Commit 0f8c8bf7 authored by Adam Rogowiec's avatar Adam Rogowiec

Fix retrieving bias input.

parent 103d5fbc
...@@ -191,7 +191,7 @@ NodeVector op::GRUCell::decompose_op() const ...@@ -191,7 +191,7 @@ NodeVector op::GRUCell::decompose_op() const
std::shared_ptr<Node> W = get_argument(1); std::shared_ptr<Node> W = get_argument(1);
std::shared_ptr<Node> R = get_argument(2); std::shared_ptr<Node> R = get_argument(2);
std::shared_ptr<Node> H_t = get_argument(3); std::shared_ptr<Node> H_t = get_argument(3);
std::shared_ptr<Node> B = get_bias(); std::shared_ptr<Node> B = get_argument(4);
// Get W and R biases separately. // Get W and R biases separately.
NodeVector b_W_R = builder::split(B, 2); NodeVector b_W_R = builder::split(B, 2);
...@@ -270,15 +270,6 @@ NodeVector op::GRUCell::decompose_op() const ...@@ -270,15 +270,6 @@ NodeVector op::GRUCell::decompose_op() const
return {H_t}; return {H_t};
} }
shared_ptr<Node> op::GRUCell::get_bias() const
{
shared_ptr<Node> bias;
// Split B onto Wb an Rb and add them.
NodeVector b_W_R = builder::split(get_argument(4), 2);
bias = b_W_R.at(0) + b_W_R.at(1);
return bias;
}
void op::GRUCell::add_default_bias_input() void op::GRUCell::add_default_bias_input()
{ {
shared_ptr<Node> B = shared_ptr<Node> B =
......
...@@ -132,8 +132,6 @@ namespace ngraph ...@@ -132,8 +132,6 @@ namespace ngraph
bool get_linear_before_reset() const { return m_linear_before_reset; } bool get_linear_before_reset() const { return m_linear_before_reset; }
private: private:
std::shared_ptr<Node> get_bias() const;
/// brief Add and initialize bias input to all zeros. /// brief Add and initialize bias input to all zeros.
void add_default_bias_input(); void add_default_bias_input();
......
...@@ -171,10 +171,7 @@ NodeVector op::RNNCell::decompose_op() const ...@@ -171,10 +171,7 @@ NodeVector op::RNNCell::decompose_op() const
std::shared_ptr<Node> W = get_argument(1); std::shared_ptr<Node> W = get_argument(1);
std::shared_ptr<Node> R = get_argument(2); std::shared_ptr<Node> R = get_argument(2);
std::shared_ptr<Node> H_t = get_argument(3); std::shared_ptr<Node> H_t = get_argument(3);
std::shared_ptr<Node> B = get_bias(); std::shared_ptr<Node> bias = get_bias();
NodeVector b_W_R = builder::split(B, 2);
auto bias = b_W_R.at(0) + b_W_R.at(1);
// Xt*(W^T) // Xt*(W^T)
auto Xt_W = std::make_shared<op::Dot>(X, builder::transpose(W)); auto Xt_W = std::make_shared<op::Dot>(X, builder::transpose(W));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment