Commit b26a53e2 authored by Ewa Tusień's avatar Ewa Tusień Committed by Scott Cyphers

[ONNX] Add Expand op to ONNX importer. (#3692)

* [ONNX] Added Expand op to ONNX importer.

* Added support only for static broadcating.

* Changed version of set from 8 to 1.

* Added test for expand op.
parent 8d1e2196
......@@ -86,6 +86,8 @@ add_library(onnx_import STATIC
op/equal.hpp
op/erf.hpp
op/exp.hpp
op/expand.hpp
op/expand.cpp
op/eye_like.cpp
op/eye_like.hpp
op/flatten.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstddef>
#include <cstdint>
#include "expand.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector expand(const Node& node)
{
const std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
const std::shared_ptr<ngraph::Node> shape{node.get_ng_inputs().at(1)};
NGRAPH_CHECK(shape->is_constant(),
"Ngraph does not support dynamic braodcasting for Expand op.");
std::vector<std::size_t> shape_vector =
ngraph::as_type_ptr<ngraph::op::Constant>(shape)->get_vector<std::size_t>();
const ngraph::Shape shape_shape{shape_vector};
return {ngraph::op::numpy_style_broadcast(data, shape_shape)};
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
// Expand operator has been available since version 8 of the default ONNX operator set.
// Currently, Expand is assigned to version 1 due to temporary reason.
{
NodeVector expand(const Node& node);
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -53,6 +53,7 @@
#include "op/equal.hpp"
#include "op/erf.hpp"
#include "op/exp.hpp"
#include "op/expand.hpp"
#include "op/eye_like.hpp"
#include "op/flatten.hpp"
#include "op/floor.hpp"
......@@ -264,6 +265,7 @@ namespace ngraph
REGISTER_OPERATOR("Equal", 1, equal);
REGISTER_OPERATOR("Erf", 1, erf);
REGISTER_OPERATOR("Exp", 1, exp);
REGISTER_OPERATOR("Expand", 1, expand);
REGISTER_OPERATOR("EyeLike", 1, eye_like);
REGISTER_OPERATOR("Flatten", 1, flatten);
REGISTER_OPERATOR("Floor", 1, floor);
......
......@@ -21,10 +21,3 @@ lrn_across_all_dims
lrn_across_nw
lrn_across_empty
lrn_6D_across_2_axes
# RandomUniform not supported in CPU backend
random_uniform_all_static_seed_unused
random_uniform_all_static_seed_used
random_uniform_seed_use_dynamic
random_uniform_all_static_range_dynamic
random_uniform_dynamic_shapes
ir_version: 4
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "data"
input: "const_shape"
output: "expanded"
name: "expand_1"
op_type: "Expand"
}
name: "expand test"
initializer {
dims: 3
data_type: 7
int64_data: 2
int64_data: 1
int64_data: 6
name: "const_shape"
}
input {
name: "data"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 1
}
}
}
}
}
output {
name: "expanded"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 3
}
dim {
dim_value: 6
}
}
}
}
}
}
opset_import {
version: 1
}
\ No newline at end of file
......@@ -30,6 +30,7 @@
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"
......@@ -419,3 +420,19 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_split_variable_parts_2d)
EXPECT_TRUE(test::all_close_f(outputs[i], expected_outputs[i]));
}
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_expand_static_shape)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/expand_static_shape.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
// input data shape (3,1)
test_case.add_input(std::vector<float>{1, 2, 3});
test_case.add_expected_output<float>(Shape{2, 3, 6},
{1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3,
1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3});
test_case.run();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment