Commit 04f212b7 authored by Tomasz Socha, committed by Scott Cyphers

[Spec] Add v1::AvgPool and v1::MaxPool (#3591)

* Add new enum: RoundingType for pooling operations

* Add v1::AvgPool op

* Add v1::MaxPool op

* Fix comments format

* Fix problem with forward declaration

* new UT & fix some bugs
parent 640295cf
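A quick orientation before the diff: the commit keeps the existing ops as op::v0::AvgPool/op::v0::MaxPool and adds op::v1 variants with reworked constructors. A minimal usage sketch (illustrative only, not code from this commit; signatures taken from the headers below) showing the v1 argument reordering (strides first, kernel last) and the replacement of include_padding_in_avg_computation/ceil_mode by exclude_pad and op::RoundingType:

#include <memory>
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/parameter.hpp"

using namespace ngraph;

std::shared_ptr<Node> make_pools(const std::shared_ptr<op::Parameter>& data)
{
    // v0 API: window shape and strides first, boolean padding flag last.
    auto p0 = std::make_shared<op::v0::AvgPool>(
        data, Shape{2, 2}, Strides{2, 2}, Shape{0, 0}, Shape{0, 0}, false);
    // v1 API: strides first, kernel last; exclude_pad is the inverse of the
    // old include_padding_in_avg_computation flag.
    auto p1 = std::make_shared<op::v1::AvgPool>(data,
                                                Strides{2, 2},
                                                Shape{0, 0},
                                                Shape{0, 0},
                                                Shape{2, 2},
                                                true,
                                                op::RoundingType::FLOOR,
                                                op::PadType::EXPLICIT);
    return p1;
}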
@@ -21,9 +21,10 @@
using namespace std;
using namespace ngraph;
-constexpr NodeTypeInfo op::AvgPool::type_info;
+// *** AvgPool OP SET 0 ***
+constexpr NodeTypeInfo op::v0::AvgPool::type_info;
-op::AvgPool::AvgPool(const Output<Node>& arg,
+op::v0::AvgPool::AvgPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
@@ -43,7 +44,7 @@ op::AvgPool::AvgPool(const Output<Node>& arg,
constructor_validate_and_infer_types();
}
-op::AvgPool::AvgPool(const Output<Node>& arg,
+op::v0::AvgPool::AvgPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
@@ -61,7 +62,7 @@ op::AvgPool::AvgPool(const Output<Node>& arg,
{
}
-op::AvgPool::AvgPool(const Output<Node>& arg,
+op::v0::AvgPool::AvgPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
@@ -77,7 +78,7 @@ op::AvgPool::AvgPool(const Output<Node>& arg,
{
}
-void op::AvgPool::validate_and_infer_types()
+void op::v0::AvgPool::validate_and_infer_types()
{
if (0 == m_window_movement_strides.size())
{
@@ -130,92 +131,93 @@ void op::AvgPool::validate_and_infer_types()
m_ceil_mode));
}
-op::AvgPool::AvgPool(const Output<Node>& arg,
+op::v0::AvgPool::AvgPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides)
: AvgPool(arg, window_shape, window_movement_strides, Shape(), Shape(), false)
{
}
-op::AvgPool::AvgPool(const Output<Node>& arg, const Shape& window_shape)
+op::v0::AvgPool::AvgPool(const Output<Node>& arg, const Shape& window_shape)
: AvgPool(arg, window_shape, Strides(), Shape(), Shape(), false)
{
}
-const Shape& op::AvgPool::get_window_shape() const
+const Shape& op::v0::AvgPool::get_window_shape() const
{
return m_window_shape;
}
-void op::AvgPool::set_window_shape(const Shape& window_shape)
+void op::v0::AvgPool::set_window_shape(const Shape& window_shape)
{
m_window_shape = window_shape;
}
-const Strides& op::AvgPool::get_window_movement_strides() const
+const Strides& op::v0::AvgPool::get_window_movement_strides() const
{
return m_window_movement_strides;
}
-void op::AvgPool::set_window_movement_strides(const Strides& window_movement_strides)
+void op::v0::AvgPool::set_window_movement_strides(const Strides& window_movement_strides)
{
m_window_movement_strides = window_movement_strides;
}
-const Shape& op::AvgPool::get_padding_below() const
+const Shape& op::v0::AvgPool::get_padding_below() const
{
return m_padding_below;
}
-void op::AvgPool::set_padding_below(const Shape& padding_below)
+void op::v0::AvgPool::set_padding_below(const Shape& padding_below)
{
m_padding_below = padding_below;
}
-const Shape& op::AvgPool::get_padding_above() const
+const Shape& op::v0::AvgPool::get_padding_above() const
{
return m_padding_above;
}
-void op::AvgPool::set_padding_above(const Shape& padding_above)
+void op::v0::AvgPool::set_padding_above(const Shape& padding_above)
{
m_padding_above = padding_above;
}
-bool op::AvgPool::get_include_padding_in_avg_computation() const
+bool op::v0::AvgPool::get_include_padding_in_avg_computation() const
{
return m_include_padding_in_avg_computation;
}
-void op::AvgPool::set_include_padding_in_avg_computation(bool include_padding_in_avg_computation)
+void op::v0::AvgPool::set_include_padding_in_avg_computation(
+    bool include_padding_in_avg_computation)
{
m_include_padding_in_avg_computation = include_padding_in_avg_computation;
}
-const op::PadType& op::AvgPool::get_pad_type() const
+const op::PadType& op::v0::AvgPool::get_pad_type() const
{
return m_pad_type;
}
-void op::AvgPool::set_pad_type(const op::PadType& pad_type)
+void op::v0::AvgPool::set_pad_type(const op::PadType& pad_type)
{
m_pad_type = pad_type;
}
-bool op::AvgPool::get_ceil_mode() const
+bool op::v0::AvgPool::get_ceil_mode() const
{
return m_ceil_mode;
}
-void op::AvgPool::set_ceil_mode(bool ceil_mode)
+void op::v0::AvgPool::set_ceil_mode(bool ceil_mode)
{
m_ceil_mode = ceil_mode;
}
-shared_ptr<Node> op::AvgPool::copy_with_new_args(const NodeVector& new_args) const
+shared_ptr<Node> op::v0::AvgPool::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
-return make_shared<AvgPool>(new_args.at(0),
+return make_shared<v0::AvgPool>(new_args.at(0),
m_window_shape,
m_window_movement_strides,
m_padding_below,
@@ -225,13 +227,13 @@ shared_ptr<Node> op::AvgPool::copy_with_new_args(const NodeVector& new_args) con
m_ceil_mode);
}
-constexpr NodeTypeInfo op::AvgPoolBackprop::type_info;
+constexpr NodeTypeInfo op::v0::AvgPoolBackprop::type_info;
-shared_ptr<Node> op::AvgPool::get_default_value() const
+shared_ptr<Node> op::v0::AvgPool::get_default_value() const
{
-return ngraph::make_constant_from_string("0", get_element_type(), get_shape());
+return Constant::create(get_element_type(), get_shape(), {0});
}
-op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
+op::v0::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
const shared_ptr<Node>& delta,
const Shape& window_shape,
const Strides& window_movement_strides,
@@ -249,7 +251,7 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
constructor_validate_and_infer_types();
}
-void op::AvgPoolBackprop::validate_and_infer_types()
+void op::v0::AvgPoolBackprop::validate_and_infer_types()
{
// infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
// now still take Shape (no negative padding).
@@ -283,81 +285,80 @@ void op::AvgPoolBackprop::validate_and_infer_types()
set_output_type(0, get_input_element_type(0), m_forward_arg_shape);
}
-const Shape& op::AvgPoolBackprop::get_forward_arg_shape() const
+const Shape& op::v0::AvgPoolBackprop::get_forward_arg_shape() const
{
return m_forward_arg_shape;
}
-void op::AvgPoolBackprop::set_forward_arg_shape(const Shape& forward_arg_shape)
+void op::v0::AvgPoolBackprop::set_forward_arg_shape(const Shape& forward_arg_shape)
{
m_forward_arg_shape = forward_arg_shape;
}
-const Shape& op::AvgPoolBackprop::get_window_shape() const
+const Shape& op::v0::AvgPoolBackprop::get_window_shape() const
{
return m_window_shape;
}
-void op::AvgPoolBackprop::set_window_shape(const Shape& window_shape)
+void op::v0::AvgPoolBackprop::set_window_shape(const Shape& window_shape)
{
m_window_shape = window_shape;
}
-const Strides& op::AvgPoolBackprop::get_window_movement_strides() const
+const Strides& op::v0::AvgPoolBackprop::get_window_movement_strides() const
{
return m_window_movement_strides;
}
-void op::AvgPoolBackprop::set_window_movement_strides(const Strides& window_movement_strides)
+void op::v0::AvgPoolBackprop::set_window_movement_strides(const Strides& window_movement_strides)
{
m_window_movement_strides = window_movement_strides;
}
-const Shape& op::AvgPoolBackprop::get_padding_below() const
+const Shape& op::v0::AvgPoolBackprop::get_padding_below() const
{
return m_padding_below;
}
-void op::AvgPoolBackprop::set_padding_below(const Shape& padding_below)
+void op::v0::AvgPoolBackprop::set_padding_below(const Shape& padding_below)
{
m_padding_below = padding_below;
}
-const Shape& op::AvgPoolBackprop::get_padding_above() const
+const Shape& op::v0::AvgPoolBackprop::get_padding_above() const
{
return m_padding_above;
}
-void op::AvgPoolBackprop::set_padding_above(const Shape& padding_above)
+void op::v0::AvgPoolBackprop::set_padding_above(const Shape& padding_above)
{
m_padding_above = padding_above;
}
-bool op::AvgPoolBackprop::get_include_padding_in_avg_computation() const
+bool op::v0::AvgPoolBackprop::get_include_padding_in_avg_computation() const
{
return m_include_padding_in_avg_computation;
}
-void op::AvgPoolBackprop::set_include_padding_in_avg_computation(
+void op::v0::AvgPoolBackprop::set_include_padding_in_avg_computation(
    bool include_padding_in_avg_computation)
{
m_include_padding_in_avg_computation = include_padding_in_avg_computation;
}
-shared_ptr<Node> op::AvgPoolBackprop::copy_with_new_args(const NodeVector& new_args) const
+shared_ptr<Node> op::v0::AvgPoolBackprop::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
-AvgPoolBackprop* avpn = new AvgPoolBackprop(m_forward_arg_shape,
+return make_shared<v0::AvgPoolBackprop>(m_forward_arg_shape,
new_args.at(0),
m_window_shape,
m_window_movement_strides,
m_padding_below,
m_padding_above,
m_include_padding_in_avg_computation);
-return shared_ptr<op::AvgPoolBackprop>(avpn);
}
-void op::AvgPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
+void op::v0::AvgPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (m_ceil_mode)
{
@@ -368,7 +369,7 @@ void op::AvgPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
auto operand = input_value(0);
auto& operand_shape = get_input_shape(0);
-auto backprop = make_shared<op::AvgPoolBackprop>(operand_shape,
+auto backprop = make_shared<op::v0::AvgPoolBackprop>(operand_shape,
delta,
m_window_shape,
m_window_movement_strides,
@@ -377,3 +378,318 @@ void op::AvgPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
m_include_padding_in_avg_computation);
adjoints.add_delta(operand, backprop);
}
// *** AvgPool OP SET 1 ***
constexpr NodeTypeInfo op::v1::AvgPool::type_info;
op::v1::AvgPool::AvgPool(const Output<Node>& arg,
const Strides& strides,
const Shape& pads_begin,
const Shape& pads_end,
const Shape& kernel,
bool exclude_pad,
op::RoundingType rounding_type,
const PadType& auto_pad)
: Op({arg})
, m_kernel(kernel)
, m_strides(strides)
, m_pads_begin(pads_begin)
, m_pads_end(pads_end)
, m_exclude_pad(exclude_pad)
, m_auto_pad(auto_pad)
, m_rounding_type(rounding_type)
{
constructor_validate_and_infer_types();
}
op::v1::AvgPool::AvgPool(const Output<Node>& arg,
const Strides& strides,
const Shape& pads_begin,
const Shape& pads_end,
const Shape& kernel,
bool exclude_pad,
op::RoundingType rounding_type)
: AvgPool(arg,
kernel,
strides,
pads_begin,
pads_end,
exclude_pad,
rounding_type,
op::PadType::EXPLICIT)
{
}
void op::v1::AvgPool::validate_and_infer_types()
{
if (0 == m_strides.size())
{
m_strides = Strides(m_kernel.size(), 1);
}
if (0 == m_pads_begin.size())
{
m_pads_begin = Shape(m_kernel.size(), 0);
}
if (0 == m_pads_end.size())
{
m_pads_end = Shape(m_kernel.size(), 0);
}
const PartialShape& arg_shape = get_input_partial_shape(0);
if (m_auto_pad == PadType::SAME_UPPER || m_auto_pad == PadType::SAME_LOWER)
{
if (arg_shape.is_static())
{
CoordinateDiff pads_end, pads_begin;
infer_auto_padding(arg_shape.to_shape(),
m_kernel,
m_strides,
Strides(m_kernel.size(), 1), // No dilation
m_auto_pad,
pads_end,
pads_begin);
m_pads_end = Shape(pads_end.begin(), pads_end.end());
m_pads_begin = Shape(pads_begin.begin(), pads_begin.end());
}
}
// infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
// now still take Shape (no negative padding).
CoordinateDiff pads_begin(m_pads_begin.begin(), m_pads_begin.end());
CoordinateDiff pads_end(m_pads_end.begin(), m_pads_end.end());
set_output_type(0,
get_input_element_type(0),
infer_batched_pooling_forward(this,
arg_shape,
pads_begin,
pads_end,
m_kernel,
m_strides,
!m_exclude_pad,
m_rounding_type == op::RoundingType::CEIL));
}
const Shape& op::v1::AvgPool::get_kernel() const
{
return m_kernel;
}
void op::v1::AvgPool::set_kernel(const Shape& kernel)
{
m_kernel = kernel;
}
const Strides& op::v1::AvgPool::get_strides() const
{
return m_strides;
}
void op::v1::AvgPool::set_strides(const Strides& strides)
{
m_strides = strides;
}
const Shape& op::v1::AvgPool::get_pads_begin() const
{
return m_pads_begin;
}
void op::v1::AvgPool::set_pads_begin(const Shape& pads_begin)
{
m_pads_begin = pads_begin;
}
const Shape& op::v1::AvgPool::get_pads_end() const
{
return m_pads_end;
}
void op::v1::AvgPool::set_pads_end(const Shape& pads_end)
{
m_pads_end = pads_end;
}
bool op::v1::AvgPool::get_exclude_pad() const
{
return m_exclude_pad;
}
void op::v1::AvgPool::set_exclude_pad(bool exclude_pad)
{
m_exclude_pad = exclude_pad;
}
const op::PadType& op::v1::AvgPool::get_auto_pad() const
{
return m_auto_pad;
}
void op::v1::AvgPool::set_auto_pad(const op::PadType& auto_pad)
{
m_auto_pad = auto_pad;
}
op::RoundingType op::v1::AvgPool::get_rounding_type() const
{
return m_rounding_type;
}
void op::v1::AvgPool::set_rounding_type(op::RoundingType rounding_type)
{
m_rounding_type = rounding_type;
}
shared_ptr<Node> op::v1::AvgPool::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v1::AvgPool>(new_args.at(0),
m_strides,
m_pads_begin,
m_pads_end,
m_kernel,
m_exclude_pad,
m_rounding_type,
m_auto_pad);
}
constexpr NodeTypeInfo op::v1::AvgPoolBackprop::type_info;
op::v1::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
const Output<Node>& delta,
const Strides& strides,
const Shape& pads_begin,
const Shape& pads_end,
const Shape& kernel,
bool exclude_pad)
: Op(check_single_output_args({delta.get_node_shared_ptr()}))
, m_forward_arg_shape(forward_arg_shape)
, m_kernel(kernel)
, m_strides(strides)
, m_pads_begin(pads_begin)
, m_pads_end(pads_end)
, m_exclude_pad(exclude_pad)
{
constructor_validate_and_infer_types();
}
void op::v1::AvgPoolBackprop::validate_and_infer_types()
{
// infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
// now still take Shape (no negative padding).
CoordinateDiff pads_begin(m_pads_begin.begin(), m_pads_begin.end());
CoordinateDiff pads_end(m_pads_end.begin(), m_pads_end.end());
PartialShape forward_result_shape = infer_batched_pooling_forward(
this, m_forward_arg_shape, pads_begin, pads_end, m_kernel, m_strides, m_exclude_pad);
const PartialShape& delta_shape = get_input_partial_shape(0);
NODE_VALIDATION_CHECK(
this,
forward_result_shape.compatible(delta_shape),
"Inferred forward output shape does not match delta shape (inferred forward output ",
"shape: ",
forward_result_shape,
", delta shape: ",
delta_shape,
").");
set_output_type(0, get_input_element_type(0), m_forward_arg_shape);
}
const Shape& op::v1::AvgPoolBackprop::get_forward_arg_shape() const
{
return m_forward_arg_shape;
}
void op::v1::AvgPoolBackprop::set_forward_arg_shape(const Shape& forward_arg_shape)
{
m_forward_arg_shape = forward_arg_shape;
}
const Shape& op::v1::AvgPoolBackprop::get_kernel() const
{
return m_kernel;
}
void op::v1::AvgPoolBackprop::set_kernel(const Shape& kernel)
{
m_kernel = kernel;
}
const Strides& op::v1::AvgPoolBackprop::get_strides() const
{
return m_strides;
}
void op::v1::AvgPoolBackprop::set_strides(const Strides& strides)
{
m_strides = strides;
}
const Shape& op::v1::AvgPoolBackprop::get_pads_begin() const
{
return m_pads_begin;
}
void op::v1::AvgPoolBackprop::set_pads_begin(const Shape& pads_begin)
{
m_pads_begin = pads_begin;
}
const Shape& op::v1::AvgPoolBackprop::get_pads_end() const
{
return m_pads_end;
}
void op::v1::AvgPoolBackprop::set_pads_end(const Shape& pads_end)
{
m_pads_end = pads_end;
}
bool op::v1::AvgPoolBackprop::get_exclude_pad() const
{
return m_exclude_pad;
}
void op::v1::AvgPoolBackprop::set_exclude_pad(bool exclude_pad)
{
m_exclude_pad = exclude_pad;
}
shared_ptr<Node> op::v1::AvgPoolBackprop::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v1::AvgPoolBackprop>(m_forward_arg_shape,
new_args.at(0),
m_strides,
m_pads_begin,
m_pads_end,
m_kernel,
m_exclude_pad);
}
void op::v1::AvgPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (m_rounding_type == op::RoundingType::CEIL)
{
throw ngraph_error("Autodiff not supported on AvgPool with ceil_mode set");
}
auto delta = deltas.at(0);
auto operand = input_value(0);
auto& operand_shape = get_input_shape(0);
auto backprop = make_shared<op::v1::AvgPoolBackprop>(
operand_shape, delta, m_strides, m_pads_begin, m_pads_end, m_kernel, m_exclude_pad);
adjoints.add_delta(operand, backprop);
}
shared_ptr<Node> op::v1::AvgPool::get_default_value() const
{
return op::Constant::create(get_element_type(), get_shape(), {0});
}
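The FLOOR/CEIL distinction threaded through validate_and_infer_types above only changes the spatial output-size arithmetic. A standalone sketch of the standard pooling formula assumed here (a summary for reference, not code from this commit):

#include <cmath>
#include <cstddef>

// FLOOR: out = floor((in + pads_begin + pads_end - kernel) / stride) + 1
// CEIL:  out = ceil((in + pads_begin + pads_end - kernel) / stride) + 1,
// which can add a final, partially covered window.
size_t pooled_dim(size_t in, size_t pad_lo, size_t pad_hi, size_t kernel, size_t stride, bool ceil_mode)
{
    double span = double(in + pad_lo + pad_hi - kernel) / double(stride);
    return static_cast<size_t>(ceil_mode ? std::ceil(span) : std::floor(span)) + 1;
}
// Example: in = 5, kernel = 2, stride = 2, no padding: 2 with FLOOR, 3 with CEIL.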
@@ -22,6 +22,8 @@
namespace ngraph
{
namespace op
{
+namespace v0
+{
/// \brief Batched average pooling operation, with optional padding and window stride.
///
@@ -47,10 +49,10 @@ namespace ngraph
/// \param padding_above The above-padding shape.<br>
/// `[n]`
/// \param include_padding_in_avg_computation If true then averages include padding
-/// elements, each treated as the number zero. If false, padding elements are entirely
-/// ignored when computing averages.
-/// \param pad_type Padding type to use for additional padded dimensions
-/// \param ceil_mode Whether to use ceiling while computing output shape.
+/// elements, each treated as the number zero. If false, padding elements are
+/// entirely ignored when computing averages. \param pad_type Padding type to use
+/// for additional padded dimensions \param ceil_mode Whether to use ceiling while
+/// computing output shape.
AvgPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
@@ -73,9 +75,9 @@ namespace ngraph
/// \param padding_above The above-padding shape.<br>
/// `[n]`
/// \param include_padding_in_avg_computation If true then averages include padding
-/// elements, each treated as the number zero. If false, padding elements are entirely
-/// ignored when computing averages.
-/// \param pad_type Padding type to use for additional padded dimensions
+/// elements, each treated as the number zero. If false, padding elements are
+/// entirely ignored when computing averages. \param pad_type Padding type to use
+/// for additional padded dimensions
AvgPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
@@ -97,8 +99,8 @@ namespace ngraph
/// \param padding_above The above-padding shape.<br>
/// `[n]`
/// \param include_padding_in_avg_computation If true then averages include padding
-/// elements, each treated as the number zero. If false, padding elements are entirely
-/// ignored when computing averages.
+/// elements, each treated as the number zero. If false, padding elements are
+/// entirely ignored when computing averages.
AvgPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
@@ -106,8 +108,8 @@ namespace ngraph
const Shape& padding_above,
bool include_padding_in_avg_computation = false);
-/// \brief Constructs a batched, unpadded average pooling operation (i.e., all padding
-/// shapes are set to 0).
+/// \brief Constructs a batched, unpadded average pooling operation (i.e., all
+/// padding shapes are set to 0).
///
/// \param arg The output producing the input data batch tensor.<br>
/// `[d1, ..., dn]`
@@ -149,7 +151,8 @@ namespace ngraph
const Shape& get_padding_above() const;
void set_padding_above(const Shape& padding_above);
bool get_include_padding_in_avg_computation() const;
-void set_include_padding_in_avg_computation(bool include_padding_in_avg_computation);
+void
+set_include_padding_in_avg_computation(bool include_padding_in_avg_computation);
/// \return The pad type for pooling.
const PadType& get_pad_type() const;
void set_pad_type(const PadType& pad_type);
@@ -199,7 +202,8 @@ namespace ngraph
const Shape& get_padding_above() const;
void set_padding_above(const Shape& padding_abve);
bool get_include_padding_in_avg_computation() const;
-void set_include_padding_in_avg_computation(bool include_padding_in_avg_computation);
+void
+set_include_padding_in_avg_computation(bool include_padding_in_avg_computation);
protected:
Shape m_forward_arg_shape;
@@ -209,5 +213,158 @@ namespace ngraph
Shape m_padding_above;
bool m_include_padding_in_avg_computation{false};
};
-}
-}
+} // namespace v0
namespace v1
{
/// \brief Batched average pooling operation.
///
class AvgPool : public Op
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"AvgPool", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a batched average pooling operation.
AvgPool() = default;
/// \brief Constructs a batched average pooling operation.
///
/// \param arg The output producing the input data batch tensor.<br>
/// `[d1, dn]`
/// \param strides The strides.<br>
/// `[n]`
/// \param pads_begin The beginning of padding shape.<br>
/// `[n]`
/// \param pads_end The end of padding shape.<br>
/// `[n]`
/// \param kernel The kernel shape.<br>
/// `[n]`
/// \param exclude_pad If false then averages include padding elements, each treated
/// as the number zero. If true, padding elements are entirely ignored when
/// computing averages. \param rounding_type Whether to use ceiling or floor
/// rounding type while computing output shape. \param auto_pad Padding type to use
/// for additional padded dimensions
AvgPool(const Output<Node>& arg,
const Strides& strides,
const Shape& pads_begin,
const Shape& pads_end,
const Shape& kernel,
bool exclude_pad,
op::RoundingType rounding_type,
const PadType& auto_pad);
//// \brief Constructs a batched average pooling operation.
///
/// \param arg The output producing the input data batch tensor.<br>
/// `[d1, dn]`
/// \param strides The strides.<br>
/// `[n]`
/// \param pads_begin The beginning of padding shape.<br>
/// `[n]`
/// \param pads_end The end of padding shape.<br>
/// `[n]`
/// \param kernel The kernel shape.<br>
/// `[n]`
/// \param exclude_pad If false then averages include padding elements, each treated
/// as the number zero. If true, padding elements are entirely ignored when
/// computing averages.
/// \param rounding_type Whether to use ceiling or floor rounding type while
/// computing output shape.
AvgPool(const Output<Node>& arg,
const Strides& strides,
const Shape& pads_begin,
const Shape& pads_end,
const Shape& kernel,
bool exclude_pad,
op::RoundingType rounding_type);
size_t get_version() const override { return 1; }
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
/// \return The kernel shape.
const Shape& get_kernel() const;
void set_kernel(const Shape& kernel);
/// \return The strides.
const Strides& get_strides() const;
void set_strides(const Strides& strides);
/// \return The beginning of padding shape.
const Shape& get_pads_begin() const;
void set_pads_begin(const Shape& pads_begin);
/// \return The end of padding shape.
const Shape& get_pads_end() const;
void set_pads_end(const Shape& pads_end);
bool get_exclude_pad() const;
void set_exclude_pad(bool exclude_pad);
/// \return The pad type for pooling.
const PadType& get_auto_pad() const;
void set_auto_pad(const PadType& auto_pad);
op::RoundingType get_rounding_type() const;
void set_rounding_type(op::RoundingType rounding_type);
/// \return The default value for AvgPool.
virtual std::shared_ptr<Node> get_default_value() const override;
protected:
Shape m_kernel;
Strides m_strides;
Shape m_pads_begin;
Shape m_pads_end;
bool m_exclude_pad{true};
PadType m_auto_pad{PadType::EXPLICIT};
op::RoundingType m_rounding_type{op::RoundingType::FLOOR};
};
class AvgPoolBackprop : public Op
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"AvgPoolBackprop", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
AvgPoolBackprop() = default;
AvgPoolBackprop(const Shape& forward_arg_shape,
const Output<Node>& delta,
const Strides& strides,
const Shape& pads_begin,
const Shape& pads_end,
const Shape& kernel,
bool exclude_pad);
size_t get_version() const override { return 1; }
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
const Shape& get_forward_arg_shape() const;
void set_forward_arg_shape(const Shape& forward_arg_shape);
const Shape& get_kernel() const;
void set_kernel(const Shape& kernel);
const Strides& get_strides() const;
void set_strides(const Strides& strides);
const Shape& get_pads_begin() const;
void set_pads_begin(const Shape& pads_begin);
const Shape& get_pads_end() const;
void set_pads_end(const Shape& padding_abve);
bool get_exclude_pad() const;
void set_exclude_pad(bool exclude_pad);
protected:
Shape m_forward_arg_shape;
Shape m_kernel;
Strides m_strides;
Shape m_pads_begin;
Shape m_pads_end;
bool m_exclude_pad{false};
};
} // namespace v1
using v0::AvgPool;
using v0::AvgPoolBackprop;
} // namespace op
} // namespace ngraph
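Note the two using-declarations just above: unqualified op::AvgPool and op::AvgPoolBackprop keep resolving to the v0 classes, so existing callers compile unchanged, and only code that explicitly names op::v1::AvgPool opts into the new interface. A hypothetical compile-time check of that aliasing (not part of the commit):

#include <type_traits>
#include "ngraph/op/avg_pool.hpp"

static_assert(std::is_same<ngraph::op::AvgPool, ngraph::op::v0::AvgPool>::value,
              "unqualified op::AvgPool still names the opset-0 op");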
@@ -22,9 +22,9 @@
using namespace std;
using namespace ngraph;
-constexpr NodeTypeInfo op::MaxPool::type_info;
+constexpr NodeTypeInfo op::v0::MaxPool::type_info;
-op::MaxPool::MaxPool(const Output<Node>& arg,
+op::v0::MaxPool::MaxPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
@@ -42,23 +42,23 @@ op::MaxPool::MaxPool(const Output<Node>& arg,
constructor_validate_and_infer_types();
}
-op::MaxPool::MaxPool(const Output<Node>& arg,
+op::v0::MaxPool::MaxPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above,
const PadType& pad_type)
-    : MaxPool(
+    : v0::MaxPool(
arg, window_shape, window_movement_strides, padding_below, padding_above, pad_type, false)
{
}
-op::MaxPool::MaxPool(const Output<Node>& arg,
+op::v0::MaxPool::MaxPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above)
-    : MaxPool(arg,
+    : v0::MaxPool(arg,
window_shape,
window_movement_strides,
padding_below,
@@ -67,7 +67,7 @@ op::MaxPool::MaxPool(const Output<Node>& arg,
{
}
-void op::MaxPool::validate_and_infer_types()
+void op::v0::MaxPool::validate_and_infer_types()
{
if (0 == m_window_movement_strides.size())
{
@@ -120,22 +120,22 @@ void op::MaxPool::validate_and_infer_types()
m_ceil_mode));
}
-op::MaxPool::MaxPool(const Output<Node>& arg,
+op::v0::MaxPool::MaxPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides)
-    : MaxPool(arg, window_shape, window_movement_strides, Shape(), Shape())
+    : v0::MaxPool(arg, window_shape, window_movement_strides, Shape(), Shape())
{
}
-op::MaxPool::MaxPool(const Output<Node>& arg, const Shape& window_shape)
-    : MaxPool(arg, window_shape, Strides(), Shape(), Shape())
+op::v0::MaxPool::MaxPool(const Output<Node>& arg, const Shape& window_shape)
+    : v0::MaxPool(arg, window_shape, Strides(), Shape(), Shape())
{
}
-shared_ptr<Node> op::MaxPool::copy_with_new_args(const NodeVector& new_args) const
+shared_ptr<Node> op::v0::MaxPool::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
-return make_shared<MaxPool>(new_args.at(0),
+return make_shared<v0::MaxPool>(new_args.at(0),
m_window_shape,
m_window_movement_strides,
m_padding_below,
@@ -144,13 +144,13 @@ shared_ptr<Node> op::MaxPool::copy_with_new_args(const NodeVector& new_args) con
m_ceil_mode);
}
-constexpr NodeTypeInfo op::MaxPoolBackprop::type_info;
+constexpr NodeTypeInfo op::v0::MaxPoolBackprop::type_info;
-shared_ptr<Node> op::MaxPool::get_default_value() const
+shared_ptr<Node> op::v0::MaxPool::get_default_value() const
{
return ngraph::make_constant_from_string("0", get_element_type(), get_shape());
}
-op::MaxPoolBackprop::MaxPoolBackprop(const Output<Node>& arg_forward,
+op::v0::MaxPoolBackprop::MaxPoolBackprop(const Output<Node>& arg_forward,
const Output<Node>& delta,
const Shape& window_shape,
const Strides& window_movement_strides,
@@ -165,7 +165,7 @@ op::MaxPoolBackprop::MaxPoolBackprop(const Output<Node>& arg_forward,
constructor_validate_and_infer_types();
}
-op::MaxPoolBackprop::MaxPoolBackprop(const Output<Node>& arg_forward,
+op::v0::MaxPoolBackprop::MaxPoolBackprop(const Output<Node>& arg_forward,
const Output<Node>& delta,
const Output<Node>& result_forward,
const Shape& window_shape,
@@ -181,7 +181,7 @@ op::MaxPoolBackprop::MaxPoolBackprop(const Output<Node>& arg_forward,
constructor_validate_and_infer_types();
}
-void op::MaxPoolBackprop::validate_and_infer_types()
+void op::v0::MaxPoolBackprop::validate_and_infer_types()
{
element::Type forward_arg_et = get_input_element_type(0);
element::Type delta_et = get_input_element_type(1);
@@ -229,12 +229,12 @@ void op::MaxPoolBackprop::validate_and_infer_types()
set_output_type(0, get_input_element_type(0), forward_arg_shape);
}
-shared_ptr<Node> op::MaxPoolBackprop::copy_with_new_args(const NodeVector& new_args) const
+shared_ptr<Node> op::v0::MaxPoolBackprop::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
if (this->get_input_size() == 3)
{
-return make_shared<op::MaxPoolBackprop>(new_args.at(0),
+return make_shared<op::v0::MaxPoolBackprop>(new_args.at(0),
new_args.at(1),
new_args.at(2),
m_window_shape,
@@ -243,7 +243,7 @@ shared_ptr<Node> op::MaxPoolBackprop::copy_with_new_args(const NodeVector& new_a
m_padding_above);
}
-return make_shared<op::MaxPoolBackprop>(new_args.at(0),
+return make_shared<op::v0::MaxPoolBackprop>(new_args.at(0),
new_args.at(1),
m_window_shape,
m_window_movement_strides,
@@ -251,7 +251,7 @@ shared_ptr<Node> op::MaxPoolBackprop::copy_with_new_args(const NodeVector& new_a
m_padding_above);
}
-void op::MaxPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
+void op::v0::MaxPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (m_ceil_mode)
{
@@ -262,7 +262,7 @@ void op::MaxPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
auto operand = input_value(0);
auto backprop =
-make_shared<op::MaxPoolBackprop>(operand,
+make_shared<op::v0::MaxPoolBackprop>(operand,
delta,
static_pointer_cast<op::MaxPool>(shared_from_this()),
m_window_shape,
@@ -272,3 +272,211 @@ void op::MaxPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
adjoints.add_delta(operand, backprop);
}
constexpr NodeTypeInfo op::v1::MaxPool::type_info;
op::v1::MaxPool::MaxPool(const Output<Node>& arg,
const Strides& strides,
const Shape& pads_begin,
const Shape& pads_end,
const Shape& kernel,
op::RoundingType rounding_type,
const PadType& auto_pad)
: Op({arg})
, m_kernel(kernel)
, m_strides(strides)
, m_pads_begin(pads_begin)
, m_pads_end(pads_end)
, m_auto_pad(auto_pad)
, m_rounding_type(rounding_type)
{
constructor_validate_and_infer_types();
}
op::v1::MaxPool::MaxPool(const Output<Node>& arg,
const Strides& strides,
const Shape& pads_begin,
const Shape& pads_end,
const Shape& kernel,
op::RoundingType rounding_type)
: v1::MaxPool(arg, strides, pads_begin, pads_end, kernel, rounding_type, op::PadType::EXPLICIT)
{
}
void op::v1::MaxPool::validate_and_infer_types()
{
if (0 == m_strides.size())
{
m_strides = Strides(m_kernel.size(), 1);
}
if (0 == m_pads_begin.size())
{
m_pads_begin = Shape(m_kernel.size(), 0);
}
if (0 == m_pads_end.size())
{
m_pads_end = Shape(m_kernel.size(), 0);
}
const PartialShape& arg_shape = get_input_partial_shape(0);
if (m_auto_pad == PadType::SAME_UPPER || m_auto_pad == PadType::SAME_LOWER)
{
if (arg_shape.is_static())
{
CoordinateDiff pads_end, pads_begin;
infer_auto_padding(arg_shape.to_shape(),
m_kernel,
m_strides,
Strides(m_kernel.size(), 1), // No dilation
m_auto_pad,
pads_end,
pads_begin);
m_pads_end = Shape(pads_end.begin(), pads_end.end());
m_pads_begin = Shape(pads_begin.begin(), pads_begin.end());
}
}
// infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
// now still take Shape (no negative padding).
CoordinateDiff pads_begin(m_pads_begin.begin(), m_pads_begin.end());
CoordinateDiff pads_end(m_pads_end.begin(), m_pads_end.end());
set_output_type(0,
get_input_element_type(0),
infer_batched_pooling_forward(this,
arg_shape,
pads_begin,
pads_end,
m_kernel,
m_strides,
true,
m_rounding_type == op::RoundingType::CEIL));
}
shared_ptr<Node> op::v1::MaxPool::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v1::MaxPool>(
new_args.at(0), m_strides, m_pads_begin, m_pads_end, m_kernel, m_rounding_type, m_auto_pad);
}
shared_ptr<Node> op::v1::MaxPool::get_default_value() const
{
return op::Constant::create(get_element_type(), get_shape(), {0});
}
constexpr NodeTypeInfo op::v1::MaxPoolBackprop::type_info;
op::v1::MaxPoolBackprop::MaxPoolBackprop(const Output<Node>& arg_forward,
const Output<Node>& delta,
const Strides& strides,
const Shape& pads_begin,
const Shape& pads_end,
const Shape& kernel)
: Op({arg_forward, delta})
, m_kernel(kernel)
, m_strides(strides)
, m_pads_begin(pads_begin)
, m_pads_end(pads_end)
{
constructor_validate_and_infer_types();
}
op::v1::MaxPoolBackprop::MaxPoolBackprop(const Output<Node>& arg_forward,
const Output<Node>& delta,
const Output<Node>& result_forward,
const Strides& strides,
const Shape& pads_begin,
const Shape& pads_end,
const Shape& kernel)
: Op({arg_forward, delta, result_forward})
, m_kernel(kernel)
, m_strides(strides)
, m_pads_begin(pads_begin)
, m_pads_end(pads_end)
{
constructor_validate_and_infer_types();
}
void op::v1::MaxPoolBackprop::validate_and_infer_types()
{
element::Type forward_arg_et = get_input_element_type(0);
element::Type delta_et = get_input_element_type(1);
element::Type result_et;
NODE_VALIDATION_CHECK(this,
element::Type::merge(result_et, forward_arg_et, delta_et),
"Element types for forward argument (",
forward_arg_et,
") and delta (",
delta_et,
") do not match.");
// infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
// now still take Shape (no negative padding).
CoordinateDiff pads_begin(m_pads_begin.begin(), m_pads_begin.end());
CoordinateDiff pads_end(m_pads_end.begin(), m_pads_end.end());
const PartialShape& forward_arg_shape = get_input_partial_shape(0);
PartialShape forward_result_shape = infer_batched_pooling_forward(
this, forward_arg_shape, pads_begin, pads_end, m_kernel, m_strides, true);
const PartialShape& delta_shape = get_input_partial_shape(1);
NODE_VALIDATION_CHECK(
this,
forward_result_shape.compatible(delta_shape),
"Inferred forward output shape does not match delta shape (inferred forward output ",
"shape: ",
forward_result_shape,
", delta shape: ",
delta_shape,
").");
set_output_type(0, get_input_element_type(0), forward_arg_shape);
}
shared_ptr<Node> op::v1::MaxPoolBackprop::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
if (this->get_input_size() == 3)
{
return make_shared<op::v1::MaxPoolBackprop>(new_args.at(0),
new_args.at(1),
new_args.at(2),
m_strides,
m_pads_begin,
m_pads_end,
m_kernel);
}
return make_shared<op::v1::MaxPoolBackprop>(
new_args.at(0), new_args.at(1), m_strides, m_pads_begin, m_pads_end, m_kernel);
}
void op::v1::MaxPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (m_rounding_type == op::RoundingType::CEIL)
{
throw ngraph_error("Autodiff not supported on MaxPool with rounding_type set");
}
auto delta = deltas.at(0);
auto operand = input_value(0);
auto backprop =
make_shared<op::v1::MaxPoolBackprop>(operand,
delta,
static_pointer_cast<op::MaxPool>(shared_from_this()),
m_strides,
m_pads_begin,
m_pads_end,
m_kernel);
adjoints.add_delta(operand, backprop);
}
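A hedged sanity check of the new shape inference (hypothetical test code in the spirit of the new unit tests this commit mentions, not part of the diff): a 1x1x5x5 input pooled with a 2x2 kernel at stride 2 should infer 1x1x2x2 under RoundingType::FLOOR and 1x1x3x3 under RoundingType::CEIL.

#include <memory>
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/parameter.hpp"

using namespace ngraph;

void check_rounding()
{
    auto data = std::make_shared<op::Parameter>(element::f32, Shape{1, 1, 5, 5});
    auto floor_pool = std::make_shared<op::v1::MaxPool>(
        data, Strides{2, 2}, Shape{0, 0}, Shape{0, 0}, Shape{2, 2}, op::RoundingType::FLOOR);
    auto ceil_pool = std::make_shared<op::v1::MaxPool>(
        data, Strides{2, 2}, Shape{0, 0}, Shape{0, 0}, Shape{2, 2}, op::RoundingType::CEIL);
    // get_shape() reflects the output type set by validate_and_infer_types().
    NGRAPH_CHECK(floor_pool->get_shape() == (Shape{1, 1, 2, 2}));
    NGRAPH_CHECK(ceil_pool->get_shape() == (Shape{1, 1, 3, 3}));
}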
@@ -22,6 +22,8 @@
namespace ngraph
{
namespace op
{
+namespace v0
+{
/// \brief Batched max pooling operation, with optional padding and window stride.
class MaxPool : public Op
@@ -104,17 +106,26 @@ namespace ngraph
const Shape& get_window_shape() const { return m_window_shape; }
void set_window_shape(const Shape& window_shape) { m_window_shape = window_shape; }
/// \return The window movement strides.
-const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
+const Strides& get_window_movement_strides() const
+{
+return m_window_movement_strides;
+}
void set_window_movement_strides(const Strides& window_movement_strides)
{
m_window_movement_strides = window_movement_strides;
}
/// \return The below-padding shape.
const Shape& get_padding_below() const { return m_padding_below; }
-void set_padding_below(const Shape& padding_below) { m_padding_below = padding_below; }
+void set_padding_below(const Shape& padding_below)
+{
+m_padding_below = padding_below;
+}
/// \return The above-padding shape.
const Shape& get_padding_above() const { return m_padding_above; }
-void set_adding_above(const Shape& padding_above) { m_padding_above = padding_above; }
+void set_adding_above(const Shape& padding_above)
+{
+m_padding_above = padding_above;
+}
/// \return The pad type for pooling.
const PadType& get_pad_type() const { return m_pad_type; }
void set_pad_type(const PadType& pad_type) { m_pad_type = pad_type; }
@@ -166,20 +177,167 @@ namespace ngraph
const Shape& get_window_shape() const { return m_window_shape; }
void set_window_shape(const Shape& window_shape) { m_window_shape = window_shape; }
-const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
+const Strides& get_window_movement_strides() const
+{
+return m_window_movement_strides;
+}
void set_window_movement_strides(const Strides& window_movement_strides)
{
m_window_movement_strides = window_movement_strides;
}
const Shape& get_padding_below() const { return m_padding_below; }
-void set_padding_below(const Shape& padding_below) { m_padding_below = padding_below; }
+void set_padding_below(const Shape& padding_below)
+{
+m_padding_below = padding_below;
+}
const Shape& get_padding_above() const { return m_padding_above; }
-void set_padding_above(const Shape& padding_above) { m_padding_above = padding_above; }
+void set_padding_above(const Shape& padding_above)
+{
+m_padding_above = padding_above;
+}
protected:
Shape m_window_shape;
Strides m_window_movement_strides;
Shape m_padding_below;
Shape m_padding_above;
};
-}
-}
+} // namespace v0
namespace v1
{
/// \brief Batched max pooling operation.
class MaxPool : public Op
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"MaxPool", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a batched max pooling operation.
MaxPool() = default;
/// \brief Constructs a batched max pooling operation.
///
/// \param arg The node producing the input data batch tensor.
/// \param strides The strides.
/// \param pads_begin The beginning of padding shape.
/// \param pads_end The end of padding shape.
/// \param kernel The kernel shape.
/// \param rounding_mode Whether to use ceiling or floor rounding type while
/// computing output shape.
/// \param auto_pad The pad type for automatically computing padding sizes.
MaxPool(const Output<Node>& arg,
const Strides& strides,
const Shape& pads_begin,
const Shape& pads_end,
const Shape& kernel,
op::RoundingType rounding_mode,
const PadType& auto_pad);
/// \brief Constructs a batched max pooling operation.
///
/// \param arg The node producing the input data batch tensor.
/// \param strides The strides.
/// \param pads_begin The beginning of padding shape.
/// \param pads_end The end of padding shape.
/// \param kernel The kernel shape.
/// \param rounding_mode Whether to use ceiling or floor rounding type while
/// computing output shape.
MaxPool(const Output<Node>& arg,
const Strides& strides,
const Shape& pads_begin,
const Shape& pads_end,
const Shape& kernel,
op::RoundingType rounding_mode);
size_t get_version() const override { return 1; }
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \return The kernel shape.
const Shape& get_kernel() const { return m_kernel; }
void set_kernel(const Shape& kernel) { m_kernel = kernel; }
/// \return The strides.
const Strides& get_strides() const { return m_strides; }
void set_strides(const Strides& strides) { m_strides = strides; }
/// \return The beginning of padding shape.
const Shape& get_pads_begin() const { return m_pads_begin; }
void set_pads_begin(const Shape& pads_begin) { m_pads_begin = pads_begin; }
/// \return The end of padding shape.
const Shape& get_pads_end() const { return m_pads_end; }
void set_adding_above(const Shape& pads_end) { m_pads_end = pads_end; }
/// \return The pad type for pooling.
const PadType& get_auto_pad() const { return m_auto_pad; }
void set_auto_pad(const PadType& auto_pad) { m_auto_pad = auto_pad; }
/// \return The ceiling mode being used for output shape computations
op::RoundingType get_rounding_type() const { return m_rounding_type; }
void set_rounding_type(op::RoundingType rounding_mode)
{
m_rounding_type = rounding_mode;
}
/// \return The default value for MaxPool.
virtual std::shared_ptr<Node> get_default_value() const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
Shape m_kernel;
Strides m_strides;
Shape m_pads_begin;
Shape m_pads_end;
PadType m_auto_pad;
op::RoundingType m_rounding_type{op::RoundingType::FLOOR};
};
class MaxPoolBackprop : public Op
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"MaxPoolBackprop", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
MaxPoolBackprop() = default;
MaxPoolBackprop(const Output<Node>& arg_forward,
const Output<Node>& delta,
const Strides& strides,
const Shape& pads_begin,
const Shape& pads_end,
const Shape& kernel);
MaxPoolBackprop(const Output<Node>& arg_forward,
const Output<Node>& delta,
const Output<Node>& result_forward,
const Strides& strides,
const Shape& pads_begin,
const Shape& pads_end,
const Shape& kernel);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
size_t get_version() const override { return 1; }
void validate_and_infer_types() override;
const Shape& get_kernel() const { return m_kernel; }
void set_kernel(const Shape& kernel) { m_kernel = kernel; }
const Strides& get_strides() const { return m_strides; }
void set_strides(const Strides& strides) { m_strides = strides; }
const Shape& get_pads_begin() const { return m_pads_begin; }
void set_pads_begin(const Shape& pads_begin) { m_pads_begin = pads_begin; }
const Shape& get_pads_end() const { return m_pads_end; }
void set_pads_end(const Shape& pads_end) { m_pads_end = pads_end; }
protected:
Shape m_kernel;
Strides m_strides;
Shape m_pads_begin;
Shape m_pads_end;
};
} // namespace v1
using v0::MaxPool;
using v0::MaxPoolBackprop;
} // namespace op
} // namespace ngraph
@@ -54,6 +54,13 @@ namespace ngraph
NOTSET = EXPLICIT,
};
/// \brief Rounding Type used for `Pooling` operators.
enum class RoundingType
{
FLOOR = 0,
CEIL = 1,
};
/// \brief Specifies the algorithm to use for implicit broadcasting of a tensor
/// to align with another tensor
///
......
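The explicit numbering of RoundingType above is load-bearing: the upgrade pass below converts the v0 boolean with static_cast<op::RoundingType>(tmp->get_ceil_mode()), which is only correct because false maps to FLOOR and true to CEIL. A hypothetical compile-time pin of that assumption (not in the commit):

static_assert(static_cast<ngraph::op::RoundingType>(false) == ngraph::op::RoundingType::FLOOR,
              "ceil_mode == false must convert to FLOOR");
static_assert(static_cast<ngraph::op::RoundingType>(true) == ngraph::op::RoundingType::CEIL,
              "ceil_mode == true must convert to CEIL");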
@@ -15,9 +15,11 @@
//*****************************************************************************
#include "ngraph/pass/opset1_upgrade.hpp"
#include "ngraph/graph_util.hpp"
+#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/get_output_element.hpp"
+#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/reduce_prod.hpp"
@@ -82,18 +84,117 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
#endif
switch (get_typeid(node))
{
-case OP_TYPEID::Softmax:
+case OP_TYPEID::AvgPool:
{
-auto tmp = dynamic_cast<const op::v0::Softmax*>(node.get());
-AxisSet axes = tmp->get_axes();
-NGRAPH_CHECK(
-axes.size() == 1,
-"Unable to convert Softmax:0 to Softmax:1 with zero or more than one axis. Node: ",
-*node);
-auto replacement_node =
-make_shared<op::v1::Softmax>(node->input(0).get_source_output(), axes.to_vector()[0]);
+auto tmp = dynamic_cast<const op::v0::AvgPool*>(node.get());
+auto rounding_type = static_cast<op::RoundingType>(tmp->get_ceil_mode());
+auto exclude_pad = !tmp->get_include_padding_in_avg_computation();
+auto auto_pad = tmp->get_pad_type();
+auto pads_begin = tmp->get_padding_below();
+auto pads_end = tmp->get_padding_above();
+auto strides = tmp->get_window_movement_strides();
+auto kernel = tmp->get_window_shape();
+auto replacement_node = make_shared<op::v1::AvgPool>(node->input(0).get_source_output(),
+strides,
+pads_begin,
+pads_end,
+kernel,
+exclude_pad,
+rounding_type,
+auto_pad);
+replace_node(node, replacement_node);
+modified = true;
+break;
+}
+case OP_TYPEID::AvgPoolBackprop:
+{
+auto tmp = dynamic_cast<const op::v0::AvgPoolBackprop*>(node.get());
+auto exclude_pad = !tmp->get_include_padding_in_avg_computation();
+auto pads_begin = tmp->get_padding_below();
+auto pads_end = tmp->get_padding_above();
+auto strides = tmp->get_window_movement_strides();
+auto kernel = tmp->get_window_shape();
+auto replacement_node =
+make_shared<op::v1::AvgPoolBackprop>(tmp->get_forward_arg_shape(),
+node->input(0).get_source_output(),
+strides,
+pads_begin,
+pads_end,
+kernel,
+exclude_pad);
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::Gather:
{
auto tmp = dynamic_cast<const op::v0::Gather*>(node.get());
int64_t axis = tmp->get_axis();
auto axis_node = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{axis});
auto replacement_node = make_shared<op::v1::Gather>(
node->input(0).get_source_output(), node->input(1).get_source_output(), axis_node);
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::MaxPool:
{
auto tmp = dynamic_cast<const op::v0::MaxPool*>(node.get());
auto rounding_type = static_cast<op::RoundingType>(tmp->get_ceil_mode());
auto auto_pad = tmp->get_pad_type();
auto pads_begin = tmp->get_padding_below();
auto pads_end = tmp->get_padding_above();
auto strides = tmp->get_window_movement_strides();
auto kernel = tmp->get_window_shape();
auto replacement_node = make_shared<op::v1::MaxPool>(node->input(0).get_source_output(),
strides,
pads_begin,
pads_end,
kernel,
rounding_type,
auto_pad);
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::MaxPoolBackprop:
{
auto tmp = dynamic_cast<const op::v0::MaxPoolBackprop*>(node.get());
auto pads_begin = tmp->get_padding_below();
auto pads_end = tmp->get_padding_above();
auto strides = tmp->get_window_movement_strides();
auto kernel = tmp->get_window_shape();
shared_ptr<Node> replacement_node;
if (node->get_inputs().size() == 3)
{
replacement_node =
make_shared<op::v1::MaxPoolBackprop>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
node->input(2).get_source_output(),
strides,
pads_begin,
pads_end,
kernel);
}
else
{
replacement_node =
make_shared<op::v1::MaxPoolBackprop>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
strides,
pads_begin,
pads_end,
kernel);
}
replace_node(node, replacement_node);
modified = true;
break;
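Note: the three-input branch above mirrors v0::MaxPoolBackprop's optional third argument (the forward op's result in the v0 API), so the upgrade simply selects the matching v1::MaxPoolBackprop overload.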
@@ -136,14 +237,18 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Softmax:
{
auto tmp = dynamic_cast<const op::v0::Softmax*>(node.get());
AxisSet axes = tmp->get_axes();
NGRAPH_CHECK(
axes.size() == 1,
"Unable to convert Softmax:0 to Softmax:1 with zero or more than one axis. Node: ",
*node);
auto replacement_node =
make_shared<op::v1::Softmax>(node->input(0).get_source_output(), axes.to_vector()[0]);
replace_node(node, replacement_node);
modified = true;
break;
...
@@ -21,7 +21,9 @@
#include "ngraph/code_writer.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/sum.hpp"
@@ -127,9 +129,6 @@ namespace ngraph
class MaxPoolWithIndices;
class Reverse;
class ReverseSequence;
class MaxPoolWithIndicesBackprop;
class Max;
class Erf;
...
@@ -360,6 +360,13 @@ static op::PadMode read_pad_mode(json node_js)
: op::PadMode::CONSTANT;
}
static op::RoundingType read_rounding_type(json node_js)
{
return has_key(node_js, "rounding_type")
? static_cast<op::RoundingType>(node_js.at("rounding_type"))
: op::RoundingType::FLOOR;
}
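A quick illustration of the fallback behaviour. This is a sketch only: it assumes it sits in the same translation unit as the static helper above, and that RoundingType serializes as an integer with FLOOR == 0 and CEIL == 1, consistent with the static_cast used in the upgrade pass.

#include <cassert>
void read_rounding_type_example()
{
    json with_key;
    with_key["rounding_type"] = 1; // assumed to decode as op::RoundingType::CEIL
    json without_key = json::object();
    assert(read_rounding_type(with_key) == op::RoundingType::CEIL);
    // Missing key falls back to FLOOR, so pre-opset1 graphs still load.
    assert(read_rounding_type(without_key) == op::RoundingType::FLOOR);
}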
static json write_element_type(const ngraph::element::Type& n)
{
json j;
@@ -800,6 +807,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::AvgPool:
{
if (op_version == 0)
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
@@ -810,7 +819,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node_js.at("include_padding_in_avg_computation").get<bool>();
op::PadType pad_type = read_pad_type(node_js);
bool ceil_mode = get_or_default<bool>(node_js, "ceil_mode", false);
node = make_shared<op::v0::AvgPool>(args[0],
window_shape,
window_movement_strides,
padding_below,
@@ -818,9 +827,30 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
include_padding_in_avg_computation,
pad_type,
ceil_mode);
}
if (op_version == 1)
{
auto kernel = node_js.at("kernel").get<vector<size_t>>();
auto strides = node_js.at("strides").get<vector<size_t>>();
auto pads_begin = node_js.at("pads_begin").get<vector<size_t>>();
auto pads_end = node_js.at("pads_end").get<vector<size_t>>();
auto exclude_pad = node_js.at("exclude_pad").get<bool>();
op::PadType pad_type = read_pad_type(node_js);
op::RoundingType rounding_type = read_rounding_type(node_js);
node = make_shared<op::v1::AvgPool>(args[0],
strides,
pads_begin,
pads_end,
kernel,
exclude_pad,
rounding_type,
pad_type);
}
break;
}
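To make the versioned branches concrete, here is a hypothetical opset-1 AvgPool node body carrying exactly the keys the op_version == 1 branch reads; the surrounding envelope (name, inputs, outputs) is omitted and the values are illustrative, not taken from the serializer:

// Hypothetical serialized v1 AvgPool attributes (values illustrative).
const char* avg_pool_v1_attrs = R"({
    "kernel": [3, 3],
    "strides": [1, 1],
    "pads_begin": [0, 0],
    "pads_end": [0, 0],
    "exclude_pad": true,
    "rounding_type": 0
})";
// The pad type is recovered separately via read_pad_type(node_js).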
case OP_TYPEID::AvgPoolBackprop:
{
if (op_version == 0)
{
auto forward_arg_shape = node_js.at("forward_arg_shape").get<vector<size_t>>();
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
@@ -830,13 +860,25 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
auto include_padding_in_avg_computation =
get_or_default<bool>(node_js, "include_padding_in_avg_computation", false);
node = make_shared<op::v0::AvgPoolBackprop>(forward_arg_shape,
args[0],
window_shape,
window_movement_strides,
padding_below,
padding_above,
include_padding_in_avg_computation);
}
if (op_version == 1)
{
auto forward_arg_shape = node_js.at("forward_arg_shape").get<vector<size_t>>();
auto kernel = node_js.at("kernel").get<vector<size_t>>();
auto strides = node_js.at("strides").get<vector<size_t>>();
auto pads_begin = node_js.at("pads_begin").get<vector<size_t>>();
auto pads_end = node_js.at("pads_end").get<vector<size_t>>();
auto exclude_pad = get_or_default<bool>(node_js, "exclude_pad", true);
node = make_shared<op::v1::AvgPoolBackprop>(
forward_arg_shape, args[0], strides, pads_begin, pads_end, kernel, exclude_pad);
}
break;
}
case OP_TYPEID::BatchMatMul:
@@ -1426,12 +1468,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
break;
}
case OP_TYPEID::MaxPool:
{
if (op_version == 0)
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
node_js.at("window_movement_strides").get<vector<size_t>>();
// For backwards compatibility, both (but not just one) of the padding_ fields may
// be omitted.
auto padding_below_maybe = get_or_default(node_js, "padding_below", json{});
auto padding_above_maybe = get_or_default(node_js, "padding_above", json{});
op::PadType pad_type = read_pad_type(node_js);
@@ -1449,7 +1493,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
{
auto padding_below = padding_below_maybe.get<vector<size_t>>();
auto padding_above = padding_above_maybe.get<vector<size_t>>();
node = make_shared<op::v0::MaxPool>(args[0],
window_shape,
window_movement_strides,
padding_below,
@@ -1458,11 +1502,26 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
else
{
node = make_shared<op::v0::MaxPool>(
args[0], window_shape, window_movement_strides);
}
}
if (op_version == 1)
{
auto kernel = node_js.at("kernel").get<vector<size_t>>();
auto strides = node_js.at("strides").get<vector<size_t>>();
auto pads_begin = node_js.at("pads_begin").get<vector<size_t>>();
auto pads_end = node_js.at("pads_end").get<vector<size_t>>();
auto rounding_type = read_rounding_type(node_js);
op::PadType pad_type = read_pad_type(node_js);
node = make_shared<op::v1::MaxPool>(
args[0], strides, pads_begin, pads_end, kernel, rounding_type, pad_type);
}
break;
}
case OP_TYPEID::MaxPoolBackprop:
{
if (op_version == 0)
{
auto window_shape = node_js.at("window_shape").get<vector<size_t>>();
auto window_movement_strides =
@@ -1471,7 +1530,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
if (args.size() == 3)
{
node = make_shared<op::v0::MaxPoolBackprop>(args[0],
args[1],
args[2],
window_shape,
@@ -1481,13 +1540,31 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
else
{
node = make_shared<op::v0::MaxPoolBackprop>(args[0],
args[1],
window_shape,
window_movement_strides,
padding_below,
padding_above);
}
}
if (op_version == 1)
{
auto kernel = node_js.at("kernel").get<vector<size_t>>();
auto strides = node_js.at("strides").get<vector<size_t>>();
auto pads_begin = node_js.at("pads_begin").get<vector<size_t>>();
auto pads_end = node_js.at("pads_end").get<vector<size_t>>();
if (args.size() == 3)
{
node = make_shared<op::v1::MaxPoolBackprop>(
args[0], args[1], args[2], kernel, strides, pads_begin, pads_end);
}
else
{
node = make_shared<op::v1::MaxPoolBackprop>(
args[0], args[1], kernel, strides, pads_begin, pads_end);
}
}
break;
}
case OP_TYPEID::Maximum:
@@ -2223,28 +2300,57 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::AvgPool:
{
if (op_version == 0)
{
auto tmp = dynamic_cast<const op::v0::AvgPool*>(&n);
node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides();
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
node["include_padding_in_avg_computation"] =
tmp->get_include_padding_in_avg_computation();
node["pad_type"] = tmp->get_pad_type();
if (tmp->get_ceil_mode())
{
node["ceil_mode"] = tmp->get_ceil_mode();
}
}
if (op_version == 1)
{
auto tmp = dynamic_cast<const op::v1::AvgPool*>(&n);
node["kernel"] = tmp->get_kernel();
node["strides"] = tmp->get_strides();
node["pads_begin"] = tmp->get_pads_begin();
node["pads_end"] = tmp->get_pads_end();
node["exclude_pad"] = tmp->get_exclude_pad();
node["auto_pad"] = tmp->get_auto_pad();
node["rounding_type"] = tmp->get_rounding_type();
}
break;
}
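Because serialization is now versioned per node, a round trip through ngraph's serialize()/deserialize() helpers (a sketch, assuming their usual string-based signatures) should reproduce whichever opset the function was built with:

#include "ngraph/serializer.hpp"
std::shared_ptr<ngraph::Function> round_trip(const std::shared_ptr<ngraph::Function>& f)
{
    // serialize_node() above emits the v0 or v1 field set per node;
    // deserialize_node() dispatches on op_version when rebuilding.
    std::string js = ngraph::serialize(f, /*indent=*/4);
    return ngraph::deserialize(js);
}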
case OP_TYPEID::AvgPoolBackprop:
{
if (op_version == 0)
{
auto tmp = dynamic_cast<const op::v0::AvgPoolBackprop*>(&n);
node["forward_arg_shape"] = tmp->get_forward_arg_shape();
node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides();
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
node["include_padding_in_avg_computation"] =
tmp->get_include_padding_in_avg_computation();
}
if (op_version == 1)
{
auto tmp = dynamic_cast<const op::v1::AvgPoolBackprop*>(&n);
node["forward_arg_shape"] = tmp->get_forward_arg_shape();
node["kernel"] = tmp->get_kernel();
node["strides"] = tmp->get_strides();
node["pads_begin"] = tmp->get_pads_begin();
node["pads_end"] = tmp->get_pads_end();
node["exclude_pad"] = tmp->get_exclude_pad();
}
break;
}
case OP_TYPEID::BatchMatMul: { break;
@@ -2636,21 +2742,45 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::MaxPool:
{
if (op_version == 0)
{
auto tmp = dynamic_cast<const op::v0::MaxPool*>(&n);
node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides();
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
node["pad_type"] = tmp->get_pad_type();
}
if (op_version == 1)
{
auto tmp = dynamic_cast<const op::v1::MaxPool*>(&n);
node["kernel"] = tmp->get_kernel();
node["strides"] = tmp->get_strides();
node["pads_begin"] = tmp->get_pads_begin();
node["pads_end"] = tmp->get_pads_end();
node["auto_pad"] = tmp->get_auto_pad();
node["rounding_type"] = tmp->get_rounding_type();
}
break;
}
case OP_TYPEID::MaxPoolBackprop:
{
if (op_version == 0)
{
auto tmp = dynamic_cast<const op::v0::MaxPoolBackprop*>(&n);
node["window_shape"] = tmp->get_window_shape();
node["window_movement_strides"] = tmp->get_window_movement_strides();
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
}
if (op_version == 1)
{
auto tmp = dynamic_cast<const op::v1::MaxPoolBackprop*>(&n);
node["kernel"] = tmp->get_kernel();
node["strides"] = tmp->get_strides();
node["pads_begin"] = tmp->get_pads_begin();
node["pads_end"] = tmp->get_pads_end();
}
break;
}
case OP_TYPEID::Maximum:
...
@@ -75,6 +75,7 @@ set(SRC
opset_pass/softmax_opset_pass.cpp
opset_pass/gather_opset_pass.cpp
opset_pass/pad_opset_pass.cpp
opset_pass/poolings_opset_pass.cpp
partial_shape.cpp
pass.cpp
pass_liveness.cpp
...
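With poolings_opset_pass.cpp registered in SRC, the upgrade tests below are compiled into the gtest-based unit-test target and can be run in isolation with a filter such as --gtest_filter=upgrade_pass.* (the exact binary name depends on the build configuration).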
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/test_control.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(upgrade_pass, opset1_avgpool_pass)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6, 9});
Shape pads_begin{0, 0};
Shape pads_end{0, 0};
Strides strides{1, 1};
Shape kernel_shape{3, 3};
bool include_pad = true;
bool ceil_mode = false;
op::PadType pad_mode = op::PadType::EXPLICIT;
auto avgpool_v0 = make_shared<op::v0::AvgPool>(
arg, kernel_shape, strides, pads_begin, pads_end, include_pad, pad_mode, ceil_mode);
auto result = make_shared<op::Result>(avgpool_v0);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
auto avgpool_s1_result = f->get_results().at(0);
auto node = avgpool_s1_result->input(0).get_source_output().get_node_shared_ptr();
auto avg_pool_v1_node = static_pointer_cast<op::v1::AvgPool>(node);
EXPECT_EQ(avg_pool_v1_node->description(), "AvgPool");
EXPECT_EQ(avg_pool_v1_node->get_version(), 1);
EXPECT_EQ(avg_pool_v1_node->get_pads_begin(), pads_begin);
EXPECT_EQ(avg_pool_v1_node->get_pads_end(), pads_end);
EXPECT_EQ(avg_pool_v1_node->get_strides(), strides);
EXPECT_EQ(avg_pool_v1_node->get_kernel(), kernel_shape);
EXPECT_EQ(avg_pool_v1_node->get_rounding_type(), op::RoundingType::FLOOR);
EXPECT_EQ(avg_pool_v1_node->get_exclude_pad(), !include_pad);
EXPECT_EQ(avg_pool_v1_node->get_auto_pad(), pad_mode);
}
TEST(upgrade_pass, opset1_maxpool_pass)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6, 9});
Shape pads_begin{0, 0};
Shape pads_end{0, 0};
Strides strides{1, 1};
Shape kernel_shape{3, 3};
bool ceil_mode = false;
op::PadType pad_mode = op::PadType::EXPLICIT;
auto maxpool_v0 = make_shared<op::v0::MaxPool>(
arg, kernel_shape, strides, pads_begin, pads_end, pad_mode, ceil_mode);
auto result = make_shared<op::Result>(maxpool_v0);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
auto maxpool_s1_result = f->get_results().at(0);
auto node = maxpool_s1_result->input(0).get_source_output().get_node_shared_ptr();
auto max_pool_v1_node = static_pointer_cast<op::v1::MaxPool>(node);
EXPECT_EQ(max_pool_v1_node->description(), "MaxPool");
EXPECT_EQ(max_pool_v1_node->get_version(), 1);
EXPECT_EQ(max_pool_v1_node->get_pads_begin(), pads_begin);
EXPECT_EQ(max_pool_v1_node->get_pads_end(), pads_end);
EXPECT_EQ(max_pool_v1_node->get_strides(), strides);
EXPECT_EQ(max_pool_v1_node->get_kernel(), kernel_shape);
EXPECT_EQ(max_pool_v1_node->get_rounding_type(), op::RoundingType::FLOOR);
EXPECT_EQ(max_pool_v1_node->get_auto_pad(), pad_mode);
}
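The two tests above cover the forward pools; a matching check for the backprop upgrades could follow the same pattern. A sketch only, not part of the commit: the v0 constructor order follows the deserializer above, and the delta shape is the 4x7 output of a 3x3, stride-1 window over a 6x9 input.

TEST(upgrade_pass, opset1_maxpool_backprop_sketch)
{
    auto arg_forward = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6, 9});
    // (6 - 3) / 1 + 1 = 4 and (9 - 3) / 1 + 1 = 7, hence the delta shape.
    auto delta = make_shared<op::Parameter>(element::f32, Shape{1, 3, 4, 7});
    Shape kernel_shape{3, 3};
    Strides strides{1, 1};
    Shape pads_begin{0, 0};
    Shape pads_end{0, 0};
    auto backprop_v0 = make_shared<op::v0::MaxPoolBackprop>(
        arg_forward, delta, kernel_shape, strides, pads_begin, pads_end);
    auto f = make_shared<Function>(NodeVector{backprop_v0},
                                   ParameterVector{arg_forward, delta});
    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset1Upgrade>();
    pass_manager.run_passes(f);
    auto node = f->get_results().at(0)->input(0).get_source_output().get_node_shared_ptr();
    auto backprop_v1 = static_pointer_cast<op::v1::MaxPoolBackprop>(node);
    EXPECT_EQ(backprop_v1->description(), "MaxPoolBackprop");
    EXPECT_EQ(backprop_v1->get_version(), 1);
}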