Unverified Commit 9ae0a564 authored by Tomasz Socha's avatar Tomasz Socha Committed by GitHub

[ONNX] Add new ConstantOfShape operator (#4245)

* [ONNX] Add new ConstantOfShape operator

* Fix a bug in op implementation

* Modify downgrade pass to support broadcast scalars

* Style-fix

* Use at instead of []

* Use onnx helper instead of ngraph builder

* Add some UT

* Style fix

* Limit range of DynBroadcast in downgrade pass

* Move tests to better location

* Rewrite tests to test_case

* Add check if arg_pshape is static

* Style

* Trigger CI
Co-authored-by: Michał Karzyński <postrational@users.noreply.github.com>
Co-authored-by: Scott Cyphers <diyessi@users.noreply.github.com>
parent 87eae9b0
......@@ -72,6 +72,8 @@ add_library(onnx_import STATIC
op/concat.hpp
op/constant.cpp
op/constant.hpp
op/constant_of_shape.cpp
op/constant_of_shape.hpp
op/conv.cpp
op/conv.hpp
op/conv_integer.cpp
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "constant_of_shape.hpp"

#include "constant.hpp"
#include "core/tensor.hpp"
#include "default_opset.hpp"
#include "ngraph/op/constant.hpp"
#include "utils/reshape.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
/// \brief Builds the nGraph subgraph for the ONNX ConstantOfShape operator.
///
/// The first input of the node holds the target shape; the optional
/// "value" attribute (a one-element tensor) supplies the fill value and
/// defaults to float 0 when absent.
///
/// \param node  Wrapper around the ONNX ConstantOfShape node.
/// \return NodeVector containing a single Broadcast of the scalar value
///         over the requested shape.
NodeVector constant_of_shape(const onnx_import::Node& node)
{
    std::shared_ptr<ngraph::Node> fill_value;
    if (!node.has_attribute("value"))
    {
        // ONNX default: a float32 scalar equal to zero.
        fill_value = default_opset::Constant::create(element::f32, {}, {0});
    }
    else
    {
        // The attribute tensor may carry shape {1}; collapse it to a scalar
        // so Broadcast receives a rank-0 input.
        const auto value_tensor = node.get_attribute_value<Tensor>("value");
        fill_value = reshape::interpret_as_scalar(value_tensor.get_ng_constant());
    }
    const auto target_shape = node.get_ng_inputs().at(0);
    return {std::make_shared<default_opset::Broadcast>(fill_value, target_shape)};
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
/// \brief Creates the nGraph subgraph for the ONNX ConstantOfShape operator.
///
/// \param node  ONNX node; input 0 is the target shape, and the optional
///              "value" attribute (one-element tensor) is the fill value,
///              defaulting to float 0.
/// \return NodeVector with a single Broadcast node producing a tensor of
///         the requested shape filled with the scalar value.
NodeVector constant_of_shape(const Node& node);
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -40,6 +40,7 @@
#include "op/clip.hpp"
#include "op/concat.hpp"
#include "op/constant.hpp"
#include "op/constant_of_shape.hpp"
#include "op/conv.hpp"
#include "op/conv_integer.hpp"
#include "op/conv_transpose.hpp"
......@@ -258,6 +259,7 @@ namespace ngraph
REGISTER_OPERATOR("Clip", 11, clip);
REGISTER_OPERATOR("Concat", 1, concat);
REGISTER_OPERATOR("Constant", 1, constant);
REGISTER_OPERATOR("ConstantOfShape", 1, constant_of_shape);
REGISTER_OPERATOR("Conv", 1, conv);
REGISTER_OPERATOR("ConvInteger", 1, conv_integer);
REGISTER_OPERATOR("ConvTranspose", 1, conv_transpose);
......
......@@ -148,49 +148,71 @@ namespace
// Downgrade pass: rewrites an opset-1 Broadcast node into the v0 form —
// either a DynBroadcast (scalar argument with a non-constant target shape)
// or a v0 Broadcast (static argument shape).
//
// NOTE(review): this region is a web-scraped unified diff with removed and
// added lines interleaved and the +/- markers lost, so it is NOT compilable
// as-is. The stale pre-change lines are flagged below; the post-change
// logic is the rank-based if/else starting at `replacement_node`.
shared_ptr<Node> op_cast(shared_ptr<op::v1::Broadcast> node)
{
auto arg = node->input_value(0);
// NOTE(review): stale pre-change line — replaced by the partial-shape
// handling on the next two lines.
const auto& arg_shape = arg.get_shape();
auto arg_pshape = arg.get_partial_shape();
auto arg_rank = arg_pshape.rank();
auto target_shape_input = node->input_value(1);
// NOTE(review): the block from here down to the loop header is the
// pre-change implementation; the post-change code repeats it inside
// the `else` branch further below.
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant());
auto target_shape = node->output(0).get_shape();
NGRAPH_CHECK(node->get_broadcast_axes().first);
// (Re)construct axes_mapping.
AxisSet broadcast_axes = node->get_broadcast_axes().second;
std::vector<size_t> axes_mapping{
ngraph::builder::opset1::get_axes_mapping(target_shape, broadcast_axes)};
Output<Node> squeezed_arg = arg;
// Collect axes to squeeze. Broadcast v0 "adds" new axes, thus we have to squeeze
// the empty ones (dim:=1), which would be broadcasted by Broadcast v1.
std::vector<size_t> empty_axes;
for (size_t a{0}; a < axes_mapping.size(); ++a)
// Post-change: the replacement op is chosen from the argument's rank and
// whether the target shape is a compile-time constant.
shared_ptr<Node> replacement_node;
// A rank-0 (scalar) argument with a dynamic target shape maps to
// DynBroadcast; Range(0, shape_of(target), 1) enumerates every output
// axis as a broadcast axis.
if (arg_rank.is_static() && static_cast<size_t>(arg_rank) == 0 &&
!target_shape_input.get_node_shared_ptr()->is_constant())
{
// NOTE(review): stale pre-change loop body interleaved here.
if (arg_shape.at(a) == 1 && target_shape.at(axes_mapping.at(a)) != 1)
{
empty_axes.push_back(a);
}
replacement_node = make_shared<op::DynBroadcast>(
arg,
target_shape_input,
make_shared<op::Range>(make_zero(element::i64, {}),
make_shared<op::ShapeOf>(target_shape_input),
make_constant_from_string("1", element::i64, {})));
}
// NOTE(review): stale pre-change comment and loop header interleaved here.
// Check if arg_shape contains some more empty dimensions marked to broadcast.
// If axes_mapping size is less than arg_shape size, then some of arg dimensions may
// be equal to one and marked to broadcast.
if (axes_mapping.size() < arg_shape.size())
// Post-change static-shape path: requires a static argument shape and a
// constant target shape, then squeezes size-1 axes and emits Broadcast v0.
else
{
// NOTE(review): stale pre-change inner loop header interleaved here.
for (size_t a{axes_mapping.size()}; a < arg_shape.size(); ++a)
NGRAPH_CHECK(arg_pshape.is_static(),
"Unable to convert Broadcast:v1 to Broadcast:v0 "
"if argument shape is not static. Node: ",
*node);
const auto& arg_shape = arg_pshape.to_shape();
NGRAPH_CHECK(target_shape_input.get_node_shared_ptr()->is_constant());
auto target_shape = node->output(0).get_shape();
NGRAPH_CHECK(node->get_broadcast_axes().first);
// (Re)construct axes_mapping.
AxisSet broadcast_axes = node->get_broadcast_axes().second;
std::vector<size_t> axes_mapping{
ngraph::builder::opset1::get_axes_mapping(target_shape, broadcast_axes)};
Output<Node> squeezed_arg = arg;
// Collect axes to squeeze. Broadcast v0 "adds" new axes, thus we have to squeeze
// the empty ones (dim:=1), which would be broadcasted by Broadcast v1.
std::vector<size_t> empty_axes;
for (size_t a{0}; a < axes_mapping.size(); ++a)
{
// NOTE(review): the first `if` is the stale pre-change condition; the
// second (also checking the target dim) is the post-change one.
if (arg_shape.at(a) == 1)
if (arg_shape.at(a) == 1 && target_shape.at(axes_mapping.at(a)) != 1)
{
empty_axes.push_back(a);
}
}
}
// NOTE(review): the squeeze/Broadcast construction below appears twice —
// once as the stale pre-change top-level code (with `auto replacement_node`)
// and once as the post-change code assigning the outer `replacement_node`.
if (!empty_axes.empty())
{
squeezed_arg = builder::squeeze(arg, empty_axes);
}
auto replacement_node =
make_shared<op::v0::Broadcast>(squeezed_arg, target_shape, broadcast_axes);
// Check if arg_shape contains some more empty dimensions marked to broadcast.
// If axes_mapping size is less than arg_shape size, then some of arg dimensions may
// be equal to one and marked to broadcast.
if (axes_mapping.size() < arg_shape.size())
{
for (size_t a{axes_mapping.size()}; a < arg_shape.size(); ++a)
{
if (arg_shape.at(a) == 1)
{
empty_axes.push_back(a);
}
}
}
if (!empty_axes.empty())
{
squeezed_arg = builder::squeeze(arg, empty_axes);
}
replacement_node =
make_shared<op::v0::Broadcast>(squeezed_arg, target_shape, broadcast_axes);
}
// Swap the new node into the graph in place of the v1 Broadcast.
replace_node(node, replacement_node);
return replacement_node;
}
......
ir_version: 4
producer_name: "backend-test"
graph {
node {
input: "x"
output: "y"
op_type: "ConstantOfShape"
attribute {
name: "value"
t {
dims: 1
data_type: 1
float_data: 0
name: "value"
}
type: TENSOR
}
}
name: "test_constantofshape_float_zeros"
input {
name: "x"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 3
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
version: 9
}
ir_version: 4
producer_name: "backend-test"
graph {
node {
input: "x"
output: "y"
op_type: "ConstantOfShape"
attribute {
name: "value"
t {
dims: 1
data_type: 6
int32_data: 1
name: "value"
}
type: TENSOR
}
}
name: "test_constantofshape_int_ones"
input {
name: "x"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 2
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 2
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 9
}
......@@ -362,3 +362,33 @@ NGRAPH_TEST(onnx_dyn_shapes_${BACKEND_NAME}, global_max_pool_dyn_shape)
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_constant_of_shape_float_zeros)
{
    // ConstantOfShape with the default fill value: a dynamic target shape
    // of {2, 3, 4} must yield a float tensor of 24 zeros.
    const auto model = onnx_import::import_onnx_model(file_util::path_join(
        SERIALIZED_ZOO, "onnx/dynamic_shapes/constant_of_shape_float_zeros.prototxt"));

    auto test_case = NgraphTestCase(model, "${BACKEND_NAME}", BackendMode::DYNAMIC);
    const std::vector<int64_t> target_shape{2, 3, 4};
    test_case.add_input<int64_t>(Shape{3}, target_shape);
    test_case.add_expected_output<float>(Shape{2, 3, 4}, std::vector<float>(24, 0));
    test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_constant_of_shape_int_ones)
{
    // ConstantOfShape with an int32 "value" of 1: a dynamic target shape
    // of {2, 3} must yield an int32 tensor of six ones.
    const auto model = onnx_import::import_onnx_model(file_util::path_join(
        SERIALIZED_ZOO, "onnx/dynamic_shapes/constant_of_shape_int_ones.prototxt"));

    auto test_case = NgraphTestCase(model, "${BACKEND_NAME}", BackendMode::DYNAMIC);
    const std::vector<int64_t> target_shape{2, 3};
    test_case.add_input<int64_t>(Shape{2}, target_shape);
    test_case.add_expected_output<int32_t>(Shape{2, 3}, std::vector<int32_t>(6, 1));
    test_case.run();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment