Commit a06b896c authored by Evgenya Stepyreva's avatar Evgenya Stepyreva Committed by Jayaram Bobba

v1::Reshape zero_flag renamed. Default value unset (#3945)

parent 160b91bf
......@@ -151,7 +151,7 @@ constexpr NodeTypeInfo op::v1::Reshape::type_info;
op::v1::Reshape::Reshape(const Output<Node>& arg, const Output<Node>& pattern, bool zero_flag)
: Op({arg, pattern})
, m_zero_flag(zero_flag)
, m_special_zero(zero_flag)
{
constructor_validate_and_infer_types();
}
......@@ -193,7 +193,7 @@ void op::v1::Reshape::validate_and_infer_types()
negative_dims,
")");
if (!(zero_dims && m_zero_flag) && !negative_dims)
if (!(zero_dims && m_special_zero) && !negative_dims)
{
set_output_type(0, get_input_element_type(0), const_shape->get_shape_val());
}
......@@ -205,9 +205,9 @@ void op::v1::Reshape::validate_and_infer_types()
out_shape_val.end(),
partial_shape.begin(),
[&](const int64_t& v) {
return (v < 0)
? Dimension()
: ((v == 0 && m_zero_flag) ? Dimension() : Dimension(v));
return (v < 0) ? Dimension()
: ((v == 0 && m_special_zero) ? Dimension()
: Dimension(v));
});
if (get_input_partial_shape(0).is_static())
......@@ -219,7 +219,7 @@ void op::v1::Reshape::validate_and_infer_types()
size_t input_elements = shape_size(input_shape);
for (size_t i = 0; i < static_cast<size_t>(output_rank); i++)
{
if (out_shape_val[i] == 0 && m_zero_flag)
if (out_shape_val[i] == 0 && m_special_zero)
{
// Copy input_shape[i] for zero values
NODE_VALIDATION_CHECK(
......@@ -274,7 +274,7 @@ void op::v1::Reshape::validate_and_infer_types()
shared_ptr<Node> op::v1::Reshape::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v1::Reshape>(new_args.at(0), new_args.at(1), m_zero_flag);
return make_shared<v1::Reshape>(new_args.at(0), new_args.at(1), m_special_zero);
}
void op::v1::Reshape::generate_adjoints(autodiff::Adjoints& /* adjoints */,
......
......@@ -132,11 +132,10 @@ namespace ngraph
/// be of the form \f$(b_0,\dots,b_{j-1})\f$ where \f$\Pi(a_i) = \Pi(b_i)\f$.
/// A value of -1 is allowed for at most one dimension, in which case the
/// dimension size is inferred based on element count of input tensor.
/// \param zero_flag Treats zeros in `pattern` as wildcard flags indicating a copy
/// \param special_zero Treats zeros in `pattern` as wildcard flags indicating a
/// copy
/// from input shape at the same index.
Reshape(const Output<Node>& arg,
const Output<Node>& pattern,
bool zero_flag = false);
Reshape(const Output<Node>& arg, const Output<Node>& pattern, bool special_zero);
void validate_and_infer_types() override;
......@@ -144,14 +143,14 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool get_zero_flag() const { return m_zero_flag; }
void set_zero_flag(bool zero_flag) { m_zero_flag = zero_flag; }
bool get_special_zero() const { return m_special_zero; }
void set_special_zero(bool special_zero) { m_special_zero = special_zero; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
private:
bool m_zero_flag;
bool m_special_zero;
};
}
using v0::Reshape;
......
......@@ -50,7 +50,8 @@ void pass::ConstantFolding::construct_constant_dyn_reshape()
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
auto constant_shape_label =
make_shared<pattern::op::Label>(element::i64, Shape{1}, pattern::has_class<op::Constant>());
auto dyn_reshape = make_shared<op::v1::Reshape>(constant_data_label, constant_shape_label);
auto dyn_reshape =
make_shared<op::v1::Reshape>(constant_data_label, constant_shape_label, false);
// Note: No need to capture or consider constant_shape_label, because
// shape propagation will have transferred the info to dyn_reshape's
......
......@@ -215,7 +215,7 @@ namespace
bool op_cast(shared_ptr<op::v1::Reshape> node)
{
auto replacement_node = make_shared<op::v0::DynReshape>(
node->input_value(0), node->input_value(1), node->get_zero_flag());
node->input_value(0), node->input_value(1), node->get_special_zero());
replace_node(node, replacement_node);
return true;
}
......
......@@ -3359,7 +3359,7 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::Reshape_v1:
{
auto tmp = static_cast<const op::v1::Reshape*>(&n);
node["zero_flag"] = tmp->get_zero_flag();
node["special_zero"] = tmp->get_special_zero();
break;
}
case OP_TYPEID::DynSlice:
......
......@@ -120,7 +120,7 @@ NGRAPH_TEST(${BACKEND_NAME}, reshape_v1)
{
auto arg = std::make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto pattern = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto reshape_v1 = std::make_shared<op::v1::Reshape>(arg, pattern);
auto reshape_v1 = std::make_shared<op::v1::Reshape>(arg, pattern, false);
auto f = std::make_shared<Function>(NodeVector{reshape_v1}, ParameterVector{arg, pattern});
......
......@@ -1548,7 +1548,7 @@ TEST(constant_folding, constant_dyn_reshape)
auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in);
auto constant_shape = make_shared<op::Constant>(element::i64, shape_shape, values_shape);
auto dyn_reshape = make_shared<op::v1::Reshape>(constant_in, constant_shape);
auto dyn_reshape = make_shared<op::v1::Reshape>(constant_in, constant_shape, false);
auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
pass::Manager pass_manager;
......@@ -1582,7 +1582,7 @@ TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant)
auto constant_shape_a = make_shared<op::Constant>(element::i64, shape_shape, values_shape_a);
auto constant_shape_b = make_shared<op::Constant>(element::i64, shape_shape, values_shape_b);
auto dyn_reshape =
make_shared<op::v1::Reshape>(constant_in, constant_shape_a + constant_shape_b);
make_shared<op::v1::Reshape>(constant_in, constant_shape_a + constant_shape_b, false);
auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
ASSERT_TRUE(dyn_reshape->output(0).get_partial_shape().is_dynamic());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment