Commit 85582d0c authored by Adam Procter's avatar Adam Procter

Moar tests

parent c99d65a0
......@@ -91,6 +91,17 @@ void op::DynReplaceSlice::validate_and_infer_types()
strides_shape.rank(),
".");
PartialShape attrs_shape{PartialShape::dynamic()};
NODE_VALIDATION_CHECK(this,
(lower_bounds_shape.same_scheme(PartialShape{0}) ||
PartialShape::merge_into(attrs_shape, lower_bounds_shape)) &&
(upper_bounds_shape.same_scheme(PartialShape{0}) ||
PartialShape::merge_into(attrs_shape, upper_bounds_shape)) &&
(strides_shape.same_scheme(PartialShape{0}) ||
PartialShape::merge_into(attrs_shape, strides_shape)),
"Shapes for lower bounds, upper bounds, and strides do not match");
set_input_is_relevant_to_shape(2);
set_input_is_relevant_to_shape(3);
set_input_is_relevant_to_shape(4);
......@@ -99,32 +110,32 @@ void op::DynReplaceSlice::validate_and_infer_types()
auto upper_bounds = dynamic_pointer_cast<op::Constant>(get_argument(3));
auto strides = dynamic_pointer_cast<op::Constant>(get_argument(4));
// TODO(amprocte): We can get a bit more information here about the ranks of arg and
// replacement by inspecting the attributes.
auto slice_shape = PartialShape::dynamic();
if (lower_bounds && upper_bounds && strides)
{
auto inferred_slice_shape = infer_slice_shape(this,
get_input_partial_shape(0),
lower_bounds->get_vector<int64_t>(),
upper_bounds->get_vector<int64_t>(),
strides->get_vector<int64_t>(),
m_lower_bounds_mask,
m_upper_bounds_mask,
m_new_axis,
m_shrink_axis,
m_ellipsis_mask);
NODE_VALIDATION_CHECK(this,
replacement_shape.compatible(inferred_slice_shape),
"Shape of the replacement is not compatible with the shape of the "
"slice (shape of slice: ",
inferred_slice_shape,
")");
slice_shape = infer_slice_shape(this,
get_input_partial_shape(0),
lower_bounds->get_vector<int64_t>(),
upper_bounds->get_vector<int64_t>(),
strides->get_vector<int64_t>(),
m_lower_bounds_mask,
m_upper_bounds_mask,
m_new_axis,
m_shrink_axis,
m_ellipsis_mask);
}
PartialShape output_shape = arg_shape;
NODE_VALIDATION_CHECK(
this,
PartialShape::merge_into(output_shape, PartialShape::dynamic(replacement_shape.rank())),
"Rank of the replacement is not compatible with rank of the argument tensor");
set_output_type(0, get_input_element_type(0), output_shape);
NODE_VALIDATION_CHECK(this,
slice_shape.compatible(replacement_shape),
"Shape of the replacement is not compatible with the shape of the "
"slice (shape of slice: ",
slice_shape,
")");
set_output_type(0, result_et, arg_shape);
}
shared_ptr<Node> op::DynReplaceSlice::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -626,7 +626,6 @@ PartialShape ngraph::infer_slice_shape(const Node* node,
const AxisSet& shrink_axis,
const AxisSet& ellipsis_mask)
{
// TODO(amprocte): double-check that these checks are needed.
if (lb.size() && ub.size())
{
NODE_VALIDATION_CHECK(node,
......
......@@ -14900,10 +14900,10 @@ TEST(type_prop, fake_quantize_invalid_rank)
}
}
TEST(type_prop, dynreplaceslice_arg_static_params_static_ok)
TEST(type_prop, dynreplaceslice_arg_static_replacement_static_params_static_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto replacement = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto replacement = make_shared<op::Parameter>(element::f32, Shape{2, 4, 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto strides = make_shared<op::Parameter>(element::i64, Shape{4});
......@@ -14915,203 +14915,527 @@ TEST(type_prop, dynreplaceslice_arg_static_params_static_ok)
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{2, 4, 6, 8}));
}
#if 0
TEST(type_prop, dynslice_arg_rank_static_dynamic_params_static_ok)
TEST(type_prop, dynreplaceslice_arg_static_replacement_rank_static_dynamic_params_static_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto replacement =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto strides = make_shared<op::Parameter>(element::i64, Shape{4});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{2, 4, 6, 8}));
}
TEST(type_prop, dynreplaceslice_arg_static_replacement_rank_dynamic_params_static_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto strides = make_shared<op::Parameter>(element::i64, Shape{4});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{2, 4, 6, 8}));
}
TEST(type_prop, dynreplaceslice_arg_rank_static_dynamic_replacement_static_params_static_ok)
{
auto arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
auto replacement = make_shared<op::Parameter>(element::f32, Shape{2, 4, 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto strides = make_shared<op::Parameter>(element::i64, Shape{4});
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides);
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(
PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8}));
}
TEST(type_prop, dynslice_arg_static_params_rank_static_dynamic_ok)
TEST(type_prop,
dynreplaceslice_arg_rank_static_dynamic_replacement_rank_static_dynamic_params_static_ok)
{
auto arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
auto replacement =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto strides = make_shared<op::Parameter>(element::i64, Shape{4});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(
PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8}));
}
TEST(type_prop, dynreplaceslice_arg_rank_static_dynamic_replacement_rank_dynamic_params_static_ok)
{
auto arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto strides = make_shared<op::Parameter>(element::i64, Shape{4});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(
PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8}));
}
TEST(type_prop, dynreplaceslice_arg_static_replacement_static_params_rank_static_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto replacement = make_shared<op::Parameter>(element::f32, Shape{2, 4, 2, 4});
auto lower_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto upper_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto strides = make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides);
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{2, 4, 6, 8}));
}
TEST(type_prop, dynslice_arg_rank_static_dynamic_params_rank_static_dynamic_ok)
TEST(type_prop,
dynreplaceslice_arg_static_replacement_rank_static_dynamic_params_rank_static_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto replacement =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 2, 4});
auto lower_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto upper_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto strides = make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{2, 4, 6, 8}));
}
TEST(type_prop, dynreplaceslice_arg_static_replacement_rank_dynamic_params_rank_static_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto lower_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto upper_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto strides = make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{2, 4, 6, 8}));
}
TEST(type_prop,
dynreplaceslice_arg_rank_static_dynamic_replacement_static_params_rank_static_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape{2, 4, 2, 4});
auto lower_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto upper_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto strides = make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides);
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(
PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8}));
}
TEST(type_prop, dynslice_arg_rank_dynamic_params_rank_static_dynamic_ok)
TEST(
type_prop,
dynreplaceslice_arg_rank_static_dynamic_replacement_rank_static_dynamic_params_rank_static_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
auto replacement =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 2, 4});
auto lower_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto upper_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto strides = make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides);
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(
PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8}));
}
TEST(type_prop,
dynreplaceslice_arg_rank_static_dynamic_replacement_rank_dynamic_params_rank_static_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto lower_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto upper_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto strides = make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(
PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8}));
}
TEST(type_prop, dynreplaceslice_arg_rank_dynamic_replacement_static_params_static_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape{2, 4, 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape{4});
auto strides = make_shared<op::Parameter>(element::i64, PartialShape{4});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
// TODO(amprocte): We should be able to infer PartialShape::dynamic(4) here.
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, dynslice_arg_rank_dynamic_params_rank_dynamic_ok)
TEST(type_prop, dynreplaceslice_arg_rank_dynamic_replacement_rank_static_dynamic_params_static_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto replacement =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape{4});
auto strides = make_shared<op::Parameter>(element::i64, PartialShape{4});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
// TODO(amprocte): We should be able to infer PartialShape::dynamic(4) here.
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, dynreplaceslice_arg_rank_dynamic_replacement_rank_dynamic_params_static_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape{4});
auto strides = make_shared<op::Parameter>(element::i64, PartialShape{4});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
// TODO(amprocte): We should be able to infer PartialShape::dynamic(4) here.
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, dynreplaceslice_arg_rank_dynamic_replacement_static_params_rank_static_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape{2, 4, 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto strides = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
// TODO(amprocte): We should be able to infer PartialShape::dynamic(4) here.
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop,
dynreplaceslice_arg_rank_dynamic_replacement_rank_static_dynamic_params_rank_static_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto replacement =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto strides = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
// TODO(amprocte): We should be able to infer PartialShape::dynamic(4) here.
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop,
dynreplaceslice_arg_rank_dynamic_replacement_rank_dynamic_params_rank_static_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto strides = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, dynreplaceslice_arg_rank_dynamic_replacement_static_params_rank_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape{2, 4, 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto strides = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides);
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
// TODO(amprocte): We should be able to infer PartialShape::dynamic(4) here.
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, dynslice_arg_rank_static_dynamic_params_rank_dynamic_ok)
TEST(type_prop,
dynreplaceslice_arg_rank_dynamic_replacement_rank_static_dynamic_params_rank_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto replacement =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto strides = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides);
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
// TODO(amprocte): We should be able to infer PartialShape::dynamic(4) here.
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, dynslice_static_shape)
TEST(type_prop, dynreplaceslice_arg_rank_dynamic_replacement_rank_dynamic_params_rank_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto strides = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, dynreplaceslice_static_shape)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 5, 6});
auto replacement = make_shared<op::Parameter>(element::f32, Shape{1, 2, 1, 1, 3});
auto lower_bounds = op::Constant::create(element::i64, Shape{5}, {0, 1, 2, 3, 1});
auto upper_bounds = op::Constant::create(element::i64, Shape{5}, {1, 3, 3, 5, 6});
auto strides = op::Constant::create(element::i64, Shape{5}, {1, 1, 1, 2, 2});
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides);
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_EQ(r->get_shape(), (Shape{1, 2, 1, 1, 3}));
EXPECT_EQ(r->get_shape(), (Shape{2, 3, 4, 5, 6}));
}
struct DynSliceParams
TEST(type_prop, dynreplaceslice_static_shape_replacement_inconsistent)
{
std::vector<Shape> shapes;
std::vector<std::vector<int64_t>> vals;
std::vector<AxisSet> attrs;
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 5, 6});
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 1, 1, 4});
auto lower_bounds = op::Constant::create(element::i64, Shape{5}, {0, 1, 2, 3, 1});
auto upper_bounds = op::Constant::create(element::i64, Shape{5}, {1, 3, 3, 5, 6});
auto strides = op::Constant::create(element::i64, Shape{5}, {1, 1, 1, 2, 2});
DynSliceParams(const std::vector<Shape>& shape,
const std::vector<std::vector<int64_t>>& val,
const std::vector<AxisSet>& attr)
: shapes(shape)
, vals(val)
, attrs(attr)
try
{
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
FAIL() << "Did not detect mismatch of replacement shape";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(), "Shape of the replacement is not compatible with the shape of the slice");
}
}
struct DynReplaceSliceParams
{
Shape arg_shape;
Shape lower_bounds_shape;
Shape upper_bounds_shape;
Shape strides_shape;
Shape replacement_shape;
std::vector<int64_t> lower_bounds_val;
std::vector<int64_t> upper_bounds_val;
std::vector<int64_t> strides_val;
AxisSet lower_bounds_mask;
AxisSet upper_bounds_mask;
AxisSet new_axis;
AxisSet shrink_axis;
AxisSet ellipsis_mask;
};
struct DeduceDynSliceTest : ::testing::TestWithParam<DynSliceParams>
struct DeduceDynReplaceSliceTest : ::testing::TestWithParam<DynReplaceSliceParams>
{
};
TEST_P(DeduceDynSliceTest, output_shape)
TEST_P(DeduceDynReplaceSliceTest, output_shape)
{
auto tp = GetParam();
auto arg = make_shared<op::Parameter>(element::f32, tp.shapes[0]);
auto lower_bounds = op::Constant::create(element::i64, tp.shapes[1], tp.vals[0]);
auto upper_bounds = op::Constant::create(element::i64, tp.shapes[2], tp.vals[1]);
auto strides = op::Constant::create(element::i64, tp.shapes[3], tp.vals[2]);
auto arg = make_shared<op::Parameter>(element::f32, tp.arg_shape);
auto replacement = make_shared<op::Parameter>(element::f32, tp.replacement_shape);
auto lower_bounds =
op::Constant::create(element::i64, tp.lower_bounds_shape, tp.lower_bounds_val);
auto upper_bounds =
op::Constant::create(element::i64, tp.upper_bounds_shape, tp.upper_bounds_val);
auto strides = op::Constant::create(element::i64, tp.strides_shape, tp.strides_val);
auto r = make_shared<op::DynSlice>(arg,
lower_bounds,
upper_bounds,
strides,
tp.attrs[0],
tp.attrs[1],
tp.attrs[2],
tp.attrs[3],
tp.attrs[4]);
auto r = make_shared<op::DynReplaceSlice>(arg,
replacement,
lower_bounds,
upper_bounds,
strides,
tp.lower_bounds_mask,
tp.upper_bounds_mask,
tp.new_axis,
tp.shrink_axis,
tp.ellipsis_mask);
EXPECT_EQ(r->get_shape(), tp.shapes[4]);
EXPECT_EQ(r->get_shape(), tp.arg_shape);
}
INSTANTIATE_TEST_CASE_P(
type_prop,
DeduceDynSliceTest,
DeduceDynReplaceSliceTest,
::testing::Values(
DynSliceParams({{2, 3, 4, 5, 6}, {5}, {5}, {5}, {1, 2, 1, 1, 3}},
{{0, 1, 2, 3, 1}, {1, 3, 3, 5, 6}, {1, 1, 1, 2, 2}},
{{}, {}, {}, {}, {}}),
DynSliceParams({{10}, {0}, {0}, {0}, {10}}, {{}, {}, {}}, {{}, {}, {}, {}, {}}),
DynSliceParams({{10}, {1}, {1}, {0}, {10}},
{{0}, {0}, {}},
{{}, {0}, {}, {}, {}}), // end-mask
DynSliceParams({{10}, {1}, {1}, {0}, {9}},
{{-1}, {-1}, {}},
{{0}, {}, {}, {}, {}}), // begin-mask
DynSliceParams({{10}, {1}, {1}, {0}, {10}}, {{0}, {10}, {}}, {{}, {}, {}, {}, {}}),
DynSliceParams({{10}, {1}, {1}, {0}, {5}}, {{5}, {10}, {}}, {{}, {}, {}, {}, {}}),
DynSliceParams({{10}, {1}, {1}, {0}, {5}}, {{-5}, {10}, {}}, {{}, {}, {}, {}, {}}),
DynSliceParams({{10}, {1}, {1}, {1}, {6}},
{{-5}, {0}, {-1}}, // negative-stride
{{}, {0}, {}, {}, {}}),
DynSliceParams({{10}, {1}, {1}, {1}, {3}}, {{-5}, {2}, {-1}}, {{}, {}, {}, {}, {}}),
DynSliceParams({{10}, {1}, {1}, {1}, {5}}, {{0}, {0}, {2}}, {{}, {0}, {}, {}, {}}),
DynSliceParams({{10}, {1}, {1}, {1}, {5}}, {{1}, {0}, {2}}, {{}, {0}, {}, {}, {}}),
DynSliceParams({{10}, {1}, {1}, {1}, {10}}, {{-1}, {0}, {-1}}, {{}, {0}, {}, {}, {}}),
DynSliceParams({{10}, {1}, {1}, {1}, {5}}, {{-1}, {0}, {-2}}, {{}, {0}, {}, {}, {}}),
DynReplaceSliceParams{{2, 3, 4, 5, 6},
{5},
{5},
{5},
{1, 2, 1, 1, 3},
{0, 1, 2, 3, 1},
{1, 3, 3, 5, 6},
{1, 1, 1, 2, 2},
{},
{},
{},
{},
{}},
DynReplaceSliceParams{{10}, {0}, {0}, {0}, {10}, {}, {}, {}, {}, {}, {}, {}, {}},
DynReplaceSliceParams{
{10}, {1}, {1}, {0}, {10}, {0}, {0}, {}, {}, {0}, {}, {}, {}}, // end-mask
DynReplaceSliceParams{
{10}, {1}, {1}, {0}, {9}, {-1}, {-1}, {}, {0}, {}, {}, {}, {}}, // begin-mask
DynReplaceSliceParams{{10}, {1}, {1}, {0}, {10}, {0}, {10}, {}, {}, {}, {}, {}, {}},
DynReplaceSliceParams{{10}, {1}, {1}, {0}, {5}, {5}, {10}, {}, {}, {}, {}, {}, {}},
DynReplaceSliceParams{{10}, {1}, {1}, {0}, {5}, {-5}, {10}, {}, {}, {}, {}, {}, {}},
DynReplaceSliceParams{{10},
{1},
{1},
{1},
{6},
{-5},
{0},
{-1}, // negative-stride
{},
{0},
{},
{},
{}},
DynReplaceSliceParams{{10}, {1}, {1}, {1}, {3}, {-5}, {2}, {-1}, {}, {}, {}, {}, {}},
DynReplaceSliceParams{{10}, {1}, {1}, {1}, {5}, {0}, {0}, {2}, {}, {0}, {}, {}, {}},
DynReplaceSliceParams{{10}, {1}, {1}, {1}, {5}, {1}, {0}, {2}, {}, {0}, {}, {}, {}},
DynReplaceSliceParams{{10}, {1}, {1}, {1}, {10}, {-1}, {0}, {-1}, {}, {0}, {}, {}, {}},
DynReplaceSliceParams{{10}, {1}, {1}, {1}, {5}, {-1}, {0}, {-2}, {}, {0}, {}, {}, {}},
/* Axis Masks: New, Shrink, Ellipsis */
DynSliceParams({{10}, {1}, {1}, {0}, {1, 10}}, {{0}, {10}, {}}, {{}, {}, {0}, {}, {}}),
DynSliceParams({{1, 2, 3}, {2}, {2}, {0}, {1, 2, 2}},
{{0, 0}, {1, 2}, {}},
{{}, {}, {}, {}, {1}}),
DynSliceParams({{1, 2, 3}, {4}, {4}, {0}, {1, 2, 1}},
{{0, 0, 0, 1}, {2, 3, 2, 2}, {}},
{{}, {}, {2}, {3}, {}}),
DynSliceParams({{1, 2, 3}, {3}, {3}, {0}, {1, 1, 2, 1}},
{{0, 0, 1}, {2, 2, 2}, {}},
{{}, {}, {0}, {}, {1}}),
DynSliceParams({{1, 2, 2, 2}, {1}, {1}, {1}, {1, 2, 2}},
{{-1}, {0}, {-2}},
{{1}, {1}, {}, {1}, {}}),
DynSliceParams({{1, 2, 2, 2}, {4}, {4}, {0}, {1, 2, 2}},
{{0, 1, 0, 0}, {1, 2, 2, 2}, {}},
{{1}, {1}, {}, {1}, {}}),
DynSliceParams({{1, 2, 3}, {3}, {3}, {0}, {1, 1, 2}},
{{0, 0, 1}, {2, 2, 2}, {}},
{{}, {}, {0}, {2}, {1}})));
void DynSlice_Test_Shape_Except(const shared_ptr<Node>& param_0,
const shared_ptr<Node>& param_1,
const shared_ptr<Node>& param_2,
const shared_ptr<Node>& param_3)
{
try
{
auto r = make_shared<op::DynSlice>(param_0, param_1, param_2, param_3);
FAIL() << "Did not detect input order not vector";
DynReplaceSliceParams{{10}, {1}, {1}, {0}, {1, 10}, {0}, {10}, {}, {}, {}, {0}, {}, {}},
DynReplaceSliceParams{
{1, 2, 3}, {2}, {2}, {0}, {1, 2, 2}, {0, 0}, {1, 2}, {}, {}, {}, {}, {}, {1}},
DynReplaceSliceParams{{1, 2, 3},
{4},
{4},
{0},
{1, 2, 1},
{0, 0, 0, 1},
{2, 3, 2, 2},
{},
{},
{},
{2},
{3},
{}},
DynReplaceSliceParams{
{1, 2, 3}, {3}, {3}, {0}, {1, 1, 2, 1}, {0, 0, 1}, {2, 2, 2}, {}, {}, {}, {0}, {}, {1}},
DynReplaceSliceParams{
{1, 2, 2, 2}, {1}, {1}, {1}, {1, 2, 2}, {-1}, {0}, {-2}, {1}, {1}, {}, {1}, {}},
DynReplaceSliceParams{{1, 2, 2, 2},
{4},
{4},
{0},
{1, 2, 2},
{0, 1, 0, 0},
{1, 2, 2, 2},
{},
{1},
{1},
{},
{1},
{}},
DynReplaceSliceParams{
{1, 2, 3}, {3}, {3}, {0}, {1, 1, 2}, {0, 0, 1}, {2, 2, 2}, {}, {}, {}, {0}, {2}, {1}}));
void DynReplaceSlice_Test_Shape_Except(const shared_ptr<Node>& param_0,
const shared_ptr<Node>& param_1,
const shared_ptr<Node>& param_2,
const shared_ptr<Node>& param_3,
const shared_ptr<Node>& param_4)
{
try
{
auto r = make_shared<op::DynReplaceSlice>(param_0, param_1, param_2, param_3, param_4);
FAIL() << "Did not detect attributes not vector";
}
catch (const NodeValidationFailure& error)
{
......@@ -15123,9 +15447,10 @@ void DynSlice_Test_Shape_Except(const shared_ptr<Node>& param_0,
}
}
TEST(type_prop, dynslice_arg_static_params_rank_static_dynamic_not_vector)
TEST(type_prop, dynreplaceslice_arg_static_replacement_static_params_rank_static_dynamic_not_vector)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape{2, 4, 6, 8});
auto replacement = make_shared<op::Parameter>(element::f32, PartialShape{2, 4, 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto strides = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
......@@ -15133,74 +15458,98 @@ TEST(type_prop, dynslice_arg_static_params_rank_static_dynamic_not_vector)
{
lower_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides);
DynReplaceSlice_Test_Shape_Except(arg, replacement, lower_bounds, upper_bounds, strides);
}
{
lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape{2, 2});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides);
DynReplaceSlice_Test_Shape_Except(arg, replacement, lower_bounds, upper_bounds, strides);
}
{
arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
lower_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides);
DynReplaceSlice_Test_Shape_Except(arg, replacement, lower_bounds, upper_bounds, strides);
}
{
upper_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides);
DynReplaceSlice_Test_Shape_Except(arg, replacement, lower_bounds, upper_bounds, strides);
}
{
upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape{2, 2});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides);
DynReplaceSlice_Test_Shape_Except(arg, replacement, lower_bounds, upper_bounds, strides);
}
{
arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
upper_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides);
DynReplaceSlice_Test_Shape_Except(arg, replacement, lower_bounds, upper_bounds, strides);
}
{
strides = make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides);
DynReplaceSlice_Test_Shape_Except(arg, replacement, lower_bounds, upper_bounds, strides);
}
{
strides = make_shared<op::Parameter>(element::i64, PartialShape{2, 2});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides);
DynReplaceSlice_Test_Shape_Except(arg, replacement, lower_bounds, upper_bounds, strides);
}
{
arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
strides = make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides);
DynReplaceSlice_Test_Shape_Except(arg, replacement, lower_bounds, upper_bounds, strides);
}
{
replacement =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 2, 4});
strides = make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynReplaceSlice_Test_Shape_Except(arg, replacement, lower_bounds, upper_bounds, strides);
}
}
TEST(type_prop, dynslice_params_et_dynamic_ok)
TEST(type_prop, dynreplaceslice_params_et_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto strides = make_shared<op::Parameter>(element::i64, Shape{4});
auto arg = make_shared<op::Parameter>(element::dynamic, Shape{2, 4, 6, 8});
auto replacement = make_shared<op::Parameter>(element::dynamic, Shape{2, 4, 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::dynamic, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::dynamic, Shape{4});
auto strides = make_shared<op::Parameter>(element::dynamic, Shape{4});
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides);
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
EXPECT_EQ(r->get_output_element_type(0), element::dynamic);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{2, 4, 6, 8}));
}
void DynSlice_Test_Type_Except(const shared_ptr<Node>& param_0,
const shared_ptr<Node>& param_1,
const shared_ptr<Node>& param_2,
const shared_ptr<Node>& param_3)
TEST(type_prop, dynreplaceslice_params_et_dynamic_inferrable_ok)
{
auto arg = make_shared<op::Parameter>(element::dynamic, Shape{2, 4, 6, 8});
auto replacement = make_shared<op::Parameter>(element::boolean, Shape{2, 4, 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::dynamic, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::dynamic, Shape{4});
auto strides = make_shared<op::Parameter>(element::dynamic, Shape{4});
auto r =
make_shared<op::DynReplaceSlice>(arg, replacement, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::boolean);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{2, 4, 6, 8}));
}
void DynReplaceSlice_Test_Type_Except(const shared_ptr<Node>& param_0,
const shared_ptr<Node>& param_1,
const shared_ptr<Node>& param_2,
const shared_ptr<Node>& param_3,
const shared_ptr<Node>& param_4)
{
try
{
auto r = make_shared<op::DynSlice>(param_0, param_1, param_2, param_3);
auto r = make_shared<op::DynReplaceSlice>(param_0, param_1, param_2, param_3, param_4);
FAIL() << "Did not detect parameter element type not i64";
}
catch (const NodeValidationFailure& error)
......@@ -15213,9 +15562,10 @@ void DynSlice_Test_Type_Except(const shared_ptr<Node>& param_0,
}
}
TEST(type_prop, dynslice_params_et_wrong)
TEST(type_prop, dynreplaceslice_params_et_wrong)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto replacement = make_shared<op::Parameter>(element::f32, Shape{2, 4, 2, 4});
auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
......@@ -15223,15 +15573,14 @@ TEST(type_prop, dynslice_params_et_wrong)
{
lower_bounds = make_shared<op::Parameter>(element::boolean, Shape{4});
DynSlice_Test_Type_Except(arg, lower_bounds, upper_bounds, strides);
DynReplaceSlice_Test_Type_Except(arg, replacement, lower_bounds, upper_bounds, strides);
}
{
upper_bounds = make_shared<op::Parameter>(element::boolean, Shape{4});
DynSlice_Test_Type_Except(arg, lower_bounds, upper_bounds, strides);
DynReplaceSlice_Test_Type_Except(arg, replacement, lower_bounds, upper_bounds, strides);
}
{
strides = make_shared<op::Parameter>(element::boolean, Shape{4});
DynSlice_Test_Type_Except(arg, lower_bounds, upper_bounds, strides);
DynReplaceSlice_Test_Type_Except(arg, replacement, lower_bounds, upper_bounds, strides);
}
}
#endif /* 0 */
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment