Commit 8c2c9b46 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by arogowie-intel

Add node validation.

parent 10649e87
......@@ -170,6 +170,14 @@ op::LSTMCell::LSTMCell(const shared_ptr<Node>& X,
// Split B onto Wb an Rb and add them.
else
{
NODE_VALIDATION_CHECK(this,
(B->get_shape() == Shape{2 * m_gates_count * get_hidden_size()}),
"Input tensor B must have shape (",
8 * get_hidden_size(),
"). Actual shape is:",
B->get_shape(),
".");
NodeVector b_W_R = builder::split(B, 2);
m_bias = b_W_R.at(0) + b_W_R.at(1);
}
......@@ -183,12 +191,66 @@ op::LSTMCell::LSTMCell(const shared_ptr<Node>& X,
vector<float>(m_peepholes_count * get_hidden_size(), 0.f));
}
NODE_VALIDATION_CHECK(this,
(P->get_shape() == Shape{3 * get_hidden_size()}),
"Input tensor P must have shape (",
3 * get_hidden_size(),
"). Actual shape is:",
P->get_shape(),
".");
m_p_iof = builder::split(peephole_weights, m_peepholes_count);
constructor_validate_and_infer_types();
}
void op::LSTMCell::pre_validate_and_infer_types()
{
const auto& x_shape = input(0).get_shape();
const size_t batch_size = x_shape.at(0);
const size_t input_size = x_shape.at(1);
const auto& w_shape = input(1).get_shape();
const auto& r_shape = input(2).get_shape();
const auto& ht_shape = input(3).get_shape();
const auto& ct_shape = input(4).get_shape();
NODE_VALIDATION_CHECK(this,
(w_shape == Shape{4 * get_hidden_size(), input_size}),
"Input tensor W must have shape (",
4 * get_hidden_size(),
", ",
input_size,
"). Actual shape is:",
w_shape,
".");
NODE_VALIDATION_CHECK(this,
(r_shape == Shape{4 * get_hidden_size(), get_hidden_size()}),
"Input tensor R must have shape (",
4 * get_hidden_size(),
", ",
get_hidden_size(),
"). Actual shape is:",
w_shape,
".");
NODE_VALIDATION_CHECK(this,
(ht_shape == Shape{batch_size, get_hidden_size()}),
"Input tensor H_t must have shape (",
batch_size,
", ",
get_hidden_size(),
"). Actual shape is:",
w_shape,
".");
NODE_VALIDATION_CHECK(this,
(ct_shape == Shape{batch_size, get_hidden_size()}),
"Input tensor C_t must have shape (",
batch_size,
", ",
get_hidden_size(),
"). Actual shape is:",
w_shape,
".");
}
NodeVector op::LSTMCell::decompose_op() const
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment