Commit 81818b79 authored by Amy Zhuang's avatar Amy Zhuang Committed by Scott Cyphers

Use Eigen kernel for REFLECT mode Pad. (#3736)

parent f1f85448
...@@ -50,7 +50,8 @@ namespace ngraph ...@@ -50,7 +50,8 @@ namespace ngraph
auto padding_above = pad->get_padding_above(); auto padding_above = pad->get_padding_above();
auto pad_mode = pad->get_pad_mode(); auto pad_mode = pad->get_pad_mode();
if (pad_mode == ngraph::op::PadMode::CONSTANT && if ((pad_mode == ngraph::op::PadMode::CONSTANT ||
pad_mode == ngraph::op::PadMode::REFLECT) &&
is_optimized_et(args[0].get_element_type())) is_optimized_et(args[0].get_element_type()))
{ {
std::function<decltype(runtime::cpu::kernel::pad_and_slice<float, 1>)> kernel; std::function<decltype(runtime::cpu::kernel::pad_and_slice<float, 1>)> kernel;
...@@ -66,6 +67,7 @@ namespace ngraph ...@@ -66,6 +67,7 @@ namespace ngraph
out_shape, out_shape,
padding_below, padding_below,
padding_above, padding_above,
pad_mode,
arg_buffer_index, arg_buffer_index,
padding_value_index, padding_value_index,
out_buffer_index](CPURuntimeContext* ctx, out_buffer_index](CPURuntimeContext* ctx,
...@@ -77,6 +79,7 @@ namespace ngraph ...@@ -77,6 +79,7 @@ namespace ngraph
out_shape, out_shape,
CoordinateDiff(padding_below.begin(), padding_below.end()), CoordinateDiff(padding_below.begin(), padding_below.end()),
CoordinateDiff(padding_above.begin(), padding_above.end()), CoordinateDiff(padding_above.begin(), padding_above.end()),
pad_mode,
ectx->arena); ectx->arena);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
...@@ -123,7 +126,8 @@ namespace ngraph ...@@ -123,7 +126,8 @@ namespace ngraph
auto padding_above = pad->get_padding_above(); auto padding_above = pad->get_padding_above();
auto pad_mode = pad->get_pad_mode(); auto pad_mode = pad->get_pad_mode();
if (pad_mode == ngraph::op::PadMode::CONSTANT && if ((pad_mode == ngraph::op::PadMode::CONSTANT ||
pad_mode == ngraph::op::PadMode::REFLECT) &&
is_optimized_et(pad->get_input_element_type(0))) is_optimized_et(pad->get_input_element_type(0)))
{ {
std::function<decltype(runtime::cpu::kernel::pad_and_slice<float, 1>)> kernel; std::function<decltype(runtime::cpu::kernel::pad_and_slice<float, 1>)> kernel;
...@@ -133,7 +137,8 @@ namespace ngraph ...@@ -133,7 +137,8 @@ namespace ngraph
arg_shape.size(), arg_shape.size(),
runtime::cpu::kernel::pad_and_slice); runtime::cpu::kernel::pad_and_slice);
auto functor = [kernel, arg_shape, out_shape, padding_below, padding_above]( auto functor =
[kernel, arg_shape, out_shape, padding_below, padding_above, pad_mode](
const std::vector<void*>& inputs, std::vector<void*>& outputs) { const std::vector<void*>& inputs, std::vector<void*>& outputs) {
kernel(inputs[0], kernel(inputs[0],
outputs[0], outputs[0],
...@@ -142,6 +147,7 @@ namespace ngraph ...@@ -142,6 +147,7 @@ namespace ngraph
out_shape, out_shape,
CoordinateDiff(padding_below.begin(), padding_below.end()), CoordinateDiff(padding_below.begin(), padding_below.end()),
CoordinateDiff(padding_above.begin(), padding_above.end()), CoordinateDiff(padding_above.begin(), padding_above.end()),
pad_mode,
0); 0);
}; };
return functor; return functor;
......
...@@ -3153,21 +3153,6 @@ namespace ngraph ...@@ -3153,21 +3153,6 @@ namespace ngraph
auto arg0_shape = args[0].get_shape(); auto arg0_shape = args[0].get_shape();
auto result_shape = out[0].get_shape(); auto result_shape = out[0].get_shape();
if (arg0_shape.size() == 4 && args[0].get_element_type() == element::f32 &&
pad->get_pad_mode() == ngraph::op::PadMode::CONSTANT)
{
writer << "cpu::kernel::pad_4d_float32(" << args[0].get_name() << ",\n"
<< " " << out[0].get_name() << ",\n"
<< " " << args[1].get_name() << ",\n"
<< " {" << join(arg0_shape) << "},\n"
<< " {" << join(result_shape) << "},\n"
<< " {" << join(pad->get_padding_below())
<< "},\n"
<< " {" << join(pad->get_padding_above())
<< "}, 0);\n";
}
else
{
std::string pad_mode_string; std::string pad_mode_string;
switch (pad->get_pad_mode()) switch (pad->get_pad_mode())
{ {
...@@ -3184,6 +3169,25 @@ namespace ngraph ...@@ -3184,6 +3169,25 @@ namespace ngraph
pad_mode_string = "ngraph::op::PadMode::SYMMETRIC"; pad_mode_string = "ngraph::op::PadMode::SYMMETRIC";
break; break;
} }
if (arg0_shape.size() == 4 && args[0].get_element_type() == element::f32 &&
(pad->get_pad_mode() == ngraph::op::PadMode::CONSTANT ||
pad->get_pad_mode() == ngraph::op::PadMode::REFLECT))
{
writer << "cpu::kernel::pad_4d_float32(" << args[0].get_name() << ",\n"
<< " " << out[0].get_name() << ",\n"
<< " " << args[1].get_name() << ",\n"
<< " {" << join(arg0_shape) << "},\n"
<< " {" << join(result_shape) << "},\n"
<< " {" << join(pad->get_padding_below())
<< "},\n"
<< " {" << join(pad->get_padding_above())
<< "}, \n"
<< " " << pad_mode_string << ",\n"
<< " 0);\n";
}
else
{
writer << "reference::pad<" << out[0].get_type() << ">(" << args[0].get_name() writer << "reference::pad<" << out[0].get_type() << ">(" << args[0].get_name()
<< ",\n"; << ",\n";
writer << " " << args[1].get_name() << ",\n"; writer << " " << args[1].get_name() << ",\n";
......
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
#include <random> #include <random>
#include <vector> #include <vector>
#include "ngraph/op/pad.hpp"
// CBLAS types and wrappers // CBLAS types and wrappers
namespace cblas namespace cblas
...@@ -146,6 +148,7 @@ namespace ngraph ...@@ -146,6 +148,7 @@ namespace ngraph
const Shape& output_shape, const Shape& output_shape,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const ngraph::op::PadMode pad_mode,
int arena); int arena);
void reduce_sum_all_1d_float32(float* input, void reduce_sum_all_1d_float32(float* input,
......
...@@ -31,6 +31,7 @@ namespace ngraph ...@@ -31,6 +31,7 @@ namespace ngraph
const Shape& output_shape, const Shape& output_shape,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const ngraph::op::PadMode pad_mode,
int arena) int arena)
{ {
pad_and_slice<float, 4>(input, pad_and_slice<float, 4>(input,
...@@ -40,6 +41,7 @@ namespace ngraph ...@@ -40,6 +41,7 @@ namespace ngraph
output_shape, output_shape,
padding_below, padding_below,
padding_above, padding_above,
pad_mode,
arena); arena);
} }
} }
......
...@@ -67,15 +67,19 @@ namespace ngraph ...@@ -67,15 +67,19 @@ namespace ngraph
const Shape& output_shape, const Shape& output_shape,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const ngraph::op::PadMode pad_mode,
int arena) int arena)
{ {
Eigen::array<Eigen::Index, Rank> out_dims, in_dims; Eigen::array<Eigen::Index, Rank> out_dims, in_dims, temp_dims;
Eigen::array<Eigen::IndexPair<size_t>, Rank> padding; Eigen::array<Eigen::IndexPair<size_t>, Rank> padding;
Eigen::array<Eigen::Index, Rank> indices; Eigen::array<Eigen::Index, Rank> indices;
bool has_negative_below_padding = false;
for (size_t i = 0; i < Rank; i++) for (size_t i = 0; i < Rank; i++)
{ {
out_dims[i] = output_shape[i]; out_dims[i] = output_shape[i];
temp_dims[i] = output_shape[i];
in_dims[i] = input_shape[i]; in_dims[i] = input_shape[i];
padding[i] = { padding[i] = {
...@@ -88,6 +92,8 @@ namespace ngraph ...@@ -88,6 +92,8 @@ namespace ngraph
{ {
NGRAPH_CHECK(padding_below[i] > INT_MIN); NGRAPH_CHECK(padding_below[i] > INT_MIN);
indices[i] = -padding_below[i]; indices[i] = -padding_below[i];
temp_dims[i] -= padding_below[i];
has_negative_below_padding = true;
} }
else else
{ {
...@@ -97,13 +103,94 @@ namespace ngraph ...@@ -97,13 +103,94 @@ namespace ngraph
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> out( Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> out(
static_cast<ElementType*>(output), out_dims); static_cast<ElementType*>(output), out_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> temp(
static_cast<ElementType*>(output), temp_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> in( Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> in(
static_cast<ElementType*>(input), in_dims); static_cast<ElementType*>(input), in_dims);
out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) = if (pad_mode == ngraph::op::PadMode::CONSTANT)
in.pad(padding, *static_cast<ElementType*>(pad_value)) {
out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = in.pad(padding, *static_cast<ElementType*>(pad_value))
.slice(indices, out_dims); .slice(indices, out_dims);
} }
else
{
// clang-format off
// PadMode::REFLECT
// We should have dim >= 2 for each dim.
// Example:
//
// Input shape: [4]
// Padding: 6 below, 13 above
// Output shape: [23]
//
// Input: 1 2 3 4
// Expected output: 1 2 3 4 3 2 1 2 3 4 3 2 1 2 3 4 3 2 1 2 3 4 3
// Pattern: ... | original n elements | middle (n - 2) elements of original n in reverse order |
// original n elements | middle (n - 2) elements of original n in reverse order | ...
// | 1 2 3 4 | 3 2 | 1 2 3 4 | 3 2 | 1 2 3 4 | 3 2 | 1 2 3 4 | 3
// clang-format on
auto generator =
[&](const Eigen::array<Eigen::DenseIndex, Rank>& out_index) {
Eigen::array<Eigen::DenseIndex, Rank> in_index;
for (size_t i = 0; i < Rank; i++)
{
auto origin_length = in_dims[i];
auto p_below = padding_below[i] >= 0 ? padding_below[i] : 0;
if (out_index[i] < p_below)
{
// padding below
auto reverse = p_below - out_index[i];
auto res = reverse % (origin_length * 2 - 2);
if (res <= origin_length - 2)
{
// copy one of the middle n-2 items
in_index[i] = res;
}
else
{
// copy one of the n items
in_index[i] = origin_length * 2 - 2 - res;
}
}
else if (out_index[i] < in_dims[i] + p_below)
{
// original
in_index[i] = out_index[i] - p_below;
}
else
{
// padding above
auto pos = out_index[i] - in_dims[i] - p_below;
auto res = pos % (origin_length * 2 - 2);
if (res < origin_length - 2)
{
// copy one of the middle n-2 items
in_index[i] = origin_length - 2 - res;
}
else
{
// copy one of the n items
in_index[i] = res - (origin_length - 2);
}
}
}
return in(in_index);
};
if (has_negative_below_padding)
{
out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = temp.generate(generator).slice(indices, out_dims);
}
else
{
out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = out.generate(generator);
}
}
}
template <typename ElementType> template <typename ElementType>
void pad_ref(const void* arg0, void pad_ref(const void* arg0,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment