Commit b3602cf6 authored by Rob Earhart's avatar Rob Earhart Committed by Scott Cyphers

Handle negative padding (#3195)

parent 00535470
...@@ -212,10 +212,6 @@ void ngraph::runtime::plaidml::ImplPad::Apply() ...@@ -212,10 +212,6 @@ void ngraph::runtime::plaidml::ImplPad::Apply()
NGRAPH_DEBUG << "Pad input dims: " << op().get_input_shape(0); NGRAPH_DEBUG << "Pad input dims: " << op().get_input_shape(0);
NGRAPH_DEBUG << "Pad output dims: " << op().get_shape(); NGRAPH_DEBUG << "Pad output dims: " << op().get_shape();
// FIXME: Compatibility hack inserted by amprocte, now that nGraph's Pad op no longer supports
// interior padding.
Shape padding_interior(op().get_padding_below().size(), 0);
auto dim_limit = op().get_shape().size(); auto dim_limit = op().get_shape().size();
bool any_zero_dims = false; bool any_zero_dims = false;
...@@ -230,16 +226,17 @@ void ngraph::runtime::plaidml::ImplPad::Apply() ...@@ -230,16 +226,17 @@ void ngraph::runtime::plaidml::ImplPad::Apply()
auto out_dsize = [&](std::size_t idx) { auto out_dsize = [&](std::size_t idx) {
std::ostringstream s; std::ostringstream s;
std::size_t total_pad = op().get_padding_below().at(idx) + op().get_padding_above().at(idx); std::ptrdiff_t total_pad =
std::size_t in_dsize = op().get_input_shape(0).at(idx); op().get_padding_below().at(idx) + op().get_padding_above().at(idx);
if (in_dsize) std::ptrdiff_t in_dsize = op().get_input_shape(0).at(idx);
{
total_pad += padding_interior.at(idx) * (in_dsize - 1);
}
if (!any_zero_dims) if (!any_zero_dims)
{ {
s << "DI" << idx + 1; s << "DI" << idx + 1;
if (total_pad) if (total_pad < 0)
{
s << " - " << (0 - total_pad);
}
else if (0 < total_pad)
{ {
s << " + " << total_pad; s << " + " << total_pad;
} }
...@@ -258,15 +255,7 @@ void ngraph::runtime::plaidml::ImplPad::Apply() ...@@ -258,15 +255,7 @@ void ngraph::runtime::plaidml::ImplPad::Apply()
{ {
s << below << " + "; s << below << " + ";
} }
auto interior = padding_interior.at(idx) + 1;
if (interior != 1)
{
s << "(d" << idx + 1 << " * " << interior << ")";
}
else
{
s << "d" << idx + 1; s << "d" << idx + 1;
}
return s.str(); return s.str();
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment