Commit ab5d4c03 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Type prop tests for Quantize/Dequantize (#1807)

* Add type_prop unit tests for Quantize

* Fix naming consistency; add test for unsupported RoundMode fail

* Add type-prop unit tests for Dequantize
parent 651ff9ff
...@@ -50,9 +50,9 @@ void op::Dequantize::validate_and_infer_types() ...@@ -50,9 +50,9 @@ void op::Dequantize::validate_and_infer_types()
NODE_VALIDATION_ASSERT(this, m_type.is_real()) << "Output element type (" << m_type NODE_VALIDATION_ASSERT(this, m_type.is_real()) << "Output element type (" << m_type
<< ") must be a floating point number"; << ") must be a floating point number";
NODE_VALIDATION_ASSERT(this, get_input_element_type(SCALE).is_real()) NODE_VALIDATION_ASSERT(this, get_input_element_type(SCALE) == m_type)
<< "Scale element type (" << get_input_element_type(SCALE) << "Scale element type (" << get_input_element_type(SCALE)
<< ") must be a floating point number"; << ") must match the output element type (" << m_type << ")";
NODE_VALIDATION_ASSERT(this, get_input_element_type(OFFSET) == get_input_element_type(INPUT)) NODE_VALIDATION_ASSERT(this, get_input_element_type(OFFSET) == get_input_element_type(INPUT))
<< "Offset element type (" << get_input_element_type(OFFSET) << "Offset element type (" << get_input_element_type(OFFSET)
...@@ -61,7 +61,7 @@ void op::Dequantize::validate_and_infer_types() ...@@ -61,7 +61,7 @@ void op::Dequantize::validate_and_infer_types()
for (auto axis : m_axes) for (auto axis : m_axes)
{ {
NODE_VALIDATION_ASSERT(this, axis < get_shape().size()) NODE_VALIDATION_ASSERT(this, axis < get_shape().size())
<< "Quantizaztion axis (" << axis << ") is greater than input shape rank (" << "Quantization axis (" << axis << ") must be less than input shape rank ("
<< get_shape().size() << ")"; << get_shape().size() << ")";
} }
......
...@@ -53,9 +53,9 @@ void op::Quantize::validate_and_infer_types() ...@@ -53,9 +53,9 @@ void op::Quantize::validate_and_infer_types()
<< "Input element type (" << get_input_element_type(INPUT) << "Input element type (" << get_input_element_type(INPUT)
<< ") must be a floating point number"; << ") must be a floating point number";
NODE_VALIDATION_ASSERT(this, get_input_element_type(SCALE).is_real()) NODE_VALIDATION_ASSERT(this, get_input_element_type(SCALE) == get_input_element_type(INPUT))
<< "Scale element type (" << get_input_element_type(SCALE) << "Scale element type (" << get_input_element_type(SCALE)
<< ") must be a floating point number"; << ") must match input element type (" << get_input_element_type(INPUT) << ")";
NODE_VALIDATION_ASSERT(this, get_input_element_type(OFFSET) == m_type) NODE_VALIDATION_ASSERT(this, get_input_element_type(OFFSET) == m_type)
<< "Offset element type (" << get_input_element_type(OFFSET) << "Offset element type (" << get_input_element_type(OFFSET)
...@@ -64,7 +64,7 @@ void op::Quantize::validate_and_infer_types() ...@@ -64,7 +64,7 @@ void op::Quantize::validate_and_infer_types()
for (auto axis : m_axes) for (auto axis : m_axes)
{ {
NODE_VALIDATION_ASSERT(this, axis < get_shape().size()) NODE_VALIDATION_ASSERT(this, axis < get_shape().size())
<< "Quantizaztion axis (" << axis << ") is greater than input shape rank (" << "Quantization axis (" << axis << ") must be less than input shape rank ("
<< get_shape().size() << ")"; << get_shape().size() << ")";
} }
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment