Commit ccce8bb1 authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Scott Cyphers

[SPEC] Implement Reshape downgrade pass (#3751)

* Introduced reshape:v1 downgrade

* Enable downgrade pass in dynamic backend

* Add unit test for dynamic backend downgrade pass

* Clang styles applied

* Apply unit tests name refactor based on code review
Co-Authored-By: 's avatarMichał Karzyński <postrational@users.noreply.github.com>

* Removed redundant pass

* Changed order of downgrade pass

* Changed order of passes in dynamic backend
parent 2ba6aea4
......@@ -117,7 +117,7 @@ namespace ngraph
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Reshape", 1};
static constexpr NodeTypeInfo type_info{"DynReshape", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Reshape() = default;
/// \brief Constructs a dynamic reshape operation. This operation does not perform
......
......@@ -22,6 +22,7 @@
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/max_pool.hpp"
......@@ -236,6 +237,16 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::DynReshape:
{
auto tmp = as_type_ptr<op::v1::Reshape>(node);
auto replacement_node = make_shared<op::v0::DynReshape>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
tmp->get_zero_flag());
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::GenerateMask:
{
auto tmp = dynamic_cast<const op::v1::GenerateMask*>(node.get());
......
......@@ -71,8 +71,9 @@ set(SRC
op.cpp
opset_pass/broadcast_opset_pass.cpp
opset_pass/convolution_opset_pass.cpp
opset_pass/generate_mask_opset_pass.cpp
opset_pass/dyn_reshape_opset_pass.cpp
opset_pass/gather_opset_pass.cpp
opset_pass/generate_mask_opset_pass.cpp
opset_pass/pad_opset_pass.cpp
opset_pass/poolings_opset_pass.cpp
opset_pass/product_opset_pass.cpp
......
......@@ -115,3 +115,30 @@ NGRAPH_TEST(${BACKEND_NAME}, dyn_reshape)
ASSERT_EQ(results, data);
}
}
NGRAPH_TEST(${BACKEND_NAME}, reshape_v1)
{
auto arg = std::make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto pattern = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto reshape_v1 = std::make_shared<op::v1::Reshape>(arg, pattern);
auto f = std::make_shared<Function>(NodeVector{reshape_v1}, ParameterVector{arg, pattern});
auto backend = runtime::Backend::create("${BACKEND_NAME}", true);
auto ex = backend->compile(f);
auto arg_data = vector<int64_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
auto pattern_data = vector<int64_t>{2, 2, 3};
auto arg_tensor = backend->create_tensor(element::i64, Shape{arg_data.size()});
auto pattern_tensor = backend->create_tensor(element::i64, Shape{pattern_data.size()});
copy_data(arg_tensor, arg_data);
copy_data(pattern_tensor, pattern_data);
auto output = backend->create_dynamic_tensor(element::i64, PartialShape::dynamic());
ex->call_with_validate({output}, {arg_tensor, pattern_tensor});
ASSERT_EQ(output->get_element_type(), element::i64);
EXPECT_EQ(read_vector<int64_t>(output),
vector<int64_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
}
......@@ -19,6 +19,7 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"
......@@ -40,8 +41,30 @@ TEST(opset_transform, opset1_dyn_reshape_upgrade_pass)
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reshape_v1 = static_pointer_cast<op::v1::Reshape>(pass_replacement_node);
const auto reshape_v1 = as_type_ptr<op::v1::Reshape>(pass_replacement_node);
EXPECT_EQ(reshape_v1->description(), "DynReshape");
EXPECT_EQ(reshape_v1->get_version(), 1);
}
TEST(opset_transform, opset1_reshape_downgrade_pass)
{
const auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto pattern = make_shared<op::Parameter>(element::i64, Shape{6});
const auto dyn_reshape_v0 = make_shared<op::v1::Reshape>(arg, pattern, true);
const auto result = make_shared<op::Result>(dyn_reshape_v0);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg, pattern});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reshape_v1 = as_type_ptr<op::v0::DynReshape>(pass_replacement_node);
EXPECT_EQ(reshape_v1->description(), "DynReshape");
EXPECT_EQ(reshape_v1->get_version(), 0);
EXPECT_EQ(reshape_v1->get_zero_flag(), true);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment