Commit 9b2095ed authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Scott Cyphers

[SPEC] Add GatherTree:v1 (#3967)

* GatherTree introduced

* Added GatherTree type_prop tests
parent d6c692af
......@@ -227,6 +227,8 @@ set (SRC
op/gather.hpp
op/gather_nd.cpp
op/gather_nd.hpp
op/gather_tree.cpp
op/gather_tree.hpp
op/get_output_element.cpp
op/get_output_element.hpp
op/greater.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/gather_tree.hpp"
#include "ngraph/shape.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::v1::GatherTree::type_info;
op::v1::GatherTree::GatherTree(const Output<Node>& step_ids,
const Output<Node>& parent_idx,
const Output<Node>& max_seq_len,
const Output<Node>& end_token)
: Op({step_ids, parent_idx, max_seq_len, end_token})
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v1::GatherTree::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v1::GatherTree>(
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));
}
void op::v1::GatherTree::validate_and_infer_types()
{
const auto& step_ids_rank = get_input_partial_shape(0);
const auto& parent_idx_rank = get_input_partial_shape(1);
const auto& max_seq_len_rank = get_input_partial_shape(2);
const auto& end_token_rank = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
step_ids_rank.rank().is_dynamic() ||
static_cast<size_t>(step_ids_rank.rank()) == 3,
"step_ids input rank must equal to 3 (step_ids rank: ",
static_cast<size_t>(step_ids_rank.rank()),
")");
NODE_VALIDATION_CHECK(this,
parent_idx_rank.rank().is_dynamic() ||
static_cast<size_t>(parent_idx_rank.rank()) == 3,
"parent_idx input rank must equal to 3 (parent_idx rank: ",
static_cast<size_t>(parent_idx_rank.rank()),
")");
NODE_VALIDATION_CHECK(this,
max_seq_len_rank.rank().is_dynamic() ||
static_cast<size_t>(max_seq_len_rank.rank()) == 1,
"max_seq_len input rank must equal to 1 (max_seq_len rank: ",
static_cast<size_t>(max_seq_len_rank.rank()),
")");
NODE_VALIDATION_CHECK(this,
end_token_rank.rank().is_dynamic() ||
static_cast<size_t>(end_token_rank.rank()) == 3,
"end_token input rank must equal to 3 (end_token rank: ",
static_cast<size_t>(end_token_rank.rank()),
")");
const auto& step_ids_et = get_input_element_type(0);
set_output_type(0, step_ids_et, step_ids_rank);
}
void op::v1::GatherTree::generate_adjoints(autodiff::Adjoints& /* adjoints */,
const NodeVector& /* deltas */)
{
throw ngraph_error("generate_adjoints is not implemented for GatherTree");
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v1
{
/// \brief Generates the complete beams from the ids per each step and the parent beam
/// ids.
class NGRAPH_API GatherTree : public Op
{
public:
static constexpr NodeTypeInfo type_info{"GatherTree", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
GatherTree() = default;
/// \param step_ids Tensor of shape [MAX_TIME, BATCH_SIZE, BEAM_WIDTH] with
/// indices from per each step
/// \param parent_idx Tensor of shape [MAX_TIME, BATCH_SIZE, BEAM_WIDTH] with
/// parent beam indices
/// \param max_seq_len Tensor of shape [BATCH_SIZE] with maximum lengths for each
/// sequence in the batch
/// \param end_token Tensor of shape [MAX_TIME, BATCH_SIZE, BEAM_WIDTH]
GatherTree(const Output<Node>& step_ids,
const Output<Node>& parent_idx,
const Output<Node>& max_seq_len,
const Output<Node>& end_token);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
}
}
}
......@@ -102,6 +102,7 @@ NGRAPH_OP(GRUCell, ngraph::op::v0, 0)
NGRAPH_OP(Gather, ngraph::op::v0, 0)
NGRAPH_OP(Gather, ngraph::op::v1, 1)
NGRAPH_OP(GatherND, ngraph::op::v0, 0)
NGRAPH_OP(GatherTree, ngraph::op::v1, 1)
NGRAPH_OP(Gelu, ngraph::op::v0, 0)
NGRAPH_OP(GeluBackpropFactor, ngraph::op::v0, 0)
NGRAPH_OP(Gemm, ngraph::op::v0, 0)
......
......@@ -119,6 +119,7 @@
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp"
#include "ngraph/op/gather_tree.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
......
......@@ -82,6 +82,7 @@ NGRAPH_OP(FakeQuantize, ngraph::op::v0)
NGRAPH_OP(Floor, ngraph::op::v0)
NGRAPH_OP(FloorMod, ngraph::op::v1)
NGRAPH_OP(Gather, ngraph::op::v1)
NGRAPH_OP(GatherTree, ngraph::op::v1)
NGRAPH_OP(Greater, ngraph::op::v1)
NGRAPH_OP(GreaterEqual, ngraph::op::v1)
NGRAPH_OP(GroupConvolution, ngraph::op::v0)
......
......@@ -1524,6 +1524,11 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::GatherND>(args[0], args[1]);
break;
}
case OP_TYPEID::GatherTree_v1:
{
node = make_shared<op::v1::GatherTree>(args[0], args[1], args[2], args[3]);
break;
}
case OP_TYPEID::Gelu:
{
node = make_shared<op::Gelu>(args[0]);
......@@ -3527,6 +3532,8 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::GatherND: { break;
}
case OP_TYPEID::GatherTree_v1: { break;
}
case OP_TYPEID::GetOutputElement:
{
auto tmp = static_cast<const op::GetOutputElement*>(&n);
......
......@@ -135,6 +135,7 @@ set(SRC
type_prop/fake_quantize.cpp
type_prop/gather.cpp
type_prop/gather_nd.cpp
type_prop/gather_tree.cpp
type_prop/gemm.cpp
type_prop/get_output_element.cpp
type_prop/grn.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(type_prop, gather_tree_output_shape)
{
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto gather_tree =
make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
ASSERT_EQ(gather_tree->get_output_shape(0), (Shape{1, 2, 3}));
ASSERT_EQ(gather_tree->get_output_element_type(0), element::i64);
}
TEST(type_prop, gather_tree_pooling_step_ids_invalid_rank)
{
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4});
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
try
{
auto gather_tree =
make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
// Should have thrown, so fail if it didn't
FAIL() << "Ivalid step_ids input rank not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("step_ids input rank must equal to 3 (step_ids rank: 4)"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, gather_tree_parent_idx_invalid_rank)
{
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4});
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
try
{
auto gather_tree =
make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
// Should have thrown, so fail if it didn't
FAIL() << "Ivalid parent_idx input rank not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("parent_idx input rank must equal to 3 (parent_idx rank: 4)"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, gather_tree_max_seq_len_invalid_rank)
{
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1, 2});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
try
{
auto gather_tree =
make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
// Should have thrown, so fail if it didn't
FAIL() << "Ivalid parent_idx input rank not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("max_seq_len input rank must equal to 1 (max_seq_len rank: 2)"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, gather_tree_end_token_invalid_rank)
{
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4});
try
{
auto gather_tree =
make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
// Should have thrown, so fail if it didn't
FAIL() << "Ivalid end_token input rank not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(), std::string("end_token input rank must equal to 3 (end_token rank: 4)"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment