Commit 0f2734dc authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by Sang Ik Lee

Use v1::Reshape in ONNX GlobalLpPool (#4090)

Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 7e319f95
......@@ -24,9 +24,6 @@
#include "ngraph/axis_set.hpp"
#include "ngraph/builder/norm.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "ngraph/util.hpp"
#include "utils/common.hpp"
......@@ -40,10 +37,10 @@ namespace ngraph
{
NodeVector global_lp_pool(const Node& node)
{
std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
std::size_t channel_axis{1};
std::size_t channels_count = data->get_shape().at(channel_axis);
std::int64_t p_norm{node.get_attribute_value<std::int64_t>("p", 2)};
const std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
const std::size_t channel_axis{1};
const std::size_t channels_count = data->get_shape().at(channel_axis);
const std::int64_t p_norm{node.get_attribute_value<std::int64_t>("p", 2)};
ASSERT_VALID_ARGUMENT(node, p_norm >= 0)
<< "Only positive (including zero) values are supported for 'p' attribute.";
......@@ -63,10 +60,12 @@ namespace ngraph
// output shape is all ones except N channel
Shape output_shape(orig_shape.size(), 1);
output_shape.at(0) = orig_shape.at(0);
slice = std::make_shared<ngraph::opset0::Reshape>(
slice,
ngraph::get_default_order(slice->get_shape().size()),
output_shape);
const auto reshape_pattern = default_opset::Constant::create(
element::i64, Shape{output_shape.size()}, output_shape);
slice =
std::make_shared<default_opset::Reshape>(slice, reshape_pattern, false);
}
return {std::make_shared<default_opset::Concat>(slices, channel_axis)};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment