Commit 6b528fb8 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Scott Cyphers

[Fused Op] LSTMCell (#2966)

* Move split utility functions into core builder.

* Move activation functions to nGraph core.

* RNN cell base class.

* LSTM cell fused operator.

* Update LSTM ONNX operator to use LSTMCell fused op.

* Use Constant::create instead of make_constant.

* Remove ngraph:: prefixes and include standard headers.

* Store member shared_ptrs as object.

* Formatting.

* Run validation at the end of constructor.

* Add more doc to ActivationFunction.

* Run FusedOpDecomposition pass two times in interpreter backend.

* Remove unnecesary class member.

* Add node validation.

* Disambiguate constructors.

* Add type property test.

* Formatting and add comment with equations.

* Update IGPU backend with LSTMCell fused op.

* Fix: clip activation function input.

* Unit tests.

* Workaround for nested fused op: run FusedOpDecomposition twice.

* Fix compilation on CentOS and on GPU.

* PR feedback.

* Fix CentOS bugs.

* Address review comments.

Remove stored inputs as class members. Use node inputs directly in
decomposition.

* Fix errors.

* Review feedback: don't use decompose_op while generating Function in UTs.

* Fix merge artifacts.

* Move RNNCellBase to op/util directory.

* Fix typo for avg_pool setter method.

* Set default values for optional inputs.

* Fix typo in comment.
parent 746a90e2
......@@ -312,6 +312,8 @@ set (SRC
op/fused/group_conv_transpose.cpp
op/fused/leaky_relu.cpp
op/fused/leaky_relu.hpp
op/fused/lstm_cell.cpp
op/fused/lstm_cell.hpp
op/fused/mvn.cpp
op/fused/mvn.hpp
op/fused/normalize.cpp
......@@ -332,6 +334,8 @@ set (SRC
op/fused/squeeze.hpp
op/fused/unsqueeze.cpp
op/fused/unsqueeze.hpp
op/util/activation_functions.cpp
op/util/activation_functions.hpp
op/util/arithmetic_reduction.cpp
op/util/arithmetic_reduction.hpp
op/util/binary_elementwise_arithmetic.cpp
......@@ -349,6 +353,8 @@ set (SRC
op/util/logical_reduction.cpp
op/util/logical_reduction.hpp
op/util/reshape.hpp
op/util/rnn_cell_base.cpp
op/util/rnn_cell_base.hpp
op/util/unary_elementwise_arithmetic.cpp
op/util/unary_elementwise_arithmetic.hpp
partial_shape.cpp
......
......@@ -198,8 +198,6 @@ add_library(onnx_import STATIC
utils/reduction.hpp
utils/reshape.cpp
utils/reshape.hpp
utils/rnn/activation_functions.cpp
utils/rnn/activation_functions.hpp
utils/variadic.hpp)
set(ONNX_IMPORT_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR} CACHE INTERNAL "")
......
This diff is collapsed.
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cmath>
#include <functional>
#include <iterator>
#include <unordered_map>
#include "activation_functions.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/tanh.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace rnn
{
namespace detail
{
std::shared_ptr<ngraph::Node>
sigmoid(const std::shared_ptr<ngraph::Node>& arg, float alpha, float beta)
{
return std::make_shared<ngraph::op::Sigmoid>(arg);
}
std::shared_ptr<ngraph::Node>
tanh(const std::shared_ptr<ngraph::Node>& arg, float alpha, float beta)
{
return std::make_shared<ngraph::op::Tanh>(arg);
}
std::shared_ptr<ngraph::Node>
relu(const std::shared_ptr<ngraph::Node>& arg, float alpha, float beta)
{
return std::make_shared<ngraph::op::Relu>(arg);
}
std::shared_ptr<ngraph::Node>
hardsigmoid(const std::shared_ptr<ngraph::Node>& arg, float alpha, float beta)
{
return std::make_shared<ngraph::op::HardSigmoid>(arg, alpha, beta);
}
} // namespace detail
ActivationFunction::ActivationFunction(ActivationFunctionType f,
float alpha,
float beta)
: m_function{f}
, m_alpha{alpha}
, m_beta{beta}
{
}
ActivationFunction::ActivationFunction(ActivationFunctionType f, float alpha)
: ActivationFunction(f, alpha, std::nanf(""))
{
}
ActivationFunction::ActivationFunction(ActivationFunctionType f)
: ActivationFunction(f, std::nanf(""), std::nanf(""))
{
}
std::shared_ptr<ngraph::Node> ActivationFunction::
operator()(const std::shared_ptr<ngraph::Node>& arg) const
{
return m_function(arg, m_alpha, m_beta);
}
ActivationFunction get_activation_func_by_name(const std::string& func_name)
{
using ActivationFunctionMap = std::unordered_map<std::string, ActivationFunction>;
using namespace std::placeholders;
static ActivationFunctionMap func_map{
{"sigmoid", ActivationFunction{detail::sigmoid}},
{"tanh", ActivationFunction{detail::tanh}},
{"relu", ActivationFunction{detail::relu}},
{"hardsigmoid", ActivationFunction{detail::hardsigmoid, 0.2f, 0.5f}},
};
auto func_it = func_map.find(func_name);
if (func_it == std::end(func_map))
{
throw error::UnknownActivationFunction(func_name);
}
return func_it->second;
}
} //namespace rnn
} // namespace onnx_import
} // namespace ngraph
......@@ -106,6 +106,7 @@
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/leaky_relu.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize.hpp"
#include "ngraph/op/fused/prelu.hpp"
......
......@@ -191,7 +191,7 @@ bool op::AvgPool::get_include_padding_in_avg_computation() const
return m_include_padding_in_avg_computation;
}
void op::AvgPool::get_include_padding_in_avg_computation(bool include_padding_in_avg_computation)
void op::AvgPool::set_include_padding_in_avg_computation(bool include_padding_in_avg_computation)
{
m_include_padding_in_avg_computation = include_padding_in_avg_computation;
}
......
......@@ -148,7 +148,7 @@ namespace ngraph
const Shape& get_padding_above() const;
void set_padding_above(const Shape& padding_above);
bool get_include_padding_in_avg_computation() const;
void get_include_padding_in_avg_computation(bool include_padding_in_avg_computation);
void set_include_padding_in_avg_computation(bool include_padding_in_avg_computation);
/// \return The pad type for pooling.
const PadType& get_pad_type() const;
void set_pad_type(const PadType& pad_type);
......
This diff is collapsed.
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstddef>
#include <memory>
#include <string>
#include <vector>
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/util/activation_functions.hpp"
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/op/util/rnn_cell_base.hpp"
namespace ngraph
{
namespace op
{
///
/// \brief Class for lstm cell node.
///
/// \note It follows notation and equations defined as in ONNX standard:
/// https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM
///
/// Note this class represents only single *cell* and not whole LSTM *layer*.
///
class LSTMCell : public util::FusedOp, public util::RNNCellBase
{
public:
///
/// \brief Constructs LSTMCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape: [4*hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [4*hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with shape:
/// [batch_size, hidden_size].
/// \param[in] C_t The cell state tensor at current time step with shape:
/// [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
///
LSTMCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
const std::shared_ptr<Node>& C_t,
std::size_t hidden_size);
///
/// \brief Constructs LSTMCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape: [4*hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [4*hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] C_t The cell state tensor at current time step with shape:
/// [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activation_beta The vector of beta parameters for activation functions
/// in order respective to activation list.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
/// \param[in] input_forget Controls coupling input and forget gates.
///
LSTMCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
const std::shared_ptr<Node>& C_t,
std::size_t hidden_size,
const std::vector<std::string>& activations,
const std::vector<float>& activation_alpha,
const std::vector<float>& activation_beta,
float clip,
bool input_forget);
///
/// \brief Constructs LSTMCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape: [4*hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [4*hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] C_t The cell state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] B The bias tensor for input gate with shape: [8*hidden_size].
/// \param[in] P The weight tensor for peepholes with shape:
/// [3*hidden_size] - 3 equals to only iof gates.
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activation_beta The vector of beta parameters for activation functions
/// in order respective to activation list.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
/// \param[in] input_forget Controls coupling input and forget gates.
///
LSTMCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
const std::shared_ptr<Node>& C_t,
std::size_t hidden_size,
const std::shared_ptr<Node>& B,
const std::shared_ptr<Node>& P,
const std::vector<std::string>& activations =
std::vector<std::string>{"sigmoid", "tanh", "tanh"},
const std::vector<float>& activation_alpha = {},
const std::vector<float>& activation_beta = {},
float clip = 0.f,
bool input_forget = false);
virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool get_input_forget() const { return m_input_forget; }
private:
std::shared_ptr<Node> get_bias() const;
NodeVector get_peephole_weigths() const;
/// brief Add and initialize bias input to all zeros.
void add_default_bias_input();
/// brief Add and initialize peepholes weights input to all zeros.
void add_default_peepholes_input();
///
/// \brief The Activation function f.
///
util::ActivationFunction m_activation_f;
///
/// \brief The Activation function g.
///
util::ActivationFunction m_activation_g;
///
/// \brief The Activation function h.
///
util::ActivationFunction m_activation_h;
///
/// \brief Controls whether to couple input and forget gates.
///
bool m_input_forget = false;
static constexpr std::size_t m_gates_count{4};
static constexpr std::size_t m_peepholes_count{3};
};
}
}
......@@ -13,9 +13,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/fused/shuffle_channels.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/reshape.hpp"
using namespace std;
using namespace ngraph;
......@@ -77,10 +78,10 @@ NodeVector op::ShuffleChannels::decompose_op() const
const auto data = get_argument(0);
const auto& data_shape = data->get_shape();
const auto reshaped = util::reshape(data, get_pre_shuffle_shape(data_shape));
const auto shuffled = util::reorder_axes(reshaped, {0, 2, 1, 3});
const auto reshaped = builder::reshape(data, get_pre_shuffle_shape(data_shape));
const auto shuffled = builder::reorder_axes(reshaped, {0, 2, 1, 3});
return {util::reshape(shuffled, data_shape)};
return {builder::reshape(shuffled, data_shape)};
}
shared_ptr<Node> op::ShuffleChannels::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -30,6 +30,7 @@ NGRAPH_OP(GroupConvolution, ngraph::op)
NGRAPH_OP(GroupConvolutionTranspose, ngraph::op)
NGRAPH_OP(HardSigmoid, ngraph::op)
NGRAPH_OP(LeakyRelu, ngraph::op)
NGRAPH_OP(LSTMCell, ngraph::op)
NGRAPH_OP(MVN, ngraph::op)
NGRAPH_OP(Normalize, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op)
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cmath>
#include <functional>
#include <iterator>
#include <memory>
#include <unordered_map>
#include "activation_functions.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/tanh.hpp"
using namespace std;
using namespace ngraph;
static shared_ptr<Node> sigmoid(const shared_ptr<Node>& arg, float alpha, float beta)
{
return make_shared<op::Sigmoid>(arg);
}
static shared_ptr<Node> tanh(const shared_ptr<Node>& arg, float alpha, float beta)
{
return make_shared<op::Tanh>(arg);
}
static shared_ptr<Node> relu(const shared_ptr<Node>& arg, float alpha, float beta)
{
return make_shared<op::Relu>(arg);
}
static shared_ptr<Node> hardsigmoid(const shared_ptr<Node>& arg, float alpha, float beta)
{
return make_shared<op::HardSigmoid>(arg, alpha, beta);
}
op::util::ActivationFunction::ActivationFunction(ActivationFunctionType f, float alpha, float beta)
: m_function{f}
, m_alpha{alpha}
, m_beta{beta}
{
}
op::util::ActivationFunction::ActivationFunction(ActivationFunctionType f, float alpha)
: ActivationFunction(f, alpha, nanf(""))
{
}
op::util::ActivationFunction::ActivationFunction(ActivationFunctionType f)
: ActivationFunction(f, nanf(""), nanf(""))
{
}
shared_ptr<Node> op::util::ActivationFunction::operator()(const shared_ptr<Node>& arg) const
{
return m_function(arg, m_alpha, m_beta);
}
op::util::ActivationFunction op::util::get_activation_func_by_name(const string& func_name)
{
using ActivationFunctionMap = unordered_map<string, op::util::ActivationFunction>;
static ActivationFunctionMap func_map{
{"sigmoid", op::util::ActivationFunction{sigmoid}},
{"tanh", op::util::ActivationFunction{tanh}},
{"relu", op::util::ActivationFunction{relu}},
{"hardsigmoid", op::util::ActivationFunction{hardsigmoid, 0.2f, 0.5f}},
};
auto func_it = func_map.find(func_name);
if (func_it == end(func_map))
{
throw op::util::error::UnknownActivationFunction(func_name);
}
return func_it->second;
}
......@@ -31,19 +31,19 @@
// Prevents the compiler from complaining about or optimizing away variables
// that appear unused on Linux
#if (defined(__GNUC__) && !defined(__clang__))
#undef ONNX_ATTRIBUTE_UNUSED
#define ONNX_ATTRIBUTE_UNUSED __attribute__((__unused__))
#undef NG_ATTRIBUTE_UNUSED
#define NG_ATTRIBUTE_UNUSED __attribute__((__unused__))
#else
#define ONNX_ATTRIBUTE_UNUSED
#define NG_ATTRIBUTE_UNUSED
#endif
#define UNUSED_PARAMETER ONNX_ATTRIBUTE_UNUSED = 0
#define UNUSED_PARAMETER NG_ATTRIBUTE_UNUSED = 0
namespace ngraph
{
namespace onnx_import
namespace op
{
namespace rnn
namespace util
{
namespace error
{
......@@ -58,22 +58,26 @@ namespace ngraph
namespace detail
{
std::shared_ptr<ngraph::Node> sigmoid(const std::shared_ptr<ngraph::Node>& arg,
std::shared_ptr<Node> sigmoid(const std::shared_ptr<Node>& arg,
float alpha UNUSED_PARAMETER,
float beta UNUSED_PARAMETER);
std::shared_ptr<ngraph::Node> tanh(const std::shared_ptr<ngraph::Node>& arg,
std::shared_ptr<Node> tanh(const std::shared_ptr<Node>& arg,
float alpha UNUSED_PARAMETER,
float beta UNUSED_PARAMETER);
std::shared_ptr<ngraph::Node> relu(const std::shared_ptr<ngraph::Node>& arg,
std::shared_ptr<Node> relu(const std::shared_ptr<Node>& arg,
float alpha UNUSED_PARAMETER,
float beta UNUSED_PARAMETER);
std::shared_ptr<ngraph::Node>
hardsigmoid(const std::shared_ptr<ngraph::Node>& arg, float alpha, float beta);
std::shared_ptr<Node>
hardsigmoid(const std::shared_ptr<Node>& arg, float alpha, float beta);
}
using ActivationFunctionType = std::shared_ptr<ngraph::Node> (*)(
const std::shared_ptr<ngraph::Node>&, float, float);
using ActivationFunctionType = std::shared_ptr<Node> (*)(const std::shared_ptr<Node>&,
float,
float);
///
/// \brief Class representing activation function used in RNN cells.
///
class ActivationFunction
{
public:
......@@ -81,14 +85,19 @@ namespace ngraph
ActivationFunction(ActivationFunctionType f, float alpha);
ActivationFunction(ActivationFunctionType f);
std::shared_ptr<ngraph::Node>
operator()(const std::shared_ptr<ngraph::Node>& arg) const;
///
/// \brief Calls stored activation function with provided node argument.
///
std::shared_ptr<Node> operator()(const std::shared_ptr<Node>& arg) const;
void set_alpha(float alpha) { m_alpha = alpha; }
void set_beta(float beta) { m_beta = beta; }
private:
/// \brief Activation function wrapper.
ActivationFunctionType m_function;
/// \brief Activation function alpha parameter (may be unused).
float m_alpha;
/// \brief Activation function beta parameter (may be unused).
float m_beta;
};
......@@ -101,10 +110,9 @@ namespace ngraph
/// \return The activation function object.
///
ActivationFunction get_activation_func_by_name(const std::string& func_name);
} // namespace util
} //namespace rnn
} // namespace onnx_import
} // namespace op
} // namespace ngraph
......@@ -115,6 +123,6 @@ namespace ngraph
#ifdef UNUSED_PARAMETER
#undef UNUSED_PARAMETER
#endif
#ifdef ONNX_ATTRIBUTE_UNUSED
#undef ONNX_ATTRIBUTE_UNUSED
#ifdef NG_ATTRIBUTE_UNUSED
#undef NG_ATTRIBUTE_UNUSED
#endif
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <iterator>
#include "ngraph/op/util/rnn_cell_base.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
// Modify input vector in-place and return reference to modified vector.
static vector<string> to_lower_case(const vector<string>& vs)
{
vector<string> res(vs);
transform(begin(res), end(res), begin(res), [](string& s) { return to_lower(s); });
return res;
}
op::util::RNNCellBase::RNNCellBase(size_t hidden_size,
float clip,
const vector<string>& activations,
const vector<float>& activation_alpha,
const vector<float>& activation_beta)
: m_hidden_size(hidden_size)
, m_clip(clip)
, m_activations(to_lower_case(activations))
, m_activation_alpha(activation_alpha)
, m_activation_beta(activation_beta)
{
}
op::util::ActivationFunction op::util::RNNCellBase::get_activation_function(size_t idx) const
{
op::util::ActivationFunction afunc = get_activation_func_by_name(m_activations.at(idx));
// Set activation functions parameters (if any)
if (m_activation_alpha.size() > idx)
{
afunc.set_alpha(m_activation_alpha.at(idx));
}
if (m_activation_beta.size() > idx)
{
afunc.set_beta(m_activation_beta.at(idx));
}
return afunc;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstddef>
#include <string>
#include <vector>
#include "ngraph/op/util/activation_functions.hpp"
namespace ngraph
{
namespace op
{
namespace util
{
/// \brief Base class for all recurrent network cells.
///
/// \note It holds all common attributes.
///
class RNNCellBase
{
public:
///
/// \brief Constructs a RNNCellBase class.
///
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions in order respective to activation list.
/// \param[in] activation_beta The vector of beta parameters for activation
/// functions in order respective to activation list.
///
RNNCellBase(std::size_t hidden_size,
float clip,
const std::vector<std::string>& activations,
const std::vector<float>& activation_alpha,
const std::vector<float>& activation_beta);
std::size_t get_hidden_size() const { return m_hidden_size; }
float get_clip() const { return m_clip; }
const std::vector<std::string>& get_activations() const { return m_activations; }
const std::vector<float>& get_activation_alpha() const
{
return m_activation_alpha;
}
const std::vector<float>& get_activation_beta() const { return m_activation_beta; }
protected:
///
/// \brief Constructs activation function object.
///
/// \param[in] idx The index of the activation function name.
///
/// \return The object representing activation function.
///
ActivationFunction get_activation_function(std::size_t idx) const;
private:
std::size_t m_hidden_size = 0.f;
float m_clip = 0.f;
const std::vector<std::string> m_activations;
const std::vector<float> m_activation_alpha;
const std::vector<float> m_activation_beta;
};
}
}
}
......@@ -173,6 +173,9 @@ void runtime::gpu::GPUCompiledFunction::compile()
pass_manager.register_pass<runtime::gpu::pass::BatchNormCache>();
pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>();
// Run this pass for the second time since, some fused operators like LSTMCell may use
// other fused operators inside.
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>();
pass_manager.register_pass<ngraph::pass::ImplicitBroadcastElimination>();
pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this);
pass_manager.register_pass<ngraph::pass::AssignLayout<descriptor::layout::DenseTensorLayout>>();
......
......@@ -89,6 +89,7 @@
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/leaky_relu.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
......@@ -427,6 +428,10 @@ shared_ptr<runtime::Executable>
if (m_disable_backend_optimizations < 2)
{
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>(
IntelGPUBackend::is_supported_impl);
// Run this pass for the second time since, some fused operators like LSTMCell may use
// other fused operators inside.
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>(
IntelGPUBackend::is_supported_impl);
pass_manager.register_pass<ngraph::pass::ImplicitBroadcastElimination>();
......@@ -2067,6 +2072,7 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::GroupConvolutionTranspose:
case OP_TYPEID::HardSigmoid:
case OP_TYPEID::LeakyRelu:
case OP_TYPEID::LSTMCell:
case OP_TYPEID::MVN:
case OP_TYPEID::Normalize:
case OP_TYPEID::PRelu:
......@@ -2187,6 +2193,7 @@ bool runtime::intelgpu::IntelGPUBackend::is_supported_impl(const Node& node)
case OP_TYPEID::GRN:
case OP_TYPEID::GroupConvolutionTranspose:
case OP_TYPEID::LeakyRelu:
case OP_TYPEID::LSTMCell:
case OP_TYPEID::MVN:
case OP_TYPEID::Normalize:
case OP_TYPEID::PRelu:
......
......@@ -47,6 +47,9 @@ runtime::interpreter::INTExecutable::INTExecutable(const shared_ptr<Function>& f
pass::Manager pass_manager;
pass_manager.register_pass<pass::LikeReplacement>();
pass_manager.register_pass<pass::FusedOpDecomposition>();
// Run this pass for the second time since, some fused operators like LSTMCell may use
// other fused operators inside.
pass_manager.register_pass<pass::FusedOpDecomposition>();
pass_manager.register_pass<pass::ImplicitBroadcastElimination>();
pass_manager.register_pass<pass::AssignLayout<DenseTensorLayout>>();
pass_manager.register_pass<pass::Liveness>();
......
......@@ -77,6 +77,7 @@
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/leaky_relu.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize.hpp"
#include "ngraph/op/fused/prelu.hpp"
......@@ -1126,6 +1127,29 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::LRN>(args[0], alpha, beta, bias, nsize);
break;
}
case OP_TYPEID::LSTMCell:
{
auto hidden_size = node_js.at("hidden_size").get<size_t>();
auto clip = node_js.at("clip").get<float>();
auto activations = node_js.at("activations").get<vector<string>>();
auto activation_alpha = node_js.at("activation_alpha").get<vector<float>>();
auto activation_beta = node_js.at("activation_beta").get<vector<float>>();
auto input_forget = node_js.at("input_forget").get<bool>();
node = make_shared<op::LSTMCell>(args[0],
args[1],
args[2],
args[3],
args[4],
hidden_size,
args[5],
args[6],
activations,
activation_alpha,
activation_beta,
clip,
input_forget);
break;
}
case OP_TYPEID::Max:
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
......@@ -2152,6 +2176,17 @@ static json write(const Node& n, bool binary_constant_data)
node["nsize"] = tmp->get_nsize();
break;
}
case OP_TYPEID::LSTMCell:
{
auto tmp = dynamic_cast<const op::LSTMCell*>(&n);
node["hidden_size"] = tmp->get_hidden_size();
node["clip"] = tmp->get_clip();
node["activations"] = tmp->get_activations();
node["activation_alpha"] = tmp->get_activation_alpha();
node["activation_beta"] = tmp->get_activation_beta();
node["input_forget"] = tmp->get_input_forget();
break;
}
case OP_TYPEID::Max:
{
auto tmp = dynamic_cast<const op::Max*>(&n);
......
This diff is collapsed.
......@@ -14869,6 +14869,120 @@ TEST(type_prop, split)
EXPECT_EQ(split->output(1).get_element_type(), element::i32);
}
TEST(type_prop, lstm_cell)
{
const size_t batch_size = 2;
const size_t input_size = 3;
const size_t hidden_size = 3;
const size_t gates_count = 4;
const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
const auto W =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
const auto R =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
EXPECT_EQ(lstm_cell->output(0).get_element_type(), element::f32);
EXPECT_EQ(lstm_cell->output(0).get_shape(), (Shape{batch_size, hidden_size}));
EXPECT_EQ(lstm_cell->output(1).get_element_type(), element::f32);
EXPECT_EQ(lstm_cell->output(1).get_shape(), (Shape{batch_size, hidden_size}));
}
TEST(type_prop, lstm_cell_invalid_input)
{
const size_t batch_size = 2;
const size_t input_size = 3;
const size_t hidden_size = 3;
const size_t gates_count = 4;
auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
auto R =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
// Invalid W tensor shape.
auto W = make_shared<op::Parameter>(element::f32, Shape{1 * hidden_size, input_size});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor W must have shape"));
}
// Invalid R tensor shape.
W = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
R = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, 1});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor R must have shape"));
}
// Invalid H_t tensor shape.
R = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
H_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor H_t must have shape"));
}
// Invalid C_t tensor shape.
H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
C_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor C_t must have shape"));
}
// Invalid B tensor shape.
C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto B = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size});
auto P = make_shared<op::Parameter>(element::f32, Shape{3 * hidden_size});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size, B, P);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor B must have shape"));
}
// Invalid P tensor shape.
B = make_shared<op::Parameter>(element::f32, Shape{2 * gates_count * hidden_size});
P = make_shared<op::Parameter>(element::f32, Shape{hidden_size});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size, B, P);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input tensor P must have shape"));
}
}
TEST(type_prop, fake_quantize)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment