Unverified Commit b2da4cee authored by Tomasz Dołbniak, committed by GitHub

POC enabling ResNet50 with dynamic batch dimension (#4298)

* Constify the onnx importer conv

* Extract and fix the groups attribute validation for Conv

* Check if the convolution's data input rank is static

* Validate the groups attribute against channels and filters

* Validate the conv operation in a separate function

* Dynamically broadcast the conv bias if needed

* Import a test model with dynamic batch conv op

* Run a conv test with dynamic batch

* Cleanup of conv bias handling code

* Use a proper Broadcast constructor for bias in onnx conv

* Handle dynamic ReduceMean with statically defined rank

* Use the target shape rank to construct the default output shape for Broadcast

* Handle ONNX Squeeze with dynamic input and static rank

* Handle ONNX Shape with dynamic input and static rank

* Handle the dynamic target shape in ONNX Reshape

* Fix for the ONNX Shape input validation

* Handle ONNX Softmax with dynamic input and static rank

* Fix the failing Broadcast type prop test

* Code formatting

* Don't broadcast bias before adding it to the conv node

* Drop the conv node validation and rely on the core op implementation checks

* Code review feedback

* Revert the Broadcast op changes

* More code review feedback

* Dynamic conv test using ng test case

* Obsolete headers removal

* Code formatting

* Variable names refactor

* Disable model_conv_with_dynamic_batch test on GPU

* Code formatting
Co-authored-by: Sang Ik Lee <sang.ik.lee@intel.com>
parent d3747036
@@ -26,7 +26,6 @@
 #include "ngraph/op/slice.hpp"
 #include "ngraph/op/util/attr_types.hpp"
 #include "ngraph/op/util/broadcasting.hpp"
-#include "ngraph/opsets/opset0.hpp"
 #include "utils/convpool.hpp"
 
 namespace ngraph
@@ -46,7 +45,7 @@ namespace ngraph
                         const ngraph::Strides& dilations,
                         const ngraph::CoordinateDiff& padding_below,
                         const ngraph::CoordinateDiff& padding_above,
-                        int groups,
+                        int64_t groups,
                         const ngraph::op::PadType& auto_pad)
                     {
                         if (groups > 1)
@@ -55,7 +54,7 @@ namespace ngraph
                             filters_shape.at(0) = filters_shape.at(0) / groups;
                             filters_shape.insert(filters_shape.begin(), groups);
 
-                            auto reshaped_filters =
+                            const auto reshaped_filters =
                                 ngraph::builder::opset1::reshape(filters, filters_shape);
 
                             return std::make_shared<default_opset::GroupConvolution>(
@@ -79,64 +78,73 @@ namespace ngraph
                     }
                 }
 
+                std::shared_ptr<ngraph::Node>
+                    add_bias(const std::shared_ptr<ngraph::Node>& ng_conv,
+                             const std::shared_ptr<ngraph::Node>& bias)
+                {
+                    const auto rank_of_conv =
+                        static_cast<size_t>(ng_conv->get_output_partial_shape(0).rank());
+
+                    // reshape the bias node {M} to {1, M, 1, 1, ..., 1}
+                    // this is required by the addition operation that needs to be able
+                    // to broadcast the bias to match the shape of the convolution node
+                    std::vector<size_t> reshape_pattern_values(rank_of_conv, 1U);
+                    reshape_pattern_values[1] = bias->get_shape().front();
+                    const auto reshape_pattern =
+                        default_opset::Constant::create(element::u64,
+                                                        Shape{reshape_pattern_values.size()},
+                                                        reshape_pattern_values);
+
+                    std::shared_ptr<ngraph::Node> reshaped_bias =
+                        std::make_shared<default_opset::Reshape>(bias, reshape_pattern, false);
+
+                    return {std::make_shared<default_opset::Add>(ng_conv, reshaped_bias)};
+                }
             } // namespace
 
             NodeVector conv(const Node& node)
             {
+                // in the current implementation we assume that the data input rank is static
+                // and only the 'batch' dimension can be dynamic
                 const NodeVector& inputs = node.get_ng_inputs();
-                auto data = inputs.at(0);
-                auto filters = inputs.at(1);
+                const auto data = inputs.at(0);
+                const auto filters = inputs.at(1);
 
-                int64_t groups{node.get_attribute_value<int64_t>("group", 1)};
+                const auto groups = node.get_attribute_value<int64_t>("group", 1);
 
-                ASSERT_VALID_ARGUMENT(
-                    node,
-                    ((groups >= 0) &&
-                     (groups <= static_cast<int64_t>(data->get_shape().at(1))) &&
-                     (groups <= static_cast<int64_t>(filters->get_shape().at(0)))))
-                    << "incorrect value of 'group' attribute: " << groups;
-
-                std::size_t n_data_channels{data->get_shape().at(1)};
-                std::size_t n_filters_channels{filters->get_shape().at(0)};
-
-                ASSERT_VALID_ARGUMENT(node, n_data_channels % groups == 0)
-                    << "provided group attribute value must be a multiple of data channels "
-                       "count.";
-                ASSERT_VALID_ARGUMENT(node, n_filters_channels % groups == 0)
-                    << "provided group attribute value must be a multiple of filter channels "
-                       "count.";
-
-                auto strides = convpool::get_strides(node);
-                auto dilations = convpool::get_dilations(node);
-                auto paddings = convpool::get_pads(node);
-                ngraph::op::PadType auto_pad_type = convpool::get_auto_pad(node);
+                NGRAPH_CHECK(data->get_output_partial_shape(0).rank().is_static(),
+                             "The input data tensor's rank has to be known (static)");
+
+                const auto strides = convpool::get_strides(node);
+                const auto dilations = convpool::get_dilations(node);
+                const auto paddings = convpool::get_pads(node);
+                const ngraph::op::PadType auto_pad_type = convpool::get_auto_pad(node);
                 const auto& padding_below = paddings.first;
                 const auto& padding_above = paddings.second;
 
-                auto conv_node = make_ng_convolution(data,
+                const auto conv_node = make_ng_convolution(data,
                                                            filters,
                                                            strides,
                                                            dilations,
                                                            padding_below,
                                                            padding_above,
                                                            groups,
                                                            auto_pad_type);
 
                 // no bias param
                 if (inputs.size() < 3)
                 {
                     return {conv_node};
                 }
-
-                auto bias = inputs.at(2);
-                const Shape& new_shape = conv_node->get_shape();
-
-                auto broadcasted_bias = std::make_shared<default_opset::Broadcast>(
-                    bias,
-                    default_opset::Constant::create(
-                        element::i64, Shape{new_shape.size()}, new_shape),
-                    default_opset::Constant::create(element::i64, Shape{1}, {1}));
-                return {std::make_shared<default_opset::Add>(conv_node, broadcasted_bias)};
+                else
+                {
+                    const auto bias = inputs.at(2);
+                    const auto bias_ps = bias->get_output_partial_shape(0);
+
+                    NGRAPH_CHECK(bias_ps.is_static() && is_vector(bias_ps.to_shape()),
+                                 "The bias input needs to be a static 1D vector");
+
+                    return {add_bias(conv_node, bias)};
+                }
             }
 
         } // namespace set_1
...
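As a side note, here is a minimal standalone sketch of the reshape pattern computed by the new add_bias helper above. bias_reshape_pattern is a hypothetical name for illustration only, not part of the nGraph API; the point is that a {M} bias becomes {1, M, 1, ..., 1}, so the elementwise Add can broadcast it over a conv output whose batch dimension is dynamic.

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical helper mirroring the pattern built by add_bias: all dims are 1
// except the channel dim, which carries the bias length.
std::vector<std::size_t> bias_reshape_pattern(std::size_t conv_rank,
                                              std::size_t bias_channels)
{
    std::vector<std::size_t> pattern(conv_rank, 1U); // {1, 1, ..., 1}
    pattern[1] = bias_channels;                      // channel dim holds the bias
    return pattern;
}

int main()
{
    // a rank-4 conv output {N, C, H, W} and a 10-element bias -> {1, 10, 1, 1}
    for (const auto dim : bias_reshape_pattern(4, 10))
    {
        std::cout << dim << ' ';
    }
    std::cout << '\n';
}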
@@ -32,30 +32,52 @@
                 {
                     NodeVector reduce_mean(const Node& node)
                     {
-                        auto input_shape = node.get_ng_inputs().at(0)->get_shape();
-                        auto reduction_axes = reduction::detail::get_reduction_axes(node);
-                        std::size_t elem_count_product =
-                            std::accumulate(std::begin(reduction_axes),
-                                            std::end(reduction_axes),
-                                            1UL,
-                                            [&input_shape](const std::size_t& a, const std::size_t& b) {
-                                                return a * input_shape.at(b);
-                                            });
+                        const auto data = node.get_ng_inputs().at(0);
+                        const auto& data_shape = data->get_output_partial_shape(0);
 
-                        auto sum_node = std::shared_ptr<ngraph::Node>{reduction::make_ng_reduction_op(
+                        // sum up the input data along the reduction axes
+                        const auto sum_node = reduction::make_ng_reduction_op(
                             node,
-                            node.get_ng_inputs().at(0),
+                            data,
                             std::make_shared<default_opset::ReduceSum,
                                              const std::shared_ptr<ngraph::Node>&,
                                              const std::shared_ptr<ngraph::Node>&,
-                                             bool>)};
+                                             bool>);
 
-                        auto const_node = default_opset::Constant::create(
-                            sum_node->get_element_type(),
-                            sum_node->get_shape(),
-                            std::vector<std::size_t>(shape_size(sum_node->get_shape()),
-                                                     elem_count_product));
+                        // calculate the product of dimensions pointed to by reduction axes
+                        size_t reduced_elems_count = 1U;
+
+                        if (data_shape.is_static())
+                        {
+                            const auto input_shape = data_shape.to_shape();
+                            // calculate the product of dimensions pointed to by reduction axes
+                            // this value represents the number of input tensor values that were reduced
+                            for (const auto axis : reduction::detail::get_reduction_axes(node))
+                            {
+                                reduced_elems_count *= input_shape.at(axis);
+                            }
+                        }
+                        else
+                        {
+                            for (const auto axis : reduction::detail::get_reduction_axes(node))
+                            {
+                                const auto dim_to_reduce = data_shape[axis];
+                                NGRAPH_CHECK(dim_to_reduce.is_static(),
+                                             "Axis ",
+                                             axis,
+                                             " in the input data tensor needs to be statically "
+                                             "specified to create a ReduceMean operation");
+
+                                reduced_elems_count *= static_cast<size_t>(dim_to_reduce);
+                            }
+                        }
+
+                        const auto const_node = default_opset::Constant::create(
+                            sum_node->get_element_type(), {}, {reduced_elems_count});
+
+                        // divide the sum node containing reduced values by the number
+                        // of those values to obtain the mean
                         return {std::make_shared<default_opset::Divide>(sum_node, const_node)};
                     }
...
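A standalone sketch of the denominator logic introduced above: the mean is the ReduceSum result divided by the product of the sizes of the reduced dimensions, and that product can only be formed when every reduced axis is static. reduced_elems_count below is an illustrative helper, not the importer's API; dynamic dimensions are mocked as -1.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Computes how many input values each mean aggregates; -1 marks a dynamic
// dimension, which is only acceptable outside the reduction axes.
std::size_t reduced_elems_count(const std::vector<std::int64_t>& shape,
                                const std::vector<std::size_t>& axes)
{
    std::size_t count = 1U;
    for (const auto axis : axes)
    {
        if (shape.at(axis) < 0)
        {
            throw std::runtime_error("reduced axes must be static");
        }
        count *= static_cast<std::size_t>(shape.at(axis));
    }
    return count;
}

int main()
{
    // {batch(dynamic), 3, 7, 7} reduced over the spatial axes {2, 3} -> 49
    std::cout << reduced_elems_count({-1, 3, 7, 7}, {2, 3}) << '\n';
}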
@@ -43,9 +43,6 @@ namespace ngraph
                     // Since opset 5 the target shape is provided as input
                     if (ng_inputs.size() == 2)
                     {
-                        NGRAPH_CHECK(ng_inputs.at(1)->is_constant(),
-                                     "The target shape input has to be a Constant.");
-
                         pattern = ng_inputs.at(1);
                     }
                     else
...
@@ -32,11 +32,21 @@ namespace ngraph
                 {
                     NodeVector shape(const Node& node)
                     {
-                        auto data = node.get_ng_inputs().at(0);
-                        auto data_shape = data->get_shape();
+                        const auto data = node.get_ng_inputs().at(0);
+                        const auto data_shape = data->get_output_partial_shape(0);
 
-                        return {std::make_shared<default_opset::Constant>(
-                            ngraph::element::i64, Shape{data_shape.size()}, data_shape)};
+                        if (data_shape.is_static())
+                        {
+                            const auto static_data_shape = data_shape.to_shape();
+
+                            return {default_opset::Constant::create(ngraph::element::i64,
+                                                                    Shape{static_data_shape.size()},
+                                                                    static_data_shape)};
+                        }
+                        else
+                        {
+                            return {std::make_shared<default_opset::ShapeOf>(data)};
+                        }
                     }
 
                 } // namespace set_1
...
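A sketch of the dispatch above: when every dimension is known at import time, the ONNX Shape op can be folded into a Constant; otherwise a runtime ShapeOf node has to be emitted. is_static here is a mock of PartialShape::is_static, with -1 standing in for a dynamic dimension.

#include <cstdint>
#include <iostream>
#include <vector>

// Mock of PartialShape::is_static: -1 marks a dynamic dimension.
bool is_static(const std::vector<std::int64_t>& shape)
{
    for (const auto dim : shape)
    {
        if (dim < 0)
        {
            return false;
        }
    }
    return true;
}

int main()
{
    // a dynamic batch dim forces a runtime ShapeOf instead of a folded Constant
    std::cout << (is_static({-1, 3, 7, 7}) ? "Constant" : "ShapeOf") << '\n';
}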
@@ -39,10 +39,7 @@ namespace ngraph
                     return {std::make_shared<default_opset::Softmax>(data, normalized_axis)};
                 }
-            } // namespace set_1
-
-        } // namespace op
-
-    } // namespace onnx_import
-
-} // namespace ngraph
+            }
+        }
+    }
+}
@@ -14,12 +14,10 @@
 // limitations under the License.
 //*****************************************************************************
 
-#include <vector>
+#include "ngraph/op/fused/squeeze.hpp"
 #include "default_opset.hpp"
 #include "exceptions.hpp"
 #include "ngraph/op/constant.hpp"
-#include "ngraph/op/fused/squeeze.hpp"
 #include "ngraph/validation_util.hpp"
 #include "squeeze.hpp"
@@ -42,6 +40,7 @@ namespace ngraph
                         ngraph::normalize_axes(node.get_description(), axes, data_rank);
                     auto axes_node = std::make_shared<default_opset::Constant>(
                         element::u64, Shape{normalized_axes.size()}, normalized_axes);
+
                     return {std::make_shared<default_opset::Squeeze>(data, axes_node)};
                 }
...
@@ -30,9 +30,10 @@ namespace ngraph
            {
                Shape get_kernel_shape(const Node& node)
                {
-                    std::size_t input_spatial_dims = node.get_ng_inputs().at(0)->get_shape().size() - 2;
-                    return node.get_attribute_value<std::vector<std::size_t>>(
-                        "kernel_shape", std::vector<std::size_t>(input_spatial_dims, 1UL));
+                    const auto& data_shape = node.get_ng_inputs().at(0)->get_output_partial_shape(0);
+                    const size_t input_spatial_dims = static_cast<size_t>(data_shape.rank()) - 2;
+
+                    return node.get_attribute_value<std::vector<size_t>>(
+                        "kernel_shape", std::vector<size_t>(input_spatial_dims, 1UL));
                }
 
                namespace detail
...
@@ -34,15 +34,22 @@ namespace ngraph
                 {
                     auto reduction_axes =
                         node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
-                    std::vector<std::size_t> normalized_axes = ngraph::normalize_axes(
-                        node.get_description(),
-                        reduction_axes,
-                        node.get_ng_inputs().at(0)->get_output_partial_shape(0).rank());
+
+                    const auto input_rank =
+                        node.get_ng_inputs().at(0)->get_output_partial_shape(0).rank();
+
+                    std::vector<std::size_t> normalized_axes =
+                        ngraph::normalize_axes(node.get_description(), reduction_axes, input_rank);
 
                     if (reduction_axes.empty())
                     {
-                        normalized_axes = onnx_import::common::get_monotonic_range<std::size_t>(
-                            node.get_ng_inputs().at(0)->get_shape().size());
+                        NGRAPH_CHECK(input_rank.is_static(),
+                                     "The input tensor's rank needs to be known(static) when the "
+                                     "'axes' attribute is not specified. Node: ",
+                                     node.get_description());
+
+                        normalized_axes = onnx_import::common::get_monotonic_range<size_t>(
+                            static_cast<size_t>(input_rank));
                     }
 
                     return AxisSet{normalized_axes};
                 }
@@ -84,17 +91,21 @@ namespace ngraph
                     const std::shared_ptr<ngraph::Node>& ng_input,
                     RuntimeReductionFunction reduction_function)
                 {
-                    auto data_shape = ng_input->get_shape();
+                    const auto data_ps = node.get_ng_inputs().at(0)->get_output_partial_shape(0);
+
+                    NGRAPH_CHECK(data_ps.rank().is_static(),
+                                 "Reduction operations input rank is required to be static");
+
+                    const auto data_rank = static_cast<size_t>(data_ps.rank());
 
-                    auto reduction_axes = detail::get_reduction_axes(node);
+                    const auto reduction_axes = detail::get_reduction_axes(node);
 
-                    ASSERT_VALID_ARGUMENT(node, reduction_axes.size() <= data_shape.size())
+                    ASSERT_VALID_ARGUMENT(node, reduction_axes.size() <= data_rank)
                         << "provided reduction axes count (" << reduction_axes.size()
-                        << ") is larger than input tensor rank (" << data_shape.size() << ")";
+                        << ") is larger than input tensor rank (" << data_rank << ")";
 
                     std::int64_t keepdims = node.get_attribute_value<std::int64_t>("keepdims", 1);
 
-                    std::shared_ptr<ngraph::Node> op_node = reduction_function(
+                    const auto op_node = reduction_function(
                         ng_input,
                         std::make_shared<ngraph::op::Constant>(element::i64,
                                                                ngraph::Shape{reduction_axes.size()},
...
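A sketch of the default-axes fallback in get_reduction_axes above: when the ONNX node carries no 'axes' attribute, all axes are reduced, i.e. the monotonic range 0..rank-1, which is only computable when the input rank is static. Plain C++ illustration, not nGraph code.

#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

int main()
{
    // equivalent of get_monotonic_range<size_t>(input_rank): reduce all axes
    const std::size_t input_rank = 4; // must be static for this fallback
    std::vector<std::size_t> axes(input_rank);
    std::iota(axes.begin(), axes.end(), std::size_t{0}); // {0, 1, 2, 3}
    for (const auto axis : axes)
    {
        std::cout << axis << ' ';
    }
    std::cout << '\n';
}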
@@ -42,9 +42,11 @@ namespace ngraph
             } // namespace detail
 
+            // An overload for reduction operators that take reduction axes as input
             using RuntimeReductionFunction = std::function<std::shared_ptr<ngraph::Node>(
                 const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&, bool)>;
 
+            // An overload for reduction operators that take reduction axes as an attribute
             using ReductionFunction = std::function<std::shared_ptr<ngraph::Node>(
                 const std::shared_ptr<ngraph::Node>&, const ngraph::AxisSet&)>;
...
@@ -502,3 +502,4 @@ model_lstm_fwd_hardsigmoid_activation
 model_acosh
 model_asinh
 model_atanh
+model_conv_with_dynamic_batch
ir_version: 3
producer_name: "nGraph ONNX Importer"
model_version: 1
graph {
node {
name: "dyn_conv"
input: "data"
input: "filters"
input: "bias"
output: "dyn_conv_out"
op_type: "Conv"
}
input {
name: "data"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_param: "batch"
}
dim {
dim_value: 3
}
dim {
dim_value: 7
}
dim {
dim_value: 7
}
}
}
}
}
input {
name: "filters"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 10
}
dim {
dim_value: 3
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "bias"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 10
}
}
}
}
}
output {
name: "dyn_conv_out"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_param: "batch"
}
dim {
dim_value: 10
}
dim {
dim_value: 6
}
dim {
dim_value: 6
}
}
}
}
}
name: "simple_dyn_shapes_graph"
}
opset_import {
domain: ""
version: 7
}
@@ -259,3 +259,26 @@ NGRAPH_TEST(onnx_dyn_shapes_${BACKEND_NAME}, model_atanh_3_2)
     test_case.run();
 }
+
+NGRAPH_TEST(onnx_dyn_shapes_${BACKEND_NAME}, model_conv_with_dynamic_batch)
+{
+    const auto function = onnx_import::import_onnx_model(file_util::path_join(
+        SERIALIZED_ZOO, "onnx/dynamic_shapes/conv_with_dynamic_batch.prototxt"));
+
+    auto test_case = NgraphTestCase(function, "${BACKEND_NAME}", BackendMode::DYNAMIC);
+
+    const auto data_shape = Shape{1, 3, 7, 7};
+    const auto filters_shape = Shape{10, 3, 2, 2};
+    const auto data_elems = shape_size(data_shape);
+    const auto filters_elems = shape_size(filters_shape);
+
+    test_case.add_input<int64_t>(data_shape, std::vector<int64_t>(data_elems, 1));
+    test_case.add_input<int64_t>(filters_shape, std::vector<int64_t>(filters_elems, 1));
+    test_case.add_input<int64_t>(Shape{10}, std::vector<int64_t>(10, 1));
+
+    const auto expected_out_shape = Shape{1, 10, 6, 6};
+    const std::vector<int64_t> expected_values(shape_size(expected_out_shape), 13);
+    test_case.add_expected_output<int64_t>(expected_out_shape, expected_values);
+
+    test_case.run();
+}
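Sanity check for the expected values above: with all-ones data {1, 3, 7, 7}, all-ones filters {10, 3, 2, 2} and an all-ones bias, every output element is a dot product over a 3 x 2 x 2 receptive field (= 12) plus the bias (= 13), and each spatial dimension shrinks to 7 - 2 + 1 = 6, which gives the {1, 10, 6, 6} output shape.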
@@ -361,14 +361,18 @@ TEST(type_prop, broadcast_v1_axes_wrong_rank)
     }
 }
 
-TEST(type_prop, broadcast_v1_output_partial_shape_dynamic)
+TEST(type_prop, broadcast_v1_fully_dynamic_target_shape)
 {
     auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
-    auto bc_shape = make_shared<op::Parameter>(element::i64, Shape{1});
+    auto bc_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
     auto bc_axes = make_shared<op::Parameter>(element::i64, Shape{2});
     auto bc = make_shared<op::v1::Broadcast>(arg, bc_shape, bc_axes);
     ASSERT_TRUE(bc->get_output_partial_shape(0).is_dynamic());
+
+    bc_shape = make_shared<op::Parameter>(element::i64, Shape{1});
+    bc = make_shared<op::v1::Broadcast>(arg, bc_shape, bc_axes);
+    ASSERT_TRUE(bc->get_output_partial_shape(0).is_dynamic());
 }
 
 TEST(type_prop, broadcast_v1_broadcast_shape_et_wrong)
...