......@@ -36,15 +36,8 @@ op::RNNCell::RNNCell(const shared_ptr<Node>& X,
const shared_ptr<Node>& R,
const shared_ptr<Node>& H_t,
size_t hidden_size)
: RNNCell(X,
: RNNCell(
X, W, R, H_t, hidden_size, vector<string>{"tanh"}, vector<float>{}, vector<float>{}, 0.f)
......@@ -65,11 +58,10 @@ op::RNNCell::RNNCell(const shared_ptr<Node>& X,
, m_H_t{H_t}
, m_activation_f{get_activation_function(0)}
// Normally we would split B into Wb and Rb and add them; however, here they are all zeros,
// so just initialize the bias with the appropriate shape and zeros.
m_bias = ngraph::op::Constant::create(element::f32,
Shape{m_gates_count * get_hidden_size()},
vector<float>(m_gates_count * get_hidden_size(), 0.f));
// The default bias is all zeros, so just initialize it with the appropriate shape and zeros.
m_B = op::Constant::create(m_X->get_element_type(),
Shape{2 * get_hidden_size()},
vector<float>(2 * get_hidden_size(), 0.f));
......@@ -90,33 +82,32 @@ op::RNNCell::RNNCell(const shared_ptr<Node>& X,
, m_W{W}
, m_R{R}
, m_H_t{H_t}
, m_B{B}
, m_activation_f{get_activation_function(0)}
// Split B into Wb and Rb and add them.
(B->get_shape() == Shape{2 * m_gates_count * get_hidden_size()}),
"Input tensor B must have shape (",
8 * get_hidden_size(),
"). Actual shape is:",
NodeVector b_W_R = builder::split(B, 2);
m_bias = +;
void op::RNNCell::pre_validate_and_infer_types()
const auto& x_shape = input(0).get_shape();
const auto& x_pshape = get_input_partial_shape(0);
const auto& w_pshape = get_input_partial_shape(1);
const auto& r_pshape = get_input_partial_shape(2);
const auto& ht_pshape = get_input_partial_shape(3);
(x_pshape.is_static() || w_pshape.is_static() || r_pshape.is_static() ||
"GRUCell supports only static input tensors.");
const Shape& x_shape{x_pshape.to_shape()};
const size_t batch_size =;
const size_t input_size =;
const auto& w_shape = input(1).get_shape();
const auto& r_shape = input(2).get_shape();
const auto& ht_shape = input(3).get_shape();
const Shape& w_shape{w_pshape.to_shape()};
const Shape& r_shape{r_pshape.to_shape()};
const Shape& ht_shape{ht_pshape.to_shape()};
(w_shape == Shape{get_hidden_size(), input_size}),
......@@ -145,6 +136,24 @@ void op::RNNCell::pre_validate_and_infer_types()
"). Actual shape is:",
if (get_input_size() > 4)
const auto& b_pshape = get_input_partial_shape(4);
this, b_pshape.is_static(), "GRUCell supports only static input tensors.");
const Shape& b_shape{b_pshape.to_shape()};
(b_shape == Shape{2 * get_hidden_size()}),
"Input tensor B must have shape (",
2 * get_hidden_size(),
"). Actual shape is:",
NodeVector op::RNNCell::decompose_op() const
......@@ -155,15 +164,12 @@ NodeVector op::RNNCell::decompose_op() const
// ------ ACRONYMS ------
// i_t - input gate at current time step
// t - time step (t-1 means previous time step)
// W - W parameter weight matrix for input, output, forget, and
// cell gates.
// R - R recurrence weight matrix for input, output, forget, and
// cell gates.
// Wb - W bias vectors for input, output, forget, and cell gates.
// Rb - R bias vectors for input, output, forget, and cell gates.
// W - W parameter weight matrix for input gate.
// R - R recurrence weight matrix for input gate.
// Wb - W bias vectors for input gate.
// Rb - R bias vectors for input gate.
// ------ VARIABLE NAMES ------
// Xt_W - Input sequence multiplied by weights tensor at current time
// step.
// Xt_W - Input sequence multiplied by weights tensor at current time step.
// Ht_R - Hidden state multiplied by weights tensor at current time step.
// (.) - Denotes element-wise multiplication.
......@@ -174,12 +180,15 @@ NodeVector op::RNNCell::decompose_op() const
// Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
// --------------------
NodeVector b_W_R = builder::split(m_B, 2);
auto bias = +;
// Xt*(W^T)
auto Xt_W = std::make_shared<ngraph::op::Dot>(m_X, ngraph::op::util::transpose(m_W));
// Ht-1*(R^T)
auto Ht_R = std::make_shared<ngraph::op::Dot>(m_H_t, ngraph::op::util::transpose(m_R));
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb
auto i_t = add(Xt_W, add(Ht_R, m_bias));
auto i_t = add(Xt_W, add(Ht_R, bias));
// f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
i_t = m_activation_f(clip(i_t));
......@@ -113,8 +113,7 @@ namespace ngraph
const std::shared_ptr<Node>& H_t,
std::size_t hidden_size,
const std::shared_ptr<Node>& B,
const std::vector<std::string>& activations =
const std::vector<std::string>& activations = std::vector<std::string>{"tanh"},
const std::vector<float>& activation_alpha = {},
const std::vector<float>& activation_beta = {},
float clip = 0.f);
......@@ -142,17 +141,16 @@ namespace ngraph
std::shared_ptr<Node> m_H_t;
/// \brief The bias tensor for the gates. Shape: [2 * gates_count * hidden_size].
/// \note Concatenation of `[Wb[i], Rb[i]]`.
std::shared_ptr<Node> m_B;
/// \brief The Activation function f.
ActivationFunction m_activation_f;
static constexpr std::size_t m_gates_count{1};
/// \brief Sum of biases (weight and recurrence) for input gate.
/// Sum of `[Wb, Rb]`.
std::shared_ptr<Node> m_bias;
