Commit 1d08f073 authored by Pruthvi's avatar Pruthvi Committed by Adam Procter

LSTM fusion + RNN fusion across time slices for single layer (#826)

* - Added pattern matcher for LSTM cell

* WIP added support to replace lstm cell instead of subgraph

* WIP LSTM pattern matcher, fuses recurrent cells

* WIP added RNN CPU op

* WIP mkldnn emitter code for fprop RNN

* WIP RNN mkldnn integration
- Added mkldnn kernel for uni directional LSTM in the CPU emitter

* add a getter for root node

* recurrent graph rewrite

* fix perms, rename match_root -> get_match_root

* fix comp errors

* make match_root return the topmost match; fix tests

* - WIP GetOutputElement for handling multiple LSTM o/ps
- use RecurrentGraphRewrite for replacing node after matching LSTM cells

* WIP LSTM multi Output + debug prints

* moved LSTM fusion to cpu_fusion

* WIP added RNN superfused OP

* WIP towards RNN layer fusion

* WIP multiple output slicing RNN

* WIP RNN multiple o/ps fusion across layer

* WIP corrected input params for fused RNN OP

* concat corresponding params across different LSTMs to form inputs to RNN fused op

* i) Added test case for RNN kernel ii) runs without errors

* refactored and moved LSTM class to standalone file

* Rename RNN -> Rnn , LSTM -> Lstm

* WIP replace lstm slices to the consumer op

* Slicing works on multiple RNN layers

* fixed all bugs

* - Added CPU RNN Recurrent Fusion
- Added CPU LSTM fusion
- removed debug code
- style fix

* - Added support to compute src_iter and dst_iter instead of taking zero_memory_desc
- Added unit test to compute one LSTM cell

*  changed RNN op signature to accept number of states in basic unit of RNN(GRU/LSTM/ vanilla RNN) cell

* added sanity checks for RNN op

* Fixed issue related to patching the graph while replacing the RNN sliced outputs

* Fixed issue to feed the input symbols in the order X0, X1, ...Xt to the RNN op

* Added unit test for multi layer RNN fusion

* Removed debug statements

* i) Added multilayered serialized graph ii) fixed compilation issue

* Addressed PR comments

* i) WIP MKLDNN layout for RNN Op ii) added test case for INTERPRETER v/s CPU Rnn results

* - Fixed bug w.r.to src_layer feature size in rnn mkldnn emitter code
- Refactored cpu_fusion rnn test case

* merge origin/master with branch pruthvi/lstm_fusion

* style fix

* Added test case for multiple RNN layers

* i) make rnn as mkldnn op if it meets the constraints ii) assert if rnn is not mkldnn op

* fix unit test failure

* - Added support to reliably identify the hidden state and input symbols from the nodes collected by Pattern matcher
- Fixed failing unit tests

* style fix

* - removed "node type" dependency to replace the intermediate LSTM outputs

* Addressed PR comments

* Fix unit test

* - added MKLDNN emitter for LSTM op
- graph pass to concat LSTM input recurrent state tensors
- CPU layout assignment for LSTM Op
- Fixed bug in rnn/lstm unit tests
- made changes to use replace_output instead of replace_node for replacing matched graph nodes in LSTM/RNN fusion pass

(cherry picked from commit d16fc709265cc0a73e60c6d5f6d2878e7b908aca)

* style fix

* Renamed passes and style fixes
parent 6425a516
......@@ -216,11 +216,15 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND MKLDNN_INCLUDE_DIR)
runtime/cpu/op/conv_relu.cpp
runtime/cpu/op/convert_layout.cpp
runtime/cpu/op/sigmoid.cpp
runtime/cpu/op/rnn.cpp
runtime/cpu/op/lstm.cpp
runtime/cpu/op/matmul_bias.cpp
runtime/cpu/op/max_pool_with_indices.cpp
runtime/cpu/op/batch_norm_relu.cpp
runtime/cpu/pass/cpu_assignment.cpp
runtime/cpu/pass/cpu_fusion.cpp
runtime/cpu/pass/cpu_rnn_fusion.cpp
runtime/cpu/pass/cpu_concat_inputs.cpp
runtime/cpu/pass/cpu_workspace_insertion.cpp
runtime/cpu/pass/cpu_layout.cpp
runtime/cpu/pass/cpu_rnn_mat_fusion.cpp
......
......@@ -97,8 +97,10 @@
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
......@@ -368,6 +370,213 @@ namespace ngraph
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Lstm)
{
    // Emits the generated-code statements that run a fused-input Lstm node through an
    // MKLDNN RNN forward primitive: bind the five inputs and the two outputs to the
    // primitive's memory objects, then invoke the primitive.
    const auto* lstm_op = static_cast<const ngraph::op::Lstm*>(node);

    // Only the fused five-input form of the op can be handed to the MKLDNN kernel.
    const bool emittable = (args.size() == 5) && lstm_op->get_fused_inputs();
    if (!emittable)
    {
        throw ngraph_error(
            "Lstm op doesnt have the required number of inputs to emit MKLDNN kernel");
    }

    const int seq_len = lstm_op->get_src_sequence_length();
    const int n_directions = lstm_op->get_direction();
    const int n_layers = lstm_op->get_num_fused_layers();
    const int n_gates = lstm_op->get_gates_per_cell();
    const int n_states = lstm_op->get_num_cell_states();
    const int hidden_size = lstm_op->get_src_iter_feature_size();
    const int batch_size = lstm_op->get_batch_size();

    // A rank-2 ht output must carry the recurrent feature size in its second dimension.
    const auto& ht_shape = out[0].get_shape();
    if (ht_shape.size() == 2 && (ht_shape[1] != hidden_size))
    {
        throw ngraph_error(
            "input slc{ht} feature size is not equal to output dlc{ht} feature size ");
    }

    // The same constraint on the recurrent-state output is only enforced when the op
    // spans more than one time step.
    const auto& state_shape = out[1].get_shape();
    if (state_shape.size() == 2 && (state_shape[1] != hidden_size) &&
        lstm_op->get_num_timesteps() != 1)
    {
        throw ngraph_error(
            "input sic{ht_1|ct_1} feature size is not equal to output dlc{ht_1|ct_1} "
            "feature size ");
    }

    const int input_size = lstm_op->get_src_layer_feature_size();
    NGRAPH_DEBUG << "slc: " << input_size << " sic: " << hidden_size;
    NGRAPH_DEBUG << "batch_size: " << batch_size << " lstm_cell_n_states " << n_states
                 << " lstm_cell_n_gates: " << n_gates
                 << " src_sequence_length_max: " << seq_len;

    // Logical dimensions of every tensor the primitive touches.
    const mkldnn::memory::dims src_layer_dims = {seq_len, batch_size, input_size};
    const mkldnn::memory::dims src_iter_dims = {
        n_layers, n_directions, n_states, batch_size, hidden_size};
    const mkldnn::memory::dims wei_layer_dims = {
        n_layers, n_directions, input_size, n_gates, hidden_size};
    const mkldnn::memory::dims wei_iter_dims = {
        n_layers, n_directions, hidden_size, n_gates, hidden_size};
    const mkldnn::memory::dims bias_dims = {n_layers, n_directions, n_gates, hidden_size};
    const mkldnn::memory::dims dst_layer_dims = {seq_len, batch_size, hidden_size};
    const mkldnn::memory::dims dst_iter_dims = {
        n_layers, n_directions, n_states, batch_size, hidden_size};

    // Memory descriptors used by the user; every tensor is f32.
    auto f32_desc = [](const mkldnn::memory::dims& dims, mkldnn::memory::format fmt) {
        return mkldnn::memory::desc(dims, mkldnn::memory::data_type::f32, fmt);
    };
    auto src_layer_md = f32_desc(src_layer_dims, mkldnn::memory::format::tnc);
    auto src_iter_md = f32_desc(src_iter_dims, mkldnn::memory::format::ldsnc);
    auto wei_layer_md = f32_desc(wei_layer_dims, mkldnn::memory::format::ldigo);
    auto wei_iter_md = f32_desc(wei_iter_dims, mkldnn::memory::format::ldigo);
    auto bias_md = f32_desc(bias_dims, mkldnn::memory::format::ldgo);
    auto dst_layer_md = f32_desc(dst_layer_dims, mkldnn::memory::format::tnc);
    auto dst_iter_md = f32_desc(dst_iter_dims, mkldnn::memory::format::ldsnc);

    auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
    auto lstm_index = mkldnn_emitter->build_rnn_forward(src_layer_md,
                                                        src_iter_md,
                                                        wei_layer_md,
                                                        wei_iter_md,
                                                        bias_md,
                                                        dst_layer_md,
                                                        dst_iter_md);
    auto& deps = mkldnn_emitter->get_primitive_deps(lstm_index);

    // deps[0..4] are the input memories (same order as args); deps[5..6] the outputs.
    const size_t n_inputs = 5;
    const size_t n_outputs = 2;
    for (size_t i = 0; i < n_inputs; ++i)
    {
        writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[i]) << ", "
               << args[i].get_name() << ");\n";
    }
    for (size_t i = 0; i < n_outputs; ++i)
    {
        writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[n_inputs + i])
               << ", " << out[i].get_name() << ");\n";
    }

    writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, " << to_string(lstm_index)
           << ");\n";
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Rnn)
{
    // Emits the generated-code statements that run a fused Rnn node through an MKLDNN
    // RNN forward primitive: bind inputs/outputs to the primitive's memory objects and
    // invoke it.
    const auto* rnn_op = static_cast<const ngraph::op::Rnn*>(node);

    const int seq_len = rnn_op->get_src_sequence_length();
    const int n_directions = rnn_op->get_direction();
    const int n_layers = rnn_op->get_num_fused_layers();
    const int n_gates = rnn_op->get_gates_per_cell();
    const int n_states = rnn_op->get_num_cell_states();
    const int hidden_size = rnn_op->get_src_iter_feature_size();
    const int batch_size = rnn_op->get_batch_size();

    // Rank-2 outputs must carry the recurrent feature size in their second dimension.
    const auto& ht_shape = out[0].get_shape();
    if (ht_shape.size() == 2 && (ht_shape[1] != hidden_size))
    {
        throw ngraph_error(
            "input slc{ht} feature size is not equal to output dlc{ht} feature size ");
    }
    const auto& state_shape = out[1].get_shape();
    if (state_shape.size() == 2 && (state_shape[1] != hidden_size))
    {
        throw ngraph_error(
            "input sic{ht_1|ct_1} feature size is not equal to output dlc{ht_1|ct_1} "
            "feature size ");
    }

    const int input_size = rnn_op->get_src_layer_feature_size();
    NGRAPH_DEBUG << "slc: " << input_size << " sic: " << hidden_size;
    NGRAPH_DEBUG << "batch_size: " << batch_size << " rnn_cell_n_states " << n_states
                 << " rnn_cell_n_gates: " << n_gates
                 << " src_sequence_length_max: " << seq_len;

    // Logical dimensions of every tensor the primitive touches.
    const mkldnn::memory::dims src_layer_dims = {seq_len, batch_size, input_size};
    const mkldnn::memory::dims src_iter_dims = {
        n_layers, n_directions, n_states, batch_size, hidden_size};
    const mkldnn::memory::dims wei_layer_dims = {
        n_layers, n_directions, input_size, n_gates, hidden_size};
    const mkldnn::memory::dims wei_iter_dims = {
        n_layers, n_directions, hidden_size, n_gates, hidden_size};
    const mkldnn::memory::dims bias_dims = {n_layers, n_directions, n_gates, hidden_size};
    const mkldnn::memory::dims dst_layer_dims = {seq_len, batch_size, hidden_size};
    const mkldnn::memory::dims dst_iter_dims = {
        n_layers, n_directions, n_states, batch_size, hidden_size};

    // Memory descriptors used by the user; every tensor is f32.
    auto f32_desc = [](const mkldnn::memory::dims& dims, mkldnn::memory::format fmt) {
        return mkldnn::memory::desc(dims, mkldnn::memory::data_type::f32, fmt);
    };
    auto src_layer_md = f32_desc(src_layer_dims, mkldnn::memory::format::tnc);
    auto src_iter_md = f32_desc(src_iter_dims, mkldnn::memory::format::ldsnc);
    auto wei_layer_md = f32_desc(wei_layer_dims, mkldnn::memory::format::ldigo);
    auto wei_iter_md = f32_desc(wei_iter_dims, mkldnn::memory::format::ldigo);
    auto bias_md = f32_desc(bias_dims, mkldnn::memory::format::ldgo);
    auto dst_layer_md = f32_desc(dst_layer_dims, mkldnn::memory::format::tnc);
    auto dst_iter_md = f32_desc(dst_iter_dims, mkldnn::memory::format::ldsnc);

    auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
    auto rnn_index = mkldnn_emitter->build_rnn_forward(src_layer_md,
                                                       src_iter_md,
                                                       wei_layer_md,
                                                       wei_iter_md,
                                                       bias_md,
                                                       dst_layer_md,
                                                       dst_iter_md);
    auto& deps = mkldnn_emitter->get_primitive_deps(rnn_index);

    // deps[0..4] are the input memories (same order as args); deps[5..6] the outputs.
    const size_t n_inputs = 5;
    const size_t n_outputs = 2;
    for (size_t i = 0; i < n_inputs; ++i)
    {
        writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[i]) << ", "
               << args[i].get_name() << ");\n";
    }
    for (size_t i = 0; i < n_outputs; ++i)
    {
        writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[n_inputs + i])
               << ", " << out[i].get_name() << ");\n";
    }

    writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, " << to_string(rnn_index)
           << ");\n";
}
void CPU_Emitter::emitBatchNorm(CPU_ExternalFunction* external_function,
codegen::CodeWriter& writer,
const ngraph::Node* node,
......
......@@ -120,13 +120,17 @@
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
#include "ngraph/runtime/cpu/pass/cpu_assignment.hpp"
#include "ngraph/runtime/cpu/pass/cpu_concat_inputs.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_layout.hpp"
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
#include "ngraph/runtime/cpu/pass/cpu_rnn_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_shuffle_folding.hpp"
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
......@@ -274,6 +278,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::BatchNorm), &runtime::cpu::CPU_Emitter::emit<op::BatchNorm>},
{TI(ngraph::op::BatchNormRelu), &runtime::cpu::CPU_Emitter::emit<op::BatchNormRelu>},
{TI(ngraph::op::BatchNormBackprop), &runtime::cpu::CPU_Emitter::emit<op::BatchNormBackprop>},
{TI(ngraph::op::Lstm), &runtime::cpu::CPU_Emitter::emit<op::Lstm>},
{TI(ngraph::op::MaxPoolBackprop), &runtime::cpu::CPU_Emitter::emit<op::MaxPoolBackprop>},
{TI(ngraph::op::MaxPoolWithIndicesBackprop),
&runtime::cpu::CPU_Emitter::emit<op::MaxPoolWithIndicesBackprop>},
......@@ -282,6 +287,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Min), &runtime::cpu::CPU_Emitter::emit<op::Min>},
{TI(ngraph::op::Relu), &runtime::cpu::CPU_Emitter::emit<op::Relu>},
{TI(ngraph::op::ReluBackprop), &runtime::cpu::CPU_Emitter::emit<op::ReluBackprop>},
{TI(ngraph::op::Rnn), &runtime::cpu::CPU_Emitter::emit<op::Rnn>},
{TI(ngraph::op::Sigmoid), &runtime::cpu::CPU_Emitter::emit<op::Sigmoid>},
{TI(ngraph::op::Softmax), &runtime::cpu::CPU_Emitter::emit<op::Softmax>},
{TI(ngraph::op::SigmoidBackprop), &runtime::cpu::CPU_Emitter::emit<op::SigmoidBackprop>},
......@@ -317,6 +323,9 @@ void runtime::cpu::CPU_ExternalFunction::compile()
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::NopElimination>();
pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
pass_manager.register_pass<runtime::cpu::pass::ConcatInputs>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.register_pass<ngraph::pass::CoreFusion>();
......
......@@ -764,6 +764,59 @@ size_t MKLDNNEmitter::build_batchnorm_backward(const mkldnn::memory::desc& weigh
return batchnorm_index;
}
size_t MKLDNNEmitter::build_rnn_forward(const mkldnn::memory::desc& src_layer_desc,
                                        const mkldnn::memory::desc& src_iter_desc,
                                        const mkldnn::memory::desc& weights_layer_desc,
                                        const mkldnn::memory::desc& weights_iter_desc,
                                        const mkldnn::memory::desc& bias_desc,
                                        const mkldnn::memory::desc& dst_layer_desc,
                                        const mkldnn::memory::desc& dst_iter_desc)
{
    // Builds an MKLDNN forward-inference RNN primitive (vanilla LSTM cell, unidirectional
    // left-to-right) over the given descriptors and returns its primitive index.

    // One memory primitive per tensor. The creation order is kept as-is because the
    // indices returned here are what gets recorded as the primitive's dependencies.
    const size_t src_layer_idx = build_memory_primitive(src_layer_desc);
    const size_t src_iter_idx = build_memory_primitive(src_iter_desc);
    const size_t wei_layer_idx = build_memory_primitive(weights_layer_desc);
    const size_t wei_iter_idx = build_memory_primitive(weights_iter_desc);
    const size_t bias_idx = build_memory_primitive(bias_desc);
    const size_t dst_layer_idx = build_memory_primitive(dst_layer_desc);
    const size_t dst_iter_idx = build_memory_primitive(dst_iter_desc);

    // TODO: figure out the role of workspace; a null memory is passed for now.
    auto workspace = mkldnn::null_memory(mkldnn_utils::global_cpu_engine);

    mkldnn::rnn_cell::desc lstm_cell(mkldnn::algorithm::vanilla_lstm);
    mkldnn::rnn_forward::desc fwd_desc(mkldnn::prop_kind::forward_inference,
                                       lstm_cell,
                                       mkldnn::rnn_direction::unidirectional_left2right,
                                       src_layer_desc,
                                       src_iter_desc,
                                       weights_layer_desc,
                                       weights_iter_desc,
                                       bias_desc,
                                       dst_layer_desc,
                                       dst_iter_desc);
    mkldnn::rnn_forward::primitive_desc fwd_prim_desc(fwd_desc,
                                                      mkldnn_utils::global_cpu_engine);

    const size_t rnn_index = insert_primitive(
        new mkldnn::rnn_forward(fwd_prim_desc,
                                mkldnn::primitive::at(*m_mkldnn_primitives[src_layer_idx]),
                                mkldnn::primitive::at(*m_mkldnn_primitives[src_iter_idx]),
                                mkldnn::primitive::at(*m_mkldnn_primitives[wei_layer_idx]),
                                mkldnn::primitive::at(*m_mkldnn_primitives[wei_iter_idx]),
                                mkldnn::primitive::at(*m_mkldnn_primitives[bias_idx]),
                                static_cast<mkldnn::memory>(*m_mkldnn_primitives[dst_layer_idx]),
                                static_cast<mkldnn::memory>(*m_mkldnn_primitives[dst_iter_idx]),
                                static_cast<mkldnn::memory>(workspace)));

    // Dependency order: five inputs first, then the two outputs.
    m_primitive_deps[rnn_index] = {src_layer_idx,
                                   src_iter_idx,
                                   wei_layer_idx,
                                   wei_iter_idx,
                                   bias_idx,
                                   dst_layer_idx,
                                   dst_iter_idx};
    return rnn_index;
}
size_t MKLDNNEmitter::build_concat(const std::vector<mkldnn::memory::desc>& inputs_data_desc,
const mkldnn::memory::desc& result_desc,
const size_t concat_dim)
......
......@@ -202,6 +202,14 @@ namespace ngraph
const mkldnn::memory::desc& dweights_desc,
const double eps);
// Builds an MKLDNN RNN forward primitive (forward inference, vanilla LSTM cell,
// unidirectional left-to-right) from the given memory descriptors.
// A memory primitive is created for each of the seven descriptors, and their indices
// are recorded as the new primitive's dependencies in this order:
// src_layer, src_iter, weights_layer, weights_iter, bias, dst_layer, dst_iter.
// Returns the index of the inserted RNN primitive.
size_t build_rnn_forward(const mkldnn::memory::desc& src_layer_desc,
const mkldnn::memory::desc& src_iter_desc,
const mkldnn::memory::desc& weights_layer_desc,
const mkldnn::memory::desc& weights_iter_desc,
const mkldnn::memory::desc& bias_desc,
const mkldnn::memory::desc& dst_layer_desc,
const mkldnn::memory::desc& dst_iter_desc);
size_t build_concat(const std::vector<mkldnn::memory::desc>& inputs_data_desc,
const mkldnn::memory::desc& result_desc,
const size_t concat_dim);
......
......@@ -23,6 +23,7 @@
#include "ngraph/op/add.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/relu.hpp"
......@@ -46,6 +47,7 @@ static const std::unordered_set<std::type_index> s_op_registry{
TI(ngraph::op::AvgPoolBackprop),
TI(ngraph::op::BatchNorm),
TI(ngraph::op::BatchNormBackprop),
TI(ngraph::op::Concat),
TI(ngraph::op::Convolution),
TI(ngraph::op::ConvolutionBackpropData),
TI(ngraph::op::ConvolutionBackpropFilters),
......@@ -113,6 +115,10 @@ static const std::map<memory::format, const std::string> s_mkldnn_format_string_
{memory::format::Ohwi8o, "memory::format::Ohwi8o"},
{memory::format::Ohwi16o, "memory::format::Ohwi16o"},
{memory::format::OhIw16o4i, "memory::format::OhIw16o4i"},
{memory::format::tnc, "memory::format::tnc"},
{memory::format::ldsnc, "memory::format::ldsnc"},
{memory::format::ldigo, "memory::format::ldigo"},
{memory::format::ldgo, "memory::format::ldgo"},
};
static const std::set<memory::format> s_filter_formats{
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/log.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
// Clones this node onto a new set of arguments. The un-fused form of the op takes
// seven inputs and the fused form takes five; any other count is rejected.
shared_ptr<Node> op::Lstm::copy_with_new_args(const NodeVector& new_args) const
{
    const size_t expected_arg_count = m_fused_inputs ? 5 : 7;
    if (new_args.size() != expected_arg_count)
    {
        throw ngraph_error("Incorrect number of new arguments");
    }

    if (m_fused_inputs)
    {
        // src_layer, src_iter, weights_layer, weights_iter, bias
        return make_shared<Lstm>(
            new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4));
    }

    // xt, i2h weights, ht_1, h2h weights, i2h bias, h2h bias, ct_1
    return make_shared<Lstm>(new_args.at(0),
                             new_args.at(1),
                             new_args.at(2),
                             new_args.at(3),
                             new_args.at(4),
                             new_args.at(5),
                             new_args.at(6));
}
// Builds the un-fused, seven-input LSTM cell node (a single time step).
//
// Inputs: xt, i2h weights, ht_1, h2h weights, i2h bias, h2h bias, ct_1.
// Outputs: [0] has the element type and shape of ht_1, [1] those of ct_1.
//
// Throws ngraph_error when the ranks of an input and its weights differ, when
// input_xt_1 is not rank 2, when hidden_state_ht_1 has rank below 2, when a bias is
// rank 0 or its leading dimension differs from that of its weights, or when the
// inputs do not all share one element type.
op::Lstm::Lstm(std::shared_ptr<Node> input_xt_1,
               std::shared_ptr<Node> i2h_weights,
               std::shared_ptr<Node> hidden_state_ht_1,
               std::shared_ptr<Node> h2h_weights,
               std::shared_ptr<Node> i2h_bias,
               std::shared_ptr<Node> h2h_bias,
               std::shared_ptr<Node> cell_state_ct_1)
    : RequiresTensorViewArgs("Lstm",
                             {input_xt_1,
                              i2h_weights,
                              hidden_state_ht_1,
                              h2h_weights,
                              i2h_bias,
                              h2h_bias,
                              cell_state_ct_1})
    , m_output_tensor_shape(hidden_state_ht_1->get_shape())
    , m_output_cell_shape(cell_state_ct_1->get_shape())
    , m_num_timesteps(1)
    , m_num_gates_per_cell(4)
    , m_src_sequence_length(1)
    // The three shape-derived sizes are filled in below, once the ranks have been
    // validated; indexing the shapes here would read out of bounds on low-rank inputs.
    , m_batch_size(0)
    , m_src_layer_feature_size(0)
    , m_src_iter_feature_size(0)
    , m_num_cell_states(2)
    , m_direction(1)
    , m_num_fused_layers(1)
    , m_fused_inputs(false)
{
    const Shape& input_shape = input_xt_1->get_shape();
    const Shape& hidden_shape = hidden_state_ht_1->get_shape();

    if (input_shape.size() != i2h_weights->get_shape().size())
    {
        throw ngraph_error("input_xt_1 and i2h weights size dont match");
    }
    if (hidden_shape.size() != h2h_weights->get_shape().size())
    {
        throw ngraph_error("hidden_state_ht_1 and h2h weights size dont match");
    }
    if (input_shape.size() != 2)
    {
        throw ngraph_error("input_xt_1 doesnt have a rank 2");
    }
    if (hidden_shape.size() < 2)
    {
        throw ngraph_error("hidden_state_ht_1 doesnt have a feature dimension");
    }

    // Ranks are known to be sufficient from here on, so the indexing is safe.
    m_batch_size = static_cast<int>(input_shape[0]);
    m_src_layer_feature_size = static_cast<int>(input_shape[1]);
    m_src_iter_feature_size = static_cast<int>(hidden_shape[1]);

    const size_t expected_elements = static_cast<size_t>(m_src_sequence_length) *
                                     static_cast<size_t>(m_batch_size) *
                                     static_cast<size_t>(m_src_layer_feature_size);
    if (shape_size(input_shape) != expected_elements)
    {
        NGRAPH_DEBUG << "shape_size: " << shape_size(input_shape);
        throw ngraph_error("input_xt_1 size is not equal t*n*c");
    }

    // The weights have rank >= 2 at this point (they match their inputs); a rank-0
    // bias cannot be compared against them.
    if (i2h_bias->get_shape().empty() || h2h_bias->get_shape().empty() ||
        i2h_bias->get_shape()[0] != i2h_weights->get_shape()[0] ||
        h2h_bias->get_shape()[0] != h2h_weights->get_shape()[0])
    {
        throw ngraph_error("bias and weights_shape are not compatible");
    }

    const auto et = input_xt_1->get_element_type();
    for (auto& lstm_input : get_arguments())
    {
        if (lstm_input->get_element_type() != et)
        {
            throw ngraph_error("all rnn inputs must have the same element type");
        }
    }

    add_output(hidden_state_ht_1->get_element_type(), hidden_shape);
    add_output(cell_state_ct_1->get_element_type(), cell_state_ct_1->get_shape());
}
// Builds the fused, five-input LSTM node (a single time step) — the form the MKLDNN
// emitter accepts.
//
// Inputs: src_layer, src_iter, weights_layer, weights_iter, bias.
// Outputs: [0] Shape{num_timesteps * batch, src_iter feature size},
//          [1] Shape{num_cell_states * batch, src_iter feature size},
//          both with src_layer's element type.
//
// Throws ngraph_error when the ranks of an input and its weights differ, when
// src_layer is not rank 2, when src_iter has rank below 2, when the bias is rank 0 or
// its leading dimension differs from that of either weights tensor, or when the
// inputs do not all share one element type.
//
// NOTE(review): m_output_tensor_shape mirrors src_layer's shape, while output 0 is
// registered with the src_iter feature size — confirm this is intended.
op::Lstm::Lstm(std::shared_ptr<Node> src_layer,
               std::shared_ptr<Node> src_iter,
               std::shared_ptr<Node> weights_layer,
               std::shared_ptr<Node> weights_iter,
               std::shared_ptr<Node> bias)
    : RequiresTensorViewArgs("Lstm", {src_layer, src_iter, weights_layer, weights_iter, bias})
    , m_output_tensor_shape(src_layer->get_shape())
    , m_output_cell_shape(src_iter->get_shape())
    , m_num_timesteps(1)
    , m_num_gates_per_cell(4)
    , m_src_sequence_length(1)
    // The three shape-derived sizes are filled in below, once the ranks have been
    // validated; indexing the shapes here would read out of bounds on low-rank inputs.
    , m_batch_size(0)
    , m_src_layer_feature_size(0)
    , m_src_iter_feature_size(0)
    , m_num_cell_states(2)
    , m_direction(1)
    , m_num_fused_layers(1)
    , m_fused_inputs(true)
{
    const Shape& src_layer_shape = src_layer->get_shape();
    const Shape& src_iter_shape = src_iter->get_shape();

    if (src_layer_shape.size() != weights_layer->get_shape().size())
    {
        throw ngraph_error("src_layer and i2h weights size dont match");
    }
    if (src_iter_shape.size() != weights_iter->get_shape().size())
    {
        throw ngraph_error("src_iter and h2h weights size dont match");
    }
    if (src_layer_shape.size() != 2)
    {
        throw ngraph_error("src_layer doesnt have a rank 2");
    }
    if (src_iter_shape.size() < 2)
    {
        throw ngraph_error("src_iter doesnt have a feature dimension");
    }

    // Ranks are known to be sufficient from here on, so the indexing is safe.
    // m_num_timesteps is fixed to 1 by the initializer list above.
    m_batch_size = static_cast<int>(src_layer_shape[0] / m_num_timesteps);
    m_src_layer_feature_size = static_cast<int>(src_layer_shape[1]);
    m_src_iter_feature_size = static_cast<int>(src_iter_shape[1]);

    const size_t expected_elements = static_cast<size_t>(m_src_sequence_length) *
                                     static_cast<size_t>(m_batch_size) *
                                     static_cast<size_t>(m_src_layer_feature_size);
    if (shape_size(src_layer_shape) != expected_elements)
    {
        NGRAPH_DEBUG << "shape_size: " << shape_size(src_layer_shape);
        throw ngraph_error("src_layer size is not equal t*n*c");
    }

    // Both weights have rank >= 2 at this point (they match their inputs); a rank-0
    // bias cannot be compared against them.
    if (bias->get_shape().empty() || bias->get_shape()[0] != weights_layer->get_shape()[0] ||
        bias->get_shape()[0] != weights_iter->get_shape()[0])
    {
        throw ngraph_error("bias and weights_shape are not compatible");
    }

    const auto et = src_layer->get_element_type();
    for (auto& rnn_input : get_arguments())
    {
        if (rnn_input->get_element_type() != et)
        {
            throw ngraph_error("all rnn inputs must have the same element type");
        }
    }

    add_output(src_layer->get_element_type(),
               Shape{static_cast<size_t>(m_num_timesteps * m_batch_size),
                     static_cast<size_t>(m_src_iter_feature_size)});
    add_output(src_layer->get_element_type(),
               Shape{static_cast<size_t>(m_num_cell_states * m_batch_size),
                     static_cast<size_t>(m_src_iter_feature_size)});
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/op/util/requires_tensor_view_args.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
    namespace op
    {
        // CPU-backend LSTM cell op. It exists in two forms, selected by constructor:
        // an un-fused seven-input form used while matching/fusing LSTM cells, and a
        // fused five-input form that the MKLDNN emitter can consume.
        class Lstm : public util::RequiresTensorViewArgs
        {
        public:
            // INPUTS:
            // [0] - xt, input tensor of layout TNC, Shape{sequence length*batch_size, feature_size}
            // [1] - initializer for the input weights matrix, used for the linear transformation of the inputs.
            // [2] - ht_1, hidden state of shape (batch_size, feature_size)
            // [3] - initializer for the recurrent weights matrix, used for the linear transformation of the recurrent state.
            // [4] - initializer for the bias vector w.r.t. the inputs.
            // [5] - initializer for the bias vector w.r.t. the hidden state.
            // [6] - ct_1, cell state of shape (batch_size, feature_size)
            // OUTPUT VALUE: A tuple with the following structure:
            // [0] - ht, output tensor with shape (sequence_length*batch_size, num_hidden).
            // [1] - ct, output recurrent state tensor with the same shape as the cell state.
            // This version of the LSTM op is only used to simplify recurrent RNN cell (LSTM) fusion
            // across horizontal time steps. It does not have MKLDNN emitter code.
            Lstm(std::shared_ptr<Node> input_xt_1,
                 std::shared_ptr<Node> i2h_weights,
                 std::shared_ptr<Node> hidden_state_ht_1,
                 std::shared_ptr<Node> h2h_weights,
                 std::shared_ptr<Node> i2h_bias,
                 std::shared_ptr<Node> h2h_bias,
                 std::shared_ptr<Node> cell_state_ct_1);

            // INPUTS:
            // [0] - {Xt} input tensor of layout TNC, Shape{sequence length*batch_size, feature_size}
            // [1] - recurrent state tensors {ht_1 | ct_1} of Shape{sequence length*batch_size, feature_size}
            // [2] - initializer for the input weights matrix, used for the linear transformation of the inputs.
            // [3] - initializer for the recurrent weights matrix, used for the linear transformation of the recurrent state.
            // [4] - initializer for the bias vector w.r.t. inputs + hidden state (ibh_bias + hbh_bias)
            // OUTPUT VALUE: A tuple with the following structure:
            // [0] - ht, output tensor with shape (sequence_length*batch_size, num_hidden).
            // [1] - {ht | ct} output recurrent state tensor with the same shape as the states.
            // This version of the LSTM op supports MKLDNN emitter code; it can be used standalone for
            // computing an RNN without fusing RNN cells (LSTMs) across time steps.
            Lstm(std::shared_ptr<Node> src_layer,
                 std::shared_ptr<Node> src_iter,
                 std::shared_ptr<Node> weights_layer,
                 std::shared_ptr<Node> weights_iter,
                 std::shared_ptr<Node> bias);

            Shape get_output_tensor_shape() const { return m_output_tensor_shape; }
            Shape get_output_cell_shape() const { return m_output_cell_shape; }
            int get_num_timesteps() const { return m_num_timesteps; }
            int get_src_sequence_length() const { return m_src_sequence_length; }
            int get_gates_per_cell() const { return m_num_gates_per_cell; }
            int get_batch_size() const { return m_batch_size; }
            int get_src_layer_feature_size() const { return m_src_layer_feature_size; }
            int get_src_iter_feature_size() const { return m_src_iter_feature_size; }
            int get_num_cell_states() const { return m_num_cell_states; }
            int get_direction() const { return m_direction; }
            int get_num_fused_layers() const { return m_num_fused_layers; }
            // True when the node was built with the fused five-input constructor.
            // Returns bool (the type of the underlying flag); callers that stored the
            // result in an int still work through the implicit conversion.
            bool get_fused_inputs() const { return m_fused_inputs; }
            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;

        private:
            Shape m_output_tensor_shape;
            Shape m_output_cell_shape;
            int m_num_timesteps;
            int m_num_gates_per_cell;
            int m_src_sequence_length;
            int m_batch_size;
            int m_src_layer_feature_size;
            int m_src_iter_feature_size;
            int m_num_cell_states;
            int m_direction;
            int m_num_fused_layers;
            bool m_fused_inputs; // True if node gets fused inputs/weights
        };
    }
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/log.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
// Clones this node onto five new arguments, carrying over every configuration value
// stored on the current node.
shared_ptr<Node> op::Rnn::copy_with_new_args(const NodeVector& new_args) const
{
    const bool arity_ok = (new_args.size() == 5);
    if (!arity_ok)
    {
        throw ngraph_error("Incorrect number of new arguments");
    }

    const auto& src_layer = new_args[0];
    const auto& src_iter = new_args[1];
    const auto& weights_layer = new_args[2];
    const auto& weights_iter = new_args[3];
    const auto& bias = new_args[4];

    return make_shared<Rnn>(src_layer,
                            src_iter,
                            weights_layer,
                            weights_iter,
                            bias,
                            m_num_timesteps,
                            m_num_gates_per_cell,
                            m_src_sequence_length,
                            m_src_layer_feature_size,
                            m_src_iter_feature_size,
                            m_num_cell_states,
                            m_direction,
                            m_num_fused_layers);
}
// Builds the fused RNN node from five tensors plus the explicit layer configuration.
//
// Inputs: src_layer, src_iter, weights_layer, weights_iter, bias.
// Outputs: [0] Shape{num_timesteps * batch, src_iter_feature_size},
//          [1] Shape{num_cell_states * batch, src_iter_feature_size},
//          both with src_layer's element type.
// The batch size is derived as src_layer's leading dimension divided by num_timesteps.
//
// Throws ngraph_error when the ranks of an input and its weights differ, when
// src_layer is not rank 2, when num_timesteps is not positive, when src_layer's
// element count differs from src_sequence_length * batch * src_layer_feature_size,
// when the bias or weights_iter is rank 0 or their leading dimensions disagree, or
// when the inputs do not all share one element type.
op::Rnn::Rnn(std::shared_ptr<Node> src_layer,
             std::shared_ptr<Node> src_iter,
             std::shared_ptr<Node> weights_layer,
             std::shared_ptr<Node> weights_iter,
             std::shared_ptr<Node> bias,
             const int num_timesteps,
             const int num_gates_per_cell,
             const int src_sequence_length,
             const int src_layer_feature_size,
             const int src_iter_feature_size,
             const int num_cell_states,
             const int direction,
             const int num_fused_layers)
    : RequiresTensorViewArgs("Rnn", {src_layer, src_iter, weights_layer, weights_iter, bias})
    , m_num_timesteps(num_timesteps)
    , m_num_gates_per_cell(num_gates_per_cell)
    , m_src_sequence_length(src_sequence_length)
    , m_src_layer_feature_size(src_layer_feature_size)
    , m_src_iter_feature_size(src_iter_feature_size)
    , m_num_cell_states(num_cell_states)
    , m_direction(direction)
    , m_num_fused_layers(num_fused_layers)
{
    const Shape& src_layer_shape = src_layer->get_shape();

    if (src_layer_shape.size() != weights_layer->get_shape().size())
    {
        throw ngraph_error("src_layer and i2h weights size dont match");
    }
    if (src_iter->get_shape().size() != weights_iter->get_shape().size())
    {
        throw ngraph_error("src_iter and h2h weights size dont match");
    }
    if (src_layer_shape.size() != 2)
    {
        throw ngraph_error("src_layer doesnt have a rank 2");
    }
    // Guard the division below: zero would divide by zero and a negative count would
    // wrap around when converted to an unsigned value.
    if (num_timesteps <= 0)
    {
        throw ngraph_error("num_timesteps must be a positive value");
    }
    m_batch_size = static_cast<int>(src_layer_shape[0] / static_cast<size_t>(num_timesteps));

    const size_t expected_elements = static_cast<size_t>(m_src_sequence_length) *
                                     static_cast<size_t>(m_batch_size) *
                                     static_cast<size_t>(m_src_layer_feature_size);
    if (shape_size(src_layer_shape) != expected_elements)
    {
        NGRAPH_DEBUG << "shape_size: " << shape_size(src_layer_shape);
        throw ngraph_error("src_layer size is not equal t*n*c");
    }

    // weights_layer is rank 2 here (it matches src_layer); weights_iter and bias have
    // not been rank-checked yet, so make sure they have a leading dimension to compare.
    if (bias->get_shape().empty() || weights_iter->get_shape().empty() ||
        bias->get_shape()[0] != weights_layer->get_shape()[0] ||
        bias->get_shape()[0] != weights_iter->get_shape()[0])
    {
        throw ngraph_error("bias and weights_shape are not compatible");
    }

    const auto et = src_layer->get_element_type();
    for (auto& rnn_input : get_arguments())
    {
        if (rnn_input->get_element_type() != et)
        {
            throw ngraph_error("all rnn inputs must have the same element type");
        }
    }

    add_output(src_layer->get_element_type(),
               Shape{static_cast<size_t>(m_num_timesteps * m_batch_size),
                     static_cast<size_t>(m_src_iter_feature_size)});
    add_output(src_layer->get_element_type(),
               Shape{static_cast<size_t>(m_num_cell_states * m_batch_size),
                     static_cast<size_t>(m_src_iter_feature_size)});
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/op/util/requires_tensor_view_args.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace op
{
// Fused RNN op, formed by fusing multiple RNN cells (LSTM / GRU / vanilla RNN)
// across multiple time slices.
//
// INPUTS:
// [0] - src_layer: {X0, X1, ..., Xt} input tensor in TNC layout,
//       Shape{sequence_length*batch_size, feature_size}. Must be rank 2.
// [1] - src_iter: recurrent state tensors {ht_1 | ct_1}.
//       NOTE(review): the fusion passes build this by concatenating the states along
//       axis 0, so the shape is presumably {num_cell_states*batch_size, feature_size}
//       rather than sequence_length*batch_size - confirm.
// [2] - weights_layer: input weights matrix, used for the linear transformation of the inputs.
// [3] - weights_iter: recurrent weights matrix, used for the linear transformation of the
//       recurrent state.
// [4] - bias: bias vector w.r.t. inputs + hidden state (ibh_bias + hbh_bias).
//
// ATTRIBUTES:
// num_timesteps          - number of unrolled cells up to timestep t.
// num_gates_per_cell     - number of gates per RNN cell, LSTM = 4, GRU = 3, vanilla RNN = 1
// src_sequence_length    - this will be the same as num_timesteps
// src_layer_feature_size - feature size w.r.t. the input tensor
// src_iter_feature_size  - feature size w.r.t. the hidden state
// num_cell_states        - number of recurrent state tensors, LSTM = 2, GRU = 1, vanilla RNN = 1
// direction              - stored and exposed via get_direction(); the constructor does not
//                          interpret it. NOTE(review): meaning of the value to be confirmed.
// num_fused_layers       - stored and exposed via get_num_fused_layers(); the constructor
//                          does not interpret it.
//
// OUTPUT VALUE: a tuple with the following structure:
// [0] - ht, output tensor with Shape{num_timesteps*batch_size, src_iter_feature_size}.
// [1] - {ht | ct} output recurrent state tensor with
//       Shape{num_cell_states*batch_size, src_iter_feature_size}.
class Rnn : public util::RequiresTensorViewArgs
{
public:
// Validates the inputs and registers the two outputs described above.
Rnn(std::shared_ptr<Node> src_layer,
std::shared_ptr<Node> src_iter,
std::shared_ptr<Node> weights_layer,
std::shared_ptr<Node> weights_iter,
std::shared_ptr<Node> bias,
const int num_timesteps,
const int num_gates_per_cell,
const int src_sequence_length,
const int src_layer_feature_size,
const int src_iter_feature_size,
const int num_cell_states,
const int direction,
const int num_fused_layers);
// Clones this op over a new set of arguments, keeping the stored attributes.
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
// Plain accessors for the attributes captured at construction time.
int get_num_timesteps() const { return m_num_timesteps; }
int get_src_sequence_length() const { return m_src_sequence_length; }
int get_gates_per_cell() const { return m_num_gates_per_cell; }
// Batch size is derived in the constructor, not passed in by the caller.
int get_batch_size() const { return m_batch_size; }
int get_src_layer_feature_size() const { return m_src_layer_feature_size; }
int get_src_iter_feature_size() const { return m_src_iter_feature_size; }
int get_num_cell_states() const { return m_num_cell_states; }
int get_direction() const { return m_direction; }
int get_num_fused_layers() const { return m_num_fused_layers; }
private:
int m_num_timesteps;
int m_num_gates_per_cell;
int m_src_sequence_length;
// Computed by the constructor as (rows of src_layer) / num_timesteps.
int m_batch_size;
int m_src_layer_feature_size;
int m_src_iter_feature_size;
int m_num_cell_states;
int m_direction;
int m_num_fused_layers;
};
}
}
......@@ -37,7 +37,9 @@
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
using namespace std;
......@@ -497,6 +499,48 @@ namespace ngraph
batchnorm->set_op_annotations(op_annotations);
}
}
// Marks an op::Lstm node as an MKLDNN op when its inputs have the supported form:
// inputs [0]..[3] (src_layer, src_iter, weights_layer, weights_iter) are rank 2,
// input [4] (bias) is rank 1, and inputs [0] and [1] are f32.
// Nodes that do not qualify are left without annotations.
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Lstm)
{
    // Query every input rank up front, in input order.
    const auto layer_rank = node->get_input_shape(0).size();
    const auto iter_rank = node->get_input_shape(1).size();
    const auto w_layer_rank = node->get_input_shape(2).size();
    const auto w_iter_rank = node->get_input_shape(3).size();
    const auto b_rank = node->get_input_shape(4).size();

    const bool ranks_supported = layer_rank == 2 && iter_rank == 2 && w_layer_rank == 2 &&
                                 w_iter_rank == 2 && b_rank == 1;
    if (!ranks_supported)
    {
        return;
    }

    if (node->get_input_element_type(0) != element::f32 ||
        node->get_input_element_type(1) != element::f32)
    {
        return;
    }

    auto annotations = std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
    annotations->set_mkldnn_op(true);
    static_cast<op::Lstm*>(node)->set_op_annotations(annotations);
}
// Marks an op::Rnn node as an MKLDNN op when its inputs have the supported form:
// inputs [0]..[3] (src_layer, src_iter, weights_layer, weights_iter) are rank 2,
// input [4] (bias) is rank 1, and inputs [0] and [1] are f32.
// Nodes that do not qualify are left without annotations.
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Rnn)
{
    // Query every input rank up front, in input order.
    const auto layer_rank = node->get_input_shape(0).size();
    const auto iter_rank = node->get_input_shape(1).size();
    const auto w_layer_rank = node->get_input_shape(2).size();
    const auto w_iter_rank = node->get_input_shape(3).size();
    const auto b_rank = node->get_input_shape(4).size();

    const bool ranks_supported = layer_rank == 2 && iter_rank == 2 && w_layer_rank == 2 &&
                                 w_iter_rank == 2 && b_rank == 1;
    if (!ranks_supported)
    {
        return;
    }

    if (node->get_input_element_type(0) != element::f32 ||
        node->get_input_element_type(1) != element::f32)
    {
        return;
    }

    auto annotations = std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
    annotations->set_mkldnn_op(true);
    static_cast<op::Rnn*>(node)->set_op_annotations(annotations);
}
}
}
}
......@@ -542,6 +586,8 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::Sigmoid), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Sigmoid>},
{TI(ngraph::op::SigmoidBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::SigmoidBackprop>},
{TI(ngraph::op::Lstm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Lstm>},
{TI(ngraph::op::Rnn), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Rnn>},
};
bool runtime::cpu::pass::CPUAssignment::run_on_call_graph(
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "cpu_concat_inputs.hpp"
#include <algorithm>
#include <iostream>
#include <numeric>
#include <typeindex>
#include <typeinfo>
#include <unordered_set>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
using namespace ngraph;
// Registers a matcher that rewrites a seven-argument op::Lstm
//     Lstm(xt, weights_i2h, ht_1, weights_h2h, bias1, bias2, ct_1)
// into the five-argument form
//     Lstm(src_layer, src_iter, weights_i2h, weights_h2h, bias)
// where src_iter = Concat({ht_1, ct_1}, axis 0) and bias = bias1 + bias2.
//
// Consumers of the old node are re-pointed as follows:
//   - GetOutputElement n == 0 (ht)  -> output 0 of the new Lstm
//   - GetOutputElement n == 1 (ct)  -> the second batch_size rows of output 1 of the
//                                      new Lstm (output 1 holds both recurrent states)
void ngraph::runtime::cpu::pass::ConcatInputs::concat_lstm_inputs()
{
    // Pattern: GetOutputElement(Lstm(...), 0), wrapped in a label so the callback can
    // reach the matched Lstm through the GOE's first argument.
    auto ht_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 100});
    auto weights_h2h = std::make_shared<pattern::op::Label>(element::f32, Shape{400, 100});
    auto xt = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 100});
    auto weights_i2h = std::make_shared<pattern::op::Label>(element::f32, Shape{400, 100});
    auto bias1 = std::make_shared<pattern::op::Label>(element::f32, Shape{400});
    auto bias2 = std::make_shared<pattern::op::Label>(element::f32, Shape{400});
    auto ct_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 100});

    auto lstm = std::make_shared<op::Lstm>(xt, weights_i2h, ht_1, weights_h2h, bias1, bias2, ct_1);
    auto goe = std::make_shared<op::GetOutputElement>(lstm, 0);
    auto lstm_node_label = std::make_shared<pattern::op::Label>(goe, nullptr, NodeVector{goe});

    pattern::graph_rewrite_callback callback =
        [lstm_node_label, xt, weights_h2h, ht_1, weights_i2h, bias1, bias2, ct_1](
            pattern::Matcher& m) {
            auto pattern_map = m.get_pattern_map();
            NGRAPH_DEBUG << " In LSTM MKLDNN callback";

            if (m.get_match_root()->get_element_type() != element::f32)
            {
                NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
                             << " type is not float!";
                return false;
            }

            std::shared_ptr<Node> src_layer = pattern_map[xt];
            std::shared_ptr<Node> src_iter =
                std::make_shared<op::Concat>(NodeVector{pattern_map[ht_1], pattern_map[ct_1]}, 0);
            std::shared_ptr<Node> bias =
                std::make_shared<op::Add>(pattern_map[bias1], pattern_map[bias2]);

            // The label wraps the GOE, so its first argument is the matched Lstm.
            auto lstm_node = pattern_map[lstm_node_label]->get_arguments()[0];
            auto matched_lstm = std::dynamic_pointer_cast<op::Lstm>(lstm_node);
            if (!matched_lstm)
            {
                // Cast once and bail out instead of dereferencing a null pointer.
                NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
                             << " is not fed by an Lstm node";
                return false;
            }
            auto batch_size = matched_lstm->get_batch_size();
            auto feature_size = matched_lstm->get_src_iter_feature_size();

            auto lstm_mkldnn_node = std::make_shared<op::Lstm>(
                src_layer, src_iter, pattern_map[weights_i2h], pattern_map[weights_h2h], bias);

            auto lstm_ht_out = std::make_shared<op::GetOutputElement>(lstm_mkldnn_node, 0);
            auto lstm_ht_ct_out = std::make_shared<op::GetOutputElement>(lstm_mkldnn_node, 1);

            // dst_iter of lstm mkldnn output holds the results of both recurrent state
            // tensor outputs. we need to slice the ct. (ht is taken from output 0, so
            // only the ct half is sliced out here.)
            auto ct_slice =
                std::make_shared<op::Slice>(lstm_ht_ct_out,
                                            Coordinate{static_cast<unsigned long>(batch_size), 0},
                                            Coordinate{static_cast<unsigned long>(2 * batch_size),
                                                       static_cast<unsigned long>(feature_size)});

            // Now go through the GOEs of the old Lstm and re-point their users.
            // lstm_outputs also keeps every visited consumer alive while the input set
            // of the old Lstm's output is being iterated.
            std::set<std::shared_ptr<ngraph::Node>> lstm_outputs;
            for (auto& goes : lstm_node->get_outputs().at(0).get_inputs())
            {
                auto consumer = goes->get_node();
                lstm_outputs.insert(consumer);

                auto goe_node = std::dynamic_pointer_cast<op::GetOutputElement>(consumer);
                if (!goe_node)
                {
                    throw ngraph_error(
                        "Pattern matcher error, Lstm outputs should only be consumed through "
                        "GetOutputElement");
                }

                // first output node of lstm
                if (goe_node->get_n() == 0)
                {
                    NGRAPH_DEBUG << "Replacing 1st output Lstm node " << goe_node->get_name()
                                 << " with " << lstm_ht_out->get_name();
                    ngraph::replace_node(goe_node, lstm_ht_out);
                }
                else if (goe_node->get_n() == 1)
                {
                    // Rewire every input of every user that currently reads this GOE.
                    for (auto& goe_ct_user : goe_node->get_users())
                    {
                        for (size_t i = 0; i < goe_ct_user->get_input_size(); i++)
                        {
                            if (goe_ct_user->get_argument(i) == goe_node)
                            {
                                goe_ct_user->get_inputs().at(i).replace_output(
                                    ct_slice->get_outputs().at(0));
                            }
                        }
                    }
                    NGRAPH_DEBUG << "Replacing 2nd output Lstm node " << goe_node->get_name()
                                 << " with " << ct_slice->get_name();
                }
            }

            if (lstm_outputs.find(m.get_match_root()) == lstm_outputs.end())
            {
                throw ngraph_error(
                    "Pattern matcher error, matched root node should be one of the LSTM outputs");
            }
            return true;
        };

    auto m = std::make_shared<pattern::Matcher>(lstm_node_label, callback);
    this->add_matcher(m);
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
    namespace runtime
    {
        namespace cpu
        {
            namespace pass
            {
                // Graph-rewrite pass whose single matcher is installed by
                // concat_lstm_inputs() at construction time.
                class ConcatInputs : public ::ngraph::pass::GraphRewrite
                {
                public:
                    ConcatInputs()
                        : GraphRewrite()
                    {
                        concat_lstm_inputs();
                    }

                private:
                    // Builds the Lstm pattern and registers its rewrite callback.
                    void concat_lstm_inputs();
                };
            }
        }
    }
}
......@@ -43,7 +43,9 @@
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
using namespace std;
......@@ -1349,6 +1351,38 @@ namespace ngraph
set_default_layouts(external_function, node);
}
}
// Layout assignment for op::Lstm: only the MKLDNN path is implemented, any other
// node is rejected with an ngraph_error.
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Lstm)
{
    if (!runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
    {
        throw ngraph_error("LSTM fused op is only supported in MKLDNN for now.");
    }

    // TODO: for now, the framework formats of src_layer, src_iter, weights_layer and
    // weights_iter match the expected mkldnn formats. Convert ops still need to be
    // inserted for the case where the formats do not match.
    set_default_layouts(external_function, node, false);
}
// Layout assignment for op::Rnn: only the MKLDNN path is implemented, any other
// node is rejected with an ngraph_error.
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Rnn)
{
    if (!runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
    {
        throw ngraph_error("RNN fused op is only supported in MKLDNN for now.");
    }

    // TODO: for now, the framework formats of src_layer, src_iter, weights_layer and
    // weights_iter match the expected mkldnn formats. Convert ops still need to be
    // inserted for the case where the formats do not match.
    set_default_layouts(external_function, node, false);
}
}
}
}
......@@ -1396,6 +1430,8 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::Sigmoid), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Sigmoid>},
{TI(ngraph::op::SigmoidBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::SigmoidBackprop>},
{TI(ngraph::op::Lstm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Lstm>},
{TI(ngraph::op::Rnn), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Rnn>},
};
bool runtime::cpu::pass::CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <algorithm>
#include <iostream>
#include <numeric>
#include <typeindex>
#include <typeinfo>
#include <unordered_set>
#include "cpu_rnn_fusion.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
using namespace ngraph;
// Registers a matcher that collapses the expanded logistic expression
//     c / (exp(-x) + c)          (c: a scalar broadcast to the shape of x)
// into a single op::Sigmoid(x) node.
// NOTE(review): the callback does not verify that the matched scalar c equals 1;
// confirm that every producer of this pattern uses 1.
void ngraph::runtime::cpu::pass::LSTMFusion::construct_sigmoid()
{
    // x, and exp(-x)
    auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
    auto negated = std::make_shared<op::Negative>(input);
    auto exp_negated = std::make_shared<op::Exp>(negated);

    // the scalar, broadcast over both axes; the same broadcast node is used as the
    // numerator and inside the denominator
    auto scalar = std::make_shared<pattern::op::Label>(element::f32, Shape{});
    auto scalar_bcast = std::make_shared<op::Broadcast>(scalar, Shape{3, 4}, AxisSet{0, 1});

    auto denominator = std::make_shared<op::Add>(exp_negated, scalar_bcast);
    auto quotient = std::make_shared<op::Divide>(scalar_bcast, denominator);

    // Invoked for every sub-graph that matches the pattern above.
    ngraph::pattern::graph_rewrite_callback callback = [input](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
                     << m.get_match_root()->get_name();
        auto pattern_map = m.get_pattern_map();
        auto root = m.get_match_root();

        // Only f32 sub-graphs are fused.
        if (root->get_element_type() != element::f32)
        {
            NGRAPH_DEBUG << "mpattern = " << root->get_name() << " type is not float!";
            return false;
        }

        // The replacement must expose as many outputs as the matched input.
        auto matched_input = pattern_map[input];
        if (root->get_outputs().size() != matched_input->get_outputs().size())
        {
            NGRAPH_DEBUG << "mpattern = " << root->get_name()
                         << "input= " << matched_input->get_name() << "size dont match!";
            return false;
        }

        ngraph::replace_node(root, std::make_shared<op::Sigmoid>(matched_input));
        return true;
    };

    this->add_matcher(std::make_shared<ngraph::pattern::Matcher>(quotient, callback));
}
// Registers a matcher for one unrolled LSTM cell and replaces it with an op::Lstm.
//
// Matched expression (X is sliced into four equal column blocks):
//     X  = (ht_1 . W_h2h^T + b_h2h) + (xt . W_i2h^T + b_i2h)
//     ct = sigmoid(X[0]) * ct_1 + sigmoid(X[1]) * tanh(X[2])
//     ht = sigmoid(X[3]) * tanh(ct)
//
// On a match, users of ct are re-pointed to output 1 of the new Lstm and the match
// root (ht) is replaced by output 0.
//
// The rank checks in the callback are evaluated on the nodes bound by the matcher,
// not on the pattern labels (whose shapes are fixed below); a rank mismatch declines
// the fusion and leaves the graph untouched.
void ngraph::runtime::cpu::pass::LSTMFusion::construct_lstm_fprop()
{
    // xt . W_i2h^T + b_i2h
    auto input_xt = std::make_shared<pattern::op::Label>(element::f32, Shape{10, 100});
    auto weights_i2h = std::make_shared<pattern::op::Label>(element::f32, Shape{400, 100});
    auto weights_i2h_reshape =
        std::make_shared<op::Reshape>(weights_i2h, AxisVector{1, 0}, Shape{100, 400});
    auto dot_1 = std::make_shared<op::Dot>(input_xt, weights_i2h_reshape);

    auto bias_i2h = std::make_shared<pattern::op::Label>(element::f32, Shape{400});
    auto broadcast_bias_i2h = std::make_shared<op::Broadcast>(bias_i2h, Shape{10, 400}, AxisSet{0});
    auto add_1 = std::make_shared<op::Add>(dot_1, broadcast_bias_i2h);

    // ht_1 . W_h2h^T + b_h2h
    auto hidden_ht = std::make_shared<pattern::op::Label>(element::f32, Shape{10, 50});
    auto weights_h2h = std::make_shared<pattern::op::Label>(element::f32, Shape{400, 50});
    auto param2_2_reshape =
        std::make_shared<op::Reshape>(weights_h2h, AxisVector{1, 0}, Shape{50, 400});
    auto dot_2 = std::make_shared<op::Dot>(hidden_ht, param2_2_reshape);

    auto bias_h2h = std::make_shared<pattern::op::Label>(element::f32, Shape{400});
    auto broadcast_bias_h2h = std::make_shared<op::Broadcast>(bias_h2h, Shape{10, 400}, AxisSet{0});
    auto add_2 = std::make_shared<op::Add>(dot_2, broadcast_bias_h2h);

    auto X = std::make_shared<op::Add>(add_2, add_1);

    // construct forget gate
    auto input_slice_0 = std::make_shared<op::Slice>(X, Coordinate{0, 0}, Coordinate{10, 100});
    auto forget_gate = std::make_shared<op::Sigmoid>(input_slice_0);

    // ct-1 -> cell state (src_iter -> {ht | ct-1})
    auto ct_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{10, 100});
    auto multiply_forget_gate_ct_1 = std::make_shared<op::Multiply>(forget_gate, ct_1);

    // construct input gate
    auto input_slice_1 = std::make_shared<op::Slice>(X, Coordinate{0, 100}, Coordinate{10, 200});
    auto input_gate = std::make_shared<op::Sigmoid>(input_slice_1);
    auto input_slice_2 = std::make_shared<op::Slice>(X, Coordinate{0, 200}, Coordinate{10, 300});
    auto tanh_1 = std::make_shared<op::Tanh>(input_slice_2);
    auto multiply_input_gate_tanh_1 = std::make_shared<op::Multiply>(input_gate, tanh_1);

    auto add_ct_1_input_gate_tanh_1 =
        std::make_shared<op::Add>(multiply_forget_gate_ct_1, multiply_input_gate_tanh_1);
    auto ct_label = std::make_shared<pattern::op::Label>(
        add_ct_1_input_gate_tanh_1, nullptr, NodeVector{add_ct_1_input_gate_tanh_1});

    // construct output gate
    auto input_slice_3 = std::make_shared<op::Slice>(X, Coordinate{0, 300}, Coordinate{10, 400});
    auto output_gate = std::make_shared<op::Sigmoid>(input_slice_3);
    auto tanh_2 = std::make_shared<op::Tanh>(ct_label);
    auto ht = std::make_shared<op::Multiply>(output_gate, tanh_2);
    auto ht_label = std::make_shared<pattern::op::Label>(ht, nullptr, NodeVector{ht});

    // Callback invoked once the data-flow graph matches the pattern.
    pattern::graph_rewrite_callback callback = [ct_label,
                                                input_xt,
                                                weights_i2h,
                                                hidden_ht,
                                                weights_h2h,
                                                bias_i2h,
                                                bias_h2h,
                                                ct_1](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In a callback for construct_fprop_lstm pattern against "
                     << m.get_match_root()->get_name();

        auto pattern_map = m.get_pattern_map();
        NGRAPH_DEBUG << "In Lstm fprop call back";

        if (m.get_match_root()->get_element_type() != element::f32)
        {
            NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
                         << " type is not float!";
            return false;
        }

        // Rank of the node bound to a label. Reading the label's own shape would
        // always yield the pattern shape declared above and make the checks vacuous.
        const auto matched_rank = [&pattern_map](const std::shared_ptr<pattern::op::Label>& label) {
            return pattern_map[label]->get_shape().size();
        };

        if (matched_rank(input_xt) != 2 || matched_rank(hidden_ht) != 2 ||
            matched_rank(weights_i2h) != 2 || matched_rank(weights_h2h) != 2)
        {
            NGRAPH_DEBUG << "Lstm fusion skipped: inputs and weights should have rank 2";
            return false;
        }

        if (matched_rank(bias_i2h) != 1 || matched_rank(bias_h2h) != 1)
        {
            NGRAPH_DEBUG << "Lstm fusion skipped: bias should have rank of 1 for MKLDNN Rnn op";
            return false;
        }

        // True when `n` is Broadcast(Constant), i.e. a constant-initialised state.
        const auto is_broadcast_constant = [](const std::shared_ptr<Node>& n) -> bool {
            return std::dynamic_pointer_cast<op::Broadcast>(n) &&
                   std::dynamic_pointer_cast<op::Constant>(n->get_argument(0));
        };

        // Determine which is ht_1 and xt. but if both xt and ht_1 have the same shape we need
        // to capture this reliably in the RNN fusion.
        // A cell whose ct_1 comes from a GetOutputElement is fed by a previous Lstm.
        const bool intermediate_lstm =
            std::dynamic_pointer_cast<op::GetOutputElement>(pattern_map[ct_1]) != nullptr;

        // swap_roles == true means the node bound to `hidden_ht` is really the input
        // symbol and the node bound to `input_xt` is really the hidden state.
        bool swap_roles = false;
        if (!intermediate_lstm && is_broadcast_constant(pattern_map[hidden_ht]))
        {
            // first LSTM cell: the constant-initialised operand is the hidden state
            swap_roles = false;
        }
        else if (!intermediate_lstm && is_broadcast_constant(pattern_map[input_xt]))
        {
            swap_roles = true;
        }
        else if (pattern_map[ct_1]->get_shape() == pattern_map[hidden_ht]->get_shape())
        {
            NGRAPH_DEBUG << "ct_shape : " << join(pattern_map[ct_1]->get_shape())
                         << " hidden state shape: " << join(pattern_map[hidden_ht]->get_shape());
            swap_roles = false;
        }
        else
        {
            NGRAPH_DEBUG << "ct_shape: " << join(pattern_map[ct_1]->get_shape())
                         << " hidden state shape: " << join(pattern_map[input_xt]->get_shape());
            swap_roles = true;
        }

        std::shared_ptr<op::Lstm> lstm = nullptr;
        if (swap_roles)
        {
            lstm = std::make_shared<op::Lstm>(pattern_map[hidden_ht],
                                              pattern_map[weights_h2h],
                                              pattern_map[input_xt],
                                              pattern_map[weights_i2h],
                                              pattern_map[bias_h2h],
                                              pattern_map[bias_i2h],
                                              pattern_map[ct_1]);
        }
        else
        {
            lstm = std::make_shared<op::Lstm>(pattern_map[input_xt],
                                              pattern_map[weights_i2h],
                                              pattern_map[hidden_ht],
                                              pattern_map[weights_h2h],
                                              pattern_map[bias_i2h],
                                              pattern_map[bias_h2h],
                                              pattern_map[ct_1]);
        }

        auto ht_output = std::make_shared<op::GetOutputElement>(lstm, 0);
        auto ct_output = std::make_shared<op::GetOutputElement>(lstm, 1);

        // Now identify the nodes which consume the outputs of the matched cell
        // and replace them accordingly.
        // find the users of ct and re-point them to output 1 of the new Lstm
        for (auto node : pattern_map[ct_label]->get_users())
        {
            NGRAPH_DEBUG << "node_name: " << node->get_name();
            for (size_t i = 0; i < node->get_input_size(); i++)
            {
                if (node->get_argument(i) == pattern_map[ct_label])
                {
                    node->get_inputs().at(i).replace_output(ct_output->get_outputs().at(0));
                }
            }
        }

        // find the users of ht and replace them with output 0 of the new Lstm
        ngraph::replace_node(m.get_match_root(), ht_output);
        return true;
    };

    auto m = std::make_shared<pattern::Matcher>(ht, callback);
    this->add_matcher(m);
}
// Builds one input of the fused RNN op from the nodes the recurrent matcher bound to
// the given pattern label(s).
//
//   concat_all == true       -> src_layer: every node bound to rnn_labels[0], reversed
//                               so the inputs are in the order 0, 1, 2 ... t, and
//                               concatenated along axis 0.
//   rnn_labels.size() == 2   -> src_iter: the last bound node of each label (ht_1|ct_1
//                               of the first LSTM cell), concatenated along axis 0.
//   otherwise                -> weights shared between the cells: the last node bound
//                               to rnn_labels[0].
//
// Throws ngraph_error when rnn_labels is empty, when a label whose last bound node is
// needed has no bound nodes, or when a state candidate is a GetOutputElement (i.e. an
// intermediate LSTM output).
static std::shared_ptr<ngraph::Node>
    compute_rnn_args(std::vector<std::shared_ptr<pattern::op::Label>>& rnn_labels,
                     pattern::RecurrentMatcher& m,
                     bool concat_all = false)
{
    NGRAPH_DEBUG << "Inside compute arg " << rnn_labels.size();

    if (rnn_labels.empty())
    {
        throw ngraph_error("compute_rnn_args requires at least one pattern label");
    }

    // src_layer -> concatenate input symbols from different LSTM cells belonging to same
    // RNN layer in the order 0, 1, 2... t time slice
    if (concat_all)
    {
        auto node_labels = m.get_bound_nodes_for_pattern(rnn_labels[0]);
        std::reverse(node_labels.begin(), node_labels.end());
        return std::make_shared<op::Concat>(node_labels, 0);
    }

    // src_iter -> concatenate ht_1|ct_1 of the first LSTM cell belonging to same RNN layer
    if (rnn_labels.size() == 2)
    {
        NodeVector concat_args;
        for (const auto& label : rnn_labels)
        {
            auto node_labels = m.get_bound_nodes_for_pattern(label);
            if (node_labels.empty())
            {
                throw ngraph_error(
                    "pattern matcher error, no nodes bound to the recurrent state label");
            }

            const auto& first_cell_state = node_labels.back();
            // this is to make sure, we are not capturing any intermediate op's as Cell states.
            if (std::dynamic_pointer_cast<op::GetOutputElement>(first_cell_state))
            {
                throw ngraph_error(
                    "pattern matcher error, ht_1|ct_1 of the first LSTM cell should not match "
                    "intermediate LSTM outputs");
            }
            concat_args.push_back(first_cell_state);
        }
        return std::make_shared<op::Concat>(concat_args, 0);
    }

    // i2h or h2h weights shared between LSTM cells
    auto node_labels = m.get_bound_nodes_for_pattern(rnn_labels[0]);
    if (node_labels.empty())
    {
        throw ngraph_error("pattern matcher error, no nodes bound to the weights label");
    }
    return node_labels.back();
}
// Returns false as soon as `node`, or any node reached by following get_users()
// transitively, reports is_output(); returns true when no such node is found.
static bool is_unreachable(std::shared_ptr<ngraph::Node> node)
{
    std::unordered_set<std::shared_ptr<ngraph::Node>> visited;
    std::deque<std::shared_ptr<ngraph::Node>> pending; // used as a LIFO stack
    pending.push_front(node);

    while (!pending.empty())
    {
        std::shared_ptr<ngraph::Node> current = pending.front();
        pending.pop_front();

        // is_output() is only consulted the first time a node is popped.
        const bool first_visit = visited.insert(current).second;
        if (first_visit && current->is_output())
        {
            return false;
        }

        for (auto user : current->get_users())
        {
            if (visited.find(user) == visited.end())
            {
                pending.push_front(user);
            }
        }
    }
    return true;
}
void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
{
auto ht_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 100});
auto weights_h2h = std::make_shared<pattern::op::Label>(element::f32, Shape{400, 100});
auto xt = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 100});
auto weights_i2h = std::make_shared<pattern::op::Label>(element::f32, Shape{400, 100});
auto bias_i2h = std::make_shared<pattern::op::Label>(element::f32, Shape{400});
auto bias_h2h = std::make_shared<pattern::op::Label>(element::f32, Shape{400});
auto rpattern_ct_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 100});
auto lstm = std::make_shared<op::Lstm>(
xt, weights_i2h, ht_1, weights_h2h, bias_i2h, bias_h2h, rpattern_ct_1);
auto goe = std::make_shared<op::GetOutputElement>(lstm, 0);
auto lstm_node_label = std::make_shared<pattern::op::Label>(goe, nullptr, NodeVector{goe});
pattern::recurrent_graph_rewrite_callback callback = [lstm_node_label,
xt,
weights_h2h,
ht_1,
weights_i2h,
bias_i2h,
bias_h2h,
rpattern_ct_1](
pattern::RecurrentMatcher& m) {
NGRAPH_DEBUG << " In recurrent RNN fusion callback";
auto ht_1_label = m.get_bound_nodes_for_pattern(ht_1);
// determine the ht and xt
std::shared_ptr<ngraph::Node> src_layer = nullptr;
std::shared_ptr<ngraph::Node> src_iter = nullptr;
auto xt_node_array = m.get_bound_nodes_for_pattern(xt);
auto hidden_ht_array = m.get_bound_nodes_for_pattern(ht_1);
// since we dont have metadata to differentiate between xt and ht_1
// we will be using the broadcasted constant initilization of the first LSTM cell
// in the RNN layer to identify ht_1
if (std::dynamic_pointer_cast<op::Broadcast>(xt_node_array[xt_node_array.size() - 1]) &&
std::dynamic_pointer_cast<op::Constant>(
xt_node_array[xt_node_array.size() - 1]->get_argument(0)))
{
std::vector<std::shared_ptr<pattern::op::Label>> src_layer_labels{ht_1};
src_layer = compute_rnn_args(src_layer_labels, m, true);
std::vector<std::shared_ptr<pattern::op::Label>> src_iter_labels{xt, rpattern_ct_1};
src_iter = compute_rnn_args(src_iter_labels, m);
}
else if (std::dynamic_pointer_cast<op::Broadcast>(
hidden_ht_array[hidden_ht_array.size() - 1]) &&
std::dynamic_pointer_cast<op::Constant>(
hidden_ht_array[hidden_ht_array.size() - 1]->get_argument(0)))
{
std::vector<std::shared_ptr<pattern::op::Label>> src_layer_labels{xt};
src_layer = compute_rnn_args(src_layer_labels, m, true);
std::vector<std::shared_ptr<pattern::op::Label>> src_iter_labels{ht_1, rpattern_ct_1};
src_iter = compute_rnn_args(src_iter_labels, m);
}
else
{
// dont fuse, if the PM didn't discover all the cells belonging to RNN layer.
// we dont want to throw an assertion, if pattern matcher cannot discover all
// nodes belonging to RNN, instead we will return and can compute LSTM cell wise
return false;
}
std::vector<std::shared_ptr<pattern::op::Label>> weights_layer_labels{weights_i2h};
auto weights_layer = compute_rnn_args(weights_layer_labels, m);
std::vector<std::shared_ptr<pattern::op::Label>> weights_iter_labels{weights_h2h};
auto weights_iter = compute_rnn_args(weights_iter_labels, m);
auto bias_i2h_label = m.get_bound_nodes_for_pattern(bias_i2h);
auto bias_h2h_label = m.get_bound_nodes_for_pattern(bias_h2h);
auto bias = std::make_shared<op::Add>(bias_i2h_label[0], bias_h2h_label[0]);
auto num_of_lstm_matched = m.get_number_of_recurrent_matches();
size_t num_gates_in_lstm = 4;
// TODO: assert for batch_size, sequence length and num_of_lstm's fused
size_t batch_size = src_layer->get_shape()[0] / num_of_lstm_matched;
size_t sequence_len = num_of_lstm_matched;
size_t src_layer_feature_size = src_layer->get_shape()[1];
size_t feature_size = ht_1_label[0]->get_shape()[1];
// number of states for LSTM is 2
size_t num_cell_states = 2;
size_t direction = 1;
size_t num_fused_rnn_layers = 1;
NGRAPH_DEBUG << "src_layer: " << join(src_layer->get_shape());
NGRAPH_DEBUG << "src_iter: " << join(src_iter->get_shape());
NGRAPH_DEBUG << "weights_layer: " << join(weights_layer->get_shape());
NGRAPH_DEBUG << "weights_iter: " << join(weights_iter->get_shape());
NGRAPH_DEBUG << "bias: " << join(bias->get_shape());
NGRAPH_DEBUG << "src_seq_len: " << sequence_len;
NGRAPH_DEBUG << "batch_size: " << batch_size;
NGRAPH_DEBUG << "feature_size: " << feature_size;
if ((src_layer->get_arguments().size()) != sequence_len)
{
throw ngraph_error(
"number of lstm inputs captured in the RNN fusion is not equal to "
"src_sequence_length");
}
if ((src_iter->get_arguments().size()) != num_cell_states)
{
throw ngraph_error("number of states for RNN op is not equal to (ht_1|ct_1)");
}
auto src_layer_rank = src_layer->get_shape().size();
auto src_iter_rank = src_iter->get_shape().size();
auto weights_layer_rank = weights_layer->get_shape().size();
auto weights_iter_rank = weights_iter->get_shape().size();
auto bias_rank = bias->get_shape().size();
if (src_layer_rank != 2 || src_iter_rank != 2 || weights_layer_rank != 2 ||
weights_iter_rank != 2)
{
throw ngraph_error(
"Pattern matcher error src_layer, weights_layer, src_iter, weights_iter should "
"have rank 2 for MKLDNN RNN op");
}
if (bias_rank != 1)
{
throw ngraph_error("Bias should have rank of 1 for MKLDNN Rnn op");
}
if (src_layer->get_element_type() != element::f32 ||
src_iter->get_element_type() != element::f32)
{
throw ngraph_error(
"input tensor type and input recurrent state tensor type for MKLDNN RNN op should "
"be float32");
}
auto rnn = std::make_shared<op::Rnn>(src_layer,
src_iter,
weights_layer,
weights_iter,
bias,
num_of_lstm_matched,
num_gates_in_lstm,
sequence_len,
src_layer_feature_size,
feature_size,
num_cell_states,
direction,
num_fused_rnn_layers);
std::vector<std::shared_ptr<op::Slice>> ht_slice_per_timestep(num_of_lstm_matched, nullptr);
auto rnn_ht_out = std::make_shared<op::GetOutputElement>(rnn, 0);
auto rnn_ct_out = std::make_shared<op::GetOutputElement>(rnn, 1);
//slice the rnn ht's
size_t start_index = 0;
size_t end_index = batch_size;
// capture the slices in reverse order, so they correspond to the lstm_goes order captured by the pattern matcher
for (size_t i = 0; i < num_of_lstm_matched; i++)
{
ht_slice_per_timestep[i] = (std::make_shared<op::Slice>(
rnn_ht_out, Coordinate{start_index, 0}, Coordinate{end_index, feature_size}));
start_index += batch_size;
end_index += batch_size;
}
std::reverse(ht_slice_per_timestep.begin(), ht_slice_per_timestep.end());
NGRAPH_DEBUG << "rnn_time_slice: " << ht_slice_per_timestep.size();
// find the lstm's nodes captured in PM
auto lstm_goes = m.get_bound_nodes_for_pattern(lstm_node_label);
std::vector<std::shared_ptr<ngraph::Node>> lstm_nodes;
// we need to collect the LSTMs from the GOEs in order to deterministically determine
// the individual time-slice output ht. lstm_goes holds the GOEs in decreasing
// order of the time slices
for (size_t i = 0; i < lstm_goes.size(); i++)
{
// lstm's will be the input to GOE's
lstm_nodes.push_back(lstm_goes[i]->get_arguments()[0]);
}
if (sequence_len != lstm_nodes.size())
{
throw ngraph_error(" Number of lstm nodes in RNN layer is not equal to time slices");
}
if (lstm_nodes.size() != lstm_goes.size() &&
lstm_goes.size() != ht_slice_per_timestep.size())
{
throw ngraph_error(
"Number of slices of rnn output ht is not equal to the time slices in RNN layer");
}
// collect all the consumers of LSTM goe's (ht)
std::set<std::shared_ptr<ngraph::Node>> lstm_goe0_user;
std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<Node>> map_goe_to_lstm_slices;
std::shared_ptr<Node> goe_0;
for (size_t index = 0; index < lstm_nodes.size(); index++)
{
// now get the GOE0 which is the first output of lstm (ht)
for (auto& goes : lstm_nodes[index]->get_outputs().at(0).get_inputs())
{
auto goe_node = std::dynamic_pointer_cast<op::GetOutputElement>(goes->get_node());
// first output node of lstm
if (goe_node->get_n() == 0)
{
goe_0 = goes->get_node();
}
}
for (auto goe0_user : goe_0->get_users())
{
if (std::find(lstm_nodes.begin(), lstm_nodes.end(), goe0_user) ==
lstm_nodes.end() &&
!is_unreachable(goe0_user))
{
lstm_goe0_user.insert(goe0_user);
map_goe_to_lstm_slices[goe_0] = ht_slice_per_timestep[index];
NGRAPH_DEBUG << "ht_slice: " << ht_slice_per_timestep[index]->get_name()
<< " goe0_user " << goe0_user->get_name() << " ";
}
}
}
//now go through the lstm consumers and replace them with the slice
for (auto& node : lstm_goe0_user)
{
for (size_t i = 0; i < node->get_input_size(); i++)
{
if (map_goe_to_lstm_slices.find(node->get_argument(i)) !=
map_goe_to_lstm_slices.end())
{
node->get_inputs().at(i).replace_output(
map_goe_to_lstm_slices[node->get_argument(i)]->get_outputs().at(0));
}
}
}
NGRAPH_DEBUG << "End of recurrent fusion call back "
<< "matched_node: " << m.get_match_root()->get_name();
return true;
};
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
auto m = std::make_shared<pattern::RecurrentMatcher>(
lstm_node_label, rpattern_ct_1, empty_correlated_matches, callback);
this->add_matcher(m);
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
// Forward declarations of the two fusion passes defined later in this header.
// They are declared inside ngraph::runtime::cpu::pass so that the class
// definitions below can use the fully qualified names.
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace pass
{
// Pass built on ngraph::pass::GraphRewrite (see the definition below).
class LSTMFusion;
// Pass built on ngraph::pass::RecurrentGraphRewrite (see the definition below).
class RNNFusion;
}
}
}
}
// GraphRewrite-derived pass for LSTM fusion on the CPU backend.
//
// Construction invokes construct_sigmoid() followed by construct_lstm_fprop().
// Both helpers are defined outside this header; presumably each one registers
// a pattern matcher with the GraphRewrite base -- confirm against the
// implementation file.
class ngraph::runtime::cpu::pass::LSTMFusion : public ngraph::pass::GraphRewrite
{
private:
    // Set-up helper for the sigmoid pattern (implemented elsewhere).
    void construct_sigmoid();

    // Set-up helper for the LSTM forward-propagation pattern (implemented elsewhere).
    void construct_lstm_fprop();

public:
    // Runs the two set-up helpers in a fixed order: sigmoid first, then LSTM fprop.
    LSTMFusion()
        : GraphRewrite()
    {
        construct_sigmoid();
        construct_lstm_fprop();
    }
};
// RecurrentGraphRewrite-derived pass for fusing LSTM cells into an RNN op on
// the CPU backend.
//
// Construction invokes construct_rnn_lstm_fprop(), which is implemented
// outside this header.
class ngraph::runtime::cpu::pass::RNNFusion : public ngraph::pass::RecurrentGraphRewrite
{
private:
    // Set-up helper for the recurrent LSTM forward-propagation pattern
    // (implemented elsewhere).
    void construct_rnn_lstm_fprop();

public:
    // Performs the single set-up step for this pass.
    RNNFusion()
        : RecurrentGraphRewrite()
    {
        construct_rnn_lstm_fprop();
    }
};
......@@ -44,10 +44,14 @@
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
#include "ngraph/runtime/cpu/pass/cpu_concat_inputs.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
#include "ngraph/runtime/cpu/pass/cpu_rnn_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_rnn_mat_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
#include "ngraph/serializer.hpp"
......@@ -1078,7 +1082,6 @@ TEST(cpu_fusion, rnn_fusion_from_json_model)
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass_manager.run_passes(func);
const size_t NUM_STEPS = 10;
auto mmb_predicate = [](std::shared_ptr<Node> node) {
auto users = node->get_users();
......@@ -1293,3 +1296,213 @@ TEST(cpu_fusion, batch_norm_folding)
auto cpu_results = execute(cpu_f, args, "CPU");
EXPECT_TRUE(test::all_close(cpu_results.at(0), int_results.at(0)));
}
// Builds a function containing a single op::Rnn node, runs it on the CPU
// backend with every input element set to 1, and compares both outputs (ht and
// ct) against fixed expected values.
TEST(cpu_fusion, rnn_fprop_1_lstm_cell)
{
    // Scalar attributes passed to the op::Rnn constructor, in argument order.
    const int number_of_timesteps = 1;
    const int number_of_gates_per_cell = 4;
    const int src_seq_length = 1;
    const int src_layer_feature_size = 100;
    const int feature_size = 100;
    const int num_rnn_cell_states = 2;
    const int rnn_direction = 1;
    const int num_of_rnn_fused_layer = 1;

    // Graph inputs.
    auto src_layer = make_shared<op::Parameter>(element::f32, Shape{10, 100});
    auto src_iter = make_shared<op::Parameter>(element::f32, Shape{20, 100});
    auto weights_layer = make_shared<op::Parameter>(element::f32, Shape{400, 100});
    auto weights_iter = make_shared<op::Parameter>(element::f32, Shape{400, 100});
    auto biases = make_shared<op::Parameter>(element::f32, Shape{400});

    auto rnn_node = make_shared<op::Rnn>(src_layer,
                                         src_iter,
                                         weights_layer,
                                         weights_iter,
                                         biases,
                                         number_of_timesteps,
                                         number_of_gates_per_cell,
                                         src_seq_length,
                                         src_layer_feature_size,
                                         feature_size,
                                         num_rnn_cell_states,
                                         rnn_direction,
                                         num_of_rnn_fused_layer);

    // Output 0 is read back as ht, output 1 as ct.
    auto rnn_ht_output = make_shared<op::GetOutputElement>(rnn_node, 0);
    auto rnn_ct_output = make_shared<op::GetOutputElement>(rnn_node, 1);

    auto func = make_shared<Function>(
        NodeVector{rnn_ht_output, rnn_ct_output},
        op::ParameterVector{src_layer, src_iter, weights_layer, weights_iter, biases});

    auto backend = runtime::Backend::create("CPU");

    // All tensors in this test are f32; only the shape varies.
    auto make_f32_tensor = [&backend](const Shape& shape) -> shared_ptr<runtime::TensorView> {
        return backend->create_tensor(element::f32, shape);
    };

    shared_ptr<runtime::TensorView> src_layer_t = make_f32_tensor(src_layer->get_shape());
    shared_ptr<runtime::TensorView> src_iter_t = make_f32_tensor(src_iter->get_shape());
    shared_ptr<runtime::TensorView> weights_layer_t = make_f32_tensor(weights_layer->get_shape());
    shared_ptr<runtime::TensorView> weights_iter_t = make_f32_tensor(weights_iter->get_shape());
    shared_ptr<runtime::TensorView> biases_t = make_f32_tensor(biases->get_shape());
    shared_ptr<runtime::TensorView> result_ht = make_f32_tensor(Shape{10, 100});
    shared_ptr<runtime::TensorView> result_ct = make_f32_tensor(Shape{20, 100});

    // Fill every input with ones.
    copy_data(src_layer_t, vector<float>(1000, 1));
    copy_data(src_iter_t, vector<float>(2000, 1));
    copy_data(weights_layer_t, vector<float>(400 * 100, 1));
    copy_data(weights_iter_t, vector<float>(400 * 100, 1));
    copy_data(biases_t, vector<float>(400, 1));

    backend->call(func,
                  {result_ht, result_ct},
                  {src_layer_t, src_iter_t, weights_layer_t, weights_iter_t, biases_t});

    vector<float> expected_ht(10 * 100, 0.964028f);

    // The first 1000 elements of the expected ct output are 0.964028 and the
    // remaining 1000 are 2.0.
    vector<float> expected_ct(1000, 0.964028f);
    expected_ct.resize(20 * 100, 2.0f);

    EXPECT_TRUE(test::all_close(expected_ht, read_vector<float>(result_ht)));
    EXPECT_TRUE(test::all_close(expected_ct, read_vector<float>(result_ct)));
}
// Runs the LSTMFusion and ConcatInputs passes over the serialized
// "mxnet/2rnn_layer_3lstm_cell.json" graph and checks how many op::Lstm nodes
// the rewritten function contains.
TEST(cpu_fusion, fuse_lstm_cells)
{
    pass::Manager pass_manager;
    pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
    pass_manager.register_pass<runtime::cpu::pass::ConcatInputs>();
    const string json_path =
        file_util::path_join(SERIALIZED_ZOO, "mxnet/2rnn_layer_3lstm_cell.json");
    const string json_string = file_util::read_file_to_string(json_path);
    stringstream ss(json_string);
    shared_ptr<Function> func = ngraph::deserialize(ss);
    pass_manager.run_passes(func);
    auto lstm_ops = get_ops_of_type<op::Lstm>(func);
    // Expected count is typed as size_t so that EXPECT_EQ compares two unsigned
    // values; a bare `6` is a signed int and makes the comparison inside the
    // gtest template a signed/unsigned mix (-Wsign-compare).
    const size_t expected_lstm_count = 6;
    EXPECT_EQ(lstm_ops.size(), expected_lstm_count);
}
// Runs LSTMFusion followed by RNNFusion over the serialized
// "mxnet/2rnn_layer_3lstm_cell.json" graph and checks the attributes of every
// op::Rnn node found afterwards.
TEST(cpu_fusion, fuse_2_layer_rnn)
{
    pass::Manager fusion_passes;
    fusion_passes.register_pass<runtime::cpu::pass::LSTMFusion>();
    fusion_passes.register_pass<runtime::cpu::pass::RNNFusion>();

    // Load and deserialize the model.
    const string model_path =
        file_util::path_join(SERIALIZED_ZOO, "mxnet/2rnn_layer_3lstm_cell.json");
    stringstream model_stream(file_util::read_file_to_string(model_path));
    shared_ptr<Function> func = ngraph::deserialize(model_stream);

    fusion_passes.run_passes(func);

    // The two helpers must agree on how many Rnn ops exist.
    size_t rnn_count = count_ops_of_type<op::Rnn>(func);
    auto fused_rnns = get_ops_of_type<op::Rnn>(func);
    EXPECT_EQ(fused_rnns.size(), rnn_count);

    for (auto& rnn : fused_rnns)
    {
        // Timestep count must match the source sequence length.
        EXPECT_EQ(rnn->get_num_timesteps(), rnn->get_src_sequence_length());
        // Argument 1 must have one input per cell state.
        EXPECT_EQ(rnn->get_num_cell_states(), rnn->get_argument(1)->get_arguments().size());
    }
}
// Runs LSTMFusion followed by RNNFusion over the serialized
// "mxnet/1rnn_layer_3lstm_cell.json" graph, expects exactly one op::Rnn node
// afterwards, and checks that node's attributes.
TEST(cpu_fusion, fuse_1_layer_rnn)
{
    pass::Manager pass_manager;
    pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
    pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
    const string json_path =
        file_util::path_join(SERIALIZED_ZOO, "mxnet/1rnn_layer_3lstm_cell.json");
    const string json_string = file_util::read_file_to_string(json_path);
    stringstream ss(json_string);
    shared_ptr<Function> func = ngraph::deserialize(ss);
    pass_manager.run_passes(func);
    size_t count = count_ops_of_type<op::Rnn>(func);
    auto rnn_ops = get_ops_of_type<op::Rnn>(func);
    // Expected count is typed as size_t so that EXPECT_EQ compares two unsigned
    // values; a bare `1` is a signed int and makes the comparison inside the
    // gtest template a signed/unsigned mix (-Wsign-compare).
    const size_t expected_rnn_count = 1;
    EXPECT_EQ(rnn_ops.size(), expected_rnn_count);
    EXPECT_EQ(rnn_ops.size(), count);
    for (auto& node : rnn_ops)
    {
        // Timestep count must match the source sequence length.
        EXPECT_EQ(node->get_num_timesteps(), node->get_src_sequence_length());
        // Argument 1 must have one input per cell state.
        EXPECT_EQ(node->get_num_cell_states(), node->get_argument(1)->get_arguments().size());
    }
}
static std::shared_ptr<Function> make_function(const std::string& file_name)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, file_name);
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
return func;
}
// Executes the "mxnet/1_lstm_cell_forward.json" model on the INTERPRETER and
// CPU backends with identical random inputs and requires every CPU result to
// be close to the corresponding interpreter result.
TEST(cpu_fusion, rnn_fusion_inter_vs_cpu_1lstm_cell)
{
    const std::string model_file("mxnet/1_lstm_cell_forward.json");
    // Each backend gets its own freshly deserialized copy of the graph.
    auto cpu_func = make_function(model_file);
    auto interp_func = make_function(model_file);

    // One uniformly distributed input vector per function parameter.
    test::Uniform<float> uniform(0.0f, 1.0f);
    vector<vector<float>> inputs;
    for (shared_ptr<op::Parameter> parameter : interp_func->get_parameters())
    {
        inputs.emplace_back(shape_size(parameter->get_shape()));
        uniform.initialize(inputs.back());
    }

    auto interp_out = execute(interp_func, inputs, "INTERPRETER");
    auto cpu_out = execute(cpu_func, inputs, "CPU");

    for (size_t out_idx = 0; out_idx < cpu_out.size(); ++out_idx)
    {
        EXPECT_TRUE(
            test::all_close(cpu_out.at(out_idx), interp_out.at(out_idx), 1.0e-4f, 1.0e-4f));
    }
}
// Executes the "mxnet/1rnn_layer_3lstm_cell.json" model on the INTERPRETER and
// CPU backends with identical random inputs and requires every CPU result to
// be close to the corresponding interpreter result.
TEST(cpu_fusion, rnn_fusion_inter_vs_cpu_1rnn_layer_3lstm_cell)
{
    const std::string model_file("mxnet/1rnn_layer_3lstm_cell.json");
    // Each backend gets its own freshly deserialized copy of the graph.
    auto cpu_func = make_function(model_file);
    auto interp_func = make_function(model_file);

    // One uniformly distributed input vector per function parameter.
    test::Uniform<float> uniform(0.0f, 1.0f);
    vector<vector<float>> inputs;
    for (shared_ptr<op::Parameter> parameter : interp_func->get_parameters())
    {
        inputs.emplace_back(shape_size(parameter->get_shape()));
        uniform.initialize(inputs.back());
    }

    auto interp_out = execute(interp_func, inputs, "INTERPRETER");
    auto cpu_out = execute(cpu_func, inputs, "CPU");

    for (size_t out_idx = 0; out_idx < cpu_out.size(); ++out_idx)
    {
        EXPECT_TRUE(
            test::all_close(cpu_out.at(out_idx), interp_out.at(out_idx), 1.0e-4f, 1.0e-4f));
    }
}
// Executes the "mxnet/2rnn_layer_3lstm_cell.json" model on the INTERPRETER and
// CPU backends with identical random inputs and requires every CPU result to
// be close to the corresponding interpreter result.
TEST(cpu_fusion, rnn_fusion_inter_vs_cpu_2rnn_layer_3lstm_cell)
{
    const std::string model_file("mxnet/2rnn_layer_3lstm_cell.json");
    // Each backend gets its own freshly deserialized copy of the graph.
    auto cpu_func = make_function(model_file);
    auto interp_func = make_function(model_file);

    // One uniformly distributed input vector per function parameter.
    test::Uniform<float> uniform(0.0f, 1.0f);
    vector<vector<float>> inputs;
    for (shared_ptr<op::Parameter> parameter : interp_func->get_parameters())
    {
        inputs.emplace_back(shape_size(parameter->get_shape()));
        uniform.initialize(inputs.back());
    }

    auto interp_out = execute(interp_func, inputs, "INTERPRETER");
    auto cpu_out = execute(cpu_func, inputs, "CPU");

    for (size_t out_idx = 0; out_idx < cpu_out.size(); ++out_idx)
    {
        EXPECT_TRUE(
            test::all_close(cpu_out.at(out_idx), interp_out.at(out_idx), 1.0e-4f, 1.0e-4f));
    }
}
......@@ -31,6 +31,7 @@
#include "ngraph/op/parameter.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
......
[{
"name" : "Function_0",
"ops" : [
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_10",
"op" : "Parameter",
"outputs" : ["Parameter_10_0"],
"shape" : [400]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_9",
"op" : "Parameter",
"outputs" : ["Parameter_9_0"],
"shape" : [ 400, 100 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_2",
"op" : "Parameter",
"outputs" : ["Parameter_2_0"],
"shape" : [400]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_1",
"op" : "Parameter",
"outputs" : ["Parameter_1_0"],
"shape" : [ 400, 100 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_0",
"op" : "Parameter",
"outputs" : ["Parameter_0_0"],
"shape" : [ 10, 100 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_7",
"op" : "Constant",
"outputs" : ["Constant_7_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_34",
"op" : "Constant",
"outputs" : ["Constant_34_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_30",
"op" : "Constant",
"outputs" : ["Constant_30_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_24",
"op" : "Constant",
"outputs" : ["Constant_24_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_17",
"op" : "Constant",
"outputs" : ["Constant_17_0"],
"shape" : [],
"value" : ["1"]
},
{
"axes" : [0],
"inputs" : ["Parameter_10"],
"name" : "Broadcast_13",
"op" : "Broadcast",
"outputs" : ["Broadcast_13_0"],
"shape" : [ 10, 400 ]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_9"],
"name" : "Reshape_11",
"op" : "Reshape",
"output_shape" : [ 100, 400 ],
"outputs" : ["Reshape_11_0"]
},
{
"axes" : [0],
"inputs" : ["Parameter_2"],
"name" : "Broadcast_5",
"op" : "Broadcast",
"outputs" : ["Broadcast_5_0"],
"shape" : [ 10, 400 ]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_1"],
"name" : "Reshape_3",
"op" : "Reshape",
"output_shape" : [ 100, 400 ],
"outputs" : ["Reshape_3_0"]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_7"],
"name" : "Broadcast_8",
"op" : "Broadcast",
"outputs" : ["Broadcast_8_0"],
"shape" : [ 10, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_34"],
"name" : "Broadcast_35",
"op" : "Broadcast",
"outputs" : ["Broadcast_35_0"],
"shape" : [ 10, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_30"],
"name" : "Broadcast_31",
"op" : "Broadcast",
"outputs" : ["Broadcast_31_0"],
"shape" : [ 10, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_24"],
"name" : "Broadcast_25",
"op" : "Broadcast",
"outputs" : ["Broadcast_25_0"],
"shape" : [ 10, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_17"],
"name" : "Broadcast_18",
"op" : "Broadcast",
"outputs" : ["Broadcast_18_0"],
"shape" : [ 10, 100 ]
},
{
"inputs" : [ "Parameter_0", "Reshape_3" ],
"name" : "Dot_4",
"op" : "Dot",
"outputs" : ["Dot_4_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Broadcast_8", "Reshape_11" ],
"name" : "Dot_12",
"op" : "Dot",
"outputs" : ["Dot_12_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Dot_4", "Broadcast_5" ],
"name" : "Add_6",
"op" : "Add",
"outputs" : ["Add_6_0"]
},
{
"inputs" : [ "Dot_12", "Broadcast_13" ],
"name" : "Add_14",
"op" : "Add",
"outputs" : ["Add_14_0"]
},
{
"inputs" : [ "Add_6", "Add_14" ],
"name" : "Add_15",
"op" : "Add",
"outputs" : ["Add_15_0"]
},
{
"inputs" : ["Add_15"],
"lower_bounds" : [ 0, 300 ],
"name" : "Slice_16",
"op" : "Slice",
"outputs" : ["Slice_16_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 10, 400 ]
},
{
"inputs" : ["Add_15"],
"lower_bounds" : [ 0, 100 ],
"name" : "Slice_23",
"op" : "Slice",
"outputs" : ["Slice_23_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 10, 200 ]
},
{
"inputs" : ["Add_15"],
"lower_bounds" : [ 0, 0 ],
"name" : "Slice_33",
"op" : "Slice",
"outputs" : ["Slice_33_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 10, 100 ]
},
{
"inputs" : ["Add_15"],
"lower_bounds" : [ 0, 200 ],
"name" : "Slice_40",
"op" : "Slice",
"outputs" : ["Slice_40_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 10, 300 ]
},
{
"inputs" : ["Slice_16"],
"name" : "Negative_19",
"op" : "Negative",
"outputs" : ["Negative_19_0"]
},
{
"inputs" : ["Slice_23"],
"name" : "Negative_26",
"op" : "Negative",
"outputs" : ["Negative_26_0"]
},
{
"inputs" : ["Slice_33"],
"name" : "Negative_36",
"op" : "Negative",
"outputs" : ["Negative_36_0"]
},
{
"inputs" : ["Slice_40"],
"name" : "Tanh_41",
"op" : "Tanh",
"outputs" : ["Tanh_41_0"]
},
{
"inputs" : ["Negative_19"],
"name" : "Exp_20",
"op" : "Exp",
"outputs" : ["Exp_20_0"]
},
{
"inputs" : ["Negative_26"],
"name" : "Exp_27",
"op" : "Exp",
"outputs" : ["Exp_27_0"]
},
{
"inputs" : ["Negative_36"],
"name" : "Exp_37",
"op" : "Exp",
"outputs" : ["Exp_37_0"]
},
{
"inputs" : [ "Broadcast_18", "Exp_20" ],
"name" : "Add_21",
"op" : "Add",
"outputs" : ["Add_21_0"]
},
{
"inputs" : [ "Broadcast_25", "Exp_27" ],
"name" : "Add_28",
"op" : "Add",
"outputs" : ["Add_28_0"]
},
{
"inputs" : [ "Broadcast_35", "Exp_37" ],
"name" : "Add_38",
"op" : "Add",
"outputs" : ["Add_38_0"]
},
{
"inputs" : [ "Broadcast_18", "Add_21" ],
"name" : "Divide_22",
"op" : "Divide",
"outputs" : ["Divide_22_0"]
},
{
"inputs" : [ "Broadcast_25", "Add_28" ],
"name" : "Divide_29",
"op" : "Divide",
"outputs" : ["Divide_29_0"]
},
{
"inputs" : [ "Broadcast_35", "Add_38" ],
"name" : "Divide_39",
"op" : "Divide",
"outputs" : ["Divide_39_0"]
},
{
"inputs" : [ "Divide_29", "Broadcast_31" ],
"name" : "Multiply_32",
"op" : "Multiply",
"outputs" : ["Multiply_32_0"]
},
{
"inputs" : [ "Divide_39", "Tanh_41" ],
"name" : "Multiply_42",
"op" : "Multiply",
"outputs" : ["Multiply_42_0"]
},
{
"inputs" : [ "Multiply_32", "Multiply_42" ],
"name" : "Add_43",
"op" : "Add",
"outputs" : ["Add_43_0"]
},
{
"inputs" : ["Add_43"],
"name" : "Tanh_44",
"op" : "Tanh",
"outputs" : ["Tanh_44_0"]
},
{
"inputs" : [ "Divide_22", "Tanh_44" ],
"name" : "Multiply_45",
"op" : "Multiply",
"outputs" : ["Multiply_45_0"]
},
{
"inputs" : ["Multiply_45"],
"name" : "Result_46",
"op" : "Result",
"outputs" : ["Result_46_0"]
}
],
"parameters" : [
"Parameter_0", "Parameter_1", "Parameter_2", "Parameter_9",
"Parameter_10"
],
"result" : ["Result_46"]
}]
[{
"name" : "Function_0",
"ops" : [
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_0",
"op" : "Parameter",
"outputs" : ["Parameter_0_0"],
"shape" : [ 10, 50 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_7",
"op" : "Parameter",
"outputs" : ["Parameter_7_0"],
"shape" : [ 10, 50 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_20",
"op" : "Parameter",
"outputs" : ["Parameter_20_0"],
"shape" : [400]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_19",
"op" : "Parameter",
"outputs" : ["Parameter_19_0"],
"shape" : [ 400, 100 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_2",
"op" : "Parameter",
"outputs" : ["Parameter_2_0"],
"shape" : [400]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_1",
"op" : "Parameter",
"outputs" : ["Parameter_1_0"],
"shape" : [ 400, 50 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_12",
"op" : "Parameter",
"outputs" : ["Parameter_12_0"],
"shape" : [ 10, 50 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_17",
"op" : "Constant",
"outputs" : ["Constant_17_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_44",
"op" : "Constant",
"outputs" : ["Constant_44_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_40",
"op" : "Constant",
"outputs" : ["Constant_40_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_34",
"op" : "Constant",
"outputs" : ["Constant_34_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_27",
"op" : "Constant",
"outputs" : ["Constant_27_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_77",
"op" : "Constant",
"outputs" : ["Constant_77_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_69",
"op" : "Constant",
"outputs" : ["Constant_69_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_62",
"op" : "Constant",
"outputs" : ["Constant_62_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_110",
"op" : "Constant",
"outputs" : ["Constant_110_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_102",
"op" : "Constant",
"outputs" : ["Constant_102_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_95",
"op" : "Constant",
"outputs" : ["Constant_95_0"],
"shape" : [],
"value" : ["1"]
},
{
"axes" : [0],
"inputs" : ["Parameter_20"],
"name" : "Broadcast_23",
"op" : "Broadcast",
"outputs" : ["Broadcast_23_0"],
"shape" : [ 10, 400 ]
},
{
"axes" : [0],
"inputs" : ["Parameter_20"],
"name" : "Broadcast_58",
"op" : "Broadcast",
"outputs" : ["Broadcast_58_0"],
"shape" : [ 10, 400 ]
},
{
"axes" : [0],
"inputs" : ["Parameter_20"],
"name" : "Broadcast_91",
"op" : "Broadcast",
"outputs" : ["Broadcast_91_0"],
"shape" : [ 10, 400 ]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_19"],
"name" : "Reshape_21",
"op" : "Reshape",
"output_shape" : [ 100, 400 ],
"outputs" : ["Reshape_21_0"]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_19"],
"name" : "Reshape_56",
"op" : "Reshape",
"output_shape" : [ 100, 400 ],
"outputs" : ["Reshape_56_0"]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_19"],
"name" : "Reshape_89",
"op" : "Reshape",
"output_shape" : [ 100, 400 ],
"outputs" : ["Reshape_89_0"]
},
{
"axes" : [0],
"inputs" : ["Parameter_2"],
"name" : "Broadcast_5",
"op" : "Broadcast",
"outputs" : ["Broadcast_5_0"],
"shape" : [ 10, 400 ]
},
{
"axes" : [0],
"inputs" : ["Parameter_2"],
"name" : "Broadcast_10",
"op" : "Broadcast",
"outputs" : ["Broadcast_10_0"],
"shape" : [ 10, 400 ]
},
{
"axes" : [0],
"inputs" : ["Parameter_2"],
"name" : "Broadcast_15",
"op" : "Broadcast",
"outputs" : ["Broadcast_15_0"],
"shape" : [ 10, 400 ]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_1"],
"name" : "Reshape_3",
"op" : "Reshape",
"output_shape" : [ 50, 400 ],
"outputs" : ["Reshape_3_0"]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_1"],
"name" : "Reshape_8",
"op" : "Reshape",
"output_shape" : [ 50, 400 ],
"outputs" : ["Reshape_8_0"]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_1"],
"name" : "Reshape_13",
"op" : "Reshape",
"output_shape" : [ 50, 400 ],
"outputs" : ["Reshape_13_0"]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_17"],
"name" : "Broadcast_18",
"op" : "Broadcast",
"outputs" : ["Broadcast_18_0"],
"shape" : [ 10, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_44"],
"name" : "Broadcast_45",
"op" : "Broadcast",
"outputs" : ["Broadcast_45_0"],
"shape" : [ 10, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_40"],
"name" : "Broadcast_41",
"op" : "Broadcast",
"outputs" : ["Broadcast_41_0"],
"shape" : [ 10, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_34"],
"name" : "Broadcast_35",
"op" : "Broadcast",
"outputs" : ["Broadcast_35_0"],
"shape" : [ 10, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_27"],
"name" : "Broadcast_28",
"op" : "Broadcast",
"outputs" : ["Broadcast_28_0"],
"shape" : [ 10, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_77"],
"name" : "Broadcast_78",
"op" : "Broadcast",
"outputs" : ["Broadcast_78_0"],
"shape" : [ 10, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_69"],
"name" : "Broadcast_70",
"op" : "Broadcast",
"outputs" : ["Broadcast_70_0"],
"shape" : [ 10, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_62"],
"name" : "Broadcast_63",
"op" : "Broadcast",
"outputs" : ["Broadcast_63_0"],
"shape" : [ 10, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_110"],
"name" : "Broadcast_111",
"op" : "Broadcast",
"outputs" : ["Broadcast_111_0"],
"shape" : [ 10, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_102"],
"name" : "Broadcast_103",
"op" : "Broadcast",
"outputs" : ["Broadcast_103_0"],
"shape" : [ 10, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_95"],
"name" : "Broadcast_96",
"op" : "Broadcast",
"outputs" : ["Broadcast_96_0"],
"shape" : [ 10, 100 ]
},
{
"inputs" : [ "Parameter_0", "Reshape_3" ],
"name" : "Dot_4",
"op" : "Dot",
"outputs" : ["Dot_4_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Parameter_7", "Reshape_8" ],
"name" : "Dot_9",
"op" : "Dot",
"outputs" : ["Dot_9_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Parameter_12", "Reshape_13" ],
"name" : "Dot_14",
"op" : "Dot",
"outputs" : ["Dot_14_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Broadcast_18", "Reshape_21" ],
"name" : "Dot_22",
"op" : "Dot",
"outputs" : ["Dot_22_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Dot_4", "Broadcast_5" ],
"name" : "Add_6",
"op" : "Add",
"outputs" : ["Add_6_0"]
},
{
"inputs" : [ "Dot_9", "Broadcast_10" ],
"name" : "Add_11",
"op" : "Add",
"outputs" : ["Add_11_0"]
},
{
"inputs" : [ "Dot_14", "Broadcast_15" ],
"name" : "Add_16",
"op" : "Add",
"outputs" : ["Add_16_0"]
},
{
"inputs" : [ "Dot_22", "Broadcast_23" ],
"name" : "Add_24",
"op" : "Add",
"outputs" : ["Add_24_0"]
},
{
"inputs" : [ "Add_16", "Add_24" ],
"name" : "Add_25",
"op" : "Add",
"outputs" : ["Add_25_0"]
},
{
"inputs" : ["Add_25"],
"lower_bounds" : [ 0, 300 ],
"name" : "Slice_26",
"op" : "Slice",
"outputs" : ["Slice_26_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 10, 400 ]
},
{
"inputs" : ["Add_25"],
"lower_bounds" : [ 0, 100 ],
"name" : "Slice_33",
"op" : "Slice",
"outputs" : ["Slice_33_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 10, 200 ]
},
{
"inputs" : ["Add_25"],
"lower_bounds" : [ 0, 0 ],
"name" : "Slice_43",
"op" : "Slice",
"outputs" : ["Slice_43_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 10, 100 ]
},
{
"inputs" : ["Add_25"],
"lower_bounds" : [ 0, 200 ],
"name" : "Slice_50",
"op" : "Slice",
"outputs" : ["Slice_50_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 10, 300 ]
},
{
"inputs" : ["Slice_26"],
"name" : "Negative_29",
"op" : "Negative",
"outputs" : ["Negative_29_0"]
},
{
"inputs" : ["Slice_33"],
"name" : "Negative_36",
"op" : "Negative",
"outputs" : ["Negative_36_0"]
},
{
"inputs" : ["Slice_43"],
"name" : "Negative_46",
"op" : "Negative",
"outputs" : ["Negative_46_0"]
},
{
"inputs" : ["Slice_50"],
"name" : "Tanh_51",
"op" : "Tanh",
"outputs" : ["Tanh_51_0"]
},
{
"inputs" : ["Negative_29"],
"name" : "Exp_30",
"op" : "Exp",
"outputs" : ["Exp_30_0"]
},
{
"inputs" : ["Negative_36"],
"name" : "Exp_37",
"op" : "Exp",
"outputs" : ["Exp_37_0"]
},
{
"inputs" : ["Negative_46"],
"name" : "Exp_47",
"op" : "Exp",
"outputs" : ["Exp_47_0"]
},
{
"inputs" : [ "Broadcast_28", "Exp_30" ],
"name" : "Add_31",
"op" : "Add",
"outputs" : ["Add_31_0"]
},
{
"inputs" : [ "Broadcast_35", "Exp_37" ],
"name" : "Add_38",
"op" : "Add",
"outputs" : ["Add_38_0"]
},
{
"inputs" : [ "Broadcast_45", "Exp_47" ],
"name" : "Add_48",
"op" : "Add",
"outputs" : ["Add_48_0"]
},
{
"inputs" : [ "Broadcast_28", "Add_31" ],
"name" : "Divide_32",
"op" : "Divide",
"outputs" : ["Divide_32_0"]
},
{
"inputs" : [ "Broadcast_35", "Add_38" ],
"name" : "Divide_39",
"op" : "Divide",
"outputs" : ["Divide_39_0"]
},
{
"inputs" : [ "Broadcast_45", "Add_48" ],
"name" : "Divide_49",
"op" : "Divide",
"outputs" : ["Divide_49_0"]
},
{
"inputs" : [ "Divide_39", "Broadcast_41" ],
"name" : "Multiply_42",
"op" : "Multiply",
"outputs" : ["Multiply_42_0"]
},
{
"inputs" : [ "Divide_49", "Tanh_51" ],
"name" : "Multiply_52",
"op" : "Multiply",
"outputs" : ["Multiply_52_0"]
},
{
"inputs" : [ "Multiply_42", "Multiply_52" ],
"name" : "Add_53",
"op" : "Add",
"outputs" : ["Add_53_0"]
},
{
"inputs" : ["Add_53"],
"name" : "Tanh_54",
"op" : "Tanh",
"outputs" : ["Tanh_54_0"]
},
{
"inputs" : [ "Divide_32", "Tanh_54" ],
"name" : "Multiply_55",
"op" : "Multiply",
"outputs" : ["Multiply_55_0"]
},
{
"inputs" : [ "Multiply_55", "Reshape_56" ],
"name" : "Dot_57",
"op" : "Dot",
"outputs" : ["Dot_57_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Dot_57", "Broadcast_58" ],
"name" : "Add_59",
"op" : "Add",
"outputs" : ["Add_59_0"]
},
{
"inputs" : [ "Add_11", "Add_59" ],
"name" : "Add_60",
"op" : "Add",
"outputs" : ["Add_60_0"]
},
{
"inputs" : ["Add_60"],
"lower_bounds" : [ 0, 300 ],
"name" : "Slice_61",
"op" : "Slice",
"outputs" : ["Slice_61_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 10, 400 ]
},
{
"inputs" : ["Add_60"],
"lower_bounds" : [ 0, 100 ],
"name" : "Slice_68",
"op" : "Slice",
"outputs" : ["Slice_68_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 10, 200 ]
},
{
"inputs" : ["Add_60"],
"lower_bounds" : [ 0, 0 ],
"name" : "Slice_76",
"op" : "Slice",
"outputs" : ["Slice_76_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 10, 100 ]
},
{
"inputs" : ["Add_60"],
"lower_bounds" : [ 0, 200 ],
"name" : "Slice_83",
"op" : "Slice",
"outputs" : ["Slice_83_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 10, 300 ]
},
{
"inputs" : ["Slice_61"],
"name" : "Negative_64",
"op" : "Negative",
"outputs" : ["Negative_64_0"]
},
{
"inputs" : ["Slice_68"],
"name" : "Negative_71",
"op" : "Negative",
"outputs" : ["Negative_71_0"]
},
{
"inputs" : ["Slice_76"],
"name" : "Negative_79",
"op" : "Negative",
"outputs" : ["Negative_79_0"]
},
{
"inputs" : ["Slice_83"],
"name" : "Tanh_84",
"op" : "Tanh",
"outputs" : ["Tanh_84_0"]
},
{
"inputs" : ["Negative_64"],
"name" : "Exp_65",
"op" : "Exp",
"outputs" : ["Exp_65_0"]
},
{
"inputs" : ["Negative_71"],
"name" : "Exp_72",
"op" : "Exp",
"outputs" : ["Exp_72_0"]
},
{
"inputs" : ["Negative_79"],
"name" : "Exp_80",
"op" : "Exp",
"outputs" : ["Exp_80_0"]
},
{
"inputs" : [ "Broadcast_63", "Exp_65" ],
"name" : "Add_66",
"op" : "Add",
"outputs" : ["Add_66_0"]
},
{
"inputs" : [ "Broadcast_70", "Exp_72" ],
"name" : "Add_73",
"op" : "Add",
"outputs" : ["Add_73_0"]
},
{
"inputs" : [ "Broadcast_78", "Exp_80" ],
"name" : "Add_81",
"op" : "Add",
"outputs" : ["Add_81_0"]
},
{
"inputs" : [ "Broadcast_63", "Add_66" ],
"name" : "Divide_67",
"op" : "Divide",
"outputs" : ["Divide_67_0"]
},
{
"inputs" : [ "Broadcast_70", "Add_73" ],
"name" : "Divide_74",
"op" : "Divide",
"outputs" : ["Divide_74_0"]
},
{
"inputs" : [ "Broadcast_78", "Add_81" ],
"name" : "Divide_82",
"op" : "Divide",
"outputs" : ["Divide_82_0"]
},
{
"inputs" : [ "Divide_74", "Add_53" ],
"name" : "Multiply_75",
"op" : "Multiply",
"outputs" : ["Multiply_75_0"]
},
{
"inputs" : [ "Divide_82", "Tanh_84" ],
"name" : "Multiply_85",
"op" : "Multiply",
"outputs" : ["Multiply_85_0"]
},
{
"inputs" : [ "Multiply_75", "Multiply_85" ],
"name" : "Add_86",
"op" : "Add",
"outputs" : ["Add_86_0"]
},
{
"inputs" : ["Add_86"],
"name" : "Tanh_87",
"op" : "Tanh",
"outputs" : ["Tanh_87_0"]
},
{
"inputs" : [ "Divide_67", "Tanh_87" ],
"name" : "Multiply_88",
"op" : "Multiply",
"outputs" : ["Multiply_88_0"]
},
{
"inputs" : [ "Multiply_88", "Reshape_89" ],
"name" : "Dot_90",
"op" : "Dot",
"outputs" : ["Dot_90_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Dot_90", "Broadcast_91" ],
"name" : "Add_92",
"op" : "Add",
"outputs" : ["Add_92_0"]
},
{
"inputs" : [ "Add_6", "Add_92" ],
"name" : "Add_93",
"op" : "Add",
"outputs" : ["Add_93_0"]
},
{
"inputs" : ["Add_93"],
"lower_bounds" : [ 0, 300 ],
"name" : "Slice_94",
"op" : "Slice",
"outputs" : ["Slice_94_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 10, 400 ]
},
{
"inputs" : ["Add_93"],
"lower_bounds" : [ 0, 100 ],
"name" : "Slice_101",
"op" : "Slice",
"outputs" : ["Slice_101_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 10, 200 ]
},
{
"inputs" : ["Add_93"],
"lower_bounds" : [ 0, 0 ],
"name" : "Slice_109",
"op" : "Slice",
"outputs" : ["Slice_109_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 10, 100 ]
},
{
"inputs" : ["Add_93"],
"lower_bounds" : [ 0, 200 ],
"name" : "Slice_116",
"op" : "Slice",
"outputs" : ["Slice_116_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 10, 300 ]
},
{
"inputs" : ["Slice_94"],
"name" : "Negative_97",
"op" : "Negative",
"outputs" : ["Negative_97_0"]
},
{
"inputs" : ["Slice_101"],
"name" : "Negative_104",
"op" : "Negative",
"outputs" : ["Negative_104_0"]
},
{
"inputs" : ["Slice_109"],
"name" : "Negative_112",
"op" : "Negative",
"outputs" : ["Negative_112_0"]
},
{
"inputs" : ["Slice_116"],
"name" : "Tanh_117",
"op" : "Tanh",
"outputs" : ["Tanh_117_0"]
},
{
"inputs" : ["Negative_97"],
"name" : "Exp_98",
"op" : "Exp",
"outputs" : ["Exp_98_0"]
},
{
"inputs" : ["Negative_104"],
"name" : "Exp_105",
"op" : "Exp",
"outputs" : ["Exp_105_0"]
},
{
"inputs" : ["Negative_112"],
"name" : "Exp_113",
"op" : "Exp",
"outputs" : ["Exp_113_0"]
},
{
"inputs" : [ "Broadcast_96", "Exp_98" ],
"name" : "Add_99",
"op" : "Add",
"outputs" : ["Add_99_0"]
},
{
"inputs" : [ "Broadcast_103", "Exp_105" ],
"name" : "Add_106",
"op" : "Add",
"outputs" : ["Add_106_0"]
},
{
"inputs" : [ "Broadcast_111", "Exp_113" ],
"name" : "Add_114",
"op" : "Add",
"outputs" : ["Add_114_0"]
},
{
"inputs" : [ "Broadcast_96", "Add_99" ],
"name" : "Divide_100",
"op" : "Divide",
"outputs" : ["Divide_100_0"]
},
{
"inputs" : [ "Broadcast_103", "Add_106" ],
"name" : "Divide_107",
"op" : "Divide",
"outputs" : ["Divide_107_0"]
},
{
"inputs" : [ "Broadcast_111", "Add_114" ],
"name" : "Divide_115",
"op" : "Divide",
"outputs" : ["Divide_115_0"]
},
{
"inputs" : [ "Divide_107", "Add_86" ],
"name" : "Multiply_108",
"op" : "Multiply",
"outputs" : ["Multiply_108_0"]
},
{
"inputs" : [ "Divide_115", "Tanh_117" ],
"name" : "Multiply_118",
"op" : "Multiply",
"outputs" : ["Multiply_118_0"]
},
{
"inputs" : [ "Multiply_108", "Multiply_118" ],
"name" : "Add_119",
"op" : "Add",
"outputs" : ["Add_119_0"]
},
{
"inputs" : ["Add_119"],
"name" : "Tanh_120",
"op" : "Tanh",
"outputs" : ["Tanh_120_0"]
},
{
"inputs" : [ "Divide_100", "Tanh_120" ],
"name" : "Multiply_121",
"op" : "Multiply",
"outputs" : ["Multiply_121_0"]
},
{
"inputs" : ["Multiply_121"],
"name" : "Result_122",
"op" : "Result",
"outputs" : ["Result_122_0"]
}
],
"parameters" : [
"Parameter_12", "Parameter_1", "Parameter_2", "Parameter_19",
"Parameter_20", "Parameter_7", "Parameter_0"
],
"result" : ["Result_122"]
}]
[{
"name" : "Function_0",
"ops" : [
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_170",
"op" : "Parameter",
"outputs" : ["Parameter_170_0"],
"shape" : [ 32, 1, 200 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_93",
"op" : "Parameter",
"outputs" : ["Parameter_93_0"],
"shape" : [ 32, 1, 200 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_56",
"op" : "Parameter",
"outputs" : ["Parameter_56_0"],
"shape" : [400]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_55",
"op" : "Parameter",
"outputs" : ["Parameter_55_0"],
"shape" : [ 400, 100 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_48",
"op" : "Parameter",
"outputs" : ["Parameter_48_0"],
"shape" : [400]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_47",
"op" : "Parameter",
"outputs" : ["Parameter_47_0"],
"shape" : [ 400, 100 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_11",
"op" : "Parameter",
"outputs" : ["Parameter_11_0"],
"shape" : [400]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_10",
"op" : "Parameter",
"outputs" : ["Parameter_10_0"],
"shape" : [ 400, 100 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_2",
"op" : "Parameter",
"outputs" : ["Parameter_2_0"],
"shape" : [400]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_1",
"op" : "Parameter",
"outputs" : ["Parameter_1_0"],
"shape" : [ 400, 200 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_0",
"op" : "Parameter",
"outputs" : ["Parameter_0_0"],
"shape" : [ 32, 1, 200 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_53",
"op" : "Constant",
"outputs" : ["Constant_53_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_8",
"op" : "Constant",
"outputs" : ["Constant_8_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_35",
"op" : "Constant",
"outputs" : ["Constant_35_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_31",
"op" : "Constant",
"outputs" : ["Constant_31_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_25",
"op" : "Constant",
"outputs" : ["Constant_25_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_18",
"op" : "Constant",
"outputs" : ["Constant_18_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_80",
"op" : "Constant",
"outputs" : ["Constant_80_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_76",
"op" : "Constant",
"outputs" : ["Constant_76_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_70",
"op" : "Constant",
"outputs" : ["Constant_70_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_63",
"op" : "Constant",
"outputs" : ["Constant_63_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_120",
"op" : "Constant",
"outputs" : ["Constant_120_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_112",
"op" : "Constant",
"outputs" : ["Constant_112_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_105",
"op" : "Constant",
"outputs" : ["Constant_105_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_157",
"op" : "Constant",
"outputs" : ["Constant_157_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_149",
"op" : "Constant",
"outputs" : ["Constant_149_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_142",
"op" : "Constant",
"outputs" : ["Constant_142_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_197",
"op" : "Constant",
"outputs" : ["Constant_197_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_189",
"op" : "Constant",
"outputs" : ["Constant_189_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_182",
"op" : "Constant",
"outputs" : ["Constant_182_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_234",
"op" : "Constant",
"outputs" : ["Constant_234_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_226",
"op" : "Constant",
"outputs" : ["Constant_226_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_219",
"op" : "Constant",
"outputs" : ["Constant_219_0"],
"shape" : [],
"value" : ["1"]
},
{
"input_order" : [ 0, 1, 2 ],
"inputs" : ["Parameter_170"],
"name" : "Reshape_171",
"op" : "Reshape",
"output_shape" : [ 32, 200 ],
"outputs" : ["Reshape_171_0"]
},
{
"input_order" : [ 0, 1, 2 ],
"inputs" : ["Parameter_93"],
"name" : "Reshape_94",
"op" : "Reshape",
"output_shape" : [ 32, 200 ],
"outputs" : ["Reshape_94_0"]
},
{
"axes" : [0],
"inputs" : ["Parameter_56"],
"name" : "Broadcast_59",
"op" : "Broadcast",
"outputs" : ["Broadcast_59_0"],
"shape" : [ 32, 400 ]
},
{
"axes" : [0],
"inputs" : ["Parameter_56"],
"name" : "Broadcast_138",
"op" : "Broadcast",
"outputs" : ["Broadcast_138_0"],
"shape" : [ 32, 400 ]
},
{
"axes" : [0],
"inputs" : ["Parameter_56"],
"name" : "Broadcast_215",
"op" : "Broadcast",
"outputs" : ["Broadcast_215_0"],
"shape" : [ 32, 400 ]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_55"],
"name" : "Reshape_57",
"op" : "Reshape",
"output_shape" : [ 100, 400 ],
"outputs" : ["Reshape_57_0"]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_55"],
"name" : "Reshape_136",
"op" : "Reshape",
"output_shape" : [ 100, 400 ],
"outputs" : ["Reshape_136_0"]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_55"],
"name" : "Reshape_213",
"op" : "Reshape",
"output_shape" : [ 100, 400 ],
"outputs" : ["Reshape_213_0"]
},
{
"axes" : [0],
"inputs" : ["Parameter_48"],
"name" : "Broadcast_51",
"op" : "Broadcast",
"outputs" : ["Broadcast_51_0"],
"shape" : [ 32, 400 ]
},
{
"axes" : [0],
"inputs" : ["Parameter_48"],
"name" : "Broadcast_134",
"op" : "Broadcast",
"outputs" : ["Broadcast_134_0"],
"shape" : [ 32, 400 ]
},
{
"axes" : [0],
"inputs" : ["Parameter_48"],
"name" : "Broadcast_211",
"op" : "Broadcast",
"outputs" : ["Broadcast_211_0"],
"shape" : [ 32, 400 ]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_47"],
"name" : "Reshape_49",
"op" : "Reshape",
"output_shape" : [ 100, 400 ],
"outputs" : ["Reshape_49_0"]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_47"],
"name" : "Reshape_132",
"op" : "Reshape",
"output_shape" : [ 100, 400 ],
"outputs" : ["Reshape_132_0"]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_47"],
"name" : "Reshape_209",
"op" : "Reshape",
"output_shape" : [ 100, 400 ],
"outputs" : ["Reshape_209_0"]
},
{
"axes" : [0],
"inputs" : ["Parameter_11"],
"name" : "Broadcast_14",
"op" : "Broadcast",
"outputs" : ["Broadcast_14_0"],
"shape" : [ 32, 400 ]
},
{
"axes" : [0],
"inputs" : ["Parameter_11"],
"name" : "Broadcast_101",
"op" : "Broadcast",
"outputs" : ["Broadcast_101_0"],
"shape" : [ 32, 400 ]
},
{
"axes" : [0],
"inputs" : ["Parameter_11"],
"name" : "Broadcast_178",
"op" : "Broadcast",
"outputs" : ["Broadcast_178_0"],
"shape" : [ 32, 400 ]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_10"],
"name" : "Reshape_12",
"op" : "Reshape",
"output_shape" : [ 100, 400 ],
"outputs" : ["Reshape_12_0"]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_10"],
"name" : "Reshape_99",
"op" : "Reshape",
"output_shape" : [ 100, 400 ],
"outputs" : ["Reshape_99_0"]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_10"],
"name" : "Reshape_176",
"op" : "Reshape",
"output_shape" : [ 100, 400 ],
"outputs" : ["Reshape_176_0"]
},
{
"axes" : [0],
"inputs" : ["Parameter_2"],
"name" : "Broadcast_6",
"op" : "Broadcast",
"outputs" : ["Broadcast_6_0"],
"shape" : [ 32, 400 ]
},
{
"axes" : [0],
"inputs" : ["Parameter_2"],
"name" : "Broadcast_97",
"op" : "Broadcast",
"outputs" : ["Broadcast_97_0"],
"shape" : [ 32, 400 ]
},
{
"axes" : [0],
"inputs" : ["Parameter_2"],
"name" : "Broadcast_174",
"op" : "Broadcast",
"outputs" : ["Broadcast_174_0"],
"shape" : [ 32, 400 ]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_1"],
"name" : "Reshape_4",
"op" : "Reshape",
"output_shape" : [ 200, 400 ],
"outputs" : ["Reshape_4_0"]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_1"],
"name" : "Reshape_95",
"op" : "Reshape",
"output_shape" : [ 200, 400 ],
"outputs" : ["Reshape_95_0"]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_1"],
"name" : "Reshape_172",
"op" : "Reshape",
"output_shape" : [ 200, 400 ],
"outputs" : ["Reshape_172_0"]
},
{
"input_order" : [ 0, 1, 2 ],
"inputs" : ["Parameter_0"],
"name" : "Reshape_3",
"op" : "Reshape",
"output_shape" : [ 32, 200 ],
"outputs" : ["Reshape_3_0"]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_53"],
"name" : "Broadcast_54",
"op" : "Broadcast",
"outputs" : ["Broadcast_54_0"],
"shape" : [ 32, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_8"],
"name" : "Broadcast_9",
"op" : "Broadcast",
"outputs" : ["Broadcast_9_0"],
"shape" : [ 32, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_35"],
"name" : "Broadcast_36",
"op" : "Broadcast",
"outputs" : ["Broadcast_36_0"],
"shape" : [ 32, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_31"],
"name" : "Broadcast_32",
"op" : "Broadcast",
"outputs" : ["Broadcast_32_0"],
"shape" : [ 32, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_25"],
"name" : "Broadcast_26",
"op" : "Broadcast",
"outputs" : ["Broadcast_26_0"],
"shape" : [ 32, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_18"],
"name" : "Broadcast_19",
"op" : "Broadcast",
"outputs" : ["Broadcast_19_0"],
"shape" : [ 32, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_80"],
"name" : "Broadcast_81",
"op" : "Broadcast",
"outputs" : ["Broadcast_81_0"],
"shape" : [ 32, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_76"],
"name" : "Broadcast_77",
"op" : "Broadcast",
"outputs" : ["Broadcast_77_0"],
"shape" : [ 32, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_70"],
"name" : "Broadcast_71",
"op" : "Broadcast",
"outputs" : ["Broadcast_71_0"],
"shape" : [ 32, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_63"],
"name" : "Broadcast_64",
"op" : "Broadcast",
"outputs" : ["Broadcast_64_0"],
"shape" : [ 32, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_120"],
"name" : "Broadcast_121",
"op" : "Broadcast",
"outputs" : ["Broadcast_121_0"],
"shape" : [ 32, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_112"],
"name" : "Broadcast_113",
"op" : "Broadcast",
"outputs" : ["Broadcast_113_0"],
"shape" : [ 32, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_105"],
"name" : "Broadcast_106",
"op" : "Broadcast",
"outputs" : ["Broadcast_106_0"],
"shape" : [ 32, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_157"],
"name" : "Broadcast_158",
"op" : "Broadcast",
"outputs" : ["Broadcast_158_0"],
"shape" : [ 32, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_149"],
"name" : "Broadcast_150",
"op" : "Broadcast",
"outputs" : ["Broadcast_150_0"],
"shape" : [ 32, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_142"],
"name" : "Broadcast_143",
"op" : "Broadcast",
"outputs" : ["Broadcast_143_0"],
"shape" : [ 32, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_197"],
"name" : "Broadcast_198",
"op" : "Broadcast",
"outputs" : ["Broadcast_198_0"],
"shape" : [ 32, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_189"],
"name" : "Broadcast_190",
"op" : "Broadcast",
"outputs" : ["Broadcast_190_0"],
"shape" : [ 32, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_182"],
"name" : "Broadcast_183",
"op" : "Broadcast",
"outputs" : ["Broadcast_183_0"],
"shape" : [ 32, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_234"],
"name" : "Broadcast_235",
"op" : "Broadcast",
"outputs" : ["Broadcast_235_0"],
"shape" : [ 32, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_226"],
"name" : "Broadcast_227",
"op" : "Broadcast",
"outputs" : ["Broadcast_227_0"],
"shape" : [ 32, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_219"],
"name" : "Broadcast_220",
"op" : "Broadcast",
"outputs" : ["Broadcast_220_0"],
"shape" : [ 32, 100 ]
},
{
"inputs" : [ "Reshape_94", "Reshape_95" ],
"name" : "Dot_96",
"op" : "Dot",
"outputs" : ["Dot_96_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Reshape_171", "Reshape_172" ],
"name" : "Dot_173",
"op" : "Dot",
"outputs" : ["Dot_173_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Reshape_3", "Reshape_4" ],
"name" : "Dot_5",
"op" : "Dot",
"outputs" : ["Dot_5_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Broadcast_54", "Reshape_57" ],
"name" : "Dot_58",
"op" : "Dot",
"outputs" : ["Dot_58_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Broadcast_9", "Reshape_12" ],
"name" : "Dot_13",
"op" : "Dot",
"outputs" : ["Dot_13_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Dot_96", "Broadcast_97" ],
"name" : "Add_98",
"op" : "Add",
"outputs" : ["Add_98_0"]
},
{
"inputs" : [ "Dot_173", "Broadcast_174" ],
"name" : "Add_175",
"op" : "Add",
"outputs" : ["Add_175_0"]
},
{
"inputs" : [ "Dot_5", "Broadcast_6" ],
"name" : "Add_7",
"op" : "Add",
"outputs" : ["Add_7_0"]
},
{
"inputs" : [ "Dot_58", "Broadcast_59" ],
"name" : "Add_60",
"op" : "Add",
"outputs" : ["Add_60_0"]
},
{
"inputs" : [ "Dot_13", "Broadcast_14" ],
"name" : "Add_15",
"op" : "Add",
"outputs" : ["Add_15_0"]
},
{
"inputs" : [ "Add_7", "Add_15" ],
"name" : "Add_16",
"op" : "Add",
"outputs" : ["Add_16_0"]
},
{
"inputs" : ["Add_16"],
"lower_bounds" : [ 0, 300 ],
"name" : "Slice_17",
"op" : "Slice",
"outputs" : ["Slice_17_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 400 ]
},
{
"inputs" : ["Add_16"],
"lower_bounds" : [ 0, 100 ],
"name" : "Slice_24",
"op" : "Slice",
"outputs" : ["Slice_24_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 200 ]
},
{
"inputs" : ["Add_16"],
"lower_bounds" : [ 0, 0 ],
"name" : "Slice_34",
"op" : "Slice",
"outputs" : ["Slice_34_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 100 ]
},
{
"inputs" : ["Add_16"],
"lower_bounds" : [ 0, 200 ],
"name" : "Slice_41",
"op" : "Slice",
"outputs" : ["Slice_41_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 300 ]
},
{
"inputs" : ["Slice_17"],
"name" : "Negative_20",
"op" : "Negative",
"outputs" : ["Negative_20_0"]
},
{
"inputs" : ["Slice_24"],
"name" : "Negative_27",
"op" : "Negative",
"outputs" : ["Negative_27_0"]
},
{
"inputs" : ["Slice_34"],
"name" : "Negative_37",
"op" : "Negative",
"outputs" : ["Negative_37_0"]
},
{
"inputs" : ["Slice_41"],
"name" : "Tanh_42",
"op" : "Tanh",
"outputs" : ["Tanh_42_0"]
},
{
"inputs" : ["Negative_20"],
"name" : "Exp_21",
"op" : "Exp",
"outputs" : ["Exp_21_0"]
},
{
"inputs" : ["Negative_27"],
"name" : "Exp_28",
"op" : "Exp",
"outputs" : ["Exp_28_0"]
},
{
"inputs" : ["Negative_37"],
"name" : "Exp_38",
"op" : "Exp",
"outputs" : ["Exp_38_0"]
},
{
"inputs" : [ "Broadcast_19", "Exp_21" ],
"name" : "Add_22",
"op" : "Add",
"outputs" : ["Add_22_0"]
},
{
"inputs" : [ "Broadcast_26", "Exp_28" ],
"name" : "Add_29",
"op" : "Add",
"outputs" : ["Add_29_0"]
},
{
"inputs" : [ "Broadcast_36", "Exp_38" ],
"name" : "Add_39",
"op" : "Add",
"outputs" : ["Add_39_0"]
},
{
"inputs" : [ "Broadcast_19", "Add_22" ],
"name" : "Divide_23",
"op" : "Divide",
"outputs" : ["Divide_23_0"]
},
{
"inputs" : [ "Broadcast_26", "Add_29" ],
"name" : "Divide_30",
"op" : "Divide",
"outputs" : ["Divide_30_0"]
},
{
"inputs" : [ "Broadcast_36", "Add_39" ],
"name" : "Divide_40",
"op" : "Divide",
"outputs" : ["Divide_40_0"]
},
{
"inputs" : [ "Divide_30", "Broadcast_32" ],
"name" : "Multiply_33",
"op" : "Multiply",
"outputs" : ["Multiply_33_0"]
},
{
"inputs" : [ "Divide_40", "Tanh_42" ],
"name" : "Multiply_43",
"op" : "Multiply",
"outputs" : ["Multiply_43_0"]
},
{
"inputs" : [ "Multiply_33", "Multiply_43" ],
"name" : "Add_44",
"op" : "Add",
"outputs" : ["Add_44_0"]
},
{
"inputs" : ["Add_44"],
"name" : "Tanh_45",
"op" : "Tanh",
"outputs" : ["Tanh_45_0"]
},
{
"inputs" : [ "Divide_23", "Tanh_45" ],
"name" : "Multiply_46",
"op" : "Multiply",
"outputs" : ["Multiply_46_0"]
},
{
"inputs" : [ "Multiply_46", "Reshape_49" ],
"name" : "Dot_50",
"op" : "Dot",
"outputs" : ["Dot_50_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Multiply_46", "Reshape_99" ],
"name" : "Dot_100",
"op" : "Dot",
"outputs" : ["Dot_100_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Dot_50", "Broadcast_51" ],
"name" : "Add_52",
"op" : "Add",
"outputs" : ["Add_52_0"]
},
{
"inputs" : [ "Dot_100", "Broadcast_101" ],
"name" : "Add_102",
"op" : "Add",
"outputs" : ["Add_102_0"]
},
{
"inputs" : [ "Add_52", "Add_60" ],
"name" : "Add_61",
"op" : "Add",
"outputs" : ["Add_61_0"]
},
{
"inputs" : [ "Add_98", "Add_102" ],
"name" : "Add_103",
"op" : "Add",
"outputs" : ["Add_103_0"]
},
{
"inputs" : ["Add_61"],
"lower_bounds" : [ 0, 300 ],
"name" : "Slice_62",
"op" : "Slice",
"outputs" : ["Slice_62_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 400 ]
},
{
"inputs" : ["Add_61"],
"lower_bounds" : [ 0, 100 ],
"name" : "Slice_69",
"op" : "Slice",
"outputs" : ["Slice_69_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 200 ]
},
{
"inputs" : ["Add_61"],
"lower_bounds" : [ 0, 0 ],
"name" : "Slice_79",
"op" : "Slice",
"outputs" : ["Slice_79_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 100 ]
},
{
"inputs" : ["Add_61"],
"lower_bounds" : [ 0, 200 ],
"name" : "Slice_86",
"op" : "Slice",
"outputs" : ["Slice_86_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 300 ]
},
{
"inputs" : ["Add_103"],
"lower_bounds" : [ 0, 300 ],
"name" : "Slice_104",
"op" : "Slice",
"outputs" : ["Slice_104_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 400 ]
},
{
"inputs" : ["Add_103"],
"lower_bounds" : [ 0, 100 ],
"name" : "Slice_111",
"op" : "Slice",
"outputs" : ["Slice_111_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 200 ]
},
{
"inputs" : ["Add_103"],
"lower_bounds" : [ 0, 0 ],
"name" : "Slice_119",
"op" : "Slice",
"outputs" : ["Slice_119_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 100 ]
},
{
"inputs" : ["Add_103"],
"lower_bounds" : [ 0, 200 ],
"name" : "Slice_126",
"op" : "Slice",
"outputs" : ["Slice_126_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 300 ]
},
{
"inputs" : ["Slice_62"],
"name" : "Negative_65",
"op" : "Negative",
"outputs" : ["Negative_65_0"]
},
{
"inputs" : ["Slice_69"],
"name" : "Negative_72",
"op" : "Negative",
"outputs" : ["Negative_72_0"]
},
{
"inputs" : ["Slice_79"],
"name" : "Negative_82",
"op" : "Negative",
"outputs" : ["Negative_82_0"]
},
{
"inputs" : ["Slice_86"],
"name" : "Tanh_87",
"op" : "Tanh",
"outputs" : ["Tanh_87_0"]
},
{
"inputs" : ["Slice_104"],
"name" : "Negative_107",
"op" : "Negative",
"outputs" : ["Negative_107_0"]
},
{
"inputs" : ["Slice_111"],
"name" : "Negative_114",
"op" : "Negative",
"outputs" : ["Negative_114_0"]
},
{
"inputs" : ["Slice_119"],
"name" : "Negative_122",
"op" : "Negative",
"outputs" : ["Negative_122_0"]
},
{
"inputs" : ["Slice_126"],
"name" : "Tanh_127",
"op" : "Tanh",
"outputs" : ["Tanh_127_0"]
},
{
"inputs" : ["Negative_65"],
"name" : "Exp_66",
"op" : "Exp",
"outputs" : ["Exp_66_0"]
},
{
"inputs" : ["Negative_72"],
"name" : "Exp_73",
"op" : "Exp",
"outputs" : ["Exp_73_0"]
},
{
"inputs" : ["Negative_82"],
"name" : "Exp_83",
"op" : "Exp",
"outputs" : ["Exp_83_0"]
},
{
"inputs" : ["Negative_107"],
"name" : "Exp_108",
"op" : "Exp",
"outputs" : ["Exp_108_0"]
},
{
"inputs" : ["Negative_114"],
"name" : "Exp_115",
"op" : "Exp",
"outputs" : ["Exp_115_0"]
},
{
"inputs" : ["Negative_122"],
"name" : "Exp_123",
"op" : "Exp",
"outputs" : ["Exp_123_0"]
},
{
"inputs" : [ "Broadcast_64", "Exp_66" ],
"name" : "Add_67",
"op" : "Add",
"outputs" : ["Add_67_0"]
},
{
"inputs" : [ "Broadcast_71", "Exp_73" ],
"name" : "Add_74",
"op" : "Add",
"outputs" : ["Add_74_0"]
},
{
"inputs" : [ "Broadcast_81", "Exp_83" ],
"name" : "Add_84",
"op" : "Add",
"outputs" : ["Add_84_0"]
},
{
"inputs" : [ "Broadcast_106", "Exp_108" ],
"name" : "Add_109",
"op" : "Add",
"outputs" : ["Add_109_0"]
},
{
"inputs" : [ "Broadcast_113", "Exp_115" ],
"name" : "Add_116",
"op" : "Add",
"outputs" : ["Add_116_0"]
},
{
"inputs" : [ "Broadcast_121", "Exp_123" ],
"name" : "Add_124",
"op" : "Add",
"outputs" : ["Add_124_0"]
},
{
"inputs" : [ "Broadcast_64", "Add_67" ],
"name" : "Divide_68",
"op" : "Divide",
"outputs" : ["Divide_68_0"]
},
{
"inputs" : [ "Broadcast_71", "Add_74" ],
"name" : "Divide_75",
"op" : "Divide",
"outputs" : ["Divide_75_0"]
},
{
"inputs" : [ "Broadcast_81", "Add_84" ],
"name" : "Divide_85",
"op" : "Divide",
"outputs" : ["Divide_85_0"]
},
{
"inputs" : [ "Broadcast_106", "Add_109" ],
"name" : "Divide_110",
"op" : "Divide",
"outputs" : ["Divide_110_0"]
},
{
"inputs" : [ "Broadcast_113", "Add_116" ],
"name" : "Divide_117",
"op" : "Divide",
"outputs" : ["Divide_117_0"]
},
{
"inputs" : [ "Broadcast_121", "Add_124" ],
"name" : "Divide_125",
"op" : "Divide",
"outputs" : ["Divide_125_0"]
},
{
"inputs" : [ "Divide_75", "Broadcast_77" ],
"name" : "Multiply_78",
"op" : "Multiply",
"outputs" : ["Multiply_78_0"]
},
{
"inputs" : [ "Divide_85", "Tanh_87" ],
"name" : "Multiply_88",
"op" : "Multiply",
"outputs" : ["Multiply_88_0"]
},
{
"inputs" : [ "Divide_117", "Add_44" ],
"name" : "Multiply_118",
"op" : "Multiply",
"outputs" : ["Multiply_118_0"]
},
{
"inputs" : [ "Divide_125", "Tanh_127" ],
"name" : "Multiply_128",
"op" : "Multiply",
"outputs" : ["Multiply_128_0"]
},
{
"inputs" : [ "Multiply_78", "Multiply_88" ],
"name" : "Add_89",
"op" : "Add",
"outputs" : ["Add_89_0"]
},
{
"inputs" : [ "Multiply_118", "Multiply_128" ],
"name" : "Add_129",
"op" : "Add",
"outputs" : ["Add_129_0"]
},
{
"inputs" : ["Add_89"],
"name" : "Tanh_90",
"op" : "Tanh",
"outputs" : ["Tanh_90_0"]
},
{
"inputs" : ["Add_129"],
"name" : "Tanh_130",
"op" : "Tanh",
"outputs" : ["Tanh_130_0"]
},
{
"inputs" : [ "Divide_68", "Tanh_90" ],
"name" : "Multiply_91",
"op" : "Multiply",
"outputs" : ["Multiply_91_0"]
},
{
"inputs" : [ "Divide_110", "Tanh_130" ],
"name" : "Multiply_131",
"op" : "Multiply",
"outputs" : ["Multiply_131_0"]
},
{
"input_order" : [ 0, 1 ],
"inputs" : ["Multiply_91"],
"name" : "Reshape_92",
"op" : "Reshape",
"output_shape" : [ 32, 1, 100 ],
"outputs" : ["Reshape_92_0"]
},
{
"inputs" : [ "Multiply_91", "Reshape_136" ],
"name" : "Dot_137",
"op" : "Dot",
"outputs" : ["Dot_137_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Multiply_131", "Reshape_132" ],
"name" : "Dot_133",
"op" : "Dot",
"outputs" : ["Dot_133_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Multiply_131", "Reshape_176" ],
"name" : "Dot_177",
"op" : "Dot",
"outputs" : ["Dot_177_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Dot_137", "Broadcast_138" ],
"name" : "Add_139",
"op" : "Add",
"outputs" : ["Add_139_0"]
},
{
"inputs" : [ "Dot_133", "Broadcast_134" ],
"name" : "Add_135",
"op" : "Add",
"outputs" : ["Add_135_0"]
},
{
"inputs" : [ "Dot_177", "Broadcast_178" ],
"name" : "Add_179",
"op" : "Add",
"outputs" : ["Add_179_0"]
},
{
"inputs" : [ "Add_135", "Add_139" ],
"name" : "Add_140",
"op" : "Add",
"outputs" : ["Add_140_0"]
},
{
"inputs" : [ "Add_175", "Add_179" ],
"name" : "Add_180",
"op" : "Add",
"outputs" : ["Add_180_0"]
},
{
"inputs" : ["Add_140"],
"lower_bounds" : [ 0, 300 ],
"name" : "Slice_141",
"op" : "Slice",
"outputs" : ["Slice_141_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 400 ]
},
{
"inputs" : ["Add_140"],
"lower_bounds" : [ 0, 100 ],
"name" : "Slice_148",
"op" : "Slice",
"outputs" : ["Slice_148_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 200 ]
},
{
"inputs" : ["Add_140"],
"lower_bounds" : [ 0, 0 ],
"name" : "Slice_156",
"op" : "Slice",
"outputs" : ["Slice_156_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 100 ]
},
{
"inputs" : ["Add_140"],
"lower_bounds" : [ 0, 200 ],
"name" : "Slice_163",
"op" : "Slice",
"outputs" : ["Slice_163_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 300 ]
},
{
"inputs" : ["Add_180"],
"lower_bounds" : [ 0, 300 ],
"name" : "Slice_181",
"op" : "Slice",
"outputs" : ["Slice_181_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 400 ]
},
{
"inputs" : ["Add_180"],
"lower_bounds" : [ 0, 100 ],
"name" : "Slice_188",
"op" : "Slice",
"outputs" : ["Slice_188_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 200 ]
},
{
"inputs" : ["Add_180"],
"lower_bounds" : [ 0, 0 ],
"name" : "Slice_196",
"op" : "Slice",
"outputs" : ["Slice_196_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 100 ]
},
{
"inputs" : ["Add_180"],
"lower_bounds" : [ 0, 200 ],
"name" : "Slice_203",
"op" : "Slice",
"outputs" : ["Slice_203_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 300 ]
},
{
"inputs" : ["Slice_141"],
"name" : "Negative_144",
"op" : "Negative",
"outputs" : ["Negative_144_0"]
},
{
"inputs" : ["Slice_148"],
"name" : "Negative_151",
"op" : "Negative",
"outputs" : ["Negative_151_0"]
},
{
"inputs" : ["Slice_156"],
"name" : "Negative_159",
"op" : "Negative",
"outputs" : ["Negative_159_0"]
},
{
"inputs" : ["Slice_163"],
"name" : "Tanh_164",
"op" : "Tanh",
"outputs" : ["Tanh_164_0"]
},
{
"inputs" : ["Slice_181"],
"name" : "Negative_184",
"op" : "Negative",
"outputs" : ["Negative_184_0"]
},
{
"inputs" : ["Slice_188"],
"name" : "Negative_191",
"op" : "Negative",
"outputs" : ["Negative_191_0"]
},
{
"inputs" : ["Slice_196"],
"name" : "Negative_199",
"op" : "Negative",
"outputs" : ["Negative_199_0"]
},
{
"inputs" : ["Slice_203"],
"name" : "Tanh_204",
"op" : "Tanh",
"outputs" : ["Tanh_204_0"]
},
{
"inputs" : ["Negative_144"],
"name" : "Exp_145",
"op" : "Exp",
"outputs" : ["Exp_145_0"]
},
{
"inputs" : ["Negative_151"],
"name" : "Exp_152",
"op" : "Exp",
"outputs" : ["Exp_152_0"]
},
{
"inputs" : ["Negative_159"],
"name" : "Exp_160",
"op" : "Exp",
"outputs" : ["Exp_160_0"]
},
{
"inputs" : ["Negative_184"],
"name" : "Exp_185",
"op" : "Exp",
"outputs" : ["Exp_185_0"]
},
{
"inputs" : ["Negative_191"],
"name" : "Exp_192",
"op" : "Exp",
"outputs" : ["Exp_192_0"]
},
{
"inputs" : ["Negative_199"],
"name" : "Exp_200",
"op" : "Exp",
"outputs" : ["Exp_200_0"]
},
{
"inputs" : [ "Broadcast_143", "Exp_145" ],
"name" : "Add_146",
"op" : "Add",
"outputs" : ["Add_146_0"]
},
{
"inputs" : [ "Broadcast_150", "Exp_152" ],
"name" : "Add_153",
"op" : "Add",
"outputs" : ["Add_153_0"]
},
{
"inputs" : [ "Broadcast_158", "Exp_160" ],
"name" : "Add_161",
"op" : "Add",
"outputs" : ["Add_161_0"]
},
{
"inputs" : [ "Broadcast_183", "Exp_185" ],
"name" : "Add_186",
"op" : "Add",
"outputs" : ["Add_186_0"]
},
{
"inputs" : [ "Broadcast_190", "Exp_192" ],
"name" : "Add_193",
"op" : "Add",
"outputs" : ["Add_193_0"]
},
{
"inputs" : [ "Broadcast_198", "Exp_200" ],
"name" : "Add_201",
"op" : "Add",
"outputs" : ["Add_201_0"]
},
{
"inputs" : [ "Broadcast_143", "Add_146" ],
"name" : "Divide_147",
"op" : "Divide",
"outputs" : ["Divide_147_0"]
},
{
"inputs" : [ "Broadcast_150", "Add_153" ],
"name" : "Divide_154",
"op" : "Divide",
"outputs" : ["Divide_154_0"]
},
{
"inputs" : [ "Broadcast_158", "Add_161" ],
"name" : "Divide_162",
"op" : "Divide",
"outputs" : ["Divide_162_0"]
},
{
"inputs" : [ "Broadcast_183", "Add_186" ],
"name" : "Divide_187",
"op" : "Divide",
"outputs" : ["Divide_187_0"]
},
{
"inputs" : [ "Broadcast_190", "Add_193" ],
"name" : "Divide_194",
"op" : "Divide",
"outputs" : ["Divide_194_0"]
},
{
"inputs" : [ "Broadcast_198", "Add_201" ],
"name" : "Divide_202",
"op" : "Divide",
"outputs" : ["Divide_202_0"]
},
{
"inputs" : [ "Divide_154", "Add_89" ],
"name" : "Multiply_155",
"op" : "Multiply",
"outputs" : ["Multiply_155_0"]
},
{
"inputs" : [ "Divide_162", "Tanh_164" ],
"name" : "Multiply_165",
"op" : "Multiply",
"outputs" : ["Multiply_165_0"]
},
{
"inputs" : [ "Divide_194", "Add_129" ],
"name" : "Multiply_195",
"op" : "Multiply",
"outputs" : ["Multiply_195_0"]
},
{
"inputs" : [ "Divide_202", "Tanh_204" ],
"name" : "Multiply_205",
"op" : "Multiply",
"outputs" : ["Multiply_205_0"]
},
{
"inputs" : [ "Multiply_155", "Multiply_165" ],
"name" : "Add_166",
"op" : "Add",
"outputs" : ["Add_166_0"]
},
{
"inputs" : [ "Multiply_195", "Multiply_205" ],
"name" : "Add_206",
"op" : "Add",
"outputs" : ["Add_206_0"]
},
{
"inputs" : ["Add_166"],
"name" : "Tanh_167",
"op" : "Tanh",
"outputs" : ["Tanh_167_0"]
},
{
"inputs" : ["Add_206"],
"name" : "Tanh_207",
"op" : "Tanh",
"outputs" : ["Tanh_207_0"]
},
{
"inputs" : [ "Divide_147", "Tanh_167" ],
"name" : "Multiply_168",
"op" : "Multiply",
"outputs" : ["Multiply_168_0"]
},
{
"inputs" : [ "Divide_187", "Tanh_207" ],
"name" : "Multiply_208",
"op" : "Multiply",
"outputs" : ["Multiply_208_0"]
},
{
"input_order" : [ 0, 1 ],
"inputs" : ["Multiply_168"],
"name" : "Reshape_169",
"op" : "Reshape",
"output_shape" : [ 32, 1, 100 ],
"outputs" : ["Reshape_169_0"]
},
{
"inputs" : [ "Multiply_168", "Reshape_213" ],
"name" : "Dot_214",
"op" : "Dot",
"outputs" : ["Dot_214_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Multiply_208", "Reshape_209" ],
"name" : "Dot_210",
"op" : "Dot",
"outputs" : ["Dot_210_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Dot_214", "Broadcast_215" ],
"name" : "Add_216",
"op" : "Add",
"outputs" : ["Add_216_0"]
},
{
"inputs" : [ "Dot_210", "Broadcast_211" ],
"name" : "Add_212",
"op" : "Add",
"outputs" : ["Add_212_0"]
},
{
"inputs" : [ "Add_212", "Add_216" ],
"name" : "Add_217",
"op" : "Add",
"outputs" : ["Add_217_0"]
},
{
"inputs" : ["Add_217"],
"lower_bounds" : [ 0, 300 ],
"name" : "Slice_218",
"op" : "Slice",
"outputs" : ["Slice_218_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 400 ]
},
{
"inputs" : ["Add_217"],
"lower_bounds" : [ 0, 100 ],
"name" : "Slice_225",
"op" : "Slice",
"outputs" : ["Slice_225_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 200 ]
},
{
"inputs" : ["Add_217"],
"lower_bounds" : [ 0, 0 ],
"name" : "Slice_233",
"op" : "Slice",
"outputs" : ["Slice_233_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 100 ]
},
{
"inputs" : ["Add_217"],
"lower_bounds" : [ 0, 200 ],
"name" : "Slice_240",
"op" : "Slice",
"outputs" : ["Slice_240_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 32, 300 ]
},
{
"inputs" : ["Slice_218"],
"name" : "Negative_221",
"op" : "Negative",
"outputs" : ["Negative_221_0"]
},
{
"inputs" : ["Slice_225"],
"name" : "Negative_228",
"op" : "Negative",
"outputs" : ["Negative_228_0"]
},
{
"inputs" : ["Slice_233"],
"name" : "Negative_236",
"op" : "Negative",
"outputs" : ["Negative_236_0"]
},
{
"inputs" : ["Slice_240"],
"name" : "Tanh_241",
"op" : "Tanh",
"outputs" : ["Tanh_241_0"]
},
{
"inputs" : ["Negative_221"],
"name" : "Exp_222",
"op" : "Exp",
"outputs" : ["Exp_222_0"]
},
{
"inputs" : ["Negative_228"],
"name" : "Exp_229",
"op" : "Exp",
"outputs" : ["Exp_229_0"]
},
{
"inputs" : ["Negative_236"],
"name" : "Exp_237",
"op" : "Exp",
"outputs" : ["Exp_237_0"]
},
{
"inputs" : [ "Broadcast_220", "Exp_222" ],
"name" : "Add_223",
"op" : "Add",
"outputs" : ["Add_223_0"]
},
{
"inputs" : [ "Broadcast_227", "Exp_229" ],
"name" : "Add_230",
"op" : "Add",
"outputs" : ["Add_230_0"]
},
{
"inputs" : [ "Broadcast_235", "Exp_237" ],
"name" : "Add_238",
"op" : "Add",
"outputs" : ["Add_238_0"]
},
{
"inputs" : [ "Broadcast_220", "Add_223" ],
"name" : "Divide_224",
"op" : "Divide",
"outputs" : ["Divide_224_0"]
},
{
"inputs" : [ "Broadcast_227", "Add_230" ],
"name" : "Divide_231",
"op" : "Divide",
"outputs" : ["Divide_231_0"]
},
{
"inputs" : [ "Broadcast_235", "Add_238" ],
"name" : "Divide_239",
"op" : "Divide",
"outputs" : ["Divide_239_0"]
},
{
"inputs" : [ "Divide_231", "Add_166" ],
"name" : "Multiply_232",
"op" : "Multiply",
"outputs" : ["Multiply_232_0"]
},
{
"inputs" : [ "Divide_239", "Tanh_241" ],
"name" : "Multiply_242",
"op" : "Multiply",
"outputs" : ["Multiply_242_0"]
},
{
"inputs" : [ "Multiply_232", "Multiply_242" ],
"name" : "Add_243",
"op" : "Add",
"outputs" : ["Add_243_0"]
},
{
"inputs" : ["Add_243"],
"name" : "Tanh_244",
"op" : "Tanh",
"outputs" : ["Tanh_244_0"]
},
{
"inputs" : [ "Divide_224", "Tanh_244" ],
"name" : "Multiply_245",
"op" : "Multiply",
"outputs" : ["Multiply_245_0"]
},
{
"input_order" : [ 0, 1 ],
"inputs" : ["Multiply_245"],
"name" : "Reshape_246",
"op" : "Reshape",
"output_shape" : [ 32, 1, 100 ],
"outputs" : ["Reshape_246_0"]
},
{
"axis" : 1,
"inputs" : [ "Reshape_92", "Reshape_169", "Reshape_246" ],
"name" : "Concat_247",
"op" : "Concat",
"outputs" : ["Concat_247_0"]
},
{
"inputs" : ["Concat_247"],
"name" : "Result_248",
"op" : "Result",
"outputs" : ["Result_248_0"]
}
],
"parameters" : [
"Parameter_0", "Parameter_1", "Parameter_2", "Parameter_10", "Parameter_11",
"Parameter_47", "Parameter_48", "Parameter_55", "Parameter_56",
"Parameter_93", "Parameter_170"
],
"result" : ["Result_248"]
}]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment