Commit ef1c5347 authored by Adam Rogowiec

LSTM cell fused operator.

parent 03dba84d
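For context, the fused cell computes the standard LSTM gate equations. The sketch below follows ONNX LSTM semantics (gates i, f, o and candidate c; optional peephole weights P; input_forget couples the input and forget gates) — an assumed reading of the op's definition, not text taken from the commit:

    i_t = \sigma(X_t W_i^T + H_{t-1} R_i^T + P_i \odot C_{t-1} + Wb_i + Rb_i)
    f_t = \sigma(X_t W_f^T + H_{t-1} R_f^T + P_f \odot C_{t-1} + Wb_f + Rb_f)
    \tilde{c}_t = \tanh(X_t W_c^T + H_{t-1} R_c^T + Wb_c + Rb_c)
    C_t = f_t \odot C_{t-1} + i_t \odot \tilde{c}_t
    o_t = \sigma(X_t W_o^T + H_{t-1} R_o^T + P_o \odot C_t + Wb_o + Rb_o)
    H_t = o_t \odot \tanh(C_t)

With input_forget = true, the forget gate is coupled to the input gate as f_t = 1 - i_t.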
@@ -302,12 +302,16 @@ set (SRC
     op/fused/grn.hpp
     op/fused/group_conv.hpp
     op/fused/group_conv.cpp
+    op/fused/lstm_cell.cpp
+    op/fused/lstm_cell.hpp
     op/fused/mvn.cpp
     op/fused/mvn.hpp
     op/fused/normalize.cpp
     op/fused/normalize.hpp
     op/fused/prelu.cpp
     op/fused/prelu.hpp
+    op/fused/rnn_cell_base.cpp
+    op/fused/rnn_cell_base.hpp
     op/fused/scale_shift.cpp
     op/fused/scale_shift.hpp
     op/fused/space_to_depth.cpp
@@ -320,6 +324,8 @@ set (SRC
     op/fused/squeeze.hpp
     op/fused/unsqueeze.cpp
     op/fused/unsqueeze.hpp
+    op/util/activation_functions.cpp
+    op/util/activation_functions.hpp
     op/util/arithmetic_reduction.cpp
     op/util/arithmetic_reduction.hpp
     op/util/binary_elementwise_arithmetic.cpp
......
@@ -189,8 +189,6 @@ add_library(onnx_import STATIC
     utils/reduction.hpp
     utils/reshape.cpp
     utils/reshape.hpp
-    utils/rnn/activation_functions.cpp
-    utils/rnn/activation_functions.hpp
     utils/variadic.hpp)
 set(ONNX_IMPORT_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR} CACHE INTERNAL "")
......
@@ -103,6 +103,7 @@
 #include "ngraph/op/fused/grn.hpp"
 #include "ngraph/op/fused/group_conv.hpp"
 #include "ngraph/op/fused/hard_sigmoid.hpp"
+#include "ngraph/op/fused/lstm_cell.hpp"
 #include "ngraph/op/fused/mvn.hpp"
 #include "ngraph/op/fused/normalize.hpp"
 #include "ngraph/op/fused/prelu.hpp"
......
@@ -27,6 +27,7 @@ NGRAPH_OP(GRN, ngraph::op)
 NGRAPH_OP(Gemm, ngraph::op)
 NGRAPH_OP(GroupConvolution, ngraph::op)
 NGRAPH_OP(HardSigmoid, ngraph::op)
+NGRAPH_OP(LSTMCell, ngraph::op)
 NGRAPH_OP(MVN, ngraph::op)
 NGRAPH_OP(Normalize, ngraph::op)
 NGRAPH_OP(PRelu, ngraph::op)
......
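The op table above is an X-macro list: registering LSTMCell here is what makes the OP_TYPEID::LSTMCell cases added to the serializer below resolvable. A minimal sketch of the pattern — the consumer code is illustrative, and the real serializer.cpp may differ in details:

    // X-macro pattern: op_tbl.hpp is included repeatedly with different
    // definitions of NGRAPH_OP to generate an enum and a name lookup table.
    #include <string>
    #include <unordered_map>

    #define NGRAPH_OP(a, b) a,
    enum class OP_TYPEID
    {
    #include "ngraph/op/op_tbl.hpp" // expands to GRN, Gemm, ..., LSTMCell, ...
        UnknownOp
    };
    #undef NGRAPH_OP

    #define NGRAPH_OP(a, b) {#a, OP_TYPEID::a},
    static const std::unordered_map<std::string, OP_TYPEID> typeid_map{
    #include "ngraph/op/op_tbl.hpp" // expands to {"GRN", OP_TYPEID::GRN}, ...
    };
    #undef NGRAPH_OP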
@@ -74,6 +74,7 @@
 #include "ngraph/op/fused/grn.hpp"
 #include "ngraph/op/fused/group_conv.hpp"
 #include "ngraph/op/fused/hard_sigmoid.hpp"
+#include "ngraph/op/fused/lstm_cell.hpp"
 #include "ngraph/op/fused/mvn.hpp"
 #include "ngraph/op/fused/normalize.hpp"
 #include "ngraph/op/fused/prelu.hpp"
@@ -1055,6 +1056,29 @@ static shared_ptr<ngraph::Function>
             node = make_shared<op::LRN>(args[0], alpha, beta, bias, nsize);
             break;
         }
+        case OP_TYPEID::LSTMCell:
+        {
+            auto hidden_size = node_js.at("hidden_size").get<size_t>();
+            auto clip = node_js.at("clip").get<float>();
+            auto activations = node_js.at("activations").get<vector<string>>();
+            auto activation_alpha = node_js.at("activation_alpha").get<vector<float>>();
+            auto activation_beta = node_js.at("activation_beta").get<vector<float>>();
+            auto input_forget = node_js.at("input_forget").get<bool>();
+            node = make_shared<op::LSTMCell>(args[0],
+                                             args[1],
+                                             args[2],
+                                             args[3],
+                                             args[4],
+                                             hidden_size,
+                                             args[5],
+                                             args[6],
+                                             activations,
+                                             activation_alpha,
+                                             activation_beta,
+                                             clip,
+                                             input_forget);
+            break;
+        }
         case OP_TYPEID::Max:
         {
             auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
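In the new case, args[0..4] are X, W, R, H_t and C_t, and args[5..6] are the optional bias B and peephole weights P, with the remaining attributes restored from JSON. A hypothetical construction of the op outside the serializer, with shapes assumed from ONNX LSTM conventions (gate order and a minimal six-argument overload are assumptions, not stated in this diff):

    #include <memory>

    #include "ngraph/op/fused/lstm_cell.hpp"
    #include "ngraph/op/parameter.hpp"

    using namespace ngraph;

    // Hypothetical helper: builds an LSTMCell with default activations, no
    // clipping and uncoupled gates. Assumed shapes: W is [4*hidden_size,
    // input_size], R is [4*hidden_size, hidden_size], H_t and C_t are
    // [batch, hidden_size].
    std::shared_ptr<op::LSTMCell>
        make_lstm_cell(size_t batch, size_t input_size, size_t hidden_size)
    {
        auto X = std::make_shared<op::Parameter>(element::f32, Shape{batch, input_size});
        auto W = std::make_shared<op::Parameter>(element::f32, Shape{4 * hidden_size, input_size});
        auto R = std::make_shared<op::Parameter>(element::f32, Shape{4 * hidden_size, hidden_size});
        auto H_t = std::make_shared<op::Parameter>(element::f32, Shape{batch, hidden_size});
        auto C_t = std::make_shared<op::Parameter>(element::f32, Shape{batch, hidden_size});
        return std::make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
    }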
@@ -1979,6 +2003,17 @@ static json write(const Node& n, bool binary_constant_data)
             node["nsize"] = tmp->get_nsize();
             break;
         }
+        case OP_TYPEID::LSTMCell:
+        {
+            auto tmp = dynamic_cast<const op::LSTMCell*>(&n);
+            node["hidden_size"] = tmp->get_hidden_size();
+            node["clip"] = tmp->get_clip();
+            node["activations"] = tmp->get_activations();
+            node["activation_alpha"] = tmp->get_activation_alpha();
+            node["activation_beta"] = tmp->get_activation_beta();
+            node["input_forget"] = tmp->get_input_forget();
+            break;
+        }
         case OP_TYPEID::Max:
         {
             auto tmp = dynamic_cast<const op::Max*>(&n);
......
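The write case mirrors the reader: the six attributes land next to the node's envelope fields. An illustrative serialized entry might look roughly like the sketch below — the envelope fields (name, op, inputs) are assumptions about the serializer's layout, and the attribute values shown are the ONNX LSTM defaults, not output captured from this code:

    {
      "name": "LSTMCell_42",
      "op": "LSTMCell",
      "inputs": ["X", "W", "R", "H_t", "C_t", "B", "P"],
      "hidden_size": 128,
      "clip": 0.0,
      "activations": ["sigmoid", "tanh", "tanh"],
      "activation_alpha": [],
      "activation_beta": [],
      "input_forget": false
    }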