Commit 346607fe authored by Christian Convey's avatar Christian Convey Committed by Scott Cyphers

Fixes NGMX-339: Adds option to AvgPoolBackprop padding. (#535)

parent 3e5aa370
......@@ -259,13 +259,15 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above)
const Shape& padding_above,
bool include_padding_in_avg_computation)
: RequiresTensorViewArgs("AvgPoolBackprop", {delta})
, m_forward_arg_shape(forward_arg_shape)
, m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below)
, m_padding_above(padding_above)
, m_include_padding_in_avg_computation(include_padding_in_avg_computation)
{
// --
// TODO: de-duplicate this code from AvgPool::AvgPool.
......@@ -386,6 +388,47 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
window_movement_strides[i]));
}
//
// Make sure we're not going to have to compute average over an empty set of tensor elements.
// That will happen if the sliding window ever resides entirely over the padding area AND
// we're planning to disregard padding when computing the window's average.
//
if (!include_padding_in_avg_computation)
{
for (size_t i = 0; i < spatial_dimension_count; i++)
{
const size_t dim_virtual_size = input_item_virtual_shape[i];
const size_t dim_window_size = window_shape[i];
const size_t dim_stride = window_movement_strides[i];
const size_t dim_padding_below = padding_below[i];
const size_t dim_padding_above = padding_above[i];
// Checking the lower edge of each dimension is easy, because there's no mystery
// regarding the window's lower-edge placement...
if ((dim_padding_below > 0) && (dim_window_size <= dim_padding_below))
{
throw ngraph_error(
"AvgPoolBackprop window will sometimes reside entirely within the "
"padding-below region, but the op disregards padding elements.");
}
// Now check the upper-bound...
{
const size_t dim_num_strides = (dim_virtual_size - dim_window_size) / dim_stride;
const size_t dim_window_max_lower_offset = dim_num_strides * dim_stride;
const size_t dim_padding_above_start_offset = dim_virtual_size - dim_padding_above;
if ((dim_padding_above > 0) &&
(dim_window_max_lower_offset >= dim_padding_above_start_offset))
{
throw ngraph_error(
"AvgPoolBackprop window will sometimes reside entirely within the "
"padding-above region, but the op disregards padding elements.");
}
}
}
}
//
// Construct result shape: NCDo.
//
......@@ -413,6 +456,7 @@ void op::AvgPool::generate_adjoints(autodiff::Adjoints& adjoints,
m_window_shape,
m_window_movement_strides,
m_padding_below,
m_padding_above);
m_padding_above,
m_include_padding_in_avg_computation);
adjoints.add_delta(operand, backprop);
}
......@@ -117,7 +117,8 @@ namespace ngraph
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above);
const Shape& padding_above,
bool include_padding_in_avg_computation);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -132,7 +133,8 @@ namespace ngraph
m_window_shape,
m_window_movement_strides,
m_padding_below,
m_padding_above);
m_padding_above,
m_include_padding_in_avg_computation);
return std::shared_ptr<op::AvgPoolBackprop>(avpn);
}
......@@ -141,12 +143,18 @@ namespace ngraph
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Shape& get_padding_below() const { return m_padding_below; }
const Shape& get_padding_above() const { return m_padding_above; }
bool get_include_padding_in_avg_computation() const
{
return m_include_padding_in_avg_computation;
}
protected:
Shape m_forward_arg_shape;
Shape m_window_shape;
Strides m_window_movement_strides;
Shape m_padding_below;
Shape m_padding_above;
bool m_include_padding_in_avg_computation;
};
}
}
......@@ -2618,9 +2618,14 @@ namespace ngraph
writer << "memory result = memory({result_desc, cpu_engine}, "
<< out[0].get_name() << ");\n";
// Dummy forward primitive descriptor to keep MKLDNN happy
const char* algorithm_enumerator =
apb->get_include_padding_in_avg_computation()
? "algorithm::pooling_avg_include_padding"
: "algorithm::pooling_avg_exclude_padding";
writer << "pooling_forward::primitive_desc fwd_pd = "
"pooling_forward::primitive_desc("
<< "{prop_kind::forward, algorithm::pooling_avg_exclude_padding, "
<< "{prop_kind::forward, " << algorithm_enumerator << ", "
<< "result_desc, input_data_desc, {"
<< join(apb->get_window_movement_strides()) << "}, {"
<< join(apb->get_window_shape()) << "}, "
......@@ -2629,7 +2634,7 @@ namespace ngraph
<< "padding_kind::zero}, cpu_engine);\n";
writer
<< "auto avg_pooling = pooling_backward(pooling_backward::primitive_desc("
<< "pooling_backward::desc(algorithm::pooling_avg_exclude_padding, "
<< "pooling_backward::desc(" << algorithm_enumerator << ", "
<< "result_desc, input_data_desc, {"
<< join(apb->get_window_movement_strides()) << "}, {"
<< join(apb->get_window_shape()) << "}, "
......@@ -2653,7 +2658,11 @@ namespace ngraph
writer << " {" << join(apb->get_window_movement_strides())
<< "},\n";
writer << " {" << join(apb->get_padding_below()) << "},\n";
writer << " {" << join(apb->get_padding_above()) << "}\n";
writer << " {" << join(apb->get_padding_above()) << "},\n";
writer << " "
<< ngraph::to_cplusplus_sourcecode_literal(
apb->get_include_padding_in_avg_computation())
<< "\n";
writer << " );\n";
}
}
......
......@@ -291,7 +291,8 @@ private:
apb->get_window_shape(),
apb->get_window_movement_strides(),
apb->get_padding_below(),
apb->get_padding_above());
apb->get_padding_above(),
apb->get_include_padding_in_avg_computation());
}
else if (node_op == "Broadcast")
{
......
......@@ -39,7 +39,8 @@ namespace ngraph
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above)
const Shape& padding_above,
bool include_padding_in_avg_computation)
{
CoordinateTransform out_transform(out_shape);
......@@ -100,7 +101,8 @@ namespace ngraph
for (const Coordinate& source_window_coord : source_window_transform)
{
if (source_window_transform.has_source_coordinate(source_window_coord))
if (source_window_transform.has_source_coordinate(source_window_coord) ||
include_padding_in_avg_computation)
{
num_elements_in_window++;
}
......
......@@ -371,12 +371,15 @@ static shared_ptr<ngraph::Function>
node_js.at("window_movement_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<size_t>>();
auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
auto include_padding_in_avg_computation =
node_js.at("include_padding_in_avg_computation").get<bool>();
node = make_shared<op::AvgPoolBackprop>(forward_arg_shape,
args[0],
window_shape,
window_movement_strides,
padding_below,
padding_above);
padding_above,
include_padding_in_avg_computation);
}
else if (node_op == "BatchNorm")
{
......@@ -879,6 +882,7 @@ static json write(const Node& n)
node["window_movement_strides"] = tmp->get_window_movement_strides();
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
node["include_padding_in_avg_computation"] = tmp->get_include_padding_in_avg_computation();
}
else if (node_op == "BatchNorm")
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment