Commit a809ed7f authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Added support for ceil mode in AvgPool (#3027)

* Added support for ceil mode in AvgPool

* Added ceil mode to MaxPool

* remove extra semicolon

* Add more constructor variants to support pybind which seems to have issues with multiple optional arguments

* More constructor variants for AvgPool

* More constructor variants for MaxPool

* Style fix

* Avoid constructor delegation

* Revert "Avoid constructor delegation"

This reverts commit 8efd59127bc9a16bae93b3c6b67dbcccfa95648f.
parent fa300fae
...@@ -33,7 +33,8 @@ op::AvgPool::AvgPool(const Output<Node>& arg, ...@@ -33,7 +33,8 @@ op::AvgPool::AvgPool(const Output<Node>& arg,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above, const Shape& padding_above,
bool include_padding_in_avg_computation, bool include_padding_in_avg_computation,
const PadType& pad_type) const PadType& pad_type,
bool ceil_mode)
: Op({arg}) : Op({arg})
, m_window_shape(window_shape) , m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
...@@ -41,10 +42,45 @@ op::AvgPool::AvgPool(const Output<Node>& arg, ...@@ -41,10 +42,45 @@ op::AvgPool::AvgPool(const Output<Node>& arg,
, m_padding_above(padding_above) , m_padding_above(padding_above)
, m_include_padding_in_avg_computation(include_padding_in_avg_computation) , m_include_padding_in_avg_computation(include_padding_in_avg_computation)
, m_pad_type(pad_type) , m_pad_type(pad_type)
, m_ceil_mode(ceil_mode)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::AvgPool::AvgPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above,
bool include_padding_in_avg_computation,
const PadType& pad_type)
: AvgPool(arg,
window_shape,
window_movement_strides,
padding_below,
padding_above,
include_padding_in_avg_computation,
pad_type,
false)
{
}
op::AvgPool::AvgPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above,
bool include_padding_in_avg_computation)
: AvgPool(arg,
window_shape,
window_movement_strides,
padding_below,
padding_above,
include_padding_in_avg_computation,
PadType::EXPLICIT)
{
}
void op::AvgPool::validate_and_infer_types() void op::AvgPool::validate_and_infer_types()
{ {
if (0 == m_window_movement_strides.size()) if (0 == m_window_movement_strides.size())
...@@ -94,7 +130,8 @@ void op::AvgPool::validate_and_infer_types() ...@@ -94,7 +130,8 @@ void op::AvgPool::validate_and_infer_types()
padding_above, padding_above,
m_window_shape, m_window_shape,
m_window_movement_strides, m_window_movement_strides,
m_include_padding_in_avg_computation)); m_include_padding_in_avg_computation,
m_ceil_mode));
} }
op::AvgPool::AvgPool(const Output<Node>& arg, op::AvgPool::AvgPool(const Output<Node>& arg,
...@@ -169,6 +206,16 @@ void op::AvgPool::set_pad_type(const op::PadType& pad_type) ...@@ -169,6 +206,16 @@ void op::AvgPool::set_pad_type(const op::PadType& pad_type)
m_pad_type = pad_type; m_pad_type = pad_type;
} }
bool op::AvgPool::get_ceil_mode() const
{
return m_ceil_mode;
}
void op::AvgPool::set_ceil_mode(bool ceil_mode)
{
m_ceil_mode = ceil_mode;
}
shared_ptr<Node> op::AvgPool::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::AvgPool::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
...@@ -177,7 +224,9 @@ shared_ptr<Node> op::AvgPool::copy_with_new_args(const NodeVector& new_args) con ...@@ -177,7 +224,9 @@ shared_ptr<Node> op::AvgPool::copy_with_new_args(const NodeVector& new_args) con
m_window_movement_strides, m_window_movement_strides,
m_padding_below, m_padding_below,
m_padding_above, m_padding_above,
m_include_padding_in_avg_computation); m_include_padding_in_avg_computation,
m_pad_type,
m_ceil_mode);
} }
const string op::AvgPoolBackprop::type_name("AvgPoolBackprop"); const string op::AvgPoolBackprop::type_name("AvgPoolBackprop");
...@@ -314,6 +363,11 @@ shared_ptr<Node> op::AvgPoolBackprop::copy_with_new_args(const NodeVector& new_a ...@@ -314,6 +363,11 @@ shared_ptr<Node> op::AvgPoolBackprop::copy_with_new_args(const NodeVector& new_a
void op::AvgPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::AvgPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{ {
if (m_ceil_mode)
{
throw ngraph_error("Autodiff not supported on AvgPool with ceil_mode set");
}
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto operand = get_argument(0); auto operand = get_argument(0);
......
...@@ -34,6 +34,57 @@ namespace ngraph ...@@ -34,6 +34,57 @@ namespace ngraph
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
/// \brief Constructs a batched average pooling operation. /// \brief Constructs a batched average pooling operation.
AvgPool(); AvgPool();
/// \brief Constructs a batched average pooling operation.
///
/// \param arg The output producing the input data batch tensor.<br>
/// `[d1, dn]`
/// \param window_shape The window shape.<br>
/// `[n]`
/// \param window_movement_strides The window movement strides.<br>
/// `[n]`
/// \param padding_below The below-padding shape.<br>
/// `[n]`
/// \param padding_above The above-padding shape.<br>
/// `[n]`
/// \param include_padding_in_avg_computation If true then averages include padding
/// elements, each treated as the number zero. If false, padding elements are entirely
/// ignored when computing averages.
/// \param pad_type Padding type to use for additional padded dimensions
/// \param ceil_mode Whether to use ceiling while computing output shape.
AvgPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above,
bool include_padding_in_avg_computation,
const PadType& pad_type,
bool ceil_mode);
/// \brief Constructs a batched average pooling operation.
///
/// \param arg The output producing the input data batch tensor.<br>
/// `[d1, dn]`
/// \param window_shape The window shape.<br>
/// `[n]`
/// \param window_movement_strides The window movement strides.<br>
/// `[n]`
/// \param padding_below The below-padding shape.<br>
/// `[n]`
/// \param padding_above The above-padding shape.<br>
/// `[n]`
/// \param include_padding_in_avg_computation If true then averages include padding
/// elements, each treated as the number zero. If false, padding elements are entirely
/// ignored when computing averages.
/// \param pad_type Padding type to use for additional padded dimensions
AvgPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above,
bool include_padding_in_avg_computation,
const PadType& pad_type);
/// \brief Constructs a batched average pooling operation. /// \brief Constructs a batched average pooling operation.
/// ///
/// \param arg The output producing the input data batch tensor.<br> /// \param arg The output producing the input data batch tensor.<br>
...@@ -54,8 +105,7 @@ namespace ngraph ...@@ -54,8 +105,7 @@ namespace ngraph
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above, const Shape& padding_above,
bool include_padding_in_avg_computation = false, bool include_padding_in_avg_computation = false);
const PadType& pad_type = PadType::EXPLICIT);
/// \brief Constructs a batched, unpadded average pooling operation (i.e., all padding shapes are set to 0). /// \brief Constructs a batched, unpadded average pooling operation (i.e., all padding shapes are set to 0).
/// ///
...@@ -102,6 +152,8 @@ namespace ngraph ...@@ -102,6 +152,8 @@ namespace ngraph
/// \return The pad type for pooling. /// \return The pad type for pooling.
const PadType& get_pad_type() const; const PadType& get_pad_type() const;
void set_pad_type(const PadType& pad_type); void set_pad_type(const PadType& pad_type);
bool get_ceil_mode() const;
void set_ceil_mode(bool ceil_mode);
/// \return The default value for AvgPool. /// \return The default value for AvgPool.
virtual std::shared_ptr<Node> get_default_value() const override virtual std::shared_ptr<Node> get_default_value() const override
{ {
...@@ -115,6 +167,7 @@ namespace ngraph ...@@ -115,6 +167,7 @@ namespace ngraph
Shape m_padding_above; Shape m_padding_above;
bool m_include_padding_in_avg_computation{false}; bool m_include_padding_in_avg_computation{false};
PadType m_pad_type{PadType::EXPLICIT}; PadType m_pad_type{PadType::EXPLICIT};
bool m_ceil_mode{false};
}; };
class AvgPoolBackprop : public Op class AvgPoolBackprop : public Op
......
...@@ -169,7 +169,8 @@ shared_ptr<Node> op::Convolution::copy_with_new_args(const NodeVector& new_args) ...@@ -169,7 +169,8 @@ shared_ptr<Node> op::Convolution::copy_with_new_args(const NodeVector& new_args)
m_window_dilation_strides, m_window_dilation_strides,
m_padding_below, m_padding_below,
m_padding_above, m_padding_above,
m_data_dilation_strides); m_data_dilation_strides,
m_pad_type);
} }
void op::Convolution::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::Convolution::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
......
...@@ -30,17 +30,44 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg, ...@@ -30,17 +30,44 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above, const Shape& padding_above,
const PadType& pad_type) const PadType& pad_type,
bool ceil_mode)
: Op("MaxPool", check_single_output_args({arg})) : Op("MaxPool", check_single_output_args({arg}))
, m_window_shape(window_shape) , m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
, m_padding_above(padding_above) , m_padding_above(padding_above)
, m_pad_type(pad_type) , m_pad_type(pad_type)
, m_ceil_mode(ceil_mode)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above,
const PadType& pad_type)
: MaxPool(
arg, window_shape, window_movement_strides, padding_below, padding_above, pad_type, false)
{
}
op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above)
: MaxPool(arg,
window_shape,
window_movement_strides,
padding_below,
padding_above,
PadType::EXPLICIT)
{
}
void op::MaxPool::validate_and_infer_types() void op::MaxPool::validate_and_infer_types()
{ {
if (0 == m_window_movement_strides.size()) if (0 == m_window_movement_strides.size())
...@@ -90,7 +117,8 @@ void op::MaxPool::validate_and_infer_types() ...@@ -90,7 +117,8 @@ void op::MaxPool::validate_and_infer_types()
padding_above, padding_above,
m_window_shape, m_window_shape,
m_window_movement_strides, m_window_movement_strides,
true)); true,
m_ceil_mode));
} }
op::MaxPool::MaxPool(const shared_ptr<Node>& arg, op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
...@@ -112,7 +140,9 @@ shared_ptr<Node> op::MaxPool::copy_with_new_args(const NodeVector& new_args) con ...@@ -112,7 +140,9 @@ shared_ptr<Node> op::MaxPool::copy_with_new_args(const NodeVector& new_args) con
m_window_shape, m_window_shape,
m_window_movement_strides, m_window_movement_strides,
m_padding_below, m_padding_below,
m_padding_above); m_padding_above,
m_pad_type,
m_ceil_mode);
} }
op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward, op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
...@@ -218,6 +248,11 @@ shared_ptr<Node> op::MaxPoolBackprop::copy_with_new_args(const NodeVector& new_a ...@@ -218,6 +248,11 @@ shared_ptr<Node> op::MaxPoolBackprop::copy_with_new_args(const NodeVector& new_a
void op::MaxPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::MaxPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{ {
if (m_ceil_mode)
{
throw ngraph_error("Autodiff not supported on MaxPool with ceil_mode set");
}
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto operand = get_argument(0); auto operand = get_argument(0);
......
...@@ -36,12 +36,42 @@ namespace ngraph ...@@ -36,12 +36,42 @@ namespace ngraph
/// \param padding_below The below-padding shape. /// \param padding_below The below-padding shape.
/// \param padding_above The above-padding shape. /// \param padding_above The above-padding shape.
/// \param pad_type The pad type for automatically computing padding sizes /// \param pad_type The pad type for automatically computing padding sizes
/// \param ceil_mode Whether to use ceiling while computing output shape.
MaxPool(const std::shared_ptr<Node>& arg, MaxPool(const std::shared_ptr<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above, const Shape& padding_above,
const PadType& pad_type = PadType::EXPLICIT); const PadType& pad_type,
bool ceil_mode);
/// \brief Constructs a batched max pooling operation.
///
/// \param arg The node producing the input data batch tensor.
/// \param window_shape The window shape.
/// \param window_movement_strides The window movement strides.
/// \param padding_below The below-padding shape.
/// \param padding_above The above-padding shape.
/// \param pad_type The pad type for automatically computing padding sizes
MaxPool(const std::shared_ptr<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above,
const PadType& pad_type);
/// \brief Constructs a batched max pooling operation.
///
/// \param arg The node producing the input data batch tensor.
/// \param window_shape The window shape.
/// \param window_movement_strides The window movement strides.
/// \param padding_below The below-padding shape.
/// \param padding_above The above-padding shape.
MaxPool(const std::shared_ptr<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above);
void validate_and_infer_types() override; void validate_and_infer_types() override;
...@@ -73,6 +103,8 @@ namespace ngraph ...@@ -73,6 +103,8 @@ namespace ngraph
const Shape& get_padding_above() const { return m_padding_above; } const Shape& get_padding_above() const { return m_padding_above; }
/// \return The pad type for pooling. /// \return The pad type for pooling.
const PadType& get_pad_type() const { return m_pad_type; } const PadType& get_pad_type() const { return m_pad_type; }
/// \return The ceiling mode being used for output shape computations
bool get_ceil_mode() const { return m_ceil_mode; }
/// \return The default value for MaxPool. /// \return The default value for MaxPool.
virtual std::shared_ptr<Node> get_default_value() const override virtual std::shared_ptr<Node> get_default_value() const override
{ {
...@@ -88,6 +120,7 @@ namespace ngraph ...@@ -88,6 +120,7 @@ namespace ngraph
Shape m_padding_below; Shape m_padding_below;
Shape m_padding_above; Shape m_padding_above;
PadType m_pad_type; PadType m_pad_type;
bool m_ceil_mode{false};
}; };
class MaxPoolBackprop : public Op class MaxPoolBackprop : public Op
......
...@@ -609,13 +609,17 @@ static shared_ptr<ngraph::Function> ...@@ -609,13 +609,17 @@ static shared_ptr<ngraph::Function>
op::PadType pad_type = node_js["pad_type"].empty() op::PadType pad_type = node_js["pad_type"].empty()
? op::PadType::EXPLICIT ? op::PadType::EXPLICIT
: static_cast<op::PadType>(node_js.at("pad_type")); : static_cast<op::PadType>(node_js.at("pad_type"));
bool ceil_mode =
node_js["ceil_mode"].empty() ? false : node_js.at("ceil_mode").get<bool>();
;
node = make_shared<op::AvgPool>(args[0], node = make_shared<op::AvgPool>(args[0],
window_shape, window_shape,
window_movement_strides, window_movement_strides,
padding_below, padding_below,
padding_above, padding_above,
include_padding_in_avg_computation, include_padding_in_avg_computation,
pad_type); pad_type,
ceil_mode);
break; break;
} }
case OP_TYPEID::AvgPoolBackprop: case OP_TYPEID::AvgPoolBackprop:
...@@ -1792,6 +1796,10 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1792,6 +1796,10 @@ static json write(const Node& n, bool binary_constant_data)
node["padding_above"] = tmp->get_padding_above(); node["padding_above"] = tmp->get_padding_above();
node["include_padding_in_avg_computation"] = tmp->get_include_padding_in_avg_computation(); node["include_padding_in_avg_computation"] = tmp->get_include_padding_in_avg_computation();
node["pad_type"] = tmp->get_pad_type(); node["pad_type"] = tmp->get_pad_type();
if (tmp->get_ceil_mode())
{
node["ceil_mode"] = tmp->get_ceil_mode();
}
break; break;
} }
case OP_TYPEID::AvgPoolBackprop: case OP_TYPEID::AvgPoolBackprop:
......
...@@ -79,7 +79,8 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node, ...@@ -79,7 +79,8 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
const PartialShape& window_shape, const PartialShape& window_shape,
const Strides& window_strides, const Strides& window_strides,
const Strides& window_dilation, const Strides& window_dilation,
bool is_window_all_in_padding_allowed) bool is_window_all_in_padding_allowed,
bool ceil_mode)
{ {
PartialShape data_shape_merged{PartialShape::dynamic()}; PartialShape data_shape_merged{PartialShape::dynamic()};
...@@ -198,9 +199,20 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node, ...@@ -198,9 +199,20 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
i, i,
"."); ".");
if (ceil_mode)
{
output_shape[i] = ceil_div(static_cast<size_t>(data_padded_dilated_dim) - output_shape[i] = ceil_div(static_cast<size_t>(data_padded_dilated_dim) -
static_cast<size_t>(window_dilated_dim) + 1, static_cast<size_t>(window_dilated_dim),
window_strides[i]); window_strides[i]) +
1;
}
else
{
output_shape[i] = ((static_cast<size_t>(data_padded_dilated_dim) -
static_cast<size_t>(window_dilated_dim)) /
window_strides[i]) +
1;
}
} }
} }
} }
...@@ -370,7 +382,8 @@ PartialShape ngraph::infer_batched_pooling_forward(const Node* node, ...@@ -370,7 +382,8 @@ PartialShape ngraph::infer_batched_pooling_forward(const Node* node,
const CoordinateDiff& data_padding_above, const CoordinateDiff& data_padding_above,
const PartialShape& window_shape, const PartialShape& window_shape,
const Strides& window_strides, const Strides& window_strides,
bool is_window_all_in_padding_allowed) bool is_window_all_in_padding_allowed,
bool ceil_mode)
{ {
NODE_VALIDATION_CHECK(node, NODE_VALIDATION_CHECK(node,
data_batch_shape.rank().is_dynamic() || data_batch_shape.rank().is_dynamic() ||
...@@ -438,7 +451,8 @@ PartialShape ngraph::infer_batched_pooling_forward(const Node* node, ...@@ -438,7 +451,8 @@ PartialShape ngraph::infer_batched_pooling_forward(const Node* node,
window_shape, window_shape,
window_strides, window_strides,
window_dilation, window_dilation,
is_window_all_in_padding_allowed); is_window_all_in_padding_allowed,
ceil_mode);
} }
PartialShape data_batch_output_shape{ PartialShape data_batch_output_shape{
......
...@@ -41,7 +41,8 @@ namespace ngraph ...@@ -41,7 +41,8 @@ namespace ngraph
const PartialShape& window_shape, const PartialShape& window_shape,
const Strides& window_strides, const Strides& window_strides,
const Strides& window_dilation, const Strides& window_dilation,
bool is_window_all_in_padding_allowed); bool is_window_all_in_padding_allowed,
bool ceil_mode = false);
std::tuple<element::Type, PartialShape> std::tuple<element::Type, PartialShape>
infer_convolution_forward(const Node* node, infer_convolution_forward(const Node* node,
...@@ -61,7 +62,8 @@ namespace ngraph ...@@ -61,7 +62,8 @@ namespace ngraph
const CoordinateDiff& data_padding_above, const CoordinateDiff& data_padding_above,
const PartialShape& window_shape, const PartialShape& window_shape,
const Strides& window_strides, const Strides& window_strides,
bool is_window_all_in_padding_allowed); bool is_window_all_in_padding_allowed,
bool ceil_mode = false);
std::tuple<element::Type, PartialShape, PartialShape> std::tuple<element::Type, PartialShape, PartialShape>
infer_batch_norm_forward(const Node* node, infer_batch_norm_forward(const Node* node,
......
...@@ -7578,6 +7578,26 @@ TEST(type_prop, max_pool_3d_deduce_strided_small) ...@@ -7578,6 +7578,26 @@ TEST(type_prop, max_pool_3d_deduce_strided_small)
EXPECT_EQ(max_pool->get_window_shape(), (Shape{2, 3, 2})); EXPECT_EQ(max_pool->get_window_shape(), (Shape{2, 3, 2}));
} }
TEST(type_prop, max_pool_ceil_mode)
{
// Deduce type
auto param = make_shared<op::Parameter>(element::f32, Shape{64, 3, 10});
Shape window_shape{2};
auto move_strides = Strides{4};
Shape padding_below{4};
Shape padding_above{5};
auto max_pool = make_shared<op::MaxPool>(param,
window_shape,
move_strides,
padding_below,
padding_above,
op::PadType::EXPLICIT,
true);
// ceil((10 + 9 - 2)/4) + 1
EXPECT_EQ(max_pool->get_shape(), (Shape{64, 3, 6}));
}
TEST(type_prop, max_pool_invalid_0d_input) TEST(type_prop, max_pool_invalid_0d_input)
{ {
// Deduce type // Deduce type
...@@ -8647,6 +8667,27 @@ TEST(type_prop, avg_pool_3d_deduce_strided_padded_small) ...@@ -8647,6 +8667,27 @@ TEST(type_prop, avg_pool_3d_deduce_strided_padded_small)
EXPECT_EQ(avg_pool->get_padding_above(), (Shape{6, 4, 5})); EXPECT_EQ(avg_pool->get_padding_above(), (Shape{6, 4, 5}));
} }
TEST(type_prop, avg_pool_ceil_mode)
{
// Deduce type
auto param = make_shared<op::Parameter>(element::f32, Shape{64, 3, 10});
Shape window_shape{2};
auto move_strides = Strides{4};
Shape padding_below{4};
Shape padding_above{5};
auto avg_pool = make_shared<op::AvgPool>(param,
window_shape,
move_strides,
padding_below,
padding_above,
true,
op::PadType::EXPLICIT,
true);
// ceil((10 + 9 - 2)/4) + 1
EXPECT_EQ(avg_pool->get_shape(), (Shape{64, 3, 6}));
}
TEST(type_prop, avg_pool_invalid_0d_input) TEST(type_prop, avg_pool_invalid_0d_input)
{ {
// Deduce type // Deduce type
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment