Commit a91abd8b authored by nishant.b.patel's avatar nishant.b.patel

some more checks for quantized conv op

parent 65ad5f08
......@@ -97,6 +97,22 @@ void op::QuantizedConvolution::validate_and_infer_types()
get_input_element_type(1),
")");
// TODO Remove these checks once we support channelwise and vector of scales
NODE_VALIDATION_CHECK(this,
shape_size(get_input_shape(2)) == 1 ||
shape_size(get_input_shape(3)) == 1,
"Input scale and input zero point shape must be same and 1");
NODE_VALIDATION_CHECK(this,
shape_size(get_input_shape(4)) == 1 ||
shape_size(get_input_shape(5)) == 1,
"Filter scale and filter zero point shape must be same and 1");
NODE_VALIDATION_CHECK(this,
shape_size(get_input_shape(6)) == 1 ||
shape_size(get_input_shape(7)) == 1,
"Output scale and output zero point shape must be same and 1");
auto input_shape = get_input_shape(0);
auto filters_shape = get_input_shape(1);
......@@ -172,3 +188,9 @@ shared_ptr<Node> op::QuantizedConvolution::copy_with_new_args(const NodeVector&
m_filter_axes,
m_output_axes));
}
void op::QuantizedConvolution::generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas)
{
throw ngraph_error("Forward-propagation-only operation");
}
......@@ -42,7 +42,9 @@ namespace ngraph
/// \param output_scale Scale to transform the output
/// \param output_zero_point Zero point used for mapping
/// \param output_type Output element type
///
/// \param input_axes Input axes set for channel wise quantization
/// \param filter_axes Filter axes set for channel wise quantization
/// \param output_axes Output axes set for channel wise quantization
QuantizedConvolution(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides,
......@@ -60,6 +62,7 @@ namespace ngraph
const ngraph::AxisSet& input_axes,
const ngraph::AxisSet& filter_axes,
const ngraph::AxisSet& output_axes);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; }
const CoordinateDiff& get_padding_below() const { return m_padding_below; }
......@@ -75,6 +78,9 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
protected:
Strides m_window_movement_strides;
Strides m_window_dilation_strides;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment