Commit de27f2b1 authored by Mateusz Bencer, committed by Scott Cyphers

[SPEC] Add OneHot:v1 (#3884)

* Moved OneHot to v0

* Introduced OneHot:v1

* Added shape calculation for OneHot:v1

* Added element types checking

* Added output shape tests

* Added tests checking if inputs are scalars

* Updated OneHot:v1 doc

* Implemented OneHot:v1 downgrade pass

* Using OneHot:v1 in onnx_importer

* Implemented OneHot:v0 upgrade

* Fixed OneHot onnx_importer

* Refactored normalize_axis

* Added OneHot:v1 serialization

* Addressed code review remarks

* Added doc to normalize_axis
parent 9b3d197a
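
In short, OneHot:v0 bakes the full output shape and the one-hot axis into the op as attributes, while OneHot:v1 takes depth, on_value and off_value as inputs and keeps only the axis as an attribute. A minimal sketch contrasting the two constructors (not part of this commit; it assumes only the ngraph ops touched in the diff below):

#include "ngraph/op/constant.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/parameter.hpp"

using namespace ngraph;

std::shared_ptr<Node> one_hot_v0_vs_v1()
{
    auto indices = std::make_shared<op::Parameter>(element::i64, Shape{3});

    // v0: the whole output shape {3, 2} and the one-hot axis are attributes.
    auto v0 = std::make_shared<op::v0::OneHot>(indices, PartialShape{3, 2}, 1);

    // v1: depth, on_value and off_value become inputs; only the axis remains
    // an attribute and may be negative.
    auto depth = op::Constant::create(element::i64, Shape{}, {2});
    auto on_value = op::Constant::create(element::f32, Shape{}, {1.0f});
    auto off_value = op::Constant::create(element::f32, Shape{}, {0.0f});
    return std::make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, -1);
}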
......@@ -17,18 +17,12 @@
#include <cstdint>
#include <memory>
#include "exceptions.hpp"
#include "ngraph/coordinate.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "onehot.hpp"
#include "utils/common.hpp"
#include "utils/reshape.hpp"
namespace ngraph
{
......@@ -43,49 +37,18 @@ namespace ngraph
NodeVector inputs{node.get_ng_inputs()};
auto indices =
std::make_shared<ngraph::op::Convert>(inputs.at(0), element::i64);
auto indices_shape = indices->get_shape();
auto depth = inputs.at(1);
auto values = inputs.at(2);
std::shared_ptr<ngraph::Node> off_value =
std::make_shared<ngraph::op::Slice>(values, Coordinate{0}, Coordinate{1});
std::shared_ptr<ngraph::Node> on_value =
std::make_shared<ngraph::op::Slice>(values, Coordinate{1}, Coordinate{2});
auto axis = node.get_attribute_value<std::int64_t>("axis", -1);
// Accepted range for axis is [-r-1, r] where r = rank(indices). Validate
// against rank+1.
std::size_t valid_axis = common::validate_axis(node,
axis,
indices_shape.size() + 1,
-indices_shape.size() - 1,
indices_shape.size());
auto depth = reshape::interpret_as_scalar(inputs.at(1));
auto constant_depth = ngraph::as_type_ptr<ngraph::op::Constant>(depth);
ASSERT_VALID_ARGUMENT(node, constant_depth)
<< "Only constant values for depth input are supported for the OneHot "
"operator.";
auto values = inputs.at(2);
std::shared_ptr<ngraph::Node> off_value = reshape::interpret_as_scalar(
std::make_shared<ngraph::op::Slice>(values, Coordinate{0}, Coordinate{1}));
std::shared_ptr<ngraph::Node> on_value = reshape::interpret_as_scalar(
std::make_shared<ngraph::op::Slice>(values, Coordinate{1}, Coordinate{2}));
std::int64_t depth_value = constant_depth->get_vector<std::int64_t>()[0];
auto output_shape = indices_shape;
// Insert the OneHot axis at the position pointed to by the axis attribute.
// example:
// data_shape = (2, 2)
// axis = 1
// depth = 10
// output_shape = (2, 10, 2)
output_shape.insert(std::next(std::begin(output_shape), valid_axis),
depth_value);
auto axis = node.get_attribute_value<std::int64_t>("axis", -1);
std::shared_ptr<ngraph::Node> one_hot = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::OneHot>(indices, output_shape, valid_axis),
values->get_element_type());
auto broadcasted_values =
ngraph::op::numpy_style_broadcast({one_hot, on_value, off_value});
on_value = broadcasted_values[1];
off_value = broadcasted_values[2];
one_hot = one_hot * (on_value - off_value) + off_value;
return {one_hot};
return {std::make_shared<ngraph::op::v1::OneHot>(
indices, depth, on_value, off_value, axis)};
}
} // namespace set_1
......
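
Worth noting about the importer change above: ONNX's OneHot packs the off and on values into a single two-element `values` input, which is why the importer slices that tensor apart before handing scalars to op::v1::OneHot. A standalone sketch of that split, with a hypothetical `values` Parameter standing in for the ONNX input:

// `values` stands in for the third ONNX input; the importer gets it from the node.
auto values = std::make_shared<op::Parameter>(element::f32, Shape{2});
std::shared_ptr<Node> off_value =
    std::make_shared<op::Slice>(values, Coordinate{0}, Coordinate{1}); // values[0]
std::shared_ptr<Node> on_value =
    std::make_shared<op::Slice>(values, Coordinate{1}, Coordinate{2}); // values[1]
// reshape::interpret_as_scalar (used in the importer above) then squeezes the
// Shape{1} slices down to the scalars op::v1::OneHot expects.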
......@@ -16,6 +16,7 @@
#include <onnx/onnx_pb.h> // onnx types
#include "common.hpp"
#include "validation_util.hpp"
namespace ngraph
{
......@@ -60,23 +61,8 @@ namespace ngraph
std::int64_t axis_range_min,
std::int64_t axis_range_max)
{
// Accepted range of value for axis is [axis_range_min, axis_range_max].
NGRAPH_CHECK(((axis >= axis_range_min) && (axis <= axis_range_max)),
node.get_description(),
"Parameter axis ",
axis,
" out of the tensor rank [-",
axis_range_min,
", ",
axis_range_max,
"].");
if (axis < 0)
{
axis = axis + tensor_rank;
}
return static_cast<size_t>(axis);
return ngraph::normalize_axis(
node.get_description(), axis, tensor_rank, axis_range_min, axis_range_max);
}
std::vector<std::size_t> validate_axes(const ngraph::onnx_import::Node& node,
......
......@@ -15,13 +15,14 @@
//*****************************************************************************
#include "ngraph/op/one_hot.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::OneHot::type_info;
constexpr NodeTypeInfo op::v0::OneHot::type_info;
op::OneHot::OneHot(const Output<Node>& arg, const PartialShape& shape, size_t one_hot_axis)
op::v0::OneHot::OneHot(const Output<Node>& arg, const PartialShape& shape, size_t one_hot_axis)
: Op({arg})
, m_shape(shape)
, m_one_hot_axis(one_hot_axis)
......@@ -29,7 +30,7 @@ op::OneHot::OneHot(const Output<Node>& arg, const PartialShape& shape, size_t on
constructor_validate_and_infer_types();
}
void op::OneHot::validate_and_infer_types()
void op::v0::OneHot::validate_and_infer_types()
{
element::Type arg_et = get_input_element_type(0);
PartialShape arg_shape = get_input_partial_shape(0);
......@@ -92,8 +93,86 @@ void op::OneHot::validate_and_infer_types()
set_output_type(0, arg_et, result_shape);
}
shared_ptr<Node> op::OneHot::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v0::OneHot::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<OneHot>(new_args.at(0), m_shape, m_one_hot_axis);
return make_shared<v0::OneHot>(new_args.at(0), m_shape, m_one_hot_axis);
}
constexpr NodeTypeInfo op::v1::OneHot::type_info;
op::v1::OneHot::OneHot(const Output<Node>& indices,
const Output<Node>& depth,
const Output<Node>& on_value,
const Output<Node>& off_value,
int64_t axis)
: Op({indices, depth, on_value, off_value})
, m_axis(axis)
{
constructor_validate_and_infer_types();
}
void op::v1::OneHot::validate_and_infer_types()
{
const auto& indices_et = get_input_element_type(0);
const auto& depth_et = get_input_element_type(1);
const auto& on_value_et = get_input_element_type(2);
const auto& off_value_et = get_input_element_type(3);
NODE_VALIDATION_CHECK(this,
indices_et.is_dynamic() || indices_et.is_integral(),
"Indices must be integral element type.");
NODE_VALIDATION_CHECK(this,
depth_et.is_dynamic() || depth_et.is_integral(),
"Depth must be integral element type.");
NODE_VALIDATION_CHECK(this,
on_value_et.compatible(off_value_et),
"on_value element type must be compatible with off_value element type.");
const auto& indices_shape = get_input_partial_shape(0);
const auto& depth_shape = get_input_partial_shape(1);
const auto& on_value_shape = get_input_partial_shape(2);
const auto& off_value_shape = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
depth_shape.is_dynamic() || is_scalar(depth_shape.to_shape()),
"depth input must be scalar.");
NODE_VALIDATION_CHECK(this,
on_value_shape.is_dynamic() || is_scalar(on_value_shape.to_shape()),
"on_value input must be scalar.");
NODE_VALIDATION_CHECK(this,
off_value_shape.is_dynamic() || is_scalar(off_value_shape.to_shape()),
"off_value input must be scalar.");
const auto& depth = input_value(1).get_node_shared_ptr();
PartialShape result_shape{PartialShape::dynamic()};
if (indices_shape.is_static() && indices_shape.rank().is_static() && depth->is_constant())
{
const auto indices_rank = static_cast<int64_t>(indices_shape.rank());
std::vector<Dimension> out_dims(indices_rank);
for (auto i = 0; i < indices_rank; i++)
{
out_dims[i] = indices_shape[i];
}
m_axis =
ngraph::normalize_axis(this, m_axis, indices_rank + 1, -indices_rank - 1, indices_rank);
int64_t depth_val = as_type_ptr<op::Constant>(depth)->get_vector<int64_t>()[0];
out_dims.insert(out_dims.begin() + m_axis, Dimension(depth_val));
result_shape = out_dims;
}
set_output_type(0, on_value_et, result_shape);
}
shared_ptr<Node> op::v1::OneHot::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v1::OneHot>(
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), m_axis);
}
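
The shape inference above boils down to inserting the depth dimension at the normalized axis position. A worked example matching the tests further down:

// indices shape = {1, 3, 2, 3}, depth = 4, axis = 3:
//   out_dims = {1, 3, 2, 3}; insert Dimension(4) at position 3
//   -> output shape = {1, 3, 2, 4, 3}
// With axis = -1 the axis is normalized against rank + 1 (here 5), giving 4,
// so the depth dimension lands at the end: {1, 3, 2, 3, 4}.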
......@@ -22,53 +22,96 @@ namespace ngraph
{
namespace op
{
// clang-format off
/// \brief One-hot operator.
///
/// ## Parameters
///
/// | | Description |
/// | -------------- | ---------------------------------------------------------- |
/// | `shape` | The desired output shape, including the new one-hot axis. |
/// | `one_hot_axis` | The index within the output shape of the new one-hot axis. |
///
/// ## Inputs
///
/// | | Type | Description |
/// | ----- | ------------------------------------------------------- | -------------------------------------------------------------- |
/// | `arg` | \f$E[d_1,\dots,d_{m-1},d_{m+1},\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and any non-floating point element type. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T'\f$, where \f$T'[i_1,\dots,i_{m-1},i_m,i_{m+1},\dots,i_n] = 1\f$ if \f$T[i_1,\dots,i_{m-1},i_{m+1},\dots,i_n] = i_m\f$, else \f$0\f$. However, \f$T'\f$ is undefined if any non-integral value or any out-of-bounds value is detected in the input tensor. |
// clang-format on
class OneHot : public Op
namespace v0
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"OneHot", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a one-hot operation.
OneHot() = default;
/// \brief Constructs a one-hot operation.
// clang-format off
/// \brief One-hot operator.
///
/// \param arg Node that produces the input tensor to be one-hot encoded.
/// \param shape The shape of the output tensor, including the new one-hot axis.
/// \param one_hot_axis The index within the output shape of the new one-hot axis.
OneHot(const Output<Node>& arg, const PartialShape& shape, size_t one_hot_axis);
/// ## Parameters
///
/// | | Description |
/// | -------------- | ---------------------------------------------------------- |
/// | `shape` | The desired output shape, including the new one-hot axis. |
/// | `one_hot_axis` | The index within the output shape of the new one-hot axis. |
///
/// ## Inputs
///
/// | | Type | Description |
/// | ----- | ------------------------------------------------------- | -------------------------------------------------------------- |
/// | `arg` | \f$E[d_1,\dots,d_{m-1},d_{m+1},\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and any non-floating point element type. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T'\f$, where \f$T'[i_1,\dots,i_{m-1},i_m,i_{m+1},\dots,i_n] = 1\f$ if \f$T[i_1,\dots,i_{m-1},i_{m+1},\dots,i_n] = i_m\f$, else \f$0\f$. However, \f$T'\f$ is undefined if any non-integral value or any out-of-bounds value is detected in the input tensor. |
// clang-format on
class OneHot : public Op
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"OneHot", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a one-hot operation.
OneHot() = default;
/// \brief Constructs a one-hot operation.
///
/// \param arg Node that produces the input tensor to be one-hot encoded.
/// \param shape The shape of the output tensor, including the new one-hot
/// axis.
/// \param one_hot_axis The index within the output shape of the new one-hot axis.
OneHot(const Output<Node>& arg, const PartialShape& shape, size_t one_hot_axis);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
/// \return The index of the one-hot axis.
size_t get_one_hot_axis() const { return m_one_hot_axis; }
void set_one_hot_axis(size_t one_hot_axis) { m_one_hot_axis = one_hot_axis; }
protected:
PartialShape m_shape;
size_t m_one_hot_axis;
};
}
namespace v1
{
class OneHot : public Op
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"OneHot", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a one-hot operation.
OneHot() = default;
/// \brief Constructs a one-hot operation.
///
/// \param indices Input tensor containing indices.
/// \param depth Specifies the number of classes and thus the size of the
/// one-hot dimension.
/// \param on_value Specifies the value that the locations in the output tensor
/// represented by indices in the input take.
/// \param off_value Specifies the value that the locations in the output tensor
/// not represented by indices in the input take.
/// \param axis Axis along which the one-hot representation is added.
OneHot(const Output<Node>& indices,
const Output<Node>& depth,
const Output<Node>& on_value,
const Output<Node>& off_value,
int64_t axis);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
/// \return The index of the one-hot axis.
size_t get_one_hot_axis() const { return m_one_hot_axis; }
void set_one_hot_axis(size_t one_hot_axis) { m_one_hot_axis = one_hot_axis; }
protected:
PartialShape m_shape;
size_t m_one_hot_axis;
};
/// \return The index of the one-hot axis.
int64_t get_axis() const { return m_axis; }
void set_axis(int64_t axis) { m_axis = axis; }
protected:
int64_t m_axis;
};
}
// default opset version
using v0::OneHot;
}
}
......@@ -113,7 +113,7 @@ NGRAPH_OP(Negative, ngraph::op)
// NGRAPH_OP(NonMaxSuppression, ngraph::op)
NGRAPH_OP(NormalizeL2, ngraph::op)
NGRAPH_OP(NotEqual, ngraph::op::v1)
NGRAPH_OP(OneHot, ngraph::op)
NGRAPH_OP(OneHot, ngraph::op::v1)
NGRAPH_OP(PRelu, ngraph::op)
// NGRAPH_OP(PSROIPooling, ngraph::op)
NGRAPH_OP(Pad, ngraph::op::v1)
......
......@@ -20,6 +20,7 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/slice_plan.hpp"
......@@ -413,6 +414,38 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::OneHot_v1:
{
auto tmp = as_type_ptr<op::v1::OneHot>(node);
const auto indices = tmp->input_value(0).get_node_shared_ptr();
const auto depth = tmp->input_value(1).get_node_shared_ptr();
auto on_value = tmp->input_value(2).get_node_shared_ptr();
auto off_value = tmp->input_value(3).get_node_shared_ptr();
const auto axis = tmp->get_axis();
NGRAPH_CHECK(depth->is_constant(), "depth input must be constant", *node);
const auto const_depth = as_type_ptr<op::Constant>(depth);
std::int64_t depth_value = const_depth->get_vector<std::int64_t>()[0];
const auto indices_shape = tmp->get_input_partial_shape(0);
NGRAPH_CHECK(indices_shape.is_static(), "indices shape must be static", *node);
auto output_shape = indices_shape.to_shape();
output_shape.insert(output_shape.begin() + axis, depth_value);
auto one_hot = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::OneHot>(indices, output_shape, axis),
on_value->get_element_type());
auto broadcasted_values = op::numpy_style_broadcast({one_hot, on_value, off_value});
on_value = broadcasted_values[1];
off_value = broadcasted_values[2];
auto replacement_node = one_hot * (on_value - off_value) + off_value;
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::Pad_v1:
{
auto tmp = as_type_ptr<op::v1::Pad>(node);
......
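
The downgrade case relies on the converted v0 one-hot being exactly 1 at hot positions and 0 elsewhere, so a single affine expression recovers the v1 semantics:

// Why one_hot * (on_value - off_value) + off_value reproduces v1 behavior:
//   hot position:     1 * (on - off) + off == on
//   non-hot position: 0 * (on - off) + off == off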
......@@ -405,6 +405,28 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::OneHot:
{
auto tmp = as_type_ptr<op::v0::OneHot>(node);
const auto indices = tmp->input_value(0).get_node_shared_ptr();
const auto one_hot_axis = tmp->get_one_hot_axis();
const auto output_pshape = tmp->get_output_partial_shape(0);
NGRAPH_CHECK(output_pshape[one_hot_axis].is_static(),
"OneHot:v0 one hot axis dimension must be static ",
*node);
const auto depth = static_cast<int64_t>(output_pshape[one_hot_axis]);
const auto depth_node = op::Constant::create(element::i64, Shape{}, {depth});
const auto on_value = op::Constant::create(element::i64, Shape{}, {1});
const auto off_value = op::Constant::create(element::i64, Shape{}, {0});
auto replacement_node =
make_shared<op::v1::OneHot>(indices, depth_node, on_value, off_value, one_hot_axis);
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::Or:
{
upgrade_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
......
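
One observation about the upgrade case above (drawn from the code, not a documented guarantee): the synthesized on/off constants are element::i64, and op::v1::OneHot takes its output element type from on_value, so the upgraded node always produces i64 output, whereas the v0 node's output type followed its input element type.

// v0: set_output_type(0, arg_et, ...)      -> output type == indices type
// v1: set_output_type(0, on_value_et, ...) -> output type == element::i64 here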
......@@ -42,6 +42,7 @@
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/power.hpp"
......@@ -110,7 +111,6 @@ namespace ngraph
class ScatterNDAdd;
class UpdateSlice;
class ReplaceSlice;
class OneHot;
class Ceiling;
class Floor;
class Sqrt;
......
......@@ -2123,9 +2123,18 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
case OP_TYPEID::OneHot:
case OP_TYPEID::OneHot_v1:
{
auto shape = node_js.at("shape").get<vector<size_t>>();
auto one_hot_axis = node_js.at("one_hot_axis").get<size_t>();
node = make_shared<op::OneHot>(args[0], read_partial_shape(shape), one_hot_axis);
if (op_version == 0)
{
auto shape = node_js.at("shape").get<vector<size_t>>();
auto one_hot_axis = node_js.at("one_hot_axis").get<size_t>();
node =
make_shared<op::v0::OneHot>(args[0], read_partial_shape(shape), one_hot_axis);
}
if (op_version == 1)
{
auto axis = node_js.at("axis").get<int64_t>();
node = make_shared<op::v1::OneHot>(args[0], args[1], args[2], args[3], axis);
}
break;
}
case OP_TYPEID::Or:
......@@ -3864,9 +3873,17 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::OneHot:
case OP_TYPEID::OneHot_v1:
{
auto tmp = static_cast<const op::OneHot*>(&n);
node["shape"] = write_partial_shape(tmp->get_output_partial_shape(0));
node["one_hot_axis"] = tmp->get_one_hot_axis();
if (op_version == 0)
{
auto tmp = static_cast<const op::v0::OneHot*>(&n);
node["shape"] = write_partial_shape(tmp->get_output_partial_shape(0));
node["one_hot_axis"] = tmp->get_one_hot_axis();
}
if (op_version == 1)
{
auto tmp = static_cast<const op::v1::OneHot*>(&n);
node["axis"] = tmp->get_axis();
}
break;
}
case OP_TYPEID::Or:
......
......@@ -795,14 +795,30 @@ PartialShape ngraph::infer_slice_shape(const Node* node,
return dim;
}
std::size_t ngraph::normalize_axis(const Node* node, std::int64_t axis, std::int64_t tensor_rank)
int64_t ngraph::normalize_axis(const Node* node, std::int64_t axis, std::int64_t tensor_rank)
{
const auto axis_range_min = -tensor_rank;
const auto axis_range_max = tensor_rank - 1;
return normalize_axis(node, axis, tensor_rank, -tensor_rank, tensor_rank - 1);
}
int64_t ngraph::normalize_axis(const Node* node,
std::int64_t axis,
std::int64_t tensor_rank,
std::int64_t axis_range_min,
std::int64_t axis_range_max)
{
return ngraph::normalize_axis(
node->description(), axis, tensor_rank, axis_range_min, axis_range_max);
}
int64_t ngraph::normalize_axis(const std::string& node_description,
std::int64_t axis,
std::int64_t tensor_rank,
std::int64_t axis_range_min,
std::int64_t axis_range_max)
{
// Accepted range of value for axis is [axis_range_min, axis_range_max].
NGRAPH_CHECK(((axis >= axis_range_min) && (axis <= axis_range_max)),
node->description(),
node_description,
"Parameter axis ",
axis,
" out of the tensor rank [-",
......@@ -816,5 +832,5 @@ std::size_t ngraph::normalize_axis(const Node* node, std::int64_t axis, std::int
axis = axis + tensor_rank;
}
return static_cast<size_t>(axis);
return static_cast<int64_t>(axis);
}
......@@ -103,5 +103,48 @@ namespace ngraph
const AxisSet& shrink_axis_mask,
const AxisSet& ellipsis_mask);
std::size_t normalize_axis(const Node* node, std::int64_t axis, std::int64_t tensor_rank);
/// \brief Validate and normalize a possibly negative axis.
///
/// \param[in] node The node with the requested axis.
/// \param[in] axis The requested axis value.
/// \param[in] tensor_rank The corresponding tensor rank.
///
/// \return The normalized axis. Raises an error unless axis is in the range
/// [-tensor_rank, tensor_rank-1]. A negative axis counts from the last
/// axis backwards and is normalized by adding tensor_rank to it.
int64_t normalize_axis(const Node* node, std::int64_t axis, std::int64_t tensor_rank);
/// \brief Validate and normalize a possibly negative axis against a custom range.
///
/// \param[in] node The node with the requested axis.
/// \param[in] axis The requested axis value.
/// \param[in] tensor_rank The corresponding tensor rank.
/// \param[in] axis_range_min The minimum accepted value for axis.
/// \param[in] axis_range_max The maximum accepted value for axis.
///
/// \return The normalized axis. Raises an error unless axis is in the range
/// [axis_range_min, axis_range_max]. A negative axis is normalized by
/// adding tensor_rank to it.
int64_t normalize_axis(const Node* node,
std::int64_t axis,
std::int64_t tensor_rank,
std::int64_t axis_range_min,
std::int64_t axis_range_max);
/// \brief Validate and normalize a possibly negative axis against a custom range.
///
/// \param[in] node_description The description of the node with the requested axis.
/// \param[in] axis The requested axis value.
/// \param[in] tensor_rank The corresponding tensor rank.
/// \param[in] axis_range_min The minimum accepted value for axis.
/// \param[in] axis_range_max The maximum accepted value for axis.
///
/// \return The normalized axis. Raises an error unless axis is in the range
/// [axis_range_min, axis_range_max]. A negative axis is normalized by
/// adding tensor_rank to it.
int64_t normalize_axis(const std::string& node_description,
std::int64_t axis,
std::int64_t tensor_rank,
std::int64_t axis_range_min,
std::int64_t axis_range_max);
}
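
A few concrete calls for the overloads declared above, following the implementation in validation_util.cpp:

// Default range, tensor_rank = 4: accepted axis values are [-4, 3].
//   normalize_axis(node, -1, 4) == 3   // -1 + 4
//   normalize_axis(node,  2, 4) == 2   // already non-negative
// Custom range as used by OneHot:v1 (indices rank r = 4, tensor_rank = r + 1 = 5):
//   normalize_axis(node, -5, 5, -5, 4) == 0   // -5 + 5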
......@@ -79,6 +79,7 @@ set(SRC
opset_pass/logical_not_opset_pass.cpp
opset_pass/logical_or_opset_pass.cpp
opset_pass/logical_xor_opset_pass.cpp
opset_pass/one_hot_opset_pass.cpp
opset_pass/gather_opset_pass.cpp
opset_pass/generate_mask_opset_pass.cpp
opset_pass/pad_opset_pass.cpp
......
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(opset_transform, opset1_one_hot_upgrade_pass)
{
auto indices = make_shared<op::Parameter>(element::i64, Shape{1, 3, 2, 3});
const auto depth = 4;
PartialShape shape{1, 3, 2, depth, 3};
size_t one_hot_axis = 3;
auto one_hot_v0 = make_shared<op::v0::OneHot>(indices, shape, one_hot_axis);
auto result = make_shared<op::Result>(one_hot_v0);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{indices});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto one_hot_v1 = static_pointer_cast<op::v1::OneHot>(pass_replacement_node);
EXPECT_EQ(one_hot_v1->description(), "OneHot");
EXPECT_EQ(one_hot_v1->get_version(), 1);
EXPECT_EQ(one_hot_v1->get_axis(), one_hot_axis);
auto one_hot_v1_depth =
as_type_ptr<op::Constant>(one_hot_v1->input_value(1).get_node_shared_ptr());
EXPECT_EQ(one_hot_v1_depth->get_vector<int64_t>()[0], depth);
auto one_hot_v1_on_value =
as_type_ptr<op::Constant>(one_hot_v1->input_value(2).get_node_shared_ptr());
EXPECT_EQ(one_hot_v1_on_value->get_vector<int64_t>()[0], 1);
auto one_hot_v1_off_value =
as_type_ptr<op::Constant>(one_hot_v1->input_value(3).get_node_shared_ptr());
EXPECT_EQ(one_hot_v1_off_value->get_vector<int64_t>()[0], 0);
}
TEST(opset_transform, opset1_one_hot_downgrade_pass)
{
auto indices = make_shared<op::Parameter>(element::i64, Shape{1, 3, 2, 3});
auto depth = op::Constant::create(element::i64, Shape{}, {4});
auto on_value = op::Constant::create(element::u32, Shape{}, {5});
auto off_value = op::Constant::create(element::u32, Shape{}, {10});
int64_t axis = 3;
auto one_hot_v1 = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
auto result = make_shared<op::Result>(one_hot_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{indices});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto one_hot_v0 = static_pointer_cast<op::v0::OneHot>(pass_replacement_node);
EXPECT_EQ(one_hot_v0->get_shape(), (Shape{1, 3, 2, 4, 3}));
}
TEST(opset_transform, opset1_one_hot_downgrade_pass_depth_not_constant)
{
auto indices = make_shared<op::Parameter>(element::i64, Shape{1, 3, 2, 3});
auto depth = make_shared<op::Parameter>(element::i64, Shape{});
auto on_value = op::Constant::create(element::u32, Shape{}, {5});
auto off_value = op::Constant::create(element::u32, Shape{}, {10});
int64_t axis = 3;
auto one_hot_v1 = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
auto result = make_shared<op::Result>(one_hot_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{indices, depth});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
try
{
pass_manager.run_passes(f);
// Should have thrown, so fail if it didn't
FAIL() << "Not constant depth not detected";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("depth input must be constant"));
}
catch (...)
{
FAIL() << "OneHot downgrade failed for unexpected reason";
}
}
TEST(opset_transform, opset1_one_hot_downgrade_pass_indices_shape_not_static)
{
auto indices = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto depth = op::Constant::create(element::i64, Shape{}, {4});
auto on_value = op::Constant::create(element::u32, Shape{}, {5});
auto off_value = op::Constant::create(element::u32, Shape{}, {10});
int64_t axis = 3;
auto one_hot_v1 = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
auto result = make_shared<op::Result>(one_hot_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{indices});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
try
{
pass_manager.run_passes(f);
// Should have thrown, so fail if it didn't
FAIL() << "Not static indices shape not detected";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("indices shape must be static"));
}
catch (...)
{
FAIL() << "OneHot downgrade failed for unexpected reason";
}
}
......@@ -372,3 +372,167 @@ TEST(type_prop, one_hot_partial_rank_static_dynamic_rank_static_dynamic_one_hot_
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, one_hot_v1_output_shape)
{
auto indices = make_shared<op::Parameter>(element::i64, Shape{3});
auto depth = op::Constant::create(element::i64, Shape{}, {2});
auto on_value = op::Constant::create(element::u32, Shape{}, {5});
auto off_value = op::Constant::create(element::u32, Shape{}, {10});
int64_t axis = -1;
auto one_hot = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
ASSERT_EQ(one_hot->get_element_type(), element::u32);
ASSERT_EQ(one_hot->get_shape(), (Shape{3, 2}));
}
TEST(type_prop, one_hot_v1_output_shape_2)
{
auto indices = make_shared<op::Parameter>(element::i64, Shape{1, 3, 2, 3});
auto depth = op::Constant::create(element::i64, Shape{}, {4});
auto on_value = op::Constant::create(element::f32, Shape{}, {1.0f});
auto off_value = op::Constant::create(element::f32, Shape{}, {0.0f});
int64_t axis = 3;
auto one_hot = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
ASSERT_EQ(one_hot->get_element_type(), element::f32);
ASSERT_EQ(one_hot->get_shape(), (Shape{1, 3, 2, 4, 3}));
}
TEST(type_prop, one_hot_v1_indices_elem_not_integral)
{
auto indices = make_shared<op::Parameter>(element::f16, Shape{2, 2});
auto depth = make_shared<op::Parameter>(element::i64, Shape{});
auto on_value = make_shared<op::Parameter>(element::u32, Shape{});
auto off_value = make_shared<op::Parameter>(element::u32, Shape{});
int64_t axis = -1;
try
{
auto one_hot = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices element type not detected";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Indices must be integral element type."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, one_hot_v1_depth_elem_not_integral)
{
auto indices = make_shared<op::Parameter>(element::i64, Shape{2, 2});
auto depth = make_shared<op::Parameter>(element::f16, Shape{});
auto on_value = make_shared<op::Parameter>(element::u32, Shape{});
auto off_value = make_shared<op::Parameter>(element::u32, Shape{});
int64_t axis = -1;
try
{
auto one_hot = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect depth element type not detected";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Depth must be integral element type."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, one_hot_v1_on_off_values_not_compatible)
{
auto indices = make_shared<op::Parameter>(element::i64, Shape{2, 2});
auto depth = make_shared<op::Parameter>(element::i64, Shape{});
auto on_value = make_shared<op::Parameter>(element::bf16, Shape{});
auto off_value = make_shared<op::Parameter>(element::f16, Shape{});
int64_t axis = -1;
try
{
auto one_hot = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
// Should have thrown, so fail if it didn't
FAIL() << "Incompatible on/off element types not detected";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("on_value element type must be compatible with off_value element type."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, one_hot_v1_depth_not_scalar)
{
auto indices = make_shared<op::Parameter>(element::i64, Shape{2, 2});
auto depth = make_shared<op::Parameter>(element::i64, Shape{1});
auto on_value = make_shared<op::Parameter>(element::bf16, Shape{});
auto off_value = make_shared<op::Parameter>(element::bf16, Shape{});
int64_t axis = -1;
try
{
auto one_hot = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
// Should have thrown, so fail if it didn't
FAIL() << "Non-scalar depth input not detected.";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("depth input must be scalar."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, one_hot_v1_on_value_not_scalar)
{
auto indices = make_shared<op::Parameter>(element::i64, Shape{2, 2});
auto depth = make_shared<op::Parameter>(element::i64, Shape{});
auto on_value = make_shared<op::Parameter>(element::bf16, Shape{2});
auto off_value = make_shared<op::Parameter>(element::bf16, Shape{});
int64_t axis = -1;
try
{
auto one_hot = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
// Should have thrown, so fail if it didn't
FAIL() << "Non-scalar on_value input not detected.";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("on_value input must be scalar."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, one_hot_v1_off_value_not_scalar)
{
auto indices = make_shared<op::Parameter>(element::i64, Shape{2, 2});
auto depth = make_shared<op::Parameter>(element::i64, Shape{});
auto on_value = make_shared<op::Parameter>(element::bf16, Shape{});
auto off_value = make_shared<op::Parameter>(element::bf16, Shape{3});
int64_t axis = -1;
try
{
auto one_hot = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
// Should have thrown, so fail if it didn't
FAIL() << "Non-scalar off_value input not detected.";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("off_value input must be scalar."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}