Commit 5b459fd9 authored by Adam Rogowiec's avatar Adam Rogowiec

Store original Bias input as RNNCell member.

parent 835dd295
......@@ -36,15 +36,8 @@ op::RNNCell::RNNCell(const shared_ptr<Node>& X,
const shared_ptr<Node>& R,
const shared_ptr<Node>& H_t,
size_t hidden_size)
: RNNCell(X,
W,
R,
H_t,
hidden_size,
vector<string>{"tanh"},
vector<float>{},
vector<float>{},
0.f)
: RNNCell(
X, W, R, H_t, hidden_size, vector<string>{"tanh"}, vector<float>{}, vector<float>{}, 0.f)
{
}
......@@ -65,11 +58,10 @@ op::RNNCell::RNNCell(const shared_ptr<Node>& X,
, m_H_t{H_t}
, m_activation_f{get_activation_function(0)}
{
// Normally we would split B onto Wb an Rb and add them, however here they are all zeros,
// thus just initialize bias with appropriate shape and zeros.
m_bias = ngraph::op::Constant::create(element::f32,
Shape{m_gates_count * get_hidden_size()},
vector<float>(m_gates_count * get_hidden_size(), 0.f));
// Since the default bias is all zeros, just initialize it with the appropriate shape and zeros.
m_B = op::Constant::create(m_X->get_element_type(),
Shape{2 * get_hidden_size()},
vector<float>(2 * get_hidden_size(), 0.f));
constructor_validate_and_infer_types();
}
......@@ -90,33 +82,32 @@ op::RNNCell::RNNCell(const shared_ptr<Node>& X,
, m_W{W}
, m_R{R}
, m_H_t{H_t}
, m_B{B}
, m_activation_f{get_activation_function(0)}
{
// Split B into Wb and Rb and add them.
NODE_VALIDATION_CHECK(this,
(B->get_shape() == Shape{2 * m_gates_count * get_hidden_size()}),
"Input tensor B must have shape (",
8 * get_hidden_size(),
"). Actual shape is:",
B->get_shape(),
".");
NodeVector b_W_R = builder::split(B, 2);
m_bias = b_W_R.at(0) + b_W_R.at(1);
constructor_validate_and_infer_types();
}
void op::RNNCell::pre_validate_and_infer_types()
{
const auto& x_shape = input(0).get_shape();
const auto& x_pshape = get_input_partial_shape(0);
const auto& w_pshape = get_input_partial_shape(1);
const auto& r_pshape = get_input_partial_shape(2);
const auto& ht_pshape = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
(x_pshape.is_static() || w_pshape.is_static() || r_pshape.is_static() ||
ht_pshape.is_static()),
"GRUCell supports only static input tensors.");
const Shape& x_shape{x_pshape.to_shape()};
const size_t batch_size = x_shape.at(0);
const size_t input_size = x_shape.at(1);
const auto& w_shape = input(1).get_shape();
const auto& r_shape = input(2).get_shape();
const auto& ht_shape = input(3).get_shape();
const Shape& w_shape{w_pshape.to_shape()};
const Shape& r_shape{r_pshape.to_shape()};
const Shape& ht_shape{ht_pshape.to_shape()};
NODE_VALIDATION_CHECK(this,
(w_shape == Shape{get_hidden_size(), input_size}),
......@@ -145,6 +136,24 @@ void op::RNNCell::pre_validate_and_infer_types()
"). Actual shape is:",
w_shape,
".");
if (get_input_size() > 4)
{
const auto& b_pshape = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(
this, b_pshape.is_static(), "GRUCell supports only static input tensors.");
const Shape& b_shape{b_pshape.to_shape()};
NODE_VALIDATION_CHECK(this,
(b_shape == Shape{2 * get_hidden_size()}),
"Input tensor B must have shape (",
2 * get_hidden_size(),
"). Actual shape is:",
b_shape,
".");
}
}
NodeVector op::RNNCell::decompose_op() const
......@@ -155,15 +164,12 @@ NodeVector op::RNNCell::decompose_op() const
// ------ ACRONYMS ------
// i_t - input gate at current time step
// t - time step (t-1 means previous time step)
// W - W parameter weight matrix for input, output, forget, and
// cell gates.
// R - R recurrence weight matrix for input, output, forget, and
// cell gates.
// Wb - W bias vectors for input, output, forget, and cell gates.
// Rb - R bias vectors for input, output, forget, and cell gates.
// W - W parameter weight matrix for input gate.
// R - R recurrence weight matrix for input gate.
// Wb - W bias vectors for input gate.
// Rb - R bias vectors for input gate.
// ------ VARIABLE NAMES ------
// Xt_W - Input sequence multiplied by weights tensor at current time
// step.
// Xt_W - Input sequence multiplied by weights tensor at current time step.
// Ht_R - Hidden state multiplied by weights tensor at current time step.
// (.) - Denotes element-wise multiplication.
......@@ -174,12 +180,15 @@ NodeVector op::RNNCell::decompose_op() const
// Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
// --------------------
NodeVector b_W_R = builder::split(m_B, 2);
auto bias = b_W_R.at(0) + b_W_R.at(1);
// Xt*(W^T)
auto Xt_W = std::make_shared<ngraph::op::Dot>(m_X, ngraph::op::util::transpose(m_W));
// Ht-1*(R^T)
auto Ht_R = std::make_shared<ngraph::op::Dot>(m_H_t, ngraph::op::util::transpose(m_R));
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb
auto i_t = add(Xt_W, add(Ht_R, m_bias));
auto i_t = add(Xt_W, add(Ht_R, bias));
// f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
i_t = m_activation_f(clip(i_t));
......
......@@ -113,8 +113,7 @@ namespace ngraph
const std::shared_ptr<Node>& H_t,
std::size_t hidden_size,
const std::shared_ptr<Node>& B,
const std::vector<std::string>& activations =
std::vector<std::string>{"tanh"},
const std::vector<std::string>& activations = std::vector<std::string>{"tanh"},
const std::vector<float>& activation_alpha = {},
const std::vector<float>& activation_beta = {},
float clip = 0.f);
......@@ -142,17 +141,16 @@ namespace ngraph
///
std::shared_ptr<Node> m_H_t;
///
/// \brief The bias tensor for the gates. Shape: [2 * gates_count * hidden_size].
/// \note Concatenation of `[Wb[i], Rb[i]]`.
///
std::shared_ptr<Node> m_B;
///
/// \brief The Activation function f.
///
ActivationFunction m_activation_f;
static constexpr std::size_t m_gates_count{1};
///
/// \brief Sum of biases (weight and recurrence) for input gate.
///
/// Sum of `[Wb, Rb]`.
///
std::shared_ptr<Node> m_bias;
};
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment