Unverified Commit f078800c authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Make Common Function Elimination a pass (#1236)

* change GPU to use cfe pass

* update per review comments
parent 4d3e4721
...@@ -109,6 +109,7 @@ set (SRC ...@@ -109,6 +109,7 @@ set (SRC
op/util/unary_elementwise.cpp op/util/unary_elementwise.cpp
pass/assign_placement.cpp pass/assign_placement.cpp
pass/algebraic_simplification.cpp pass/algebraic_simplification.cpp
pass/common_function_collection.cpp
pass/constant_folding.cpp pass/constant_folding.cpp
pass/cse.cpp pass/cse.cpp
pass/dump_sorted.cpp pass/dump_sorted.cpp
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <sstream>
#include <utility>
#include "common_function_collection.hpp"
using namespace std;
using namespace ngraph;
// Construct the pass.
//   emitter           - callback that renders a node as the text of a function whose name is
//                       the supplied string. It is taken by value and moved into the member so
//                       the callable is not copied a second time.
//   result_map        - caller-owned map that run_on_module fills with
//                       node -> node-whose-function-to-call entries.
//   emitted_functions - caller-owned string that run_on_module overwrites with the emitted
//                       shared functions.
// Only references to result_map and emitted_functions are stored, so both must outlive the pass.
pass::CommonFunctionCollection::CommonFunctionCollection(function<string(Node&, string)> emitter,
                                                         unordered_map<Node*, Node*>& result_map,
                                                         string& emitted_functions)
    : m_emit_op_as_function(std::move(emitter))
    , m_node_function_map(result_map)
    , m_emitted_functions(emitted_functions)
{
}
// Nothing to release here: the pass stores a callback plus references to caller-owned objects.
pass::CommonFunctionCollection::~CommonFunctionCollection() = default;
// Find ops that emit identical code and arrange for them to share one emitted function.
//
// Every op of every function (constants and parameters are skipped) is rendered through
// m_emit_op_as_function under the placeholder name `__f__`, producing text such as:
//     static void __f__(float* _arg0, float* _out1)
//     {
//         op specific code here
//     }
// The rendered text is then used as a lookup key. The first op producing a given text is only
// remembered. When a later op produces exactly the same text, that op is mapped to the first
// op in m_node_function_map, and the first op is mapped to itself. The first time such a
// duplicate is seen, the text is also appended to the output with `__f__` replaced by
// create_function_name(<first op>), so each shared function gets a distinct name.
//
// On completion m_emitted_functions is overwritten with all shared functions.
// Always returns false.
bool pass::CommonFunctionCollection::run_on_module(vector<shared_ptr<Function>>& functions)
{
    const string placeholder = "__f__";

    // Emitted text -> first node that produced exactly that text.
    unordered_map<string, Node*> first_node_for_code;
    string shared_code;

    for (const shared_ptr<Function>& func : functions)
    {
        list<shared_ptr<Node>> ordered_ops = func->get_ordered_ops();
        for (const shared_ptr<Node>& op_ptr : ordered_ops)
        {
            if (op_ptr->is_constant() || op_ptr->is_parameter())
            {
                continue;
            }

            Node& current = *op_ptr;
            string code = m_emit_op_as_function(current, placeholder);

            auto seen = first_node_for_code.find(code);
            if (seen == first_node_for_code.end())
            {
                // First op with this code: just remember it, nothing is shared yet.
                first_node_for_code.insert({code, &current});
                continue;
            }

            // Duplicate: route the current op to the original op's function.
            Node* original = seen->second;
            m_node_function_map.insert({&current, original});

            // The original maps to itself. insert() reports whether that entry is new, which
            // tells us whether the shared function still has to be emitted.
            const bool newly_shared = m_node_function_map.insert({original, original}).second;
            if (!newly_shared)
            {
                continue;
            }

            // Swap the common placeholder name for a unique one so that all the shared
            // functions can be compiled together.
            auto name_pos = code.find(placeholder);
            string renamed = code;
            string unique_name = create_function_name(*original);
            renamed.replace(name_pos, placeholder.size(), unique_name);
            shared_code += renamed;
            shared_code += "\n";
        }
    }

    m_emitted_functions = shared_code;
    return false;
}
// Build the name of the shared function associated with `node`: the node's own name
// prefixed with "func_".
string pass::CommonFunctionCollection::create_function_name(const Node& node)
{
    string name = "func_";
    name += node.get_name();
    return name;
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <unordered_map>
#include "ngraph/codegen/code_writer.hpp"
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class CommonFunctionCollection;
}
}
// Module pass that records which nodes can share a single emitted function. The results are
// written to the map and string supplied by the caller at construction time.
class ngraph::pass::CommonFunctionCollection : public ModulePass
{
public:
// @brief Create the CommonFunctionCollection pass
// @param function_emitter - This is a function that takes a reference to a Node and a string.
// The string is the name of the emitted function and the body of the function is
// the code for the op.
// @param result_map - This is a mapping of source node -> emitted static function node, where
// the key is the source node and the value is the emitted static function node. The
// name of the function to call is create_function_name(<emitted static function node>).
// Stored by reference, so it must outlive the pass.
// @param emitted_functions - string to contain the emitted code for all of the static
// functions. Stored by reference, so it must outlive the pass.
CommonFunctionCollection(std::function<std::string(Node&, std::string)> function_emitter,
std::unordered_map<Node*, Node*>& result_map,
std::string& emitted_functions);
virtual ~CommonFunctionCollection();
// @brief Run the pass over all functions of the module, filling in the result map and the
// emitted-functions string given to the constructor.
bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;
// @brief Construct the name of the function to call for this op
// @param node - Node used to construct the function name. This node is the `value` of the
// result_map passed to the pass's constructor.
// @return string containing the name of the function to be called
static std::string create_function_name(const Node& node);
private:
// Callback that renders a node as function text under the given function name.
std::function<std::string(Node&, std::string)> m_emit_op_as_function;
// Caller-owned output: node -> node whose emitted function should be called.
std::unordered_map<Node*, Node*>& m_node_function_map;
// Caller-owned output: text of the emitted functions.
std::string& m_emitted_functions;
};
...@@ -100,6 +100,7 @@ ...@@ -100,6 +100,7 @@
#include "ngraph/op/tan.hpp" #include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp" #include "ngraph/op/tanh.hpp"
#include "ngraph/pass/algebraic_simplification.hpp" #include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/common_function_collection.hpp"
#include "ngraph/pass/core_fusion.hpp" #include "ngraph/pass/core_fusion.hpp"
#include "ngraph/pass/cse.hpp" #include "ngraph/pass/cse.hpp"
#include "ngraph/pass/dump_sorted.hpp" #include "ngraph/pass/dump_sorted.hpp"
...@@ -364,6 +365,14 @@ void runtime::cpu::CPU_ExternalFunction::compile() ...@@ -364,6 +365,14 @@ void runtime::cpu::CPU_ExternalFunction::compile()
pass_manager.register_pass<runtime::cpu::pass::CPUShuffleFolding>(); pass_manager.register_pass<runtime::cpu::pass::CPUShuffleFolding>();
pass_manager.register_pass<ngraph::pass::ResultCopyElimination>(); pass_manager.register_pass<ngraph::pass::ResultCopyElimination>();
pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>(); pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>();
unordered_map<Node*, Node*> node_function_map;
string common_function_string;
auto femitter = bind(&ngraph::runtime::cpu::CPU_ExternalFunction::emit_op_as_function,
this,
placeholders::_1,
placeholders::_2);
pass_manager.register_pass<ngraph::pass::CommonFunctionCollection>(
femitter, node_function_map, common_function_string);
pass_manager.register_pass<ngraph::pass::Liveness>(); pass_manager.register_pass<ngraph::pass::Liveness>();
pass_manager.register_pass<ngraph::pass::MemoryLayout>(s_memory_pool_alignment, true); pass_manager.register_pass<ngraph::pass::MemoryLayout>(s_memory_pool_alignment, true);
pass_manager.run_passes(m_function); pass_manager.run_passes(m_function);
...@@ -518,73 +527,7 @@ using namespace ngraph::runtime; ...@@ -518,73 +527,7 @@ using namespace ngraph::runtime;
} }
writer << "\n"; writer << "\n";
// This for loop creates a collection of functions that are called more than once writer << common_function_string << "\n";
// and emitting them as globally callable functions.
// ops implement the is_functionally_identical method
unordered_map<Node*, string> match_functions;
for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
{
list<shared_ptr<Node>> tmp = function_ordered_ops.at(current_function);
if (tmp.size() < 2)
{
// Since we are comparing ops there must be at least two ops to proceed.
continue;
}
vector<shared_ptr<Node>> op_list{tmp.begin(), tmp.end()};
unordered_map<const Node*, string> node_cache;
for (size_t i = 0; i < op_list.size(); i++)
{
// constants and parameters cannot be outlined
if (op_list[i]->is_constant() || op_list[i]->is_parameter())
{
continue;
}
Node& node = *op_list[i];
auto handler = dispatcher.find(type_index(typeid(node)));
if (handler == dispatcher.end())
{
throw ngraph_error("Unhandled op during code generation : " + node.description());
}
string s = emit_op_as_function(node, "f");
node_cache.insert({&node, s});
}
for (size_t i = 0; i < op_list.size() - 1; i++)
{
if (op_list[i]->is_constant() || op_list[i]->is_parameter())
{
continue;
}
if (contains_key(match_functions, op_list[i].get()))
{
continue;
}
string match_function_name;
for (size_t j = i + 1; j < op_list.size(); j++)
{
if (op_list[j]->is_constant() || op_list[j]->is_parameter())
{
continue;
}
Node* op1 = op_list[i].get();
Node* op2 = op_list[j].get();
if (is_functionally_identical(*op1, *op2, node_cache))
{
if (match_function_name.empty())
{
match_function_name = "func_" + op1->get_name();
match_functions.insert({op1, match_function_name});
}
match_functions.insert({op2, match_function_name});
}
}
if (!match_function_name.empty())
{
writer << emit_op_as_function(*op_list[i], match_function_name);
}
}
}
for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions()) for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
{ {
...@@ -818,15 +761,15 @@ using namespace ngraph::runtime; ...@@ -818,15 +761,15 @@ using namespace ngraph::runtime;
writer.indent++; writer.indent++;
} }
string func_name; auto it = node_function_map.find(node.get());
auto it = match_functions.find(node.get()); if (it == node_function_map.end())
if (it == match_functions.end())
{ {
handler->second(this, writer, node.get(), in, out); handler->second(this, writer, node.get(), in, out);
} }
else else
{ {
func_name = it->second; string func_name =
ngraph::pass::CommonFunctionCollection::create_function_name(*it->second);
vector<string> names; vector<string> names;
for (const TensorViewWrapper& tv : in) for (const TensorViewWrapper& tv : in)
{ {
...@@ -1302,6 +1245,10 @@ string runtime::cpu::CPU_ExternalFunction::emit_op_as_function(const Node& node, ...@@ -1302,6 +1245,10 @@ string runtime::cpu::CPU_ExternalFunction::emit_op_as_function(const Node& node,
// Work around a compiler warning (*node inside typeid may have effects // Work around a compiler warning (*node inside typeid may have effects
// with shared pointers, which is fine here but clang doesn't like it.) // with shared pointers, which is fine here but clang doesn't like it.)
auto handler = dispatcher.find(type_index(typeid(node))); auto handler = dispatcher.find(type_index(typeid(node)));
if (handler == dispatcher.end())
{
throw ngraph_error("Unhandled op during function emit : " + node.description());
}
vector<TensorViewWrapper> in; vector<TensorViewWrapper> in;
size_t arg_index = 0; size_t arg_index = 0;
set<string> arg_names; set<string> arg_names;
......
...@@ -97,6 +97,7 @@ ...@@ -97,6 +97,7 @@
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
#include "ngraph/op/tan.hpp" #include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp" #include "ngraph/op/tanh.hpp"
#include "ngraph/pass/common_function_collection.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp" #include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp" #include "ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_emitter.hpp" #include "ngraph/runtime/gpu/gpu_emitter.hpp"
...@@ -428,55 +429,6 @@ void runtime::gpu::GPU_ExternalFunction::emit_function_declarations() ...@@ -428,55 +429,6 @@ void runtime::gpu::GPU_ExternalFunction::emit_function_declarations()
m_writer << "\n"; m_writer << "\n";
} }
// Emit one function per distinct piece of op code and record, for every op, which function
// to call. (This is the removed side of the diff; it is kept as-is and only annotated.)
void runtime::gpu::GPU_ExternalFunction::collect_unique_functions()
{
// Walk the ops of every function known to the pass manager, render each op as a function
// named `__f__`, and use the rendered text as the key for detecting identical code.
// match_function_map: rendered text -> unique function name chosen for that text.
unordered_map<string, string> match_function_map;
for (shared_ptr<Function> current_function : m_pass_manager.get_state().get_functions())
{
list<shared_ptr<Node>> tmp = m_function_ordered_ops.at(current_function);
if (tmp.size() < 2)
{
// Functions with fewer than two ops are skipped entirely, so their ops get no entry
// in m_node_function_map.
continue;
}
vector<shared_ptr<Node>> op_list{tmp.begin(), tmp.end()};
for (size_t i = 0; i < op_list.size(); i++)
{
// Constants and parameters are never emitted as functions.
if (op_list[i]->is_constant() || op_list[i]->is_parameter())
{
continue;
}
Node& node = *op_list[i];
auto handler = dispatcher.find(type_index(typeid(node)));
if (handler == dispatcher.end())
{
throw ngraph_error("Unhandled op during code generation : " + node.description());
}
string match_function = emit_op_as_function(node, "__f__");
string match_function_name;
if (contains_key(match_function_map, match_function))
{
// Same code was already emitted: reuse the name chosen the first time.
match_function_name = match_function_map[match_function];
}
else
{
// First time this code is seen: give it a unique name derived from the node and
// write it out immediately. The literal 5 is the length of "__f__".
auto offset = match_function.find("__f__");
string emitted_function = match_function;
match_function_name = "func_" + node.get_name();
emitted_function.replace(offset, 5, match_function_name);
match_function_map.insert({match_function, match_function_name});
m_writer << emitted_function << "\n";
}
// Every processed op is mapped to a function name, whether or not its code is shared.
m_node_function_map.insert({&node, match_function_name});
}
}
}
void runtime::gpu::GPU_ExternalFunction::emit_temp_mem_pool_allocation( void runtime::gpu::GPU_ExternalFunction::emit_temp_mem_pool_allocation(
shared_ptr<Function> current_function) shared_ptr<Function> current_function)
{ {
...@@ -636,15 +588,15 @@ void runtime::gpu::GPU_ExternalFunction::emit_functions() ...@@ -636,15 +588,15 @@ void runtime::gpu::GPU_ExternalFunction::emit_functions()
} }
// Emit operation body // Emit operation body
string func_name; auto it = m_node_function_map.find(node.get());
func_name = m_node_function_map[node.get()]; if (it == m_node_function_map.end())
if (func_name.empty())
{ {
//throw runtime_error("No matching function found for '" + node->get_name() + "'");
handler->second(this, m_writer, node.get(), in, out); handler->second(this, m_writer, node.get(), in, out);
} }
else else
{ {
string func_name =
ngraph::pass::CommonFunctionCollection::create_function_name(*it->second);
vector<string> names; vector<string> names;
for (const GPU_TensorViewWrapper& tv : in) for (const GPU_TensorViewWrapper& tv : in)
{ {
...@@ -692,6 +644,13 @@ void runtime::gpu::GPU_ExternalFunction::compile() ...@@ -692,6 +644,13 @@ void runtime::gpu::GPU_ExternalFunction::compile()
// For now, just make everyone row-major. // For now, just make everyone row-major.
m_pass_manager.register_pass<pass::ResultCopyElimination>(); m_pass_manager.register_pass<pass::ResultCopyElimination>();
m_pass_manager.register_pass<pass::AssignLayout<descriptor::layout::DenseTensorViewLayout>>(); m_pass_manager.register_pass<pass::AssignLayout<descriptor::layout::DenseTensorViewLayout>>();
string common_function_string;
auto femitter = bind(&ngraph::runtime::gpu::GPU_ExternalFunction::emit_op_as_function,
this,
placeholders::_1,
placeholders::_2);
m_pass_manager.register_pass<ngraph::pass::CommonFunctionCollection>(
femitter, m_node_function_map, common_function_string);
m_pass_manager.register_pass<pass::Liveness>(); m_pass_manager.register_pass<pass::Liveness>();
m_pass_manager.register_pass<pass::MemoryLayout>(64); m_pass_manager.register_pass<pass::MemoryLayout>(64);
m_pass_manager.register_pass<pass::DumpSorted>(dump_filename); m_pass_manager.register_pass<pass::DumpSorted>(dump_filename);
...@@ -706,7 +665,7 @@ void runtime::gpu::GPU_ExternalFunction::compile() ...@@ -706,7 +665,7 @@ void runtime::gpu::GPU_ExternalFunction::compile()
emit_timer_functions(); emit_timer_functions();
emit_constant_declarations(); emit_constant_declarations();
emit_function_declarations(); emit_function_declarations();
collect_unique_functions(); m_writer << common_function_string << "\n";
emit_functions(); emit_functions();
// allocate device buffers for primitive arguments and workspace // allocate device buffers for primitive arguments and workspace
......
...@@ -104,7 +104,7 @@ namespace ngraph ...@@ -104,7 +104,7 @@ namespace ngraph
std::map<std::string, size_t> m_name_index_map; std::map<std::string, size_t> m_name_index_map;
std::unordered_map<std::string, std::string> m_variable_name_map; std::unordered_map<std::string, std::string> m_variable_name_map;
std::unordered_map<const Node*, std::string> m_node_function_map; std::unordered_map<Node*, Node*> m_node_function_map;
std::unordered_map<std::shared_ptr<Function>, std::list<std::shared_ptr<Node>>> std::unordered_map<std::shared_ptr<Function>, std::list<std::shared_ptr<Node>>>
m_function_ordered_ops; m_function_ordered_ops;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment