Commit 9fea22b2 authored by Adam Procter, committed by Scott Cyphers

Change OneHot to accept only integral types (#2689)

* Change OneHot to accept only non-real types

* Update docstring

* Update Python test

* Add is_integral to element::Type

* Update docs

* Change is_integral to false for boolean

* Revert "Change is_integral to false for boolean"

This reverts commit 099ff378ae7fcbd1d9346665812f6b95e4886186.

* Revert "Add is_integral to element::Type"

This reverts commit 58fdf76fecaefdad10431f9a894523f326f3adca.

* Change is_integral so it is, by definition, !is_real
parent d734d6d6
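
Since the final bullet settles on defining is_integral() as simply !is_real(), every non-floating-point element type reports as integral, including boolean (the stricter definition that excluded it was reverted above). A minimal usage sketch of what that implies, assuming the nGraph headers and element constants are available (the include path below is an assumption):

#include <iostream>
#include "ngraph/type/element_type.hpp" // assumed include path

int main()
{
    using namespace ngraph;
    std::cout << element::f32.is_integral() << '\n'; // 0: f32 is a real type
    std::cout << element::i32.is_integral() << '\n'; // 1: i32 is not real
    // Because is_integral() is defined as !is_real(), boolean counts as
    // integral too; this commit deliberately settles on that definition.
    std::cout << element::boolean.is_integral() << '\n'; // 1
    return 0;
}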
@@ -15,11 +15,11 @@ Description
 Inputs
 ------
 
-+-----------------+-------------------------+---------------------------------------------------------+
-| Name            | Element Type            | Shape                                                   |
-+=================+=========================+=========================================================+
-| ``arg``         | Any                     | :math:`d_1,\dots,d_{m-1},d_{m+1},\dots,d_n)~(n \geq 0)` |
-+-----------------+-------------------------+---------------------------------------------------------+
++-----------------+-------------------+---------------------------------------------------------+
+| Name            | Element Type      | Shape                                                   |
++=================+===================+=========================================================+
+| ``arg``         | Any integral type | :math:`d_1,\dots,d_{m-1},d_{m+1},\dots,d_n)~(n \geq 0)` |
++-----------------+-------------------+---------------------------------------------------------+
 
 Attributes
 ----------
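
For orientation, here is a minimal standalone sketch (plain C++, not nGraph code) of the expansion OneHot performs: each input value selects the position of the 1 along the new axis. It mirrors the Python test below, where input [1, 0, 2] is expanded along axis 0 to shape (3, 3):

#include <cstdint>
#include <iostream>
#include <vector>

// One-hot along a new leading axis (axis 0): result[i][j] == 1 iff arg[j] == i.
std::vector<int32_t> one_hot_axis0(const std::vector<int32_t>& arg, size_t depth)
{
    std::vector<int32_t> result(depth * arg.size(), 0);
    for (size_t j = 0; j < arg.size(); ++j)
    {
        result[static_cast<size_t>(arg[j]) * arg.size() + j] = 1;
    }
    return result;
}

int main()
{
    auto r = one_hot_axis0({1, 0, 2}, 3);
    for (size_t i = 0; i < 3; ++i) // prints: 0 1 0 / 1 0 0 / 0 0 1
    {
        for (size_t j = 0; j < 3; ++j)
        {
            std::cout << r[i * 3 + j] << ' ';
        }
        std::cout << '\n';
    }
    return 0;
}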
@@ -659,7 +659,7 @@ def test_constant():
 
 
 def test_onehot():
-    element_type = Type.f32
+    element_type = Type.i32
     A = Parameter(element_type, Shape([3]))
     parameter_list = [A]
     function = Function(NodeVector([OneHot(A, Shape([3, 3]), 0)]), parameter_list, 'test')
@@ -668,9 +668,9 @@ def test_onehot():
     a = backend.create_tensor(element_type, Shape([3]))
     result = backend.create_tensor(element_type, Shape([3, 3]))
 
-    a.write(util.numpy_to_c(np.array([1, 0, 2], dtype=np.float32)), 0, 12)
+    a.write(util.numpy_to_c(np.array([1, 0, 2], dtype=np.int32)), 0, 12)
-    result_arr = np.zeros((3, 3), dtype=np.float32)
+    result_arr = np.zeros((3, 3), dtype=np.int32)
     result.write(util.numpy_to_c(result_arr), 0, 36)
 
     handle = backend.compile(function)
     handle.call([result], [a])
@@ -34,6 +34,10 @@ void op::OneHot::validate_and_infer_types()
     PartialShape arg_shape = get_input_partial_shape(0);
     Rank arg_rank = arg_shape.rank();
 
+    NODE_VALIDATION_CHECK(this,
+                          arg_et.is_dynamic() || arg_et.is_integral(),
+                          "Argument does not have integral element type.");
+
     NODE_VALIDATION_CHECK(
         this, m_shape.rank().is_static(), "Requested result shape has dynamic rank.");
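
The new check runs before any shape validation, so a floating-point argument is rejected up front while a dynamic element type passes through unexamined. A small standalone sketch of the predicate's behavior, using mock types rather than the real NODE_VALIDATION_CHECK machinery:

#include <stdexcept>

// Mock element type: just enough state to express the predicate.
struct MockElementType
{
    bool dynamic;
    bool real;
    bool is_dynamic() const { return dynamic; }
    bool is_real() const { return real; }
    bool is_integral() const { return !is_real(); } // same definition as the commit
};

// Stand-in for the NODE_VALIDATION_CHECK call above: throw on failure.
void check_one_hot_arg_type(const MockElementType& arg_et)
{
    if (!(arg_et.is_dynamic() || arg_et.is_integral()))
    {
        throw std::invalid_argument("Argument does not have integral element type.");
    }
}

int main()
{
    check_one_hot_arg_type({true, false});  // dynamic element type: accepted
    check_one_hot_arg_type({false, false}); // integral element type: accepted
    try
    {
        check_one_hot_arg_type({false, true}); // f32-like: rejected
    }
    catch (const std::invalid_argument&)
    {
        // expected: non-integral argument element type
    }
    return 0;
}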
@@ -33,9 +33,9 @@ namespace ngraph
         ///
         /// ## Inputs
         ///
-        /// |       | Type                                                    | Description                                 |
-        /// | ----- | ------------------------------------------------------- | ------------------------------------------- |
-        /// | `arg` | \f$E[d_1,\dots,d_{m-1},d_{m+1},\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and any element type. |
+        /// |       | Type                                                    | Description                                                    |
+        /// | ----- | ------------------------------------------------------- | ---------------------------------------------------------------- |
+        /// | `arg` | \f$E[d_1,\dots,d_{m-1},d_{m+1},\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and any non-floating point element type. |
         ///
         /// ## Output
         ///
@@ -78,6 +78,9 @@ namespace ngraph
         bool is_static() const;
         bool is_dynamic() const { return !is_static(); }
         bool is_real() const;
+        // TODO: We may want to revisit this definition when we do a more general cleanup of
+        // element types:
+        bool is_integral() const { return !is_real(); }
         bool is_signed() const;
         bool is_quantized() const;
         size_t bitwidth() const;
@@ -94,36 +94,6 @@ NGRAPH_TEST(${BACKEND_NAME}, one_hot_scalar_0_in_3)
     EXPECT_EQ((vector<int32_t>{1, 0, 0}), read_vector<int32_t>(result));
 }
 
-NGRAPH_TEST(${BACKEND_NAME}, one_hot_scalar_fp_nonint_in_3)
-{
-    Shape shape_a{};
-    auto A = make_shared<op::Parameter>(element::f32, shape_a);
-    Shape shape_r{3};
-    auto r = make_shared<op::OneHot>(A, Shape{3}, 0);
-    auto f = make_shared<Function>(r, ParameterVector{A});
-
-    auto backend = runtime::Backend::create("${BACKEND_NAME}");
-
-    // Create some tensors for input/output
-    auto a = backend->create_tensor(element::f32, shape_a);
-    copy_data(a, vector<float>{1.1f});
-    auto result = backend->create_tensor(element::f32, shape_r);
-
-    try
-    {
-        auto handle = backend->compile(f);
-        handle->call_with_validate({result}, {a});
-    }
-    catch (const std::exception& e)
-    {
-        EXPECT_EQ(e.what(), std::string("One-hot: non-integral value in input"));
-    }
-    catch (...)
-    {
-        FAIL() << "Expected a std::out_of_range exception";
-    }
-}
-
 NGRAPH_TEST(${BACKEND_NAME}, one_hot_scalar_oob_in_3)
 {
     Shape shape_a{};
@@ -270,58 +240,6 @@ NGRAPH_TEST(${BACKEND_NAME}, one_hot_matrix_0)
         read_vector<int32_t>(result));
 }
 
-NGRAPH_TEST(${BACKEND_NAME}, one_hot_vector_1_fp)
-{
-    Shape shape_a{8};
-    auto A = make_shared<op::Parameter>(element::f32, shape_a);
-    Shape shape_r{8, 3};
-    auto r = make_shared<op::OneHot>(A, Shape{8, 3}, 1);
-    auto f = make_shared<Function>(r, ParameterVector{A});
-
-    auto backend = runtime::Backend::create("${BACKEND_NAME}");
-
-    // Create some tensors for input/output
-    auto a = backend->create_tensor(element::f32, shape_a);
-    copy_data(a, vector<float>{2, 1, 0, 0, 2, 2, 1, 0});
-    auto result = backend->create_tensor(element::f32, shape_r);
-
-    auto handle = backend->compile(f);
-    handle->call_with_validate({result}, {a});
-    EXPECT_EQ(
-        (vector<float>{0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0}),
-        read_vector<float>(result));
-}
-
-NGRAPH_TEST(${BACKEND_NAME}, one_hot_vector_1_fp_nonint)
-{
-    Shape shape_a{8};
-    auto A = make_shared<op::Parameter>(element::f32, shape_a);
-    Shape shape_r{8, 3};
-    auto r = make_shared<op::OneHot>(A, Shape{8, 3}, 1);
-    auto f = make_shared<Function>(r, ParameterVector{A});
-
-    auto backend = runtime::Backend::create("${BACKEND_NAME}");
-
-    // Create some tensors for input/output
-    auto a = backend->create_tensor(element::f32, shape_a);
-    copy_data(a, vector<float>{2, 1, 0, 0, 2, 2, 1.01f, 0});
-    auto result = backend->create_tensor(element::f32, shape_r);
-
-    try
-    {
-        auto handle = backend->compile(f);
-        handle->call_with_validate({result}, {a});
-    }
-    catch (const std::exception& e)
-    {
-        EXPECT_EQ(e.what(), std::string("One-hot: non-integral value in input"));
-    }
-    catch (...)
-    {
-        FAIL() << "Expected a std::out_of_range exception";
-    }
-}
-
 NGRAPH_TEST(${BACKEND_NAME}, one_hot_vector_many_categories)
 {
     // Imagenet has roughly 20,000 categories
@@ -4343,12 +4343,32 @@ TEST(type_prop, one_hot_deduce_matrix_2)
     ASSERT_EQ(oh->get_shape(), (Shape{12, 24, 2}));
 }
 
+TEST(type_prop, one_hot_deduce_et_dynamic)
+{
+    auto param = make_shared<op::Parameter>(element::dynamic, Shape{12, 24});
+    auto oh = make_shared<op::OneHot>(param, Shape{12, 24, 2}, 2);
+    ASSERT_EQ(oh->get_element_type(), element::dynamic);
+    ASSERT_EQ(oh->get_shape(), (Shape{12, 24, 2}));
+}
+
 TEST(type_prop, one_hot_deduce_floating_point)
 {
     auto param = make_shared<op::Parameter>(element::f32, Shape{12, 24});
-    auto oh = make_shared<op::OneHot>(param, Shape{12, 24, 8}, 2);
-    ASSERT_EQ(oh->get_element_type(), element::f32);
-    ASSERT_EQ(oh->get_shape(), (Shape{12, 24, 8}));
+    try
+    {
+        auto oh = make_shared<op::OneHot>(param, Shape{12, 24, 8}, 3);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "Invalid floating-point element type not detected.";
+    }
+    catch (const NodeValidationFailure& error)
+    {
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             std::string("Argument does not have integral element type."));
+    }
+    catch (...)
+    {
+        FAIL() << "Deduced type check failed for unexpected reason";
+    }
+}
 
 TEST(type_prop, one_hot_deduce_axis_oob)
@@ -4396,7 +4416,7 @@ TEST(type_prop, one_hot_partial_rank_dynamic_rank_dynamic)
     PartialShape requested_shape{PartialShape::dynamic()};
     size_t one_hot_axis{3000};
 
-    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    auto param = make_shared<op::Parameter>(element::i32, input_shape);
     try
     {
         auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
@@ -4419,10 +4439,10 @@ TEST(type_prop, one_hot_partial_rank_dynamic_rank_static_dynamic_ok)
     PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic()};
     size_t one_hot_axis{2};
 
-    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    auto param = make_shared<op::Parameter>(element::i32, input_shape);
     auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
 
-    ASSERT_EQ(oh->get_output_element_type(0), element::f32);
+    ASSERT_EQ(oh->get_output_element_type(0), element::i32);
     ASSERT_TRUE(oh->get_output_partial_shape(0).same_scheme(
         PartialShape{Dimension::dynamic(), 2, 3, Dimension::dynamic()}));
 }
@@ -4433,7 +4453,7 @@ TEST(type_prop, one_hot_partial_rank_dynamic_rank_static_dynamic_one_hot_dim_dyn
     PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic()};
     size_t one_hot_axis{3};
 
-    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    auto param = make_shared<op::Parameter>(element::i32, input_shape);
     try
     {
         auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
@@ -4458,7 +4478,7 @@ TEST(type_prop, one_hot_partial_rank_dynamic_rank_static_dynamic_one_hot_axis_oo
     PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic()};
     size_t one_hot_axis{4};
 
-    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    auto param = make_shared<op::Parameter>(element::i32, input_shape);
     try
     {
         auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
@@ -4484,10 +4504,10 @@ TEST(type_prop, one_hot_partial_rank_static_dynamic_rank_static_dynamic_ok)
     PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic(), 4};
     size_t one_hot_axis{2};
 
-    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    auto param = make_shared<op::Parameter>(element::i32, input_shape);
     auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
 
-    ASSERT_EQ(oh->get_output_element_type(0), element::f32);
+    ASSERT_EQ(oh->get_output_element_type(0), element::i32);
     ASSERT_TRUE(oh->get_output_partial_shape(0).same_scheme(
         PartialShape{3, 2, 3, Dimension::dynamic(), 4}));
 }
@@ -4499,7 +4519,7 @@ TEST(type_prop,
     PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic(), 4};
     size_t one_hot_axis{2};
 
-    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    auto param = make_shared<op::Parameter>(element::i32, input_shape);
     try
     {
         auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
@@ -4526,7 +4546,7 @@ TEST(type_prop,
     PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic(), 4};
     size_t one_hot_axis{2};
 
-    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    auto param = make_shared<op::Parameter>(element::i32, input_shape);
     try
     {
         auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
@@ -4553,7 +4573,7 @@ TEST(type_prop, one_hot_partial_rank_static_dynamic_rank_static_dynamic_incompat
     PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic(), 4};
     size_t one_hot_axis{2};
 
-    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    auto param = make_shared<op::Parameter>(element::i32, input_shape);
     try
     {
        auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
@@ -4580,7 +4600,7 @@ TEST(type_prop, one_hot_partial_rank_static_dynamic_rank_static_dynamic_one_hot_
         Dimension::dynamic(), 2, Dimension::dynamic(), Dimension::dynamic(), 4};
     size_t one_hot_axis{2};
 
-    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    auto param = make_shared<op::Parameter>(element::i32, input_shape);
     try
     {
         auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
@@ -4607,7 +4627,7 @@ TEST(type_prop, one_hot_partial_rank_static_dynamic_rank_static_dynamic_one_hot_
         Dimension::dynamic(), 2, Dimension::dynamic(), Dimension::dynamic(), 4};
     size_t one_hot_axis{2};
 
-    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    auto param = make_shared<op::Parameter>(element::i32, input_shape);
     try
     {
         auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);