Commit 64c7a437 authored by arogowie-intel

Merge branch 'arogowiec/fused_lstm_cell' into arogowiec/fused_rnn_gru_cells

parents 5cd708c2 75c0b4cc
@@ -241,14 +241,14 @@ namespace ngraph
 const LSTMAttributes& attributes)
 : m_X{X}
 // Since we have a forward LSTM we can squeeze the `num_directions` axis from the inputs.
-, m_W{reshape::squeeze(W)}
-, m_R{reshape::squeeze(R)}
-, m_B{reshape::squeeze(B)}
-, m_P{reshape::squeeze(P)}
-, m_initial_h{reshape::squeeze(initial_h)}
-, m_initial_c{reshape::squeeze(initial_c)}
-, m_seq_lengths{seq_lengths}
-, m_attributes{attributes}
+, m_W(reshape::squeeze(W))
+, m_R(reshape::squeeze(R))
+, m_B(reshape::squeeze(B))
+, m_P(reshape::squeeze(P))
+, m_initial_h(reshape::squeeze(initial_h))
+, m_initial_c(reshape::squeeze(initial_c))
+, m_seq_lengths(seq_lengths)
+, m_attributes(attributes)
 {
 }
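For context on the squeeze in the initializer list: the ONNX LSTM inputs carry a leading `num_directions` axis (for example W has shape [num_directions, 4*hidden_size, input_size]), and for a forward-only LSTM that axis has size 1, so it can be dropped before building the cell. A minimal plain-C++ sketch of that shape logic (illustrative only, not the actual `reshape::squeeze` implementation):

    #include <cassert>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    using Shape = std::vector<std::size_t>;

    // Drop the outermost axis, which must have size 1 (the `num_directions`
    // axis of a forward-only LSTM).
    Shape squeeze_outermost(const Shape& shape)
    {
        assert(!shape.empty() && shape.front() == 1);
        return Shape(shape.begin() + 1, shape.end());
    }

    int main()
    {
        const std::size_t hidden_size = 32, input_size = 16;
        // ONNX LSTM weight tensor W: [num_directions, 4*hidden_size, input_size].
        Shape W{1, 4 * hidden_size, input_size};
        Shape W_squeezed = squeeze_outermost(W); // [4*hidden_size, input_size]
        for (auto d : W_squeezed) std::cout << d << ' ';
        std::cout << '\n';                       // prints: 128 16
    }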
@@ -302,7 +302,7 @@ namespace ngraph
 std::int32_t time_step{1};
 for (const auto& in_x : in_seqs)
 {
-const std::shared_ptr<ngraph::Node>& lstm_cell =
+std::shared_ptr<ngraph::Node> lstm_cell =
 std::make_shared<ngraph::op::LSTMCell>(
 in_x,
 m_W,
@@ -318,10 +318,8 @@
 m_attributes.m_clip_threshold,
 m_attributes.m_input_forget);
-const std::shared_ptr<ngraph::Node>& H =
-get_output_element(lstm_cell, 0);
-const std::shared_ptr<ngraph::Node>& C =
-get_output_element(lstm_cell, 1);
+std::shared_ptr<ngraph::Node> H = get_output_element(lstm_cell, 0);
+std::shared_ptr<ngraph::Node> C = get_output_element(lstm_cell, 1);
 // Expand tensors with empty outermost dim, so we can later concatenate
 // them.
......
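The loop above unrolls the recurrence: one LSTMCell node is built per time step, its output 0 is the new hidden state H and output 1 the new cell state C, and each H is given a singleton outermost dimension so the per-step results can later be concatenated along the time axis; presumably H and C also feed the next iteration, though that part of the loop is collapsed in the diff. A small plain-C++ sketch of this state-carry pattern (the cell arithmetic is a placeholder, not the real LSTMCell gate math):

    #include <cstddef>
    #include <iostream>
    #include <tuple>
    #include <utility>
    #include <vector>

    using Tensor = std::vector<float>;

    // Stub for a single LSTM cell step: given the input slice for one time
    // step and the previous hidden/cell state, return the new (H, C) pair.
    std::pair<Tensor, Tensor> lstm_cell(const Tensor& x, const Tensor& h, const Tensor& c)
    {
        Tensor new_h(h.size()), new_c(c.size());
        for (std::size_t i = 0; i < h.size(); ++i)
        {
            // Placeholder arithmetic standing in for the gate computations.
            new_c[i] = 0.5f * c[i] + 0.5f * x[i % x.size()];
            new_h[i] = new_c[i];
        }
        return {new_h, new_c};
    }

    int main()
    {
        const std::size_t hidden_size = 4;
        std::vector<Tensor> in_seqs = {{1.f, 2.f}, {3.f, 4.f}, {5.f, 6.f}}; // one slice per time step
        Tensor H(hidden_size, 0.f), C(hidden_size, 0.f);                    // initial_h / initial_c

        // Collect the hidden state of every step; each entry plays the role
        // of an H tensor that would get an extra outermost dim of 1 before
        // the final concatenation along the time axis.
        std::vector<Tensor> h_per_step;
        for (const auto& in_x : in_seqs)
        {
            std::tie(H, C) = lstm_cell(in_x, H, C); // output 0 -> H, output 1 -> C
            h_per_step.push_back(H);
        }
        std::cout << "time steps collected: " << h_per_step.size() << '\n'; // 3
    }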
@@ -172,6 +172,8 @@ void runtime::gpu::GPUCompiledFunction::compile()
 pass_manager.register_pass<runtime::gpu::pass::BatchNormCache>();
 pass_manager.register_pass<ngraph::pass::LikeReplacement>();
 pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>();
+// Run this pass a second time, since some fused operators like LSTMCell may use
+// other fused operators inside.
+pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>();
 pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this);
 pass_manager.register_pass<ngraph::pass::AssignLayout<descriptor::layout::DenseTensorLayout>>();
......
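The duplicated registration is the point of this hunk: ngraph::pass::FusedOpDecomposition is run a second time because decomposing one fused op (such as LSTMCell) can introduce other fused ops that the first application has already walked past. A toy sketch of that two-pass idea (the op names and decomposition table below are made up for illustration; the real pass rewrites nGraph graph nodes, not strings):

    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    // Toy model of fused-op decomposition: a fused op expands into a list of
    // ops, some of which may themselves be fused.
    static const std::map<std::string, std::vector<std::string>> decompositions = {
        {"LSTMCell", {"FusedGates", "Add", "Multiply"}},          // hypothetical: still contains a fused op
        {"FusedGates", {"MatMul", "Add", "Sigmoid", "Tanh"}},
    };

    std::vector<std::string> decompose_once(const std::vector<std::string>& ops)
    {
        std::vector<std::string> out;
        for (const auto& op : ops)
        {
            auto it = decompositions.find(op);
            if (it == decompositions.end())
                out.push_back(op);                                // already a primitive op
            else
                out.insert(out.end(), it->second.begin(), it->second.end());
        }
        return out;
    }

    int main()
    {
        std::vector<std::string> graph = {"LSTMCell"};
        graph = decompose_once(graph); // first pass: FusedGates is still fused
        graph = decompose_once(graph); // second pass: only primitive ops remain
        for (const auto& op : graph) std::cout << op << ' ';
        std::cout << '\n';
    }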
@@ -428,6 +428,8 @@ shared_ptr<runtime::Executable>
 {
 pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>(
 IntelGPUBackend::is_supported_impl);
+// Run this pass a second time, since some fused operators like LSTMCell may use
+// other fused operators inside.
+pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>(
+IntelGPUBackend::is_supported_impl);
 }
......