Commit a444f7a9 authored by Pruthvi's avatar Pruthvi Committed by Scott Cyphers

Pruthvi/bi rnn (#2232)

* - Added reorder support for rnn weights_layer/iter

* i) fixed compilation issues ii) working but still observing precision error

* i) fixed failing rnn unit test for DEX ii) refactored workspace in RNN mkldnn emitter

* i) added support for src reorder to TNC from NTC

* reorder support for rnn output from NTC to TNC

* - added support for rnn weight reorder ldgoi -> ldigo
- code refactor for lstm/rnn kernel in mkldnn emitter

* - refactor rnn mkldnn kernel, change variable names

* fix RNN codegen kernel

* disable layer rnn fusion pass, to test CI

* method to validate recurrent rnn inputs

* add correlated matches for Recurrent RNN PM

* - simplify reorder logic for rnn_weights
- fix graph pattern for fusing rnn cell across time steps

* do weights reorders in rnn timesteps fusion

* refactored LSTM graph pass

* - Bug fix for finding the lstm inputs deterministically
- Refactored LSTM graph pass to single pass
- made changes to LSTM RNN time step fusion graph pass

* - use replace_node instead of replace_output in Lstm_step_wise fusion graph pass

* fix compilation error

* Fix GNMT rnn fusion

* check if the node is in use before replacing in RNN graph passes

*  i) fix style ii) fix topo sort issue in RNN graph pass

* style fix

* fix bug in simplify_concat pass

* replaces Lstm1 -> {GOE1, GOE2} -> {Slice1, Slice2} -> Concat -> Lstm2 with Lstm1 -> Lstm2

* cse for convert layout

* addressed PR comments

* - optimization pass to remove  Lstm1 -> {GOE1, GOE2} -> {Slice1, Slice2} -> Lstm2
- conditional fusing of LSTM cells only for the decoder

* made changes to multi layer RNN fusion callback

* fix asserts in RNN op

* - added support to fuse layers when slc=dlc for RNN cells
- bug fix on the sanity checks for RNN Op

* - support RNN layer fusion till slc = dlc
- bug fixes in multi layer rnn fusion call back

* capture reshape in the RNN weights

* Addressed PR comments

* - added comments in multi layer PM call back
- fuse only if slc == DLC across layers

* restore deleted 3_lstm_cell_forward.json file

* fix typo

* fix failing unit tests

* When processing in place slice, do not change the offset of the slice node if the argument pointer comes from function input.

* Address PR feedback: process in place slice after propagating in place input.

* Set INTERMEDIATE role before propagating in place input.

* Do not add temporaries to the variable name map before propagating in place input in codegen.

* Fix a bug in codegen.

* Fix a bug in codegen slice.

* reenable disabled rnn unit test

* fix compiler error

* - bug fix in the slicing logic for the layer fused rnn cell
- fix failing rnn unit test

* - Addressed PR comments
- removed redundant checks from the rnn graph pass
- simplified rnn call back replace node logic

* - added new multilayer rnn *.json file
- fix test case

* [PRIVATE BRANCH] Style fixes (#2080)

* Style fixes

* change order of lstm gates

* WIP bi rnn

* [PRIVATE BRANCH] Jbobba/rnn fusion review (#2113)

* Style fixes for single-layer RNN fusion

* Style fixes to multi-layer RNN

* added callback routine for bi-directional rnn

* fix rnn op ctor, rnn mkldnn emitter to accommodate bi directional rnn

* style fix

* added helper function for rnn's to query direction and cell_type

* fix clang error

* - unit test case for bi rnn fusion
- style fix

* - updated bi-rnn graph pass to handle reverse and reverse_seq ops in the predicate
- added bi-rnn interpreter vs. cpu unit test case
- add support in mkldnn_utils to create_md with tnc/ntc format

* - added enum type to deduce rnn_type

* Addressed PR comments
    - handle reshapes from {t, n, c} to {n, t, c} in the graph pass

* fix style

* fix clang error

* fix style

* i) move enum specific to rnn to separate header
parent f8632ea0
......@@ -1104,6 +1104,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
REGISTER_KNOBBED_PASS(RNNFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(AlgebraicSimplification, true, ngraph::pass);
REGISTER_KNOBBED_PASS(MultiLayerRNNFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(BiDirectionalRnn, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPURnnMatFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUBatchFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(ReshapeSinking, false, ngraph::pass);
......
......@@ -28,6 +28,7 @@
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/type/element_type.hpp"
using namespace ngraph::runtime::cpu;
......@@ -1052,7 +1053,9 @@ size_t MKLDNNEmitter::build_rnn_forward(const mkldnn::memory::desc& src_layer_de
const mkldnn::memory::desc& weights_iter_desc,
const mkldnn::memory::desc& bias_desc,
const mkldnn::memory::desc& dst_layer_desc,
const mkldnn::memory::desc& dst_iter_desc)
const mkldnn::memory::desc& dst_iter_desc,
const mkldnn::rnn_direction& rnn_direction,
const mkldnn::algorithm& rnn_algorithm)
{
size_t src_layer_index = build_memory_primitive(src_layer_desc);
size_t src_iter_index = build_memory_primitive(src_iter_desc);
......@@ -1062,10 +1065,10 @@ size_t MKLDNNEmitter::build_rnn_forward(const mkldnn::memory::desc& src_layer_de
size_t dst_layer_index = build_memory_primitive(dst_layer_desc);
size_t dst_iter_index = build_memory_primitive(dst_iter_desc);
mkldnn::rnn_cell::desc rnn_cell(mkldnn::algorithm::vanilla_lstm);
mkldnn::rnn_cell::desc rnn_cell(rnn_algorithm);
mkldnn::rnn_forward::desc rnn_layer_desc(mkldnn::prop_kind::forward_training,
rnn_cell,
mkldnn::rnn_direction::unidirectional_left2right,
rnn_direction,
src_layer_desc,
src_iter_desc,
weights_layer_desc,
......@@ -1073,6 +1076,7 @@ size_t MKLDNNEmitter::build_rnn_forward(const mkldnn::memory::desc& src_layer_de
bias_desc,
dst_layer_desc,
dst_iter_desc);
auto rnn_layer_prim_desc =
mkldnn::rnn_forward::primitive_desc(rnn_layer_desc, executor::global_cpu_engine);
auto workspace_index =
......@@ -1080,6 +1084,7 @@ size_t MKLDNNEmitter::build_rnn_forward(const mkldnn::memory::desc& src_layer_de
auto workspace = std::unique_ptr<MKLDNNWorkspace>(
new MKLDNNWorkspace(rnn_layer_prim_desc.workspace_primitive_desc().get_size()));
auto workspace_buf_index = insert_workspace(workspace);
size_t rnn_index = insert_primitive(new mkldnn::rnn_forward(
rnn_layer_prim_desc,
mkldnn::primitive::at(*m_mkldnn_primitives[src_layer_index]),
......
......@@ -37,6 +37,7 @@
#include "ngraph/runtime/cpu/op/conv_add.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/rnn_utils.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/strides.hpp"
#include "ngraph/type/element_type.hpp"
......@@ -477,7 +478,28 @@ namespace ngraph
auto rnn_cell_n_states =
static_cast<unsigned long>(rnn_node->get_num_cell_states());
if (out[0].get_shape().size() == 2 && (out[0].get_shape()[1] != feature_size))
// Maps the cell type stored on the RNN node to the MKLDNN algorithm passed to
// build_rnn_forward. Any cell type without a mapping below is rejected.
auto get_mkldnn_rnn_cell_type = [&]() {
    const auto cell_type = rnn_node->get_rnn_type();
    if (cell_type == rnn_utils::rnntype::vanilla_rnn)
    {
        return mkldnn::algorithm::vanilla_rnn;
    }
    if (cell_type == rnn_utils::rnntype::vanilla_gru)
    {
        return mkldnn::algorithm::vanilla_gru;
    }
    if (cell_type == rnn_utils::rnntype::vanilla_lstm)
    {
        return mkldnn::algorithm::vanilla_lstm;
    }
    throw ngraph_error("unsupported mkldnn rnn algorithm");
};
// Maps the numeric direction count to the MKLDNN direction enum:
// 1 -> left-to-right unidirectional, 2 -> bidirectional with concatenated outputs.
// Every other value is rejected.
auto get_mkldnn_rnn_direction = [&]() {
    if (direction == 1)
    {
        return mkldnn::rnn_direction::unidirectional_left2right;
    }
    if (direction == 2)
    {
        return mkldnn::rnn_direction::bidirectional_concat;
    }
    throw ngraph_error("unsupported mkldnn rnn direction");
};
if (out[0].get_shape().size() == 2 &&
(out[0].get_shape()[1] != direction * feature_size))
{
throw ngraph_error(
"input slc{ht} feature size is not equal to output dlc{ht} feature "
......@@ -508,7 +530,7 @@ namespace ngraph
Shape wei_iter_tz{
num_fused_layers, direction, feature_size, rnn_cell_n_gates, feature_size};
Shape bias_tz{num_fused_layers, direction, rnn_cell_n_gates, feature_size};
Shape dst_layer_tz{src_sequence_length_max, batch, feature_size};
Shape dst_layer_tz{src_sequence_length_max, batch, direction * feature_size};
Shape dst_iter_tz{
num_fused_layers, direction, rnn_cell_n_states, batch, feature_size};
......@@ -534,7 +556,9 @@ namespace ngraph
wei_iter_md,
bias_md,
dst_layer_md,
dst_iter_md);
dst_iter_md,
get_mkldnn_rnn_direction(),
get_mkldnn_rnn_cell_type());
}
size_t build_rnn_forward(const mkldnn::memory::desc& src_layer_desc,
......@@ -543,7 +567,9 @@ namespace ngraph
const mkldnn::memory::desc& weights_iter_desc,
const mkldnn::memory::desc& bias_desc,
const mkldnn::memory::desc& dst_layer_desc,
const mkldnn::memory::desc& dst_iter_desc);
const mkldnn::memory::desc& dst_iter_desc,
const mkldnn::rnn_direction& rnn_direction,
const mkldnn::algorithm& rnn_algorithm);
size_t build_concat(const std::vector<mkldnn::memory::desc>& inputs_data_desc,
const mkldnn::memory::desc& result_desc,
......
......@@ -366,6 +366,18 @@ mkldnn::memory::desc runtime::cpu::mkldnn_utils::create_blocked_mkldnn_md(
}
}
if (dims.size() == 3)
{
if (is_perm_sorted(strides, {0, 1, 2}))
{
return memory::desc(dim, dtype, memory::format::tnc);
}
if (is_perm_sorted(strides, {1, 0, 2}))
{
return memory::desc(dim, dtype, memory::format::ntc);
}
}
if (dims.size() == 4)
{
if (is_perm_sorted(strides, {0, 1, 2, 3}))
......@@ -450,6 +462,10 @@ memory::desc runtime::cpu::mkldnn_utils::try_get_named_md(const mkldnn_memory_de
{
case 1: CANONICALIZE_MD(mkldnn_x); break;
case 2: CANONICALIZE_MD(mkldnn_nc); break;
case 3:
CANONICALIZE_MD(mkldnn_tnc);
CANONICALIZE_MD(mkldnn_ntc);
break;
case 4:
CANONICALIZE_MD(mkldnn_nchw);
CANONICALIZE_MD(mkldnn_nhwc);
......
......@@ -28,14 +28,15 @@ shared_ptr<Node> op::Lstm::copy_with_new_args(const NodeVector& new_args) const
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Lstm>(
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4));
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4), m_rnntype);
}
op::Lstm::Lstm(std::shared_ptr<Node> src_layer,
std::shared_ptr<Node> src_iter,
std::shared_ptr<Node> weights_layer,
std::shared_ptr<Node> weights_iter,
std::shared_ptr<Node> bias)
std::shared_ptr<Node> bias,
ngraph::runtime::cpu::rnn_utils::rnntype rnn_type)
: Op("Lstm", check_single_output_args({src_layer, src_iter, weights_layer, weights_iter, bias}))
, m_output_tensor_shape(src_layer->get_shape())
, m_output_cell_shape(src_iter->get_shape())
......@@ -47,6 +48,7 @@ op::Lstm::Lstm(std::shared_ptr<Node> src_layer,
, m_num_cell_states(2)
, m_direction(1)
, m_num_fused_layers(1)
, m_rnntype(rnn_type)
{
constructor_validate_and_infer_types();
......
......@@ -17,6 +17,7 @@
#pragma once
#include "ngraph/op/op.hpp"
#include "ngraph/runtime/cpu/op/rnn_utils.hpp"
#include "ngraph/util.hpp"
namespace ngraph
......@@ -43,9 +44,11 @@ namespace ngraph
std::shared_ptr<Node> src_iter,
std::shared_ptr<Node> weights_layer,
std::shared_ptr<Node> weights_iter,
std::shared_ptr<Node> bias);
std::shared_ptr<Node> bias,
ngraph::runtime::cpu::rnn_utils::rnntype rnn_type);
Shape get_output_tensor_shape() const { return m_output_tensor_shape; }
Shape get_output_cell_shape() const { return m_output_cell_shape; }
ngraph::runtime::cpu::rnn_utils::rnntype get_rnn_type() const { return m_rnntype; }
size_t get_num_timesteps() const { return m_num_timesteps; }
size_t get_src_sequence_length() const { return m_src_sequence_length; }
size_t get_gates_per_cell() const { return m_num_gates_per_cell; }
......@@ -70,6 +73,7 @@ namespace ngraph
size_t m_num_cell_states;
size_t m_direction;
size_t m_num_fused_layers;
ngraph::runtime::cpu::rnn_utils::rnntype m_rnntype;
};
}
}
......@@ -37,7 +37,8 @@ shared_ptr<Node> op::Rnn::copy_with_new_args(const NodeVector& new_args) const
m_src_sequence_length,
m_num_cell_states,
m_direction,
m_num_fused_layers);
m_num_fused_layers,
m_rnntype);
}
op::Rnn::Rnn(std::shared_ptr<Node> src_layer,
......@@ -50,7 +51,8 @@ op::Rnn::Rnn(std::shared_ptr<Node> src_layer,
size_t src_sequence_length,
size_t num_cell_states,
size_t direction,
size_t num_fused_layers)
size_t num_fused_layers,
ngraph::runtime::cpu::rnn_utils::rnntype rnn_type)
: Op("Rnn", check_single_output_args({src_layer, src_iter, weights_layer, weights_iter, bias}))
, m_num_timesteps(num_timesteps)
, m_num_gates_per_cell(num_gates_per_cell)
......@@ -58,6 +60,7 @@ op::Rnn::Rnn(std::shared_ptr<Node> src_layer,
, m_num_cell_states(num_cell_states)
, m_direction(direction)
, m_num_fused_layers(num_fused_layers)
, m_rnntype(rnn_type)
{
constructor_validate_and_infer_types();
if (src_layer->get_shape().size() != weights_layer->get_shape().size())
......@@ -90,8 +93,10 @@ op::Rnn::Rnn(std::shared_ptr<Node> src_layer,
throw ngraph_error("src_layer size is not equal t*n*c");
}
if ((bias->get_shape()[0] / m_num_fused_layers) != (weights_layer->get_shape()[1]) ||
(bias->get_shape()[0] / m_num_fused_layers) != (weights_iter->get_shape()[1]))
if ((bias->get_shape()[0] / (m_direction * m_num_fused_layers)) !=
(weights_layer->get_shape()[1]) ||
(bias->get_shape()[0] / (m_direction * m_num_fused_layers)) !=
(weights_iter->get_shape()[1]))
{
throw ngraph_error("bias and weights_shape are not compatible");
}
......@@ -108,7 +113,7 @@ op::Rnn::Rnn(std::shared_ptr<Node> src_layer,
set_output_size(2);
set_output_type(0,
src_layer->get_element_type(),
Shape{(m_direction * m_num_timesteps * m_batch_size), m_src_iter_feature_size});
Shape{(m_num_timesteps * m_batch_size), m_direction * m_src_iter_feature_size});
set_output_type(1,
src_layer->get_element_type(),
Shape{(m_num_cell_states * m_direction * m_num_fused_layers * m_batch_size),
......
......@@ -18,6 +18,7 @@
#include "ngraph/op/op.hpp"
#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
#include "ngraph/runtime/cpu/op/rnn_utils.hpp"
#include "ngraph/util.hpp"
namespace ngraph
......@@ -57,9 +58,12 @@ namespace ngraph
size_t src_sequence_length,
size_t num_cell_states,
size_t direction,
size_t num_fused_layers);
size_t num_fused_layers,
ngraph::runtime::cpu::rnn_utils::rnntype rnn_type);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
ngraph::runtime::cpu::rnn_utils::rnntype get_rnn_type() const { return m_rnntype; }
size_t get_num_timesteps() const { return m_num_timesteps; }
size_t get_src_sequence_length() const { return m_src_sequence_length; }
size_t get_gates_per_cell() const { return m_num_gates_per_cell; }
......@@ -83,6 +87,7 @@ namespace ngraph
size_t m_num_cell_states;
size_t m_direction;
size_t m_num_fused_layers;
ngraph::runtime::cpu::rnn_utils::rnntype m_rnntype;
};
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstddef>
#include <cstdint>
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace rnn_utils
{
// Recurrent cell variant carried by the CPU Rnn/Lstm ops (see get_rnn_type()).
// The MKLDNN emitter translates each value to the mkldnn::algorithm of the same name.
// TODO(pruthvi): Extend this enum as new MKLDNN RNN variants are added.
enum rnntype
{
vanilla_rnn,
vanilla_gru,
vanilla_lstm
};
}
}
}
}
......@@ -70,6 +70,7 @@
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/rnn_utils.hpp"
#include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
#include "ngraph/runtime/cpu/op/update_slice.hpp"
#include "ngraph/util.hpp"
......@@ -1780,8 +1781,14 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fuse_lstm_recurrent_state(
auto weights_layer_label = std::make_shared<pattern::op::Label>(element::f32, Shape{100, 400});
auto weights_iter_label = std::make_shared<pattern::op::Label>(element::f32, Shape{100, 400});
auto bias_label = std::make_shared<pattern::op::Label>(element::f32, Shape{400});
auto lstm1 = std::make_shared<op::Lstm>(
src_layer_label, src_iter_label, weights_layer_label, weights_iter_label, bias_label);
ngraph::runtime::cpu::rnn_utils::rnntype rnn_type =
ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_lstm;
auto lstm1 = std::make_shared<op::Lstm>(src_layer_label,
src_iter_label,
weights_layer_label,
weights_iter_label,
bias_label,
rnn_type);
auto lstm1_goe0 = std::make_shared<op::GetOutputElement>(lstm1, 0);
auto lstm1_goe1 = std::make_shared<op::GetOutputElement>(lstm1, 1);
......
......@@ -30,6 +30,7 @@ namespace ngraph
{
class LSTMFusion;
class RNNFusion;
class BiDirectionalRnn;
class MultiLayerRNNFusion;
}
}
......@@ -77,3 +78,16 @@ public:
private:
void construct_multi_layer_rnn_fusion_fprop();
};
// Graph-rewrite pass for bidirectional RNNs on the CPU backend.
// The rewrite itself lives in construct_bidirectional_rnn(), whose definition is
// not visible here. NOTE(review): per the accompanying unit test it is expected to
// fold the Reverse ops and leave a single fused Rnn op -- confirm against the
// implementation.
class ngraph::runtime::cpu::pass::BiDirectionalRnn : public ngraph::pass::GraphRewrite
{
public:
// Sets up the pass by invoking construct_bidirectional_rnn() at construction time.
BiDirectionalRnn()
: GraphRewrite()
{
construct_bidirectional_rnn();
}
private:
// Builds the bidirectional-RNN rewrite; called once from the constructor.
void construct_bidirectional_rnn();
};
......@@ -38,6 +38,7 @@
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tanh.hpp"
......@@ -66,6 +67,7 @@
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/rnn_utils.hpp"
#include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
#include "ngraph/runtime/cpu/op/update_slice.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
......@@ -2180,6 +2182,9 @@ TEST(cpu_fusion, rnn_fprop_1_lstm_cell)
const int num_rnn_cell_states = 2;
const int rnn_direction = 1;
const int num_of_rnn_fused_layer = 1;
ngraph::runtime::cpu::rnn_utils::rnntype rnn_type =
ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_lstm;
auto rnn_node = make_shared<op::Rnn>(src_layer,
src_iter,
weights_layer,
......@@ -2190,7 +2195,9 @@ TEST(cpu_fusion, rnn_fprop_1_lstm_cell)
src_seq_length,
num_rnn_cell_states,
rnn_direction,
num_of_rnn_fused_layer);
num_of_rnn_fused_layer,
rnn_type);
auto rnn_ht_output = make_shared<op::GetOutputElement>(rnn_node, 0);
auto rnn_ct_output = make_shared<op::GetOutputElement>(rnn_node, 1);
......@@ -3590,7 +3597,6 @@ TEST(cpu_quant_fusion, qconvb_relu)
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
set_environment("NGRAPH_PASS_ENABLES", "CPUQuantFusion:0", 1);
auto cpu1_results = execute(cpu_f1, args, "CPU");
set_environment("NGRAPH_PASS_ENABLES", "CPUQuantFusion:1", 1);
......@@ -3808,3 +3814,46 @@ TEST(cpu_quant_fusion, qconvba)
auto cpu2_results = execute(cpu_f2, args, "CPU");
EXPECT_TRUE(test::all_close(cpu1_results.at(0), cpu2_results.at(0)));
}
// Runs the CPU RNN fusion passes over a serialized bidirectional LSTM model and
// checks the structure of the rewritten graph.
TEST(cpu_fusion, fuse_bi_directional_rnn)
{
    // Fusion pipeline under test.
    pass::Manager fusion_passes;
    fusion_passes.register_pass<runtime::cpu::pass::LSTMFusion>();
    fusion_passes.register_pass<runtime::cpu::pass::RNNFusion>();
    fusion_passes.register_pass<ngraph::pass::AlgebraicSimplification>();
    fusion_passes.register_pass<runtime::cpu::pass::MultiLayerRNNFusion>();
    fusion_passes.register_pass<runtime::cpu::pass::BiDirectionalRnn>();

    // Load the serialized model and apply the passes to it.
    const string model_path =
        file_util::path_join(SERIALIZED_ZOO, "mxnet/lstm_bi_directional.json");
    stringstream model_stream(file_util::read_file_to_string(model_path));
    shared_ptr<Function> graph = ngraph::deserialize(model_stream);
    fusion_passes.run_passes(graph);

    auto reverse_ops = get_ops_of_type<op::Reverse>(graph);
    auto fused_rnn_ops = get_ops_of_type<op::Rnn>(graph);

    // The bidirectional pass is expected to fold away every Reverse op.
    EXPECT_EQ(reverse_ops.size(), 0);
    // Both directions should end up fused into a single MKLDNN Rnn op.
    EXPECT_EQ(fused_rnn_ops.size(), 1);
}
// Executes the same bidirectional LSTM model on the INTERPRETER and CPU backends
// with identical random inputs and requires the outputs to agree within tolerance.
TEST(cpu_fusion, bi_rnn_interpreter_vs_cpu)
{
    const std::string model_file("mxnet/lstm_bi_directional.json");
    auto cpu_func = make_function_from_file(model_file);
    auto reference_func = make_function_from_file(model_file);

    // One uniformly random input vector per function parameter.
    test::Uniform<float> generator(0.0f, 1.0f);
    vector<vector<float>> inputs;
    for (shared_ptr<op::Parameter> parameter : reference_func->get_parameters())
    {
        vector<float> values(shape_size(parameter->get_shape()));
        generator.initialize(values);
        inputs.push_back(values);
    }

    auto reference_results = execute(reference_func, inputs, "INTERPRETER");
    auto cpu_results = execute(cpu_func, inputs, "CPU");

    // Compare each output pair-wise against the interpreter reference.
    for (size_t idx = 0; idx < reference_results.size(); idx++)
    {
        EXPECT_TRUE(
            test::all_close(cpu_results.at(idx), reference_results.at(idx), 1.0e-4f, 1.0e-4f));
    }
}
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment