Unverified Commit 72bf9831 authored by Michał Karzyński's avatar Michał Karzyński Committed by GitHub

[Fused] Unsqueeze op (#2916)

parent b2ca3e79
...@@ -312,6 +312,8 @@ set (SRC ...@@ -312,6 +312,8 @@ set (SRC
op/fused/space_to_depth.hpp op/fused/space_to_depth.hpp
op/fused/squeeze.cpp op/fused/squeeze.cpp
op/fused/squeeze.hpp op/fused/squeeze.hpp
op/fused/unsqueeze.cpp
op/fused/unsqueeze.hpp
op/util/arithmetic_reduction.cpp op/util/arithmetic_reduction.cpp
op/util/arithmetic_reduction.hpp op/util/arithmetic_reduction.hpp
op/util/binary_elementwise_arithmetic.cpp op/util/binary_elementwise_arithmetic.cpp
......
...@@ -14,16 +14,9 @@ ...@@ -14,16 +14,9 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <numeric> #include "ngraph/op/fused/unsqueeze.hpp"
#include <set> #include "ngraph/op/constant.hpp"
#include <vector> #include "squeeze.hpp"
#include "exceptions.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/reshape.hpp"
#include "ngraph/util.hpp"
#include "unsqueeze.hpp"
#include "utils/reshape.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,31 +28,11 @@ namespace ngraph ...@@ -35,31 +28,11 @@ namespace ngraph
{ {
NodeVector unsqueeze(const Node& node) NodeVector unsqueeze(const Node& node)
{ {
NodeVector inputs{node.get_ng_inputs()}; auto data = node.get_ng_inputs().at(0);
auto data = inputs.at(0); auto axes = node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
auto data_shape = data->get_shape(); auto axes_node = std::make_shared<ngraph::op::Constant>(
auto axes = node.get_attribute_value<std::vector<std::int64_t>>("axes"); element::i64, Shape{axes.size()}, axes);
return {std::make_shared<ngraph::op::Unsqueeze>(data, axes_node)};
ASSERT_VALID_ARGUMENT(node, !axes.empty()) << "'axes' attribute is mandatory.";
ASSERT_VALID_ARGUMENT(
node,
axes.size() ==
std::set<std::int64_t>(std::begin(axes), std::end(axes)).size())
<< "'axes' has a duplicate axis.";
std::sort(std::begin(axes), std::end(axes), std::less<int64_t>());
AxisVector input_order{ngraph::get_default_order(data_shape.size())};
for (auto axis : axes)
{
ASSERT_VALID_ARGUMENT(node, axis >= 0 && axis <= data_shape.size())
<< "provided 'axes' attribute is not valid.";
data_shape.insert(std::next(std::begin(data_shape), axis), 1);
}
return {std::make_shared<ngraph::op::Reshape>(data, input_order, data_shape)};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -109,6 +109,7 @@ ...@@ -109,6 +109,7 @@
#include "ngraph/op/fused/scale_shift.hpp" #include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/space_to_depth.hpp" #include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/fused/squeeze.hpp" #include "ngraph/op/fused/squeeze.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/gather.hpp" #include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp" #include "ngraph/op/gather_nd.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstddef>
#include <functional>
#include <iterator>
#include <set>
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/reshape.hpp"
using namespace std;
using namespace ngraph;
op::Unsqueeze::Unsqueeze(const shared_ptr<Node>& data, const shared_ptr<Node>& axes)
: FusedOp("Unsqueeze", {data, axes})
{
constructor_validate_and_infer_types();
}
void op::Unsqueeze::pre_validate_and_infer_types()
{
auto axes_node = get_argument(1);
// Currently only support Constant node for axes.
NODE_VALIDATION_CHECK(this,
axes_node->is_constant(),
"doesn't support 'axes' input of other type than a Constant.");
}
NodeVector op::Unsqueeze::decompose_op() const
{
auto data = get_argument(0);
auto axes_node = get_argument(1);
// Get value of axes from Constant
auto axes_constant = dynamic_pointer_cast<op::Constant>(axes_node);
auto axes = axes_constant->get_vector<size_t>();
auto data_shape = data->get_shape();
NODE_VALIDATION_CHECK(this, !axes.empty(), "'axes' input is mandatory.");
NODE_VALIDATION_CHECK(this,
axes.size() == set<int64_t>(begin(axes), end(axes)).size(),
"'axes' input has a duplicate axis.");
sort(begin(axes), end(axes), less<int64_t>());
AxisVector input_order{ngraph::get_default_order(data_shape.size())};
for (auto axis : axes)
{
NODE_VALIDATION_CHECK(this,
axis >= 0 && axis <= data_shape.size(),
"provided 'axes' value ",
axis,
" is not valid.");
data_shape.insert(next(begin(data_shape), axis), 1);
}
return {make_shared<ngraph::op::Reshape>(data, input_order, data_shape)};
}
shared_ptr<Node> op::Unsqueeze::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Unsqueeze>(new_args.at(0), new_args.at(1));
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/axis_vector.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
namespace op
{
class Unsqueeze : public ngraph::op::util::FusedOp
{
public:
Unsqueeze(const std::shared_ptr<ngraph::Node>& data,
const std::shared_ptr<ngraph::Node>& axes);
virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
}
}
...@@ -33,3 +33,4 @@ NGRAPH_OP(PRelu, ngraph::op) ...@@ -33,3 +33,4 @@ NGRAPH_OP(PRelu, ngraph::op)
NGRAPH_OP(ScaleShift, ngraph::op) NGRAPH_OP(ScaleShift, ngraph::op)
NGRAPH_OP(SpaceToDepth, ngraph::op) NGRAPH_OP(SpaceToDepth, ngraph::op)
NGRAPH_OP(Squeeze, ngraph::op) NGRAPH_OP(Squeeze, ngraph::op)
NGRAPH_OP(Unsqueeze, ngraph::op)
...@@ -90,6 +90,7 @@ ...@@ -90,6 +90,7 @@
#include "ngraph/op/fused/scale_shift.hpp" #include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/space_to_depth.hpp" #include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/fused/squeeze.hpp" #include "ngraph/op/fused/squeeze.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp" #include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp" #include "ngraph/op/greater_eq.hpp"
...@@ -2079,6 +2080,7 @@ shared_ptr<runtime::Executable> ...@@ -2079,6 +2080,7 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::StopGradient: case OP_TYPEID::StopGradient:
case OP_TYPEID::Tile: case OP_TYPEID::Tile:
case OP_TYPEID::Transpose: case OP_TYPEID::Transpose:
case OP_TYPEID::Unsqueeze:
default: default:
{ {
throw unsupported_op("Unsupported op '" + op->description() + throw unsupported_op("Unsupported op '" + op->description() +
...@@ -2171,7 +2173,8 @@ bool runtime::intelgpu::IntelGPUBackend::is_supported_impl(const Node& node) ...@@ -2171,7 +2173,8 @@ bool runtime::intelgpu::IntelGPUBackend::is_supported_impl(const Node& node)
case OP_TYPEID::PRelu: case OP_TYPEID::PRelu:
case OP_TYPEID::ScaleShift: case OP_TYPEID::ScaleShift:
case OP_TYPEID::SpaceToDepth: case OP_TYPEID::SpaceToDepth:
case OP_TYPEID::Squeeze: { return false; case OP_TYPEID::Squeeze:
case OP_TYPEID::Unsqueeze: { return false;
} }
default: { return true; default: { return true;
} }
......
...@@ -80,6 +80,7 @@ ...@@ -80,6 +80,7 @@
#include "ngraph/op/fused/scale_shift.hpp" #include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/space_to_depth.hpp" #include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/fused/squeeze.hpp" #include "ngraph/op/fused/squeeze.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/gather.hpp" #include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp" #include "ngraph/op/gather_nd.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
...@@ -1501,6 +1502,11 @@ static shared_ptr<ngraph::Function> ...@@ -1501,6 +1502,11 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::StopGradient>(args[0]); node = make_shared<op::StopGradient>(args[0]);
break; break;
} }
case OP_TYPEID::Unsqueeze:
{
node = make_shared<op::Unsqueeze>(args[0], args[1]);
break;
}
case OP_TYPEID::UnknownOp: case OP_TYPEID::UnknownOp:
{ {
stringstream ss; stringstream ss;
...@@ -2227,6 +2233,8 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -2227,6 +2233,8 @@ static json write(const Node& n, bool binary_constant_data)
} }
case OP_TYPEID::Transpose: { break; case OP_TYPEID::Transpose: { break;
} }
case OP_TYPEID::Unsqueeze: { break;
}
case OP_TYPEID::UnknownOp: { break; case OP_TYPEID::UnknownOp: { break;
} }
} }
......
...@@ -773,6 +773,21 @@ NGRAPH_TEST(${BACKEND_NAME}, grn_2d_with_bias) ...@@ -773,6 +773,21 @@ NGRAPH_TEST(${BACKEND_NAME}, grn_2d_with_bias)
test_case.run(); test_case.run();
} }
NGRAPH_TEST(${BACKEND_NAME}, unsqueeze)
{
auto data_node = make_shared<op::Parameter>(element::f32, Shape{4, 2});
auto axes_node =
make_shared<ngraph::op::Constant>(element::u64, Shape{2}, vector<int64_t>{1, 2});
auto squeeze = make_shared<op::Unsqueeze>(data_node, axes_node);
auto function = make_shared<Function>(NodeVector{squeeze}, ParameterVector{data_node});
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
auto data = vector<float>{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
test_case.add_input(data);
test_case.add_expected_output<float>(Shape{4, 1, 1, 2}, data);
}
NGRAPH_TEST(${BACKEND_NAME}, scale_shift_no_broadcast) NGRAPH_TEST(${BACKEND_NAME}, scale_shift_no_broadcast)
{ {
auto data = make_shared<op::Parameter>(element::f64, Shape{3, 6}); auto data = make_shared<op::Parameter>(element::f64, Shape{3, 6});
......
...@@ -14495,6 +14495,17 @@ TEST(type_prop, fused_clamp) ...@@ -14495,6 +14495,17 @@ TEST(type_prop, fused_clamp)
EXPECT_EQ(clamp->get_shape(), (Shape{2, 2})); EXPECT_EQ(clamp->get_shape(), (Shape{2, 2}));
} }
TEST(type_prop, unsqueeze)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{4, 1, 4, 1, 8});
auto axes_node =
make_shared<ngraph::op::Constant>(element::u64, Shape{2}, vector<int64_t>{1, 2});
auto squeeze = make_shared<op::Unsqueeze>(param, axes_node);
ASSERT_EQ(squeeze->get_element_type(), element::f32);
ASSERT_EQ(squeeze->get_shape(), (Shape{4, 1, 1, 1, 4, 1, 8}));
}
TEST(type_prop, scale_shift_no_broadcast) TEST(type_prop, scale_shift_no_broadcast)
{ {
auto data = make_shared<op::Parameter>(element::f64, Shape{3, 6}); auto data = make_shared<op::Parameter>(element::f64, Shape{3, 6});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment