Commit dfb1476a authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Michał Karzyński

[Spec] Implement Gather:v1 (#3590)

* Gather:v1 was introduced

* Added support for negative axis

* Removed unsused serialization

* Code review remarks introduced

* Change returned type of get_axis method

* Code review remarks introduced

* Chnaged axis_node to scalar during transformation

* Clang style applied

* Fixed clang errors

* style
parent ab440246
...@@ -15,23 +15,36 @@ ...@@ -15,23 +15,36 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/gather.hpp" #include "ngraph/op/gather.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include <limits>
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
static int PARAMS = 0; static const int PARAMS = 0;
static int INDICES = 1; static const int INDICES = 1;
static const int AXIS = 2;
static const int64_t AXIS_NOT_SET_VALUE = std::numeric_limits<int64_t>::max();
constexpr NodeTypeInfo op::v0::Gather::type_info;
constexpr NodeTypeInfo op::Gather::type_info; op::v0::Gather::Gather(const Output<Node>& params, const Output<Node>& indices, size_t axis)
: Op({params, indices})
, m_axis(axis)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Gather::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::Gather::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Gather>(new_args.at(PARAMS), new_args.at(INDICES), m_axis); return make_shared<v0::Gather>(new_args.at(PARAMS), new_args.at(INDICES), m_axis);
} }
void op::Gather::validate_and_infer_types() void op::v0::Gather::validate_and_infer_types()
{ {
element::Type result_et = get_input_element_type(PARAMS); element::Type result_et = get_input_element_type(PARAMS);
element::Type indices_et = get_input_element_type(INDICES); element::Type indices_et = get_input_element_type(INDICES);
...@@ -82,3 +95,115 @@ void op::Gather::validate_and_infer_types() ...@@ -82,3 +95,115 @@ void op::Gather::validate_and_infer_types()
set_output_type(0, result_et, result_shape); set_output_type(0, result_et, result_shape);
} }
void op::v0::Gather::generate_adjoints(autodiff::Adjoints& /* adjoints */,
const NodeVector& /* deltas */)
{
throw ngraph_error("Not yet implemented");
}
constexpr NodeTypeInfo op::v1::Gather::type_info;
op::v1::Gather::Gather(const Output<Node>& params,
const Output<Node>& indices,
const Output<Node>& axes)
: Op({params, indices, axes})
{
constructor_validate_and_infer_types();
}
void op::v1::Gather::validate_and_infer_types()
{
const auto& input_rank = get_input_partial_shape(PARAMS).rank();
const auto& axis_shape = get_input_partial_shape(AXIS);
const auto& axis_rank = axis_shape.rank();
if (axis_rank.is_static() && axis_shape.is_static())
{
const auto axis_is_scalar = static_cast<size_t>(axis_rank) == 0;
const auto axis_has_one_elem =
static_cast<size_t>(axis_rank) == 1 && static_cast<size_t>(axis_shape[0]) == 1;
NODE_VALIDATION_CHECK(this,
axis_is_scalar || axis_has_one_elem,
"Axes input must be scalar or have 1 element (shape: ",
axis_shape,
").");
}
auto axis = get_axis();
if (input_rank.is_static() && axis != AXIS_NOT_SET_VALUE)
{
NODE_VALIDATION_CHECK(this,
axis >= 0 && axis < static_cast<size_t>(input_rank),
"The axis must => 0 and <= input_rank (axis: ",
axis,
").");
}
element::Type result_et = get_input_element_type(PARAMS);
element::Type indices_et = get_input_element_type(INDICES);
const PartialShape& params_shape = get_input_partial_shape(PARAMS);
const PartialShape& indices_shape = get_input_partial_shape(INDICES);
PartialShape result_shape;
if (params_shape.rank().is_static() && indices_shape.rank().is_static() &&
axis != AXIS_NOT_SET_VALUE)
{
std::vector<Dimension> result_dims(static_cast<size_t>(params_shape.rank()) +
static_cast<size_t>(indices_shape.rank()) - 1);
size_t i = 0;
for (; i < static_cast<size_t>(axis); i++)
{
result_dims[i] = params_shape[i];
}
for (size_t j = 0; j < static_cast<size_t>(indices_shape.rank()); i++, j++)
{
result_dims[i] = indices_shape[j];
}
for (size_t j = static_cast<size_t>(axis) + 1; j < static_cast<size_t>(params_shape.rank());
i++, j++)
{
result_dims[i] = params_shape[j];
}
result_shape = PartialShape(result_dims);
}
else
{
result_shape = PartialShape::dynamic();
}
set_output_type(0, result_et, result_shape);
}
size_t op::v1::Gather::get_axis() const
{
int64_t axis = AXIS_NOT_SET_VALUE;
auto axes_input_node = input_value(AXIS).get_node_shared_ptr();
if (auto const_op = dynamic_pointer_cast<op::Constant>(axes_input_node))
{
axis = const_op->get_vector<int64_t>()[0];
}
if (axis < 0)
{
const auto& input_rank = get_input_partial_shape(PARAMS).rank();
if (input_rank.is_static())
{
axis += static_cast<size_t>(input_rank);
}
}
return static_cast<size_t>(axis);
}
void op::v1::Gather::generate_adjoints(autodiff::Adjoints& /* adjoints */,
const NodeVector& /* deltas */)
{
throw ngraph_error("Not yet implemented");
}
shared_ptr<Node> op::v1::Gather::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v1::Gather>(new_args.at(PARAMS), new_args.at(INDICES), new_args.at(AXIS));
}
...@@ -22,39 +22,67 @@ namespace ngraph ...@@ -22,39 +22,67 @@ namespace ngraph
{ {
namespace op namespace op
{ {
/// \brief Gather slices from axis of params according to indices namespace v0
class Gather : public Op
{ {
public: /// \brief Gather slices from axis of params according to indices
NGRAPH_API class Gather : public Op
static constexpr NodeTypeInfo type_info{"Gather", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Gather() = default;
/// \param params The tensor from which slices are gathered
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
/// \param axis Axis in params to gather
Gather(const Output<Node>& params, const Output<Node>& indices, size_t axis = 0)
: Op({params, indices})
, m_axis(axis)
{ {
constructor_validate_and_infer_types(); public:
} NGRAPH_API
static constexpr NodeTypeInfo type_info{"Gather", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Gather() = default;
/// \param params The tensor from which slices are gathered
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
/// \param axis Axis in params to gather
Gather(const Output<Node>& params, const Output<Node>& indices, size_t axis = 0);
void validate_and_infer_types() override; void validate_and_infer_types() override;
void generate_adjoints(autodiff::Adjoints& /* adjoints */, void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& /* deltas */) override const NodeVector& deltas) override;
size_t get_axis() const { return m_axis; }
void set_axis(size_t axis) { m_axis = axis; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
size_t m_axis;
};
}
namespace v1
{
/// \brief Gather slices from axis of params according to indices
class Gather : public Op
{ {
throw ngraph_error("Not yet implemented"); public:
} NGRAPH_API
static constexpr NodeTypeInfo type_info{"Gather", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Gather() = default;
/// \param params The tensor from which slices are gathered
/// \param indices Tensor with indexes to gather
/// \param axis The tensor is a dimension index to gather data from
Gather(const Output<Node>& params,
const Output<Node>& indices,
const Output<Node>& axis);
size_t get_version() const override { return 1; }
size_t get_axis() const;
void validate_and_infer_types() override;
void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
size_t get_axis() const { return m_axis; } virtual std::shared_ptr<Node>
void set_axis(size_t axis) { m_axis = axis; } copy_with_new_args(const NodeVector& new_args) const override;
virtual std::shared_ptr<Node> };
copy_with_new_args(const NodeVector& new_args) const override; }
protected: // latest stable opset version
size_t m_axis; using v0::Gather;
};
} }
} }
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "ngraph/pass/opset1_upgrade.hpp" #include "ngraph/pass/opset1_upgrade.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/pad.hpp" #include "ngraph/op/pad.hpp"
#include "ngraph/op/softmax.hpp" #include "ngraph/op/softmax.hpp"
...@@ -113,6 +114,18 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node) ...@@ -113,6 +114,18 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true; modified = true;
break; break;
} }
case OP_TYPEID::Gather:
{
auto tmp = dynamic_cast<const op::v0::Gather*>(node.get());
int64_t axis = tmp->get_axis();
auto axis_node = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{axis});
auto replacement_node = make_shared<op::v1::Gather>(
node->input(0).get_source_output(), node->input(1).get_source_output(), axis_node);
replace_node(node, replacement_node);
modified = true;
break;
}
default: break; default: break;
} }
#if defined(__clang__) #if defined(__clang__)
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "ngraph/code_writer.hpp" #include "ngraph/code_writer.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/pad.hpp" #include "ngraph/op/pad.hpp"
#include "ngraph/runtime/cpu/cpu_external_function.hpp" #include "ngraph/runtime/cpu/cpu_external_function.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp" #include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
...@@ -90,7 +91,6 @@ namespace ngraph ...@@ -90,7 +91,6 @@ namespace ngraph
class ArgMin; class ArgMin;
class ArgMax; class ArgMax;
class TopK; class TopK;
class Gather;
class GatherND; class GatherND;
class ScatterAdd; class ScatterAdd;
class ScatterNDAdd; class ScatterNDAdd;
......
...@@ -1214,8 +1214,15 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -1214,8 +1214,15 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
} }
case OP_TYPEID::Gather: case OP_TYPEID::Gather:
{ {
auto axis = node_js.at("axis").get<size_t>(); if (op_version == 0)
node = make_shared<op::Gather>(args[0], args[1], axis); {
auto axis = node_js.at("axis").get<size_t>();
node = make_shared<op::v0::Gather>(args[0], args[1], axis);
}
if (op_version == 1)
{
node = make_shared<op::v1::Gather>(args[0], args[1], args[2]);
}
break; break;
} }
case OP_TYPEID::GatherND: case OP_TYPEID::GatherND:
...@@ -2445,8 +2452,11 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -2445,8 +2452,11 @@ json JSONSerializer::serialize_node(const Node& n)
} }
case OP_TYPEID::Gather: case OP_TYPEID::Gather:
{ {
auto tmp = dynamic_cast<const op::Gather*>(&n); if (op_version == 0)
node["axis"] = tmp->get_axis(); {
auto tmp = dynamic_cast<const op::v0::Gather*>(&n);
node["axis"] = tmp->get_axis();
}
break; break;
} }
case OP_TYPEID::GatherND: { break; case OP_TYPEID::GatherND: { break;
......
...@@ -70,6 +70,7 @@ set(SRC ...@@ -70,6 +70,7 @@ set(SRC
nop_elimination.cpp nop_elimination.cpp
op.cpp op.cpp
opset_pass/softmax_opset_pass.cpp opset_pass/softmax_opset_pass.cpp
opset_pass/gather_opset_pass.cpp
opset_pass/pad_opset_pass.cpp opset_pass/pad_opset_pass.cpp
partial_shape.cpp partial_shape.cpp
pass.cpp pass.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(serialize, opset1_gather_pass)
{
auto params = make_shared<op::Parameter>(element::f32, Shape{5, 6});
auto indices = make_shared<op::Parameter>(element::i64, Shape{4});
size_t axis = 1;
auto gather_v0 = make_shared<op::v0::Gather>(params, indices, axis);
auto result = make_shared<op::Result>(gather_v0);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{params, indices});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
auto gather_s1_result = f->get_results().at(0);
auto node = gather_s1_result->input(0).get_source_output().get_node_shared_ptr();
auto gather_v1_node = static_pointer_cast<op::v1::Gather>(node);
EXPECT_EQ(gather_v1_node->description(), "Gather");
EXPECT_EQ(gather_v1_node->get_version(), 1);
EXPECT_EQ(gather_v1_node->get_axis(), axis);
}
...@@ -357,6 +357,25 @@ TEST(serialize, opset1_softmax) ...@@ -357,6 +357,25 @@ TEST(serialize, opset1_softmax)
EXPECT_EQ(g_softmax->get_version(), 1); EXPECT_EQ(g_softmax->get_version(), 1);
} }
TEST(serialize, opset1_gather)
{
auto params = make_shared<op::Parameter>(element::f32, Shape{5, 6});
auto indices = make_shared<op::Parameter>(element::i64, Shape{4});
auto axis = make_shared<op::Parameter>(element::i64, Shape{1});
auto gather_v1 = make_shared<op::v1::Gather>(params, indices, axis);
auto result = make_shared<op::Result>(gather_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{params, indices, axis});
string s = serialize(f);
shared_ptr<Function> g = deserialize(s);
auto g_result = g->get_results().at(0);
auto g_gather = g_result->input(0).get_source_output().get_node_shared_ptr();
EXPECT_EQ(g_gather->description(), "Gather");
EXPECT_EQ(g_gather->get_version(), 1);
}
TEST(serialize, opset1_pad) TEST(serialize, opset1_pad)
{ {
auto arg = make_shared<op::Parameter>(element::f32, Shape{4, 5, 6}); auto arg = make_shared<op::Parameter>(element::f32, Shape{4, 5, 6});
......
...@@ -91,3 +91,57 @@ TEST(type_prop, gather_fail_indices_element_type) ...@@ -91,3 +91,57 @@ TEST(type_prop, gather_fail_indices_element_type)
FAIL() << "Deduced type check failed for unexpected reason"; FAIL() << "Deduced type check failed for unexpected reason";
} }
} }
TEST(type_prop, gather_v1_incorrect_axis_shape)
{
auto params = make_shared<op::Parameter>(element::f32, Shape{5, 6});
auto indices = make_shared<op::Parameter>(element::i64, Shape{4});
auto axis = make_shared<op::Parameter>(element::i64, Shape{2});
try
{
auto G = make_shared<op::v1::Gather>(params, indices, axis);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect axis input shape";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Axes input must be scalar or have 1 element (shape:"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, gather_v1_axis_out_of_input_rank)
{
auto params = make_shared<op::Parameter>(element::f32, Shape{5, 6});
auto indices = make_shared<op::Parameter>(element::i64, Shape{4});
auto axis = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{2});
try
{
auto G = make_shared<op::v1::Gather>(params, indices, axis);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect element of axis input";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("The axis must => 0 and <= input_rank (axis:"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, gather_v1_negative_axis)
{
auto params = make_shared<op::Parameter>(element::f32, Shape{5, 6, 7});
auto indices = make_shared<op::Parameter>(element::i64, Shape{4});
int64_t axis = -2;
auto axis_node = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
auto gather_v1 = make_shared<op::v1::Gather>(params, indices, axis_node);
ASSERT_EQ(gather_v1->get_axis(), 1);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment