Commit 85582d0c authored by Adam Procter's avatar Adam Procter

Moar tests

parent c99d65a0
......@@ -91,6 +91,17 @@ void op::DynReplaceSlice::validate_and_infer_types()
strides_shape.rank(),
".");
PartialShape attrs_shape{PartialShape::dynamic()};
NODE_VALIDATION_CHECK(this,
(lower_bounds_shape.same_scheme(PartialShape{0}) ||
PartialShape::merge_into(attrs_shape, lower_bounds_shape)) &&
(upper_bounds_shape.same_scheme(PartialShape{0}) ||
PartialShape::merge_into(attrs_shape, upper_bounds_shape)) &&
(strides_shape.same_scheme(PartialShape{0}) ||
PartialShape::merge_into(attrs_shape, strides_shape)),
"Shapes for lower bounds, upper bounds, and strides do not match");
set_input_is_relevant_to_shape(2);
set_input_is_relevant_to_shape(3);
set_input_is_relevant_to_shape(4);
......@@ -99,32 +110,32 @@ void op::DynReplaceSlice::validate_and_infer_types()
auto upper_bounds = dynamic_pointer_cast<op::Constant>(get_argument(3));
auto strides = dynamic_pointer_cast<op::Constant>(get_argument(4));
// TODO(amprocte): We can get a bit more information here about the ranks of arg and
// replacement by inspecting the attributes.
auto slice_shape = PartialShape::dynamic();
if (lower_bounds && upper_bounds && strides)
{
auto inferred_slice_shape = infer_slice_shape(this,
get_input_partial_shape(0),
lower_bounds->get_vector<int64_t>(),
upper_bounds->get_vector<int64_t>(),
strides->get_vector<int64_t>(),
m_lower_bounds_mask,
m_upper_bounds_mask,
m_new_axis,
m_shrink_axis,
m_ellipsis_mask);
NODE_VALIDATION_CHECK(this,
replacement_shape.compatible(inferred_slice_shape),
"Shape of the replacement is not compatible with the shape of the "
"slice (shape of slice: ",
inferred_slice_shape,
")");
slice_shape = infer_slice_shape(this,
get_input_partial_shape(0),
lower_bounds->get_vector<int64_t>(),
upper_bounds->get_vector<int64_t>(),
strides->get_vector<int64_t>(),
m_lower_bounds_mask,
m_upper_bounds_mask,
m_new_axis,
m_shrink_axis,
m_ellipsis_mask);
}
PartialShape output_shape = arg_shape;
NODE_VALIDATION_CHECK(
this,
PartialShape::merge_into(output_shape, PartialShape::dynamic(replacement_shape.rank())),
"Rank of the replacement is not compatible with rank of the argument tensor");
set_output_type(0, get_input_element_type(0), output_shape);
NODE_VALIDATION_CHECK(this,
slice_shape.compatible(replacement_shape),
"Shape of the replacement is not compatible with the shape of the "
"slice (shape of slice: ",
slice_shape,
")");
set_output_type(0, result_et, arg_shape);
}
shared_ptr<Node> op::DynReplaceSlice::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -626,7 +626,6 @@ PartialShape ngraph::infer_slice_shape(const Node* node,
const AxisSet& shrink_axis,
const AxisSet& ellipsis_mask)
{
// TODO(amprocte): double-check that these checks are needed.
if (lb.size() && ub.size())
{
NODE_VALIDATION_CHECK(node,
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment