Commit eca45d1c authored by Adam Rogowiec's avatar Adam Rogowiec

Move common RNN utility functions to RNNCellBase.

parent 8ad92a06
...@@ -23,11 +23,6 @@ ...@@ -23,11 +23,6 @@
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
#include "ngraph/op/fused/lstm_cell.hpp" #include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/op/util/reshape.hpp" #include "ngraph/op/util/reshape.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
...@@ -36,46 +31,6 @@ ...@@ -36,46 +31,6 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
// ------------- HELPER FUNCTIONS ---------------------------------------------
static shared_ptr<Node> add(const shared_ptr<Node>& lhs, const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Add>(args.at(0), args.at(1))};
}
static shared_ptr<Node> sub(const shared_ptr<Node>& lhs, const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Subtract>(args.at(0), args.at(1))};
}
static shared_ptr<Node> mul(const shared_ptr<Node>& lhs, const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Multiply>(args.at(0), args.at(1))};
}
static shared_ptr<Node> clip(const shared_ptr<Node>& data, float threshold)
{
if (threshold == 0.f)
{
return data;
}
float min_val = -threshold;
float max_val = threshold;
size_t size = shape_size(data->get_shape());
const shared_ptr<Node> min_val_node = op::Constant::create(
data->get_element_type(), data->get_shape(), vector<float>(size, min_val));
const shared_ptr<Node> max_val_node = op::Constant::create(
data->get_element_type(), data->get_shape(), vector<float>(size, max_val));
return make_shared<op::Minimum>(max_val_node, make_shared<op::Maximum>(data, min_val_node));
}
// ------------- LSTM_CELL ----------------------------------------------------
op::LSTMCell::LSTMCell(const shared_ptr<Node>& X, op::LSTMCell::LSTMCell(const shared_ptr<Node>& X,
const shared_ptr<Node>& W, const shared_ptr<Node>& W,
const shared_ptr<Node>& R, const shared_ptr<Node>& R,
......
...@@ -17,7 +17,14 @@ ...@@ -17,7 +17,14 @@
#include <algorithm> #include <algorithm>
#include <iterator> #include <iterator>
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/rnn_cell_base.hpp" #include "ngraph/op/fused/rnn_cell_base.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std; using namespace std;
...@@ -60,3 +67,39 @@ op::ActivationFunction op::RNNCellBase::get_activation_function(size_t idx) cons ...@@ -60,3 +67,39 @@ op::ActivationFunction op::RNNCellBase::get_activation_function(size_t idx) cons
return afunc; return afunc;
} }
shared_ptr<Node> op::RNNCellBase::add(const shared_ptr<Node>& lhs, const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Add>(args.at(0), args.at(1))};
}
shared_ptr<Node> op::RNNCellBase::sub(const shared_ptr<Node>& lhs, const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Subtract>(args.at(0), args.at(1))};
}
shared_ptr<Node> op::RNNCellBase::mul(const shared_ptr<Node>& lhs, const shared_ptr<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
return {make_shared<op::Multiply>(args.at(0), args.at(1))};
}
shared_ptr<Node> op::RNNCellBase::clip(const shared_ptr<Node>& data, const float threshold)
{
if (threshold == 0.f)
{
return data;
}
float min_val = -threshold;
float max_val = threshold;
size_t size = shape_size(data->get_shape());
const shared_ptr<Node> min_val_node = op::Constant::create(
data->get_element_type(), data->get_shape(), vector<float>(size, min_val));
const shared_ptr<Node> max_val_node = op::Constant::create(
data->get_element_type(), data->get_shape(), vector<float>(size, max_val));
return make_shared<op::Minimum>(max_val_node, make_shared<op::Maximum>(data, min_val_node));
}
...@@ -17,9 +17,11 @@ ...@@ -17,9 +17,11 @@
#pragma once #pragma once
#include <cstddef> #include <cstddef>
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/util/activation_functions.hpp" #include "ngraph/op/util/activation_functions.hpp"
namespace ngraph namespace ngraph
...@@ -66,6 +68,46 @@ namespace ngraph ...@@ -66,6 +68,46 @@ namespace ngraph
/// \return The object representing activation function. /// \return The object representing activation function.
/// ///
ActivationFunction get_activation_function(std::size_t idx) const; ActivationFunction get_activation_function(std::size_t idx) const;
///
/// \brief Creates node with element-wise add operation with numpy broadcasting.
///
/// \param[in] lhs The left hand side argument node.
/// \param[in] rhs The right hand side argument node.
///
/// \return Node with element-wise add operation.
///
static std::shared_ptr<Node> add(const std::shared_ptr<Node>& lhs,
const std::shared_ptr<Node>& rhs);
///
/// \brief Creates node with element-wise subtract operation with numpy broadcasting.
///
/// \param[in] lhs The left hand side argument node.
/// \param[in] rhs The right hand side argument node.
///
/// \return Node with element-wise subtract operation.
///
static std::shared_ptr<Node> sub(const std::shared_ptr<Node>& lhs,
const std::shared_ptr<Node>& rhs);
///
/// \brief Creates node with element-wise multiply operation with numpy broadcasting.
///
/// \param[in] lhs The left hand side argument node.
/// \param[in] rhs The right hand side argument node.
///
/// \return Node with element-wise multiply operation.
///
static std::shared_ptr<Node> mul(const std::shared_ptr<Node>& lhs,
const std::shared_ptr<Node>& rhs);
///
/// \brief Creates node with element-wise clip operation with numpy broadcasting.
///
/// \param[in] lhs The left hand side argument node.
/// \param[in] rhs The right hand side argument node.
///
/// \return Node with element-wise clip operation.
///
static std::shared_ptr<Node> clip(const std::shared_ptr<Node>& data,
const float threshold);
private: private:
std::size_t m_hidden_size = 0.f; std::size_t m_hidden_size = 0.f;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment