Commit 59d1504c authored by Jayaram Bobba, committed by Robert Kimball

Handle 0 and -1 in shape inputs for dynamic reshape op (#2999)

* Handle 0 and -1 in shape inputs for dynamic reshape op

* Flag to control semantics of zero values in dynreshape op
parent 70cf8f28
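
With a constant pattern, a -1 entry now marks a dimension whose size is inferred from the input's element count, and, when zero_flag is set, a 0 entry copies the input dimension at the same index. A minimal usage sketch, mirroring the type_prop tests added in this commit (assumes the usual nGraph headers and namespaces are available):

#include "ngraph/ngraph.hpp"

using namespace ngraph;

int main()
{
    // Input has shape {2, 4, 2, 8} -> 2 * 4 * 2 * 8 = 128 elements.
    auto arg = std::make_shared<op::Parameter>(element::f32, Shape{2, 4, 2, 8});

    // Pattern {1, 2, 4, -1}: the -1 dimension is inferred so element counts
    // match, i.e. 128 / (1 * 2 * 4) = 16.
    auto pattern = op::Constant::create(element::i64, Shape{4}, {1, 2, 4, -1});
    auto r = std::make_shared<op::DynReshape>(arg, pattern);
    // r->get_output_shape(0) == Shape{1, 2, 4, 16}
    return 0;
}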
@@ -24,8 +24,11 @@
using namespace std;
using namespace ngraph;
-op::DynReshape::DynReshape(const shared_ptr<Node>& arg, const shared_ptr<Node>& pattern)
op::DynReshape::DynReshape(const shared_ptr<Node>& arg,
                           const shared_ptr<Node>& pattern,
                           bool zero_flag)
    : Op("DynReshape", check_single_output_args({arg, pattern}))
    , m_zero_flag(zero_flag)
{
    constructor_validate_and_infer_types();
}
@@ -47,9 +50,87 @@ void op::DynReshape::validate_and_infer_types()
    Rank output_rank = pattern_shape.rank().is_dynamic() ? Rank::dynamic() : pattern_shape[0];

    set_input_is_relevant_to_shape(1);

    if (auto const_shape = dynamic_pointer_cast<op::Constant>(get_argument(1)))
    {
-        set_output_type(0, get_input_element_type(0), const_shape->get_shape_val());
        std::vector<int64_t> out_shape_val = const_shape->get_vector<int64_t>();
        NODE_VALIDATION_CHECK(this,
                              std::none_of(out_shape_val.begin(),
                                           out_shape_val.end(),
                                           [](int64_t v) { return v < -1; }),
                              "Dim size cannot be less than -1");
        int zero_dims = std::count_if(
            out_shape_val.begin(), out_shape_val.end(), [](int64_t v) { return v == 0; });
        int negative_dims = std::count_if(
            out_shape_val.begin(), out_shape_val.end(), [](int64_t v) { return v == -1; });
        NODE_VALIDATION_CHECK(this,
                              negative_dims <= 1,
                              "More than one dimension has size of -1 (",
                              negative_dims,
                              ")");

        // No special values to resolve: the pattern gives the output shape directly.
        if (!(zero_dims && m_zero_flag) && !negative_dims)
        {
            set_output_type(0, get_input_element_type(0), const_shape->get_shape_val());
        }
        else
        {
            std::vector<Dimension> partial_shape(static_cast<size_t>(output_rank));
            // Replace zeros and negatives with Dynamic dimensions as needed
            std::transform(out_shape_val.begin(),
                           out_shape_val.end(),
                           partial_shape.begin(),
                           [&](const int64_t& v) {
                               return (v < 0)
                                          ? Dimension()
                                          : ((v == 0 && m_zero_flag) ? Dimension() : Dimension(v));
                           });

            if (get_input_partial_shape(0).is_static())
            {
                size_t output_elements = 1;
                int negative_dim = -1;

                auto input_shape = get_input_partial_shape(0).to_shape();
                size_t input_elements = shape_size(input_shape);
                for (size_t i = 0; i < static_cast<size_t>(output_rank); i++)
                {
                    if (out_shape_val[i] == 0 && m_zero_flag)
                    {
                        // Copy input_shape[i] for zero values
                        NGRAPH_CHECK(i < input_shape.size());
                        partial_shape[i] = Dimension(input_shape[i]);
                        output_elements *= input_shape[i];
                    }
                    else if (out_shape_val[i] == -1)
                    {
                        negative_dim = i;
                    }
                    else
                    {
                        output_elements *= out_shape_val[i];
                    }
                }

                if (negative_dim != -1)
                {
                    // Infer size such that number of output elements matches
                    // input elements
                    if (output_elements == 0)
                    {
                        NGRAPH_CHECK(input_elements == 0);
                        partial_shape[negative_dim] = Dimension(0);
                    }
                    else
                    {
                        NGRAPH_CHECK(input_elements % output_elements == 0);
                        partial_shape[negative_dim] = Dimension(input_elements / output_elements);
                    }
                }
            }
            set_output_type(0, get_input_element_type(0), PartialShape(partial_shape));
        }
    }
    else
    {
@@ -60,7 +141,7 @@ void op::DynReshape::validate_and_infer_types()
shared_ptr<Node> op::DynReshape::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
-    return make_shared<DynReshape>(new_args.at(0), new_args.at(1));
    return make_shared<DynReshape>(new_args.at(0), new_args.at(1), m_zero_flag);
}

void op::DynReshape::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
......
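
For a fully static input, the inference above reduces to simple arithmetic. A standalone sketch of that logic follows; the helper name infer_output_shape is hypothetical and not part of this commit:

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper sketching the shape arithmetic in
// validate_and_infer_types; not part of the commit itself.
std::vector<int64_t> infer_output_shape(const std::vector<int64_t>& input_shape,
                                        std::vector<int64_t> pattern,
                                        bool zero_flag)
{
    int64_t input_elements = 1;
    for (int64_t d : input_shape)
    {
        input_elements *= d;
    }

    int64_t output_elements = 1;
    int64_t negative_dim = -1;
    for (size_t i = 0; i < pattern.size(); i++)
    {
        if (pattern[i] == 0 && zero_flag)
        {
            // 0 copies the input dimension at the same index.
            assert(i < input_shape.size());
            pattern[i] = input_shape[i];
            output_elements *= pattern[i];
        }
        else if (pattern[i] == -1)
        {
            // Remember the dimension to infer; validation guarantees at most one.
            negative_dim = static_cast<int64_t>(i);
        }
        else
        {
            output_elements *= pattern[i];
        }
    }

    if (negative_dim != -1)
    {
        // Infer the -1 dimension so output elements match input elements;
        // zero known output elements force the inferred dimension to 0.
        if (output_elements == 0)
        {
            assert(input_elements == 0);
            pattern[negative_dim] = 0;
        }
        else
        {
            assert(input_elements % output_elements == 0);
            pattern[negative_dim] = input_elements / output_elements;
        }
    }
    return pattern;
}

For example, infer_output_shape({2, 2, 0}, {0, -1}, true) returns {2, 0}, matching the dynreshape_arg_rank_static_pattern_zero_negative test below.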
@@ -37,16 +37,27 @@ namespace ngraph
            /// \param pattern The node that defines the output shape pattern.
            ///        If the input shape is \f$(a_0,\dots,a_{k-1})\f$ then the output shape
            ///        must be of the form \f$(b_0,\dots,b_{j-1})\f$ where
            ///        \f$\Pi(a_i) = \Pi(b_i)\f$.
-            DynReshape(const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& pattern);
            ///        A value of -1 is allowed for at most one dimension, in which case that
            ///        dimension's size is inferred from the element count of the input tensor.
            /// \param zero_flag If true, zeros in `pattern` are treated as wildcards that copy
            ///        the input shape at the same index.
            DynReshape(const std::shared_ptr<Node>& arg,
                       const std::shared_ptr<Node>& pattern,
                       bool zero_flag = false);

            void validate_and_infer_types() override;

            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;

            bool get_zero_flag() const { return m_zero_flag; }
            void set_zero_flag(bool zero_flag) { m_zero_flag = zero_flag; }
        protected:
            virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                           const NodeVector& deltas) override;

        private:
            bool m_zero_flag;
        };
    }
}
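
In sketch form, the two zero semantics described in the doc comment above (this mirrors the dynreshape_arg_rank_static_pattern_zero test that follows, with the same nGraph-header assumptions as the earlier snippet):

#include "ngraph/ngraph.hpp"

using namespace ngraph;

int main()
{
    auto arg = std::make_shared<op::Parameter>(element::f32, Shape{2, 0, 2, 8});
    auto pattern = op::Constant::create(element::i64, Shape{4}, {1, 2, 0, 32});

    // Default (zero_flag = false): a 0 in the pattern is a literal
    // zero-sized dimension.
    auto r1 = std::make_shared<op::DynReshape>(arg, pattern);
    // r1->get_output_shape(0) == Shape{1, 2, 0, 32}

    // zero_flag = true: the 0 copies input dimension 2 (which is 2).
    auto r2 = std::make_shared<op::DynReshape>(arg, pattern, true /*zero_flag*/);
    // r2->get_output_shape(0) == Shape{1, 2, 2, 32}
    return 0;
}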
@@ -12978,6 +12978,95 @@ TEST(type_prop, dynreshape_arg_rank_static_dynamic_pattern_rank_dynamic_ok)
    EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}

TEST(type_prop, dynreshape_arg_rank_static_pattern_zero)
{
    auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 0, 2, 8});
    auto dynamic_arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto pattern = op::Constant::create(element::i64, Shape{4}, {1, 2, 0, 32});

    auto r1 = make_shared<op::DynReshape>(arg, pattern);
    EXPECT_EQ(r1->get_output_shape(0), (Shape{1, 2, 0, 32}));

    auto r2 = make_shared<op::DynReshape>(arg, pattern, true /*zero_flag*/);
    EXPECT_EQ(r2->get_output_shape(0), (Shape{1, 2, 2, 32}));

    auto r3 = make_shared<op::DynReshape>(dynamic_arg, pattern, true /*zero_flag*/);
    EXPECT_TRUE(
        r3->get_output_partial_shape(0).same_scheme(PartialShape{1, 2, Dimension::dynamic(), 32}));
}

TEST(type_prop, dynreshape_arg_rank_static_pattern_negative)
{
    auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 2, 8});
    auto dynamic_arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto pattern = op::Constant::create(element::i64, Shape{4}, {1, 2, 4, -1});

    auto r1 = make_shared<op::DynReshape>(arg, pattern);
    EXPECT_EQ(r1->get_output_shape(0), (Shape{1, 2, 4, 16}));

    auto r2 = make_shared<op::DynReshape>(dynamic_arg, pattern);
    EXPECT_TRUE(
        r2->get_output_partial_shape(0).same_scheme(PartialShape{1, 2, 4, Dimension::dynamic()}));
}

TEST(type_prop, dynreshape_arg_rank_static_pattern_zero_negative)
{
    auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 2, 0});
    auto dynamic_arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto pattern = op::Constant::create(element::i64, Shape{2}, {0, -1});

    auto r1 = make_shared<op::DynReshape>(arg, pattern);
    auto r2 = make_shared<op::DynReshape>(arg, pattern, true);
    EXPECT_EQ(r1->get_output_shape(0), (Shape{0, 0}));
    EXPECT_EQ(r2->get_output_shape(0), (Shape{2, 0}));

    auto r3 = make_shared<op::DynReshape>(dynamic_arg, pattern);
    auto r4 = make_shared<op::DynReshape>(dynamic_arg, pattern, true);
    EXPECT_TRUE(r3->get_output_partial_shape(0).same_scheme(PartialShape{0, Dimension::dynamic()}));
    EXPECT_TRUE(r4->get_output_partial_shape(0).same_scheme(
        PartialShape{Dimension::dynamic(), Dimension::dynamic()}));
}

TEST(type_prop, dynreshape_arg_rank_static_pattern_negative_failure1)
{
    auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 2, 8});
    auto pattern = op::Constant::create(element::i64, Shape{4}, {1, 2, -1, -1});

    try
    {
        auto r = make_shared<op::DynReshape>(arg, pattern);
        FAIL() << "Expected failure on dynreshape construction";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("More than one dimension has size of -1"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}

TEST(type_prop, dynreshape_arg_rank_static_pattern_negative_failure2)
{
    auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 2, 8});
    auto pattern = op::Constant::create(element::i64, Shape{4}, {1, 2, 4, -2});

    try
    {
        auto r = make_shared<op::DynReshape>(arg, pattern);
        FAIL() << "Expected failure on dynreshape construction";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Dim size cannot be less than -1"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}

void DynReshape_Test_Shape_Except(const shared_ptr<Node>& param_0, const shared_ptr<Node>& param_1)
{
    try
......