Commit 81818b79 authored by Amy Zhuang, committed by Scott Cyphers

Use Eigen kernel for REFLECT mode Pad. (#3736)

parent f1f85448
...@@ -50,7 +50,8 @@ namespace ngraph ...@@ -50,7 +50,8 @@ namespace ngraph
auto padding_above = pad->get_padding_above(); auto padding_above = pad->get_padding_above();
auto pad_mode = pad->get_pad_mode(); auto pad_mode = pad->get_pad_mode();
if (pad_mode == ngraph::op::PadMode::CONSTANT && if ((pad_mode == ngraph::op::PadMode::CONSTANT ||
pad_mode == ngraph::op::PadMode::REFLECT) &&
is_optimized_et(args[0].get_element_type())) is_optimized_et(args[0].get_element_type()))
{ {
std::function<decltype(runtime::cpu::kernel::pad_and_slice<float, 1>)> kernel; std::function<decltype(runtime::cpu::kernel::pad_and_slice<float, 1>)> kernel;
...@@ -66,6 +67,7 @@ namespace ngraph ...@@ -66,6 +67,7 @@ namespace ngraph
out_shape, out_shape,
padding_below, padding_below,
padding_above, padding_above,
pad_mode,
arg_buffer_index, arg_buffer_index,
padding_value_index, padding_value_index,
out_buffer_index](CPURuntimeContext* ctx, out_buffer_index](CPURuntimeContext* ctx,
...@@ -77,6 +79,7 @@ namespace ngraph ...@@ -77,6 +79,7 @@ namespace ngraph
out_shape, out_shape,
CoordinateDiff(padding_below.begin(), padding_below.end()), CoordinateDiff(padding_below.begin(), padding_below.end()),
CoordinateDiff(padding_above.begin(), padding_above.end()), CoordinateDiff(padding_above.begin(), padding_above.end()),
pad_mode,
ectx->arena); ectx->arena);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
...@@ -123,7 +126,8 @@ namespace ngraph ...@@ -123,7 +126,8 @@ namespace ngraph
auto padding_above = pad->get_padding_above(); auto padding_above = pad->get_padding_above();
auto pad_mode = pad->get_pad_mode(); auto pad_mode = pad->get_pad_mode();
if (pad_mode == ngraph::op::PadMode::CONSTANT && if ((pad_mode == ngraph::op::PadMode::CONSTANT ||
pad_mode == ngraph::op::PadMode::REFLECT) &&
is_optimized_et(pad->get_input_element_type(0))) is_optimized_et(pad->get_input_element_type(0)))
{ {
std::function<decltype(runtime::cpu::kernel::pad_and_slice<float, 1>)> kernel; std::function<decltype(runtime::cpu::kernel::pad_and_slice<float, 1>)> kernel;
...@@ -133,17 +137,19 @@ namespace ngraph ...@@ -133,17 +137,19 @@ namespace ngraph
arg_shape.size(), arg_shape.size(),
runtime::cpu::kernel::pad_and_slice); runtime::cpu::kernel::pad_and_slice);
auto functor = [kernel, arg_shape, out_shape, padding_below, padding_above]( auto functor =
const std::vector<void*>& inputs, std::vector<void*>& outputs) { [kernel, arg_shape, out_shape, padding_below, padding_above, pad_mode](
kernel(inputs[0], const std::vector<void*>& inputs, std::vector<void*>& outputs) {
outputs[0], kernel(inputs[0],
inputs[1], outputs[0],
arg_shape, inputs[1],
out_shape, arg_shape,
CoordinateDiff(padding_below.begin(), padding_below.end()), out_shape,
CoordinateDiff(padding_above.begin(), padding_above.end()), CoordinateDiff(padding_below.begin(), padding_below.end()),
0); CoordinateDiff(padding_above.begin(), padding_above.end()),
}; pad_mode,
0);
};
return functor; return functor;
} }
else else
......
...@@ -3153,8 +3153,26 @@ namespace ngraph ...@@ -3153,8 +3153,26 @@ namespace ngraph
auto arg0_shape = args[0].get_shape(); auto arg0_shape = args[0].get_shape();
auto result_shape = out[0].get_shape(); auto result_shape = out[0].get_shape();
std::string pad_mode_string;
switch (pad->get_pad_mode())
{
case ngraph::op::PadMode::CONSTANT:
pad_mode_string = "ngraph::op::PadMode::CONSTANT";
break;
case ngraph::op::PadMode::EDGE:
pad_mode_string = "ngraph::op::PadMode::EDGE";
break;
case ngraph::op::PadMode::REFLECT:
pad_mode_string = "ngraph::op::PadMode::REFLECT";
break;
case ngraph::op::PadMode::SYMMETRIC:
pad_mode_string = "ngraph::op::PadMode::SYMMETRIC";
break;
}
if (arg0_shape.size() == 4 && args[0].get_element_type() == element::f32 && if (arg0_shape.size() == 4 && args[0].get_element_type() == element::f32 &&
pad->get_pad_mode() == ngraph::op::PadMode::CONSTANT) (pad->get_pad_mode() == ngraph::op::PadMode::CONSTANT ||
pad->get_pad_mode() == ngraph::op::PadMode::REFLECT))
{ {
writer << "cpu::kernel::pad_4d_float32(" << args[0].get_name() << ",\n" writer << "cpu::kernel::pad_4d_float32(" << args[0].get_name() << ",\n"
<< " " << out[0].get_name() << ",\n" << " " << out[0].get_name() << ",\n"
...@@ -3164,26 +3182,12 @@ namespace ngraph ...@@ -3164,26 +3182,12 @@ namespace ngraph
<< " {" << join(pad->get_padding_below()) << " {" << join(pad->get_padding_below())
<< "},\n" << "},\n"
<< " {" << join(pad->get_padding_above()) << " {" << join(pad->get_padding_above())
<< "}, 0);\n"; << "}, \n"
<< " " << pad_mode_string << ",\n"
<< " 0);\n";
} }
else else
{ {
std::string pad_mode_string;
switch (pad->get_pad_mode())
{
case ngraph::op::PadMode::CONSTANT:
pad_mode_string = "ngraph::op::PadMode::CONSTANT";
break;
case ngraph::op::PadMode::EDGE:
pad_mode_string = "ngraph::op::PadMode::EDGE";
break;
case ngraph::op::PadMode::REFLECT:
pad_mode_string = "ngraph::op::PadMode::REFLECT";
break;
case ngraph::op::PadMode::SYMMETRIC:
pad_mode_string = "ngraph::op::PadMode::SYMMETRIC";
break;
}
writer << "reference::pad<" << out[0].get_type() << ">(" << args[0].get_name() writer << "reference::pad<" << out[0].get_type() << ">(" << args[0].get_name()
<< ",\n"; << ",\n";
writer << " " << args[1].get_name() << ",\n"; writer << " " << args[1].get_name() << ",\n";
......
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
#include <random> #include <random>
#include <vector> #include <vector>
#include "ngraph/op/pad.hpp"
// CBLAS types and wrappers // CBLAS types and wrappers
namespace cblas namespace cblas
...@@ -146,6 +148,7 @@ namespace ngraph ...@@ -146,6 +148,7 @@ namespace ngraph
const Shape& output_shape, const Shape& output_shape,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const ngraph::op::PadMode pad_mode,
int arena); int arena);
void reduce_sum_all_1d_float32(float* input, void reduce_sum_all_1d_float32(float* input,
......
...@@ -31,6 +31,7 @@ namespace ngraph ...@@ -31,6 +31,7 @@ namespace ngraph
const Shape& output_shape, const Shape& output_shape,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const ngraph::op::PadMode pad_mode,
int arena) int arena)
{ {
pad_and_slice<float, 4>(input, pad_and_slice<float, 4>(input,
...@@ -40,6 +41,7 @@ namespace ngraph ...@@ -40,6 +41,7 @@ namespace ngraph
output_shape, output_shape,
padding_below, padding_below,
padding_above, padding_above,
pad_mode,
arena); arena);
} }
} }
......
...@@ -67,15 +67,19 @@ namespace ngraph ...@@ -67,15 +67,19 @@ namespace ngraph
const Shape& output_shape, const Shape& output_shape,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const ngraph::op::PadMode pad_mode,
int arena) int arena)
{ {
Eigen::array<Eigen::Index, Rank> out_dims, in_dims; Eigen::array<Eigen::Index, Rank> out_dims, in_dims, temp_dims;
Eigen::array<Eigen::IndexPair<size_t>, Rank> padding; Eigen::array<Eigen::IndexPair<size_t>, Rank> padding;
Eigen::array<Eigen::Index, Rank> indices; Eigen::array<Eigen::Index, Rank> indices;
bool has_negative_below_padding = false;
for (size_t i = 0; i < Rank; i++) for (size_t i = 0; i < Rank; i++)
{ {
out_dims[i] = output_shape[i]; out_dims[i] = output_shape[i];
temp_dims[i] = output_shape[i];
in_dims[i] = input_shape[i]; in_dims[i] = input_shape[i];
padding[i] = { padding[i] = {
...@@ -88,6 +92,8 @@ namespace ngraph ...@@ -88,6 +92,8 @@ namespace ngraph
{ {
NGRAPH_CHECK(padding_below[i] > INT_MIN); NGRAPH_CHECK(padding_below[i] > INT_MIN);
indices[i] = -padding_below[i]; indices[i] = -padding_below[i];
temp_dims[i] -= padding_below[i];
has_negative_below_padding = true;
} }
else else
{ {
...@@ -97,12 +103,93 @@ namespace ngraph ...@@ -97,12 +103,93 @@ namespace ngraph
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> out( Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> out(
static_cast<ElementType*>(output), out_dims); static_cast<ElementType*>(output), out_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> temp(
static_cast<ElementType*>(output), temp_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> in( Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> in(
static_cast<ElementType*>(input), in_dims); static_cast<ElementType*>(input), in_dims);
out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) = if (pad_mode == ngraph::op::PadMode::CONSTANT)
in.pad(padding, *static_cast<ElementType*>(pad_value)) {
.slice(indices, out_dims); out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = in.pad(padding, *static_cast<ElementType*>(pad_value))
.slice(indices, out_dims);
}
else
{
// clang-format off
// PadMode::REFLECT
// We should have dim >= 2 for each dim.
// Example:
//
// Input shape: [4]
// Padding: 6 below, 13 above
// Output shape: [23]
//
// Input: 1 2 3 4
// Expected output: 1 2 3 4 3 2 1 2 3 4 3 2 1 2 3 4 3 2 1 2 3 4 3
// Pattern: ... | original n elements | middle (n - 2) elements of original n in reverse order |
// original n elements | middle (n - 2) elements of original n in reverse order | ...
// | 1 2 3 4 | 3 2 | 1 2 3 4 | 3 2 | 1 2 3 4 | 3 2 | 1 2 3 4 | 3
// clang-format on
// Coordinate generator for REFLECT padding: given an output coordinate, it
// computes, independently for every dimension, the input coordinate whose
// value belongs there, and returns that input element via `in(in_index)`.
//
// Captures by reference (`[&]`): `in_dims`, `padding_below` and `in` are read.
//
// Per dimension, with n = in_dims[i], the reflected sequence repeats with
// period 2n - 2. Note that n == 1 makes that period zero, so the modulo
// operations below would divide by zero; every input extent must be >= 2.
auto generator =
[&](const Eigen::array<Eigen::DenseIndex, Rank>& out_index) {
Eigen::array<Eigen::DenseIndex, Rank> in_index;
for (size_t i = 0; i < Rank; i++)
{
auto origin_length = in_dims[i];
// Negative below-padding is treated as zero here.
auto p_below = padding_below[i] >= 0 ? padding_below[i] : 0;
if (out_index[i] < p_below)
{
// Region before the original data ("padding below").
// `reverse` is the distance back from the first original element.
auto reverse = p_below - out_index[i];
auto res = reverse % (origin_length * 2 - 2);
if (res <= origin_length - 2)
{
// res in [0, n-2]: walking away from element 0, so the source
// index equals the distance itself.
in_index[i] = res;
}
else
{
// res in [n-1, 2n-3]: the walk has bounced off the last element
// and is heading back, giving source indices n-1 down to 1.
in_index[i] = origin_length * 2 - 2 - res;
}
}
else if (out_index[i] < in_dims[i] + p_below)
{
// Inside the original data: shift by the below-padding.
in_index[i] = out_index[i] - p_below;
}
else
{
// Region after the original data ("padding above").
// `pos` is the zero-based offset past the last original element.
auto pos = out_index[i] - in_dims[i] - p_below;
auto res = pos % (origin_length * 2 - 2);
if (res < origin_length - 2)
{
// res in [0, n-3]: walking back from the last element, giving
// source indices n-2 down to 1.
in_index[i] = origin_length - 2 - res;
}
else
{
// res in [n-2, 2n-3]: the walk has reached element 0 and moves
// forward again, giving source indices 0 up to n-1.
in_index[i] = res - (origin_length - 2);
}
}
}
return in(in_index);
};
if (has_negative_below_padding)
{
out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = temp.generate(generator).slice(indices, out_dims);
}
else
{
out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = out.generate(generator);
}
}
} }
template <typename ElementType> template <typename ElementType>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment