Commit da4e9a0e authored by Adam Rogowiec

GRUCell operator.

parent 3e12cefa
......@@ -302,6 +302,8 @@ set (SRC
op/fused/grn.hpp
op/fused/group_conv.hpp
op/fused/group_conv.cpp
op/fused/gru_cell.cpp
op/fused/gru_cell.hpp
op/fused/lstm_cell.cpp
op/fused/lstm_cell.hpp
op/fused/mvn.cpp
......
......@@ -102,6 +102,7 @@
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/mvn.hpp"
......
(This diff is collapsed and not shown; presumably the new op/fused/gru_cell.cpp implementation.)
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstddef>
#include <memory>
#include <string>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/fused/rnn_cell_base.hpp"
#include "ngraph/op/util/activation_functions.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
namespace op
{
///
/// \brief Class for GRU cell node.
///
/// \note It follows the notation and equations defined in the ONNX standard:
/// https://github.com/onnx/onnx/blob/master/docs/Operators.md#GRU
///
/// Note that this class represents only a single *cell*, not a whole GRU *layer*.
///
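///
/// For reference, the equations as given in the linked ONNX specification
/// (f and g are the activation functions, by default sigmoid and tanh;
/// (.) denotes element-wise multiplication):
///
/// zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
/// rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
/// ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
/// (default, when linear_before_reset = false)
/// ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
/// (when linear_before_reset = true)
/// Ht = (1 - zt) (.) ht + zt (.) Ht-1
///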
class GRUCell : public util::FusedOp, public RNNCellBase
{
public:
///
/// \brief Constructs GRUCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape:
/// [gates_count * hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [gates_count * hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
///
GRUCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
std::size_t hidden_size);
///
/// \brief Constructs GRUCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape:
/// [gates_count * hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [gates_count * hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] activations The vector of activation functions used inside
/// the recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions, in the order matching the activation list.
/// \param[in] activation_beta The vector of beta parameters for activation functions,
/// in the order matching the activation list.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
/// \param[in] linear_before_reset Whether or not to apply the linear transformation
/// before multiplying by the output of the reset gate.
///
GRUCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
std::size_t hidden_size,
const std::vector<std::string>& activations,
const std::vector<float>& activation_alpha,
const std::vector<float>& activation_beta,
float clip,
bool linear_before_reset);
///
/// \brief Constructs GRUCell node.
///
/// \param[in] X The input tensor with shape: [batch_size, input_size].
/// \param[in] W The weight tensor with shape:
/// [gates_count * hidden_size, input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [gates_count * hidden_size, hidden_size].
/// \param[in] H_t The hidden state tensor at current time step with
/// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] B The bias tensor for the gates with shape:
/// [2 * gates_count * hidden_size]. Concatenation of `[Wb[zrh], Rb[zrh]]`.
/// \param[in] activations The vector of activation functions used inside
/// the recurrent cell.
/// \param[in] activation_alpha The vector of alpha parameters for activation
/// functions, in the order matching the activation list.
/// \param[in] activation_beta The vector of beta parameters for activation functions,
/// in the order matching the activation list.
/// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions.
/// \param[in] linear_before_reset Whether or not to apply the linear transformation
/// before multiplying by the output of the reset gate.
///
GRUCell(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& W,
const std::shared_ptr<Node>& R,
const std::shared_ptr<Node>& H_t,
std::size_t hidden_size,
const std::shared_ptr<Node>& B,
const std::vector<std::string>& activations =
std::vector<std::string>{"sigmoid", "tanh"},
const std::vector<float>& activation_alpha = {},
const std::vector<float>& activation_beta = {},
float clip = 0.f,
bool linear_before_reset = false);
virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool get_linear_before_reset() const { return m_linear_before_reset; }
private:
///
/// \brief The input data tensor. Shape: [batch_size, input_size].
///
std::shared_ptr<Node> m_X;
///
/// \brief The weight tensor. Shape: [gates_count * hidden_size, input_size].
///
std::shared_ptr<Node> m_W;
///
/// \brief The recurrence weight tensor. Shape: [gates_count * hidden_size, hidden_size].
///
std::shared_ptr<Node> m_R;
///
/// \brief The hidden state tensor at current time step. Shape: [batch_size, hidden_size].
///
std::shared_ptr<Node> m_H_t;
///
/// \brief The bias tensor for the gates. Shape: [2 * gates_count * hidden_size].
/// \note Concatenation of `[Wb[zrh], Rb[zrh]]`.
///
std::shared_ptr<Node> m_B;
///
/// \brief The Activation function f.
///
ActivationFunction m_activation_f;
///
/// \brief The Activation function g.
///
ActivationFunction m_activation_g;
///
/// \brief The number of gates (update, reset, hidden) in the GRU cell.
///
static constexpr std::size_t m_gates_count{3};
///
/// \brief Controls whether or not to apply the linear transformation.
///
/// \note The linear transformation may be applied when computing the output of the
/// hidden gate. It is done before multiplying by the output of the reset gate.
///
bool m_linear_before_reset;
};
}
}
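For illustration, a minimal sketch of constructing a GRUCell node against this header (example dimensions and the use of op::Parameter inputs are assumptions, not part of this commit):

#include "ngraph/ngraph.hpp"
using namespace ngraph;

// Hypothetical dimensions for the example.
const std::size_t batch_size = 2;
const std::size_t input_size = 3;
const std::size_t hidden_size = 4;
const std::size_t gates_count = 3; // update (z), reset (r), hidden (h)

// Inputs with the shapes documented on the constructors.
auto X = std::make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
auto W = std::make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
auto R = std::make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
auto H_t = std::make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});

// Simplest form: defaults to {"sigmoid", "tanh"} activations,
// clip = 0.f and linear_before_reset = false.
auto gru = std::make_shared<op::GRUCell>(X, W, R, H_t, hidden_size);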
......@@ -26,6 +26,7 @@ NGRAPH_OP(Elu, ngraph::op)
NGRAPH_OP(GRN, ngraph::op)
NGRAPH_OP(Gemm, ngraph::op)
NGRAPH_OP(GroupConvolution, ngraph::op)
NGRAPH_OP(GRUCell, ngraph::op)
NGRAPH_OP(HardSigmoid, ngraph::op)
NGRAPH_OP(LSTMCell, ngraph::op)
NGRAPH_OP(MVN, ngraph::op)
......
......@@ -84,6 +84,7 @@
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/mvn.hpp"
......@@ -2061,6 +2062,7 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::GatherND:
case OP_TYPEID::GenerateMask:
case OP_TYPEID::GRN:
case OP_TYPEID::GRUCell:
case OP_TYPEID::HardSigmoid:
case OP_TYPEID::LSTMCell:
case OP_TYPEID::MVN:
......@@ -2180,6 +2182,7 @@ bool runtime::intelgpu::IntelGPUBackend::is_supported_impl(const Node& node)
case OP_TYPEID::Elu:
case OP_TYPEID::Gemm:
case OP_TYPEID::GRN:
case OP_TYPEID::GRUCell:
case OP_TYPEID::LSTMCell:
case OP_TYPEID::MVN:
case OP_TYPEID::Normalize:
......
......@@ -73,6 +73,7 @@
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/grn.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/mvn.hpp"
......@@ -999,6 +1000,27 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::GRN>(args[0], bias);
break;
}
case OP_TYPEID::GRUCell:
{
auto hidden_size = node_js.at("hidden_size").get<size_t>();
auto clip = node_js.at("clip").get<float>();
auto activations = node_js.at("activations").get<vector<string>>();
auto activation_alpha = node_js.at("activation_alpha").get<vector<float>>();
auto activation_beta = node_js.at("activation_beta").get<vector<float>>();
auto linear_before_reset = node_js.at("linear_before_reset").get<bool>();
node = make_shared<op::GRUCell>(args[0],
args[1],
args[2],
args[3],
hidden_size,
args[4],
activations,
activation_alpha,
activation_beta,
clip,
linear_before_reset);
break;
}
case OP_TYPEID::HardSigmoid:
{
auto alpha = node_js.at("alpha").get<float>();
......@@ -1989,6 +2011,17 @@ static json write(const Node& n, bool binary_constant_data)
node["bias"] = tmp->get_bias();
break;
}
case OP_TYPEID::GRUCell:
{
auto tmp = dynamic_cast<const op::GRUCell*>(&n);
node["hidden_size"] = tmp->get_hidden_size();
node["clip"] = tmp->get_clip();
node["activations"] = tmp->get_activations();
node["activation_alpha"] = tmp->get_activation_alpha();
node["activation_beta"] = tmp->get_activation_beta();
node["linear_before_reset"] = tmp->get_linear_before_reset();
break;
}
case OP_TYPEID::HardSigmoid:
{
auto tmp = dynamic_cast<const op::HardSigmoid*>(&n);
......
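For reference, the GRUCell read and write cases above round-trip the cell's attributes; the JSON fragment carried by a serialized GRUCell node would look roughly like this (values illustrative, surrounding node fields omitted):

"hidden_size": 4,
"clip": 0.0,
"activations": ["sigmoid", "tanh"],
"activation_alpha": [],
"activation_beta": [],
"linear_before_reset": false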