Commit 85582d0c authored by Adam Procter's avatar Adam Procter

Moar tests

parent c99d65a0
...@@ -91,6 +91,17 @@ void op::DynReplaceSlice::validate_and_infer_types() ...@@ -91,6 +91,17 @@ void op::DynReplaceSlice::validate_and_infer_types()
strides_shape.rank(), strides_shape.rank(),
"."); ".");
PartialShape attrs_shape{PartialShape::dynamic()};
NODE_VALIDATION_CHECK(this,
(lower_bounds_shape.same_scheme(PartialShape{0}) ||
PartialShape::merge_into(attrs_shape, lower_bounds_shape)) &&
(upper_bounds_shape.same_scheme(PartialShape{0}) ||
PartialShape::merge_into(attrs_shape, upper_bounds_shape)) &&
(strides_shape.same_scheme(PartialShape{0}) ||
PartialShape::merge_into(attrs_shape, strides_shape)),
"Shapes for lower bounds, upper bounds, and strides do not match");
set_input_is_relevant_to_shape(2); set_input_is_relevant_to_shape(2);
set_input_is_relevant_to_shape(3); set_input_is_relevant_to_shape(3);
set_input_is_relevant_to_shape(4); set_input_is_relevant_to_shape(4);
...@@ -99,32 +110,32 @@ void op::DynReplaceSlice::validate_and_infer_types() ...@@ -99,32 +110,32 @@ void op::DynReplaceSlice::validate_and_infer_types()
auto upper_bounds = dynamic_pointer_cast<op::Constant>(get_argument(3)); auto upper_bounds = dynamic_pointer_cast<op::Constant>(get_argument(3));
auto strides = dynamic_pointer_cast<op::Constant>(get_argument(4)); auto strides = dynamic_pointer_cast<op::Constant>(get_argument(4));
// TODO(amprocte): We can get a bit more information here about the ranks of arg and
// replacement by inspecting the attributes.
auto slice_shape = PartialShape::dynamic();
if (lower_bounds && upper_bounds && strides) if (lower_bounds && upper_bounds && strides)
{ {
auto inferred_slice_shape = infer_slice_shape(this, slice_shape = infer_slice_shape(this,
get_input_partial_shape(0), get_input_partial_shape(0),
lower_bounds->get_vector<int64_t>(), lower_bounds->get_vector<int64_t>(),
upper_bounds->get_vector<int64_t>(), upper_bounds->get_vector<int64_t>(),
strides->get_vector<int64_t>(), strides->get_vector<int64_t>(),
m_lower_bounds_mask, m_lower_bounds_mask,
m_upper_bounds_mask, m_upper_bounds_mask,
m_new_axis, m_new_axis,
m_shrink_axis, m_shrink_axis,
m_ellipsis_mask); m_ellipsis_mask);
NODE_VALIDATION_CHECK(this,
replacement_shape.compatible(inferred_slice_shape),
"Shape of the replacement is not compatible with the shape of the "
"slice (shape of slice: ",
inferred_slice_shape,
")");
} }
PartialShape output_shape = arg_shape; NODE_VALIDATION_CHECK(this,
NODE_VALIDATION_CHECK( slice_shape.compatible(replacement_shape),
this, "Shape of the replacement is not compatible with the shape of the "
PartialShape::merge_into(output_shape, PartialShape::dynamic(replacement_shape.rank())), "slice (shape of slice: ",
"Rank of the replacement is not compatible with rank of the argument tensor"); slice_shape,
set_output_type(0, get_input_element_type(0), output_shape); ")");
set_output_type(0, result_et, arg_shape);
} }
shared_ptr<Node> op::DynReplaceSlice::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::DynReplaceSlice::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -626,7 +626,6 @@ PartialShape ngraph::infer_slice_shape(const Node* node, ...@@ -626,7 +626,6 @@ PartialShape ngraph::infer_slice_shape(const Node* node,
const AxisSet& shrink_axis, const AxisSet& shrink_axis,
const AxisSet& ellipsis_mask) const AxisSet& ellipsis_mask)
{ {
// TODO(amprocte): double-check that these checks are needed.
if (lb.size() && ub.size()) if (lb.size() && ub.size())
{ {
NODE_VALIDATION_CHECK(node, NODE_VALIDATION_CHECK(node,
......
...@@ -14900,10 +14900,10 @@ TEST(type_prop, fake_quantize_invalid_rank) ...@@ -14900,10 +14900,10 @@ TEST(type_prop, fake_quantize_invalid_rank)
} }
} }
TEST(type_prop, dynreplaceslice_arg_static_params_static_ok) TEST(type_prop, dynreplaceslice_arg_static_replacement_static_params_static_ok)
{ {
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8}); auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto replacement = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8}); auto replacement = make_shared<op::Parameter>(element::f32, Shape{2, 4, 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4}); auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4}); auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto strides = make_shared<op::Parameter>(element::i64, Shape{4}); auto strides = make_shared<op::Parameter>(element::i64, Shape{4});
...@@ -14915,203 +14915,527 @@ TEST(type_prop, dynreplaceslice_arg_static_params_static_ok) ...@@ -14915,203 +14915,527 @@ TEST(type_prop, dynreplaceslice_arg_static_params_static_ok)
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{2, 4, 6, 8})); EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{2, 4, 6, 8}));
} }
#if 0 TEST(type_prop, dynreplaceslice_arg_static_replacement_rank_static_dynamic_params_static_ok)
TEST(type_prop, dynslice_arg_rank_static_dynamic_params_static_ok) {
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto replacement =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto strides = make_shared<op::Parameter>(element::i64, Shape{4});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{2, 4, 6, 8}));
}
TEST(type_prop, dynreplaceslice_arg_static_replacement_rank_dynamic_params_static_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto strides = make_shared<op::Parameter>(element::i64, Shape{4});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{2, 4, 6, 8}));
}
TEST(type_prop, dynreplaceslice_arg_rank_static_dynamic_replacement_static_params_static_ok)
{ {
auto arg = make_shared<op::Parameter>( auto arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8}); element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
auto replacement = make_shared<op::Parameter>(element::f32, Shape{2, 4, 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4}); auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4}); auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto strides = make_shared<op::Parameter>(element::i64, Shape{4}); auto strides = make_shared<op::Parameter>(element::i64, Shape{4});
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides); auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32); EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4))); EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(
PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8}));
} }
TEST(type_prop, dynslice_arg_static_params_rank_static_dynamic_ok) TEST(type_prop,
dynreplaceslice_arg_rank_static_dynamic_replacement_rank_static_dynamic_params_static_ok)
{
auto arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
auto replacement =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto strides = make_shared<op::Parameter>(element::i64, Shape{4});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(
PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8}));
}
TEST(type_prop, dynreplaceslice_arg_rank_static_dynamic_replacement_rank_dynamic_params_static_ok)
{
auto arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto strides = make_shared<op::Parameter>(element::i64, Shape{4});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(
PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8}));
}
TEST(type_prop, dynreplaceslice_arg_static_replacement_static_params_rank_static_dynamic_ok)
{ {
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8}); auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto replacement = make_shared<op::Parameter>(element::f32, Shape{2, 4, 2, 4});
auto lower_bounds = auto lower_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()}); make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto upper_bounds = auto upper_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()}); make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto strides = make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()}); auto strides = make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides); auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32); EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4))); EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{2, 4, 6, 8}));
} }
TEST(type_prop, dynslice_arg_rank_static_dynamic_params_rank_static_dynamic_ok) TEST(type_prop,
dynreplaceslice_arg_static_replacement_rank_static_dynamic_params_rank_static_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto replacement =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 2, 4});
auto lower_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto upper_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto strides = make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{2, 4, 6, 8}));
}
TEST(type_prop, dynreplaceslice_arg_static_replacement_rank_dynamic_params_rank_static_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto lower_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto upper_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto strides = make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{2, 4, 6, 8}));
}
TEST(type_prop,
dynreplaceslice_arg_rank_static_dynamic_replacement_static_params_rank_static_dynamic_ok)
{ {
auto arg = make_shared<op::Parameter>( auto arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8}); element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape{2, 4, 2, 4});
auto lower_bounds = auto lower_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()}); make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto upper_bounds = auto upper_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()}); make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto strides = make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()}); auto strides = make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides); auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32); EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4))); EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(
PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8}));
} }
TEST(type_prop, dynslice_arg_rank_dynamic_params_rank_static_dynamic_ok) TEST(
type_prop,
dynreplaceslice_arg_rank_static_dynamic_replacement_rank_static_dynamic_params_rank_static_dynamic_ok)
{ {
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic()); auto arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
auto replacement =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 2, 4});
auto lower_bounds = auto lower_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()}); make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto upper_bounds = auto upper_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()}); make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto strides = make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()}); auto strides = make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides); auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(
PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8}));
}
TEST(type_prop,
dynreplaceslice_arg_rank_static_dynamic_replacement_rank_dynamic_params_rank_static_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto lower_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto upper_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto strides = make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(
PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8}));
}
TEST(type_prop, dynreplaceslice_arg_rank_dynamic_replacement_static_params_static_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape{2, 4, 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape{4});
auto strides = make_shared<op::Parameter>(element::i64, PartialShape{4});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32); EXPECT_EQ(r->get_output_element_type(0), element::f32);
// TODO(amprocte): We should be able to infer PartialShape::dynamic(4) here.
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic())); EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
} }
TEST(type_prop, dynslice_arg_rank_dynamic_params_rank_dynamic_ok) TEST(type_prop, dynreplaceslice_arg_rank_dynamic_replacement_rank_static_dynamic_params_static_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto replacement =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape{4});
auto strides = make_shared<op::Parameter>(element::i64, PartialShape{4});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
// TODO(amprocte): We should be able to infer PartialShape::dynamic(4) here.
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, dynreplaceslice_arg_rank_dynamic_replacement_rank_dynamic_params_static_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape{4});
auto strides = make_shared<op::Parameter>(element::i64, PartialShape{4});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
// TODO(amprocte): We should be able to infer PartialShape::dynamic(4) here.
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, dynreplaceslice_arg_rank_dynamic_replacement_static_params_rank_static_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape{2, 4, 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto strides = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
// TODO(amprocte): We should be able to infer PartialShape::dynamic(4) here.
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop,
dynreplaceslice_arg_rank_dynamic_replacement_rank_static_dynamic_params_rank_static_dynamic_ok)
{ {
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic()); auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto replacement =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto strides = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
// TODO(amprocte): We should be able to infer PartialShape::dynamic(4) here.
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop,
dynreplaceslice_arg_rank_dynamic_replacement_rank_dynamic_params_rank_static_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto strides = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, dynreplaceslice_arg_rank_dynamic_replacement_static_params_rank_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape{2, 4, 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic()); auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic()); auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto strides = make_shared<op::Parameter>(element::i64, PartialShape::dynamic()); auto strides = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides); auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32); EXPECT_EQ(r->get_output_element_type(0), element::f32);
// TODO(amprocte): We should be able to infer PartialShape::dynamic(4) here.
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic())); EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
} }
TEST(type_prop, dynslice_arg_rank_static_dynamic_params_rank_dynamic_ok) TEST(type_prop,
dynreplaceslice_arg_rank_dynamic_replacement_rank_static_dynamic_params_rank_dynamic_ok)
{ {
auto arg = make_shared<op::Parameter>( auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8}); auto replacement =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic()); auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic()); auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto strides = make_shared<op::Parameter>(element::i64, PartialShape::dynamic()); auto strides = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides); auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32); EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4))); // TODO(amprocte): We should be able to infer PartialShape::dynamic(4) here.
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
} }
TEST(type_prop, dynslice_static_shape) TEST(type_prop, dynreplaceslice_arg_rank_dynamic_replacement_rank_dynamic_params_rank_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto strides = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, dynreplaceslice_static_shape)
{ {
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 5, 6}); auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 5, 6});
auto replacement = make_shared<op::Parameter>(element::f32, Shape{1, 2, 1, 1, 3});
auto lower_bounds = op::Constant::create(element::i64, Shape{5}, {0, 1, 2, 3, 1}); auto lower_bounds = op::Constant::create(element::i64, Shape{5}, {0, 1, 2, 3, 1});
auto upper_bounds = op::Constant::create(element::i64, Shape{5}, {1, 3, 3, 5, 6}); auto upper_bounds = op::Constant::create(element::i64, Shape{5}, {1, 3, 3, 5, 6});
auto strides = op::Constant::create(element::i64, Shape{5}, {1, 1, 1, 2, 2}); auto strides = op::Constant::create(element::i64, Shape{5}, {1, 1, 1, 2, 2});
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides); auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32); EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_EQ(r->get_shape(), (Shape{1, 2, 1, 1, 3})); EXPECT_EQ(r->get_shape(), (Shape{2, 3, 4, 5, 6}));
} }
struct DynSliceParams TEST(type_prop, dynreplaceslice_static_shape_replacement_inconsistent)
{ {
std::vector<Shape> shapes; auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 5, 6});
std::vector<std::vector<int64_t>> vals; auto replacement = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 1, 1, 4});
std::vector<AxisSet> attrs; auto lower_bounds = op::Constant::create(element::i64, Shape{5}, {0, 1, 2, 3, 1});
auto upper_bounds = op::Constant::create(element::i64, Shape{5}, {1, 3, 3, 5, 6});
auto strides = op::Constant::create(element::i64, Shape{5}, {1, 1, 1, 2, 2});
DynSliceParams(const std::vector<Shape>& shape, try
const std::vector<std::vector<int64_t>>& val, {
const std::vector<AxisSet>& attr) auto r =
: shapes(shape) make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
, vals(val) FAIL() << "Did not detect mismatch of replacement shape";
, attrs(attr) }
catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(
error.what(), "Shape of the replacement is not compatible with the shape of the slice");
} }
}
struct DynReplaceSliceParams
{
Shape arg_shape;
Shape lower_bounds_shape;
Shape upper_bounds_shape;
Shape strides_shape;
Shape replacement_shape;
std::vector<int64_t> lower_bounds_val;
std::vector<int64_t> upper_bounds_val;
std::vector<int64_t> strides_val;
AxisSet lower_bounds_mask;
AxisSet upper_bounds_mask;
AxisSet new_axis;
AxisSet shrink_axis;
AxisSet ellipsis_mask;
}; };
struct DeduceDynSliceTest : ::testing::TestWithParam<DynSliceParams> struct DeduceDynReplaceSliceTest : ::testing::TestWithParam<DynReplaceSliceParams>
{ {
}; };
TEST_P(DeduceDynSliceTest, output_shape) TEST_P(DeduceDynReplaceSliceTest, output_shape)
{ {
auto tp = GetParam(); auto tp = GetParam();
auto arg = make_shared<op::Parameter>(element::f32, tp.shapes[0]); auto arg = make_shared<op::Parameter>(element::f32, tp.arg_shape);
auto lower_bounds = op::Constant::create(element::i64, tp.shapes[1], tp.vals[0]); auto replacement = make_shared<op::Parameter>(element::f32, tp.replacement_shape);
auto upper_bounds = op::Constant::create(element::i64, tp.shapes[2], tp.vals[1]); auto lower_bounds =
auto strides = op::Constant::create(element::i64, tp.shapes[3], tp.vals[2]); op::Constant::create(element::i64, tp.lower_bounds_shape, tp.lower_bounds_val);
auto upper_bounds =
op::Constant::create(element::i64, tp.upper_bounds_shape, tp.upper_bounds_val);
auto strides = op::Constant::create(element::i64, tp.strides_shape, tp.strides_val);
auto r = make_shared<op::DynSlice>(arg, auto r = make_shared<op::DynReplaceSlice>(arg,
lower_bounds, replacement,
upper_bounds, lower_bounds,
strides, upper_bounds,
tp.attrs[0], strides,
tp.attrs[1], tp.lower_bounds_mask,
tp.attrs[2], tp.upper_bounds_mask,
tp.attrs[3], tp.new_axis,
tp.attrs[4]); tp.shrink_axis,
tp.ellipsis_mask);
EXPECT_EQ(r->get_shape(), tp.shapes[4]); EXPECT_EQ(r->get_shape(), tp.arg_shape);
} }
INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_CASE_P(
type_prop, type_prop,
DeduceDynSliceTest, DeduceDynReplaceSliceTest,
::testing::Values( ::testing::Values(
DynSliceParams({{2, 3, 4, 5, 6}, {5}, {5}, {5}, {1, 2, 1, 1, 3}}, DynReplaceSliceParams{{2, 3, 4, 5, 6},
{{0, 1, 2, 3, 1}, {1, 3, 3, 5, 6}, {1, 1, 1, 2, 2}}, {5},
{{}, {}, {}, {}, {}}), {5},
DynSliceParams({{10}, {0}, {0}, {0}, {10}}, {{}, {}, {}}, {{}, {}, {}, {}, {}}), {5},
DynSliceParams({{10}, {1}, {1}, {0}, {10}}, {1, 2, 1, 1, 3},
{{0}, {0}, {}}, {0, 1, 2, 3, 1},
{{}, {0}, {}, {}, {}}), // end-mask {1, 3, 3, 5, 6},
DynSliceParams({{10}, {1}, {1}, {0}, {9}}, {1, 1, 1, 2, 2},
{{-1}, {-1}, {}}, {},
{{0}, {}, {}, {}, {}}), // begin-mask {},
DynSliceParams({{10}, {1}, {1}, {0}, {10}}, {{0}, {10}, {}}, {{}, {}, {}, {}, {}}), {},
DynSliceParams({{10}, {1}, {1}, {0}, {5}}, {{5}, {10}, {}}, {{}, {}, {}, {}, {}}), {},
DynSliceParams({{10}, {1}, {1}, {0}, {5}}, {{-5}, {10}, {}}, {{}, {}, {}, {}, {}}), {}},
DynSliceParams({{10}, {1}, {1}, {1}, {6}}, DynReplaceSliceParams{{10}, {0}, {0}, {0}, {10}, {}, {}, {}, {}, {}, {}, {}, {}},
{{-5}, {0}, {-1}}, // negative-stride DynReplaceSliceParams{
{{}, {0}, {}, {}, {}}), {10}, {1}, {1}, {0}, {10}, {0}, {0}, {}, {}, {0}, {}, {}, {}}, // end-mask
DynSliceParams({{10}, {1}, {1}, {1}, {3}}, {{-5}, {2}, {-1}}, {{}, {}, {}, {}, {}}), DynReplaceSliceParams{
DynSliceParams({{10}, {1}, {1}, {1}, {5}}, {{0}, {0}, {2}}, {{}, {0}, {}, {}, {}}), {10}, {1}, {1}, {0}, {9}, {-1}, {-1}, {}, {0}, {}, {}, {}, {}}, // begin-mask
DynSliceParams({{10}, {1}, {1}, {1}, {5}}, {{1}, {0}, {2}}, {{}, {0}, {}, {}, {}}), DynReplaceSliceParams{{10}, {1}, {1}, {0}, {10}, {0}, {10}, {}, {}, {}, {}, {}, {}},
DynSliceParams({{10}, {1}, {1}, {1}, {10}}, {{-1}, {0}, {-1}}, {{}, {0}, {}, {}, {}}), DynReplaceSliceParams{{10}, {1}, {1}, {0}, {5}, {5}, {10}, {}, {}, {}, {}, {}, {}},
DynSliceParams({{10}, {1}, {1}, {1}, {5}}, {{-1}, {0}, {-2}}, {{}, {0}, {}, {}, {}}), DynReplaceSliceParams{{10}, {1}, {1}, {0}, {5}, {-5}, {10}, {}, {}, {}, {}, {}, {}},
DynReplaceSliceParams{{10},
{1},
{1},
{1},
{6},
{-5},
{0},
{-1}, // negative-stride
{},
{0},
{},
{},
{}},
DynReplaceSliceParams{{10}, {1}, {1}, {1}, {3}, {-5}, {2}, {-1}, {}, {}, {}, {}, {}},
DynReplaceSliceParams{{10}, {1}, {1}, {1}, {5}, {0}, {0}, {2}, {}, {0}, {}, {}, {}},
DynReplaceSliceParams{{10}, {1}, {1}, {1}, {5}, {1}, {0}, {2}, {}, {0}, {}, {}, {}},
DynReplaceSliceParams{{10}, {1}, {1}, {1}, {10}, {-1}, {0}, {-1}, {}, {0}, {}, {}, {}},
DynReplaceSliceParams{{10}, {1}, {1}, {1}, {5}, {-1}, {0}, {-2}, {}, {0}, {}, {}, {}},
/* Axis Masks: New, Shrink, Ellipsis */ /* Axis Masks: New, Shrink, Ellipsis */
DynSliceParams({{10}, {1}, {1}, {0}, {1, 10}}, {{0}, {10}, {}}, {{}, {}, {0}, {}, {}}), DynReplaceSliceParams{{10}, {1}, {1}, {0}, {1, 10}, {0}, {10}, {}, {}, {}, {0}, {}, {}},
DynSliceParams({{1, 2, 3}, {2}, {2}, {0}, {1, 2, 2}}, DynReplaceSliceParams{
{{0, 0}, {1, 2}, {}}, {1, 2, 3}, {2}, {2}, {0}, {1, 2, 2}, {0, 0}, {1, 2}, {}, {}, {}, {}, {}, {1}},
{{}, {}, {}, {}, {1}}), DynReplaceSliceParams{{1, 2, 3},
DynSliceParams({{1, 2, 3}, {4}, {4}, {0}, {1, 2, 1}}, {4},
{{0, 0, 0, 1}, {2, 3, 2, 2}, {}}, {4},
{{}, {}, {2}, {3}, {}}), {0},
DynSliceParams({{1, 2, 3}, {3}, {3}, {0}, {1, 1, 2, 1}}, {1, 2, 1},
{{0, 0, 1}, {2, 2, 2}, {}}, {0, 0, 0, 1},
{{}, {}, {0}, {}, {1}}), {2, 3, 2, 2},
DynSliceParams({{1, 2, 2, 2}, {1}, {1}, {1}, {1, 2, 2}}, {},
{{-1}, {0}, {-2}}, {},
{{1}, {1}, {}, {1}, {}}), {},
DynSliceParams({{1, 2, 2, 2}, {4}, {4}, {0}, {1, 2, 2}}, {2},
{{0, 1, 0, 0}, {1, 2, 2, 2}, {}}, {3},
{{1}, {1}, {}, {1}, {}}), {}},
DynSliceParams({{1, 2, 3}, {3}, {3}, {0}, {1, 1, 2}}, DynReplaceSliceParams{
{{0, 0, 1}, {2, 2, 2}, {}}, {1, 2, 3}, {3}, {3}, {0}, {1, 1, 2, 1}, {0, 0, 1}, {2, 2, 2}, {}, {}, {}, {0}, {}, {1}},
{{}, {}, {0}, {2}, {1}}))); DynReplaceSliceParams{
{1, 2, 2, 2}, {1}, {1}, {1}, {1, 2, 2}, {-1}, {0}, {-2}, {1}, {1}, {}, {1}, {}},
void DynSlice_Test_Shape_Except(const shared_ptr<Node>& param_0, DynReplaceSliceParams{{1, 2, 2, 2},
const shared_ptr<Node>& param_1, {4},
const shared_ptr<Node>& param_2, {4},
const shared_ptr<Node>& param_3) {0},
{ {1, 2, 2},
try {0, 1, 0, 0},
{ {1, 2, 2, 2},
auto r = make_shared<op::DynSlice>(param_0, param_1, param_2, param_3); {},
FAIL() << "Did not detect input order not vector"; {1},
{1},
{},
{1},
{}},
DynReplaceSliceParams{
{1, 2, 3}, {3}, {3}, {0}, {1, 1, 2}, {0, 0, 1}, {2, 2, 2}, {}, {}, {}, {0}, {2}, {1}}));
void DynReplaceSlice_Test_Shape_Except(const shared_ptr<Node>& param_0,
const shared_ptr<Node>& param_1,
const shared_ptr<Node>& param_2,
const shared_ptr<Node>& param_3,
const shared_ptr<Node>& param_4)
{
try
{
auto r = make_shared<op::DynReplaceSlice>(param_0, param_1, param_2, param_3, param_4);
FAIL() << "Did not detect attributes not vector";
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
...@@ -15123,9 +15447,10 @@ void DynSlice_Test_Shape_Except(const shared_ptr<Node>& param_0, ...@@ -15123,9 +15447,10 @@ void DynSlice_Test_Shape_Except(const shared_ptr<Node>& param_0,
} }
} }
TEST(type_prop, dynslice_arg_static_params_rank_static_dynamic_not_vector) TEST(type_prop, dynreplaceslice_arg_static_replacement_static_params_rank_static_dynamic_not_vector)
{ {
auto arg = make_shared<op::Parameter>(element::f32, PartialShape{2, 4, 6, 8}); auto arg = make_shared<op::Parameter>(element::f32, PartialShape{2, 4, 6, 8});
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape{2, 4, 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic()); auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic()); auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto strides = make_shared<op::Parameter>(element::i64, PartialShape::dynamic()); auto strides = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
...@@ -15133,74 +15458,98 @@ TEST(type_prop, dynslice_arg_static_params_rank_static_dynamic_not_vector) ...@@ -15133,74 +15458,98 @@ TEST(type_prop, dynslice_arg_static_params_rank_static_dynamic_not_vector)
{ {
lower_bounds = lower_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()}); make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides); DynReplaceSlice_Test_Shape_Except(arg, replacement, lower_bounds, upper_bounds, strides);
} }
{ {
lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape{2, 2}); lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape{2, 2});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides); DynReplaceSlice_Test_Shape_Except(arg, replacement, lower_bounds, upper_bounds, strides);
} }
{ {
arg = make_shared<op::Parameter>( arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8}); element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
lower_bounds = lower_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()}); make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides); DynReplaceSlice_Test_Shape_Except(arg, replacement, lower_bounds, upper_bounds, strides);
} }
{ {
upper_bounds = upper_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()}); make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides); DynReplaceSlice_Test_Shape_Except(arg, replacement, lower_bounds, upper_bounds, strides);
} }
{ {
upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape{2, 2}); upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape{2, 2});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides); DynReplaceSlice_Test_Shape_Except(arg, replacement, lower_bounds, upper_bounds, strides);
} }
{ {
arg = make_shared<op::Parameter>( arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8}); element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
upper_bounds = upper_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()}); make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides); DynReplaceSlice_Test_Shape_Except(arg, replacement, lower_bounds, upper_bounds, strides);
} }
{ {
strides = make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()}); strides = make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides); DynReplaceSlice_Test_Shape_Except(arg, replacement, lower_bounds, upper_bounds, strides);
} }
{ {
strides = make_shared<op::Parameter>(element::i64, PartialShape{2, 2}); strides = make_shared<op::Parameter>(element::i64, PartialShape{2, 2});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides); DynReplaceSlice_Test_Shape_Except(arg, replacement, lower_bounds, upper_bounds, strides);
} }
{ {
arg = make_shared<op::Parameter>( arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8}); element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
strides = make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()}); strides = make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides); DynReplaceSlice_Test_Shape_Except(arg, replacement, lower_bounds, upper_bounds, strides);
}
{
replacement =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 2, 4});
strides = make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynReplaceSlice_Test_Shape_Except(arg, replacement, lower_bounds, upper_bounds, strides);
} }
} }
TEST(type_prop, dynslice_params_et_dynamic_ok) TEST(type_prop, dynreplaceslice_params_et_dynamic_ok)
{ {
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8}); auto arg = make_shared<op::Parameter>(element::dynamic, Shape{2, 4, 6, 8});
auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4}); auto replacement = make_shared<op::Parameter>(element::dynamic, Shape{2, 4, 2, 4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4}); auto lower_bounds = make_shared<op::Parameter>(element::dynamic, Shape{4});
auto strides = make_shared<op::Parameter>(element::i64, Shape{4}); auto upper_bounds = make_shared<op::Parameter>(element::dynamic, Shape{4});
auto strides = make_shared<op::Parameter>(element::dynamic, Shape{4});
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides); auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32); EXPECT_EQ(r->get_output_element_type(0), element::dynamic);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4))); EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{2, 4, 6, 8}));
} }
void DynSlice_Test_Type_Except(const shared_ptr<Node>& param_0, TEST(type_prop, dynreplaceslice_params_et_dynamic_inferrable_ok)
const shared_ptr<Node>& param_1, {
const shared_ptr<Node>& param_2, auto arg = make_shared<op::Parameter>(element::dynamic, Shape{2, 4, 6, 8});
const shared_ptr<Node>& param_3) auto replacement = make_shared<op::Parameter>(element::boolean, Shape{2, 4, 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::dynamic, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::dynamic, Shape{4});
auto strides = make_shared<op::Parameter>(element::dynamic, Shape{4});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::boolean);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{2, 4, 6, 8}));
}
void DynReplaceSlice_Test_Type_Except(const shared_ptr<Node>& param_0,
const shared_ptr<Node>& param_1,
const shared_ptr<Node>& param_2,
const shared_ptr<Node>& param_3,
const shared_ptr<Node>& param_4)
{ {
try try
{ {
auto r = make_shared<op::DynSlice>(param_0, param_1, param_2, param_3); auto r = make_shared<op::DynReplaceSlice>(param_0, param_1, param_2, param_3, param_4);
FAIL() << "Did not detect parameter element type not i64"; FAIL() << "Did not detect parameter element type not i64";
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
...@@ -15213,9 +15562,10 @@ void DynSlice_Test_Type_Except(const shared_ptr<Node>& param_0, ...@@ -15213,9 +15562,10 @@ void DynSlice_Test_Type_Except(const shared_ptr<Node>& param_0,
} }
} }
TEST(type_prop, dynslice_params_et_wrong) TEST(type_prop, dynreplaceslice_params_et_wrong)
{ {
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8}); auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto replacement = make_shared<op::Parameter>(element::f32, Shape{2, 4, 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4}); auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4}); auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
...@@ -15223,15 +15573,14 @@ TEST(type_prop, dynslice_params_et_wrong) ...@@ -15223,15 +15573,14 @@ TEST(type_prop, dynslice_params_et_wrong)
{ {
lower_bounds = make_shared<op::Parameter>(element::boolean, Shape{4}); lower_bounds = make_shared<op::Parameter>(element::boolean, Shape{4});
DynSlice_Test_Type_Except(arg, lower_bounds, upper_bounds, strides); DynReplaceSlice_Test_Type_Except(arg, replacement, lower_bounds, upper_bounds, strides);
} }
{ {
upper_bounds = make_shared<op::Parameter>(element::boolean, Shape{4}); upper_bounds = make_shared<op::Parameter>(element::boolean, Shape{4});
DynSlice_Test_Type_Except(arg, lower_bounds, upper_bounds, strides); DynReplaceSlice_Test_Type_Except(arg, replacement, lower_bounds, upper_bounds, strides);
} }
{ {
strides = make_shared<op::Parameter>(element::boolean, Shape{4}); strides = make_shared<op::Parameter>(element::boolean, Shape{4});
DynSlice_Test_Type_Except(arg, lower_bounds, upper_bounds, strides); DynReplaceSlice_Test_Type_Except(arg, replacement, lower_bounds, upper_bounds, strides);
} }
} }
#endif /* 0 */
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment