Commit 6e6c8af4 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Michał Karzyński

[ONNX] Enhance LSTM support. (#2408)

parent 25c9152f
...@@ -177,6 +177,8 @@ add_library(onnx_import STATIC ...@@ -177,6 +177,8 @@ add_library(onnx_import STATIC
utils/reduction.hpp utils/reduction.hpp
utils/reshape.cpp utils/reshape.cpp
utils/reshape.hpp utils/reshape.hpp
utils/rnn/activation_functions.cpp
utils/rnn/activation_functions.hpp
utils/variadic.hpp) utils/variadic.hpp)
set(ONNX_IMPORT_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR} CACHE INTERNAL "") set(ONNX_IMPORT_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR} CACHE INTERNAL "")
......
...@@ -258,6 +258,15 @@ namespace ngraph ...@@ -258,6 +258,15 @@ namespace ngraph
name, std::move(default_value)); name, std::move(default_value));
} }
// Explicit specialization for attributes whose value is a list of strings.
// The lookup is delegated to the implementation object; `default_value` is
// moved into that call.
template <>
std::vector<std::string>
    Node::get_attribute_value(const std::string& name,
                              std::vector<std::string> default_value) const
{
    using StringList = std::vector<std::string>;
    return m_pimpl->template get_attribute_value<StringList>(name, std::move(default_value));
}
template <> template <>
std::vector<Tensor> Node::get_attribute_value(const std::string& name, std::vector<Tensor> Node::get_attribute_value(const std::string& name,
std::vector<Tensor> default_value) const std::vector<Tensor> default_value) const
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <algorithm>
#include <cmath>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <functional> #include <functional>
...@@ -24,21 +26,30 @@ ...@@ -24,21 +26,30 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "core/null_node.hpp"
#include "exceptions.hpp" #include "exceptions.hpp"
#include "lstm.hpp" #include "lstm.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/sigmoid.hpp" #include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/tanh.hpp" #include "ngraph/op/tanh.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
#include "utils/broadcasting.hpp" #include "utils/broadcasting.hpp"
#include "utils/common.hpp" #include "utils/common.hpp"
#include "utils/reshape.hpp" #include "utils/reshape.hpp"
#include "utils/rnn/activation_functions.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -55,6 +66,13 @@ namespace ngraph ...@@ -55,6 +66,13 @@ namespace ngraph
return {std::make_shared<ngraph::op::Add>(args.at(0), args.at(1))}; return {std::make_shared<ngraph::op::Add>(args.at(0), args.at(1))};
} }
// Builds an op::Subtract node computing `lhs - rhs`. Both operands are first
// passed through numpy-style broadcasting.
std::shared_ptr<ngraph::Node> sub(const std::shared_ptr<ngraph::Node>& lhs,
                                  const std::shared_ptr<ngraph::Node>& rhs)
{
    auto operands = numpy_style_broadcast_for_binary_operation(lhs, rhs);
    auto& minuend = operands.at(0);
    auto& subtrahend = operands.at(1);
    return std::make_shared<ngraph::op::Subtract>(minuend, subtrahend);
}
std::shared_ptr<ngraph::Node> mul(const std::shared_ptr<ngraph::Node>& lhs, std::shared_ptr<ngraph::Node> mul(const std::shared_ptr<ngraph::Node>& lhs,
const std::shared_ptr<ngraph::Node>& rhs) const std::shared_ptr<ngraph::Node>& rhs)
{ {
...@@ -62,16 +80,38 @@ namespace ngraph ...@@ -62,16 +80,38 @@ namespace ngraph
return {std::make_shared<ngraph::op::Multiply>(args.at(0), args.at(1))}; return {std::make_shared<ngraph::op::Multiply>(args.at(0), args.at(1))};
} }
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ACTIVATION FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ std::shared_ptr<ngraph::Node> clip(const std::shared_ptr<ngraph::Node>& data,
float threshold)
std::shared_ptr<ngraph::Node> sigmoid(const std::shared_ptr<ngraph::Node>& arg)
{ {
return std::make_shared<ngraph::op::Sigmoid>(arg); if (threshold == 0.f)
{
return data;
}
float min_val = -threshold;
float max_val = threshold;
std::size_t size = ngraph::shape_size(data->get_shape());
const std::shared_ptr<ngraph::Node> min_val_node =
ngraph::op::Constant::create(data->get_element_type(),
data->get_shape(),
std::vector<float>(size, min_val));
const std::shared_ptr<ngraph::Node> max_val_node =
ngraph::op::Constant::create(data->get_element_type(),
data->get_shape(),
std::vector<float>(size, max_val));
return std::make_shared<ngraph::op::Minimum>(
max_val_node, std::make_shared<ngraph::op::Maximum>(data, min_val_node));
} }
std::shared_ptr<ngraph::Node> tanh(const std::shared_ptr<ngraph::Node>& arg) // Modify input vector in-place and return reference to modified vector.
std::vector<std::string>& to_lower_case(std::vector<std::string>&& vs)
{ {
return std::make_shared<ngraph::op::Tanh>(arg); std::transform(std::begin(vs),
std::end(vs),
std::begin(vs),
[](std::string& s) { return ngraph::to_lower(s); });
return vs;
} }
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INPUT NODES PARSING ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INPUT NODES PARSING ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...@@ -88,22 +128,6 @@ namespace ngraph ...@@ -88,22 +128,6 @@ namespace ngraph
LSTM_INPUT_P LSTM_INPUT_P
}; };
// Maps an LSTMInput enumerator to its textual name. Any value that is not one
// of the listed enumerators yields a fixed placeholder string.
std::string to_str(const LSTMInput& in)
{
    if (in == LSTMInput::LSTM_INPUT_X)
    {
        return "X";
    }
    if (in == LSTMInput::LSTM_INPUT_W)
    {
        return "W";
    }
    if (in == LSTMInput::LSTM_INPUT_R)
    {
        return "R";
    }
    if (in == LSTMInput::LSTM_INPUT_B)
    {
        return "B";
    }
    if (in == LSTMInput::LSTM_INPUT_SEQ_LENGTHS)
    {
        return "sequence_lens";
    }
    if (in == LSTMInput::LSTM_INPUT_INIT_H)
    {
        return "initial_h";
    }
    if (in == LSTMInput::LSTM_INPUT_INIT_C)
    {
        return "initial_c";
    }
    if (in == LSTMInput::LSTM_INPUT_P)
    {
        return "P";
    }
    return "Unrecognized input value!";
}
struct LSTMNgInputMap struct LSTMNgInputMap
{ {
using container_type = std::map<LSTMInput, std::shared_ptr<ngraph::Node>>; using container_type = std::map<LSTMInput, std::shared_ptr<ngraph::Node>>;
...@@ -134,7 +158,7 @@ namespace ngraph ...@@ -134,7 +158,7 @@ namespace ngraph
// ------ Optional inputs ------ // ------ Optional inputs ------
// The bias tensor for input gate. Shape [num_directions, 8*hidden_size] // The bias tensor for input gate. Shape [num_directions, 8*hidden_size]
if (ng_inputs.size() >= 4) if (ng_inputs.size() > 3 && !ng_inputs.at(3)->is_null())
{ {
m_map[LSTMInput::LSTM_INPUT_B] = ng_inputs.at(3); m_map[LSTMInput::LSTM_INPUT_B] = ng_inputs.at(3);
} }
...@@ -146,21 +170,20 @@ namespace ngraph ...@@ -146,21 +170,20 @@ namespace ngraph
{0.f}); {0.f});
} }
// The lengths of the sequences in a batch. Shape [batch_size] // The lengths of the sequences in a batch. Shape [batch_size]
if (ng_inputs.size() >= 5) if (ng_inputs.size() > 4 && !ng_inputs.at(4)->is_null())
{ {
m_map[LSTMInput::LSTM_INPUT_SEQ_LENGTHS] = ng_inputs.at(4); m_map[LSTMInput::LSTM_INPUT_SEQ_LENGTHS] = ng_inputs.at(4);
} }
else else
{ {
m_map[LSTMInput::LSTM_INPUT_SEQ_LENGTHS] = m_map[LSTMInput::LSTM_INPUT_SEQ_LENGTHS] = ngraph::op::Constant::create(
common::make_constant_node<std::int32_t>( element::i32,
element::i32, Shape{batch_size},
{batch_size}, std::vector<std::int32_t>(
{static_cast<std::int32_t>( batch_size, m_map[LSTMInput::LSTM_INPUT_X]->get_shape().at(0)));
m_map[LSTMInput::LSTM_INPUT_X]->get_shape().at(0))});
} }
// The initial value of the hidden. Shape [num_directions, batch_size, hidden_size] // The initial value of the hidden. Shape [num_directions, batch_size, hidden_size]
if (ng_inputs.size() >= 6) if (ng_inputs.size() > 5 && !ng_inputs.at(5)->is_null())
{ {
m_map[LSTMInput::LSTM_INPUT_INIT_H] = ng_inputs.at(5); m_map[LSTMInput::LSTM_INPUT_INIT_H] = ng_inputs.at(5);
} }
...@@ -170,7 +193,7 @@ namespace ngraph ...@@ -170,7 +193,7 @@ namespace ngraph
element::f32, {num_directions, batch_size, hidden_size}, {0.f}); element::f32, {num_directions, batch_size, hidden_size}, {0.f});
} }
// The initial value of the cell. Shape [num_directions, batch_size, hidden_size] // The initial value of the cell. Shape [num_directions, batch_size, hidden_size]
if (ng_inputs.size() >= 7) if (ng_inputs.size() > 6 && !ng_inputs.at(6)->is_null())
{ {
m_map[LSTMInput::LSTM_INPUT_INIT_C] = ng_inputs.at(6); m_map[LSTMInput::LSTM_INPUT_INIT_C] = ng_inputs.at(6);
} }
...@@ -180,7 +203,7 @@ namespace ngraph ...@@ -180,7 +203,7 @@ namespace ngraph
element::f32, {num_directions, batch_size, hidden_size}, {0.f}); element::f32, {num_directions, batch_size, hidden_size}, {0.f});
} }
// The weight tensor for peepholes. Shape [num_directions, 3*hidde_size] // The weight tensor for peepholes. Shape [num_directions, 3*hidde_size]
if (ng_inputs.size() >= 8) if (ng_inputs.size() > 7 && !ng_inputs.at(7)->is_null())
{ {
m_map[LSTMInput::LSTM_INPUT_P] = ng_inputs.at(7); m_map[LSTMInput::LSTM_INPUT_P] = ng_inputs.at(7);
} }
...@@ -197,8 +220,6 @@ namespace ngraph ...@@ -197,8 +220,6 @@ namespace ngraph
{ {
return m_map.at(key); return m_map.at(key);
} }
iterator begin() { return m_map.begin(); }
iterator end() { return m_map.end(); }
container_type m_map; container_type m_map;
}; };
...@@ -208,20 +229,248 @@ namespace ngraph ...@@ -208,20 +229,248 @@ namespace ngraph
{ {
LSTM_DIRECTION_FORWARD, LSTM_DIRECTION_FORWARD,
LSTM_DIRECTION_REVERSE, LSTM_DIRECTION_REVERSE,
LSTM_DIRECTION_BIDIRECTIONAL LSTM_DIRECTION_BIDIRECTIONAL,
LSTM_DIRECTION_UNKNOWN,
}; };
// Translates a direction name into the matching LSTMDirection enumerator.
// The comparison is exact (case-sensitive); any other string maps to
// LSTM_DIRECTION_UNKNOWN.
LSTMDirection getLSTMDirection(const std::string& s)
{
    return (s == "forward")
               ? LSTMDirection::LSTM_DIRECTION_FORWARD
               : (s == "reverse")
                     ? LSTMDirection::LSTM_DIRECTION_REVERSE
                     : (s == "bidirectional") ? LSTMDirection::LSTM_DIRECTION_BIDIRECTIONAL
                                              : LSTMDirection::LSTM_DIRECTION_UNKNOWN;
}
struct LSTMAttributes struct LSTMAttributes
{ {
explicit LSTMAttributes(const Node& node) explicit LSTMAttributes(const Node& node)
: m_direction{LSTMDirection::LSTM_DIRECTION_FORWARD} : m_hidden_size{node.get_attribute_value<std::int64_t>("hidden_size")}
, m_hidden_size{node.get_attribute_value<std::int64_t>("hidden_size")} , m_clip_threshold{node.get_attribute_value<float>("clip", 0.f)}
, m_activations{to_lower_case(
node.get_attribute_value<std::vector<std::string>>(
"activations", {"sigmoid", "tanh", "tanh"}))}
, m_input_forget{static_cast<bool>(
node.get_attribute_value<std::int64_t>("input_forget", 0))}
{ {
m_clip_threshold = std::abs(m_clip_threshold);
std::string direction{ngraph::to_lower(
node.get_attribute_value<std::string>("direction", {"forward"}))};
ASSERT_VALID_ARGUMENT(node,
getLSTMDirection(direction) !=
LSTMDirection::LSTM_DIRECTION_UNKNOWN)
<< "Provided attribute \"direction\" value is incorrect: " << direction;
m_direction = getLSTMDirection(direction);
} }
// Currently only LSTM_DIRECTION_FORWARD is supported.
LSTMDirection m_direction; LSTMDirection m_direction;
std::int64_t m_hidden_size; std::int64_t m_hidden_size;
float m_clip_threshold;
std::vector<std::string> m_activations;
bool m_input_forget;
};
/// \brief Builds the nGraph sub-graph for one direction of an ONNX LSTM.
///
/// The constructor passes W, R, B, P, initial_h and initial_c through
/// reshape::squeeze (to drop their leading `num_directions` axis); X and
/// seq_lengths are stored as given.
///
/// \note `seq_lengths` is stored but never read by run(): every batch entry is
///       processed for the full sequence length.
class LSTMForward
{
public:
    /// \param[in] X              The input tensor.
    /// \param[in] W              The weight tensor.
    /// \param[in] R              The recurrence weight tensor.
    /// \param[in] B              The bias tensor.
    /// \param[in] P              The peephole weight tensor.
    /// \param[in] initial_h      The initial hidden state.
    /// \param[in] initial_c      The initial cell state.
    /// \param[in] seq_lengths    The sequence lengths (currently unused by run()).
    /// \param[in] activation_f   Activation applied to the input, forget and output gates.
    /// \param[in] activation_g   Activation applied to the cell gate.
    /// \param[in] activation_h   Activation applied to the new cell state when
    ///                           computing the hidden state.
    /// \param[in] input_forget   Couple the input and forget gates (f = 1 - i).
    /// \param[in] clip_threshold Clipping threshold handed to clip().
    explicit LSTMForward(std::shared_ptr<ngraph::Node> X,
                         std::shared_ptr<ngraph::Node> W,
                         std::shared_ptr<ngraph::Node> R,
                         std::shared_ptr<ngraph::Node> B,
                         std::shared_ptr<ngraph::Node> P,
                         std::shared_ptr<ngraph::Node> initial_h,
                         std::shared_ptr<ngraph::Node> initial_c,
                         std::shared_ptr<ngraph::Node> seq_lengths,
                         rnn::ActivationFunction activation_f,
                         rnn::ActivationFunction activation_g,
                         rnn::ActivationFunction activation_h,
                         bool input_forget = false,
                         float clip_threshold = 0.f)
        : m_X{X}
        // Since we have forward LSTM we can squeeze `num_directions` axis from inputs.
        , m_W{reshape::squeeze(W)}
        , m_R{reshape::squeeze(R)}
        , m_B{reshape::squeeze(B)}
        , m_P{reshape::squeeze(P)}
        , m_initial_h{reshape::squeeze(initial_h)}
        , m_initial_c{reshape::squeeze(initial_c)}
        , m_seq_lengths{seq_lengths}
        , m_activation_f{activation_f}
        , m_activation_g{activation_g}
        , m_activation_h{activation_h}
        , m_input_forget{input_forget}
        , m_clip_threshold{clip_threshold}
    {
    }

    /// \brief Unrolls the LSTM over the sequence axis of X.
    ///
    /// \param[in] reverse When true, X is reversed along axis 0 before
    ///                    processing and Y is reversed back afterwards.
    ///
    /// \return {Y, Y_h, Y_c}: the concatenated hidden states, the hidden state
    ///         of the last processed step and the final cell state.
    NodeVector run(bool reverse = false)
    {
        // ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------
        // The names used below are analogous to the one used in ONNX documentation.
        //
        // ------ INPUTS ------
        // X - The input tensor. [seq_length, batch_size, input_size]
        // W - The weight tensor. [num_directions, 4*hidden_size, input_size]
        // R - The recurrence weight tensor. [num_directions, 4*hidden_size, hidden_size]
        // B - The bias tensor for input gate. [num_directions, 8*hidden_size]
        // P - The weight tensor for peepholes. [num_directions, 3*hidden_size]
        // ------ ACRONYMS ------
        // i - input gate
        // o - output gate
        // f - forget gate
        // c - cell gate
        // t - time step (t-1 means previous time step)
        // ------ VARIABLE NAMES ------
        // W       - W parameter weight matrix for input, output, forget, and
        //           cell gates.
        // R       - R recurrence weight matrix for input, output, forget, and
        //           cell gates.
        // Wb      - W bias vectors for input, output, forget, and cell gates.
        // Rb      - R bias vectors for input, output, forget, and cell gates.
        // b_W_R   - Bias vectors for input, output, forget, and cell gates.
        //           Concatenation of `[Wb, Rb]`.
        // p_[iof] - P peephole weight vector for respectively: input, output,
        //           and forget gates.
        // H_t     - Hidden state vector at current time step.
        // C_t     - Cell state vector at current time step.
        // h_list  - The list of hidden states at all processed time steps.
        //
        // Xt_W    - Input sequence multiplied by weights tensor at current time
        //           step.
        // Ht_R    - Hidden state multiplied by weights tensor at current time step.

        NodeVector p_iof = reshape::split(m_P, 3);
        const auto& p_i = p_iof.at(0);
        const auto& p_o = p_iof.at(1);
        const auto& p_f = p_iof.at(2);
        NodeVector h_list;

        NodeVector b_W_R = reshape::split(m_B, 2);
        std::shared_ptr<ngraph::Node> bias = b_W_R.at(0) + b_W_R.at(1);
        std::shared_ptr<ngraph::Node> H_t = m_initial_h;
        std::shared_ptr<ngraph::Node> C_t = m_initial_c;

        // Work on a local handle so that a reversed run does not overwrite m_X;
        // otherwise a later run() on the same object would see reversed input.
        std::shared_ptr<ngraph::Node> X = m_X;
        if (reverse)
        {
            X = std::make_shared<ngraph::op::Reverse>(X, AxisSet{0});
        }

        NodeVector in_seqs{};
        if (X->get_shape().at(0) != 1)
        {
            in_seqs = reshape::split(X, X->get_shape().at(0));
        }
        else
        {
            in_seqs = NodeVector{X};
        }

        for (auto& in_x : in_seqs)
        {
            // remove first empty dim, after above split.
            in_x = reshape::squeeze(in_x);
        }

        // The transposed weights do not depend on the time step, so build them once.
        auto W_transposed = reshape::transpose(m_W);
        auto R_transposed = reshape::transpose(m_R);

        for (const auto& in_x : in_seqs)
        {
            // (.) - Denotes element-wise multiplication.
            // *   - Denotes dot product.

            // Xt*(W^T) -- for [iofc] gates.
            auto Xt_W = std::make_shared<ngraph::op::Dot>(in_x, W_transposed);
            // Ht-1*(R^T)  -- for [iofc] gates.
            auto Ht_R = std::make_shared<ngraph::op::Dot>(H_t, R_transposed);
            // Xt*(W^T) + Ht-1*(R^T) + Wb + Rb  -- for [iofc] gates.
            auto gates = add(Xt_W, add(Ht_R, bias));

            NodeVector split_gates = reshape::split(gates, 4, -1);
            auto i = split_gates.at(0);
            auto o = split_gates.at(1);
            auto f = split_gates.at(2);
            auto c = split_gates.at(3);

            // f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
            i = m_activation_f(clip(add(i, mul(p_i, C_t)), m_clip_threshold));
            if (m_input_forget)
            {
                // Couple input with forget gate: 1 - i
                f = sub(ngraph::op::Constant::create(
                            i->get_element_type(),
                            i->get_shape(),
                            std::vector<float>(shape_size(i->get_shape()), 1.f)),
                        i);
            }
            else
            {
                // f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
                f = m_activation_f(clip(add(f, mul(p_f, C_t)), m_clip_threshold));
            }
            // ft (.) Ct-1 + it (.) ct
            auto C = add(mul(f, C_t), mul(i, m_activation_g(clip(c, m_clip_threshold))));
            // f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
            o = m_activation_f(clip(add(o, mul(p_o, C)), m_clip_threshold));
            // ot (.) h(Ct)
            auto H = mul(o, m_activation_h(C));
            h_list.push_back(H);
            H_t = H;
            C_t = C;
        }
        // The tensor that concats all the intermediate output values of the hidden.
        // It has shape [seq_length, batch_size, hidden_size]
        NodeVector exp_h_list;
        for (const auto& ht : h_list)
        {
            // Expand tensors with empty outermost dim, so we can later concatenate them.
            exp_h_list.push_back(reshape::expand_dims(ht));
        }

        std::shared_ptr<ngraph::Node> Y{std::make_shared<ngraph::op::Concat>(exp_h_list, 0)};

        // Get back the original order of the output data.
        if (reverse)
        {
            Y = std::make_shared<ngraph::op::Reverse>(Y, AxisSet{0});
        }

        // Expand Y so that it has expected shape:
        // [seq_length, num_directions, batch_size, hidden_size]
        Y = reshape::expand_dims(Y, 1);

        // expand C_t so that it has expected shape:
        // [num_directions, batch_size, hidden_size]
        auto Y_c = reshape::expand_dims(C_t);

        // Y_h is the (already expanded) hidden state of the last processed step.
        return {Y, exp_h_list.back(), Y_c};
    }

private:
    std::shared_ptr<ngraph::Node> m_X;
    std::shared_ptr<ngraph::Node> m_W;
    std::shared_ptr<ngraph::Node> m_R;
    std::shared_ptr<ngraph::Node> m_B;
    std::shared_ptr<ngraph::Node> m_P;
    std::shared_ptr<ngraph::Node> m_initial_h;
    std::shared_ptr<ngraph::Node> m_initial_c;
    // Stored for completeness; not consulted by run().
    std::shared_ptr<ngraph::Node> m_seq_lengths;
    rnn::ActivationFunction m_activation_f;
    rnn::ActivationFunction m_activation_g;
    rnn::ActivationFunction m_activation_h;
    // For coupling input and forget gates.
    bool m_input_forget;
    // For clipping cell input in the range [-clip_threshold, clip_threshold].
    float m_clip_threshold;
};
} // anonymous namespace } // anonymous namespace
...@@ -233,131 +482,85 @@ namespace ngraph ...@@ -233,131 +482,85 @@ namespace ngraph
LSTMNgInputMap input_map{node}; LSTMNgInputMap input_map{node};
LSTMAttributes attributes{node}; LSTMAttributes attributes{node};
if (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_FORWARD) rnn::ActivationFunction activation_f =
{ rnn::get_activation_func_by_name(attributes.m_activations.at(0));
// Since we have forward LSTM we can squeeze `num_directions` axis from inputs. rnn::ActivationFunction activation_g =
for (auto& ng_in : input_map) rnn::get_activation_func_by_name(attributes.m_activations.at(1));
{ rnn::ActivationFunction activation_h =
if (ng_in.first != LSTMInput::LSTM_INPUT_X && rnn::get_activation_func_by_name(attributes.m_activations.at(2));
ng_in.first != LSTMInput::LSTM_INPUT_SEQ_LENGTHS)
{
ASSERT_VALID_ARGUMENT(node, ng_in.second->get_shape().at(0) == 1)
<< "Input: { " << to_str(ng_in.first)
<< " } first axis has size different "
"from 1, while direction attribute set to 'forward'.";
ng_in.second = reshape::squeeze(ng_in.second);
}
}
}
// ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------ NodeVector results;
// The names used below are analogous to the one used in ONNX documentation.
//
// ------ INPUTS ------
// X - The input tensor. [seq_length, batch_size, input_size]
// W - The weight tensor. [num_directions, 4*hidden_size, input_size]
// R - The recurrence weight tensor. [num_directions, 4*hidden_size, hidden_size]
// B - The bias tensor for input gate. [num_directions, 8*hidden_size]
// P - The weight tensor forr peepholes. [num_directions, 3*hidde_size]
// ------ ACRONYMS ------
// i - input gate
// o - output gate
// f - forget gate
// c - cell gate
// t - time step (t-1 means previous time step)
// ------ VARIABLE NAMES ------
// W - W parameter weight matrix for input, output, forget, and
// cell gates.
// R - R recurrence weight matrix for input, output, forget, and
// cell gates.
// Wb - W bias vectors for input, output, forget, and cell gates.
// Rb - R bias vectors for input, output, forget, and cell gates.
// b_W_R - Bias vectors for input, output, forget, and cell gates.
// Concatenation of `[Wb, Rb]`.
// p_[iof] - P peephole weight vector for respectively: input, output,
// and forget gates.
// H_t - Hidden state vector at current time step.
// C_t - Cell state vector at current time step.
// h_list - The list of hidden states at all processed time steps.
//
// Xt_W - Input sequence multiplied by weights tensor at current time
// step.
// Ht_R - Hidden state multiplied by weights tensor at current time step.
NodeVector p_iof = reshape::split(input_map.at(LSTMInput::LSTM_INPUT_P), 3);
const auto& p_i = p_iof.at(0);
const auto& p_o = p_iof.at(1);
const auto& p_f = p_iof.at(2);
std::shared_ptr<ngraph::Node> H_t{input_map.at(LSTMInput::LSTM_INPUT_INIT_H)};
std::shared_ptr<ngraph::Node> C_t{input_map.at(LSTMInput::LSTM_INPUT_INIT_C)};
NodeVector h_list;
NodeVector b_W_R = reshape::split(input_map.at(LSTMInput::LSTM_INPUT_B), 2);
std::shared_ptr<ngraph::Node> bias = b_W_R.at(0) + b_W_R.at(1);
NodeVector in_seqs =
reshape::split(input_map.at(LSTMInput::LSTM_INPUT_X),
input_map.at(LSTMInput::LSTM_INPUT_X)->get_shape().at(0));
for (auto& in_x : in_seqs)
{
// remove first empty dim, after above split.
in_x = reshape::squeeze(in_x);
}
for (const auto& in_x : in_seqs) if (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_FORWARD ||
attributes.m_direction == LSTMDirection::LSTM_DIRECTION_REVERSE)
{ {
// (.) - Denotes element-wise multiplication. LSTMForward lstm_fwd(input_map.at(LSTMInput::LSTM_INPUT_X),
// * - Denotes dot product. input_map.at(LSTMInput::LSTM_INPUT_W),
input_map.at(LSTMInput::LSTM_INPUT_R),
// Xt*(W^T) -- for [iofc] gates. input_map.at(LSTMInput::LSTM_INPUT_B),
auto Xt_W = std::make_shared<ngraph::op::Dot>( input_map.at(LSTMInput::LSTM_INPUT_P),
in_x, reshape::transpose(input_map.at(LSTMInput::LSTM_INPUT_W))); input_map.at(LSTMInput::LSTM_INPUT_INIT_H),
// Ht-1*(R^T) -- for [iofc] gates. input_map.at(LSTMInput::LSTM_INPUT_INIT_C),
auto Ht_R = std::make_shared<ngraph::op::Dot>( input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
H_t, reshape::transpose(input_map.at(LSTMInput::LSTM_INPUT_R))); activation_f,
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb -- for [iofc] gates. activation_g,
auto gates = add(Xt_W, add(Ht_R, bias)); activation_h,
attributes.m_input_forget,
NodeVector split_gates = reshape::split(gates, 4, -1); attributes.m_clip_threshold);
auto i = split_gates.at(0); results = lstm_fwd.run(
auto o = split_gates.at(1); (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_REVERSE));
auto f = split_gates.at(2);
auto c = split_gates.at(3);
// f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
i = sigmoid(add(i, mul(p_i, C_t)));
// f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
f = sigmoid(add(f, mul(p_f, C_t)));
// ft (.) Ct-1 + it (.) ct
auto C = add(mul(f, C_t), mul(i, tanh(c)));
// f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
o = sigmoid(add(o, mul(p_o, C)));
// ot (.) h(Ct)
auto H = mul(o, tanh(C));
h_list.push_back(H);
H_t = H;
C_t = C;
} }
// The tensor that concats all the intermediate output values of the hidden. if (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_BIDIRECTIONAL)
// It has shape [seq_length, batch_size, hidden_size]
NodeVector exp_h_list;
for (const auto& ht : h_list)
{ {
// Expand tensors with empty outermost dim, so we can later concatenate them. // In bidirectional mode weights are stacked together, so we must split them.
exp_h_list.push_back(reshape::add_empty_axes(ht)); NodeVector W{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_W), 2)};
NodeVector R{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_R), 2)};
NodeVector B{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_B), 2)};
NodeVector P{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_P), 2)};
NodeVector H{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_INIT_H), 2)};
NodeVector C{reshape::split(input_map.at(LSTMInput::LSTM_INPUT_INIT_C), 2)};
LSTMForward lstm_fwd(input_map.at(LSTMInput::LSTM_INPUT_X),
W.at(0),
R.at(0),
B.at(0),
P.at(0),
H.at(0),
C.at(0),
input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
activation_f,
activation_g,
activation_h,
attributes.m_input_forget,
attributes.m_clip_threshold);
LSTMForward lstm_reversed(input_map.at(LSTMInput::LSTM_INPUT_X),
W.at(1),
R.at(1),
B.at(1),
P.at(1),
H.at(1),
C.at(1),
input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
activation_f,
activation_g,
activation_h,
attributes.m_input_forget,
attributes.m_clip_threshold);
NodeVector fwd_results{lstm_fwd.run()};
// The reverse pass must run on `lstm_reversed`, which was built from the second
// (reverse-direction) slices of W/R/B/P/H/C. Calling `lstm_fwd.run(true)` here
// would reuse the forward-direction weights and leave `lstm_reversed` unused.
NodeVector rev_results{lstm_reversed.run(true)};
// Stack together respective outputs from both forward and reverse passes.
std::shared_ptr<ngraph::Node> Y{std::make_shared<ngraph::op::Concat>(
NodeVector{fwd_results.at(0), rev_results.at(0)}, 1)};
std::shared_ptr<ngraph::Node> Y_h{std::make_shared<ngraph::op::Concat>(
NodeVector{fwd_results.at(1), rev_results.at(1)}, 0)};
std::shared_ptr<ngraph::Node> Y_c{std::make_shared<ngraph::op::Concat>(
NodeVector{fwd_results.at(2), rev_results.at(2)}, 0)};
results = NodeVector{Y, Y_h, Y_c};
} }
std::shared_ptr<ngraph::Node> Y{
std::make_shared<ngraph::op::Concat>(exp_h_list, 0)};
// Expand Y so that it has expected shape: return results;
// [seq_length, num_directions, batch_size, hidden_size]
if (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_FORWARD)
{
Shape shape{Y->get_shape()};
shape.insert(std::next(std::begin(shape)), 1);
Y = std::make_shared<ngraph::op::Reshape>(
Y, reshape::get_default_axis_vector(Y->get_shape().size()), shape);
}
return {Y, exp_h_list.back()};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -138,7 +138,7 @@ namespace ngraph ...@@ -138,7 +138,7 @@ namespace ngraph
// Expand sub_dot result with single empty outermost axis, in order to // Expand sub_dot result with single empty outermost axis, in order to
// later concatenate sub_dots at this axis. // later concatenate sub_dots at this axis.
small_dots.at(g) = reshape::add_empty_axes(sub_dot); small_dots.at(g) = reshape::expand_dims(sub_dot);
} }
// Concatenate sub_dots on groups axis. // Concatenate sub_dots on groups axis.
......
...@@ -112,7 +112,7 @@ opset versions starting from `1` to `6` and to the latest opset version. ...@@ -112,7 +112,7 @@ opset versions starting from `1` to `6` and to the latest opset version.
|------|-----------------|--------|--------|---------| |------|-----------------|--------|--------|---------|
| Erf | (9) | 284 | 442 | Need separate kernel for this in nGraph core. | | Erf | (9) | 284 | 442 | Need separate kernel for this in nGraph core. |
| Pad | 1-2- | 273 | 416 | Not fully supported. | | Pad | 1-2- | 273 | 416 | Not fully supported. |
| LSTM | 1-7- | | 430 | Not fully supported. | | LSTM | 1-7- | | 476 | Mixed sequences length not supported yet. |
| MaxUnpool | (9) | 286, 289 | 447 | | | MaxUnpool | (9) | 286, 289 | 447 | |
| LpPool | - | 291 | 437 | Unsupported by nGraph - only max/avg pooling ops. Need separate kernel. | | LpPool | - | 291 | 437 | Unsupported by nGraph - only max/avg pooling ops. Need separate kernel. |
| Multinomial | - | 199 | 435 | Lack of PRNG in nGraph. | | Multinomial | - | 199 | 435 | Lack of PRNG in nGraph. |
......
...@@ -221,17 +221,14 @@ namespace ngraph ...@@ -221,17 +221,14 @@ namespace ngraph
node, get_default_axis_vector(node->get_shape().size()), shape); node, get_default_axis_vector(node->get_shape().size()), shape);
} }
std::shared_ptr<ngraph::Node> add_empty_axes(const std::shared_ptr<ngraph::Node>& node, std::shared_ptr<ngraph::Node> expand_dims(const std::shared_ptr<ngraph::Node>& node,
std::size_t outermost_axes_count, std::size_t axis)
std::size_t innermost_axes_count)
{ {
// Add outermost empty dimensions. Shape output_shape(node->get_shape());
Shape output_shape(outermost_axes_count, 1); // Add empty axis at specified position.
output_shape.insert(std::end(output_shape), auto empty_axis_it = std::begin(output_shape);
std::begin(node->get_shape()), std::advance(empty_axis_it, axis);
std::end(node->get_shape())); output_shape.insert(empty_axis_it, 1);
// Add innermost empty dimensions.
output_shape.insert(std::end(output_shape), innermost_axes_count, 1);
return std::make_shared<ngraph::op::Reshape>( return std::make_shared<ngraph::op::Reshape>(
node, reshape::get_default_axis_vector(node->get_shape().size()), output_shape); node, reshape::get_default_axis_vector(node->get_shape().size()), output_shape);
} }
......
...@@ -127,19 +127,17 @@ namespace ngraph ...@@ -127,19 +127,17 @@ namespace ngraph
return reshape(node, get_default_axis_vector(node->get_shape().size()), shape); return reshape(node, get_default_axis_vector(node->get_shape().size()), shape);
} }
/// \brief Expands node tensor shape with empty axes. /// \brief Expands node tensor shape with empty axis at
/// specified position.
/// ///
/// \param[in] node The node to be expanded. /// \param[in] node The node to be expanded.
/// \param[in] outermost_axes_count The number of added outermost axes. /// \param[in] axis The position in the expanded axes where the
/// At the front of the shape. /// new axis is placed.
/// \param[in] innermost_axes_count The number of added innermost axes.
/// At the end of the shape.
/// ///
/// \return The node with added empty axes. /// \return The node with added empty axis.
/// ///
std::shared_ptr<ngraph::Node> add_empty_axes(const std::shared_ptr<ngraph::Node>& node, std::shared_ptr<ngraph::Node> expand_dims(const std::shared_ptr<ngraph::Node>& node,
std::size_t outermost_axes_count = 1, std::size_t axis = 0);
std::size_t innermost_axes_count = 0);
/// \brief Split node on specified axis into multiple parts. /// \brief Split node on specified axis into multiple parts.
/// ///
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <functional>
#include <iterator>
#include <unordered_map>
#include "activation_functions.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/tanh.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace rnn
{
namespace detail
{
std::shared_ptr<ngraph::Node> sigmoid(const std::shared_ptr<ngraph::Node>& arg)
{
return std::make_shared<ngraph::op::Sigmoid>(arg);
}
std::shared_ptr<ngraph::Node> tanh(const std::shared_ptr<ngraph::Node>& arg)
{
return std::make_shared<ngraph::op::Tanh>(arg);
}
std::shared_ptr<ngraph::Node> relu(const std::shared_ptr<ngraph::Node>& arg)
{
return std::make_shared<ngraph::op::Relu>(arg);
}
} // namespace detail
// Looks up the activation function registered under `func_name`.
// The match is exact (case-sensitive); known names are "sigmoid", "tanh" and
// "relu". Throws error::UnknownActivationFunction for any other name.
ActivationFunction get_activation_func_by_name(const std::string& func_name)
{
    using ActivationFunctionMap = std::unordered_map<std::string, ActivationFunction>;

    // Read-only lookup table, built once on first call. The plain functions
    // convert to ActivationFunction directly, so no std::bind wrapper is needed.
    static const ActivationFunctionMap func_map{{"sigmoid", &detail::sigmoid},
                                                {"tanh", &detail::tanh},
                                                {"relu", &detail::relu}};

    const auto func_it = func_map.find(func_name);
    if (func_it == std::end(func_map))
    {
        throw error::UnknownActivationFunction(func_name);
    }
    return func_it->second;
}
} //namespace rnn
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <functional>
#include <memory>
#include <string>
#include "ngraph/except.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace rnn
{
namespace error
{
    /// \brief Thrown when an activation function is requested by a name that
    ///        has no registered implementation.
    struct UnknownActivationFunction : ngraph_error
    {
        /// \param[in] func_name The unrecognized name; appended to the error message.
        ///
        /// Marked explicit so a std::string never converts to this exception
        /// type implicitly.
        explicit UnknownActivationFunction(const std::string& func_name)
            : ngraph_error{"Unknown activation function: " + func_name}
        {
        }
    };
}
// Builders for the individual activation functions. Each takes the input node
// and returns a new node applying the activation to it.
namespace detail
{
// Wraps `arg` in an op::Sigmoid node.
std::shared_ptr<ngraph::Node> sigmoid(const std::shared_ptr<ngraph::Node>& arg);
// Wraps `arg` in an op::Tanh node.
std::shared_ptr<ngraph::Node> tanh(const std::shared_ptr<ngraph::Node>& arg);
// Wraps `arg` in an op::Relu node.
std::shared_ptr<ngraph::Node> relu(const std::shared_ptr<ngraph::Node>& arg);
}
using ActivationFunction =
std::function<std::shared_ptr<ngraph::Node>(const std::shared_ptr<ngraph::Node>&)>;
/// \brief Gets the activation function by name.
///
/// \param[in] func_name The function name
///
/// \throws UnknownActivationFunction When provided func_name is unknown.
///
/// \return The activation function object.
///
ActivationFunction get_activation_func_by_name(const std::string& func_name);
} //namespace rnn
} // namespace onnx_import
} // namespace ngraph
...@@ -1864,6 +1864,91 @@ TEST(onnx_${BACKEND_NAME}, model_top_k) ...@@ -1864,6 +1864,91 @@ TEST(onnx_${BACKEND_NAME}, model_top_k)
EXPECT_TRUE(test::all_close(expected_indices_output, indices_output)); EXPECT_TRUE(test::all_close(expected_indices_output, indices_output));
} }
// Imports the forward LSTM model with clipping, feeds it fixed X/W/R/B/P inputs
// and compares the three outputs (Y, Y_h, Y_c) against reference values.
TEST(onnx_${BACKEND_NAME}, model_lstm_fwd_with_clip)
{
    auto function = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/lstm_fwd_with_clip.onnx"));

    Inputs inputs{};
    // X
    inputs.emplace_back(std::vector<float>{-0.455351f, -0.276391f, -0.185934f, -0.269585f});

    // W
    inputs.emplace_back(std::vector<float>{-0.494659f,
                                           0.0453352f,
                                           -0.487793f,
                                           0.417264f,
                                           -0.0175329f,
                                           0.489074f,
                                           -0.446013f,
                                           0.414029f,
                                           -0.0091708f,
                                           -0.255364f,
                                           -0.106952f,
                                           -0.266717f,
                                           -0.0888852f,
                                           -0.428709f,
                                           -0.283349f,
                                           0.208792f});

    // R
    inputs.emplace_back(std::vector<float>{0.146626f,
                                           -0.0620289f,
                                           -0.0815302f,
                                           0.100482f,
                                           -0.219535f,
                                           -0.306635f,
                                           -0.28515f,
                                           -0.314112f,
                                           -0.228172f,
                                           0.405972f,
                                           0.31576f,
                                           0.281487f,
                                           -0.394864f,
                                           0.42111f,
                                           -0.386624f,
                                           -0.390225f});

    // B
    inputs.emplace_back(std::vector<float>{0.381619f,
                                           0.0323954f,
                                           -0.14449f,
                                           0.420804f,
                                           -0.258721f,
                                           0.45056f,
                                           -0.250755f,
                                           0.0967895f,
                                           0.0f,
                                           0.0f,
                                           0.0f,
                                           0.0f,
                                           0.0f,
                                           0.0f,
                                           0.0f,
                                           0.0f});

    // P
    inputs.emplace_back(std::vector<float>{0.2345f, 0.5235f, 0.4378f, 0.3475f, 0.8927f, 0.3456f});

    Outputs expected_output{};
    // Y_data
    expected_output.emplace_back(
        std::vector<float>{-0.02280854f, 0.02744377f, -0.03516197f, 0.03875681f});
    // Y_h_data
    expected_output.emplace_back(std::vector<float>{-0.03516197f, 0.03875681f});
    // Y_c_data
    expected_output.emplace_back(std::vector<float>{-0.07415761f, 0.07395997f});

    Outputs outputs{execute(function, inputs, "${BACKEND_NAME}")};

    // Stop here on a size mismatch: the loop below indexes `outputs` with
    // `.at(i)`, which would throw if fewer outputs were produced.
    ASSERT_EQ(outputs.size(), expected_output.size());
    for (std::size_t i{0}; i < expected_output.size(); ++i)
    {
        // We have to enlarge tolerance bits to 3 - it's only one bit more than default value.
        // The discrepancies may occur at most on 7th decimal position.
        EXPECT_TRUE(test::all_close_f(expected_output.at(i), outputs.at(i), 3));
    }
}
TEST(onnx_${BACKEND_NAME}, model_missing_input) TEST(onnx_${BACKEND_NAME}, model_missing_input)
{ {
onnx_import::register_operator( onnx_import::register_operator(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment