Unverified Commit 29231e11 authored by Adam Procter's avatar Adam Procter Committed by GitHub

Implement select-and-scatter (#364)

parent d2b081c8
......@@ -22,7 +22,27 @@ namespace ngraph
{
/// \brief Select-and-scatter operation.
///
/// TODO: More formal definition. For now, see: https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter.
/// Select-and-scatter takes three inputs, all of which must have the same element type \f$E\f$:
///
/// 1. the <i>selectee</i>, a tensor of shape \f$(d_1,\dots,d_n)\f$,
/// 2. the <i>source</i>, a tensor of shape \f$(d'_1,\dots,d'_n)\f$ where \f$d'_i = \lceil \frac {d_i - w_i + 1}{s_i} \rceil\f$ (see below for the definitions of the window sizes \f$w_i\f$ and strides \f$s_i\f$; a worked shape example follows the parameter list), and
/// 3. the <i>initial value</i>, a scalar.
///
/// It also takes four parameters:
///
/// 1. the <i>selection function</i>, a function that takes two arguments of type \f$E[]\f$ and returns `Bool` (think of this as a binary relation),
/// 2. the <i>scatter function</i>, a function that takes two arguments of type \f$E[]\f$ and returns \f$E[]\f$,
/// 3. the <i>window shape</i>, a vector \f$(w_1,\dots,w_n)\f$ of non-negative integers, and
/// 4. the <i>window movement strides</i>, a vector \f$(s_1,\dots,s_n)\f$ of non-negative integers.
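///
/// For example, a selectee of shape \f$(4,5)\f$ with window shape \f$(2,3)\f$ and window movement strides \f$(2,2)\f$ requires a source of shape \f$(\lceil \frac{4-2+1}{2} \rceil,\lceil \frac{5-3+1}{2} \rceil) = (2,2)\f$; see the unit tests below.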
///
/// It is assumed that the selection function defines a strict total order; otherwise behavior is undefined. (TODO: we may be able to generalize usefully here.)
///
/// The output \f$T_\textit{out}\f$ has the same element type and shape as the selectee. Its values are produced as follows:
/// 1. Initialize every element of \f$T_\textit{out}\f$ with the initial value.
/// 2. Slide a window of shape \f$(w_1,\dots,w_n)\f$ over the selectee, with stride \f$(s_1,\dots,s_n)\f$, with the start corner of the window increasing in natural ("row-major") order. Note that for every valid window position, there will be a corresponding value in the source tensor located at some coordinate \f$(i_1,\dots,i_n)\f$.
/// 3. At each window position, using the selection function as the relation, find a coordinate \f$(j_1,\dots,j_n)\f$ where some "maximum" value resides. Replace \f$T_\textit{out}[j_1,\dots,j_n]\f$ with the value \f$f(T_\textit{out}[j_1,\dots,j_n],T_\textit{source}[i_1,\dots,i_n])\f$ where \f$f\f$ is the scatter function and \f$T_\textit{source}\f$ is the source tensor.
///
/// The XLA documentation has good examples at https://www.tensorflow.org/versions/r1.5/performance/xla/operation_semantics#selectandscatter .
///
/// ## Parameters
///
......@@ -43,9 +63,9 @@ namespace ngraph
///
/// ## Output
///
/// | Type | Description |
/// | ------------------------ | -------------------------------------- |
/// | \f$E[d'_1,\dots,d'_n]\f$ | (TODO: explain more) See the XLA docs. |
/// | Type | Description |
/// | ---------------------- | -------------------- |
/// | \f$E[d_1,\dots,d_n]\f$ | See the algorithm above. |
class SelectAndScatter : public RequiresTensorViewArgs
{
public:
......
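A minimal standalone sketch of the algorithm the doc comment above describes, on a 1-D input with window shape (3) and strides (3); every name here is invented for illustration and none of it is part of this commit:

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    // Selectee of shape (6), window shape (3), strides (3) => source shape (2),
    // since ceil((6 - 3 + 1) / 3) = 2.
    std::vector<float> selectee{1, 4, 2, 5, 3, 9};
    std::vector<float> source{10, 20};
    float init = 0;
    std::vector<float> out(selectee.size(), init);
    const std::size_t w = 3, s = 3;
    for (std::size_t win = 0; win * s + w <= selectee.size(); ++win)
    {
        // Select: find the window coordinate holding the "maximum" under >.
        std::size_t start = win * s;
        std::size_t winner = start;
        for (std::size_t j = start + 1; j < start + w; ++j)
        {
            if (selectee[j] > selectee[winner])
            {
                winner = j;
            }
        }
        // Scatter: fold the corresponding source value into the output.
        out[winner] = out[winner] + source[win];
    }
    for (float v : out)
    {
        std::cout << v << " "; // prints: 0 10 0 0 0 20
    }
    std::cout << "\n";
}

Window 0 covers {1, 4, 2} and selects the 4 at index 1; window 1 covers {5, 3, 9} and selects the 9 at index 5; each winner accumulates the corresponding source value.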
......@@ -35,6 +35,7 @@
#include "ngraph/ops/replace_slice.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/reverse.hpp"
#include "ngraph/ops/select_and_scatter.hpp"
#include "ngraph/ops/slice.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/runtime/cpu/cpu_emitter.hpp"
......@@ -1841,6 +1842,59 @@ void runtime::cpu::CPU_Emitter::EmitReduceWindow(
<< "});\n";
}
void runtime::cpu::CPU_Emitter::EmitSelectAndScatter(
codegen::CodeWriter& writer,
const ngraph::Node* n,
const vector<runtime::cpu::TensorViewWrapper>& args,
const vector<runtime::cpu::TensorViewWrapper>& out)
{
auto select_and_scatter = static_cast<const op::SelectAndScatter*>(n);
auto selection_function = select_and_scatter->get_functions()[0];
auto scatter_function = select_and_scatter->get_functions()[1];
auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape();
auto result_shape = out[0].get_shape();
string type = n->get_output_element_type(0).c_type_string();
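// Emit a wrapper lambda around the compiled selection function. The selection
// function's result has element type Bool, which the generated code (and
// kernel::select_and_scatter below) represents as char.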
writer << "auto f_select = [](" << type << " x, " << type << " y) -> char\n{";
writer.indent++;
writer << "\n";
writer << "char result;\n";
writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n";
writer << selection_function->get_name() << "(args, out);\n";
writer << "return result;\n";
writer.indent--;
writer << "};\n";
writer << "auto f_scatter = [](" << type << " x, " << type << " y) -> " << type << "\n{";
writer.indent++;
writer << "\n";
writer << type << " result;\n";
writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n";
writer << scatter_function->get_name() << "(args, out);\n";
writer << "return result;\n";
writer.indent--;
writer << "};\n";
writer << "kernel::select_and_scatter<" << out[0].get_type() << ">(" << args[0].get_name()
<< ",\n";
writer << " " << args[1].get_name() << ",\n";
writer << " " << args[2].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(arg0_shape) << "},\n";
writer << " {" << join(arg1_shape) << "},\n";
writer << " {" << join(result_shape) << "},\n";
writer << " f_select,\n";
writer << " f_scatter,\n";
writer << " {" << join(select_and_scatter->get_window_shape()) << "},\n";
writer << " {" << join(select_and_scatter->get_window_movement_strides())
<< "});\n";
}
//------------------------------------------------------------------------------------------------
// Utility methods
//------------------------------------------------------------------------------------------------
......
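For reference, EmitSelectAndScatter above writes out C++ of roughly the following shape for the f32 case. This is a runnable approximation, not the literal output: function_0 and function_1 stand in for the separately compiled selection and scatter functions (the real names come from get_name()), given bodies here that match the unit tests below (greater-than and addition) and using the same void** calling convention the emitter targets:

#include <iostream>

void function_0(void** args, void** out) // stand-in selection function: x > y
{
    float x = *static_cast<float*>(args[0]);
    float y = *static_cast<float*>(args[1]);
    *static_cast<char*>(out[0]) = x > y;
}

void function_1(void** args, void** out) // stand-in scatter function: x + y
{
    float x = *static_cast<float*>(args[0]);
    float y = *static_cast<float*>(args[1]);
    *static_cast<float*>(out[0]) = x + y;
}

int main()
{
    // These lambdas follow the exact shape the emitter writes.
    auto f_select = [](float x, float y) -> char {
        char result;
        void* args[] = {&x, &y};
        void* out[] = {&result};
        function_0(args, out);
        return result;
    };
    auto f_scatter = [](float x, float y) -> float {
        float result;
        void* args[] = {&x, &y};
        void* out[] = {&result};
        function_1(args, out);
        return result;
    };
    std::cout << int(f_select(3.0f, 2.0f)) << " " << f_scatter(3.0f, 2.0f) << "\n"; // 1 5
}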
......@@ -89,6 +89,7 @@ namespace ngraph
static void EMITTER_DECL(EmitMaxPool);
static void EMITTER_DECL(EmitReverse);
static void EMITTER_DECL(EmitReduceWindow);
static void EMITTER_DECL(EmitSelectAndScatter);
private:
static std::string emit_vector(const TensorViewWrapper&,
......
......@@ -71,6 +71,7 @@
#include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/reverse.hpp"
#include "ngraph/ops/select.hpp"
#include "ngraph/ops/select_and_scatter.hpp"
#include "ngraph/ops/sign.hpp"
#include "ngraph/ops/sin.hpp"
#include "ngraph/ops/sinh.hpp"
......@@ -187,6 +188,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::MaxPool), &runtime::cpu::CPU_Emitter::EmitMaxPool},
{TI(ngraph::op::Reverse), &runtime::cpu::CPU_Emitter::EmitReverse},
{TI(ngraph::op::ReduceWindow), &runtime::cpu::CPU_Emitter::EmitReduceWindow},
{TI(ngraph::op::SelectAndScatter), &runtime::cpu::CPU_Emitter::EmitSelectAndScatter},
};
runtime::cpu::CPU_ExternalFunction::CPU_ExternalFunction(
......@@ -240,6 +242,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
#include "ngraph/runtime/kernel/reduce_window.hpp"
#include "ngraph/runtime/kernel/replace_slice.hpp"
#include "ngraph/runtime/kernel/reverse.hpp"
#include "ngraph/runtime/kernel/select_and_scatter.hpp"
#include "ngraph/runtime/kernel/slice.hpp"
#include "ngraph/runtime/kernel/sum.hpp"
#include "ngraph/util.hpp"
......
......@@ -33,6 +33,7 @@
#include "ngraph/ops/replace_slice.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/reverse.hpp"
#include "ngraph/ops/select_and_scatter.hpp"
#include "ngraph/ops/slice.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/runtime/call_frame.hpp"
......@@ -77,6 +78,7 @@
#include "ngraph/runtime/kernel/reshape.hpp"
#include "ngraph/runtime/kernel/reverse.hpp"
#include "ngraph/runtime/kernel/select.hpp"
#include "ngraph/runtime/kernel/select_and_scatter.hpp"
#include "ngraph/runtime/kernel/sign.hpp"
#include "ngraph/runtime/kernel/sin.hpp"
#include "ngraph/runtime/kernel/sinh.hpp"
......@@ -565,7 +567,51 @@ private:
}
else if (node_op == "SelectAndScatter")
{
// TODO: Implement this. Stubbed out because the XLA bridge folks need it.
ngraph::op::SelectAndScatter* select_and_scatter =
dynamic_cast<ngraph::op::SelectAndScatter*>(&node);
std::shared_ptr<ngraph::Function> selection_function =
select_and_scatter->get_functions()[0];
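// Wrap the nested selection Function as a scalar std::function: stage x and y
// in freshly allocated scalar tensor views, invoke the interpreter recursively
// on the nested function, and read the boolean result back out.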
std::function<bool(T, T)> f_selection = [this, &node, selection_function](T x,
T y) -> bool {
auto tx = std::make_shared<runtime::interpreter::INT_TensorView>(
node.get_inputs().at(0).get_element_type(), Shape{}, "selection_temp_x");
auto ty = std::make_shared<runtime::interpreter::INT_TensorView>(
node.get_inputs().at(1).get_element_type(), Shape{}, "selection_temp_y");
auto tr = std::make_shared<runtime::interpreter::INT_TensorView>(
element::boolean, Shape{}, "selection_temp_r");
*(reinterpret_cast<T*>(tx->get_data_ptr())) = x;
*(reinterpret_cast<T*>(ty->get_data_ptr())) = y;
call(selection_function, {tx, ty}, {tr});
return *(reinterpret_cast<char*>(tr->get_data_ptr()));
};
std::shared_ptr<ngraph::Function> scatter_function =
select_and_scatter->get_functions()[1];
std::function<T(T, T)> f_scatter = [this, &node, scatter_function](T x, T y) -> T {
auto tx = std::make_shared<runtime::interpreter::INT_TensorView>(
node.get_inputs().at(0).get_element_type(), Shape{}, "scatter_temp_x");
auto ty = std::make_shared<runtime::interpreter::INT_TensorView>(
node.get_inputs().at(1).get_element_type(), Shape{}, "scatter_temp_y");
auto tr = std::make_shared<runtime::interpreter::INT_TensorView>(
node.get_output_element_type(0), Shape{}, "scatter_temp_r");
*(reinterpret_cast<T*>(tx->get_data_ptr())) = x;
*(reinterpret_cast<T*>(ty->get_data_ptr())) = y;
call(scatter_function, {tx, ty}, {tr});
return *(reinterpret_cast<T*>(tr->get_data_ptr()));
};
kernel::select_and_scatter<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(args[2]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
args[1]->get_shape(),
out[0]->get_shape(),
f_selection,
f_scatter,
select_and_scatter->get_window_shape(),
select_and_scatter->get_window_movement_strides());
}
else if (node_op == "Sign")
{
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
#include <functional>
#include "ngraph/common.hpp"
#include "ngraph/coordinate_transform.hpp"
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void select_and_scatter(T* arg_selectee,
T* arg_source,
T* arg_init,
T* out,
const Shape& arg_selectee_shape,
const Shape& arg_source_shape,
const Shape& out_shape,
std::function<char(T, T)> selection_function,
std::function<T(T, T)> scatter_function,
const Shape& window_shape,
const Strides& window_movement_strides)
{
// First write every element of the output with the supplied initial value.
CoordinateTransform output_transform(out_shape);
for (const Coordinate& out_coord : output_transform)
{
out[output_transform.index(out_coord)] = *arg_init;
}
// Slide the window over selectee/output.
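// A window may start anywhere its far corner stays in bounds, i.e. at
// coordinates 0 .. arg_selectee_shape[i] - window_shape[i] along each axis,
// hence the "+ 1" in the (exclusive) end coordinates below.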
Shape window_start_corner_transform_start(arg_selectee_shape.size(), 0);
Shape window_start_corner_transform_end(arg_selectee_shape.size());
for (size_t i = 0; i < arg_selectee_shape.size(); i++)
{
window_start_corner_transform_end[i] =
arg_selectee_shape[i] - window_shape[i] + 1;
}
CoordinateTransform window_start_corner_transform(
arg_selectee_shape,
window_start_corner_transform_start,
window_start_corner_transform_end,
window_movement_strides);
CoordinateTransform source_transform(arg_source_shape);
CoordinateTransform::Iterator source_it = source_transform.begin();
for (Coordinate window_start_coord : window_start_corner_transform)
{
// We need a physical rather than virtual coordinate to start the window.
window_start_coord =
window_start_corner_transform.to_source_coordinate(window_start_coord);
Shape window_transform_end(arg_selectee_shape.size());
for (size_t i = 0; i < arg_selectee_shape.size(); i++)
{
window_transform_end[i] = window_start_coord[i] + window_shape[i];
}
CoordinateTransform window_transform(
arg_selectee_shape, window_start_coord, window_transform_end);
bool first_val = true;
Coordinate winner_coord;
// This initial value is ignored; it's just here so the compiler knows
// for sure that winner_val is initialized.
T winner_val = 0;
for (const Coordinate& challenger_coord : window_transform)
{
T challenger_val = arg_selectee[window_transform.index(challenger_coord)];
if (first_val || selection_function(challenger_val, winner_val))
{
winner_coord = challenger_coord;
winner_val = challenger_val;
first_val = false;
}
}
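// Both the window-start iteration above and source_it advance in row-major
// order, so the k-th window position pairs with the k-th source element.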
Coordinate source_coord = *source_it;
T old_output_val = out[window_transform.index(winner_coord)];
T source_val = arg_source[source_transform.index(source_coord)];
T new_output_val = scatter_function(old_output_val, source_val);
out[window_transform.index(winner_coord)] = new_output_val;
++source_it;
}
}
}
}
}
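A hypothetical direct invocation of the kernel above, mirroring the select_and_scatter_with_overlap unit test below; it assumes only that this header and its nGraph dependencies are available to build against:

#include <cstddef>
#include <iostream>
#include <vector>
#include "ngraph/runtime/kernel/select_and_scatter.hpp"

int main()
{
    using namespace ngraph;
    // Same data as the select_and_scatter_with_overlap unit test below.
    std::vector<float> selectee{7, 2, 5, 3, 8,
                                3, 8, 9, 3, 4,
                                1, 5, 7, 5, 6,
                                0, 6, 2, 10, 2};
    std::vector<float> source{2, 6, 3, 1};
    float init = 0;
    std::vector<float> out(selectee.size());
    runtime::kernel::select_and_scatter<float>(
        selectee.data(), source.data(), &init, out.data(),
        Shape{4, 5}, Shape{2, 2}, Shape{4, 5},
        [](float x, float y) -> char { return x > y; },  // selection: greater
        [](float x, float y) -> float { return x + y; }, // scatter: add
        Shape{2, 3}, Strides{2, 2});
    // Prints: 0 0 0 0 0 / 0 0 8 0 0 / 0 0 3 0 0 / 0 0 0 1 0
    for (std::size_t i = 0; i < 4; ++i)
    {
        for (std::size_t j = 0; j < 5; ++j)
        {
            std::cout << out[i * 5 + j] << " ";
        }
        std::cout << "\n";
    }
}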
......@@ -5162,3 +5162,175 @@ TEST(${BACKEND_NAME}, reduce_window_emulating_max_pool_2d_1channel_1image_stride
EXPECT_EQ((test::NDArray<float, 4>({{{{3, 2, 2}, {2, 2, 3}, {2, 2, 2}}}}).get_vector()),
result->get_vector<float>());
}
//
// From the XLA docs: https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter
//
TEST(${BACKEND_NAME}, select_and_scatter_with_overlap)
{
auto shape_sel_a = Shape{};
auto SEL_A = make_shared<op::Parameter>(element::f32, shape_sel_a);
auto shape_sel_b = Shape{};
auto SEL_B = make_shared<op::Parameter>(element::f32, shape_sel_b);
auto sel_f =
make_shared<Function>(make_shared<op::Greater>(SEL_A, SEL_B), op::Parameters{SEL_A, SEL_B});
auto shape_scatter_a = Shape{};
auto SCATTER_A = make_shared<op::Parameter>(element::f32, shape_scatter_a);
auto shape_scatter_b = Shape{};
auto SCATTER_B = make_shared<op::Parameter>(element::f32, shape_scatter_b);
auto scatter_f =
make_shared<Function>(SCATTER_A + SCATTER_B, op::Parameters{SCATTER_A, SCATTER_B});
auto shape_a = Shape{4, 5};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto shape_b = Shape{2, 2};
auto B = make_shared<op::Parameter>(element::f32, shape_b);
auto shape_c = Shape{};
auto C = make_shared<op::Parameter>(element::f32, shape_c);
auto shape_r = Shape{4, 5};
auto window_shape = Shape{2, 3};
auto window_strides = Strides{2, 2};
auto f = make_shared<Function>(
make_shared<op::SelectAndScatter>(A, B, C, sel_f, scatter_f, window_shape, window_strides),
op::Parameters{A, B, C});
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = backend->make_primary_tensor_view(element::f32, shape_a);
copy_data(a,
test::NDArray<float, 2>(
{{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}, {1, 5, 7, 5, 6}, {0, 6, 2, 10, 2}})
.get_vector());
auto b = backend->make_primary_tensor_view(element::f32, shape_b);
copy_data(b, test::NDArray<float, 2>({{2, 6}, {3, 1}}).get_vector());
auto c = backend->make_primary_tensor_view(element::f32, shape_c);
copy_data(c, vector<float>{0});
auto result = backend->make_primary_tensor_view(element::f32, shape_r);
cf->call({a, b, c}, {result});
EXPECT_EQ((test::NDArray<float, 2>(
{{0, 0, 0, 0, 0}, {0, 0, 8, 0, 0}, {0, 0, 3, 0, 0}, {0, 0, 0, 1, 0}})
.get_vector()),
result->get_vector<float>());
}
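//
// (Trace for the test above: the window start corners are (0,0), (0,2), (2,0),
// and (2,2). The windows at (0,0) and (0,2) both select the 9 at (1,2), which
// therefore accumulates 2 + 6 = 8; the window at (2,0) selects the 7 at (2,2),
// receiving 3; and the window at (2,2) selects the 10 at (3,3), receiving 1.)
//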
//
// From the XLA docs: https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter
//
TEST(${BACKEND_NAME}, select_and_scatter_without_overlap)
{
auto shape_sel_a = Shape{};
auto SEL_A = make_shared<op::Parameter>(element::f32, shape_sel_a);
auto shape_sel_b = Shape{};
auto SEL_B = make_shared<op::Parameter>(element::f32, shape_sel_b);
auto sel_f =
make_shared<Function>(make_shared<op::Greater>(SEL_A, SEL_B), op::Parameters{SEL_A, SEL_B});
auto shape_scatter_a = Shape{};
auto SCATTER_A = make_shared<op::Parameter>(element::f32, shape_scatter_a);
auto shape_scatter_b = Shape{};
auto SCATTER_B = make_shared<op::Parameter>(element::f32, shape_scatter_b);
auto scatter_f =
make_shared<Function>(SCATTER_A + SCATTER_B, op::Parameters{SCATTER_A, SCATTER_B});
auto shape_a = Shape{4, 6};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto shape_b = Shape{2, 2};
auto B = make_shared<op::Parameter>(element::f32, shape_b);
auto shape_c = Shape{};
auto C = make_shared<op::Parameter>(element::f32, shape_c);
auto shape_r = Shape{4, 6};
auto window_shape = Shape{2, 3};
auto window_strides = Strides{2, 3};
auto f = make_shared<Function>(
make_shared<op::SelectAndScatter>(A, B, C, sel_f, scatter_f, window_shape, window_strides),
op::Parameters{A, B, C});
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = backend->make_primary_tensor_view(element::f32, shape_a);
copy_data(a,
test::NDArray<float, 2>(
{{7, 2, 5, 3, 10, 2}, {3, 8, 9, 3, 4, 2}, {1, 5, 7, 5, 6, 1}, {0, 6, 2, 7, 2, 8}})
.get_vector());
auto b = backend->make_primary_tensor_view(element::f32, shape_b);
copy_data(b, test::NDArray<float, 2>({{2, 6}, {3, 1}}).get_vector());
auto c = backend->make_primary_tensor_view(element::f32, shape_c);
copy_data(c, vector<float>{0});
auto result = backend->make_primary_tensor_view(element::f32, shape_r);
cf->call({a, b, c}, {result});
EXPECT_EQ((test::NDArray<float, 2>(
{{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}, {0, 0, 3, 0, 0, 0}, {0, 0, 0, 0, 0, 1}})
.get_vector()),
result->get_vector<float>());
}
//
// Adapted from the XLA docs to provide an example in >2D: https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter
//
TEST(${BACKEND_NAME}, select_and_scatter_3d_without_overlap)
{
auto shape_sel_a = Shape{};
auto SEL_A = make_shared<op::Parameter>(element::f32, shape_sel_a);
auto shape_sel_b = Shape{};
auto SEL_B = make_shared<op::Parameter>(element::f32, shape_sel_b);
auto sel_f =
make_shared<Function>(make_shared<op::Greater>(SEL_A, SEL_B), op::Parameters{SEL_A, SEL_B});
auto shape_scatter_a = Shape{};
auto SCATTER_A = make_shared<op::Parameter>(element::f32, shape_scatter_a);
auto shape_scatter_b = Shape{};
auto SCATTER_B = make_shared<op::Parameter>(element::f32, shape_scatter_b);
auto scatter_f =
make_shared<Function>(SCATTER_A + SCATTER_B, op::Parameters{SCATTER_A, SCATTER_B});
auto shape_a = Shape{2, 4, 6};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto shape_b = Shape{1, 2, 2};
auto B = make_shared<op::Parameter>(element::f32, shape_b);
auto shape_c = Shape{};
auto C = make_shared<op::Parameter>(element::f32, shape_c);
auto shape_r = Shape{2, 4, 6};
auto window_shape = Shape{2, 2, 3};
auto window_strides = Strides{2, 2, 3};
auto f = make_shared<Function>(
make_shared<op::SelectAndScatter>(A, B, C, sel_f, scatter_f, window_shape, window_strides),
op::Parameters{A, B, C});
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = backend->make_primary_tensor_view(element::f32, shape_a);
copy_data(
a,
test::NDArray<float, 3>(
{{{7, 2, 5, 3, 10, 2}, {3, 8, 9, 3, 4, 2}, {1, 5, 7, 5, 6, 1}, {0, 6, 2, 7, 2, 8}},
{{2, 5, 8, 3, 4, 2}, {1, 2, 8, 4, 5, 2}, {10, 2, 3, 4, 1, 0}, {4, 1, 2, 4, 5, 7}}})
.get_vector());
auto b = backend->make_primary_tensor_view(element::f32, shape_b);
copy_data(b, test::NDArray<float, 3>({{{2, 6}, {3, 1}}}).get_vector());
auto c = backend->make_primary_tensor_view(element::f32, shape_c);
copy_data(c, vector<float>{0});
auto result = backend->make_primary_tensor_view(element::f32, shape_r);
cf->call({a, b, c}, {result});
EXPECT_EQ(
(test::NDArray<float, 3>(
{{{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}, {0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 1}},
{{0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0}, {3, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0}}})
.get_vector()),
result->get_vector<float>());
}
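//
// (Trace for the 3D test above: the four window start corners are (0,0,0),
// (0,0,3), (0,2,0), and (0,2,3); they select the 9 at (0,1,2), the 10 at
// (0,0,4), the 10 at (1,2,0), and the 8 at (0,3,5), which receive 2, 6, 3,
// and 1 respectively.)
//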