Commit a06b896c authored by Evgenya Stepyreva, committed by Jayaram Bobba

v1::Reshape: zero_flag renamed to special_zero; default value removed (#3945)

parent 160b91bf
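
Context for readers skimming the diff: special_zero keeps the old zero_flag meaning (a 0 in `pattern` copies the input dimension at the same index), but the constructor no longer supplies a default, so every call site below now passes the flag explicitly. A minimal sketch of the new call-site shape, assuming the usual nGraph umbrella header and a hypothetical build_reshape helper:

    #include "ngraph/ngraph.hpp"

    using namespace ngraph;

    // Hypothetical helper; `data` and `pattern` stand in for real graph outputs.
    std::shared_ptr<op::v1::Reshape> build_reshape(const Output<Node>& data,
                                                   const Output<Node>& pattern)
    {
        // Before this commit the third argument defaulted to false and could be
        // omitted; after it, callers must state their intent explicitly.
        return std::make_shared<op::v1::Reshape>(data, pattern, /*special_zero=*/false);
    }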
@@ -151,7 +151,7 @@ constexpr NodeTypeInfo op::v1::Reshape::type_info;
 op::v1::Reshape::Reshape(const Output<Node>& arg, const Output<Node>& pattern, bool zero_flag)
     : Op({arg, pattern})
-    , m_zero_flag(zero_flag)
+    , m_special_zero(zero_flag)
 {
     constructor_validate_and_infer_types();
 }
@@ -193,7 +193,7 @@ void op::v1::Reshape::validate_and_infer_types()
             negative_dims,
             ")");
-        if (!(zero_dims && m_zero_flag) && !negative_dims)
+        if (!(zero_dims && m_special_zero) && !negative_dims)
         {
             set_output_type(0, get_input_element_type(0), const_shape->get_shape_val());
         }
@@ -205,9 +205,9 @@ void op::v1::Reshape::validate_and_infer_types()
                            out_shape_val.end(),
                            partial_shape.begin(),
                            [&](const int64_t& v) {
-                               return (v < 0)
-                                   ? Dimension()
-                                   : ((v == 0 && m_zero_flag) ? Dimension() : Dimension(v));
+                               return (v < 0) ? Dimension()
+                                              : ((v == 0 && m_special_zero) ? Dimension()
+                                                                            : Dimension(v));
                            });

             if (get_input_partial_shape(0).is_static())
@@ -219,7 +219,7 @@ void op::v1::Reshape::validate_and_infer_types()
                 size_t input_elements = shape_size(input_shape);
                 for (size_t i = 0; i < static_cast<size_t>(output_rank); i++)
                 {
-                    if (out_shape_val[i] == 0 && m_zero_flag)
+                    if (out_shape_val[i] == 0 && m_special_zero)
                     {
                         // Copy input_shape[i] for zero values
                         NODE_VALIDATION_CHECK(
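
The two hunks above carry the core of the shape inference. As a standalone illustration (plain C++, independent of nGraph's Dimension/PartialShape machinery, and only a sketch of the rule rather than the op's actual code), the special-zero and minus-one handling amounts to:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Resolve a reshape pattern against a static input shape: with special_zero
    // set, a 0 copies the input dimension at the same index, and a single -1 is
    // inferred from the remaining element count.
    // e.g. input {2, 4, 6}, pattern {0, -1, 2}, special_zero=true -> {2, 12, 2}
    std::vector<size_t> resolve_pattern(const std::vector<size_t>& input_shape,
                                        const std::vector<int64_t>& pattern,
                                        bool special_zero)
    {
        size_t input_elements = 1;
        for (size_t d : input_shape)
            input_elements *= d;

        std::vector<size_t> out(pattern.size());
        size_t known_product = 1;
        int64_t inferred_index = -1;
        for (size_t i = 0; i < pattern.size(); ++i)
        {
            int64_t v = pattern[i];
            if (v == 0 && special_zero)
            {
                assert(i < input_shape.size()); // 0 must index a real input dim
                out[i] = input_shape[i];
            }
            else if (v == -1)
            {
                assert(inferred_index == -1); // at most one -1 is allowed
                inferred_index = static_cast<int64_t>(i);
                continue; // skip the product update; this dim is inferred below
            }
            else
            {
                out[i] = static_cast<size_t>(v);
            }
            known_product *= out[i];
        }
        if (inferred_index >= 0)
        {
            assert(known_product != 0); // -1 cannot coexist with a zero-sized dim
            out[inferred_index] = input_elements / known_product;
        }
        return out;
    }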
@@ -274,7 +274,7 @@ void op::v1::Reshape::validate_and_infer_types()
 shared_ptr<Node> op::v1::Reshape::copy_with_new_args(const NodeVector& new_args) const
 {
     check_new_args_count(this, new_args);
-    return make_shared<v1::Reshape>(new_args.at(0), new_args.at(1), m_zero_flag);
+    return make_shared<v1::Reshape>(new_args.at(0), new_args.at(1), m_special_zero);
 }

 void op::v1::Reshape::generate_adjoints(autodiff::Adjoints& /* adjoints */,
...
@@ -132,11 +132,10 @@ namespace ngraph
             /// be of the form \f$(b_0,\dots,b_{j-1})\f$ where \f$\Pi(a_i) = \Pi(b_i)\f$.
             /// A value of -1 is allowed for at most one dimension, in which case the
             /// dimension size is inferred based on element count of input tensor.
-            /// \param zero_flag Treats zeros in `pattern` as wildcard flags indicating a copy
+            /// \param special_zero Treats zeros in `pattern` as wildcard flags indicating a
+            /// copy
             /// from input shape at the same index.
-            Reshape(const Output<Node>& arg,
-                    const Output<Node>& pattern,
-                    bool zero_flag = false);
+            Reshape(const Output<Node>& arg, const Output<Node>& pattern, bool special_zero);

             void validate_and_infer_types() override;
@@ -144,14 +143,14 @@ namespace ngraph
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;

-            bool get_zero_flag() const { return m_zero_flag; }
-            void set_zero_flag(bool zero_flag) { m_zero_flag = zero_flag; }
+            bool get_special_zero() const { return m_special_zero; }
+            void set_special_zero(bool special_zero) { m_special_zero = special_zero; }
         protected:
             virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                            const NodeVector& deltas) override;

         private:
-            bool m_zero_flag;
+            bool m_special_zero;
         };
     }
     using v0::Reshape;
...
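
To make the documented semantics concrete, a hedged usage sketch (shapes and values invented for illustration; assumes the standard nGraph headers):

    #include "ngraph/ngraph.hpp"

    using namespace ngraph;

    int main()
    {
        // Input of shape {2, 4, 6} -> 48 elements.
        auto data = std::make_shared<op::Parameter>(element::f32, Shape{2, 4, 6});
        // Pattern {0, -1, 2}: with special_zero=true the 0 copies input dim 0
        // (= 2) and the -1 is inferred, giving an output shape of {2, 12, 2}.
        auto pattern =
            op::Constant::create(element::i64, Shape{3}, std::vector<int64_t>{0, -1, 2});
        auto reshape = std::make_shared<op::v1::Reshape>(data, pattern, true);
        return 0;
    }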
@@ -50,7 +50,8 @@ void pass::ConstantFolding::construct_constant_dyn_reshape()
         element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
     auto constant_shape_label =
         make_shared<pattern::op::Label>(element::i64, Shape{1}, pattern::has_class<op::Constant>());
-    auto dyn_reshape = make_shared<op::v1::Reshape>(constant_data_label, constant_shape_label);
+    auto dyn_reshape =
+        make_shared<op::v1::Reshape>(constant_data_label, constant_shape_label, false);

     // Note: No need to capture or consider constant_shape_label, because
     // shape propagation will have transferred the info to dyn_reshape's
...
@@ -215,7 +215,7 @@ namespace
     bool op_cast(shared_ptr<op::v1::Reshape> node)
     {
         auto replacement_node = make_shared<op::v0::DynReshape>(
-            node->input_value(0), node->input_value(1), node->get_zero_flag());
+            node->input_value(0), node->input_value(1), node->get_special_zero());
         replace_node(node, replacement_node);
         return true;
     }
...
@@ -3359,7 +3359,7 @@ json JSONSerializer::serialize_node(const Node& n)
     case OP_TYPEID::Reshape_v1:
     {
         auto tmp = static_cast<const op::v1::Reshape*>(&n);
-        node["zero_flag"] = tmp->get_zero_flag();
+        node["special_zero"] = tmp->get_special_zero();
         break;
     }
     case OP_TYPEID::DynSlice:
...
@@ -120,7 +120,7 @@ NGRAPH_TEST(${BACKEND_NAME}, reshape_v1)
 {
     auto arg = std::make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
     auto pattern = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
-    auto reshape_v1 = std::make_shared<op::v1::Reshape>(arg, pattern);
+    auto reshape_v1 = std::make_shared<op::v1::Reshape>(arg, pattern, false);
     auto f = std::make_shared<Function>(NodeVector{reshape_v1}, ParameterVector{arg, pattern});
...
@@ -1548,7 +1548,7 @@ TEST(constant_folding, constant_dyn_reshape)
     auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in);
     auto constant_shape = make_shared<op::Constant>(element::i64, shape_shape, values_shape);
-    auto dyn_reshape = make_shared<op::v1::Reshape>(constant_in, constant_shape);
+    auto dyn_reshape = make_shared<op::v1::Reshape>(constant_in, constant_shape, false);
     auto f = make_shared<Function>(dyn_reshape, ParameterVector{});

     pass::Manager pass_manager;
@@ -1582,7 +1582,7 @@ TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant)
     auto constant_shape_a = make_shared<op::Constant>(element::i64, shape_shape, values_shape_a);
     auto constant_shape_b = make_shared<op::Constant>(element::i64, shape_shape, values_shape_b);
     auto dyn_reshape =
-        make_shared<op::v1::Reshape>(constant_in, constant_shape_a + constant_shape_b);
+        make_shared<op::v1::Reshape>(constant_in, constant_shape_a + constant_shape_b, false);
     auto f = make_shared<Function>(dyn_reshape, ParameterVector{});

     ASSERT_TRUE(dyn_reshape->output(0).get_partial_shape().is_dynamic());
...
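
For completeness, a sketch of how the updated tests exercise constant folding end to end (names and values invented; assumes nGraph's pass-manager API as used elsewhere in this diff):

    #include "ngraph/ngraph.hpp"
    #include "ngraph/pass/constant_folding.hpp"
    #include "ngraph/pass/manager.hpp"

    using namespace ngraph;

    int main()
    {
        auto constant_in =
            op::Constant::create(element::f32, Shape{2, 4}, std::vector<float>(8, 1.0f));
        auto constant_shape =
            op::Constant::create(element::i64, Shape{1}, std::vector<int64_t>{8});
        // The flag must now be passed explicitly at every construction site.
        auto dyn_reshape = std::make_shared<op::v1::Reshape>(constant_in, constant_shape, false);
        auto f = std::make_shared<Function>(dyn_reshape, ParameterVector{});

        pass::Manager pass_manager;
        pass_manager.register_pass<pass::ConstantFolding>();
        pass_manager.run_passes(f);
        // After the pass, the Reshape over constants should be folded away,
        // leaving a single Constant of shape {8}.
        return 0;
    }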