Commit b88fa59d authored by Jayaram Bobba, committed by Scott Cyphers

Shuffle folding to collapse transposes into layout conversions (#950)

parent b5844622
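Shuffle folding targets the pattern of a Reshape that is a pure axis permutation (a transpose) feeding a ConvertLayout node: instead of materializing the transpose, the pass retags the source tensor's layout descriptor with the equivalent MKLDNN format and axis order and removes the Reshape, letting the downstream MKLDNN reorder absorb the permutation. A minimal standalone sketch of the "pure shuffle" property the pass checks (plain C++, not the nGraph API; is_pure_shuffle is an illustrative name):

#include <cstddef>
#include <iostream>
#include <vector>

// A reshape is a pure shuffle when out_shape[j] == in_shape[order[j]] for all j.
bool is_pure_shuffle(const std::vector<size_t>& in_shape,
                     const std::vector<size_t>& order,
                     const std::vector<size_t>& out_shape)
{
    if (in_shape.size() != out_shape.size() || order.size() != out_shape.size())
    {
        return false;
    }
    for (size_t j = 0; j < order.size(); j++)
    {
        if (in_shape[order[j]] != out_shape[j])
        {
            return false;
        }
    }
    return true;
}

int main()
{
    // Axis order {3, 2, 0, 1} turns hwio weights {7, 7, 3, 64} into oihw
    // weights {64, 3, 7, 7}; the map in the new pass folds exactly this
    // permutation into an hwio format tag.
    std::cout << std::boolalpha
              << is_pure_shuffle({7, 7, 3, 64}, {3, 2, 0, 1}, {64, 3, 7, 7})
              << std::endl; // true
}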
@@ -222,6 +222,7 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND
     runtime/cpu/pass/cpu_nop_elimination.cpp
     runtime/cpu/pass/cpu_rnn_mat_fusion.cpp
     runtime/cpu/pass/cpu_post_layout_optimizations.cpp
+    runtime/cpu/pass/cpu_shuffle_folding.cpp
 )
 # LLVM binary builds are typically built without RTTI
 # The built-in headers are in a version-specific directory
...
@@ -3312,9 +3312,18 @@ namespace ngraph
 {
     auto input_tvl =
         node->get_inputs()[0].get_output().get_tensor_view()->get_tensor_view_layout();
+    auto& input_cpu_tvl = dynamic_cast<runtime::cpu::LayoutDescriptor&>(*input_tvl);
+    auto input_format = input_cpu_tvl.get_mkldnn_format();
+    // Reorder the input shape into the layout's axis order if needed
+    auto input_axis_order = input_cpu_tvl.get_axis_order();
+    Shape input_shape(input_axis_order.size());
+    for (size_t idx = 0; idx < input_axis_order.size(); idx++)
+    {
+        input_shape[idx] = args[0].get_shape()[input_axis_order[idx]];
+    }
     auto output_tvl = node->get_output_tensor_view(0)->get_tensor_view_layout();
-    auto input_format =
-        dynamic_cast<runtime::cpu::LayoutDescriptor&>(*input_tvl).get_mkldnn_format();
     auto output_format =
         dynamic_cast<runtime::cpu::LayoutDescriptor&>(*output_tvl).get_mkldnn_format();
@@ -3332,7 +3341,9 @@ namespace ngraph
     }
     auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
-    auto input_desc = mkldnn_emitter->build_memory_descriptor(args[0], input_format);
+    auto input_desc = mkldnn_emitter->build_memory_descriptor(
+        input_shape, args[0].get_element_type(), input_format);
     auto result_desc = mkldnn_emitter->build_memory_descriptor(out[0], output_format);
     size_t reorder_index = mkldnn_emitter->build_reorder(input_desc, result_desc);
...
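Once a transpose has been folded, the ConvertLayout input's axis order may no longer be the identity, so the emitter now rebuilds the input shape in the layout's axis order (and passes the element type) before constructing the MKLDNN memory descriptor. A hedged sketch of that permutation in isolation (plain C++; permute_by_axis_order is an illustrative name, Shape replaced by a vector):

#include <cassert>
#include <cstddef>
#include <vector>

// Mirrors the loop added in the diff above: rebuild the shape in the
// layout's axis order before creating the MKLDNN memory descriptor.
std::vector<size_t> permute_by_axis_order(const std::vector<size_t>& shape,
                                          const std::vector<size_t>& axis_order)
{
    std::vector<size_t> permuted(axis_order.size());
    for (size_t idx = 0; idx < axis_order.size(); idx++)
    {
        permuted[idx] = shape[axis_order[idx]];
    }
    return permuted;
}

int main()
{
    // A folded hwio tensor carries axis order {3, 2, 0, 1}; its descriptor
    // shape must be the logical shape permuted into that order.
    assert((permute_by_axis_order({7, 7, 3, 64}, {3, 2, 0, 1}) ==
            std::vector<size_t>{64, 3, 7, 7}));
    return 0;
}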
@@ -124,6 +124,7 @@
 #include "ngraph/runtime/cpu/pass/cpu_layout.hpp"
 #include "ngraph/runtime/cpu/pass/cpu_nop_elimination.hpp"
 #include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
+#include "ngraph/runtime/cpu/pass/cpu_shuffle_folding.hpp"
 #ifdef NGRAPH_DISTRIBUTED
 #include "ngraph/op/allreduce.hpp"
@@ -317,6 +318,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
     pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
     pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this);
     pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>();
+    pass_manager.register_pass<runtime::cpu::pass::CPUShuffleFolding>();
     pass_manager.register_pass<ngraph::pass::ResultCopyElimination>();
     pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>();
     pass_manager.register_pass<ngraph::pass::Liveness>();
...
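Note on ordering: the new pass is registered right after CPUPostLayoutOptimizations, so the MKLDNN formats assigned by CPULayout are already in place when it looks for foldable transposes, and before ResultCopyElimination and the liveness-related passes that follow.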
@@ -64,6 +64,7 @@ namespace ngraph
     std::reverse(strides.begin(), strides.end());
 }
+void LayoutDescriptor::set_axis_order(const AxisVector& perm) { axis_order = perm; }
 size_t LayoutDescriptor::get_index_offset(const std::vector<size_t>& indices)
 {
     if (indices.size() != strides.size())
...
@@ -53,6 +53,7 @@ namespace ngraph
 }
 mkldnn::memory::format get_mkldnn_format() const { return mkldnn_format; }
 const AxisVector& get_axis_order() const { return axis_order; }
+void set_axis_order(const AxisVector& perm);
 static const AxisVector Native2DAxisOrder;
 static const AxisVector Native4DAxisOrder;
 static const AxisVector CHWNAxisOrder;
...
runtime/cpu/pass/cpu_shuffle_folding.cpp (new file)

/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <algorithm>
#include <iostream>
#include <map>
#include <memory>
#include "ngraph/op/reshape.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "cpu_shuffle_folding.hpp"
static const std::map<const ngraph::AxisVector, const mkldnn::memory::format>
    input_order_format_map{{ngraph::AxisVector{3, 2, 0, 1}, mkldnn::memory::format::hwio}};
bool ngraph::runtime::cpu::pass::CPUShuffleFolding::run_on_function(
    std::shared_ptr<ngraph::Function> function)
{
    bool clobbered = false;
    for (const auto& n : function->get_ordered_ops())
    {
        auto convert_layout = std::dynamic_pointer_cast<op::ConvertLayout>(n);
        if (convert_layout)
        {
            auto reshape = std::dynamic_pointer_cast<ngraph::op::Reshape>(n->get_argument(0));
            if (reshape)
            {
                auto output_shape = reshape->get_output_shape();
                auto input_shape = reshape->get_input_shape(0);
                if (output_shape.size() != input_shape.size())
                {
                    continue;
                }
                size_t j = 0;
                bool is_shuffle = true;
                for (auto i : reshape->get_input_order())
                {
                    if (input_shape.at(i) != output_shape.at(j++))
                    {
                        is_shuffle = false;
                        break;
                    }
                }
                if (!is_shuffle)
                {
                    continue;
                }
                auto reshape_input_layout =
                    reshape->get_argument(0)->get_output_tensor_view()->get_tensor_view_layout();
                auto output_layout =
                    convert_layout->get_output_tensor_view()->get_tensor_view_layout();
                if (reshape_input_layout)
                {
                    auto reshape_input_layout_descriptor =
                        std::static_pointer_cast<runtime::cpu::LayoutDescriptor>(
                            reshape_input_layout);
                    auto reshape_input_format =
                        reshape_input_layout_descriptor->get_mkldnn_format();
                    auto output_format =
                        std::static_pointer_cast<runtime::cpu::LayoutDescriptor>(output_layout)
                            ->get_mkldnn_format();
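                    // Fold only the known-safe case: a row-major (nchw-tagged)
                    // tensor that is being converted into the blocked
                    // OIhw16i16o filter format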
                    if (mkldnn_utils::is_mkldnn_filter_format(output_format) &&
                        output_format == mkldnn::memory::format::OIhw16i16o &&
                        reshape_input_format == mkldnn::memory::format::nchw)
                    {
                        if (input_order_format_map.find(reshape->get_input_order()) !=
                            input_order_format_map.end())
                        {
                            reshape_input_layout_descriptor->set_mkldnn_format(
                                input_order_format_map.at(reshape->get_input_order()));
                            reshape_input_layout_descriptor->set_axis_order(
                                reshape->get_input_order());
                            function->replace_node(reshape, reshape->get_argument(0));
                            clobbered = true;
                        }
                    }
                }
            }
        }
    }
    return clobbered;
}
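The fold is sound because a transpose never has to move bytes: re-describing the same buffer with a different format and axis order exposes every element at its new logical index. A standalone illustration of that equivalence (plain C++; sizes and names are illustrative, not from the commit):

#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    const size_t H = 2, W = 3, I = 4, O = 5;
    // Original buffer laid out hwio (row-major over h, w, i, o).
    std::vector<float> buf(H * W * I * O);
    for (size_t v = 0; v < buf.size(); v++)
    {
        buf[v] = float(v);
    }
    auto hwio_at = [&](size_t h, size_t w, size_t i, size_t o) {
        return buf[((h * W + w) * I + i) * O + o];
    };

    // What the Reshape{3, 2, 0, 1} would materialize: an oihw copy.
    std::vector<float> oihw(buf.size());
    for (size_t o = 0; o < O; o++)
        for (size_t i = 0; i < I; i++)
            for (size_t h = 0; h < H; h++)
                for (size_t w = 0; w < W; w++)
                    oihw[((o * I + i) * H + h) * W + w] = hwio_at(h, w, i, o);

    // The copy is redundant: every transposed element is already addressable
    // in the original buffer, so retagging the layout as hwio is enough and
    // the MKLDNN reorder can consume the buffer in place.
    assert(oihw[((1 * I + 2) * H + 1) * W + 2] == hwio_at(1, 2, 2, 1));
    return 0;
}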
runtime/cpu/pass/cpu_shuffle_folding.hpp (new file)

/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once

#include "ngraph/pass/pass.hpp"

namespace ngraph
{
    namespace runtime
    {
        namespace cpu
        {
            namespace pass
            {
                class CPUShuffleFolding : public ngraph::pass::FunctionPass
                {
                public:
                    bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
                };
            }
        }
    }
}