Commit 7f6f07ee authored by Adam Procter, committed by Scott Cyphers

Partial Shapes and Types, Part 4j: Quantize/Dequantize (#1842)

* Adapt Tensor class to have partial shapes

* Add PartialShapes to Input, Output, Function, Node classes

* Terminological cleanup

* Add PartialShape propagation for Parameter and Result

* Implement partial-shape propagation for elementwise ops

* More comments

* One more comment tweak

* Add tests for the merge functions

* Add merging of undetermined element types

* Fix a goof-up in deserializer implementation

* Implement fallback for ops that do not support partial shape/type validation

* Updates for some older unit tests, now that operator[] exists

* Add missing validate_punt_if_incomplete to AllReduce

* Implement partial shape/type propagation for AllReduce

* Implement partial shape/type propagation for Reshape

* Remove unneeded validate_punt from Result

* Implement partial shape/type propagation for Reverse

* Implement partial shape/type validation for ReverseSequence

* Implement partial shape/type validation for ArithmeticReduction

* Better docstrings for the stuff introduced in #1692; remove prototype for unimplemented, unused PartialShape::append()

* One more docstring thing I forgot to save

* Switch terminology from 'determined/undetermined' to 'static/dynamic'

* Switch terminology from 'complete/incomplete' to 'static/dynamic' for shapes; fix up some mushily worded comments

* Fix overzealous edits from the last commit

* Rename one test that escaped the Great Renaming

* Remove unnecessary validate_punt_if_dynamic from Reshape

* Fix comment typo

* Rewrite operator+ and operator* for Dimension as members, not friends

* Formatting tweak

* Show argument types/shapes in long NodeDescription; tank unit tests to block merge

* Fix dynamic element type propagation for elementwise ops, add some unit tests for same

* Fix error message

* Roll 'Not' back to existing behavior (non-boolean input types allowed)

* Add a TODO tag to a todo item

* Add unit tests for partial shape/type propagation with ReverseSequence

* Add unit tests for partial-shape/type propagation for ArithmeticReduction (via Sum)

* Implement partial type/shape propagation for GetOutputElement

* Function signatures

* Add implementations, unit tests for relaxes/refines functions

* Generalize project/reduce/inject functions to cover PartialShape, move to shape_util.[ch]pp

* Dynamic shape/type prop for Quantize

* Add unit tests for partial shape/type validation for Quantize

* Implement partial shape/type validation for Dequantize, with unit tests

* Remove #if 0'd code

* Deal with std::find_if #include issues

* Fix more include madness
parent 925e7b27
......@@ -35,11 +35,6 @@ op::Dequantize::Dequantize(shared_ptr<Node> input,
void op::Dequantize::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
enum
{
INPUT,
......@@ -47,41 +42,82 @@ void op::Dequantize::validate_and_infer_types()
OFFSET
};
set_output_size(1);
set_output_type(0, m_type, get_input_shape(INPUT));
NODE_VALIDATION_ASSERT(this, get_input_element_type(INPUT).is_quantized())
<< "Input element type (" << get_input_element_type(INPUT) << ") must be a quantized type";
NODE_VALIDATION_ASSERT(this, m_type.is_static()) << "Output element type must not be dynamic";
NODE_VALIDATION_ASSERT(this, m_type.is_real()) << "Output element type (" << m_type
<< ") must be a floating point number";
<< ") must be a floating point type";
NODE_VALIDATION_ASSERT(this, get_input_element_type(SCALE) == m_type)
<< "Scale element type (" << get_input_element_type(SCALE)
<< ") must match the output element type (" << m_type << ")";
element::Type quantized_type;
NODE_VALIDATION_ASSERT(this, get_input_element_type(OFFSET) == get_input_element_type(INPUT))
NODE_VALIDATION_ASSERT(this,
element::Type::merge(quantized_type,
get_input_element_type(INPUT),
get_input_element_type(OFFSET)))
<< "Offset element type (" << get_input_element_type(OFFSET)
<< ") must match input element type (" << get_input_element_type(INPUT) << ")";
NODE_VALIDATION_ASSERT(this, quantized_type.is_dynamic() || quantized_type.is_quantized())
<< "Offset/input element type (" << quantized_type << ") must be a quantized type";
element::Type unquantized_type;
NODE_VALIDATION_ASSERT(
this, element::Type::merge(unquantized_type, get_input_element_type(SCALE), m_type))
<< "Scale element type (" << get_input_element_type(SCALE)
<< ") must match output element type (" << m_type << ")";
PartialShape input_shape = get_input_partial_shape(0);
Dimension input_rank = input_shape.rank();
for (auto axis : m_axes)
{
NODE_VALIDATION_ASSERT(this, axis < get_shape().size())
NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || axis < size_t(input_rank))
<< "Quantization axis (" << axis << ") must be less than input shape rank ("
<< get_shape().size() << ")";
<< input_rank << ")";
}
Shape projected_shape = project(get_input_shape(INPUT), m_axes);
PartialShape scale_offset_shape = get_input_partial_shape(SCALE);
NODE_VALIDATION_ASSERT(
this, PartialShape::merge_into(scale_offset_shape, get_input_partial_shape(OFFSET)))
<< "Scale shape (" << get_input_partial_shape(SCALE) << ") and offset shape ("
<< get_input_partial_shape(OFFSET) << ") must match";
NODE_VALIDATION_ASSERT(this, get_input_shape(SCALE) == projected_shape)
<< "Scale shape (" << get_input_shape(SCALE)
<< ") must match input shape projected along the quantization axes (" << projected_shape
<< ")";
NODE_VALIDATION_ASSERT(this, scale_offset_shape.rank().compatible(m_axes.size()))
<< "Scale/offset rank (" << scale_offset_shape.rank() << ") does not match the number of "
<< "quantization axes (" << m_axes.size() << ")";
set_output_size(1);
NODE_VALIDATION_ASSERT(this, get_input_shape(OFFSET) == projected_shape)
<< "Offset shape (" << get_input_shape(OFFSET)
<< ") must match input shape projected along the quantization axes (" << projected_shape
<< ")";
if (input_shape.rank().is_static() && scale_offset_shape.rank().is_static())
{
size_t i = 0;
std::vector<Dimension> injected_scale_offset_dims;
for (size_t j = 0; j < size_t(input_shape.rank()); j++)
{
if (m_axes.count(j) != 0)
{
injected_scale_offset_dims.push_back(scale_offset_shape[i++]);
}
else
{
injected_scale_offset_dims.push_back(Dimension::dynamic());
}
}
PartialShape result_shape = input_shape;
NODE_VALIDATION_ASSERT(
this, PartialShape::merge_into(result_shape, PartialShape{injected_scale_offset_dims}))
<< "Scale/offset shape (" << scale_offset_shape << ") must match input shape ("
<< input_shape << ") at the quantization axes (" << m_axes << ")";
set_output_type(0, unquantized_type, result_shape);
}
else
{
set_output_type(0, unquantized_type, PartialShape::dynamic());
}
}
shared_ptr<Node> op::Dequantize::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -37,11 +37,6 @@ op::Quantize::Quantize(shared_ptr<Node> input,
void op::Quantize::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
enum
{
INPUT,
......@@ -49,45 +44,85 @@ void op::Quantize::validate_and_infer_types()
OFFSET
};
set_output_size(1);
set_output_type(0, m_type, get_input_shape(INPUT));
NODE_VALIDATION_ASSERT(this, m_round_mode == RoundMode::HALF_AWAY_FROM_ZERO)
<< "Only RoundMode = HALF_AWAY_FROM_ZERO is supported, for now";
NODE_VALIDATION_ASSERT(this, m_type.is_static()) << "Output element type must not be dynamic";
NODE_VALIDATION_ASSERT(this, m_type.is_quantized()) << "Output element type (" << m_type
<< ") must be a quantized type";
NODE_VALIDATION_ASSERT(this, get_input_element_type(INPUT).is_real())
<< "Input element type (" << get_input_element_type(INPUT)
<< ") must be a floating point number";
element::Type unquantized_type;
NODE_VALIDATION_ASSERT(this, get_input_element_type(SCALE) == get_input_element_type(INPUT))
NODE_VALIDATION_ASSERT(this,
element::Type::merge(unquantized_type,
get_input_element_type(INPUT),
get_input_element_type(SCALE)))
<< "Scale element type (" << get_input_element_type(SCALE)
<< ") must match input element type (" << get_input_element_type(INPUT) << ")";
NODE_VALIDATION_ASSERT(this, get_input_element_type(OFFSET) == m_type)
NODE_VALIDATION_ASSERT(this, unquantized_type.is_dynamic() || unquantized_type.is_real())
<< "Scale/input element type (" << unquantized_type << ") must be a floating point number";
element::Type quantized_type;
NODE_VALIDATION_ASSERT(
this, element::Type::merge(quantized_type, get_input_element_type(OFFSET), m_type))
<< "Offset element type (" << get_input_element_type(OFFSET)
<< ") must match output element type (" << m_type << ")";
PartialShape input_shape = get_input_partial_shape(0);
Dimension input_rank = input_shape.rank();
for (auto axis : m_axes)
{
NODE_VALIDATION_ASSERT(this, axis < get_shape().size())
NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || axis < size_t(input_rank))
<< "Quantization axis (" << axis << ") must be less than input shape rank ("
<< get_shape().size() << ")";
<< input_rank << ")";
}
Shape projected_shape = project(get_input_shape(INPUT), m_axes);
PartialShape scale_offset_shape = get_input_partial_shape(SCALE);
NODE_VALIDATION_ASSERT(this, get_input_shape(SCALE) == projected_shape)
<< "Scale shape (" << get_input_shape(SCALE)
<< ") must match input shape projected along the quantization axes (" << projected_shape
<< ")";
NODE_VALIDATION_ASSERT(
this, PartialShape::merge_into(scale_offset_shape, get_input_partial_shape(OFFSET)))
<< "Scale shape (" << get_input_partial_shape(SCALE) << ") and offset shape ("
<< get_input_partial_shape(OFFSET) << ") must match";
NODE_VALIDATION_ASSERT(this, get_input_shape(OFFSET) == projected_shape)
<< "Offset shape (" << get_input_shape(OFFSET)
<< ") must match input shape projected along the quantization axes (" << projected_shape
<< ")";
NODE_VALIDATION_ASSERT(this, scale_offset_shape.rank().compatible(m_axes.size()))
<< "Scale/offset rank (" << scale_offset_shape.rank() << ") does not match the number of "
<< "quantization axes (" << m_axes.size() << ")";
NODE_VALIDATION_ASSERT(this, m_round_mode == RoundMode::HALF_AWAY_FROM_ZERO)
<< "Only RoundMode = HALF_AWAY_FROM_ZERO is supported, for now";
set_output_size(1);
if (input_shape.rank().is_static() && scale_offset_shape.rank().is_static())
{
size_t i = 0;
std::vector<Dimension> injected_scale_offset_dims;
for (size_t j = 0; j < size_t(input_shape.rank()); j++)
{
if (m_axes.count(j) != 0)
{
injected_scale_offset_dims.push_back(scale_offset_shape[i++]);
}
else
{
injected_scale_offset_dims.push_back(Dimension::dynamic());
}
}
PartialShape result_shape = input_shape;
NODE_VALIDATION_ASSERT(
this, PartialShape::merge_into(result_shape, PartialShape{injected_scale_offset_dims}))
<< "Scale/offset shape (" << scale_offset_shape << ") must match input shape ("
<< input_shape << ") at the quantization axes (" << m_axes << ")";
set_output_type(0, quantized_type, result_shape);
}
else
{
set_output_type(0, quantized_type, PartialShape::dynamic());
}
}
shared_ptr<Node> op::Quantize::copy_with_new_args(const NodeVector& new_args) const
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment