Commit e3ad1b31 authored by Nick Korovaiko, committed by Scott Cyphers

CPU Loop Kernel Fusion optimization (#1190)

* cpu loop kernel fusion pass

* remove extra code

* bounded relu test

* address Scott's feedback
parent 9a3a0314
@@ -59,6 +59,7 @@ set(SRC
pass/cpu_post_layout_optimizations.cpp
pass/cpu_rnn_fusion.cpp
pass/cpu_mat_fusion.cpp
+pass/cpu_loop_kernel_fusion.cpp
pass/cpu_shuffle_folding.cpp
pass/cpu_workspace_insertion.cpp
)
@@ -4466,6 +4466,10 @@
{
auto abse =
std::bind(emit_function_call, std::string("std::abs"), std::placeholders::_1);
+auto mine =
+    std::bind(emit_function_call, std::string("std::min"), std::placeholders::_1);
+auto maxe =
+    std::bind(emit_function_call, std::string("std::max"), std::placeholders::_1);
auto adde = std::bind(emit_infix_operator, std::string("+"), std::placeholders::_1);
auto nege =
std::bind(emit_prefix_operator, std::string("-"), std::placeholders::_1);
@@ -4475,6 +4479,9 @@
std::type_index,
std::function<std::string(const std::vector<std::string>&)>>{
{TI(ngraph::op::Abs), abse},
+{TI(ngraph::op::Minimum), mine},
+{TI(ngraph::op::Relu), maxe},
+{TI(ngraph::op::Maximum), maxe},
{TI(ngraph::op::Add), adde},
{TI(ngraph::op::Negative), nege},
{TI(ngraph::op::Subtract), sube},
@@ -4485,10 +4492,25 @@
std::function<std::string(const std::vector<std::string>&)>>
inline_emitters = initialize_inline_emitters();
+// GOEE doesn't see GOEs in subgraphs that are hidden inside LoopKernels;
+// we have to manually propagate the source output
+static const ngraph::descriptor::Output*
+    get_goe_input_output(ngraph::descriptor::Output* output)
+{
+    auto it = output;
+    while (auto goe =
+               std::dynamic_pointer_cast<ngraph::op::GetOutputElement>(it->get_node()))
+    {
+        it = &goe->get_inputs().at(goe->get_n()).get_output();
+    }
+    return it;
+}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::runtime::cpu::op::LoopKernel)
{
-std::unordered_map<std::shared_ptr<Node>, std::string> loop_symbol_table;
+std::unordered_map<const ngraph::descriptor::Output*, std::string>
+    loop_symbol_table;
//pre-fill symbol table with inputs
const ngraph::runtime::cpu::op::LoopKernel* clk =
@@ -4496,10 +4518,11 @@
NodeVector output_nodes = clk->get_kernel_outputs();
NodeVector node_list = clk->get_node_list();
for (size_t i = 0; i < args.size(); i++)
{
std::string sname = std::string(args[i].get_name()) + "[i]";
-auto entry = std::make_pair(clk->get_argument(i), sname);
+auto entry = std::make_pair(&clk->get_inputs().at(i).get_output(), sname);
loop_symbol_table.insert(entry);
}
@@ -4508,7 +4531,8 @@
for (size_t i = 0; i < out.size(); i++)
{
std::string sname = std::string(out[i].get_name()) + "[i]";
-auto entry = std::make_pair(output_nodes.at(i), sname);
+// TODO: no support for multiple-output ops in loop kernel
+auto entry = std::make_pair(&output_nodes.at(i)->get_outputs().at(0), sname);
loop_symbol_table.insert(entry);
}
@@ -4520,7 +4544,8 @@
for (size_t i = 0; i < node_list.size(); i++)
{
-auto op = node_list[i];
+auto op_node = node_list[i];
+auto op = &op_node->get_outputs().at(0);
std::string tmp;
if (loop_symbol_table.count(op) == 0)
{
@@ -4540,13 +4565,22 @@
//prepare arguments
std::vector<std::string> sargs;
-for (auto arg : op->get_arguments())
+for (auto& input : op_node->get_inputs())
{
//args are expected to be in a map already
-sargs.push_back(loop_symbol_table.at(arg));
+sargs.push_back(
+    loop_symbol_table.at(get_goe_input_output(&input.get_output())));
}
+if (std::dynamic_pointer_cast<ngraph::op::Relu>(op_node))
+{
+    auto casted_zero = std::string("static_cast<") +
+                       op->get_element_type().c_type_string() +
+                       std::string(">(0)");
+    sargs.push_back(casted_zero);
+}
-const Node& n = *op;
+const Node& n = *op_node;
auto emitter = inline_emitters.at(TI(n));
writer << tmp << " = " << emitter(sargs) << ";\n";
}
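For illustration, here is the kind of per-element body the LoopKernel emitter above generates for a fused group computing relu(a + b). This is a hedged sketch, not verbatim emitter output: the names arg0, arg1, out0, tmp0, count and the float element type are assumptions.

// Sketch of generated kernel code (hypothetical names):
for (size_t i = 0; i < count; i++)
{
    float tmp0 = (arg0[i] + arg1[i]);                // op::Add via emit_infix_operator
    out0[i] = std::max(tmp0, static_cast<float>(0)); // op::Relu via maxe plus the appended casted zero
}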
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <algorithm>
#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <typeindex>
#include <unordered_map>
#include <unordered_set>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/runtime/cpu/op/loop_kernel.hpp"
#include "ngraph/runtime/cpu/pass/cpu_loop_kernel_fusion.hpp"
#define TI(x) std::type_index(typeid(x))
using namespace ngraph;
struct LKGraph
{
LKGraph(const NodeVector& ns, const NodeVector& ins)
: m_inputs(ins)
, m_nodes(ns)
{
}
NodeVector m_inputs;
NodeVector m_nodes;
};
class LoopKernelCollector
{
public:
LoopKernelCollector(std::shared_ptr<Function> f, size_t min_nodes_to_fuse)
{
for (auto n : f->get_ordered_ops())
{
if (is_fusible(n))
{
auto arg_from_fusible_group = collect_fusible_args(n);
//create a new group
if (!arg_from_fusible_group)
{
m_heads.insert(std::make_pair(n, n));
m_graphs.insert(std::make_pair(n, LKGraph{{n}, n->get_arguments()}));
NGRAPH_DEBUG << "Created a new group for " << n->get_name();
log_group(n);
}
else
{
auto smallest_head = m_heads.at(arg_from_fusible_group);
auto& lkgraph = m_graphs.at(smallest_head);
lkgraph.m_nodes.push_back(n);
for (auto arg : n->get_arguments())
{
if (is_leaf(arg))
{
lkgraph.m_inputs.push_back(arg);
}
}
m_heads.insert(std::make_pair(n, smallest_head));
log_group(smallest_head);
}
}
}
prune_graphs(min_nodes_to_fuse);
}
const std::vector<std::shared_ptr<runtime::cpu::op::LoopKernel>> get_loop_kernels() const
{
std::vector<std::shared_ptr<runtime::cpu::op::LoopKernel>> lks;
for (auto e : m_graphs)
{
auto& lkg = e.second;
std::unordered_set<std::shared_ptr<Node>> graph_nodes{lkg.m_nodes.begin(),
lkg.m_nodes.end()};
NodeVector member_outputs;
auto has_external_user = [&graph_nodes](std::shared_ptr<Node> u) {
return graph_nodes.count(u) == 0;
};
for (auto member : lkg.m_nodes)
{
auto member_users = member->get_users();
if (std::any_of(member_users.cbegin(), member_users.cend(), has_external_user))
{
member_outputs.push_back(member);
}
}
auto lk = std::make_shared<runtime::cpu::op::LoopKernel>(
lkg.m_nodes, member_outputs, lkg.m_inputs);
lks.push_back(lk);
}
return lks;
}
private:
static bool is_fusible(std::shared_ptr<Node> n)
{
static const std::set<std::type_index> fusible_ops_set{TI(ngraph::op::Abs),
TI(ngraph::op::Add),
TI(ngraph::op::Negative),
TI(ngraph::op::Subtract),
TI(ngraph::op::Relu),
TI(ngraph::op::Minimum),
TI(ngraph::op::Maximum)};
const Node& node = *n;
return fusible_ops_set.count(TI(node)) != 0;
// return (std::dynamic_pointer_cast<op::util::BinaryElementwiseArithmetic>(n) ||
// std::dynamic_pointer_cast<op::util::UnaryElementwiseArithmetic>(n));
}
bool is_leaf(std::shared_ptr<Node> src) { return src->is_parameter() || src->is_constant(); }
void prune_graphs(size_t min_nodes_to_fuse)
{
for (auto it = m_graphs.begin(); it != m_graphs.end();)
{
if (it->second.m_nodes.size() < min_nodes_to_fuse)
{
it = m_graphs.erase(it);
}
else
{
it++;
}
}
}
void log_group(std::shared_ptr<Node> head) const
{
NGRAPH_DEBUG << "Group leader : " << head->get_name() << std::endl;
NGRAPH_DEBUG << "Group members : " << m_graphs.at(head).m_nodes << std::endl;
NGRAPH_DEBUG << "Inputs: " << m_graphs.at(head).m_inputs << std::endl;
}
std::shared_ptr<Node> collect_fusible_args(std::shared_ptr<Node> n)
{
std::shared_ptr<Node> arg_from_fusible_group;
for (auto arg : n->get_arguments())
{
//an argument is fusible and a part of some group
NGRAPH_DEBUG << "Considering " << arg->get_name();
if (m_heads.count(arg) != 0)
{
if (!arg_from_fusible_group)
{
arg_from_fusible_group = arg;
}
else
{
if (!is_leaf(arg) && m_heads.at(arg) != m_heads.at(arg_from_fusible_group))
{
return {nullptr};
}
}
}
}
return arg_from_fusible_group;
}
std::unordered_map<std::shared_ptr<Node>, LKGraph> m_graphs;
std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<Node>> m_heads;
};
bool ngraph::runtime::cpu::pass::CPULoopKernelFusion::run_on_function(
std::shared_ptr<ngraph::Function> function)
{
LoopKernelCollector lkc(function, m_min_kernel_size);
auto loop_kernels = lkc.get_loop_kernels();
for (auto lk : loop_kernels)
{
auto outputs = lk->get_kernel_outputs();
std::set<std::shared_ptr<Node>> lk_nodes_set(lk->get_node_list().begin(),
lk->get_node_list().end());
for (size_t i = 0; i < outputs.size(); i++)
{
auto ith_goe = std::make_shared<ngraph::op::GetOutputElement>(lk, i);
auto& ith_output = ith_goe->get_outputs().at(0);
if (outputs.at(i)->get_outputs().size() > 1)
{
throw ngraph_error(
"support for fusing multi-output nodes in loop kernels isn't yet implemented");
}
//TODO: revisit when we need support for multi-output nodes
auto& orig_output = outputs.at(i)->get_outputs().at(0);
//this is needed since replace_output modifies orig_output.get_inputs()
std::set<ngraph::descriptor::Input*> inputs_copy{begin(orig_output.get_inputs()),
end(orig_output.get_inputs())};
for (auto input : inputs_copy)
{
//this user is NOT internal to this loop kernel
//so it needs to be replaced with corresponding lk's GOE
if (lk_nodes_set.count(input->get_node()) == 0)
{
input->replace_output(ith_output);
}
}
}
}
return !loop_kernels.empty();
}
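To make the rewiring above concrete, here is a hedged sketch of what run_on_function effectively does for a two-node group {add, relu} with leaf inputs {a, b}; the node names are hypothetical and the snippet assumes the same headers the pass itself includes:

// The collector found the group, so a LoopKernel wraps it
// (same argument order as the call above):
auto lk = std::make_shared<runtime::cpu::op::LoopKernel>(
    NodeVector{add, relu}, // node_list: ops executed inside the fused loop
    NodeVector{relu},      // kernel outputs: members with users outside the group
    NodeVector{a, b});     // inputs: leaves (parameters/constants) feeding the group
// Each kernel output i is exposed through a GetOutputElement, and every
// consumer outside the group is re-pointed at it via replace_output:
auto goe0 = std::make_shared<ngraph::op::GetOutputElement>(lk, 0);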
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace pass
{
class CPULoopKernelFusion : public ngraph::pass::FunctionPass
{
public:
CPULoopKernelFusion(size_t min_kernel_size = 2)
: FunctionPass()
, m_min_kernel_size(min_kernel_size)
{
}
bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
protected:
size_t m_min_kernel_size;
};
}
}
}
}
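A minimal usage sketch of the pass (not part of the commit; shapes and values are assumptions, and the graph mirrors the bounded-relu pattern mentioned in the commit message). With min_kernel_size = 3, the Add/Maximum/Minimum chain should collapse into a single LoopKernel:

#include <memory>
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/runtime/cpu/pass/cpu_loop_kernel_fusion.hpp"

using namespace ngraph;

int main()
{
    Shape shape{2, 2};
    auto a = std::make_shared<op::Parameter>(element::f32, shape);
    auto b = std::make_shared<op::Parameter>(element::f32, shape);
    auto zero = op::Constant::create(element::f32, shape, {0, 0, 0, 0});
    auto six = op::Constant::create(element::f32, shape, {6, 6, 6, 6});

    // bounded relu: min(max(a + b, 0), 6) -- three fusible elementwise ops
    auto add = std::make_shared<op::Add>(a, b);
    auto max = std::make_shared<op::Maximum>(add, zero);
    auto bounded = std::make_shared<op::Minimum>(max, six);
    auto f = std::make_shared<Function>(NodeVector{bounded},
                                        op::ParameterVector{a, b});

    // Groups smaller than min_kernel_size (here 3) are pruned, so this
    // three-node chain is exactly large enough to fuse.
    pass::Manager pm;
    pm.register_pass<runtime::cpu::pass::CPULoopKernelFusion>(3);
    pm.run_passes(f);
    return 0;
}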