Unverified Commit 3dce6fdb authored by Evgenya Stepyreva's avatar Evgenya Stepyreva Committed by GitHub

[ Reshape ] Check shape_size(input_shape)==shape_size(output_shape) (#4476)

parent 28432321
......@@ -201,7 +201,17 @@ void op::v1::Reshape::validate_and_infer_types()
if (!(zero_dims && m_special_zero) && !negative_dims)
{
set_output_type(0, get_input_element_type(0), const_shape->get_shape_val());
auto output_shape = const_shape->get_shape_val();
if (get_input_partial_shape(0).is_static())
{
NODE_VALIDATION_CHECK(this,
shape_size(get_input_shape(0)) == shape_size(output_shape),
"Requested output shape ",
output_shape,
" is incompatible with input shape ",
get_input_shape(0));
}
set_output_type(0, get_input_element_type(0), output_shape);
}
else
{
......
......@@ -292,4 +292,20 @@ TEST(type_prop, reshape_v1_arg_rank_static_pattern_zero)
auto reshape_v1_dynamic = make_shared<op::v1::Reshape>(dynamic_arg, pattern, true);
EXPECT_TRUE(reshape_v1_dynamic->get_output_partial_shape(0).same_scheme(
PartialShape{1, 2, Dimension::dynamic(), 32}));
try
{
auto static_shape_parameter = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
auto reshape_output_pattern = op::Constant::create(element::i64, Shape{4}, {2, 2, 3, 4});
auto reshape =
make_shared<op::v1::Reshape>(static_shape_parameter, reshape_output_pattern, true);
FAIL() << "Expected failure on reshape construction";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("is incompatible with input shape"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment