Commit 4df5ea8b authored by Chris Sullivan, committed by Robert Kimball

RNN fusion (inference) (#1459)

* Add op::Sigmoid to nvgpu.

* Bring rnn fusion and concat passes over into GPU from IA. This is a temporary move until generalization and gpu specification can occur.

* Add LSTM fusion and cudnn inference kernel. Next need recurrent fusion and layer fusion.

* Formatting

* Removed unnecessary extra output from LSTM op (rnn with seq. length = 1, so y = hy).

* Add RNN fusion of LSTM cells within a recurrent layer.

* Formatting.

* Add fusion across RNN layers.

* Formatting.

* Add algebraic simplification.

* Added rnn fusion tests.

* Updated conditional on LSTM fusion to better distinguish bound nodes as ht vs xt.

* Formatting.

* Removed print statements.

* Formatting.

* Committing missing file.

* Remove concat inputs pass and mkldnn references.

* fix cmake paths

* conflict resolution with merge from master.

* remove explicit lstm op support. bare LSTM ops are converted to RNN ops for emission.

* Formatting.

* Use NGRAPH_ASSERT. Formatting of Intel copyright.

* Add check on the feature size (shape) of the recurrent (hidden) input and cell state, to ensure they are the same size.

* fix wrong rnn header

* Formatting.

* Add back lstm op to dispatch table.

* Added RNN test which shows cudnn rnn kernel is not producing correct results.

* With the update to AlgebraicSimplification to simplify concat-reshape-slice, the check modified in this commit needed to be relaxed.

* Bug fix in parameter tensor packing.

* Alias third output element of RNN for cell state (bug fix).

* Resolve numerical correctness issue with negative values in RNN (bug fix).
Add minimal test to evaluate LSTM and compare with values calculated by hand.

* Add tensor parameter sizes to kernel hash as
they are kernel-specific.

* Add 2-layer LSTM fusion test against a by-hand solution.

* Export parameter concatenation to the graph for the cudnn kernel, at both the single-layer and multi-layer RNN levels.

* Formatting.

* Finishing touches after merge: add support for macro-expanded dispatch via op_tbl.

* Simplify macro support for gpu ops.

* Add CUDNN_VERSION >= 7200 defguards for RNN fusion.
Need to decide how to notify user of increased performance with >= 7200.

* Revert lstm_analytic test to explicitly copy data to tensor params.

* Removed namespace arg from NGRAPH_GPU_OP.

* Refactored macros to different header so op_tbl only contains op list.

* Defguard on cudnn_descriptor<cudnnRNNDataDescriptor_t>.

* doubles -> floats

* Reorg. pass asserts, prepare to replace with non-throwing pass failures.

* Remove Lstm op and replace it with Rnn.

* Format

* Utilize RETURN_IF_FALSE in rnn pass to avoid any RT asserts.
Note that falling back to the raw (no passes) graph for the 2rnn_3lstm json from mxnet models
results in a double free inside the memory layout pass. Appears to be a bug
in the Reshape pass-through.

* Removed print statements. Add check on input data and recurrent data.

* Don't reuse memory for non-destructive ops.

* Add back Rnn test.

* Formatting.

* Clean up comments.

* Update test per review comments.
@@ -41,6 +41,8 @@ set(SRC
pass/gpu_layout.cpp
pass/tensor_memory_reservation.cpp
gpu_kernel_args.cpp
pass/gpu_rnn_fusion.cpp
op/rnn.cpp
)
if (NGRAPH_GPU_ENABLE)
@@ -136,6 +136,20 @@ namespace ngraph
}
};
#if CUDNN_VERSION >= 7200
template <>
struct cudnn_descriptor<cudnnRNNDataDescriptor_t>
{
static void create(cudnnRNNDataDescriptor_t& desc)
{
CUDNN_SAFE_CALL(cudnnCreateRNNDataDescriptor(&desc));
}
static void destroy(cudnnRNNDataDescriptor_t& desc)
{
CUDNN_SAFE_CALL_NO_THROW(cudnnDestroyRNNDataDescriptor(desc));
}
};
#endif
template <>
struct cudnn_descriptor<cudnnPoolingDescriptor_t>
{
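The specialization above follows the backend's create/destroy trait pattern. As a hedged illustration (the wrapper name below is hypothetical and not part of this commit), a scoped RAII handle built on cudnn_descriptor<T> could look like:

// Hypothetical sketch: a scoped RAII handle built on the cudnn_descriptor<T>
// trait above. "scoped_descriptor" is illustrative, not a name from this commit.
template <typename T>
class scoped_descriptor
{
public:
    scoped_descriptor() { cudnn_descriptor<T>::create(m_handle); }
    ~scoped_descriptor() { cudnn_descriptor<T>::destroy(m_handle); }
    scoped_descriptor(const scoped_descriptor&) = delete;
    scoped_descriptor& operator=(const scoped_descriptor&) = delete;
    T& get() { return m_handle; }

private:
    T m_handle;
};

// Usage, guarded the same way as the specialization above:
// #if CUDNN_VERSION >= 7200
//     scoped_descriptor<cudnnRNNDataDescriptor_t> rnn_data;
//     /* cudnnSetRNNDataDescriptor(rnn_data.get(), ...); */
// #endif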
[diff collapsed]
@@ -35,6 +35,7 @@
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/runtime/gpu/op/rnn.hpp"
namespace ngraph
{
@@ -61,6 +62,7 @@ namespace ngraph
size_t build_primitive(const op::MaxPool* node);
size_t build_primitive(const op::Max* node);
size_t build_primitive(const op::Min* node);
size_t build_primitive(const op::gpu::Rnn* node);
public:
enum class Prop
@@ -149,10 +151,18 @@ namespace ngraph
tensor_descriptor_from_shape(const Shape& shape,
const cudnnDataType_t data_type,
const cudnnTensorFormat_t tensor_format);
cudnnTensorDescriptor_t&
get_nd_tensor_descriptor(const Shape& shape,
const cudnnDataType_t data_type,
const cudnnTensorFormat_t tensor_format);
cudnnFilterDescriptor_t&
get_cudnn_filter_descriptor(const Shape& shape,
const cudnnDataType_t data_type,
const cudnnTensorFormat_t tensor_format);
cudnnFilterDescriptor_t&
get_nd_filter_descriptor(const Shape& shape,
const cudnnDataType_t data_type,
const cudnnTensorFormat_t tensor_format);
cudnnConvolutionDescriptor_t&
get_cudnn_convolution_descriptor(const Shape& padding,
const Strides& window_movement_strides,
@@ -108,6 +108,7 @@
#include "ngraph/runtime/gpu/gpu_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_util.hpp"
#include "ngraph/runtime/gpu/op/rnn.hpp"
#include "ngraph/runtime/gpu/type_info.hpp"
#include "ngraph/util.hpp"
@@ -124,7 +125,7 @@ function<void(EMIT_ARGS)> runtime::gpu::GPU_Emitter::get_emit_function(const Nod
// ...
#define NGRAPH_OP(a, b) {type_index(typeid(b::a)), runtime::gpu::GPU_Emitter::emit_##a},
static const map<type_index, function<void(EMIT_ARGS)>> typeid_map{
#include "ngraph/op/op_tbl.hpp"
#include "ngraph/runtime/gpu/op/op_tbl.hpp"
};
#undef NGRAPH_OP
auto it = typeid_map.find(type_index(typeid(node)));
@@ -1297,6 +1298,24 @@ void runtime::gpu::GPU_Emitter::emit_ReverseSequence(EMIT_ARGS)
writer.block_end();
}
#if CUDNN_VERSION >= 7200
void runtime::gpu::GPU_Emitter::emit_Rnn(EMIT_ARGS)
{
auto rnn = static_cast<const ngraph::op::gpu::Rnn*>(node);
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
size_t index = cudnn_emitter->build_primitive(rnn);
writer.block_begin();
{
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
}
writer.block_end();
}
#endif
void runtime::gpu::GPU_Emitter::emit_Select(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Select>(external_function, writer, node, args, out);
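For reference, emit_Rnn writes a block like the following into the generated source; the tensor names and primitive index here are illustrative, not taken from an actual compile:

// Illustrative output of emit_Rnn for one Rnn node (names and index invented):
{
    void* input[] = {Parameter_0, Parameter_1, Parameter_2, Parameter_3};
    void* output[] = {Rnn_4_0, Rnn_4_1, Rnn_4_2};
    gpu::invoke_primitive(ctx, 17, input, output);
}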
@@ -39,7 +39,7 @@ namespace ngraph
// static void emit_Abs(EMIT_ARGS);
// static void emit_Acos(EMIT_ARGS);
#define NGRAPH_OP(a, b) static void emit_##a(EMIT_ARGS);
#include "ngraph/op/op_tbl.hpp"
#include "ngraph/runtime/gpu/op/op_tbl.hpp"
#undef NGRAPH_OP
template <typename T>
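With this one-line NGRAPH_OP definition, each entry in the op table declares an emitter; the guarded Rnn entry in the new op_tbl.hpp, for example, expands to:

static void emit_Rnn(EMIT_ARGS);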
@@ -103,6 +103,7 @@
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/common_function_collection.hpp"
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp"
@@ -111,7 +112,9 @@
#include "ngraph/runtime/gpu/gpu_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp"
#include "ngraph/runtime/gpu/op/rnn.hpp"
#include "ngraph/runtime/gpu/pass/gpu_layout.hpp"
#include "ngraph/runtime/gpu/pass/gpu_rnn_fusion.hpp"
#include "ngraph/runtime/gpu/pass/tensor_memory_reservation.hpp"
using namespace std;
@@ -551,19 +554,26 @@ void runtime::gpu::GPU_ExternalFunction::compile()
m_function_name = m_function->get_name();
auto allocator = std::make_shared<runtime::gpu::GPUAllocator>(
m_shared_context->m_primitive_emitter->get_memory_allocator());
#if CUDNN_VERSION >= 7200
// recurrent network fusion
m_pass_manager.register_pass<runtime::gpu::pass::LSTMFusion>();
m_pass_manager.register_pass<runtime::gpu::pass::RNNFusion>();
m_pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
m_pass_manager.register_pass<runtime::gpu::pass::MultiLayerRNNFusion>();
#else
m_pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
#endif
m_pass_manager.register_pass<ngraph::pass::LikeReplacement>();
m_pass_manager
.register_pass<ngraph::pass::AssignLayout<descriptor::layout::DenseTensorLayout>>();
m_pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this);
m_pass_manager.register_pass<ngraph::pass::Liveness>();
m_pass_manager.register_pass<ngraph::pass::MemoryLayout>(s_memory_pool_alignment);
GPUAllocator allocator = m_shared_context->m_primitive_emitter->get_memory_allocator();
m_pass_manager.register_pass<runtime::gpu::pass::TensorMemoryReservation>(
allocator, m_tensor_memory_buffers);
*allocator, m_tensor_memory_buffers);
std::string common_function_string;
auto femitter = bind(&ngraph::runtime::gpu::GPU_ExternalFunction::emit_op_as_function,
this,
@@ -571,7 +581,6 @@ void runtime::gpu::GPU_ExternalFunction::compile()
placeholders::_2);
m_pass_manager.register_pass<ngraph::pass::CommonFunctionCollection>(
femitter, m_node_function_map, common_function_string);
string dump_filename = file_util::path_join(s_output_dir, m_function_name + "_ops.txt");
m_pass_manager.register_pass<ngraph::pass::DumpSorted>(dump_filename);
@@ -590,7 +599,7 @@ void runtime::gpu::GPU_ExternalFunction::compile()
emit_functions();
// allocate device buffers for primitive arguments and workspace
allocator.close();
allocator->close();
m_shared_context->m_primitive_emitter->allocate_primitive_memory();
string code = m_writer.get_code();
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/op_tbl.hpp"
#if CUDNN_VERSION >= 7200
NGRAPH_OP(Rnn, ngraph::op::gpu)
#endif
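The second macro argument carries the op's namespace, so each consumer of op_tbl.hpp controls the expansion. Under the emitter's definition of NGRAPH_OP shown earlier, the guarded entry above expands to the dispatch-table element:

{type_index(typeid(ngraph::op::gpu::Rnn)), runtime::gpu::GPU_Emitter::emit_Rnn},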
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/gpu/op/rnn.hpp"
#include "ngraph/log.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
shared_ptr<Node> op::gpu::Rnn::copy_with_new_args(const NodeVector& new_args) const
{
NGRAPH_ASSERT(new_args.size() == 4) << "Incorrect number of new arguments";
return make_shared<Rnn>(new_args[0],
new_args[1],
new_args[2],
new_args[3],
m_num_timesteps,
m_num_gates_per_cell,
m_src_sequence_length,
m_src_layer_feature_size,
m_src_iter_feature_size,
m_direction,
m_num_fused_layers);
}
op::gpu::Rnn::Rnn(std::shared_ptr<Node> src_layer,
std::shared_ptr<Node> src_iter,
std::shared_ptr<Node> params,
std::shared_ptr<Node> state_iter,
const int num_timesteps,
const int num_gates_per_cell,
const int src_sequence_length,
const int src_layer_feature_size,
const int src_iter_feature_size,
const int direction,
const int num_fused_layers)
: Op("Rnn", {src_layer, src_iter, params, state_iter})
, m_num_timesteps(num_timesteps)
, m_num_gates_per_cell(num_gates_per_cell)
, m_src_sequence_length(src_sequence_length)
, m_src_layer_feature_size(src_layer_feature_size)
, m_src_iter_feature_size(src_iter_feature_size)
, m_direction(direction)
, m_num_fused_layers(num_fused_layers)
{
NGRAPH_ASSERT(src_layer->get_shape().size() == 2) << "src_layer doesn't have rank 2";
m_batch_size = static_cast<int>(src_layer->get_shape()[0] / num_timesteps);
NGRAPH_ASSERT(shape_size(src_layer->get_shape()) ==
m_src_sequence_length * m_batch_size * m_src_layer_feature_size)
<< "src_layer size is not equal t*n*c";
auto et = src_layer->get_element_type();
for (auto& rnn_input : get_arguments())
{
if (rnn_input->get_element_type() != et)
{
throw ngraph_error("all rnn inputs must have the same element type");
}
}
set_output_size(3);
set_output_type(0,
src_layer->get_element_type(),
Shape{static_cast<unsigned long>(m_direction * m_num_timesteps * m_batch_size),
static_cast<unsigned long>(m_src_iter_feature_size)});
set_output_type(
1,
src_layer->get_element_type(),
Shape{static_cast<unsigned long>(m_direction * m_num_fused_layers * m_batch_size),
static_cast<unsigned long>(m_src_iter_feature_size)});
set_output_type(
2,
src_layer->get_element_type(),
Shape{static_cast<unsigned long>(m_direction * m_num_fused_layers * m_batch_size),
static_cast<unsigned long>(m_src_iter_feature_size)});
}
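As a worked example of the shape logic in this constructor (the numbers are illustrative): a unidirectional, single-layer LSTM with 3 timesteps, batch size 2, and input/hidden feature sizes of 4 yields:

// src_layer shape : {3 * 2, 4} = {6, 4}      (t*n rows, c columns)
// m_batch_size    : 6 / 3 = 2
// output 0 (ht)   : {1 * 3 * 2, 4} = {6, 4}  (direction * timesteps * batch)
// output 1 (hf)   : {1 * 1 * 2, 4} = {2, 4}  (direction * fused layers * batch)
// output 2 (ct)   : {1 * 1 * 2, 4} = {2, 4}  (same shape as hf)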
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace op
{
namespace gpu
{
// This is the RNN op, formed by fusing multiple RNN cells (LSTM/GRU/vanilla RNN)
// across multiple time slices.
// INPUTS:
// [0] - {X0, X1, ..., Xt}, input tensor of layout TNC, Shape{sequence_length*batch_size, feature_size}
// [1] - recurrent input tensor ht_1 of Shape{num_fused_layers*batch_size, feature_size}
// [2] - flat parameter tensor consisting of the weights and biases for each layer
// {W_x^0 | W_h^0 | W_x^1 | W_h^1 | ... | B_x^0 | B_h^0 | B_x^1 | B_h^1}
// [3] - recurrent cell state tensor ct_1 with the same shape as ht_1
// number_of_timesteps - number of unrolled cells up to timestep t
// num_gates_per_cell - number of gates per RNN cell: LSTM = 4, GRU = 3, vanilla RNN = 1
// src_sequence_length - same as number_of_timesteps
// src_layer_feature_size - feature size w.r.t. the input tensor
// src_iter_feature_size - feature size w.r.t. the hidden state
// num_cell_states - number of recurrent state tensors: LSTM = 2, GRU = 1, vanilla RNN = 1
// OUTPUT VALUE: a tuple with the following structure:
// [0] - ht, sequence-wise output tensor with shape (sequence_length*batch_size, feature_size)
// [1] - hf, layer-wise output tensor with shape (num_fused_layers*batch_size, feature_size)
// [2] - ct, output cell state tensor with the same shape as hf, i.e. (num_fused_layers*batch_size, feature_size)
class Rnn : public Op
{
public:
Rnn(std::shared_ptr<Node> src_layer, // x
std::shared_ptr<Node> src_iter, // hx
std::shared_ptr<Node> params,
std::shared_ptr<Node> state_iter, // cx
const int num_timesteps,
const int num_gates_per_cell,
const int src_sequence_length,
const int src_layer_feature_size,
const int src_iter_feature_size,
const int direction,
const int num_fused_layers);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
int get_num_timesteps() const { return m_num_timesteps; }
int get_src_sequence_length() const { return m_src_sequence_length; }
int get_gates_per_cell() const { return m_num_gates_per_cell; }
int get_batch_size() const { return m_batch_size; }
int get_src_layer_feature_size() const { return m_src_layer_feature_size; }
int get_src_iter_feature_size() const { return m_src_iter_feature_size; }
int get_direction() const { return m_direction; }
int get_num_fused_layers() const { return m_num_fused_layers; }
private:
int m_num_timesteps;
int m_num_gates_per_cell;
int m_src_sequence_length;
int m_batch_size;
int m_src_layer_feature_size;
int m_src_iter_feature_size;
int m_direction;
int m_num_fused_layers;
};
}
}
}
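A minimal construction sketch for this op, reusing the shapes from the worked example in rnn.cpp above. The flat parameter size of 160 assumes a cudnn-style LSTM layout, 4 gates x (4x4 input weights + 4x4 recurrent weights) plus both bias vectors; treat it as illustrative:

// Hedged usage sketch; shapes follow the worked example above.
auto x  = std::make_shared<op::Parameter>(element::f32, Shape{6, 4}); // {t*n, c}
auto hx = std::make_shared<op::Parameter>(element::f32, Shape{2, 4}); // {layers*n, c}
auto cx = std::make_shared<op::Parameter>(element::f32, Shape{2, 4}); // same shape as hx
auto w  = std::make_shared<op::Parameter>(element::f32, Shape{160});  // flat weights + biases
auto rnn = std::make_shared<op::gpu::Rnn>(x, hx, w, cx,
                                          /*num_timesteps=*/3,
                                          /*num_gates_per_cell=*/4, // LSTM
                                          /*src_sequence_length=*/3,
                                          /*src_layer_feature_size=*/4,
                                          /*src_iter_feature_size=*/4,
                                          /*direction=*/1,
                                          /*num_fused_layers=*/1);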
[diff collapsed]
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
namespace pass
{
class LSTMFusion;
class RNNFusion;
class MultiLayerRNNFusion;
}
}
}
}
class ngraph::runtime::gpu::pass::LSTMFusion : public ngraph::pass::GraphRewrite
{
public:
LSTMFusion()
: GraphRewrite()
{
construct_sigmoid();
construct_lstm_fprop();
}
private:
void construct_sigmoid();
void construct_lstm_fprop();
};
class ngraph::runtime::gpu::pass::RNNFusion : public ngraph::pass::RecurrentGraphRewrite
{
public:
RNNFusion()
: RecurrentGraphRewrite()
{
construct_rnn_lstm_fprop();
}
private:
void construct_rnn_lstm_fprop();
};
class ngraph::runtime::gpu::pass::MultiLayerRNNFusion : public ngraph::pass::RecurrentGraphRewrite
{
public:
MultiLayerRNNFusion()
: RecurrentGraphRewrite()
{
construct_multi_layer_rnn_fusion_fprop();
}
private:
void construct_multi_layer_rnn_fusion_fprop();
};
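These three rewrites are meant to run in the order shown in the gpu_external_function diff above: fuse individual LSTM cells, stitch them across timesteps, simplify the resulting concat/reshape/slice chains, then fuse across layers. A minimal sketch of registering them on a standalone pass manager, mirroring that diff:

// Sketch: running the fusion pipeline on a Function, mirroring the
// registration order in gpu_external_function.cpp above.
ngraph::pass::Manager pm;
pm.register_pass<ngraph::runtime::gpu::pass::LSTMFusion>();
pm.register_pass<ngraph::runtime::gpu::pass::RNNFusion>();
pm.register_pass<ngraph::pass::AlgebraicSimplification>();
pm.register_pass<ngraph::runtime::gpu::pass::MultiLayerRNNFusion>();
pm.run_passes(func); // func: std::shared_ptr<ngraph::Function>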
@@ -73,7 +73,7 @@ if(NGRAPH_CPU_ENABLE)
endif()
if(NGRAPH_GPU_ENABLE)
set(SRC ${SRC} cudnn.cpp gpu_test.cpp)
set(SRC ${SRC} cudnn.cpp gpu_test.cpp gpu_fusion.cpp)
endif()
foreach(TEST_CONFIG ${UNIT_TEST_CONFIG_LIST})
[diff collapsed]