Commit 34f04d37 authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Scott Cyphers

[SPEC] Implement Reverse:v1 downgrade pass (#3707)

* Implemented downgrade pass

* Using Pad:v1 in onnx_importer

* Downgrade transformation doc fixed

* Downgrade pass added for all backends

* Apply suggestions from code review

Changed pad_opset_pass to opset_downgrade
Co-Authored-By: 's avatarTomasz Socha <tomasz.socha@intel.com>

* Changed order of passes

* Changed downgrade pass order of CPU backend

* Added reverse:v1 downgrade

* Code review remakrs introduced

* Fixed problem with foward declaration
parent bd50f338
...@@ -24,16 +24,16 @@ ...@@ -24,16 +24,16 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
constexpr NodeTypeInfo op::Reverse::type_info; constexpr NodeTypeInfo op::v0::Reverse::type_info;
op::Reverse::Reverse(const Output<Node>& arg, const AxisSet& reversed_axes) op::v0::Reverse::Reverse(const Output<Node>& arg, const AxisSet& reversed_axes)
: Op({arg}) : Op({arg})
, m_reversed_axes(reversed_axes) , m_reversed_axes(reversed_axes)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
void op::Reverse::validate_and_infer_types() void op::v0::Reverse::validate_and_infer_types()
{ {
const auto input_shape = get_input_partial_shape(0); const auto input_shape = get_input_partial_shape(0);
const Dimension input_rank = input_shape.rank(); const Dimension input_rank = input_shape.rank();
...@@ -56,13 +56,13 @@ void op::Reverse::validate_and_infer_types() ...@@ -56,13 +56,13 @@ void op::Reverse::validate_and_infer_types()
set_output_type(0, get_input_element_type(0), input_shape); set_output_type(0, get_input_element_type(0), input_shape);
} }
shared_ptr<Node> op::Reverse::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::Reverse::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Reverse>(new_args.at(0), m_reversed_axes); return make_shared<v0::Reverse>(new_args.at(0), m_reversed_axes);
} }
void op::Reverse::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::v0::Reverse::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
......
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
namespace ngraph namespace ngraph
{ {
namespace op namespace op
{
namespace v0
{ {
// clang-format off // clang-format off
/// \brief Axis-reverse operation. /// \brief Axis-reverse operation.
...@@ -77,6 +79,7 @@ namespace ngraph ...@@ -77,6 +79,7 @@ namespace ngraph
AxisSet m_reversed_axes; AxisSet m_reversed_axes;
}; };
}
namespace v1 namespace v1
{ {
...@@ -131,5 +134,7 @@ namespace ngraph ...@@ -131,5 +134,7 @@ namespace ngraph
Mode m_mode; Mode m_mode;
}; };
} }
// default opset version
using v0::Reverse;
} }
} }
...@@ -15,8 +15,10 @@ ...@@ -15,8 +15,10 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/pass/opset0_downgrade.hpp" #include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/pad.hpp" #include "ngraph/op/pad.hpp"
#include "ngraph/op/reverse.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -87,6 +89,38 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node) ...@@ -87,6 +89,38 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true; modified = true;
break; break;
} }
case OP_TYPEID::Reverse:
{
auto tmp = as_type_ptr<op::v1::Reverse>(node);
auto axes_node = tmp->input_value(1).get_node_shared_ptr();
NGRAPH_CHECK(axes_node->is_constant(),
"Unable to convert Reverse:v1 to Reverse:v0 "
"if reduction axes are not constant. Node: ",
*node);
const auto axes_node_const = as_type_ptr<op::Constant>(axes_node);
AxisSet axes{};
if (tmp->get_mode() == op::v1::Reverse::Mode::INDEX)
{
axes = axes_node_const->get_axis_vector_val();
}
else // Mode::MASK
{
auto axes_mask = axes_node_const->get_vector<bool>();
for (size_t i = 0; i < axes_mask.size(); ++i)
{
if (axes_mask[i])
{
axes.emplace(i);
}
}
}
auto replacement_node =
make_shared<op::v0::Reverse>(node->input(0).get_source_output(), axes);
replace_node(node, replacement_node);
modified = true;
break;
}
default: break; default: break;
} }
#if defined(__clang__) #if defined(__clang__)
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "ngraph/op/min.hpp" #include "ngraph/op/min.hpp"
#include "ngraph/op/pad.hpp" #include "ngraph/op/pad.hpp"
#include "ngraph/op/product.hpp" #include "ngraph/op/product.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
#include "ngraph/runtime/cpu/cpu_external_function.hpp" #include "ngraph/runtime/cpu/cpu_external_function.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp" #include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
...@@ -127,7 +128,6 @@ namespace ngraph ...@@ -127,7 +128,6 @@ namespace ngraph
class QuantizedMaxPool; class QuantizedMaxPool;
class QuantizedAvgPool; class QuantizedAvgPool;
class MaxPoolWithIndices; class MaxPoolWithIndices;
class Reverse;
class ReverseSequence; class ReverseSequence;
class MaxPoolWithIndicesBackprop; class MaxPoolWithIndicesBackprop;
class Erf; class Erf;
......
...@@ -19,18 +19,19 @@ ...@@ -19,18 +19,19 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp" #include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp" #include "util/type_prop.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
TEST(serialize, opset1_reverse_upgrade) TEST(opset_upgrade, opset1_reverse_upgrade)
{ {
const auto data = make_shared<op::Parameter>(element::f32, Shape{2, 2, 2}); const auto data = make_shared<op::Parameter>(element::f32, Shape{2, 2, 2});
const AxisSet reverse_axes{1, 2}; const AxisSet reverse_axes{1, 2};
const auto reverse_v0 = make_shared<op::Reverse>(data, reverse_axes); const auto reverse_v0 = make_shared<op::v0::Reverse>(data, reverse_axes);
const auto result = make_shared<op::Result>(reverse_v0); const auto result = make_shared<op::Result>(reverse_v0);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data}); auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
...@@ -50,3 +51,78 @@ TEST(serialize, opset1_reverse_upgrade) ...@@ -50,3 +51,78 @@ TEST(serialize, opset1_reverse_upgrade)
// should match the number of elements of v0::Reverse reverse_axes attribute // should match the number of elements of v0::Reverse reverse_axes attribute
EXPECT_EQ(rev_axes_input_shape, Shape{2}); EXPECT_EQ(rev_axes_input_shape, Shape{2});
} }
TEST(opset_downgrade, opset0_reverse_downgrade_index_mode)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{2, 2, 2});
const auto reverse_axes =
make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{1, 2});
auto mode = op::v1::Reverse::Mode::INDEX;
const auto reverse_v1 = make_shared<op::v1::Reverse>(data, reverse_axes, mode);
const auto result = make_shared<op::Result>(reverse_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reverse_v0 = static_pointer_cast<op::v0::Reverse>(pass_replacement_node);
EXPECT_EQ(reverse_v0->description(), "Reverse");
EXPECT_EQ(reverse_v0->get_version(), 0);
EXPECT_EQ(reverse_v0->get_reversed_axes(), AxisSet({1, 2}));
}
TEST(opset_downgrade, opset0_reverse_downgrade_mask_mode)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{2, 2, 2});
const auto reverse_axes =
make_shared<op::Constant>(element::boolean, Shape{3}, vector<bool>{true, false, true});
auto mode = op::v1::Reverse::Mode::MASK;
const auto reverse_v1 = make_shared<op::v1::Reverse>(data, reverse_axes, mode);
const auto result = make_shared<op::Result>(reverse_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reverse_v0 = static_pointer_cast<op::v0::Reverse>(pass_replacement_node);
EXPECT_EQ(reverse_v0->description(), "Reverse");
EXPECT_EQ(reverse_v0->get_version(), 0);
EXPECT_EQ(reverse_v0->get_reversed_axes(), AxisSet({0, 2}));
}
TEST(opset_downgrade, opset0_reverse_downgrade_axes_not_constant)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{2, 2, 2});
const auto axes = make_shared<op::Parameter>(element::boolean, Shape{3});
const auto reverse_v1 = make_shared<op::v1::Reverse>(data, axes, op::v1::Reverse::Mode::MASK);
const auto result = make_shared<op::Result>(reverse_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data, axes});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
try
{
pass_manager.run_passes(f);
FAIL() << "Exception after Opset0Downgrade pass was not thrown.";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Unable to convert Reverse:v1 to Reverse:v0"));
}
catch (...)
{
FAIL() << "Reverse:v1 pass failed for unexpected reason";
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment