Commit 835dd295 authored by Adam Rogowiec

Store biases and peephole weights as LSTMCell members.

parent 93255d43
......@@ -74,17 +74,15 @@ op::LSTMCell::LSTMCell(const shared_ptr<Node>& X,
, m_activation_h{get_activation_function(2)}
, m_input_forget{input_forget}
{
// Normally we would split B into Wb and Rb and add them; however, here they are all zeros,
// so we just initialize the bias with the appropriate shape and zeros.
m_bias = ngraph::op::Constant::create(element::f32,
Shape{m_gates_count * get_hidden_size()},
vector<float>(m_gates_count * get_hidden_size(), 0.f));
// The default bias is all zeros, so just initialize it with the appropriate shape and zeros.
m_B = op::Constant::create(m_X->get_element_type(),
Shape{2 * m_gates_count * get_hidden_size()},
vector<float>(2 * m_gates_count * get_hidden_size(), 0.f));
m_P = op::Constant::create(m_X->get_element_type(),
Shape{m_peepholes_count * get_hidden_size()},
vector<float>(m_peepholes_count * get_hidden_size(), 0.f));
const auto& peephole_weights =
ngraph::op::Constant::create(element::f32,
Shape{m_peepholes_count * get_hidden_size()},
vector<float>(m_peepholes_count * get_hidden_size(), 0.f));
m_p_iof = builder::split(peephole_weights, m_peepholes_count);
constructor_validate_and_infer_types();
}
......@@ -108,60 +106,52 @@ op::LSTMCell::LSTMCell(const shared_ptr<Node>& X,
, m_R{R}
, m_H_t{H_t}
, m_C_t{C_t}
, m_B{B}
, m_P{P}
, m_activation_f{get_activation_function(0)}
, m_activation_g{get_activation_function(1)}
, m_activation_h{get_activation_function(2)}
, m_input_forget{input_forget}
{
// Split B into Wb and Rb and add them.
NODE_VALIDATION_CHECK(this,
(B->get_shape() == Shape{2 * m_gates_count * get_hidden_size()}),
"Input tensor B must have shape (",
8 * get_hidden_size(),
"). Actual shape is:",
B->get_shape(),
".");
NodeVector b_W_R = builder::split(B, 2);
m_bias = b_W_R.at(0) + b_W_R.at(1);
NODE_VALIDATION_CHECK(this,
(P->get_shape() == Shape{m_peepholes_count * get_hidden_size()}),
"Input tensor P must have shape (",
m_peepholes_count * get_hidden_size(),
"). Actual shape is:",
P->get_shape(),
".");
m_p_iof = builder::split(P, m_peepholes_count);
constructor_validate_and_infer_types();
}
void op::LSTMCell::pre_validate_and_infer_types()
{
const auto& x_shape = input(0).get_shape();
const auto& x_pshape = get_input_partial_shape(0);
const auto& w_pshape = get_input_partial_shape(1);
const auto& r_pshape = get_input_partial_shape(2);
const auto& ht_pshape = get_input_partial_shape(3);
const auto& ct_pshape = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(this,
(x_pshape.is_static() || w_pshape.is_static() || r_pshape.is_static() ||
ht_pshape.is_static() || ct_pshape.is_static()),
"LSTMCell supports only static input tensors.");
const Shape& x_shape{x_pshape.to_shape()};
const size_t batch_size = x_shape.at(0);
const size_t input_size = x_shape.at(1);
const auto& w_shape = input(1).get_shape();
const auto& r_shape = input(2).get_shape();
const auto& ht_shape = input(3).get_shape();
const auto& ct_shape = input(4).get_shape();
const Shape& w_shape{w_pshape.to_shape()};
const Shape& r_shape{r_pshape.to_shape()};
const Shape& ht_shape{ht_pshape.to_shape()};
const Shape& ct_shape{ct_pshape.to_shape()};
NODE_VALIDATION_CHECK(this,
(w_shape == Shape{4 * get_hidden_size(), input_size}),
(w_shape == Shape{m_gates_count * get_hidden_size(), input_size}),
"Input tensor W must have shape (",
4 * get_hidden_size(),
m_gates_count * get_hidden_size(),
", ",
input_size,
"). Actual shape is:",
w_shape,
".");
NODE_VALIDATION_CHECK(this,
(r_shape == Shape{4 * get_hidden_size(), get_hidden_size()}),
(r_shape == Shape{m_gates_count * get_hidden_size(), get_hidden_size()}),
"Input tensor R must have shape (",
4 * get_hidden_size(),
m_gates_count * get_hidden_size(),
", ",
get_hidden_size(),
"). Actual shape is:",
......@@ -185,6 +175,34 @@ void op::LSTMCell::pre_validate_and_infer_types()
"). Actual shape is:",
w_shape,
".");
if (get_input_size() > 5)
{
const auto& b_pshape = get_input_partial_shape(5);
const auto& p_pshape = get_input_partial_shape(6);
NODE_VALIDATION_CHECK(this,
(b_pshape.is_static() || p_pshape.is_static()),
"LSTMCell supports only static input tensors.");
const Shape& b_shape{b_pshape.to_shape()};
const Shape& p_shape{p_pshape.to_shape()};
NODE_VALIDATION_CHECK(this,
(b_shape == Shape{2 * m_gates_count * get_hidden_size()}),
"Input tensor B must have shape (",
2 * m_gates_count * get_hidden_size(),
"). Actual shape is:",
b_shape,
".");
NODE_VALIDATION_CHECK(this,
(p_shape == Shape{m_peepholes_count * get_hidden_size()}),
"Input tensor P must have shape (",
m_peepholes_count * get_hidden_size(),
"). Actual shape is:",
p_shape,
".");
}
}
NodeVector op::LSTMCell::decompose_op() const
......@@ -224,16 +242,20 @@ NodeVector op::LSTMCell::decompose_op() const
// Ht = ot (.) h(Ct)
// --------------------
const auto& p_i = m_p_iof.at(0);
const auto& p_o = m_p_iof.at(1);
const auto& p_f = m_p_iof.at(2);
NodeVector b_W_R = builder::split(m_B, 2);
auto bias = b_W_R.at(0) + b_W_R.at(1);
NodeVector p_iof = builder::split(m_P, m_peepholes_count);
const auto& p_i = p_iof.at(0);
const auto& p_o = p_iof.at(1);
const auto& p_f = p_iof.at(2);
// Xt*(W^T) -- for [iofc] gates.
auto Xt_W = std::make_shared<ngraph::op::Dot>(m_X, ngraph::op::util::transpose(m_W));
auto Xt_W = make_shared<op::Dot>(m_X, op::util::transpose(m_W));
// Ht-1*(R^T) -- for [iofc] gates.
auto Ht_R = std::make_shared<ngraph::op::Dot>(m_H_t, ngraph::op::util::transpose(m_R));
auto Ht_R = make_shared<op::Dot>(m_H_t, op::util::transpose(m_R));
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb -- for [iofc] gates.
auto gates = add(Xt_W, add(Ht_R, m_bias));
auto gates = add(Xt_W, add(Ht_R, bias));
NodeVector split_gates = builder::split(gates, 4, -1);
auto i_t = split_gates.at(0);
......@@ -246,11 +268,10 @@ NodeVector op::LSTMCell::decompose_op() const
if (m_input_forget)
{
// Couple input with forget gate: 1 - i_t
f_t =
sub(ngraph::op::Constant::create(i_t->get_element_type(),
i_t->get_shape(),
std::vector<float>(shape_size(i_t->get_shape()), 1.f)),
i_t);
f_t = sub(op::Constant::create(i_t->get_element_type(),
i_t->get_shape(),
vector<float>(shape_size(i_t->get_shape()), 1.f)),
i_t);
}
else
{
......
......@@ -164,6 +164,15 @@ namespace ngraph
///
std::shared_ptr<Node> m_C_t;
///
/// \brief The bias tensor for the gates. Shape: [2 * gates_count * hidden_size].
/// \note Concatenation of `[Wb[iofc], Rb[iofc]]`.
///
std::shared_ptr<Node> m_B;
///
/// \brief Peephole weights for iof gates. Shape: [3 * hidden_size]
///
std::shared_ptr<Node> m_P;
///
/// \brief The Activation function f.
///
ActivationFunction m_activation_f;
......@@ -182,18 +191,6 @@ namespace ngraph
static constexpr std::size_t m_gates_count{4};
static constexpr std::size_t m_peepholes_count{3};
///
/// \brief Peephole weights vector for respectively: input, output, and forget gates.
///
/// Each peephole has shape [hidden_size].
///
NodeVector m_p_iof;
///
/// \brief Sum of biases (weight and recurrence) for input, output, forget, and cell gates.
///
/// Sum of `[Wb, Rb]`.
///
std::shared_ptr<Node> m_bias;
};
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment