Commit aa3692d2 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Sang Ik Lee

Adding auto pad to convolution and pooling (#2743)

* Adding auto pad to convolution

* Added auto pad to pooling ops and moved auto pad computation to utility method

* Added serializer support for autopadding. workaround for clang macro warning

* Style fix

* Addressed PR feedback

* Fix docstrings for same_upper and same_lower
parent 2257d9cf
...@@ -141,6 +141,7 @@ ...@@ -141,6 +141,7 @@
#include "ngraph/op/tan.hpp" #include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp" #include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp" #include "ngraph/op/topk.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/partial_shape.hpp" #include "ngraph/partial_shape.hpp"
#include "ngraph/runtime/backend.hpp" #include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/tensor.hpp" #include "ngraph/runtime/tensor.hpp"
......
...@@ -26,13 +26,15 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg, ...@@ -26,13 +26,15 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above, const Shape& padding_above,
bool include_padding_in_avg_computation) bool include_padding_in_avg_computation,
const PadType& pad_type)
: Op("AvgPool", check_single_output_args({arg})) : Op("AvgPool", check_single_output_args({arg}))
, m_window_shape(window_shape) , m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
, m_padding_above(padding_above) , m_padding_above(padding_above)
, m_include_padding_in_avg_computation(include_padding_in_avg_computation) , m_include_padding_in_avg_computation(include_padding_in_avg_computation)
, m_pad_type(pad_type)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -56,6 +58,23 @@ void op::AvgPool::validate_and_infer_types() ...@@ -56,6 +58,23 @@ void op::AvgPool::validate_and_infer_types()
const PartialShape& arg_shape = get_input_partial_shape(0); const PartialShape& arg_shape = get_input_partial_shape(0);
if (m_pad_type == PadType::SAME_UPPER || m_pad_type == PadType::SAME_LOWER)
{
if (arg_shape.is_static())
{
CoordinateDiff padding_above, padding_below;
infer_auto_padding(arg_shape.to_shape(),
m_window_shape,
m_window_movement_strides,
Strides(m_window_shape.size(), 1), // No dilation
m_pad_type,
padding_above,
padding_below);
m_padding_above = Shape(padding_above.begin(), padding_above.end());
m_padding_below = Shape(padding_below.begin(), padding_below.end());
}
}
// infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for // infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
// now still take Shape (no negative padding). // now still take Shape (no negative padding).
CoordinateDiff padding_below(m_padding_below.begin(), m_padding_below.end()); CoordinateDiff padding_below(m_padding_below.begin(), m_padding_below.end());
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -48,7 +49,8 @@ namespace ngraph ...@@ -48,7 +49,8 @@ namespace ngraph
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above, const Shape& padding_above,
bool include_padding_in_avg_computation = false); bool include_padding_in_avg_computation = false,
const PadType& pad_type = PadType::EXPLICIT);
/// \brief Constructs a batched, unpadded average pooling operation (i.e., all padding shapes are set to 0). /// \brief Constructs a batched, unpadded average pooling operation (i.e., all padding shapes are set to 0).
/// ///
...@@ -90,6 +92,8 @@ namespace ngraph ...@@ -90,6 +92,8 @@ namespace ngraph
{ {
return m_include_padding_in_avg_computation; return m_include_padding_in_avg_computation;
} }
/// \return The pad type for pooling.
const PadType& get_pad_type() const { return m_pad_type; }
/// \return The default value for AvgPool. /// \return The default value for AvgPool.
virtual std::shared_ptr<Node> get_default_value() const override virtual std::shared_ptr<Node> get_default_value() const override
{ {
...@@ -102,6 +106,7 @@ namespace ngraph ...@@ -102,6 +106,7 @@ namespace ngraph
Shape m_padding_below; Shape m_padding_below;
Shape m_padding_above; Shape m_padding_above;
bool m_include_padding_in_avg_computation; bool m_include_padding_in_avg_computation;
PadType m_pad_type;
}; };
class AvgPoolBackprop : public Op class AvgPoolBackprop : public Op
......
...@@ -33,13 +33,15 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch, ...@@ -33,13 +33,15 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides) const Strides& data_dilation_strides,
const PadType& pad_type)
: Op("Convolution", check_single_output_args({data_batch, filters})) : Op("Convolution", check_single_output_args({data_batch, filters}))
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides) , m_window_dilation_strides(window_dilation_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
, m_padding_above(padding_above) , m_padding_above(padding_above)
, m_data_dilation_strides(data_dilation_strides) , m_data_dilation_strides(data_dilation_strides)
, m_pad_type(pad_type)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -76,6 +78,25 @@ void op::Convolution::validate_and_infer_types() ...@@ -76,6 +78,25 @@ void op::Convolution::validate_and_infer_types()
m_padding_above = conv_default_padding(this, data_batch_shape, filters_shape); m_padding_above = conv_default_padding(this, data_batch_shape, filters_shape);
} }
if (m_pad_type == PadType::SAME_UPPER || m_pad_type == PadType::SAME_LOWER)
{
if (data_batch_shape.is_static() && filters_shape.is_static())
{
// TODO: data dilation
m_padding_below.clear();
m_padding_above.clear();
auto filter_shape = filters_shape.to_shape();
filter_shape.erase(filter_shape.begin(), filter_shape.begin() + 2); // Remove {O,I}
infer_auto_padding(data_batch_shape.to_shape(),
filter_shape,
m_window_movement_strides,
m_window_dilation_strides,
m_pad_type,
m_padding_above,
m_padding_below);
}
}
element::Type result_et; element::Type result_et;
PartialShape result_shape; PartialShape result_shape;
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "ngraph/coordinate_diff.hpp" #include "ngraph/coordinate_diff.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -45,6 +46,8 @@ namespace ngraph ...@@ -45,6 +46,8 @@ namespace ngraph
/// `[f]` /// `[f]`
/// \param data_dilation_strides The data dilation strides.<br> /// \param data_dilation_strides The data dilation strides.<br>
/// `[f]` /// `[f]`
/// \param pad_type The pad type for automatically computing padding sizes.<br>
/// `[f]`
/// ///
/// Output `[N, C_OUT, R1, ... Rf]` /// Output `[N, C_OUT, R1, ... Rf]`
/// ///
...@@ -54,7 +57,8 @@ namespace ngraph ...@@ -54,7 +57,8 @@ namespace ngraph
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides); const Strides& data_dilation_strides,
const PadType& pad_type = PadType::EXPLICIT);
/// \brief Constructs a batched convolution operation with no data dilation (i.e., all data dilation strides are 1). /// \brief Constructs a batched convolution operation with no data dilation (i.e., all data dilation strides are 1).
/// ///
...@@ -141,6 +145,8 @@ namespace ngraph ...@@ -141,6 +145,8 @@ namespace ngraph
const CoordinateDiff& get_padding_above() const { return m_padding_above; } const CoordinateDiff& get_padding_above() const { return m_padding_above; }
/// \return The input data dilation strides. /// \return The input data dilation strides.
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; } const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
/// \return The pad type for convolution.
const PadType& get_pad_type() const { return m_pad_type; }
/// \return The default value for Convolution. /// \return The default value for Convolution.
virtual std::shared_ptr<Node> get_default_value() const override virtual std::shared_ptr<Node> get_default_value() const override
{ {
...@@ -153,6 +159,7 @@ namespace ngraph ...@@ -153,6 +159,7 @@ namespace ngraph
CoordinateDiff m_padding_below; CoordinateDiff m_padding_below;
CoordinateDiff m_padding_above; CoordinateDiff m_padding_above;
Strides m_data_dilation_strides; Strides m_data_dilation_strides;
PadType m_pad_type;
}; };
/// \brief Data batch backprop for batched convolution operation. /// \brief Data batch backprop for batched convolution operation.
......
...@@ -29,12 +29,14 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg, ...@@ -29,12 +29,14 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above) const Shape& padding_above,
const PadType& pad_type)
: Op("MaxPool", check_single_output_args({arg})) : Op("MaxPool", check_single_output_args({arg}))
, m_window_shape(window_shape) , m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
, m_padding_above(padding_above) , m_padding_above(padding_above)
, m_pad_type(pad_type)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -58,6 +60,23 @@ void op::MaxPool::validate_and_infer_types() ...@@ -58,6 +60,23 @@ void op::MaxPool::validate_and_infer_types()
const PartialShape& arg_shape = get_input_partial_shape(0); const PartialShape& arg_shape = get_input_partial_shape(0);
if (m_pad_type == PadType::SAME_UPPER || m_pad_type == PadType::SAME_LOWER)
{
if (arg_shape.is_static())
{
CoordinateDiff padding_above, padding_below;
infer_auto_padding(arg_shape.to_shape(),
m_window_shape,
m_window_movement_strides,
Strides(m_window_shape.size(), 1), // No dilation
m_pad_type,
padding_above,
padding_below);
m_padding_above = Shape(padding_above.begin(), padding_above.end());
m_padding_below = Shape(padding_below.begin(), padding_below.end());
}
}
// infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for // infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
// now still take Shape (no negative padding). // now still take Shape (no negative padding).
CoordinateDiff padding_below(m_padding_below.begin(), m_padding_below.end()); CoordinateDiff padding_below(m_padding_below.begin(), m_padding_below.end());
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -34,11 +35,13 @@ namespace ngraph ...@@ -34,11 +35,13 @@ namespace ngraph
/// \param window_movement_strides The window movement strides. /// \param window_movement_strides The window movement strides.
/// \param padding_below The below-padding shape. /// \param padding_below The below-padding shape.
/// \param padding_above The above-padding shape. /// \param padding_above The above-padding shape.
/// \param pad_type The pad type for automatically computing padding sizes
MaxPool(const std::shared_ptr<Node>& arg, MaxPool(const std::shared_ptr<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above); const Shape& padding_above,
const PadType& pad_type = PadType::EXPLICIT);
void validate_and_infer_types() override; void validate_and_infer_types() override;
...@@ -68,6 +71,8 @@ namespace ngraph ...@@ -68,6 +71,8 @@ namespace ngraph
const Shape& get_padding_below() const { return m_padding_below; } const Shape& get_padding_below() const { return m_padding_below; }
/// \return The above-padding shape. /// \return The above-padding shape.
const Shape& get_padding_above() const { return m_padding_above; } const Shape& get_padding_above() const { return m_padding_above; }
/// \return The pad type for pooling.
const PadType& get_pad_type() const { return m_pad_type; }
/// \return The default value for MaxPool. /// \return The default value for MaxPool.
virtual std::shared_ptr<Node> get_default_value() const override virtual std::shared_ptr<Node> get_default_value() const override
{ {
...@@ -82,6 +87,7 @@ namespace ngraph ...@@ -82,6 +87,7 @@ namespace ngraph
Strides m_window_movement_strides; Strides m_window_movement_strides;
Shape m_padding_below; Shape m_padding_below;
Shape m_padding_above; Shape m_padding_above;
PadType m_pad_type;
}; };
class MaxPoolBackprop : public Op class MaxPoolBackprop : public Op
......
...@@ -18,19 +18,12 @@ ...@@ -18,19 +18,12 @@
#include "ngraph/coordinate_diff.hpp" #include "ngraph/coordinate_diff.hpp"
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
{ {
/// \brief Modes for the `Pad` operator.
enum class PadMode
{
CONSTANT = 0,
EDGE,
REFLECT
};
/// \brief Generic padding operation. /// \brief Generic padding operation.
class Pad : public Op class Pad : public Op
{ {
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
namespace ngraph
{
namespace op
{
/// \brief Modes for the `Pad` operator.
enum class PadMode
{
CONSTANT = 0,
EDGE,
REFLECT
};
/// \brief Padding Type used for `Convolution` and `Pooling`
///
/// Follows ONNX padding type definitions
/// EXPLICIT - Pad dimensions are explicity specified
/// SAME_LOWER - Pad dimensions computed to match input shape
/// Ceil(num_dims/2) at the beginning and
/// Floor(num_dims/2) at the end
/// SAME_UPPER - Pad dimensions computed to match input shape
/// Floor(num_dims/2) at the beginning and
/// Ceil(num_dims/2) at the end
/// VALID - No padding
///
enum class PadType
{
EXPLICIT = 0,
SAME_LOWER,
SAME_UPPER,
VALID,
AUTO = SAME_UPPER,
NOTSET = EXPLICIT
};
}
}
...@@ -540,12 +540,16 @@ static shared_ptr<ngraph::Function> ...@@ -540,12 +540,16 @@ static shared_ptr<ngraph::Function>
auto padding_above = node_js.at("padding_above").get<vector<size_t>>(); auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
auto include_padding_in_avg_computation = auto include_padding_in_avg_computation =
node_js.at("include_padding_in_avg_computation").get<bool>(); node_js.at("include_padding_in_avg_computation").get<bool>();
op::PadType pad_type = node_js["pad_type"].empty()
? op::PadType::EXPLICIT
: static_cast<op::PadType>(node_js.at("pad_type"));
node = make_shared<op::AvgPool>(args[0], node = make_shared<op::AvgPool>(args[0],
window_shape, window_shape,
window_movement_strides, window_movement_strides,
padding_below, padding_below,
padding_above, padding_above,
include_padding_in_avg_computation); include_padding_in_avg_computation,
pad_type);
break; break;
} }
case OP_TYPEID::AvgPoolBackprop: case OP_TYPEID::AvgPoolBackprop:
...@@ -666,6 +670,10 @@ static shared_ptr<ngraph::Function> ...@@ -666,6 +670,10 @@ static shared_ptr<ngraph::Function>
data_dilation_strides_maybe = node_js["image_dilation_strides"]; data_dilation_strides_maybe = node_js["image_dilation_strides"];
} }
op::PadType pad_type = node_js["pad_type"].empty()
? op::PadType::EXPLICIT
: static_cast<op::PadType>(node_js.at("pad_type"));
if (data_dilation_strides_maybe.empty()) if (data_dilation_strides_maybe.empty())
{ {
node = make_shared<op::Convolution>(args[0], node = make_shared<op::Convolution>(args[0],
...@@ -684,7 +692,8 @@ static shared_ptr<ngraph::Function> ...@@ -684,7 +692,8 @@ static shared_ptr<ngraph::Function>
window_dilation_strides, window_dilation_strides,
padding_below, padding_below,
padding_above, padding_above,
data_dilation_strides_maybe.get<std::vector<size_t>>()); data_dilation_strides_maybe.get<std::vector<size_t>>(),
pad_type);
} }
break; break;
} }
...@@ -961,6 +970,9 @@ static shared_ptr<ngraph::Function> ...@@ -961,6 +970,9 @@ static shared_ptr<ngraph::Function>
// omitted. // omitted.
auto padding_below_maybe = node_js["padding_below"]; auto padding_below_maybe = node_js["padding_below"];
auto padding_above_maybe = node_js["padding_above"]; auto padding_above_maybe = node_js["padding_above"];
op::PadType pad_type = node_js["pad_type"].empty()
? op::PadType::EXPLICIT
: static_cast<op::PadType>(node_js.at("pad_type"));
if (padding_below_maybe.empty() && !padding_above_maybe.empty()) if (padding_below_maybe.empty() && !padding_above_maybe.empty())
{ {
throw runtime_error( throw runtime_error(
...@@ -979,7 +991,8 @@ static shared_ptr<ngraph::Function> ...@@ -979,7 +991,8 @@ static shared_ptr<ngraph::Function>
window_shape, window_shape,
window_movement_strides, window_movement_strides,
padding_below, padding_below,
padding_above); padding_above,
pad_type);
} }
else else
{ {
...@@ -1518,6 +1531,7 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1518,6 +1531,7 @@ static json write(const Node& n, bool binary_constant_data)
node["padding_below"] = tmp->get_padding_below(); node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above(); node["padding_above"] = tmp->get_padding_above();
node["include_padding_in_avg_computation"] = tmp->get_include_padding_in_avg_computation(); node["include_padding_in_avg_computation"] = tmp->get_include_padding_in_avg_computation();
node["pad_type"] = tmp->get_pad_type();
break; break;
} }
case OP_TYPEID::AvgPoolBackprop: case OP_TYPEID::AvgPoolBackprop:
...@@ -1599,6 +1613,7 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1599,6 +1613,7 @@ static json write(const Node& n, bool binary_constant_data)
node["padding_below"] = tmp->get_padding_below(); node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above(); node["padding_above"] = tmp->get_padding_above();
node["data_dilation_strides"] = tmp->get_data_dilation_strides(); node["data_dilation_strides"] = tmp->get_data_dilation_strides();
node["pad_type"] = tmp->get_pad_type();
break; break;
} }
case OP_TYPEID::ConvolutionBackpropData: case OP_TYPEID::ConvolutionBackpropData:
...@@ -1747,6 +1762,7 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1747,6 +1762,7 @@ static json write(const Node& n, bool binary_constant_data)
node["window_movement_strides"] = tmp->get_window_movement_strides(); node["window_movement_strides"] = tmp->get_window_movement_strides();
node["padding_below"] = tmp->get_padding_below(); node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above(); node["padding_above"] = tmp->get_padding_above();
node["pad_type"] = tmp->get_pad_type();
break; break;
} }
case OP_TYPEID::MaxPoolBackprop: case OP_TYPEID::MaxPoolBackprop:
......
...@@ -589,3 +589,28 @@ std::tuple<element::Type, PartialShape, PartialShape> ...@@ -589,3 +589,28 @@ std::tuple<element::Type, PartialShape, PartialShape>
input_shape, input_shape,
{{gamma_element_type, gamma_shape, "gamma"}, {beta_element_type, beta_shape, "beta"}}); {{gamma_element_type, gamma_shape, "gamma"}, {beta_element_type, beta_shape, "beta"}});
} }
void ngraph::infer_auto_padding(const Shape& image_shape,
const Shape& filter_shape,
const Strides& filter_strides,
const Strides& filter_dilations,
const op::PadType pad_type,
CoordinateDiff& padding_above,
CoordinateDiff& padding_below)
{
NGRAPH_CHECK(pad_type == op::PadType::SAME_UPPER || pad_type == op::PadType::SAME_LOWER);
for (size_t i = 0; i < static_cast<size_t>(filter_shape.size()); i++)
{
int64_t image_size = static_cast<int64_t>(image_shape[i + 2]);
int64_t filter_size = (static_cast<int64_t>(filter_shape[i]) - 1) * filter_dilations[i] + 1;
int64_t filter_stride = static_cast<int64_t>(filter_strides[i]);
auto output_size = (image_size + filter_stride - 1) / filter_stride;
auto padding_needed =
std::max(int64_t(0), (output_size - 1) * filter_stride + filter_size - image_size);
auto padding_lhs = padding_needed / 2;
auto padding_rhs = padding_needed - padding_lhs;
padding_below.push_back(pad_type == op::PadType::SAME_UPPER ? padding_lhs : padding_rhs);
padding_above.push_back(pad_type == op::PadType::SAME_UPPER ? padding_rhs : padding_lhs);
}
}
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "ngraph/coordinate_diff.hpp" #include "ngraph/coordinate_diff.hpp"
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -82,4 +83,12 @@ namespace ngraph ...@@ -82,4 +83,12 @@ namespace ngraph
const PartialShape& input_shape, const PartialShape& input_shape,
const PartialShape& gamma_shape, const PartialShape& gamma_shape,
const PartialShape& beta_shape); const PartialShape& beta_shape);
void infer_auto_padding(const Shape& image_shape,
const Shape& filter_shape,
const Strides& filter_strides,
const Strides& filter_dilations,
const op::PadType pad_type,
CoordinateDiff& padding_above,
CoordinateDiff& padding_below);
} }
...@@ -5395,6 +5395,132 @@ TEST(type_prop, conv_2d_deduce_padded_neg) ...@@ -5395,6 +5395,132 @@ TEST(type_prop, conv_2d_deduce_padded_neg)
EXPECT_EQ(conv->get_padding_above(), (CoordinateDiff{3, -4})); EXPECT_EQ(conv->get_padding_above(), (CoordinateDiff{3, -4}));
} }
struct DeduceAutoPadTest
: ::testing::TestWithParam<
std::tuple<Shape, Shape, Strides, Strides, CoordinateDiff, CoordinateDiff>>
{
};
TEST_P(DeduceAutoPadTest, same_upper)
{
auto image_shape = std::get<0>(GetParam());
image_shape.insert(image_shape.begin(), {1, 1}); // Add {N, C}
auto filter_shape = std::get<1>(GetParam());
filter_shape.insert(filter_shape.begin(), {1, 1}); // Add {O, I}
auto param0 = make_shared<op::Parameter>(element::f32, image_shape);
auto param1 = make_shared<op::Parameter>(element::f32, filter_shape);
auto conv = make_shared<op::Convolution>(param0,
param1,
std::get<2>(GetParam()),
std::get<3>(GetParam()),
CoordinateDiff(),
CoordinateDiff(),
Strides(),
op::PadType::SAME_UPPER);
EXPECT_EQ(conv->get_padding_below(), std::get<4>(GetParam()));
EXPECT_EQ(conv->get_padding_above(), std::get<5>(GetParam()));
auto no_dilation = std::all_of(std::get<3>(GetParam()).begin(),
std::get<3>(GetParam()).end(),
[](size_t i) { return i <= 1; });
if (no_dilation)
{
auto max_pool = make_shared<op::MaxPool>(param0,
std::get<1>(GetParam()),
std::get<2>(GetParam()),
Shape(),
Shape(),
op::PadType::SAME_UPPER);
CoordinateDiff padding_below(max_pool->get_padding_below().begin(),
max_pool->get_padding_below().end());
CoordinateDiff padding_above(max_pool->get_padding_above().begin(),
max_pool->get_padding_above().end());
EXPECT_EQ(padding_below, std::get<4>(GetParam()));
EXPECT_EQ(padding_above, std::get<5>(GetParam()));
auto avg_pool = make_shared<op::AvgPool>(param0,
std::get<1>(GetParam()),
std::get<2>(GetParam()),
Shape(),
Shape(),
false,
op::PadType::SAME_UPPER);
CoordinateDiff pad_below(avg_pool->get_padding_below().begin(),
avg_pool->get_padding_below().end());
CoordinateDiff pad_above(avg_pool->get_padding_above().begin(),
avg_pool->get_padding_above().end());
EXPECT_EQ(pad_below, std::get<4>(GetParam()));
EXPECT_EQ(pad_above, std::get<5>(GetParam()));
}
}
TEST_P(DeduceAutoPadTest, same_lower)
{
auto image_shape = std::get<0>(GetParam());
image_shape.insert(image_shape.begin(), {1, 1}); // Add {N, C}
auto filter_shape = std::get<1>(GetParam());
filter_shape.insert(filter_shape.begin(), {1, 1}); // Add {O, I}
auto param0 = make_shared<op::Parameter>(element::f32, image_shape);
auto param1 = make_shared<op::Parameter>(element::f32, filter_shape);
auto conv = make_shared<op::Convolution>(param0,
param1,
std::get<2>(GetParam()),
std::get<3>(GetParam()),
CoordinateDiff(),
CoordinateDiff(),
Strides(),
op::PadType::SAME_LOWER);
EXPECT_EQ(conv->get_padding_above(), std::get<4>(GetParam()));
EXPECT_EQ(conv->get_padding_below(), std::get<5>(GetParam()));
}
INSTANTIATE_TEST_CASE_P(type_prop,
DeduceAutoPadTest,
::testing::Values(std::make_tuple(Shape{5, 6},
Shape{3, 4},
Strides{2, 1},
Strides{1, 1},
CoordinateDiff{1, 1},
CoordinateDiff{1, 2}),
std::make_tuple(Shape{3, 3},
Shape{2, 2},
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{0, 0},
CoordinateDiff{1, 1}),
std::make_tuple(Shape{28, 28},
Shape{3, 3},
Strides{2, 2},
Strides{1, 1},
CoordinateDiff{0, 0},
CoordinateDiff{1, 1}),
std::make_tuple(Shape{100, 150},
Shape{10, 20},
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{4, 9},
CoordinateDiff{5, 10}),
std::make_tuple(Shape{2},
Shape{1},
Strides{3},
Strides{1},
CoordinateDiff{0},
CoordinateDiff{0}),
std::make_tuple(Shape{10, 1},
Shape{4, 1},
Strides{1, 1},
Strides{2, 1},
CoordinateDiff{3, 0},
CoordinateDiff{3, 0}),
std::make_tuple(Shape{10, 5, 6},
Shape{3, 3, 4},
Strides{1, 2, 1},
Strides{2, 1, 1},
CoordinateDiff{2, 1, 1},
CoordinateDiff{2, 1, 2})), );
TEST(type_prop, conv_2d_deduce_strided) TEST(type_prop, conv_2d_deduce_strided)
{ {
// Deduce type // Deduce type
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment