Commit 4df5ea8b authored by Chris Sullivan, committed by Robert Kimball

RNN fusion (inference) (#1459)

* Add op::Sigmoid to nvgpu.

* Bring rnn fusion and concat passes over into GPU from IA. This is a temporary move until generalization and gpu specification can occur.

* Add LSTM fusion and cudnn inference kernel. Next need recurrent fusion and layer fusion.

* Formatting

* Removed unnecessary extra output from LSTM op (RNN with seq. length = 1, so y = hy).

* Add RNN fusion of LSTM cells within a recurrent layer.

* Formatting.

* Add fusion across RNN layers.

* Formatting.

* Add algebraic simplification.

* Added rnn fusion tests.

* Updated conditional on LSTM fusion to better distinguish bound nodes as ht vs xt.

* Formatting.

* Removed print statements.

* Formatting.

* Committing missing file.

* Remove concat inputs pass and mkldnn references.

* fix cmake paths

* conflict resolution with merge from master.

* remove explicit lstm op support. bare LSTM ops are converted to RNN ops for emission.

* Formatting.

* Use NGRAPH_ASSERT. Formatting of intel copyright.

* Add check on the feature size (shape) of the recurrent (hidden) input and cell state, to ensure they are the same size.

* fix wrong rnn header

* Formatting.

* Add back lstm op to dispatch table.

* Added RNN test which shows cudnn rnn kernel is not producing correct results.

* With the update to AlgebraicSimplification to simplify concat-reshape-slice, the check modified in this commit needed to be relaxed.

* Bug fix in parameter tensor packing.

* Alias third output element of RNN for cell state (bug fix).

* Resolve numerical correctness issue with negative values in RNN (bug fix).
Add minimal test to evaluate LSTM and compare with values calculated by hand.

* Add tensor parameter sizes to kernel hash as
they are kernel-specific.

* Add 2 layer lstm fusion test against by-hand solution.

* Export param concatenation to graph for cudnn kernel at both the single rnn layer and multi-layer.

* Formatting.

* Finishing touches after merge: add support for macro-expanded dispatch via op_tbl.

* Simplify macro support for gpu ops.

* Add CUDNN_VERSION >= 7200 defguards for RNN fusion.
Need to decide how to notify user of increased performance with >= 7200.

* Revert lstm_analytic test to explicitly copy data to tensor params.

* Removed namespace arg from NGRAPH_GPU_OP.

* Refactored macros to different header so op_tbl only contains op list.

* Defguard on cudnn_descriptor<cudnnRNNDataDescriptor_t>.

* doubles -> floats

* Reorg. pass asserts, prepare to replace with non-throwing pass failures.

* Remove Lstm op and replace it with Rnn.

* Format

* Utilize RETURN_IF_FALSE in rnn pass to avoid any RT asserts.
Note that falling back to the raw (no passes) graph for the 2rnn_3lstm json from mxnet models
results in a double free inside the memory layout pass. Appears to be a bug
in Reshape pass-through.

* Removed print statements. Add check on input data and recurrent data.

* Don't reuse memory for non-destructive ops.

* Add back Rnn test.

* Formatting.

* Clean up comments.

* Update test per review comments.
parent f04503b6
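Taken together, these changes register the fusion pipeline on the GPU pass manager when building against cuDNN >= 7.2. A condensed sketch of that registration, taken from gpu_external_function.cpp further below in this diff:

#if CUDNN_VERSION >= 7200
// 1. fuse the sigmoid/tanh elementwise graph of each LSTM cell into a single-cell Rnn op
m_pass_manager.register_pass<runtime::gpu::pass::LSTMFusion>();
// 2. fuse the per-timestep cells of one layer into a single Rnn op
m_pass_manager.register_pass<runtime::gpu::pass::RNNFusion>();
// 3. clean up the concat/reshape/slice chains produced by the fusion
m_pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
// 4. fuse stacked RNN layers into one multi-layer Rnn op
m_pass_manager.register_pass<runtime::gpu::pass::MultiLayerRNNFusion>();
#endif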
......@@ -41,6 +41,8 @@ set(SRC
pass/gpu_layout.cpp
pass/tensor_memory_reservation.cpp
gpu_kernel_args.cpp
pass/gpu_rnn_fusion.cpp
op/rnn.cpp
)
if (NGRAPH_GPU_ENABLE)
......
......@@ -136,6 +136,20 @@ namespace ngraph
}
};
#if CUDNN_VERSION >= 7200
template <>
struct cudnn_descriptor<cudnnRNNDataDescriptor_t>
{
static void create(cudnnRNNDataDescriptor_t& desc)
{
CUDNN_SAFE_CALL(cudnnCreateRNNDataDescriptor(&desc));
}
static void destroy(cudnnRNNDataDescriptor_t& desc)
{
CUDNN_SAFE_CALL_NO_THROW(cudnnDestroyRNNDataDescriptor(desc));
}
};
#endif
template <>
struct cudnn_descriptor<cudnnPoolingDescriptor_t>
{
......
......@@ -83,6 +83,24 @@ cudnnTensorDescriptor_t& runtime::gpu::CUDNNEmitter::tensor_descriptor_from_shap
return desc;
}
cudnnTensorDescriptor_t& runtime::gpu::CUDNNEmitter::get_nd_tensor_descriptor(
const Shape& shape, const cudnnDataType_t data_type, const cudnnTensorFormat_t tensor_format)
{
cudnnTensorDescriptor_t& desc = m_descriptors.build<cudnnTensorDescriptor_t>();
std::vector<int> dimensions(shape.size());
for (auto i = 0u; i < shape.size(); i++)
{
dimensions[i] = static_cast<int>(shape[i]);
}
CUDNN_SAFE_CALL(
cudnnSetTensorNdDescriptor(desc,
data_type,
static_cast<int>(dimensions.size()),
dimensions.data(),
runtime::gpu::cudnn_util::compute_strides(dimensions).data()));
return desc;
}
std::vector<int> runtime::gpu::cudnn_util::compute_strides(const Shape& shape)
{
return runtime::gpu::cudnn_util::get_vector_int_from_size_t(row_major_strides(shape));
......@@ -286,6 +304,24 @@ cudnnFilterDescriptor_t& runtime::gpu::CUDNNEmitter::get_cudnn_filter_descriptor
return filter_descriptor;
}
cudnnFilterDescriptor_t& runtime::gpu::CUDNNEmitter::get_nd_filter_descriptor(
const Shape& shape, const cudnnDataType_t data_type, const cudnnTensorFormat_t tensor_format)
{
auto& filter_descriptor = m_descriptors.build<cudnnFilterDescriptor_t>();
std::vector<int> dimensions(shape.size());
for (auto i = 0u; i < shape.size(); i++)
{
dimensions[i] = static_cast<int>(shape[i]);
}
CUDNN_SAFE_CALL(
cudnnSetFilterNdDescriptor(filter_descriptor,
/*dataType=*/data_type,
/*format=*/tensor_format,
/*num_dimensions=*/static_cast<int>(dimensions.size()),
/*dimensions*/ dimensions.data()));
return filter_descriptor;
}
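The get_nd_tensor_descriptor helper above sizes its descriptor with row-major strides derived from the shape (via compute_strides / row_major_strides). A minimal standalone sketch of that stride computation, independent of the ngraph utilities and for illustration only:

#include <cstddef>
#include <vector>

// Row-major strides for a densely packed tensor: the innermost dimension is
// contiguous, so for shape {2, 3, 4} this returns {12, 4, 1}.
std::vector<int> row_major_strides_as_int(const std::vector<size_t>& shape)
{
    std::vector<int> strides(shape.size(), 1);
    for (int i = static_cast<int>(shape.size()) - 2; i >= 0; i--)
    {
        strides[i] = strides[i + 1] * static_cast<int>(shape[i + 1]);
    }
    return strides;
}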
cudnnConvolutionDescriptor_t& runtime::gpu::CUDNNEmitter::get_cudnn_convolution_descriptor(
const Shape& padding,
const Strides& window_movement_strides,
......@@ -896,6 +932,239 @@ size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::Min* node)
return this->m_primitive_emitter->register_primitive(kernel_launch, hash);
}
#if CUDNN_VERSION >= 7200
size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::gpu::Rnn* node)
{
auto& args = node->get_inputs();
auto& out = node->get_outputs();
auto dtype = out[0].get_element_type().c_type_string();
std::stringstream ss;
ss << "rnn_psz" << shape_size(args[2].get_shape());
std::string hash = ss.str();
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
uint32_t seq_length = node->get_src_sequence_length();
uint32_t batch_size = node->get_batch_size();
std::vector<int32_t> sequence_lengths(batch_size, seq_length);
cudnnDataType_t data_type = get_cudnn_datatype(dtype);
void* pad_value = m_host_parameters.allocate_by_datatype(data_type, 0);
// determine if LSTM cell is uni/bi-directional
cudnnDirectionMode_t cell_dir;
int direction = node->get_direction();
if (direction == 1)
{
cell_dir = CUDNN_UNIDIRECTIONAL;
}
else if (direction == 2)
{
cell_dir = CUDNN_BIDIRECTIONAL;
}
else
{
throw std::runtime_error("Encountered unhandled cudnnDirectionMode_t");
}
// TO DO: add support for projected input layer
// In that case, input vectorSize must match recProjSize
auto& x_desc = m_descriptors.build<cudnnRNNDataDescriptor_t>();
auto& y_desc = m_descriptors.build<cudnnRNNDataDescriptor_t>();
uint32_t input_size = node->get_src_layer_feature_size() * direction;
CUDNN_SAFE_CALL(
cudnnSetRNNDataDescriptor(x_desc,
data_type,
CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED, // TO DO: only unpacked
seq_length,
batch_size,
input_size,
sequence_lengths.data(),
pad_value));
uint32_t hidden_size = node->get_src_iter_feature_size() * direction;
CUDNN_SAFE_CALL(cudnnSetRNNDataDescriptor(y_desc,
data_type,
CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED,
seq_length,
batch_size,
hidden_size,
sequence_lengths.data(),
pad_value));
// TO DO: with rnn projection layers the third dimension of the hidden_shape should be recProjSize
cudnnTensorFormat_t format = CUDNN_TENSOR_NCHW;
uint32_t num_layers = node->get_num_fused_layers() * direction;
Shape hidden_shape{num_layers, batch_size, hidden_size};
auto& hx_desc = get_nd_tensor_descriptor(hidden_shape, data_type, format);
auto& hy_desc = get_nd_tensor_descriptor(hidden_shape, data_type, format);
Shape cell_state_shape{num_layers, batch_size, hidden_size};
auto& cx_desc = get_nd_tensor_descriptor(cell_state_shape, data_type, format);
auto& cy_desc = get_nd_tensor_descriptor(cell_state_shape, data_type, format);
GPUAllocator allocator = m_primitive_emitter->get_memory_allocator();
// TO DO: enable fused dropout layers
// this will require eager allocation of scratch space which we don't currently support
float dropout_prob = 0.0f;
size_t dropout_state_size = 0;
uint64_t seed;
auto& dropout_desc = m_descriptors.build<cudnnDropoutDescriptor_t>();
if (dropout_prob > 0.0f)
{
CUDNN_SAFE_CALL(cudnnDropoutGetStatesSize(*m_ctx->cudnn_handle, &dropout_state_size));
seed = 0UL; // TO DO: add random seed
// Requires memory allocation for RNG state. Need to test adding this eagerly vs
// wrapping the below call into a closure and executing it at RT. Possible failure
// vector in the second method as the dropout descriptor is used in the initialization
// of the RNN descriptor.
CUDNN_SAFE_CALL(cudnnSetDropoutDescriptor(dropout_desc,
*m_ctx->cudnn_handle,
dropout_prob,
nullptr, // device pointer
dropout_state_size,
seed));
}
// TO DO: support all RNN modes
cudnnRNNMode_t mode = CUDNN_LSTM;
if (node->get_gates_per_cell() != 4)
{
throw std::runtime_error("Only LSTMs are currently supported in fused RNN layers");
}
auto& rnn_desc = m_descriptors.build<cudnnRNNDescriptor_t>();
cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD;
CUDNN_SAFE_CALL(cudnnSetRNNDescriptor(*m_ctx->cudnn_handle,
rnn_desc,
hidden_size,
num_layers,
dropout_desc,
CUDNN_LINEAR_INPUT, // TO DO: support CUDNN_SKIP_INPUT
cell_dir,
mode,
algo,
data_type));
if (algo == CUDNN_RNN_ALGO_PERSIST_DYNAMIC)
{
// TO DO: add support for persistent RNN plan
}
// construct descriptor for RNN parameters
auto& temp_input_desc =
get_nd_tensor_descriptor(Shape{batch_size, input_size, 1}, data_type, format);
size_t params_size = 0;
CUDNN_SAFE_CALL(cudnnGetRNNParamsSize(
*m_ctx->cudnn_handle, rnn_desc, temp_input_desc, &params_size, data_type));
auto& w_desc = get_nd_filter_descriptor(Shape{params_size, 1, 1}, data_type, format);
size_t w_idx = allocator.reserve_workspace(params_size);
int num_tensors_per_layer = [&mode] {
switch (mode)
{
case CUDNN_RNN_RELU:
case CUDNN_RNN_TANH:
return 2; // 1 input + 1 recurrent input
case CUDNN_GRU:
return 6; // 3 input + 3 recurrent input
case CUDNN_LSTM:
return 8; // 4 input + 4 recurrent input
default: throw std::runtime_error("Encountered unsupported CUDNN RNN mode");
}
}();
std::vector<std::pair<int64_t, int64_t>> bias_offsets;
std::vector<std::pair<int64_t, int64_t>> weight_offsets;
auto& ifilter_desc = m_descriptors.build<cudnnFilterDescriptor_t>();
for (int ilayer = 0; ilayer < num_layers; ilayer++)
{
for (int itensor = 0; itensor < num_tensors_per_layer; itensor++)
{
for (int kind = 0; kind < 2; kind++)
{
void* offset = nullptr;
CUDNN_SAFE_CALL(((kind == 0) ? cudnnGetRNNLinLayerMatrixParams
: cudnnGetRNNLinLayerBiasParams)(*m_ctx->cudnn_handle,
rnn_desc,
ilayer,
temp_input_desc,
w_desc,
nullptr,
itensor,
ifilter_desc,
&offset));
cudnnDataType_t return_data_type;
cudnnTensorFormat_t return_format;
std::vector<int> dimensions = {1, 1, 1};
int return_rank;
CUDNN_SAFE_CALL(cudnnGetFilterNdDescriptor(ifilter_desc,
static_cast<int>(dimensions.size()),
&return_data_type,
&return_format,
&return_rank,
dimensions.data()));
(kind == 0 ? weight_offsets : bias_offsets)
.emplace_back(reinterpret_cast<int64_t>(offset),
shape_size(dimensions) * args[0].get_element_type().size());
}
}
}
size_t workspace_size = 0;
std::vector<cudnnTensorDescriptor_t> seq_descriptors(seq_length, temp_input_desc);
CUDNN_SAFE_CALL(cudnnGetRNNWorkspaceSize(
*m_ctx->cudnn_handle, rnn_desc, seq_length, seq_descriptors.data(), &workspace_size));
size_t workspace_idx = allocator.reserve_workspace(workspace_size);
auto recurrent_index = num_tensors_per_layer / 2;
std::unique_ptr<gpu::primitive> kernel_launch(
new gpu::primitive{[=](void** inputs, void** outputs) {
void* workspace_ptr = runtime::gpu::invoke_memory_primitive(m_ctx, workspace_idx);
CUDNN_SAFE_CALL(cudnnRNNForwardInferenceEx(*m_ctx->cudnn_handle,
rnn_desc,
x_desc,
inputs[0],
hx_desc,
inputs[1],
cx_desc,
inputs[3],
w_desc,
inputs[2],
y_desc, // h_i
outputs[0],
hy_desc, // h_t
outputs[1],
cy_desc, // c_t
outputs[2],
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
workspace_ptr,
workspace_size));
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
#endif
size_t runtime::gpu::CUDNNEmitter::build_convolution(const std::string& dtype,
const Shape& input_tensor_shape,
const Shape& input_filter_shape,
......
......@@ -35,6 +35,7 @@
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/runtime/gpu/op/rnn.hpp"
namespace ngraph
{
......@@ -61,6 +62,7 @@ namespace ngraph
size_t build_primitive(const op::MaxPool* node);
size_t build_primitive(const op::Max* node);
size_t build_primitive(const op::Min* node);
size_t build_primitive(const op::gpu::Rnn* node);
public:
enum class Prop
......@@ -149,10 +151,18 @@ namespace ngraph
tensor_descriptor_from_shape(const Shape& shape,
const cudnnDataType_t data_type,
const cudnnTensorFormat_t tensor_format);
cudnnTensorDescriptor_t&
get_nd_tensor_descriptor(const Shape& shape,
const cudnnDataType_t data_type,
const cudnnTensorFormat_t tensor_format);
cudnnFilterDescriptor_t&
get_cudnn_filter_descriptor(const Shape& shape,
const cudnnDataType_t data_type,
const cudnnTensorFormat_t tensor_format);
cudnnFilterDescriptor_t&
get_nd_filter_descriptor(const Shape& shape,
const cudnnDataType_t data_type,
const cudnnTensorFormat_t tensor_format);
cudnnConvolutionDescriptor_t&
get_cudnn_convolution_descriptor(const Shape& padding,
const Strides& window_movement_strides,
......
......@@ -108,6 +108,7 @@
#include "ngraph/runtime/gpu/gpu_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_util.hpp"
#include "ngraph/runtime/gpu/op/rnn.hpp"
#include "ngraph/runtime/gpu/type_info.hpp"
#include "ngraph/util.hpp"
......@@ -124,7 +125,7 @@ function<void(EMIT_ARGS)> runtime::gpu::GPU_Emitter::get_emit_function(const Nod
// ...
#define NGRAPH_OP(a, b) {type_index(typeid(b::a)), runtime::gpu::GPU_Emitter::emit_##a},
static const map<type_index, function<void(EMIT_ARGS)>> typeid_map{
#include "ngraph/op/op_tbl.hpp"
#include "ngraph/runtime/gpu/op/op_tbl.hpp"
};
#undef NGRAPH_OP
auto it = typeid_map.find(type_index(typeid(node)));
......@@ -1297,6 +1298,24 @@ void runtime::gpu::GPU_Emitter::emit_ReverseSequence(EMIT_ARGS)
writer.block_end();
}
#if CUDNN_VERSION >= 7200
void runtime::gpu::GPU_Emitter::emit_Rnn(EMIT_ARGS)
{
auto rnn = static_cast<const ngraph::op::gpu::Rnn*>(node);
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
size_t index = cudnn_emitter->build_primitive(rnn);
writer.block_begin();
{
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
}
writer.block_end();
}
#endif
void runtime::gpu::GPU_Emitter::emit_Select(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Select>(external_function, writer, node, args, out);
......
......@@ -39,7 +39,7 @@ namespace ngraph
// static void emit_Abs(EMIT_ARGS);
// static void emit_Acos(EMIT_ARGS);
#define NGRAPH_OP(a, b) static void emit_##a(EMIT_ARGS);
#include "ngraph/op/op_tbl.hpp"
#include "ngraph/runtime/gpu/op/op_tbl.hpp"
#undef NGRAPH_OP
template <typename T>
......
......@@ -103,6 +103,7 @@
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/common_function_collection.hpp"
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp"
......@@ -111,7 +112,9 @@
#include "ngraph/runtime/gpu/gpu_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp"
#include "ngraph/runtime/gpu/op/rnn.hpp"
#include "ngraph/runtime/gpu/pass/gpu_layout.hpp"
#include "ngraph/runtime/gpu/pass/gpu_rnn_fusion.hpp"
#include "ngraph/runtime/gpu/pass/tensor_memory_reservation.hpp"
using namespace std;
......@@ -551,19 +554,26 @@ void runtime::gpu::GPU_ExternalFunction::compile()
m_function_name = m_function->get_name();
auto allocator = std::make_shared<runtime::gpu::GPUAllocator>(
m_shared_context->m_primitive_emitter->get_memory_allocator());
#if CUDNN_VERSION >= 7200
// recurrent network fusion
m_pass_manager.register_pass<runtime::gpu::pass::LSTMFusion>();
m_pass_manager.register_pass<runtime::gpu::pass::RNNFusion>();
m_pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
m_pass_manager.register_pass<runtime::gpu::pass::MultiLayerRNNFusion>();
#else
m_pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
#endif
m_pass_manager.register_pass<ngraph::pass::LikeReplacement>();
m_pass_manager
.register_pass<ngraph::pass::AssignLayout<descriptor::layout::DenseTensorLayout>>();
m_pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this);
m_pass_manager.register_pass<ngraph::pass::Liveness>();
m_pass_manager.register_pass<ngraph::pass::MemoryLayout>(s_memory_pool_alignment);
GPUAllocator allocator = m_shared_context->m_primitive_emitter->get_memory_allocator();
m_pass_manager.register_pass<runtime::gpu::pass::TensorMemoryReservation>(
allocator, m_tensor_memory_buffers);
*allocator, m_tensor_memory_buffers);
std::string common_function_string;
auto femitter = bind(&ngraph::runtime::gpu::GPU_ExternalFunction::emit_op_as_function,
this,
......@@ -571,7 +581,6 @@ void runtime::gpu::GPU_ExternalFunction::compile()
placeholders::_2);
m_pass_manager.register_pass<ngraph::pass::CommonFunctionCollection>(
femitter, m_node_function_map, common_function_string);
string dump_filename = file_util::path_join(s_output_dir, m_function_name + "_ops.txt");
m_pass_manager.register_pass<ngraph::pass::DumpSorted>(dump_filename);
......@@ -590,7 +599,7 @@ void runtime::gpu::GPU_ExternalFunction::compile()
emit_functions();
// allocate device buffers for primitive arguments and workspace
allocator.close();
allocator->close();
m_shared_context->m_primitive_emitter->allocate_primitive_memory();
string code = m_writer.get_code();
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/op_tbl.hpp"
#if CUDNN_VERSION >= 7200
NGRAPH_OP(Rnn, ngraph::op::gpu)
#endif
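The per-backend op_tbl.hpp above is an X-macro table: each includer defines NGRAPH_OP before including it, so the same op list expands into emitter declarations in one place and into the typeid dispatch map in another. A self-contained sketch of the pattern, with illustrative op names rather than the real nGraph table:

#include <iostream>

// The table: one entry per op. Includers define DEMO_OP before expanding it.
#define DEMO_OP_TABLE \
    DEMO_OP(Add)      \
    DEMO_OP(Rnn)

// Expansion 1: define one emitter function per op.
#define DEMO_OP(name)                                        \
    void emit_##name() { std::cout << "emit_" #name "\n"; }
DEMO_OP_TABLE
#undef DEMO_OP

int main()
{
    // Expansion 2: walk the same table again to dispatch to every emitter.
#define DEMO_OP(name) emit_##name();
    DEMO_OP_TABLE
#undef DEMO_OP
    return 0;
}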
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/gpu/op/rnn.hpp"
#include "ngraph/log.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
shared_ptr<Node> op::gpu::Rnn::copy_with_new_args(const NodeVector& new_args) const
{
NGRAPH_ASSERT(new_args.size() == 4) << "Incorrect number of new arguments";
return make_shared<Rnn>(new_args[0],
new_args[1],
new_args[2],
new_args[3],
m_num_timesteps,
m_num_gates_per_cell,
m_src_sequence_length,
m_src_layer_feature_size,
m_src_iter_feature_size,
m_direction,
m_num_fused_layers);
}
op::gpu::Rnn::Rnn(std::shared_ptr<Node> src_layer,
std::shared_ptr<Node> src_iter,
std::shared_ptr<Node> params,
std::shared_ptr<Node> state_iter,
const int num_timesteps,
const int num_gates_per_cell,
const int src_sequence_length,
const int src_layer_feature_size,
const int src_iter_feature_size,
const int direction,
const int num_fused_layers)
: Op("Rnn", {src_layer, src_iter, params, state_iter})
, m_num_timesteps(num_timesteps)
, m_num_gates_per_cell(num_gates_per_cell)
, m_src_sequence_length(src_sequence_length)
, m_src_layer_feature_size(src_layer_feature_size)
, m_src_iter_feature_size(src_iter_feature_size)
, m_direction(direction)
, m_num_fused_layers(num_fused_layers)
{
NGRAPH_ASSERT(src_layer->get_shape().size() == 2) << "src_layer does not have rank 2";
m_batch_size = static_cast<int>(src_layer->get_shape()[0] / num_timesteps);
NGRAPH_ASSERT(shape_size(src_layer->get_shape()) ==
m_src_sequence_length * m_batch_size * m_src_layer_feature_size)
<< "src_layer size is not equal t*n*c";
auto et = src_layer->get_element_type();
for (auto& rnn_input : get_arguments())
{
if (rnn_input->get_element_type() != et)
{
throw ngraph_error("all rnn inputs must have the same element type");
}
}
set_output_size(3);
set_output_type(0,
src_layer->get_element_type(),
Shape{static_cast<unsigned long>(m_direction * m_num_timesteps * m_batch_size),
static_cast<unsigned long>(m_src_iter_feature_size)});
set_output_type(
1,
src_layer->get_element_type(),
Shape{static_cast<unsigned long>(m_direction * m_num_fused_layers * m_batch_size),
static_cast<unsigned long>(m_src_iter_feature_size)});
set_output_type(
2,
src_layer->get_element_type(),
Shape{static_cast<unsigned long>(m_direction * m_num_fused_layers * m_batch_size),
static_cast<unsigned long>(m_src_iter_feature_size)});
}
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace op
{
namespace gpu
{
// This op represents an RNN formed by fusing multiple RNN cells (LSTM / GRU / vanilla RNN)
// across multiple time slices.
// INPUTS:
// [0] - {X0, X1, ..., Xt}, input tensor of layout TNC, Shape{sequence_length*batch_size, feature_size}
// [1] - recurrent input tensor ht_1 of Shape{num_fused_layers*batch_size, feature_size}
// [2] - flat parameter tensor consisting of the weights and biases for each layer:
// {W_x^0 | W_h^0 | W_x^1 | W_h^1 | ... | B_x^0 | B_h^0 | B_x^1 | B_h^1}
// [3] - recurrent cell state tensor ct_1 with the same shape as ht_1
// number_of_timesteps - number of unrolled cells up to timestep t
// num_gates_per_cell - number of gates per RNN cell: LSTM = 4, GRU = 3, vanilla RNN = 1
// src_sequence_length - same as number_of_timesteps
// src_layer_feature_size - feature size w.r.t. the input tensor
// src_iter_feature_size - feature size w.r.t. the hidden state
// num_cell_states - number of recurrent state tensors: LSTM = 2, GRU = 1, vanilla RNN = 1
// OUTPUT VALUE: A tuple with the following structure:
// [0] - ht, sequence-wise output tensor with shape (sequence_length*batch_size, feature_size)
// [1] - hf, layer-wise output tensor with shape (num_fused_layers*batch_size, feature_size)
// [2] - ct, output cell state tensor with the same shape as hf in [1]
class Rnn : public Op
{
public:
Rnn(std::shared_ptr<Node> src_layer, // x
std::shared_ptr<Node> src_iter, // hx
std::shared_ptr<Node> params,
std::shared_ptr<Node> state_iter, // cx
const int num_timesteps,
const int num_gates_per_cell,
const int src_sequence_length,
const int src_layer_feature_size,
const int src_iter_feature_size,
const int direction,
const int num_fused_layers);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
int get_num_timesteps() const { return m_num_timesteps; }
int get_src_sequence_length() const { return m_src_sequence_length; }
int get_gates_per_cell() const { return m_num_gates_per_cell; }
int get_batch_size() const { return m_batch_size; }
int get_src_layer_feature_size() const { return m_src_layer_feature_size; }
int get_src_iter_feature_size() const { return m_src_iter_feature_size; }
int get_direction() const { return m_direction; }
int get_num_fused_layers() const { return m_num_fused_layers; }
private:
int m_num_timesteps;
int m_num_gates_per_cell;
int m_src_sequence_length;
int m_batch_size;
int m_src_layer_feature_size;
int m_src_iter_feature_size;
int m_direction;
int m_num_fused_layers;
};
}
}
}
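For reference, the output shapes declared by the constructor in rnn.cpp above can be summarized with a small sketch. The function and values below are illustrative only and assume a single unidirectional fused LSTM layer:

#include <array>
#include <cstddef>

// Output shapes of op::gpu::Rnn as set by set_output_type in rnn.cpp.
struct RnnOutputShapes
{
    std::array<size_t, 2> ht; // output 0: sequence-wise hidden output
    std::array<size_t, 2> hf; // output 1: layer-wise hidden output
    std::array<size_t, 2> ct; // output 2: layer-wise cell state
};

RnnOutputShapes rnn_output_shapes(size_t direction,
                                  size_t num_timesteps,
                                  size_t num_fused_layers,
                                  size_t batch_size,
                                  size_t feature_size)
{
    RnnOutputShapes shapes;
    shapes.ht = {direction * num_timesteps * batch_size, feature_size};
    shapes.hf = {direction * num_fused_layers * batch_size, feature_size};
    shapes.ct = shapes.hf;
    return shapes;
}
// e.g. rnn_output_shapes(1, 3, 1, 10, 100) gives ht {30, 100}, hf {10, 100}, ct {10, 100}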
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <iostream>
#include <numeric>
#include <typeindex>
#include <typeinfo>
#include <unordered_set>
#include "gpu_rnn_fusion.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/runtime/gpu/op/rnn.hpp"
#define RETURN_IF_FALSE(cond, message) \
if (!(cond)) \
{ \
NGRAPH_DEBUG << message; \
return false; \
}
using namespace ngraph;
void ngraph::runtime::gpu::pass::LSTMFusion::construct_sigmoid()
{
// construct the sigmoid pattern: 1 / (1 + exp(-x))
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
auto neg_input = std::make_shared<op::Negative>(input);
auto exp_neg_input = std::make_shared<op::Exp>(neg_input);
// broadcast the constant (the numerator 1)
auto constant = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto broadcast_constant = std::make_shared<op::Broadcast>(constant, Shape{3, 4}, AxisSet{0, 1});
auto add_exp = std::make_shared<op::Add>(exp_neg_input, broadcast_constant);
auto divide_1_over_exp = std::make_shared<op::Divide>(broadcast_constant, add_exp);
// Define a callback that is called once the DFG matches the pattern
ngraph::pattern::graph_rewrite_callback callback = [input](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
if (m.get_match_root()->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< " type is not float!";
return false;
}
if (m.get_match_root()->get_outputs().size() != pattern_map[input]->get_outputs().size())
{
NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< "input= " << pattern_map[input]->get_name() << "size dont match!";
return false;
}
auto sigmoid_node = std::make_shared<op::Sigmoid>(pattern_map[input]);
ngraph::replace_node(m.get_match_root(), sigmoid_node);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(divide_1_over_exp, callback);
this->add_matcher(m);
}
static std::shared_ptr<Node> compute_lstm_params(const std::shared_ptr<Node>& w_x,
const std::shared_ptr<Node>& w_h,
const std::shared_ptr<Node>& b_x,
const std::shared_ptr<Node>& b_h)
{
// check if concat of params exists already
// if so, use it
for (auto& node : w_x->get_users())
{
for (auto& possible_concat : node->get_users())
{
if (auto concat = std::dynamic_pointer_cast<op::Concat>(possible_concat))
{
return concat;
}
}
}
NodeVector flat_params;
for (auto& param : NodeVector{w_x, w_h, b_x, b_h})
{
auto shape = param->get_shape();
flat_params.push_back(std::make_shared<op::Reshape>(
param, get_default_order(shape), Shape{shape_size(shape)}));
}
return std::make_shared<op::Concat>(flat_params, 0);
}
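compute_lstm_params flattens and concatenates {W_x, W_h, b_x, b_h} into the single parameter tensor consumed by the Rnn op. A small sketch of the resulting size, using a hypothetical helper and illustrative values that line up with the shapes in the unit tests below (4 gates, input and hidden feature size 100):

#include <cstddef>

// Flat parameter size for one LSTM layer: |W_x| + |W_h| + |b_x| + |b_h|.
// With 4 gates and feature sizes of 100 this is
// 400*100 + 400*100 + 400 + 400 = 80800, matching the parameter Shape used in gpu_fusion.cpp.
size_t lstm_flat_param_size(size_t gates, size_t input_size, size_t hidden_size)
{
    const size_t rows = gates * hidden_size; // 4 * 100 = 400 rows per weight matrix
    return rows * input_size    // W_x
           + rows * hidden_size // W_h
           + rows               // b_x
           + rows;              // b_h
}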
void ngraph::runtime::gpu::pass::LSTMFusion::construct_lstm_fprop()
{
auto input_xt = std::make_shared<pattern::op::Label>(element::f32, Shape{10, 100});
auto weights_i2h = std::make_shared<pattern::op::Label>(element::f32, Shape{400, 100});
auto weights_i2h_reshape =
std::make_shared<op::Reshape>(weights_i2h, AxisVector{1, 0}, Shape{100, 400});
auto dot_1 = std::make_shared<op::Dot>(input_xt, weights_i2h_reshape);
auto bias_i2h = std::make_shared<pattern::op::Label>(element::f32, Shape{400});
auto broadcast_bias_i2h = std::make_shared<op::Broadcast>(bias_i2h, Shape{10, 400}, AxisSet{0});
auto add_1 = std::make_shared<op::Add>(dot_1, broadcast_bias_i2h);
auto hidden_ht = std::make_shared<pattern::op::Label>(element::f32, Shape{10, 50});
auto weights_h2h = std::make_shared<pattern::op::Label>(element::f32, Shape{400, 50});
auto param2_2_reshape =
std::make_shared<op::Reshape>(weights_h2h, AxisVector{1, 0}, Shape{50, 400});
auto dot_2 = std::make_shared<op::Dot>(hidden_ht, param2_2_reshape);
auto bias_h2h = std::make_shared<pattern::op::Label>(element::f32, Shape{400});
auto broadcast_bias_h2h = std::make_shared<op::Broadcast>(bias_h2h, Shape{10, 400}, AxisSet{0});
auto add_2 = std::make_shared<op::Add>(dot_2, broadcast_bias_h2h);
auto X = std::make_shared<op::Add>(add_2, add_1);
// construct forget gate
auto input_slice_0 = std::make_shared<op::Slice>(X, Coordinate{0, 0}, Coordinate{10, 100});
auto forget_gate = std::make_shared<op::Sigmoid>(input_slice_0);
//ct-1 -> cell state
auto ct_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{10, 100});
auto multiply_forget_gate_ct_1 = std::make_shared<op::Multiply>(forget_gate, ct_1);
// construct input gate
auto input_slice_1 = std::make_shared<op::Slice>(X, Coordinate{0, 100}, Coordinate{10, 200});
auto input_gate = std::make_shared<op::Sigmoid>(input_slice_1);
auto input_slice_2 = std::make_shared<op::Slice>(X, Coordinate{0, 200}, Coordinate{10, 300});
auto tanh_1 = std::make_shared<op::Tanh>(input_slice_2);
auto multiply_input_gate_tanh_1 = std::make_shared<op::Multiply>(input_gate, tanh_1);
auto add_ct_1_input_gate_tanh_1 =
std::make_shared<op::Add>(multiply_forget_gate_ct_1, multiply_input_gate_tanh_1);
auto ct_label = std::make_shared<pattern::op::Label>(
add_ct_1_input_gate_tanh_1, nullptr, NodeVector{add_ct_1_input_gate_tanh_1});
// construct output gate
auto input_slice_3 = std::make_shared<op::Slice>(X, Coordinate{0, 300}, Coordinate{10, 400});
auto output_gate = std::make_shared<op::Sigmoid>(input_slice_3);
auto tanh_2 = std::make_shared<op::Tanh>(ct_label);
auto ht = std::make_shared<op::Multiply>(output_gate, tanh_2);
auto ht_label = std::make_shared<pattern::op::Label>(ht, nullptr, NodeVector{ht});
// Define a callback that is called once the DFG matches the pattern
pattern::graph_rewrite_callback callback = [ct_label,
input_xt,
weights_i2h,
hidden_ht,
weights_h2h,
bias_i2h,
bias_h2h,
ct_1](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_lstm pattern against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
NGRAPH_DEBUG << "In Lstm fprop call back";
if (m.get_match_root()->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< " type is not float!";
return false;
}
auto input_xt_rank = input_xt->get_shape().size();
auto hidden_ht_rank = hidden_ht->get_shape().size();
auto weights_i2h_rank = weights_i2h->get_shape().size();
auto weights_h2h_rank = weights_h2h->get_shape().size();
if (input_xt_rank != 2 || hidden_ht_rank != 2 || weights_i2h_rank != 2 ||
weights_h2h_rank != 2)
{
return false;
}
RETURN_IF_FALSE(bias_i2h->get_shape().size() == 1 && bias_h2h->get_shape().size() == 1,
"Bias should have rank of 1 for Rnn op");
// Determine which label is ht_1 and which is xt; if both xt and ht_1 have the same shape, this
// needs to be captured reliably in the RNN fusion.
std::shared_ptr<op::gpu::Rnn> lstm = nullptr;
bool intermediate_lstm = false;
if (std::dynamic_pointer_cast<op::GetOutputElement>(pattern_map[ct_1]))
{
intermediate_lstm = true;
}
// if the matched LSTM is the first cell we need to check if symbol input_xt corresponds
// to the input data tensor, or the hidden (recurrent) data tensor
if (!intermediate_lstm &&
(std::dynamic_pointer_cast<op::Broadcast>(pattern_map[hidden_ht]) &&
std::dynamic_pointer_cast<op::Constant>(pattern_map[hidden_ht]->get_argument(0))))
// label input_xt is the input data to the first LSTM
{
auto params = compute_lstm_params(pattern_map[weights_i2h],
pattern_map[weights_h2h],
pattern_map[bias_i2h],
pattern_map[bias_h2h]);
lstm = std::make_shared<op::gpu::Rnn>(pattern_map[input_xt],
pattern_map[hidden_ht],
params,
pattern_map[ct_1],
1,
4,
1,
pattern_map[input_xt]->get_shape()[1],
pattern_map[hidden_ht]->get_shape()[1],
1,
1);
}
else if (!intermediate_lstm &&
(std::dynamic_pointer_cast<op::Broadcast>(pattern_map[input_xt]) &&
std::dynamic_pointer_cast<op::Constant>(pattern_map[input_xt]->get_argument(0))))
// label hidden_ht is the input data to the first LSTM
{
auto params = compute_lstm_params(pattern_map[weights_h2h],
pattern_map[weights_i2h],
pattern_map[bias_h2h],
pattern_map[bias_i2h]);
lstm = std::make_shared<op::gpu::Rnn>(pattern_map[hidden_ht],
pattern_map[input_xt],
params,
pattern_map[ct_1],
1,
4,
1,
pattern_map[hidden_ht]->get_shape()[1],
pattern_map[input_xt]->get_shape()[1],
1,
1);
}
else if (pattern_map[hidden_ht]->get_arguments().size() &&
pattern_map[ct_1]->get_arguments().at(0)->get_instance_id() ==
pattern_map[hidden_ht]->get_arguments().at(0)->get_instance_id())
// this still has a bug vector: if the hidden input ht is a non-broadcasted constant
// it will be misclassified as input data xt
{
// label input_xt is the output data from the previous LSTM cell
NGRAPH_DEBUG << "ct_shape : " << join(pattern_map[ct_1]->get_shape())
<< " hidden state shape: " << join(pattern_map[hidden_ht]->get_shape());
auto params = compute_lstm_params(pattern_map[weights_i2h],
pattern_map[weights_h2h],
pattern_map[bias_i2h],
pattern_map[bias_h2h]);
lstm = std::make_shared<op::gpu::Rnn>(pattern_map[input_xt],
pattern_map[hidden_ht],
params,
pattern_map[ct_1],
1,
4,
1,
pattern_map[input_xt]->get_shape()[1],
pattern_map[hidden_ht]->get_shape()[1],
1,
1);
}
else
{
// label hidden_ht is the output data from the previous LSTM cell
NGRAPH_DEBUG << "ct_shape: " << join(pattern_map[ct_1]->get_shape())
<< " hidden state shape: " << join(pattern_map[input_xt]->get_shape());
auto params = compute_lstm_params(pattern_map[weights_h2h],
pattern_map[weights_i2h],
pattern_map[bias_h2h],
pattern_map[bias_i2h]);
lstm = std::make_shared<op::gpu::Rnn>(pattern_map[hidden_ht],
pattern_map[input_xt],
params,
pattern_map[ct_1],
1,
4,
1,
pattern_map[hidden_ht]->get_shape()[1],
pattern_map[input_xt]->get_shape()[1],
1,
1);
}
auto ht_output = std::make_shared<op::GetOutputElement>(lstm, 0);
auto ct_output = std::make_shared<op::GetOutputElement>(lstm, 2);
NGRAPH_ASSERT(lstm->get_outputs().at(0).get_inputs().size() == 2)
<< "Lstm node doesnt have two outputs";
// Now identify the nodes which consume the outputs of the LSTM node
// and replace them accordingly.
// find the users of {ct} and replace them with the cell state output (lstm_goe_2)
for (auto node : pattern_map[ct_label]->get_users())
{
NGRAPH_DEBUG << "node_name: " << node->get_name();
for (size_t i = 0; i < node->get_input_size(); i++)
{
if (node->get_argument(i) == pattern_map[ct_label])
{
node->get_inputs().at(i).replace_output(ct_output->get_outputs().at(0));
}
}
}
// find the users of {ht} and replace them with lstm_goe_0
ngraph::replace_node(m.get_match_root(), ht_output);
return true;
};
auto m = std::make_shared<pattern::Matcher>(ht, callback);
this->add_matcher(m);
}
static std::shared_ptr<ngraph::Node>
compute_rnn_args(std::vector<std::shared_ptr<pattern::op::Label>>& rnn_labels,
pattern::RecurrentMatcher& m,
bool concat_all = false)
{
NGRAPH_DEBUG << "Inside compute arg " << rnn_labels.size();
// src_layer -> concatenate the input symbols from the different LSTM cells belonging to the same RNN layer,
// in time-slice order 0, 1, 2, ..., t
if (concat_all)
{
auto node_labels = m.get_bound_nodes_for_pattern(rnn_labels[0]);
if (node_labels.size() > 1)
{
std::reverse(node_labels.begin(), node_labels.end());
return std::make_shared<op::Concat>(node_labels, 0);
}
else
{
return node_labels[0];
}
}
// i2h or h2h weights shared between LSTM cells
else
{
// return the first instance of the weight matrix for this RNN layer
// weights and biases are reused for all cells in a layer.
auto node_labels = m.get_bound_nodes_for_pattern(rnn_labels[0]);
return node_labels[node_labels.size() - 1];
}
}
void ngraph::runtime::gpu::pass::RNNFusion::construct_rnn_lstm_fprop()
{
auto xt = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 100});
auto ht_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 100});
auto params_label = std::make_shared<pattern::op::Label>(
element::f32, Shape{400 * 100 + 400 * 100 + 400 + 400});
auto rpattern_ct_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 100});
auto lstm = std::make_shared<op::gpu::Rnn>(xt,
ht_1,
params_label,
rpattern_ct_1,
1,
4,
1,
xt->get_shape()[1],
ht_1->get_shape()[1],
1,
1);
auto goe = std::make_shared<op::GetOutputElement>(lstm, 0); // hidden output
auto lstm_node_label = std::make_shared<pattern::op::Label>(goe, nullptr, NodeVector{goe});
pattern::recurrent_graph_rewrite_callback callback = [lstm_node_label,
xt,
ht_1,
params_label,
rpattern_ct_1](
pattern::RecurrentMatcher& m) {
NGRAPH_DEBUG << " In RNN fusion callback";
auto ht_1_label = m.get_bound_nodes_for_pattern(ht_1);
auto params_bound = m.get_bound_nodes_for_pattern(params_label);
// determine the ht and xt
std::shared_ptr<ngraph::Node> src_layer = nullptr;
std::shared_ptr<ngraph::Node> src_iter = nullptr;
auto xt_node_array = m.get_bound_nodes_for_pattern(xt);
auto hidden_ht_array = m.get_bound_nodes_for_pattern(ht_1);
// since we don't have metadata to differentiate between xt and ht_1,
// we use the broadcasted constant initialization of the first LSTM cell
// in the RNN layer to identify ht_1
if (std::dynamic_pointer_cast<op::Broadcast>(xt_node_array[xt_node_array.size() - 1]) &&
std::dynamic_pointer_cast<op::Constant>(
xt_node_array[xt_node_array.size() - 1]->get_argument(0)))
// here xt is determined to be the hidden (recurrent) input data and so ht is the feedforward input
{
// concatenate the sequence inputs for a given layer
std::vector<std::shared_ptr<pattern::op::Label>> src_layer_labels{ht_1};
src_layer = compute_rnn_args(src_layer_labels, m, true);
// concatenate the hidden (recurrent) input with the cell
std::vector<std::shared_ptr<pattern::op::Label>> src_iter_labels{xt};
src_iter = compute_rnn_args(src_iter_labels, m);
}
else if (std::dynamic_pointer_cast<op::Broadcast>(
hidden_ht_array[hidden_ht_array.size() - 1]) &&
std::dynamic_pointer_cast<op::Constant>(
hidden_ht_array[hidden_ht_array.size() - 1]->get_argument(0)))
// here ht is determined to be the hidden (recurrent) input data and so xt is the feedforward input
{
std::vector<std::shared_ptr<pattern::op::Label>> src_layer_labels{xt};
src_layer = compute_rnn_args(src_layer_labels, m, true);
std::vector<std::shared_ptr<pattern::op::Label>> src_iter_labels{ht_1};
src_iter = compute_rnn_args(src_iter_labels, m);
}
else
{
// don't fuse if the pattern matcher didn't discover all the cells belonging to the RNN layer.
// we don't want to throw an assertion if the pattern matcher cannot discover all
// nodes belonging to the RNN; instead we return and the LSTM cells are computed individually
return false;
}
std::vector<std::shared_ptr<pattern::op::Label>> params_labels{params_label};
auto params = compute_rnn_args(params_labels, m);
std::vector<std::shared_ptr<pattern::op::Label>> state_iter_labels{rpattern_ct_1};
auto state_iter = compute_rnn_args(state_iter_labels, m);
auto num_of_lstm_matched = m.get_number_of_recurrent_matches();
if (num_of_lstm_matched <= 1)
{
return false;
}
size_t num_gates_in_lstm = 4;
size_t batch_size = src_layer->get_shape()[0] / num_of_lstm_matched;
size_t sequence_len = num_of_lstm_matched;
size_t src_layer_feature_size = src_layer->get_shape()[1];
size_t feature_size = ht_1_label[0]->get_shape()[1];
// number of states for LSTM is 2
size_t direction = 1;
size_t num_fused_rnn_layers = 1;
NGRAPH_DEBUG << "src_layer: " << join(src_layer->get_shape());
NGRAPH_DEBUG << "src_iter: " << join(src_iter->get_shape());
NGRAPH_DEBUG << "src_seq_len: " << sequence_len;
NGRAPH_DEBUG << "batch_size: " << batch_size;
NGRAPH_DEBUG << "feature_size: " << feature_size;
RETURN_IF_FALSE(src_layer->get_arguments().size() == sequence_len ||
std::dynamic_pointer_cast<op::Parameter>(src_layer),
"number of lstm inputs captured in the RNN fusion is not equal to "
"src_sequence_length");
RETURN_IF_FALSE(!std::dynamic_pointer_cast<op::Parameter>(src_layer) || sequence_len == 1,
"number of lstm inputs captured in the RNN fusion is not equal to "
"src_sequence_length");
auto src_layer_rank = src_layer->get_shape().size();
auto src_iter_rank = src_iter->get_shape().size();
RETURN_IF_FALSE(src_layer_rank == 2 && src_iter_rank == 2,
"Pattern matcher error src_layer, src_iter, have rank 2 for RNN op");
RETURN_IF_FALSE(src_layer->get_element_type() == element::f32 &&
src_iter->get_element_type() == element::f32,
"input tensor type and input recurrent state tensor type for RNN op should "
"be float32");
auto rnn = std::make_shared<op::gpu::Rnn>(src_layer,
src_iter,
params,
state_iter,
num_of_lstm_matched,
num_gates_in_lstm,
sequence_len,
src_layer_feature_size,
feature_size,
direction,
num_fused_rnn_layers);
std::vector<std::shared_ptr<op::Slice>> ht_slice_per_timestep(num_of_lstm_matched, nullptr);
auto rnn_ht_out = std::make_shared<op::GetOutputElement>(rnn, 0);
auto layer_rnn_ht = std::make_shared<op::GetOutputElement>(rnn, 1);
auto layer_rnn_ct = std::make_shared<op::GetOutputElement>(rnn, 2);
//slice the rnn ht's
size_t start_index = 0;
size_t end_index = batch_size;
// capture the slices in reverse order so they correspond to the lstm_goes order captured by the pattern matcher
for (size_t i = 0; i < num_of_lstm_matched; i++)
{
ht_slice_per_timestep[i] = (std::make_shared<op::Slice>(
rnn_ht_out, Coordinate{start_index, 0}, Coordinate{end_index, feature_size}));
start_index += batch_size;
end_index += batch_size;
}
std::reverse(ht_slice_per_timestep.begin(), ht_slice_per_timestep.end());
NGRAPH_DEBUG << "rnn_time_slice: " << ht_slice_per_timestep.size();
// find the lstm's nodes captured in PM
auto lstm_goes = m.get_bound_nodes_for_pattern(lstm_node_label);
std::vector<std::shared_ptr<ngraph::Node>> lstm_nodes;
// we need to collect the LSTMs from the GOE's in order to determine
// the individual time-slice output ht. lstm_goes holds the GOEs in decreasing
// order of the time slices
for (size_t i = 0; i < lstm_goes.size(); i++)
{
// lstm's will be the input to GOE's
lstm_nodes.push_back(lstm_goes[i]->get_arguments()[0]);
}
RETURN_IF_FALSE(sequence_len == lstm_nodes.size(),
" Number of lstm nodes in RNN layer is not equal to time slices");
RETURN_IF_FALSE(
lstm_nodes.size() == lstm_goes.size() ||
lstm_goes.size() == ht_slice_per_timestep.size(),
"Number of slices of rnn output ht is not equal to the time slices in RNN layer");
// collect all the consumers of LSTM goe's (ht)
std::set<std::shared_ptr<ngraph::Node>> lstm_goe0_user;
std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<Node>> map_goe_to_lstm_slices;
std::shared_ptr<Node> goe_0;
for (size_t index = 0; index < lstm_nodes.size(); index++)
{
// now get the GOE0 which is the first output of lstm (ht)
for (auto& goes : lstm_nodes[index]->get_outputs().at(0).get_inputs())
{
auto goe_node = std::dynamic_pointer_cast<op::GetOutputElement>(goes->get_node());
// first output node of lstm
if (goe_node->get_n() == 0)
{
goe_0 = goes->get_node();
for (auto goe0_user : goe_0->get_users())
{
if (std::find(lstm_nodes.begin(), lstm_nodes.end(), goe0_user) ==
lstm_nodes.end() &&
ngraph::is_used(goe0_user.get()))
{
lstm_goe0_user.insert(goe0_user);
map_goe_to_lstm_slices[goe_0] = ht_slice_per_timestep[index];
NGRAPH_DEBUG << "ht_slice: " << ht_slice_per_timestep[index]->get_name()
<< " goe0_user " << goe0_user->get_name() << " ";
}
}
}
// we need to only check the last LSTM cell Ct user and replace if needed.
if ((index == 0) && (goe_node->get_n() == 1))
{
// check if the last LSTM cell has any consumers
auto n_time_step_lstm_ct_goe = goes->get_node();
ngraph::replace_node(n_time_step_lstm_ct_goe, layer_rnn_ct);
}
}
}
//now go through the lstm goe_0 consumers and replace them with the slice
for (auto& node : lstm_goe0_user)
{
for (size_t i = 0; i < node->get_input_size(); i++)
{
if (map_goe_to_lstm_slices.find(node->get_argument(i)) !=
map_goe_to_lstm_slices.end())
{
node->get_inputs().at(i).replace_output(
map_goe_to_lstm_slices[node->get_argument(i)]->get_outputs().at(0));
}
}
}
NGRAPH_DEBUG << "End of recurrent fusion call back "
<< "matched_node: " << m.get_match_root()->get_name();
return true;
};
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
auto m = std::make_shared<pattern::RecurrentMatcher>(
lstm_node_label, rpattern_ct_1, empty_correlated_matches, callback);
this->add_matcher(m);
}
static std::shared_ptr<Node>
compute_multi_layer_rnn_inputs(const std::shared_ptr<pattern::op::Label>& rnn_label,
pattern::RecurrentMatcher& m)
{
auto node_labels = m.get_bound_nodes_for_pattern(rnn_label);
std::reverse(node_labels.begin(), node_labels.end());
return std::make_shared<op::Concat>(node_labels, 0);
}
static std::shared_ptr<Node>
compute_multi_layer_rnn_params(const std::shared_ptr<pattern::op::Label>& param_label,
pattern::RecurrentMatcher& m)
{
auto param_nodes = m.get_bound_nodes_for_pattern(param_label);
std::reverse(param_nodes.begin(), param_nodes.end());
// iterate over params for each layer in order [layer 0, ... layer n]
NodeVector biases;
NodeVector layer_params;
for (auto& param : param_nodes)
{
// split and group layer weights and layer biases
auto const& args = param->get_arguments();
for (size_t i = 0; i < args.size(); i++)
{
// first half set of params are weights, second half are biases
if (i < (args.size() / 2))
{
layer_params.push_back(args[i]);
}
else
{
biases.push_back(args[i]);
}
}
}
layer_params.insert(layer_params.end(), biases.begin(), biases.end());
return std::make_shared<op::Concat>(layer_params, 0);
}
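compute_multi_layer_rnn_params above regroups the per-layer {W_x, W_h, B_x, B_h} concats so that all weights come first and all biases follow, which is the layout documented on the Rnn op. A small illustrative sketch of the resulting order (the function and names are placeholders, not real graph nodes):

#include <cstddef>
#include <string>
#include <vector>

// For each fused layer the first half of the per-layer params (weights) is kept in
// order, and the biases from every layer are appended at the end:
// {W_x^0, W_h^0, W_x^1, W_h^1, ..., B_x^0, B_h^0, B_x^1, B_h^1}
std::vector<std::string> multi_layer_param_order(size_t num_layers)
{
    std::vector<std::string> weights;
    std::vector<std::string> biases;
    for (size_t layer = 0; layer < num_layers; layer++)
    {
        weights.push_back("W_x^" + std::to_string(layer));
        weights.push_back("W_h^" + std::to_string(layer));
        biases.push_back("B_x^" + std::to_string(layer));
        biases.push_back("B_h^" + std::to_string(layer));
    }
    weights.insert(weights.end(), biases.begin(), biases.end());
    return weights;
}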
void ngraph::runtime::gpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_fusion_fprop()
{
auto src_layer_label = std::make_shared<pattern::op::Label>(element::f32, Shape{30, 100});
auto src_slice =
std::make_shared<pattern::op::Skip>(src_layer_label, pattern::has_class<op::Slice>());
auto src_iter_label = std::make_shared<pattern::op::Label>(element::f32, Shape{20, 100});
auto params_label = std::make_shared<pattern::op::Label>(
element::f32, Shape{400 * 100 + 400 * 100 + 400 + 400});
auto state_iter_label = std::make_shared<pattern::op::Label>(element::f32, Shape{20, 100});
size_t ref_number_of_timesteps = 3;
size_t ref_number_of_gates_per_cell = 4;
size_t ref_src_seq_length = 3;
size_t ref_src_layer_feature_size = 100;
size_t ref_feature_size = 100;
size_t ref_rnn_direction = 1;
size_t ref_num_of_rnn_fused_layer = 1;
auto ref_rnn_node = std::make_shared<op::gpu::Rnn>(src_slice,
src_iter_label,
params_label,
state_iter_label,
ref_number_of_timesteps,
ref_number_of_gates_per_cell,
ref_src_seq_length,
ref_src_layer_feature_size,
ref_feature_size,
ref_rnn_direction,
ref_num_of_rnn_fused_layer);
auto rnn_ht_out = std::make_shared<op::GetOutputElement>(ref_rnn_node, 0);
auto rnn_ht_label =
std::make_shared<pattern::op::Label>(rnn_ht_out, nullptr, NodeVector{rnn_ht_out});
pattern::recurrent_graph_rewrite_callback callback =
[src_layer_label, src_iter_label, params_label, state_iter_label, rnn_ht_label](
pattern::RecurrentMatcher& m) {
if (m.get_number_of_recurrent_matches() <= 1)
{
return false;
}
auto src_nodes = m.get_bound_nodes_for_pattern(src_layer_label);
auto rnn_ht_out_nodes = m.get_bound_nodes_for_pattern(rnn_ht_label);
auto number_of_rnn_cell_matched = m.get_number_of_recurrent_matches();
NGRAPH_DEBUG << "In Recurrent multi layer RNN fusion callback ";
NGRAPH_DEBUG << "Number of RNN's Matched: " << number_of_rnn_cell_matched;
NGRAPH_DEBUG << "matched_root: " << m.get_match_root()->get_name();
NGRAPH_DEBUG << "src_layer_node: " << src_nodes[0]->get_name();
// we can fuse across different RNN layers only if SLC == DLC
for (size_t i = 0; i < number_of_rnn_cell_matched; i++)
{
if (src_nodes[i]->get_shape()[1] != rnn_ht_out_nodes[i]->get_shape()[1])
{
NGRAPH_DEBUG << "Not fusing since the feature sizes for xt and ht_1 dont match";
return false;
}
}
// we just need to capture the input symbols {x0 | x1.....| xt} of the first lstm layer
// the intermediate inputs for the next layer will be computed by the kernel
auto src_layer_nodes = m.get_bound_nodes_for_pattern(src_layer_label);
auto src_layer = src_layer_nodes[src_layer_nodes.size() - 1];
auto src_iter = compute_multi_layer_rnn_inputs(src_iter_label, m);
auto state_iter = compute_multi_layer_rnn_inputs(state_iter_label, m);
auto params = compute_multi_layer_rnn_params(params_label, m);
// collect list of rnn ops (layers)
std::vector<std::shared_ptr<op::gpu::Rnn>> rnn_nodes;
for (auto& rnn_goe_input : m.get_bound_nodes_for_pattern(rnn_ht_label))
{
auto rnn_op =
std::dynamic_pointer_cast<op::gpu::Rnn>(rnn_goe_input->get_arguments()[0]);
if (rnn_op)
{
rnn_nodes.push_back(rnn_op);
}
else
{
throw ngraph_error("Input for RNN output GetOuputElement Op should be RNN");
}
}
size_t num_time_steps = rnn_nodes[0]->get_num_timesteps();
size_t num_gates_in_lstm = rnn_nodes[0]->get_gates_per_cell();
size_t batch_size = rnn_nodes[0]->get_batch_size();
size_t sequence_len = rnn_nodes[0]->get_src_sequence_length();
size_t src_layer_feature_size = rnn_nodes[0]->get_src_layer_feature_size();
size_t feature_size = rnn_nodes[0]->get_src_iter_feature_size();
size_t rnn_direction = rnn_nodes[0]->get_direction();
size_t num_fused_rnn_layers = m.get_number_of_recurrent_matches();
NGRAPH_DEBUG << "src_layer: " << join(src_layer->get_shape());
NGRAPH_DEBUG << "src_iter: " << join(src_iter->get_shape());
NGRAPH_DEBUG << "state_iter: " << join(state_iter->get_shape());
NGRAPH_DEBUG << "params size {wx|wh|bx|bh}: " << shape_size(params->get_shape());
NGRAPH_DEBUG << "src_seq_len: " << sequence_len;
NGRAPH_DEBUG << "batch_size: " << batch_size;
NGRAPH_DEBUG << "feature_size: " << feature_size;
if (auto src_rnn = std::dynamic_pointer_cast<op::gpu::Rnn>(src_layer))
{
RETURN_IF_FALSE(
src_rnn->get_num_timesteps() == num_time_steps,
"input symbols for the layer fused RNN op, should be captured only for the "
"first layer");
}
RETURN_IF_FALSE(
!std::dynamic_pointer_cast<op::Parameter>(src_layer) ||
rnn_nodes[0]->get_num_timesteps() == 1,
"input symbols for the layer fused RNN op, should be captured only for the first "
"layer");
RETURN_IF_FALSE(
(src_iter->get_arguments().size()) == num_fused_rnn_layers,
"number of hidden states for RNN op in the layer fusion is not equal to num of "
"fused_rnn_layers");
RETURN_IF_FALSE(
(state_iter->get_arguments().size()) == num_fused_rnn_layers,
"number of cell states for RNN op in the layer fusion is not equal to num of "
"fused_rnn_layers");
RETURN_IF_FALSE(
(params->get_arguments().size()) == num_fused_rnn_layers * 4,
"RNN param tensor does not consist of normal and recurrent weight and bias tensor "
"for each layer");
auto rnn = std::make_shared<op::gpu::Rnn>(src_layer,
src_iter,
params,
state_iter,
num_time_steps,
num_gates_in_lstm,
sequence_len,
src_layer_feature_size,
feature_size,
rnn_direction,
num_fused_rnn_layers);
auto output_layer_rnn_ht = std::make_shared<op::GetOutputElement>(rnn, 0);
auto layer_rnn_ht = std::make_shared<op::GetOutputElement>(rnn, 1);
auto layer_rnn_ct = std::make_shared<op::GetOutputElement>(rnn, 2);
// Replace the users of the RNN cell state {ct} of each fused layer with a slice of the fused output.
auto replace_rnn_output_cellstate = [&](std::shared_ptr<Node>& rnn_ct, size_t layer) {
std::shared_ptr<Node> node_to_replace = rnn_ct;
auto ct_slice = std::make_shared<op::Slice>(
layer_rnn_ct,
Coordinate{static_cast<unsigned long>(batch_size * (layer - 1)), 0},
Coordinate{static_cast<unsigned long>(batch_size * rnn_direction * layer),
static_cast<unsigned long>(feature_size)});
if (rnn_ct->get_users().size() == 1)
{
if (std::dynamic_pointer_cast<op::Slice>(rnn_ct->get_users()[0]))
{
node_to_replace = rnn_ct->get_users()[0];
}
}
if (ngraph::is_used(node_to_replace.get()))
{
ngraph::replace_node(node_to_replace, ct_slice);
}
};
for (size_t index = 0; index < rnn_nodes.size(); index++)
{
for (auto& rnn_goes : rnn_nodes[index]->get_users())
{
NGRAPH_DEBUG << "rnn_goes: " << rnn_goes->get_name();
if (rnn_goes->get_users().empty())
{
continue;
}
if (auto rnn_goe_node =
std::dynamic_pointer_cast<op::GetOutputElement>(rnn_goes))
{
// we need to only replace the {ht} consumers of the last RNN layer,
// since for other layers the intermediate outputs {ht} will be computed
// within the kernel
if (index == 0)
{
if (rnn_goe_node->get_n() == 0)
{
ngraph::replace_node(rnn_goes, output_layer_rnn_ht);
}
}
if (rnn_goe_node->get_n() == 2)
{
replace_rnn_output_cellstate(rnn_goes, num_fused_rnn_layers - index);
}
}
}
}
return true;
};
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
auto m = std::make_shared<pattern::RecurrentMatcher>(
rnn_ht_label, src_layer_label, empty_correlated_matches, callback);
this->add_matcher(m);
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
namespace pass
{
class LSTMFusion;
class RNNFusion;
class MultiLayerRNNFusion;
}
}
}
}
class ngraph::runtime::gpu::pass::LSTMFusion : public ngraph::pass::GraphRewrite
{
public:
LSTMFusion()
: GraphRewrite()
{
construct_sigmoid();
construct_lstm_fprop();
}
private:
void construct_sigmoid();
void construct_lstm_fprop();
};
class ngraph::runtime::gpu::pass::RNNFusion : public ngraph::pass::RecurrentGraphRewrite
{
public:
RNNFusion()
: RecurrentGraphRewrite()
{
construct_rnn_lstm_fprop();
}
private:
void construct_rnn_lstm_fprop();
};
class ngraph::runtime::gpu::pass::MultiLayerRNNFusion : public ngraph::pass::RecurrentGraphRewrite
{
public:
MultiLayerRNNFusion()
: RecurrentGraphRewrite()
{
construct_multi_layer_rnn_fusion_fprop();
}
private:
void construct_multi_layer_rnn_fusion_fprop();
};
......@@ -73,7 +73,7 @@ if(NGRAPH_CPU_ENABLE)
endif()
if(NGRAPH_GPU_ENABLE)
set(SRC ${SRC} cudnn.cpp gpu_test.cpp)
set(SRC ${SRC} cudnn.cpp gpu_test.cpp gpu_fusion.cpp)
endif()
foreach(TEST_CONFIG ${UNIT_TEST_CONFIG_LIST})
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <algorithm>
#include <cstdio>
#include <cudnn.h>
#include <iostream>
#include <list>
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/core_fusion.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/reshape_elimination.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/runtime/gpu/op/rnn.hpp"
#include "ngraph/runtime/gpu/pass/gpu_rnn_fusion.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
#include "util/all_close.hpp"
#include "util/autodiff/backprop_function.hpp"
#include "util/autodiff/numeric_compare.hpp"
#include "util/matcher.hpp"
#include "util/random.hpp"
#include "util/random.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
#if CUDNN_VERSION >= 7200
TEST(gpu_fusion, rnn_fprop_1_lstm_cell)
{
auto src_layer = make_shared<op::Parameter>(element::f32, Shape{10, 100});
auto src_iter = make_shared<op::Parameter>(element::f32, Shape{10, 100});
auto params =
make_shared<op::Parameter>(element::f32, Shape{400 * 100 + 400 * 100 + 400 + 400});
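// Packed parameter tensor: input and recurrent weights (4 gates x 100 features x 100 inputs
// each), followed by the two 400-element bias vectors.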
auto state_iter = make_shared<op::Parameter>(element::f32, Shape{10, 100});
const int number_of_timesteps = 1;
const int number_of_gates_per_cell = 4;
const int src_seq_length = 1;
const int src_layer_feature_size = 100;
const int feature_size = 100;
const int rnn_direction = 1;
const int num_of_rnn_fused_layer = 1;
auto rnn_node = make_shared<op::gpu::Rnn>(src_layer,
src_iter,
params,
state_iter,
number_of_timesteps,
number_of_gates_per_cell,
src_seq_length,
src_layer_feature_size,
feature_size,
rnn_direction,
num_of_rnn_fused_layer);
auto rnn_ht_output = make_shared<op::GetOutputElement>(rnn_node, 0);
auto rnn_ct_output = make_shared<op::GetOutputElement>(rnn_node, 1);
auto func = make_shared<Function>(NodeVector{rnn_ht_output, rnn_ct_output},
op::ParameterVector{src_layer, src_iter, params, state_iter});
auto backend = runtime::Backend::create("GPU");
shared_ptr<runtime::Tensor> src_layer_t =
backend->create_tensor(element::f32, src_layer->get_shape());
shared_ptr<runtime::Tensor> src_iter_t =
backend->create_tensor(element::f32, src_iter->get_shape());
shared_ptr<runtime::Tensor> state_iter_t =
backend->create_tensor(element::f32, state_iter->get_shape());
shared_ptr<runtime::Tensor> params_t =
backend->create_tensor(element::f32, params->get_shape());
shared_ptr<runtime::Tensor> result_ht = backend->create_tensor(element::f32, Shape{10, 100});
shared_ptr<runtime::Tensor> result_ct = backend->create_tensor(element::f32, Shape{10, 100});
copy_data(src_layer_t, vector<float>(1000, 1));
copy_data(src_iter_t, vector<float>(1000, 1));
copy_data(state_iter_t, vector<float>(1000, 1));
copy_data(params_t, vector<float>(shape_size(params->get_shape()), 1));
backend->call_with_validate(
func, {result_ht, result_ct}, {src_layer_t, src_iter_t, params_t, state_iter_t});
vector<float> expected_ht(10 * 100, 0.964028f);
vector<float> expected_ct(10 * 100, 0.964028f);
EXPECT_TRUE(test::all_close(expected_ht, read_vector<float>(result_ht)));
EXPECT_TRUE(test::all_close(expected_ct, read_vector<float>(result_ct)));
}
#endif
TEST(gpu_fusion, fuse_lstm_cells)
{
pass::Manager pass_manager;
pass_manager.register_pass<runtime::gpu::pass::LSTMFusion>();
const string json_path =
file_util::path_join(SERIALIZED_ZOO, "mxnet/2rnn_layer_3lstm_cell.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass_manager.run_passes(func);
auto lstm_ops = get_ops_of_type<op::gpu::Rnn>(func);
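// 2 layers x 3 LSTM cells: LSTMFusion alone rewrites each cell as a single-timestep Rnn op.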
EXPECT_EQ(lstm_ops.size(), 6);
}
TEST(gpu_fusion, fuse_2_layer_rnn)
{
pass::Manager pass_manager;
pass_manager.register_pass<runtime::gpu::pass::LSTMFusion>();
pass_manager.register_pass<runtime::gpu::pass::RNNFusion>();
const string json_path =
file_util::path_join(SERIALIZED_ZOO, "mxnet/2rnn_layer_3lstm_cell.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass_manager.run_passes(func);
size_t count = count_ops_of_type<op::gpu::Rnn>(func);
auto rnn_ops = get_ops_of_type<op::gpu::Rnn>(func);
EXPECT_EQ(rnn_ops.size(), count);
for (auto& node : rnn_ops)
{
EXPECT_EQ(node->get_num_timesteps(), node->get_src_sequence_length());
}
}
TEST(gpu_fusion, fuse_1_layer_rnn)
{
pass::Manager pass_manager;
pass_manager.register_pass<runtime::gpu::pass::LSTMFusion>();
pass_manager.register_pass<runtime::gpu::pass::RNNFusion>();
const string json_path =
file_util::path_join(SERIALIZED_ZOO, "mxnet/1rnn_layer_3lstm_cell.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass_manager.run_passes(func);
size_t count = count_ops_of_type<op::gpu::Rnn>(func);
auto rnn_ops = get_ops_of_type<op::gpu::Rnn>(func);
EXPECT_EQ(rnn_ops.size(), 1);
EXPECT_EQ(rnn_ops.size(), count);
for (auto& node : rnn_ops)
{
EXPECT_EQ(node->get_num_timesteps(), node->get_src_sequence_length());
}
}
static std::shared_ptr<Function> make_function(const std::string& file_name)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, file_name);
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
return func;
}
TEST(gpu_fusion, lstm_analytic)
{
auto input_xt = std::make_shared<op::Parameter>(element::f32, Shape{1, 1});
auto weights_i2h = std::make_shared<op::Parameter>(element::f32, Shape{4, 1});
auto weights_i2h_reshape =
std::make_shared<op::Reshape>(weights_i2h, AxisVector{1, 0}, Shape{1, 4});
auto dot_1 = std::make_shared<op::Dot>(input_xt, weights_i2h_reshape);
auto bias_i2h = std::make_shared<op::Parameter>(element::f32, Shape{4});
auto broadcast_bias_i2h = std::make_shared<op::Broadcast>(bias_i2h, Shape{1, 4}, AxisSet{0});
auto add_1 = std::make_shared<op::Add>(dot_1, broadcast_bias_i2h);
auto h_const = op::Constant::create(element::f32, Shape{}, {1.0});
auto hidden_ht = std::make_shared<op::Broadcast>(h_const, Shape{1, 1}, AxisSet{0, 1});
auto weights_h2h = std::make_shared<op::Parameter>(element::f32, Shape{4, 1});
auto param2_2_reshape =
std::make_shared<op::Reshape>(weights_h2h, AxisVector{1, 0}, Shape{1, 4});
auto dot_2 = std::make_shared<op::Dot>(hidden_ht, param2_2_reshape);
auto bias_h2h = std::make_shared<op::Parameter>(element::f32, Shape{4});
auto broadcast_bias_h2h = std::make_shared<op::Broadcast>(bias_h2h, Shape{1, 4}, AxisSet{0});
auto add_2 = std::make_shared<op::Add>(dot_2, broadcast_bias_h2h);
auto X = std::make_shared<op::Add>(add_2, add_1);
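// X holds the four gate pre-activations; the slices below are interpreted as the
// forget, input, candidate (tanh), and output gates.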
// construct forget gate
auto input_slice_0 = std::make_shared<op::Slice>(X, Coordinate{0, 0}, Coordinate{1, 1});
auto forget_gate = std::make_shared<op::Sigmoid>(input_slice_0);
//ct-1 -> cell state
auto c_const = op::Constant::create(element::f32, Shape{}, {-1.0});
auto ct_1 = std::make_shared<op::Broadcast>(c_const, Shape{1, 1}, AxisSet{0, 1});
//auto ct_1 = std::make_shared<op::>(element::f32, Shape{10, 100});
auto multiply_forget_gate_ct_1 = std::make_shared<op::Multiply>(forget_gate, ct_1);
// construct input gate
auto input_slice_1 = std::make_shared<op::Slice>(X, Coordinate{0, 1}, Coordinate{1, 2});
auto input_gate = std::make_shared<op::Sigmoid>(input_slice_1);
auto input_slice_2 = std::make_shared<op::Slice>(X, Coordinate{0, 2}, Coordinate{1, 3});
auto tanh_1 = std::make_shared<op::Tanh>(input_slice_2);
auto multiply_input_gate_tanh_1 = std::make_shared<op::Multiply>(input_gate, tanh_1);
auto ct = std::make_shared<op::Add>(multiply_forget_gate_ct_1, multiply_input_gate_tanh_1);
// construct output gate
auto input_slice_3 = std::make_shared<op::Slice>(X, Coordinate{0, 3}, Coordinate{1, 4});
auto output_gate = std::make_shared<op::Sigmoid>(input_slice_3);
auto tanh_2 = std::make_shared<op::Tanh>(ct);
auto ht = std::make_shared<op::Multiply>(output_gate, tanh_2);
auto f = make_shared<Function>(
NodeVector{ht, ct},
op::ParameterVector{input_xt, weights_i2h, weights_h2h, bias_i2h, bias_h2h});
auto backend = runtime::Backend::create("GPU");
std::shared_ptr<runtime::Tensor> input_xt_t =
backend->create_tensor(element::f32, input_xt->get_shape());
copy_data(input_xt_t, std::vector<float>{1.0});
std::shared_ptr<runtime::Tensor> weights_i2h_t =
backend->create_tensor(element::f32, weights_i2h->get_shape());
copy_data(weights_i2h_t, std::vector<float>{-1.0, -1.0, -1.0, -1.0});
std::shared_ptr<runtime::Tensor> weights_h2h_t =
backend->create_tensor(element::f32, weights_h2h->get_shape());
copy_data(weights_h2h_t, std::vector<float>{-1.0, -1.0, -1.0, -1.0});
std::shared_ptr<runtime::Tensor> bias_i2h_t =
backend->create_tensor(element::f32, bias_i2h->get_shape());
copy_data(bias_i2h_t, std::vector<float>{-1.0, -1.0, -1.0, -1.0});
std::shared_ptr<runtime::Tensor> bias_h2h_t =
backend->create_tensor(element::f32, bias_h2h->get_shape());
copy_data(bias_h2h_t, std::vector<float>{-1.0, -1.0, -1.0, -1.0});
std::shared_ptr<runtime::Tensor> result_ht =
backend->create_tensor(element::f32, ht->get_shape());
std::shared_ptr<runtime::Tensor> result_ct =
backend->create_tensor(element::f32, ct->get_shape());
backend->call_with_validate(f,
{result_ht, result_ct},
{input_xt_t, weights_i2h_t, weights_h2h_t, bias_i2h_t, bias_h2h_t});
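// Hand-computed reference: with xt = ht-1 = 1 and all weights/biases = -1, each gate
// pre-activation is -4 and ct-1 = -1, so ct = -sigmoid(-4) + sigmoid(-4)*tanh(-4)
// and ht = sigmoid(-4)*tanh(ct).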
auto sig = [](float x) { return 1.0f / (1.0f + std::exp(-x)); };
float ct_val = -sig(-4.0f) + sig(-4.0f) * std::tanh(-4.0f);
float ht_val = sig(-4.0f) * std::tanh(ct_val);
EXPECT_TRUE(test::all_close(std::vector<float>{ht_val}, read_vector<float>(result_ht)));
EXPECT_TRUE(test::all_close(std::vector<float>{ct_val}, read_vector<float>(result_ct)));
}
TEST(gpu_fusion, fuse_2_layer_rnn_1lstm_analytic)
{
auto input_xt = std::make_shared<op::Parameter>(element::f32, Shape{1, 1});
auto weights_i2h = std::make_shared<op::Parameter>(element::f32, Shape{4, 1});
auto weights_i2h_reshape =
std::make_shared<op::Reshape>(weights_i2h, AxisVector{1, 0}, Shape{1, 4});
auto dot_1 = std::make_shared<op::Dot>(input_xt, weights_i2h_reshape);
auto bias_i2h = std::make_shared<op::Parameter>(element::f32, Shape{4});
auto broadcast_bias_i2h = std::make_shared<op::Broadcast>(bias_i2h, Shape{1, 4}, AxisSet{0});
auto add_1 = std::make_shared<op::Add>(dot_1, broadcast_bias_i2h);
auto h_const = op::Constant::create(element::f32, Shape{}, {1.0});
auto hidden_ht = std::make_shared<op::Broadcast>(h_const, Shape{1, 1}, AxisSet{0, 1});
auto weights_h2h = std::make_shared<op::Parameter>(element::f32, Shape{4, 1});
auto param2_2_reshape =
std::make_shared<op::Reshape>(weights_h2h, AxisVector{1, 0}, Shape{1, 4});
auto dot_2 = std::make_shared<op::Dot>(hidden_ht, param2_2_reshape);
auto bias_h2h = std::make_shared<op::Parameter>(element::f32, Shape{4});
auto broadcast_bias_h2h = std::make_shared<op::Broadcast>(bias_h2h, Shape{1, 4}, AxisSet{0});
auto add_2 = std::make_shared<op::Add>(dot_2, broadcast_bias_h2h);
auto X = std::make_shared<op::Add>(add_2, add_1);
// construct forget gate
auto input_slice_0 = std::make_shared<op::Slice>(X, Coordinate{0, 0}, Coordinate{1, 1});
auto forget_gate = std::make_shared<op::Sigmoid>(input_slice_0);
//ct-1 -> cell state
auto c_const = op::Constant::create(element::f32, Shape{}, {1.0});
auto ct_1 = std::make_shared<op::Broadcast>(c_const, Shape{1, 1}, AxisSet{0, 1});
//auto ct_1 = std::make_shared<op::>(element::f32, Shape{10, 100});
auto multiply_forget_gate_ct_1 = std::make_shared<op::Multiply>(forget_gate, ct_1);
// construct input gate
auto input_slice_1 = std::make_shared<op::Slice>(X, Coordinate{0, 1}, Coordinate{1, 2});
auto input_gate = std::make_shared<op::Sigmoid>(input_slice_1);
auto input_slice_2 = std::make_shared<op::Slice>(X, Coordinate{0, 2}, Coordinate{1, 3});
auto tanh_1 = std::make_shared<op::Tanh>(input_slice_2);
auto multiply_input_gate_tanh_1 = std::make_shared<op::Multiply>(input_gate, tanh_1);
auto ct = std::make_shared<op::Add>(multiply_forget_gate_ct_1, multiply_input_gate_tanh_1);
// construct output gate
auto input_slice_3 = std::make_shared<op::Slice>(X, Coordinate{0, 3}, Coordinate{1, 4});
auto output_gate = std::make_shared<op::Sigmoid>(input_slice_3);
auto tanh_2 = std::make_shared<op::Tanh>(ct);
auto ht = std::make_shared<op::Multiply>(output_gate, tanh_2);
// next lstm layer
auto weights_i2h_0 = std::make_shared<op::Parameter>(element::f32, Shape{4, 1});
auto weights_i2h_0_reshape_0 =
std::make_shared<op::Reshape>(weights_i2h_0, AxisVector{1, 0}, Shape{1, 4});
auto dot_1_0 = std::make_shared<op::Dot>(ht, weights_i2h_0_reshape_0);
auto bias_i2h_0 = std::make_shared<op::Parameter>(element::f32, Shape{4});
auto broadcast_bias_i2h_0_0 =
std::make_shared<op::Broadcast>(bias_i2h_0, Shape{1, 4}, AxisSet{0});
auto add_1_0 = std::make_shared<op::Add>(dot_1_0, broadcast_bias_i2h_0_0);
auto h_const_0 = op::Constant::create(element::f32, Shape{}, {1.0});
auto hidden_ht_0 = std::make_shared<op::Broadcast>(h_const_0, Shape{1, 1}, AxisSet{0, 1});
auto weights_h2h_0 = std::make_shared<op::Parameter>(element::f32, Shape{4, 1});
auto param2_2_reshape_0 =
std::make_shared<op::Reshape>(weights_h2h_0, AxisVector{1, 0}, Shape{1, 4});
auto dot_2_0 = std::make_shared<op::Dot>(hidden_ht_0, param2_2_reshape_0);
auto bias_h2h_0 = std::make_shared<op::Parameter>(element::f32, Shape{4});
auto broadcast_bias_h2h_0_0 =
std::make_shared<op::Broadcast>(bias_h2h_0, Shape{1, 4}, AxisSet{0});
auto add_2_0 = std::make_shared<op::Add>(dot_2_0, broadcast_bias_h2h_0_0);
auto X_0 = std::make_shared<op::Add>(add_2_0, add_1_0);
// construct forget gate
auto input_slice_0_0 = std::make_shared<op::Slice>(X_0, Coordinate{0, 0}, Coordinate{1, 1});
auto forget_gate_0 = std::make_shared<op::Sigmoid>(input_slice_0_0);
//ct-1 -> cell state
auto c_const_0 = op::Constant::create(element::f32, Shape{}, {1.0});
auto ct_1_0 = std::make_shared<op::Broadcast>(c_const_0, Shape{1, 1}, AxisSet{0, 1});
//auto ct_1 = std::make_shared<op::>(element::f32, Shape{10, 100});
auto multiply_forget_gate_0_ct_1_0 = std::make_shared<op::Multiply>(forget_gate_0, ct_1_0);
// construct input gate
auto input_slice_1_0 = std::make_shared<op::Slice>(X_0, Coordinate{0, 1}, Coordinate{1, 2});
auto input_gate_0 = std::make_shared<op::Sigmoid>(input_slice_1_0);
auto input_slice_2_0 = std::make_shared<op::Slice>(X_0, Coordinate{0, 2}, Coordinate{1, 3});
auto tanh_1_0 = std::make_shared<op::Tanh>(input_slice_2_0);
auto multiply_input_gate_0_tanh_1_0 = std::make_shared<op::Multiply>(input_gate_0, tanh_1_0);
auto ct_0 =
std::make_shared<op::Add>(multiply_forget_gate_0_ct_1_0, multiply_input_gate_0_tanh_1_0);
// construct output gate
auto input_slice_3_0 = std::make_shared<op::Slice>(X_0, Coordinate{0, 3}, Coordinate{1, 4});
auto output_gate_0 = std::make_shared<op::Sigmoid>(input_slice_3_0);
auto tanh_2_0 = std::make_shared<op::Tanh>(ct_0);
auto ht_0 = std::make_shared<op::Multiply>(output_gate_0, tanh_2_0);
auto f = make_shared<Function>(NodeVector{ht_0, ct_0},
op::ParameterVector{input_xt,
weights_i2h,
weights_h2h,
bias_i2h,
bias_h2h,
weights_i2h_0,
weights_h2h_0,
bias_i2h_0,
bias_h2h_0});
auto backend = runtime::Backend::create("GPU");
auto params = f->get_parameters();
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> arg_tensors;
for (shared_ptr<op::Parameter> param : params)
{
vector<float> tensor_vals(shape_size(param->get_shape()), 1.0f);
auto tensor = backend->create_tensor(element::f32, param->get_shape());
copy_data(tensor, tensor_vals);
arg_tensors.push_back(tensor);
}
std::shared_ptr<runtime::Tensor> result_ht =
backend->create_tensor(element::f32, ht->get_shape());
std::shared_ptr<runtime::Tensor> result_ct =
backend->create_tensor(element::f32, ct->get_shape());
backend->call_with_validate(f, {result_ht, result_ct}, arg_tensors);
//EXPECT_EQ(1, count_ops_of_type<op::gpu::Rnn>(f));
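// Hand-computed reference: all inputs, weights, and biases are 1, so each layer-1 gate
// pre-activation is 4 with ct-1 = 1; layer 2 then sees ht_first*1 + 1 + 1*1 + 1 = ht_first + 3
// per gate.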
auto sig = [](float x) { return 1.0f / (1.0f + std::exp(-x)); };
float kernel = 4.0f;
float ct_val_first = sig(kernel) + sig(kernel) * std::tanh(kernel);
float ht_val_first = sig(kernel) * std::tanh(ct_val_first);
kernel = 3.0f + ht_val_first;
float ct_val_second = sig(kernel) + sig(kernel) * std::tanh(kernel);
float ht_val_second = sig(kernel) * std::tanh(ct_val_second);
EXPECT_TRUE(test::all_close(std::vector<float>{ht_val_second}, read_vector<float>(result_ht)));
EXPECT_TRUE(test::all_close(std::vector<float>{ct_val_second}, read_vector<float>(result_ct)));
}
TEST(gpu_fusion, rnn_fusion_inter_vs_gpu_1lstm_cell)
{
const std::string file_name("mxnet/1_lstm_cell_forward.json");
auto gpu_f = make_function(file_name);
auto int_f = make_function(file_name);
test::Uniform<float> rng(-10.0f, 10.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto gpu_results = execute(gpu_f, args, "GPU");
for (size_t i = 0; i < gpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(gpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
TEST(gpu_fusion, rnn_fusion_inter_vs_gpu_1rnn_layer_3lstm_cell)
{
const std::string file_name("mxnet/1rnn_layer_3lstm_cell.json");
auto gpu_f = make_function(file_name);
auto int_f = make_function(file_name);
test::Uniform<float> rng(-10.0f, 10.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto gpu_results = execute(gpu_f, args, "GPU");
for (size_t i = 0; i < gpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(gpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
TEST(gpu_fusion, rnn_fusion_inter_vs_gpu_2rnn_layer_3lstm_cell)
{
const std::string file_name("mxnet/2rnn_layer_3lstm_cell.json");
auto gpu_f = make_function(file_name);
auto int_f = make_function(file_name);
test::Uniform<float> rng(-10.0f, 10.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto gpu_results = execute(gpu_f, args, "GPU");
for (size_t i = 0; i < gpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(gpu_results.at(i), int_results.at(i), 1.0e-3f, 1.0e-3f));
}
}
TEST(gpu_fusion, fuse_rnn_across_layer)
{
pass::Manager pass_manager;
pass_manager.register_pass<runtime::gpu::pass::LSTMFusion>();
pass_manager.register_pass<runtime::gpu::pass::RNNFusion>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<runtime::gpu::pass::MultiLayerRNNFusion>();
const string json_path =
file_util::path_join(SERIALIZED_ZOO, "mxnet/2rnn_layer_1timestep.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass_manager.run_passes(func);
size_t ref_rnn_count = 1;
auto rnn_count = count_ops_of_type<op::gpu::Rnn>(func);
EXPECT_EQ(ref_rnn_count, rnn_count);
}
TEST(gpu_fusion, fuse_rnn_across_2layer_1timestep)
{
const std::string file_name("mxnet/2rnn_layer_1timestep.json");
auto gpu_f = make_function(file_name);
auto int_f = make_function(file_name);
test::Uniform<float> rng(-10.0f, 10.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto gpu_results = execute(gpu_f, args, "GPU");
// TODO (pruthvi): Enable this after fixing failing
// mxnet rnn unit tests
// EXPECT_EQ(1, count_ops_of_type<op::gpu::Rnn>(gpu_f));
for (size_t i = 0; i < gpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(gpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}