Commit 7e6c34cf authored by Sang Ik Lee's avatar Sang Ik Lee Committed by Scott Cyphers

scatter_add and scatter_nd_add (#2874)

* Temp save.

* Temp save.

* Temp save.

* Temp save.

* Temp save.

* Temp save.

* Temp save.

* Fix compile errors.

* Fix incorrect index.

* Fix UT typo.

* Interpreter passes UT.

* Fix more bugs.

* Apply style.

* Add shape check for updates tensor.

* Merge typo
parent b94a042d
......@@ -252,6 +252,10 @@ set (SRC
op/reverse.hpp
op/reverse_sequence.cpp
op/reverse_sequence.hpp
op/scatter_add.cpp
op/scatter_add.hpp
op/scatter_nd_add.cpp
op/scatter_nd_add.hpp
op/select.cpp
op/select.hpp
op/sigmoid.cpp
......
......@@ -138,6 +138,8 @@
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/scatter_add.hpp"
#include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp"
......
......@@ -135,6 +135,8 @@ NGRAPH_OP(Result, ngraph::op)
NGRAPH_OP(Reverse, ngraph::op)
NGRAPH_OP(ReverseSequence, ngraph::op)
NGRAPH_OP(ScalarConstantLike, ngraph::op)
NGRAPH_OP(ScatterAdd, ngraph::op)
NGRAPH_OP(ScatterNDAdd, ngraph::op)
NGRAPH_OP(Select, ngraph::op)
NGRAPH_OP(ShapeOf, ngraph::op)
NGRAPH_OP(Sigmoid, ngraph::op)
......
//*****************************************************************************
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/scatter_add.hpp"
#include "ngraph/shape.hpp"
using namespace std;
using namespace ngraph;
static int INPUTS = 0;
static int INDICES = 1;
static int UPDATES = 2;
// Clones this node with replacement arguments (graph-copy support).
shared_ptr<Node> op::ScatterAdd::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
    const auto& inputs = new_args.at(INPUTS);
    const auto& indices = new_args.at(INDICES);
    const auto& updates = new_args.at(UPDATES);
    return make_shared<ScatterAdd>(inputs, indices, updates);
}
// Validates element types and shape compatibility of (inputs, indices, updates)
// and sets the output to have the element type and shape of `inputs`.
void op::ScatterAdd::validate_and_infer_types()
{
    element::Type inputs_et = get_input_element_type(INPUTS);
    element::Type indices_et = get_input_element_type(INDICES);
    element::Type updates_et = get_input_element_type(UPDATES);
    const PartialShape& inputs_shape = get_input_partial_shape(INPUTS);
    const PartialShape& indices_shape = get_input_partial_shape(INDICES);
    const PartialShape& updates_shape = get_input_partial_shape(UPDATES);
    // Indices must be an integral index type.
    NODE_VALIDATION_CHECK(this,
                          indices_et == element::i32 || indices_et == element::i64,
                          "Indices element type must be i64 or i32");
    NODE_VALIDATION_CHECK(
        this, updates_et == inputs_et, "Updates element type must be the same as Inputs");
    // updates rank must be indices rank + inputs rank - 1
    // (rank checks are skipped when any of the ranks is dynamic).
    NODE_VALIDATION_CHECK(this,
                          inputs_shape.rank().is_dynamic() || indices_shape.rank().is_dynamic() ||
                              updates_shape.rank().is_dynamic() ||
                              static_cast<size_t>(updates_shape.rank()) ==
                                  static_cast<size_t>(indices_shape.rank()) +
                                      static_cast<size_t>(inputs_shape.rank()) - 1,
                          "Updates rank is expected to be indices rank + inputs rank - 1");
    bool compatible = true;
    // Dimension-wise check requires fully static shapes:
    // updates_shape must equal indices_shape + inputs_shape[1:].
    if (inputs_shape.is_static() && indices_shape.is_static() && updates_shape.is_static())
    {
        // Leading dims of updates must match indices.
        for (size_t i = 0; i < static_cast<size_t>(indices_shape.rank()); i++)
        {
            compatible = compatible && updates_shape[i].same_scheme(indices_shape[i]);
        }
        // Trailing dims of updates must match inputs (excluding axis 0, which
        // is the axis being indexed).
        for (size_t i = 1; i < static_cast<size_t>(inputs_shape.rank()); i++)
        {
            compatible =
                compatible &&
                updates_shape[static_cast<size_t>(indices_shape.rank()) + i - 1].same_scheme(
                    inputs_shape[i]);
        }
    }
    NODE_VALIDATION_CHECK(
        this, compatible, "Updates shape must be indices_shape + inputs_shape[1:]");
    set_output_type(0, inputs_et, inputs_shape);
}
//*****************************************************************************
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
namespace ngraph
{
    namespace op
    {
        /// \brief Add updates to slices from inputs addressed by indices
        ///        (output starts as a copy of inputs; updates are accumulated
        ///        into the axis-0 slices selected by indices).
        class ScatterAdd : public Op
        {
        public:
            /// \param inputs Tensor whose slices receive the additions
            /// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
            /// \param updates Tensor: Must have same element type as inputs
            ScatterAdd(const std::shared_ptr<Node>& inputs,
                       const std::shared_ptr<Node>& indices,
                       const std::shared_ptr<Node>& updates)
                : Op("ScatterAdd", check_single_output_args({inputs, indices, updates}))
            {
                constructor_validate_and_infer_types();
            }

            /// Checks index/element types and shape compatibility; output has
            /// the element type and shape of `inputs`.
            void validate_and_infer_types() override;

            /// Autodiff is not implemented for this op yet.
            void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) override
            {
                throw ngraph_error("Not yet implemented");
            }

            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
        };
    }
}
//*****************************************************************************
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/shape.hpp"
using namespace std;
using namespace ngraph;
static int INPUTS = 0;
static int INDICES = 1;
static int UPDATES = 2;
// Clones this node with replacement arguments (graph-copy support).
shared_ptr<Node> op::ScatterNDAdd::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
    const auto& inputs = new_args.at(INPUTS);
    const auto& indices = new_args.at(INDICES);
    const auto& updates = new_args.at(UPDATES);
    return make_shared<ScatterNDAdd>(inputs, indices, updates);
}
// Validates element types and shape compatibility of (inputs, indices, updates)
// for the ScatterND-style accumulation (the last axis of indices addresses the
// leading axes of inputs) and sets the output to the type/shape of `inputs`.
void op::ScatterNDAdd::validate_and_infer_types()
{
    element::Type inputs_et = get_input_element_type(INPUTS);
    element::Type indices_et = get_input_element_type(INDICES);
    element::Type updates_et = get_input_element_type(UPDATES);
    const PartialShape& inputs_shape = get_input_partial_shape(INPUTS);
    const PartialShape& indices_shape = get_input_partial_shape(INDICES);
    const PartialShape& updates_shape = get_input_partial_shape(UPDATES);
    // Indices must be an integral index type.
    NODE_VALIDATION_CHECK(this,
                          indices_et == element::i32 || indices_et == element::i64,
                          "Indices element type must be i64 or i32");
    NODE_VALIDATION_CHECK(
        this, updates_et == inputs_et, "Updates element type must be the same as inputs");
    // Indices needs at least one axis: its last axis holds the coordinate tuples.
    NODE_VALIDATION_CHECK(this,
                          indices_shape.rank().is_dynamic() ||
                              static_cast<size_t>(indices_shape.rank()) >= 1,
                          "Indices rank is expected to be at least 1");
    // A coordinate tuple cannot address more axes than inputs has.
    NODE_VALIDATION_CHECK(
        this,
        inputs_shape.rank().is_dynamic() || indices_shape.rank().is_dynamic() ||
            static_cast<size_t>(indices_shape[static_cast<size_t>(indices_shape.rank()) - 1]) <=
                static_cast<size_t>(inputs_shape.rank()),
        "Last dimension of indices can be at most the rank of inputs");
    // updates rank = indices rank + inputs rank - indices.shape[-1] - 1
    // (skipped when any rank involved is dynamic).
    NODE_VALIDATION_CHECK(
        this,
        inputs_shape.rank().is_dynamic() || indices_shape.rank().is_dynamic() ||
            updates_shape.rank().is_dynamic() ||
            static_cast<size_t>(updates_shape.rank()) ==
                static_cast<size_t>(indices_shape.rank()) +
                    static_cast<size_t>(inputs_shape.rank()) -
                    static_cast<size_t>(
                        indices_shape[static_cast<size_t>(indices_shape.rank()) - 1]) -
                    1,
        "Rank of updates must be rank of inputs + rank of indices - last dimension of indices - 1");
    bool compatible = true;
    // Dimension-wise check requires fully static shapes:
    // updates_shape must equal indices_shape[:-1] + inputs_shape[indices.shape[-1]:].
    if (inputs_shape.is_static() && indices_shape.is_static() && updates_shape.is_static())
    {
        // Leading dims of updates must match indices (minus the coordinate axis).
        for (size_t i = 0; i < static_cast<size_t>(indices_shape.rank()) - 1; i++)
        {
            compatible = compatible && updates_shape[i].same_scheme(indices_shape[i]);
        }
        // j = number of inputs axes consumed by each coordinate tuple.
        size_t j =
            static_cast<size_t>(indices_shape[static_cast<size_t>(indices_shape.rank()) - 1]);
        // Remaining dims of updates must match the un-indexed trailing dims of inputs.
        for (size_t i = j; i < static_cast<size_t>(inputs_shape.rank()); i++)
        {
            compatible =
                compatible &&
                updates_shape[static_cast<size_t>(indices_shape.rank()) + i - 2].same_scheme(
                    inputs_shape[i]);
        }
    }
    NODE_VALIDATION_CHECK(
        this,
        compatible,
        "Updates shape must be indices_shape[:-1] + inputs_shape[indices.shape[-1]:]");
    set_output_type(0, inputs_et, inputs_shape);
}
//*****************************************************************************
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
namespace ngraph
{
    namespace op
    {
        /// \brief Add updates to slices from inputs addressed by indices
        ///        (ScatterND style: the last axis of indices holds coordinate
        ///        tuples into the leading axes of inputs).
        class ScatterNDAdd : public Op
        {
        public:
            /// \param inputs Tensor whose slices receive the additions
            /// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
            /// \param updates Tensor: Must have same element type as inputs
            ScatterNDAdd(const std::shared_ptr<Node>& inputs,
                         const std::shared_ptr<Node>& indices,
                         const std::shared_ptr<Node>& updates)
                : Op("ScatterNDAdd", check_single_output_args({inputs, indices, updates}))
            {
                constructor_validate_and_infer_types();
            }

            /// Checks index/element types and shape compatibility; output has
            /// the element type and shape of `inputs`.
            void validate_and_infer_types() override;

            /// Autodiff is not implemented for this op yet.
            void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) override
            {
                throw ngraph_error("Not yet implemented");
            }

            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
        };
    }
}
......@@ -72,6 +72,8 @@ set(SRC
builder/reverse.cpp
builder/reverse_sequence.cpp
builder/rnn.cpp
builder/scatter_add.cpp
builder/scatter_nd_add.cpp
builder/select.cpp
builder/sigmoid.cpp
builder/slice.cpp
......
//*****************************************************************************
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstring>
#include "ngraph/op/scatter_add.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/reference/scatter_add.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::ScatterAdd)
{
auto& functors = external_function->get_functors();
CPUKernelFunctor functor;
auto inputs_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto indices_buffer_index = external_function->get_buffer_index(args[1].get_name());
auto updates_buffer_index = external_function->get_buffer_index(args[2].get_name());
auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
if (args[1].get_element_type() != element::i64 &&
args[1].get_element_type() != element::i32)
{
throw ngraph_error("Unsupported index element type");
}
bool is_int64 = args[1].get_element_type() == element::i64;
auto inputs_shape = args[0].get_shape();
auto indices_shape = args[1].get_shape();
auto updates_shape = args[2].get_shape();
auto out_shape = out[0].get_shape();
auto element_type = args[0].get_element_type();
if (element_type == element::f32)
{
if (is_int64)
{
functor = [&,
inputs_shape,
indices_shape,
updates_shape,
out_shape,
inputs_buffer_index,
indices_buffer_index,
updates_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::scatter_add<float, int64_t>(
static_cast<float*>(ctx->buffer_data[inputs_buffer_index]),
static_cast<int64_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<float*>(ctx->buffer_data[updates_buffer_index]),
static_cast<float*>(ctx->buffer_data[out_buffer_index]),
inputs_shape,
indices_shape,
updates_shape,
out_shape);
};
}
else
{
functor = [&,
inputs_shape,
indices_shape,
updates_shape,
out_shape,
inputs_buffer_index,
indices_buffer_index,
updates_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::scatter_add<float, int32_t>(
static_cast<float*>(ctx->buffer_data[inputs_buffer_index]),
static_cast<int32_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<float*>(ctx->buffer_data[updates_buffer_index]),
static_cast<float*>(ctx->buffer_data[out_buffer_index]),
inputs_shape,
indices_shape,
updates_shape,
out_shape);
};
}
}
else if (element_type == element::f64)
{
if (is_int64)
{
functor = [&,
inputs_shape,
indices_shape,
updates_shape,
out_shape,
inputs_buffer_index,
indices_buffer_index,
updates_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::scatter_add<double, int64_t>(
static_cast<double*>(ctx->buffer_data[inputs_buffer_index]),
static_cast<int64_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<double*>(ctx->buffer_data[updates_buffer_index]),
static_cast<double*>(ctx->buffer_data[out_buffer_index]),
inputs_shape,
indices_shape,
updates_shape,
out_shape);
};
}
else
{
functor = [&,
inputs_shape,
indices_shape,
updates_shape,
out_shape,
inputs_buffer_index,
indices_buffer_index,
updates_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::scatter_add<double, int32_t>(
static_cast<double*>(ctx->buffer_data[inputs_buffer_index]),
static_cast<int32_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<double*>(ctx->buffer_data[updates_buffer_index]),
static_cast<double*>(ctx->buffer_data[out_buffer_index]),
inputs_shape,
indices_shape,
updates_shape,
out_shape);
};
}
}
else
{
throw ngraph_error("Unsupported type in CPU Builder for ScatterAdd");
}
functors.emplace_back(functor);
}
REGISTER_OP_BUILDER(ScatterAdd);
}
}
}
//*****************************************************************************
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstring>
#include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/reference/scatter_nd_add.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::ScatterNDAdd)
{
auto& functors = external_function->get_functors();
CPUKernelFunctor functor;
auto inputs_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto indices_buffer_index = external_function->get_buffer_index(args[1].get_name());
auto updates_buffer_index = external_function->get_buffer_index(args[2].get_name());
auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
if (args[1].get_element_type() != element::i64 &&
args[1].get_element_type() != element::i32)
{
throw ngraph_error("Unsupported index element type");
}
bool is_int64 = args[1].get_element_type() == element::i64;
auto inputs_shape = args[0].get_shape();
auto indices_shape = args[1].get_shape();
auto updates_shape = args[2].get_shape();
auto out_shape = out[0].get_shape();
auto element_type = args[0].get_element_type();
if (element_type == element::f32)
{
if (is_int64)
{
functor = [&,
inputs_shape,
indices_shape,
updates_shape,
out_shape,
inputs_buffer_index,
indices_buffer_index,
updates_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::scatter_nd_add<float, int64_t>(
static_cast<float*>(ctx->buffer_data[inputs_buffer_index]),
static_cast<int64_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<float*>(ctx->buffer_data[updates_buffer_index]),
static_cast<float*>(ctx->buffer_data[out_buffer_index]),
inputs_shape,
indices_shape,
updates_shape,
out_shape);
};
}
else
{
functor = [&,
inputs_shape,
indices_shape,
updates_shape,
out_shape,
inputs_buffer_index,
indices_buffer_index,
updates_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::scatter_nd_add<float, int32_t>(
static_cast<float*>(ctx->buffer_data[inputs_buffer_index]),
static_cast<int32_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<float*>(ctx->buffer_data[updates_buffer_index]),
static_cast<float*>(ctx->buffer_data[out_buffer_index]),
inputs_shape,
indices_shape,
updates_shape,
out_shape);
};
}
}
else if (element_type == element::f64)
{
if (is_int64)
{
functor = [&,
inputs_shape,
indices_shape,
updates_shape,
out_shape,
inputs_buffer_index,
indices_buffer_index,
updates_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::scatter_nd_add<double, int64_t>(
static_cast<double*>(ctx->buffer_data[inputs_buffer_index]),
static_cast<int64_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<double*>(ctx->buffer_data[updates_buffer_index]),
static_cast<double*>(ctx->buffer_data[out_buffer_index]),
inputs_shape,
indices_shape,
updates_shape,
out_shape);
};
}
else
{
functor = [&,
inputs_shape,
indices_shape,
updates_shape,
out_shape,
inputs_buffer_index,
indices_buffer_index,
updates_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::scatter_nd_add<double, int32_t>(
static_cast<double*>(ctx->buffer_data[inputs_buffer_index]),
static_cast<int32_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<double*>(ctx->buffer_data[updates_buffer_index]),
static_cast<double*>(ctx->buffer_data[out_buffer_index]),
inputs_shape,
indices_shape,
updates_shape,
out_shape);
};
}
}
else
{
throw ngraph_error("Unsupported type in CPU Builder for ScatterNDAdd");
}
functors.emplace_back(functor);
}
REGISTER_OP_BUILDER(ScatterNDAdd);
}
}
}
......@@ -96,6 +96,8 @@
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/scatter_add.hpp"
#include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp"
......@@ -1777,6 +1779,52 @@ namespace ngraph
writer.block_end();
}
// Emits C++ source that calls the reference scatter_add kernel with the
// node's concrete element/index types and static shapes (codegen path).
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::ScatterAdd)
{
    // Only i32/i64 indices are supported by the reference kernel.
    if (args[1].get_element_type() != element::i64 &&
        args[1].get_element_type() != element::i32)
    {
        throw ngraph_error("Unsupported index element type");
    }
    writer.block_begin();
    // Instantiate reference::scatter_add<data_type, index_type>(...) in the
    // generated source; shapes are baked in as brace-initializer lists.
    writer << "reference::scatter_add<" << args[0].get_type() << ", "
           << args[1].get_element_type().c_type_string() << ">(" << args[0].get_name()
           << ",\n";
    writer << " " << args[1].get_name() << ",\n";
    writer << " " << args[2].get_name() << ",\n";
    writer << " " << out[0].get_name() << ",\n";
    writer << " {" << join(args[0].get_shape()) << "},\n";
    writer << " {" << join(args[1].get_shape()) << "},\n";
    writer << " {" << join(args[2].get_shape()) << "},\n";
    writer << " {" << join(out[0].get_shape()) << "});\n";
    writer.block_end();
}
// Emits C++ source that calls the reference scatter_nd_add kernel with the
// node's concrete element/index types and static shapes (codegen path).
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::ScatterNDAdd)
{
    // Only i32/i64 indices are supported by the reference kernel.
    if (args[1].get_element_type() != element::i64 &&
        args[1].get_element_type() != element::i32)
    {
        throw ngraph_error("Unsupported index element type");
    }
    writer.block_begin();
    // Instantiate reference::scatter_nd_add<data_type, index_type>(...) in the
    // generated source; shapes are baked in as brace-initializer lists.
    writer << "reference::scatter_nd_add<" << args[0].get_type() << ", "
           << args[1].get_element_type().c_type_string() << ">(" << args[0].get_name()
           << ",\n";
    writer << " " << args[1].get_name() << ",\n";
    writer << " " << args[2].get_name() << ",\n";
    writer << " " << out[0].get_name() << ",\n";
    writer << " {" << join(args[0].get_shape()) << "},\n";
    writer << " {" << join(args[1].get_shape()) << "},\n";
    writer << " {" << join(args[2].get_shape()) << "},\n";
    writer << " {" << join(out[0].get_shape()) << "});\n";
    writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Power)
{
......
......@@ -113,6 +113,8 @@
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/scatter_add.hpp"
#include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp"
......@@ -313,6 +315,8 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Erf), &runtime::cpu::CPU_Emitter::emit<op::Erf>},
{TI(ngraph::op::Gather), &runtime::cpu::CPU_Emitter::emit<op::Gather>},
{TI(ngraph::op::GatherND), &runtime::cpu::CPU_Emitter::emit<op::GatherND>},
{TI(ngraph::op::ScatterAdd), &runtime::cpu::CPU_Emitter::emit<op::ScatterAdd>},
{TI(ngraph::op::ScatterNDAdd), &runtime::cpu::CPU_Emitter::emit<op::ScatterNDAdd>},
{TI(ngraph::op::GetOutputElement), &runtime::cpu::CPU_Emitter::emit<op::GetOutputElement>},
{TI(ngraph::op::Greater), &runtime::cpu::CPU_Emitter::emit<op::Greater>},
{TI(ngraph::op::GreaterEq), &runtime::cpu::CPU_Emitter::emit<op::GreaterEq>},
......@@ -557,6 +561,8 @@ void runtime::cpu::CPU_ExternalFunction::compile(ngraph::pass::PassConfig& pass_
#include "ngraph/runtime/reference/result.hpp"
#include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/reverse_sequence.hpp"
#include "ngraph/runtime/reference/scatter_add.hpp"
#include "ngraph/runtime/reference/scatter_nd_add.hpp"
#include "ngraph/runtime/reference/slice.hpp"
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/runtime/reference/topk.hpp"
......
......@@ -122,6 +122,8 @@
#include "ngraph/runtime/reference/result.hpp"
#include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/reverse_sequence.hpp"
#include "ngraph/runtime/reference/scatter_add.hpp"
#include "ngraph/runtime/reference/scatter_nd_add.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/shape_of.hpp"
#include "ngraph/runtime/reference/sigmoid.hpp"
......@@ -1165,6 +1167,66 @@ private:
}
break;
}
case OP_TYPEID::ScatterAdd:
{
if (node.get_input_element_type(1) == element::i64)
{
reference::scatter_add<T, int64_t>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<int64_t>(),
args[2]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_input_shape(2),
node.get_output_shape(0));
}
else if (node.get_input_element_type(1) == element::i32)
{
reference::scatter_add<T, int32_t>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<int32_t>(),
args[2]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_input_shape(2),
node.get_output_shape(0));
}
else
{
throw ngraph_error("Unexpected type");
}
break;
}
case OP_TYPEID::ScatterNDAdd:
{
if (node.get_input_element_type(1) == element::i64)
{
reference::scatter_nd_add<T, int64_t>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<int64_t>(),
args[2]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_input_shape(2),
node.get_output_shape(0));
}
else if (node.get_input_element_type(1) == element::i32)
{
reference::scatter_nd_add<T, int32_t>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<int32_t>(),
args[2]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_input_shape(2),
node.get_output_shape(0));
}
else
{
throw ngraph_error("Unexpected type");
}
break;
}
case OP_TYPEID::Select:
{
size_t element_count = shape_size(node.get_output_shape(0));
......
......@@ -109,6 +109,8 @@
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/scatter_add.hpp"
#include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp"
......@@ -1227,6 +1229,16 @@ std::string runtime::gpu::GPU_Emitter::emit_ScalarConstantLike(EMIT_ARGS)
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
// ScatterAdd has no GPU kernel yet; the GPU backend rejects graphs using it.
std::string runtime::gpu::GPU_Emitter::emit_ScatterAdd(EMIT_ARGS)
{
    throw unsupported_op("Unsupported op '" + node->description() + "'");
}
// ScatterNDAdd has no GPU kernel yet; the GPU backend rejects graphs using it.
std::string runtime::gpu::GPU_Emitter::emit_ScatterNDAdd(EMIT_ARGS)
{
    throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_Select(EMIT_ARGS)
{
return emit_elementwise<ngraph::op::Select>(compiled_function, function_name, node, args, out);
......
......@@ -160,3 +160,8 @@ mvn_mean_normalization
mvn_mean_normalization_split_channels
mvn_mean_variance_normalization
mvn_mean_variance_normalization_split_channels
scatter_add_2d_indices
scatter_add_1d_indices
scatter_add_scalar_indices
scatter_nd_add_batch_2d_to_3d
scatter_nd_add_2d_to_3d
......@@ -2056,6 +2056,8 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::ReplaceSlice:
case OP_TYPEID::ScalarConstantLike:
case OP_TYPEID::ScaleShift:
case OP_TYPEID::ScatterAdd:
case OP_TYPEID::ScatterNDAdd:
case OP_TYPEID::ShapeOf:
case OP_TYPEID::SpaceToDepth:
case OP_TYPEID::StopGradient:
......
......@@ -64,6 +64,8 @@ gather_nd_batch_2d_from_3d
gather_scalar_indices_no_axis
gather_scalar_indices
gather_nd_single_indices
gemm
gemm_broadcast_input_C
normalize_across_chw_scalar_scale_4d
normalize_across_chw_scalar_scale_3d
normalize_across_chw_scalar_scale_2d
......@@ -79,4 +81,9 @@ mvn_mean_variance_normalization
mvn_mean_variance_normalization_split_channels
scale_shift_no_broadcast
scale_shift
scatter_add_2d_indices
scatter_add_1d_indices
scatter_add_scalar_indices
scatter_nd_add_batch_2d_to_3d
scatter_nd_add_2d_to_3d
zero_sized_erf
......@@ -127,6 +127,8 @@
#include "ngraph/runtime/reference/result.hpp"
#include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/reverse_sequence.hpp"
#include "ngraph/runtime/reference/scatter_add.hpp"
#include "ngraph/runtime/reference/scatter_nd_add.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/shape_of.hpp"
#include "ngraph/runtime/reference/sigmoid.hpp"
......@@ -1212,6 +1214,66 @@ private:
}
break;
}
case OP_TYPEID::ScatterAdd:
{
if (node.get_input_element_type(1) == element::i64)
{
reference::scatter_add<T, int64_t>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<int64_t>(),
args[2]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_input_shape(2),
node.get_output_shape(0));
}
else if (node.get_input_element_type(1) == element::i32)
{
reference::scatter_add<T, int32_t>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<int32_t>(),
args[2]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_input_shape(2),
node.get_output_shape(0));
}
else
{
throw ngraph_error("Unexpected type");
}
break;
}
case OP_TYPEID::ScatterNDAdd:
{
if (node.get_input_element_type(1) == element::i64)
{
reference::scatter_nd_add<T, int64_t>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<int64_t>(),
args[2]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_input_shape(2),
node.get_output_shape(0));
}
else if (node.get_input_element_type(1) == element::i32)
{
reference::scatter_nd_add<T, int32_t>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<int32_t>(),
args[2]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_input_shape(2),
node.get_output_shape(0));
}
else
{
throw ngraph_error("Unexpected type");
}
break;
}
case OP_TYPEID::Select:
{
size_t element_count = shape_size(node.get_output_shape(0));
......
......@@ -117,3 +117,8 @@ gather_nd_batch_2d_from_3d
gather_scalar_indices_no_axis
gather_scalar_indices
gather_nd_single_indices
scatter_add_2d_indices
scatter_add_1d_indices
scatter_add_scalar_indices
scatter_nd_add_batch_2d_to_3d
scatter_nd_add_2d_to_3d
//*****************************************************************************
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstring>
#include <numeric>
#include "ngraph/coordinate_transform.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
// Reference implementation of scatter-add: out = copy of inputs, then for each
// scalar index in `indices`, the matching slice of `updates` is accumulated
// into out[index] along axis 0. `updates` is expected to have shape
// indices_shape + inputs_shape[1:] (checked by the op's validation).
// NOTE(review): slice_index is used unchecked; an out-of-range or negative
// index writes out of bounds — assumes callers validate indices. TODO confirm.
template <typename T, typename U>
void scatter_add(T* inputs,
                 U* indices,
                 T* updates,
                 T* out,
                 const Shape& inputs_shape,
                 const Shape& indices_shape,
                 const Shape& updates_shape,
                 const Shape& out_shape)
{
    using namespace std;
    // Copy inputs to out; additions are applied on top of this copy.
    memcpy(out, inputs, sizeof(T) * shape_size(inputs_shape));
    // Create a CoordinateTransform for "indices" that visits every index value.
    size_t indices_ndim = static_cast<size_t>(indices_shape.size());
    Coordinate indices_start_corner(indices_ndim, 0);
    Coordinate indices_end_corner(indices_shape);
    Strides indices_strides(indices_ndim, 1);
    AxisVector indices_axis_order(indices_ndim);
    iota(indices_axis_order.begin(), indices_axis_order.end(), 0);
    CoordinateTransform indices_transform(indices_shape,
                                          indices_start_corner,
                                          indices_end_corner,
                                          indices_strides,
                                          indices_axis_order);
    // Create an outer CoordinateTransform for "update": it iterates only over
    // the leading (indices-shaped) axes; trailing axes are clamped to extent 1
    // so each outer coordinate marks the start of one updates slice.
    size_t updates_ndim = static_cast<size_t>(updates_shape.size());
    Coordinate updates_outer_start_corner(updates_ndim, 0);
    Coordinate updates_outer_end_corner(updates_shape);
    for (size_t i = indices_ndim; i < updates_ndim; i++)
    {
        updates_outer_end_corner[i] = 1;
    }
    Strides updates_strides(updates_ndim, 1);
    AxisVector updates_axis_order(updates_ndim);
    iota(updates_axis_order.begin(), updates_axis_order.end(), 0);
    CoordinateTransform updates_outer_transform(updates_shape,
                                                updates_outer_start_corner,
                                                updates_outer_end_corner,
                                                updates_strides,
                                                updates_axis_order);
    // Common vars for out (identity strides/axis order reused per slice).
    size_t out_ndim = static_cast<size_t>(out_shape.size());
    Strides out_strides(out_ndim, 1);
    AxisVector out_axis_order(out_ndim);
    iota(out_axis_order.begin(), out_axis_order.end(), 0);
    // Visit one updates slice and one out slice at a time. The two iterations
    // advance in lockstep: indices_transform supplies the target row, the
    // outer updates iterator supplies the source slice.
    auto updates_outer_coord_iter = updates_outer_transform.begin();
    for (const Coordinate& indices_coord : indices_transform)
    {
        auto indices_index = indices_transform.index(indices_coord);
        U slice_index = indices[indices_index];
        // Create CoordinateTransform for out slice: row `slice_index` of out.
        Coordinate out_start_corner(out_ndim, 0);
        Coordinate out_end_corner(out_shape);
        out_start_corner[0] = static_cast<size_t>(slice_index);
        out_end_corner[0] = out_start_corner[0] + 1;
        CoordinateTransform out_transform(
            out_shape, out_start_corner, out_end_corner, out_strides, out_axis_order);
        // Create CoordinateTransform for updates slice: the single slice whose
        // leading coordinates equal the current outer coordinate.
        Coordinate updates_inner_start_corner = *updates_outer_coord_iter;
        Coordinate updates_inner_end_corner(updates_shape);
        for (size_t i = 0; i < indices_ndim; i++)
        {
            updates_inner_end_corner[i] = updates_inner_start_corner[i] + 1;
        }
        CoordinateTransform updates_inner_transform(updates_shape,
                                                    updates_inner_start_corner,
                                                    updates_inner_end_corner,
                                                    updates_strides,
                                                    updates_axis_order);
        // Add one element from updates to out at a time; both transforms
        // enumerate the same number of elements in the same order.
        auto updates_inner_coord_iter = updates_inner_transform.begin();
        for (const Coordinate& out_coord : out_transform)
        {
            out[out_transform.index(out_coord)] +=
                updates[updates_inner_transform.index(*updates_inner_coord_iter)];
            updates_inner_coord_iter++;
        }
        updates_outer_coord_iter++;
    }
}
}
}
}
//*****************************************************************************
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstring>
#include <numeric>
#include "ngraph/coordinate_transform.hpp"
namespace ngraph
{
    namespace runtime
    {
        namespace reference
        {
            // Reference kernel for ScatterNDAdd.
            //
            // Semantics: out starts as a copy of inputs; then each row of
            // "indices" (a tuple taken along the innermost indices axis) names
            // a slice of "out" whose leading coordinates are fixed by that
            // tuple, and the matching slice of "updates" is accumulated (+=)
            // into it.
            //
            // Shape contract (validated by the op's type-prop, not re-checked
            // here): indices rank >= 1, indices_shape[-1] <= inputs rank, and
            // updates_shape == indices_shape[:-1] + inputs_shape[indices_shape[-1]:].
            // Index values are assumed in range -- no bounds checking here.
            template <typename T, typename U>
            void scatter_nd_add(T* inputs,
                                U* indices,
                                T* updates,
                                T* out,
                                const Shape& inputs_shape,
                                const Shape& indices_shape,
                                const Shape& updates_shape,
                                const Shape& out_shape)
            {
                using namespace std;
                // Copy inputs to out; updates are added on top of this copy.
                memcpy(out, inputs, sizeof(T) * shape_size(inputs_shape));
                // Create a CoordinateTransform for "indices" that visits only the
                // first element along the innermost axis, i.e. one coordinate per
                // index tuple; the rest of each tuple is read by offsetting from
                // that element's flat index below.
                size_t indices_ndim = static_cast<size_t>(indices_shape.size());
                Coordinate indices_outer_start_corner(indices_ndim, 0);
                Coordinate indices_outer_end_corner(indices_shape);
                // Length of each index tuple = how many leading out axes it pins.
                size_t slice_rank = indices_shape[indices_ndim - 1];
                indices_outer_end_corner[indices_ndim - 1] = 1;
                Strides indices_strides(indices_ndim, 1);
                AxisVector indices_axis_order(indices_ndim);
                std::iota(indices_axis_order.begin(), indices_axis_order.end(), 0);
                CoordinateTransform indices_outer_transform(indices_shape,
                                                            indices_outer_start_corner,
                                                            indices_outer_end_corner,
                                                            indices_strides,
                                                            indices_axis_order);
                // Create a matching CoordinateTransform for "updates" that visits
                // the same outer coordinates: clamp every axis from indices_ndim - 1
                // onward (the per-slice axes) to extent 1.
                size_t updates_ndim = static_cast<size_t>(updates_shape.size());
                Strides updates_strides(updates_ndim, 1);
                AxisVector updates_axis_order(updates_ndim);
                std::iota(updates_axis_order.begin(), updates_axis_order.end(), 0);
                Coordinate updates_outer_start_corner(updates_ndim, 0);
                Coordinate updates_outer_end_corner(updates_shape);
                for (size_t i = indices_ndim - 1; i < updates_ndim; i++)
                {
                    updates_outer_end_corner[i] = 1;
                }
                CoordinateTransform updates_outer_transform(updates_shape,
                                                            updates_outer_start_corner,
                                                            updates_outer_end_corner,
                                                            updates_strides,
                                                            updates_axis_order);
                // Add an updates slice to a slice on out indexed by the innermost
                // dim of indices. The outer updates iterator advances in lockstep
                // with the indices loop (same number of outer positions, same
                // row-major order).
                size_t out_ndim = static_cast<size_t>(out_shape.size());
                Strides out_strides(out_ndim, 1);
                AxisVector out_axis_order(out_ndim);
                std::iota(out_axis_order.begin(), out_axis_order.end(), 0);
                auto updates_outer_coord_iter = updates_outer_transform.begin();
                for (const Coordinate& indices_coord : indices_outer_transform)
                {
                    Coordinate out_start_corner(out_ndim, 0);
                    Coordinate out_end_corner(out_shape);
                    // Read the slice_rank-long index tuple; its elements are
                    // contiguous in memory because the innermost axis has
                    // stride 1, so incrementing the flat index walks the tuple.
                    auto indices_index = indices_outer_transform.index(indices_coord);
                    for (size_t i = 0; i < slice_rank; i++)
                    {
                        U index = indices[indices_index];
                        out_start_corner[i] = index;
                        out_end_corner[i] = index + 1;
                        indices_index++;
                    }
                    CoordinateTransform out_transform(
                        out_shape, out_start_corner, out_end_corner, out_strides, out_axis_order);
                    // The updates slice is contiguous, so a single flat index
                    // incremented per element suffices here.
                    auto updates_index = updates_outer_transform.index(*updates_outer_coord_iter);
                    for (const Coordinate& out_coord : out_transform)
                    {
                        out[out_transform.index(out_coord)] += updates[updates_index];
                        updates_index++;
                    }
                    updates_outer_coord_iter++;
                }
            }
        }
    }
}
......@@ -110,6 +110,8 @@
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/scatter_add.hpp"
#include "ngraph/op/scatter_nd_add.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp"
......@@ -1371,6 +1373,16 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::ScaleShift>(args[0], args[1], args[2]);
break;
}
case OP_TYPEID::ScatterAdd:
{
node = make_shared<op::ScatterAdd>(args[0], args[1], args[2]);
break;
}
case OP_TYPEID::ScatterNDAdd:
{
node = make_shared<op::ScatterNDAdd>(args[0], args[1], args[2]);
break;
}
case OP_TYPEID::Select:
{
node = make_shared<op::Select>(args[0], args[1], args[2]);
......@@ -2126,6 +2138,10 @@ static json write(const Node& n, bool binary_constant_data)
}
case OP_TYPEID::ScaleShift: { break;
}
case OP_TYPEID::ScatterAdd: { break;
}
case OP_TYPEID::ScatterNDAdd: { break;
}
case OP_TYPEID::Select: { break;
}
case OP_TYPEID::ShapeOf: { break;
......
......@@ -153,6 +153,7 @@ set(MULTI_TEST_SRC
backend_one_hot.in.cpp
backend_pool.in.cpp
backend_reshape.in.cpp
backend_scatter.in.cpp
backend_sum.in.cpp
backend_topk.in.cpp
backend_arg_reduce.in.cpp
......
//*****************************************************************************
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cinttypes>
#include <cmath>
#include <cstdlib>
#include <random>
#include <string>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
#include "util/random.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"
using namespace std;
using namespace ngraph;
static string s_manifest = "${MANIFEST}";
// ScatterAdd with rank-2 indices: updates rank = indices rank + inputs rank - 1.
// Indices {0, 1, 1, 0} mean each of the two inputs slices receives two updates
// slices plus its original values, hence every output element is tripled-ish
// relative to its contributing inputs.
NGRAPH_TEST(${BACKEND_NAME}, scatter_add_2d_indices)
{
    Shape ref_shape{2, 3, 3};
    Shape indices_shape{2, 2};
    Shape updates_shape{2, 2, 3, 3};
    Shape out_shape{2, 3, 3};
    auto R = make_shared<op::Parameter>(element::f32, ref_shape);
    auto I = make_shared<op::Parameter>(element::i32, indices_shape);
    auto U = make_shared<op::Parameter>(element::f32, updates_shape);
    auto G = make_shared<op::ScatterAdd>(R, I, U);
    auto f =
        make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{R, I, U});
    auto backend = runtime::Backend::create("${BACKEND_NAME}");
    // Create some tensors for input/output
    auto r = backend->create_tensor(element::f32, ref_shape);
    copy_data(r, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9});
    auto i = backend->create_tensor(element::i32, indices_shape);
    copy_data(i, vector<int32_t>{0, 1, 1, 0});
    auto u = backend->create_tensor(element::f32, updates_shape);
    copy_data(u, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                               1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8});
    auto result = backend->create_tensor(element::f32, out_shape);
    auto c = backend->compile(f);
    c->call_with_validate({result}, {r, i, u});
    EXPECT_TRUE(test::all_close_f(
        (vector<float>{0, 3, 6, 9, 12, 15, 18, 21, 24, 3, 6, 9, 12, 15, 18, 21, 24, 27}),
        read_vector<float>(result),
        MIN_FLOAT_TOLERANCE_BITS));
}
// ScatterAdd with rank-1 indices {1, 0}: updates slice 0 lands on inputs
// slice 1 and vice versa, so each output element is inputs + the swapped
// updates slice.
NGRAPH_TEST(${BACKEND_NAME}, scatter_add_1d_indices)
{
    Shape ref_shape{2, 3, 3};
    Shape indices_shape{2};
    Shape updates_shape{2, 3, 3};
    Shape out_shape{2, 3, 3};
    auto R = make_shared<op::Parameter>(element::f32, ref_shape);
    auto I = make_shared<op::Parameter>(element::i32, indices_shape);
    auto U = make_shared<op::Parameter>(element::f32, updates_shape);
    auto G = make_shared<op::ScatterAdd>(R, I, U);
    auto f =
        make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{R, I, U});
    auto backend = runtime::Backend::create("${BACKEND_NAME}");
    // Create some tensors for input/output
    auto r = backend->create_tensor(element::f32, ref_shape);
    copy_data(r, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9});
    auto i = backend->create_tensor(element::i32, indices_shape);
    copy_data(i, vector<int32_t>{1, 0});
    auto u = backend->create_tensor(element::f32, updates_shape);
    copy_data(u, vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8});
    auto result = backend->create_tensor(element::f32, out_shape);
    auto c = backend->compile(f);
    c->call_with_validate({result}, {r, i, u});
    EXPECT_TRUE(test::all_close_f(
        (vector<float>{0, 2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16, 18}),
        read_vector<float>(result),
        MIN_FLOAT_TOLERANCE_BITS));
}
// ScatterAdd with scalar (rank-0) index 1: only inputs slice 1 is updated;
// slice 0 passes through unchanged.
NGRAPH_TEST(${BACKEND_NAME}, scatter_add_scalar_indices)
{
    Shape ref_shape{2, 3, 3};
    Shape indices_shape{};
    Shape updates_shape{3, 3};
    Shape out_shape{2, 3, 3};
    auto R = make_shared<op::Parameter>(element::f32, ref_shape);
    auto I = make_shared<op::Parameter>(element::i32, indices_shape);
    auto U = make_shared<op::Parameter>(element::f32, updates_shape);
    auto G = make_shared<op::ScatterAdd>(R, I, U);
    auto f =
        make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{R, I, U});
    auto backend = runtime::Backend::create("${BACKEND_NAME}");
    // Create some tensors for input/output
    auto r = backend->create_tensor(element::f32, ref_shape);
    copy_data(r, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9});
    auto i = backend->create_tensor(element::i32, indices_shape);
    copy_data(i, vector<int32_t>{1});
    auto u = backend->create_tensor(element::f32, updates_shape);
    copy_data(u, vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9});
    auto result = backend->create_tensor(element::f32, out_shape);
    auto c = backend->compile(f);
    c->call_with_validate({result}, {r, i, u});
    EXPECT_TRUE(test::all_close_f(
        (vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 2, 4, 6, 8, 10, 12, 14, 16, 18}),
        read_vector<float>(result),
        MIN_FLOAT_TOLERANCE_BITS));
}
// ScatterNDAdd with a batch of index tuples (shape {2, 1}): tuples {0} and {2}
// each select a full 3x3 slice of the rank-3 inputs; slice 1 is untouched.
NGRAPH_TEST(${BACKEND_NAME}, scatter_nd_add_batch_2d_to_3d)
{
    Shape ref_shape{3, 3, 3};
    Shape indices_shape{2, 1};
    Shape updates_shape{2, 3, 3};
    Shape out_shape{3, 3, 3};
    auto R = make_shared<op::Parameter>(element::f32, ref_shape);
    auto I = make_shared<op::Parameter>(element::i32, indices_shape);
    auto U = make_shared<op::Parameter>(element::f32, updates_shape);
    auto G = make_shared<op::ScatterNDAdd>(R, I, U);
    auto f =
        make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{R, I, U});
    auto backend = runtime::Backend::create("${BACKEND_NAME}");
    // Create some tensors for input/output
    auto r = backend->create_tensor(element::f32, ref_shape);
    copy_data(r, vector<float>{1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5,
                               5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9});
    auto i = backend->create_tensor(element::i32, indices_shape);
    copy_data(i, vector<int32_t>{0, 2});
    auto u = backend->create_tensor(element::f32, updates_shape);
    copy_data(u, vector<float>{1, 1, 1, 2, 2, 2, 3, 3, 3, 7, 7, 7, 8, 8, 8, 9, 9, 9});
    auto result = backend->create_tensor(element::f32, out_shape);
    auto c = backend->compile(f);
    c->call_with_validate({result}, {r, i, u});
    EXPECT_TRUE(test::all_close_f((vector<float>{2, 2, 2, 4, 4, 4, 6, 6, 6, 4, 4, 4, 5, 5,
                                                 5, 6, 6, 6, 14, 14, 14, 16, 16, 16, 18, 18, 18}),
                                  read_vector<float>(result),
                                  MIN_FLOAT_TOLERANCE_BITS));
}
// ScatterNDAdd with a single rank-1 index tuple {0}: only inputs slice 0
// receives the 3x3 update; slices 1 and 2 pass through unchanged.
NGRAPH_TEST(${BACKEND_NAME}, scatter_nd_add_2d_to_3d)
{
    Shape ref_shape{3, 3, 3};
    Shape indices_shape{1};
    Shape updates_shape{3, 3};
    Shape out_shape{3, 3, 3};
    auto R = make_shared<op::Parameter>(element::f32, ref_shape);
    auto I = make_shared<op::Parameter>(element::i32, indices_shape);
    auto U = make_shared<op::Parameter>(element::f32, updates_shape);
    auto G = make_shared<op::ScatterNDAdd>(R, I, U);
    auto f =
        make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{R, I, U});
    auto backend = runtime::Backend::create("${BACKEND_NAME}");
    // Create some tensors for input/output
    auto r = backend->create_tensor(element::f32, ref_shape);
    copy_data(r, vector<float>{1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5,
                               5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9});
    auto i = backend->create_tensor(element::i32, indices_shape);
    copy_data(i, vector<int32_t>{0});
    auto u = backend->create_tensor(element::f32, updates_shape);
    copy_data(u, vector<float>{1, 1, 1, 2, 2, 2, 3, 3, 3});
    auto result = backend->create_tensor(element::f32, out_shape);
    auto c = backend->compile(f);
    c->call_with_validate({result}, {r, i, u});
    EXPECT_TRUE(test::all_close_f((vector<float>{2, 2, 2, 4, 4, 4, 6, 6, 6, 4, 4, 4, 5, 5,
                                                 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9}),
                                  read_vector<float>(result),
                                  MIN_FLOAT_TOLERANCE_BITS));
}
......@@ -13791,6 +13791,269 @@ TEST(type_prop, gather_nd_fail_indices_element_type)
}
}
// Negative test: ScatterAdd must reject indices that are not i32/i64 (i16 here).
TEST(type_prop, scatter_add_fail_indices_element_type)
{
    Shape ref_shape{2, 3, 3};
    Shape indices_shape{2, 2};
    Shape updates_shape{2, 2, 3, 3};
    Shape out_shape{2, 3, 3};
    auto R = make_shared<op::Parameter>(element::f32, ref_shape);
    auto I = make_shared<op::Parameter>(element::i16, indices_shape);
    auto U = make_shared<op::Parameter>(element::f32, updates_shape);
    try
    {
        auto G = make_shared<op::ScatterAdd>(R, I, U);
        // Should have thrown, so fail if it didn't
        FAIL() << "Incorrect indices element type";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Indices element type must be i64 or i32"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// Negative test: ScatterAdd must reject updates whose element type (i32)
// differs from the inputs element type (f32).
TEST(type_prop, scatter_add_fail_updates_element_type)
{
    Shape ref_shape{2, 3, 3};
    Shape indices_shape{2, 2};
    Shape updates_shape{2, 2, 3, 3};
    Shape out_shape{2, 3, 3};
    auto R = make_shared<op::Parameter>(element::f32, ref_shape);
    auto I = make_shared<op::Parameter>(element::i32, indices_shape);
    auto U = make_shared<op::Parameter>(element::i32, updates_shape);
    try
    {
        auto G = make_shared<op::ScatterAdd>(R, I, U);
        // Should have thrown, so fail if it didn't
        FAIL() << "Incorrect updates element type";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("Updates element type must be the same as Inputs"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// Negative test: updates rank must equal indices rank + inputs rank - 1
// (expected 4 here; rank-3 updates must be rejected).
TEST(type_prop, scatter_add_fail_updates_rank)
{
    Shape ref_shape{2, 3, 3};
    Shape indices_shape{2, 2};
    Shape updates_shape{2, 3, 3};
    Shape out_shape{2, 3, 3};
    auto R = make_shared<op::Parameter>(element::f32, ref_shape);
    auto I = make_shared<op::Parameter>(element::i32, indices_shape);
    auto U = make_shared<op::Parameter>(element::f32, updates_shape);
    try
    {
        auto G = make_shared<op::ScatterAdd>(R, I, U);
        // Should have thrown, so fail if it didn't
        FAIL() << "Incorrect updates rank";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            std::string("Updates rank is expected to be indices rank + inputs rank - 1"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// Negative test: updates shape must equal indices_shape + inputs_shape[1:]
// ({2,2,3,3} expected; the leading 1 here makes it invalid).
TEST(type_prop, scatter_add_fail_updates_shape)
{
    Shape ref_shape{2, 3, 3};
    Shape indices_shape{2, 2};
    Shape updates_shape{1, 2, 3, 3};
    Shape out_shape{2, 3, 3};
    auto R = make_shared<op::Parameter>(element::f32, ref_shape);
    auto I = make_shared<op::Parameter>(element::i32, indices_shape);
    auto U = make_shared<op::Parameter>(element::f32, updates_shape);
    try
    {
        auto G = make_shared<op::ScatterAdd>(R, I, U);
        // Should have thrown, so fail if it didn't
        FAIL() << "Incorrect updates shape";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("Updates shape must be indices_shape + inputs_shape[1:]"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// Negative test: ScatterNDAdd must reject indices that are not i32/i64 (i16 here).
TEST(type_prop, scatter_nd_add_fail_indices_element_type)
{
    Shape ref_shape{3, 3, 3};
    Shape indices_shape{1};
    Shape updates_shape{3, 3};
    Shape out_shape{3, 3, 3};
    auto R = make_shared<op::Parameter>(element::f32, ref_shape);
    auto I = make_shared<op::Parameter>(element::i16, indices_shape);
    auto U = make_shared<op::Parameter>(element::f32, updates_shape);
    try
    {
        auto G = make_shared<op::ScatterNDAdd>(R, I, U);
        // Should have thrown, so fail if it didn't
        FAIL() << "Incorrect indices element type";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Indices element type must be i64 or i32"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// Negative test: ScatterNDAdd requires indices rank >= 1 (scalar indices
// must be rejected, unlike ScatterAdd).
TEST(type_prop, scatter_nd_add_fail_indices_rank)
{
    Shape ref_shape{3, 3, 3};
    Shape indices_shape{};
    Shape updates_shape{3, 3};
    Shape out_shape{3, 3, 3};
    auto R = make_shared<op::Parameter>(element::f32, ref_shape);
    auto I = make_shared<op::Parameter>(element::i32, indices_shape);
    auto U = make_shared<op::Parameter>(element::f32, updates_shape);
    try
    {
        auto G = make_shared<op::ScatterNDAdd>(R, I, U);
        // Should have thrown, so fail if it didn't
        FAIL() << "Incorrect indices rank";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("Indices rank is expected to be at least 1"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// Negative test: the innermost indices dimension (tuple length, 4 here) may
// not exceed the rank of inputs (3).
TEST(type_prop, scatter_nd_add_fail_indices_last_dim)
{
    Shape ref_shape{3, 3, 3};
    Shape indices_shape{2, 4};
    Shape updates_shape{2, 3, 3};
    Shape out_shape{3, 3, 3};
    auto R = make_shared<op::Parameter>(element::f32, ref_shape);
    auto I = make_shared<op::Parameter>(element::i32, indices_shape);
    auto U = make_shared<op::Parameter>(element::f32, updates_shape);
    try
    {
        auto G = make_shared<op::ScatterNDAdd>(R, I, U);
        // Should have thrown, so fail if it didn't
        FAIL() << "Incorrect indices innermost dim";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            std::string("Last dimension of indices can be at most the rank of inputs"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// Negative test: ScatterNDAdd must reject updates whose element type (i32)
// differs from the inputs element type (f32).
TEST(type_prop, scatter_nd_add_fail_updates_element_type)
{
    Shape ref_shape{3, 3, 3};
    Shape indices_shape{1};
    Shape updates_shape{3, 3};
    Shape out_shape{3, 3, 3};
    auto R = make_shared<op::Parameter>(element::f32, ref_shape);
    auto I = make_shared<op::Parameter>(element::i32, indices_shape);
    auto U = make_shared<op::Parameter>(element::i32, updates_shape);
    try
    {
        auto G = make_shared<op::ScatterNDAdd>(R, I, U);
        // Should have thrown, so fail if it didn't
        FAIL() << "Incorrect updates element type";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("Updates element type must be the same as inputs"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// Negative test: updates rank must equal
// inputs rank + indices rank - indices_shape[-1] - 1 (2 here; rank 3 is wrong).
TEST(type_prop, scatter_nd_add_fail_updates_rank)
{
    Shape ref_shape{3, 3, 3};
    Shape indices_shape{1};
    Shape updates_shape{3, 3, 3};
    Shape out_shape{3, 3, 3};
    auto R = make_shared<op::Parameter>(element::f32, ref_shape);
    auto I = make_shared<op::Parameter>(element::i32, indices_shape);
    auto U = make_shared<op::Parameter>(element::f32, updates_shape);
    try
    {
        auto G = make_shared<op::ScatterNDAdd>(R, I, U);
        // Should have thrown, so fail if it didn't
        FAIL() << "Incorrect updates rank";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("Rank of updates must be rank of inputs + rank of indices "
                                         "- last dimension of indices - 1"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// Negative test: updates shape must equal
// indices_shape[:-1] + inputs_shape[indices_shape[-1]:] ({3,3} here; {2,3} is wrong).
TEST(type_prop, scatter_nd_add_fail_updates_shape)
{
    Shape ref_shape{3, 3, 3};
    Shape indices_shape{1};
    Shape updates_shape{2, 3};
    Shape out_shape{3, 3, 3};
    auto R = make_shared<op::Parameter>(element::f32, ref_shape);
    auto I = make_shared<op::Parameter>(element::i32, indices_shape);
    auto U = make_shared<op::Parameter>(element::f32, updates_shape);
    try
    {
        auto G = make_shared<op::ScatterNDAdd>(R, I, U);
        // Should have thrown, so fail if it didn't
        FAIL() << "Incorrect updates shape";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            std::string(
                "Updates shape must be indices_shape[:-1] + inputs_shape[indices.shape[-1]:]"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
TEST(type_prop, conv_bias_2d_deduce)
{
// Deduce type
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment