Commit 46ff13c7 authored by Ewa Tusień, committed by Scott Cyphers

Add ScatterND FusedOp (#4018)

* Added scatterND op to ONNX importer.

* Added ScatterND FusedOp.

* Removed unnecessary files.

* Added op to config files.

* Changed input order.

* Fixed validation checking.

* Added support for int64 in ScatterNDAdd op.

* Changed test.

* Disabled test for plaidML.

* Code refactoring.

* Added tests.
parent 6e38579b
@@ -407,6 +407,8 @@ set (SRC
op/fused/rnn_cell.hpp
op/fused/scale_shift.cpp
op/fused/scale_shift.hpp
op/fused/scatter_nd.cpp
op/fused/scatter_nd.hpp
op/fused/selu.cpp
op/fused/selu.hpp
op/fused/shuffle_channels.cpp
......
@@ -169,6 +169,8 @@ add_library(onnx_import STATIC
op/reshape.hpp
op/reverse_sequence.cpp
op/reverse_sequence.hpp
op/scatter_nd.cpp
op/scatter_nd.hpp
op/selu.cpp
op/selu.hpp
op/shape.hpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include "ngraph/op/fused/scatter_nd.hpp"
#include "scatter_nd.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector scatter_nd(const Node& node)
{
NodeVector ng_inputs{node.get_ng_inputs()};
auto data = ng_inputs.at(0);
auto indices = ng_inputs.at(1);
auto updates = ng_inputs.at(2);
return {std::make_shared<ngraph::op::ScatterND>(data, indices, updates)};
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector scatter_nd(const Node& node);
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
@@ -101,6 +101,7 @@
#include "op/relu.hpp"
#include "op/reshape.hpp"
#include "op/reverse_sequence.hpp"
#include "op/scatter_nd.hpp"
#include "op/selu.hpp"
#include "op/shape.hpp"
#include "op/shrink.hpp"
@@ -333,6 +334,7 @@ namespace ngraph
REGISTER_OPERATOR("Relu", 1, relu);
REGISTER_OPERATOR("Reshape", 1, reshape);
REGISTER_OPERATOR("ReverseSequence", 1, reverse_sequence);
REGISTER_OPERATOR("ScatterND", 1, scatter_nd);
REGISTER_OPERATOR("Selu", 1, selu);
REGISTER_OPERATOR("Shape", 1, shape);
REGISTER_OPERATOR("Shrink", 1, shrink);
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/fused/scatter_nd.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/shape.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::v0::ScatterND::type_info;
op::v0::ScatterND::ScatterND(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& updates)
: op::util::FusedOp({data, indices, updates})
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::ScatterND::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<ScatterND>(new_args.at(0), new_args.at(1), new_args.at(2));
}
void op::v0::ScatterND::pre_validate_and_infer_types()
{
const static int DATA = 0;
const static int INDICES = 1;
const static int UPDATES = 2;
element::Type data_et = input_value(DATA).get_element_type();
element::Type indices_et = input_value(INDICES).get_element_type();
element::Type updates_et = input_value(UPDATES).get_element_type();
NODE_VALIDATION_CHECK(this,
indices_et == element::i32 || indices_et == element::i64,
"Indices element type must be i64 or i32.");
NODE_VALIDATION_CHECK(this,
data_et == updates_et,
"Updates element type must be the same as element type of data.");
const PartialShape& data_ps = get_input_partial_shape(DATA);
const PartialShape& indices_ps = get_input_partial_shape(INDICES);
const PartialShape& updates_ps = get_input_partial_shape(UPDATES);
if (data_ps.rank().is_static())
{
const size_t data_rank = static_cast<size_t>(data_ps.rank());
NODE_VALIDATION_CHECK(this, data_rank >= 1, "Data rank is expected to be at least 1.");
}
if (indices_ps.rank().is_static())
{
const size_t indices_rank = static_cast<size_t>(indices_ps.rank());
NODE_VALIDATION_CHECK(
this, indices_rank >= 1, "Indices rank is expected to be at least 1.");
}
if (indices_ps.rank().is_static() && data_ps.rank().is_static())
{
const size_t indices_rank = static_cast<size_t>(indices_ps.rank());
const size_t last_dim_pos = indices_rank - 1;
const Dimension indices_last_dim = indices_ps[last_dim_pos];
if (indices_last_dim.is_static())
{
const size_t indices_last_dim_value = static_cast<size_t>(indices_last_dim);
const size_t data_rank = static_cast<size_t>(data_ps.rank());
NODE_VALIDATION_CHECK(this,
indices_last_dim_value <= data_rank,
"Last dimension of indices can be at most the rank of data.");
if (updates_ps.rank().is_static())
{
const size_t expected_updates_rank =
data_rank + indices_rank - indices_last_dim_value - 1;
NODE_VALIDATION_CHECK(
this,
static_cast<size_t>(updates_ps.rank()) == expected_updates_rank,
"Updates rank is expected to be equal data_rank + indices_rank - "
"indices_shape[-1] - 1.");
}
}
}
set_output_type(0, data_et, data_ps);
}
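To make the rank check above concrete, here is a minimal sketch (the shapes are illustrative assumptions, not taken from this commit) of inputs that satisfy updates_rank = data_rank + indices_rank - indices_shape[-1] - 1:

#include <memory>
#include "ngraph/ngraph.hpp"

using namespace ngraph;

int main()
{
    // data_rank = 3, indices_rank = 2, indices_shape[-1] = 1
    // => updates rank must be 3 + 2 - 1 - 1 = 3
    auto data = std::make_shared<op::Parameter>(element::f32, Shape{3, 3, 3});
    auto indices = std::make_shared<op::Parameter>(element::i64, Shape{2, 1});
    auto updates = std::make_shared<op::Parameter>(element::f32, Shape{2, 3, 3});
    auto scatter = std::make_shared<op::ScatterND>(data, indices, updates); // passes validation
    return 0;
}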
NodeVector op::ScatterND::decompose_op() const
{
const auto data = input_value(0);
const auto indices = input_value(1);
const auto updates = input_value(2);
const Shape& data_shape = data.get_shape();
const Shape& updates_shape = updates.get_shape();
element::Type data_et = data.get_element_type();
// Create a boolean mask that matches the data tensor shape and
// contains 'true' values in the positions indicated by 'indices'
// and 'false' values everywhere else.
const auto true_values = op::Constant::create(element::i64, updates_shape, {1});
const auto false_values = op::Constant::create(element::i64, data_shape, {0});
const auto mask = std::make_shared<op::v0::ScatterNDAdd>(false_values, indices, true_values);
const auto mask_bool = std::make_shared<op::v0::Convert>(mask, element::boolean);
const auto zeros = op::Constant::create(data_et, data_shape, {0});
// Create an intermediate node holding the original data with zeros
// substituted at the positions indicated by 'indices'.
const auto intermediate = std::make_shared<op::v0::Select>(mask_bool, zeros, data);
return {std::make_shared<op::v0::ScatterNDAdd>(intermediate, indices, updates)};
}
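For intuition, here is a sketch of the same three-step decomposition built by hand on the values used in the ONNX test further down; the traced values in the comments are worked out by hand, not produced by this commit:

#include <memory>
#include "ngraph/ngraph.hpp"

using namespace ngraph;

int main()
{
    auto data = op::Constant::create(element::f32, Shape{8}, {1, 2, 3, 4, 5, 6, 7, 8});
    auto indices = op::Constant::create(element::i64, Shape{4, 1}, {4, 3, 1, 7});
    auto updates = op::Constant::create(element::f32, Shape{4}, {9, 10, 11, 12});
    // Step 1: mask = [0, 1, 0, 1, 1, 0, 0, 1] -- ones scattered at the indexed positions.
    auto ones = op::Constant::create(element::i64, Shape{4}, {1});
    auto zeros_i64 = op::Constant::create(element::i64, Shape{8}, {0});
    auto mask = std::make_shared<op::ScatterNDAdd>(zeros_i64, indices, ones);
    auto mask_bool = std::make_shared<op::Convert>(mask, element::boolean);
    // Step 2: intermediate = [1, 0, 3, 0, 0, 6, 7, 0] -- the indexed slots zeroed out.
    auto zeros_f32 = op::Constant::create(element::f32, Shape{8}, {0});
    auto intermediate = std::make_shared<op::Select>(mask_bool, zeros_f32, data);
    // Step 3: result = [1, 11, 3, 10, 9, 6, 7, 12] -- updates added into the zeroed slots.
    auto result = std::make_shared<op::ScatterNDAdd>(intermediate, indices, updates);
    return 0;
}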
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
namespace op
{
namespace v0
{
/// \brief Replaces values within the provided tensor with `updates` at positions given by `indices`.
class NGRAPH_API ScatterND : public op::util::FusedOp
{
public:
static constexpr NodeTypeInfo type_info{"ScatterND", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
ScatterND() = default;
/// \param data The tensor within which slice-values will be updated
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
/// \param updates The tensor of replacement-slice-values
ScatterND(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& updates);
void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
}
using v0::ScatterND;
}
}
@@ -209,6 +209,7 @@ NGRAPH_OP(ReverseSequence, ngraph::op::v0, 0)
NGRAPH_OP(ScalarConstantLike, ngraph::op, 0)
NGRAPH_OP(ScaleShift, ngraph::op::v0, 0)
NGRAPH_OP(ScatterAdd, ngraph::op::v0, 0)
NGRAPH_OP(ScatterND, ngraph::op::v0, 0)
NGRAPH_OP(ScatterNDAdd, ngraph::op::v0, 0)
NGRAPH_OP(Select, ngraph::op::v0, 0)
NGRAPH_OP(Select, ngraph::op::v1, 1)
......
@@ -110,6 +110,7 @@
#include "ngraph/op/fused/reciprocal.hpp"
#include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/scatter_nd.hpp"
#include "ngraph/op/fused/selu.hpp"
#include "ngraph/op/fused/shuffle_channels.hpp"
#include "ngraph/op/fused/softmax_crossentropy.hpp"
......
@@ -177,6 +177,7 @@ NGRAPH_OP(RNNCell, ngraph::op)
NGRAPH_OP(ScalarConstantLike, ngraph::op)
NGRAPH_OP(ScaleShift, ngraph::op)
NGRAPH_OP(ScatterAdd, ngraph::op)
NGRAPH_OP(ScatterND, ngraph::op)
NGRAPH_OP(ScatterNDAdd, ngraph::op)
NGRAPH_OP(Select, ngraph::op)
NGRAPH_OP(Selu, ngraph::op)
......
@@ -149,6 +149,55 @@ namespace ngraph
};
}
}
else if (element_type == element::i64)
{
if (is_int64)
{
functor = [&,
inputs_shape,
indices_shape,
updates_shape,
out_shape,
inputs_buffer_index,
indices_buffer_index,
updates_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
ngraph::runtime::reference::scatter_nd_add<int64_t, int64_t>(
static_cast<int64_t*>(ctx->buffer_data[inputs_buffer_index]),
static_cast<int64_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<int64_t*>(ctx->buffer_data[updates_buffer_index]),
static_cast<int64_t*>(ctx->buffer_data[out_buffer_index]),
inputs_shape,
indices_shape,
updates_shape,
out_shape);
};
}
else
{
functor = [&,
inputs_shape,
indices_shape,
updates_shape,
out_shape,
inputs_buffer_index,
indices_buffer_index,
updates_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
ngraph::runtime::reference::scatter_nd_add<int64_t, int32_t>(
static_cast<int64_t*>(ctx->buffer_data[inputs_buffer_index]),
static_cast<int32_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<int64_t*>(ctx->buffer_data[updates_buffer_index]),
static_cast<int64_t*>(ctx->buffer_data[out_buffer_index]),
inputs_shape,
indices_shape,
updates_shape,
out_shape);
};
}
}
else
{
throw ngraph_error("Unsupported type in CPU Builder for ScatterNDAdd");
......
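The functors above delegate to the templated reference kernel; a minimal sketch of invoking it directly with int64 data and int32 indices (values borrowed from the ONNX test below, buffer setup simplified):

#include <vector>
#include "ngraph/runtime/reference/scatter_nd_add.hpp"
#include "ngraph/shape.hpp"

using namespace ngraph;

int main()
{
    std::vector<int64_t> data{1, 2, 3, 4, 5, 6, 7, 8};
    std::vector<int32_t> indices{4, 3, 1, 7};
    std::vector<int64_t> updates{9, 10, 11, 12};
    std::vector<int64_t> out(8);
    runtime::reference::scatter_nd_add<int64_t, int32_t>(data.data(),
                                                         indices.data(),
                                                         updates.data(),
                                                         out.data(),
                                                         Shape{8},
                                                         Shape{4, 1},
                                                         Shape{4},
                                                         Shape{8});
    // ScatterNDAdd accumulates, so each indexed slot gains its update:
    // out = {1, 2 + 11, 3, 4 + 10, 5 + 9, 6, 7, 8 + 12}
    return 0;
}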
@@ -1637,6 +1637,7 @@ private:
case OP_TYPEID::Reciprocal:
case OP_TYPEID::RNNCell:
case OP_TYPEID::ScaleShift:
case OP_TYPEID::ScatterND:
case OP_TYPEID::Selu:
case OP_TYPEID::ShuffleChannels:
case OP_TYPEID::SoftmaxCrossEntropy:
......
@@ -1880,6 +1880,7 @@ private:
case OP_TYPEID::RNNCell:
case OP_TYPEID::ScalarConstantLike:
case OP_TYPEID::ScaleShift:
case OP_TYPEID::ScatterND:
case OP_TYPEID::Selu:
case OP_TYPEID::ShuffleChannels:
case OP_TYPEID::SoftmaxCrossEntropy:
......
@@ -269,6 +269,7 @@ model_global_lp_pool_p3
model_argmin_no_keepdims
model_reduce_log_sum_exp
model_elu
model_scatterND
model_selu
model_sigmoid
model_softplus
......
@@ -2624,6 +2624,11 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::ScatterAdd>(args[0], args[1], args[2]);
break;
}
case OP_TYPEID::ScatterND:
{
node = make_shared<op::ScatterND>(args[0], args[1], args[2]);
break;
}
case OP_TYPEID::ScatterNDAdd:
{
node = make_shared<op::ScatterNDAdd>(args[0], args[1], args[2]);
@@ -4386,6 +4391,8 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::ScatterAdd: { break;
}
case OP_TYPEID::ScatterND: { break;
}
case OP_TYPEID::ScatterNDAdd: { break;
}
case OP_TYPEID::Select: { break;
......
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
input: "i"
input: "u"
output: "y"
op_type: "ScatterND"
}
name: "test_scatterND"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 8
}
}
}
}
}
input {
name: "i"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 4
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "u"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 8
}
}
}
}
}
}
opset_import {
version: 7
}
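For reference, ONNX's TensorProto encoding uses elem_type 1 for FLOAT and elem_type 7 for INT64, so this model declares float data and updates with int64 indices, matching the element-type checks added to the op.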
@@ -1777,6 +1777,19 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_mod)
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_scatterND)
{
const auto scatterND_fn = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/scatter_nd.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(scatterND_fn, "${BACKEND_NAME}");
test_case.add_input<float>({1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
test_case.add_input<int64_t>({4, 3, 1, 7});
test_case.add_input<float>({9.f, 10.f, 11.f, 12.f});
test_case.add_expected_output<float>(Shape{8}, {1.f, 11.f, 3.f, 10.f, 9.f, 6.f, 7.f, 12.f});
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_gatherND_int32)
{
const auto gatherND_fn = onnx_import::import_onnx_model(
......
@@ -179,3 +179,85 @@ TEST(type_prop, scatter_nd_add_fail_updates_shape)
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, scatter_nd_fail_updates_element_type)
{
Shape ref_shape{3, 3, 3};
Shape indices_shape{1};
Shape updates_shape{3, 3};
Shape out_shape{3, 3, 3};
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto U = make_shared<op::Parameter>(element::i32, updates_shape);
try
{
auto G = make_shared<op::ScatterND>(R, I, U);
// Should have thrown, so fail if it didn't
FAIL() << "Created ScatterND op with incorrect updates element type.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Updates element type must be the same as element type of data."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, scatter_nd_fail_updates_rank)
{
Shape ref_shape{3, 3, 3};
Shape indices_shape{1};
Shape updates_shape{3, 3, 3};
Shape out_shape{3, 3, 3};
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
try
{
auto G = make_shared<op::ScatterND>(R, I, U);
// Should have thrown, so fail if it didn't
FAIL() << "Created ScatterND op with incorrect updates rank";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Updates rank is expected to be equal data_rank + indices_rank - "
"indices_shape[-1] - 1."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, scatter_nd_fail_updates_shape)
{
Shape ref_shape{3, 3, 3};
Shape indices_shape{4};
Shape updates_shape{2};
Shape out_shape{3, 3, 3};
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
try
{
auto G = make_shared<op::ScatterND>(R, I, U);
// Should have thrown, so fail if it didn't
FAIL() << "Created ScatterND op with incorrect indices shape";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Last dimension of indices can be at most the rank of data."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
\ No newline at end of file