Commit 5b459fd9 authored by Adam Rogowiec's avatar Adam Rogowiec

Store original Bias input as RNNCell member.

parent 835dd295
This diff is collapsed.
@@ -53,10 +53,10 @@ namespace ngraph
                /// \param[in] hidden_size The number of hidden units for recurrent cell.
                ///
                RNNCell(const std::shared_ptr<Node>& X,
                        const std::shared_ptr<Node>& W,
                        const std::shared_ptr<Node>& R,
                        const std::shared_ptr<Node>& H_t,
                        std::size_t hidden_size);
                ///
                /// \brief Constructs RNNCell node.
@@ -78,14 +78,14 @@ namespace ngraph
                /// input of activation functions.
                ///
                RNNCell(const std::shared_ptr<Node>& X,
                        const std::shared_ptr<Node>& W,
                        const std::shared_ptr<Node>& R,
                        const std::shared_ptr<Node>& H_t,
                        std::size_t hidden_size,
                        const std::vector<std::string>& activations,
                        const std::vector<float>& activation_alpha,
                        const std::vector<float>& activation_beta,
                        float clip);
                ///
                /// \brief Constructs RNNCell node.
@@ -108,16 +108,15 @@ namespace ngraph
                /// input of activation functions.
                ///
                RNNCell(const std::shared_ptr<Node>& X,
                        const std::shared_ptr<Node>& W,
                        const std::shared_ptr<Node>& R,
                        const std::shared_ptr<Node>& H_t,
                        std::size_t hidden_size,
                        const std::shared_ptr<Node>& B,
                        const std::vector<std::string>& activations = std::vector<std::string>{"tanh"},
                        const std::vector<float>& activation_alpha = {},
                        const std::vector<float>& activation_beta = {},
                        float clip = 0.f);

                virtual void pre_validate_and_infer_types() override;
                virtual NodeVector decompose_op() const override;
@@ -142,17 +141,16 @@ namespace ngraph
                ///
                std::shared_ptr<Node> m_H_t;
                ///
                /// \brief The bias tensor for the gates. Shape: [2 * gates_count * hidden_size].
                /// \note Concatenation of `[Wb[i], Rb[i]]`.
                ///
                std::shared_ptr<Node> m_B;
                ///
                /// \brief The Activation function f.
                ///
                ActivationFunction m_activation_f;

                static constexpr std::size_t m_gates_count{1};
            };
        }
    }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment