Commit 982889f5 authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

Partial Shapes and Types, Part 4c: Slice and ReplaceSlice (#1781)

* Adapt Tensor class to have partial shapes

* Add PartialShapes to Input, Output, Function, Node classes

* Terminological cleanup

* Add PartialShape propagation for Parameter and Result

* Implement partial-shape propagation for elementwise ops

* More comments

* One more comment tweak

* Add tests for the merge functions

* Add merging of undetermined element types

* Fix a goophup in deserializer implementation

* Implement fallback for ops that do not support partial shape/type validation

* Updates for some older unit tests, now that operator[] exists

* Add missing validate_punt_if_incomplete to AllReduce

* Implement partial shape/type propagation for AllReduce

* Implement partial shape/type propagation for Reshape

* Remove unneeded validate_punt from Result

* Implement partial shape/type propagation for Reverse

* Implement partial shape/type validation for ReverseSequence

* Implement partial shape/type validation for ArithmeticReduction

* Better docstrings for the stuff introduced in #1692; remove prototype for unimplemented, unused PartialShape::append()

* One more docstring thing I forgot to save

* Switch terminology from 'determined/undetermined' to 'static/dynamic'

* Switch terminology from 'complete/incomplete' to 'static/dynamic' for shapes; fix up some mushily worded comments

* Fix overzealous edits from the last commit

* Rename one test that escaped the Great Renaming

* Remove unnecessary validate_punt_if_dynamic from Reshape

* Fix comment typo

* Rewrite operator+ and operator* for Dimension as members, not friends

* Formatting tweak

* Show argument types/shapes in long NodeDescription; tank unit tests to block merge

* Fix dynamic element type propagation for elementwise ops, add some unit tests for same

* Fix error message

* Roll 'Not' back to existing behavior (non-boolean input types allowed)

* Add a TODO tag to a todo item

* Add unit tests for partial shape/type propagation with ReverseSequence

* Add unit tests for partial-shape/type propagation for ArithmeticReduction (via Sum)

* Implement partial shape/type validation for Slice

* Add unit tests for partial shape/type validation for Slice

* Implement partial shape/type propagation for ReplaceSlice

* Add unit tests for partial type/shape validation for ReplaceSlice

* Implement partial type/shape propagation for GetOutputElement

* Function signatures

* Add implementations, unit tests for relaxes/refines functions

* Generalize project/reduce/inject functions to cover PartialShape, move to shape_util.[ch]pp

* Deal with std::find_if #include issues

* Fix more include madness

* Change an internal variable name to something more descriptive

* Review comments
parent 91219e40
...@@ -32,8 +32,6 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0, ...@@ -32,8 +32,6 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
, m_strides(strides) , m_strides(strides)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
check_args();
} }
op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0, op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
...@@ -46,69 +44,86 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0, ...@@ -46,69 +44,86 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
, m_strides(Strides(lower_bounds.size(), 1)) , m_strides(Strides(lower_bounds.size(), 1))
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
check_args();
} }
void op::ReplaceSlice::check_args() void op::ReplaceSlice::validate_and_infer_types()
{ {
auto& input_0 = get_inputs().at(0); // An empty stride vector with lower_bounds/upper_bounds filled in means that we need to
auto& input_0_shape = input_0.get_shape(); // construct the default value.
auto& input_0_element_type = input_0.get_element_type(); if (m_strides.size() == 0)
{
auto& input_1 = get_inputs().at(1); m_strides = Strides(m_lower_bounds.size(), 1);
auto& input_1_shape = input_1.get_shape(); }
auto& input_1_element_type = input_1.get_element_type();
NODE_VALIDATION_ASSERT(this, input_0_shape.size() == input_1_shape.size()) const PartialShape& arg0_shape = get_input_partial_shape(0);
<< "Argument ranks do not match (arg0 shape: " << input_0_shape const PartialShape& arg1_shape = get_input_partial_shape(1);
<< ", arg1 shape: " << input_1_shape << ")."; Dimension merged_args_rank;
NODE_VALIDATION_ASSERT(this, input_0_element_type == input_1_element_type) NODE_VALIDATION_ASSERT(this,
<< "Argument element types do not match (arg0 element type: " << input_0_element_type Dimension::merge(merged_args_rank, arg0_shape.rank(), arg1_shape.rank()))
<< ", arg1 element type: " << input_1_element_type << ")."; << "Argument ranks do not match (arg0 shape: " << arg0_shape
<< ", arg1 shape: " << arg1_shape << ").";
NODE_VALIDATION_ASSERT(this, m_lower_bounds.size() == input_0_shape.size()) element::Type arg0_et = get_input_element_type(0);
<< "Rank of lower bounds (" << m_lower_bounds.size() << ") does not match rank " element::Type arg1_et = get_input_element_type(1);
<< "of argument (" << input_0_shape.size() << ") (lower bounds: " << m_lower_bounds element::Type merged_args_et;
<< ", argument shape: " << input_0_shape << ").";
NODE_VALIDATION_ASSERT(this, m_upper_bounds.size() == input_0_shape.size()) NODE_VALIDATION_ASSERT(this, element::Type::merge(merged_args_et, arg0_et, arg1_et))
<< "Rank of upper bounds (" << m_upper_bounds.size() << ") does not match rank " << "Argument element types do not match (arg0 element type: " << arg0_et
<< "of argument (" << input_0_shape.size() << ") (upper bounds: " << m_upper_bounds << ", arg1 element type: " << arg1_et << ").";
<< ", argument shape: " << input_0_shape << ").";
NODE_VALIDATION_ASSERT(this, m_strides.size() == input_0_shape.size()) NODE_VALIDATION_ASSERT(this,
<< "Rank of strides (" << m_strides.size() << ") does not match rank " m_lower_bounds.size() == m_upper_bounds.size() &&
<< "of argument (" << input_0_shape.size() << ") (strides: " << m_strides m_lower_bounds.size() == m_strides.size())
<< ", argument shape: " << input_0_shape << ")."; << "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds
<< ") and strides (" << m_strides << ") do not match.";
Shape slice_shape; size_t output_rank = m_upper_bounds.size();
for (size_t i = 0; i < input_0_shape.size(); i++) for (size_t i = 0; i < output_rank; i++)
{ {
NODE_VALIDATION_ASSERT(this, m_upper_bounds[i] <= input_0_shape[i])
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_0_shape << ").";
NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i]) NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
<< "Lower bound for slice is greater than upper bound at axis " << i << "Lower bound for slice is greater than upper bound at axis " << i
<< " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ")."; << " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
<< " (strides: " << m_strides << ")."; << " (strides: " << m_strides << ").";
}
size_t slice_axis_size = m_upper_bounds[i] - m_lower_bounds[i]; NODE_VALIDATION_ASSERT(this,
slice_axis_size = merged_args_rank.is_dynamic() || size_t(merged_args_rank) == output_rank)
slice_axis_size / m_strides[i] + ((slice_axis_size % m_strides[i] == 0) ? 0 : 1); << "Argument ranks do not match the rank of the lower bounds (" << m_lower_bounds
slice_shape.push_back(slice_axis_size); << "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ").";
std::vector<Dimension> sliced_dims(output_rank);
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this,
arg0_shape.rank().is_dynamic() || arg0_shape[i].is_dynamic() ||
m_upper_bounds[i] <= size_t(arg0_shape[i]))
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << arg0_shape << ").";
size_t sliced_dim = m_upper_bounds[i] - m_lower_bounds[i];
sliced_dim = sliced_dim / m_strides[i] + ((sliced_dim % m_strides[i] == 0) ? 0 : 1);
sliced_dims[i] = sliced_dim;
} }
NODE_VALIDATION_ASSERT(this, input_1_shape == slice_shape) PartialShape slice_shape{sliced_dims};
<< "Shape of replacement tensor (" << input_1_shape << ") does not match the slice shape "
NODE_VALIDATION_ASSERT(this, arg1_shape.compatible(slice_shape))
<< "Shape of replacement tensor (" << arg1_shape << ") does not match the slice shape "
<< "(" << slice_shape << ")."; << "(" << slice_shape << ").";
set_output_type(0, input_0_element_type, input_0_shape); // Slight corner case here: if arg0 was rank-unknown, we can go ahead and set the output rank
// because the attribs will have given us enough info.
PartialShape result_shape =
(arg0_shape.rank().is_static())
? arg0_shape
: PartialShape(std::vector<Dimension>(output_rank, Dimension::dynamic()));
set_output_type(0, merged_args_et, result_shape);
} }
shared_ptr<Node> op::ReplaceSlice::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ReplaceSlice::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -88,11 +88,11 @@ namespace ngraph ...@@ -88,11 +88,11 @@ namespace ngraph
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
void check_args(); void validate_and_infer_types() override;
const Coordinate m_lower_bounds; Coordinate m_lower_bounds;
const Coordinate m_upper_bounds; Coordinate m_upper_bounds;
const Strides m_strides; Strides m_strides;
}; };
} }
} }
...@@ -44,55 +44,55 @@ op::Slice::Slice(const shared_ptr<Node>& arg, ...@@ -44,55 +44,55 @@ op::Slice::Slice(const shared_ptr<Node>& arg,
void op::Slice::validate_and_infer_types() void op::Slice::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic()) // An empty stride vector with lower_bounds/upper_bounds filled in means that we need to
{ // construct the default value.
return; if (m_strides.size() == 0)
}
if (0 == m_strides.size())
{ {
m_strides = Strides(m_lower_bounds.size(), 1); m_strides = Strides(m_lower_bounds.size(), 1);
} }
auto& input = get_inputs().at(0);
auto& input_shape = input.get_shape();
NODE_VALIDATION_ASSERT(this, m_lower_bounds.size() == input_shape.size())
<< "Rank of lower bounds (" << m_lower_bounds.size() << ") does not match rank "
<< "of argument (" << input_shape.size() << ") (lower bounds: " << m_lower_bounds
<< ", argument shape: " << input_shape << ").";
NODE_VALIDATION_ASSERT(this, m_upper_bounds.size() == input_shape.size()) NODE_VALIDATION_ASSERT(this,
<< "Rank of upper bounds (" << m_upper_bounds.size() << ") does not match rank " m_lower_bounds.size() == m_upper_bounds.size() &&
<< "of argument (" << input_shape.size() << ") (upper bounds: " << m_upper_bounds m_lower_bounds.size() == m_strides.size())
<< ", argument shape: " << input_shape << ")."; << "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds
<< ") and strides (" << m_strides << ") do not match.";
NODE_VALIDATION_ASSERT(this, m_strides.size() == input_shape.size()) size_t output_rank = m_upper_bounds.size();
<< "Rank of strides (" << m_strides.size() << ") does not match rank "
<< "of argument (" << input_shape.size() << ") (strides: " << m_strides
<< ", argument shape: " << input_shape << ").";
Shape result_shape; for (size_t i = 0; i < output_rank; i++)
for (size_t i = 0; i < input_shape.size(); i++)
{ {
NODE_VALIDATION_ASSERT(this, m_upper_bounds[i] <= input_shape[i])
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_shape << ").";
NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i]) NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
<< "Lower bound for slice is greater than upper bound at axis " << i << "Lower bound for slice is greater than upper bound at axis " << i
<< " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ")."; << " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
<< " (strides: " << m_strides << ")."; << " (strides: " << m_strides << ").";
}
const PartialShape& input_shape = get_input_partial_shape(0);
Dimension input_rank = input_shape.rank();
NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || size_t(input_rank) == output_rank)
<< "Input rank does not match the rank of the lower bounds (" << m_lower_bounds
<< "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ").";
std::vector<Dimension> result_dims(output_rank);
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this,
input_rank.is_dynamic() || input_shape[i].is_dynamic() ||
m_upper_bounds[i] <= size_t(input_shape[i]))
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_shape << ").";
size_t result_axis_size = m_upper_bounds[i] - m_lower_bounds[i]; size_t result_axis_size = m_upper_bounds[i] - m_lower_bounds[i];
result_axis_size = result_axis_size =
result_axis_size / m_strides[i] + ((result_axis_size % m_strides[i] == 0) ? 0 : 1); result_axis_size / m_strides[i] + ((result_axis_size % m_strides[i] == 0) ? 0 : 1);
result_shape.push_back(result_axis_size); result_dims[i] = result_axis_size;
} }
set_output_type(0, input.get_element_type(), result_shape); set_output_type(0, get_input_element_type(0), PartialShape{result_dims});
} }
shared_ptr<Node> op::Slice::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Slice::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -1972,7 +1972,9 @@ TEST(type_prop, slice_deduce_vector_invalid_strides) ...@@ -1972,7 +1972,9 @@ TEST(type_prop, slice_deduce_vector_invalid_strides)
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), std::string("Rank of strides (2) does not match rank of argument (1)")); error.what(),
std::string("Ranks of lower bounds (Coordinate{0}), upper bounds "
"(Coordinate{7}) and strides (Strides{1, 2}) do not match"));
} }
catch (...) catch (...)
{ {
...@@ -2075,7 +2077,8 @@ TEST(type_prop, slice_deduce_matrix_lower_missing) ...@@ -2075,7 +2077,8 @@ TEST(type_prop, slice_deduce_matrix_lower_missing)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
std::string("Rank of lower bounds (1) does not match rank of argument (2)")); std::string("Ranks of lower bounds (Coordinate{0}), upper bounds "
"(Coordinate{5, 5}) and strides (Strides{1}) do not match"));
} }
catch (...) catch (...)
{ {
...@@ -2096,7 +2099,8 @@ TEST(type_prop, slice_deduce_matrix_upper_missing) ...@@ -2096,7 +2099,8 @@ TEST(type_prop, slice_deduce_matrix_upper_missing)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
std::string("Rank of upper bounds (1) does not match rank of argument (2)")); std::string("Ranks of lower bounds (Coordinate{0, 0}), upper bounds "
"(Coordinate{5}) and strides (Strides{1, 1}) do not match"));
} }
catch (...) catch (...)
{ {
...@@ -2115,9 +2119,10 @@ TEST(type_prop, slice_deduce_matrix_lower_extra) ...@@ -2115,9 +2119,10 @@ TEST(type_prop, slice_deduce_matrix_lower_extra)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), std::string("Ranks of lower bounds (Coordinate{0, 0, "
std::string("Rank of lower bounds (3) does not match rank of argument (2)")); "0}), upper bounds (Coordinate{5, 5}) and "
"strides (Strides{1, 1, 1}) do not match"));
} }
catch (...) catch (...)
{ {
...@@ -2135,10 +2140,169 @@ TEST(type_prop, slice_deduce_matrix_upper_extra) ...@@ -2135,10 +2140,169 @@ TEST(type_prop, slice_deduce_matrix_upper_extra)
FAIL() << "Extra upper bound coordinate not detected"; FAIL() << "Extra upper bound coordinate not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Ranks of lower bounds (Coordinate{0, 0}), "
"upper bounds (Coordinate{5, 5, 5}) and "
"strides (Strides{1, 1}) do not match"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, slice_partial_arg_input_rank_dynamic_attribs_ok)
{
PartialShape input_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param = make_shared<op::Parameter>(element::f32, input_shape);
auto sl = make_shared<op::Slice>(param, lower_bounds, upper_bounds, strides);
ASSERT_EQ(sl->get_element_type(), element::f32);
ASSERT_EQ(sl->get_shape(), (Shape{0, 1, 2, 2}));
}
TEST(type_prop, slice_partial_arg_rank_dynamic_attribs_rank_mismatch)
{
PartialShape input_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5};
Strides strides{1, 1, 1, 2};
auto param = make_shared<op::Parameter>(element::f32, input_shape);
try
{
auto sl = make_shared<op::Slice>(param, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatch of lower-bounds/upper-bounds/strides ranks not detected (argument "
"rank-dynamic)";
}
catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
std::string("Rank of upper bounds (3) does not match rank of argument (2)")); std::string("Ranks of lower bounds (Coordinate{1, 2, 3, 4}), upper bounds "
"(Coordinate{1, 3, 5}) and strides (Strides{1, 1, 1, 2}) do not match"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, slice_partial_arg_rank_dynamic_attribs_bounds_crossing)
{
PartialShape input_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 8};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param = make_shared<op::Parameter>(element::f32, input_shape);
try
{
auto sl = make_shared<op::Slice>(param, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Crossing lower/upper bounds not detected (argument rank-dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Lower bound for slice is greater than upper bound at axis 3 (lower "
"bounds: Coordinate{1, 2, 3, 8}, upper bounds: Coordinate{1, 3, 5, 7})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, slice_partial_arg_rank_static_dynamic_ok)
{
PartialShape input_shape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param = make_shared<op::Parameter>(element::f32, input_shape);
auto sl = make_shared<op::Slice>(param, lower_bounds, upper_bounds, strides);
ASSERT_EQ(sl->get_element_type(), element::f32);
ASSERT_EQ(sl->get_shape(), (Shape{0, 1, 2, 2}));
}
TEST(type_prop, slice_partial_arg_rank_static_dynamic_some_dims_known_ok)
{
PartialShape input_shape{2, 4, 10, Dimension::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param = make_shared<op::Parameter>(element::f32, input_shape);
auto sl = make_shared<op::Slice>(param, lower_bounds, upper_bounds, strides);
ASSERT_EQ(sl->get_element_type(), element::f32);
ASSERT_EQ(sl->get_shape(), (Shape{0, 1, 2, 2}));
}
TEST(type_prop, slice_partial_arg_rank_static_dynamic_attribs_rank_mismatches_arg)
{
PartialShape input_shape{Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param = make_shared<op::Parameter>(element::f32, input_shape);
try
{
auto sl = make_shared<op::Slice>(param, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatch of attrib ranks with arg ranks not detected (argument rank-static "
"dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Input rank does not match the "
"rank of the lower bounds (Coordinate{1, 2, "
"3, 4}), upper bounds (Coordinate{1, 3, 5, "
"7}), and strides (Strides{1, 1, 1, 2})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, slice_partial_arg_rank_static_dynamic_some_dims_known_upper_bounds_oob)
{
PartialShape input_shape{2, 2, 10, Dimension::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param = make_shared<op::Parameter>(element::f32, input_shape);
try
{
auto sl = make_shared<op::Slice>(param, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Upper bounds out of bounds not detected (argument rank-static dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Upper bound for slice at axis 1 is out of "
"range (upper bounds: Coordinate{1, 3, 5, "
"7}, argument shape: {2,2,10,?})"));
} }
catch (...) catch (...)
{ {
...@@ -2282,7 +2446,9 @@ TEST(type_prop, replace_slice_deduce_vector_invalid_strides) ...@@ -2282,7 +2446,9 @@ TEST(type_prop, replace_slice_deduce_vector_invalid_strides)
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), std::string("Rank of strides (2) does not match rank of argument (1)")); error.what(),
std::string("Ranks of lower bounds (Coordinate{0}), upper bounds "
"(Coordinate{7}) and strides (Strides{1, 2}) do not match"));
} }
catch (...) catch (...)
{ {
...@@ -2345,9 +2511,10 @@ TEST(type_prop, replace_slice_deduce_matrix_slice_shape_mismatch) ...@@ -2345,9 +2511,10 @@ TEST(type_prop, replace_slice_deduce_matrix_slice_shape_mismatch)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(
std::string("Shape of replacement tensor (Shape{3, 6}) does not match " error.what(),
"the slice shape (Shape{4, 6})")); std::string(
"Shape of replacement tensor ({3,6}) does not match the slice shape ({4,6})"));
} }
catch (...) catch (...)
{ {
...@@ -2371,7 +2538,7 @@ TEST(type_prop, replace_slice_deduce_matrix_slice_shape_mismatch_strided) ...@@ -2371,7 +2538,7 @@ TEST(type_prop, replace_slice_deduce_matrix_slice_shape_mismatch_strided)
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
std::string( std::string(
"Shape of replacement tensor (Shape{4, 6}) does not match the slice shape")); "Shape of replacement tensor ({4,6}) does not match the slice shape ({4,3})"));
} }
catch (...) catch (...)
{ {
...@@ -2481,7 +2648,8 @@ TEST(type_prop, replace_slice_deduce_matrix_lower_missing) ...@@ -2481,7 +2648,8 @@ TEST(type_prop, replace_slice_deduce_matrix_lower_missing)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
std::string("Rank of lower bounds (1) does not match rank of argument (2)")); std::string("Ranks of lower bounds (Coordinate{0}), upper bounds "
"(Coordinate{5, 5}) and strides (Strides{1}) do not match"));
} }
catch (...) catch (...)
{ {
...@@ -2503,7 +2671,8 @@ TEST(type_prop, replace_slice_deduce_matrix_upper_missing) ...@@ -2503,7 +2671,8 @@ TEST(type_prop, replace_slice_deduce_matrix_upper_missing)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
std::string("Rank of upper bounds (1) does not match rank of argument (2)")); std::string("Ranks of lower bounds (Coordinate{0, 0}), upper bounds "
"(Coordinate{5}) and strides (Strides{1, 1}) do not match"));
} }
catch (...) catch (...)
{ {
...@@ -2524,9 +2693,10 @@ TEST(type_prop, replace_slice_deduce_matrix_lower_extra) ...@@ -2524,9 +2693,10 @@ TEST(type_prop, replace_slice_deduce_matrix_lower_extra)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), std::string("Ranks of lower bounds (Coordinate{0, 0, "
std::string("Rank of lower bounds (3) does not match rank of argument (2)")); "0}), upper bounds (Coordinate{5, 5}) and "
"strides (Strides{1, 1, 1}) do not match"));
} }
catch (...) catch (...)
{ {
...@@ -2546,10 +2716,337 @@ TEST(type_prop, replace_slice_deduce_matrix_upper_extra) ...@@ -2546,10 +2716,337 @@ TEST(type_prop, replace_slice_deduce_matrix_upper_extra)
FAIL() << "Extra upper bound coordinate not detected"; FAIL() << "Extra upper bound coordinate not detected";
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Ranks of lower bounds (Coordinate{0, 0}), "
"upper bounds (Coordinate{5, 5, 5}) and "
"strides (Strides{1, 1}) do not match"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, replace_slice_partial_input_rank_dynamic_replacement_rank_dynamic_attribs_ok)
{
PartialShape input_shape{PartialShape::dynamic()};
PartialShape replacement_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
auto rsl = make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
ASSERT_EQ(rsl->get_element_type(), element::f32);
ASSERT_TRUE(rsl->get_output_partial_shape(0).same_scheme(PartialShape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(type_prop,
replace_slice_partial_input_rank_dynamic_replacement_rank_dynamic_attribs_rank_mismatch)
{
PartialShape input_shape{PartialShape::dynamic()};
PartialShape replacement_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
try
{
auto rsl =
make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatch of lower-bounds/upper-bounds/strides ranks not detected (argument "
"rank-dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Ranks of lower bounds (Coordinate{1, 2, 3, 4}), upper bounds "
"(Coordinate{1, 3, 5}) and strides (Strides{1, 1, 1, 2}) do not match"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop,
replace_slice_partial_input_rank_dynamic_replacement_rank_dynamic_attribs_bounds_crossing)
{
PartialShape input_shape{PartialShape::dynamic()};
PartialShape replacement_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 8};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
try
{
auto rsl =
make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Crossing lower/upper bounds not detected (argument rank-dynamic)";
}
catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
std::string("Rank of upper bounds (3) does not match rank of argument (2)")); std::string("Lower bound for slice is greater than upper bound at axis 3 (lower "
"bounds: Coordinate{1, 2, 3, 8}, upper bounds: Coordinate{1, 3, 5, 7})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, replace_slice_partial_input_rank_static_dynamic_replacement_rank_dynamic_ok)
{
PartialShape input_shape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()};
PartialShape replacement_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
auto rsl = make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
ASSERT_EQ(rsl->get_element_type(), element::f32);
ASSERT_TRUE(rsl->get_output_partial_shape(0).same_scheme(PartialShape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(type_prop,
replace_slice_partial_input_rank_static_dynamic_some_dims_known_replacement_rank_dynamic_ok)
{
PartialShape input_shape{2, 4, 10, Dimension::dynamic()};
PartialShape replacement_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
auto rsl = make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
ASSERT_EQ(rsl->get_element_type(), element::f32);
ASSERT_TRUE(
rsl->get_output_partial_shape(0).same_scheme(PartialShape{2, 4, 10, Dimension::dynamic()}));
}
TEST(
type_prop,
replace_slice_partial_input_rank_static_dynamic_replacement_rank_dynamic_attribs_rank_mismatches_input)
{
PartialShape input_shape{Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic()};
PartialShape replacement_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
try
{
auto rsl =
make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatch of attrib ranks with arg ranks not detected (argument rank-static "
"dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Argument ranks do not match the rank of the lower bounds "
"(Coordinate{1, 2, 3, 4}), upper bounds (Coordinate{1, 3, "
"5, 7}), and strides (Strides{1, 1, 1, 2})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(
type_prop,
replace_slice_partial_input_rank_static_dynamic_some_dims_known_replacement_rank_dynamic_upper_bounds_oob)
{
PartialShape input_shape{2, 2, 10, Dimension::dynamic()};
PartialShape replacement_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
try
{
auto rsl =
make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Upper bounds out of bounds not detected (argument rank-static dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Upper bound for slice at axis 1 is out of "
"range (upper bounds: Coordinate{1, 3, 5, "
"7}, argument shape: {2,2,10,?})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, replace_slice_partial_input_rank_dynamic_replacement_rank_static_dynamic_ok)
{
PartialShape input_shape{PartialShape::dynamic()};
PartialShape replacement_shape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
auto rsl = make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
ASSERT_EQ(rsl->get_element_type(), element::f32);
ASSERT_TRUE(rsl->get_output_partial_shape(0).same_scheme(PartialShape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(type_prop,
replace_slice_partial_input_rank_dynamic_replacement_rank_static_dynamic_some_dims_known_ok)
{
PartialShape input_shape{PartialShape::dynamic()};
PartialShape replacement_shape{0, Dimension::dynamic(), Dimension::dynamic(), 2};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
auto rsl = make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
ASSERT_EQ(rsl->get_element_type(), element::f32);
ASSERT_TRUE(rsl->get_output_partial_shape(0).same_scheme(PartialShape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(
type_prop,
replace_slice_partial_input_rank_dynamic_replacement_rank_static_dynamic_some_dims_known_attribs_mismatch_replacement_shape)
{
PartialShape input_shape{PartialShape::dynamic()};
PartialShape replacement_shape{1, Dimension::dynamic(), Dimension::dynamic(), 2};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
try
{
auto rsl =
make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatch of shape inferred from attributes with provided replacement shape not "
"detected (rank-dynamic/rank-static dynamic inputs)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Shape of replacement tensor ({1,?,?,2}) does not match "
"the slice shape ({0,1,2,2})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(
type_prop,
replace_slice_partial_input_rank_dynamic_replacement_rank_static_dynamic_attribs_rank_mismatches_replacement)
{
PartialShape input_shape{PartialShape::dynamic()};
PartialShape replacement_shape{Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
try
{
auto rsl =
make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatch of attrib ranks with arg ranks not detected (arguments "
"rank-dynamic/rank-static "
"dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Argument ranks do not match the rank of the lower bounds "
"(Coordinate{1, 2, 3, 4}), upper bounds (Coordinate{1, 3, "
"5, 7}), and strides (Strides{1, 1, 1, 2})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(
type_prop,
replace_slice_partial_input_rank_static_dynamic_replacement_rank_static_dynamic_argument_ranks_mismatch)
{
PartialShape input_shape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()};
PartialShape replacement_shape{Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
try
{
auto rsl =
make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatching input/replacement ranks not detected (arguments both rank-static "
"dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument ranks do not match"));
} }
catch (...) catch (...)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment