Commit a444f7a9 authored by Pruthvi, committed by Scott Cyphers

Pruthvi/bi rnn (#2232)

* - Added reorder support for rnn weights_layer/iter

* i) fixed compilation issues ii) working but still observing precision error

* i) fixed failing rnn unit test for DEX ii) refactored workspace in RNN mkldnn emitter

* i) added support for src reorder to TNC from NTC

* reorder support for rnn output from NTC to TNC

* - added support for rnn weight reorder ldgoi -> ldigo
- code refactor for lstm/rnn kernel in mkldnn emitter

* - refactor rnn mkldnn kernel, change variable names

* fix RNN codegen kernel

* disable layer rnn fusion pass to test CI

* method to validate recurrent rnn inputs

* add correlated matches for Recurrent RNN PM

* - simplify reorder logic for rnn_weights
- fix graph pattern for fusing rnn cell across time steps

* do weights reorders in rnn timesteps fusion

* refactored LSTM graph pass

* - Bug fix for finding the lstm inputs deterministically
- Refactored LSTM graph pass to single pass
- made changes to LSTM RNN time step fusion graph pass

* - use replace_node instead of replace_output in Lstm_step_wise fusion graph pass

* fix compilation error

* Fix GNMT rnn fusion

* check if the node is in use before replacing in RNN graph passes

*  i) fix style ii) fix topo sort issue in RNN graph pass

* style fix

* fix bug in simplify_concat pass

* replaces Lstm1 -> {GOE1, GOE2} -> {Slice1, Slice2} -> Concat -> Lstm2 with Lstm1 -> Lstm2

* cse for convert layout

* addressed PR comments

* - optimization pass to remove  Lstm1 -> {GOE1, GOE2} -> {Slice1, Slice2} -> Lstm2
- conditional fusing of LSTM cells only for the decoder

* made changes to multi layer RNN fusion callback

* fix asserts in RNN op

* - added support to fuse layers when slc=dlc for RNN cells
- bug fix on the sanity checks for RNN Op

* - support RNN layer fusion till slc = dlc
- bug fixes in multi layer rnn fusion call back

* capture reshape in the RNN weights

* Addressed PR comments

* - added comments in multi layer PM call back
- fuse only if slc == dlc across layers

* restore deleted 3_lstm_cell_forward.json file

* fix typo

* fix failing unit tests

* When processing an in-place slice, do not change the offset of the slice node if the argument pointer comes from a function input.

* Address PR feedback: process in-place slice after propagating in-place input.

* Set INTERMEDIATE role before propagating in-place input.

* Do not add temporaries to the variable name map before propagating in-place input in codegen.

* Fix a bug in codegen.

* Fix a bug in codegen slice.

* reenable disabled rnn unit test

* fix compiler error

* - bug fix in the slicing logic for the layer fused rnn cell
- fix failing rnn unit test

* - Addressed PR comments
- removed redundant checks from the rnn graph pass
- simplified rnn call back replace node logic

* - added new multilayer rnn *.json file
- fix test case

* [PRIVATE BRANCH] Style fixes (#2080)

* Style fixes

* change order of lstm gates

* WIP bi rnn

* [PRIVATE BRANCH] Jbobba/rnn fusion review (#2113)

* Style fixes for single-layer RNN fusion

* Style fixes to multi-layer RNN

* added callback routine for bi-directional rnn

* fix rnn op ctor and rnn mkldnn emitter to accommodate bi-directional rnn

* style fix

* added helper function for rnns to query direction and cell_type

* fix clang error

* - unit test case for bi rnn fusion
- style fix

* - updated bi-rnn graph pass to handle reverse and reverse_seq ops in the predicate
- added bi-rnn INTERPRETER vs CPU unit test case
- added support in mkldnn_utils to create_md with tnc/ntc format

* - added enum type to deduce rnn_type

* Addressed PR comments
    - handle reshapes from {t, n, c} to {n, t, c} in the graph pass

* fix style

* fix clang error

* fix style

* i) move rnn-specific enum to a separate header
parent f8632ea0
@@ -1104,6 +1104,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
     REGISTER_KNOBBED_PASS(RNNFusion, true, runtime::cpu::pass);
     REGISTER_KNOBBED_PASS(AlgebraicSimplification, true, ngraph::pass);
     REGISTER_KNOBBED_PASS(MultiLayerRNNFusion, true, runtime::cpu::pass);
+    REGISTER_KNOBBED_PASS(BiDirectionalRnn, true, runtime::cpu::pass);
     REGISTER_KNOBBED_PASS(CPURnnMatFusion, true, runtime::cpu::pass);
     REGISTER_KNOBBED_PASS(CPUBatchFusion, true, runtime::cpu::pass);
     REGISTER_KNOBBED_PASS(ReshapeSinking, false, ngraph::pass);
...
@@ -28,6 +28,7 @@
 #include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
 #include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
 #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
+#include "ngraph/runtime/cpu/op/rnn.hpp"
 #include "ngraph/type/element_type.hpp"

 using namespace ngraph::runtime::cpu;
@@ -1052,7 +1053,9 @@ size_t MKLDNNEmitter::build_rnn_forward(const mkldnn::memory::desc& src_layer_de
                                         const mkldnn::memory::desc& weights_iter_desc,
                                         const mkldnn::memory::desc& bias_desc,
                                         const mkldnn::memory::desc& dst_layer_desc,
-                                        const mkldnn::memory::desc& dst_iter_desc)
+                                        const mkldnn::memory::desc& dst_iter_desc,
+                                        const mkldnn::rnn_direction& rnn_direction,
+                                        const mkldnn::algorithm& rnn_algorithm)
 {
     size_t src_layer_index = build_memory_primitive(src_layer_desc);
     size_t src_iter_index = build_memory_primitive(src_iter_desc);
@@ -1062,10 +1065,10 @@ size_t MKLDNNEmitter::build_rnn_forward(const mkldnn::memory::desc& src_layer_de
     size_t dst_layer_index = build_memory_primitive(dst_layer_desc);
     size_t dst_iter_index = build_memory_primitive(dst_iter_desc);

-    mkldnn::rnn_cell::desc rnn_cell(mkldnn::algorithm::vanilla_lstm);
+    mkldnn::rnn_cell::desc rnn_cell(rnn_algorithm);
     mkldnn::rnn_forward::desc rnn_layer_desc(mkldnn::prop_kind::forward_training,
                                              rnn_cell,
-                                             mkldnn::rnn_direction::unidirectional_left2right,
+                                             rnn_direction,
                                              src_layer_desc,
                                              src_iter_desc,
                                              weights_layer_desc,
@@ -1073,6 +1076,7 @@ size_t MKLDNNEmitter::build_rnn_forward(const mkldnn::memory::desc& src_layer_de
                                              bias_desc,
                                              dst_layer_desc,
                                              dst_iter_desc);
     auto rnn_layer_prim_desc =
         mkldnn::rnn_forward::primitive_desc(rnn_layer_desc, executor::global_cpu_engine);
     auto workspace_index =
@@ -1080,6 +1084,7 @@ size_t MKLDNNEmitter::build_rnn_forward(const mkldnn::memory::desc& src_layer_de
     auto workspace = std::unique_ptr<MKLDNNWorkspace>(
         new MKLDNNWorkspace(rnn_layer_prim_desc.workspace_primitive_desc().get_size()));
     auto workspace_buf_index = insert_workspace(workspace);
     size_t rnn_index = insert_primitive(new mkldnn::rnn_forward(
         rnn_layer_prim_desc,
         mkldnn::primitive::at(*m_mkldnn_primitives[src_layer_index]),
...
@@ -37,6 +37,7 @@
 #include "ngraph/runtime/cpu/op/conv_add.hpp"
 #include "ngraph/runtime/cpu/op/conv_bias.hpp"
 #include "ngraph/runtime/cpu/op/conv_relu.hpp"
+#include "ngraph/runtime/cpu/op/rnn_utils.hpp"
 #include "ngraph/shape.hpp"
 #include "ngraph/strides.hpp"
 #include "ngraph/type/element_type.hpp"
@@ -477,7 +478,28 @@ namespace ngraph
                 auto rnn_cell_n_states =
                     static_cast<unsigned long>(rnn_node->get_num_cell_states());

-                if (out[0].get_shape().size() == 2 && (out[0].get_shape()[1] != feature_size))
+                auto get_mkldnn_rnn_cell_type = [&]() {
+                    switch (rnn_node->get_rnn_type())
+                    {
+                    case rnn_utils::rnntype::vanilla_rnn: return mkldnn::algorithm::vanilla_rnn;
+                    case rnn_utils::rnntype::vanilla_gru: return mkldnn::algorithm::vanilla_gru;
+                    case rnn_utils::rnntype::vanilla_lstm:
+                        return mkldnn::algorithm::vanilla_lstm;
+                    default: throw ngraph_error("unsupported mkldnn rnn algorithm");
+                    }
+                };
+
+                auto get_mkldnn_rnn_direction = [&]() {
+                    switch (direction)
+                    {
+                    case 1: return mkldnn::rnn_direction::unidirectional_left2right;
+                    case 2: return mkldnn::rnn_direction::bidirectional_concat;
+                    default: throw ngraph_error("unsupported mkldnn rnn direction");
+                    }
+                };
+
+                if (out[0].get_shape().size() == 2 &&
+                    (out[0].get_shape()[1] != direction * feature_size))
                 {
                     throw ngraph_error(
                         "input slc{ht} feature size is not equal to output dlc{ht} feature "
@@ -508,7 +530,7 @@ namespace ngraph
                 Shape wei_iter_tz{
                     num_fused_layers, direction, feature_size, rnn_cell_n_gates, feature_size};
                 Shape bias_tz{num_fused_layers, direction, rnn_cell_n_gates, feature_size};
-                Shape dst_layer_tz{src_sequence_length_max, batch, feature_size};
+                Shape dst_layer_tz{src_sequence_length_max, batch, direction * feature_size};
                 Shape dst_iter_tz{
                     num_fused_layers, direction, rnn_cell_n_states, batch, feature_size};
@@ -534,7 +556,9 @@ namespace ngraph
                     wei_iter_md,
                     bias_md,
                     dst_layer_md,
-                    dst_iter_md);
+                    dst_iter_md,
+                    get_mkldnn_rnn_direction(),
+                    get_mkldnn_rnn_cell_type());
             }

             size_t build_rnn_forward(const mkldnn::memory::desc& src_layer_desc,
@@ -543,7 +567,9 @@ namespace ngraph
                                      const mkldnn::memory::desc& weights_iter_desc,
                                      const mkldnn::memory::desc& bias_desc,
                                      const mkldnn::memory::desc& dst_layer_desc,
-                                     const mkldnn::memory::desc& dst_iter_desc);
+                                     const mkldnn::memory::desc& dst_iter_desc,
+                                     const mkldnn::rnn_direction& rnn_direction,
+                                     const mkldnn::algorithm& rnn_algorithm);

             size_t build_concat(const std::vector<mkldnn::memory::desc>& inputs_data_desc,
                                 const mkldnn::memory::desc& result_desc,
...
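As a worked example of the shape changes in this hunk (all concrete values are assumed for illustration): only the destination layer descriptor scales with the direction count, because `bidirectional_concat` concatenates the forward and backward hidden states along the feature axis, while the iter and bias descriptors gain a separate direction dimension.

```cpp
#include <cassert>
#include <cstddef>

int main()
{
    const std::size_t t = 10, n = 32, c = 100; // timesteps, batch, feature_size
    const std::size_t direction = 2;           // mapped to bidirectional_concat
    const std::size_t gates = 4, states = 2, layers = 1;

    // src_layer_tz = {t, n, c}, dst_layer_tz = {t, n, direction * c}:
    // per timestep the two directions are concatenated along the feature axis.
    assert(t * n * (direction * c) == direction * (t * n * c));

    // bias_tz = {layers, direction, gates, c} and
    // dst_iter_tz = {layers, direction, states, n, c} keep c per direction.
    const std::size_t bias_elems = layers * direction * gates * c;          // 800
    const std::size_t dst_iter_elems = layers * direction * states * n * c; // 12800
    assert(bias_elems == 800 && dst_iter_elems == 12800);

    // The sanity check above now compares the 2-D output's second dimension
    // against direction * feature_size (200 here), not feature_size alone.
    assert(direction * c == 200);
    return 0;
}
```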
@@ -366,6 +366,18 @@ mkldnn::memory::desc runtime::cpu::mkldnn_utils::create_blocked_mkldnn_md(
         }
     }

+    if (dims.size() == 3)
+    {
+        if (is_perm_sorted(strides, {0, 1, 2}))
+        {
+            return memory::desc(dim, dtype, memory::format::tnc);
+        }
+        if (is_perm_sorted(strides, {1, 0, 2}))
+        {
+            return memory::desc(dim, dtype, memory::format::ntc);
+        }
+    }
+
     if (dims.size() == 4)
     {
         if (is_perm_sorted(strides, {0, 1, 2, 3}))
@@ -450,6 +462,10 @@ memory::desc runtime::cpu::mkldnn_utils::try_get_named_md(const mkldnn_memory_de
     {
     case 1: CANONICALIZE_MD(mkldnn_x); break;
     case 2: CANONICALIZE_MD(mkldnn_nc); break;
+    case 3:
+        CANONICALIZE_MD(mkldnn_tnc);
+        CANONICALIZE_MD(mkldnn_ntc);
+        break;
     case 4:
         CANONICALIZE_MD(mkldnn_nchw);
         CANONICALIZE_MD(mkldnn_nhwc);
...
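For context, a standalone sketch of what the new 3-D branch above decides (no mkldnn dependency; this re-implementation of `is_perm_sorted` is an assumption about the real helper, kept only to make the example self-contained): for a tensor shaped `{t, n, c}`, time-major strides map to `tnc` and batch-major strides map to `ntc`.

```cpp
#include <array>
#include <cstddef>
#include <iostream>
#include <string>

// Assumed re-implementation: true if the strides are non-increasing when the
// axes are visited in 'perm' order, i.e. perm lists axes from outer to inner.
bool is_perm_sorted(const std::array<std::size_t, 3>& strides,
                    const std::array<std::size_t, 3>& perm)
{
    return strides[perm[0]] >= strides[perm[1]] && strides[perm[1]] >= strides[perm[2]];
}

std::string deduce_rnn_layout(const std::array<std::size_t, 3>& strides)
{
    if (is_perm_sorted(strides, {0, 1, 2})) return "tnc"; // time axis outermost
    if (is_perm_sorted(strides, {1, 0, 2})) return "ntc"; // batch axis outermost
    return "unknown";
}

int main()
{
    // Shape {t=10, n=32, c=100}, dense time-major buffer: strides {3200, 100, 1}.
    std::cout << deduce_rnn_layout({32 * 100, 100, 1}) << "\n"; // prints tnc
    // Same shape, batch-major buffer: strides {100, 1000, 1}.
    std::cout << deduce_rnn_layout({100, 10 * 100, 1}) << "\n"; // prints ntc
}
```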
@@ -28,14 +28,15 @@ shared_ptr<Node> op::Lstm::copy_with_new_args(const NodeVector& new_args) const
         throw ngraph_error("Incorrect number of new arguments");
     }
     return make_shared<Lstm>(
-        new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4));
+        new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4), m_rnntype);
 }

 op::Lstm::Lstm(std::shared_ptr<Node> src_layer,
                std::shared_ptr<Node> src_iter,
                std::shared_ptr<Node> weights_layer,
                std::shared_ptr<Node> weights_iter,
-               std::shared_ptr<Node> bias)
+               std::shared_ptr<Node> bias,
+               ngraph::runtime::cpu::rnn_utils::rnntype rnn_type)
     : Op("Lstm", check_single_output_args({src_layer, src_iter, weights_layer, weights_iter, bias}))
     , m_output_tensor_shape(src_layer->get_shape())
     , m_output_cell_shape(src_iter->get_shape())
@@ -47,6 +48,7 @@ op::Lstm::Lstm(std::shared_ptr<Node> src_layer,
     , m_num_cell_states(2)
     , m_direction(1)
     , m_num_fused_layers(1)
+    , m_rnntype(rnn_type)
 {
     constructor_validate_and_infer_types();
...
@@ -17,6 +17,7 @@
 #pragma once

 #include "ngraph/op/op.hpp"
+#include "ngraph/runtime/cpu/op/rnn_utils.hpp"
 #include "ngraph/util.hpp"

 namespace ngraph
@@ -43,9 +44,11 @@ namespace ngraph
                  std::shared_ptr<Node> src_iter,
                  std::shared_ptr<Node> weights_layer,
                  std::shared_ptr<Node> weights_iter,
-                 std::shared_ptr<Node> bias);
+                 std::shared_ptr<Node> bias,
+                 ngraph::runtime::cpu::rnn_utils::rnntype rnn_type);
             Shape get_output_tensor_shape() const { return m_output_tensor_shape; }
             Shape get_output_cell_shape() const { return m_output_cell_shape; }
+            ngraph::runtime::cpu::rnn_utils::rnntype get_rnn_type() const { return m_rnntype; }
             size_t get_num_timesteps() const { return m_num_timesteps; }
             size_t get_src_sequence_length() const { return m_src_sequence_length; }
             size_t get_gates_per_cell() const { return m_num_gates_per_cell; }
@@ -70,6 +73,7 @@ namespace ngraph
             size_t m_num_cell_states;
             size_t m_direction;
             size_t m_num_fused_layers;
+            ngraph::runtime::cpu::rnn_utils::rnntype m_rnntype;
         };
     }
 }
@@ -37,7 +37,8 @@ shared_ptr<Node> op::Rnn::copy_with_new_args(const NodeVector& new_args) const
                             m_src_sequence_length,
                             m_num_cell_states,
                             m_direction,
-                            m_num_fused_layers);
+                            m_num_fused_layers,
+                            m_rnntype);
 }

 op::Rnn::Rnn(std::shared_ptr<Node> src_layer,
@@ -50,7 +51,8 @@ op::Rnn::Rnn(std::shared_ptr<Node> src_layer,
              size_t src_sequence_length,
              size_t num_cell_states,
              size_t direction,
-             size_t num_fused_layers)
+             size_t num_fused_layers,
+             ngraph::runtime::cpu::rnn_utils::rnntype rnn_type)
     : Op("Rnn", check_single_output_args({src_layer, src_iter, weights_layer, weights_iter, bias}))
     , m_num_timesteps(num_timesteps)
     , m_num_gates_per_cell(num_gates_per_cell)
@@ -58,6 +60,7 @@ op::Rnn::Rnn(std::shared_ptr<Node> src_layer,
     , m_num_cell_states(num_cell_states)
     , m_direction(direction)
     , m_num_fused_layers(num_fused_layers)
+    , m_rnntype(rnn_type)
 {
     constructor_validate_and_infer_types();
     if (src_layer->get_shape().size() != weights_layer->get_shape().size())
@@ -90,8 +93,10 @@ op::Rnn::Rnn(std::shared_ptr<Node> src_layer,
         throw ngraph_error("src_layer size is not equal t*n*c");
     }

-    if ((bias->get_shape()[0] / m_num_fused_layers) != (weights_layer->get_shape()[1]) ||
-        (bias->get_shape()[0] / m_num_fused_layers) != (weights_iter->get_shape()[1]))
+    if ((bias->get_shape()[0] / (m_direction * m_num_fused_layers)) !=
+            (weights_layer->get_shape()[1]) ||
+        (bias->get_shape()[0] / (m_direction * m_num_fused_layers)) !=
+            (weights_iter->get_shape()[1]))
     {
         throw ngraph_error("bias and weights_shape are not compatible");
     }
@@ -108,7 +113,7 @@ op::Rnn::Rnn(std::shared_ptr<Node> src_layer,
     set_output_size(2);
     set_output_type(0,
                     src_layer->get_element_type(),
-                    Shape{(m_direction * m_num_timesteps * m_batch_size), m_src_iter_feature_size});
+                    Shape{(m_num_timesteps * m_batch_size), m_direction * m_src_iter_feature_size});
     set_output_type(1,
                     src_layer->get_element_type(),
                     Shape{(m_num_cell_states * m_direction * m_num_fused_layers * m_batch_size),
...
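A small worked check of the relaxed constructor validation above (concrete sizes assumed): for one bidirectional layer with 4 gates and feature size 100, the bias has `direction * num_fused_layers * gates * feature_size` rows, so dividing by `direction * num_fused_layers` recovers the `gates * feature_size` columns of each weight tensor; output 0 is likewise reshaped to `{timesteps * batch, direction * feature_size}`.

```cpp
#include <cassert>
#include <cstddef>

int main()
{
    const std::size_t direction = 2, num_fused_layers = 1;
    const std::size_t gates = 4, feature_size = 100;
    const std::size_t timesteps = 10, batch = 32;

    // Flattened 2-D shapes as the Rnn op receives them (assumed example sizes).
    const std::size_t bias_rows = direction * num_fused_layers * gates * feature_size; // 800
    const std::size_t weights_layer_cols = gates * feature_size;                       // 400
    const std::size_t weights_iter_cols = gates * feature_size;                        // 400

    // Dividing by (direction * num_fused_layers) instead of num_fused_layers
    // lets bidirectional bias tensors pass the compatibility check.
    assert(bias_rows / (direction * num_fused_layers) == weights_layer_cols);
    assert(bias_rows / (direction * num_fused_layers) == weights_iter_cols);

    // Output 0: {timesteps * batch, direction * feature_size} = {320, 200},
    // i.e. the direction factor moved from the row count to the feature axis.
    assert(timesteps * batch == 320 && direction * feature_size == 200);
    return 0;
}
```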
@@ -18,6 +18,7 @@
 #include "ngraph/op/op.hpp"
 #include "ngraph/runtime/cpu/cpu_backend_visibility.h"
+#include "ngraph/runtime/cpu/op/rnn_utils.hpp"
 #include "ngraph/util.hpp"

 namespace ngraph
@@ -57,9 +58,12 @@ namespace ngraph
                 size_t src_sequence_length,
                 size_t num_cell_states,
                 size_t direction,
-                size_t num_fused_layers);
+                size_t num_fused_layers,
+                ngraph::runtime::cpu::rnn_utils::rnntype rnn_type);
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
+
+            ngraph::runtime::cpu::rnn_utils::rnntype get_rnn_type() const { return m_rnntype; }
             size_t get_num_timesteps() const { return m_num_timesteps; }
             size_t get_src_sequence_length() const { return m_src_sequence_length; }
             size_t get_gates_per_cell() const { return m_num_gates_per_cell; }
@@ -83,6 +87,7 @@ namespace ngraph
             size_t m_num_cell_states;
             size_t m_direction;
             size_t m_num_fused_layers;
+            ngraph::runtime::cpu::rnn_utils::rnntype m_rnntype;
         };
     }
 }
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once

#include <cstddef>
#include <cstdint>

namespace ngraph
{
    namespace runtime
    {
        namespace cpu
        {
            namespace rnn_utils
            {
                // TODO(pruthvi): extend this enum as new MKLDNN RNN variants are added
                enum rnntype
                {
                    vanilla_rnn,
                    vanilla_gru,
                    vanilla_lstm
                };
            }
        }
    }
}
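For orientation, a minimal sketch (the helper name `make_vanilla_lstm` is hypothetical) of how this enum travels through the backend: the fusion passes tag each fused cell with its variant when they construct the op, exactly as the updated `op::Lstm` constructor below expects, and the emitter later maps the value to the matching `mkldnn::algorithm`.

```cpp
#include <memory>

#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/rnn_utils.hpp"

namespace cpu_rnn = ngraph::runtime::cpu::rnn_utils;

// Hypothetical helper: builds a fused LSTM cell tagged as vanilla_lstm.
std::shared_ptr<ngraph::op::Lstm> make_vanilla_lstm(std::shared_ptr<ngraph::Node> src_layer,
                                                    std::shared_ptr<ngraph::Node> src_iter,
                                                    std::shared_ptr<ngraph::Node> weights_layer,
                                                    std::shared_ptr<ngraph::Node> weights_iter,
                                                    std::shared_ptr<ngraph::Node> bias)
{
    // Every fused cell now records which MKLDNN RNN variant it stands for;
    // the MKLDNN emitter reads it back through get_rnn_type().
    auto rnn_type = cpu_rnn::rnntype::vanilla_lstm;
    return std::make_shared<ngraph::op::Lstm>(
        src_layer, src_iter, weights_layer, weights_iter, bias, rnn_type);
}
```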
@@ -70,6 +70,7 @@
 #include "ngraph/runtime/cpu/op/leaky_relu.hpp"
 #include "ngraph/runtime/cpu/op/lstm.hpp"
 #include "ngraph/runtime/cpu/op/matmul_bias.hpp"
+#include "ngraph/runtime/cpu/op/rnn_utils.hpp"
 #include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
 #include "ngraph/runtime/cpu/op/update_slice.hpp"
 #include "ngraph/util.hpp"
@@ -1780,8 +1781,14 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fuse_lstm_recurrent_state(
     auto weights_layer_label = std::make_shared<pattern::op::Label>(element::f32, Shape{100, 400});
     auto weights_iter_label = std::make_shared<pattern::op::Label>(element::f32, Shape{100, 400});
     auto bias_label = std::make_shared<pattern::op::Label>(element::f32, Shape{400});
-    auto lstm1 = std::make_shared<op::Lstm>(
-        src_layer_label, src_iter_label, weights_layer_label, weights_iter_label, bias_label);
+    ngraph::runtime::cpu::rnn_utils::rnntype rnn_type =
+        ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_lstm;
+    auto lstm1 = std::make_shared<op::Lstm>(src_layer_label,
+                                            src_iter_label,
+                                            weights_layer_label,
+                                            weights_iter_label,
+                                            bias_label,
+                                            rnn_type);
     auto lstm1_goe0 = std::make_shared<op::GetOutputElement>(lstm1, 0);
     auto lstm1_goe1 = std::make_shared<op::GetOutputElement>(lstm1, 1);
...
@@ -30,6 +30,7 @@ namespace ngraph
             {
                 class LSTMFusion;
                 class RNNFusion;
+                class BiDirectionalRnn;
                 class MultiLayerRNNFusion;
             }
         }
@@ -77,3 +78,16 @@ public:
 private:
     void construct_multi_layer_rnn_fusion_fprop();
 };
+
+class ngraph::runtime::cpu::pass::BiDirectionalRnn : public ngraph::pass::GraphRewrite
+{
+public:
+    BiDirectionalRnn()
+        : GraphRewrite()
+    {
+        construct_bidirectional_rnn();
+    }
+
+private:
+    void construct_bidirectional_rnn();
+};
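Usage-wise, the new pass is intended to run after the uni-directional fusions, matching the registration order in `register_common_passes` and in the `fuse_bi_directional_rnn` test below. A minimal sketch, assuming the passes are declared in the header this hunk belongs to (the include path here is an assumption):

```cpp
#include <memory>

#include "ngraph/function.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/runtime/cpu/pass/cpu_rnn_fusion.hpp" // assumed path for the RNN fusion passes

void fuse_rnn_for_cpu(std::shared_ptr<ngraph::Function> func)
{
    ngraph::pass::Manager pass_manager;
    // Fuse single cells, then timesteps, then layers, so that BiDirectionalRnn
    // sees one fused uni-directional Rnn node per direction that it can pair up.
    pass_manager.register_pass<ngraph::runtime::cpu::pass::LSTMFusion>();
    pass_manager.register_pass<ngraph::runtime::cpu::pass::RNNFusion>();
    pass_manager.register_pass<ngraph::runtime::cpu::pass::MultiLayerRNNFusion>();
    pass_manager.register_pass<ngraph::runtime::cpu::pass::BiDirectionalRnn>();
    pass_manager.run_passes(func);
}
```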
@@ -38,6 +38,7 @@
 #include "ngraph/op/parameter.hpp"
 #include "ngraph/op/quantize.hpp"
 #include "ngraph/op/relu.hpp"
+#include "ngraph/op/reverse_sequence.hpp"
 #include "ngraph/op/sigmoid.hpp"
 #include "ngraph/op/sum.hpp"
 #include "ngraph/op/tanh.hpp"
@@ -66,6 +67,7 @@
 #include "ngraph/runtime/cpu/op/lstm.hpp"
 #include "ngraph/runtime/cpu/op/matmul_bias.hpp"
 #include "ngraph/runtime/cpu/op/rnn.hpp"
+#include "ngraph/runtime/cpu/op/rnn_utils.hpp"
 #include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
 #include "ngraph/runtime/cpu/op/update_slice.hpp"
 #include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
@@ -2180,6 +2182,9 @@ TEST(cpu_fusion, rnn_fprop_1_lstm_cell)
     const int num_rnn_cell_states = 2;
     const int rnn_direction = 1;
     const int num_of_rnn_fused_layer = 1;
+    ngraph::runtime::cpu::rnn_utils::rnntype rnn_type =
+        ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_lstm;
+
     auto rnn_node = make_shared<op::Rnn>(src_layer,
                                          src_iter,
                                          weights_layer,
@@ -2190,7 +2195,9 @@ TEST(cpu_fusion, rnn_fprop_1_lstm_cell)
                                          src_seq_length,
                                          num_rnn_cell_states,
                                          rnn_direction,
-                                         num_of_rnn_fused_layer);
+                                         num_of_rnn_fused_layer,
+                                         rnn_type);
     auto rnn_ht_output = make_shared<op::GetOutputElement>(rnn_node, 0);
     auto rnn_ct_output = make_shared<op::GetOutputElement>(rnn_node, 1);
@@ -3590,7 +3597,6 @@ TEST(cpu_quant_fusion, qconvb_relu)
         rng.initialize(tensor_val);
         args.push_back(tensor_val);
     }
     set_environment("NGRAPH_PASS_ENABLES", "CPUQuantFusion:0", 1);
     auto cpu1_results = execute(cpu_f1, args, "CPU");
     set_environment("NGRAPH_PASS_ENABLES", "CPUQuantFusion:1", 1);
@@ -3808,3 +3814,46 @@ TEST(cpu_quant_fusion, qconvba)
     auto cpu2_results = execute(cpu_f2, args, "CPU");
     EXPECT_TRUE(test::all_close(cpu1_results.at(0), cpu2_results.at(0)));
 }
+
+TEST(cpu_fusion, fuse_bi_directional_rnn)
+{
+    pass::Manager pass_manager;
+    pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
+    pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
+    pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
+    pass_manager.register_pass<runtime::cpu::pass::MultiLayerRNNFusion>();
+    pass_manager.register_pass<runtime::cpu::pass::BiDirectionalRnn>();
+    const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/lstm_bi_directional.json");
+    const string json_string = file_util::read_file_to_string(json_path);
+    stringstream ss(json_string);
+    shared_ptr<Function> func = ngraph::deserialize(ss);
+    pass_manager.run_passes(func);
+
+    // the bi-directional graph pass folds the reverse seq ops
+    auto rev_seq_ops = get_ops_of_type<op::Reverse>(func);
+    auto rnn_ops = get_ops_of_type<op::Rnn>(func);
+    EXPECT_EQ(rev_seq_ops.size(), 0);
+    // fuse two bi-directional rnn layers into one MKLDNN op
+    EXPECT_EQ(rnn_ops.size(), 1);
+}
+
+TEST(cpu_fusion, bi_rnn_interpreter_vs_cpu)
+{
+    const std::string file_name("mxnet/lstm_bi_directional.json");
+    auto cpu_f = make_function_from_file(file_name);
+    auto int_f = make_function_from_file(file_name);
+    test::Uniform<float> rng(0.0f, 1.0f);
+    vector<vector<float>> args;
+
+    for (shared_ptr<op::Parameter> param : int_f->get_parameters())
+    {
+        vector<float> tensor_val(shape_size(param->get_shape()));
+        rng.initialize(tensor_val);
+        args.push_back(tensor_val);
+    }
+    auto int_results = execute(int_f, args, "INTERPRETER");
+    auto cpu_results = execute(cpu_f, args, "CPU");
+    for (size_t i = 0; i < int_results.size(); i++)
+    {
+        EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
+    }
+}