Commit 982889f5 authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

Partial Shapes and Types, Part 4c: Slice and ReplaceSlice (#1781)

* Adapt Tensor class to have partial shapes

* Add PartialShapes to Input, Output, Function, Node classes

* Terminological cleanup

* Add PartialShape propagation for Parameter and Result

* Implement partial-shape propagation for elementwise ops

* More comments

* One more comment tweak

* Add tests for the merge functions

* Add merging of undetermined element types

* Fix a goophup in deserializer implementation

* Implement fallback for ops that do not support partial shape/type validation

* Updates for some older unit tests, now that operator[] exists

* Add missing validate_punt_if_incomplete to AllReduce

* Implement partial shape/type propagation for AllReduce

* Implement partial shape/type propagation for Reshape

* Remove unneeded validate_punt from Result

* Implement partial shape/type propagation for Reverse

* Implement partial shape/type validation for ReverseSequence

* Implement partial shape/type validation for ArithmeticReduction

* Better docstrings for the stuff introduced in #1692; remove prototype for unimplemented, unused PartialShape::append()

* One more docstring thing I forgot to save

* Switch terminology from 'determined/undetermined' to 'static/dynamic'

* Switch terminology from 'complete/incomplete' to 'static/dynamic' for shapes; fix up some mushily worded comments

* Fix overzealous edits from the last commit

* Rename one test that escaped the Great Renaming

* Remove unnecessary validate_punt_if_dynamic from Reshape

* Fix comment typo

* Rewrite operator+ and operator* for Dimension as members, not friends

* Formatting tweak

* Show argument types/shapes in long NodeDescription; tank unit tests to block merge

* Fix dynamic element type propagation for elementwise ops, add some unit tests for same

* Fix error message

* Roll 'Not' back to existing behavior (non-boolean input types allowed)

* Add a TODO tag to a todo item

* Add unit tests for partial shape/type propagation with ReverseSequence

* Add unit tests for partial-shape/type propagation for ArithmeticReduction (via Sum)

* Implement partial shape/type validation for Slice

* Add unit tests for partial shape/type validation for Slice

* Implement partial shape/type propagation for ReplaceSlice

* Add unit tests for partial type/shape validation for ReplaceSlice

* Implement partial type/shape propagation for GetOutputElement

* Function signatures

* Add implementations, unit tests for relaxes/refines functions

* Generalize project/reduce/inject functions to cover PartialShape, move to shape_util.[ch]pp

* Deal with std::find_if #include issues

* Fix more include madness

* Change an internal variable name to something more descriptive

* Review comments
parent 91219e40
......@@ -32,8 +32,6 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
, m_strides(strides)
{
constructor_validate_and_infer_types();
check_args();
}
op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
......@@ -46,69 +44,86 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
, m_strides(Strides(lower_bounds.size(), 1))
{
constructor_validate_and_infer_types();
check_args();
}
void op::ReplaceSlice::check_args()
void op::ReplaceSlice::validate_and_infer_types()
{
auto& input_0 = get_inputs().at(0);
auto& input_0_shape = input_0.get_shape();
auto& input_0_element_type = input_0.get_element_type();
auto& input_1 = get_inputs().at(1);
auto& input_1_shape = input_1.get_shape();
auto& input_1_element_type = input_1.get_element_type();
// An empty stride vector with lower_bounds/upper_bounds filled in means that we need to
// construct the default value.
if (m_strides.size() == 0)
{
m_strides = Strides(m_lower_bounds.size(), 1);
}
NODE_VALIDATION_ASSERT(this, input_0_shape.size() == input_1_shape.size())
<< "Argument ranks do not match (arg0 shape: " << input_0_shape
<< ", arg1 shape: " << input_1_shape << ").";
const PartialShape& arg0_shape = get_input_partial_shape(0);
const PartialShape& arg1_shape = get_input_partial_shape(1);
Dimension merged_args_rank;
NODE_VALIDATION_ASSERT(this, input_0_element_type == input_1_element_type)
<< "Argument element types do not match (arg0 element type: " << input_0_element_type
<< ", arg1 element type: " << input_1_element_type << ").";
NODE_VALIDATION_ASSERT(this,
Dimension::merge(merged_args_rank, arg0_shape.rank(), arg1_shape.rank()))
<< "Argument ranks do not match (arg0 shape: " << arg0_shape
<< ", arg1 shape: " << arg1_shape << ").";
NODE_VALIDATION_ASSERT(this, m_lower_bounds.size() == input_0_shape.size())
<< "Rank of lower bounds (" << m_lower_bounds.size() << ") does not match rank "
<< "of argument (" << input_0_shape.size() << ") (lower bounds: " << m_lower_bounds
<< ", argument shape: " << input_0_shape << ").";
element::Type arg0_et = get_input_element_type(0);
element::Type arg1_et = get_input_element_type(1);
element::Type merged_args_et;
NODE_VALIDATION_ASSERT(this, m_upper_bounds.size() == input_0_shape.size())
<< "Rank of upper bounds (" << m_upper_bounds.size() << ") does not match rank "
<< "of argument (" << input_0_shape.size() << ") (upper bounds: " << m_upper_bounds
<< ", argument shape: " << input_0_shape << ").";
NODE_VALIDATION_ASSERT(this, element::Type::merge(merged_args_et, arg0_et, arg1_et))
<< "Argument element types do not match (arg0 element type: " << arg0_et
<< ", arg1 element type: " << arg1_et << ").";
NODE_VALIDATION_ASSERT(this, m_strides.size() == input_0_shape.size())
<< "Rank of strides (" << m_strides.size() << ") does not match rank "
<< "of argument (" << input_0_shape.size() << ") (strides: " << m_strides
<< ", argument shape: " << input_0_shape << ").";
NODE_VALIDATION_ASSERT(this,
m_lower_bounds.size() == m_upper_bounds.size() &&
m_lower_bounds.size() == m_strides.size())
<< "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds
<< ") and strides (" << m_strides << ") do not match.";
Shape slice_shape;
size_t output_rank = m_upper_bounds.size();
for (size_t i = 0; i < input_0_shape.size(); i++)
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this, m_upper_bounds[i] <= input_0_shape[i])
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_0_shape << ").";
NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
<< "Lower bound for slice is greater than upper bound at axis " << i
<< " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
<< " (strides: " << m_strides << ").";
}
size_t slice_axis_size = m_upper_bounds[i] - m_lower_bounds[i];
slice_axis_size =
slice_axis_size / m_strides[i] + ((slice_axis_size % m_strides[i] == 0) ? 0 : 1);
slice_shape.push_back(slice_axis_size);
NODE_VALIDATION_ASSERT(this,
merged_args_rank.is_dynamic() || size_t(merged_args_rank) == output_rank)
<< "Argument ranks do not match the rank of the lower bounds (" << m_lower_bounds
<< "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ").";
std::vector<Dimension> sliced_dims(output_rank);
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this,
arg0_shape.rank().is_dynamic() || arg0_shape[i].is_dynamic() ||
m_upper_bounds[i] <= size_t(arg0_shape[i]))
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << arg0_shape << ").";
size_t sliced_dim = m_upper_bounds[i] - m_lower_bounds[i];
sliced_dim = sliced_dim / m_strides[i] + ((sliced_dim % m_strides[i] == 0) ? 0 : 1);
sliced_dims[i] = sliced_dim;
}
NODE_VALIDATION_ASSERT(this, input_1_shape == slice_shape)
<< "Shape of replacement tensor (" << input_1_shape << ") does not match the slice shape "
PartialShape slice_shape{sliced_dims};
NODE_VALIDATION_ASSERT(this, arg1_shape.compatible(slice_shape))
<< "Shape of replacement tensor (" << arg1_shape << ") does not match the slice shape "
<< "(" << slice_shape << ").";
set_output_type(0, input_0_element_type, input_0_shape);
// Slight corner case here: if arg0 was rank-unknown, we can go ahead and set the output rank
// because the attribs will have given us enough info.
PartialShape result_shape =
(arg0_shape.rank().is_static())
? arg0_shape
: PartialShape(std::vector<Dimension>(output_rank, Dimension::dynamic()));
set_output_type(0, merged_args_et, result_shape);
}
shared_ptr<Node> op::ReplaceSlice::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -88,11 +88,11 @@ namespace ngraph
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
void check_args();
void validate_and_infer_types() override;
const Coordinate m_lower_bounds;
const Coordinate m_upper_bounds;
const Strides m_strides;
Coordinate m_lower_bounds;
Coordinate m_upper_bounds;
Strides m_strides;
};
}
}
......@@ -44,55 +44,55 @@ op::Slice::Slice(const shared_ptr<Node>& arg,
void op::Slice::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
if (0 == m_strides.size())
// An empty stride vector with lower_bounds/upper_bounds filled in means that we need to
// construct the default value.
if (m_strides.size() == 0)
{
m_strides = Strides(m_lower_bounds.size(), 1);
}
auto& input = get_inputs().at(0);
auto& input_shape = input.get_shape();
NODE_VALIDATION_ASSERT(this, m_lower_bounds.size() == input_shape.size())
<< "Rank of lower bounds (" << m_lower_bounds.size() << ") does not match rank "
<< "of argument (" << input_shape.size() << ") (lower bounds: " << m_lower_bounds
<< ", argument shape: " << input_shape << ").";
NODE_VALIDATION_ASSERT(this, m_upper_bounds.size() == input_shape.size())
<< "Rank of upper bounds (" << m_upper_bounds.size() << ") does not match rank "
<< "of argument (" << input_shape.size() << ") (upper bounds: " << m_upper_bounds
<< ", argument shape: " << input_shape << ").";
NODE_VALIDATION_ASSERT(this,
m_lower_bounds.size() == m_upper_bounds.size() &&
m_lower_bounds.size() == m_strides.size())
<< "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds
<< ") and strides (" << m_strides << ") do not match.";
NODE_VALIDATION_ASSERT(this, m_strides.size() == input_shape.size())
<< "Rank of strides (" << m_strides.size() << ") does not match rank "
<< "of argument (" << input_shape.size() << ") (strides: " << m_strides
<< ", argument shape: " << input_shape << ").";
size_t output_rank = m_upper_bounds.size();
Shape result_shape;
for (size_t i = 0; i < input_shape.size(); i++)
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this, m_upper_bounds[i] <= input_shape[i])
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_shape << ").";
NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
<< "Lower bound for slice is greater than upper bound at axis " << i
<< " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
<< " (strides: " << m_strides << ").";
}
const PartialShape& input_shape = get_input_partial_shape(0);
Dimension input_rank = input_shape.rank();
NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || size_t(input_rank) == output_rank)
<< "Input rank does not match the rank of the lower bounds (" << m_lower_bounds
<< "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ").";
std::vector<Dimension> result_dims(output_rank);
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this,
input_rank.is_dynamic() || input_shape[i].is_dynamic() ||
m_upper_bounds[i] <= size_t(input_shape[i]))
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_shape << ").";
size_t result_axis_size = m_upper_bounds[i] - m_lower_bounds[i];
result_axis_size =
result_axis_size / m_strides[i] + ((result_axis_size % m_strides[i] == 0) ? 0 : 1);
result_shape.push_back(result_axis_size);
result_dims[i] = result_axis_size;
}
set_output_type(0, input.get_element_type(), result_shape);
set_output_type(0, get_input_element_type(0), PartialShape{result_dims});
}
shared_ptr<Node> op::Slice::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -1972,7 +1972,9 @@ TEST(type_prop, slice_deduce_vector_invalid_strides)
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(), std::string("Rank of strides (2) does not match rank of argument (1)"));
error.what(),
std::string("Ranks of lower bounds (Coordinate{0}), upper bounds "
"(Coordinate{7}) and strides (Strides{1, 2}) do not match"));
}
catch (...)
{
......@@ -2075,7 +2077,8 @@ TEST(type_prop, slice_deduce_matrix_lower_missing)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Rank of lower bounds (1) does not match rank of argument (2)"));
std::string("Ranks of lower bounds (Coordinate{0}), upper bounds "
"(Coordinate{5, 5}) and strides (Strides{1}) do not match"));
}
catch (...)
{
......@@ -2096,7 +2099,8 @@ TEST(type_prop, slice_deduce_matrix_upper_missing)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Rank of upper bounds (1) does not match rank of argument (2)"));
std::string("Ranks of lower bounds (Coordinate{0, 0}), upper bounds "
"(Coordinate{5}) and strides (Strides{1, 1}) do not match"));
}
catch (...)
{
......@@ -2115,9 +2119,10 @@ TEST(type_prop, slice_deduce_matrix_lower_extra)
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Rank of lower bounds (3) does not match rank of argument (2)"));
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Ranks of lower bounds (Coordinate{0, 0, "
"0}), upper bounds (Coordinate{5, 5}) and "
"strides (Strides{1, 1, 1}) do not match"));
}
catch (...)
{
......@@ -2135,10 +2140,169 @@ TEST(type_prop, slice_deduce_matrix_upper_extra)
FAIL() << "Extra upper bound coordinate not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Ranks of lower bounds (Coordinate{0, 0}), "
"upper bounds (Coordinate{5, 5, 5}) and "
"strides (Strides{1, 1}) do not match"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, slice_partial_arg_input_rank_dynamic_attribs_ok)
{
PartialShape input_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param = make_shared<op::Parameter>(element::f32, input_shape);
auto sl = make_shared<op::Slice>(param, lower_bounds, upper_bounds, strides);
ASSERT_EQ(sl->get_element_type(), element::f32);
ASSERT_EQ(sl->get_shape(), (Shape{0, 1, 2, 2}));
}
TEST(type_prop, slice_partial_arg_rank_dynamic_attribs_rank_mismatch)
{
PartialShape input_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5};
Strides strides{1, 1, 1, 2};
auto param = make_shared<op::Parameter>(element::f32, input_shape);
try
{
auto sl = make_shared<op::Slice>(param, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatch of lower-bounds/upper-bounds/strides ranks not detected (argument "
"rank-dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Ranks of lower bounds (Coordinate{1, 2, 3, 4}), upper bounds "
"(Coordinate{1, 3, 5}) and strides (Strides{1, 1, 1, 2}) do not match"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, slice_partial_arg_rank_dynamic_attribs_bounds_crossing)
{
PartialShape input_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 8};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param = make_shared<op::Parameter>(element::f32, input_shape);
try
{
auto sl = make_shared<op::Slice>(param, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Crossing lower/upper bounds not detected (argument rank-dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Rank of upper bounds (3) does not match rank of argument (2)"));
std::string("Lower bound for slice is greater than upper bound at axis 3 (lower "
"bounds: Coordinate{1, 2, 3, 8}, upper bounds: Coordinate{1, 3, 5, 7})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, slice_partial_arg_rank_static_dynamic_ok)
{
PartialShape input_shape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param = make_shared<op::Parameter>(element::f32, input_shape);
auto sl = make_shared<op::Slice>(param, lower_bounds, upper_bounds, strides);
ASSERT_EQ(sl->get_element_type(), element::f32);
ASSERT_EQ(sl->get_shape(), (Shape{0, 1, 2, 2}));
}
TEST(type_prop, slice_partial_arg_rank_static_dynamic_some_dims_known_ok)
{
PartialShape input_shape{2, 4, 10, Dimension::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param = make_shared<op::Parameter>(element::f32, input_shape);
auto sl = make_shared<op::Slice>(param, lower_bounds, upper_bounds, strides);
ASSERT_EQ(sl->get_element_type(), element::f32);
ASSERT_EQ(sl->get_shape(), (Shape{0, 1, 2, 2}));
}
TEST(type_prop, slice_partial_arg_rank_static_dynamic_attribs_rank_mismatches_arg)
{
PartialShape input_shape{Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param = make_shared<op::Parameter>(element::f32, input_shape);
try
{
auto sl = make_shared<op::Slice>(param, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatch of attrib ranks with arg ranks not detected (argument rank-static "
"dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Input rank does not match the "
"rank of the lower bounds (Coordinate{1, 2, "
"3, 4}), upper bounds (Coordinate{1, 3, 5, "
"7}), and strides (Strides{1, 1, 1, 2})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, slice_partial_arg_rank_static_dynamic_some_dims_known_upper_bounds_oob)
{
PartialShape input_shape{2, 2, 10, Dimension::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param = make_shared<op::Parameter>(element::f32, input_shape);
try
{
auto sl = make_shared<op::Slice>(param, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Upper bounds out of bounds not detected (argument rank-static dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Upper bound for slice at axis 1 is out of "
"range (upper bounds: Coordinate{1, 3, 5, "
"7}, argument shape: {2,2,10,?})"));
}
catch (...)
{
......@@ -2282,7 +2446,9 @@ TEST(type_prop, replace_slice_deduce_vector_invalid_strides)
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(), std::string("Rank of strides (2) does not match rank of argument (1)"));
error.what(),
std::string("Ranks of lower bounds (Coordinate{0}), upper bounds "
"(Coordinate{7}) and strides (Strides{1, 2}) do not match"));
}
catch (...)
{
......@@ -2345,9 +2511,10 @@ TEST(type_prop, replace_slice_deduce_matrix_slice_shape_mismatch)
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Shape of replacement tensor (Shape{3, 6}) does not match "
"the slice shape (Shape{4, 6})"));
EXPECT_HAS_SUBSTRING(
error.what(),
std::string(
"Shape of replacement tensor ({3,6}) does not match the slice shape ({4,6})"));
}
catch (...)
{
......@@ -2371,7 +2538,7 @@ TEST(type_prop, replace_slice_deduce_matrix_slice_shape_mismatch_strided)
EXPECT_HAS_SUBSTRING(
error.what(),
std::string(
"Shape of replacement tensor (Shape{4, 6}) does not match the slice shape"));
"Shape of replacement tensor ({4,6}) does not match the slice shape ({4,3})"));
}
catch (...)
{
......@@ -2481,7 +2648,8 @@ TEST(type_prop, replace_slice_deduce_matrix_lower_missing)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Rank of lower bounds (1) does not match rank of argument (2)"));
std::string("Ranks of lower bounds (Coordinate{0}), upper bounds "
"(Coordinate{5, 5}) and strides (Strides{1}) do not match"));
}
catch (...)
{
......@@ -2503,7 +2671,8 @@ TEST(type_prop, replace_slice_deduce_matrix_upper_missing)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Rank of upper bounds (1) does not match rank of argument (2)"));
std::string("Ranks of lower bounds (Coordinate{0, 0}), upper bounds "
"(Coordinate{5}) and strides (Strides{1, 1}) do not match"));
}
catch (...)
{
......@@ -2524,9 +2693,10 @@ TEST(type_prop, replace_slice_deduce_matrix_lower_extra)
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Rank of lower bounds (3) does not match rank of argument (2)"));
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Ranks of lower bounds (Coordinate{0, 0, "
"0}), upper bounds (Coordinate{5, 5}) and "
"strides (Strides{1, 1, 1}) do not match"));
}
catch (...)
{
......@@ -2546,10 +2716,337 @@ TEST(type_prop, replace_slice_deduce_matrix_upper_extra)
FAIL() << "Extra upper bound coordinate not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Ranks of lower bounds (Coordinate{0, 0}), "
"upper bounds (Coordinate{5, 5, 5}) and "
"strides (Strides{1, 1}) do not match"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, replace_slice_partial_input_rank_dynamic_replacement_rank_dynamic_attribs_ok)
{
PartialShape input_shape{PartialShape::dynamic()};
PartialShape replacement_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
auto rsl = make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
ASSERT_EQ(rsl->get_element_type(), element::f32);
ASSERT_TRUE(rsl->get_output_partial_shape(0).same_scheme(PartialShape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(type_prop,
replace_slice_partial_input_rank_dynamic_replacement_rank_dynamic_attribs_rank_mismatch)
{
PartialShape input_shape{PartialShape::dynamic()};
PartialShape replacement_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
try
{
auto rsl =
make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatch of lower-bounds/upper-bounds/strides ranks not detected (argument "
"rank-dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Ranks of lower bounds (Coordinate{1, 2, 3, 4}), upper bounds "
"(Coordinate{1, 3, 5}) and strides (Strides{1, 1, 1, 2}) do not match"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop,
replace_slice_partial_input_rank_dynamic_replacement_rank_dynamic_attribs_bounds_crossing)
{
PartialShape input_shape{PartialShape::dynamic()};
PartialShape replacement_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 8};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
try
{
auto rsl =
make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Crossing lower/upper bounds not detected (argument rank-dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Rank of upper bounds (3) does not match rank of argument (2)"));
std::string("Lower bound for slice is greater than upper bound at axis 3 (lower "
"bounds: Coordinate{1, 2, 3, 8}, upper bounds: Coordinate{1, 3, 5, 7})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, replace_slice_partial_input_rank_static_dynamic_replacement_rank_dynamic_ok)
{
PartialShape input_shape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()};
PartialShape replacement_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
auto rsl = make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
ASSERT_EQ(rsl->get_element_type(), element::f32);
ASSERT_TRUE(rsl->get_output_partial_shape(0).same_scheme(PartialShape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(type_prop,
replace_slice_partial_input_rank_static_dynamic_some_dims_known_replacement_rank_dynamic_ok)
{
PartialShape input_shape{2, 4, 10, Dimension::dynamic()};
PartialShape replacement_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
auto rsl = make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
ASSERT_EQ(rsl->get_element_type(), element::f32);
ASSERT_TRUE(
rsl->get_output_partial_shape(0).same_scheme(PartialShape{2, 4, 10, Dimension::dynamic()}));
}
TEST(
type_prop,
replace_slice_partial_input_rank_static_dynamic_replacement_rank_dynamic_attribs_rank_mismatches_input)
{
PartialShape input_shape{Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic()};
PartialShape replacement_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
try
{
auto rsl =
make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatch of attrib ranks with arg ranks not detected (argument rank-static "
"dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Argument ranks do not match the rank of the lower bounds "
"(Coordinate{1, 2, 3, 4}), upper bounds (Coordinate{1, 3, "
"5, 7}), and strides (Strides{1, 1, 1, 2})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(
type_prop,
replace_slice_partial_input_rank_static_dynamic_some_dims_known_replacement_rank_dynamic_upper_bounds_oob)
{
PartialShape input_shape{2, 2, 10, Dimension::dynamic()};
PartialShape replacement_shape{PartialShape::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
try
{
auto rsl =
make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Upper bounds out of bounds not detected (argument rank-static dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Upper bound for slice at axis 1 is out of "
"range (upper bounds: Coordinate{1, 3, 5, "
"7}, argument shape: {2,2,10,?})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, replace_slice_partial_input_rank_dynamic_replacement_rank_static_dynamic_ok)
{
PartialShape input_shape{PartialShape::dynamic()};
PartialShape replacement_shape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
auto rsl = make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
ASSERT_EQ(rsl->get_element_type(), element::f32);
ASSERT_TRUE(rsl->get_output_partial_shape(0).same_scheme(PartialShape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(type_prop,
replace_slice_partial_input_rank_dynamic_replacement_rank_static_dynamic_some_dims_known_ok)
{
PartialShape input_shape{PartialShape::dynamic()};
PartialShape replacement_shape{0, Dimension::dynamic(), Dimension::dynamic(), 2};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
auto rsl = make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
ASSERT_EQ(rsl->get_element_type(), element::f32);
ASSERT_TRUE(rsl->get_output_partial_shape(0).same_scheme(PartialShape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
}
TEST(
type_prop,
replace_slice_partial_input_rank_dynamic_replacement_rank_static_dynamic_some_dims_known_attribs_mismatch_replacement_shape)
{
PartialShape input_shape{PartialShape::dynamic()};
PartialShape replacement_shape{1, Dimension::dynamic(), Dimension::dynamic(), 2};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
try
{
auto rsl =
make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatch of shape inferred from attributes with provided replacement shape not "
"detected (rank-dynamic/rank-static dynamic inputs)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Shape of replacement tensor ({1,?,?,2}) does not match "
"the slice shape ({0,1,2,2})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(
type_prop,
replace_slice_partial_input_rank_dynamic_replacement_rank_static_dynamic_attribs_rank_mismatches_replacement)
{
PartialShape input_shape{PartialShape::dynamic()};
PartialShape replacement_shape{Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
try
{
auto rsl =
make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatch of attrib ranks with arg ranks not detected (arguments "
"rank-dynamic/rank-static "
"dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Argument ranks do not match the rank of the lower bounds "
"(Coordinate{1, 2, 3, 4}), upper bounds (Coordinate{1, 3, "
"5, 7}), and strides (Strides{1, 1, 1, 2})"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(
type_prop,
replace_slice_partial_input_rank_static_dynamic_replacement_rank_static_dynamic_argument_ranks_mismatch)
{
PartialShape input_shape{
Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()};
PartialShape replacement_shape{Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic()};
Coordinate lower_bounds{1, 2, 3, 4};
Coordinate upper_bounds{1, 3, 5, 7};
Strides strides{1, 1, 1, 2};
auto param0 = make_shared<op::Parameter>(element::f32, input_shape);
auto param1 = make_shared<op::Parameter>(element::f32, replacement_shape);
try
{
auto rsl =
make_shared<op::ReplaceSlice>(param0, param1, lower_bounds, upper_bounds, strides);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatching input/replacement ranks not detected (arguments both rank-static "
"dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument ranks do not match"));
}
catch (...)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment