Commit ac4676ff authored by Tomasz Dołbniak, committed by Michał Karzyński

[Spec] Implementation of Reverse:v1 (#3536)

parent 385770d8
@@ -18,6 +18,7 @@
#include <sstream>
#include "ngraph/function.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/reverse.hpp"
using namespace std;
@@ -34,8 +35,8 @@ op::Reverse::Reverse(const Output<Node>& arg, const AxisSet& reversed_axes)
void op::Reverse::validate_and_infer_types()
{
auto input_shape = get_input_partial_shape(0);
Dimension input_rank = input_shape.rank();
const auto input_shape = get_input_partial_shape(0);
const Dimension input_rank = input_shape.rank();
if (input_rank.is_static())
{
@@ -69,3 +70,127 @@ void op::Reverse::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
adjoints.add_delta(x, make_shared<op::Reverse>(delta, m_reversed_axes));
}
constexpr NodeTypeInfo op::v1::Reverse::type_info;
op::v1::Reverse::Reverse(const Output<Node>& data,
const Output<Node>& reversed_axes,
const std::string& mode)
: Op({data, reversed_axes})
, m_mode{mode_from_string(mode)}
{
constructor_validate_and_infer_types();
}
op::v1::Reverse::Reverse(const Output<Node>& data,
const Output<Node>& reversed_axes,
const Mode mode)
: Op({data, reversed_axes})
, m_mode{mode}
{
constructor_validate_and_infer_types();
}
void op::v1::Reverse::validate_and_infer_types()
{
if (m_mode == Mode::MASK)
{
NODE_VALIDATION_CHECK(this,
get_input_element_type(1) == element::boolean,
"In 'mask' mode the second input must contain boolean values.");
}
const auto input_shape = get_input_partial_shape(0);
const auto input_rank = input_shape.rank();
const auto rev_axes_shape = get_input_partial_shape(1);
const auto rev_axes_rank = rev_axes_shape.rank();
if (rev_axes_rank.is_static())
{
NODE_VALIDATION_CHECK(this,
static_cast<size_t>(rev_axes_rank) == 1,
"The reversed_axes input must be a 1D tensor (got ",
static_cast<size_t>(rev_axes_rank),
").");
if (m_mode == Mode::MASK)
{
if (input_rank.is_static() && rev_axes_shape[0].is_static())
{
const auto rev_axes_mask_elems_count = static_cast<size_t>(rev_axes_shape[0]);
NODE_VALIDATION_CHECK(this,
rev_axes_mask_elems_count == static_cast<size_t>(input_rank),
"The number of elements in the reversed_axes tensor (",
rev_axes_mask_elems_count,
") must match the input data tensor rank (",
static_cast<size_t>(input_rank),
") in 'mask' mode.");
}
}
}
if (input_rank.is_static())
{
const auto rank = static_cast<size_t>(input_rank);
const auto rev_axes_node = input_value(1).get_node_shared_ptr();
if (rev_axes_node->is_constant())
{
const auto rev_axes_constant = dynamic_pointer_cast<op::Constant>(rev_axes_node);
if (m_mode == Mode::INDEX)
{
const AxisSet rev_axes = rev_axes_constant->get_axis_set_val();
NODE_VALIDATION_CHECK(this,
rev_axes.size() <= rank,
"Too many axes(",
rev_axes,
") have been provided for given input shape(",
input_shape,
").");
bool all_axes_in_range = all_of(rev_axes.begin(),
rev_axes.end(),
[&rank](const size_t axis) { return axis < rank; });
NODE_VALIDATION_CHECK(this,
all_axes_in_range,
"Some of the provided axes (",
rev_axes,
") are out of bounds (input rank: ",
static_cast<size_t>(input_rank),
").");
}
}
}
set_output_type(0, get_input_element_type(0), input_shape);
}
shared_ptr<Node> op::v1::Reverse::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v1::Reverse>(new_args.at(0), new_args.at(1), m_mode);
}
void op::v1::Reverse::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
const auto delta = deltas.at(0);
const auto x = input_value(0);
const auto reversed_axes = input_value(1);
adjoints.add_delta(x, make_shared<op::v1::Reverse>(delta, reversed_axes, m_mode));
}
op::v1::Reverse::Mode op::v1::Reverse::mode_from_string(const std::string& mode) const
{
static const std::map<std::string, Mode> allowed_values = {{"index", Mode::INDEX},
{"mask", Mode::MASK}};
NODE_VALIDATION_CHECK(this, allowed_values.count(mode) > 0, "Invalid 'mode' value passed in.");
return allowed_values.at(mode);
}
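For context, a minimal usage sketch of the new op follows. It is not part of this commit; the wrapper function, variable names, and shapes are illustrative, and it assumes the relevant ngraph op headers (constant.hpp, parameter.hpp, reverse.hpp) plus the std/ngraph using-directives are in scope, as in reverse.cpp above.

void reverse_v1_usage_sketch()
{
    // Reverse axes 1 and 2 of a 3D tensor, pointing at them by index.
    const auto data = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
    const auto axes = op::Constant::create(element::i64, Shape{2}, {1, 2});
    const auto by_index = make_shared<op::v1::Reverse>(data, axes, op::v1::Reverse::Mode::INDEX);

    // The same reversal expressed as a boolean mask: one flag per data axis.
    const auto mask = op::Constant::create(element::boolean, Shape{3}, {0, 1, 1});
    const auto by_mask = make_shared<op::v1::Reverse>(data, mask, "mask");
}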
@@ -77,5 +77,59 @@ namespace ngraph
AxisSet m_reversed_axes;
};
namespace v1
{
class Reverse : public Op
{
public:
enum class Mode
{
INDEX,
MASK
};
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Reverse", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Reverse() = default;
/// \brief Constructs a reverse operation.
///
/// \param data The input tensor, some of whose axes are to be reversed.
/// \param reversed_axes The axes to reverse, expressed either as a set of indices or
///                      as a boolean mask.
/// \param mode The way reversed_axes should be interpreted - a set or a mask.
Reverse(const Output<Node>& data,
const Output<Node>& reversed_axes,
const std::string& mode);
Reverse(const Output<Node>& data,
const Output<Node>& reversed_axes,
const Mode mode);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \return The second input data interpretation mode.
Mode get_mode() const { return m_mode; }
void set_mode(const Mode mode) { m_mode = mode; }
virtual size_t get_version() const override { return 1; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
Mode mode_from_string(const std::string& mode) const;
/// \brief Indicates how the values from the second input should be interpreted.
///
/// The second input can contain a set of indices pointing to axes in the data
/// tensor shape.
/// Alternatively, it can contain a boolean mask that indicates which axes should be
/// reversed.
Mode m_mode;
};
}
}
}
@@ -24,6 +24,7 @@
#include "ngraph/op/product.hpp"
#include "ngraph/op/reduce_prod.hpp"
#include "ngraph/op/reduce_sum.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp"
@@ -237,6 +238,25 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Reverse:
{
// Creates a Constant node from the v0::Reverse reversed_axes attribute
// and uses it as the second input of v1::Reverse.
const auto reverse_v0 = dynamic_cast<const op::Reverse*>(node.get());
const auto reversed_axes = reverse_v0->get_reversed_axes();
const auto reversed_axes_constant = op::Constant::create(
element::i64, Shape{reversed_axes.size()}, reversed_axes.to_vector());
const auto reverse_v1 = make_shared<op::v1::Reverse>(node->input(0).get_source_output(),
reversed_axes_constant,
op::v1::Reverse::Mode::INDEX);
replace_node(node, reverse_v1);
modified = true;
break;
}
case OP_TYPEID::Softmax:
{
auto tmp = dynamic_cast<const op::v0::Softmax*>(node.get());
......
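For context, a minimal sketch of how the upgrade pass above is driven; it mirrors the new unit test added later in this commit, and the wrapper function and variable names are illustrative.

#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"

using namespace std;
using namespace ngraph;

void upgrade_reverse_sketch()
{
    // Build a function around a v0::Reverse node, then let the pass rewrite it to v1::Reverse.
    const auto data = make_shared<op::Parameter>(element::f32, Shape{2, 2, 2});
    const auto reverse_v0 = make_shared<op::Reverse>(data, AxisSet{1, 2});
    auto f = make_shared<Function>(NodeVector{reverse_v0}, ParameterVector{data});

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset1Upgrade>();
    pass_manager.run_passes(f);
    // f now holds an op::v1::Reverse whose second input is an i64 Constant containing {1, 2}.
}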
@@ -1897,10 +1897,20 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::Reverse:
{
auto reversed_axes = deserialize_axis_set(node_js.at("reversed_axes"));
if (op_version == 0)
{
const auto reversed_axes = deserialize_axis_set(node_js.at("reversed_axes"));
node = make_shared<op::Reverse>(args[0], reversed_axes);
break;
}
else if (op_version == 1)
{
const auto mode = node_js.at("mode").get<op::v1::Reverse::Mode>();
node = make_shared<op::v1::Reverse>(args[0], args[1], mode);
break;
}
break;
}
case OP_TYPEID::ReverseSequence:
{
auto batch_axis = node_js.at("batch_axis").get<size_t>();
@@ -3071,10 +3081,20 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Reverse:
{
auto tmp = dynamic_cast<const op::Reverse*>(&n);
if (op_version == 0)
{
const auto tmp = dynamic_cast<const op::Reverse*>(&n);
node["reversed_axes"] = serialize_axis_set(tmp->get_reversed_axes());
break;
}
else if (op_version == 1)
{
const auto tmp = dynamic_cast<const op::v1::Reverse*>(&n);
node["mode"] = tmp->get_mode();
break;
}
break;
}
case OP_TYPEID::ReverseSequence:
{
auto tmp = dynamic_cast<const op::ReverseSequence*>(&n);
......
@@ -69,12 +69,12 @@ set(SRC
node_input_output.cpp
nop_elimination.cpp
op.cpp
opset_pass/sum_opset_pass.cpp
opset_pass/product_opset_pass.cpp
opset_pass/softmax_opset_pass.cpp
opset_pass/softmax_opset_pass.cpp
opset_pass/gather_opset_pass.cpp
opset_pass/pad_opset_pass.cpp
opset_pass/product_opset_pass.cpp
opset_pass/reverse_opset_pass.cpp
opset_pass/softmax_opset_pass.cpp
opset_pass/sum_opset_pass.cpp
opset_pass/poolings_opset_pass.cpp
partial_shape.cpp
pass.cpp
......
@@ -382,3 +382,41 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_3d_012)
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, reverse_v1_incorrect_rev_axes_rank_index_mode)
{
const auto Data = make_shared<op::Parameter>(element::f32, Shape{2, 2, 2});
const auto Rev_Axes = make_shared<op::Parameter>(element::i64, Shape{1, 1}); // correct: a 1D tensor
EXPECT_THROW(make_shared<Function>(
make_shared<op::v1::Reverse>(Data, Rev_Axes, op::v1::Reverse::Mode::INDEX),
ParameterVector{Data, Rev_Axes}),
ngraph::NodeValidationFailure);
}
NGRAPH_TEST(${BACKEND_NAME}, reverse_v1_incorrect_rev_axes_elems_mask_mode)
{
const auto Data = make_shared<op::Parameter>(element::f32, Shape{2, 2, 2});
const auto Rev_Axes = make_shared<op::Parameter>(element::boolean, Shape{2}); // correct: 3 elements (one per data axis)
EXPECT_THROW(make_shared<op::v1::Reverse>(Data, Rev_Axes, op::v1::Reverse::Mode::MASK),
ngraph::NodeValidationFailure);
}
NGRAPH_TEST(${BACKEND_NAME}, reverse_v1_axes_out_of_bounds)
{
const auto Data = make_shared<op::Parameter>(element::f32, Shape{2, 2, 2});
const auto Rev_Axes = op::Constant::create(element::i64, Shape{2}, {1, 10});
EXPECT_THROW(make_shared<op::v1::Reverse>(Data, Rev_Axes, op::v1::Reverse::Mode::INDEX),
ngraph::NodeValidationFailure);
}
NGRAPH_TEST(${BACKEND_NAME}, reverse_v1_too_many_axes)
{
const auto Data = make_shared<op::Parameter>(element::f32, Shape{2, 2, 2});
const auto Rev_Axes = op::Constant::create(element::i64, Shape{4}, {0, 1, 2, 3});
EXPECT_THROW(make_shared<op::v1::Reverse>(Data, Rev_Axes, op::v1::Reverse::Mode::INDEX),
ngraph::NodeValidationFailure);
}
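A positive-path companion check could look like the sketch below. It is not part of this commit; it relies only on helpers already used in the tests above, and the expected output type and shape follow from validate_and_infer_types propagating the data tensor's element type and shape.

NGRAPH_TEST(${BACKEND_NAME}, reverse_v1_valid_mask_mode)
{
    const auto Data = make_shared<op::Parameter>(element::f32, Shape{2, 2, 2});
    const auto Rev_Axes = make_shared<op::Parameter>(element::boolean, Shape{3}); // one flag per axis
    const auto Rev = make_shared<op::v1::Reverse>(Data, Rev_Axes, op::v1::Reverse::Mode::MASK);
    EXPECT_EQ(Rev->get_output_element_type(0), element::f32);
    EXPECT_EQ(Rev->get_output_shape(0), (Shape{2, 2, 2}));
}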
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(serialize, opset1_reverse_upgrade)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{2, 2, 2});
const AxisSet reverse_axes{1, 2};
const auto reverse_v0 = make_shared<op::Reverse>(data, reverse_axes);
const auto result = make_shared<op::Result>(reverse_v0);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reverse_v1 = static_pointer_cast<op::v1::Reverse>(pass_replacement_node);
EXPECT_EQ(reverse_v1->get_mode(), op::v1::Reverse::Mode::INDEX);
EXPECT_EQ(reverse_v1->description(), "Reverse");
EXPECT_EQ(reverse_v1->get_version(), 1);
const auto& rev_axes_input_shape = reverse_v1->get_input_shape(1);
// should match the number of elements of v0::Reverse reverse_axes attribute
EXPECT_EQ(rev_axes_input_shape, Shape{2});
}
@@ -343,15 +343,15 @@ TEST(serialize, non_zero_node_output)
TEST(serialize, opset1_softmax)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{10});
auto softmax = make_shared<op::v1::Softmax>(arg, 0);
auto result = make_shared<op::Result>(softmax);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});
const auto arg = make_shared<op::Parameter>(element::f32, Shape{10});
const auto softmax = make_shared<op::v1::Softmax>(arg, 0);
const auto result = make_shared<op::Result>(softmax);
const auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});
string s = serialize(f);
shared_ptr<Function> g = deserialize(s);
auto g_result = g->get_results().at(0);
auto g_softmax = g_result->input(0).get_source_output().get_node_shared_ptr();
const auto g_result = g->get_results().at(0);
const auto g_softmax = g_result->input(0).get_source_output().get_node_shared_ptr();
EXPECT_EQ(g_softmax->description(), "Softmax");
EXPECT_EQ(g_softmax->get_version(), 1);
......