Commit 346607fe authored by Christian Convey, committed by Scott Cyphers

Fixes NGMX-339: Adds option to AvgPoolBackprop padding. (#535)

parent 3e5aa370
...@@ -259,13 +259,15 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape, ...@@ -259,13 +259,15 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above) const Shape& padding_above,
bool include_padding_in_avg_computation)
: RequiresTensorViewArgs("AvgPoolBackprop", {delta}) : RequiresTensorViewArgs("AvgPoolBackprop", {delta})
, m_forward_arg_shape(forward_arg_shape) , m_forward_arg_shape(forward_arg_shape)
, m_window_shape(window_shape) , m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
, m_padding_above(padding_above) , m_padding_above(padding_above)
, m_include_padding_in_avg_computation(include_padding_in_avg_computation)
{ {
// -- // --
// TODO: de-duplicate this code from AvgPool::AvgPool. // TODO: de-duplicate this code from AvgPool::AvgPool.
...@@ -386,6 +388,47 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape, ...@@ -386,6 +388,47 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
window_movement_strides[i])); window_movement_strides[i]));
} }
//
// Guard against averaging over an empty set of tensor elements. When padding is excluded
// from the average, a window positioned wholly inside a padding region would cover no
// real elements, so such configurations are rejected here.
//
if (!include_padding_in_avg_computation)
{
    for (size_t axis = 0; axis < spatial_dimension_count; ++axis)
    {
        const size_t virtual_extent = input_item_virtual_shape[axis];
        const size_t window_extent = window_shape[axis];
        const size_t stride = window_movement_strides[axis];
        const size_t pad_below = padding_below[axis];
        const size_t pad_above = padding_above[axis];

        // Lower edge: the lowest window placement is known, so it suffices to compare
        // the window extent against the below-padding extent.
        const bool window_inside_pad_below = (pad_below > 0) && (window_extent <= pad_below);
        if (window_inside_pad_below)
        {
            throw ngraph_error(
                "AvgPoolBackprop window will sometimes reside entirely within the "
                "padding-below region, but the op disregards padding elements.");
        }

        // Upper edge: find the lower offset of the last window placement (the largest
        // stride multiple not exceeding virtual_extent - window_extent) and compare it
        // with the offset at which the above-padding begins.
        const size_t stride_count = (virtual_extent - window_extent) / stride;
        const size_t last_window_offset = stride_count * stride;
        const size_t pad_above_begin = virtual_extent - pad_above;
        const bool window_inside_pad_above =
            (pad_above > 0) && (last_window_offset >= pad_above_begin);
        if (window_inside_pad_above)
        {
            throw ngraph_error(
                "AvgPoolBackprop window will sometimes reside entirely within the "
                "padding-above region, but the op disregards padding elements.");
        }
    }
}
// //
// Construct result shape: NCDo. // Construct result shape: NCDo.
// //
...@@ -413,6 +456,7 @@ void op::AvgPool::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -413,6 +456,7 @@ void op::AvgPool::generate_adjoints(autodiff::Adjoints& adjoints,
m_window_shape, m_window_shape,
m_window_movement_strides, m_window_movement_strides,
m_padding_below, m_padding_below,
m_padding_above); m_padding_above,
m_include_padding_in_avg_computation);
adjoints.add_delta(operand, backprop); adjoints.add_delta(operand, backprop);
} }
...@@ -117,7 +117,8 @@ namespace ngraph ...@@ -117,7 +117,8 @@ namespace ngraph
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above); const Shape& padding_above,
bool include_padding_in_avg_computation);
virtual std::shared_ptr<Node> copy_with_new_args( virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override const std::vector<std::shared_ptr<Node>>& new_args) const override
...@@ -132,7 +133,8 @@ namespace ngraph ...@@ -132,7 +133,8 @@ namespace ngraph
m_window_shape, m_window_shape,
m_window_movement_strides, m_window_movement_strides,
m_padding_below, m_padding_below,
m_padding_above); m_padding_above,
m_include_padding_in_avg_computation);
return std::shared_ptr<op::AvgPoolBackprop>(avpn); return std::shared_ptr<op::AvgPoolBackprop>(avpn);
} }
...@@ -141,12 +143,18 @@ namespace ngraph ...@@ -141,12 +143,18 @@ namespace ngraph
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Shape& get_padding_below() const { return m_padding_below; } const Shape& get_padding_below() const { return m_padding_below; }
const Shape& get_padding_above() const { return m_padding_above; } const Shape& get_padding_above() const { return m_padding_above; }
// Returns the include_padding_in_avg_computation flag stored by the constructor.
// The constructor's validation comments indicate that when this is false, padding
// elements are disregarded when computing each window's average.
bool get_include_padding_in_avg_computation() const
{
return m_include_padding_in_avg_computation;
}
protected: protected:
Shape m_forward_arg_shape; Shape m_forward_arg_shape;
Shape m_window_shape; Shape m_window_shape;
Strides m_window_movement_strides; Strides m_window_movement_strides;
Shape m_padding_below; Shape m_padding_below;
Shape m_padding_above; Shape m_padding_above;
bool m_include_padding_in_avg_computation;
}; };
} }
} }
...@@ -2618,9 +2618,14 @@ namespace ngraph ...@@ -2618,9 +2618,14 @@ namespace ngraph
writer << "memory result = memory({result_desc, cpu_engine}, " writer << "memory result = memory({result_desc, cpu_engine}, "
<< out[0].get_name() << ");\n"; << out[0].get_name() << ");\n";
// Dummy forward primitive descriptor to keep MKLDNN happy // Dummy forward primitive descriptor to keep MKLDNN happy
const char* algorithm_enumerator =
apb->get_include_padding_in_avg_computation()
? "algorithm::pooling_avg_include_padding"
: "algorithm::pooling_avg_exclude_padding";
writer << "pooling_forward::primitive_desc fwd_pd = " writer << "pooling_forward::primitive_desc fwd_pd = "
"pooling_forward::primitive_desc(" "pooling_forward::primitive_desc("
<< "{prop_kind::forward, algorithm::pooling_avg_exclude_padding, " << "{prop_kind::forward, " << algorithm_enumerator << ", "
<< "result_desc, input_data_desc, {" << "result_desc, input_data_desc, {"
<< join(apb->get_window_movement_strides()) << "}, {" << join(apb->get_window_movement_strides()) << "}, {"
<< join(apb->get_window_shape()) << "}, " << join(apb->get_window_shape()) << "}, "
...@@ -2629,7 +2634,7 @@ namespace ngraph ...@@ -2629,7 +2634,7 @@ namespace ngraph
<< "padding_kind::zero}, cpu_engine);\n"; << "padding_kind::zero}, cpu_engine);\n";
writer writer
<< "auto avg_pooling = pooling_backward(pooling_backward::primitive_desc(" << "auto avg_pooling = pooling_backward(pooling_backward::primitive_desc("
<< "pooling_backward::desc(algorithm::pooling_avg_exclude_padding, " << "pooling_backward::desc(" << algorithm_enumerator << ", "
<< "result_desc, input_data_desc, {" << "result_desc, input_data_desc, {"
<< join(apb->get_window_movement_strides()) << "}, {" << join(apb->get_window_movement_strides()) << "}, {"
<< join(apb->get_window_shape()) << "}, " << join(apb->get_window_shape()) << "}, "
...@@ -2653,7 +2658,11 @@ namespace ngraph ...@@ -2653,7 +2658,11 @@ namespace ngraph
writer << " {" << join(apb->get_window_movement_strides()) writer << " {" << join(apb->get_window_movement_strides())
<< "},\n"; << "},\n";
writer << " {" << join(apb->get_padding_below()) << "},\n"; writer << " {" << join(apb->get_padding_below()) << "},\n";
writer << " {" << join(apb->get_padding_above()) << "}\n"; writer << " {" << join(apb->get_padding_above()) << "},\n";
writer << " "
<< ngraph::to_cplusplus_sourcecode_literal(
apb->get_include_padding_in_avg_computation())
<< "\n";
writer << " );\n"; writer << " );\n";
} }
} }
......
...@@ -291,7 +291,8 @@ private: ...@@ -291,7 +291,8 @@ private:
apb->get_window_shape(), apb->get_window_shape(),
apb->get_window_movement_strides(), apb->get_window_movement_strides(),
apb->get_padding_below(), apb->get_padding_below(),
apb->get_padding_above()); apb->get_padding_above(),
apb->get_include_padding_in_avg_computation());
} }
else if (node_op == "Broadcast") else if (node_op == "Broadcast")
{ {
......
...@@ -39,7 +39,8 @@ namespace ngraph ...@@ -39,7 +39,8 @@ namespace ngraph
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above) const Shape& padding_above,
bool include_padding_in_avg_computation)
{ {
CoordinateTransform out_transform(out_shape); CoordinateTransform out_transform(out_shape);
...@@ -100,7 +101,8 @@ namespace ngraph ...@@ -100,7 +101,8 @@ namespace ngraph
for (const Coordinate& source_window_coord : source_window_transform) for (const Coordinate& source_window_coord : source_window_transform)
{ {
if (source_window_transform.has_source_coordinate(source_window_coord)) if (source_window_transform.has_source_coordinate(source_window_coord) ||
include_padding_in_avg_computation)
{ {
num_elements_in_window++; num_elements_in_window++;
} }
......
...@@ -371,12 +371,15 @@ static shared_ptr<ngraph::Function> ...@@ -371,12 +371,15 @@ static shared_ptr<ngraph::Function>
node_js.at("window_movement_strides").get<vector<size_t>>(); node_js.at("window_movement_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<size_t>>(); auto padding_below = node_js.at("padding_below").get<vector<size_t>>();
auto padding_above = node_js.at("padding_above").get<vector<size_t>>(); auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
auto include_padding_in_avg_computation =
node_js.at("include_padding_in_avg_computation").get<bool>();
node = make_shared<op::AvgPoolBackprop>(forward_arg_shape, node = make_shared<op::AvgPoolBackprop>(forward_arg_shape,
args[0], args[0],
window_shape, window_shape,
window_movement_strides, window_movement_strides,
padding_below, padding_below,
padding_above); padding_above,
include_padding_in_avg_computation);
} }
else if (node_op == "BatchNorm") else if (node_op == "BatchNorm")
{ {
...@@ -879,6 +882,7 @@ static json write(const Node& n) ...@@ -879,6 +882,7 @@ static json write(const Node& n)
node["window_movement_strides"] = tmp->get_window_movement_strides(); node["window_movement_strides"] = tmp->get_window_movement_strides();
node["padding_below"] = tmp->get_padding_below(); node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above(); node["padding_above"] = tmp->get_padding_above();
node["include_padding_in_avg_computation"] = tmp->get_include_padding_in_avg_computation();
} }
else if (node_op == "BatchNorm") else if (node_op == "BatchNorm")
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment