Commit ab5d4c03 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Type prop tests for Quantize/Dequantize (#1807)

* Add type_prop unit tests for Quantize

* Fix naming consistency; add test for unsupported RoundMode fail

* Add type-prop unit tests for Dequantize
parent 651ff9ff
......@@ -50,9 +50,9 @@ void op::Dequantize::validate_and_infer_types()
NODE_VALIDATION_ASSERT(this, m_type.is_real()) << "Output element type (" << m_type
<< ") must be a floating point number";
NODE_VALIDATION_ASSERT(this, get_input_element_type(SCALE).is_real())
NODE_VALIDATION_ASSERT(this, get_input_element_type(SCALE) == m_type)
<< "Scale element type (" << get_input_element_type(SCALE)
<< ") must be a floating point number";
<< ") must match the output element type (" << m_type << ")";
NODE_VALIDATION_ASSERT(this, get_input_element_type(OFFSET) == get_input_element_type(INPUT))
<< "Offset element type (" << get_input_element_type(OFFSET)
......@@ -61,7 +61,7 @@ void op::Dequantize::validate_and_infer_types()
for (auto axis : m_axes)
{
NODE_VALIDATION_ASSERT(this, axis < get_shape().size())
<< "Quantizaztion axis (" << axis << ") is greater than input shape rank ("
<< "Quantization axis (" << axis << ") must be less than input shape rank ("
<< get_shape().size() << ")";
}
......
......@@ -53,9 +53,9 @@ void op::Quantize::validate_and_infer_types()
<< "Input element type (" << get_input_element_type(INPUT)
<< ") must be a floating point number";
NODE_VALIDATION_ASSERT(this, get_input_element_type(SCALE).is_real())
NODE_VALIDATION_ASSERT(this, get_input_element_type(SCALE) == get_input_element_type(INPUT))
<< "Scale element type (" << get_input_element_type(SCALE)
<< ") must be a floating point number";
<< ") must match input element type (" << get_input_element_type(INPUT) << ")";
NODE_VALIDATION_ASSERT(this, get_input_element_type(OFFSET) == m_type)
<< "Offset element type (" << get_input_element_type(OFFSET)
......@@ -64,7 +64,7 @@ void op::Quantize::validate_and_infer_types()
for (auto axis : m_axes)
{
NODE_VALIDATION_ASSERT(this, axis < get_shape().size())
<< "Quantizaztion axis (" << axis << ") is greater than input shape rank ("
<< "Quantization axis (" << axis << ") must be less than input shape rank ("
<< get_shape().size() << ")";
}
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment