Commit 16550767 authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Scott Cyphers

[ONNX] Add ReverseSequence operator (#3239)

* ReverseSequence operator introduced

* Code review remarks introduced

* Added missing EOF

* Removed unused whitespaces in onnx_import.in.cpp

* Added convert to i32 for sequence_lenghts

* Coode review remarks introduced

* Disable reverse sequence for plaidml backend

* Code style fixed
parent c0e2714c
......@@ -155,6 +155,8 @@ add_library(onnx_import STATIC
op/relu.hpp
op/reshape.cpp
op/reshape.hpp
op/reverse_sequence.cpp
op/reverse_sequence.h
op/selu.cpp
op/selu.hpp
op/shape.hpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include "core/node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector reverse_sequence(const Node& node)
{
const auto data = node.get_ng_inputs().at(0);
const auto sequence_lengths = node.get_ng_inputs().at(1);
//nGraph supports only int32 type of sequence_lengths
const auto sequence_lengths_i32 = std::make_shared<ngraph::op::Convert>(
node.get_ng_inputs().at(1), element::i32);
const auto batch_axis = node.get_attribute_value<int64_t>("batch_axis", 1);
const auto time_axis = node.get_attribute_value<int64_t>("time_axis", 0);
return {std::make_shared<ngraph::op::ReverseSequence>(
data, sequence_lengths_i32, batch_axis, time_axis)};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector reverse_sequence(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -96,6 +96,7 @@
#include "op/reduce.hpp"
#include "op/relu.hpp"
#include "op/reshape.hpp"
#include "op/reverse_sequence.h"
#include "op/selu.hpp"
#include "op/shape.hpp"
#include "op/shrink.hpp"
......@@ -319,6 +320,7 @@ namespace ngraph
REGISTER_OPERATOR("ReduceSumSquare", 1, reduce_sum_square);
REGISTER_OPERATOR("Relu", 1, relu);
REGISTER_OPERATOR("Reshape", 1, reshape);
REGISTER_OPERATOR("ReverseSequence", 1, reverse_sequence);
REGISTER_OPERATOR("Selu", 1, selu);
REGISTER_OPERATOR("Shape", 1, shape);
REGISTER_OPERATOR("Shrink", 1, shrink);
......
......@@ -16,8 +16,6 @@
#pragma once
#include <memory>
#include "ngraph/op/op.hpp"
namespace ngraph
......
......@@ -297,6 +297,8 @@ quantized_conv_int32_output
# unsupported op: `ReverseSequence`
model_lstm_bdir_short_input_seq
model_lstm_mixed_seq_reverse
model_reverse_sequence_0_batch_1
model_reverse_sequence_1_batch_0
# node validation error: "Argument shapes are inconsistent."
model_lstm_fwd_with_clip
......
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
input: "sequence_lengths"
output: "y"
op_type: "ReverseSequence"
attribute {
name: "batch_axis"
i: 1
type: INT
}
attribute {
name: "time_axis"
i: 0
type: INT
}
}
name: "reverse_sequence_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "sequence_lengths"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 4
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
version: 10
}
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
input: "sequence_lengths"
output: "y"
op_type: "ReverseSequence"
attribute {
name: "batch_axis"
i: 0
type: INT
}
attribute {
name: "time_axis"
i: 1
type: INT
}
}
name: "reverse_sequence_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "sequence_lengths"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 4
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
version: 10
}
......@@ -1577,3 +1577,35 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_eye_like)
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_reverse_sequence_0_batch_1)
{
const auto reverse_sequence_fn = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/reverse_sequence_time_0_batch_1.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(reverse_sequence_fn, "${BACKEND_NAME}");
test_case.add_input<float>(
{0.f, 4.f, 8.f, 12.f, 1.f, 5.f, 9.f, 13.f, 2.f, 6.f, 10.f, 14.f, 3.f, 7.f, 11.f, 15.f});
test_case.add_input<int>({4, 3, 2, 1});
test_case.add_expected_output<float>(
Shape{4, 4},
{3.f, 6.f, 9.f, 12.f, 2.f, 5.f, 8.f, 13.f, 1.f, 4.f, 10.f, 14.f, 0.f, 7.f, 11.f, 15.f});
test_case.run();
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_reverse_sequence_1_batch_0)
{
const auto reverse_sequence_fn = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/reverse_sequence_time_1_batch_0.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(reverse_sequence_fn, "${BACKEND_NAME}");
test_case.add_input<float>(
{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f});
test_case.add_input<int>({1, 2, 3, 4});
test_case.add_expected_output<float>(
Shape{4, 4},
{0.f, 1.f, 2.f, 3.f, 5.f, 4.f, 6.f, 7.f, 10.f, 9.f, 8.f, 11.f, 15.f, 14.f, 13.f, 12.f});
test_case.run();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment