Commit 351917d5 authored by nishant.b.patel's avatar nishant.b.patel

Address feedback

parent d394d986
...@@ -77,20 +77,25 @@ void op::QuantizedConvolution::validate_and_infer_types() ...@@ -77,20 +77,25 @@ void op::QuantizedConvolution::validate_and_infer_types()
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
get_input_element_type(INPUT_SCALE).is_real() || get_input_element_type(INPUT_SCALE).is_real() ||
get_input_element_type(INPUT_SCALE).is_dynamic() ||
get_input_element_type(FILTER_SCALE).is_real() || get_input_element_type(FILTER_SCALE).is_real() ||
get_input_element_type(OUTPUT_SCALE).is_real(), get_input_element_type(FILTER_SCALE).is_dynamic() ||
get_input_element_type(OUTPUT_SCALE).is_real() ||
get_input_element_type(OUTPUT_SCALE).is_dynamic(),
"Scale must be a floating point number"); "Scale must be a floating point number");
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(
get_input_element_type(0) == get_input_element_type(INPUT_ZERO_POINT), this,
get_input_element_type(0).compatible(get_input_element_type(INPUT_ZERO_POINT)),
"Input Zero point element type (", "Input Zero point element type (",
get_input_element_type(INPUT_ZERO_POINT), get_input_element_type(INPUT_ZERO_POINT),
") must match input element type (", ") must match input element type (",
get_input_element_type(0), get_input_element_type(0),
")"); ")");
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(
get_input_element_type(1) == get_input_element_type(FILTER_ZERO_POINT), this,
get_input_element_type(1).compatible(get_input_element_type(FILTER_ZERO_POINT)),
"Filter Zero point element type (", "Filter Zero point element type (",
get_input_element_type(FILTER_ZERO_POINT), get_input_element_type(FILTER_ZERO_POINT),
") must match filter element type (", ") must match filter element type (",
...@@ -158,8 +163,10 @@ void op::QuantizedConvolution::validate_and_infer_types() ...@@ -158,8 +163,10 @@ void op::QuantizedConvolution::validate_and_infer_types()
0, /* batch_axis_result, */ 0, /* batch_axis_result, */
1 /* output_channel_axis_result, */ 1 /* output_channel_axis_result, */
)); ));
NODE_VALIDATION_CHECK(this,
get_output_element_type(0) == get_input_element_type(OUTPUT_ZERO_POINT), NODE_VALIDATION_CHECK(
this,
get_output_element_type(0).compatible(get_input_element_type(OUTPUT_ZERO_POINT)),
"Output Zero point element type (", "Output Zero point element type (",
get_input_element_type(OUTPUT_ZERO_POINT), get_input_element_type(OUTPUT_ZERO_POINT),
") must match output element type (", ") must match output element type (",
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment