Commit 84167659 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Function call working (#2472)

* function call working

* fix compile error

* fix compile error

* add attribute support to plot_graph

* fix build error

* fix merge error

* better colors for FunctionCall op
parent 1575e2d1
......@@ -172,6 +172,19 @@ set (SRC
cpio.cpp
)
set(SRC ${SRC}
runtime/hybrid/hybrid_backend.cpp
runtime/hybrid/hybrid_executable.cpp
runtime/hybrid/hybrid_util.cpp
runtime/hybrid/op/function_call.cpp
runtime/hybrid/pass/default_placement.cpp
runtime/hybrid/pass/dump.cpp
runtime/hybrid/pass/fix_get_output_element.cpp
runtime/hybrid/pass/liveness.cpp
runtime/hybrid/pass/memory_layout.cpp
)
if(NGRAPH_DISTRIBUTED_ENABLE)
list(APPEND SRC distributed.cpp)
endif()
......
......@@ -31,6 +31,8 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/result_vector.hpp"
#include "ngraph/util.hpp"
......@@ -564,3 +566,13 @@ bool ngraph::compare_constants(const std::shared_ptr<Node>& n1, const std::share
return true;
}
void ngraph::plot_graph(
std::shared_ptr<Function> f,
const std::string& filename,
std::function<void(const Node& node, std::vector<std::string>& attributes)> attributes)
{
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::VisualizeTree>(filename, attributes);
pass_manager.run_passes(f);
}
......@@ -325,4 +325,9 @@ namespace ngraph
bool is_strided(const Strides& strides);
bool is_valid_rank(const std::shared_ptr<Node>& node, std::vector<size_t> valid_ranks);
void plot_graph(
std::shared_ptr<Function> f,
const std::string& filename,
std::function<void(const Node& node, std::vector<std::string>& attributes)> = nullptr);
}
......@@ -15,7 +15,6 @@
# ******************************************************************************
add_subdirectory(interpreter)
add_subdirectory(hybrid)
# With CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS, when creating cpu_backend.dll, link reports error: library limit of 65535 objects exceeded
if (NGRAPH_CPU_ENABLE)
......
# ******************************************************************************
# Copyright 2017-2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
add_library(hybrid_base STATIC
hybrid_backend.cpp
hybrid_executable.cpp
hybrid_util.cpp
pass/default_placement.cpp
pass/dump.cpp
pass/fix_get_output_element.cpp
pass/liveness.cpp
pass/memory_layout.cpp
)
target_link_libraries(hybrid_base PUBLIC ngraph)
install(TARGETS hybrid_base
ARCHIVE DESTINATION "${NGRAPH_INSTALL_LIB}"
)
......@@ -31,17 +31,6 @@
using namespace ngraph;
using namespace std;
static void node_modifiers(const Node& node, vector<string>& attributes)
{
vector<string> colors = {"\"#A0FFA0\"", "\"#FFF790\""};
if (node.get_placement_index() < colors.size())
{
string color = colors[node.get_placement_index()];
attributes.push_back("style=filled");
attributes.push_back("fillcolor=" + color);
}
}
runtime::hybrid::HybridExecutable::HybridExecutable(
const std::vector<std::shared_ptr<runtime::Backend>>& backend_list,
const shared_ptr<Function>& func,
......@@ -51,49 +40,21 @@ runtime::hybrid::HybridExecutable::HybridExecutable(
, m_backend_list{backend_list}
, m_debug_enabled{debug_enabled}
{
// Run placement pass
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<runtime::hybrid::pass::DefaultPlacement>(m_backend_list);
pass_manager.register_pass<runtime::hybrid::pass::FixGetOutputElement>();
pass_manager.register_pass<runtime::hybrid::pass::Liveness>();
pass_manager.register_pass<runtime::hybrid::pass::Dump>("graph.dump");
// pass_manager.register_pass<runtime::hybrid::pass::MemoryLayout>();
if (m_debug_enabled)
{
// Run placement pass
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<runtime::hybrid::pass::DefaultPlacement>(m_backend_list);
pass_manager.register_pass<runtime::hybrid::pass::FixGetOutputElement>();
pass_manager.register_pass<runtime::hybrid::pass::Liveness>();
pass_manager.register_pass<runtime::hybrid::pass::Dump>("graph.dump");
// pass_manager.register_pass<runtime::hybrid::pass::MemoryLayout>();
if (m_debug_enabled)
{
pass_manager.register_pass<ngraph::pass::VisualizeTree>("graph.png", node_modifiers);
}
pass_manager.run_passes(m_function);
// Split function to sub_functions
tie(m_sub_functions, m_map_parameter_to_result) =
runtime::hybrid::split_function_by_placement(m_function);
// Compile subfunctions in corresponding backends
size_t subfunction_number = 0;
for (shared_ptr<Function>& sub_function : m_sub_functions)
{
size_t placement = sub_function->get_placement();
if (m_debug_enabled)
{
string name = "subfunction_" + to_string(subfunction_number++);
ngraph::pass::Manager pm;
pm.register_pass<ngraph::pass::VisualizeTree>(name + ".png", node_modifiers);
pm.register_pass<runtime::hybrid::pass::Dump>(name + ".dump");
pm.run_passes(sub_function);
}
auto backend = m_backend_list[placement];
shared_ptr<Executable> exec = backend->compile(sub_function);
m_executable_map[sub_function] = exec;
// Compile will replace nodes so we need to make one more pass through all
// ops to reset placement
for (auto op : sub_function->get_ops())
{
op->set_placement_index(placement);
}
}
pass_manager.register_pass<ngraph::pass::VisualizeTree>("graph.png", node_modifiers);
}
pass_manager.run_passes(m_function);
runtime::hybrid::rewrite_function(m_function, m_backend_list);
m_executable = backend_list[0]->compile(m_function);
set_parameters_and_results(*func);
}
......@@ -105,7 +66,7 @@ bool runtime::hybrid::HybridExecutable::call(const vector<shared_ptr<runtime::Te
using node_map_t = unordered_map<shared_ptr<Node>, shared_ptr<runtime::Tensor>>;
// Parameter and result node in sub_function maps to one Tensor
// Parameter and result node in m_function maps to one Tensor
node_map_t map_node_to_tensor;
for (size_t i = 0; i < inputs.size(); ++i)
{
......@@ -116,85 +77,80 @@ bool runtime::hybrid::HybridExecutable::call(const vector<shared_ptr<runtime::Te
map_node_to_tensor[m_function->get_results()[i]] = outputs[i];
}
// Call subfunctions
for (const shared_ptr<Function>& sub_function : m_sub_functions)
{
// Init backend
size_t placement = sub_function->get_placement();
auto backend = m_backend_list[placement];
// Init backend
size_t placement = m_function->get_placement();
auto backend = m_backend_list[placement];
// Prepare parameter Tensors
vector<shared_ptr<runtime::Tensor>> parameters;
for (const shared_ptr<op::Parameter>& parameter_node : sub_function->get_parameters())
// Prepare parameter Tensors
vector<shared_ptr<runtime::Tensor>> parameters;
for (const shared_ptr<op::Parameter>& parameter_node : m_function->get_parameters())
{
auto it = map_node_to_tensor.find(parameter_node);
if (it != map_node_to_tensor.end())
{
auto it = map_node_to_tensor.find(parameter_node);
if (it != map_node_to_tensor.end())
if (it->second->get_parent() == backend.get())
{
if (it->second->get_parent() == backend.get())
{
parameters.push_back(it->second);
}
else
{
auto parameter = backend->create_tensor(parameter_node->get_element_type(),
parameter_node->get_shape());
parameter->copy_from(*(it->second));
parameters.push_back(parameter);
}
parameters.push_back(it->second);
}
else
{
// Handle temporary tensors that go between subgraphs
auto result_node = m_map_parameter_to_result.at(parameter_node);
auto result = map_node_to_tensor.at(result_node);
auto parameter = backend->create_tensor(parameter_node->get_element_type(),
parameter_node->get_shape());
parameter->copy_from(*result);
map_node_to_tensor[parameter_node] = parameter;
parameter->copy_from(*(it->second));
parameters.push_back(parameter);
}
}
else
{
// Handle temporary tensors that go between subgraphs
auto result_node = m_map_parameter_to_result.at(parameter_node);
auto result = map_node_to_tensor.at(result_node);
auto parameter = backend->create_tensor(parameter_node->get_element_type(),
parameter_node->get_shape());
parameter->copy_from(*result);
map_node_to_tensor[parameter_node] = parameter;
parameters.push_back(parameter);
}
}
// Prepare result Tensors
vector<shared_ptr<runtime::Tensor>> results;
map<runtime::Tensor*, runtime::Tensor*> copy_back;
for (const shared_ptr<op::Result>& result_node : sub_function->get_results())
// Prepare result Tensors
vector<shared_ptr<runtime::Tensor>> results;
map<runtime::Tensor*, runtime::Tensor*> copy_back;
for (const shared_ptr<op::Result>& result_node : m_function->get_results())
{
auto it = map_node_to_tensor.find(result_node);
if (it != map_node_to_tensor.end())
{
auto it = map_node_to_tensor.find(result_node);
if (it != map_node_to_tensor.end())
if (it->second->get_parent() == backend.get())
{
if (it->second->get_parent() == backend.get())
{
results.push_back(it->second);
}
else
{
auto result = backend->create_tensor(result_node->get_element_type(),
result_node->get_shape());
results.push_back(result);
copy_back.insert({result.get(), it->second.get()});
}
results.push_back(it->second);
}
else
{
// Handle temporary tensors that go between subgraphs
auto result = backend->create_tensor(result_node->get_element_type(),
result_node->get_shape());
map_node_to_tensor[result_node] = result;
results.push_back(result);
copy_back.insert({result.get(), it->second.get()});
}
}
// Call
auto exec = m_executable_map[sub_function];
exec->call(results, parameters);
// Need to copy any results to the correct device
for (const auto& p : copy_back)
else
{
p.second->copy_from(*p.first);
// Handle temporary tensors that go between subgraphs
auto result =
backend->create_tensor(result_node->get_element_type(), result_node->get_shape());
map_node_to_tensor[result_node] = result;
results.push_back(result);
}
}
m_executable->call(results, parameters);
// Need to copy any results to the correct device
for (const auto& p : copy_back)
{
p.second->copy_from(*p.first);
}
return rc;
}
......
......@@ -47,13 +47,12 @@ public:
private:
std::shared_ptr<ngraph::Function> m_function;
std::vector<std::shared_ptr<ngraph::Function>> m_sub_functions;
std::shared_ptr<Executable> m_executable;
std::unordered_map<std::shared_ptr<ngraph::op::Parameter>, std::shared_ptr<ngraph::op::Result>>
m_map_parameter_to_result;
std::vector<std::shared_ptr<runtime::Backend>> m_backend_list;
bool m_debug_enabled = false;
std::unordered_map<std::shared_ptr<Function>, std::shared_ptr<Executable>> m_executable_map;
size_t get_placement(const runtime::Tensor* t);
};
......@@ -15,9 +15,11 @@
//*****************************************************************************
#include "ngraph/runtime/hybrid/hybrid_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/runtime/hybrid/op/function_call.hpp"
using namespace ngraph;
using namespace std;
......@@ -75,7 +77,7 @@ static vector<unordered_set<shared_ptr<Node>>>
previous_placement = independent_node->get_placement_index();
sorted_nodes.push_back(node_map.at(independent_node));
for (auto user : independent_node->get_users())
for (auto user : independent_node->get_users(true))
{
Node* user_node = user.get();
node_dependency_count.at(user_node) -= 1;
......@@ -97,15 +99,24 @@ static vector<unordered_set<shared_ptr<Node>>>
// Build clusters from the sorted_nodes
previous_placement = Node::placement_invalid;
vector<unordered_set<shared_ptr<Node>>> clusters;
clusters.push_back(unordered_set<shared_ptr<Node>>());
for (shared_ptr<Node> node : sorted_nodes)
{
size_t node_placement = node->get_placement_index();
if (node_placement != previous_placement)
if (node_placement == 0)
{
clusters[0].insert(node);
}
else
{
unordered_set<shared_ptr<Node>> new_cluster;
clusters.push_back(new_cluster);
if (node_placement != previous_placement)
{
unordered_set<shared_ptr<Node>> new_cluster;
clusters.push_back(new_cluster);
}
clusters.back().insert(node);
}
clusters.back().insert(node);
previous_placement = node_placement;
}
......@@ -137,131 +148,106 @@ static vector<unordered_set<shared_ptr<Node>>>
return clusters;
}
// Insert result and parameter node between src_node and dst_node by splitting the graph
//
// Before: | After:
// (Device:0) (Device:1) | (Device:0) (Device:0) (Device:1) (Device:1)
// +-----+---+ +---+-----+ | +-----+---+ +---+-----+ +-----+---+ +---+-----+
// | | | | | | | | | | | | | | | | | | |
// | | o +--[0]--> i | | | | | o +--[4]--> i | | | | o +--[8]--> i | |
// | | <--[1]--+ | | | | | <--[5]--+ | | | | <--[9]--+ | |
// | src +---+ +---+ dst | | | src +---+ +---+ res | | par +---+ +---+ dst |
// | | | | | | | | | | | | |
// | +------[2]------> | | | +------[6]------> | | +------[10]-----> |
// | <------[3]------+ | | | <------[7]------+ | | <------[11]-----+ |
// +-----+ +-----+ | +-----+ +-----+ +-----+ +-----+
static map<shared_ptr<op::Result>, shared_ptr<op::Parameter>>
insert_result_parameter_split(const shared_ptr<Node>& src_node,
const shared_ptr<Node>& dst_node)
{
map<shared_ptr<op::Result>, shared_ptr<op::Parameter>> result_map;
for (descriptor::Input& input : dst_node->get_inputs())
{
if (input.get_output().get_node() == src_node)
{
descriptor::Input* dst_input = &input;
descriptor::Output* src_output = &input.get_output();
// Make parameter node
shared_ptr<op::Parameter> par_node =
make_shared<op::Parameter>(src_output->get_element_type(), src_output->get_shape());
par_node->set_placement_index(dst_node->get_placement_index());
// Fix input / output among src, dst and par
// Remove [0]
src_output->remove_input(dst_input);
// Remove [0] (again), add [8], remove [1], add [9]
dst_input->replace_output(par_node, 0);
// Add res node
shared_ptr<op::Result> res_node =
make_shared<op::Result>(src_node); // Add [4], [5], [6], [7]
res_node->set_placement_index(src_node->get_placement_index());
result_map.insert({res_node, par_node});
}
}
return result_map;
}
// will be removed when the backends move to the latest Hybrid backend
pair<vector<shared_ptr<Function>>, unordered_map<shared_ptr<op::Parameter>, shared_ptr<op::Result>>>
runtime::hybrid::split_function_by_placement(const shared_ptr<Function>& f)
void runtime::hybrid::rewrite_function(const shared_ptr<Function>& f,
const vector<shared_ptr<runtime::Backend>>& backend_list)
{
// Split functions to clusters of nodes that can be computed together
vector<unordered_set<shared_ptr<Node>>> clusters = ::group_function_nodes_to_clusters(f);
// Map from (intermediate) parameter to result node, for guiding data copy among devices
unordered_map<shared_ptr<op::Parameter>, shared_ptr<op::Result>> map_parameter_to_result;
// Split neighboring nodes if they belong to different clusters
// TODO: optimization to group multiple result node from the same source,
// and to group the parameter node in the same cluster with the same result node source
// unordered_map<shared_ptr<op::Parameter>, shared_ptr<op::Result>> map_parameter_to_result;
unordered_map<shared_ptr<Node>, unordered_set<shared_ptr<Node>>*> map_node_to_cluster;
for (auto& cluster : clusters)
{
for (auto node : cluster)
{
map_node_to_cluster[node] = &cluster;
}
}
for (auto dst_node : f->get_ordered_ops())
{
for (auto src_node : dst_node->get_arguments())
if (cluster.size() > 0)
{
auto src_cluster = map_node_to_cluster.at(src_node);
auto dst_cluster = map_node_to_cluster.at(dst_node);
if (src_cluster != dst_cluster)
shared_ptr<Node> tmp_node = *cluster.begin();
auto placement = tmp_node->get_placement_index();
if (placement != 0)
{
// Split src_node and dst_node
map<shared_ptr<op::Result>, shared_ptr<op::Parameter>> res_par_pair_map =
::insert_result_parameter_split(src_node, dst_node);
for (const auto& res_par_pair : res_par_pair_map)
// This is a non-native cluster so make it a FunctionCall
vector<shared_ptr<Node>> function_call_inputs;
vector<shared_ptr<Node>> function_call_outputs;
ParameterVector cluster_inputs;
NodeVector cluster_outputs;
for (auto node : cluster)
{
shared_ptr<op::Result> res_node = res_par_pair.first;
shared_ptr<op::Parameter> par_node = res_par_pair.second;
map_parameter_to_result[par_node] = res_node;
for (auto input : node->get_arguments())
{
if (input->get_placement_index() == 0)
{
// Since this input is from outside the cluster we need to create
// a new Parameter node placed in the cluster instead of this external
// node
descriptor::Output* source_output = input->get_output_to(node);
descriptor::Input* target_input = node->get_input_from(input);
auto new_parameter = make_shared<ngraph::op::Parameter>(
source_output->get_element_type(), source_output->get_shape());
descriptor::Output& new_output = new_parameter->get_outputs()[0];
new_parameter->set_placement_index(placement);
target_input->replace_output(new_output);
cluster_inputs.push_back(new_parameter);
function_call_inputs.push_back(input);
}
}
for (auto output : node->get_users(true))
{
if (output->get_placement_index() == 0)
{
// Since this output is to outside the cluster we need to create
// a new Result node placed in the cluster instead of this external
// node
function_call_outputs.push_back(output);
cluster_outputs.push_back(node);
}
}
}
// Insert newly created nodes into clusters
src_cluster->insert(res_node);
dst_cluster->insert(par_node);
// Now make a FunctionCall out of the nodes in cluster, including the new nodes
// we just added
auto sub_function = make_shared<Function>(cluster_outputs, cluster_inputs);
sub_function->set_placement(placement);
ngraph::plot_graph(sub_function, "sub_function.png", node_modifiers);
auto fc = make_shared<runtime::hybrid::op::FunctionCall>(function_call_outputs,
function_call_inputs,
sub_function,
backend_list[placement]);
fc->set_placement_index(0);
for (size_t i = 0; i < function_call_outputs.size(); i++)
{
// // First add a GetOutputElement to the ith output of the FunctionCall
// auto goe = make_shared<GetOutpu
auto old_source = cluster_outputs[i];
auto new_source = fc;
auto target = function_call_outputs[i];
descriptor::Input* target_input = target->get_input_from(old_source);
descriptor::Output& new_output = new_source->get_outputs()[i];
target_input->replace_output(new_output);
}
}
}
}
ngraph::plot_graph(f, "f.png", node_modifiers);
}
// Create functions from clusters
vector<shared_ptr<Function>> sub_functions;
for (auto cluster : clusters)
void runtime::hybrid::node_modifiers(const Node& node, vector<string>& attributes)
{
vector<string> colors = {"\"#A0FFA0\"", "\"#FFF790\""};
auto fc = dynamic_cast<const hybrid::op::FunctionCall*>(&node);
if (fc != nullptr)
{
ParameterVector par_vector;
ResultVector res_vector;
size_t placement = -1;
for (auto node : cluster)
{
placement = node->get_placement_index();
if (auto res_node = dynamic_pointer_cast<op::Result>(node))
{
res_vector.push_back(res_node);
}
else if (auto par_node = dynamic_pointer_cast<op::Parameter>(node))
{
par_vector.push_back(par_node);
}
}
auto sub_function = make_shared<Function>(res_vector, par_vector);
sub_function->set_placement(placement);
sub_functions.push_back(sub_function);
#ifdef HYBRID_DEBUG
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::VisualizeTree>("subgraph_" + to_string(index++) +
".png");
pass_manager.run_passes(sub_function);
#endif
string fill_color = colors[fc->get_function()->get_placement()];
string outline_color = colors[node.get_placement_index()];
attributes.push_back("style=filled");
attributes.push_back("fillcolor=" + fill_color);
attributes.push_back("color=" + outline_color);
attributes.push_back("penwidth=3");
}
else if (node.get_placement_index() < colors.size())
{
string color = colors[node.get_placement_index()];
attributes.push_back("style=filled");
attributes.push_back("fillcolor=" + color);
}
return make_pair(sub_functions, map_parameter_to_result);
}
......@@ -30,11 +30,11 @@ namespace ngraph
{
namespace hybrid
{
// Split function to function(s) with unique placement
std::pair<
std::vector<std::shared_ptr<Function>>,
std::unordered_map<std::shared_ptr<op::Parameter>, std::shared_ptr<op::Result>>>
split_function_by_placement(const std::shared_ptr<Function>& f);
void rewrite_function(
const std::shared_ptr<Function>& f,
const std::vector<std::shared_ptr<runtime::Backend>>& backend_list);
void node_modifiers(const Node& node, std::vector<std::string>& attributes);
}
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "function_call.hpp"
#include "ngraph/runtime/backend.hpp"
using namespace std;
using namespace ngraph;
runtime::hybrid::op::FunctionCall::FunctionCall(const NodeVector& outputs,
const NodeVector& inputs,
shared_ptr<Function> function,
shared_ptr<Backend> backend)
: Op("FunctionCall", inputs)
, m_outputs{outputs}
, m_function{function}
, m_backend{backend}
, m_executable{backend->compile(function)}
{
set_output_size(outputs.size());
for (size_t i = 0; i < outputs.size(); i++)
{
set_output_type(i, outputs[i]->get_element_type(), outputs[i]->get_output_shape(0));
}
}
shared_ptr<Node>
runtime::hybrid::op::FunctionCall::copy_with_new_args(const NodeVector& new_args) const
{
return make_shared<FunctionCall>(m_outputs, new_args, m_function, m_backend);
}
shared_ptr<runtime::Backend> runtime::hybrid::op::FunctionCall::get_backend() const
{
return m_backend;
}
shared_ptr<runtime::Executable> runtime::hybrid::op::FunctionCall::get_executable() const
{
return m_executable;
}
shared_ptr<Function> runtime::hybrid::op::FunctionCall::get_function() const
{
return m_function;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
#include "ngraph/runtime/backend.hpp"
namespace ngraph
{
namespace runtime
{
namespace hybrid
{
namespace op
{
class FunctionCall;
}
}
}
}
class ngraph::runtime::hybrid::op::FunctionCall : public ngraph::op::Op
{
public:
FunctionCall(const NodeVector& outputs,
const NodeVector& inputs,
std::shared_ptr<Function> function,
std::shared_ptr<Backend> backend);
std::shared_ptr<Backend> get_backend() const;
std::shared_ptr<Executable> get_executable() const;
std::shared_ptr<Function> get_function() const;
private:
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
const NodeVector m_outputs;
std::shared_ptr<Function> m_function;
std::shared_ptr<Backend> m_backend;
std::shared_ptr<Executable> m_executable;
};
NGRAPH_OP(FunctionCall, ngraph::runtime::hybrid::op)
......@@ -59,6 +59,7 @@
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/hybrid/op/function_call.hpp"
#include "ngraph/runtime/interpreter/node_wrapper.hpp"
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/acos.hpp"
......@@ -186,7 +187,6 @@ private:
const std::vector<std::shared_ptr<HostTensor>>& args)
{
const Node& node = node_wrapper.get_node();
std::string node_op = node.description();
// We want to check that every OP_TYPEID enumeration is included in the list.
// These GCC flags enable compile-time checking so that if an enumeration
......@@ -741,6 +741,29 @@ private:
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::FunctionCall:
{
auto f = static_cast<const runtime::hybrid::op::FunctionCall*>(&node);
auto backend = f->get_backend();
auto executable = f->get_executable();
std::vector<std::shared_ptr<Tensor>> outputs;
std::vector<std::shared_ptr<Tensor>> inputs;
for (const std::shared_ptr<HostTensor>& t : out)
{
auto backend_tensor = backend->create_tensor(
t->get_element_type(), t->get_shape(), t->get_data_ptr());
outputs.push_back(backend_tensor);
}
for (const std::shared_ptr<HostTensor>& t : args)
{
auto backend_tensor = backend->create_tensor(
t->get_element_type(), t->get_shape(), t->get_data_ptr());
inputs.push_back(backend_tensor);
}
executable->call(outputs, inputs);
break;
}
case OP_TYPEID::Floor:
{
size_t element_count = shape_size(node.get_output_shape(0));
......
......@@ -29,6 +29,7 @@ runtime::interpreter::NodeWrapper::NodeWrapper(const shared_ptr<const Node>& nod
#define NGRAPH_OP(a, b) {#a, runtime::interpreter::OP_TYPEID::a},
static unordered_map<string, runtime::interpreter::OP_TYPEID> typeid_map{
#include "ngraph/op/op_tbl.hpp"
#include "ngraph/runtime/hybrid/op/op_tbl.hpp"
};
#undef NGRAPH_OP
......
......@@ -40,6 +40,7 @@ namespace ngraph
enum class ngraph::runtime::interpreter::OP_TYPEID
{
#include "ngraph/op/op_tbl.hpp"
#include "ngraph/runtime/hybrid/op/op_tbl.hpp"
};
#undef NGRAPH_OP
......
......@@ -241,7 +241,7 @@ endif()
if (NGRAPH_INTERPRETER_ENABLE)
target_compile_definitions(unit-test PRIVATE NGRAPH_INTERPRETER_ENABLE)
target_link_libraries(unit-test PRIVATE interpreter_backend hybrid_base)
target_link_libraries(unit-test PRIVATE interpreter_backend)
endif()
if (NGRAPH_GPU_ENABLE)
......
......@@ -20,9 +20,14 @@
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/runtime/hybrid/hybrid_backend.hpp"
#include "ngraph/runtime/hybrid/hybrid_util.hpp"
#include "ngraph/runtime/hybrid/op/function_call.hpp"
#include "ngraph/runtime/interpreter/int_backend.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
......@@ -44,6 +49,53 @@ static runtime::Backend* hybrid_creator(const char* config)
return new runtime::hybrid::HybridBackend(backend_list);
}
TEST(HYBRID, function_call)
{
vector<shared_ptr<runtime::Backend>> backend_list = {
make_shared<runtime::interpreter::INTBackend>()};
auto backend = make_shared<runtime::hybrid::HybridBackend>(backend_list);
Shape shape{};
shared_ptr<Function> inner_function;
auto inner_A = make_shared<op::Parameter>(element::f32, shape);
auto inner_B = make_shared<op::Parameter>(element::f32, shape);
auto inner_C = make_shared<op::Parameter>(element::f32, shape);
auto inner_R1 = (inner_A + inner_B) * inner_C;
auto inner_R2 = (inner_A + inner_C) * inner_C;
NodeVector inner_Result{inner_R1, inner_R2};
inner_function =
make_shared<Function>(inner_Result, ParameterVector{inner_A, inner_B, inner_C});
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto C = make_shared<op::Parameter>(element::f32, shape);
NodeVector fcall_args{A, B, C};
auto H = make_shared<runtime::hybrid::op::FunctionCall>(
inner_Result, fcall_args, inner_function, backend_list[0]);
auto G0 = make_shared<ngraph::op::GetOutputElement>(H, 0);
auto G1 = make_shared<ngraph::op::GetOutputElement>(H, 1);
NodeVector out{G0, G1};
auto J = G0 + G1;
auto f = make_shared<Function>(out, ParameterVector{A, B, C});
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> b = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> c = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> r0 = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> r1 = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{2});
copy_data(b, vector<float>{3});
copy_data(c, vector<float>{4});
auto exec = backend->compile(f);
exec->call({r0, r1}, {a, b, c});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::VisualizeTree>("test.png");
pass_manager.run_passes(f);
}
TEST(HYBRID, abc)
{
const string backend_name = "H1";
......@@ -52,11 +104,14 @@ TEST(HYBRID, abc)
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto C = make_shared<op::Parameter>(element::f32, shape);
auto D = make_shared<op::Parameter>(element::f32, shape);
auto t1 = A * B;
auto t2 = t1 * D;
auto C = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(((t2 + C) + A) * t1, ParameterVector{A, B, C, D});
auto t3 = (t2 + C);
auto t4 = (t3 + A) * t1;
NodeVector result({t3, t4});
auto f = make_shared<Function>(result, ParameterVector{A, B, C, D});
shared_ptr<runtime::Backend> backend = runtime::Backend::create("H1");
static_pointer_cast<runtime::hybrid::HybridBackend>(backend)->set_debug_enabled(true);
......@@ -66,7 +121,8 @@ TEST(HYBRID, abc)
shared_ptr<runtime::Tensor> b = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> c = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> d = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> result = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> result1 = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> result2 = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{1, 2, 3, 4});
copy_data(b, vector<float>{5, 6, 7, 8});
......@@ -74,6 +130,6 @@ TEST(HYBRID, abc)
copy_data(d, vector<float>{4, 3, 2, 1});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, b, c, d});
EXPECT_EQ(read_vector<float>(result), (vector<float>{150, 576, 1176, 1536}));
handle->call_with_validate({result1, result2}, {a, b, c, d});
EXPECT_EQ(read_vector<float>(result2), (vector<float>{150, 576, 1176, 1536}));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment