Commit 924ec039 authored by Jayaram Bobba

Merge branch 'jbobba/mkldnn-outlining' of https://github.com/NervanaSystems/private-ngraph-cpp into jbobba/mkldnn-outlining
parents d7c4dedf e4f1abfa
@@ -79,6 +79,7 @@ set (SRC
    ops/unary_elementwise.cpp
    pass/dump_sorted.cpp
    pass/graph_rewrite.cpp
    pass/inliner.cpp
    pass/liveness.cpp
    pass/manager.cpp
    pass/manager_state.cpp
......
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "inliner.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/ops/function_call.hpp"
std::vector<std::shared_ptr<ngraph::op::FunctionCall>>
    ngraph::pass::InlineSmallCalls::create_inlining_plan(std::shared_ptr<ngraph::Function> f,
                                                         size_t depth)
{
    std::vector<std::shared_ptr<ngraph::op::FunctionCall>> callees;
    for (auto n : f->get_ops())
    {
        auto fc = std::dynamic_pointer_cast<op::FunctionCall>(n);
        if (!fc)
        {
            continue;
        }
        auto callee_function = fc->get_functions().at(0);
        NGRAPH_DEBUG << "InlineSmallCalls is considering " << callee_function->get_name() << " of "
                     << fc->get_name();
        size_t callee_size = callee_function->get_ops().size();
        NGRAPH_DEBUG << "\t" << callee_function->get_name() << " size is " << callee_size
                     << " , depth = " << depth;
        if (depth < m_depth && callee_size < m_call_size_limit)
        {
            callees.push_back(fc);
        }
    }
    return callees;
}
bool ngraph::pass::Inliner::inline_function_call(std::shared_ptr<ngraph::Node> inlinee,
                                                 std::shared_ptr<ngraph::Function> caller)
{
    auto callsite = std::dynamic_pointer_cast<ngraph::op::FunctionCall>(inlinee);
    if (!callsite)
    {
        return false;
    }
    // map the call arguments at the callsite to the callee's parameters
    auto callee = callsite->get_functions().at(0);
    if (callee->get_results().size() > 1)
    {
        return false; // multi-output callees aren't handled yet (can't just use replace_node); relax in a later iteration
    }
    ngraph::NodeMap nm;
    for (size_t i = 0; i < callee->get_parameters().size(); i++)
    {
        nm.add(callee->get_parameters().at(i), callsite->get_input_op(i));
    }
    // clone the callee body with its parameters remapped to the call arguments,
    // then splice the cloned result graph in place of the FunctionCall node
    ngraph::clone_function(callee, nm);
    auto callee_graph = nm.get(callee->get_result());
    caller->replace_node(callsite, callee_graph);
    NGRAPH_DEBUG << "Inlined " << callee->get_name() << " of " << callsite->get_name() << " into "
                 << caller->get_name();
    return true;
}
bool ngraph::pass::Inliner::run_on_function_call(std::shared_ptr<ngraph::op::FunctionCall> fc)
{
    auto f = fc->get_functions().at(0);
    NGRAPH_DEBUG << "Inliner::run_on_function_call on " << f->get_name();
    auto callees = m_inlining_heuristics->create_inlining_plan(f, m_depth);
    if (callees.empty())
    {
        return false;
    }
    // we could clone_function f first if we needed to preserve the original callee
    run_on_functions(callees, f);
    return true;
}
void ngraph::pass::Inliner::run_on_functions(
    std::vector<std::shared_ptr<ngraph::op::FunctionCall>> callees,
    std::shared_ptr<ngraph::Function> caller)
{
    for (auto callee : callees)
    {
        // recursively inline inside the callee first, then inline the call itself into the caller
        m_depth++;
        run_on_function_call(callee);
        m_depth--;
        inline_function_call(callee, caller);
    }
}
bool ngraph::pass::Inliner::run_on_module(std::vector<std::shared_ptr<ngraph::Function>>& funcs)
{
    auto outermost = funcs.front();
    NGRAPH_DEBUG << "Outermost function = " << outermost->get_name();
    auto callees = m_inlining_heuristics->create_inlining_plan(outermost, m_depth);
    if (callees.empty())
    {
        return false;
    }
    run_on_functions(callees, outermost);
    return true;
}
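
For context, the new pass is driven through the usual pass::Manager flow, exactly as the unit test added later in this change does. A minimal usage sketch follows; the helper name and the thresholds are illustrative, not part of this commit:

#include <memory>
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/inliner.hpp"
#include "ngraph/pass/manager.hpp"

// Hypothetical helper (illustration only): inline calls whose callee has fewer
// than 10 ops, recursing at most one level deep.
void inline_small_calls(std::shared_ptr<ngraph::Function> f)
{
    auto heuristics = std::make_shared<ngraph::pass::InlineSmallCalls>(10 /* call size limit */, 1 /* depth */);
    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<ngraph::pass::Inliner>(heuristics);
    // FunctionCall nodes selected by the heuristic are replaced by their callee bodies
    pass_manager.run_passes(f);
}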
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once

#include <functional>
#include <memory>
#include <vector>

#include "ngraph/ops/function_call.hpp"
#include "ngraph/pass/pass.hpp"

namespace ngraph
{
    namespace pass
    {
        class Inliner;
        class InliningHeuristics;
        class InlineSmallCalls;
    }
}
class ngraph::pass::InliningHeuristics
{
public:
    virtual std::vector<std::shared_ptr<ngraph::op::FunctionCall>>
        create_inlining_plan(std::shared_ptr<ngraph::Function> f, size_t depth) = 0;
    virtual ~InliningHeuristics() {}
};
class ngraph::pass::InlineSmallCalls : public ngraph::pass::InliningHeuristics
{
public:
    InlineSmallCalls(size_t call_size_limit, size_t depth)
        : InliningHeuristics()
        , m_call_size_limit(call_size_limit)
        , m_depth(depth)
    {
    }

    std::vector<std::shared_ptr<ngraph::op::FunctionCall>>
        create_inlining_plan(std::shared_ptr<ngraph::Function> f, size_t depth) override;
    virtual ~InlineSmallCalls() {}

private:
    size_t m_call_size_limit;
    size_t m_depth;
};
class ngraph::pass::Inliner : public ModulePass
{
public:
    Inliner(std::shared_ptr<InliningHeuristics> ih)
        : ModulePass()
        , m_inlining_heuristics(ih)
        , m_depth(0)
    {
    }

    static bool inline_function_call(std::shared_ptr<ngraph::Node> inlinee,
                                     std::shared_ptr<ngraph::Function> caller);
    bool run_on_function_call(std::shared_ptr<ngraph::op::FunctionCall> fc);
    void run_on_functions(std::vector<std::shared_ptr<ngraph::op::FunctionCall>> callees,
                          std::shared_ptr<ngraph::Function> caller);
    bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>& funcs) override;

private:
    std::shared_ptr<InliningHeuristics> m_inlining_heuristics;
    size_t m_depth;
};
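
The InliningHeuristics interface above is the extension point; InlineSmallCalls is just one policy. As a rough illustration of how another policy could plug in, here is a hedged sketch of a heuristic that offers every call for inlining, bounded only by depth (InlineAllCalls is a hypothetical name used for illustration, not part of this change):

#include <memory>
#include <vector>
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/inliner.hpp"

// Hypothetical heuristic (illustration only): offer every FunctionCall in the
// function for inlining, stopping once the configured recursion depth is hit.
class InlineAllCalls : public ngraph::pass::InliningHeuristics
{
public:
    InlineAllCalls(size_t max_depth)
        : m_max_depth(max_depth)
    {
    }
    std::vector<std::shared_ptr<ngraph::op::FunctionCall>>
        create_inlining_plan(std::shared_ptr<ngraph::Function> f, size_t depth) override
    {
        std::vector<std::shared_ptr<ngraph::op::FunctionCall>> callees;
        if (depth >= m_max_depth)
        {
            return callees; // stop recursing past the configured depth
        }
        for (auto n : f->get_ops())
        {
            if (auto fc = std::dynamic_pointer_cast<ngraph::op::FunctionCall>(n))
            {
                callees.push_back(fc);
            }
        }
        return callees;
    }

private:
    size_t m_max_depth;
};

Such a heuristic would be registered the same way as InlineSmallCalls, e.g. pass_manager.register_pass<ngraph::pass::Inliner>(std::make_shared<InlineAllCalls>(2)).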
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include <iostream>
#include <memory>
@@ -41,18 +42,18 @@ void ngraph::pass::Manager::initialize_default_passes()
 void ngraph::pass::Manager::run_passes(shared_ptr<Function> func)
 {
     // find all functions
-    set<shared_ptr<Function>> tfs;
-    traverse_functions(func, [&](shared_ptr<Function> f) { tfs.insert(f); });
+    vector<shared_ptr<Function>> fs;
+    traverse_functions(func, [&](shared_ptr<Function> f) { fs.push_back(f); });
+    set<shared_ptr<Function>> tfs(begin(fs), end(fs));
     get_state().set_functions(tfs);
-    vector<shared_ptr<Function>> fs;
     for (shared_ptr<Function> f : get_state().get_functions())
     {
         for (size_t i = 0; i < f->get_output_size(); ++i)
         {
             f->get_output_op(i)->set_is_output();
         }
-        fs.push_back(f);
     }
     for (shared_ptr<PassBase> pass : m_pass_list)
......
@@ -31,6 +31,7 @@ set (SRC
    eigen.cpp
    element_type.cpp
    file_util.cpp
    inliner.cpp
    input_output_assign.cpp
    main.cpp
    op.cpp
......
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/inliner.hpp"
#include "ngraph/pass/manager.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
TEST(inline, basic)
{
    auto shape = Shape{2, 2};
    auto A = make_shared<op::Parameter>(element::f32, shape);
    auto B = make_shared<op::Parameter>(element::f32, shape);
    auto C = make_shared<op::Parameter>(element::f32, shape);
    auto f = make_shared<Function>((A + B) * C, op::Parameters{A, B, C});

    auto X = make_shared<op::Parameter>(element::f32, shape);
    auto Y = make_shared<op::Parameter>(element::f32, shape);
    auto Z = make_shared<op::Parameter>(element::f32, shape);
    auto fc1 = make_shared<op::FunctionCall>(f, Nodes{X, Y, Z});
    auto fc2 = make_shared<op::FunctionCall>(f, Nodes{X, Y, Z});
    auto g = make_shared<Function>(fc1 + fc2, op::Parameters{X, Y, Z});

    auto ih = std::make_shared<ngraph::pass::InlineSmallCalls>(10, 1);
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::Inliner>(ih);
    auto bc = g->get_ops().size();
    pass_manager.run_passes(g);
    auto ac = g->get_ops().size();
    ASSERT_EQ(count_ops_of_type<op::FunctionCall>(g), 0); // check that FunctionCalls disappear
    ASSERT_LT(bc, ac);                                    // we should get more ops after inlining
}
TEST(inline, recursive)
{
    auto shape = Shape{2, 2};
    auto A = make_shared<op::Parameter>(element::f32, shape);
    auto B = make_shared<op::Parameter>(element::f32, shape);
    auto f = make_shared<Function>((A + B), op::Parameters{A, B});

    auto X = make_shared<op::Parameter>(element::f32, shape);
    auto Y = make_shared<op::Parameter>(element::f32, shape);
    auto fc1 = make_shared<op::FunctionCall>(f, Nodes{X, Y});
    auto g = make_shared<Function>(make_shared<op::Negative>(fc1), op::Parameters{X, Y});

    auto P1 = make_shared<op::Parameter>(element::f32, shape);
    auto P2 = make_shared<op::Parameter>(element::f32, shape);
    auto P3 = make_shared<op::Parameter>(element::f32, shape);
    auto fc2 = make_shared<op::FunctionCall>(g, Nodes{P1, P2});
    auto e = make_shared<Function>(fc2 * P3, op::Parameters{P1, P2, P3});

    auto ih = std::make_shared<ngraph::pass::InlineSmallCalls>(15, 2);
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::Inliner>(ih);
    auto bce = e->get_ops().size();
    pass_manager.run_passes(e);
    auto ace = e->get_ops().size();
    ASSERT_EQ(count_ops_of_type<op::FunctionCall>(g), 0); // check that FunctionCalls disappear
    ASSERT_EQ(count_ops_of_type<op::Add>(g), 1);          // the FunctionCall is replaced with an Add
    ASSERT_EQ(count_ops_of_type<op::FunctionCall>(e), 0);
    ASSERT_LT(bce, ace); // we should get more ops after inlining
}
@@ -49,18 +49,3 @@ public:
        return is_match;
    }
};

template <typename T>
size_t count_ops_of_type(std::shared_ptr<ngraph::Function> f)
{
    size_t count = 0;
    for (auto op : f->get_ops())
    {
        if (std::dynamic_pointer_cast<T>(op))
        {
            count++;
        }
    }
    return count;
}
@@ -56,3 +56,18 @@ void write_vector(std::shared_ptr<ngraph::runtime::TensorView> tv, const std::ve
{
    tv->write(values.data(), 0, values.size() * sizeof(T));
}

template <typename T>
size_t count_ops_of_type(std::shared_ptr<ngraph::Function> f)
{
    size_t count = 0;
    for (auto op : f->get_ops())
    {
        if (std::dynamic_pointer_cast<T>(op))
        {
            count++;
        }
    }
    return count;
}