Commit 351917d5 authored by nishant.b.patel's avatar nishant.b.patel

Address feedback

parent d394d986
...@@ -77,25 +77,30 @@ void op::QuantizedConvolution::validate_and_infer_types() ...@@ -77,25 +77,30 @@ void op::QuantizedConvolution::validate_and_infer_types()
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
get_input_element_type(INPUT_SCALE).is_real() || get_input_element_type(INPUT_SCALE).is_real() ||
get_input_element_type(INPUT_SCALE).is_dynamic() ||
get_input_element_type(FILTER_SCALE).is_real() || get_input_element_type(FILTER_SCALE).is_real() ||
get_input_element_type(OUTPUT_SCALE).is_real(), get_input_element_type(FILTER_SCALE).is_dynamic() ||
get_input_element_type(OUTPUT_SCALE).is_real() ||
get_input_element_type(OUTPUT_SCALE).is_dynamic(),
"Scale must be a floating point number"); "Scale must be a floating point number");
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(
get_input_element_type(0) == get_input_element_type(INPUT_ZERO_POINT), this,
"Input Zero point element type (", get_input_element_type(0).compatible(get_input_element_type(INPUT_ZERO_POINT)),
get_input_element_type(INPUT_ZERO_POINT), "Input Zero point element type (",
") must match input element type (", get_input_element_type(INPUT_ZERO_POINT),
get_input_element_type(0), ") must match input element type (",
")"); get_input_element_type(0),
")");
NODE_VALIDATION_CHECK(this,
get_input_element_type(1) == get_input_element_type(FILTER_ZERO_POINT), NODE_VALIDATION_CHECK(
"Filter Zero point element type (", this,
get_input_element_type(FILTER_ZERO_POINT), get_input_element_type(1).compatible(get_input_element_type(FILTER_ZERO_POINT)),
") must match filter element type (", "Filter Zero point element type (",
get_input_element_type(1), get_input_element_type(FILTER_ZERO_POINT),
")"); ") must match filter element type (",
get_input_element_type(1),
")");
// TODO Remove these checks once we support channelwise and vector of scales // TODO Remove these checks once we support channelwise and vector of scales
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
...@@ -158,13 +163,15 @@ void op::QuantizedConvolution::validate_and_infer_types() ...@@ -158,13 +163,15 @@ void op::QuantizedConvolution::validate_and_infer_types()
0, /* batch_axis_result, */ 0, /* batch_axis_result, */
1 /* output_channel_axis_result, */ 1 /* output_channel_axis_result, */
)); ));
NODE_VALIDATION_CHECK(this,
get_output_element_type(0) == get_input_element_type(OUTPUT_ZERO_POINT), NODE_VALIDATION_CHECK(
"Output Zero point element type (", this,
get_input_element_type(OUTPUT_ZERO_POINT), get_output_element_type(0).compatible(get_input_element_type(OUTPUT_ZERO_POINT)),
") must match output element type (", "Output Zero point element type (",
get_output_element_type(0), get_input_element_type(OUTPUT_ZERO_POINT),
")"); ") must match output element type (",
get_output_element_type(0),
")");
} }
shared_ptr<Node> op::QuantizedConvolution::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::QuantizedConvolution::copy_with_new_args(const NodeVector& new_args) const
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment