Commit 34f04d37 authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Scott Cyphers

[SPEC] Implement Reverse:v1 downgrade pass (#3707)

* Implemented downgrade pass

* Using Pad:v1 in onnx_importer

* Downgrade transformation doc fixed

* Downgrade pass added for all backends

* Apply suggestions from code review

Changed pad_opset_pass to opset_downgrade
Co-Authored-By: 's avatarTomasz Socha <tomasz.socha@intel.com>

* Changed order of passes

* Changed downgrade pass order of CPU backend

* Added reverse:v1 downgrade

* Code review remakrs introduced

* Fixed problem with foward declaration
parent bd50f338
......@@ -24,16 +24,16 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::Reverse::type_info;
constexpr NodeTypeInfo op::v0::Reverse::type_info;
op::Reverse::Reverse(const Output<Node>& arg, const AxisSet& reversed_axes)
op::v0::Reverse::Reverse(const Output<Node>& arg, const AxisSet& reversed_axes)
: Op({arg})
, m_reversed_axes(reversed_axes)
{
constructor_validate_and_infer_types();
}
void op::Reverse::validate_and_infer_types()
void op::v0::Reverse::validate_and_infer_types()
{
const auto input_shape = get_input_partial_shape(0);
const Dimension input_rank = input_shape.rank();
......@@ -56,13 +56,13 @@ void op::Reverse::validate_and_infer_types()
set_output_type(0, get_input_element_type(0), input_shape);
}
shared_ptr<Node> op::Reverse::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v0::Reverse::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Reverse>(new_args.at(0), m_reversed_axes);
return make_shared<v0::Reverse>(new_args.at(0), m_reversed_axes);
}
void op::Reverse::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
void op::v0::Reverse::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
auto delta = deltas.at(0);
......
......@@ -22,61 +22,64 @@ namespace ngraph
{
namespace op
{
// clang-format off
/// \brief Axis-reverse operation.
///
/// Reverses the direction of zero or more axes in a tensor, where "reversing" an axis means
/// that at the output tensor.
///
/// ## Parameters
///
/// | | Description |
/// | --------------- | ------------------------ |
/// | `reversed_axes` | The axes to be reversed. |
///
/// ## Inputs
///
/// | | Type | Description |
/// | ----- | --------------------------------- | -------------------------------------- |
/// | `arg` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | An input tensor of any type and shape. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg}[j_1,\dots,j_n]\f$ and \f$j_k = d_k - i_k - 1\f$ if axis \f$k\f$ is in the reverse set; else \f$j_k = i_k\f$. |
// clang-format on
class Reverse : public Op
namespace v0
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Reverse", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Reverse() = default;
/// \brief Constructs a reverse operation.
// clang-format off
/// \brief Axis-reverse operation.
///
/// \param arg The input tensor, some of whose axes are to be reversed.
/// \param reversed_axes The axes to reverse.
Reverse(const Output<Node>& arg, const AxisSet& reversed_axes);
/// Reverses the direction of zero or more axes in a tensor, where "reversing" an axis means
/// that at the output tensor.
///
/// ## Parameters
///
/// | | Description |
/// | --------------- | ------------------------ |
/// | `reversed_axes` | The axes to be reversed. |
///
/// ## Inputs
///
/// | | Type | Description |
/// | ----- | --------------------------------- | -------------------------------------- |
/// | `arg` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | An input tensor of any type and shape. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg}[j_1,\dots,j_n]\f$ and \f$j_k = d_k - i_k - 1\f$ if axis \f$k\f$ is in the reverse set; else \f$j_k = i_k\f$. |
// clang-format on
class Reverse : public Op
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Reverse", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Reverse() = default;
/// \brief Constructs a reverse operation.
///
/// \param arg The input tensor, some of whose axes are to be reversed.
/// \param reversed_axes The axes to reverse.
Reverse(const Output<Node>& arg, const AxisSet& reversed_axes);
void validate_and_infer_types() override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \return The set of axes to reverse.
const AxisSet& get_reversed_axes() const { return m_reversed_axes; }
void set_reversed_axes(const AxisSet& reversed_axes)
{
m_reversed_axes = reversed_axes;
}
/// \return The set of axes to reverse.
const AxisSet& get_reversed_axes() const { return m_reversed_axes; }
void set_reversed_axes(const AxisSet& reversed_axes)
{
m_reversed_axes = reversed_axes;
}
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
AxisSet m_reversed_axes;
};
AxisSet m_reversed_axes;
};
}
namespace v1
{
......@@ -131,5 +134,7 @@ namespace ngraph
Mode m_mode;
};
}
// default opset version
using v0::Reverse;
}
}
......@@ -15,8 +15,10 @@
//*****************************************************************************
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/reverse.hpp"
using namespace std;
using namespace ngraph;
......@@ -87,6 +89,38 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Reverse:
{
auto tmp = as_type_ptr<op::v1::Reverse>(node);
auto axes_node = tmp->input_value(1).get_node_shared_ptr();
NGRAPH_CHECK(axes_node->is_constant(),
"Unable to convert Reverse:v1 to Reverse:v0 "
"if reduction axes are not constant. Node: ",
*node);
const auto axes_node_const = as_type_ptr<op::Constant>(axes_node);
AxisSet axes{};
if (tmp->get_mode() == op::v1::Reverse::Mode::INDEX)
{
axes = axes_node_const->get_axis_vector_val();
}
else // Mode::MASK
{
auto axes_mask = axes_node_const->get_vector<bool>();
for (size_t i = 0; i < axes_mask.size(); ++i)
{
if (axes_mask[i])
{
axes.emplace(i);
}
}
}
auto replacement_node =
make_shared<op::v0::Reverse>(node->input(0).get_source_output(), axes);
replace_node(node, replacement_node);
modified = true;
break;
}
default: break;
}
#if defined(__clang__)
......
......@@ -29,6 +29,7 @@
#include "ngraph/op/min.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/runtime/cpu/cpu_external_function.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
......@@ -127,7 +128,6 @@ namespace ngraph
class QuantizedMaxPool;
class QuantizedAvgPool;
class MaxPoolWithIndices;
class Reverse;
class ReverseSequence;
class MaxPoolWithIndicesBackprop;
class Erf;
......
......@@ -19,18 +19,19 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(serialize, opset1_reverse_upgrade)
TEST(opset_upgrade, opset1_reverse_upgrade)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{2, 2, 2});
const AxisSet reverse_axes{1, 2};
const auto reverse_v0 = make_shared<op::Reverse>(data, reverse_axes);
const auto reverse_v0 = make_shared<op::v0::Reverse>(data, reverse_axes);
const auto result = make_shared<op::Result>(reverse_v0);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
......@@ -50,3 +51,78 @@ TEST(serialize, opset1_reverse_upgrade)
// should match the number of elements of v0::Reverse reverse_axes attribute
EXPECT_EQ(rev_axes_input_shape, Shape{2});
}
TEST(opset_downgrade, opset0_reverse_downgrade_index_mode)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{2, 2, 2});
const auto reverse_axes =
make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{1, 2});
auto mode = op::v1::Reverse::Mode::INDEX;
const auto reverse_v1 = make_shared<op::v1::Reverse>(data, reverse_axes, mode);
const auto result = make_shared<op::Result>(reverse_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reverse_v0 = static_pointer_cast<op::v0::Reverse>(pass_replacement_node);
EXPECT_EQ(reverse_v0->description(), "Reverse");
EXPECT_EQ(reverse_v0->get_version(), 0);
EXPECT_EQ(reverse_v0->get_reversed_axes(), AxisSet({1, 2}));
}
TEST(opset_downgrade, opset0_reverse_downgrade_mask_mode)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{2, 2, 2});
const auto reverse_axes =
make_shared<op::Constant>(element::boolean, Shape{3}, vector<bool>{true, false, true});
auto mode = op::v1::Reverse::Mode::MASK;
const auto reverse_v1 = make_shared<op::v1::Reverse>(data, reverse_axes, mode);
const auto result = make_shared<op::Result>(reverse_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reverse_v0 = static_pointer_cast<op::v0::Reverse>(pass_replacement_node);
EXPECT_EQ(reverse_v0->description(), "Reverse");
EXPECT_EQ(reverse_v0->get_version(), 0);
EXPECT_EQ(reverse_v0->get_reversed_axes(), AxisSet({0, 2}));
}
TEST(opset_downgrade, opset0_reverse_downgrade_axes_not_constant)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{2, 2, 2});
const auto axes = make_shared<op::Parameter>(element::boolean, Shape{3});
const auto reverse_v1 = make_shared<op::v1::Reverse>(data, axes, op::v1::Reverse::Mode::MASK);
const auto result = make_shared<op::Result>(reverse_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data, axes});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
try
{
pass_manager.run_passes(f);
FAIL() << "Exception after Opset0Downgrade pass was not thrown.";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Unable to convert Reverse:v1 to Reverse:v0"));
}
catch (...)
{
FAIL() << "Reverse:v1 pass failed for unexpected reason";
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment