Commit 6e6c8af4 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Michał Karzyński

[ONNX] Enhance LSTM support. (#2408)

parent 25c9152f
......@@ -177,6 +177,8 @@ add_library(onnx_import STATIC
utils/reduction.hpp
utils/reshape.cpp
utils/reshape.hpp
utils/rnn/activation_functions.cpp
utils/rnn/activation_functions.hpp
utils/variadic.hpp)
set(ONNX_IMPORT_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR} CACHE INTERNAL "")
......
......@@ -258,6 +258,15 @@ namespace ngraph
name, std::move(default_value));
}
template <>
std::vector<std::string>
Node::get_attribute_value(const std::string& name,
std::vector<std::string> default_value) const
{
return m_pimpl->template get_attribute_value<std::vector<std::string>>(
name, std::move(default_value));
}
template <>
std::vector<Tensor> Node::get_attribute_value(const std::string& name,
std::vector<Tensor> default_value) const
......
This diff is collapsed.
......@@ -138,7 +138,7 @@ namespace ngraph
// Expand sub_dot result with single empty outermost axis, in order to
// later concatenate sub_dots at this axis.
small_dots.at(g) = reshape::add_empty_axes(sub_dot);
small_dots.at(g) = reshape::expand_dims(sub_dot);
}
// Concatenate sub_dots on groups axis.
......
......@@ -112,7 +112,7 @@ opset versions starting from `1` to `6` and to the latest opset version.
|------|-----------------|--------|--------|---------|
| Erf | (9) | 284 | 442 | Need separate kernel for this in nGraph core. |
| Pad | 1-2- | 273 | 416 | Not fully supported. |
| LSTM | 1-7- | | 430 | Not fully supported. |
| LSTM | 1-7- | | 476 | Mixed sequences length not supported yet. |
| MaxUnpool | (9) | 286, 289 | 447 | |
| LpPool | - | 291 | 437 | Unsupported by nGraph - only max/avg pooling ops. Need separate kernel. |
| Multinomial | - | 199 | 435 | Lack of PRNG in nGraph. |
......
......@@ -221,17 +221,14 @@ namespace ngraph
node, get_default_axis_vector(node->get_shape().size()), shape);
}
std::shared_ptr<ngraph::Node> add_empty_axes(const std::shared_ptr<ngraph::Node>& node,
std::size_t outermost_axes_count,
std::size_t innermost_axes_count)
std::shared_ptr<ngraph::Node> expand_dims(const std::shared_ptr<ngraph::Node>& node,
std::size_t axis)
{
// Add outermost empty dimensions.
Shape output_shape(outermost_axes_count, 1);
output_shape.insert(std::end(output_shape),
std::begin(node->get_shape()),
std::end(node->get_shape()));
// Add innermost empty dimensions.
output_shape.insert(std::end(output_shape), innermost_axes_count, 1);
Shape output_shape(node->get_shape());
// Add empty axis at specified position.
auto empty_axis_it = std::begin(output_shape);
std::advance(empty_axis_it, axis);
output_shape.insert(empty_axis_it, 1);
return std::make_shared<ngraph::op::Reshape>(
node, reshape::get_default_axis_vector(node->get_shape().size()), output_shape);
}
......
......@@ -127,19 +127,17 @@ namespace ngraph
return reshape(node, get_default_axis_vector(node->get_shape().size()), shape);
}
/// \brief Expands node tensor shape with empty axes.
/// \brief Expands node tensor shape with empty axis at
/// specified position.
///
/// \param[in] node The node to be expanded.
/// \param[in] outermost_axes_count The number of added outermost axes.
/// At the front of the shape.
/// \param[in] innermost_axes_count The number of added innermost axes.
/// At the end of the shape.
/// \param[in] node The node to be expanded.
/// \param[in] axis The position in the expanded axes where the
/// new axis is placed.
///
/// \return The node with added empty axes.
/// \return The node with added empty axis.
///
std::shared_ptr<ngraph::Node> add_empty_axes(const std::shared_ptr<ngraph::Node>& node,
std::size_t outermost_axes_count = 1,
std::size_t innermost_axes_count = 0);
std::shared_ptr<ngraph::Node> expand_dims(const std::shared_ptr<ngraph::Node>& node,
std::size_t axis = 0);
/// \brief Split node on specified axis into multiple parts.
///
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <functional>
#include <iterator>
#include <unordered_map>
#include "activation_functions.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/tanh.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace rnn
{
namespace detail
{
std::shared_ptr<ngraph::Node> sigmoid(const std::shared_ptr<ngraph::Node>& arg)
{
return std::make_shared<ngraph::op::Sigmoid>(arg);
}
std::shared_ptr<ngraph::Node> tanh(const std::shared_ptr<ngraph::Node>& arg)
{
return std::make_shared<ngraph::op::Tanh>(arg);
}
std::shared_ptr<ngraph::Node> relu(const std::shared_ptr<ngraph::Node>& arg)
{
return std::make_shared<ngraph::op::Relu>(arg);
}
} // namespace detail
ActivationFunction get_activation_func_by_name(const std::string& func_name)
{
using ActivationFunctionMap = std::unordered_map<std::string, ActivationFunction>;
static ActivationFunctionMap func_map{
{"sigmoid", std::bind(detail::sigmoid, std::placeholders::_1)},
{"tanh", std::bind(detail::tanh, std::placeholders::_1)},
{"relu", std::bind(detail::relu, std::placeholders::_1)}};
auto func_it = func_map.find(func_name);
if (func_it == std::end(func_map))
{
throw error::UnknownActivationFunction(func_name);
}
return func_it->second;
}
} //namespace rnn
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include <string>
#include "ngraph/except.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace rnn
{
namespace error
{
struct UnknownActivationFunction : ngraph_error
{
UnknownActivationFunction(const std::string& func_name)
: ngraph_error{"Unknown activation function: " + func_name}
{
}
};
}
namespace detail
{
std::shared_ptr<ngraph::Node> sigmoid(const std::shared_ptr<ngraph::Node>& arg);
std::shared_ptr<ngraph::Node> tanh(const std::shared_ptr<ngraph::Node>& arg);
std::shared_ptr<ngraph::Node> relu(const std::shared_ptr<ngraph::Node>& arg);
}
using ActivationFunction =
std::function<std::shared_ptr<ngraph::Node>(const std::shared_ptr<ngraph::Node>&)>;
/// \brief Gets the activation function by name.
///
/// \param[in] func_name The function name
///
/// \throws UnknownActivationFunction When provided func_name is unknown.
///
/// \return The activation function object.
///
ActivationFunction get_activation_func_by_name(const std::string& func_name);
} //namespace rnn
} // namespace onnx_import
} // namespace ngraph
......@@ -1864,6 +1864,91 @@ TEST(onnx_${BACKEND_NAME}, model_top_k)
EXPECT_TRUE(test::all_close(expected_indices_output, indices_output));
}
TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_with_clip)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/lstm_fwd_with_clip.onnx"));
Inputs inputs{};
// X
inputs.emplace_back(std::vector<float>{-0.455351, -0.276391, -0.185934, -0.269585});
// W
inputs.emplace_back(std::vector<float>{-0.494659f,
0.0453352f,
-0.487793f,
0.417264f,
-0.0175329f,
0.489074f,
-0.446013f,
0.414029f,
-0.0091708f,
-0.255364f,
-0.106952f,
-0.266717f,
-0.0888852f,
-0.428709f,
-0.283349f,
0.208792f});
// R
inputs.emplace_back(std::vector<float>{0.146626f,
-0.0620289f,
-0.0815302f,
0.100482f,
-0.219535f,
-0.306635f,
-0.28515f,
-0.314112f,
-0.228172f,
0.405972f,
0.31576f,
0.281487f,
-0.394864f,
0.42111f,
-0.386624f,
-0.390225f});
// B
inputs.emplace_back(std::vector<float>{0.381619f,
0.0323954f,
-0.14449f,
0.420804f,
-0.258721f,
0.45056f,
-0.250755f,
0.0967895f,
0.0f,
0.0f,
0.0f,
0.0f,
0.0f,
0.0f,
0.0f,
0.0f});
// P
inputs.emplace_back(std::vector<float>{0.2345f, 0.5235f, 0.4378f, 0.3475f, 0.8927f, 0.3456f});
Outputs expected_output{};
// Y_data
expected_output.emplace_back(
std::vector<float>{-0.02280854f, 0.02744377f, -0.03516197f, 0.03875681f});
// Y_h_data
expected_output.emplace_back(std::vector<float>{-0.03516197f, 0.03875681f});
// Y_c_data
expected_output.emplace_back(std::vector<float>{-0.07415761f, 0.07395997f});
Outputs outputs{execute(function, inputs, "${BACKEND_NAME}")};
EXPECT_TRUE(outputs.size() == expected_output.size());
for (std::size_t i{0}; i < expected_output.size(); ++i)
{
// We have to enlarge tolerance bits to 3 - it's only one bit more than default value.
// The discrepancies may occur at most on 7th decimal position.
EXPECT_TRUE(test::all_close_f(expected_output.at(i), outputs.at(i), 3));
}
}
TEST(onnx_${BACKEND_NAME}, model_missing_input)
{
onnx_import::register_operator(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment