Unverified Commit c5ffe8e9 authored by Adam Procter's avatar Adam Procter Committed by GitHub

Implement reduce-window in interpreter and CPU (#359)

parent 7b1dc3e3
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "ngraph/ops/max_pool.hpp" #include "ngraph/ops/max_pool.hpp"
#include "ngraph/ops/one_hot.hpp" #include "ngraph/ops/one_hot.hpp"
#include "ngraph/ops/reduce.hpp" #include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/reduce_window.hpp"
#include "ngraph/ops/replace_slice.hpp" #include "ngraph/ops/replace_slice.hpp"
#include "ngraph/ops/reshape.hpp" #include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/reverse.hpp" #include "ngraph/ops/reverse.hpp"
...@@ -1722,6 +1723,41 @@ void runtime::cpu::CPU_Emitter::EmitReverse(const ngraph::Node* n, ...@@ -1722,6 +1723,41 @@ void runtime::cpu::CPU_Emitter::EmitReverse(const ngraph::Node* n,
m_out << " {" << join(reverse->get_reversed_axes()) << "});\n"; m_out << " {" << join(reverse->get_reversed_axes()) << "});\n";
} }
void runtime::cpu::CPU_Emitter::EmitReduceWindow(
const ngraph::Node* n,
const vector<runtime::cpu::TensorViewWrapper>& args,
const vector<runtime::cpu::TensorViewWrapper>& out)
{
auto reduce_window = static_cast<const op::ReduceWindow*>(n);
auto arg_reductee_shape = args[0].get_shape();
auto result_shape = out[0].get_shape();
auto reduction_function = reduce_window->get_functions()[0];
auto& f_result_element_type = out[0].get_element_type();
string type = f_result_element_type.c_type_string();
m_out << "auto f = [](" << type << " x, " << type << " y) -> " << type << "\n{";
m_out.indent++;
m_out << "\n";
m_out << type << " result;\n";
m_out << "void* args[] = {&x, &y};\n";
m_out << "void* out[] = {&result};\n";
m_out << reduction_function->get_name() << "(args, out);\n";
m_out << "return result;\n";
m_out.indent--;
m_out << "};\n";
m_out << "kernel::reduce_window<" << out[0].get_type() << ">(" << args[0].get_name() << ",\n";
m_out << " " << args[1].get_name() << ",\n";
m_out << " " << out[0].get_name() << ",\n";
m_out << " {" << join(arg_reductee_shape) << "},\n";
m_out << " {" << join(result_shape) << "},\n";
m_out << " f,\n";
m_out << " {" << join(reduce_window->get_window_shape()) << "},\n";
m_out << " {" << join(reduce_window->get_window_movement_strides())
<< "});\n";
}
//------------------------------------------------------------------------------------------------ //------------------------------------------------------------------------------------------------
// Utility methods // Utility methods
//------------------------------------------------------------------------------------------------ //------------------------------------------------------------------------------------------------
......
...@@ -98,6 +98,7 @@ namespace ngraph ...@@ -98,6 +98,7 @@ namespace ngraph
void EMITTER_DECL(EmitNot); void EMITTER_DECL(EmitNot);
void EMITTER_DECL(EmitMaxPool); void EMITTER_DECL(EmitMaxPool);
void EMITTER_DECL(EmitReverse); void EMITTER_DECL(EmitReverse);
void EMITTER_DECL(EmitReduceWindow);
private: private:
void generate_call(const std::vector<TensorViewWrapper>& args, void generate_call(const std::vector<TensorViewWrapper>& args,
......
...@@ -66,6 +66,7 @@ ...@@ -66,6 +66,7 @@
#include "ngraph/ops/one_hot.hpp" #include "ngraph/ops/one_hot.hpp"
#include "ngraph/ops/power.hpp" #include "ngraph/ops/power.hpp"
#include "ngraph/ops/reduce.hpp" #include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/reduce_window.hpp"
#include "ngraph/ops/replace_slice.hpp" #include "ngraph/ops/replace_slice.hpp"
#include "ngraph/ops/reshape.hpp" #include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/reverse.hpp" #include "ngraph/ops/reverse.hpp"
...@@ -185,6 +186,7 @@ static const runtime::cpu::OpMap dispatcher{ ...@@ -185,6 +186,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Not), &runtime::cpu::CPU_Emitter::EmitNot}, {TI(ngraph::op::Not), &runtime::cpu::CPU_Emitter::EmitNot},
{TI(ngraph::op::MaxPool), &runtime::cpu::CPU_Emitter::EmitMaxPool}, {TI(ngraph::op::MaxPool), &runtime::cpu::CPU_Emitter::EmitMaxPool},
{TI(ngraph::op::Reverse), &runtime::cpu::CPU_Emitter::EmitReverse}, {TI(ngraph::op::Reverse), &runtime::cpu::CPU_Emitter::EmitReverse},
{TI(ngraph::op::ReduceWindow), &runtime::cpu::CPU_Emitter::EmitReduceWindow},
}; };
runtime::cpu::CPU_ExternalFunction::CPU_ExternalFunction( runtime::cpu::CPU_ExternalFunction::CPU_ExternalFunction(
...@@ -236,6 +238,7 @@ void runtime::cpu::CPU_ExternalFunction::compile() ...@@ -236,6 +238,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
#include "ngraph/runtime/kernel/not.hpp" #include "ngraph/runtime/kernel/not.hpp"
#include "ngraph/runtime/kernel/one_hot.hpp" #include "ngraph/runtime/kernel/one_hot.hpp"
#include "ngraph/runtime/kernel/reduce.hpp" #include "ngraph/runtime/kernel/reduce.hpp"
#include "ngraph/runtime/kernel/reduce_window.hpp"
#include "ngraph/runtime/kernel/replace_slice.hpp" #include "ngraph/runtime/kernel/replace_slice.hpp"
#include "ngraph/runtime/kernel/reverse.hpp" #include "ngraph/runtime/kernel/reverse.hpp"
#include "ngraph/runtime/kernel/slice.hpp" #include "ngraph/runtime/kernel/slice.hpp"
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "ngraph/ops/max_pool.hpp" #include "ngraph/ops/max_pool.hpp"
#include "ngraph/ops/one_hot.hpp" #include "ngraph/ops/one_hot.hpp"
#include "ngraph/ops/reduce.hpp" #include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/reduce_window.hpp"
#include "ngraph/ops/replace_slice.hpp" #include "ngraph/ops/replace_slice.hpp"
#include "ngraph/ops/reshape.hpp" #include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/reverse.hpp" #include "ngraph/ops/reverse.hpp"
...@@ -71,6 +72,7 @@ ...@@ -71,6 +72,7 @@
#include "ngraph/runtime/kernel/one_hot.hpp" #include "ngraph/runtime/kernel/one_hot.hpp"
#include "ngraph/runtime/kernel/power.hpp" #include "ngraph/runtime/kernel/power.hpp"
#include "ngraph/runtime/kernel/reduce.hpp" #include "ngraph/runtime/kernel/reduce.hpp"
#include "ngraph/runtime/kernel/reduce_window.hpp"
#include "ngraph/runtime/kernel/replace_slice.hpp" #include "ngraph/runtime/kernel/replace_slice.hpp"
#include "ngraph/runtime/kernel/reshape.hpp" #include "ngraph/runtime/kernel/reshape.hpp"
#include "ngraph/runtime/kernel/reverse.hpp" #include "ngraph/runtime/kernel/reverse.hpp"
...@@ -485,7 +487,32 @@ private: ...@@ -485,7 +487,32 @@ private:
} }
else if (node_op == "ReduceWindow") else if (node_op == "ReduceWindow")
{ {
// TODO: Implement this. Stubbed out for because XLA bridge folks need it. ngraph::op::ReduceWindow* reduce_window =
dynamic_cast<ngraph::op::ReduceWindow*>(&node);
std::shared_ptr<ngraph::Function> reduction_function =
reduce_window->get_functions()[0];
std::function<T(T, T)> f = [this, &node, reduction_function](T x, T y) -> T {
auto tx = std::make_shared<runtime::interpreter::INT_TensorView>(
node.get_inputs().at(0).get_element_type(), Shape{}, "reduce_window_temp_x");
auto ty = std::make_shared<runtime::interpreter::INT_TensorView>(
node.get_inputs().at(1).get_element_type(), Shape{}, "reduce_window_temp_y");
auto tr = std::make_shared<runtime::interpreter::INT_TensorView>(
node.get_output_element_type(0), Shape{}, "reduce_window_temp_r");
*(reinterpret_cast<T*>(tx->get_data_ptr())) = x;
*(reinterpret_cast<T*>(ty->get_data_ptr())) = y;
call(reduction_function, {tx, ty}, {tr});
return *(reinterpret_cast<T*>(tr->get_data_ptr()));
};
kernel::reduce_window(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(args[1]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
node.get_inputs().at(0).get_shape(),
node.get_output_shape(0),
f,
reduce_window->get_window_shape(),
reduce_window->get_window_movement_strides());
} }
// else if (node_op == "Remainder") // else if (node_op == "Remainder")
// { // {
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
#include "ngraph/common.hpp"
#include "ngraph/coordinate_transform.hpp"
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void reduce_window(T* arg_reductee,
T* arg_init,
T* out,
const Shape& arg_reductee_shape,
const Shape& out_shape,
std::function<T(T, T)> reduction_function,
const Shape& window_shape,
const Strides& window_movement_strides)
{
// At the outermost level we will walk over every output coordinate O.
CoordinateTransform output_transform(out_shape);
for (const Coordinate& out_coord : output_transform)
{
// Our output coordinate O will have the form:
//
// (i_1,...,i_n)
//
// For the reductee we need to iterate the coordinate:
//
// I:
//
// over the range (noninclusive on the right):
//
// (s_1*i_1,s_2*i_2,...,s_n*i_n) ->
//
// (s_1*i_1 + window_shape_1,...,s_n*i_n + window_shape_n)
//
// with unit stride.
Shape reductee_transform_start;
Shape reductee_transform_end;
for (size_t i = 0; i < arg_reductee_shape.size(); i++)
{
size_t window_shape_this_dim = window_shape[i];
size_t movement_stride = window_movement_strides[i];
reductee_transform_start.push_back(movement_stride * out_coord[i]);
reductee_transform_end.push_back(reductee_transform_start[i] +
window_shape_this_dim);
}
CoordinateTransform reductee_transform(
arg_reductee_shape, reductee_transform_start, reductee_transform_end);
// As we go, we compute the reduced value:
//
// output[O] := reduction_function(output[O],arg[I])
T result = *arg_init;
for (const Coordinate& reductee_coord : reductee_transform)
{
result = reduction_function(
result, arg_reductee[reductee_transform.index(reductee_coord)]);
}
out[output_transform.index(out_coord)] = result;
}
}
}
}
}
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment