Commit 982889f5 authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

Partial Shapes and Types, Part 4c: Slice and ReplaceSlice (#1781)

* Adapt Tensor class to have partial shapes

* Add PartialShapes to Input, Output, Function, Node classes

* Terminological cleanup

* Add PartialShape propagation for Parameter and Result

* Implement partial-shape propagation for elementwise ops

* More comments

* One more comment tweak

* Add tests for the merge functions

* Add merging of undetermined element types

* Fix a goophup in deserializer implementation

* Implement fallback for ops that do not support partial shape/type validation

* Updates for some older unit tests, now that operator[] exists

* Add missing validate_punt_if_incomplete to AllReduce

* Implement partial shape/type propagation for AllReduce

* Implement partial shape/type propagation for Reshape

* Remove unneeded validate_punt from Result

* Implement partial shape/type propagation for Reverse

* Implement partial shape/type validation for ReverseSequence

* Implement partial shape/type validation for ArithmeticReduction

* Better docstrings for the stuff introduced in #1692; remove prototype for unimplemented, unused PartialShape::append()

* One more docstring thing I forgot to save

* Switch terminology from 'determined/undetermined' to 'static/dynamic'

* Switch terminology from 'complete/incomplete' to 'static/dynamic' for shapes; fix up some mushily worded comments

* Fix overzealous edits from the last commit

* Rename one test that escaped the Great Renaming

* Remove unnecessary validate_punt_if_dynamic from Reshape

* Fix comment typo

* Rewrite operator+ and operator* for Dimension as members, not friends

* Formatting tweak

* Show argument types/shapes in long NodeDescription; tank unit tests to block merge

* Fix dynamic element type propagation for elementwise ops, add some unit tests for same

* Fix error message

* Roll 'Not' back to existing behavior (non-boolean input types allowed)

* Add a TODO tag to a todo item

* Add unit tests for partial shape/type propagation with ReverseSequence

* Add unit tests for partial-shape/type propagation for ArithmeticReduction (via Sum)

* Implement partial shape/type validation for Slice

* Add unit tests for partial shape/type validation for Slice

* Implement partial shape/type propagation for ReplaceSlice

* Add unit tests for partial type/shape validation for ReplaceSlice

* Implement partial type/shape propagation for GetOutputElement

* Function signatures

* Add implementations, unit tests for relaxes/refines functions

* Generalize project/reduce/inject functions to cover PartialShape, move to shape_util.[ch]pp

* Deal with std::find_if #include issues

* Fix more include madness

* Change an internal variable name to something more descriptive

* Review comments
parent 91219e40
...@@ -32,8 +32,6 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0, ...@@ -32,8 +32,6 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
, m_strides(strides) , m_strides(strides)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
check_args();
} }
op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0, op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
...@@ -46,69 +44,86 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0, ...@@ -46,69 +44,86 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
, m_strides(Strides(lower_bounds.size(), 1)) , m_strides(Strides(lower_bounds.size(), 1))
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
check_args();
} }
void op::ReplaceSlice::check_args() void op::ReplaceSlice::validate_and_infer_types()
{ {
auto& input_0 = get_inputs().at(0); // An empty stride vector with lower_bounds/upper_bounds filled in means that we need to
auto& input_0_shape = input_0.get_shape(); // construct the default value.
auto& input_0_element_type = input_0.get_element_type(); if (m_strides.size() == 0)
{
auto& input_1 = get_inputs().at(1); m_strides = Strides(m_lower_bounds.size(), 1);
auto& input_1_shape = input_1.get_shape(); }
auto& input_1_element_type = input_1.get_element_type();
NODE_VALIDATION_ASSERT(this, input_0_shape.size() == input_1_shape.size()) const PartialShape& arg0_shape = get_input_partial_shape(0);
<< "Argument ranks do not match (arg0 shape: " << input_0_shape const PartialShape& arg1_shape = get_input_partial_shape(1);
<< ", arg1 shape: " << input_1_shape << ")."; Dimension merged_args_rank;
NODE_VALIDATION_ASSERT(this, input_0_element_type == input_1_element_type) NODE_VALIDATION_ASSERT(this,
<< "Argument element types do not match (arg0 element type: " << input_0_element_type Dimension::merge(merged_args_rank, arg0_shape.rank(), arg1_shape.rank()))
<< ", arg1 element type: " << input_1_element_type << ")."; << "Argument ranks do not match (arg0 shape: " << arg0_shape
<< ", arg1 shape: " << arg1_shape << ").";
NODE_VALIDATION_ASSERT(this, m_lower_bounds.size() == input_0_shape.size()) element::Type arg0_et = get_input_element_type(0);
<< "Rank of lower bounds (" << m_lower_bounds.size() << ") does not match rank " element::Type arg1_et = get_input_element_type(1);
<< "of argument (" << input_0_shape.size() << ") (lower bounds: " << m_lower_bounds element::Type merged_args_et;
<< ", argument shape: " << input_0_shape << ").";
NODE_VALIDATION_ASSERT(this, m_upper_bounds.size() == input_0_shape.size()) NODE_VALIDATION_ASSERT(this, element::Type::merge(merged_args_et, arg0_et, arg1_et))
<< "Rank of upper bounds (" << m_upper_bounds.size() << ") does not match rank " << "Argument element types do not match (arg0 element type: " << arg0_et
<< "of argument (" << input_0_shape.size() << ") (upper bounds: " << m_upper_bounds << ", arg1 element type: " << arg1_et << ").";
<< ", argument shape: " << input_0_shape << ").";
NODE_VALIDATION_ASSERT(this, m_strides.size() == input_0_shape.size()) NODE_VALIDATION_ASSERT(this,
<< "Rank of strides (" << m_strides.size() << ") does not match rank " m_lower_bounds.size() == m_upper_bounds.size() &&
<< "of argument (" << input_0_shape.size() << ") (strides: " << m_strides m_lower_bounds.size() == m_strides.size())
<< ", argument shape: " << input_0_shape << ")."; << "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds
<< ") and strides (" << m_strides << ") do not match.";
Shape slice_shape; size_t output_rank = m_upper_bounds.size();
for (size_t i = 0; i < input_0_shape.size(); i++) for (size_t i = 0; i < output_rank; i++)
{ {
NODE_VALIDATION_ASSERT(this, m_upper_bounds[i] <= input_0_shape[i])
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_0_shape << ").";
NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i]) NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
<< "Lower bound for slice is greater than upper bound at axis " << i << "Lower bound for slice is greater than upper bound at axis " << i
<< " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ")."; << " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
<< " (strides: " << m_strides << ")."; << " (strides: " << m_strides << ").";
}
size_t slice_axis_size = m_upper_bounds[i] - m_lower_bounds[i]; NODE_VALIDATION_ASSERT(this,
slice_axis_size = merged_args_rank.is_dynamic() || size_t(merged_args_rank) == output_rank)
slice_axis_size / m_strides[i] + ((slice_axis_size % m_strides[i] == 0) ? 0 : 1); << "Argument ranks do not match the rank of the lower bounds (" << m_lower_bounds
slice_shape.push_back(slice_axis_size); << "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ").";
std::vector<Dimension> sliced_dims(output_rank);
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this,
arg0_shape.rank().is_dynamic() || arg0_shape[i].is_dynamic() ||
m_upper_bounds[i] <= size_t(arg0_shape[i]))
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << arg0_shape << ").";
size_t sliced_dim = m_upper_bounds[i] - m_lower_bounds[i];
sliced_dim = sliced_dim / m_strides[i] + ((sliced_dim % m_strides[i] == 0) ? 0 : 1);
sliced_dims[i] = sliced_dim;
} }
NODE_VALIDATION_ASSERT(this, input_1_shape == slice_shape) PartialShape slice_shape{sliced_dims};
<< "Shape of replacement tensor (" << input_1_shape << ") does not match the slice shape "
NODE_VALIDATION_ASSERT(this, arg1_shape.compatible(slice_shape))
<< "Shape of replacement tensor (" << arg1_shape << ") does not match the slice shape "
<< "(" << slice_shape << ")."; << "(" << slice_shape << ").";
set_output_type(0, input_0_element_type, input_0_shape); // Slight corner case here: if arg0 was rank-unknown, we can go ahead and set the output rank
// because the attribs will have given us enough info.
PartialShape result_shape =
(arg0_shape.rank().is_static())
? arg0_shape
: PartialShape(std::vector<Dimension>(output_rank, Dimension::dynamic()));
set_output_type(0, merged_args_et, result_shape);
} }
shared_ptr<Node> op::ReplaceSlice::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ReplaceSlice::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -88,11 +88,11 @@ namespace ngraph ...@@ -88,11 +88,11 @@ namespace ngraph
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
void check_args(); void validate_and_infer_types() override;
const Coordinate m_lower_bounds; Coordinate m_lower_bounds;
const Coordinate m_upper_bounds; Coordinate m_upper_bounds;
const Strides m_strides; Strides m_strides;
}; };
} }
} }
...@@ -44,55 +44,55 @@ op::Slice::Slice(const shared_ptr<Node>& arg, ...@@ -44,55 +44,55 @@ op::Slice::Slice(const shared_ptr<Node>& arg,
void op::Slice::validate_and_infer_types() void op::Slice::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic()) // An empty stride vector with lower_bounds/upper_bounds filled in means that we need to
{ // construct the default value.
return; if (m_strides.size() == 0)
}
if (0 == m_strides.size())
{ {
m_strides = Strides(m_lower_bounds.size(), 1); m_strides = Strides(m_lower_bounds.size(), 1);
} }
auto& input = get_inputs().at(0);
auto& input_shape = input.get_shape();
NODE_VALIDATION_ASSERT(this, m_lower_bounds.size() == input_shape.size())
<< "Rank of lower bounds (" << m_lower_bounds.size() << ") does not match rank "
<< "of argument (" << input_shape.size() << ") (lower bounds: " << m_lower_bounds
<< ", argument shape: " << input_shape << ").";
NODE_VALIDATION_ASSERT(this, m_upper_bounds.size() == input_shape.size()) NODE_VALIDATION_ASSERT(this,
<< "Rank of upper bounds (" << m_upper_bounds.size() << ") does not match rank " m_lower_bounds.size() == m_upper_bounds.size() &&
<< "of argument (" << input_shape.size() << ") (upper bounds: " << m_upper_bounds m_lower_bounds.size() == m_strides.size())
<< ", argument shape: " << input_shape << ")."; << "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds
<< ") and strides (" << m_strides << ") do not match.";
NODE_VALIDATION_ASSERT(this, m_strides.size() == input_shape.size()) size_t output_rank = m_upper_bounds.size();
<< "Rank of strides (" << m_strides.size() << ") does not match rank "
<< "of argument (" << input_shape.size() << ") (strides: " << m_strides
<< ", argument shape: " << input_shape << ").";
Shape result_shape; for (size_t i = 0; i < output_rank; i++)
for (size_t i = 0; i < input_shape.size(); i++)
{ {
NODE_VALIDATION_ASSERT(this, m_upper_bounds[i] <= input_shape[i])
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_shape << ").";
NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i]) NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
<< "Lower bound for slice is greater than upper bound at axis " << i << "Lower bound for slice is greater than upper bound at axis " << i
<< " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ")."; << " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
<< " (strides: " << m_strides << ")."; << " (strides: " << m_strides << ").";
}
const PartialShape& input_shape = get_input_partial_shape(0);
Dimension input_rank = input_shape.rank();
NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || size_t(input_rank) == output_rank)
<< "Input rank does not match the rank of the lower bounds (" << m_lower_bounds
<< "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ").";
std::vector<Dimension> result_dims(output_rank);
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this,
input_rank.is_dynamic() || input_shape[i].is_dynamic() ||
m_upper_bounds[i] <= size_t(input_shape[i]))
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_shape << ").";
size_t result_axis_size = m_upper_bounds[i] - m_lower_bounds[i]; size_t result_axis_size = m_upper_bounds[i] - m_lower_bounds[i];
result_axis_size = result_axis_size =
result_axis_size / m_strides[i] + ((result_axis_size % m_strides[i] == 0) ? 0 : 1); result_axis_size / m_strides[i] + ((result_axis_size % m_strides[i] == 0) ? 0 : 1);
result_shape.push_back(result_axis_size); result_dims[i] = result_axis_size;
} }
set_output_type(0, input.get_element_type(), result_shape); set_output_type(0, get_input_element_type(0), PartialShape{result_dims});
} }
shared_ptr<Node> op::Slice::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Slice::copy_with_new_args(const NodeVector& new_args) const
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment