Commit 5df0e17e authored by Jaikrishnan Menon

Merge branch 'master' into dex2

parents c829a9c7 f75b8006
......@@ -68,16 +68,16 @@ public:
std::string generate_temporary_name(std::string prefix = "tempvar");
void block_begin(std::string block_prefix = "")
void block_begin()
{
*this << "{" << block_prefix << "\n";
*this << "{\n";
indent++;
}
void block_end(std::string block_suffix = "")
void block_end()
{
indent--;
*this << "}" << block_suffix << "\n";
*this << "}\n";
}
private:
......
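For context, this hunk removes the optional block_prefix/block_suffix parameters from CodeWriter's block helpers; callers that want a trailing comment now emit it themselves. A minimal usage sketch, assuming only the interface shown above plus the get_code() accessor used later in this diff:

codegen::CodeWriter writer;
writer << "int scale(int x)\n";
writer.block_begin();   // emits "{\n" and increments the indent
writer << "return x * 2;\n";
writer.block_end();     // decrements the indent and emits "}\n"
std::string code = writer.get_code();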
......@@ -265,7 +265,6 @@ void codegen::StaticCompiler::add_header_search_path(const string& p)
vector<string> paths = split(p, ';');
for (const string& path : paths)
{
NGRAPH_INFO << path;
if (!contains(m_extra_search_path_list, path))
{
m_extra_search_path_list.push_back(path);
......
......@@ -344,9 +344,10 @@ void runtime::cpu::CPU_ExternalFunction::compile()
pass_manager.register_pass<ngraph::pass::NopElimination>();
pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
pass_manager.register_pass<runtime::cpu::pass::ConcatInputs>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<runtime::cpu::pass::MultiLayerRNNFusion>();
pass_manager.register_pass<runtime::cpu::pass::ConcatInputs>();
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.register_pass<ngraph::pass::CoreFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
......
......@@ -242,10 +242,12 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_lstm_fprop()
auto ht_output = std::make_shared<op::GetOutputElement>(lstm, 0);
auto ct_output = std::make_shared<op::GetOutputElement>(lstm, 1);
if (lstm->get_outputs().at(0).get_inputs().size() != 2)
{
throw ngraph_error("Lstm node doesnt have two outputs");
}
// Now identify the nodes which consume the output of the LSTM nodes
// and replace them accordingly
std::vector<std::shared_ptr<Node>> new_args;
// find the users of {ht|ct} and replace them with lstm_goe_1
for (auto node : pattern_map[ct_label]->get_users())
{
......@@ -280,8 +282,15 @@ static std::shared_ptr<ngraph::Node>
if (concat_all)
{
auto node_labels = m.get_bound_nodes_for_pattern(rnn_labels[0]);
std::reverse(node_labels.begin(), node_labels.end());
return std::make_shared<op::Concat>(node_labels, 0);
if (node_labels.size() > 1)
{
std::reverse(node_labels.begin(), node_labels.end());
return std::make_shared<op::Concat>(node_labels, 0);
}
else
{
return node_labels[0];
}
}
// src_iter -> concatenate ht_1|ct_1 of the first LSTM cells belonging to the same RNN layer
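For reference, op::Concat joins its argument tensors along the given axis, which is why the single-match case above can return the bound node unwrapped. A hedged sketch of the axis-0 arithmetic these helpers rely on (shapes hypothetical):

auto a = std::make_shared<op::Parameter>(element::f32, Shape{10, 100});
auto b = std::make_shared<op::Parameter>(element::f32, Shape{10, 100});
auto c = std::make_shared<op::Parameter>(element::f32, Shape{10, 100});
auto cat = std::make_shared<op::Concat>(NodeVector{a, b, c}, 0);
// cat->get_shape() == Shape{30, 100}: axis-0 extents add, remaining axes must match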
......@@ -437,7 +446,15 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
NGRAPH_DEBUG << "batch_size: " << batch_size;
NGRAPH_DEBUG << "feature_size: " << feature_size;
if ((src_layer->get_arguments().size()) != sequence_len)
if ((src_layer->get_arguments().size()) != sequence_len &&
!std::dynamic_pointer_cast<op::Parameter>(src_layer))
{
throw ngraph_error(
"number of lstm inputs captured in the RNN fusion is not equal to "
"src_sequence_length");
}
if (std::dynamic_pointer_cast<op::Parameter>(src_layer) && sequence_len != 1)
{
throw ngraph_error(
"number of lstm inputs captured in the RNN fusion is not equal to "
......@@ -491,7 +508,7 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
std::vector<std::shared_ptr<op::Slice>> ht_slice_per_timestep(num_of_lstm_matched, nullptr);
auto rnn_ht_out = std::make_shared<op::GetOutputElement>(rnn, 0);
auto rnn_ct_out = std::make_shared<op::GetOutputElement>(rnn, 1);
auto rnn_ht_ct_out = std::make_shared<op::GetOutputElement>(rnn, 1);
//slice the rnn ht's
size_t start_index = 0;
......@@ -547,24 +564,43 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
if (goe_node->get_n() == 0)
{
goe_0 = goes->get_node();
for (auto goe0_user : goe_0->get_users())
{
if (std::find(lstm_nodes.begin(), lstm_nodes.end(), goe0_user) ==
lstm_nodes.end() &&
!is_unreachable(goe0_user))
{
lstm_goe0_user.insert(goe0_user);
map_goe_to_lstm_slices[goe_0] = ht_slice_per_timestep[index];
NGRAPH_DEBUG << "ht_slice: " << ht_slice_per_timestep[index]->get_name()
<< " goe0_user " << goe0_user->get_name() << " ";
}
}
}
}
for (auto goe0_user : goe_0->get_users())
{
if (std::find(lstm_nodes.begin(), lstm_nodes.end(), goe0_user) ==
lstm_nodes.end() &&
!is_unreachable(goe0_user))
// we only need to check the last LSTM cell's Ct users and replace them if needed.
if ((index == 0) && (goe_node->get_n() == 1))
{
lstm_goe0_user.insert(goe0_user);
map_goe_to_lstm_slices[goe_0] = ht_slice_per_timestep[index];
NGRAPH_DEBUG << "ht_slice: " << ht_slice_per_timestep[index]->get_name()
<< " goe0_user " << goe0_user->get_name() << " ";
// the dst_iter output of the MKLDNN LSTM holds both recurrent state
// tensor outputs; we need to slice out the ct.
auto ht_slice = std::make_shared<op::Slice>(
rnn_ht_ct_out,
Coordinate{0, 0},
Coordinate{static_cast<unsigned long>(batch_size),
static_cast<unsigned long>(feature_size)});
auto ct_slice = std::make_shared<op::Slice>(
rnn_ht_ct_out,
Coordinate{static_cast<unsigned long>(batch_size), 0},
Coordinate{static_cast<unsigned long>(2 * batch_size),
static_cast<unsigned long>(feature_size)});
// check if the last LSTM cell has any consumers
auto n_time_step_lstm_ct_goe = goes->get_node();
ngraph::replace_node(n_time_step_lstm_ct_goe, ct_slice);
}
}
}
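A hedged recap of the layout the two slices above assume, with dimensions taken from the surrounding variables: the fused dst_iter result stacks ht on top of ct along axis 0.

// rnn_ht_ct_out : Shape{2 * batch_size, feature_size}
// ht slice      : rows [0, batch_size)            via Coordinate{0, 0} .. {batch_size, feature_size}
// ct slice      : rows [batch_size, 2*batch_size) via Coordinate{batch_size, 0} .. {2*batch_size, feature_size}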
//now go through the lstm consumers and replace them with the slice
//now go through the lstm goe_0 consumers and replace them with the slice
for (auto& node : lstm_goe0_user)
{
for (size_t i = 0; i < node->get_input_size(); i++)
......@@ -577,6 +613,7 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
}
}
}
NGRAPH_DEBUG << "End of recurrent fusion call back "
<< "matched_node: " << m.get_match_root()->get_name();
return true;
......@@ -588,3 +625,213 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
lstm_node_label, rpattern_ct_1, empty_correlated_matches, callback);
this->add_matcher(m);
}
static std::shared_ptr<Node>
compute_multi_layer_rnn_inputs(const std::shared_ptr<pattern::op::Label>& rnn_label,
pattern::RecurrentMatcher& m)
{
auto node_labels = m.get_bound_nodes_for_pattern(rnn_label);
std::reverse(node_labels.begin(), node_labels.end());
return std::make_shared<op::Concat>(node_labels, 0);
}
void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_fusion_fprop()
{
auto src_layer_label = std::make_shared<pattern::op::Label>(element::f32, Shape{30, 100});
auto slice_pred = [](std::shared_ptr<Node> n) {
return static_cast<bool>(std::dynamic_pointer_cast<op::Slice>(n));
};
auto src_slice = std::make_shared<pattern::op::Skip>(src_layer_label, slice_pred);
auto src_iter_label = std::make_shared<pattern::op::Label>(element::f32, Shape{20, 100});
auto weights_layer_label = std::make_shared<pattern::op::Label>(element::f32, Shape{400, 100});
auto weights_iter_label = std::make_shared<pattern::op::Label>(element::f32, Shape{400, 100});
auto bias_label = std::make_shared<pattern::op::Label>(element::f32, Shape{400});
size_t ref_number_of_timesteps = 3;
size_t ref_number_of_gates_per_cell = 4;
size_t ref_src_seq_length = 3;
size_t ref_src_layer_feature_size = 100;
size_t ref_feature_size = 100;
size_t ref_num_rnn_cell_states = 2;
size_t ref_rnn_direction = 1;
size_t ref_num_of_rnn_fused_layer = 1;
auto ref_rnn_node = std::make_shared<op::Rnn>(src_slice,
src_iter_label,
weights_layer_label,
weights_iter_label,
bias_label,
ref_number_of_timesteps,
ref_number_of_gates_per_cell,
ref_src_seq_length,
ref_src_layer_feature_size,
ref_feature_size,
ref_num_rnn_cell_states,
ref_rnn_direction,
ref_num_of_rnn_fused_layer);
NodeVector ht_slice_per_timestep;
auto rnn_ht_out = std::make_shared<op::GetOutputElement>(ref_rnn_node, 0);
auto rnn_ht_label =
std::make_shared<pattern::op::Label>(rnn_ht_out, nullptr, NodeVector{rnn_ht_out});
auto rnn_ct_out = std::make_shared<op::GetOutputElement>(ref_rnn_node, 1);
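A hedged reading of the pattern scaffolding above: pattern::op::Label stands for "any node with this element type and shape" and records what it bound, while pattern::op::Skip lets an optional op::Slice sit between src_layer_label and the Rnn input, so both of these graph shapes match:

// src_layer_label -> Rnn             (Skip forwards the label unchanged)
// src_layer_label -> Slice -> Rnn    (slice_pred accepts the intervening Slice)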
pattern::recurrent_graph_rewrite_callback callback = [src_layer_label,
src_iter_label,
weights_layer_label,
weights_iter_label,
bias_label,
rnn_ht_label](
pattern::RecurrentMatcher& m) {
if (m.get_number_of_recurrent_matches() <= 1)
{
return false;
}
auto src_nodes = m.get_bound_nodes_for_pattern(src_layer_label);
auto rnn_ht_out_nodes = m.get_bound_nodes_for_pattern(rnn_ht_label);
auto number_of_rnn_cell_matched = m.get_number_of_recurrent_matches();
NGRAPH_DEBUG << " In Recurrent multi layer RNN fusion callback ";
NGRAPH_DEBUG << "Number of RNN's Matched: " << number_of_rnn_cell_matched;
NGRAPH_DEBUG << "matched_root: " << m.get_match_root()->get_name();
NGRAPH_DEBUG << "src_layer_node: " << src_nodes[0]->get_name();
// we can fuse across different RNN layers only if SLC == DLC
for (size_t i = 0; i < number_of_rnn_cell_matched; i++)
{
if (src_nodes[i]->get_shape()[1] != rnn_ht_out_nodes[i]->get_shape()[1])
{
NGRAPH_DEBUG << "Not fusing since the feature sizes for xt and ht_1 dont match";
return false;
}
}
// we just need to capture the input symbols {x0 | x1.....| xt} of the first LSTM layer;
// the intermediate inputs for the next layer will be computed by MKLDNN
auto src_layer_nodes = m.get_bound_nodes_for_pattern(src_layer_label);
auto src_layer = src_layer_nodes[src_layer_nodes.size() - 1];
auto src_iter = compute_multi_layer_rnn_inputs(src_iter_label, m);
auto weights_layer = compute_multi_layer_rnn_inputs(weights_layer_label, m);
auto weights_iter = compute_multi_layer_rnn_inputs(weights_iter_label, m);
auto bias = compute_multi_layer_rnn_inputs(bias_label, m);
std::shared_ptr<op::Rnn> rnn_node = nullptr;
for (auto& rnn_goe_input : m.get_bound_nodes_for_pattern(rnn_ht_label)[0]->get_arguments())
{
if (std::dynamic_pointer_cast<op::Rnn>(rnn_goe_input))
{
rnn_node = std::dynamic_pointer_cast<op::Rnn>(rnn_goe_input);
}
else
{
throw ngraph_error("Input for RNN output GetOuputElement Op should be RNN");
}
}
size_t num_time_steps = rnn_node->get_num_timesteps();
size_t num_gates_in_lstm = rnn_node->get_gates_per_cell();
size_t batch_size = rnn_node->get_batch_size();
size_t sequence_len = rnn_node->get_src_sequence_length();
size_t src_layer_feature_size = rnn_node->get_src_layer_feature_size();
size_t feature_size = rnn_node->get_src_iter_feature_size();
size_t num_rnn_cell_states = rnn_node->get_num_cell_states();
size_t rnn_direction = rnn_node->get_direction();
size_t num_fused_rnn_layers = m.get_number_of_recurrent_matches();
NGRAPH_DEBUG << "src_layer: " << join(src_layer->get_shape());
NGRAPH_DEBUG << "src_iter: " << join(src_iter->get_shape());
NGRAPH_DEBUG << "weights_layer: " << join(weights_layer->get_shape());
NGRAPH_DEBUG << "weights_iter: " << join(weights_iter->get_shape());
NGRAPH_DEBUG << "bias: " << join(bias->get_shape());
NGRAPH_DEBUG << "src_seq_len: " << sequence_len;
NGRAPH_DEBUG << "batch_size: " << batch_size;
NGRAPH_DEBUG << "feature_size: " << feature_size;
if ((src_layer->get_arguments().size()) != rnn_node->get_num_timesteps() &&
!std::dynamic_pointer_cast<op::Parameter>(src_layer))
{
throw ngraph_error(
" input symbols for the layer fused RNN op, should be captured only for the first "
"layer");
}
if (std::dynamic_pointer_cast<op::Parameter>(src_layer) &&
rnn_node->get_num_timesteps() != 1)
{
throw ngraph_error(
" input symbols for the layer fused RNN op, should be captured only for the first "
"layer");
}
if ((src_iter->get_arguments().size()) != num_fused_rnn_layers)
{
throw ngraph_error(
"number of states(ht_1|ct_1) for RNN op in the layer fusion is not equal to num of "
"fused_rnn_layers");
}
if ((weights_layer->get_arguments().size()) != num_fused_rnn_layers)
{
throw ngraph_error(
"weights w.r.to input symbols of RNN op in the layer fusion is not equal to num of "
"fused_rnn_layers");
}
if ((weights_iter->get_arguments().size()) != num_fused_rnn_layers)
{
throw ngraph_error(
"weights w.r.to cell states of RNN op in the layer fusion is not equal to num of "
"fused_rnn_layers");
}
if ((bias->get_arguments().size()) != num_fused_rnn_layers)
{
throw ngraph_error(
"bias of RNN op in the layer fusion is not equal to num of fused_rnn_layers");
}
auto rnn = std::make_shared<op::Rnn>(src_layer,
src_iter,
weights_layer,
weights_iter,
bias,
num_time_steps,
num_gates_in_lstm,
sequence_len,
src_layer_feature_size,
feature_size,
num_rnn_cell_states,
rnn_direction,
num_fused_rnn_layers);
auto layer_rnn_ht = std::make_shared<op::GetOutputElement>(rnn, 0);
auto layer_rnn_ht_ct = std::make_shared<op::GetOutputElement>(rnn, 1);
// find the last RNN cell GOE's and replace them with the layer fused RNN GOE.
for (auto& rnn_goes : rnn_node->get_users())
{
NGRAPH_DEBUG << "rnn_goes: " << rnn_goes->get_name();
if (auto rnn_goe_node = std::dynamic_pointer_cast<op::GetOutputElement>(rnn_goes))
{
if (rnn_goe_node->get_n() == 0)
{
ngraph::replace_node(rnn_goes, layer_rnn_ht);
}
else if (rnn_goe_node->get_n() == 1)
{
ngraph::replace_node(rnn_goes, layer_rnn_ht_ct);
}
}
}
return true;
};
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
auto m = std::make_shared<pattern::RecurrentMatcher>(
rnn_ht_label, src_layer_label, empty_correlated_matches, callback);
this->add_matcher(m);
}
......@@ -29,6 +29,7 @@ namespace ngraph
{
class LSTMFusion;
class RNNFusion;
class MultiLayerRNNFusion;
}
}
}
......@@ -61,3 +62,16 @@ public:
private:
void construct_rnn_lstm_fprop();
};
class ngraph::runtime::cpu::pass::MultiLayerRNNFusion : public ngraph::pass::RecurrentGraphRewrite
{
public:
MultiLayerRNNFusion()
: RecurrentGraphRewrite()
{
construct_multi_layer_rnn_fusion_fprop();
}
private:
void construct_multi_layer_rnn_fusion_fprop();
};
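A minimal usage sketch for the new pass, mirroring the cpu_fusion test added later in this diff; func is assumed to be a deserialized shared_ptr<Function>:

pass::Manager pm;
pm.register_pass<runtime::cpu::pass::LSTMFusion>();
pm.register_pass<runtime::cpu::pass::RNNFusion>();
pm.register_pass<ngraph::pass::AlgebraicSimplification>();
pm.register_pass<runtime::cpu::pass::MultiLayerRNNFusion>();
pm.run_passes(func); // stacked single-layer op::Rnn nodes collapse into one multi-layer op::Rnn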
......@@ -268,8 +268,8 @@ size_t runtime::gpu::CUDAEmitter::build_pad_dynamic(const runtime::gpu::GPURunti
compiled_kernel = ctx->compiled_kernel_pool->set(kernel_name.str(), writer.get_code());
}
unsigned int rank = static_cast<unsigned int>(input_shape.size());
unsigned int nthreads = static_cast<unsigned int>(shape_size(input_shape));
uint32_t rank = static_cast<uint32_t>(input_shape.size());
uint32_t nthreads = static_cast<uint32_t>(shape_size(input_shape));
GPUShape pad_below(input_shape.size(), 0);
GPUShape pad_interior(input_shape.size(), 1);
......@@ -286,14 +286,14 @@ size_t runtime::gpu::CUDAEmitter::build_pad_dynamic(const runtime::gpu::GPURunti
// get an allocator for transient per kernel gpu memory
GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
size_t idx_input_strides = allocator.reserve_argspace(
input_strides.data(), input_strides.size() * sizeof(unsigned int));
size_t idx_output_strides = allocator.reserve_argspace(
output_strides.data(), output_strides.size() * sizeof(unsigned int));
size_t idx_input_strides =
allocator.reserve_argspace(input_strides.data(), input_strides.size() * sizeof(uint32_t));
size_t idx_output_strides =
allocator.reserve_argspace(output_strides.data(), output_strides.size() * sizeof(uint32_t));
size_t idx_padding_below =
allocator.reserve_argspace(pad_below.data(), pad_below.size() * sizeof(unsigned int));
allocator.reserve_argspace(pad_below.data(), pad_below.size() * sizeof(uint32_t));
size_t idx_padding_interior =
allocator.reserve_argspace(pad_interior.data(), pad_interior.size() * sizeof(unsigned int));
allocator.reserve_argspace(pad_interior.data(), pad_interior.size() * sizeof(uint32_t));
// create the launch primitive
std::unique_ptr<gpu::primitive> pad_dynamic(new gpu::primitive{[=](void** inputs,
......@@ -1015,7 +1015,7 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_window(const GPURuntimeContext* c
args_list[6] = &nthreads;
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<unsigned int>(nthreads),
static_cast<uint32_t>(nthreads),
1,
1, // grid dim
1,
......
......@@ -285,19 +285,19 @@ void runtime::gpu::CudaKernelBuilder::get_pad_dynamic_op(
const std::array<std::string, 2>& data_types)
{
writer << "extern \"C\" __global__ void cuda_" << name << "(" << data_types[0] << "* in, "
<< data_types[1] << "* out, unsigned int* input_strides, unsigned int* output_strides, "
"unsigned int* padding_below, unsigned int* "
"padding_interior, unsigned int rank, unsigned int n)\n";
<< data_types[1] << "* out, uint32_t* input_strides, uint32_t* output_strides, "
"uint32_t* padding_below, uint32_t* "
"padding_interior, uint32_t rank, uint32_t n)\n";
writer.block_begin();
{
writer << "unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
writer << "uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
writer << "if (tid < n)\n";
writer.block_begin();
{
writer << "unsigned int output_idx = 0;\n";
writer << "unsigned int input_idx = tid;\n";
writer << "uint32_t output_idx = 0;\n";
writer << "uint32_t input_idx = tid;\n";
writer << "for(unsigned int i = 0; i < rank; i++)\n";
writer << "for(uint32_t i = 0; i < rank; i++)\n";
writer.block_begin();
{
writer << "output_idx += (input_idx / input_strides[i] * padding_interior[i] + "
......
......@@ -47,7 +47,7 @@ void runtime::gpu::emit_onehot(const std::string& name,
void* args_list[] = {&in, &out, &repeat_size, &repeat_times, &count};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<unsigned int>(count),
static_cast<uint32_t>(count),
1,
1, // grid dim
1,
......@@ -84,7 +84,7 @@ void runtime::gpu::emit_reshape(const std::string& name,
void* args_list[] = {&in, &out, &input_strides, &trans_strides, &rank, &count};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<unsigned int>(count),
static_cast<uint32_t>(count),
1,
1, // grid dim
1,
......@@ -124,7 +124,7 @@ void runtime::gpu::emit_slice(const std::string& name,
void* args_list[] = {
&in, &out, &input_strides, &lower_bounds, &slice_strides, &output_strides, &rank, &count};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<unsigned int>(count),
static_cast<uint32_t>(count),
1,
1, // grid dim
1,
......@@ -161,7 +161,7 @@ void runtime::gpu::emit_reverse(const std::string& name,
void* args_list[] = {&in, &out, &input_shapes, &reverse_axes, &rank, &count};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<unsigned int>(count),
static_cast<uint32_t>(count),
1,
1, // grid dim
1,
......
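A hedged note on the unsigned int -> uint32_t sweep in these hunks: the host-side argument buffers handed to cuLaunchKernel must match the device-side parameter widths byte for byte, and uint32_t fixes that width on every platform, while unsigned int is only guaranteed to be at least 16 bits. A one-line guard capturing the assumption:

static_assert(sizeof(uint32_t) == 4, "GPU kernel ABI assumes 32-bit index arguments");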
......@@ -118,7 +118,7 @@ namespace ngraph
{
return;
}
writer.block_begin(" // " + node->get_name());
writer.block_begin();
writer << "int count = " << out[0].get_size() << ";\n";
writer += R"(
float alpha1 = 1.0, alpha2 = 1.0, beta = 0;
......@@ -136,7 +136,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
CUDNN_OP_TENSOR_ADD,
CUDNN_DATA_FLOAT,
CUDNN_NOT_PROPAGATE_NAN));
)";
)";
writer << "CUDNN_SAFE_CALL(cudnnOpTensor(*ctx->cudnn_handle,"
<< "opTensorDesc,"
......@@ -193,7 +193,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
auto input_shape = args[0].get_shape();
Shape input_shape_padded = input_shape;
Shape padding_interior(data_dilation_strides);
writer.block_begin(" // " + node->get_name());
writer.block_begin();
if (pad_required || is_deconvolution)
{
input_shape_padded = get_padded_shape(
......@@ -314,7 +314,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
}
Shape padding_interior(data_dilation_strides);
writer.block_begin(" // " + node->get_name());
writer.block_begin();
if (pad_required || is_deconvolution)
{
output_shape_padded = get_padded_shape(
......@@ -467,7 +467,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
auto input_shape = args[0].get_shape();
auto input_shape_padded = input_shape;
Shape padding_interior(data_dilation_strides);
writer.block_begin(" // " + node->get_name());
writer.block_begin();
if (pad_required || is_deconvolution)
{
input_shape_padded = get_padded_shape(
......@@ -549,7 +549,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
auto& first = (arg0_shape.empty() ? args[0] : args[1]);
auto& second = (arg0_shape.empty() ? args[1] : args[0]);
writer.block_begin(" // " + node->get_name());
writer.block_begin();
writer << "int count = " << second.get_size() << ";\n";
writer << "CUBLAS_SAFE_CALL(cublasScopy("
<< "*ctx->cublas_handle,"
......@@ -566,7 +566,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
// set output to 0 if input size is 0
if (args[0].get_size() == 0 || args[1].get_size() == 0)
{
writer.block_begin(" // " + node->get_name());
writer.block_begin();
writer << "runtime::gpu::cuda_memset(" << out[0].get_name() << ", 0, "
<< out[0].get_size() << " * sizeof(float));\n";
writer.block_end();
......@@ -586,7 +586,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
"arg0 and arg1 shape does not match for dot.");
}
}
writer.block_begin(" // " + node->get_name());
writer.block_begin();
writer << "CUBLAS_SAFE_CALL(cublasSdot("
<< "*ctx->cublas_handle," << args[0].get_size() << ","
<< args[0].get_name() << ","
......@@ -598,7 +598,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) &&
(dot->get_reduction_axes_count() == 1))
{
writer.block_begin(" // " + node->get_name());
writer.block_begin();
writer << "const float alpha = 1.0;\n";
writer << "const float beta = 0;\n";
writer << "CUBLAS_SAFE_CALL(cublasSetPointerMode(*ctx->cublas_handle, "
......@@ -668,7 +668,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
}
// GEMM Call
writer.block_begin(" // " + node->get_name());
writer.block_begin();
writer << "const float alpha = 1.0;\n";
writer << "const float beta = 0.0;\n";
writer << "int m = " << m << ";\n";
......@@ -703,7 +703,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
{
return;
}
writer.block_begin(" // " + node->get_name());
writer.block_begin();
writer << "int count = " << out[0].get_size() << ";\n";
writer += R"(
float alpha1 = 1.0, alpha2 = 1.0, beta = 0;
......@@ -721,7 +721,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
CUDNN_OP_TENSOR_MAX,
CUDNN_DATA_FLOAT,
CUDNN_NOT_PROPAGATE_NAN));
)";
)";
writer << "CUDNN_SAFE_CALL(cudnnOpTensor(*ctx->cudnn_handle,"
<< "opTensorDesc,"
......@@ -741,7 +741,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
{
return;
}
writer.block_begin(" // " + node->get_name());
writer.block_begin();
writer << "int count = " << out[0].get_size() << ";\n";
writer += R"(
float alpha1 = 1.0, alpha2 = 1.0, beta = 0;
......@@ -759,7 +759,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
CUDNN_OP_TENSOR_MIN,
CUDNN_DATA_FLOAT,
CUDNN_NOT_PROPAGATE_NAN));
)";
)";
writer << "CUDNN_SAFE_CALL(cudnnOpTensor(*ctx->cudnn_handle,"
<< "opTensorDesc,"
......@@ -779,7 +779,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
{
return;
}
writer.block_begin(" // " + node->get_name());
writer.block_begin();
writer << "int count = " << out[0].get_size() << ";\n";
writer += R"(
float alpha1 = -1.0, alpha2 = 0, beta = 0;
......@@ -797,7 +797,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
CUDNN_OP_TENSOR_ADD,
CUDNN_DATA_FLOAT,
CUDNN_NOT_PROPAGATE_NAN));
)";
)";
writer << "CUDNN_SAFE_CALL(cudnnOpTensor(*ctx->cudnn_handle,"
<< "opTensorDesc,"
......@@ -825,7 +825,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
// broadcast axes is empty, do a copy
if (axes.empty())
{
writer.block_begin(" // " + node->get_name());
writer.block_begin();
kernel::emit_memcpyDtD(writer, out[0], args[0]);
writer.block_end();
return;
......@@ -867,7 +867,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
block_size += block_strides[i];
}
writer.block_begin(" // " + node->get_name());
writer.block_begin();
writer << "int count = " << out[0].get_size() << ";\n";
writer << "int num_inputs = " << args.size() << ";\n";
......@@ -910,7 +910,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
return;
}
auto reshape = static_cast<const op::Reshape*>(node);
writer.block_begin(" // " + node->get_name());
writer.block_begin();
auto arg_shape = args[0].get_shape();
auto arg_rank = arg_shape.size();
auto result_shape = out[0].get_shape();
......@@ -996,6 +996,45 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
writer.block_end();
}
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::FunctionCall)
{
auto function_call = static_cast<const ngraph::op::FunctionCall*>(node);
shared_ptr<Function> function = function_call->get_functions()[0];
writer.block_begin();
{
std::vector<string> input_names;
std::vector<string> output_names;
for (const runtime::gpu::GPU_TensorViewWrapper& input : args)
{
input_names.push_back(input.get_name());
}
for (const runtime::gpu::GPU_TensorViewWrapper& output : out)
{
output_names.push_back(output.get_name());
}
writer << "void* args[] =\n";
writer.block_begin();
writer << "\n" << join(input_names, ",\n");
writer.block_end();
writer << ";\n";
writer << "void* out[] =\n";
writer.block_begin();
writer << "\n" << join(output_names, ",\n");
writer.block_end();
writer << ";\n";
writer << "\n";
writer << function->get_name() << "(args, out, ctx);\n";
}
writer.block_end();
}
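For reference, a hedged sketch of the code this emitter produces for a two-input, one-output FunctionCall (tensor and function names hypothetical, indentation elided):

// void* args[] =
// {
//     tensor_a,
//     tensor_b
// };
// void* out[] =
// {
//     tensor_c
// };
//
// Function_1(args, out, ctx);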
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::Slice)
{
......@@ -1013,7 +1052,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
const auto input_strides = row_major_strides(arg_shape);
const auto output_strides = row_major_strides(result_shape);
writer.block_begin(" // " + node->get_name());
writer.block_begin();
if (args[0].get_size() == out[0].get_size())
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
......@@ -1077,7 +1116,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
{
reverse_axes_flag[a] = 1;
}
writer.block_begin(" // " + node->get_name());
writer.block_begin();
if (out[0].get_size() == 1)
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
......@@ -1112,11 +1151,6 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
writer.block_end();
}
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::FunctionCall)
{
}
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::Multiply)
{
......@@ -1124,7 +1158,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
{
return;
}
writer.block_begin(" // " + node->get_name());
writer.block_begin();
writer << "int count = " << out[0].get_size() << ";\n";
writer += R"(
float alpha1 = 1.0, alpha2 = 1.0, beta = 0;
......@@ -1142,7 +1176,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
CUDNN_OP_TENSOR_MUL,
CUDNN_DATA_FLOAT,
CUDNN_NOT_PROPAGATE_NAN));
)";
)";
writer << "CUDNN_SAFE_CALL(cudnnOpTensor(*ctx->cudnn_handle,"
<< "opTensorDesc,"
......@@ -1173,7 +1207,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
repeat_size *= result_shape[i];
}
writer.block_begin(" // " + node->get_name());
writer.block_begin();
writer << "runtime::gpu::cuda_memset(" << out[0].get_name() << ", 0, "
<< out[0].get_size() << " * " << out[0].get_element_type().size() << ");\n";
writer << "runtime::gpu::emit_onehot(\"" << node->description() << "\", {\""
......@@ -1193,7 +1227,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
{
return;
}
writer.block_begin(" // " + node->get_name());
writer.block_begin();
writer << "int count = " << out[0].get_size() << ";\n";
writer += R"(
float alpha1 = 1.0, alpha2 = 0, beta = 0;
......@@ -1211,7 +1245,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
CUDNN_OP_TENSOR_SQRT,
CUDNN_DATA_FLOAT,
CUDNN_NOT_PROPAGATE_NAN));
)";
)";
writer << "CUDNN_SAFE_CALL(cudnnOpTensor(*ctx->cudnn_handle,"
<< "opTensorDesc,"
......@@ -1227,7 +1261,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::Result)
{
writer.block_begin(" // " + node->get_name());
writer.block_begin();
kernel::emit_memcpyDtD(writer, out[0], args[0]);
writer.block_end();
return;
......@@ -1237,7 +1271,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
void GPU_Emitter::EMITTER_DECL(ngraph::op::Max)
{
const ngraph::op::Max* max_op = static_cast<const ngraph::op::Max*>(node);
writer.block_begin(" // " + node->get_name());
writer.block_begin();
{
if (out[0].get_size() != 0)
{
......@@ -1285,7 +1319,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
void GPU_Emitter::EMITTER_DECL(ngraph::op::Min)
{
const ngraph::op::Min* min_op = static_cast<const ngraph::op::Min*>(node);
writer.block_begin(" // " + node->get_name());
writer.block_begin();
{
if (out[0].get_size() != 0)
{
......@@ -1333,7 +1367,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
void GPU_Emitter::EMITTER_DECL(ngraph::op::Sum)
{
const ngraph::op::Sum* sum = static_cast<const ngraph::op::Sum*>(node);
writer.block_begin(" // " + node->get_name());
writer.block_begin();
{
if (out[0].get_size() != 0)
{
......@@ -1372,7 +1406,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
void GPU_Emitter::EMITTER_DECL(ngraph::op::Product)
{
const ngraph::op::Product* product = static_cast<const ngraph::op::Product*>(node);
writer.block_begin(" // " + node->get_name());
writer.block_begin();
{
if (out[0].get_size() != 0)
{
......@@ -1432,7 +1466,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
{TI(ngraph::op::Maximum), CUDNN_REDUCE_TENSOR_MAX},
{TI(ngraph::op::Minimum), CUDNN_REDUCE_TENSOR_MIN}};
const ngraph::op::Reduce* reduce_op = static_cast<const ngraph::op::Reduce*>(node);
writer.block_begin(" // " + node->get_name());
writer.block_begin();
{
if (out[0].get_size() != 0)
{
......@@ -1521,7 +1555,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
const ngraph::op::ReduceWindow* reduce_window_op =
static_cast<const ngraph::op::ReduceWindow*>(node);
writer.block_begin(" // " + node->get_name());
writer.block_begin();
{
if (out[0].get_size() != 0)
{
......@@ -1620,7 +1654,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
void GPU_Emitter::EMITTER_DECL(ngraph::op::Pad)
{
auto pad = static_cast<const ngraph::op::Pad*>(node);
writer.block_begin(" // " + node->get_name());
writer.block_begin();
{
auto input_shape = args[0].get_shape();
auto output_shape = out[0].get_shape();
......@@ -1653,7 +1687,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
{
// assumes NC{d1,d2,...} format
auto max_pool = static_cast<const ngraph::op::MaxPool*>(node);
writer.block_begin(" // " + node->get_name());
writer.block_begin();
{
auto& input_shape = args[0].get_shape();
auto& result_shape = out[0].get_shape();
......@@ -1785,7 +1819,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::MaxPoolBackprop)
{
writer.block_begin(" // " + node->get_name());
writer.block_begin();
{
auto mpb = static_cast<const ngraph::op::MaxPoolBackprop*>(node);
auto fp_input_shape = out[0].get_shape();
......@@ -1843,7 +1877,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
args[0].get_shape(),
batchnorm->get_eps_value());
writer.block_begin(" // " + node->get_name());
writer.block_begin();
{
writer << "gpu::invoke_primitive(ctx, " << bn_index << ", ";
writer << "std::vector<void*>{" << args.front().get_name();
......@@ -1879,7 +1913,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
args[0].get_shape(),
batchnorm->get_eps_value());
writer.block_begin(" // " + node->get_name());
writer.block_begin();
{
writer << "gpu::invoke_primitive(ctx, " << bn_index << ", ";
writer << "std::vector<void*>{" << args.front().get_name();
......@@ -1904,7 +1938,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
{
auto get_tuple_element = static_cast<const ngraph::op::GetOutputElement*>(node);
writer.block_begin(" // " + node->get_name());
writer.block_begin();
writer << "runtime::gpu::cuda_memcpyDtH(" << out[0].get_name() << ", "
<< args[get_tuple_element->get_n()].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
......@@ -1959,7 +1993,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
{
// assumes NC{d1,d2,...} format
auto avg_pool = static_cast<const ngraph::op::AvgPool*>(node);
writer.block_begin(" // " + node->get_name());
writer.block_begin();
{
auto& input_shape = args[0].get_shape();
auto& result_shape = out[0].get_shape();
......@@ -2034,7 +2068,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::AvgPoolBackprop)
{
writer.block_begin(" // " + node->get_name());
writer.block_begin();
{
auto apb = static_cast<const ngraph::op::AvgPoolBackprop*>(node);
auto output_shape = out[0].get_shape();
......@@ -2079,7 +2113,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
{
// assumes NC{d1,d2,...} format
auto rep_slice = static_cast<const ngraph::op::ReplaceSlice*>(node);
writer.block_begin(" // " + node->get_name());
writer.block_begin();
{
auto& input_shape = args[0].get_shape();
auto& source_shape = args[1].get_shape();
......@@ -2129,7 +2163,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::Softmax)
{
writer.block_begin(" // " + node->get_name());
writer.block_begin();
{
auto softmax = static_cast<const ngraph::op::Softmax*>(node);
auto tensor_shape = args[0].get_shape();
......
......@@ -77,7 +77,7 @@ namespace ngraph
auto& cuda_emitter =
external_function->get_primitive_emitter()->get_cuda_emitter();
writer.block_begin(" // " + node->get_name());
writer.block_begin();
{
std::vector<std::string> dtypes;
for (auto& arg : args)
......
......@@ -312,32 +312,32 @@ void runtime::gpu::GPU_ExternalFunction::compile()
writer +=
R"(// Generated by the NGraph GPU backend
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cudnn.h>
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/descriptor/primary_tensor_view.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/runtime/gpu/cudnn_descriptors.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_ops.hpp"
#include "ngraph/runtime/gpu/gpu_invoke.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/runtime/gpu/gpu_util.hpp"
#include "ngraph/util.hpp"
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cudnn.h>
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/descriptor/primary_tensor_view.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/runtime/gpu/cudnn_descriptors.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_ops.hpp"
#include "ngraph/runtime/gpu/gpu_invoke.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/runtime/gpu/gpu_util.hpp"
#include "ngraph/util.hpp"
)";
string pch_header_source = writer.get_code();
......@@ -346,81 +346,12 @@ void runtime::gpu::GPU_ExternalFunction::compile()
using namespace ngraph;
using namespace ngraph::runtime;
using namespace std;
)";
if (m_emit_timing)
{
writer << "// Declare debug timers\n";
vector<string> names;
for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
{
for (shared_ptr<Node> node : current_function->get_ordered_ops())
{
if (!node->is_parameter() && !node->is_constant())
{
names.push_back(node->get_name());
}
}
}
for (const string& s : names)
{
writer << "ngraph::stopwatch timer_" << s << ";\n";
}
writer << "extern \"C\" size_t get_debug_timer_count() { return " << names.size()
<< "; }\n";
writer << "extern \"C\" const char* get_debug_timer_name(size_t index)\n";
writer << "{\n";
writer.indent++;
writer << "const char* rc;\n";
writer << "switch(index)\n";
writer << "{\n";
for (size_t i = 0; i < names.size(); i++)
{
writer << "case " << i << ": rc = \"" << names[i] << "\"; break;\n";
}
writer << "default: rc = \"\";\n";
writer << "}\n";
writer << "return rc;\n";
writer.indent--;
writer << "}\n";
writer << "extern \"C\" const size_t get_debug_timer_microseconds(size_t index)\n";
writer << "{\n";
writer.indent++;
writer << "size_t rc;\n";
writer << "switch(index)\n";
writer << "{\n";
for (size_t i = 0; i < names.size(); i++)
{
writer << "case " << i << ": rc = timer_" << names[i]
<< ".get_total_microseconds(); break;\n";
}
writer << "default: rc = 0;\n";
writer << "}\n";
writer << "return rc;\n";
writer.indent--;
writer << "}\n";
writer << "extern \"C\" const size_t get_debug_timer_call_count(size_t index)\n";
writer << "{\n";
writer.indent++;
writer << "size_t rc;\n";
writer << "switch(index)\n";
writer << "{\n";
for (size_t i = 0; i < names.size(); i++)
{
writer << "case " << i << ": rc = timer_" << names[i] << ".get_call_count(); break;\n";
}
writer << "default: rc = 0;\n";
writer << "}\n";
writer << "return rc;\n";
writer.indent--;
writer << "}\n";
writer << "\n";
}
// // The "dso_handle" symbol is required by __cxa_atexit()
// // which is enabled because the JIT uses it as the default mechanism
// // to register cleanup handlers. We use it, and not atexit(), because
// // atexit() happens too late, when the JIT is no longer alive
)";
// The "dso_handle" symbol is required by __cxa_atexit()
// which is enabled because the JIT uses it as the default mechanism
// to register cleanup handlers. We use it, and not atexit(), because
// atexit() happens too late, when the JIT is no longer alive
writer << "void *__dso_handle = 0;\n\n";
writer << "// Declare all constants\n";
for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
......@@ -432,6 +363,8 @@ using namespace std;
{
shared_ptr<descriptor::TensorView> tv = node->get_outputs()[0].get_tensor_view();
auto c_value_strings = c->get_value_strings();
writer << "static " << tv->get_tensor().get_element_type().c_type_string() << " *"
<< tv->get_tensor().get_name() << ";\n";
writer << "static " << tv->get_tensor().get_element_type().c_type_string() << " "
<< tv->get_tensor().get_name() << "_cpu[" << c_value_strings.size()
<< "] =\n";
......@@ -440,8 +373,6 @@ using namespace std;
writer << emit_string_array(c_value_strings, 100 - writer.indent * 4);
writer.indent--;
writer << "\n};\n\n";
writer << "static " << tv->get_tensor().get_element_type().c_type_string() << " *"
<< tv->get_tensor().get_name() << ";\n";
m_variable_name_map[tv->get_tensor().get_name()] = tv->get_tensor().get_name();
}
}
......@@ -449,7 +380,7 @@ using namespace std;
// Add cudnn descriptor factory for descriptor management.
// After the cuDNN code emitted in gpu_emitter.cc is refactored
// into the CUDNNEmitter class, this can be removed.
writer << "static runtime::gpu::CUDNNDescriptors descriptors;\n";
writer << "static runtime::gpu::CUDNNDescriptors descriptors;\n\n";
writer << "// Declare all functions\n";
for (shared_ptr<Function> f : pass_manager.get_state().get_functions())
......@@ -457,85 +388,52 @@ using namespace std;
writer << "extern \"C\" void " << f->get_name() << "(void** inputs, void** outputs, "
<< "gpu::GPURuntimeContext* ctx);\n";
}
writer << "\n";
unordered_map<Node*, string> match_functions;
// This loop creates a collection of functions that are called more than once
// and emits them as globally callable functions.
// Ops implement the is_functionally_identical method.
unordered_map<string, string> match_function_map;
unordered_map<const Node*, string> node_function_map;
for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
{
set<string> output_names;
for (shared_ptr<Node> op : current_function->get_results())
{
shared_ptr<descriptor::TensorView> tv = op->get_output_tensor_view();
output_names.insert(tv->get_tensor().get_name());
}
const list<shared_ptr<Node>>& tmp = current_function->get_ordered_ops();
list<shared_ptr<Node>> tmp = current_function->get_ordered_ops();
if (tmp.size() < 2)
{
// Since we are comparing ops there must be at least two ops to proceed.
continue;
}
vector<shared_ptr<Node>> op_list{tmp.begin(), tmp.end()};
for (size_t i = 0; i < op_list.size() - 1; i++)
for (size_t i = 0; i < op_list.size(); i++)
{
if (op_list[i]->is_constant() || op_list[i]->is_parameter())
{
continue;
}
if (contains_key(match_functions, op_list[i].get()))
Node& node = *op_list[i];
auto handler = dispatcher.find(type_index(typeid(node)));
if (handler == dispatcher.end())
{
continue;
throw ngraph_error("Unhandled op during code generation : " + node.description());
}
string match_function = emit_op_as_function(node, "__f__");
string match_function_name;
if (!match_function_name.empty())
if (contains_key(match_function_map, match_function))
{
writer << "static void " << match_function_name << "(";
writer.indent++;
// Work around a compiler warning (*node inside typeid may have effects
// with shared pointers, which is fine here but clang doesn't like it.)
auto& n = *op_list[i];
auto handler = dispatcher.find(type_index(typeid(n)));
vector<GPU_TensorViewWrapper> in;
size_t arg_index = 0;
set<string> arg_names;
for (const descriptor::Input& input : n.get_inputs())
{
const descriptor::Output& output = input.get_output();
shared_ptr<descriptor::TensorView> tv = output.get_tensor_view();
GPU_TensorViewWrapper tvw{tv, "_arg" + to_string(arg_index)};
if (!contains(arg_names, tvw.get_name()))
{
arg_names.insert(tvw.get_name());
if (arg_index++ > 0)
{
writer << ",";
}
writer << "\n";
writer << tvw.get_type() << "* " << tvw.get_name();
}
in.push_back(tvw);
}
vector<GPU_TensorViewWrapper> out;
for (const descriptor::Output& output : n.get_outputs())
{
shared_ptr<descriptor::TensorView> tv = output.get_tensor_view();
GPU_TensorViewWrapper tvw{tv, "_out" + to_string(arg_index)};
if (arg_index++ > 0)
{
writer << ",";
}
writer << "\n";
writer << tvw.get_type() << "* " << tvw.get_name();
out.push_back(tvw);
}
writer.indent--;
writer << "\n)\n";
writer << "{\n";
writer.indent++;
handler->second(this, writer, &n, in, out);
writer.indent--;
writer << "}\n";
match_function_name = match_function_map[match_function];
}
else
{
auto offset = match_function.find("__f__");
string emitted_function = match_function;
match_function_name = "func_" + node.get_name();
emitted_function.replace(offset, 5, match_function_name);
match_function_map.insert({match_function, match_function_name});
writer << emitted_function << "\n";
}
node_function_map.insert({&node, match_function_name});
}
}
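A hedged illustration of the dedup scheme above (node names hypothetical): each op is emitted once as a function body under the placeholder name __f__, and that emitted string is the cache key, so ops with textually identical bodies share one generated function.

// string key_a = emit_op_as_function(add_0, "__f__");
// string key_b = emit_op_as_function(add_1, "__f__");
// key_a == key_b  =>  node_function_map[&add_0] == node_function_map[&add_1] == "func_Add_0"
// where "func_Add_0" comes from the first node's name at insertion time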
......@@ -704,12 +602,15 @@ using namespace std;
throw ngraph_error("Unhandled op during code generation : " + node->description());
}
vector<GPU_TensorViewWrapper> in;
vector<string> node_input_names;
vector<string> node_output_names;
for (const descriptor::Input& input : node->get_inputs())
{
const descriptor::Output& output = input.get_output();
shared_ptr<descriptor::TensorView> tv = output.get_tensor_view();
in.push_back(
GPU_TensorViewWrapper(tv, m_variable_name_map[tv->get_tensor().get_name()]));
node_input_names.emplace_back(tv->get_tensor().get_name());
}
vector<GPU_TensorViewWrapper> out;
for (const descriptor::Output& output : node->get_outputs())
......@@ -717,6 +618,18 @@ using namespace std;
shared_ptr<descriptor::TensorView> tv = output.get_tensor_view();
out.push_back(
GPU_TensorViewWrapper(tv, m_variable_name_map[tv->get_tensor().get_name()]));
node_output_names.emplace_back(tv->get_tensor().get_name());
}
// Emit function description comment
if (!node->is_parameter() && !node->is_constant())
{
writer << "\n// " << node->get_name() << "(";
vector<string> parameter_nodes = node_input_names;
parameter_nodes.insert(
parameter_nodes.end(), node_output_names.begin(), node_output_names.end());
writer << join(parameter_nodes);
writer << ")\n";
}
// Emit operation prologue
......@@ -730,13 +643,10 @@ using namespace std;
// Emit operation body
string func_name;
auto it = match_functions.find(node.get());
if (it != match_functions.end())
{
func_name = it->second;
}
func_name = node_function_map[node.get()];
if (func_name.empty())
{
//throw runtime_error("No matching function found for '" + node->get_name() + "'");
handler->second(this, writer, node.get(), in, out);
}
else
......@@ -750,6 +660,7 @@ using namespace std;
{
names.push_back(tv.get_name());
}
names.push_back("ctx");
writer << func_name << "(" << join(names) << ");\n";
}
......@@ -875,3 +786,117 @@ std::unique_ptr<runtime::gpu::GPURuntimeContext>& runtime::gpu::GPU_ExternalFunc
{
return m_ctx;
}
bool runtime::gpu::GPU_ExternalFunction::is_functionally_identical(
const Node& n1, const Node& n2, const unordered_map<const Node*, string>& node_cache) const
{
return node_cache.at(&n1) == node_cache.at(&n2);
}
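A hedged note: assuming node_cache holds each node's emitted source, "functionally identical" here reduces to string equality of the two cached bodies.

// node_cache[&n] = emit_op_as_function(n, "__f__");   // hypothetical fill
// is_functionally_identical(a, b, node_cache) == (node_cache.at(&a) == node_cache.at(&b))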
string runtime::gpu::GPU_ExternalFunction::emit_op_as_function(const Node& node,
const string& function_name)
{
codegen::CodeWriter writer;
writer << "static void " << function_name << "(";
writer.indent++;
// Work around a compiler warning (*node inside typeid may have effects
// with shared pointers, which is fine here but clang doesn't like it.)
auto handler = dispatcher.find(type_index(typeid(node)));
vector<GPU_TensorViewWrapper> in;
size_t arg_index = 0;
set<string> arg_names;
for (const descriptor::Input& input : node.get_inputs())
{
const descriptor::Output& output = input.get_output();
shared_ptr<descriptor::TensorView> tv = output.get_tensor_view();
GPU_TensorViewWrapper tvw{tv, "_arg" + to_string(arg_index)};
if (!contains(arg_names, tvw.get_name()))
{
arg_names.insert(tvw.get_name());
if (arg_index++ > 0)
{
writer << ",";
}
writer << "\n";
writer << tvw.get_type() << "* " << tvw.get_name();
}
in.push_back(tvw);
}
vector<GPU_TensorViewWrapper> out;
for (const descriptor::Output& output : node.get_outputs())
{
shared_ptr<descriptor::TensorView> tv = output.get_tensor_view();
GPU_TensorViewWrapper tvw{tv, "_out" + to_string(arg_index)};
if (arg_index++ > 0)
{
writer << ",";
}
writer << "\n";
writer << tvw.get_type() << "* " << tvw.get_name();
out.push_back(tvw);
}
writer << ",\ngpu::GPURuntimeContext* ctx";
writer.indent--;
writer << "\n)\n";
codegen::CodeWriter tmp_writer;
handler->second(this, tmp_writer, &node, in, out);
string body = tmp_writer.get_code();
if (body.size() > 0 && body[0] == '{')
{
// Body already surrounded by curly braces so don't add more
writer << body;
}
else
{
writer.block_begin();
writer << body;
writer.block_end();
}
string rc = writer.get_code();
if (function_name == "f")
{
rc = strip_comments(rc);
}
return rc;
}
string runtime::gpu::GPU_ExternalFunction::strip_comments(const string& s) const
{
stringstream out;
for (size_t i = 0; i < s.size(); i++)
{
if (i < s.size() - 2)
{
if (s[i] == '/' && s[i + 1] == '/')
{
// line comment
i += 2;
while (s[i] != '\n')
{
i++;
}
out << '\n';
}
else if (s[i] == '/' && s[i + 1] == '*')
{
// multi-line comment
i += 2;
while (!(s[i] == '*' && s[i + 1] == '/'))
{
i++;
}
i++;
}
else
{
out << s[i];
}
}
else
{
out << s[i];
}
}
return out.str();
}
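A hedged usage sketch (input literal hypothetical): strip_comments() drops a // comment up to its newline and elides a /* ... */ span entirely.

// strip_comments("int x; // tmp\n/* note */ int y;\n")
// returns "int x; \n int y;\n"

Note that emit_op_as_function applies it only when the function is emitted under the name "f".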
......@@ -83,6 +83,13 @@ namespace ngraph
const Node&,
const std::unordered_map<descriptor::TensorView*, std::vector<size_t>>&);
void release_function() { m_function = nullptr; }
std::string emit_op_as_function(const Node& node, const std::string& function_name);
std::string strip_comments(const std::string& s) const;
bool is_functionally_identical(
const Node& n1,
const Node& n2,
const std::unordered_map<const Node*, std::string>& node_cache) const;
std::unique_ptr<codegen::Compiler> m_compiler;
std::unique_ptr<codegen::ExecutionEngine> m_execution_engine;
bool m_emit_timing;
......
......@@ -21,7 +21,6 @@ divide_by_zero_float32
divide_by_zero_int32
dot_4d_5d_multi_axis_big_fp64_VERY_SLOW
dot_matrix_vector_int64
function_call
mkldnn_layouts
numeric_double_nan
numeric_float_inf
......
......@@ -35,6 +35,7 @@
#include "ngraph/op/relu.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/reshape_elimination.hpp"
......@@ -2197,3 +2198,45 @@ TEST(cpu_fusion, fuse_batch_dot_forward)
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
TEST(cpu_fusion, fuse_rnn_across_layer)
{
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<runtime::cpu::pass::MultiLayerRNNFusion>();
const string json_path =
file_util::path_join(SERIALIZED_ZOO, "mxnet/2rnn_layer_1timestep.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass_manager.run_passes(func);
size_t ref_rnn_count = 1;
auto rnn_count = count_ops_of_type<op::Rnn>(func);
EXPECT_EQ(ref_rnn_count, rnn_count);
}
TEST(cpu_fusion, fuse_rnn_across_2layer_1timestep)
{
const std::string file_name("mxnet/2rnn_layer_1timestep.json");
auto cpu_f = make_function(file_name);
auto int_f = make_function(file_name);
test::Uniform<float> rng(0.0f, 1.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
EXPECT_EQ(1, count_ops_of_type<op::Rnn>(cpu_f));
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
[
{
"name": "Function_0",
"ops": [
{
"element_type": "float",
"inputs": [],
"name": "Parameter_55",
"op": "Parameter",
"outputs": [
"Parameter_55_0"
],
"shape": [
400
]
},
{
"element_type": "float",
"inputs": [],
"name": "Parameter_54",
"op": "Parameter",
"outputs": [
"Parameter_54_0"
],
"shape": [
400,
100
]
},
{
"element_type": "float",
"inputs": [],
"name": "Parameter_47",
"op": "Parameter",
"outputs": [
"Parameter_47_0"
],
"shape": [
400
]
},
{
"element_type": "float",
"inputs": [],
"name": "Parameter_46",
"op": "Parameter",
"outputs": [
"Parameter_46_0"
],
"shape": [
400,
100
]
},
{
"element_type": "float",
"inputs": [],
"name": "Parameter_10",
"op": "Parameter",
"outputs": [
"Parameter_10_0"
],
"shape": [
400
]
},
{
"element_type": "float",
"inputs": [],
"name": "Parameter_9",
"op": "Parameter",
"outputs": [
"Parameter_9_0"
],
"shape": [
400,
100
]
},
{
"element_type": "float",
"inputs": [],
"name": "Parameter_2",
"op": "Parameter",
"outputs": [
"Parameter_2_0"
],
"shape": [
400
]
},
{
"element_type": "float",
"inputs": [],
"name": "Parameter_1",
"op": "Parameter",
"outputs": [
"Parameter_1_0"
],
"shape": [
400,
100
]
},
{
"element_type": "float",
"inputs": [],
"name": "Parameter_0",
"op": "Parameter",
"outputs": [
"Parameter_0_0"
],
"shape": [
10,
100
]
},
{
"element_type": "float",
"inputs": [],
"name": "Constant_52",
"op": "Constant",
"outputs": [
"Constant_52_0"
],
"shape": [],
"value": [
"0"
]
},
{
"element_type": "float",
"inputs": [],
"name": "Constant_7",
"op": "Constant",
"outputs": [
"Constant_7_0"
],
"shape": [],
"value": [
"0"
]
},
{
"element_type": "float",
"inputs": [],
"name": "Constant_34",
"op": "Constant",
"outputs": [
"Constant_34_0"
],
"shape": [],
"value": [
"1"
]
},
{
"element_type": "float",
"inputs": [],
"name": "Constant_30",
"op": "Constant",
"outputs": [
"Constant_30_0"
],
"shape": [],
"value": [
"0"
]
},
{
"element_type": "float",
"inputs": [],
"name": "Constant_24",
"op": "Constant",
"outputs": [
"Constant_24_0"
],
"shape": [],
"value": [
"1"
]
},
{
"element_type": "float",
"inputs": [],
"name": "Constant_17",
"op": "Constant",
"outputs": [
"Constant_17_0"
],
"shape": [],
"value": [
"1"
]
},
{
"element_type": "float",
"inputs": [],
"name": "Constant_79",
"op": "Constant",
"outputs": [
"Constant_79_0"
],
"shape": [],
"value": [
"1"
]
},
{
"element_type": "float",
"inputs": [],
"name": "Constant_75",
"op": "Constant",
"outputs": [
"Constant_75_0"
],
"shape": [],
"value": [
"0"
]
},
{
"element_type": "float",
"inputs": [],
"name": "Constant_69",
"op": "Constant",
"outputs": [
"Constant_69_0"
],
"shape": [],
"value": [
"1"
]
},
{
"element_type": "float",
"inputs": [],
"name": "Constant_62",
"op": "Constant",
"outputs": [
"Constant_62_0"
],
"shape": [],
"value": [
"1"
]
},
{
"axes": [
0
],
"inputs": [
"Parameter_55"
],
"name": "Broadcast_58",
"op": "Broadcast",
"outputs": [
"Broadcast_58_0"
],
"shape": [
10,
400
]
},
{
"input_order": [
1,
0
],
"inputs": [
"Parameter_54"
],
"name": "Reshape_56",
"op": "Reshape",
"output_shape": [
100,
400
],
"outputs": [
"Reshape_56_0"
]
},
{
"axes": [
0
],
"inputs": [
"Parameter_47"
],
"name": "Broadcast_50",
"op": "Broadcast",
"outputs": [
"Broadcast_50_0"
],
"shape": [
10,
400
]
},
{
"input_order": [
1,
0
],
"inputs": [
"Parameter_46"
],
"name": "Reshape_48",
"op": "Reshape",
"output_shape": [
100,
400
],
"outputs": [
"Reshape_48_0"
]
},
{
"axes": [
0
],
"inputs": [
"Parameter_10"
],
"name": "Broadcast_13",
"op": "Broadcast",
"outputs": [
"Broadcast_13_0"
],
"shape": [
10,
400
]
},
{
"input_order": [
1,
0
],
"inputs": [
"Parameter_9"
],
"name": "Reshape_11",
"op": "Reshape",
"output_shape": [
100,
400
],
"outputs": [
"Reshape_11_0"
]
},
{
"axes": [
0
],
"inputs": [
"Parameter_2"
],
"name": "Broadcast_5",
"op": "Broadcast",
"outputs": [
"Broadcast_5_0"
],
"shape": [
10,
400
]
},
{
"input_order": [
1,
0
],
"inputs": [
"Parameter_1"
],
"name": "Reshape_3",
"op": "Reshape",
"output_shape": [
100,
400
],
"outputs": [
"Reshape_3_0"
]
},
{
"axes": [
0,
1
],
"inputs": [
"Constant_52"
],
"name": "Broadcast_53",
"op": "Broadcast",
"outputs": [
"Broadcast_53_0"
],
"shape": [
10,
100
]
},
{
"axes": [
0,
1
],
"inputs": [
"Constant_7"
],
"name": "Broadcast_8",
"op": "Broadcast",
"outputs": [
"Broadcast_8_0"
],
"shape": [
10,
100
]
},
{
"axes": [
0,
1
],
"inputs": [
"Constant_34"
],
"name": "Broadcast_35",
"op": "Broadcast",
"outputs": [
"Broadcast_35_0"
],
"shape": [
10,
100
]
},
{
"axes": [
0,
1
],
"inputs": [
"Constant_30"
],
"name": "Broadcast_31",
"op": "Broadcast",
"outputs": [
"Broadcast_31_0"
],
"shape": [
10,
100
]
},
{
"axes": [
0,
1
],
"inputs": [
"Constant_24"
],
"name": "Broadcast_25",
"op": "Broadcast",
"outputs": [
"Broadcast_25_0"
],
"shape": [
10,
100
]
},
{
"axes": [
0,
1
],
"inputs": [
"Constant_17"
],
"name": "Broadcast_18",
"op": "Broadcast",
"outputs": [
"Broadcast_18_0"
],
"shape": [
10,
100
]
},
{
"axes": [
0,
1
],
"inputs": [
"Constant_79"
],
"name": "Broadcast_80",
"op": "Broadcast",
"outputs": [
"Broadcast_80_0"
],
"shape": [
10,
100
]
},
{
"axes": [
0,
1
],
"inputs": [
"Constant_75"
],
"name": "Broadcast_76",
"op": "Broadcast",
"outputs": [
"Broadcast_76_0"
],
"shape": [
10,
100
]
},
{
"axes": [
0,
1
],
"inputs": [
"Constant_69"
],
"name": "Broadcast_70",
"op": "Broadcast",
"outputs": [
"Broadcast_70_0"
],
"shape": [
10,
100
]
},
{
"axes": [
0,
1
],
"inputs": [
"Constant_62"
],
"name": "Broadcast_63",
"op": "Broadcast",
"outputs": [
"Broadcast_63_0"
],
"shape": [
10,
100
]
},
{
"inputs": [
"Parameter_0",
"Reshape_3"
],
"name": "Dot_4",
"op": "Dot",
"outputs": [
"Dot_4_0"
],
"reduction_axes_count": 1
},
{
"inputs": [
"Broadcast_53",
"Reshape_56"
],
"name": "Dot_57",
"op": "Dot",
"outputs": [
"Dot_57_0"
],
"reduction_axes_count": 1
},
{
"inputs": [
"Broadcast_8",
"Reshape_11"
],
"name": "Dot_12",
"op": "Dot",
"outputs": [
"Dot_12_0"
],
"reduction_axes_count": 1
},
{
"inputs": [
"Dot_4",
"Broadcast_5"
],
"name": "Add_6",
"op": "Add",
"outputs": [
"Add_6_0"
]
},
{
"inputs": [
"Dot_57",
"Broadcast_58"
],
"name": "Add_59",
"op": "Add",
"outputs": [
"Add_59_0"
]
},
{
"inputs": [
"Dot_12",
"Broadcast_13"
],
"name": "Add_14",
"op": "Add",
"outputs": [
"Add_14_0"
]
},
{
"inputs": [
"Add_6",
"Add_14"
],
"name": "Add_15",
"op": "Add",
"outputs": [
"Add_15_0"
]
},
{
"inputs": [
"Add_15"
],
"lower_bounds": [
0,
300
],
"name": "Slice_16",
"op": "Slice",
"outputs": [
"Slice_16_0"
],
"strides": [
1,
1
],
"upper_bounds": [
10,
400
]
},
{
"inputs": [
"Add_15"
],
"lower_bounds": [
0,
100
],
"name": "Slice_23",
"op": "Slice",
"outputs": [
"Slice_23_0"
],
"strides": [
1,
1
],
"upper_bounds": [
10,
200
]
},
{
"inputs": [
"Add_15"
],
"lower_bounds": [
0,
0
],
"name": "Slice_33",
"op": "Slice",
"outputs": [
"Slice_33_0"
],
"strides": [
1,
1
],
"upper_bounds": [
10,
100
]
},
{
"inputs": [
"Add_15"
],
"lower_bounds": [
0,
200
],
"name": "Slice_40",
"op": "Slice",
"outputs": [
"Slice_40_0"
],
"strides": [
1,
1
],
"upper_bounds": [
10,
300
]
},
{
"inputs": [
"Slice_16"
],
"name": "Negative_19",
"op": "Negative",
"outputs": [
"Negative_19_0"
]
},
{
"inputs": [
"Slice_23"
],
"name": "Negative_26",
"op": "Negative",
"outputs": [
"Negative_26_0"
]
},
{
"inputs": [
"Slice_33"
],
"name": "Negative_36",
"op": "Negative",
"outputs": [
"Negative_36_0"
]
},
{
"inputs": [
"Slice_40"
],
"name": "Tanh_41",
"op": "Tanh",
"outputs": [
"Tanh_41_0"
]
},
{
"inputs": [
"Negative_19"
],
"name": "Exp_20",
"op": "Exp",
"outputs": [
"Exp_20_0"
]
},
{
"inputs": [
"Negative_26"
],
"name": "Exp_27",
"op": "Exp",
"outputs": [
"Exp_27_0"
]
},
{
"inputs": [
"Negative_36"
],
"name": "Exp_37",
"op": "Exp",
"outputs": [
"Exp_37_0"
]
},
{
"inputs": [
"Broadcast_18",
"Exp_20"
],
"name": "Add_21",
"op": "Add",
"outputs": [
"Add_21_0"
]
},
{
"inputs": [
"Broadcast_25",
"Exp_27"
],
"name": "Add_28",
"op": "Add",
"outputs": [
"Add_28_0"
]
},
{
"inputs": [
"Broadcast_35",
"Exp_37"
],
"name": "Add_38",
"op": "Add",
"outputs": [
"Add_38_0"
]
},
{
"inputs": [
"Broadcast_18",
"Add_21"
],
"name": "Divide_22",
"op": "Divide",
"outputs": [
"Divide_22_0"
]
},
{
"inputs": [
"Broadcast_25",
"Add_28"
],
"name": "Divide_29",
"op": "Divide",
"outputs": [
"Divide_29_0"
]
},
{
"inputs": [
"Broadcast_35",
"Add_38"
],
"name": "Divide_39",
"op": "Divide",
"outputs": [
"Divide_39_0"
]
},
{
"inputs": [
"Divide_29",
"Broadcast_31"
],
"name": "Multiply_32",
"op": "Multiply",
"outputs": [
"Multiply_32_0"
]
},
{
"inputs": [
"Divide_39",
"Tanh_41"
],
"name": "Multiply_42",
"op": "Multiply",
"outputs": [
"Multiply_42_0"
]
},
{
"inputs": [
"Multiply_32",
"Multiply_42"
],
"name": "Add_43",
"op": "Add",
"outputs": [
"Add_43_0"
]
},
{
"inputs": [
"Add_43"
],
"name": "Tanh_44",
"op": "Tanh",
"outputs": [
"Tanh_44_0"
]
},
{
"inputs": [
"Divide_22",
"Tanh_44"
],
"name": "Multiply_45",
"op": "Multiply",
"outputs": [
"Multiply_45_0"
]
},
{
"inputs": [
"Multiply_45",
"Reshape_48"
],
"name": "Dot_49",
"op": "Dot",
"outputs": [
"Dot_49_0"
],
"reduction_axes_count": 1
},
{
"inputs": [
"Dot_49",
"Broadcast_50"
],
"name": "Add_51",
"op": "Add",
"outputs": [
"Add_51_0"
]
},
{
"inputs": [
"Add_51",
"Add_59"
],
"name": "Add_60",
"op": "Add",
"outputs": [
"Add_60_0"
]
},
{
"inputs": [
"Add_60"
],
"lower_bounds": [
0,
300
],
"name": "Slice_61",
"op": "Slice",
"outputs": [
"Slice_61_0"
],
"strides": [
1,
1
],
"upper_bounds": [
10,
400
]
},
{
"inputs": [
"Add_60"
],
"lower_bounds": [
0,
100
],
"name": "Slice_68",
"op": "Slice",
"outputs": [
"Slice_68_0"
],
"strides": [
1,
1
],
"upper_bounds": [
10,
200
]
},
{
"inputs": [
"Add_60"
],
"lower_bounds": [
0,
0
],
"name": "Slice_78",
"op": "Slice",
"outputs": [
"Slice_78_0"
],
"strides": [
1,
1
],
"upper_bounds": [
10,
100
]
},
{
"inputs": [
"Add_60"
],
"lower_bounds": [
0,
200
],
"name": "Slice_85",
"op": "Slice",
"outputs": [
"Slice_85_0"
],
"strides": [
1,
1
],
"upper_bounds": [
10,
300
]
},
{
"inputs": [
"Slice_61"
],
"name": "Negative_64",
"op": "Negative",
"outputs": [
"Negative_64_0"
]
},
{
"inputs": [
"Slice_68"
],
"name": "Negative_71",
"op": "Negative",
"outputs": [
"Negative_71_0"
]
},
{
"inputs": [
"Slice_78"
],
"name": "Negative_81",
"op": "Negative",
"outputs": [
"Negative_81_0"
]
},
{
"inputs": [
"Slice_85"
],
"name": "Tanh_86",
"op": "Tanh",
"outputs": [
"Tanh_86_0"
]
},
{
"inputs": [
"Negative_64"
],
"name": "Exp_65",
"op": "Exp",
"outputs": [
"Exp_65_0"
]
},
{
"inputs": [
"Negative_71"
],
"name": "Exp_72",
"op": "Exp",
"outputs": [
"Exp_72_0"
]
},
{
"inputs": [
"Negative_81"
],
"name": "Exp_82",
"op": "Exp",
"outputs": [
"Exp_82_0"
]
},
{
"inputs": [
"Broadcast_63",
"Exp_65"
],
"name": "Add_66",
"op": "Add",
"outputs": [
"Add_66_0"
]
},
{
"inputs": [
"Broadcast_70",
"Exp_72"
],
"name": "Add_73",
"op": "Add",
"outputs": [
"Add_73_0"
]
},
{
"inputs": [
"Broadcast_80",
"Exp_82"
],
"name": "Add_83",
"op": "Add",
"outputs": [
"Add_83_0"
]
},
{
"inputs": [
"Broadcast_63",
"Add_66"
],
"name": "Divide_67",
"op": "Divide",
"outputs": [
"Divide_67_0"
]
},
{
"inputs": [
"Broadcast_70",
"Add_73"
],
"name": "Divide_74",
"op": "Divide",
"outputs": [
"Divide_74_0"
]
},
{
"inputs": [
"Broadcast_80",
"Add_83"
],
"name": "Divide_84",
"op": "Divide",
"outputs": [
"Divide_84_0"
]
},
{
"inputs": [
"Divide_74",
"Broadcast_76"
],
"name": "Multiply_77",
"op": "Multiply",
"outputs": [
"Multiply_77_0"
]
},
{
"inputs": [
"Divide_84",
"Tanh_86"
],
"name": "Multiply_87",
"op": "Multiply",
"outputs": [
"Multiply_87_0"
]
},
{
"inputs": [
"Multiply_77",
"Multiply_87"
],
"name": "Add_88",
"op": "Add",
"outputs": [
"Add_88_0"
]
},
{
"inputs": [
"Add_88"
],
"name": "Tanh_89",
"op": "Tanh",
"outputs": [
"Tanh_89_0"
]
},
{
"inputs": [
"Add_88"
],
"name": "Result_94",
"op": "Result",
"outputs": [
"Result_94_0"
]
},
{
"inputs": [
"Divide_67",
"Tanh_89"
],
"name": "Multiply_90",
"op": "Multiply",
"outputs": [
"Multiply_90_0"
]
},
{
"inputs": [
"Multiply_90"
],
"name": "Result_93",
"op": "Result",
"outputs": [
"Result_93_0"
]
}
],
"parameters": [
"Parameter_0",
"Parameter_1",
"Parameter_2",
"Parameter_9",
"Parameter_10",
"Parameter_46",
"Parameter_47",
"Parameter_54",
"Parameter_55"
],
"result": [
"Result_93",
"Result_94"
]
}
]