Commit 0f2734dc authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by Sang Ik Lee

Use v1::Reshape in ONNX GlobalLpPool (#4090)

Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 7e319f95
...@@ -24,9 +24,6 @@ ...@@ -24,9 +24,6 @@
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/builder/norm.hpp" #include "ngraph/builder/norm.hpp"
#include "ngraph/builder/split.hpp" #include "ngraph/builder/split.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "utils/common.hpp" #include "utils/common.hpp"
...@@ -40,10 +37,10 @@ namespace ngraph ...@@ -40,10 +37,10 @@ namespace ngraph
{ {
NodeVector global_lp_pool(const Node& node) NodeVector global_lp_pool(const Node& node)
{ {
std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)}; const std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
std::size_t channel_axis{1}; const std::size_t channel_axis{1};
std::size_t channels_count = data->get_shape().at(channel_axis); const std::size_t channels_count = data->get_shape().at(channel_axis);
std::int64_t p_norm{node.get_attribute_value<std::int64_t>("p", 2)}; const std::int64_t p_norm{node.get_attribute_value<std::int64_t>("p", 2)};
ASSERT_VALID_ARGUMENT(node, p_norm >= 0) ASSERT_VALID_ARGUMENT(node, p_norm >= 0)
<< "Only positive (including zero) values are supported for 'p' attribute."; << "Only positive (including zero) values are supported for 'p' attribute.";
...@@ -63,10 +60,12 @@ namespace ngraph ...@@ -63,10 +60,12 @@ namespace ngraph
// output shape is all ones except N channel // output shape is all ones except N channel
Shape output_shape(orig_shape.size(), 1); Shape output_shape(orig_shape.size(), 1);
output_shape.at(0) = orig_shape.at(0); output_shape.at(0) = orig_shape.at(0);
slice = std::make_shared<ngraph::opset0::Reshape>(
slice, const auto reshape_pattern = default_opset::Constant::create(
ngraph::get_default_order(slice->get_shape().size()), element::i64, Shape{output_shape.size()}, output_shape);
output_shape);
slice =
std::make_shared<default_opset::Reshape>(slice, reshape_pattern, false);
} }
return {std::make_shared<default_opset::Concat>(slices, channel_axis)}; return {std::make_shared<default_opset::Concat>(slices, channel_axis)};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment