Commit 64c7a437 authored by arogowie-intel

Merge branch 'arogowiec/fused_lstm_cell' into arogowiec/fused_rnn_gru_cells

parents 5cd708c2 75c0b4cc
...@@ -241,14 +241,14 @@ namespace ngraph ...@@ -241,14 +241,14 @@ namespace ngraph
const LSTMAttributes& attributes) const LSTMAttributes& attributes)
: m_X{X} : m_X{X}
// Since we have forward LSTM we can squeeze `num_directions` axis from inputs. // Since we have forward LSTM we can squeeze `num_directions` axis from inputs.
, m_W{reshape::squeeze(W)} , m_W(reshape::squeeze(W))
, m_R{reshape::squeeze(R)} , m_R(reshape::squeeze(R))
, m_B{reshape::squeeze(B)} , m_B(reshape::squeeze(B))
, m_P{reshape::squeeze(P)} , m_P(reshape::squeeze(P))
, m_initial_h{reshape::squeeze(initial_h)} , m_initial_h(reshape::squeeze(initial_h))
, m_initial_c{reshape::squeeze(initial_c)} , m_initial_c(reshape::squeeze(initial_c))
, m_seq_lengths{seq_lengths} , m_seq_lengths(seq_lengths)
, m_attributes{attributes} , m_attributes(attributes)
{ {
} }
...@@ -302,7 +302,7 @@ namespace ngraph ...@@ -302,7 +302,7 @@ namespace ngraph
std::int32_t time_step{1}; std::int32_t time_step{1};
for (const auto& in_x : in_seqs) for (const auto& in_x : in_seqs)
{ {
const std::shared_ptr<ngraph::Node>& lstm_cell = std::shared_ptr<ngraph::Node> lstm_cell =
std::make_shared<ngraph::op::LSTMCell>( std::make_shared<ngraph::op::LSTMCell>(
in_x, in_x,
m_W, m_W,
...@@ -318,10 +318,8 @@ namespace ngraph ...@@ -318,10 +318,8 @@ namespace ngraph
m_attributes.m_clip_threshold, m_attributes.m_clip_threshold,
m_attributes.m_input_forget); m_attributes.m_input_forget);
const std::shared_ptr<ngraph::Node>& H = std::shared_ptr<ngraph::Node> H = get_output_element(lstm_cell, 0);
get_output_element(lstm_cell, 0); std::shared_ptr<ngraph::Node> C = get_output_element(lstm_cell, 1);
const std::shared_ptr<ngraph::Node>& C =
get_output_element(lstm_cell, 1);
// Expand tensors with empty outermost dim, so we can later concatenate // Expand tensors with empty outermost dim, so we can later concatenate
// them. // them.
......
...@@ -172,6 +172,8 @@ void runtime::gpu::GPUCompiledFunction::compile() ...@@ -172,6 +172,8 @@ void runtime::gpu::GPUCompiledFunction::compile()
pass_manager.register_pass<runtime::gpu::pass::BatchNormCache>(); pass_manager.register_pass<runtime::gpu::pass::BatchNormCache>();
pass_manager.register_pass<ngraph::pass::LikeReplacement>(); pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>(); pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>();
// Run this pass a second time, since some fused operators (e.g. LSTMCell) may use
// other fused operators inside.
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>(); pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>();
pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this); pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this);
pass_manager.register_pass<ngraph::pass::AssignLayout<descriptor::layout::DenseTensorLayout>>(); pass_manager.register_pass<ngraph::pass::AssignLayout<descriptor::layout::DenseTensorLayout>>();
......
...@@ -428,6 +428,8 @@ shared_ptr<runtime::Executable> ...@@ -428,6 +428,8 @@ shared_ptr<runtime::Executable>
{ {
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>( pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>(
IntelGPUBackend::is_supported_impl); IntelGPUBackend::is_supported_impl);
// Run this pass a second time, since some fused operators (e.g. LSTMCell) may use
// other fused operators inside.
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>( pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>(
IntelGPUBackend::is_supported_impl); IntelGPUBackend::is_supported_impl);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment