Commit 16ac55e3 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Michał Karzyński

[ONNX] LSTM node (#1945)

parent a3cab07b
......@@ -88,6 +88,8 @@ add_library(onnx_import STATIC
op/log_softmax.hpp
op/lrn.cpp
op/lrn.hpp
op/lstm.cpp
op/lstm.hpp
op/matmul.cpp
op/matmul.hpp
op/max_pool.cpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "exceptions.hpp"
#include "lstm.hpp"
#include "utils/broadcasting.hpp"
#include "utils/common.hpp"
#include "utils/reshape.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace
{
std::shared_ptr<ngraph::Node> add(const std::shared_ptr<ngraph::Node>& lhs,
const std::shared_ptr<ngraph::Node>& rhs)
{
auto args = numpy_style_broadcast_for_binary_operation(lhs, rhs);
return {std::make_shared<ngraph::op::Add>(args.at(0), args.at(1))};
}
std::shared_ptr<ngraph::Node> mul(const std::shared_ptr<ngraph::Node>& lhs,
const std::shared_ptr<ngraph::Node>& rhs)
{
auto args = numpy_style_broadcast_for_binary_operation(lhs, rhs);
return {std::make_shared<ngraph::op::Multiply>(args.at(0), args.at(1))};
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ACTIVATION FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
std::shared_ptr<ngraph::Node> sigmoid(const std::shared_ptr<ngraph::Node>& arg)
{
return std::make_shared<ngraph::op::Sigmoid>(arg);
}
std::shared_ptr<ngraph::Node> tanh(const std::shared_ptr<ngraph::Node>& arg)
{
return std::make_shared<ngraph::op::Tanh>(arg);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INPUT NODES PARSING ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
enum class LSTMInput
{
LSTM_INPUT_X,
LSTM_INPUT_W,
LSTM_INPUT_R,
LSTM_INPUT_B,
LSTM_INPUT_SEQ_LENGTHS,
LSTM_INPUT_INIT_H,
LSTM_INPUT_INIT_C,
LSTM_INPUT_P
};
std::string to_str(const LSTMInput& in)
{
switch (in)
{
case LSTMInput::LSTM_INPUT_X: return "X";
case LSTMInput::LSTM_INPUT_W: return "W";
case LSTMInput::LSTM_INPUT_R: return "R";
case LSTMInput::LSTM_INPUT_B: return "B";
case LSTMInput::LSTM_INPUT_SEQ_LENGTHS: return "sequence_lens";
case LSTMInput::LSTM_INPUT_INIT_H: return "initial_h";
case LSTMInput::LSTM_INPUT_INIT_C: return "initial_c";
case LSTMInput::LSTM_INPUT_P: return "P";
default: return "Unrecognized input value!";
}
}
struct LSTMNgInputMap
{
using container_type = std::map<LSTMInput, std::shared_ptr<ngraph::Node>>;
using iterator = typename container_type::iterator;
explicit LSTMNgInputMap(const Node& node)
{
const auto& ng_inputs = node.get_ng_inputs();
// We have input, output, forget and cell gates
constexpr std::size_t gates_count{4};
// Peepholes add additional connections to input, output and forget gates.
constexpr std::size_t peepholes_count{3};
// ----- Mandatory inputs ------
// Packed input sequences. Shape: [seq_length, batch_size, input_size]
m_map[LSTMInput::LSTM_INPUT_X] = ng_inputs.at(0);
// Weight tensor for the gates. Shape: [num_directions, 4*hidden_size, input_size]
m_map[LSTMInput::LSTM_INPUT_W] = ng_inputs.at(1);
// The recurrence weight tensor. Shape: [num_directions, 4*hidden_size, hidden_size]
m_map[LSTMInput::LSTM_INPUT_R] = ng_inputs.at(2);
const std::size_t hidden_size =
m_map[LSTMInput::LSTM_INPUT_R]->get_shape().back();
const std::size_t batch_size =
m_map[LSTMInput::LSTM_INPUT_X]->get_shape().at(1);
const std::size_t num_directions =
m_map[LSTMInput::LSTM_INPUT_W]->get_shape().front();
// ------ Optional inputs ------
// The bias tensor for the gates, i.e. the concatenation of `[Wb, Rb]`. Shape [num_directions, 8*hidden_size]
if (ng_inputs.size() >= 4)
{
m_map[LSTMInput::LSTM_INPUT_B] = ng_inputs.at(3);
}
else
{
m_map[LSTMInput::LSTM_INPUT_B] = common::make_constant_node<float>(
element::f32,
{num_directions, 2 * gates_count * hidden_size},
{0.f});
}
// The lengths of the sequences in a batch. Shape [batch_size]
if (ng_inputs.size() >= 5)
{
m_map[LSTMInput::LSTM_INPUT_SEQ_LENGTHS] = ng_inputs.at(4);
}
else
{
m_map[LSTMInput::LSTM_INPUT_SEQ_LENGTHS] =
common::make_constant_node<std::int32_t>(
element::i32,
{batch_size},
{static_cast<std::int32_t>(
m_map[LSTMInput::LSTM_INPUT_X]->get_shape().at(0))});
}
// The initial value of the hidden state. Shape [num_directions, batch_size, hidden_size]
if (ng_inputs.size() >= 6)
{
m_map[LSTMInput::LSTM_INPUT_INIT_H] = ng_inputs.at(5);
}
else
{
m_map[LSTMInput::LSTM_INPUT_INIT_H] = common::make_constant_node<float>(
element::f32, {num_directions, batch_size, hidden_size}, {0.f});
}
// The initial value of the cell state. Shape [num_directions, batch_size, hidden_size]
if (ng_inputs.size() >= 7)
{
m_map[LSTMInput::LSTM_INPUT_INIT_C] = ng_inputs.at(6);
}
else
{
m_map[LSTMInput::LSTM_INPUT_INIT_C] = common::make_constant_node<float>(
element::f32, {num_directions, batch_size, hidden_size}, {0.f});
}
// The weight tensor for peepholes. Shape [num_directions, 3*hidden_size]
if (ng_inputs.size() >= 8)
{
m_map[LSTMInput::LSTM_INPUT_P] = ng_inputs.at(7);
}
else
{
m_map[LSTMInput::LSTM_INPUT_P] = common::make_constant_node<float>(
element::f32,
{num_directions, peepholes_count * hidden_size},
{0.f});
}
}
std::shared_ptr<ngraph::Node>& at(const LSTMInput& key)
{
return m_map.at(key);
}
iterator begin() { return m_map.begin(); }
iterator end() { return m_map.end(); }
container_type m_map;
};
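// Illustrative example (not part of the importer): for a forward LSTM with
// num_directions = 1, batch_size = 3 and hidden_size = 5, a model that omits the
// optional B input would get a zero bias constant of shape {1, 40}
// (num_directions, 2 * gates_count * hidden_size), roughly equivalent to:
//
//   auto default_bias = common::make_constant_node<float>(
//       element::f32, {1, 2 * 4 * 5}, {0.f});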
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ATTRIBUTES PARSING ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
enum class LSTMDirection
{
LSTM_DIRECTION_FORWARD,
LSTM_DIRECTION_REVERSE,
LSTM_DIRECTION_BIDIRECTIONAL
};
struct LSTMAttributes
{
explicit LSTMAttributes(const Node& node)
: m_direction{LSTMDirection::LSTM_DIRECTION_FORWARD}
, m_hidden_size{node.get_attribute_value<std::int64_t>("hidden_size")}
{
}
// Currently only LSTM_DIRECTION_FORWARD is supported.
LSTMDirection m_direction;
std::int64_t m_hidden_size;
};
} // anonymous namespace
namespace set_1
{
NodeVector lstm(const Node& node)
{
LSTMNgInputMap input_map{node};
LSTMAttributes attributes{node};
if (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_FORWARD)
{
// Since we have a forward LSTM, we can squeeze the `num_directions` axis from the inputs.
for (auto& ng_in : input_map)
{
if (ng_in.first != LSTMInput::LSTM_INPUT_X &&
ng_in.first != LSTMInput::LSTM_INPUT_SEQ_LENGTHS)
{
ASSERT_VALID_ARGUMENT(node, ng_in.second->get_shape().at(0) == 1)
<< "Input: { " << to_str(ng_in.first)
<< " } first axis has size different "
"from 1, while direction attribute set to 'forward'.";
ng_in.second = reshape::squeeze(ng_in.second);
}
}
}
// ------ VARIABLE NAMES AND ACRONYM DEFINITIONS ------
// The names used below are analogous to the ones used in the ONNX documentation.
//
// ------ INPUTS ------
// X - The input tensor. [seq_length, batch_size, input_size]
// W - The weight tensor. [num_directions, 4*hidden_size, input_size]
// R - The recurrence weight tensor. [num_directions, 4*hidden_size, hidden_size]
// B - The bias tensor for the gates. [num_directions, 8*hidden_size]
// P - The weight tensor for peepholes. [num_directions, 3*hidden_size]
// ------ ACRONYMS ------
// i - input gate
// o - output gate
// f - forget gate
// c - cell gate
// t - time step (t-1 means previous time step)
// ------ VARIABLE NAMES ------
// W - W parameter weight matrix for input, output, forget, and
// cell gates.
// R - R recurrence weight matrix for input, output, forget, and
// cell gates.
// Wb - W bias vectors for input, output, forget, and cell gates.
// Rb - R bias vectors for input, output, forget, and cell gates.
// b_W_R - Bias vectors for input, output, forget, and cell gates.
// Concatenation of `[Wb, Rb]`.
// p_[iof] - P peephole weight vector for respectively: input, output,
// and forget gates.
// H_t - Hidden state vector at current time step.
// C_t - Cell state vector at current time step.
// h_list - The list of hidden states at all processed time steps.
//
// Xt_W - Input sequence multiplied by weights tensor at current time
// step.
// Ht_R - Hidden state multiplied by weights tensor at current time step.
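//
// ------ CELL EQUATIONS (ONNX defaults: f = Sigmoid, g = h = Tanh) ------
// For reference, the per-time-step equations implemented by the loop below:
// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
// ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
// Ct = ft (.) Ct-1 + it (.) ct
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
// Ht = ot (.) h(Ct)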
NodeVector p_iof = reshape::split(input_map.at(LSTMInput::LSTM_INPUT_P), 3);
const auto& p_i = p_iof.at(0);
const auto& p_o = p_iof.at(1);
const auto& p_f = p_iof.at(2);
std::shared_ptr<ngraph::Node> H_t{input_map.at(LSTMInput::LSTM_INPUT_INIT_H)};
std::shared_ptr<ngraph::Node> C_t{input_map.at(LSTMInput::LSTM_INPUT_INIT_C)};
NodeVector h_list;
NodeVector b_W_R = reshape::split(input_map.at(LSTMInput::LSTM_INPUT_B), 2);
std::shared_ptr<ngraph::Node> bias = b_W_R.at(0) + b_W_R.at(1);
NodeVector in_seqs =
reshape::split(input_map.at(LSTMInput::LSTM_INPUT_X),
input_map.at(LSTMInput::LSTM_INPUT_X)->get_shape().at(0));
for (auto& in_x : in_seqs)
{
// Remove the first (empty) dimension left over after the split above.
in_x = reshape::squeeze(in_x);
}
for (const auto& in_x : in_seqs)
{
// (.) - Denotes element-wise multiplication.
// * - Denotes dot product.
// Xt*(W^T) -- for [iofc] gates.
auto Xt_W = std::make_shared<ngraph::op::Dot>(
in_x, reshape::transpose(input_map.at(LSTMInput::LSTM_INPUT_W)));
// Ht-1*(R^T) -- for [iofc] gates.
auto Ht_R = std::make_shared<ngraph::op::Dot>(
H_t, reshape::transpose(input_map.at(LSTMInput::LSTM_INPUT_R)));
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb -- for [iofc] gates.
auto gates = add(Xt_W, add(Ht_R, bias));
NodeVector split_gates = reshape::split(gates, 4, -1);
auto i = split_gates.at(0);
auto o = split_gates.at(1);
auto f = split_gates.at(2);
auto c = split_gates.at(3);
// f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
i = sigmoid(add(i, mul(p_i, C_t)));
// f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
f = sigmoid(add(f, mul(p_f, C_t)));
// ft (.) Ct-1 + it (.) ct
auto C = add(mul(f, C_t), mul(i, tanh(c)));
// f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
o = sigmoid(add(o, mul(p_o, C)));
// ot (.) h(Ct)
auto H = mul(o, tanh(C));
h_list.push_back(H);
H_t = H;
C_t = C;
}
// The tensor that concatenates all the intermediate hidden state outputs.
// It has shape [seq_length, batch_size, hidden_size]
NodeVector exp_h_list;
for (const auto& ht : h_list)
{
// Expand each tensor with an empty outermost dim so we can later concatenate them.
exp_h_list.push_back(reshape::add_empty_axes(ht));
}
std::shared_ptr<ngraph::Node> Y{
std::make_shared<ngraph::op::Concat>(exp_h_list, 0)};
// Expand Y so that it has expected shape:
// [seq_length, num_directions, batch_size, hidden_size]
if (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_FORWARD)
{
Shape shape{Y->get_shape()};
shape.insert(std::next(std::begin(shape)), 1);
Y = std::make_shared<ngraph::op::Reshape>(
Y, reshape::get_default_axis_vector(Y->get_shape().size()), shape);
}
return {Y, exp_h_list.back()};
}
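// A sketch of the expected output shapes (forward direction, hypothetical sizes):
// for X of shape {seq_length, batch_size, input_size} = {7, 3, 10} and
// hidden_size = 5, Y has shape {7, 1, 3, 5} and the second output (the last
// intermediate hidden state) has shape {1, 3, 5}, matching the ONNX outputs
// Y and Y_h respectively.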
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector lstm(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -14,9 +14,8 @@
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/slice.hpp"
#include "op/split.hpp"
#include "utils/reshape.hpp"
namespace ngraph
{
......@@ -82,37 +81,6 @@ namespace ngraph
{
namespace set_1
{
namespace detail
{
template <typename T>
inline T get_valid_array_index(T left, T right)
{
return (left >= 0) ? std::min(left, right)
: std::max(static_cast<T>(0), right + left);
}
inline std::shared_ptr<ngraph::op::Slice>
make_ng_slice(const std::shared_ptr<ngraph::Node>& node,
std::vector<std::size_t> axes,
std::vector<std::size_t> starts,
std::vector<std::size_t> ends)
{
std::vector<std::size_t> upper_bounds{node->get_shape()};
std::vector<std::size_t> lower_bounds(upper_bounds.size());
for (std::size_t index{0}; index < axes.size(); ++index)
{
std::size_t axis{axes.at(index)};
lower_bounds.at(axis) =
get_valid_array_index(starts.at(index), node->get_shape().at(axis));
upper_bounds.at(axis) =
get_valid_array_index(ends.at(index), node->get_shape().at(axis));
}
return std::make_shared<ngraph::op::Slice>(
node, lower_bounds, upper_bounds);
}
} // namespace detail
NodeVector split(const Node& node)
{
std::shared_ptr<ngraph::Node> input = node.get_ng_inputs().at(0);
......@@ -143,16 +111,7 @@ namespace ngraph
length_parts.assign(count_outputs, length_axis_to_split / count_outputs);
}
std::size_t start_index{0};
NodeVector outputs;
for (const auto& length_part : length_parts)
{
std::size_t end_index{start_index + length_part};
outputs.push_back(detail::make_ng_slice(
input, {axis_to_split}, {start_index}, {end_index}));
start_index = end_index;
}
return outputs;
return reshape::split(input, length_parts, axis_to_split);
}
} // namespace set_1
......
......@@ -56,6 +56,7 @@
#include "op/log.hpp"
#include "op/log_softmax.hpp"
#include "op/lrn.hpp"
#include "op/lstm.hpp"
#include "op/matmul.hpp"
#include "op/max.hpp"
#include "op/max_pool.hpp"
......@@ -183,6 +184,7 @@ namespace ngraph
REGISTER_OPERATOR("Log", 1, log);
REGISTER_OPERATOR("LogSoftmax", 1, log_softmax);
REGISTER_OPERATOR("LRN", 1, lrn);
REGISTER_OPERATOR("LSTM", 1, lstm);
REGISTER_OPERATOR("MatMul", 1, matmul);
REGISTER_OPERATOR("MaxPool", 1, max_pool);
REGISTER_OPERATOR("Max", 1, max);
......
......@@ -19,9 +19,15 @@
#include <cmath> // std::floor
#include <cstddef> // std::size_t
#include <iterator> // std::begin, std::end
#include <memory> // std::shared_ptr, std::make_shared
#include <type_traits> // std::enable_if, std::is_floating_point, std::is_integral
#include <vector>
#include "ngraph/op/constant.hpp"
#include "ngraph/shape.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
namespace onnx_import
......@@ -100,6 +106,36 @@ namespace ngraph
return range;
}
/// \brief Makes a constant nGraph node.
///
/// \param[in] type The node element type.
/// \param[in] shape The tensor data shape.
/// \param[in] data The data to initialize the node with.
///
/// \tparam T Input data value type.
///
/// \return The nGraph node representing the constant data.
///
template <typename T>
std::shared_ptr<ngraph::Node> make_constant_node(const ngraph::element::Type& type,
const ngraph::Shape& shape,
const std::vector<T>& data)
{
std::shared_ptr<ngraph::Node> node;
// Make a scalar constant and broadcast it to the requested shape.
if (data.size() == 1)
{
node = std::make_shared<ngraph::op::Constant>(type, ngraph::Shape{}, data);
node = make_broadcast_node(node, shape);
}
else
{
node = std::make_shared<ngraph::op::Constant>(type, shape, data);
}
return node;
}
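// Illustrative usage (hypothetical values):
//
//   // A 2x3 constant filled with zeros - the single value is broadcast:
//   auto zeros = common::make_constant_node<float>(element::f32, {2, 3}, {0.f});
//   // A 1x3 constant initialized element-wise:
//   auto row = common::make_constant_node<float>(
//       element::f32, {1, 3}, {1.f, 2.f, 3.f});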
} // namespace common
} // namespace onnx_import
} // namespace ngraph
......@@ -15,11 +15,15 @@
//*****************************************************************************
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <iterator>
#include <numeric>
#include <vector>
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp"
#include "exceptions.hpp"
#include "utils/common.hpp"
......@@ -31,6 +35,33 @@ namespace ngraph
{
namespace reshape
{
namespace
{
inline std::size_t get_valid_array_index(std::size_t idx, std::size_t axis_size)
{
return std::min(idx, axis_size);
}
std::shared_ptr<op::Slice> make_ng_slice(const std::shared_ptr<ngraph::Node>& node,
const std::vector<std::size_t>& axes,
const std::vector<std::size_t>& starts,
const std::vector<std::size_t>& ends)
{
std::vector<std::size_t> upper_bounds{node->get_shape()};
std::vector<std::size_t> lower_bounds(upper_bounds.size());
for (std::size_t index{0}; index < axes.size(); ++index)
{
std::size_t axis{axes.at(index)};
lower_bounds.at(axis) =
get_valid_array_index(starts.at(index), node->get_shape().at(axis));
upper_bounds.at(axis) =
get_valid_array_index(ends.at(index), node->get_shape().at(axis));
}
return std::make_shared<op::Slice>(node, lower_bounds, upper_bounds);
}
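// Illustrative example: for a node of shape {5, 4}, make_ng_slice(node, {0}, {1}, {3})
// clamps the bounds to the axis length and yields a Slice with lower_bounds {1, 0}
// and upper_bounds {3, 4}, i.e. rows 1 and 2 of the input.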
} // namespace anonymous
std::shared_ptr<ngraph::Node> flatten(const std::shared_ptr<ngraph::Node>& node,
int axis)
{
......@@ -206,6 +237,36 @@ namespace ngraph
node, reshape::get_default_axis_vector(node->get_shape().size()), output_shape);
}
NodeVector split(const std::shared_ptr<ngraph::Node>& node,
const std::vector<std::size_t>& length_parts,
std::size_t axis)
{
std::size_t start_index{0};
NodeVector outputs;
for (const auto& length_part : length_parts)
{
std::size_t end_index{start_index + length_part};
outputs.push_back(make_ng_slice(node, {axis}, {start_index}, {end_index}));
start_index = end_index;
}
return outputs;
}
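// Illustrative example: splitting a node of shape {5, 4} with length_parts = {2, 3}
// along axis 0 yields two nodes of shapes {2, 4} and {3, 4} respectively.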
NodeVector
split(const std::shared_ptr<ngraph::Node>& node, std::size_t split_parts, int axis)
{
std::size_t axis_to_split{static_cast<std::size_t>(axis)};
if (axis < 0)
{
axis_to_split = node->get_shape().size() + axis;
}
std::size_t length_axis_to_split{node->get_shape().at(axis_to_split)};
std::vector<std::size_t> length_parts(split_parts,
length_axis_to_split / split_parts);
return split(node, length_parts, axis_to_split);
}
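// Illustrative example: splitting a gates tensor of shape {batch_size, 4 * hidden_size}
// with split_parts = 4 and axis = -1 (NumPy-style negative indexing) yields four nodes
// of shape {batch_size, hidden_size} - this is how the LSTM importer above separates
// the i, o, f and c gates.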
} // namespace reshape
} // namespace onnx_import
} // namespace ngraph
......@@ -141,6 +141,35 @@ namespace ngraph
std::size_t outermost_axes_count = 1,
std::size_t innermost_axes_count = 0);
/// \brief Split node on specified axis into multiple parts.
///
/// \param[in] node The input node.
/// \param[in] length_parts The vector defining the lengths of each split part.
/// \param[in] axis The axis along which the input node is split. Defaults to axis zero.
///
/// \return The vector containing the nodes the input node was split into.
///
NodeVector split(const std::shared_ptr<ngraph::Node>& node,
const std::vector<std::size_t>& length_parts,
std::size_t axis = 0);
/// \brief Split node on specified axis into multiple parts.
///
/// \param[in] node The input node.
/// \param[in] split_parts The number of parts to split the input node into along
/// the given axis. The length of the axis to split must be
/// divisible by this value.
/// \param[in] axis The axis along which the input node is split. Defaults to axis zero.
///
/// \note This implementation supports negative `axis` values (similar to NumPy
/// indexing).
///
/// \return The vector containing the nodes the input node was split into.
///
NodeVector split(const std::shared_ptr<ngraph::Node>& node,
std::size_t split_parts,
int axis = 0);
} // namespace reshape
} // namespace onnx_import
} // namespace ngraph