Commit 982889f5 authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

Partial Shapes and Types, Part 4c: Slice and ReplaceSlice (#1781)

* Adapt Tensor class to have partial shapes

* Add PartialShapes to Input, Output, Function, Node classes

* Terminological cleanup

* Add PartialShape propagation for Parameter and Result

* Implement partial-shape propagation for elementwise ops

* More comments

* One more comment tweak

* Add tests for the merge functions

* Add merging of undetermined element types

* Fix a goophup in deserializer implementation

* Implement fallback for ops that do not support partial shape/type validation

* Updates for some older unit tests, now that operator[] exists

* Add missing validate_punt_if_incomplete to AllReduce

* Implement partial shape/type propagation for AllReduce

* Implement partial shape/type propagation for Reshape

* Remove unneeded validate_punt from Result

* Implement partial shape/type propagation for Reverse

* Implement partial shape/type validation for ReverseSequence

* Implement partial shape/type validation for ArithmeticReduction

* Better docstrings for the stuff introduced in #1692; remove prototype for unimplemented, unused PartialShape::append()

* One more docstring thing I forgot to save

* Switch terminology from 'determined/undetermined' to 'static/dynamic'

* Switch terminology from 'complete/incomplete' to 'static/dynamic' for shapes; fix up some mushily worded comments

* Fix overzealous edits from the last commit

* Rename one test that escaped the Great Renaming

* Remove unnecessary validate_punt_if_dynamic from Reshape

* Fix comment typo

* Rewrite operator+ and operator* for Dimension as members, not friends

* Formatting tweak

* Show argument types/shapes in long NodeDescription; tank unit tests to block merge

* Fix dynamic element type propagation for elementwise ops, add some unit tests for same

* Fix error message

* Roll 'Not' back to existing behavior (non-boolean input types allowed)

* Add a TODO tag to a todo item

* Add unit tests for partial shape/type propagation with ReverseSequence

* Add unit tests for partial-shape/type propagation for ArithmeticReduction (via Sum)

* Implement partial shape/type validation for Slice

* Add unit tests for partial shape/type validation for Slice

* Implement partial shape/type propagation for ReplaceSlice

* Add unit tests for partial type/shape validation for ReplaceSlice

* Implement partial type/shape propagation for GetOutputElement

* Function signatures

* Add implementations, unit tests for relaxes/refines functions

* Generalize project/reduce/inject functions to cover PartialShape, move to shape_util.[ch]pp

* Deal with std::find_if #include issues

* Fix more include madness

* Change an internal variable name to something more descriptive

* Review comments
parent 91219e40
......@@ -32,8 +32,6 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
, m_strides(strides)
{
constructor_validate_and_infer_types();
check_args();
}
op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
......@@ -46,69 +44,86 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
, m_strides(Strides(lower_bounds.size(), 1))
{
constructor_validate_and_infer_types();
check_args();
}
void op::ReplaceSlice::check_args()
void op::ReplaceSlice::validate_and_infer_types()
{
auto& input_0 = get_inputs().at(0);
auto& input_0_shape = input_0.get_shape();
auto& input_0_element_type = input_0.get_element_type();
auto& input_1 = get_inputs().at(1);
auto& input_1_shape = input_1.get_shape();
auto& input_1_element_type = input_1.get_element_type();
// An empty stride vector with lower_bounds/upper_bounds filled in means that we need to
// construct the default value.
if (m_strides.size() == 0)
{
m_strides = Strides(m_lower_bounds.size(), 1);
}
NODE_VALIDATION_ASSERT(this, input_0_shape.size() == input_1_shape.size())
<< "Argument ranks do not match (arg0 shape: " << input_0_shape
<< ", arg1 shape: " << input_1_shape << ").";
const PartialShape& arg0_shape = get_input_partial_shape(0);
const PartialShape& arg1_shape = get_input_partial_shape(1);
Dimension merged_args_rank;
NODE_VALIDATION_ASSERT(this, input_0_element_type == input_1_element_type)
<< "Argument element types do not match (arg0 element type: " << input_0_element_type
<< ", arg1 element type: " << input_1_element_type << ").";
NODE_VALIDATION_ASSERT(this,
Dimension::merge(merged_args_rank, arg0_shape.rank(), arg1_shape.rank()))
<< "Argument ranks do not match (arg0 shape: " << arg0_shape
<< ", arg1 shape: " << arg1_shape << ").";
NODE_VALIDATION_ASSERT(this, m_lower_bounds.size() == input_0_shape.size())
<< "Rank of lower bounds (" << m_lower_bounds.size() << ") does not match rank "
<< "of argument (" << input_0_shape.size() << ") (lower bounds: " << m_lower_bounds
<< ", argument shape: " << input_0_shape << ").";
element::Type arg0_et = get_input_element_type(0);
element::Type arg1_et = get_input_element_type(1);
element::Type merged_args_et;
NODE_VALIDATION_ASSERT(this, m_upper_bounds.size() == input_0_shape.size())
<< "Rank of upper bounds (" << m_upper_bounds.size() << ") does not match rank "
<< "of argument (" << input_0_shape.size() << ") (upper bounds: " << m_upper_bounds
<< ", argument shape: " << input_0_shape << ").";
NODE_VALIDATION_ASSERT(this, element::Type::merge(merged_args_et, arg0_et, arg1_et))
<< "Argument element types do not match (arg0 element type: " << arg0_et
<< ", arg1 element type: " << arg1_et << ").";
NODE_VALIDATION_ASSERT(this, m_strides.size() == input_0_shape.size())
<< "Rank of strides (" << m_strides.size() << ") does not match rank "
<< "of argument (" << input_0_shape.size() << ") (strides: " << m_strides
<< ", argument shape: " << input_0_shape << ").";
NODE_VALIDATION_ASSERT(this,
m_lower_bounds.size() == m_upper_bounds.size() &&
m_lower_bounds.size() == m_strides.size())
<< "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds
<< ") and strides (" << m_strides << ") do not match.";
Shape slice_shape;
size_t output_rank = m_upper_bounds.size();
for (size_t i = 0; i < input_0_shape.size(); i++)
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this, m_upper_bounds[i] <= input_0_shape[i])
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_0_shape << ").";
NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
<< "Lower bound for slice is greater than upper bound at axis " << i
<< " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
<< " (strides: " << m_strides << ").";
}
size_t slice_axis_size = m_upper_bounds[i] - m_lower_bounds[i];
slice_axis_size =
slice_axis_size / m_strides[i] + ((slice_axis_size % m_strides[i] == 0) ? 0 : 1);
slice_shape.push_back(slice_axis_size);
NODE_VALIDATION_ASSERT(this,
merged_args_rank.is_dynamic() || size_t(merged_args_rank) == output_rank)
<< "Argument ranks do not match the rank of the lower bounds (" << m_lower_bounds
<< "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ").";
std::vector<Dimension> sliced_dims(output_rank);
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this,
arg0_shape.rank().is_dynamic() || arg0_shape[i].is_dynamic() ||
m_upper_bounds[i] <= size_t(arg0_shape[i]))
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << arg0_shape << ").";
size_t sliced_dim = m_upper_bounds[i] - m_lower_bounds[i];
sliced_dim = sliced_dim / m_strides[i] + ((sliced_dim % m_strides[i] == 0) ? 0 : 1);
sliced_dims[i] = sliced_dim;
}
NODE_VALIDATION_ASSERT(this, input_1_shape == slice_shape)
<< "Shape of replacement tensor (" << input_1_shape << ") does not match the slice shape "
PartialShape slice_shape{sliced_dims};
NODE_VALIDATION_ASSERT(this, arg1_shape.compatible(slice_shape))
<< "Shape of replacement tensor (" << arg1_shape << ") does not match the slice shape "
<< "(" << slice_shape << ").";
set_output_type(0, input_0_element_type, input_0_shape);
// Slight corner case here: if arg0 was rank-unknown, we can go ahead and set the output rank
// because the attribs will have given us enough info.
PartialShape result_shape =
(arg0_shape.rank().is_static())
? arg0_shape
: PartialShape(std::vector<Dimension>(output_rank, Dimension::dynamic()));
set_output_type(0, merged_args_et, result_shape);
}
shared_ptr<Node> op::ReplaceSlice::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -88,11 +88,11 @@ namespace ngraph
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
void check_args();
void validate_and_infer_types() override;
const Coordinate m_lower_bounds;
const Coordinate m_upper_bounds;
const Strides m_strides;
Coordinate m_lower_bounds;
Coordinate m_upper_bounds;
Strides m_strides;
};
}
}
......@@ -44,55 +44,55 @@ op::Slice::Slice(const shared_ptr<Node>& arg,
void op::Slice::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
if (0 == m_strides.size())
// An empty stride vector with lower_bounds/upper_bounds filled in means that we need to
// construct the default value.
if (m_strides.size() == 0)
{
m_strides = Strides(m_lower_bounds.size(), 1);
}
auto& input = get_inputs().at(0);
auto& input_shape = input.get_shape();
NODE_VALIDATION_ASSERT(this, m_lower_bounds.size() == input_shape.size())
<< "Rank of lower bounds (" << m_lower_bounds.size() << ") does not match rank "
<< "of argument (" << input_shape.size() << ") (lower bounds: " << m_lower_bounds
<< ", argument shape: " << input_shape << ").";
NODE_VALIDATION_ASSERT(this, m_upper_bounds.size() == input_shape.size())
<< "Rank of upper bounds (" << m_upper_bounds.size() << ") does not match rank "
<< "of argument (" << input_shape.size() << ") (upper bounds: " << m_upper_bounds
<< ", argument shape: " << input_shape << ").";
NODE_VALIDATION_ASSERT(this,
m_lower_bounds.size() == m_upper_bounds.size() &&
m_lower_bounds.size() == m_strides.size())
<< "Ranks of lower bounds (" << m_lower_bounds << "), upper bounds (" << m_upper_bounds
<< ") and strides (" << m_strides << ") do not match.";
NODE_VALIDATION_ASSERT(this, m_strides.size() == input_shape.size())
<< "Rank of strides (" << m_strides.size() << ") does not match rank "
<< "of argument (" << input_shape.size() << ") (strides: " << m_strides
<< ", argument shape: " << input_shape << ").";
size_t output_rank = m_upper_bounds.size();
Shape result_shape;
for (size_t i = 0; i < input_shape.size(); i++)
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this, m_upper_bounds[i] <= input_shape[i])
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_shape << ").";
NODE_VALIDATION_ASSERT(this, m_lower_bounds[i] <= m_upper_bounds[i])
<< "Lower bound for slice is greater than upper bound at axis " << i
<< " (lower bounds: " << m_lower_bounds << ", upper bounds: " << m_upper_bounds << ").";
NODE_VALIDATION_ASSERT(this, m_strides[i] != 0) << "Stride for slice is zero at axis " << i
<< " (strides: " << m_strides << ").";
}
const PartialShape& input_shape = get_input_partial_shape(0);
Dimension input_rank = input_shape.rank();
NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || size_t(input_rank) == output_rank)
<< "Input rank does not match the rank of the lower bounds (" << m_lower_bounds
<< "), upper bounds (" << m_upper_bounds << "), and strides (" << m_strides << ").";
std::vector<Dimension> result_dims(output_rank);
for (size_t i = 0; i < output_rank; i++)
{
NODE_VALIDATION_ASSERT(this,
input_rank.is_dynamic() || input_shape[i].is_dynamic() ||
m_upper_bounds[i] <= size_t(input_shape[i]))
<< "Upper bound for slice at axis " << i << " is out of range "
<< "(upper bounds: " << m_upper_bounds << ", argument shape: " << input_shape << ").";
size_t result_axis_size = m_upper_bounds[i] - m_lower_bounds[i];
result_axis_size =
result_axis_size / m_strides[i] + ((result_axis_size % m_strides[i] == 0) ? 0 : 1);
result_shape.push_back(result_axis_size);
result_dims[i] = result_axis_size;
}
set_output_type(0, input.get_element_type(), result_shape);
set_output_type(0, get_input_element_type(0), PartialShape{result_dims});
}
shared_ptr<Node> op::Slice::copy_with_new_args(const NodeVector& new_args) const
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment