Commit ac17d797 authored by Amy Zhuang's avatar Amy Zhuang Committed by Scott Cyphers

Use Eigen kernel for ScatterAdd. (#3002)

*  Use Eigen kernel for ScatterAdd.

* Emit ScatterAdd Eigen kernel for CODEGEN.

Add comments.

Address PR feedback.

* Add more unit tests.

Fix style error.

Add ScatterAdd to AssignOpMap.

* Combine non-scalar cases together.

* Address PR feedback.

* Fix a bug.

* Use reshape to make the shapes of two slices match.

* Rename variables.

Fix bugs.

Use helper function.

Add one unit test.

* Add reshape back.
parent 743fcb47
......@@ -204,6 +204,73 @@
throw ngraph_error("Unsupported element type " + ET.c_type_string() + " for kernel " #K); \
}
#define SELECT_RANK1(KV, ET, R1, R2, K) \
if (R1 == 1) \
KV = K<ET, 1, R2>; \
else if (R1 == 2) \
KV = K<ET, 2, R2>; \
else if (R1 == 3) \
KV = K<ET, 3, R2>; \
else if (R1 == 4) \
KV = K<ET, 4, R2>; \
else if (R1 == 5) \
KV = K<ET, 5, R2>; \
else if (R1 == 6) \
KV = K<ET, 6, R2>; \
else if (R1 == 7) \
KV = K<ET, 7, R2>; \
else \
throw ngraph_error("Unsupported first rank " + std::to_string(R1) + " for kernel " #K);
#define SELECT_2RANKS(KV, ET, R1, R2, K) \
if (R2 == 1) \
{ \
SELECT_RANK1(KV, ET, R1, 1, K); \
} \
else if (R2 == 2) \
{ \
SELECT_RANK1(KV, ET, R1, 2, K); \
} \
else if (R2 == 3) \
{ \
SELECT_RANK1(KV, ET, R1, 3, K); \
} \
else if (R2 == 4) \
{ \
SELECT_RANK1(KV, ET, R1, 4, K); \
} \
else if (R2 == 5) \
{ \
SELECT_RANK1(KV, ET, R1, 5, K); \
} \
else if (R2 == 6) \
{ \
SELECT_RANK1(KV, ET, R1, 6, K); \
} \
else if (R2 == 7) \
{ \
SELECT_RANK1(KV, ET, R1, 7, K); \
} \
else \
{ \
throw ngraph_error("Unsupported second rank " + std::to_string(R2) + " for kernel " #K); \
}
// Per-type and ranks kernel macro
#define SELECT_KERNEL_BY_2RANKS(KV, ET, R1, R2, K) \
if (ET == element::f32) \
{ \
SELECT_2RANKS(KV, float, R1, R2, K); \
} \
else if (ET == element::f64) \
{ \
SELECT_2RANKS(KV, double, R1, R2, K); \
} \
else \
{ \
throw ngraph_error("Unsupported element type " + ET.c_type_string() + " for kernel " #K); \
}
// Helper macros for a partial set of element types and ranks
// Useful for keeping compilation time and memory usage reasonable
// when the computed expression is complex
......
......@@ -1858,16 +1858,34 @@ namespace ngraph
}
writer.block_begin();
writer << "reference::scatter_add<" << args[0].get_type() << ", "
<< args[1].get_element_type().c_type_string() << ">(" << args[0].get_name()
<< ",\n";
writer << " " << args[1].get_name() << ",\n";
writer << " " << args[2].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(args[1].get_shape()) << "},\n";
writer << " {" << join(args[2].get_shape()) << "},\n";
writer << " {" << join(out[0].get_shape()) << "});\n";
if (args[0].get_element_type() == element::f64 ||
args[0].get_element_type() == element::f32)
{
writer << "cpu::kernel::scatter_add<" << args[0].get_type() << ", "
<< args[1].get_element_type().c_type_string() << ", "
<< args[0].get_shape().size() << ", " << args[2].get_shape().size()
<< ">(" << args[0].get_name() << ",\n";
writer << " " << args[1].get_name() << ",\n";
writer << " " << args[2].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(args[1].get_shape()) << "},\n";
writer << " {" << join(args[2].get_shape()) << "},\n";
writer << " 0);\n";
}
else
{
writer << "reference::scatter_add<" << args[0].get_type() << ", "
<< args[1].get_element_type().c_type_string() << ">("
<< args[0].get_name() << ",\n";
writer << " " << args[1].get_name() << ",\n";
writer << " " << args[2].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(args[1].get_shape()) << "},\n";
writer << " {" << join(args[2].get_shape()) << "},\n";
writer << " {" << join(out[0].get_shape()) << "});\n";
}
writer.block_end();
}
......@@ -2003,8 +2021,8 @@ namespace ngraph
<< "auto pos_raw = " << emit_vector(args[0]) << "(0, 0);\n"
<< "if (floor(pos_raw) != pos_raw)\n";
writer.block_begin();
writer
<< "throw(std::range_error(\"One-hot: non-integral value in input\"));\n";
writer << "throw(std::range_error(\"One-hot: non-integral value in "
"input\"));\n";
writer.block_end();
writer << "size_t pos = pos_raw;\n"
......@@ -2031,8 +2049,8 @@ namespace ngraph
writer << "if (floor(pos_raw) != pos_raw)\n";
writer.block_begin();
writer
<< "throw(std::range_error(\"One-hot: non-integral value in input\"));\n";
writer << "throw(std::range_error(\"One-hot: non-integral value in "
"input\"));\n";
writer.block_end();
writer << "size_t pos = pos_raw;\n";
......@@ -2468,7 +2486,8 @@ namespace ngraph
else
{
throw ngraph_error(
"QuantizedConvolutionBiasAdd is only supported with MKLDNN kernel.");
"QuantizedConvolutionBiasAdd is only supported with MKLDNN "
"kernel.");
}
}
......@@ -2500,7 +2519,8 @@ namespace ngraph
else
{
throw ngraph_error(
"QuantizedConvolutionBiasSignedAdd is only supported with MKLDNN kernel.");
"QuantizedConvolutionBiasSignedAdd is only supported with MKLDNN "
"kernel.");
}
}
......@@ -2682,7 +2702,8 @@ namespace ngraph
else
{
throw ngraph_error(
"ConvolutionBiasBackpropFiltersBias is only supported with MKLDNN kernel.");
"ConvolutionBiasBackpropFiltersBias is only supported with MKLDNN "
"kernel.");
}
}
......
......@@ -224,6 +224,19 @@ namespace ngraph
template <typename ElementType>
void reference_erf(void* arg, void* out, size_t count);
template <typename ElementType,
typename IndicesType,
unsigned int Rank1,
unsigned int Rank2>
void scatter_add(void* inputs,
void* indices,
void* updates,
void* output,
const Shape& inputs_shape,
const Shape& indices_shape,
const Shape& updates_shape,
int arena);
}
}
}
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>
#include "ngraph/coordinate.hpp"
#include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace kernel
{
static void
get_leading_indices(const Shape& shape, int index, std::vector<int>& indices)
{
auto rank = shape.size();
std::vector<int> partial_sum(rank);
partial_sum[rank - 1] = 1;
for (int j = rank - 2; j >= 0; j--)
{
partial_sum[j] = partial_sum[j + 1] * shape[j + 1];
}
for (int j = 0; j < rank; j++)
{
indices[j] = index / partial_sum[j];
index = index % partial_sum[j];
}
}
// ScatterAdd is to update bunch of slices of the inputs. The rank of slice is 1 less than the rank of the inputs.
template <typename ElementType,
typename IndicesType,
unsigned int Rank1,
unsigned int Rank2>
void scatter_add(void* inputs,
void* indices,
void* updates,
void* output,
const Shape& inputs_shape,
const Shape& indices_shape,
const Shape& updates_shape,
int arena)
{
// For Eigen slice op, both parameters (offsets and extents) need to have the same rank.
// Here *_offsets and *_extents have the same rank.
Eigen::array<Eigen::Index, Rank1> in_dims, in_extents, in_offsets;
Eigen::array<Eigen::Index, Rank2> updates_dims, updates_extents,
updates_offsets;
for (int i = 0; i < Rank1; i++)
{
in_extents[i] = in_dims[i] = inputs_shape[i];
in_offsets[i] = 0;
}
in_extents[0] = 1;
for (int i = 0; i < Rank2; i++)
{
updates_extents[i] = updates_dims[i] = updates_shape[i];
updates_offsets[i] = 0;
}
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank1, Eigen::RowMajor>> out(
static_cast<ElementType*>(output), in_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank1, Eigen::RowMajor>> in(
static_cast<ElementType*>(inputs), in_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank2, Eigen::RowMajor>> up(
static_cast<ElementType*>(updates), updates_dims);
// copy if not in place.
if (inputs != output)
{
out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) = in;
}
auto indices_ptr = static_cast<IndicesType*>(indices);
auto indices_rank = indices_shape.size();
if (indices_rank == 0)
{
in_offsets[0] = indices_ptr[0];
out.slice(in_offsets, in_extents)
.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) =
out.slice(in_offsets, in_extents) +
up.slice(updates_offsets, updates_extents).reshape(in_extents);
}
else
{
std::vector<int> leading_indices(indices_rank);
for (int i = 0; i < shape_size(indices_shape); i++)
{
in_offsets[0] = indices_ptr[i];
get_leading_indices(indices_shape, i, leading_indices);
for (int j = 0; j < indices_rank; j++)
{
updates_extents[j] = 1;
updates_offsets[j] = leading_indices[j];
}
out.slice(in_offsets, in_extents)
.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(
arena)) =
out.slice(in_offsets, in_extents) +
up.slice(updates_offsets, updates_extents).reshape(in_extents);
}
}
}
template <typename ElementType, unsigned int Rank1, unsigned int Rank2>
void scatter_add_i64(void* inputs,
void* indices,
void* updates,
void* output,
const Shape& inputs_shape,
const Shape& indices_shape,
const Shape& updates_shape,
int arena)
{
scatter_add<ElementType, int64_t, Rank1, Rank2>(inputs,
indices,
updates,
output,
inputs_shape,
indices_shape,
updates_shape,
arena);
}
template <typename ElementType, unsigned int Rank1, unsigned int Rank2>
void scatter_add_i32(void* inputs,
void* indices,
void* updates,
void* output,
const Shape& inputs_shape,
const Shape& indices_shape,
const Shape& updates_shape,
int arena)
{
scatter_add<ElementType, int32_t, Rank1, Rank2>(inputs,
indices,
updates,
output,
inputs_shape,
indices_shape,
updates_shape,
arena);
}
}
}
}
}
......@@ -47,6 +47,7 @@
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/scatter_add.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
......@@ -527,6 +528,21 @@ namespace ngraph
update_slice->set_op_annotations(op_annotations);
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ScatterAdd)
{
auto update_slice = static_cast<ngraph::op::ScatterAdd*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
if (get_user_count(node->get_argument(0).get()) == 1)
{
// Safe to overwrite input
op_annotations->add_in_place_oi_pair({0, 0, true});
}
update_slice->set_op_annotations(op_annotations);
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::LRN)
{
......@@ -998,6 +1014,8 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::GetOutputElement>},
{TI(ngraph::op::DeconvolutionBias),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::DeconvolutionBias>},
{TI(ngraph::op::ScatterAdd),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ScatterAdd>},
};
bool runtime::cpu::pass::CPUAssignment::run_on_call_graph(
......
......@@ -162,6 +162,8 @@ mvn_mean_normalization
mvn_mean_normalization_split_channels
mvn_mean_variance_normalization
mvn_mean_variance_normalization_split_channels
scatter_add_4d_indices
scatter_add_3d_indices
scatter_add_2d_indices
scatter_add_1d_indices
scatter_add_scalar_indices
......
......@@ -78,6 +78,8 @@ mvn_mean_variance_normalization
mvn_mean_variance_normalization_split_channels
scale_shift_no_broadcast
scale_shift
scatter_add_4d_indices
scatter_add_3d_indices
scatter_add_2d_indices
scatter_add_1d_indices
scatter_add_scalar_indices
......
......@@ -117,6 +117,8 @@ gather_nd_batch_2d_from_3d
gather_scalar_indices_no_axis
gather_scalar_indices
gather_nd_single_indices
scatter_add_4d_indices
scatter_add_3d_indices
scatter_add_2d_indices
scatter_add_1d_indices
scatter_add_scalar_indices
......
......@@ -35,6 +35,94 @@ using namespace ngraph;
static string s_manifest = "${MANIFEST}";
NGRAPH_TEST(${BACKEND_NAME}, scatter_add_4d_indices)
{
Shape ref_shape{3, 3, 3};
Shape indices_shape{2, 3, 4, 2};
Shape updates_shape{2, 3, 4, 2, 3, 3};
Shape out_shape{3, 3, 3};
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
auto G = make_shared<op::ScatterAdd>(R, I, U);
auto f =
make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{R, I, U});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto r = backend->create_tensor(element::f32, ref_shape);
copy_data(r, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5,
6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0,
1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1,
2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2});
auto u = backend->create_tensor(element::f32, updates_shape);
copy_data(u,
vector<float>{
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {r, i, u});
EXPECT_TRUE(test::all_close_f(
(vector<float>{0, 17, 34, 51, 68, 85, 102, 119, 136, 17, 34, 51, 68, 85,
102, 119, 136, 153, 0, 17, 34, 51, 68, 85, 102, 119, 136}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, scatter_add_3d_indices)
{
Shape ref_shape{2, 3, 3};
Shape indices_shape{2, 2, 2};
Shape updates_shape{2, 2, 2, 3, 3};
Shape out_shape{2, 3, 3};
auto R = make_shared<op::Parameter>(element::f32, ref_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
auto G = make_shared<op::ScatterAdd>(R, I, U);
auto f =
make_shared<Function>(make_shared<op::GetOutputElement>(G, 0), ParameterVector{R, I, U});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto r = backend->create_tensor(element::f32, ref_shape);
copy_data(r, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 1, 1, 0, 0, 1, 1, 0});
auto u = backend->create_tensor(element::f32, updates_shape);
copy_data(u, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9,
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9,
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {r, i, u});
EXPECT_TRUE(test::all_close_f(
(vector<float>{0, 5, 10, 15, 20, 25, 30, 35, 40, 5, 10, 15, 20, 25, 30, 35, 40, 45}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, scatter_add_2d_indices)
{
Shape ref_shape{2, 3, 3};
......
......@@ -1729,3 +1729,73 @@ TEST(cpu_test, avg_pool_bprop_2d_2channel_2image)
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
TEST(cpu_test, scatter_add_1d_indices_in_place)
{
Shape ref_shape{2, 3, 3};
Shape indices_shape{2};
Shape updates_shape{2, 3, 3};
Shape out_shape{2, 3, 3};
auto R1 = make_shared<op::Parameter>(element::f32, ref_shape);
auto R2 = make_shared<op::Parameter>(element::f32, ref_shape);
auto R = make_shared<op::Add>(R1, R2);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
auto G = make_shared<op::ScatterAdd>(R, I, U);
auto add = make_shared<op::Add>(G, R2);
auto f = make_shared<Function>(add, ParameterVector{R1, R2, I, U});
auto backend = runtime::Backend::create("CPU");
// Create some tensors for input/output
auto r1 = backend->create_tensor(element::f32, ref_shape);
copy_data(r1, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9});
auto r2 = backend->create_tensor(element::f32, ref_shape);
copy_data(r2, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{1, 0});
auto u = backend->create_tensor(element::f32, updates_shape);
copy_data(u, vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {r1, r2, i, u});
EXPECT_TRUE(test::all_close_f(
(vector<float>{0, 4, 8, 12, 16, 20, 24, 28, 32, 4, 8, 12, 16, 20, 24, 28, 32, 36}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
TEST(cpu_test, scatter_add_1d_indices_no_in_place)
{
Shape ref_shape{2, 3, 3};
Shape indices_shape{2};
Shape updates_shape{2, 3, 3};
Shape out_shape{2, 3, 3};
auto R1 = make_shared<op::Parameter>(element::f32, ref_shape);
auto R2 = make_shared<op::Parameter>(element::f32, ref_shape);
auto R = make_shared<op::Add>(R1, R2);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto U = make_shared<op::Parameter>(element::f32, updates_shape);
auto G = make_shared<op::ScatterAdd>(R, I, U);
auto add = make_shared<op::Add>(G, R);
auto f = make_shared<Function>(add, ParameterVector{R1, R2, I, U});
auto backend = runtime::Backend::create("CPU");
// Create some tensors for input/output
auto r1 = backend->create_tensor(element::f32, ref_shape);
copy_data(r1, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9});
auto r2 = backend->create_tensor(element::f32, ref_shape);
copy_data(r2, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{1, 0});
auto u = backend->create_tensor(element::f32, updates_shape);
copy_data(u, vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {r1, r2, i, u});
EXPECT_TRUE(test::all_close_f(
(vector<float>{0, 5, 10, 15, 20, 25, 30, 35, 40, 5, 10, 15, 20, 25, 30, 35, 40, 45}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment