Unverified Commit 2f18a44b authored by Jayaram Bobba's avatar Jayaram Bobba Committed by GitHub

Handle unknown ranks and partial shapes in reduction axes inference for MVN (#4262)

Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
parent b57c7661
...@@ -55,13 +55,14 @@ void op::MVN::validate_and_infer_types() ...@@ -55,13 +55,14 @@ void op::MVN::validate_and_infer_types()
{ {
// if m_across_channels is true we should calculate mean and variance per batch // if m_across_channels is true we should calculate mean and variance per batch
// else we calculate these per channel // else we calculate these per channel
if (m_reduction_axes.empty()) if (m_reduction_axes.empty() && input_value(0).get_partial_shape().rank().is_static())
{ {
auto data = input_value(0);
AxisSet reduction_axes; AxisSet reduction_axes;
reduction_axes.insert(0); reduction_axes.insert(0);
size_t start_axis = m_across_channels ? 1 : 2; size_t start_axis = m_across_channels ? 1 : 2;
for (size_t i = start_axis; i < data.get_shape().size(); ++i) for (size_t i = start_axis;
i < static_cast<size_t>(input_value(0).get_partial_shape().rank());
++i)
{ {
reduction_axes.insert(i); reduction_axes.insert(i);
} }
......
...@@ -28,3 +28,22 @@ TEST(type_prop, mvn) ...@@ -28,3 +28,22 @@ TEST(type_prop, mvn)
EXPECT_EQ(mvn_func->get_element_type(), element::f32); EXPECT_EQ(mvn_func->get_element_type(), element::f32);
EXPECT_EQ(mvn_func->get_shape(), (Shape{1, 3, 6})); EXPECT_EQ(mvn_func->get_shape(), (Shape{1, 3, 6}));
} }
TEST(type_prop, mvn_partial)
{
auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
auto mvn_func = make_shared<op::MVN>(data);
EXPECT_EQ(mvn_func->get_element_type(), element::f32);
EXPECT_EQ(mvn_func->get_reduction_axes(), (AxisSet{0, 1, 2}));
ASSERT_TRUE(mvn_func->get_output_partial_shape(0).same_scheme(
(PartialShape{1, Dimension::dynamic(), 6})));
// across_channels = false
EXPECT_EQ(make_shared<op::MVN>(data, false)->get_reduction_axes(), (AxisSet{0, 2}));
// rank unknown
auto mvn_partial =
make_shared<op::MVN>(make_shared<op::Parameter>(element::f32, PartialShape::dynamic()));
EXPECT_EQ(mvn_partial->get_reduction_axes(), AxisSet{});
ASSERT_TRUE(mvn_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment