Unverified Commit 412f5f7b authored by Pruthvi's avatar Pruthvi Committed by GitHub

Vanilla RNN Optimization (#4439)

* WIP PM and callback for vanilla RNN

* - Graph pass to fuse vanilla RNN Cell
- Unit test case for vanilla rnn through TF serialized graph

* - emit MKLDNN kernel for vanilla RNN ii) test case for Vanilla for cpu v/s inter

* i) style check ii) serialized graph for tf vanilla rnn

* fix warnings

* i) fixed emitter code ii) test case passes

* - added support for Vanilla RNN for MKLDNN > v1.0

* fix compiler warnings

* Fix build error

* fix unit test case
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent ef4692a6
This diff is collapsed.
...@@ -1279,6 +1279,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes( ...@@ -1279,6 +1279,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
REGISTER_KNOBBED_PASS(ImplicitBroadcastElimination, true, ngraph::pass) REGISTER_KNOBBED_PASS(ImplicitBroadcastElimination, true, ngraph::pass)
REGISTER_KNOBBED_PASS(NopElimination, true, ngraph::pass) REGISTER_KNOBBED_PASS(NopElimination, true, ngraph::pass)
REGISTER_KNOBBED_PASS(ZeroDimTensorElimination, true, ngraph::pass) REGISTER_KNOBBED_PASS(ZeroDimTensorElimination, true, ngraph::pass)
REGISTER_KNOBBED_PASS(VanillaRNNFusion, true, runtime::cpu::pass)
REGISTER_KNOBBED_PASS(LSTMFusion, true, runtime::cpu::pass) REGISTER_KNOBBED_PASS(LSTMFusion, true, runtime::cpu::pass)
REGISTER_KNOBBED_PASS(RNNFusion, true, runtime::cpu::pass) REGISTER_KNOBBED_PASS(RNNFusion, true, runtime::cpu::pass)
REGISTER_KNOBBED_PASS(AlgebraicSimplification, true, ngraph::pass) REGISTER_KNOBBED_PASS(AlgebraicSimplification, true, ngraph::pass)
......
...@@ -1242,6 +1242,47 @@ void MKLDNNEmitter::build_batchnorm_backward( ...@@ -1242,6 +1242,47 @@ void MKLDNNEmitter::build_batchnorm_backward(
mkldnn_primitives[batchnorm_index] = new mkldnn::batch_normalization_backward(batchnorm_pd); mkldnn_primitives[batchnorm_index] = new mkldnn::batch_normalization_backward(batchnorm_pd);
} }
void MKLDNNEmitter::build_vanilla_rnn_forward(
std::vector<mkldnn::memory*>& mkldnn_memories,
std::vector<mkldnn::primitive*>& mkldnn_primitives,
std::vector<mkldnn::memory::desc*>& mkldnn_scratchpad_mds,
std::vector<char*>& mkldnn_workspaces,
const mkldnn::vanilla_rnn_forward::desc& rnn_desc,
std::vector<size_t>& deps,
size_t rnn_index)
{
size_t src_layer_index = deps[0];
build_memory(mkldnn_memories, rnn_desc.data.src_layer_desc, src_layer_index);
size_t src_iter_index = deps[1];
build_memory(mkldnn_memories, rnn_desc.data.src_iter_desc, src_iter_index);
size_t weights_layer_index = deps[2];
build_memory(mkldnn_memories, rnn_desc.data.weights_layer_desc, weights_layer_index);
size_t weights_iter_index = deps[3];
build_memory(mkldnn_memories, rnn_desc.data.weights_iter_desc, weights_iter_index);
size_t bias_index = deps[4];
build_memory(mkldnn_memories, rnn_desc.data.bias_desc, bias_index);
size_t dst_layer_index = deps[5];
build_memory(mkldnn_memories, rnn_desc.data.dst_layer_desc, dst_layer_index);
size_t dst_iter_index = deps[6];
build_memory(mkldnn_memories, rnn_desc.data.dst_iter_desc, dst_iter_index);
mkldnn::primitive_attr attr;
attr.set_scratchpad_mode(mkldnn::scratchpad_mode::user);
auto rnn_layer_prim_desc =
mkldnn::vanilla_rnn_forward::primitive_desc(rnn_desc, attr, executor::global_cpu_engine);
mkldnn_scratchpad_mds[rnn_index] =
new mkldnn::memory::desc(rnn_layer_prim_desc.scratchpad_desc());
size_t workspace_index = deps[7];
build_memory(mkldnn_memories, rnn_layer_prim_desc.workspace_desc(), workspace_index);
auto workspace = std::unique_ptr<MKLDNNWorkspace>(
new MKLDNNWorkspace(rnn_layer_prim_desc.workspace_desc().get_size()));
auto workspace_buf_index = insert_workspace(mkldnn_workspaces, workspace);
deps[8] = workspace_buf_index;
mkldnn_primitives[rnn_index] = new mkldnn::vanilla_rnn_forward(rnn_layer_prim_desc);
}
void MKLDNNEmitter::build_rnn_forward(std::vector<mkldnn::memory*>& mkldnn_memories, void MKLDNNEmitter::build_rnn_forward(std::vector<mkldnn::memory*>& mkldnn_memories,
std::vector<mkldnn::primitive*>& mkldnn_primitives, std::vector<mkldnn::primitive*>& mkldnn_primitives,
std::vector<mkldnn::memory::desc*>& mkldnn_scratchpad_mds, std::vector<mkldnn::memory::desc*>& mkldnn_scratchpad_mds,
...@@ -1624,6 +1665,14 @@ size_t MKLDNNEmitter::query_scratchpad_rnn_forward(const mkldnn::lstm_forward::d ...@@ -1624,6 +1665,14 @@ size_t MKLDNNEmitter::query_scratchpad_rnn_forward(const mkldnn::lstm_forward::d
GET_SIZE GET_SIZE
} }
size_t MKLDNNEmitter::query_scratchpad_vanilla_rnn_forward(
const mkldnn::vanilla_rnn_forward::desc& desc)
{
ATTR_S
auto pd = mkldnn::vanilla_rnn_forward::primitive_desc(desc, attr, executor::global_cpu_engine);
GET_SIZE
}
size_t MKLDNNEmitter::query_scratchpad_lrn_forward(const mkldnn::lrn_forward::desc& desc) size_t MKLDNNEmitter::query_scratchpad_lrn_forward(const mkldnn::lrn_forward::desc& desc)
{ {
ATTR_S ATTR_S
......
...@@ -1332,6 +1332,93 @@ namespace ngraph ...@@ -1332,6 +1332,93 @@ namespace ngraph
dst_iter_c_desc); dst_iter_c_desc);
} }
template <typename OP>
mkldnn::vanilla_rnn_forward::desc
get_vanilla_rnn_forward_desc(const ngraph::Node* node,
const std::vector<TensorWrapper>& args,
const std::vector<TensorWrapper>& out)
{
auto rnn_node = static_cast<const OP*>(node);
auto src_sequence_length_max =
static_cast<unsigned long>(rnn_node->get_src_sequence_length());
auto direction = static_cast<unsigned long>(rnn_node->get_direction());
auto num_fused_layers =
static_cast<unsigned long>(rnn_node->get_num_fused_layers());
auto feature_size =
static_cast<unsigned long>(rnn_node->get_src_iter_feature_size());
auto batch = static_cast<unsigned long>(rnn_node->get_batch_size());
auto rnn_cell_n_gates =
static_cast<unsigned long>(rnn_node->get_gates_per_cell());
auto get_mkldnn_rnn_direction = [&]() {
switch (direction)
{
case 1: return mkldnn::rnn_direction::unidirectional_left2right;
case 2: return mkldnn::rnn_direction::bidirectional_concat;
default: throw ngraph_error("unsupported mkldnn rnn direction");
}
};
if (out[0].get_shape().size() == 2 &&
(out[0].get_shape()[1] != direction * feature_size))
{
throw ngraph_error(
"input slc{ht} feature size is not equal to output dlc{ht} feature "
"size ");
}
if (out[1].get_shape().size() == 2 && (out[1].get_shape()[1] != feature_size) &&
rnn_node->get_num_timesteps() != 1)
{
throw ngraph_error(
"input sic{ht_1|ct_1} feature size is not equal to output "
"dlc{ht_1|ct_1} "
"feature size ");
}
Shape src_layer_tz{
src_sequence_length_max,
batch,
static_cast<unsigned long>(rnn_node->get_src_layer_feature_size())};
Shape src_iter_tz{num_fused_layers, direction, batch, feature_size};
Shape wei_layer_tz{
num_fused_layers,
direction,
static_cast<unsigned long>(rnn_node->get_src_layer_feature_size()),
rnn_cell_n_gates,
feature_size};
Shape wei_iter_tz{
num_fused_layers, direction, feature_size, rnn_cell_n_gates, feature_size};
Shape bias_tz{num_fused_layers, direction, rnn_cell_n_gates, feature_size};
Shape dst_layer_tz{src_sequence_length_max, batch, direction * feature_size};
Shape dst_iter_tz{num_fused_layers, direction, batch, feature_size};
// We create the memory descriptors used by the user
auto src_layer_desc = build_memory_descriptor(
src_layer_tz, args[0].get_element_type(), mkldnn::memory::FORMAT::tnc);
auto src_iter_desc = build_memory_descriptor(
src_iter_tz, args[1].get_element_type(), mkldnn::memory::FORMAT::ldnc);
auto weights_layer_desc = build_memory_descriptor(
wei_layer_tz, args[2].get_element_type(), mkldnn::memory::FORMAT::ldigo);
auto weights_iter_desc = build_memory_descriptor(
wei_iter_tz, args[3].get_element_type(), mkldnn::memory::FORMAT::ldigo);
auto bias_desc = build_memory_descriptor(
bias_tz, args[4].get_element_type(), mkldnn::memory::FORMAT::ldgo);
auto dst_layer_desc = build_memory_descriptor(
dst_layer_tz, out[0].get_element_type(), mkldnn::memory::FORMAT::tnc);
auto dst_iter_desc = build_memory_descriptor(
dst_iter_tz, out[1].get_element_type(), mkldnn::memory::FORMAT::ldnc);
return mkldnn::vanilla_rnn_forward::desc(mkldnn::prop_kind::forward_training,
mkldnn::algorithm::eltwise_tanh,
get_mkldnn_rnn_direction(),
src_layer_desc,
src_iter_desc,
weights_layer_desc,
weights_iter_desc,
bias_desc,
dst_layer_desc,
dst_iter_desc);
}
void build_rnn_forward(std::vector<mkldnn::memory*>& mkldnn_memories, void build_rnn_forward(std::vector<mkldnn::memory*>& mkldnn_memories,
std::vector<mkldnn::primitive*>& mkldnn_primitives, std::vector<mkldnn::primitive*>& mkldnn_primitives,
std::vector<mkldnn::memory::desc*>& mkldnn_scratchpad_mds, std::vector<mkldnn::memory::desc*>& mkldnn_scratchpad_mds,
...@@ -1340,6 +1427,15 @@ namespace ngraph ...@@ -1340,6 +1427,15 @@ namespace ngraph
std::vector<size_t>& deps, std::vector<size_t>& deps,
size_t rnn_idx); size_t rnn_idx);
void build_vanilla_rnn_forward(
std::vector<mkldnn::memory*>& mkldnn_memories,
std::vector<mkldnn::primitive*>& mkldnn_primitives,
std::vector<mkldnn::memory::desc*>& mkldnn_scratchpad_mds,
std::vector<char*>& mkldnn_workspaces,
const mkldnn::vanilla_rnn_forward::desc& desc,
std::vector<size_t>& deps,
size_t rnn_idx);
template <bool with_bias> template <bool with_bias>
void build_convolution_forward( void build_convolution_forward(
std::vector<mkldnn::memory*>& mkldnn_memories, std::vector<mkldnn::memory*>& mkldnn_memories,
...@@ -1459,6 +1555,8 @@ namespace ngraph ...@@ -1459,6 +1555,8 @@ namespace ngraph
const mkldnn::memory::desc& result_desc); const mkldnn::memory::desc& result_desc);
size_t query_scratchpad_lrn_forward(const mkldnn::lrn_forward::desc& desc); size_t query_scratchpad_lrn_forward(const mkldnn::lrn_forward::desc& desc);
size_t query_scratchpad_rnn_forward(const mkldnn::lstm_forward::desc& desc); size_t query_scratchpad_rnn_forward(const mkldnn::lstm_forward::desc& desc);
size_t query_scratchpad_vanilla_rnn_forward(
const mkldnn::vanilla_rnn_forward::desc& desc);
size_t query_scratchpad_slice(mkldnn::memory::desc& input_desc, size_t query_scratchpad_slice(mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& output_desc, const mkldnn::memory::desc& output_desc,
const ngraph::Coordinate& lower_bounds, const ngraph::Coordinate& lower_bounds,
...@@ -1593,7 +1691,8 @@ namespace ngraph ...@@ -1593,7 +1691,8 @@ namespace ngraph
auto dst_iter_desc = build_memory_descriptor( auto dst_iter_desc = build_memory_descriptor(
dst_iter_tz, out[1].get_element_type(), mkldnn::memory::FORMAT::ldsnc); dst_iter_tz, out[1].get_element_type(), mkldnn::memory::FORMAT::ldsnc);
mkldnn::rnn_cell::desc rnn_cell_desc(get_mkldnn_rnn_cell_type()); mkldnn::rnn_cell::desc rnn_cell_desc(get_mkldnn_rnn_cell_type(),
mkldnn::algorithm::eltwise_tanh);
return mkldnn::rnn_forward::desc(mkldnn::prop_kind::forward_training, return mkldnn::rnn_forward::desc(mkldnn::prop_kind::forward_training,
rnn_cell_desc, rnn_cell_desc,
get_mkldnn_rnn_direction(), get_mkldnn_rnn_direction(),
......
...@@ -187,7 +187,19 @@ extern "C" void ...@@ -187,7 +187,19 @@ extern "C" void
{MKLDNN_ARG_DST_ITER, *ctx->mkldnn_memories[deps[7]]}, {MKLDNN_ARG_DST_ITER, *ctx->mkldnn_memories[deps[7]]},
{MKLDNN_ARG_DST_ITER_C, *ctx->mkldnn_memories[deps[8]]}, {MKLDNN_ARG_DST_ITER_C, *ctx->mkldnn_memories[deps[8]]},
{MKLDNN_ARG_WORKSPACE, *ctx->mkldnn_memories[deps[9]]}}; {MKLDNN_ARG_WORKSPACE, *ctx->mkldnn_memories[deps[9]]}};
break; break;
case OpType::VANILLA_RNN:
exec_args = {{MKLDNN_ARG_SRC_LAYER, *ctx->mkldnn_memories[deps[0]]},
{MKLDNN_ARG_SRC_ITER, *ctx->mkldnn_memories[deps[1]]},
{MKLDNN_ARG_WEIGHTS_LAYER, *ctx->mkldnn_memories[deps[2]]},
{MKLDNN_ARG_WEIGHTS_ITER, *ctx->mkldnn_memories[deps[3]]},
{MKLDNN_ARG_BIAS, *ctx->mkldnn_memories[deps[4]]},
{MKLDNN_ARG_DST_LAYER, *ctx->mkldnn_memories[deps[5]]},
{MKLDNN_ARG_DST_ITER, *ctx->mkldnn_memories[deps[6]]},
{MKLDNN_ARG_WORKSPACE, *ctx->mkldnn_memories[deps[7]]}};
break;
case OpType::MAXPOOLBACKPROPFORWARD: case OpType::MAXPOOLBACKPROPFORWARD:
case OpType::MAXPOOLWITHINDICES: case OpType::MAXPOOLWITHINDICES:
exec_args = {{MKLDNN_ARG_SRC, *ctx->mkldnn_memories[deps[0]]}, exec_args = {{MKLDNN_ARG_SRC, *ctx->mkldnn_memories[deps[0]]},
......
...@@ -75,6 +75,7 @@ namespace ngraph ...@@ -75,6 +75,7 @@ namespace ngraph
RELU, RELU,
RELUBACKPROP, RELUBACKPROP,
RNN, RNN,
VANILLA_RNN,
SIGMOID, SIGMOID,
SIGMOIDBACKPROP, SIGMOIDBACKPROP,
SLICE, SLICE,
......
...@@ -26,23 +26,120 @@ constexpr NodeTypeInfo op::Rnn::type_info; ...@@ -26,23 +26,120 @@ constexpr NodeTypeInfo op::Rnn::type_info;
#if MKLDNN_VERSION_MAJOR >= 1 #if MKLDNN_VERSION_MAJOR >= 1
shared_ptr<Node> op::Rnn::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Rnn::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 6) if (new_args.size() != 6 && new_args.size() != 5)
{ {
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
} }
return make_shared<Rnn>(new_args[0],
new_args[1], if (new_args.size() == 5)
new_args[2], {
new_args[3], return make_shared<Rnn>(new_args[0],
new_args[4], new_args[1],
new_args[5], new_args[2],
m_num_timesteps, new_args[3],
m_num_gates_per_cell, new_args[4],
m_src_sequence_length, m_num_timesteps,
m_num_cell_states, m_num_gates_per_cell,
m_direction, m_src_sequence_length,
m_num_fused_layers, m_num_cell_states,
m_rnntype); m_direction,
m_num_fused_layers,
m_rnntype);
}
else
{
return make_shared<Rnn>(new_args[0],
new_args[1],
new_args[2],
new_args[3],
new_args[4],
new_args[5],
m_num_timesteps,
m_num_gates_per_cell,
m_src_sequence_length,
m_num_cell_states,
m_direction,
m_num_fused_layers,
m_rnntype);
}
}
op::Rnn::Rnn(const Output<Node>& src_layer,
const Output<Node>& src_iter,
const Output<Node>& weights_layer,
const Output<Node>& weights_iter,
const Output<Node>& bias,
size_t num_timesteps,
size_t num_gates_per_cell,
size_t src_sequence_length,
size_t num_cell_states,
size_t direction,
size_t num_fused_layers,
ngraph::runtime::cpu::rnn_utils::rnntype rnn_type)
: Op({src_layer, src_iter, weights_layer, weights_iter, bias})
, m_num_timesteps(num_timesteps)
, m_num_gates_per_cell(num_gates_per_cell)
, m_src_sequence_length(src_sequence_length)
, m_num_cell_states(num_cell_states)
, m_direction(direction)
, m_num_fused_layers(num_fused_layers)
, m_rnntype(rnn_type)
{
constructor_validate_and_infer_types();
if (src_layer.get_shape().size() != weights_layer.get_shape().size())
{
throw ngraph_error("src_layer and i2h weights size dont match");
}
if (src_iter.get_shape().size() != weights_iter.get_shape().size())
{
throw ngraph_error("src_iter and h2h weights size dont match");
}
if (src_layer.get_shape().size() == 2)
{
m_batch_size = src_layer.get_shape()[0] / m_num_timesteps;
}
else
{
throw ngraph_error("src_layer doesnt have a rank 2");
}
m_dst_iter_feature_size = weights_iter.get_shape()[1] / (m_num_gates_per_cell);
m_dst_layer_feature_size = weights_layer.get_shape()[1] / (m_num_gates_per_cell);
m_src_iter_feature_size = weights_iter.get_shape()[0] / (m_direction * m_num_fused_layers);
m_src_layer_feature_size = weights_layer.get_shape()[0] / (m_direction * m_num_fused_layers);
if (shape_size(src_layer.get_shape()) !=
m_src_sequence_length * m_batch_size * m_src_layer_feature_size)
{
throw ngraph_error("src_layer size is not equal t*n*c");
}
if ((bias.get_shape()[0] / (m_direction * m_num_fused_layers)) !=
(weights_layer.get_shape()[1]) ||
(bias.get_shape()[0] / (m_direction * m_num_fused_layers)) != (weights_iter.get_shape()[1]))
{
throw ngraph_error("bias and weights_shape are not compatible");
}
auto et = src_layer.get_element_type();
for (auto& rnn_input : get_arguments())
{
if (rnn_input->get_element_type() != et)
{
throw ngraph_error("all rnn inputs must have the same element type");
}
}
set_output_size(2);
set_output_type(0,
src_layer.get_element_type(),
Shape{(m_num_timesteps * m_batch_size), m_direction * m_src_iter_feature_size});
set_output_type(1,
src_layer.get_element_type(),
Shape{(m_num_cell_states * m_direction * m_num_fused_layers * m_batch_size),
m_src_iter_feature_size});
} }
op::Rnn::Rnn(const Output<Node>& src_layer, op::Rnn::Rnn(const Output<Node>& src_layer,
......
...@@ -72,6 +72,19 @@ namespace ngraph ...@@ -72,6 +72,19 @@ namespace ngraph
size_t num_fused_layers, size_t num_fused_layers,
ngraph::runtime::cpu::rnn_utils::rnntype rnn_type); ngraph::runtime::cpu::rnn_utils::rnntype rnn_type);
#else #else
CPU_BACKEND_API Rnn(const Output<Node>& src_layer,
const Output<Node>& src_iter,
const Output<Node>& weights_layer,
const Output<Node>& weights_iter,
const Output<Node>& bias,
size_t num_timesteps,
size_t num_gates_per_cell,
size_t src_sequence_length,
size_t num_cell_states,
size_t direction,
size_t num_fused_layers,
ngraph::runtime::cpu::rnn_utils::rnntype rnn_type);
CPU_BACKEND_API Rnn(const Output<Node>& src_layer, CPU_BACKEND_API Rnn(const Output<Node>& src_layer,
const Output<Node>& src_iter, const Output<Node>& src_iter,
const Output<Node>& src_iter_c, const Output<Node>& src_iter_c,
...@@ -101,6 +114,11 @@ namespace ngraph ...@@ -101,6 +114,11 @@ namespace ngraph
size_t get_num_cell_states() const { return m_num_cell_states; } size_t get_num_cell_states() const { return m_num_cell_states; }
size_t get_direction() const { return m_direction; } size_t get_direction() const { return m_direction; }
size_t get_num_fused_layers() const { return m_num_fused_layers; } size_t get_num_fused_layers() const { return m_num_fused_layers; }
bool is_type(ngraph::runtime::cpu::rnn_utils::rnntype rnn_type) const
{
return m_rnntype == rnn_type;
}
private: private:
size_t m_num_timesteps; size_t m_num_timesteps;
size_t m_num_gates_per_cell; size_t m_num_gates_per_cell;
......
...@@ -663,6 +663,7 @@ namespace ngraph ...@@ -663,6 +663,7 @@ namespace ngraph
void CPUAssignment::ASSIGN_DECL(ngraph::op::Rnn) void CPUAssignment::ASSIGN_DECL(ngraph::op::Rnn)
{ {
(void)external_function; (void)external_function;
auto rnn_op = static_cast<ngraph::op::Rnn*>(node);
auto src_layer_rank = node->get_input_shape(0).size(); auto src_layer_rank = node->get_input_shape(0).size();
auto src_iter_rank = node->get_input_shape(1).size(); auto src_iter_rank = node->get_input_shape(1).size();
#if MKLDNN_VERSION_MAJOR < 1 #if MKLDNN_VERSION_MAJOR < 1
...@@ -677,16 +678,33 @@ namespace ngraph ...@@ -677,16 +678,33 @@ namespace ngraph
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node); runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
} }
#else #else
auto src_iter_c_rank = node->get_input_shape(2).size();
auto weights_layer_rank = node->get_input_shape(3).size(); if (rnn_op->is_type(ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_lstm))
auto weights_iter_rank = node->get_input_shape(4).size();
auto bias_rank = node->get_input_shape(5).size();
if ((src_layer_rank == 2 && src_iter_rank == 2 && src_iter_c_rank == 2 &&
weights_layer_rank == 2 && weights_iter_rank == 2 && bias_rank == 1 &&
node->get_input_element_type(0) == element::f32 &&
node->get_input_element_type(1) == element::f32))
{ {
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node); auto src_iter_c_rank = node->get_input_shape(2).size();
auto weights_layer_rank = node->get_input_shape(3).size();
auto weights_iter_rank = node->get_input_shape(4).size();
auto bias_rank = node->get_input_shape(5).size();
if ((src_layer_rank == 2 && src_iter_rank == 2 && src_iter_c_rank == 2 &&
weights_layer_rank == 2 && weights_iter_rank == 2 && bias_rank == 1 &&
node->get_input_element_type(0) == element::f32 &&
node->get_input_element_type(1) == element::f32))
{
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
else if (rnn_op->is_type(ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_rnn))
{
auto weights_layer_rank = node->get_input_shape(2).size();
auto weights_iter_rank = node->get_input_shape(3).size();
auto bias_rank = node->get_input_shape(4).size();
if ((src_layer_rank == 2 && src_iter_rank == 2 && weights_layer_rank == 2 &&
weights_iter_rank == 2 && bias_rank == 1 &&
node->get_input_element_type(0) == element::f32 &&
node->get_input_element_type(1) == element::f32))
{
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
} }
#endif #endif
} }
......
...@@ -72,6 +72,77 @@ ...@@ -72,6 +72,77 @@
using namespace ngraph; using namespace ngraph;
void ngraph::runtime::cpu::pass::VanillaRNNFusion::construct_vanilla_rnn()
{
// pattern to capture the vanilla RNN
// at = W*h{t, l-1} + U *h{t-1, l} + B
// ht = activation(at)
auto src_layer_label = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 34});
auto src_iter_label = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 34});
auto concat =
std::make_shared<ngraph::op::Concat>(NodeVector{src_layer_label, src_iter_label}, 0);
auto weights = std::make_shared<pattern::op::Label>(element::f32, Shape{34, 2});
auto bias_label = std::make_shared<pattern::op::Label>(element::f32, Shape{64, 2});
auto broadcast_pred = [](std::shared_ptr<Node> n) {
return ((is_type<ngraph::op::Broadcast>(n)) || (is_type<ngraph::op::Reshape>(n)));
};
auto dot = std::make_shared<ngraph::op::Dot>(concat, weights);
auto add = std::make_shared<ngraph::op::Add>(
dot, std::make_shared<pattern::op::Skip>(bias_label, broadcast_pred));
auto activation = std::make_shared<ngraph::op::Tanh>(add);
auto callback = [src_layer_label, src_iter_label, weights, bias_label](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
ngraph::runtime::cpu::rnn_utils::rnntype rnn_type =
ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_rnn;
auto fused_weights = pattern_map[weights];
auto bias = pattern_map[bias_label];
auto src_layer = pattern_map[src_layer_label];
auto src_iter = pattern_map[src_iter_label];
size_t slc = src_layer->get_shape()[1];
size_t sic = src_iter->get_shape()[1];
size_t dlc = fused_weights->get_shape()[1];
size_t n_gates = 1;
size_t direction = 1;
size_t n_layers = 1;
size_t n_state = 1;
size_t time_steps = 1;
size_t seq_length = 1;
// split the fused weights for RNN kernel
auto wei_layer = std::make_shared<ngraph::op::Slice>(
fused_weights, Coordinate{0, 0}, Coordinate{slc, dlc});
auto wei_iter = std::make_shared<ngraph::op::Slice>(
fused_weights, Coordinate{slc, 0}, Coordinate{slc + sic, dlc});
auto rnn_node = std::make_shared<ngraph::op::Rnn>(src_layer,
src_iter,
wei_layer,
wei_iter,
bias,
time_steps,
n_gates,
seq_length,
n_state,
direction,
n_layers,
rnn_type);
auto dst_layer = std::make_shared<ngraph::op::GetOutputElement>(rnn_node, 0);
auto dst_iter = std::make_shared<ngraph::op::GetOutputElement>(rnn_node, 1);
ngraph::replace_node(m.get_match_root(), dst_layer);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(activation, "VanillaRNNFusion.vanilla_rnn");
this->add_matcher(m, callback);
}
void ngraph::runtime::cpu::pass::LSTMFusion::construct_onnx_lstmcell_fprop() void ngraph::runtime::cpu::pass::LSTMFusion::construct_onnx_lstmcell_fprop()
{ {
size_t ref_batch_size = 2; size_t ref_batch_size = 2;
......
...@@ -28,6 +28,7 @@ namespace ngraph ...@@ -28,6 +28,7 @@ namespace ngraph
{ {
namespace pass namespace pass
{ {
class VanillaRNNFusion;
class LSTMFusion; class LSTMFusion;
class RNNFusion; class RNNFusion;
class BiDirectionalRnn; class BiDirectionalRnn;
...@@ -36,6 +37,19 @@ namespace ngraph ...@@ -36,6 +37,19 @@ namespace ngraph
} }
} }
} }
class CPU_BACKEND_API ngraph::runtime::cpu::pass::VanillaRNNFusion
: public ngraph::pass::GraphRewrite
{
public:
VanillaRNNFusion()
: GraphRewrite()
{
construct_vanilla_rnn();
}
private:
void construct_vanilla_rnn();
};
class CPU_BACKEND_API ngraph::runtime::cpu::pass::LSTMFusion : public ngraph::pass::GraphRewrite class CPU_BACKEND_API ngraph::runtime::cpu::pass::LSTMFusion : public ngraph::pass::GraphRewrite
{ {
......
...@@ -4013,6 +4013,30 @@ TEST(cpu_fusion, rnn_fusion_2rnn_layer_3lstm_cell) ...@@ -4013,6 +4013,30 @@ TEST(cpu_fusion, rnn_fusion_2rnn_layer_3lstm_cell)
} }
} }
TEST(cpu_fusion, vanilla_rnn_cpu_vs_inter)
{
const std::string file_name("tensorflow/rnn/vanilla_rnn_3_time_step.json");
auto cpu_f = make_function_from_file(file_name);
auto int_f = make_function_from_file(file_name);
test::Uniform<float> rng(-1.0f, 1.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
auto lstm_ops = get_ops_of_type<op::Rnn>(cpu_f);
EXPECT_EQ(lstm_ops.size(), 3);
}
TEST(cpu_fusion, validate_fuse_gru_inputs) TEST(cpu_fusion, validate_fuse_gru_inputs)
{ {
const std::string file_name("mxnet/gru_debug.json"); const std::string file_name("mxnet/gru_debug.json");
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment