Commit 8ccddb19 authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Michał Karzyński

[Spec] Implement Reshape:v1 (#3633)

parent 1ce31a49
......@@ -23,16 +23,16 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::DynReshape::type_info;
constexpr NodeTypeInfo op::v0::DynReshape::type_info;
op::DynReshape::DynReshape(const Output<Node>& arg, const Output<Node>& pattern, bool zero_flag)
op::v0::DynReshape::DynReshape(const Output<Node>& arg, const Output<Node>& pattern, bool zero_flag)
: Op({arg, pattern})
, m_zero_flag(zero_flag)
{
constructor_validate_and_infer_types();
}
void op::DynReshape::validate_and_infer_types()
void op::v0::DynReshape::validate_and_infer_types()
{
auto pattern_et = get_input_element_type(1);
// check data types
......@@ -147,13 +147,13 @@ void op::DynReshape::validate_and_infer_types()
}
}
shared_ptr<Node> op::DynReshape::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v0::DynReshape::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<DynReshape>(new_args.at(0), new_args.at(1), m_zero_flag);
return make_shared<v0::DynReshape>(new_args.at(0), new_args.at(1), m_zero_flag);
}
void op::DynReshape::generate_adjoints(autodiff::Adjoints& /* adjoints */,
void op::v0::DynReshape::generate_adjoints(autodiff::Adjoints& /* adjoints */,
const NodeVector& /* deltas */)
{
throw ngraph_error("generate_adjoints not implemented for DynReshape");
......
......@@ -22,6 +22,8 @@
namespace ngraph
{
namespace op
{
namespace v0
{
/// \brief Tensor dynamic reshape operation.
///
......@@ -40,11 +42,13 @@ namespace ngraph
///
/// \param arg The tensor to be reshaped.
/// \param pattern The node that defines output shape pattern.
/// If the input shape is \f$(a_0,\dots,a_{k-1})\f$ then the output shape must
/// If the input shape is \f$(a_0,\dots,a_{k-1})\f$ then the output shape
/// must
/// be of the form \f$(b_0,\dots,b_{j-1})\f$ where \f$\Pi(a_i) = \Pi(b_i)\f$.
/// A value of -1 is allowed for at most one dimension, in which case the
/// dimension size is inferred based on element count of input tensor.
/// \param zero_flag Treats zeros in `pattern` as wildcard flags indicating a copy from
/// \param zero_flag Treats zeros in `pattern` as wildcard flags indicating a copy
/// from
/// input shape at the same index.
DynReshape(const Output<Node>& arg,
const Output<Node>& pattern,
......@@ -65,4 +69,7 @@ namespace ngraph
bool m_zero_flag;
};
}
// default opset version
using v0::DynReshape;
}
}
......@@ -18,6 +18,7 @@
#include <iostream>
#include "ngraph/function.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/reshape.hpp"
using namespace std;
......@@ -145,3 +146,139 @@ void op::Reshape::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
adjoints.add_delta(input_value(0), reshape);
}
constexpr NodeTypeInfo op::v1::Reshape::type_info;
op::v1::Reshape::Reshape(const Output<Node>& arg, const Output<Node>& pattern, bool zero_flag)
: Op({arg, pattern})
, m_zero_flag(zero_flag)
{
constructor_validate_and_infer_types();
}
void op::v1::Reshape::validate_and_infer_types()
{
auto pattern_et = get_input_element_type(1);
// check data types
NODE_VALIDATION_CHECK(
this, pattern_et.compatible(element::Type_t::i64), "Pattern must have element type i64.");
// check shapes
const PartialShape& pattern_shape = get_input_partial_shape(1);
NODE_VALIDATION_CHECK(this,
pattern_shape.rank().compatible(1),
"Pattern shape must have rank 1, got ",
pattern_shape.rank(),
".");
Rank output_rank = pattern_shape.rank().is_dynamic() ? Rank::dynamic() : pattern_shape[0];
set_input_is_relevant_to_shape(1);
if (auto const_shape = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr()))
{
std::vector<int64_t> out_shape_val = const_shape->get_vector<int64_t>();
NODE_VALIDATION_CHECK(this,
std::none_of(out_shape_val.begin(),
out_shape_val.end(),
[](int64_t v) { return v < -1; }),
"Dim size cannot be less than -1 ");
int zero_dims = std::count_if(
out_shape_val.begin(), out_shape_val.end(), [](int64_t v) { return v == 0; });
int negative_dims = std::count_if(
out_shape_val.begin(), out_shape_val.end(), [](int64_t v) { return v == -1; });
NODE_VALIDATION_CHECK(this,
negative_dims <= 1,
"More than one dimension has size of -1 (",
negative_dims,
")");
if (!(zero_dims && m_zero_flag) && !negative_dims)
{
set_output_type(0, get_input_element_type(0), const_shape->get_shape_val());
}
else
{
std::vector<Dimension> partial_shape(static_cast<size_t>(output_rank));
// Replace zeros and negatives with Dynamic dimensions as needed
std::transform(out_shape_val.begin(),
out_shape_val.end(),
partial_shape.begin(),
[&](const int64_t& v) {
return (v < 0)
? Dimension()
: ((v == 0 && m_zero_flag) ? Dimension() : Dimension(v));
});
if (get_input_partial_shape(0).is_static())
{
size_t output_elements = 1;
int negative_dim = -1;
auto input_shape = get_input_partial_shape(0).to_shape();
size_t input_elements = shape_size(input_shape);
for (size_t i = 0; i < static_cast<size_t>(output_rank); i++)
{
if (out_shape_val[i] == 0 && m_zero_flag)
{
// Copy input_shape[i] for zero values
NODE_VALIDATION_CHECK(
this, i < input_shape.size(), "'0' dimension is out of range");
partial_shape[i] = Dimension(input_shape[i]);
output_elements *= input_shape[i];
}
else if (out_shape_val[i] == -1)
{
negative_dim = i;
}
else
{
output_elements *= out_shape_val[i];
}
}
if (negative_dim != -1)
{
// Infer size such that number of output elements matches
// input elements
if (output_elements == 0)
{
// TODO(amprocte): Decide if this is desired behavior here. (NumPy seems
// to fail.)
NODE_VALIDATION_CHECK(this,
input_elements == 0,
"Cannot infer '-1' dimension with zero-size output "
"dimension unless at least one input dimension is "
"also zero-size");
partial_shape[negative_dim] = Dimension(0);
}
else
{
NODE_VALIDATION_CHECK(
this,
input_elements % output_elements == 0,
"Non-'-1' output dimensions do not evenly divide the input dimensions");
partial_shape[negative_dim] = Dimension(input_elements / output_elements);
}
}
}
set_output_type(0, get_input_element_type(0), PartialShape(partial_shape));
}
}
else
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(output_rank));
}
}
shared_ptr<Node> op::v1::Reshape::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v1::Reshape>(new_args.at(0), new_args.at(1), m_zero_flag);
}
void op::v1::Reshape::generate_adjoints(autodiff::Adjoints& /* adjoints */,
const NodeVector& /* deltas */)
{
throw ngraph_error("generate_adjoints not implemented for Reshape");
}
......@@ -105,5 +105,52 @@ namespace ngraph
Shape m_output_shape;
bool m_is_transpose{false};
};
namespace v1
{
/// \brief Tensor dynamic reshape operation.
///
/// "Converts" an input tensor into a new shape with the same number of elements.
/// This op does not touch the actual data. If needed, use Transpose for that purpose.
///
class Reshape : public Op
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"DynReshape", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Reshape() = default;
/// \brief Constructs a dynamic reshape operation. This operation does not perform
/// transpose.
///
/// \param arg The tensor to be reshaped.
/// \param pattern The node that defines output shape pattern.
/// If the input shape is \f$(a_0,\dots,a_{k-1})\f$ then the output shape
/// must
/// be of the form \f$(b_0,\dots,b_{j-1})\f$ where \f$\Pi(a_i) = \Pi(b_i)\f$.
/// A value of -1 is allowed for at most one dimension, in which case the
/// dimension size is inferred based on element count of input tensor.
/// \param zero_flag Treats zeros in `pattern` as wildcard flags indicating a copy
/// from input shape at the same index.
Reshape(const Output<Node>& arg,
const Output<Node>& pattern,
bool zero_flag = false);
void validate_and_infer_types() override;
size_t get_version() const override { return 1; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool get_zero_flag() const { return m_zero_flag; }
void set_zero_flag(bool zero_flag) { m_zero_flag = zero_flag; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
private:
bool m_zero_flag;
};
}
}
}
......@@ -18,6 +18,7 @@
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/max_pool.hpp"
......@@ -25,6 +26,7 @@
#include "ngraph/op/product.hpp"
#include "ngraph/op/reduce_prod.hpp"
#include "ngraph/op/reduce_sum.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp"
......@@ -233,6 +235,15 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::DynReshape:
{
auto zero_flag = false;
auto replacement_node = make_shared<op::v1::Reshape>(
node->input(0).get_source_output(), node->input(1).get_source_output(), zero_flag);
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::Gather:
{
auto tmp = dynamic_cast<const op::v0::Gather*>(node.get());
......@@ -245,60 +256,6 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Product:
{
bool keep_dims = false;
auto replacement_node = make_shared<op::v1::ReduceProd>(
node->input(0).get_source_output(), node->input(1).get_source_output(), keep_dims);
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::Sum:
{
bool keep_dims = false;
auto replacement_node = make_shared<op::v1::ReduceSum>(
node->input(0).get_source_output(), node->input(1).get_source_output(), keep_dims);
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::Pad:
{
auto tmp = dynamic_cast<const op::v0::Pad*>(node.get());
auto padding_below = tmp->get_padding_below();
auto pads_begin_node =
make_shared<op::Constant>(element::i64, Shape{padding_below.size()}, padding_below);
auto padding_above = tmp->get_padding_above();
auto pads_end_node =
make_shared<op::Constant>(element::i64, Shape{padding_above.size()}, padding_above);
auto replacement_node = make_shared<op::v1::Pad>(node->input(0).get_source_output(),
pads_begin_node,
pads_end_node,
node->input(1).get_source_output(),
tmp->get_pad_mode());
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::Softmax:
{
auto tmp = dynamic_cast<const op::v0::Softmax*>(node.get());
AxisSet axes = tmp->get_axes();
NGRAPH_CHECK(
axes.size() == 1,
"Unable to convert Softmax:0 to Softmax:1 with zero or more than one axis. Node: ",
*node);
auto replacement_node =
make_shared<op::v1::Softmax>(node->input(0).get_source_output(), axes.to_vector()[0]);
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::MaxPool:
{
auto tmp = dynamic_cast<const op::v0::MaxPool*>(node.get());
......@@ -356,6 +313,35 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Pad:
{
auto tmp = dynamic_cast<const op::v0::Pad*>(node.get());
auto padding_below = tmp->get_padding_below();
auto pads_begin_node =
make_shared<op::Constant>(element::i64, Shape{padding_below.size()}, padding_below);
auto padding_above = tmp->get_padding_above();
auto pads_end_node =
make_shared<op::Constant>(element::i64, Shape{padding_above.size()}, padding_above);
auto replacement_node = make_shared<op::v1::Pad>(node->input(0).get_source_output(),
pads_begin_node,
pads_end_node,
node->input(1).get_source_output(),
tmp->get_pad_mode());
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::Product:
{
bool keep_dims = false;
auto replacement_node = make_shared<op::v1::ReduceProd>(
node->input(0).get_source_output(), node->input(1).get_source_output(), keep_dims);
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::Reverse:
{
// creates a Constant node from the v0::Reverse reversed_axes attribute
......@@ -375,6 +361,31 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
break;
}
case OP_TYPEID::Softmax:
{
auto tmp = dynamic_cast<const op::v0::Softmax*>(node.get());
AxisSet axes = tmp->get_axes();
NGRAPH_CHECK(
axes.size() == 1,
"Unable to convert Softmax:0 to Softmax:1 with zero or more than one axis. Node: ",
*node);
auto replacement_node =
make_shared<op::v1::Softmax>(node->input(0).get_source_output(), axes.to_vector()[0]);
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::Sum:
{
bool keep_dims = false;
auto replacement_node = make_shared<op::v1::ReduceSum>(
node->input(0).get_source_output(), node->input(1).get_source_output(), keep_dims);
replace_node(node, replacement_node);
modified = true;
break;
}
default: break;
}
......
......@@ -1239,7 +1239,15 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::DynReshape:
{
node = make_shared<op::DynReshape>(args[0], args[1]);
const auto zero_flag = node_js.at("zero_flag").get<bool>();
if (op_version == 0)
{
node = make_shared<op::v0::DynReshape>(args[0], args[1], zero_flag);
}
if (op_version == 1)
{
node = make_shared<op::v1::Reshape>(args[0], args[1], zero_flag);
}
break;
}
case OP_TYPEID::DynSlice:
......@@ -2672,7 +2680,19 @@ json JSONSerializer::serialize_node(const Node& n)
node["ellipsis_mask"] = tmp->get_ellipsis_mask();
break;
}
case OP_TYPEID::DynReshape: { break;
case OP_TYPEID::DynReshape:
{
if (op_version == 0)
{
auto tmp = dynamic_cast<const op::v0::DynReshape*>(&n);
node["zero_flag"] = tmp->get_zero_flag();
}
if (op_version == 1)
{
auto tmp = dynamic_cast<const op::v1::Reshape*>(&n);
node["zero_flag"] = tmp->get_zero_flag();
}
break;
}
case OP_TYPEID::DynSlice:
{
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(serialize, opset1_dyn_reshape_upgrade)
{
const auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto pattern = make_shared<op::Parameter>(element::i64, Shape{6});
const auto dyn_reshape_v0 = make_shared<op::v0::DynReshape>(arg, pattern, true);
const auto result = make_shared<op::Result>(dyn_reshape_v0);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg, pattern});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reshape_v1 = static_pointer_cast<op::v1::Reshape>(pass_replacement_node);
EXPECT_EQ(reshape_v1->description(), "DynReshape");
EXPECT_EQ(reshape_v1->get_version(), 1);
}
......@@ -279,3 +279,17 @@ TEST(type_prop, dynreshape_pattern_et_wrong)
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, reshape_v1_arg_rank_static_pattern_zero)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 0, 2, 8});
auto pattern = op::Constant::create(element::i64, Shape{4}, {1, 2, 0, 32});
auto reshape_v1_static = make_shared<op::v1::Reshape>(arg, pattern, true);
EXPECT_EQ(reshape_v1_static->get_output_shape(0), Shape({1, 2, 2, 32}));
auto dynamic_arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto reshape_v1_dynamic = make_shared<op::v1::Reshape>(dynamic_arg, pattern, true);
EXPECT_TRUE(reshape_v1_dynamic->get_output_partial_shape(0).same_scheme(
PartialShape{1, 2, Dimension::dynamic(), 32}));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment