Unverified commit 524d04fc, authored by Adam Procter, committed by GitHub

Definitions of XLA ConvNet MNIST ops (#324)

parent 686ee9ab
@@ -59,9 +59,12 @@ set (SRC
ops/parameter.cpp
ops/power.cpp
ops/reduce.cpp
+ ops/reduce_window.cpp
ops/replace_slice.cpp
ops/reshape.cpp
+ ops/reverse.cpp
ops/select.cpp
+ ops/select_and_scatter.cpp
ops/sin.cpp
ops/sinh.cpp
ops/slice.cpp
@@ -85,8 +85,7 @@ void ngraph::traverse_functions(std::shared_ptr<ngraph::Function> p,
stack.pop_front();
for (shared_ptr<Node> op : func->get_ops())
{
- shared_ptr<Function> fp = op->get_function();
- if (fp)
+ for (shared_ptr<Function> fp : op->get_functions())
{
stack.push_front(fp);
}
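The hunk above replaces the single-function accessor with a list-valued one, so ops that own more than one nested function (such as the new SelectAndScatter, which carries a selection function and a scatter function) can be traversed uniformly. A minimal sketch of the resulting pattern; collect_functions is a hypothetical helper, not part of this commit:

#include <deque>
#include <memory>
#include <set>
#include "ngraph/function.hpp"

// Hypothetical helper: gather every Function reachable from `root`,
// including functions nested inside ops such as FunctionCall, Reduce,
// ReduceWindow, and SelectAndScatter.
void collect_functions(const std::shared_ptr<ngraph::Function>& root,
                       std::set<std::shared_ptr<ngraph::Function>>& result)
{
    std::deque<std::shared_ptr<ngraph::Function>> stack{root};
    while (!stack.empty())
    {
        auto func = stack.front();
        stack.pop_front();
        if (!result.insert(func).second)
        {
            continue; // already visited
        }
        for (const auto& op : func->get_ops())
        {
            // get_functions() returns zero, one, or two functions depending
            // on the op; the loop handles all cases uniformly.
            for (const auto& fp : op->get_functions())
            {
                stack.push_front(fp);
            }
        }
    }
}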
@@ -93,10 +93,13 @@
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/power.hpp"
#include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/reduce_window.hpp"
#include "ngraph/ops/remainder.hpp"
#include "ngraph/ops/replace_slice.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/reverse.hpp"
#include "ngraph/ops/select.hpp"
#include "ngraph/ops/select_and_scatter.hpp"
#include "ngraph/ops/sign.hpp"
#include "ngraph/ops/sin.hpp"
#include "ngraph/ops/sinh.hpp"
@@ -214,9 +214,9 @@ std::shared_ptr<Node> Node::backprop_node(const std::shared_ptr<Node>& x,
return adjoints_it->second.get(x);
}
- std::shared_ptr<Function> Node::get_function() const
+ std::vector<std::shared_ptr<Function>> Node::get_functions() const
{
-     return nullptr;
+     return std::vector<std::shared_ptr<Function>>{};
}
namespace ngraph
@@ -155,7 +155,7 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const std::vector<std::shared_ptr<Node>>& new_args) const = 0;
- virtual std::shared_ptr<Function> get_function() const;
+ virtual std::vector<std::shared_ptr<Function>> get_functions() const;
// True if this and node have one output with same element type and shape
bool has_same_type(std::shared_ptr<const Node> node) const;
@@ -55,8 +55,12 @@ namespace ngraph
return std::make_shared<FunctionCall>(m_function, new_args);
}
- /// \return The function to be called.
- std::shared_ptr<Function> get_function() const override { return m_function; }
+ /// \return A singleton vector containing the function to be called.
+ std::vector<std::shared_ptr<Function>> get_functions() const override
+ {
+     return std::vector<std::shared_ptr<Function>>{m_function};
+ }
protected:
std::shared_ptr<Function> m_function;
};
@@ -81,7 +81,11 @@ op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee,
}
if (m_reduction_function->get_output_element_type(0) != arg_init->get_element_type())
{
throw ngraph_error("Return type from reduction function does not match expected");
throw ngraph_error("Return element type from reduction function does not match expected");
}
+ if (m_reduction_function->get_output_shape(0) != Shape{})
+ {
+     throw ngraph_error("Return shape from reduction function is not a scalar");
+ }
add_output(input_reductee.get_element_type(), result_shape);
@@ -103,8 +103,11 @@ namespace ngraph
new_args.at(0), new_args.at(1), m_reduction_function, m_reduction_axes);
}
- /// \return The function to use for reduction.
- std::shared_ptr<Function> get_function() const override { return m_reduction_function; }
+ /// \return A one-element vector containing the function to use for reduction.
+ std::vector<std::shared_ptr<Function>> get_functions() const override
+ {
+     return std::vector<std::shared_ptr<Function>>{m_reduction_function};
+ }
/// \return The axis positions (0-based) to be eliminated through reduction.
const AxisSet& get_reduction_axes() const { return m_reduction_axes; }
protected:
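Together with the checks added in reduce.cpp above, a Reduce node must now be built with a scalar-to-scalar reduction function of the input element type. A hedged construction sketch (the shapes and the use of op::Add are illustrative assumptions, not taken from this commit):

// The reduction function must map (scalar, scalar) -> scalar of the input
// element type, which is exactly what the new checks enforce.
auto f_a = std::make_shared<op::Parameter>(element::f32, Shape{});
auto f_b = std::make_shared<op::Parameter>(element::f32, Shape{});
auto sum = std::make_shared<Function>(
    std::make_shared<op::Add>(f_a, f_b),
    std::vector<std::shared_ptr<op::Parameter>>{f_a, f_b});

auto arg = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
auto init = std::make_shared<op::Parameter>(element::f32, Shape{}); // scalar init value
// Reduce over axis 0: result shape is {3}.
auto reduce = std::make_shared<op::Reduce>(arg, init, sum, AxisSet{0});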
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#include "ngraph/ops/reduce_window.hpp"
#include "ngraph/function.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
op::ReduceWindow::ReduceWindow(const std::shared_ptr<Node>& arg_reductee,
const std::shared_ptr<Node>& arg_init,
const std::shared_ptr<Function>& reduction_function,
const Shape& window_shape,
const Strides& window_movement_strides)
: RequiresTensorViewArgs("ReduceWindow", {arg_reductee, arg_init})
, m_reduction_function(reduction_function)
, m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides)
{
auto& input_reductee = get_inputs().at(0);
auto& input_init = get_inputs().at(1);
auto input_reductee_shape = input_reductee.get_shape();
auto input_init_shape = input_init.get_shape();
if (input_init.get_shape().size() != 0)
{
throw ngraph_error("Argument for initial value is not a scalar");
}
if (input_init.get_element_type() != input_reductee.get_element_type())
{
throw ngraph_error("Element types for reductee and initial values do not match");
}
if (input_reductee_shape.size() != window_shape.size())
{
throw ngraph_error("Window shape has different rank from input tensor");
}
if (input_reductee_shape.size() != window_movement_strides.size())
{
throw ngraph_error("Window movement strides have different rank from input tensor");
}
for (size_t s : window_shape)
{
if (s == 0)
{
throw ngraph_error("Window shape has a zero-length axis");
}
}
for (size_t s : window_movement_strides)
{
if (s == 0)
{
throw ngraph_error("Window movement stride for some axis is zero");
}
}
for (size_t i = 0; i < input_reductee_shape.size(); i++)
{
if (window_shape[i] > input_reductee_shape[i])
{
throw ngraph_error("Reduction window is bigger than input");
}
}
auto f_params = m_reduction_function->get_parameters();
if (f_params.size() != 2)
{
throw ngraph_error("Reduction function has wrong number of parameters (should be two)");
}
if (f_params.at(0)->get_element_type() != arg_init->get_element_type())
{
throw ngraph_error("Parameter 0 of reduction function has wrong element type");
}
if (f_params.at(1)->get_element_type() != arg_init->get_element_type())
{
throw ngraph_error("Parameter 1 of reduction function has wrong element type");
}
if (f_params.at(0)->get_shape() != Shape{})
{
throw ngraph_error("Parameter 0 of reduction function is not a scalar");
}
if (f_params.at(1)->get_shape() != Shape{})
{
throw ngraph_error("Parameter 1 of reduction function is not a scalar");
}
if (m_reduction_function->get_output_size() > 1)
{
throw ngraph_error("Single-output reduction function was expected");
}
if (m_reduction_function->get_output_element_type(0) != arg_init->get_element_type())
{
throw ngraph_error("Return element type from reduction function does not match expected");
}
if (m_reduction_function->get_output_shape(0) != Shape{})
{
throw ngraph_error("Return shape from reduction function is not a scalar");
}
Shape result_shape;
for (size_t i = 0; i < input_reductee_shape.size(); i++)
{
result_shape.push_back(
ceil_div(input_reductee_shape[i] - window_shape[i] + 1, window_movement_strides[i]));
}
set_value_type_checked(input_reductee.get_element_type(), result_shape);
}
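A usage sketch for the constructor just defined; the shapes are illustrative, and op::Maximum is assumed available as the scalar reduction function (giving max-pool-like behavior):

// Scalar max function: (f32[], f32[]) -> f32[].
auto f_a = std::make_shared<op::Parameter>(element::f32, Shape{});
auto f_b = std::make_shared<op::Parameter>(element::f32, Shape{});
auto max_f = std::make_shared<Function>(
    std::make_shared<op::Maximum>(f_a, f_b),
    std::vector<std::shared_ptr<op::Parameter>>{f_a, f_b});

auto arg = std::make_shared<op::Parameter>(element::f32, Shape{4, 4});
auto init = std::make_shared<op::Parameter>(element::f32, Shape{}); // scalar init
// Window {2,2} with strides {1,1}: output shape is {ceil(3/1), ceil(3/1)} = {3, 3}.
auto rw = std::make_shared<op::ReduceWindow>(arg, init, max_f, Shape{2, 2}, Strides{1, 1});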
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Windowed reduction operation.
///
/// Slides a window of user-defined shape, with user-defined strides, over the tensor and produces for each window position the result obtained by
/// reducing the values in the window to a scalar, using the user-supplied reduction function.
///
/// Given an input of shape \f$(d_1,\dots,d_n)\f$, a window shape of \f$(w_1,\dots,w_n)\f$ and window movement strides of \f$(s_1,\dots,s_n)\f$, the shape
/// of the output is \f$(d'_1,\dots,d'_n)\f$ where \f$d'_i = \lceil \frac {d_i - w_i + 1}{s_i} \rceil\f$.
///
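/// For example (an illustrative case, not part of the formal definition): with an input of shape \f$(4,4)\f$, window shape \f$(2,2)\f$, and strides \f$(1,1)\f$, each \f$d'_i = \lceil \frac{4-2+1}{1} \rceil = 3\f$, giving an output of shape \f$(3,3)\f$.
///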
/// ## Parameters
///
/// | | Description |
/// | ------------------------- | ------------------------------------------------------------------------------------------------------------------------- |
/// | `reduction_function` | The scalar function used to reduce the input tensor. Must take two arguments of type \f$E[]\f$ and return type \f$E[]\f$. |
/// | `window_shape` | The shape \f$(w_1,\dots,w_n)\f$ of the reduction window. |
/// | `window_movement_strides` | Movement strides \f$(s_1,\dots,s_n)\f$ to apply to the sliding window. |
///
/// ## Inputs
///
/// | | Type | Description |
/// | -------------- | --------------------------------- | ----------------------------------------------------------------------------------------------------- |
/// | `arg_reductee` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | An input tensor of any shape, with the element type matching that expected by the reduction function. |
/// | `arg_init` | \f$E[]\f$ | A scalar to be used as an initial value for reduction computations. |
///
/// ## Output
///
/// | Type | Description |
/// | ------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d'_1,\dots,d'_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{reduce}(\mathit{reduction\_function},\mathit{arg\_init},V)\f$ where \f$V\f$ is the set of values in the input tensor within the window defined by the lower bound \f$(s_1i_1,\dots,s_ni_n)\f$ and the noninclusive upper bound \f$(s_1i_1 + w_1,\dots,s_ni_n + w_n)\f$. |
class ReduceWindow : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a reduce-window operation.
///
/// \param arg_reductee The tensor view to be reduced.
/// \param arg_init The initial value for reduction.
/// \param reduction_function The reduction function to use.
/// \param window_shape The window shape.
/// \param window_movement_strides The window movement strides.
ReduceWindow(const std::shared_ptr<Node>& arg_reductee,
const std::shared_ptr<Node>& arg_init,
const std::shared_ptr<Function>& reduction_function,
const Shape& window_shape,
const Strides& window_movement_strides);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<ReduceWindow>(new_args.at(0),
new_args.at(1),
m_reduction_function,
m_window_shape,
m_window_movement_strides);
}
/// \return A singleton vector containing the function to use for reduction.
std::vector<std::shared_ptr<Function>> get_functions() const override
{
return std::vector<std::shared_ptr<Function>>{m_reduction_function};
}
/// \return The window shape.
const Shape& get_window_shape() const { return m_window_shape; }
/// \return The window movement strides.
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
protected:
std::shared_ptr<Function> m_reduction_function;
Shape m_window_shape;
Strides m_window_movement_strides;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#include "ngraph/ops/reverse.hpp"
#include "ngraph/function.hpp"
#include <algorithm>
#include <sstream>
using namespace std;
using namespace ngraph;
op::Reverse::Reverse(const std::shared_ptr<Node>& arg, const AxisSet& reversed_axes)
: RequiresTensorViewArgs("Reverse", {arg})
, m_reversed_axes(reversed_axes)
{
auto& input = get_inputs().at(0);
auto input_shape = input.get_shape();
auto input_rank = input_shape.size();
// Make sure all reversed axis indices are valid.
for (size_t axis : reversed_axes)
{
if (axis >= input_rank)
{
std::stringstream ss;
ss << "Reverse axis " << axis << " is out of bounds (input rank is " << input_rank
<< ").";
throw ngraph_error(ss.str());
}
}
set_value_type_checked(input.get_element_type(), input_shape);
}
void op::Reverse::generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta)
{
auto x = get_input_op(0);
adjoints.add_delta(x, make_shared<op::Reverse>(delta, m_reversed_axes));
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Axis-reverse operation.
///
/// Reverses the direction of zero or more axes in a tensor, where "reversing" an axis means that, at the output tensor, the coordinates along that axis appear in reverse order relative to the input.
///
/// ## Parameters
///
/// | | Description |
/// | --------------- | ------------------------ |
/// | `reversed_axes` | The axes to be reversed. |
///
/// ## Inputs
///
/// | | Type | Description |
/// | ----- | --------------------------------- | -------------------------------------- |
/// | `arg` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | An input tensor of any type and shape. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg}[j_1,\dots,j_n]\f$ and \f$j_k = d_k - i_k - 1\f$ if axis \f$k\f$ is in the reverse set; else \f$j_k = i_k\f$. |
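///
/// For example (illustrative): reversing axis 1 of the \f$2 \times 3\f$ tensor \f$[[1,2,3],[4,5,6]]\f$ yields \f$[[3,2,1],[6,5,4]]\f$; reversing axis 0 instead yields \f$[[4,5,6],[1,2,3]]\f$.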
class Reverse : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a reverse operation.
///
/// \param arg The input tensor view, some of whose axes are to be reversed.
/// \param reversed_axes The axes to reverse.
Reverse(const std::shared_ptr<Node>& arg, const AxisSet& reversed_axes);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<Reverse>(new_args.at(0), m_reversed_axes);
}
/// \return The set of axes to reverse.
const AxisSet& get_reversed_axes() const { return m_reversed_axes; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
const AxisSet m_reversed_axes;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#include "ngraph/ops/select_and_scatter.hpp"
#include "ngraph/function.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
op::SelectAndScatter::SelectAndScatter(const std::shared_ptr<Node>& arg_selectee,
const std::shared_ptr<Node>& arg_source,
const std::shared_ptr<Node>& arg_init,
const std::shared_ptr<Function>& selection_function,
const std::shared_ptr<Function>& scatter_function,
const Shape& window_shape,
const Strides& window_movement_strides)
: RequiresTensorViewArgs("SelectAndScatter", {arg_selectee, arg_source, arg_init})
, m_selection_function(selection_function)
, m_scatter_function(scatter_function)
, m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides)
{
auto& input_selectee = get_inputs().at(0);
auto& input_source = get_inputs().at(1);
auto& input_init = get_inputs().at(2);
auto input_selectee_shape = input_selectee.get_shape();
auto input_source_shape = input_source.get_shape();
auto input_init_shape = input_init.get_shape();
auto& input_selectee_element_type = input_selectee.get_element_type();
auto& input_source_element_type = input_source.get_element_type();
auto& input_init_element_type = input_init.get_element_type();
//
// Make sure the initial value is a scalar.
//
if (input_init_shape.size() != 0)
{
throw ngraph_error("Argument for initial value is not a scalar");
}
//
// Make sure input element types all match.
//
if (input_init_element_type != input_selectee_element_type)
{
throw ngraph_error("Element types for selectee and initial values do not match");
}
if (input_source_element_type != input_selectee_element_type)
{
throw ngraph_error("Element types for selectee and source tensors do not match");
}
//
// Check that the window shape and strides have the right rank.
//
if (input_selectee_shape.size() != window_shape.size())
{
throw ngraph_error("Window shape has different rank from selectee tensor");
}
if (input_selectee_shape.size() != window_movement_strides.size())
{
throw ngraph_error("Window movement strides have different rank from selectee tensor");
}
//
// Check for zero-length window axes or strides.
//
for (size_t s : window_shape)
{
if (s == 0)
{
throw ngraph_error("Window shape has a zero-length axis");
}
}
for (size_t s : window_movement_strides)
{
if (s == 0)
{
throw ngraph_error("Window movement stride for some axis is zero");
}
}
//
// Check that the window is not bigger than the selectee tensor.
//
for (size_t i = 0; i < input_selectee_shape.size(); i++)
{
if (window_shape[i] > input_selectee_shape[i])
{
throw ngraph_error("Reduction window is bigger than selectee tensor");
}
}
//
// The expected shape of the source tensor is the same as the shape of the output
// we would get if we window-reduced the selectee; in other words, this logic is
// the same as the logic for computing the output shape of reduce-window.
//
Shape expected_source_shape;
for (size_t i = 0; i < input_selectee_shape.size(); i++)
{
expected_source_shape.push_back(
ceil_div(input_selectee_shape[i] - window_shape[i] + 1, window_movement_strides[i]));
}
if (input_source_shape != expected_source_shape)
{
throw ngraph_error("Source tensor does not have expected shape");
}
//
// Check the type signature of the selection function. Should be T -> T -> Bool.
//
auto selection_function_params = m_selection_function->get_parameters();
if (selection_function_params.size() != 2)
{
throw ngraph_error("Selection function has wrong number of parameters (should be two)");
}
if (selection_function_params.at(0)->get_element_type() != arg_init->get_element_type())
{
throw ngraph_error("Parameter 0 of selection function has wrong element type");
}
if (selection_function_params.at(1)->get_element_type() != arg_init->get_element_type())
{
throw ngraph_error("Parameter 1 of selection function has wrong element type");
}
if (selection_function_params.at(0)->get_shape() != Shape{})
{
throw ngraph_error("Parameter 0 of selection function is not a scalar");
}
if (selection_function_params.at(1)->get_shape() != Shape{})
{
throw ngraph_error("Parameter 1 of selection function is not a scalar");
}
if (m_selection_function->get_output_size() > 1)
{
throw ngraph_error("Single-output selection function was expected");
}
if (m_selection_function->get_output_element_type(0) != element::boolean)
{
throw ngraph_error("Return element type from selection function is not boolean");
}
if (m_selection_function->get_output_shape(0) != Shape{})
{
throw ngraph_error("Return shape from selection function is not a scalar");
}
//
// Check the type signature of the scatter function. Should be T -> T -> T.
//
auto scatter_function_params = m_scatter_function->get_parameters();
if (scatter_function_params.size() != 2)
{
throw ngraph_error("Scatter function has wrong number of parameters (should be two)");
}
if (scatter_function_params.at(0)->get_element_type() != arg_init->get_element_type())
{
throw ngraph_error("Parameter 0 of scatter function has wrong element type");
}
if (scatter_function_params.at(1)->get_element_type() != arg_init->get_element_type())
{
throw ngraph_error("Parameter 1 of scatter function has wrong element type");
}
if (scatter_function_params.at(0)->get_shape() != Shape{})
{
throw ngraph_error("Parameter 0 of scatter function is not a scalar");
}
if (scatter_function_params.at(1)->get_shape() != Shape{})
{
throw ngraph_error("Parameter 1 of scatter function is not a scalar");
}
if (m_scatter_function->get_output_size() > 1)
{
throw ngraph_error("Single-output scatter function was expected");
}
if (m_scatter_function->get_output_element_type(0) != arg_init->get_element_type())
{
throw ngraph_error(
"Return element type from scatter function does not match the init value type");
}
if (m_scatter_function->get_output_shape(0) != Shape{})
{
throw ngraph_error("Return shape from scatter function is not a scalar");
}
//
// Result type is the same element type and shape as the selectee.
//
set_value_type_checked(input_selectee_element_type, input_selectee_shape);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Select-and-scatter operation.
///
/// TODO: More formal definition. For now, see: https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter.
///
/// ## Parameters
///
/// | | Description |
/// | ------------------------- | --------------------------------------------------------------------------------------------------------------------------------------- |
/// | `selection_function` | The scalar function used to select between two values. Must take two arguments of type \f$E[]\f$ and return type \f$\mathit{Bool}[]\f$. |
/// | `scatter_function` | The scalar function used to apply a scattered value. Must take two arguments of type \f$E[]\f$ and return type \f$E[]\f$. |
/// | `window_shape` | The shape \f$(w_1,\dots,w_n)\f$ of the selection window. |
/// | `window_movement_strides` | Movement strides \f$(s_1,\dots,s_n)\f$ to apply to the sliding window. |
///
/// ## Inputs
///
/// | | Type | Description |
/// | -------------- | ----------------------------------- | ----------------------------------------------------------------------------------------------------- |
/// | `arg_selectee` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | An input tensor of any shape, with the element type matching that expected by the selection function. |
/// | `arg_source` | \f$E[d'_1,\dots,d'_n]~(n \geq 0)\f$ | The input tensor from which to scatter values. |
/// | `arg_init` | \f$E[]\f$ | A scalar to be used as an initial value in each output cell. |
///
/// ## Output
///
/// | Type | Description |
/// | ------------------------ | -------------------------------------- |
/// | \f$E[d'_1,\dots,d'_n]\f$ | (TODO: explain more) See the XLA docs. |
class SelectAndScatter : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a select-and-scatter operation.
///
/// \param arg_selectee The tensor view to be selected from.
/// \param arg_source The tensor to scatter values from.
/// \param arg_init The initial value for output.
/// \param selection_function The selection function.
/// \param scatter_function The scatter function.
/// \param window_shape The window shape.
/// \param window_movement_strides The window movement strides.
SelectAndScatter(const std::shared_ptr<Node>& arg_selectee,
const std::shared_ptr<Node>& arg_source,
const std::shared_ptr<Node>& arg_init,
const std::shared_ptr<Function>& selection_function,
const std::shared_ptr<Function>& scatter_function,
const Shape& window_shape,
const Strides& window_movement_strides);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 3)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<SelectAndScatter>(new_args.at(0),
new_args.at(1),
new_args.at(2),
m_selection_function,
m_scatter_function,
m_window_shape,
m_window_movement_strides);
}
/// \return A vector of length 2 containing the selection function as element 0, and the scatter function as element 1.
std::vector<std::shared_ptr<Function>> get_functions() const override
{
return {m_selection_function, m_scatter_function};
}
/// \return The window shape.
const Shape& get_window_shape() const { return m_window_shape; }
/// \return The window movement strides.
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
protected:
std::shared_ptr<Function> m_selection_function;
std::shared_ptr<Function> m_scatter_function;
Shape m_window_shape;
Strides m_window_movement_strides;
};
}
}
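A construction sketch for SelectAndScatter, following the max-pool-backprop pattern from the XLA docs (greater-than selection, addition scatter); the shapes and the use of op::Greater / op::Add here are illustrative assumptions, not taken from this commit:

// Selection function: (f32[], f32[]) -> boolean[], as required above.
auto s_a = std::make_shared<op::Parameter>(element::f32, Shape{});
auto s_b = std::make_shared<op::Parameter>(element::f32, Shape{});
auto sel_f = std::make_shared<Function>(
    std::make_shared<op::Greater>(s_a, s_b),
    std::vector<std::shared_ptr<op::Parameter>>{s_a, s_b});

// Scatter function: (f32[], f32[]) -> f32[].
auto t_a = std::make_shared<op::Parameter>(element::f32, Shape{});
auto t_b = std::make_shared<op::Parameter>(element::f32, Shape{});
auto sc_f = std::make_shared<Function>(
    std::make_shared<op::Add>(t_a, t_b),
    std::vector<std::shared_ptr<op::Parameter>>{t_a, t_b});

auto selectee = std::make_shared<op::Parameter>(element::f32, Shape{4, 4});
// The source shape must equal the reduce-window output shape for the same
// window/strides: here {3, 3}.
auto source = std::make_shared<op::Parameter>(element::f32, Shape{3, 3});
auto init = std::make_shared<op::Parameter>(element::f32, Shape{});
auto sas = std::make_shared<op::SelectAndScatter>(
    selectee, source, init, sel_f, sc_f, Shape{2, 2}, Strides{1, 1});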
@@ -33,6 +33,7 @@
#include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/replace_slice.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/reverse.hpp"
#include "ngraph/ops/slice.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/runtime/cpu/cpu_emitter.hpp"
@@ -652,7 +653,7 @@ void runtime::cpu::CPU_Emitter::EmitFunctionCall(
const vector<runtime::cpu::TensorViewWrapper>& out)
{
auto function_call = static_cast<const op::FunctionCall*>(n);
- shared_ptr<Function> function = function_call->get_function();
+ shared_ptr<Function> function = function_call->get_functions()[0];
m_out << "{ // Call " << function->get_name() << "\n";
m_out.indent++;
@@ -672,7 +673,7 @@ void runtime::cpu::CPU_Emitter::EmitReduce(const ngraph::Node* n,
const vector<runtime::cpu::TensorViewWrapper>& out)
{
auto reduce = static_cast<const op::Reduce*>(n);
- auto reduction_function = reduce->get_function();
+ auto reduction_function = reduce->get_functions()[0];
auto reductee_shape = args[0].get_shape();
@@ -1382,6 +1383,22 @@ void runtime::cpu::CPU_Emitter::EmitMaxPool(const ngraph::Node* n,
m_out << " {" << join(max_pool->get_window_movement_strides()) << "});\n";
}
+ void runtime::cpu::CPU_Emitter::EmitReverse(const ngraph::Node* n,
+                                             const vector<runtime::cpu::TensorViewWrapper>& args,
+                                             const vector<runtime::cpu::TensorViewWrapper>& out)
+ {
+     auto reverse = static_cast<const op::Reverse*>(n);
+     auto arg_shape = args[0].get_shape();
+     auto result_shape = out[0].get_shape();
+     m_out << "kernel::reverse<" << out[0].get_type() << ">(" << args[0].get_name() << ",\n";
+     m_out << " " << out[0].get_name() << ",\n";
+     m_out << " {" << join(arg_shape) << "},\n";
+     m_out << " {" << join(result_shape) << "},\n";
+     m_out << " {" << join(reverse->get_reversed_axes()) << "});\n";
+ }
//------------------------------------------------------------------------------------------------
// Utility methods
//------------------------------------------------------------------------------------------------
@@ -97,6 +97,7 @@ namespace ngraph
void EMITTER_DECL(EmitConvolution);
void EMITTER_DECL(EmitNot);
void EMITTER_DECL(EmitMaxPool);
+ void EMITTER_DECL(EmitReverse);
private:
void generate_call(const std::vector<TensorViewWrapper>& args,
@@ -67,6 +67,7 @@
#include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/replace_slice.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/reverse.hpp"
#include "ngraph/ops/select.hpp"
#include "ngraph/ops/sign.hpp"
#include "ngraph/ops/sin.hpp"
@@ -151,6 +152,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Convolution), &runtime::cpu::CPU_Emitter::EmitConvolution},
{TI(ngraph::op::Not), &runtime::cpu::CPU_Emitter::EmitNot},
{TI(ngraph::op::MaxPool), &runtime::cpu::CPU_Emitter::EmitMaxPool},
+ {TI(ngraph::op::Reverse), &runtime::cpu::CPU_Emitter::EmitReverse},
};
runtime::cpu::CPU_ExternalFunction::CPU_ExternalFunction(
@@ -204,6 +206,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
#include "ngraph/runtime/kernel/one_hot.hpp"
#include "ngraph/runtime/kernel/reduce.hpp"
#include "ngraph/runtime/kernel/replace_slice.hpp"
#include "ngraph/runtime/kernel/reverse.hpp"
#include "ngraph/runtime/kernel/slice.hpp"
#include "ngraph/runtime/kernel/sum.hpp"
#include "ngraph/util.hpp"
@@ -31,6 +31,7 @@
#include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/replace_slice.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/reverse.hpp"
#include "ngraph/ops/slice.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/runtime/call_frame.hpp"
@@ -72,6 +73,7 @@
#include "ngraph/runtime/kernel/reduce.hpp"
#include "ngraph/runtime/kernel/replace_slice.hpp"
#include "ngraph/runtime/kernel/reshape.hpp"
#include "ngraph/runtime/kernel/reverse.hpp"
#include "ngraph/runtime/kernel/select.hpp"
#include "ngraph/runtime/kernel/sign.hpp"
#include "ngraph/runtime/kernel/sin.hpp"
@@ -344,7 +346,7 @@ private:
}
else if (node_op == "FunctionCall")
{
- std::shared_ptr<Function> function = node.get_function();
+ std::shared_ptr<Function> function = node.get_functions()[0];
call(function, args, out);
}
else if (node_op == "Greater")
@@ -454,7 +456,7 @@ private:
else if (node_op == "Reduce")
{
ngraph::op::Reduce* reduce = dynamic_cast<ngraph::op::Reduce*>(&node);
- std::shared_ptr<ngraph::Function> reduction_function = reduce->get_function();
+ std::shared_ptr<ngraph::Function> reduction_function = reduce->get_functions()[0];
std::function<T(T, T)> f = [this, &node, reduction_function](T x, T y) -> T {
auto tx = std::make_shared<runtime::interpreter::INT_TensorView>(
@@ -477,6 +479,10 @@ private:
reduce->get_reduction_axes(),
f);
}
else if (node_op == "ReduceWindow")
{
// TODO: Implement this. Stubbed out for because XLA bridge folks need it.
}
// else if (node_op == "Remainder")
// {
// // node = make_shared<op::Remainder>(args[0], args[1]);
@@ -502,6 +508,15 @@ private:
reshape->get_input_order(),
out[0]->get_shape());
}
else if (node_op == "Reverse")
{
ngraph::op::Reverse* reverse = dynamic_cast<ngraph::op::Reverse*>(&node);
kernel::reverse(reinterpret_cast<T*>(args[0]->get_data_ptr()),
reinterpret_cast<T*>(out[0]->get_data_ptr()),
args[0]->get_shape(),
out[0]->get_shape(),
reverse->get_reversed_axes());
}
else if (node_op == "Select")
{
kernel::select<T>(reinterpret_cast<char*>(args[0]->get_data_ptr()),
@@ -510,6 +525,10 @@ private:
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "SelectAndScatter")
{
// TODO: Implement this. Stubbed out for because XLA bridge folks need it.
}
else if (node_op == "Sign")
{
kernel::sign<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
#include "ngraph/common.hpp"
#include "ngraph/coordinate_transform.hpp"
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void reverse(T* arg,
T* out,
const Shape& arg_shape,
const Shape& out_shape,
const AxisSet& reversed_axes)
{
// In fact arg_shape == out_shape, but we'll use both for stylistic consistency with other kernels.
CoordinateTransform arg_transform(arg_shape);
CoordinateTransform output_transform(out_shape);
for (Coordinate out_coord : output_transform)
{
Coordinate arg_coord = out_coord;
for (size_t i = 0; i < arg_coord.size(); i++)
{
if (reversed_axes.count(i) != 0)
{
arg_coord[i] = arg_shape[i] - arg_coord[i] - 1;
}
}
out[output_transform.index(out_coord)] = arg[arg_transform.index(arg_coord)];
}
}
}
}
}
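A small usage sketch of this kernel (the buffers and shapes are illustrative assumptions):

// Reverse axis 0 of a 2x3 row-major tensor: [[1,2,3],[4,5,6]] -> [[4,5,6],[1,2,3]].
float in[] = {1, 2, 3, 4, 5, 6};
float out[6];
ngraph::runtime::kernel::reverse(
    in, out, ngraph::Shape{2, 3}, ngraph::Shape{2, 3}, ngraph::AxisSet{0});
// out now holds {4, 5, 6, 1, 2, 3}.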
@@ -670,7 +670,7 @@ static json write(const Node& n)
}
else if (node_op == "FunctionCall")
{
node["function"] = n.get_function()->get_name();
node["function"] = n.get_functions()[0]->get_name();
}
else if (node_op == "GetOutputElement")
{
@@ -732,7 +732,7 @@ static json write(const Node& n)
else if (node_op == "Reduce")
{
auto tmp = dynamic_cast<const op::Reduce*>(&n);
node["function"] = tmp->get_function()->get_name();
node["function"] = tmp->get_functions()[0]->get_name();
node["reduction_axes"] = tmp->get_reduction_axes();
}
else if (node_op == "Remainder")
@@ -1110,3 +1110,20 @@ TEST(${BACKEND_NAME}, backwards_abc)
EXPECT_TRUE(
autodiff_numeric_compare<float>(manager, backend, make_graph, {x0, x1, x2}, .01f, .01f));
}
+ TEST(${BACKEND_NAME}, backwards_reverse_3d_02)
+ {
+     auto manager = runtime::Manager::get("${BACKEND_NAME}");
+     auto backend = manager->allocate_backend();
+     test::Uniform<float> rng(-1.0f, 1.0f);
+     auto shape = Shape{2, 4, 5};
+     auto x = rng.initialize(backend->make_primary_tensor_view(element::f32, shape));
+     auto make_graph = [shape]() {
+         auto X = make_shared<op::Parameter>(element::f32, shape);
+         return make_shared<Function>(make_shared<op::Reverse>(X, AxisSet{0, 2}),
+                                      std::vector<std::shared_ptr<op::Parameter>>{X});
+     };
+     EXPECT_TRUE(autodiff_numeric_compare<float>(manager, backend, make_graph, {x}, .01f, .01f));
+ }
@@ -216,7 +216,7 @@ TEST(copy, FunctionCall)
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_input_ops());
- ASSERT_TRUE(node_cast->get_function() == f);
+ ASSERT_TRUE(node_cast->get_functions()[0] == f);
}
TEST(copy, greater_eq)
@@ -309,7 +309,7 @@ TEST(copy, reduce)
ASSERT_TRUE(nullptr != new_node);
ASSERT_TRUE(new_args == new_node->get_input_ops());
- ASSERT_TRUE(f == node_cast->get_function());
+ ASSERT_TRUE(f == node_cast->get_functions()[0]);
ASSERT_TRUE(axes == node_cast->get_reduction_axes());
}