Unverified Commit 412f5f7b authored by Pruthvi's avatar Pruthvi Committed by GitHub

Vanilla RNN Optimization (#4439)

* WIP PM and callback for vanilla RNN

* - Graph pass to fuse vanilla RNN Cell
- Unit test case for vanilla rnn through TF serialized graph

* i) emit MKLDNN kernel for vanilla RNN ii) test case for vanilla RNN, CPU v/s INTERPRETER

* i) style check ii) serialized graph for tf vanilla rnn

* fix warnings

* i) fixed emitter code ii) test case passes

* - added support for Vanilla RNN for MKLDNN > v1.0

* fix compiler warnings

* Fix build error

* fix unit test case
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent ef4692a6
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "ngraph/runtime/cpu/cpu_builder.hpp" #include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp" #include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp" #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/rnn_utils.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -39,7 +40,7 @@ namespace ngraph ...@@ -39,7 +40,7 @@ namespace ngraph
} }
auto& functors = external_function->get_functors(); auto& functors = external_function->get_functors();
auto rnn_op = static_cast<const ngraph::op::Rnn*>(node);
auto src_layer_buffer_index = auto src_layer_buffer_index =
external_function->get_buffer_index(args[0].get_name()); external_function->get_buffer_index(args[0].get_name());
auto src_iter_buffer_index = auto src_iter_buffer_index =
...@@ -49,10 +50,10 @@ namespace ngraph ...@@ -49,10 +50,10 @@ namespace ngraph
auto dst_iter_buffer_index = external_function->get_buffer_index(out[1].get_name()); auto dst_iter_buffer_index = external_function->get_buffer_index(out[1].get_name());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto rnn_desc =
mkldnn_emitter->get_rnn_forward_desc<ngraph::op::Rnn>(node, args, out);
#if MKLDNN_VERSION_MAJOR < 1 #if MKLDNN_VERSION_MAJOR < 1
auto rnn_desc =
mkldnn_emitter->get_rnn_forward_desc<ngraph::op::Rnn>(node, args, out);
auto weights_layer_buffer_index = auto weights_layer_buffer_index =
external_function->get_buffer_index(args[2].get_name()); external_function->get_buffer_index(args[2].get_name());
auto weights_iter_buffer_index = auto weights_iter_buffer_index =
...@@ -109,75 +110,148 @@ namespace ngraph ...@@ -109,75 +110,148 @@ namespace ngraph
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
#else #else
size_t scratchpad_size = mkldnn_emitter->query_scratchpad_rnn_forward(rnn_desc); if (rnn_op->is_type(ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_rnn))
{
auto weights_layer_buffer_index =
external_function->get_buffer_index(args[2].get_name());
auto weights_iter_buffer_index =
external_function->get_buffer_index(args[3].get_name());
auto bias_buffer_index =
external_function->get_buffer_index(args[4].get_name());
auto src_iter_c_buffer_index = // Rnn needs 9 primitives: src_layer, src_iter, weights_layer, weights_iter,
external_function->get_buffer_index(args[2].get_name()); // bias,
auto weights_layer_buffer_index = // dst_layer, dst_iter, workspace, and rnn_forward.
external_function->get_buffer_index(args[3].get_name()); // It needs a new workspace.
auto weights_iter_buffer_index = auto rnn_index = mkldnn_emitter->reserve_primitive_space(
external_function->get_buffer_index(args[4].get_name()); 9, false /* fwd and bwd */, true /* new workspace */);
auto bias_buffer_index = external_function->get_buffer_index(args[5].get_name()); auto& deps = mkldnn_emitter->get_primitive_deps(rnn_index);
auto dst_iter_c_buffer_index = auto vanilla_rnn_desc =
external_function->get_buffer_index(out[2].get_name()); mkldnn_emitter->get_vanilla_rnn_forward_desc<ngraph::op::Rnn>(
node, args, out);
size_t scratchpad_size =
mkldnn_emitter->query_scratchpad_vanilla_rnn_forward(vanilla_rnn_desc);
// Rnn needs 11 primitives: src_layer, src_iter, src_iter_c, weights_layer, auto functor = [&,
// weights_iter, bias, vanilla_rnn_desc,
// dst_layer, dst_iter, dst_iter_c, workspace, and lstm_forward. rnn_index,
// It needs a new workspace. src_layer_buffer_index,
auto rnn_index = mkldnn_emitter->reserve_primitive_space( src_iter_buffer_index,
11, false /* fwd and bwd */, true /* new workspace */); weights_layer_buffer_index,
auto& deps = mkldnn_emitter->get_primitive_deps(rnn_index); weights_iter_buffer_index,
bias_buffer_index,
dst_layer_buffer_index,
dst_iter_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
if (ctx->first_iteration)
{
mkldnn_emitter->build_vanilla_rnn_forward(ctx->mkldnn_memories,
ctx->mkldnn_primitives,
ctx->mkldnn_scratchpad_mds,
ctx->mkldnn_workspaces,
vanilla_rnn_desc,
deps,
rnn_index);
}
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[0], ctx->buffer_data[src_layer_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[src_iter_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[2], ctx->buffer_data[weights_layer_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[3], ctx->buffer_data[weights_iter_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[4], ctx->buffer_data[bias_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[5], ctx->buffer_data[dst_layer_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[6], ctx->buffer_data[dst_iter_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[7], ctx->mkldnn_workspaces[deps[8]]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx,
rnn_index,
deps,
cpu::mkldnn_utils::OpType::VANILLA_RNN,
scratchpad_size);
};
functors.emplace_back(functor);
}
else if (rnn_op->is_type(ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_lstm))
{
auto src_iter_c_buffer_index =
external_function->get_buffer_index(args[2].get_name());
auto weights_layer_buffer_index =
external_function->get_buffer_index(args[3].get_name());
auto weights_iter_buffer_index =
external_function->get_buffer_index(args[4].get_name());
auto bias_buffer_index =
external_function->get_buffer_index(args[5].get_name());
auto dst_iter_c_buffer_index =
external_function->get_buffer_index(out[2].get_name());
auto functor = [&, // Rnn needs 11 primitives: src_layer, src_iter, src_iter_c, weights_layer,
rnn_desc, // weights_iter, bias,
rnn_index, // dst_layer, dst_iter, dst_iter_c, workspace, and lstm_forward.
scratchpad_size, // It needs a new workspace.
src_layer_buffer_index, auto rnn_index = mkldnn_emitter->reserve_primitive_space(
src_iter_buffer_index, 11, false /* fwd and bwd */, true /* new workspace */);
src_iter_c_buffer_index, auto& deps = mkldnn_emitter->get_primitive_deps(rnn_index);
weights_layer_buffer_index, auto rnn_desc =
weights_iter_buffer_index, mkldnn_emitter->get_rnn_forward_desc<ngraph::op::Rnn>(node, args, out);
bias_buffer_index, size_t scratchpad_size = mkldnn_emitter->query_scratchpad_rnn_forward(rnn_desc);
dst_layer_buffer_index,
dst_iter_buffer_index,
dst_iter_c_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
if (ctx->first_iteration)
{
mkldnn_emitter->build_rnn_forward(ctx->mkldnn_memories,
ctx->mkldnn_primitives,
ctx->mkldnn_scratchpad_mds,
ctx->mkldnn_workspaces,
rnn_desc,
deps,
rnn_index);
}
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[0], ctx->buffer_data[src_layer_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[src_iter_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[2], ctx->buffer_data[src_iter_c_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[3], ctx->buffer_data[weights_layer_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[4], ctx->buffer_data[weights_iter_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[5], ctx->buffer_data[bias_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[6], ctx->buffer_data[dst_layer_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[7], ctx->buffer_data[dst_iter_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[8], ctx->buffer_data[dst_iter_c_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[9], ctx->mkldnn_workspaces[deps[10]]);
cpu::mkldnn_utils::mkldnn_invoke_primitive( auto functor = [&,
ctx, rnn_index, deps, cpu::mkldnn_utils::OpType::RNN, scratchpad_size); rnn_desc,
}; rnn_index,
functors.emplace_back(functor); scratchpad_size,
src_layer_buffer_index,
src_iter_buffer_index,
src_iter_c_buffer_index,
weights_layer_buffer_index,
weights_iter_buffer_index,
bias_buffer_index,
dst_layer_buffer_index,
dst_iter_buffer_index,
dst_iter_c_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
if (ctx->first_iteration)
{
mkldnn_emitter->build_rnn_forward(ctx->mkldnn_memories,
ctx->mkldnn_primitives,
ctx->mkldnn_scratchpad_mds,
ctx->mkldnn_workspaces,
rnn_desc,
deps,
rnn_index);
}
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[0], ctx->buffer_data[src_layer_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[1], ctx->buffer_data[src_iter_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[2], ctx->buffer_data[src_iter_c_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[3], ctx->buffer_data[weights_layer_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[4], ctx->buffer_data[weights_iter_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[5], ctx->buffer_data[bias_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[6], ctx->buffer_data[dst_layer_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[7], ctx->buffer_data[dst_iter_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[8], ctx->buffer_data[dst_iter_c_buffer_index]);
cpu::mkldnn_utils::set_memory_ptr(
ctx, deps[9], ctx->mkldnn_workspaces[deps[10]]);
cpu::mkldnn_utils::mkldnn_invoke_primitive(
ctx, rnn_index, deps, cpu::mkldnn_utils::OpType::RNN, scratchpad_size);
};
functors.emplace_back(functor);
}
#endif #endif
} }
......
...@@ -1279,6 +1279,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes( ...@@ -1279,6 +1279,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
REGISTER_KNOBBED_PASS(ImplicitBroadcastElimination, true, ngraph::pass) REGISTER_KNOBBED_PASS(ImplicitBroadcastElimination, true, ngraph::pass)
REGISTER_KNOBBED_PASS(NopElimination, true, ngraph::pass) REGISTER_KNOBBED_PASS(NopElimination, true, ngraph::pass)
REGISTER_KNOBBED_PASS(ZeroDimTensorElimination, true, ngraph::pass) REGISTER_KNOBBED_PASS(ZeroDimTensorElimination, true, ngraph::pass)
REGISTER_KNOBBED_PASS(VanillaRNNFusion, true, runtime::cpu::pass)
REGISTER_KNOBBED_PASS(LSTMFusion, true, runtime::cpu::pass) REGISTER_KNOBBED_PASS(LSTMFusion, true, runtime::cpu::pass)
REGISTER_KNOBBED_PASS(RNNFusion, true, runtime::cpu::pass) REGISTER_KNOBBED_PASS(RNNFusion, true, runtime::cpu::pass)
REGISTER_KNOBBED_PASS(AlgebraicSimplification, true, ngraph::pass) REGISTER_KNOBBED_PASS(AlgebraicSimplification, true, ngraph::pass)
......
...@@ -1242,6 +1242,47 @@ void MKLDNNEmitter::build_batchnorm_backward( ...@@ -1242,6 +1242,47 @@ void MKLDNNEmitter::build_batchnorm_backward(
mkldnn_primitives[batchnorm_index] = new mkldnn::batch_normalization_backward(batchnorm_pd); mkldnn_primitives[batchnorm_index] = new mkldnn::batch_normalization_backward(batchnorm_pd);
} }
// Builds the MKL-DNN (v >= 1.0) vanilla RNN forward primitive and the user
// memory objects it reads/writes.
// deps[] slot layout (indices into mkldnn_memories):
//   0 src_layer, 1 src_iter, 2 weights_layer, 3 weights_iter, 4 bias,
//   5 dst_layer, 6 dst_iter, 7 workspace.
// deps[8] is written here with the index of the newly allocated workspace
// buffer in mkldnn_workspaces.
void MKLDNNEmitter::build_vanilla_rnn_forward(
    std::vector<mkldnn::memory*>& mkldnn_memories,
    std::vector<mkldnn::primitive*>& mkldnn_primitives,
    std::vector<mkldnn::memory::desc*>& mkldnn_scratchpad_mds,
    std::vector<char*>& mkldnn_workspaces,
    const mkldnn::vanilla_rnn_forward::desc& rnn_desc,
    std::vector<size_t>& deps,
    size_t rnn_index)
{
    // Create user memory objects for each primitive input/output from the
    // descriptors embedded in the forward op descriptor.
    size_t src_layer_index = deps[0];
    build_memory(mkldnn_memories, rnn_desc.data.src_layer_desc, src_layer_index);
    size_t src_iter_index = deps[1];
    build_memory(mkldnn_memories, rnn_desc.data.src_iter_desc, src_iter_index);
    size_t weights_layer_index = deps[2];
    build_memory(mkldnn_memories, rnn_desc.data.weights_layer_desc, weights_layer_index);
    size_t weights_iter_index = deps[3];
    build_memory(mkldnn_memories, rnn_desc.data.weights_iter_desc, weights_iter_index);
    size_t bias_index = deps[4];
    build_memory(mkldnn_memories, rnn_desc.data.bias_desc, bias_index);
    size_t dst_layer_index = deps[5];
    build_memory(mkldnn_memories, rnn_desc.data.dst_layer_desc, dst_layer_index);
    size_t dst_iter_index = deps[6];
    build_memory(mkldnn_memories, rnn_desc.data.dst_iter_desc, dst_iter_index);

    // User-managed scratchpad: record the scratchpad descriptor so the
    // invoker can supply the buffer at execution time.
    mkldnn::primitive_attr attr;
    attr.set_scratchpad_mode(mkldnn::scratchpad_mode::user);
    auto rnn_layer_prim_desc =
        mkldnn::vanilla_rnn_forward::primitive_desc(rnn_desc, attr, executor::global_cpu_engine);
    mkldnn_scratchpad_mds[rnn_index] =
        new mkldnn::memory::desc(rnn_layer_prim_desc.scratchpad_desc());

    // The descriptor is built with prop_kind::forward_training (see
    // get_vanilla_rnn_forward_desc), which requires a workspace; allocate it
    // here and publish its buffer index through deps[8].
    size_t workspace_index = deps[7];
    build_memory(mkldnn_memories, rnn_layer_prim_desc.workspace_desc(), workspace_index);
    auto workspace = std::unique_ptr<MKLDNNWorkspace>(
        new MKLDNNWorkspace(rnn_layer_prim_desc.workspace_desc().get_size()));
    auto workspace_buf_index = insert_workspace(mkldnn_workspaces, workspace);
    deps[8] = workspace_buf_index;
    mkldnn_primitives[rnn_index] = new mkldnn::vanilla_rnn_forward(rnn_layer_prim_desc);
}
void MKLDNNEmitter::build_rnn_forward(std::vector<mkldnn::memory*>& mkldnn_memories, void MKLDNNEmitter::build_rnn_forward(std::vector<mkldnn::memory*>& mkldnn_memories,
std::vector<mkldnn::primitive*>& mkldnn_primitives, std::vector<mkldnn::primitive*>& mkldnn_primitives,
std::vector<mkldnn::memory::desc*>& mkldnn_scratchpad_mds, std::vector<mkldnn::memory::desc*>& mkldnn_scratchpad_mds,
...@@ -1624,6 +1665,14 @@ size_t MKLDNNEmitter::query_scratchpad_rnn_forward(const mkldnn::lstm_forward::d ...@@ -1624,6 +1665,14 @@ size_t MKLDNNEmitter::query_scratchpad_rnn_forward(const mkldnn::lstm_forward::d
GET_SIZE GET_SIZE
} }
// Returns the scratchpad size (bytes) needed by a vanilla RNN forward
// primitive built from `desc`, mirroring query_scratchpad_rnn_forward above.
// ATTR_S / GET_SIZE are emitter-local macros: presumably ATTR_S declares a
// primitive_attr `attr` with user scratchpad mode and GET_SIZE returns the
// scratchpad size from `pd` — confirm against their definitions.
size_t MKLDNNEmitter::query_scratchpad_vanilla_rnn_forward(
    const mkldnn::vanilla_rnn_forward::desc& desc)
{
    ATTR_S
    auto pd = mkldnn::vanilla_rnn_forward::primitive_desc(desc, attr, executor::global_cpu_engine);
    GET_SIZE
}
size_t MKLDNNEmitter::query_scratchpad_lrn_forward(const mkldnn::lrn_forward::desc& desc) size_t MKLDNNEmitter::query_scratchpad_lrn_forward(const mkldnn::lrn_forward::desc& desc)
{ {
ATTR_S ATTR_S
......
...@@ -1332,6 +1332,93 @@ namespace ngraph ...@@ -1332,6 +1332,93 @@ namespace ngraph
dst_iter_c_desc); dst_iter_c_desc);
} }
// Builds the MKL-DNN op descriptor for a single-gate (vanilla) RNN forward
// pass from an ngraph RNN-style node.
// Expected tensor order — assumed from the indexing below, confirm at call
// sites:
//   args: 0 src_layer, 1 src_iter, 2 weights_layer, 3 weights_iter, 4 bias
//   out:  0 dst_layer, 1 dst_iter
// The activation is hard-coded to tanh (eltwise_tanh) and the propagation
// kind to forward_training.
template <typename OP>
mkldnn::vanilla_rnn_forward::desc
    get_vanilla_rnn_forward_desc(const ngraph::Node* node,
                                 const std::vector<TensorWrapper>& args,
                                 const std::vector<TensorWrapper>& out)
{
    auto rnn_node = static_cast<const OP*>(node);
    // Dimensions queried from the fused RNN node (MKL-DNN naming:
    // t = timesteps, n = batch, c = channels/features, l = layers,
    // d = directions, g = gates).
    auto src_sequence_length_max =
        static_cast<unsigned long>(rnn_node->get_src_sequence_length());
    auto direction = static_cast<unsigned long>(rnn_node->get_direction());
    auto num_fused_layers =
        static_cast<unsigned long>(rnn_node->get_num_fused_layers());
    auto feature_size =
        static_cast<unsigned long>(rnn_node->get_src_iter_feature_size());
    auto batch = static_cast<unsigned long>(rnn_node->get_batch_size());
    auto rnn_cell_n_gates =
        static_cast<unsigned long>(rnn_node->get_gates_per_cell());

    // Only unidirectional (1) and bidirectional-concat (2) are supported.
    auto get_mkldnn_rnn_direction = [&]() {
        switch (direction)
        {
        case 1: return mkldnn::rnn_direction::unidirectional_left2right;
        case 2: return mkldnn::rnn_direction::bidirectional_concat;
        default: throw ngraph_error("unsupported mkldnn rnn direction");
        }
    };

    // Sanity checks: output feature sizes must agree with the hidden state
    // feature size derived from the node.
    if (out[0].get_shape().size() == 2 &&
        (out[0].get_shape()[1] != direction * feature_size))
    {
        throw ngraph_error(
            "input slc{ht} feature size is not equal to output dlc{ht} feature "
            "size ");
    }
    if (out[1].get_shape().size() == 2 && (out[1].get_shape()[1] != feature_size) &&
        rnn_node->get_num_timesteps() != 1)
    {
        throw ngraph_error(
            "input sic{ht_1|ct_1} feature size is not equal to output "
            "dlc{ht_1|ct_1} "
            "feature size ");
    }

    // Logical MKL-DNN tensor shapes:
    //   src/dst_layer: {t, n, c}; src/dst_iter: {l, d, n, c};
    //   weights: {l, d, input_c, g, output_c}; bias: {l, d, g, c}.
    Shape src_layer_tz{
        src_sequence_length_max,
        batch,
        static_cast<unsigned long>(rnn_node->get_src_layer_feature_size())};
    Shape src_iter_tz{num_fused_layers, direction, batch, feature_size};
    Shape wei_layer_tz{
        num_fused_layers,
        direction,
        static_cast<unsigned long>(rnn_node->get_src_layer_feature_size()),
        rnn_cell_n_gates,
        feature_size};
    Shape wei_iter_tz{
        num_fused_layers, direction, feature_size, rnn_cell_n_gates, feature_size};
    Shape bias_tz{num_fused_layers, direction, rnn_cell_n_gates, feature_size};
    Shape dst_layer_tz{src_sequence_length_max, batch, direction * feature_size};
    Shape dst_iter_tz{num_fused_layers, direction, batch, feature_size};

    // We create the memory descriptors used by the user
    // (tnc/ldnc/ldigo/ldgo are the plain MKL-DNN user formats for the
    // shapes above).
    auto src_layer_desc = build_memory_descriptor(
        src_layer_tz, args[0].get_element_type(), mkldnn::memory::FORMAT::tnc);
    auto src_iter_desc = build_memory_descriptor(
        src_iter_tz, args[1].get_element_type(), mkldnn::memory::FORMAT::ldnc);
    auto weights_layer_desc = build_memory_descriptor(
        wei_layer_tz, args[2].get_element_type(), mkldnn::memory::FORMAT::ldigo);
    auto weights_iter_desc = build_memory_descriptor(
        wei_iter_tz, args[3].get_element_type(), mkldnn::memory::FORMAT::ldigo);
    auto bias_desc = build_memory_descriptor(
        bias_tz, args[4].get_element_type(), mkldnn::memory::FORMAT::ldgo);
    auto dst_layer_desc = build_memory_descriptor(
        dst_layer_tz, out[0].get_element_type(), mkldnn::memory::FORMAT::tnc);
    auto dst_iter_desc = build_memory_descriptor(
        dst_iter_tz, out[1].get_element_type(), mkldnn::memory::FORMAT::ldnc);

    return mkldnn::vanilla_rnn_forward::desc(mkldnn::prop_kind::forward_training,
                                             mkldnn::algorithm::eltwise_tanh,
                                             get_mkldnn_rnn_direction(),
                                             src_layer_desc,
                                             src_iter_desc,
                                             weights_layer_desc,
                                             weights_iter_desc,
                                             bias_desc,
                                             dst_layer_desc,
                                             dst_iter_desc);
}
void build_rnn_forward(std::vector<mkldnn::memory*>& mkldnn_memories, void build_rnn_forward(std::vector<mkldnn::memory*>& mkldnn_memories,
std::vector<mkldnn::primitive*>& mkldnn_primitives, std::vector<mkldnn::primitive*>& mkldnn_primitives,
std::vector<mkldnn::memory::desc*>& mkldnn_scratchpad_mds, std::vector<mkldnn::memory::desc*>& mkldnn_scratchpad_mds,
...@@ -1340,6 +1427,15 @@ namespace ngraph ...@@ -1340,6 +1427,15 @@ namespace ngraph
std::vector<size_t>& deps, std::vector<size_t>& deps,
size_t rnn_idx); size_t rnn_idx);
void build_vanilla_rnn_forward(
std::vector<mkldnn::memory*>& mkldnn_memories,
std::vector<mkldnn::primitive*>& mkldnn_primitives,
std::vector<mkldnn::memory::desc*>& mkldnn_scratchpad_mds,
std::vector<char*>& mkldnn_workspaces,
const mkldnn::vanilla_rnn_forward::desc& desc,
std::vector<size_t>& deps,
size_t rnn_idx);
template <bool with_bias> template <bool with_bias>
void build_convolution_forward( void build_convolution_forward(
std::vector<mkldnn::memory*>& mkldnn_memories, std::vector<mkldnn::memory*>& mkldnn_memories,
...@@ -1459,6 +1555,8 @@ namespace ngraph ...@@ -1459,6 +1555,8 @@ namespace ngraph
const mkldnn::memory::desc& result_desc); const mkldnn::memory::desc& result_desc);
size_t query_scratchpad_lrn_forward(const mkldnn::lrn_forward::desc& desc); size_t query_scratchpad_lrn_forward(const mkldnn::lrn_forward::desc& desc);
size_t query_scratchpad_rnn_forward(const mkldnn::lstm_forward::desc& desc); size_t query_scratchpad_rnn_forward(const mkldnn::lstm_forward::desc& desc);
size_t query_scratchpad_vanilla_rnn_forward(
const mkldnn::vanilla_rnn_forward::desc& desc);
size_t query_scratchpad_slice(mkldnn::memory::desc& input_desc, size_t query_scratchpad_slice(mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& output_desc, const mkldnn::memory::desc& output_desc,
const ngraph::Coordinate& lower_bounds, const ngraph::Coordinate& lower_bounds,
...@@ -1593,7 +1691,8 @@ namespace ngraph ...@@ -1593,7 +1691,8 @@ namespace ngraph
auto dst_iter_desc = build_memory_descriptor( auto dst_iter_desc = build_memory_descriptor(
dst_iter_tz, out[1].get_element_type(), mkldnn::memory::FORMAT::ldsnc); dst_iter_tz, out[1].get_element_type(), mkldnn::memory::FORMAT::ldsnc);
mkldnn::rnn_cell::desc rnn_cell_desc(get_mkldnn_rnn_cell_type()); mkldnn::rnn_cell::desc rnn_cell_desc(get_mkldnn_rnn_cell_type(),
mkldnn::algorithm::eltwise_tanh);
return mkldnn::rnn_forward::desc(mkldnn::prop_kind::forward_training, return mkldnn::rnn_forward::desc(mkldnn::prop_kind::forward_training,
rnn_cell_desc, rnn_cell_desc,
get_mkldnn_rnn_direction(), get_mkldnn_rnn_direction(),
......
...@@ -187,7 +187,19 @@ extern "C" void ...@@ -187,7 +187,19 @@ extern "C" void
{MKLDNN_ARG_DST_ITER, *ctx->mkldnn_memories[deps[7]]}, {MKLDNN_ARG_DST_ITER, *ctx->mkldnn_memories[deps[7]]},
{MKLDNN_ARG_DST_ITER_C, *ctx->mkldnn_memories[deps[8]]}, {MKLDNN_ARG_DST_ITER_C, *ctx->mkldnn_memories[deps[8]]},
{MKLDNN_ARG_WORKSPACE, *ctx->mkldnn_memories[deps[9]]}}; {MKLDNN_ARG_WORKSPACE, *ctx->mkldnn_memories[deps[9]]}};
break; break;
case OpType::VANILLA_RNN:
exec_args = {{MKLDNN_ARG_SRC_LAYER, *ctx->mkldnn_memories[deps[0]]},
{MKLDNN_ARG_SRC_ITER, *ctx->mkldnn_memories[deps[1]]},
{MKLDNN_ARG_WEIGHTS_LAYER, *ctx->mkldnn_memories[deps[2]]},
{MKLDNN_ARG_WEIGHTS_ITER, *ctx->mkldnn_memories[deps[3]]},
{MKLDNN_ARG_BIAS, *ctx->mkldnn_memories[deps[4]]},
{MKLDNN_ARG_DST_LAYER, *ctx->mkldnn_memories[deps[5]]},
{MKLDNN_ARG_DST_ITER, *ctx->mkldnn_memories[deps[6]]},
{MKLDNN_ARG_WORKSPACE, *ctx->mkldnn_memories[deps[7]]}};
break;
case OpType::MAXPOOLBACKPROPFORWARD: case OpType::MAXPOOLBACKPROPFORWARD:
case OpType::MAXPOOLWITHINDICES: case OpType::MAXPOOLWITHINDICES:
exec_args = {{MKLDNN_ARG_SRC, *ctx->mkldnn_memories[deps[0]]}, exec_args = {{MKLDNN_ARG_SRC, *ctx->mkldnn_memories[deps[0]]},
......
...@@ -75,6 +75,7 @@ namespace ngraph ...@@ -75,6 +75,7 @@ namespace ngraph
RELU, RELU,
RELUBACKPROP, RELUBACKPROP,
RNN, RNN,
VANILLA_RNN,
SIGMOID, SIGMOID,
SIGMOIDBACKPROP, SIGMOIDBACKPROP,
SLICE, SLICE,
......
...@@ -26,23 +26,120 @@ constexpr NodeTypeInfo op::Rnn::type_info; ...@@ -26,23 +26,120 @@ constexpr NodeTypeInfo op::Rnn::type_info;
#if MKLDNN_VERSION_MAJOR >= 1 #if MKLDNN_VERSION_MAJOR >= 1
shared_ptr<Node> op::Rnn::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Rnn::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 6) if (new_args.size() != 6 && new_args.size() != 5)
{ {
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
} }
return make_shared<Rnn>(new_args[0],
new_args[1], if (new_args.size() == 5)
new_args[2], {
new_args[3], return make_shared<Rnn>(new_args[0],
new_args[4], new_args[1],
new_args[5], new_args[2],
m_num_timesteps, new_args[3],
m_num_gates_per_cell, new_args[4],
m_src_sequence_length, m_num_timesteps,
m_num_cell_states, m_num_gates_per_cell,
m_direction, m_src_sequence_length,
m_num_fused_layers, m_num_cell_states,
m_rnntype); m_direction,
m_num_fused_layers,
m_rnntype);
}
else
{
return make_shared<Rnn>(new_args[0],
new_args[1],
new_args[2],
new_args[3],
new_args[4],
new_args[5],
m_num_timesteps,
m_num_gates_per_cell,
m_src_sequence_length,
m_num_cell_states,
m_direction,
m_num_fused_layers,
m_rnntype);
}
}
// Vanilla-RNN constructor: five data inputs (src_layer, src_iter,
// weights_layer, weights_iter, bias) — no cell state (src_iter_c), unlike
// the six-input LSTM overload below.
// Validates input ranks/shapes, derives batch and feature sizes from the
// weight shapes, and declares the two outputs (dst_layer, dst_iter).
op::Rnn::Rnn(const Output<Node>& src_layer,
             const Output<Node>& src_iter,
             const Output<Node>& weights_layer,
             const Output<Node>& weights_iter,
             const Output<Node>& bias,
             size_t num_timesteps,
             size_t num_gates_per_cell,
             size_t src_sequence_length,
             size_t num_cell_states,
             size_t direction,
             size_t num_fused_layers,
             ngraph::runtime::cpu::rnn_utils::rnntype rnn_type)
    : Op({src_layer, src_iter, weights_layer, weights_iter, bias})
    , m_num_timesteps(num_timesteps)
    , m_num_gates_per_cell(num_gates_per_cell)
    , m_src_sequence_length(src_sequence_length)
    , m_num_cell_states(num_cell_states)
    , m_direction(direction)
    , m_num_fused_layers(num_fused_layers)
    , m_rnntype(rnn_type)
{
    constructor_validate_and_infer_types();
    // Layer and iter inputs must pair up rank-wise with their weights.
    if (src_layer.get_shape().size() != weights_layer.get_shape().size())
    {
        throw ngraph_error("src_layer and i2h weights size dont match");
    }
    if (src_iter.get_shape().size() != weights_iter.get_shape().size())
    {
        throw ngraph_error("src_iter and h2h weights size dont match");
    }
    // src_layer is expected flattened to {timesteps * batch, features}, so
    // the batch size is recovered by dividing out the timestep count.
    if (src_layer.get_shape().size() == 2)
    {
        m_batch_size = src_layer.get_shape()[0] / m_num_timesteps;
    }
    else
    {
        throw ngraph_error("src_layer doesnt have a rank 2");
    }
    // Feature sizes are derived from the weight shapes: dim 1 holds
    // gates * features, dim 0 holds (direction * layers) * features.
    m_dst_iter_feature_size = weights_iter.get_shape()[1] / (m_num_gates_per_cell);
    m_dst_layer_feature_size = weights_layer.get_shape()[1] / (m_num_gates_per_cell);
    m_src_iter_feature_size = weights_iter.get_shape()[0] / (m_direction * m_num_fused_layers);
    m_src_layer_feature_size = weights_layer.get_shape()[0] / (m_direction * m_num_fused_layers);
    // src_layer must contain exactly t*n*c elements.
    if (shape_size(src_layer.get_shape()) !=
        m_src_sequence_length * m_batch_size * m_src_layer_feature_size)
    {
        throw ngraph_error("src_layer size is not equal t*n*c");
    }
    // Bias carries (direction * layers) blocks that must match the
    // gates*features width of both weight tensors.
    if ((bias.get_shape()[0] / (m_direction * m_num_fused_layers)) !=
            (weights_layer.get_shape()[1]) ||
        (bias.get_shape()[0] / (m_direction * m_num_fused_layers)) != (weights_iter.get_shape()[1]))
    {
        throw ngraph_error("bias and weights_shape are not compatible");
    }
    // All inputs must share src_layer's element type.
    auto et = src_layer.get_element_type();
    for (auto& rnn_input : get_arguments())
    {
        if (rnn_input->get_element_type() != et)
        {
            throw ngraph_error("all rnn inputs must have the same element type");
        }
    }
    // Output 0: dst_layer {t*n, d*c}; output 1: dst_iter
    // {states*d*l*n, c}.
    set_output_size(2);
    set_output_type(0,
                    src_layer.get_element_type(),
                    Shape{(m_num_timesteps * m_batch_size), m_direction * m_src_iter_feature_size});
    set_output_type(1,
                    src_layer.get_element_type(),
                    Shape{(m_num_cell_states * m_direction * m_num_fused_layers * m_batch_size),
                          m_src_iter_feature_size});
}
op::Rnn::Rnn(const Output<Node>& src_layer, op::Rnn::Rnn(const Output<Node>& src_layer,
......
...@@ -72,6 +72,19 @@ namespace ngraph ...@@ -72,6 +72,19 @@ namespace ngraph
size_t num_fused_layers, size_t num_fused_layers,
ngraph::runtime::cpu::rnn_utils::rnntype rnn_type); ngraph::runtime::cpu::rnn_utils::rnntype rnn_type);
#else #else
CPU_BACKEND_API Rnn(const Output<Node>& src_layer,
const Output<Node>& src_iter,
const Output<Node>& weights_layer,
const Output<Node>& weights_iter,
const Output<Node>& bias,
size_t num_timesteps,
size_t num_gates_per_cell,
size_t src_sequence_length,
size_t num_cell_states,
size_t direction,
size_t num_fused_layers,
ngraph::runtime::cpu::rnn_utils::rnntype rnn_type);
CPU_BACKEND_API Rnn(const Output<Node>& src_layer, CPU_BACKEND_API Rnn(const Output<Node>& src_layer,
const Output<Node>& src_iter, const Output<Node>& src_iter,
const Output<Node>& src_iter_c, const Output<Node>& src_iter_c,
...@@ -101,6 +114,11 @@ namespace ngraph ...@@ -101,6 +114,11 @@ namespace ngraph
size_t get_num_cell_states() const { return m_num_cell_states; } size_t get_num_cell_states() const { return m_num_cell_states; }
size_t get_direction() const { return m_direction; } size_t get_direction() const { return m_direction; }
size_t get_num_fused_layers() const { return m_num_fused_layers; } size_t get_num_fused_layers() const { return m_num_fused_layers; }
/// \brief True iff this fused RNN node is of the given rnn kind
///        (e.g. vanilla_rnn vs vanilla_lstm).
bool is_type(ngraph::runtime::cpu::rnn_utils::rnntype rnn_type) const
{
    return rnn_type == m_rnntype;
}
private: private:
size_t m_num_timesteps; size_t m_num_timesteps;
size_t m_num_gates_per_cell; size_t m_num_gates_per_cell;
......
...@@ -663,6 +663,7 @@ namespace ngraph ...@@ -663,6 +663,7 @@ namespace ngraph
void CPUAssignment::ASSIGN_DECL(ngraph::op::Rnn) void CPUAssignment::ASSIGN_DECL(ngraph::op::Rnn)
{ {
(void)external_function; (void)external_function;
auto rnn_op = static_cast<ngraph::op::Rnn*>(node);
auto src_layer_rank = node->get_input_shape(0).size(); auto src_layer_rank = node->get_input_shape(0).size();
auto src_iter_rank = node->get_input_shape(1).size(); auto src_iter_rank = node->get_input_shape(1).size();
#if MKLDNN_VERSION_MAJOR < 1 #if MKLDNN_VERSION_MAJOR < 1
...@@ -677,16 +678,33 @@ namespace ngraph ...@@ -677,16 +678,33 @@ namespace ngraph
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node); runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
} }
#else #else
auto src_iter_c_rank = node->get_input_shape(2).size();
auto weights_layer_rank = node->get_input_shape(3).size(); if (rnn_op->is_type(ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_lstm))
auto weights_iter_rank = node->get_input_shape(4).size();
auto bias_rank = node->get_input_shape(5).size();
if ((src_layer_rank == 2 && src_iter_rank == 2 && src_iter_c_rank == 2 &&
weights_layer_rank == 2 && weights_iter_rank == 2 && bias_rank == 1 &&
node->get_input_element_type(0) == element::f32 &&
node->get_input_element_type(1) == element::f32))
{ {
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node); auto src_iter_c_rank = node->get_input_shape(2).size();
auto weights_layer_rank = node->get_input_shape(3).size();
auto weights_iter_rank = node->get_input_shape(4).size();
auto bias_rank = node->get_input_shape(5).size();
if ((src_layer_rank == 2 && src_iter_rank == 2 && src_iter_c_rank == 2 &&
weights_layer_rank == 2 && weights_iter_rank == 2 && bias_rank == 1 &&
node->get_input_element_type(0) == element::f32 &&
node->get_input_element_type(1) == element::f32))
{
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
else if (rnn_op->is_type(ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_rnn))
{
auto weights_layer_rank = node->get_input_shape(2).size();
auto weights_iter_rank = node->get_input_shape(3).size();
auto bias_rank = node->get_input_shape(4).size();
if ((src_layer_rank == 2 && src_iter_rank == 2 && weights_layer_rank == 2 &&
weights_iter_rank == 2 && bias_rank == 1 &&
node->get_input_element_type(0) == element::f32 &&
node->get_input_element_type(1) == element::f32))
{
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
} }
#endif #endif
} }
......
...@@ -72,6 +72,77 @@ ...@@ -72,6 +72,77 @@
using namespace ngraph; using namespace ngraph;
// Registers a matcher that fuses the vanilla RNN cell expressed as
//   at = Dot(Concat(x_t, h_{t-1}), fused_weights) + bias
//   ht = tanh(at)
// into a single op::Rnn node (rnntype::vanilla_rnn), splitting the fused
// weight matrix back into input-to-hidden and hidden-to-hidden parts.
void ngraph::runtime::cpu::pass::VanillaRNNFusion::construct_vanilla_rnn()
{
    // pattern to capture the vanilla RNN
    // at = W*h{t, l-1} + U *h{t-1, l} + B
    // ht = activation(at)
    // The Label shapes below are representative placeholders used to build
    // the pattern graph — presumably they do not constrain matching; confirm
    // against the Matcher semantics.
    auto src_layer_label = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 34});
    auto src_iter_label = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 34});
    auto concat =
        std::make_shared<ngraph::op::Concat>(NodeVector{src_layer_label, src_iter_label}, 0);
    auto weights = std::make_shared<pattern::op::Label>(element::f32, Shape{34, 2});
    auto bias_label = std::make_shared<pattern::op::Label>(element::f32, Shape{64, 2});
    // Allow an optional Broadcast/Reshape between the bias and the Add.
    auto broadcast_pred = [](std::shared_ptr<Node> n) {
        return ((is_type<ngraph::op::Broadcast>(n)) || (is_type<ngraph::op::Reshape>(n)));
    };
    auto dot = std::make_shared<ngraph::op::Dot>(concat, weights);
    auto add = std::make_shared<ngraph::op::Add>(
        dot, std::make_shared<pattern::op::Skip>(bias_label, broadcast_pred));
    auto activation = std::make_shared<ngraph::op::Tanh>(add);

    auto callback = [src_layer_label, src_iter_label, weights, bias_label](pattern::Matcher& m) {
        auto pattern_map = m.get_pattern_map();
        ngraph::runtime::cpu::rnn_utils::rnntype rnn_type =
            ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_rnn;
        auto fused_weights = pattern_map[weights];
        auto bias = pattern_map[bias_label];
        auto src_layer = pattern_map[src_layer_label];
        auto src_iter = pattern_map[src_iter_label];
        // slc: src_layer feature size; sic: src_iter (hidden) feature size;
        // dlc: output feature size (columns of the fused weight matrix).
        size_t slc = src_layer->get_shape()[1];
        size_t sic = src_iter->get_shape()[1];
        size_t dlc = fused_weights->get_shape()[1];
        // A fused vanilla RNN cell: one gate, one state, unidirectional,
        // single layer/timestep.
        size_t n_gates = 1;
        size_t direction = 1;
        size_t n_layers = 1;
        size_t n_state = 1;
        size_t time_steps = 1;
        size_t seq_length = 1;

        // split the fused weights for RNN kernel
        // rows [0, slc) -> input-to-hidden (W); rows [slc, slc+sic) ->
        // hidden-to-hidden (U).
        auto wei_layer = std::make_shared<ngraph::op::Slice>(
            fused_weights, Coordinate{0, 0}, Coordinate{slc, dlc});
        auto wei_iter = std::make_shared<ngraph::op::Slice>(
            fused_weights, Coordinate{slc, 0}, Coordinate{slc + sic, dlc});

        auto rnn_node = std::make_shared<ngraph::op::Rnn>(src_layer,
                                                          src_iter,
                                                          wei_layer,
                                                          wei_iter,
                                                          bias,
                                                          time_steps,
                                                          n_gates,
                                                          seq_length,
                                                          n_state,
                                                          direction,
                                                          n_layers,
                                                          rnn_type);
        auto dst_layer = std::make_shared<ngraph::op::GetOutputElement>(rnn_node, 0);
        // NOTE(review): dst_iter is constructed but never used — only the
        // matched root is replaced by dst_layer. Confirm whether the GOE for
        // output 1 is needed (e.g. for graph bookkeeping) or can be dropped.
        auto dst_iter = std::make_shared<ngraph::op::GetOutputElement>(rnn_node, 1);
        ngraph::replace_node(m.get_match_root(), dst_layer);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(activation, "VanillaRNNFusion.vanilla_rnn");
    this->add_matcher(m, callback);
}
void ngraph::runtime::cpu::pass::LSTMFusion::construct_onnx_lstmcell_fprop() void ngraph::runtime::cpu::pass::LSTMFusion::construct_onnx_lstmcell_fprop()
{ {
size_t ref_batch_size = 2; size_t ref_batch_size = 2;
......
...@@ -28,6 +28,7 @@ namespace ngraph ...@@ -28,6 +28,7 @@ namespace ngraph
{ {
namespace pass namespace pass
{ {
class VanillaRNNFusion;
class LSTMFusion; class LSTMFusion;
class RNNFusion; class RNNFusion;
class BiDirectionalRnn; class BiDirectionalRnn;
...@@ -36,6 +37,19 @@ namespace ngraph ...@@ -36,6 +37,19 @@ namespace ngraph
} }
} }
} }
// Graph-rewrite pass that fuses a TensorFlow-style vanilla RNN cell
// (Concat(x, h) -> Dot -> Add(bias) -> Tanh) into a single op::Rnn node
// so the CPU backend can emit one MKL-DNN RNN kernel for it.
class CPU_BACKEND_API ngraph::runtime::cpu::pass::VanillaRNNFusion
    : public ngraph::pass::GraphRewrite
{
public:
    // Registers the vanilla-RNN matcher/callback at construction time.
    VanillaRNNFusion()
        : GraphRewrite()
    {
        construct_vanilla_rnn();
    }
private:
    // Builds the pattern (Concat -> Dot -> Add -> Tanh) and the callback
    // that replaces a match with op::Rnn; defined in the .cpp file.
    void construct_vanilla_rnn();
};
class CPU_BACKEND_API ngraph::runtime::cpu::pass::LSTMFusion : public ngraph::pass::GraphRewrite class CPU_BACKEND_API ngraph::runtime::cpu::pass::LSTMFusion : public ngraph::pass::GraphRewrite
{ {
......
...@@ -4013,6 +4013,30 @@ TEST(cpu_fusion, rnn_fusion_2rnn_layer_3lstm_cell) ...@@ -4013,6 +4013,30 @@ TEST(cpu_fusion, rnn_fusion_2rnn_layer_3lstm_cell)
} }
} }
// Runs a serialized 3-time-step TF vanilla RNN graph on both the CPU backend
// (where VanillaRNNFusion applies) and the INTERPRETER reference backend,
// checks the outputs agree, and verifies that each of the three cells was
// fused into an op::Rnn node.
TEST(cpu_fusion, vanilla_rnn_cpu_vs_inter)
{
    const std::string file_name("tensorflow/rnn/vanilla_rnn_3_time_step.json");
    // Two independent copies of the same graph: one per backend.
    auto cpu_f = make_function_from_file(file_name);
    auto int_f = make_function_from_file(file_name);

    // Generate one shared set of random inputs for both backends.
    test::Uniform<float> rng(-1.0f, 1.0f);
    vector<vector<float>> args;
    for (const shared_ptr<op::Parameter>& param : int_f->get_parameters())
    {
        vector<float> values(shape_size(param->get_shape()));
        rng.initialize(values);
        args.push_back(move(values));
    }

    auto int_results = execute(int_f, args, "INTERPRETER");
    auto cpu_results = execute(cpu_f, args, "CPU");

    // Fused CPU results must match the reference within tolerance.
    for (size_t idx = 0; idx < cpu_results.size(); ++idx)
    {
        EXPECT_TRUE(test::all_close(cpu_results.at(idx), int_results.at(idx), 1.0e-4f, 1.0e-4f));
    }

    // All three time steps should have been rewritten into op::Rnn nodes.
    auto rnn_ops = get_ops_of_type<op::Rnn>(cpu_f);
    EXPECT_EQ(rnn_ops.size(), 3);
}
TEST(cpu_fusion, validate_fuse_gru_inputs) TEST(cpu_fusion, validate_fuse_gru_inputs)
{ {
const std::string file_name("mxnet/gru_debug.json"); const std::string file_name("mxnet/gru_debug.json");
......
[
{
"name": "Function_5",
"ops": [
{
"cacheable": false,
"element_type": "float",
"friendly_name": "rnn/basic_rnn_cell/kernel",
"name": "Parameter_4535",
"op": "Parameter",
"op_version": 0,
"outputs": [
"Parameter_4535_0"
],
"shape": [
34,
2
],
"type_info": {
"name": "Parameter",
"version": 0
}
},
{
"cacheable": false,
"element_type": "float",
"friendly_name": "rnn/basic_rnn_cell/bias",
"name": "Parameter_4536",
"op": "Parameter",
"op_version": 0,
"outputs": [
"Parameter_4536_0"
],
"shape": [
2
],
"type_info": {
"name": "Parameter",
"version": 0
}
},
{
"cacheable": false,
"element_type": "float",
"friendly_name": "_arg_x_0_0",
"name": "Parameter_4537",
"op": "Parameter",
"op_version": 0,
"outputs": [
"Parameter_4537_0"
],
"shape": [
64,
3,
32
],
"type_info": {
"name": "Parameter",
"version": 0
}
},
{
"friendly_name": "unstack",
"inputs": [
"Parameter_4537"
],
"lower_bounds": [
0,
0,
0
],
"name": "Slice_4538",
"op": "Slice",
"op_version": 0,
"outputs": [
"Slice_4538_0"
],
"strides": [
1,
1,
1
],
"type_info": {
"name": "Slice",
"version": 0
},
"upper_bounds": [
64,
1,
32
]
},
{
"friendly_name": "unstack",
"input_order": [
0,
1,
2
],
"inputs": [
"Slice_4538"
],
"name": "Reshape_4539",
"op": "Reshape",
"op_version": 0,
"output_shape": [
64,
32
],
"outputs": [
"Reshape_4539_0"
],
"type_info": {
"name": "Reshape",
"version": 0
}
},
{
"element_type": "float",
"friendly_name": "rnn/BasicRNNCellZeroState/zeros",
"name": "Constant_4544",
"op": "Constant",
"op_version": 0,
"outputs": [
"Constant_4544_0"
],
"shape": [
64,
2
],
"type_info": {
"name": "Constant",
"version": 0
},
"value": [
"0"
]
},
{
"axis": 1,
"friendly_name": "rnn/basic_rnn_cell/concat",
"inputs": [
"Reshape_4539",
"Constant_4544"
],
"name": "Concat_4546",
"op": "Concat",
"op_version": 0,
"outputs": [
"Concat_4546_0"
],
"type_info": {
"name": "Concat",
"version": 0
}
},
{
"friendly_name": "rnn/basic_rnn_cell/BiasAdd",
"inputs": [
"Concat_4546",
"Parameter_4535"
],
"name": "Dot_4547",
"op": "Dot",
"op_version": 0,
"outputs": [
"Dot_4547_0"
],
"reduction_axes_count": 1,
"type_info": {
"name": "Dot",
"version": 0
}
},
{
"axes": [
0
],
"friendly_name": "rnn/basic_rnn_cell/BiasAdd",
"inputs": [
"Parameter_4536"
],
"name": "Broadcast_4548",
"op": "Broadcast",
"op_version": 0,
"outputs": [
"Broadcast_4548_0"
],
"shape": [
64,
2
],
"type_info": {
"name": "Broadcast",
"version": 0
}
},
{
"friendly_name": "rnn/basic_rnn_cell/BiasAdd",
"inputs": [
"Dot_4547",
"Broadcast_4548"
],
"name": "Add_4549",
"op": "Add",
"op_version": 0,
"outputs": [
"Add_4549_0"
],
"type_info": {
"name": "Add",
"version": 0
}
},
{
"friendly_name": "rnn/basic_rnn_cell/Tanh",
"inputs": [
"Add_4549"
],
"name": "Tanh_4550",
"op": "Tanh",
"op_version": 0,
"outputs": [
"Tanh_4550_0"
],
"type_info": {
"name": "Tanh",
"version": 0
}
},
{
"inputs": [
"Tanh_4550"
],
"name": "Result_4561",
"needs_default_layout": true,
"op": "Result",
"op_version": 0,
"outputs": [
"Result_4561_0"
],
"type_info": {
"name": "Result",
"version": 0
}
},
{
"friendly_name": "unstack",
"inputs": [
"Parameter_4537"
],
"lower_bounds": [
0,
1,
0
],
"name": "Slice_4540",
"op": "Slice",
"op_version": 0,
"outputs": [
"Slice_4540_0"
],
"strides": [
1,
1,
1
],
"type_info": {
"name": "Slice",
"version": 0
},
"upper_bounds": [
64,
2,
32
]
},
{
"friendly_name": "unstack",
"input_order": [
0,
1,
2
],
"inputs": [
"Slice_4540"
],
"name": "Reshape_4541",
"op": "Reshape",
"op_version": 0,
"output_shape": [
64,
32
],
"outputs": [
"Reshape_4541_0"
],
"type_info": {
"name": "Reshape",
"version": 0
}
},
{
"axis": 1,
"friendly_name": "rnn/basic_rnn_cell/concat_1",
"inputs": [
"Reshape_4541",
"Tanh_4550"
],
"name": "Concat_4551",
"op": "Concat",
"op_version": 0,
"outputs": [
"Concat_4551_0"
],
"type_info": {
"name": "Concat",
"version": 0
}
},
{
"friendly_name": "rnn/basic_rnn_cell/BiasAdd_1",
"inputs": [
"Concat_4551",
"Parameter_4535"
],
"name": "Dot_4552",
"op": "Dot",
"op_version": 0,
"outputs": [
"Dot_4552_0"
],
"reduction_axes_count": 1,
"type_info": {
"name": "Dot",
"version": 0
}
},
{
"axes": [
0
],
"friendly_name": "rnn/basic_rnn_cell/BiasAdd_1",
"inputs": [
"Parameter_4536"
],
"name": "Broadcast_4553",
"op": "Broadcast",
"op_version": 0,
"outputs": [
"Broadcast_4553_0"
],
"shape": [
64,
2
],
"type_info": {
"name": "Broadcast",
"version": 0
}
},
{
"friendly_name": "rnn/basic_rnn_cell/BiasAdd_1",
"inputs": [
"Dot_4552",
"Broadcast_4553"
],
"name": "Add_4554",
"op": "Add",
"op_version": 0,
"outputs": [
"Add_4554_0"
],
"type_info": {
"name": "Add",
"version": 0
}
},
{
"friendly_name": "rnn/basic_rnn_cell/Tanh_1",
"inputs": [
"Add_4554"
],
"name": "Tanh_4555",
"op": "Tanh",
"op_version": 0,
"outputs": [
"Tanh_4555_0"
],
"type_info": {
"name": "Tanh",
"version": 0
}
},
{
"inputs": [
"Tanh_4555"
],
"name": "Result_4562",
"needs_default_layout": true,
"op": "Result",
"op_version": 0,
"outputs": [
"Result_4562_0"
],
"type_info": {
"name": "Result",
"version": 0
}
},
{
"friendly_name": "unstack",
"inputs": [
"Parameter_4537"
],
"lower_bounds": [
0,
2,
0
],
"name": "Slice_4542",
"op": "Slice",
"op_version": 0,
"outputs": [
"Slice_4542_0"
],
"strides": [
1,
1,
1
],
"type_info": {
"name": "Slice",
"version": 0
},
"upper_bounds": [
64,
3,
32
]
},
{
"friendly_name": "unstack",
"input_order": [
0,
1,
2
],
"inputs": [
"Slice_4542"
],
"name": "Reshape_4543",
"op": "Reshape",
"op_version": 0,
"output_shape": [
64,
32
],
"outputs": [
"Reshape_4543_0"
],
"type_info": {
"name": "Reshape",
"version": 0
}
},
{
"axis": 1,
"friendly_name": "rnn/basic_rnn_cell/concat_2",
"inputs": [
"Reshape_4543",
"Tanh_4555"
],
"name": "Concat_4556",
"op": "Concat",
"op_version": 0,
"outputs": [
"Concat_4556_0"
],
"type_info": {
"name": "Concat",
"version": 0
}
},
{
"friendly_name": "rnn/basic_rnn_cell/BiasAdd_2",
"inputs": [
"Concat_4556",
"Parameter_4535"
],
"name": "Dot_4557",
"op": "Dot",
"op_version": 0,
"outputs": [
"Dot_4557_0"
],
"reduction_axes_count": 1,
"type_info": {
"name": "Dot",
"version": 0
}
},
{
"axes": [
0
],
"friendly_name": "rnn/basic_rnn_cell/BiasAdd_2",
"inputs": [
"Parameter_4536"
],
"name": "Broadcast_4558",
"op": "Broadcast",
"op_version": 0,
"outputs": [
"Broadcast_4558_0"
],
"shape": [
64,
2
],
"type_info": {
"name": "Broadcast",
"version": 0
}
},
{
"friendly_name": "rnn/basic_rnn_cell/BiasAdd_2",
"inputs": [
"Dot_4557",
"Broadcast_4558"
],
"name": "Add_4559",
"op": "Add",
"op_version": 0,
"outputs": [
"Add_4559_0"
],
"type_info": {
"name": "Add",
"version": 0
}
},
{
"friendly_name": "rnn/basic_rnn_cell/Tanh_2",
"inputs": [
"Add_4559"
],
"name": "Tanh_4560",
"op": "Tanh",
"op_version": 0,
"outputs": [
"Tanh_4560_0"
],
"type_info": {
"name": "Tanh",
"version": 0
}
},
{
"inputs": [
"Tanh_4560"
],
"name": "Result_4563",
"needs_default_layout": true,
"op": "Result",
"op_version": 0,
"outputs": [
"Result_4563_0"
],
"type_info": {
"name": "Result",
"version": 0
}
}
],
"parameters": [
"Parameter_4535",
"Parameter_4536",
"Parameter_4537"
],
"result": [
"Result_4561",
"Result_4562",
"Result_4563"
]
}
]
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment