Commit 351917d5 authored by nishant.b.patel's avatar nishant.b.patel

Address feedback

parent d394d986
......@@ -77,25 +77,30 @@ void op::QuantizedConvolution::validate_and_infer_types()
NODE_VALIDATION_CHECK(this,
get_input_element_type(INPUT_SCALE).is_real() ||
get_input_element_type(INPUT_SCALE).is_dynamic() ||
get_input_element_type(FILTER_SCALE).is_real() ||
get_input_element_type(OUTPUT_SCALE).is_real(),
get_input_element_type(FILTER_SCALE).is_dynamic() ||
get_input_element_type(OUTPUT_SCALE).is_real() ||
get_input_element_type(OUTPUT_SCALE).is_dynamic(),
"Scale must be a floating point number");
NODE_VALIDATION_CHECK(this,
get_input_element_type(0) == get_input_element_type(INPUT_ZERO_POINT),
"Input Zero point element type (",
get_input_element_type(INPUT_ZERO_POINT),
") must match input element type (",
get_input_element_type(0),
")");
NODE_VALIDATION_CHECK(this,
get_input_element_type(1) == get_input_element_type(FILTER_ZERO_POINT),
"Filter Zero point element type (",
get_input_element_type(FILTER_ZERO_POINT),
") must match filter element type (",
get_input_element_type(1),
")");
NODE_VALIDATION_CHECK(
this,
get_input_element_type(0).compatible(get_input_element_type(INPUT_ZERO_POINT)),
"Input Zero point element type (",
get_input_element_type(INPUT_ZERO_POINT),
") must match input element type (",
get_input_element_type(0),
")");
NODE_VALIDATION_CHECK(
this,
get_input_element_type(1).compatible(get_input_element_type(FILTER_ZERO_POINT)),
"Filter Zero point element type (",
get_input_element_type(FILTER_ZERO_POINT),
") must match filter element type (",
get_input_element_type(1),
")");
// TODO Remove these checks once we support channelwise and vector of scales
NODE_VALIDATION_CHECK(this,
......@@ -158,13 +163,15 @@ void op::QuantizedConvolution::validate_and_infer_types()
0, /* batch_axis_result, */
1 /* output_channel_axis_result, */
));
NODE_VALIDATION_CHECK(this,
get_output_element_type(0) == get_input_element_type(OUTPUT_ZERO_POINT),
"Output Zero point element type (",
get_input_element_type(OUTPUT_ZERO_POINT),
") must match output element type (",
get_output_element_type(0),
")");
NODE_VALIDATION_CHECK(
this,
get_output_element_type(0).compatible(get_input_element_type(OUTPUT_ZERO_POINT)),
"Output Zero point element type (",
get_input_element_type(OUTPUT_ZERO_POINT),
") must match output element type (",
get_output_element_type(0),
")");
}
shared_ptr<Node> op::QuantizedConvolution::copy_with_new_args(const NodeVector& new_args) const
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment