Unverified Commit 3a59991e authored by Adam Procter's avatar Adam Procter Committed by GitHub

One-hot op (#272)

parent 47607dfd
......@@ -51,6 +51,7 @@ set (SRC
ops/multiply.cpp
ops/negative.cpp
ops/not.cpp
ops/one_hot.cpp
ops/op.cpp
ops/parameter.cpp
ops/power.cpp
......
......@@ -89,6 +89,7 @@
#include "ngraph/ops/negative.hpp"
#include "ngraph/ops/not.hpp"
#include "ngraph/ops/not_equal.hpp"
#include "ngraph/ops/one_hot.hpp"
#include "ngraph/ops/op.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/power.hpp"
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ops/one_hot.hpp"
#include "ngraph/ops/sum.hpp"
using namespace std;
using namespace ngraph;
// Constructs a one-hot operation node.
//
// Validates that the requested one-hot axis lies inside the output shape, and
// that the argument's shape equals the output shape with the one-hot axis
// removed; on success the deduced output type is the argument's element type
// with the requested output shape.
op::OneHot::OneHot(const std::shared_ptr<Node>& arg, const Shape& shape, size_t one_hot_axis)
    : RequiresTensorViewArgs("OneHot", {arg})
    , m_shape(shape)
    , m_one_hot_axis(one_hot_axis)
{
    auto input_tvt = m_inputs.at(0).get_tensor_view_type();
    auto& input_element_type = input_tvt->get_element_type();

    // The one-hot axis must name an axis of the output shape.
    if (one_hot_axis >= shape.size())
    {
        throw ngraph_error("One-hot axis is out of bounds");
    }

    // Deleting the one-hot axis from the output shape must yield the input shape.
    Shape expected_input_shape(shape);
    expected_input_shape.erase(expected_input_shape.begin() + one_hot_axis);

    if (input_tvt->get_shape() != expected_input_shape)
    {
        throw ngraph_error("One-hot argument shape is not compatible with desired output shape");
    }

    set_value_type_checked(make_shared<TensorViewType>(input_element_type, shape));
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph
{
namespace op
{
/// \brief One-hot operator.
///
/// Inserts a new axis into the input tensor's shape at position `one_hot_axis`, writing 1 at
/// the position along that axis selected by each input value, and 0 everywhere else.
///
/// ## Parameters
///
/// |                | Description                                                |
/// | -------------- | ---------------------------------------------------------- |
/// | `shape`        | The desired output shape, including the new one-hot axis.  |
/// | `one_hot_axis` | The index within the output shape of the new one-hot axis. |
///
/// ## Inputs
///
/// |       | Type                                                    | Description                                 |
/// | ----- | ------------------------------------------------------- | ------------------------------------------- |
/// | `arg` | \f$E[d_1,\dots,d_{m-1},d_{m+1},\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and any element type. |
///
/// ## Output
///
/// | Type                   | Description                                                                                                                                                                                                                                                               |
/// | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T'\f$, where \f$T'[i_1,\dots,i_{m-1},i_m,i_{m+1},\dots,i_n] = 1\f$ if \f$T[i_1,\dots,i_{m-1},i_{m+1},\dots,i_n] = i_m\f$, else \f$0\f$. However, \f$T'\f$ is undefined if any non-integral value or any out-of-bounds value is detected in the input tensor. |
///
/// ## Implementation Status
///
/// | Backend | Status                                                                                                                                                 |
/// | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ |
/// | NGVM    | Fully implemented. NOTE: Execution throws `std::range_error` if either a non-integral value or an out-of-bounds value is detected in the input tensor. |
class OneHot : public RequiresTensorViewArgs
{
public:
    /// \brief Constructs a one-hot operation.
    ///
    /// \param arg Node that produces the input tensor to be one-hot encoded.
    /// \param shape The shape of the output tensor, including the new one-hot axis.
    /// \param one_hot_axis The index within the output shape of the new one-hot axis.
    OneHot(const std::shared_ptr<Node>& arg, const Shape& shape, size_t one_hot_axis);

    /// \brief Makes a copy of this op wired to a new argument.
    ///
    /// \param new_args Exactly one node, replacing the input tensor argument.
    virtual std::shared_ptr<Node> copy_with_new_args(
        const std::vector<std::shared_ptr<Node>>& new_args) const override
    {
        if (new_args.size() != 1)
        {
            throw ngraph_error("Incorrect number of new arguments");
        }
        return std::make_shared<OneHot>(new_args.at(0), m_shape, m_one_hot_axis);
    }

    /// \return The index of the one-hot axis.
    size_t get_one_hot_axis() const { return m_one_hot_axis; }
protected:
    Shape m_shape;          // Output shape, including the one-hot axis.
    size_t m_one_hot_axis;  // Index of the one-hot axis within m_shape.
};
}
}
File mode changed from 100755 to 100644
......@@ -34,7 +34,8 @@ namespace ngraph
using VectorStrides = Eigen::Stride<Eigen::Dynamic, 1>;
template <typename T>
using DynamicArray = Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>;
using DynamicArray =
Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
template <typename T>
using EigenArrayBase = Eigen::Map<DynamicArray<T>, 0, DynamicStrides>;
......
......@@ -27,6 +27,7 @@
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/function_call.hpp"
#include "ngraph/ops/get_tuple_element.hpp"
#include "ngraph/ops/one_hot.hpp"
#include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/replace_slice.hpp"
#include "ngraph/ops/reshape.hpp"
......@@ -1541,6 +1542,105 @@ void Emitter::EmitReplaceSlice(const ngraph::Node* n,
}
}
// Emits CPU code for the one-hot op.
//
// Generates source into the translation unit (TU) that zero-fills the output and
// then, for each input value, writes a 1 at that value's position along the
// one-hot axis. The generated code throws std::range_error if an input value is
// non-integral or falls outside [0, bounds). Only rank-0 (scalar) and rank-1
// (vector) inputs are supported; higher ranks throw at compile (emit) time.
void Emitter::EmitOneHot(const ngraph::Node* n,
                         const std::vector<TensorViewInfo>& inputs,
                         const std::vector<TensorViewInfo>& outputs)
{
    auto oh = static_cast<const op::OneHot*>(n);

    auto arg_type = oh->get_arguments().at(0)->get_value_type();
    auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_type);
    assert(nullptr != arg_tensor_view_type);
    auto arg_shape = arg_tensor_view_type->get_shape();
    auto arg_rank = arg_shape.size();

    auto result_type = oh->get_value_type();
    auto result_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(result_type);
    assert(nullptr != result_tensor_view_type);
    auto result_shape = result_tensor_view_type->get_shape();

    // Number of categories: input values must lie in [0, bounds).
    size_t bounds = result_shape[oh->get_one_hot_axis()];

    if (arg_rank == 0)
    {
        // Scalar input: the output is a vector with a single 1 at the input's value.
        TU << "{ // " << n->get_name() << " 1\n";
        TU.indent++;

        TU << emit_vector(outputs[0], "out_vector") << ";\n";
        TU << "out_vector.setZero();\n"
           << "auto pos_raw = " << emit_vector(inputs[0]) << "(0, 0);\n"
           << "if (std::floor(pos_raw) != pos_raw)\n"
           << "{\n";
        TU.indent++;
        TU << "throw(std::range_error(\"One-hot: non-integral value in input\"));\n";
        TU.indent--;
        TU << "}\n";

        TU << "size_t pos = pos_raw;\n"
           << "if (pos >= " << bounds << ")\n";
        TU << "{\n";
        TU.indent++;
        TU << "throw(std::range_error(\"One-hot: value is out of category range\"));\n";
        TU.indent--;
        TU << "}\n";

        TU << "out_vector(pos, 0) = 1;\n";

        TU.indent--;
        TU << "}\n";
    }
    else if (arg_rank == 1)
    {
        // Vector input: the output is a matrix; each input element selects the
        // position of the 1 along the one-hot axis of the corresponding row/column.
        TU << "{ // " << n->get_name() << " 1\n";
        TU.indent++;

        TU << emit_vector(inputs[0], "arg_vector") << ";\n";
        TU << emit_matrix(outputs[0], "out_vector") << ";\n";
        TU << "out_vector.setZero();\n";

        TU << "for (size_t i = 0; i < " << arg_shape[0] << "; i++)\n"
           << "{\n";
        TU.indent++;

        TU << "auto pos_raw = arg_vector(i, 0);\n";
        TU << "if (std::floor(pos_raw) != pos_raw)\n"
           << "{\n";
        TU.indent++;
        TU << "throw(std::range_error(\"One-hot: non-integral value in input\"));\n";
        TU.indent--;
        TU << "}\n";

        TU << "size_t pos = pos_raw;\n";
        TU << "if (pos >= " << bounds << ")\n"
           << "{\n";
        TU.indent++;
        TU << "throw(std::range_error(\"One-hot: value is out of category range\"));\n";
        TU.indent--;
        TU << "}\n";

        // The axis index determines whether the one-hot axis is the row or
        // column dimension of the output matrix.
        TU << "out_vector" << (oh->get_one_hot_axis() == 0 ? "(pos, i)" : "(i, pos)") << " = 1;\n";

        TU.indent--;
        TU << "}\n";

        TU.indent--;
        TU << "}\n";
    }
    // Other cases are not handled yet.
    else
    {
        throw ngraph_error("One-hot is not implemented yet for tensors with rank>1");
    }
}
//------------------------------------------------------------------------------------------------
// Utility methods
//------------------------------------------------------------------------------------------------
......
......@@ -95,6 +95,7 @@ namespace ngraph
void EMITTER_DECL(EmitAtan);
void EMITTER_DECL(EmitPower);
void EMITTER_DECL(EmitReplaceSlice);
void EMITTER_DECL(EmitOneHot);
private:
void generate_call(const std::vector<TensorViewInfo>& inputs,
......
......@@ -56,6 +56,7 @@
#include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/negative.hpp"
#include "ngraph/ops/not_equal.hpp"
#include "ngraph/ops/one_hot.hpp"
#include "ngraph/ops/power.hpp"
#include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/replace_slice.hpp"
......@@ -158,6 +159,7 @@ static const OpMap dispatcher{
{TI(ngraph::op::Acos), &Emitter::EmitAcos},
{TI(ngraph::op::Atan), &Emitter::EmitAtan},
{TI(ngraph::op::ReplaceSlice), &Emitter::EmitReplaceSlice},
{TI(ngraph::op::OneHot), &Emitter::EmitOneHot},
};
ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
#include "ngraph/common.hpp"
#include "ngraph/coordinate_iterator.hpp"
namespace ngraph
{
namespace runtime
{
namespace kernel
{
// One-hot kernel: expands the input tensor `arg` (shape `in_shape`) into `out`
// (shape `out_shape`), where `out_shape` is expected to be `in_shape` with one
// extra axis inserted at position `one_hot_axis`. Each input value selects the
// position along the one-hot axis that is set to 1; every other position along
// that axis is set to 0.
//
// Throws std::range_error if an input value matches no position along the
// one-hot axis — i.e. it is outside [0, out_shape[one_hot_axis]) or (for
// floating-point element types) non-integral.
template <typename T>
void one_hot(
    T* arg, T* out, const Shape& in_shape, const Shape& out_shape, size_t one_hot_axis)
{
    // For the outer loop we will walk over the entire input shape.
    CoordinateIterator arg_iter(in_shape);
    do
    {
        // For the inner loop we will walk across the entire axis for the one-hot axis, and stay put at the current arg position for the existing axes.
        Coordinate arg_coordinate = arg_iter.get_current_coordinate();

        // Unit strides: the inner iteration visits every coordinate in its range.
        Strides out_strides(out_shape.size(), 1);
        Coordinate out_outer_corner(out_shape.size());
        Coordinate out_inner_corner(out_shape.size());

        // Build the inner iteration range: existing axes are pinned to the current
        // input coordinate, while the one-hot axis spans its full extent.
        size_t arg_pos = 0;
        for (size_t i = 0; i < out_shape.size(); i++)
        {
            if (i != one_hot_axis)
            {
                // This is an existing axis.
                out_outer_corner[i] = arg_coordinate[arg_pos];
                out_inner_corner[i] = arg_coordinate[arg_pos];
                arg_pos++;
            }
            else
            {
                // This is the one-hot axis.
                out_outer_corner[i] = out_shape[i];
                out_inner_corner[i] = 0;
            }
        }

        CoordinateIterator out_iter(
            out_shape, out_strides, out_outer_corner, out_inner_corner);

        // Tracks whether the input value matched any position along the axis.
        bool found = false;

        do
        {
            auto out_index = out_iter.get_current_index();
            auto one_hot_pos = out_iter.get_current_coordinate()[one_hot_axis];
            auto in_index = arg_iter.get_current_index();

            // The weird test for equality here is because this template winds up being
            // instantiated for floating-point types, and clang complains if you try to
            // == on a float.
            if (arg[in_index] <= one_hot_pos && arg[in_index] >= one_hot_pos)
            {
                out[out_index] = 1;
                found = true;
            }
            else
            {
                out[out_index] = 0;
            }
        } while (out_iter.increment());

        // No match means the input value was out of range (or non-integral, since a
        // non-integral value can never equal an integer axis position).
        if (!found)
        {
            throw std::range_error("One-hot: value is out of category range");
        }
    } while (arg_iter.increment());
}
}
}
}
......@@ -38,8 +38,8 @@ namespace ngraph
using VectorStrides = Eigen::Stride<Eigen::Dynamic, 1>;
template <typename ET>
using DynamicArray =
Eigen::Array<typename ET::type, Eigen::Dynamic, Eigen::Dynamic>;
using DynamicArray = Eigen::
Array<typename ET::type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
template <typename ET>
using EigenArrayBase = Eigen::Map<DynamicArray<ET>, 0, DynamicStrides>;
......
......@@ -53,6 +53,7 @@
#include "ngraph/ops/negative.hpp"
#include "ngraph/ops/not.hpp"
#include "ngraph/ops/not_equal.hpp"
#include "ngraph/ops/one_hot.hpp"
#include "ngraph/ops/power.hpp"
#include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/replace_slice.hpp"
......@@ -118,6 +119,7 @@
#include "ngraph/runtime/ngvm/instruction/negate.hpp"
#include "ngraph/runtime/ngvm/instruction/not.hpp"
#include "ngraph/runtime/ngvm/instruction/not_equal.hpp"
#include "ngraph/runtime/ngvm/instruction/one_hot.hpp"
#include "ngraph/runtime/ngvm/instruction/power.hpp"
#include "ngraph/runtime/ngvm/instruction/return.hpp"
#include "ngraph/runtime/ngvm/instruction/select.hpp"
......@@ -1044,6 +1046,31 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
}
};
// NGVM handler for the one-hot op: extracts the argument shape, result shape,
// and result element type, then pushes a OneHotInstruction specialized for
// that element type.
REGISTER_TO_OP_MAP(op::OneHot)
{
    auto one_hot = static_cast<const op::OneHot*>(n);

    // Input tensor shape (the output shape minus the one-hot axis).
    auto arg_tensor_type = dynamic_pointer_cast<const TensorViewType>(
        n->get_arguments().at(0)->get_value_type());
    assert(nullptr != arg_tensor_type);
    auto arg_shape = arg_tensor_type->get_shape();

    // Output tensor shape and element type; the kernel template is
    // instantiated for whichever concrete element type the op carries.
    auto result_tensor_type =
        dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
    assert(nullptr != result_tensor_type);
    auto result_shape = result_tensor_type->get_shape();
    auto& result_element_type = result_tensor_type->get_element_type();

    PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
                                 "One-hot has unhandled element type",
                                 instruction::OneHotInstruction,
                                 in[0],
                                 out[0],
                                 arg_shape,
                                 result_shape,
                                 one_hot->get_one_hot_axis());
};
initialized = true;
}
return op_map;
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/kernel/one_hot.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace instruction
{
/// \brief NGVM instruction that applies the one-hot kernel to a tensor view.
///
/// ET is an element-type descriptor; ET::type is the underlying C++ scalar type.
template <typename ET>
class OneHotInstruction : public Instruction
{
public:
    /// \param arg          Tensor view holding the input values (category indices).
    /// \param out          Tensor view that receives the one-hot-expanded result.
    /// \param arg_shape    Shape of the input tensor.
    /// \param out_shape    Shape of the output tensor (arg_shape with the one-hot axis inserted).
    /// \param one_hot_axis Index of the new one-hot axis within out_shape.
    OneHotInstruction(const TensorViewInfo& arg,
                      const TensorViewInfo& out,
                      const Shape& arg_shape,
                      const Shape& out_shape,
                      size_t one_hot_axis)
        : m_arg(arg)
        , m_out(out)
        , m_arg_shape(arg_shape)
        , m_out_shape(out_shape)
        , m_one_hot_axis(one_hot_axis)
    {
    }

    /// Resolves the raw data pointers from the call frame and runs the
    /// one-hot kernel. The kernel may throw std::range_error on invalid input.
    virtual void execute(CallFrame& call_frame) const override
    {
        typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
        typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);

        kernel::one_hot<typename ET::type>(
            arg, out, m_arg_shape, m_out_shape, m_one_hot_axis);
    }

protected:
    TensorViewInfo m_arg;    // Input tensor view.
    TensorViewInfo m_out;    // Output tensor view.
    Shape m_arg_shape;       // Input shape.
    Shape m_out_shape;       // Output shape (input shape plus the one-hot axis).
    size_t m_one_hot_axis;   // Index of the one-hot axis in m_out_shape.
};
}
}
}
}
......@@ -63,6 +63,8 @@ namespace ngraph
bool operator<(const Type& other) const;
friend std::ostream& operator<<(std::ostream&, const Type&);
/// Returns true if the type is floating point, else false.
bool get_is_real() const { return m_is_real; }
private:
static std::map<std::string, Type> m_element_list;
size_t m_bitwidth;
......
This diff is collapsed.
......@@ -1922,3 +1922,99 @@ TEST(type_prop, replace_slice_deduce_matrix_upper_extra)
FAIL() << "Deduced type check failed for unexpected reason";
}
}
// Scalar input with a rank-1 output: the one-hot axis is the only axis.
TEST(type_prop, one_hot_deduce_scalar)
{
    auto param = make_shared<op::Parameter>(element::Int32::element_type(), Shape{});
    auto oh = make_shared<op::OneHot>(param, Shape{9}, 0);
    auto oh_vt = oh->get_value_type();
    ASSERT_EQ(*oh_vt, TensorViewType(element::Int32::element_type(), Shape{9}));
}

// Vector input, one-hot axis inserted before the existing axis (axis 0).
TEST(type_prop, one_hot_deduce_vector_0)
{
    auto param = make_shared<op::Parameter>(element::Int32::element_type(), Shape{8});
    auto oh = make_shared<op::OneHot>(param, Shape{9, 8}, 0);
    auto oh_vt = oh->get_value_type();
    ASSERT_EQ(*oh_vt, TensorViewType(element::Int32::element_type(), Shape{9, 8}));
}

// Vector input, one-hot axis inserted after the existing axis (axis 1).
TEST(type_prop, one_hot_deduce_vector_1)
{
    auto param = make_shared<op::Parameter>(element::Int32::element_type(), Shape{8});
    auto oh = make_shared<op::OneHot>(param, Shape{8, 9}, 1);
    auto oh_vt = oh->get_value_type();
    ASSERT_EQ(*oh_vt, TensorViewType(element::Int32::element_type(), Shape{8, 9}));
}

// Matrix input, one-hot axis at position 0 of the output.
TEST(type_prop, one_hot_deduce_matrix_0)
{
    auto param = make_shared<op::Parameter>(element::Int32::element_type(), Shape{12, 24});
    auto oh = make_shared<op::OneHot>(param, Shape{2, 12, 24}, 0);
    auto oh_vt = oh->get_value_type();
    ASSERT_EQ(*oh_vt, TensorViewType(element::Int32::element_type(), Shape{2, 12, 24}));
}

// Matrix input, one-hot axis at position 1 (between the existing axes).
TEST(type_prop, one_hot_deduce_matrix_1)
{
    auto param = make_shared<op::Parameter>(element::Int32::element_type(), Shape{12, 24});
    auto oh = make_shared<op::OneHot>(param, Shape{12, 2, 24}, 1);
    auto oh_vt = oh->get_value_type();
    ASSERT_EQ(*oh_vt, TensorViewType(element::Int32::element_type(), Shape{12, 2, 24}));
}

// Matrix input, one-hot axis at position 2 (after the existing axes).
TEST(type_prop, one_hot_deduce_matrix_2)
{
    auto param = make_shared<op::Parameter>(element::Int32::element_type(), Shape{12, 24});
    auto oh = make_shared<op::OneHot>(param, Shape{12, 24, 2}, 2);
    auto oh_vt = oh->get_value_type();
    ASSERT_EQ(*oh_vt, TensorViewType(element::Int32::element_type(), Shape{12, 24, 2}));
}

// Type deduction works for floating-point element types too; the element
// type passes through unchanged.
TEST(type_prop, one_hot_deduce_floating_point)
{
    auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{12, 24});
    auto oh = make_shared<op::OneHot>(param, Shape{12, 24, 8}, 2);
    auto oh_vt = oh->get_value_type();
    ASSERT_EQ(*oh_vt, TensorViewType(element::Float32::element_type(), Shape{12, 24, 8}));
}

// A one-hot axis index beyond the output rank must be rejected at construction.
TEST(type_prop, one_hot_deduce_axis_oob)
{
    auto param = make_shared<op::Parameter>(element::Int32::element_type(), Shape{12, 24});
    try
    {
        auto oh = make_shared<op::OneHot>(param, Shape{12, 24, 8}, 3);

        // Should have thrown, so fail if it didn't
        FAIL() << "One-hot axis out of bounds not detected.";
    }
    catch (const ngraph_error& error)
    {
        EXPECT_EQ(error.what(), std::string("One-hot axis is out of bounds"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}

// An output shape whose non-one-hot axes don't match the input shape must be
// rejected at construction.
TEST(type_prop, one_hot_deduce_shape_incompatible)
{
    auto param = make_shared<op::Parameter>(element::Int32::element_type(), Shape{12, 24});
    try
    {
        auto oh = make_shared<op::OneHot>(param, Shape{12, 22, 8}, 2);

        // Should have thrown, so fail if it didn't
        FAIL() << "Incompatible one-hot output shape not detected.";
    }
    catch (const ngraph_error& error)
    {
        EXPECT_EQ(
            error.what(),
            std::string("One-hot argument shape is not compatible with desired output shape"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment