Commit b5f14735 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by arogowie-intel

Disambiguate constructors.

parent 8c2c9b46
...@@ -158,48 +158,27 @@ op::LSTMCell::LSTMCell(const shared_ptr<Node>& X, ...@@ -158,48 +158,27 @@ op::LSTMCell::LSTMCell(const shared_ptr<Node>& X,
, m_activation_h{get_activation_function(2)} , m_activation_h{get_activation_function(2)}
, m_input_forget{input_forget} , m_input_forget{input_forget}
{ {
// Normally we would split B onto Wb an Rb and add them, however here they are all zeros, // Split B onto Wb and Rb and add them.
// thus just initialize bias with appropriate shape and zeros. NODE_VALIDATION_CHECK(this,
if (!B) (B->get_shape() == Shape{2 * m_gates_count * get_hidden_size()}),
{ "Input tensor B must have shape (",
m_bias = 8 * get_hidden_size(),
ngraph::op::Constant::create(element::f32, "). Actual shape is:",
Shape{m_gates_count * get_hidden_size()}, B->get_shape(),
vector<float>(m_gates_count * get_hidden_size(), 0.f)); ".");
}
// Split B onto Wb an Rb and add them.
else
{
NODE_VALIDATION_CHECK(this,
(B->get_shape() == Shape{2 * m_gates_count * get_hidden_size()}),
"Input tensor B must have shape (",
8 * get_hidden_size(),
"). Actual shape is:",
B->get_shape(),
".");
NodeVector b_W_R = builder::split(B, 2);
m_bias = b_W_R.at(0) + b_W_R.at(1);
}
auto peephole_weights = P; NodeVector b_W_R = builder::split(B, 2);
if (!peephole_weights) m_bias = b_W_R.at(0) + b_W_R.at(1);
{
peephole_weights =
ngraph::op::Constant::create(element::f32,
Shape{m_peepholes_count * get_hidden_size()},
vector<float>(m_peepholes_count * get_hidden_size(), 0.f));
}
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
(P->get_shape() == Shape{3 * get_hidden_size()}), (P->get_shape() == Shape{m_peepholes_count * get_hidden_size()}),
"Input tensor P must have shape (", "Input tensor P must have shape (",
3 * get_hidden_size(), m_peepholes_count * get_hidden_size(),
"). Actual shape is:", "). Actual shape is:",
P->get_shape(), P->get_shape(),
"."); ".");
m_p_iof = builder::split(peephole_weights, m_peepholes_count); m_p_iof = builder::split(P, m_peepholes_count);
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -127,8 +127,8 @@ namespace ngraph ...@@ -127,8 +127,8 @@ namespace ngraph
const std::shared_ptr<Node>& H_t, const std::shared_ptr<Node>& H_t,
const std::shared_ptr<Node>& C_t, const std::shared_ptr<Node>& C_t,
std::size_t hidden_size, std::size_t hidden_size,
const std::shared_ptr<Node>& B = nullptr, const std::shared_ptr<Node>& B,
const std::shared_ptr<Node>& P = nullptr, const std::shared_ptr<Node>& P,
const std::vector<std::string>& activations = const std::vector<std::string>& activations =
std::vector<std::string>{"sigmoid", "tanh", "tanh"}, std::vector<std::string>{"sigmoid", "tanh", "tanh"},
const std::vector<float>& activation_alpha = {}, const std::vector<float>& activation_alpha = {},
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment