Commit 8d2d827f authored by Ilya Churaev's avatar Ilya Churaev Committed by Robert Kimball

Direct shape inference for fused MVN op instead of relying on op decomposition (#4221)

Co-authored-by: 's avatarJayaram Bobba <jayaram.bobba@intel.com>
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 06e02c28
......@@ -48,7 +48,10 @@ op::MVN::MVN(const Output<Node>& data, AxisSet reduction_axes, bool normalize_va
constructor_validate_and_infer_types();
}
void op::MVN::pre_validate_and_infer_types()
// decompose_op() relies on knowing the data type of input data which might
// not be available at shape inference time. So do direct shape inference
// instead of relying on op decomposition.
void op::MVN::validate_and_infer_types()
{
// if m_across_channels is true we should calculate mean and variance per batch
// else we calculate these per channel
......@@ -64,6 +67,8 @@ void op::MVN::pre_validate_and_infer_types()
}
set_reduction_axes(reduction_axes);
}
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}
NodeVector op::MVN::decompose_op() const
......
......@@ -65,7 +65,7 @@ namespace ngraph
virtual NodeVector decompose_op() const override;
virtual void pre_validate_and_infer_types() override;
virtual void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment