Commit 3ea3c0c3 authored by Adam Procter, committed by Robert Kimball

Partial Shapes and Types, Part 4i: OneHot (#1832)

* Adapt Tensor class to have partial shapes

* Add PartialShapes to Input, Output, Function, Node classes

* Terminological cleanup

* Add PartialShape propagation for Parameter and Result

* Implement partial-shape propagation for elementwise ops

* More comments

* One more comment tweak

* Add tests for the merge functions

* Add merging of undetermined element types

* Fix a goof-up in deserializer implementation

* Implement fallback for ops that do not support partial shape/type validation

* Updates for some older unit tests, now that operator[] exists

* Add missing validate_punt_if_incomplete to AllReduce

* Implement partial shape/type propagation for AllReduce

* Implement partial shape/type propagation for Reshape

* Remove unneeded validate_punt from Result

* Implement partial shape/type propagation for Reverse

* Implement partial shape/type validation for ReverseSequence

* Implement partial shape/type validation for ArithmeticReduction

* Better docstrings for the stuff introduced in #1692; remove prototype for unimplemented, unused PartialShape::append()

* One more docstring thing I forgot to save

* Switch terminology from 'determined/undetermined' to 'static/dynamic' (see the sketch after this list)

* Switch terminology from 'complete/incomplete' to 'static/dynamic' for shapes; fix up some mushily worded comments

* Fix overzealous edits from the last commit

* Rename one test that escaped the Great Renaming

* Remove unnecessary validate_punt_if_dynamic from Reshape

* Fix comment typo

* Rewrite operator+ and operator* for Dimension as members, not friends

* Formatting tweak

* Show argument types/shapes in long NodeDescription; tank unit tests to block merge

* Fix dynamic element type propagation for elementwise ops, add some unit tests for same

* Fix error message

* Roll 'Not' back to existing behavior (non-boolean input types allowed)

* Add a TODO tag to a todo item

* Add unit tests for partial shape/type propagation with ReverseSequence

* Add unit tests for partial-shape/type propagation for ArithmeticReduction (via Sum)

* Implement partial type/shape propagation for GetOutputElement

* Basic support for partial shape/type propagation for OneHot

* Function signatures

* Add implementations, unit tests for relaxes/refines functions

* Update OneHot to take PartialShape for result, with dynamic dimension allowed at non-one-hot axes

* Generalize project/reduce/inject functions to cover PartialShape, move to shape_util.[ch]pp

* Deal with std::find_if #include issues

* Fix more include madness

* (->{ ; )->}

* size_t{...} -> static_cast<size_t>(...)
parent 62524c8d
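
Two ideas recur throughout this commit: a Dimension is either static (a known value) or dynamic (unknown), and two dimensions or shapes merge successfully only when they are compatible. Below is a minimal, self-contained model of that scheme, for illustration only: it is not the real ngraph::Dimension class, though the names mirror the API used in the diff that follows.

// Illustrative model only -- not the real ngraph::Dimension. A dimension is
// either a known size_t ("static") or unknown ("dynamic"); merging succeeds
// iff the two are compatible, with dynamic compatible with everything.
#include <cstddef>
#include <iostream>

class Dimension
{
public:
    Dimension(size_t value) : m_dynamic(false), m_value(value) {}
    static Dimension dynamic() { return Dimension(); }
    bool is_static() const { return !m_dynamic; }
    bool is_dynamic() const { return m_dynamic; }
    explicit operator size_t() const { return m_value; } // only meaningful if static

    // Merge d1 and d2 into dst. A dynamic dimension merges with anything;
    // two static dimensions must be equal, otherwise the merge fails.
    static bool merge(Dimension& dst, const Dimension& d1, const Dimension& d2)
    {
        if (d1.is_dynamic()) { dst = d2; return true; }
        if (d2.is_dynamic()) { dst = d1; return true; }
        if (d1.m_value != d2.m_value) { return false; }
        dst = d1;
        return true;
    }

private:
    Dimension() : m_dynamic(true), m_value(0) {}
    bool m_dynamic;
    size_t m_value;
};

int main()
{
    Dimension merged{0};
    std::cout << Dimension::merge(merged, Dimension::dynamic(), Dimension(3)) << "\n"; // 1: ? merges with 3
    std::cout << Dimension::merge(merged, Dimension(2), Dimension(3)) << "\n";         // 0: 2 vs. 3 conflict
}

In the real class, merging two static but unequal dimensions fails, which is exactly the failure surfaced by the "does not match the expected shape" assertions in the OneHot hunk below.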
@@ -20,28 +20,58 @@
 using namespace std;
 using namespace ngraph;
 
-op::OneHot::OneHot(const shared_ptr<Node>& arg, const Shape& shape, size_t one_hot_axis)
+op::OneHot::OneHot(const shared_ptr<Node>& arg, const PartialShape& shape, size_t one_hot_axis)
     : Op("OneHot", check_single_output_args({arg}))
     , m_shape(shape)
     , m_one_hot_axis(one_hot_axis)
 {
     constructor_validate_and_infer_types();
+}
+
+void op::OneHot::validate_and_infer_types()
+{
+    element::Type arg_et = get_input_element_type(0);
+    PartialShape arg_shape = get_input_partial_shape(0);
+    Rank arg_rank = arg_shape.rank();
 
-    auto& input = m_inputs.at(0);
-    auto& input_element_type = input.get_element_type();
+    NODE_VALIDATION_ASSERT(this, m_shape.rank().is_static())
+        << "Requested result shape has dynamic rank.";
 
-    NODE_VALIDATION_ASSERT(this, one_hot_axis < shape.size())
-        << "One-hot axis (" << one_hot_axis
-        << ") is out of bounds (requested result shape: " << shape << ").";
+    NODE_VALIDATION_ASSERT(this, m_one_hot_axis < static_cast<size_t>(m_shape.rank()))
+        << "One-hot axis (" << m_one_hot_axis
+        << ") is out of bounds (requested result shape: " << m_shape << ").";
 
-    auto expected_input_shape = shape;
-    expected_input_shape.erase(expected_input_shape.begin() + one_hot_axis);
+    NODE_VALIDATION_ASSERT(this, m_shape[m_one_hot_axis].is_static())
+        << "Requested result shape (" << m_shape << ") has dynamic dimension at the one-hot axis "
+        << "(" << m_one_hot_axis << ").";
+
+    PartialShape result_shape{m_shape};
+
+    if (arg_rank.is_static())
+    {
+        std::vector<Dimension> expected_input_dims(static_cast<size_t>(m_shape.rank()));
+        for (size_t i = 0; i < static_cast<size_t>(m_shape.rank()); i++)
+        {
+            expected_input_dims[i] = m_shape[i];
+        }
+        expected_input_dims.erase(expected_input_dims.begin() + m_one_hot_axis);
+        PartialShape expected_input_shape{expected_input_dims};
 
-    NODE_VALIDATION_ASSERT(this, input.get_shape() == expected_input_shape)
-        << "Argument shape " << input.get_shape() << " does not match the expected shape of "
-        << expected_input_shape << ".";
+        PartialShape merged_input_shape{expected_input_shape};
+        NODE_VALIDATION_ASSERT(this, PartialShape::merge_into(merged_input_shape, arg_shape))
+            << "Argument shape " << arg_shape << " does not match the expected shape of "
+            << expected_input_shape << ".";
+
+        std::vector<Dimension> output_dims(static_cast<size_t>(merged_input_shape.rank()));
+        for (size_t i = 0; i < static_cast<size_t>(merged_input_shape.rank()); i++)
+        {
+            output_dims[i] = merged_input_shape[i];
+        }
+        output_dims.insert(output_dims.begin() + m_one_hot_axis, m_shape[m_one_hot_axis]);
+        result_shape = PartialShape{output_dims};
+    }
 
-    set_output_type(0, input_element_type, shape);
+    set_output_type(0, arg_et, result_shape);
 }
 
 shared_ptr<Node> op::OneHot::copy_with_new_args(const NodeVector& new_args) const
...
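
The new validate_and_infer_types() above implements one inference rule: erase the one-hot axis from the requested result shape to get the expected input shape, merge that with the actual argument shape (possibly sharpening dynamic dimensions), then re-insert the one-hot dimension to produce the output shape. A standalone sketch of that rule, using a toy PartialShape built from std::optional rather than the real nGraph types:

// Toy sketch of the OneHot shape-inference rule (not the real nGraph code):
// delete the one-hot axis, merge with the argument shape, re-insert the axis.
#include <cstddef>
#include <optional>
#include <stdexcept>
#include <vector>

using Dim = std::optional<size_t>; // nullopt models a dynamic dimension
using PShape = std::vector<Dim>;

// Merge two rank-static shapes elementwise; throws on incompatibility.
PShape merge(const PShape& a, const PShape& b)
{
    if (a.size() != b.size()) throw std::runtime_error("rank mismatch");
    PShape out(a.size());
    for (size_t i = 0; i < a.size(); i++)
    {
        if (a[i] && b[i] && *a[i] != *b[i]) throw std::runtime_error("dim mismatch");
        out[i] = a[i] ? a[i] : b[i]; // prefer the static (known) dimension
    }
    return out;
}

PShape infer_one_hot_shape(const PShape& arg_shape, const PShape& requested, size_t axis)
{
    if (axis >= requested.size()) throw std::runtime_error("one-hot axis out of bounds");
    if (!requested[axis]) throw std::runtime_error("one-hot dimension must be static");

    PShape expected = requested;
    expected.erase(expected.begin() + axis);     // the input lacks the one-hot axis

    PShape merged = merge(expected, arg_shape);  // may sharpen dynamic dims
    merged.insert(merged.begin() + axis, requested[axis]);
    return merged;
}

int main()
{
    // Mirrors one of the unit tests below: argument {3,?,?,4}, requested
    // shape {?,2,3,?,4}, one-hot axis 2. Expected input is {?,2,?,4},
    // the merge gives {3,2,?,4}, and the output is {3,2,3,?,4}.
    PShape arg{Dim{3}, std::nullopt, std::nullopt, Dim{4}};
    PShape requested{std::nullopt, Dim{2}, Dim{3}, std::nullopt, Dim{4}};
    PShape out = infer_one_hot_shape(arg, requested, 2);
}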
@@ -50,7 +50,9 @@
             /// \param arg Node that produces the input tensor to be one-hot encoded.
             /// \param shape The shape of the output tensor, including the new one-hot axis.
             /// \param one_hot_axis The index within the output shape of the new one-hot axis.
-            OneHot(const std::shared_ptr<Node>& arg, const Shape& shape, size_t one_hot_axis);
+            OneHot(const std::shared_ptr<Node>& arg,
+                   const PartialShape& shape,
+                   size_t one_hot_axis);
 
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
@@ -58,7 +60,9 @@
             /// \return The index of the one-hot axis.
             size_t get_one_hot_axis() const { return m_one_hot_axis; }
         protected:
-            Shape m_shape;
+            void validate_and_infer_types() override;
+
+            PartialShape m_shape;
             size_t m_one_hot_axis;
         };
     }
...
@@ -151,13 +151,25 @@ static string
 static json write_dimension(Dimension d)
 {
-    if (d.is_static())
+    if (d.is_dynamic())
     {
-        return size_t(d);
+        return nullptr;
     }
     else
     {
-        return nullptr;
+        return static_cast<size_t>(d);
+    }
+}
+
+static Dimension read_dimension(const json& j)
+{
+    if (j.is_null())
+    {
+        return Dimension::dynamic();
+    }
+    else
+    {
+        return Dimension(static_cast<size_t>(j));
     }
 }
 
@@ -169,7 +181,7 @@ static json write_partial_shape(const PartialShape& s)
     }
     else
     {
-        std::vector<json> vals(size_t(s.rank()));
+        std::vector<json> vals(static_cast<size_t>(s.rank()));
         for (size_t i = 0; i < vals.size(); i++)
         {
             vals[i] = write_dimension(s[i]);
@@ -189,14 +201,7 @@ static PartialShape read_partial_shape(const json& j)
     std::vector<Dimension> dims(j.size());
     for (size_t i = 0; i < j.size(); i++)
     {
-        if (j[i].is_null())
-        {
-            dims[i] = Dimension::dynamic();
-        }
-        else
-        {
-            dims[i] = size_t(j[i]);
-        }
+        dims[i] = read_dimension(j[i]);
     }
     return PartialShape(dims);
 }
 
@@ -868,7 +873,7 @@ static shared_ptr<ngraph::Function>
        {
            auto shape = node_js.at("shape").get<vector<size_t>>();
            auto one_hot_axis = node_js.at("one_hot_axis").get<size_t>();
-           node = make_shared<op::OneHot>(args[0], shape, one_hot_axis);
+           node = make_shared<op::OneHot>(args[0], read_partial_shape(shape), one_hot_axis);
            break;
        }
        case OP_TYPEID::Or:
@@ -1426,7 +1431,7 @@ static json write(const Node& n, bool binary_constant_data)
     case OP_TYPEID::OneHot:
     {
         auto tmp = dynamic_cast<const op::OneHot*>(&n);
-        node["shape"] = tmp->get_shape();
+        node["shape"] = write_partial_shape(tmp->get_output_partial_shape(0));
         node["one_hot_axis"] = tmp->get_one_hot_axis();
         break;
     }
...
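
The serializer hunk above settles on a simple convention: a static dimension is written as its integer value, a dynamic dimension as JSON null, with write_dimension and read_dimension as inverses. A round-trip sketch, assuming the serializer's json alias is nlohmann::json (as in nGraph) and modeling Dimension as std::optional<size_t>:

// Round-trip sketch of the dimension <-> JSON convention. Assumes
// nlohmann::json; Dim is a stand-in for the real Dimension class.
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>
#include <nlohmann/json.hpp>

using json = nlohmann::json;
using Dim = std::optional<size_t>; // nullopt models Dimension::dynamic()

json write_dimension(Dim d)
{
    return d ? json(*d) : json(nullptr); // static -> integer, dynamic -> null
}

Dim read_dimension(const json& j)
{
    return j.is_null() ? std::nullopt : Dim(j.get<size_t>());
}

int main()
{
    // Round trip: {3, ?, 5} -> [3, null, 5] -> {3, ?, 5}
    std::vector<Dim> dims{Dim{3}, std::nullopt, Dim{5}};
    json j = json::array();
    for (auto& d : dims) j.push_back(write_dimension(d));
    for (size_t i = 0; i < j.size(); i++) assert(read_dimension(j[i]) == dims[i]);
}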
@@ -3365,10 +3365,246 @@ TEST(type_prop, one_hot_deduce_shape_incompatible)
         FAIL() << "Incompatible one-hot output shape not detected.";
     }
     catch (const ngraph_error& error)
     {
         EXPECT_HAS_SUBSTRING(
-            error.what(),
-            std::string("Argument shape Shape{12, 24} does not match the expected shape"));
+            error.what(), std::string("Argument shape {12,24} does not match the expected shape"));
+    }
+    catch (...)
+    {
+        FAIL() << "Deduced type check failed for unexpected reason";
+    }
+}
+
+TEST(type_prop, one_hot_partial_rank_dynamic_rank_dynamic)
+{
+    PartialShape input_shape{PartialShape::dynamic()};
+    PartialShape requested_shape{PartialShape::dynamic()};
+    size_t one_hot_axis{3000};
+    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    try
+    {
+        auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "Dynamic rank for requested result shape not detected";
+    }
+    catch (const ngraph_error& error)
+    {
+        EXPECT_HAS_SUBSTRING(error.what(), std::string("Requested result shape has dynamic rank"));
+    }
+    catch (...)
+    {
+        FAIL() << "Deduced type check failed for unexpected reason";
+    }
+}
+
+TEST(type_prop, one_hot_partial_rank_dynamic_rank_static_dynamic_ok)
+{
+    PartialShape input_shape{PartialShape::dynamic()};
+    PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic()};
+    size_t one_hot_axis{2};
+    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
+
+    ASSERT_EQ(oh->get_output_element_type(0), element::f32);
+    ASSERT_TRUE(oh->get_output_partial_shape(0).same_scheme(
+        PartialShape{Dimension::dynamic(), 2, 3, Dimension::dynamic()}));
+}
+
+TEST(type_prop, one_hot_partial_rank_dynamic_rank_static_dynamic_one_hot_dim_dynamic)
+{
+    PartialShape input_shape{PartialShape::dynamic()};
+    PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic()};
+    size_t one_hot_axis{3};
+    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    try
+    {
+        auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "Dynamic one-hot dimension not detected";
+    }
+    catch (const ngraph_error& error)
+    {
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             std::string("Requested result shape ({?,2,3,?}) has dynamic dimension "
+                                         "at the one-hot axis (3)"));
+    }
+    catch (...)
+    {
+        FAIL() << "Deduced type check failed for unexpected reason";
+    }
+}
+
+TEST(type_prop, one_hot_partial_rank_dynamic_rank_static_dynamic_one_hot_axis_oob)
+{
+    PartialShape input_shape{PartialShape::dynamic()};
+    PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic()};
+    size_t one_hot_axis{4};
+    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    try
+    {
+        auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "One-hot axis out of bounds not detected (rank-dynamic argument, rank-static "
+               "dynamic result shape)";
+    }
+    catch (const ngraph_error& error)
+    {
+        EXPECT_HAS_SUBSTRING(
+            error.what(),
+            std::string("One-hot axis (4) is out of bounds (requested result shape: {?,2,3,?})"));
+    }
+    catch (...)
+    {
+        FAIL() << "Deduced type check failed for unexpected reason";
+    }
+}
+
+TEST(type_prop, one_hot_partial_rank_static_dynamic_rank_static_dynamic_ok)
+{
+    PartialShape input_shape{3, Dimension::dynamic(), Dimension::dynamic(), 4};
+    PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic(), 4};
+    size_t one_hot_axis{2};
+    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
+
+    ASSERT_EQ(oh->get_output_element_type(0), element::f32);
+    ASSERT_TRUE(oh->get_output_partial_shape(0).same_scheme(
+        PartialShape{3, 2, 3, Dimension::dynamic(), 4}));
+}
+
+TEST(type_prop,
+     one_hot_partial_rank_static_dynamic_rank_static_dynamic_incompatible_rank_input_short)
+{
+    PartialShape input_shape{3, Dimension::dynamic(), Dimension::dynamic()};
+    PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic(), 4};
+    size_t one_hot_axis{2};
+    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    try
+    {
+        auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "Incompatible input/output ranks not detected (rank-static dynamic argument, "
+               "rank-static dynamic result shape)";
+    }
+    catch (const ngraph_error& error)
+    {
+        EXPECT_HAS_SUBSTRING(
+            error.what(),
+            std::string("Argument shape {3,?,?} does not match the expected shape of {?,2,?,4}"));
+    }
+    catch (...)
+    {
+        FAIL() << "Deduced type check failed for unexpected reason";
+    }
+}
+
+TEST(type_prop,
+     one_hot_partial_rank_static_dynamic_rank_static_dynamic_incompatible_rank_input_long)
+{
+    PartialShape input_shape{3, Dimension::dynamic(), Dimension::dynamic(), 4, 5};
+    PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic(), 4};
+    size_t one_hot_axis{2};
+    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    try
+    {
+        auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "Incompatible input/output ranks not detected (rank-static dynamic argument, "
+               "rank-static dynamic result shape)";
+    }
+    catch (const ngraph_error& error)
+    {
+        EXPECT_HAS_SUBSTRING(
+            error.what(),
+            std::string(
+                "Argument shape {3,?,?,4,5} does not match the expected shape of {?,2,?,4}"));
+    }
+    catch (...)
+    {
+        FAIL() << "Deduced type check failed for unexpected reason";
+    }
+}
+
+TEST(type_prop, one_hot_partial_rank_static_dynamic_rank_static_dynamic_incompatible_dim)
+{
+    PartialShape input_shape{3, Dimension::dynamic(), Dimension::dynamic(), 5};
+    PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic(), 4};
+    size_t one_hot_axis{2};
+    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    try
+    {
+        auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "Incompatible input/output dimensions not detected (rank-static dynamic "
+               "argument, rank-static dynamic result shape)";
+    }
+    catch (const ngraph_error& error)
+    {
+        EXPECT_HAS_SUBSTRING(
+            error.what(),
+            std::string("Argument shape {3,?,?,5} does not match the expected shape of {?,2,?,4}"));
+    }
+    catch (...)
+    {
+        FAIL() << "Deduced type check failed for unexpected reason";
+    }
+}
+
+TEST(type_prop, one_hot_partial_rank_static_dynamic_rank_static_dynamic_one_hot_dim_dynamic)
+{
+    PartialShape input_shape{3, Dimension::dynamic(), Dimension::dynamic(), 4};
+    PartialShape requested_shape{
+        Dimension::dynamic(), 2, Dimension::dynamic(), Dimension::dynamic(), 4};
+    size_t one_hot_axis{2};
+    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    try
+    {
+        auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "Dynamic one-hot dimension not detected (rank-static dynamic argument, "
+               "rank-static dynamic result shape)";
+    }
+    catch (const ngraph_error& error)
+    {
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             std::string("Requested result shape ({?,2,?,?,4}) has dynamic "
+                                         "dimension at the one-hot axis (2)"));
+    }
+    catch (...)
+    {
+        FAIL() << "Deduced type check failed for unexpected reason";
+    }
+}
+
+TEST(type_prop, one_hot_partial_rank_static_dynamic_rank_static_dynamic_one_hot_axis_oob)
+{
+    PartialShape input_shape{3, Dimension::dynamic(), Dimension::dynamic(), 4};
+    PartialShape requested_shape{
+        Dimension::dynamic(), 2, Dimension::dynamic(), Dimension::dynamic(), 4};
+    size_t one_hot_axis{2};
+    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    try
+    {
+        auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "One-hot axis out of bounds not detected (rank-static dynamic argument, "
+               "rank-static dynamic result shape)";
+    }
+    catch (const ngraph_error& error)
+    {
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             std::string("Requested result shape ({?,2,?,?,4}) has dynamic "
+                                         "dimension at the one-hot axis (2)"));
     }
     catch (...)
     {
...
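
A note on the assertions in the new tests: they compare shapes with same_scheme rather than ==, since plain equality is not meaningful once dynamic dimensions are involved. As I read the API, two partial shapes have the same scheme when their ranks match and each pair of corresponding dimensions is either both dynamic or both static with equal values. A toy version of that check, using the same optional-based stand-in as the sketches above (not the real implementation):

// Toy same_scheme: same rank, and per-dimension the static/dynamic status
// matches; where both are static, the values must be equal.
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

using Dim = std::optional<std::size_t>;

bool same_scheme(const std::vector<Dim>& a, const std::vector<Dim>& b)
{
    if (a.size() != b.size()) return false;
    for (std::size_t i = 0; i < a.size(); i++)
    {
        if (a[i].has_value() != b[i].has_value()) return false; // static vs. dynamic
        if (a[i] && *a[i] != *b[i]) return false;               // both static: must match
    }
    return true;
}

int main()
{
    assert(same_scheme({Dim{3}, std::nullopt}, {Dim{3}, std::nullopt})); // {3,?} ~ {3,?}
    assert(!same_scheme({Dim{3}, std::nullopt}, {Dim{3}, Dim{4}}));      // {3,?} !~ {3,4}
}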