Commit 6b528fb8 authored by Adam Rogowiec, committed by Scott Cyphers

[Fused Op] LSTMCell (#2966)

* Move split utility functions into core builder.

* Move activation functions to nGraph core.

* RNN cell base class.

* LSTM cell fused operator.

* Update LSTM ONNX operator to use LSTMCell fused op.

* Use Constant::create instead of make_constant.

* Remove ngraph:: prefixes and include standard headers.

* Store member shared_ptrs as object.

* Formatting.

* Run validation at the end of constructor.

* Add more doc to ActivationFunction.

* Run FusedOpDecomposition pass two times in interpreter backend.

* Remove unnecessary class member.

* Add node validation.

* Disambiguate constructors.

* Add type property test.

* Formatting and add comment with equations.

* Update IGPU backend with LSTMCell fused op.

* Fix: clip activation function input.

* Unit tests.

* Workaround for nested fused op: run FusedOpDecomposition twice.

* Fix compilation on CentOS and on GPU.

* PR feedback.

* Fix CentOS bugs.

* Address review comments.

Remove stored inputs as class members. Use node inputs directly in
decomposition.

* Fix errors.

* Review feedback: don't use decompose_op while generating Function in UTs.

* Fix merge artifacts.

* Move RNNCellBase to op/util directory.

* Fix typo for avg_pool setter method.

* Set default values for optional inputs.

* Fix typo in comment.
parent 746a90e2
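
For orientation, here is a minimal usage sketch (not part of this commit's diff) of the new fused op; it mirrors the lstm_cell_no_bias_no_peepholes unit test added below, and the include paths and helper name are assumptions based on the files touched in this diff.

// Minimal usage sketch of the new op::LSTMCell fused op (illustrative only).
#include <memory>

#include "ngraph/function.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/parameter.hpp"

using namespace ngraph;

std::shared_ptr<Function> make_lstm_cell_function()
{
    const size_t batch_size = 2;
    const size_t input_size = 3;
    const size_t hidden_size = 3;
    const size_t gates_count = 4;

    auto X = std::make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
    auto W = std::make_shared<op::Parameter>(element::f32,
                                             Shape{gates_count * hidden_size, input_size});
    auto R = std::make_shared<op::Parameter>(element::f32,
                                             Shape{gates_count * hidden_size, hidden_size});
    auto H_t = std::make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
    auto C_t = std::make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});

    // B and P are filled with zero constants inside the op when this constructor is used.
    auto lstm_cell = std::make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);

    // Output 0 is the next hidden state H_t, output 1 is the next cell state C_t.
    auto H = std::make_shared<op::GetOutputElement>(lstm_cell, 0);
    return std::make_shared<Function>(H, ParameterVector{X, W, R, H_t, C_t});
}
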
......@@ -312,6 +312,8 @@ set (SRC
op/fused/group_conv_transpose.cpp
op/fused/leaky_relu.cpp
op/fused/leaky_relu.hpp
op/fused/lstm_cell.cpp
op/fused/lstm_cell.hpp
op/fused/mvn.cpp
op/fused/mvn.hpp
op/fused/normalize.cpp
......@@ -332,6 +334,8 @@ set (SRC
op/fused/squeeze.hpp
op/fused/unsqueeze.cpp
op/fused/unsqueeze.hpp
op/util/activation_functions.cpp
op/util/activation_functions.hpp
op/util/arithmetic_reduction.cpp
op/util/arithmetic_reduction.hpp
op/util/binary_elementwise_arithmetic.cpp
......@@ -349,6 +353,8 @@ set (SRC
op/util/logical_reduction.cpp
op/util/logical_reduction.hpp
op/util/reshape.hpp
op/util/rnn_cell_base.cpp
op/util/rnn_cell_base.hpp
op/util/unary_elementwise_arithmetic.cpp
op/util/unary_elementwise_arithmetic.hpp
partial_shape.cpp
......
......@@ -198,8 +198,6 @@ add_library(onnx_import STATIC
utils/reduction.hpp
utils/reshape.cpp
utils/reshape.hpp
utils/rnn/activation_functions.cpp
utils/rnn/activation_functions.hpp
utils/variadic.hpp)
set(ONNX_IMPORT_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR} CACHE INTERNAL "")
......
......@@ -14,46 +14,31 @@
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "core/null_node.hpp"
#include "exceptions.hpp"
#include "lstm.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
#include "utils/reshape.hpp"
#include "utils/rnn/activation_functions.hpp"
namespace ngraph
{
......@@ -63,61 +48,6 @@ namespace ngraph
{
namespace
{
std::shared_ptr<ngraph::Node> add(const std::shared_ptr<ngraph::Node>& lhs,
const std::shared_ptr<ngraph::Node>& rhs)
{
auto args = ngraph::op::numpy_style_broadcast({lhs, rhs});
return {std::make_shared<ngraph::op::Add>(args.at(0), args.at(1))};
}
std::shared_ptr<ngraph::Node> sub(const std::shared_ptr<ngraph::Node>& lhs,
const std::shared_ptr<ngraph::Node>& rhs)
{
auto args = ngraph::op::numpy_style_broadcast({lhs, rhs});
return {std::make_shared<ngraph::op::Subtract>(args.at(0), args.at(1))};
}
std::shared_ptr<ngraph::Node> mul(const std::shared_ptr<ngraph::Node>& lhs,
const std::shared_ptr<ngraph::Node>& rhs)
{
auto args = ngraph::op::numpy_style_broadcast({lhs, rhs});
return {std::make_shared<ngraph::op::Multiply>(args.at(0), args.at(1))};
}
std::shared_ptr<ngraph::Node> clip(const std::shared_ptr<ngraph::Node>& data,
float threshold)
{
if (threshold == 0.f)
{
return data;
}
float min_val = -threshold;
float max_val = threshold;
std::size_t size = ngraph::shape_size(data->get_shape());
const std::shared_ptr<ngraph::Node> min_val_node =
ngraph::op::Constant::create(data->get_element_type(),
data->get_shape(),
std::vector<float>(size, min_val));
const std::shared_ptr<ngraph::Node> max_val_node =
ngraph::op::Constant::create(data->get_element_type(),
data->get_shape(),
std::vector<float>(size, max_val));
return std::make_shared<ngraph::op::Minimum>(
max_val_node, std::make_shared<ngraph::op::Maximum>(data, min_val_node));
}
// Modify input vector in-place and return reference to modified vector.
std::vector<std::string>& to_lower_case(std::vector<std::string>&& vs)
{
std::transform(std::begin(vs),
std::end(vs),
std::begin(vs),
[](std::string& s) { return ngraph::to_lower(s); });
return vs;
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INPUT NODES PARSING ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
enum class LSTMInput
......@@ -168,8 +98,11 @@ namespace ngraph
}
else
{
m_map[LSTMInput::LSTM_INPUT_B] = ngraph::builder::make_constant<float>(
element::f32, {num_directions, 2 * gates_count * hidden_size}, 0.f);
m_map[LSTMInput::LSTM_INPUT_B] = ngraph::op::Constant::create(
element::f32,
Shape{num_directions, 2 * gates_count * hidden_size},
std::vector<float>(num_directions * 2 * gates_count * hidden_size,
0.f));
}
// The lengths of the sequences in a batch. Shape [batch_size]
if (ng_inputs.size() > 4 && !ng_inputs.at(4)->is_null())
......@@ -191,9 +124,10 @@ namespace ngraph
}
else
{
m_map[LSTMInput::LSTM_INPUT_INIT_H] =
ngraph::builder::make_constant<float>(
element::f32, {num_directions, batch_size, hidden_size}, 0.f);
m_map[LSTMInput::LSTM_INPUT_INIT_H] = ngraph::op::Constant::create(
element::f32,
Shape{num_directions, batch_size, hidden_size},
std::vector<float>(num_directions * batch_size * hidden_size, 0.f));
}
// The initial value of the cell. Shape [num_directions, batch_size, hidden_size]
if (ng_inputs.size() > 6 && !ng_inputs.at(6)->is_null())
......@@ -202,9 +136,10 @@ namespace ngraph
}
else
{
m_map[LSTMInput::LSTM_INPUT_INIT_C] =
ngraph::builder::make_constant<float>(
element::f32, {num_directions, batch_size, hidden_size}, 0.f);
m_map[LSTMInput::LSTM_INPUT_INIT_C] = ngraph::op::Constant::create(
element::f32,
Shape{num_directions, batch_size, hidden_size},
std::vector<float>(num_directions * batch_size * hidden_size, 0.f));
}
// The weight tensor for peepholes. Shape [num_directions, 3*hidden_size]
if (ng_inputs.size() > 7 && !ng_inputs.at(7)->is_null())
......@@ -213,8 +148,11 @@ namespace ngraph
}
else
{
m_map[LSTMInput::LSTM_INPUT_P] = ngraph::builder::make_constant<float>(
element::f32, {num_directions, peepholes_count * hidden_size}, 0.f);
m_map[LSTMInput::LSTM_INPUT_P] = ngraph::op::Constant::create(
element::f32,
Shape{num_directions, peepholes_count * hidden_size},
std::vector<float>(num_directions * peepholes_count * hidden_size,
0.f));
}
}
......@@ -257,9 +195,8 @@ namespace ngraph
explicit LSTMAttributes(const Node& node)
: m_hidden_size{node.get_attribute_value<std::int64_t>("hidden_size")}
, m_clip_threshold{node.get_attribute_value<float>("clip", 0.f)}
, m_activations{to_lower_case(
node.get_attribute_value<std::vector<std::string>>(
"activations", {"sigmoid", "tanh", "tanh"}))}
, m_activations{node.get_attribute_value<std::vector<std::string>>(
"activations", {"sigmoid", "tanh", "tanh"})}
// Default values for activation functions are the same as for the corresponding
// ONNX operator.
, m_activation_alpha{node.get_attribute_value<std::vector<float>>(
......@@ -292,33 +229,25 @@ namespace ngraph
class LSTMForward
{
public:
explicit LSTMForward(std::shared_ptr<ngraph::Node> X,
std::shared_ptr<ngraph::Node> W,
std::shared_ptr<ngraph::Node> R,
std::shared_ptr<ngraph::Node> B,
std::shared_ptr<ngraph::Node> P,
std::shared_ptr<ngraph::Node> initial_h,
std::shared_ptr<ngraph::Node> initial_c,
std::shared_ptr<ngraph::Node> seq_lengths,
rnn::ActivationFunction activation_f,
rnn::ActivationFunction activation_g,
rnn::ActivationFunction activation_h,
bool input_forget = false,
float clip_threshold = 0.f)
explicit LSTMForward(const std::shared_ptr<ngraph::Node>& X,
const std::shared_ptr<ngraph::Node>& W,
const std::shared_ptr<ngraph::Node>& R,
const std::shared_ptr<ngraph::Node>& B,
const std::shared_ptr<ngraph::Node>& P,
const std::shared_ptr<ngraph::Node>& initial_h,
const std::shared_ptr<ngraph::Node>& initial_c,
const std::shared_ptr<ngraph::Node>& seq_lengths,
const LSTMAttributes& attributes)
: m_X{X}
// Since we have a forward LSTM we can squeeze the `num_directions` axis from the inputs.
, m_W{reshape::squeeze(W)}
, m_R{reshape::squeeze(R)}
, m_B{reshape::squeeze(B)}
, m_P{reshape::squeeze(P)}
, m_initial_h{reshape::squeeze(initial_h)}
, m_initial_c{reshape::squeeze(initial_c)}
, m_seq_lengths{seq_lengths}
, m_activation_f{activation_f}
, m_activation_g{activation_g}
, m_activation_h{activation_h}
, m_input_forget{input_forget}
, m_clip_threshold{clip_threshold}
, m_W(reshape::squeeze(W))
, m_R(reshape::squeeze(R))
, m_B(reshape::squeeze(B))
, m_P(reshape::squeeze(P))
, m_initial_h(reshape::squeeze(initial_h))
, m_initial_c(reshape::squeeze(initial_c))
, m_seq_lengths(seq_lengths)
, m_attributes(attributes)
{
}
......@@ -332,7 +261,7 @@ namespace ngraph
// W - The weight tensor. [num_directions, 4*hidden_size, input_size]
// R - The recurrence weight tensor. [num_directions, 4*hidden_size, hidden_size]
// B - The bias tensor for input gate. [num_directions, 8*hidden_size]
// P - The weight tensor forr peepholes. [num_directions, 3*hidde_size]
// P - The weight tensor for peepholes. [num_directions, 3*hidden_size]
// ------ ACRONYMS ------
// i - input gate
// o - output gate
......@@ -340,32 +269,11 @@ namespace ngraph
// c - cell gate
// t - time step (t-1 means previous time step)
// ------ VARIABLE NAMES ------
// W - W parameter weight matrix for input, output, forget, and
// cell gates.
// R - R recurrence weight matrix for input, output, forget, and
// cell gates.
// Wb - W bias vectors for input, output, forget, and cell gates.
// Rb - R bias vectors for input, output, forget, and cell gates.
// b_W_R - Bias vectors for input, output, forget, and cell gates.
// Concatenation of `[Wb, Rb]`.
// p_[iof] - P peephole weight vector for respectively: input, output,
// and forget gates.
// H_t - Hidden state vector at current time step.
// C_t - Cell state vector at current time step.
// h_list - The list of hidden states at all processed time steps.
//
// Xt_W - Input sequence multiplied by weights tensor at current time
// step.
// Ht_R - Hidden state multiplied by weights tensor at current time step.
NodeVector p_iof = ngraph::builder::split(m_P, 3);
const auto& p_i = p_iof.at(0);
const auto& p_o = p_iof.at(1);
const auto& p_f = p_iof.at(2);
NodeVector h_list;
NodeVector b_W_R = ngraph::builder::split(m_B, 2);
std::shared_ptr<ngraph::Node> bias = b_W_R.at(0) + b_W_R.at(1);
NodeVector h_list;
std::shared_ptr<ngraph::Node> H_t = m_initial_h;
std::shared_ptr<ngraph::Node> C_t = m_initial_c;
......@@ -393,47 +301,24 @@ namespace ngraph
std::int32_t time_step{1};
for (const auto& in_x : in_seqs)
{
// (.) - Denotes element-wise multiplication.
// * - Denotes dot product.
// Xt*(W^T) -- for [iofc] gates.
auto Xt_W = std::make_shared<ngraph::op::Dot>(
in_x, ngraph::builder::transpose(m_W));
// Ht-1*(R^T) -- for [iofc] gates.
auto Ht_R = std::make_shared<ngraph::op::Dot>(
H_t, ngraph::builder::transpose(m_R));
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb -- for [iofc] gates.
auto gates = add(Xt_W, add(Ht_R, bias));
NodeVector split_gates = ngraph::builder::split(gates, 4, -1);
auto i = split_gates.at(0);
auto o = split_gates.at(1);
auto f = split_gates.at(2);
auto c = split_gates.at(3);
// f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
i = m_activation_f(clip(add(i, mul(p_i, C_t)), m_clip_threshold));
if (m_input_forget)
{
// Couple input with forget gate: 1 - i
f = sub(ngraph::op::Constant::create(
i->get_element_type(),
i->get_shape(),
std::vector<float>(shape_size(i->get_shape()), 1.f)),
i);
}
else
{
// f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
f = m_activation_f(clip(add(f, mul(p_f, C_t)), m_clip_threshold));
}
// ft (.) Ct-1 + it (.) ct
auto C =
add(mul(f, C_t), mul(i, m_activation_g(clip(c, m_clip_threshold))));
// f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
o = m_activation_f(clip(add(o, mul(p_o, C)), m_clip_threshold));
// ot (.) h(Ct)
auto H = mul(o, m_activation_h(C));
std::shared_ptr<ngraph::Node> lstm_cell =
std::make_shared<ngraph::op::LSTMCell>(
in_x,
m_W,
m_R,
H_t,
C_t,
m_attributes.m_hidden_size,
m_B,
m_P,
m_attributes.m_activations,
m_attributes.m_activation_alpha,
m_attributes.m_activation_beta,
m_attributes.m_clip_threshold,
m_attributes.m_input_forget);
std::shared_ptr<ngraph::Node> H = get_output_element(lstm_cell, 0);
std::shared_ptr<ngraph::Node> C = get_output_element(lstm_cell, 1);
// Expand tensors with empty outermost dim, so we can later concatenate
// them.
......@@ -535,34 +420,9 @@ namespace ngraph
std::shared_ptr<ngraph::Node> m_initial_h;
std::shared_ptr<ngraph::Node> m_initial_c;
std::shared_ptr<ngraph::Node> m_seq_lengths;
rnn::ActivationFunction m_activation_f;
rnn::ActivationFunction m_activation_g;
rnn::ActivationFunction m_activation_h;
// For coupling input and forget gates.
bool m_input_forget;
// For clipping cell input in the range [-clip_threshold, clip_threshold].
float m_clip_threshold;
const LSTMAttributes& m_attributes;
};
rnn::ActivationFunction get_activation_function(const LSTMAttributes& attributes,
std::size_t idx)
{
rnn::ActivationFunction afunc =
rnn::get_activation_func_by_name(attributes.m_activations.at(idx));
// Set activation functions parameters (if any)
if (attributes.m_activation_alpha.size() > idx)
{
afunc.set_alpha(attributes.m_activation_alpha.at(idx));
}
if (attributes.m_activation_beta.size() > idx)
{
afunc.set_beta(attributes.m_activation_beta.at(idx));
}
return afunc;
}
} // anonymous namespace
namespace set_1
......@@ -572,14 +432,6 @@ namespace ngraph
LSTMNgInputMap input_map{node};
LSTMAttributes attributes{node};
// Get activation functions.
const rnn::ActivationFunction& activation_f =
get_activation_function(attributes, 0);
const rnn::ActivationFunction& activation_g =
get_activation_function(attributes, 1);
const rnn::ActivationFunction& activation_h =
get_activation_function(attributes, 2);
NodeVector results;
if (attributes.m_direction == LSTMDirection::LSTM_DIRECTION_FORWARD ||
......@@ -593,11 +445,7 @@ namespace ngraph
input_map.at(LSTMInput::LSTM_INPUT_INIT_H),
input_map.at(LSTMInput::LSTM_INPUT_INIT_C),
input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
activation_f,
activation_g,
activation_h,
attributes.m_input_forget,
attributes.m_clip_threshold);
attributes);
results = lstm_fwd.run(
(attributes.m_direction == LSTMDirection::LSTM_DIRECTION_REVERSE));
}
......@@ -625,11 +473,7 @@ namespace ngraph
H.at(0),
C.at(0),
input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
activation_f,
activation_g,
activation_h,
attributes.m_input_forget,
attributes.m_clip_threshold);
attributes);
LSTMForward lstm_reversed(input_map.at(LSTMInput::LSTM_INPUT_X),
W.at(1),
R.at(1),
......@@ -638,11 +482,7 @@ namespace ngraph
H.at(1),
C.at(1),
input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
activation_f,
activation_g,
activation_h,
attributes.m_input_forget,
attributes.m_clip_threshold);
attributes);
NodeVector fwd_results{lstm_fwd.run()};
NodeVector rev_results{lstm_fwd.run(true)};
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cmath>
#include <functional>
#include <iterator>
#include <unordered_map>
#include "activation_functions.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/tanh.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace rnn
{
namespace detail
{
std::shared_ptr<ngraph::Node>
sigmoid(const std::shared_ptr<ngraph::Node>& arg, float alpha, float beta)
{
return std::make_shared<ngraph::op::Sigmoid>(arg);
}
std::shared_ptr<ngraph::Node>
tanh(const std::shared_ptr<ngraph::Node>& arg, float alpha, float beta)
{
return std::make_shared<ngraph::op::Tanh>(arg);
}
std::shared_ptr<ngraph::Node>
relu(const std::shared_ptr<ngraph::Node>& arg, float alpha, float beta)
{
return std::make_shared<ngraph::op::Relu>(arg);
}
std::shared_ptr<ngraph::Node>
hardsigmoid(const std::shared_ptr<ngraph::Node>& arg, float alpha, float beta)
{
return std::make_shared<ngraph::op::HardSigmoid>(arg, alpha, beta);
}
} // namespace detail
ActivationFunction::ActivationFunction(ActivationFunctionType f,
float alpha,
float beta)
: m_function{f}
, m_alpha{alpha}
, m_beta{beta}
{
}
ActivationFunction::ActivationFunction(ActivationFunctionType f, float alpha)
: ActivationFunction(f, alpha, std::nanf(""))
{
}
ActivationFunction::ActivationFunction(ActivationFunctionType f)
: ActivationFunction(f, std::nanf(""), std::nanf(""))
{
}
std::shared_ptr<ngraph::Node> ActivationFunction::
operator()(const std::shared_ptr<ngraph::Node>& arg) const
{
return m_function(arg, m_alpha, m_beta);
}
ActivationFunction get_activation_func_by_name(const std::string& func_name)
{
using ActivationFunctionMap = std::unordered_map<std::string, ActivationFunction>;
using namespace std::placeholders;
static ActivationFunctionMap func_map{
{"sigmoid", ActivationFunction{detail::sigmoid}},
{"tanh", ActivationFunction{detail::tanh}},
{"relu", ActivationFunction{detail::relu}},
{"hardsigmoid", ActivationFunction{detail::hardsigmoid, 0.2f, 0.5f}},
};
auto func_it = func_map.find(func_name);
if (func_it == std::end(func_map))
{
throw error::UnknownActivationFunction(func_name);
}
return func_it->second;
}
} //namespace rnn
} // namespace onnx_import
} // namespace ngraph
......@@ -106,6 +106,7 @@
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/leaky_relu.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize.hpp"
#include "ngraph/op/fused/prelu.hpp"
......
......@@ -191,7 +191,7 @@ bool op::AvgPool::get_include_padding_in_avg_computation() const
return m_include_padding_in_avg_computation;
}
void op::AvgPool::get_include_padding_in_avg_computation(bool include_padding_in_avg_computation)
void op::AvgPool::set_include_padding_in_avg_computation(bool include_padding_in_avg_computation)
{
m_include_padding_in_avg_computation = include_padding_in_avg_computation;
}
......
......@@ -148,7 +148,7 @@ namespace ngraph
const Shape& get_padding_above() const;
void set_padding_above(const Shape& padding_above);
bool get_include_padding_in_avg_computation() const;
void get_include_padding_in_avg_computation(bool include_padding_in_avg_computation);
void set_include_padding_in_avg_computation(bool include_padding_in_avg_computation);
/// \return The pad type for pooling.
const PadType& get_pad_type() const;
void set_pad_type(const PadType& pad_type);
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cmath>
#include <functional>
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
// ------------- HELPER FUNCTIONS ---------------------------------------------
static shared_ptr<Node> add(const shared_ptr<Node>& lhs, const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Add>(args.at(0), args.at(1))};
}
static shared_ptr<Node> sub(const shared_ptr<Node>& lhs, const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Subtract>(args.at(0), args.at(1))};
}
static shared_ptr<Node> mul(const shared_ptr<Node>& lhs, const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Multiply>(args.at(0), args.at(1))};
}
static shared_ptr<Node> clip(const shared_ptr<Node>& data, float threshold)
{
if (threshold == 0.f)
{
return data;
}
float min_val = -threshold;
float max_val = threshold;
size_t size = shape_size(data->get_shape());
const shared_ptr<Node> min_val_node = op::Constant::create(
data->get_element_type(), data->get_shape(), vector<float>(size, min_val));
const shared_ptr<Node> max_val_node = op::Constant::create(
data->get_element_type(), data->get_shape(), vector<float>(size, max_val));
return make_shared<op::Minimum>(max_val_node, make_shared<op::Maximum>(data, min_val_node));
}
// ------------- LSTM_CELL ----------------------------------------------------
op::LSTMCell::LSTMCell(const shared_ptr<Node>& X,
const shared_ptr<Node>& W,
const shared_ptr<Node>& R,
const shared_ptr<Node>& H_t,
const shared_ptr<Node>& C_t,
size_t hidden_size)
: LSTMCell(X,
W,
R,
H_t,
C_t,
hidden_size,
vector<string>{"sigmoid", "tanh", "tanh"},
vector<float>{},
vector<float>{},
0.f,
false)
{
}
op::LSTMCell::LSTMCell(const shared_ptr<Node>& X,
const shared_ptr<Node>& W,
const shared_ptr<Node>& R,
const shared_ptr<Node>& H_t,
const shared_ptr<Node>& C_t,
size_t hidden_size,
const vector<string>& activations,
const vector<float>& activation_alpha,
const vector<float>& activation_beta,
float clip,
bool input_forget)
: FusedOp("LSTMCell", {X, W, R, H_t, C_t})
, RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
, m_activation_f{get_activation_function(0)}
, m_activation_g{get_activation_function(1)}
, m_activation_h{get_activation_function(2)}
, m_input_forget{input_forget}
{
add_default_bias_input();
add_default_peepholes_input();
constructor_validate_and_infer_types();
}
op::LSTMCell::LSTMCell(const shared_ptr<Node>& X,
const shared_ptr<Node>& W,
const shared_ptr<Node>& R,
const shared_ptr<Node>& H_t,
const shared_ptr<Node>& C_t,
size_t hidden_size,
const shared_ptr<Node>& B,
const shared_ptr<Node>& P,
const vector<string>& activations,
const vector<float>& activation_alpha,
const vector<float>& activation_beta,
float clip,
bool input_forget)
: FusedOp("LSTMCell", {X, W, R, H_t, C_t, B, P})
, RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
, m_activation_f{get_activation_function(0)}
, m_activation_g{get_activation_function(1)}
, m_activation_h{get_activation_function(2)}
, m_input_forget{input_forget}
{
constructor_validate_and_infer_types();
}
void op::LSTMCell::pre_validate_and_infer_types()
{
const auto& x_pshape = get_input_partial_shape(0);
const auto& w_pshape = get_input_partial_shape(1);
const auto& r_pshape = get_input_partial_shape(2);
const auto& ht_pshape = get_input_partial_shape(3);
const auto& ct_pshape = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(this,
(x_pshape.is_static() && w_pshape.is_static() && r_pshape.is_static() &&
ht_pshape.is_static() && ct_pshape.is_static()),
"LSTMCell supports only static input tensors.");
const Shape& x_shape{x_pshape.to_shape()};
const size_t batch_size = x_shape.at(0);
const size_t input_size = x_shape.at(1);
const Shape& w_shape{w_pshape.to_shape()};
const Shape& r_shape{r_pshape.to_shape()};
const Shape& ht_shape{ht_pshape.to_shape()};
const Shape& ct_shape{ct_pshape.to_shape()};
NODE_VALIDATION_CHECK(this,
(w_shape == Shape{m_gates_count * get_hidden_size(), input_size}),
"Input tensor W must have shape (",
m_gates_count * get_hidden_size(),
", ",
input_size,
"). Actual shape is:",
w_shape,
".");
NODE_VALIDATION_CHECK(this,
(r_shape == Shape{m_gates_count * get_hidden_size(), get_hidden_size()}),
"Input tensor R must have shape (",
m_gates_count * get_hidden_size(),
", ",
get_hidden_size(),
"). Actual shape is:",
r_shape,
".");
NODE_VALIDATION_CHECK(this,
(ht_shape == Shape{batch_size, get_hidden_size()}),
"Input tensor H_t must have shape (",
batch_size,
", ",
get_hidden_size(),
"). Actual shape is:",
ht_shape,
".");
NODE_VALIDATION_CHECK(this,
(ct_shape == Shape{batch_size, get_hidden_size()}),
"Input tensor C_t must have shape (",
batch_size,
", ",
get_hidden_size(),
"). Actual shape is:",
ct_shape,
".");
const auto& b_pshape = get_input_partial_shape(5);
const auto& p_pshape = get_input_partial_shape(6);
NODE_VALIDATION_CHECK(this,
(b_pshape.is_static() && p_pshape.is_static()),
"LSTMCell supports only static input tensors.");
const Shape& b_shape{b_pshape.to_shape()};
const Shape& p_shape{p_pshape.to_shape()};
NODE_VALIDATION_CHECK(this,
(b_shape == Shape{2 * m_gates_count * get_hidden_size()}),
"Input tensor B must have shape (",
8 * get_hidden_size(),
"). Actual shape is:",
b_shape,
".");
NODE_VALIDATION_CHECK(this,
(p_shape == Shape{m_peepholes_count * get_hidden_size()}),
"Input tensor P must have shape (",
m_peepholes_count * get_hidden_size(),
"). Actual shape is:",
p_shape,
".");
}
NodeVector op::LSTMCell::decompose_op() const
{
// ------ VARIABLE NAMES AND ACRONYM DEFINITIONS ------
// The names used below are analogous to those used in the ONNX documentation.
//
// ------ ACRONYMS ------
// i - input gate
// o - output gate
// f - forget gate
// c - cell gate
// t - time step (t-1 means previous time step)
// Wb - W bias vectors for input, output, forget, and cell gates.
// Rb - R bias vectors for input, output, forget, and cell gates.
// P - The peephole weights for input, output and forget gates.
// ------ VARIABLE NAMES ------
// X - The input data tensor. Shape: [batch_size, input_size].
// W - The weight matrix for input, output, forget, and cell gates
// Shape: [4*hidden_size, input_size]
// R - The recurrence weight matrix for input, output, forget, and cell gates.
// Shape: [4*hidden_size, hidden_size].
// H_t - The hidden state tensor at current time step. Shape: [batch_size, hidden_size].
// C_t - The cell state tensor at current time step. Shape: [batch_size, hidden_size].
// bias - The sum of biases (weight and recurrence) for input, output, forget, and cell gates.
// Shape: [4 * hidden_size]
// p_[iof] - The peephole weight vector for respectively: input, output, and forget gates.
// Each peephole has shape [hidden_size].
//
// (.) - Denotes element-wise multiplication.
// * - Denotes dot product.
//
// ---- Equations ----
// f, g, h - are activation functions.
// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
// ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
// Ct = ft (.) Ct-1 + it (.) ct
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
// Ht = ot (.) h(Ct)
// --------------------
shared_ptr<Node> X = get_argument(0);
shared_ptr<Node> W = get_argument(1);
shared_ptr<Node> R = get_argument(2);
shared_ptr<Node> H_t = get_argument(3);
shared_ptr<Node> C_t = get_argument(4);
shared_ptr<Node> bias = get_bias();
NodeVector p_iof = get_peephole_weights();
const auto& p_i = p_iof.at(0);
const auto& p_o = p_iof.at(1);
const auto& p_f = p_iof.at(2);
// Xt*(W^T) -- for [iofc] gates.
auto Xt_W = make_shared<op::Dot>(X, builder::transpose(W));
// Ht-1*(R^T) -- for [iofc] gates.
auto Ht_R = make_shared<op::Dot>(H_t, builder::transpose(R));
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb -- for [iofc] gates.
auto gates = add(Xt_W, add(Ht_R, bias));
NodeVector split_gates = builder::split(gates, 4, -1);
auto i_t = split_gates.at(0);
auto o_t = split_gates.at(1);
auto f_t = split_gates.at(2);
auto c_t = split_gates.at(3);
// f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
i_t = m_activation_f(clip(add(i_t, mul(p_i, C_t)), get_clip()));
if (m_input_forget)
{
// Couple input with forget gate: 1 - i_t
f_t = sub(op::Constant::create(i_t->get_element_type(),
i_t->get_shape(),
vector<float>(shape_size(i_t->get_shape()), 1.f)),
i_t);
}
else
{
// f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
f_t = m_activation_f(clip(add(f_t, mul(p_f, C_t)), get_clip()));
}
// ft (.) Ct-1 + it (.) ct
auto C = add(mul(f_t, C_t), mul(i_t, m_activation_g(clip(c_t, get_clip()))));
// f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
o_t = m_activation_f(clip(add(o_t, mul(p_o, C)), get_clip()));
// ot (.) h(Ct)
auto H = mul(o_t, m_activation_h(clip(C, get_clip())));
return {H, C};
}
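
For readability, the equations implemented by decompose_op() above, restated in LaTeX (same symbols as the comment block; $\odot$ denotes element-wise multiplication, and $f_t$ is replaced by $1 - i_t$ when input_forget couples the gates):

$$
\begin{aligned}
i_t &= f\!\left(X_t W_i^{T} + H_{t-1} R_i^{T} + P_i \odot C_{t-1} + W_{bi} + R_{bi}\right)\\
f_t &= f\!\left(X_t W_f^{T} + H_{t-1} R_f^{T} + P_f \odot C_{t-1} + W_{bf} + R_{bf}\right)\\
c_t &= g\!\left(X_t W_c^{T} + H_{t-1} R_c^{T} + W_{bc} + R_{bc}\right)\\
C_t &= f_t \odot C_{t-1} + i_t \odot c_t\\
o_t &= f\!\left(X_t W_o^{T} + H_{t-1} R_o^{T} + P_o \odot C_t + W_{bo} + R_{bo}\right)\\
H_t &= o_t \odot h(C_t)
\end{aligned}
$$
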
shared_ptr<Node> op::LSTMCell::get_bias() const
{
shared_ptr<Node> bias;
// Split B into Wb and Rb and add them.
NodeVector b_W_R = builder::split(get_argument(5), 2);
bias = b_W_R.at(0) + b_W_R.at(1);
return bias;
}
NodeVector op::LSTMCell::get_peephole_weights() const
{
shared_ptr<Node> P;
P = get_argument(6);
return builder::split(P, m_peepholes_count);
}
void op::LSTMCell::add_default_bias_input()
{
shared_ptr<Node> B =
op::Constant::create(input(0).get_element_type(),
Shape{2 * m_gates_count * get_hidden_size()},
vector<float>(2 * m_gates_count * get_hidden_size(), 0.f));
set_argument(5, B->output(0));
}
void op::LSTMCell::add_default_peepholes_input()
{
shared_ptr<Node> P =
op::Constant::create(input(0).get_element_type(),
Shape{m_peepholes_count * get_hidden_size()},
vector<float>(m_peepholes_count * get_hidden_size(), 0.f));
set_argument(6, P->output(0));
}
shared_ptr<Node> op::LSTMCell::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
if (new_args.size() == 5)
{
return make_shared<LSTMCell>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
get_hidden_size(),
get_activations(),
get_activation_alpha(),
get_activation_beta(),
get_clip(),
m_input_forget);
}
else if (new_args.size() == 7)
{
return make_shared<LSTMCell>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
get_hidden_size(),
new_args.at(5),
new_args.at(6),
get_activations(),
get_activation_alpha(),
get_activation_beta(),
get_clip(),
m_input_forget);
}
else
{
throw ngraph_error("Incorrect number of new arguments");
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstddef>
#include <memory>
#include <string>
#include <vector>
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/util/activation_functions.hpp"
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/op/util/rnn_cell_base.hpp"
namespace ngraph
{
namespace op
{
///
/// \brief Class for an LSTM cell node.
///
/// \note It follows the notation and equations defined in the ONNX standard:
/// https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM
///
/// Note this class represents only a single *cell* and not a whole LSTM *layer*.
///
class LSTMCell : public util::FusedOp, public util::RNNCellBase
{
public:
///
/// \brief Constructs LSTMCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape: [4*hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [4*hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with shape:
/// [batch_size, hidden_size].
/// \param[in] C_t The cell state tensor at current time step with shape:
/// [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
///
LSTMCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
const std::shared_ptr<Node>& C_t,
std::size_t hidden_size);
///
/// \brief Constructs LSTMCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape: [4*hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [4*hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] C_t The cell state tensor at current time step with shape:
/// [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activation_beta The vector of beta parameters for activation functions
/// in order respective to activation list.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
/// \param[in] input_forget Controls coupling input and forget gates.
///
LSTMCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
const std::shared_ptr<Node>& C_t,
std::size_t hidden_size,
const std::vector<std::string>& activations,
const std::vector<float>& activation_alpha,
const std::vector<float>& activation_beta,
float clip,
bool input_forget);
///
/// \brief Constructs LSTMCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape: [4*hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [4*hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] C_t The cell state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] B The bias tensor for input gate with shape: [8*hidden_size].
/// \param[in] P The weight tensor for peepholes with shape:
/// [3*hidden_size] - 3 corresponds to the i, o, f gates only.
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activation_beta The vector of beta parameters for activation functions
/// in order respective to activation list.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
/// \param[in] input_forget Controls coupling input and forget gates.
///
LSTMCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
const std::shared_ptr<Node>& C_t,
std::size_t hidden_size,
const std::shared_ptr<Node>& B,
const std::shared_ptr<Node>& P,
const std::vector<std::string>& activations =
std::vector<std::string>{"sigmoid", "tanh", "tanh"},
const std::vector<float>& activation_alpha = {},
const std::vector<float>& activation_beta = {},
float clip = 0.f,
bool input_forget = false);
virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool get_input_forget() const { return m_input_forget; }
private:
std::shared_ptr<Node> get_bias() const;
NodeVector get_peephole_weights() const;
/// \brief Add and initialize bias input to all zeros.
void add_default_bias_input();
/// \brief Add and initialize peephole weights input to all zeros.
void add_default_peepholes_input();
///
/// \brief The Activation function f.
///
util::ActivationFunction m_activation_f;
///
/// \brief The Activation function g.
///
util::ActivationFunction m_activation_g;
///
/// \brief The Activation function h.
///
util::ActivationFunction m_activation_h;
///
/// \brief Controls whether to couple input and forget gates.
///
bool m_input_forget = false;
static constexpr std::size_t m_gates_count{4};
static constexpr std::size_t m_peepholes_count{3};
};
}
}
......@@ -13,9 +13,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/fused/shuffle_channels.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/reshape.hpp"
using namespace std;
using namespace ngraph;
......@@ -77,10 +78,10 @@ NodeVector op::ShuffleChannels::decompose_op() const
const auto data = get_argument(0);
const auto& data_shape = data->get_shape();
const auto reshaped = util::reshape(data, get_pre_shuffle_shape(data_shape));
const auto shuffled = util::reorder_axes(reshaped, {0, 2, 1, 3});
const auto reshaped = builder::reshape(data, get_pre_shuffle_shape(data_shape));
const auto shuffled = builder::reorder_axes(reshaped, {0, 2, 1, 3});
return {util::reshape(shuffled, data_shape)};
return {builder::reshape(shuffled, data_shape)};
}
shared_ptr<Node> op::ShuffleChannels::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -30,6 +30,7 @@ NGRAPH_OP(GroupConvolution, ngraph::op)
NGRAPH_OP(GroupConvolutionTranspose, ngraph::op)
NGRAPH_OP(HardSigmoid, ngraph::op)
NGRAPH_OP(LeakyRelu, ngraph::op)
NGRAPH_OP(LSTMCell, ngraph::op)
NGRAPH_OP(MVN, ngraph::op)
NGRAPH_OP(Normalize, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op)
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cmath>
#include <functional>
#include <iterator>
#include <memory>
#include <unordered_map>
#include "activation_functions.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/tanh.hpp"
using namespace std;
using namespace ngraph;
static shared_ptr<Node> sigmoid(const shared_ptr<Node>& arg, float alpha, float beta)
{
return make_shared<op::Sigmoid>(arg);
}
static shared_ptr<Node> tanh(const shared_ptr<Node>& arg, float alpha, float beta)
{
return make_shared<op::Tanh>(arg);
}
static shared_ptr<Node> relu(const shared_ptr<Node>& arg, float alpha, float beta)
{
return make_shared<op::Relu>(arg);
}
static shared_ptr<Node> hardsigmoid(const shared_ptr<Node>& arg, float alpha, float beta)
{
return make_shared<op::HardSigmoid>(arg, alpha, beta);
}
op::util::ActivationFunction::ActivationFunction(ActivationFunctionType f, float alpha, float beta)
: m_function{f}
, m_alpha{alpha}
, m_beta{beta}
{
}
op::util::ActivationFunction::ActivationFunction(ActivationFunctionType f, float alpha)
: ActivationFunction(f, alpha, nanf(""))
{
}
op::util::ActivationFunction::ActivationFunction(ActivationFunctionType f)
: ActivationFunction(f, nanf(""), nanf(""))
{
}
shared_ptr<Node> op::util::ActivationFunction::operator()(const shared_ptr<Node>& arg) const
{
return m_function(arg, m_alpha, m_beta);
}
op::util::ActivationFunction op::util::get_activation_func_by_name(const string& func_name)
{
using ActivationFunctionMap = unordered_map<string, op::util::ActivationFunction>;
static ActivationFunctionMap func_map{
{"sigmoid", op::util::ActivationFunction{sigmoid}},
{"tanh", op::util::ActivationFunction{tanh}},
{"relu", op::util::ActivationFunction{relu}},
{"hardsigmoid", op::util::ActivationFunction{hardsigmoid, 0.2f, 0.5f}},
};
auto func_it = func_map.find(func_name);
if (func_it == end(func_map))
{
throw op::util::error::UnknownActivationFunction(func_name);
}
return func_it->second;
}
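
As a quick illustration (not part of this commit's diff), the relocated factory can be used roughly as follows; some_arg here is a hypothetical, previously constructed node:

// Hypothetical usage sketch of the activation-function factory above.
// "hardsigmoid" is registered with default alpha = 0.2f and beta = 0.5f (see func_map).
op::util::ActivationFunction act = op::util::get_activation_func_by_name("hardsigmoid");
act.set_alpha(0.25f);                      // alpha/beta may be overridden per attribute values
std::shared_ptr<Node> out = act(some_arg); // builds the corresponding HardSigmoid node
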
......@@ -31,19 +31,19 @@
// Prevents the compiler from complaining about or optimizing away variables
// that appear unused on Linux
#if (defined(__GNUC__) && !defined(__clang__))
#undef ONNX_ATTRIBUTE_UNUSED
#define ONNX_ATTRIBUTE_UNUSED __attribute__((__unused__))
#undef NG_ATTRIBUTE_UNUSED
#define NG_ATTRIBUTE_UNUSED __attribute__((__unused__))
#else
#define ONNX_ATTRIBUTE_UNUSED
#define NG_ATTRIBUTE_UNUSED
#endif
#define UNUSED_PARAMETER ONNX_ATTRIBUTE_UNUSED = 0
#define UNUSED_PARAMETER NG_ATTRIBUTE_UNUSED = 0
namespace ngraph
{
namespace onnx_import
namespace op
{
namespace rnn
namespace util
{
namespace error
{
......@@ -58,22 +58,26 @@ namespace ngraph
namespace detail
{
std::shared_ptr<ngraph::Node> sigmoid(const std::shared_ptr<ngraph::Node>& arg,
float alpha UNUSED_PARAMETER,
float beta UNUSED_PARAMETER);
std::shared_ptr<ngraph::Node> tanh(const std::shared_ptr<ngraph::Node>& arg,
float alpha UNUSED_PARAMETER,
float beta UNUSED_PARAMETER);
std::shared_ptr<ngraph::Node> relu(const std::shared_ptr<ngraph::Node>& arg,
float alpha UNUSED_PARAMETER,
float beta UNUSED_PARAMETER);
std::shared_ptr<ngraph::Node>
hardsigmoid(const std::shared_ptr<ngraph::Node>& arg, float alpha, float beta);
std::shared_ptr<Node> sigmoid(const std::shared_ptr<Node>& arg,
float alpha UNUSED_PARAMETER,
float beta UNUSED_PARAMETER);
std::shared_ptr<Node> tanh(const std::shared_ptr<Node>& arg,
float alpha UNUSED_PARAMETER,
float beta UNUSED_PARAMETER);
std::shared_ptr<Node> relu(const std::shared_ptr<Node>& arg,
float alpha UNUSED_PARAMETER,
float beta UNUSED_PARAMETER);
std::shared_ptr<Node>
hardsigmoid(const std::shared_ptr<Node>& arg, float alpha, float beta);
}
using ActivationFunctionType = std::shared_ptr<ngraph::Node> (*)(
const std::shared_ptr<ngraph::Node>&, float, float);
using ActivationFunctionType = std::shared_ptr<Node> (*)(const std::shared_ptr<Node>&,
float,
float);
///
/// \brief Class representing activation function used in RNN cells.
///
class ActivationFunction
{
public:
......@@ -81,14 +85,19 @@ namespace ngraph
ActivationFunction(ActivationFunctionType f, float alpha);
ActivationFunction(ActivationFunctionType f);
std::shared_ptr<ngraph::Node>
operator()(const std::shared_ptr<ngraph::Node>& arg) const;
///
/// \brief Calls stored activation function with provided node argument.
///
std::shared_ptr<Node> operator()(const std::shared_ptr<Node>& arg) const;
void set_alpha(float alpha) { m_alpha = alpha; }
void set_beta(float beta) { m_beta = beta; }
private:
/// \brief Activation function wrapper.
ActivationFunctionType m_function;
/// \brief Activation function alpha parameter (may be unused).
float m_alpha;
/// \brief Activation function beta parameter (may be unused).
float m_beta;
};
......@@ -101,10 +110,9 @@ namespace ngraph
/// \return The activation function object.
///
ActivationFunction get_activation_func_by_name(const std::string& func_name);
} // namespace util
} //namespace rnn
} // namespace onnx_import
} // namespace op
} // namespace ngraph
......@@ -115,6 +123,6 @@ namespace ngraph
#ifdef UNUSED_PARAMETER
#undef UNUSED_PARAMETER
#endif
#ifdef ONNX_ATTRIBUTE_UNUSED
#undef ONNX_ATTRIBUTE_UNUSED
#ifdef NG_ATTRIBUTE_UNUSED
#undef NG_ATTRIBUTE_UNUSED
#endif
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <iterator>
#include "ngraph/op/util/rnn_cell_base.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
// Convert a vector of strings to lower case and return the converted copy.
static vector<string> to_lower_case(const vector<string>& vs)
{
vector<string> res(vs);
transform(begin(res), end(res), begin(res), [](string& s) { return to_lower(s); });
return res;
}
op::util::RNNCellBase::RNNCellBase(size_t hidden_size,
float clip,
const vector<string>& activations,
const vector<float>& activation_alpha,
const vector<float>& activation_beta)
: m_hidden_size(hidden_size)
, m_clip(clip)
, m_activations(to_lower_case(activations))
, m_activation_alpha(activation_alpha)
, m_activation_beta(activation_beta)
{
}
op::util::ActivationFunction op::util::RNNCellBase::get_activation_function(size_t idx) const
{
op::util::ActivationFunction afunc = get_activation_func_by_name(m_activations.at(idx));
// Set activation functions parameters (if any)
if (m_activation_alpha.size() > idx)
{
afunc.set_alpha(m_activation_alpha.at(idx));
}
if (m_activation_beta.size() > idx)
{
afunc.set_beta(m_activation_beta.at(idx));
}
return afunc;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstddef>
#include <string>
#include <vector>
#include "ngraph/op/util/activation_functions.hpp"
namespace ngraph
{
namespace op
{
namespace util
{
/// \brief Base class for all recurrent network cells.
///
/// \note It holds all common attributes.
///
class RNNCellBase
{
public:
///
/// \brief Constructs a RNNCellBase class.
///
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activation_beta The vector of beta parameters for activation
/// functions in order respective to activation list.
///
RNNCellBase(std::size_t hidden_size,
float clip,
const std::vector<std::string>& activations,
const std::vector<float>& activation_alpha,
const std::vector<float>& activation_beta);
std::size_t get_hidden_size() const { return m_hidden_size; }
float get_clip() const { return m_clip; }
const std::vector<std::string>& get_activations() const { return m_activations; }
const std::vector<float>& get_activation_alpha() const
{
return m_activation_alpha;
}
const std::vector<float>& get_activation_beta() const { return m_activation_beta; }
protected:
///
/// \brief Constructs activation function object.
///
/// \param[in] idx The index of the activation function name.
///
/// \return The object representing activation function.
///
ActivationFunction get_activation_function(std::size_t idx) const;
private:
std::size_t m_hidden_size = 0;
float m_clip = 0.f;
const std::vector<std::string> m_activations;
const std::vector<float> m_activation_alpha;
const std::vector<float> m_activation_beta;
};
}
}
}
......@@ -173,6 +173,9 @@ void runtime::gpu::GPUCompiledFunction::compile()
pass_manager.register_pass<runtime::gpu::pass::BatchNormCache>();
pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>();
// Run this pass a second time since some fused operators like LSTMCell may use
// other fused operators inside.
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>();
pass_manager.register_pass<ngraph::pass::ImplicitBroadcastElimination>();
pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this);
pass_manager.register_pass<ngraph::pass::AssignLayout<descriptor::layout::DenseTensorLayout>>();
......
......@@ -89,6 +89,7 @@
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/leaky_relu.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
......@@ -427,6 +428,10 @@ shared_ptr<runtime::Executable>
if (m_disable_backend_optimizations < 2)
{
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>(
IntelGPUBackend::is_supported_impl);
// Run this pass a second time since some fused operators like LSTMCell may use
// other fused operators inside.
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>(
IntelGPUBackend::is_supported_impl);
pass_manager.register_pass<ngraph::pass::ImplicitBroadcastElimination>();
......@@ -2067,6 +2072,7 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::GroupConvolutionTranspose:
case OP_TYPEID::HardSigmoid:
case OP_TYPEID::LeakyRelu:
case OP_TYPEID::LSTMCell:
case OP_TYPEID::MVN:
case OP_TYPEID::Normalize:
case OP_TYPEID::PRelu:
......@@ -2187,6 +2193,7 @@ bool runtime::intelgpu::IntelGPUBackend::is_supported_impl(const Node& node)
case OP_TYPEID::GRN:
case OP_TYPEID::GroupConvolutionTranspose:
case OP_TYPEID::LeakyRelu:
case OP_TYPEID::LSTMCell:
case OP_TYPEID::MVN:
case OP_TYPEID::Normalize:
case OP_TYPEID::PRelu:
......
......@@ -47,6 +47,9 @@ runtime::interpreter::INTExecutable::INTExecutable(const shared_ptr<Function>& f
pass::Manager pass_manager;
pass_manager.register_pass<pass::LikeReplacement>();
pass_manager.register_pass<pass::FusedOpDecomposition>();
// Run this pass a second time since some fused operators like LSTMCell may use
// other fused operators inside.
pass_manager.register_pass<pass::FusedOpDecomposition>();
pass_manager.register_pass<pass::ImplicitBroadcastElimination>();
pass_manager.register_pass<pass::AssignLayout<DenseTensorLayout>>();
pass_manager.register_pass<pass::Liveness>();
......
......@@ -77,6 +77,7 @@
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/leaky_relu.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize.hpp"
#include "ngraph/op/fused/prelu.hpp"
......@@ -1126,6 +1127,29 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::LRN>(args[0], alpha, beta, bias, nsize);
break;
}
case OP_TYPEID::LSTMCell:
{
auto hidden_size = node_js.at("hidden_size").get<size_t>();
auto clip = node_js.at("clip").get<float>();
auto activations = node_js.at("activations").get<vector<string>>();
auto activation_alpha = node_js.at("activation_alpha").get<vector<float>>();
auto activation_beta = node_js.at("activation_beta").get<vector<float>>();
auto input_forget = node_js.at("input_forget").get<bool>();
node = make_shared<op::LSTMCell>(args[0],
args[1],
args[2],
args[3],
args[4],
hidden_size,
args[5],
args[6],
activations,
activation_alpha,
activation_beta,
clip,
input_forget);
break;
}
case OP_TYPEID::Max:
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
......@@ -2152,6 +2176,17 @@ static json write(const Node& n, bool binary_constant_data)
node["nsize"] = tmp->get_nsize();
break;
}
case OP_TYPEID::LSTMCell:
{
auto tmp = dynamic_cast<const op::LSTMCell*>(&n);
node["hidden_size"] = tmp->get_hidden_size();
node["clip"] = tmp->get_clip();
node["activations"] = tmp->get_activations();
node["activation_alpha"] = tmp->get_activation_alpha();
node["activation_beta"] = tmp->get_activation_beta();
node["input_forget"] = tmp->get_input_forget();
break;
}
case OP_TYPEID::Max:
{
auto tmp = dynamic_cast<const op::Max*>(&n);
......
......@@ -1031,6 +1031,342 @@ NGRAPH_TEST(${BACKEND_NAME}, split_var_len_parts)
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_no_bias_no_peepholes)
{
const size_t batch_size = 2;
const size_t input_size = 3;
const size_t hidden_size = 3;
const size_t gates_count = 4;
const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
const auto W =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
const auto R =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
auto ht_function = make_shared<Function>(make_shared<op::GetOutputElement>(lstm_cell, 0),
ParameterVector{X, W, R, H_t, C_t});
auto ht_test_case = ngraph::test::NgraphTestCase(ht_function, "${BACKEND_NAME}");
// X
vector<float> in_X{0.81342685f, 0.84108883f, 0.8152282f, 0.46893653f, 0.0901856f, 0.37088776f};
// W
vector<float> in_W{3.3330739e-01f, 3.6229487e-04f, 4.6773660e-01f, 4.3046016e-01f,
7.3950343e-02f, 3.8063636e-01f, 9.6921772e-01f, 9.6897459e-01f,
6.2964785e-01f, 3.1134409e-01f, 8.4709978e-01f, 9.4928098e-01f,
6.1676943e-01f, 6.6020679e-01f, 1.9072217e-01f, 8.8032126e-02f,
4.0472135e-01f, 6.8342745e-01f, 8.3432144e-01f, 4.4928190e-01f,
7.9524308e-01f, 5.3966165e-01f, 8.5936421e-01f, 8.3136767e-01f,
5.5125546e-02f, 4.7791195e-01f, 3.5788772e-01f, 6.7507404e-01f,
2.1716513e-01f, 2.7473119e-01f, 3.3999152e-02f, 9.6835363e-01f,
3.7581277e-01f, 2.4026000e-01f, 6.7418844e-01f, 3.4199652e-01f};
// R
vector<float> in_R{
0.0987983f, 0.52032113f, 0.5848073f, 0.5356095f, 0.74497133f, 0.73260087f,
0.1700787f, 0.45684233f, 0.1495722f, 0.42734373f, 0.4433832f, 0.25906256f,
0.03854987f, 0.47480518f, 0.37215272f, 0.99890584f, 0.74019486f, 0.3518967f,
0.6881257f, 0.8170279f, 0.54088944f, 0.81225616f, 0.14619833f, 0.42941234f,
0.86843914f, 0.45967972f, 0.6237719f, 0.11074839f, 0.6029616f, 0.3149305f,
0.46504205f, 0.5843412f, 0.8733427f, 0.7687243f, 0.07074859f, 0.39188156f};
// Ht
vector<float> in_Ht{0.77956f, 0.5331557f, 0.04297554f, 0.7962175f, 0.7635707f, 0.11989366f};
// Ct
vector<float> in_Ct{0.8488452f, 0.18851636f, 0.5020695f, 0.29716516f, 0.06740791f, 0.45384037f};
ht_test_case.add_multiple_inputs(vector<vector<float>>{in_X, in_W, in_R, in_Ht, in_Ct});
ht_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{0.81457126f, 0.61109227f, 0.769522f, 0.52239674f, 0.4324641f, 0.63183f});
ht_test_case.run();
auto ct_function = make_shared<Function>(make_shared<op::GetOutputElement>(lstm_cell, 1),
ParameterVector{X, W, R, H_t, C_t});
auto ct_test_case = ngraph::test::NgraphTestCase(ct_function, "${BACKEND_NAME}");
ct_test_case.add_multiple_inputs(vector<vector<float>>{in_X, in_W, in_R, in_Ht, in_Ct});
ct_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{1.4444952f, 0.9635685f, 1.2875274f, 0.8053419f, 0.7184521f, 0.95803297f});
ct_test_case.run();
}
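For reference, the expected H_t and C_t values above follow the ONNX-style LSTM cell equations this fused op is built to match (the importer change in this commit maps ONNX LSTM onto op::LSTMCell); a sketch with the default activations f = sigmoid, g = h = tanh and the optional bias B and peepholes P taken as zero, as in this test:

    i_t = f(X_t W_i^T + H_{t-1} R_i^T)
    f_t = f(X_t W_f^T + H_{t-1} R_f^T)
    c_t = g(X_t W_c^T + H_{t-1} R_c^T)
    C_t = f_t \odot C_{t-1} + i_t \odot c_t
    o_t = f(X_t W_o^T + H_{t-1} R_o^T)
    H_t = o_t \odot h(C_t)

W and R are sliced into the four per-gate blocks along their first dimension, hence the 4 * hidden_size rows.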
NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_bias_peepholes)
{
const size_t batch_size = 2;
const size_t input_size = 3;
const size_t hidden_size = 3;
const size_t gates_count = 4;
const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
const auto W =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
const auto R =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{2 * gates_count * hidden_size});
const auto P = make_shared<op::Parameter>(element::f32, Shape{3 * hidden_size});
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size, B, P);
auto ht_function = make_shared<Function>(make_shared<op::GetOutputElement>(lstm_cell, 0),
ParameterVector{X, W, R, H_t, C_t, B, P});
auto ht_test_case = ngraph::test::NgraphTestCase(ht_function, "${BACKEND_NAME}");
// X
vector<float> in_X{0.81342685f, 0.84108883f, 0.8152282f, 0.46893653f, 0.0901856f, 0.37088776f};
// W
vector<float> in_W{3.3330739e-01f, 3.6229487e-04f, 4.6773660e-01f, 4.3046016e-01f,
7.3950343e-02f, 3.8063636e-01f, 9.6921772e-01f, 9.6897459e-01f,
6.2964785e-01f, 3.1134409e-01f, 8.4709978e-01f, 9.4928098e-01f,
6.1676943e-01f, 6.6020679e-01f, 1.9072217e-01f, 8.8032126e-02f,
4.0472135e-01f, 6.8342745e-01f, 8.3432144e-01f, 4.4928190e-01f,
7.9524308e-01f, 5.3966165e-01f, 8.5936421e-01f, 8.3136767e-01f,
5.5125546e-02f, 4.7791195e-01f, 3.5788772e-01f, 6.7507404e-01f,
2.1716513e-01f, 2.7473119e-01f, 3.3999152e-02f, 9.6835363e-01f,
3.7581277e-01f, 2.4026000e-01f, 6.7418844e-01f, 3.4199652e-01f};
// R
vector<float> in_R{
0.0987983f, 0.52032113f, 0.5848073f, 0.5356095f, 0.74497133f, 0.73260087f,
0.1700787f, 0.45684233f, 0.1495722f, 0.42734373f, 0.4433832f, 0.25906256f,
0.03854987f, 0.47480518f, 0.37215272f, 0.99890584f, 0.74019486f, 0.3518967f,
0.6881257f, 0.8170279f, 0.54088944f, 0.81225616f, 0.14619833f, 0.42941234f,
0.86843914f, 0.45967972f, 0.6237719f, 0.11074839f, 0.6029616f, 0.3149305f,
0.46504205f, 0.5843412f, 0.8733427f, 0.7687243f, 0.07074859f, 0.39188156f};
// Ht
vector<float> in_Ht{0.77956f, 0.5331557f, 0.04297554f, 0.7962175f, 0.7635707f, 0.11989366f};
// Ct
vector<float> in_Ct{0.8488452f, 0.18851636f, 0.5020695f, 0.29716516f, 0.06740791f, 0.45384037f};
// B
vector<float> in_B{0.81130236f, 0.31332242f, 0.6423671f, 0.09981899f, 0.7847627f,
0.8405669f, 0.0330242f, 0.45014873f, 0.5599519f, 0.31807426f,
0.7356558f, 0.6298691f, 0.26263478f, 0.8391581f, 0.52434635f,
0.11468413f, 0.4533051f, 0.67632145f, 0.43415946f, 0.46795473f,
0.5674715f, 0.19214648f, 0.37824264f, 0.11187395f};
// P
vector<float> in_P{0.38557124f,
0.9482306f,
0.6808912f,
0.93585867f,
0.74540526f,
0.10507805f,
0.8180733f,
0.13840231f,
0.24175227f};
ht_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_W, in_R, in_Ht, in_Ct, in_B, in_P});
ht_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{0.9218244f, 0.78787273f, 0.8754273f, 0.7361462f, 0.70927656f, 0.83522964f});
ht_test_case.run();
auto ct_function = make_shared<Function>(make_shared<op::GetOutputElement>(lstm_cell, 1),
ParameterVector{X, W, R, H_t, C_t, B, P});
auto ct_test_case = ngraph::test::NgraphTestCase(ct_function, "${BACKEND_NAME}");
ct_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_W, in_R, in_Ht, in_Ct, in_B, in_P});
ct_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{1.7094649f, 1.1259761f, 1.444019f, 1.086587f, 0.9762144f, 1.3066899f});
ct_test_case.run();
}
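The two extra inputs here follow the ONNX LSTM packing, which the fused op appears to reuse: B of shape 2 * gates_count * hidden_size concatenates the input-weight and recurrence-weight biases, and P of shape 3 * hidden_size carries the peephole weights, one per gate that has them:

    B = [Wb, Rb],    P = [p_i, p_o, p_f]

With these present, each gate's pre-activation gains its Wb + Rb term, and the i, o, f gates additionally add p \odot C (C_{t-1} for i and f, C_t for o), which accounts for the larger expected outputs relative to the previous test.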
NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_bias_peepholes_clip_input_forget)
{
const size_t batch_size = 2;
const size_t input_size = 3;
const size_t hidden_size = 3;
const size_t gates_count = 4;
const float clip_threshold = 3.5f;
    const bool input_forget = true;
const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
const auto W =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
const auto R =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{2 * gates_count * hidden_size});
const auto P = make_shared<op::Parameter>(element::f32, Shape{3 * hidden_size});
const auto lstm_cell = make_shared<op::LSTMCell>(X,
W,
R,
H_t,
C_t,
hidden_size,
B,
P,
vector<string>{"sigmoid", "tanh", "tanh"},
vector<float>{},
vector<float>{},
clip_threshold,
input_forget);
auto ht_function = make_shared<Function>(make_shared<op::GetOutputElement>(lstm_cell, 0),
ParameterVector{X, W, R, H_t, C_t, B, P});
auto ht_test_case = ngraph::test::NgraphTestCase(ht_function, "${BACKEND_NAME}");
// X
vector<float> in_X{0.81342685f, 0.84108883f, 0.8152282f, 0.46893653f, 0.0901856f, 0.37088776f};
// W
vector<float> in_W{3.3330739e-01f, 3.6229487e-04f, 4.6773660e-01f, 4.3046016e-01f,
7.3950343e-02f, 3.8063636e-01f, 9.6921772e-01f, 9.6897459e-01f,
6.2964785e-01f, 3.1134409e-01f, 8.4709978e-01f, 9.4928098e-01f,
6.1676943e-01f, 6.6020679e-01f, 1.9072217e-01f, 8.8032126e-02f,
4.0472135e-01f, 6.8342745e-01f, 8.3432144e-01f, 4.4928190e-01f,
7.9524308e-01f, 5.3966165e-01f, 8.5936421e-01f, 8.3136767e-01f,
5.5125546e-02f, 4.7791195e-01f, 3.5788772e-01f, 6.7507404e-01f,
2.1716513e-01f, 2.7473119e-01f, 3.3999152e-02f, 9.6835363e-01f,
3.7581277e-01f, 2.4026000e-01f, 6.7418844e-01f, 3.4199652e-01f};
// R
vector<float> in_R{
0.0987983f, 0.52032113f, 0.5848073f, 0.5356095f, 0.74497133f, 0.73260087f,
0.1700787f, 0.45684233f, 0.1495722f, 0.42734373f, 0.4433832f, 0.25906256f,
0.03854987f, 0.47480518f, 0.37215272f, 0.99890584f, 0.74019486f, 0.3518967f,
0.6881257f, 0.8170279f, 0.54088944f, 0.81225616f, 0.14619833f, 0.42941234f,
0.86843914f, 0.45967972f, 0.6237719f, 0.11074839f, 0.6029616f, 0.3149305f,
0.46504205f, 0.5843412f, 0.8733427f, 0.7687243f, 0.07074859f, 0.39188156f};
// Ht
vector<float> in_Ht{0.77956f, 0.5331557f, 0.04297554f, 0.7962175f, 0.7635707f, 0.11989366f};
// Ct
vector<float> in_Ct{0.8488452f, 0.18851636f, 0.5020695f, 0.29716516f, 0.06740791f, 0.45384037f};
// B
vector<float> in_B{0.81130236f, 0.31332242f, 0.6423671f, 0.09981899f, 0.7847627f,
0.8405669f, 0.0330242f, 0.45014873f, 0.5599519f, 0.31807426f,
0.7356558f, 0.6298691f, 0.26263478f, 0.8391581f, 0.52434635f,
0.11468413f, 0.4533051f, 0.67632145f, 0.43415946f, 0.46795473f,
0.5674715f, 0.19214648f, 0.37824264f, 0.11187395f};
// P
vector<float> in_P{0.38557124f,
0.9482306f,
0.6808912f,
0.93585867f,
0.74540526f,
0.10507805f,
0.8180733f,
0.13840231f,
0.24175227f};
ht_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_W, in_R, in_Ht, in_Ct, in_B, in_P});
ht_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{0.71485436f, 0.71844107f, 0.72704613f, 0.6235602f, 0.68306124f, 0.6978715f});
ht_test_case.run();
auto ct_function = make_shared<Function>(make_shared<op::GetOutputElement>(lstm_cell, 1),
ParameterVector{X, W, R, H_t, C_t, B, P});
auto ct_test_case = ngraph::test::NgraphTestCase(ct_function, "${BACKEND_NAME}");
ct_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_W, in_R, in_Ht, in_Ct, in_B, in_P});
ct_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{0.94656503f, 0.9527454f, 0.9706756f, 0.84206575f, 0.91898793f, 0.9127192f});
ct_test_case.run();
}
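Two further attributes are exercised here; stated as an assumption drawn from the ONNX LSTM semantics this op tracks: clip_threshold saturates every gate pre-activation element-wise before its activation is applied,

    clip(x) = \min(\max(x, -3.5), 3.5)

and input_forget couples the input and forget gates (commonly f_t = 1 - i_t), so the cell update is driven by the input gate alone. Both effects pull the expected H_t and C_t values down compared with the unclipped test above.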
NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_activation_functions)
{
const size_t batch_size = 2;
const size_t input_size = 3;
const size_t hidden_size = 3;
const size_t gates_count = 4;
const float clip_threshold = 3.5f;
    const bool input_forget = true;
vector<string> activations{"sigmoid", "tanh", "hardsigmoid"};
vector<float> activation_alpha{0.f, 0.f, 1.8345f};
vector<float> activation_beta{0.f, 0.f, 3.05f};
const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
const auto W =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
const auto R =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{2 * gates_count * hidden_size});
const auto P = make_shared<op::Parameter>(element::f32, Shape{3 * hidden_size});
const auto lstm_cell = make_shared<op::LSTMCell>(X,
W,
R,
H_t,
C_t,
hidden_size,
B,
P,
activations,
activation_alpha,
activation_beta,
clip_threshold,
input_forget);
auto ht_function = make_shared<Function>(make_shared<op::GetOutputElement>(lstm_cell, 0),
ParameterVector{X, W, R, H_t, C_t, B, P});
auto ht_test_case = ngraph::test::NgraphTestCase(ht_function, "${BACKEND_NAME}");
// X
vector<float> in_X{0.81342685f, 0.84108883f, 0.8152282f, 0.46893653f, 0.0901856f, 0.37088776f};
// W
vector<float> in_W{3.3330739e-01f, 3.6229487e-04f, 4.6773660e-01f, 4.3046016e-01f,
7.3950343e-02f, 3.8063636e-01f, 9.6921772e-01f, 9.6897459e-01f,
6.2964785e-01f, 3.1134409e-01f, 8.4709978e-01f, 9.4928098e-01f,
6.1676943e-01f, 6.6020679e-01f, 1.9072217e-01f, 8.8032126e-02f,
4.0472135e-01f, 6.8342745e-01f, 8.3432144e-01f, 4.4928190e-01f,
7.9524308e-01f, 5.3966165e-01f, 8.5936421e-01f, 8.3136767e-01f,
5.5125546e-02f, 4.7791195e-01f, 3.5788772e-01f, 6.7507404e-01f,
2.1716513e-01f, 2.7473119e-01f, 3.3999152e-02f, 9.6835363e-01f,
3.7581277e-01f, 2.4026000e-01f, 6.7418844e-01f, 3.4199652e-01f};
// R
vector<float> in_R{
0.0987983f, 0.52032113f, 0.5848073f, 0.5356095f, 0.74497133f, 0.73260087f,
0.1700787f, 0.45684233f, 0.1495722f, 0.42734373f, 0.4433832f, 0.25906256f,
0.03854987f, 0.47480518f, 0.37215272f, 0.99890584f, 0.74019486f, 0.3518967f,
0.6881257f, 0.8170279f, 0.54088944f, 0.81225616f, 0.14619833f, 0.42941234f,
0.86843914f, 0.45967972f, 0.6237719f, 0.11074839f, 0.6029616f, 0.3149305f,
0.46504205f, 0.5843412f, 0.8733427f, 0.7687243f, 0.07074859f, 0.39188156f};
// Ht
vector<float> in_Ht{0.77956f, 0.5331557f, 0.04297554f, 0.7962175f, 0.7635707f, 0.11989366f};
// Ct
vector<float> in_Ct{0.8488452f, 0.18851636f, 0.5020695f, 0.29716516f, 0.06740791f, 0.45384037f};
// B
vector<float> in_B{0.81130236f, 0.31332242f, 0.6423671f, 0.09981899f, 0.7847627f,
0.8405669f, 0.0330242f, 0.45014873f, 0.5599519f, 0.31807426f,
0.7356558f, 0.6298691f, 0.26263478f, 0.8391581f, 0.52434635f,
0.11468413f, 0.4533051f, 0.67632145f, 0.43415946f, 0.46795473f,
0.5674715f, 0.19214648f, 0.37824264f, 0.11187395f};
// P
vector<float> in_P{0.38557124f,
0.9482306f,
0.6808912f,
0.93585867f,
0.74540526f,
0.10507805f,
0.8180733f,
0.13840231f,
0.24175227f};
ht_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_W, in_R, in_Ht, in_Ct, in_B, in_P});
ht_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{0.96834344f, 0.9695254f, 0.97068775f, 0.9077866f, 0.94161016f, 0.96599925f});
ht_test_case.run();
auto ct_function = make_shared<Function>(make_shared<op::GetOutputElement>(lstm_cell, 1),
ParameterVector{X, W, R, H_t, C_t, B, P});
auto ct_test_case = ngraph::test::NgraphTestCase(ct_function, "${BACKEND_NAME}");
ct_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_W, in_R, in_Ht, in_Ct, in_B, in_P});
ct_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{0.94656503f, 0.9527454f, 0.9706756f, 0.84206575f, 0.91898793f, 0.9127192f});
ct_test_case.run();
}
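The only change from the previous test is the third activation: the output nonlinearity h is replaced by HardSigmoid with the supplied alpha and beta. Under the standard definition,

    hardsigmoid(x) = \max(0, \min(1, \alpha x + \beta)),    \alpha = 1.8345,  \beta = 3.05

every positive cell-state value saturates to 1, so the expected H_t effectively reduces to the output gate values, while the expected C_t is unchanged from the clip/input_forget test above.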
NGRAPH_TEST(${BACKEND_NAME}, fake_quantize)
{
const Shape data_shape{1, 2, 3, 4};
......
......@@ -14869,6 +14869,120 @@ TEST(type_prop, split)
EXPECT_EQ(split->output(1).get_element_type(), element::i32);
}
TEST(type_prop, lstm_cell)
{
const size_t batch_size = 2;
const size_t input_size = 3;
const size_t hidden_size = 3;
const size_t gates_count = 4;
const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
const auto W =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
const auto R =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
EXPECT_EQ(lstm_cell->output(0).get_element_type(), element::f32);
EXPECT_EQ(lstm_cell->output(0).get_shape(), (Shape{batch_size, hidden_size}));
EXPECT_EQ(lstm_cell->output(1).get_element_type(), element::f32);
EXPECT_EQ(lstm_cell->output(1).get_shape(), (Shape{batch_size, hidden_size}));
}
TEST(type_prop, lstm_cell_invalid_input)
{
const size_t batch_size = 2;
const size_t input_size = 3;
const size_t hidden_size = 3;
const size_t gates_count = 4;
auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
auto R =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
// Invalid W tensor shape.
auto W = make_shared<op::Parameter>(element::f32, Shape{1 * hidden_size, input_size});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor W must have shape"));
}
// Invalid R tensor shape.
W = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
R = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, 1});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor R must have shape"));
}
// Invalid H_t tensor shape.
R = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
H_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor H_t must have shape"));
}
// Invalid C_t tensor shape.
H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
C_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor C_t must have shape"));
}
// Invalid B tensor shape.
C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto B = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size});
auto P = make_shared<op::Parameter>(element::f32, Shape{3 * hidden_size});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size, B, P);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor B must have shape"));
}
// Invalid P tensor shape.
B = make_shared<op::Parameter>(element::f32, Shape{2 * gates_count * hidden_size});
P = make_shared<op::Parameter>(element::f32, Shape{hidden_size});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size, B, P);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor P must have shape"));
}
}
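Taken together, these two type_prop tests pin down the shape contract the op validates, with everything derived from batch_size, input_size and hidden_size:

    X:   [batch_size, input_size]
    W:   [4 * hidden_size, input_size]
    R:   [4 * hidden_size, hidden_size]
    H_t: [batch_size, hidden_size]
    C_t: [batch_size, hidden_size]
    B:   [2 * 4 * hidden_size]   (optional)
    P:   [3 * hidden_size]       (optional)

Any deviation is rejected with a NodeValidationFailure whose message names the offending tensor.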
TEST(type_prop, fake_quantize)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
......