Commit e3ad1b31 authored by Nick Korovaiko, committed by Scott Cyphers

CPU Loop Kernel Fusion optimization (#1190)

* cpu loop kernel fusion pass

*  remove extra code

* bounded relu test

* address Scott's feedback
parent 9a3a0314
src/ngraph/runtime/cpu/CMakeLists.txt

@@ -59,6 +59,7 @@ set(SRC
     pass/cpu_post_layout_optimizations.cpp
     pass/cpu_rnn_fusion.cpp
     pass/cpu_mat_fusion.cpp
+    pass/cpu_loop_kernel_fusion.cpp
     pass/cpu_shuffle_folding.cpp
     pass/cpu_workspace_insertion.cpp
 )
src/ngraph/runtime/cpu/cpu_emitter.cpp

@@ -4466,6 +4466,10 @@
                 auto abse =
                     std::bind(emit_function_call, std::string("std::abs"), std::placeholders::_1);
+                auto mine =
+                    std::bind(emit_function_call, std::string("std::min"), std::placeholders::_1);
+                auto maxe =
+                    std::bind(emit_function_call, std::string("std::max"), std::placeholders::_1);
                 auto adde = std::bind(emit_infix_operator, std::string("+"), std::placeholders::_1);
                 auto nege =
                     std::bind(emit_prefix_operator, std::string("-"), std::placeholders::_1);
@@ -4475,6 +4479,9 @@
                     std::type_index,
                     std::function<std::string(const std::vector<std::string>&)>>{
                     {TI(ngraph::op::Abs), abse},
+                    {TI(ngraph::op::Minimum), mine},
+                    {TI(ngraph::op::Relu), maxe},
+                    {TI(ngraph::op::Maximum), maxe},
                     {TI(ngraph::op::Add), adde},
                     {TI(ngraph::op::Negative), nege},
                     {TI(ngraph::op::Subtract), sube},
@@ -4485,10 +4492,25 @@
                 std::function<std::string(const std::vector<std::string>&)>>
                 inline_emitters = initialize_inline_emitters();

+            //GOEE doesn't see GOEs in subgraphs that are hidden inside LoopKernels
+            //we have to manually propagate the source output
+            static const ngraph::descriptor::Output*
+                get_goe_input_output(ngraph::descriptor::Output* output)
+            {
+                auto it = output;
+                while (auto goe =
+                           std::dynamic_pointer_cast<ngraph::op::GetOutputElement>(it->get_node()))
+                {
+                    it = &goe->get_inputs().at(goe->get_n()).get_output();
+                }
+                return it;
+            }
+
             template <>
             void CPU_Emitter::EMITTER_DECL(ngraph::runtime::cpu::op::LoopKernel)
             {
-                std::unordered_map<std::shared_ptr<Node>, std::string> loop_symbol_table;
+                std::unordered_map<const ngraph::descriptor::Output*, std::string>
+                    loop_symbol_table;
                 //pre-fill symbol table with inputs
                 const ngraph::runtime::cpu::op::LoopKernel* clk =
@@ -4496,10 +4518,11 @@
                 NodeVector output_nodes = clk->get_kernel_outputs();
                 NodeVector node_list = clk->get_node_list();

                 for (size_t i = 0; i < args.size(); i++)
                 {
                     std::string sname = std::string(args[i].get_name()) + "[i]";
-                    auto entry = std::make_pair(clk->get_argument(i), sname);
+                    auto entry = std::make_pair(&clk->get_inputs().at(i).get_output(), sname);
                     loop_symbol_table.insert(entry);
                 }
@@ -4508,7 +4531,8 @@
                 for (size_t i = 0; i < out.size(); i++)
                 {
                     std::string sname = std::string(out[i].get_name()) + "[i]";
-                    auto entry = std::make_pair(output_nodes.at(i), sname);
+                    //TODO: no support for multiple-output ops in loop kernel
+                    auto entry = std::make_pair(&output_nodes.at(i)->get_outputs().at(0), sname);
                     loop_symbol_table.insert(entry);
                 }
@@ -4520,7 +4544,8 @@
                 for (size_t i = 0; i < node_list.size(); i++)
                 {
-                    auto op = node_list[i];
+                    auto op_node = node_list[i];
+                    auto op = &op_node->get_outputs().at(0);
                     std::string tmp;
                     if (loop_symbol_table.count(op) == 0)
                     {
@@ -4540,13 +4565,22 @@
                     //prepare arguments
                     std::vector<std::string> sargs;
-                    for (auto arg : op->get_arguments())
+                    for (auto& input : op_node->get_inputs())
                     {
                         //args are expected to be in a map already
-                        sargs.push_back(loop_symbol_table.at(arg));
+                        sargs.push_back(
+                            loop_symbol_table.at(get_goe_input_output(&input.get_output())));
+                    }
+
+                    if (std::dynamic_pointer_cast<ngraph::op::Relu>(op_node))
+                    {
+                        auto casted_zero = std::string("static_cast<") +
+                                           op->get_element_type().c_type_string() +
+                                           std::string(">(0)");
+                        sargs.push_back(casted_zero);
                     }

-                    const Node& n = *op;
+                    const Node& n = *op_node;
                     auto emitter = inline_emitters.at(TI(n));
                     writer << tmp << " = " << emitter(sargs) << ";\n";
                 }
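For orientation: each inline emitter above returns a scalar C++ expression, and the LoopKernel emitter stitches those expressions into a single elementwise loop in the generated source. Below is a rough sketch of what that output might look like for a kernel fusing an Add that feeds a Relu, assuming a float kernel; the buffer and temporary names are invented for illustration (the real ones come from args[i].get_name(), out[i].get_name(), and the tmp symbols above), and the surrounding loop is paraphrased rather than taken from the commit.

    // hypothetical generated code, for illustration only
    for (size_t i = 0; i < count; i++)
    {
        float tmp0 = (arg0[i] + arg1[i]);                // Add via adde (infix "+")
        out0[i] = std::max(tmp0, static_cast<float>(0)); // Relu via maxe; the casted
                                                         // zero is the extra argument
                                                         // appended for Relu above
    }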
src/ngraph/runtime/cpu/pass/cpu_loop_kernel_fusion.cpp (new file)

/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <algorithm>
#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <typeindex>
#include <unordered_map>
#include <unordered_set>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/runtime/cpu/op/loop_kernel.hpp"
#include "ngraph/runtime/cpu/pass/cpu_loop_kernel_fusion.hpp"
#define TI(x) std::type_index(typeid(x))
using namespace ngraph;
struct LKGraph
{
    LKGraph(const NodeVector& ns, const NodeVector& ins)
        : m_inputs(ins)
        , m_nodes(ns)
    {
    }
    NodeVector m_inputs;
    NodeVector m_nodes;
};

class LoopKernelCollector
{
public:
    LoopKernelCollector(std::shared_ptr<Function> f, size_t min_nodes_to_fuse)
    {
        for (auto n : f->get_ordered_ops())
        {
            if (is_fusible(n))
            {
                auto arg_from_fusible_group = collect_fusible_args(n);
                //create a new group
                if (!arg_from_fusible_group)
                {
                    m_heads.insert(std::make_pair(n, n));
                    m_graphs.insert(std::make_pair(n, LKGraph{{n}, n->get_arguments()}));
                    NGRAPH_DEBUG << "Created a new group for " << n->get_name();
                    log_group(n);
                }
                else
                {
                    auto smallest_head = m_heads.at(arg_from_fusible_group);
                    auto& lkgraph = m_graphs.at(smallest_head);
                    lkgraph.m_nodes.push_back(n);
                    for (auto arg : n->get_arguments())
                    {
                        if (is_leaf(arg))
                        {
                            lkgraph.m_inputs.push_back(arg);
                        }
                    }
                    m_heads.insert(std::make_pair(n, smallest_head));
                    log_group(smallest_head);
                }
            }
        }
        prune_graphs(min_nodes_to_fuse);
    }

    const std::vector<std::shared_ptr<runtime::cpu::op::LoopKernel>> get_loop_kernels() const
    {
        std::vector<std::shared_ptr<runtime::cpu::op::LoopKernel>> lks;
        for (auto e : m_graphs)
        {
            auto& lkg = e.second;
            std::unordered_set<std::shared_ptr<Node>> graph_nodes{lkg.m_nodes.begin(),
                                                                  lkg.m_nodes.end()};
            NodeVector member_outputs;
            auto has_external_user = [&graph_nodes](std::shared_ptr<Node> u) {
                return graph_nodes.count(u) == 0;
            };
            for (auto member : lkg.m_nodes)
            {
                auto member_users = member->get_users();
                if (std::any_of(member_users.cbegin(), member_users.cend(), has_external_user))
                {
                    member_outputs.push_back(member);
                }
            }
            auto lk = std::make_shared<runtime::cpu::op::LoopKernel>(
                lkg.m_nodes, member_outputs, lkg.m_inputs);
            lks.push_back(lk);
        }
        return lks;
    }

private:
    static bool is_fusible(std::shared_ptr<Node> n)
    {
        static const std::set<std::type_index> fusible_ops_set{TI(ngraph::op::Abs),
                                                               TI(ngraph::op::Add),
                                                               TI(ngraph::op::Negative),
                                                               TI(ngraph::op::Subtract),
                                                               TI(ngraph::op::Relu),
                                                               TI(ngraph::op::Minimum),
                                                               TI(ngraph::op::Maximum)};

        const Node& node = *n;
        return fusible_ops_set.count(TI(node)) != 0;
        // return (std::dynamic_pointer_cast<op::util::BinaryElementwiseArithmetic>(n) ||
        //         std::dynamic_pointer_cast<op::util::UnaryElementwiseArithmetic>(n));
    }

    bool is_leaf(std::shared_ptr<Node> src) { return src->is_parameter() || src->is_constant(); }

    void prune_graphs(size_t min_nodes_to_fuse)
    {
        for (auto it = m_graphs.begin(); it != m_graphs.end();)
        {
            if (it->second.m_nodes.size() < min_nodes_to_fuse)
            {
                it = m_graphs.erase(it);
            }
            else
            {
                it++;
            }
        }
    }

    void log_group(std::shared_ptr<Node> head) const
    {
        NGRAPH_DEBUG << "Group leader : " << head->get_name() << std::endl;
        NGRAPH_DEBUG << "Group members : " << m_graphs.at(head).m_nodes << std::endl;
        NGRAPH_DEBUG << "Inputs: " << m_graphs.at(head).m_inputs << std::endl;
    }

    std::shared_ptr<Node> collect_fusible_args(std::shared_ptr<Node> n)
    {
        std::shared_ptr<Node> arg_from_fusible_group;
        for (auto arg : n->get_arguments())
        {
            //an argument is fusible and a part of some group
            NGRAPH_DEBUG << "Considering " << arg->get_name();
            if (m_heads.count(arg) != 0)
            {
                if (!arg_from_fusible_group)
                {
                    arg_from_fusible_group = arg;
                }
                else
                {
                    if (!is_leaf(arg) && m_heads.at(arg) != m_heads.at(arg_from_fusible_group))
                    {
                        return {nullptr};
                    }
                }
            }
        }
        return arg_from_fusible_group;
    }

    std::unordered_map<std::shared_ptr<Node>, LKGraph> m_graphs;
    std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<Node>> m_heads;
};

bool ngraph::runtime::cpu::pass::CPULoopKernelFusion::run_on_function(
    std::shared_ptr<ngraph::Function> function)
{
    LoopKernelCollector lkc(function, m_min_kernel_size);
    auto loop_kernels = lkc.get_loop_kernels();
    for (auto lk : loop_kernels)
    {
        auto outputs = lk->get_kernel_outputs();
        std::set<std::shared_ptr<Node>> lk_nodes_set(lk->get_node_list().begin(),
                                                     lk->get_node_list().end());
        for (size_t i = 0; i < outputs.size(); i++)
        {
            auto ith_goe = std::make_shared<ngraph::op::GetOutputElement>(lk, i);
            auto& ith_output = ith_goe->get_outputs().at(0);
            if (outputs.at(i)->get_outputs().size() > 1)
            {
                throw ngraph_error(
                    "support for fusing multi-output nodes in loop kernels isn't yet implemented");
            }
            //TODO: revisit when we need support for multi-output nodes
            auto& orig_output = outputs.at(i)->get_outputs().at(0);
            //this is needed since replace_output modifies orig_output.get_inputs()
            std::set<ngraph::descriptor::Input*> inputs_copy{begin(orig_output.get_inputs()),
                                                             end(orig_output.get_inputs())};
            for (auto input : inputs_copy)
            {
                //this user is NOT internal to this loop kernel
                //so it needs to be replaced with corresponding lk's GOE
                if (lk_nodes_set.count(input->get_node()) == 0)
                {
                    input->replace_output(ith_output);
                }
            }
        }
    }
    return !lkc.get_loop_kernels().empty();
}
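To make the pass's effect concrete, here is a minimal usage sketch modeled on the tests added below. The graph and variable names are invented, and it assumes the same helpers the cpu_fusion tests use (count_ops_of_type, gtest assertions):

    Shape shape{};
    auto a = std::make_shared<op::Parameter>(element::f32, shape);
    auto b = std::make_shared<op::Parameter>(element::f32, shape);
    auto c = std::make_shared<op::Parameter>(element::f32, shape);
    auto add_ab = a + b;                             // fusible
    auto abs_ab = std::make_shared<op::Abs>(add_ab); // fusible
    auto sub_c = c - abs_ab;                         // fusible
    auto mul = sub_c * sub_c;                        // Multiply is not fusible
    auto f = std::make_shared<Function>(ngraph::NodeVector{mul},
                                        op::ParameterVector{a, b, c});

    pass::Manager pm;
    pm.register_pass<runtime::cpu::pass::CPULoopKernelFusion>(3); // groups under 3 nodes are pruned
    pm.run_passes(f);

    // the three fusible ops now live inside one LoopKernel; mul, the external
    // user, reads sub_c's value through a GetOutputElement on the kernel,
    // which is exactly the rewiring run_on_function performs above
    ASSERT_GT(count_ops_of_type<runtime::cpu::op::LoopKernel>(f), 0);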
src/ngraph/runtime/cpu/pass/cpu_loop_kernel_fusion.hpp (new file)

/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once

#include "ngraph/pass/pass.hpp"

namespace ngraph
{
    namespace runtime
    {
        namespace cpu
        {
            namespace pass
            {
                class CPULoopKernelFusion : public ngraph::pass::FunctionPass
                {
                public:
                    CPULoopKernelFusion(size_t min_kernel_size = 2)
                        : FunctionPass()
                        , m_min_kernel_size(min_kernel_size)
                    {
                    }
                    bool run_on_function(std::shared_ptr<ngraph::Function> function) override;

                protected:
                    size_t m_min_kernel_size;
                };
            }
        }
    }
}
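The min_kernel_size argument (default 2) is the threshold that prune_graphs applies: candidate groups with fewer fusible nodes are dropped rather than fused. A registration sketch, assuming an existing pass::Manager pm and a std::shared_ptr<Function> func:

    // groups of fewer than 4 fusible elementwise ops are pruned, not fused
    pm.register_pass<ngraph::runtime::cpu::pass::CPULoopKernelFusion>(4);
    pm.run_passes(func);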
test/cpu_fusion.cpp

@@ -46,6 +46,7 @@
 #include "ngraph/pattern/op/label.hpp"
 #include "ngraph/pattern/op/skip.hpp"
 #include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
+#include "ngraph/runtime/cpu/cpu_tensor_view.hpp"
 #include "ngraph/runtime/cpu/op/batch_dot.hpp"
 #include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
 #include "ngraph/runtime/cpu/op/bounded_relu.hpp"
@@ -60,6 +61,7 @@
 #include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
 #include "ngraph/runtime/cpu/pass/cpu_concat_inputs.hpp"
 #include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
+#include "ngraph/runtime/cpu/pass/cpu_loop_kernel_fusion.hpp"
 #include "ngraph/runtime/cpu/pass/cpu_mat_fusion.hpp"
 #include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
 #include "ngraph/runtime/cpu/pass/cpu_rnn_fusion.hpp"
@@ -75,8 +77,6 @@
 #include "util/random.hpp"
 #include "util/test_tools.hpp"

-#include "ngraph/runtime/cpu/cpu_tensor_view.hpp"
-
 using namespace ngraph;
 using namespace std;
@@ -1890,156 +1890,112 @@ TEST(cpu_fusion, rnn_fusion_inter_vs_cpu_2rnn_layer_3lstm_cell)
     }
 }

 [removed here: an inline copy of struct LKGraph and class LoopKernelCollector,
  essentially the implementation that now lives in
  src/ngraph/runtime/cpu/pass/cpu_loop_kernel_fusion.cpp above, except that its
  is_fusible accepted any op derived from BinaryElementwiseArithmetic or
  UnaryElementwiseArithmetic instead of using an explicit op set]

-TEST(cpu_fusion, graph_partition_multiple_groups_one_pruned)
-{
-    Shape shape{};
-    auto a = make_shared<op::Parameter>(element::i32, shape);
-    auto b = make_shared<op::Parameter>(element::i32, shape);
-    auto c = make_shared<op::Parameter>(element::i32, shape);
-    auto add_ab = a + b;
-    auto add_abs = std::make_shared<op::Abs>(add_ab);
-    auto abs_neg = std::make_shared<op::Negative>(add_abs);
-    auto sub_c_neg = c - abs_neg;
-    auto d = make_shared<op::Parameter>(element::i32, shape);
-    auto d_abs = std::make_shared<op::Abs>(d);
-    auto add_d = d_abs + add_ab;
-    auto neg_d = std::make_shared<op::Negative>(add_d);
-    auto mul_cd = neg_d * sub_c_neg;
-    auto f =
-        std::make_shared<Function>(ngraph::NodeVector{mul_cd}, op::ParameterVector{a, b, c, d});
-
-    const size_t MIN_NODES_TO_FUSE = 3;
-    LoopKernelCollector lkc(f, MIN_NODES_TO_FUSE);
-    const auto& kernels = lkc.get_loop_kernels();
-
-    ASSERT_EQ(kernels.size(), 1);
-    ASSERT_EQ(kernels.at(0)->get_arguments(), (NodeVector{a, b, c}));
-    ASSERT_EQ(kernels.at(0)->get_kernel_outputs(), (NodeVector{add_ab, sub_c_neg}));
-    ASSERT_EQ(kernels.at(0)->get_node_list(), (NodeVector{add_ab, add_abs, abs_neg, sub_c_neg}));
-}
-
-TEST(cpu_fusion, graph_partition_one_group)
-{
-    Shape shape{};
-    auto a = make_shared<op::Parameter>(element::i32, shape);
-    auto b = make_shared<op::Parameter>(element::i32, shape);
-    auto c = make_shared<op::Parameter>(element::i32, shape);
-    auto add_ab = a + b;
-    auto add_abs = std::make_shared<op::Abs>(add_ab);
-    auto abs_neg = std::make_shared<op::Negative>(add_abs);
-    auto sub_c_neg = c - abs_neg;
-    auto d = make_shared<op::Parameter>(element::i32, shape);
-    auto add_d = sub_c_neg + d;
-    auto abs_add_d = std::make_shared<op::Abs>(add_d);
-    auto e = make_shared<op::Parameter>(element::i32, shape);
-    auto add_e = e + abs_add_d;
-    auto neg_e = std::make_shared<op::Negative>(add_e);
-    auto f =
-        std::make_shared<Function>(ngraph::NodeVector{neg_e}, op::ParameterVector{a, b, c, d, e});
-
-    pass::Manager pass_manager;
-    pass_manager.run_passes(f);
-
-    const size_t MIN_NODES_TO_FUSE = 3;
-    LoopKernelCollector lkc(f, MIN_NODES_TO_FUSE);
-    const auto& kernels = lkc.get_loop_kernels();
-    ASSERT_EQ(kernels.size(), 1);
-    ASSERT_EQ(kernels.at(0)->get_arguments(), (NodeVector{a, b, c, d, e}));
-    ASSERT_EQ(kernels.at(0)->get_kernel_outputs(), (NodeVector{neg_e}));
-    ASSERT_EQ(kernels.at(0)->get_node_list(),
-              (NodeVector{add_ab, add_abs, abs_neg, sub_c_neg, add_d, abs_add_d, add_e, neg_e}));
-}
+TEST(cpu_fusion, loop_kernel_fusion_multiple_groups_pruned)
+{
+    auto make_function = []() -> std::shared_ptr<Function> {
+        Shape shape{};
+        auto a = make_shared<op::Parameter>(element::f32, shape);
+        auto b = make_shared<op::Parameter>(element::f32, shape);
+        auto c = make_shared<op::Parameter>(element::f32, shape);
+        auto add_ab = a + b;
+        auto add_abs = std::make_shared<op::Abs>(add_ab);
+        auto abs_neg = std::make_shared<op::Negative>(add_abs);
+        auto sub_c_neg = c - abs_neg;
+
+        auto d = make_shared<op::Parameter>(element::f32, shape);
+        auto d_abs = std::make_shared<op::Abs>(d);
+        auto add_d = d_abs + add_ab;
+        auto neg_d = std::make_shared<op::Negative>(add_d);
+
+        auto mul_cd = neg_d * sub_c_neg;
+        auto f = std::make_shared<Function>(ngraph::NodeVector{mul_cd},
+                                            op::ParameterVector{a, b, c, d});
+        return f;
+    };
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<runtime::cpu::pass::CPULoopKernelFusion>(3);
+    auto cpu_f = make_function();
+    auto int_f = make_function();
+    pass_manager.run_passes(cpu_f);
+    test::Uniform<float> rng(-100.0f, 100.0f);
+    vector<vector<float>> args;
+
+    size_t lkn = count_ops_of_type<runtime::cpu::op::LoopKernel>(cpu_f);
+    ASSERT_GT(lkn, 0);
+
+    for (shared_ptr<op::Parameter> param : cpu_f->get_parameters())
+    {
+        vector<float> tensor_val(shape_size(param->get_shape()));
+        rng.initialize(tensor_val);
+        args.push_back(tensor_val);
+    }
+    auto int_results = execute(int_f, args, "INTERPRETER");
+    auto cpu_results = execute(cpu_f, args, "CPU");
+    for (size_t i = 0; i < cpu_results.size(); i++)
+    {
+        EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
+    }
+}
+
+TEST(cpu_fusion, loop_kernel_fusion_bounded_relu)
+{
+    auto make_function = []() -> std::shared_ptr<Function> {
+        Shape shape{};
+        auto a = make_shared<op::Parameter>(element::f32, shape);
+        auto relu = make_shared<op::Relu>(a);
+        auto upper_bound =
+            op::Constant::create<float>(element::f32, shape, std::vector<float>{6.0f});
+        auto minn = make_shared<op::Minimum>(relu, upper_bound);
+        auto absn = make_shared<op::Abs>(minn);
+        auto negn = std::make_shared<op::Negative>(absn);
+
+        auto f = std::make_shared<Function>(ngraph::NodeVector{negn}, op::ParameterVector{a});
+        return f;
+    };
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::VisualizeTree>("before_relu_fusion.pdf");
+    pass_manager.register_pass<runtime::cpu::pass::CPULoopKernelFusion>(3);
+    pass_manager.register_pass<pass::VisualizeTree>("after_relu_fusion.pdf");
+    auto cpu_f = make_function();
+    auto int_f = make_function();
+    pass_manager.run_passes(cpu_f);
+    test::Uniform<float> rng(-100.0f, 100.0f);
+    vector<vector<float>> args;
+
+    size_t lkn = count_ops_of_type<runtime::cpu::op::LoopKernel>(cpu_f);
+    ASSERT_GT(lkn, 0);
+
+    for (shared_ptr<op::Parameter> param : cpu_f->get_parameters())
+    {
+        vector<float> tensor_val(shape_size(param->get_shape()));
+        rng.initialize(tensor_val);
+        args.push_back(tensor_val);
+    }
+    auto int_results = execute(int_f, args, "INTERPRETER");
+    auto cpu_results = execute(cpu_f, args, "CPU");
+    for (size_t i = 0; i < cpu_results.size(); i++)
+    {
+        EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
+    }
+}
+
+TEST(cpu_fusion, loop_kernel_fusion_multiple_groups)
+{
+    auto make_function = []() -> std::shared_ptr<Function> {
+        Shape shape{};
+        auto a = make_shared<op::Parameter>(element::f32, shape);
+        auto b = make_shared<op::Parameter>(element::f32, shape);
+        auto c = make_shared<op::Parameter>(element::f32, shape);
+        auto add_ab = a + b;
+        auto add_abs = std::make_shared<op::Abs>(add_ab);
+        auto abs_neg = std::make_shared<op::Negative>(add_abs);
+        auto sub_c_neg = c - abs_neg;
+
+        auto d = make_shared<op::Parameter>(element::f32, shape);
+        auto d_abs = std::make_shared<op::Abs>(d);
+        auto add_d = d_abs + add_ab;
+        auto neg_d = std::make_shared<op::Negative>(add_d);
+
+        auto mul_cd = neg_d * sub_c_neg;
+        auto f = std::make_shared<Function>(ngraph::NodeVector{mul_cd},
+                                            op::ParameterVector{a, b, c, d});
+        return f;
+    };
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<runtime::cpu::pass::CPULoopKernelFusion>(2);
+    auto cpu_f = make_function();
+    auto int_f = make_function();
+    pass_manager.run_passes(cpu_f);
+    test::Uniform<float> rng(-100.0f, 100.0f);
+    vector<vector<float>> args;
+
+    size_t lkn = count_ops_of_type<runtime::cpu::op::LoopKernel>(cpu_f);
+    ASSERT_GT(lkn, 0);
+
+    for (shared_ptr<op::Parameter> param : cpu_f->get_parameters())
+    {
+        vector<float> tensor_val(shape_size(param->get_shape()));
+        rng.initialize(tensor_val);
+        args.push_back(tensor_val);
+    }
+    auto int_results = execute(int_f, args, "INTERPRETER");
+    auto cpu_results = execute(cpu_f, args, "CPU");
+    for (size_t i = 0; i < cpu_results.size(); i++)
+    {
+        EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
+    }
+}
+
+TEST(cpu_fusion, loop_kernel_fusion_one_group)
+{
+    auto make_function = []() -> std::shared_ptr<Function> {
+        Shape shape{};
+        auto a = make_shared<op::Parameter>(element::f32, shape);
+        auto b = make_shared<op::Parameter>(element::f32, shape);
+        auto c = make_shared<op::Parameter>(element::f32, shape);
+        auto add_ab = a + b;
+        auto add_abs = std::make_shared<op::Abs>(add_ab);
+        auto abs_neg = std::make_shared<op::Negative>(add_abs);
+        auto sub_c_neg = c - abs_neg;
+        auto d = make_shared<op::Parameter>(element::f32, shape);
+        auto add_d = sub_c_neg + d;
+        auto abs_add_d = std::make_shared<op::Abs>(add_d);
+        auto e = make_shared<op::Parameter>(element::f32, shape);
+        auto add_e = e + abs_add_d;
+        auto neg_e = std::make_shared<op::Negative>(add_e);
+        auto f = std::make_shared<Function>(ngraph::NodeVector{neg_e},
+                                            op::ParameterVector{a, b, c, d, e});
+        return f;
+    };
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<runtime::cpu::pass::CPULoopKernelFusion>(2);
+    auto cpu_f = make_function();
+    auto int_f = make_function();
+    pass_manager.run_passes(cpu_f);
+    test::Uniform<float> rng(-100.0f, 100.0f);
+    vector<vector<float>> args;
+
+    size_t lkn = count_ops_of_type<runtime::cpu::op::LoopKernel>(cpu_f);
+    ASSERT_GT(lkn, 0);
+
+    for (shared_ptr<op::Parameter> param : cpu_f->get_parameters())
+    {
+        vector<float> tensor_val(shape_size(param->get_shape()));
+        rng.initialize(tensor_val);
+        args.push_back(tensor_val);
+    }
+    auto int_results = execute(int_f, args, "INTERPRETER");
+    auto cpu_results = execute(cpu_f, args, "CPU");
+    for (size_t i = 0; i < cpu_results.size(); i++)
+    {
+        EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
+    }
+}

 TEST(cpu_fusion, sigmoid_multiply_fusion)