Unverified commit 4d272f1f, authored by Robert Kimball, committed by GitHub

Hybrid backend update (#2306)

* update hybrid unit test

* remove unused files

* add graphviz to test

* add ability to add attributes to graphviz nodes

* tweak colors

* more interesting graph

* update test model

* add memory management passes

* add Dump

* wip

* remove in-place code from memory layout
parent e01c47f1
......@@ -65,8 +65,9 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<ngraph::Function>>& fu
return false;
}
pass::VisualizeTree::VisualizeTree(const string& file_name)
pass::VisualizeTree::VisualizeTree(const string& file_name, node_modifiers_t nm)
: m_name{file_name}
, m_node_modifiers{nm}
{
}
......@@ -83,32 +84,38 @@ std::string pass::VisualizeTree::add_attributes(shared_ptr<Node> node)
std::string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
{
stringstream ss;
vector<string> attributes;
if (node->is_parameter() || node->is_output())
{
ss << " " << node->get_name() << " [shape=box ";
attributes.push_back("shape=box");
if (node->is_parameter())
{
ss << "color=blue ";
attributes.push_back("color=blue");
attributes.push_back("penwidth=1.5");
}
if (node->is_output())
{
ss << "style=filled fillcolor=pink ";
attributes.push_back("color=crimson");
attributes.push_back("penwidth=1.5");
}
}
else
{
ss << " " << node->get_name() << " [shape=ellipse color=black";
attributes.push_back("shape=ellipse");
attributes.push_back("color=black");
}
ss << " label=\"" << node->get_name();
// Construct the label attribute
{
stringstream label;
label << "label=\"" << node->get_name();
static const char* nvtos = std::getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_SHAPES");
if (nvtos != nullptr)
{
// The shapes of the Outputs of a multi-output op
// will be printed for its corresponding `GetOutputElement`s
ss << " " << (node->get_outputs().size() != 1 ? std::string("[skipped]")
label << " " << (node->get_outputs().size() != 1 ? std::string("[skipped]")
: vector_to_string(node->get_shape()));
}
......@@ -117,7 +124,8 @@ std::string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
{
// The types of the Outputs of a multi-output op
// will be printed for its corresponding `GetOutputElement`s
ss << " " << ((node->get_outputs().size() != 1) ? std::string("[skipped]")
label << " "
<< ((node->get_outputs().size() != 1) ? std::string("[skipped]")
: node->get_element_type().c_type_string());
}
......@@ -125,10 +133,19 @@ std::string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
auto eh = m_ops_to_details.find(TI(n));
if (eh != m_ops_to_details.end())
{
eh->second(n, ss);
eh->second(n, label);
}
label << "\"";
attributes.push_back(label.str());
}
ss << " \"]\n";
if (m_node_modifiers)
{
m_node_modifiers(*node, attributes);
}
stringstream ss;
ss << " " << node->get_name() << " [" << join(attributes, " ") << "]\n";
return ss.str();
}
......
......@@ -39,7 +39,9 @@ namespace ngraph
class ngraph::pass::VisualizeTree : public ModulePass
{
public:
VisualizeTree(const std::string& file_name);
using node_modifiers_t =
std::function<void(const Node& node, std::vector<std::string>& attributes)>;
VisualizeTree(const std::string& file_name, node_modifiers_t nm = nullptr);
bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;
static std::string get_file_ext();
......@@ -54,4 +56,5 @@ private:
std::set<std::shared_ptr<Node>> m_nodes_with_attributes;
std::unordered_map<std::type_index, std::function<void(const Node&, std::ostream& ss)>>
m_ops_to_details;
node_modifiers_t m_node_modifiers = nullptr;
};
......@@ -18,7 +18,11 @@ add_library(hybrid_base STATIC
hybrid_backend.cpp
hybrid_util.cpp
pass/assign_placement.cpp
pass/fix_get_output_element.cpp)
pass/dump.cpp
pass/fix_get_output_element.cpp
pass/liveness.cpp
pass/memory_layout.cpp
)
target_link_libraries(hybrid_base PUBLIC ngraph)
set_target_properties(hybrid_base PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${NGRAPH_BUILD_DIR})
......
......@@ -21,7 +21,10 @@
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/hybrid/hybrid_util.hpp"
#include "ngraph/runtime/hybrid/pass/assign_placement.hpp"
#include "ngraph/runtime/hybrid/pass/dump.hpp"
#include "ngraph/runtime/hybrid/pass/fix_get_output_element.hpp"
#include "ngraph/runtime/hybrid/pass/liveness.hpp"
#include "ngraph/runtime/hybrid/pass/memory_layout.hpp"
#include "ngraph/runtime/tensor.hpp"
using namespace ngraph;
......@@ -48,6 +51,17 @@ shared_ptr<runtime::Tensor> runtime::hybrid::HybridBackend::create_tensor(
return (*it)->create_tensor(element_type, shape, memory_pointer);
}
static void node_modifiers(const Node& node, vector<string>& attributes)
{
vector<string> colors = {"\"#A0FFA0\"", "\"#FFF790\""};
if (node.get_placement_index() < colors.size())
{
string color = colors[node.get_placement_index()];
attributes.push_back("style=filled");
attributes.push_back("fillcolor=" + color);
}
}
runtime::Handle runtime::hybrid::HybridBackend::compile(shared_ptr<Function> func)
{
if (m_function_map.find(func) == m_function_map.end())
......@@ -60,9 +74,13 @@ runtime::Handle runtime::hybrid::HybridBackend::compile(shared_ptr<Function> fun
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<runtime::hybrid::pass::AssignPlacement>(m_backend_list);
pass_manager.register_pass<runtime::hybrid::pass::FixGetOutputElement>();
#ifdef GPUH_DEBUG
pass_manager.register_pass<ngraph::pass::VisualizeTree>("graph.png");
#endif
pass_manager.register_pass<runtime::hybrid::pass::Liveness>();
pass_manager.register_pass<runtime::hybrid::pass::Dump>("graph.dump");
// pass_manager.register_pass<runtime::hybrid::pass::MemoryLayout>();
if (m_debug_enabled)
{
pass_manager.register_pass<ngraph::pass::VisualizeTree>("graph.png", node_modifiers);
}
pass_manager.run_passes(instance.m_function);
// Split function to sub_functions
......@@ -71,9 +89,18 @@ runtime::Handle runtime::hybrid::HybridBackend::compile(shared_ptr<Function> fun
m_function_map.insert({func, instance});
// Compile subfunctions in corresponding backends
size_t subfunction_number = 0;
for (shared_ptr<Function>& sub_function : instance.m_sub_functions)
{
size_t placement = runtime::hybrid::get_colocated_function_placement(sub_function);
if (m_debug_enabled)
{
string name = "subfunction_" + to_string(subfunction_number++);
ngraph::pass::Manager pm;
pm.register_pass<ngraph::pass::VisualizeTree>(name + ".png", node_modifiers);
pm.register_pass<runtime::hybrid::pass::Dump>(name + ".dump");
pm.run_passes(sub_function);
}
auto backend = m_backend_list[placement];
backend->compile(sub_function);
......
......@@ -56,6 +56,7 @@ public:
bool is_supported(const ngraph::Node& node) const override;
void set_debug_enabled(bool flag) { m_debug_enabled = flag; }
private:
class FunctionInstance
{
......@@ -69,6 +70,7 @@ private:
std::map<std::shared_ptr<ngraph::Function>, FunctionInstance> m_function_map;
std::vector<std::shared_ptr<runtime::Backend>> m_backend_list;
bool m_debug_enabled = false;
size_t get_placement(const runtime::Tensor* t);
};
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <fstream>
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/runtime/hybrid/pass/dump.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
// Remembers the path of the file that run_on_module() will write to.
runtime::hybrid::pass::Dump::Dump(const string& output_file)
    : m_output_file(output_file)
{
}
// Writes a plain-text listing of every function in the module to m_output_file.
// For each op, in get_ordered_ops() order, the listing contains:
//   name(input tensor names) -> output tensor names
//   the placement index
//   one "N" line per tensor in liveness_new_list
//   one "F" line per tensor in liveness_free_list
// Nothing is written when the file cannot be opened. Always returns false.
bool runtime::hybrid::pass::Dump::run_on_module(vector<shared_ptr<Function>>& functions)
{
    const char* const banner =
        "=====================================================================\n";

    ofstream out{m_output_file};
    if (!out)
    {
        return false;
    }

    for (const shared_ptr<Function>& function : functions)
    {
        out << banner << function->get_name() << " start\n" << banner;

        for (const shared_ptr<Node>& op : function->get_ordered_ops())
        {
            // Signature line: name(inputs) -> outputs
            vector<string> input_names;
            for (const descriptor::Input& in : op->get_inputs())
            {
                input_names.push_back(in.get_tensor().get_name());
            }

            vector<string> output_names;
            for (size_t index = 0; index < op->get_output_size(); ++index)
            {
                output_names.push_back(op->get_output_tensor(index).get_name());
            }

            out << op->get_name() << "(" << join(input_names) << ") -> " << join(output_names)
                << "\n";
            out << " " << op->get_placement_index() << " Placement\n";

            // Liveness annotations: N = new at this op, F = freed at this op
            for (const descriptor::Tensor* created : op->liveness_new_list)
            {
                out << " N " << created->get_name() << "\n";
            }
            for (const descriptor::Tensor* released : op->liveness_free_list)
            {
                out << " F " << released->get_name() << "\n";
            }
        }

        out << banner << function->get_name() << " end\n" << banner;
    }
    return false;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <string>
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace runtime
{
namespace hybrid
{
namespace pass
{
class Dump;
}
}
}
}
/// \brief Module pass that writes a text listing of each function's ops to a file.
///
/// For every op the listing shows its input and output tensor names, its
/// placement index, and the tensors in its liveness new/free lists.
class ngraph::runtime::hybrid::pass::Dump : public ngraph::pass::ModulePass
{
public:
/// \param output_file Path of the file the listing is written to.
Dump(const std::string& output_file);
/// \brief Writes the listing for all functions in the module.
/// \return Always false.
virtual bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;
private:
// Destination path, fixed at construction.
const std::string m_output_file;
};
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <exception>
#include <sstream>
#include <unordered_set>
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/runtime/hybrid/pass/liveness.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
// Computes, for every op of `function`, which tensors come to life at that op
// (liveness_new_list) and which can be released after it (liveness_free_list).
//
// Output tensors of Parameter, Result and Constant nodes are treated as
// persistent and never appear in either list. The ops are scanned in reverse
// execution order, so the first op at which a tensor is seen is its last use.
// Always returns false.
bool runtime::hybrid::pass::Liveness::run_on_function(shared_ptr<ngraph::Function> function)
{
    using TensorSet = unordered_set<descriptor::Tensor*>;

    list<shared_ptr<Node>> ordered_ops = function->get_ordered_ops();

    TensorSet persistent;
    TensorSet function_outputs;

    for (const shared_ptr<op::Parameter>& parameter : function->get_parameters())
    {
        for (size_t index = 0; index < parameter->get_output_size(); ++index)
        {
            persistent.insert(&parameter->get_output_tensor(index));
        }
    }

    for (const shared_ptr<op::Result>& result : function->get_results())
    {
        for (size_t index = 0; index < result->get_output_size(); ++index)
        {
            descriptor::Tensor* tensor = &result->get_output_tensor(index);
            persistent.insert(tensor);
            function_outputs.insert(tensor);
        }
    }

    for (const shared_ptr<Node>& op : ordered_ops)
    {
        auto constant = dynamic_pointer_cast<op::Constant>(op);
        if (!constant)
        {
            continue;
        }
        for (size_t index = 0; index < constant->get_output_size(); ++index)
        {
            persistent.insert(&constant->get_output_tensor(index));
        }
    }

    const auto is_persistent = [&persistent](descriptor::Tensor* tensor) {
        return persistent.count(tensor) != 0;
    };

    TensorSet live;
    for (auto op_it = ordered_ops.rbegin(); op_it != ordered_ops.rend(); ++op_it)
    {
        const shared_ptr<Node>& op = *op_it;
        op->liveness_new_list.clear();
        op->liveness_free_list.clear();

        // Non-persistent tensors read by this op.
        TensorSet consumed;
        for (descriptor::Input& input : op->get_inputs())
        {
            descriptor::Tensor* tensor = &input.get_tensor();
            if (!is_persistent(tensor))
            {
                consumed.insert(tensor);
            }
        }

        // Non-persistent tensors written by this op.
        TensorSet produced;
        for (size_t index = 0; index < op->get_output_size(); ++index)
        {
            descriptor::Tensor* tensor = &op->get_output_tensor(index);
            if (!is_persistent(tensor))
            {
                produced.insert(tensor);
            }
        }

        TensorSet touched = consumed;
        touched.insert(produced.begin(), produced.end());

        // A tensor that was not yet live is seen here for the last time in
        // execution order: free it after this op, unless it is a function output.
        TensorSet to_free;
        for (descriptor::Tensor* tensor : touched)
        {
            const bool newly_live = live.insert(tensor).second;
            if (newly_live && function_outputs.count(tensor) == 0)
            {
                to_free.insert(tensor);
            }
        }

        // Tensors produced here start their life at this op; moving further
        // back in the op order they are no longer live.
        TensorSet to_create;
        for (descriptor::Tensor* tensor : produced)
        {
            if (live.erase(tensor) != 0)
            {
                to_create.insert(tensor);
            }
        }

        op->liveness_free_list = to_free;
        op->liveness_new_list = to_create;
    }

    return false;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace runtime
{
namespace hybrid
{
namespace pass
{
class Liveness;
}
}
}
}
/// \brief Function pass that fills each node's liveness_new_list and
/// liveness_free_list.
///
/// Output tensors of Parameter, Result and Constant nodes are treated as
/// persistent and are excluded from both lists.
class ngraph::runtime::hybrid::pass::Liveness : public ngraph::pass::FunctionPass
{
public:
/// \brief Recomputes the liveness lists for every op in the function.
/// \return Always false.
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
};
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <exception>
#include <sstream>
#include "ngraph/op/concat.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/runtime/hybrid/pass/memory_layout.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
runtime::hybrid::pass::MemoryLayout::MemoryLayout(size_t alignment)
: m_alignment(alignment)
{
if (m_alignment == 0)
{
throw invalid_argument("Memory alignment must be > 0");
}
}
// Assigns a pool offset to every tensor, driven by the per-node liveness lists.
// Walking the ops in execution order, each tensor in liveness_new_list is
// allocated from the memory manager and each tensor in liveness_free_list is
// released again. The function's temporary pool size is then set to the
// manager's maximum allocation. Always returns false.
bool runtime::hybrid::pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function)
{
    ngraph::pass::MemoryManager manager(m_alignment, false);

    for (const shared_ptr<Node>& op : function->get_ordered_ops())
    {
        for (descriptor::Tensor* created : op->liveness_new_list)
        {
            created->set_pool_offset(manager.allocate(created->size()));
        }

        for (const descriptor::Tensor* released : op->liveness_free_list)
        {
            manager.free(released->get_pool_offset());
        }
    }

    function->set_temporary_pool_size(manager.max_allocated());
    return false;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace runtime
{
namespace hybrid
{
namespace pass
{
class MemoryLayout;
}
}
}
}
/// \brief Function pass that assigns pool offsets to tensors from the per-node
/// liveness lists and records the resulting temporary pool size on the function.
class ngraph::runtime::hybrid::pass::MemoryLayout : public ngraph::pass::FunctionPass
{
public:
/// \param alignment Alignment passed to the memory manager; must be non-zero.
/// \throws std::invalid_argument if alignment is 0.
MemoryLayout(size_t alignment = 64);
/// \brief Lays out the tensors of the given function.
/// \return Always false.
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
private:
// Alignment handed to the memory manager; validated in the constructor.
size_t m_alignment;
};
......@@ -66,7 +66,8 @@ if (NGRAPH_INTERPRETER_ENABLE)
list(APPEND SRC
backend_debug_api.cpp
builder.cpp
backend_api.cpp)
backend_api.cpp
hybrid_backend.cpp)
set(ACTIVE_BACKEND_LIST ${ACTIVE_BACKEND_LIST} INTERPRETER)
endif()
......@@ -224,7 +225,7 @@ endif()
if (NGRAPH_INTERPRETER_ENABLE)
target_compile_definitions(unit-test PRIVATE NGRAPH_INTERPRETER_ENABLE)
target_link_libraries(unit-test PRIVATE interpreter_backend)
target_link_libraries(unit-test PRIVATE interpreter_backend hybrid_base)
endif()
if (NGRAPH_GPU_ENABLE)
......
......@@ -18,12 +18,12 @@
#include "gtest/gtest.h"
#include "hybrid_utils.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/runtime/hybrid/hybrid_backend.hpp"
#include "ngraph/runtime/interpreter/int_backend.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
......@@ -33,58 +33,47 @@
using namespace std;
using namespace ngraph;
static runtime::Backend* hybrid1_creator(const char* config)
static runtime::Backend* hybrid_creator(const char* config)
{
vector<shared_ptr<runtime::Backend>> backend_list;
set<string> s0 = {"Add"};
auto b0 = make_shared<BackendWrapper>("INTERPRETER", s0, "AddOnly");
backend_list.push_back(b0);
vector<string> unsupported_0 = {"Add"};
vector<string> unsupported_1 = {"Multiply"};
vector<shared_ptr<runtime::Backend>> backend_list = {
make_shared<runtime::interpreter::INTBackend>(unsupported_0),
make_shared<runtime::interpreter::INTBackend>(unsupported_1)};
#define NGRAPH_OP(a, b) #a,
set<string> s1 = {
#include "ngraph/op/op_tbl.hpp"
};
auto b1 = make_shared<BackendWrapper>("INTERPRETER", s1, "AllOps");
backend_list.push_back(b1);
return new TestBackend(backend_list);
return new runtime::hybrid::HybridBackend(backend_list);
}
TEST(HYBRID, abc)
{
const string backend_name = "HYBRID1";
runtime::BackendManager::register_backend(backend_name, hybrid1_creator);
const string backend_name = "H1";
runtime::BackendManager::register_backend(backend_name, hybrid_creator);
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto D = make_shared<op::Parameter>(element::f32, shape);
auto t1 = A * B;
auto t2 = t1 * D;
auto C = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>((A + B) * C, ParameterVector{A, B, C});
auto f = make_shared<Function>((t2 + C) * t1, ParameterVector{A, B, C, D});
auto backend = runtime::Backend::create(backend_name);
shared_ptr<runtime::Backend> backend = runtime::Backend::create("H1");
static_pointer_cast<runtime::hybrid::HybridBackend>(backend)->set_debug_enabled(true);
// Create some tensors for input/output
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> b = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> c = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> d = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> result = backend->create_tensor(element::f32, shape);
copy_data(a, test::NDArray<float, 2>({{1, 2}, {3, 4}}).get_vector());
copy_data(b, test::NDArray<float, 2>({{5, 6}, {7, 8}}).get_vector());
copy_data(c, test::NDArray<float, 2>({{9, 10}, {11, 12}}).get_vector());
auto handle = backend->compile(f);
backend->call_with_validate(handle, {result}, {a, b, c});
EXPECT_EQ(read_vector<float>(result),
(test::NDArray<float, 2>({{54, 80}, {110, 144}})).get_vector());
auto handle = backend->compile(f);
backend->call_with_validate(handle, {result}, {b, a, c});
EXPECT_EQ(read_vector<float>(result),
(test::NDArray<float, 2>({{54, 80}, {110, 144}})).get_vector());
copy_data(a, vector<float>{1, 2, 3, 4});
copy_data(b, vector<float>{5, 6, 7, 8});
copy_data(c, vector<float>{9, 10, 11, 12});
copy_data(d, vector<float>{4, 3, 2, 1});
auto handle = backend->compile(f);
backend->call_with_validate(handle, {result}, {a, c, b});
EXPECT_EQ(read_vector<float>(result),
(test::NDArray<float, 2>({{50, 72}, {98, 128}})).get_vector());
backend->call_with_validate(handle, {result}, {a, b, c, d});
EXPECT_EQ(read_vector<float>(result), (vector<float>{145, 552, 1113, 1408}));
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "hybrid_utils.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_placement.hpp"
#include "ngraph/pass/manager.hpp"
using namespace std;
using namespace ngraph;
// Takes a copy of the backend list. At least one backend is required;
// an empty list is rejected with std::runtime_error.
TestBackend::TestBackend(const vector<shared_ptr<runtime::Backend>>& backend_list)
    : m_backend_list(backend_list)
{
    if (m_backend_list.empty())
    {
        throw runtime_error("TestBackend backend list empty");
    }
}
// Creates a tensor on the first backend in the list.
shared_ptr<runtime::Tensor> TestBackend::create_tensor(const element::Type& element_type,
                                                       const Shape& shape)
{
    const shared_ptr<runtime::Backend>& primary = m_backend_list[0];
    return primary->create_tensor(element_type, shape);
}
// Creates a tensor on the first backend in the list, forwarding the
// caller-supplied memory pointer.
shared_ptr<runtime::Tensor> TestBackend::create_tensor(const element::Type& element_type,
                                                       const Shape& shape,
                                                       void* memory_pointer)
{
    const shared_ptr<runtime::Backend>& primary = m_backend_list[0];
    return primary->create_tensor(element_type, shape, memory_pointer);
}
// Compiles `func` once: clones it, runs the placement pass, splits the clone
// into per-placement sub-functions and compiles each of those on its backend.
// A function that is already in the map is left alone. Always returns true.
bool TestBackend::compile(shared_ptr<Function> func)
{
    if (m_function_map.find(func) != m_function_map.end())
    {
        return true;
    }

    // Work on a clone so the caller's function is not modified
    FunctionInstance instance;
    instance.m_function = clone_function(*func);

    // Assign each node to a backend
    pass::Manager placement_passes;
    placement_passes.register_pass<pass::AssignPlacement>(m_backend_list);
    placement_passes.run_passes(instance.m_function);

    // Partition the clone by placement
    tie(instance.m_sub_functions, instance.m_map_parameter_to_result) =
        split_function_by_placement_size(instance.m_function);
    m_function_map.insert({func, instance});

    // Hand every sub-function to the backend it was placed on
    for (shared_ptr<Function>& sub_function : instance.m_sub_functions)
    {
        const size_t placement = get_colocated_function_placement_size(sub_function);
        // (placement-1) as 0 is default placement
        auto target = m_backend_list[(placement - 1)];
        target->compile(sub_function);
    }
    return true;
}
// Executes `func` by running its sub-functions one after another on the
// backends they were placed on, compiling first if necessary.
//
// `inputs` and `outputs` are bound positionally to the parameters and results
// of the cloned function. Tensors crossing a sub-function boundary are created
// on the consuming backend and filled by copying the producer's result, read as
// float data. Throws std::runtime_error if the function cannot be found after
// compiling. Always returns true.
bool TestBackend::call(shared_ptr<Function> func,
                       const vector<shared_ptr<runtime::Tensor>>& outputs,
                       const vector<shared_ptr<runtime::Tensor>>& inputs)
{
    // Look up the compiled instance, compiling on first use
    auto entry = m_function_map.find(func);
    if (entry == m_function_map.end())
    {
        compile(func);
        entry = m_function_map.find(func);
        if (entry == m_function_map.end())
        {
            throw runtime_error("Error constructing backend.");
        }
    }
    FunctionInstance& instance = entry->second;

    // Every parameter/result node of a sub-function is bound to one tensor
    unordered_map<shared_ptr<Node>, shared_ptr<runtime::Tensor>> tensor_of;
    for (size_t index = 0; index < inputs.size(); ++index)
    {
        tensor_of[instance.m_function->get_parameters()[index]] = inputs[index];
    }
    for (size_t index = 0; index < outputs.size(); ++index)
    {
        tensor_of[instance.m_function->get_results()[index]] = outputs[index];
    }

    for (shared_ptr<Function>& sub_function : instance.m_sub_functions)
    {
        // (placement-1) as 0 is default placement
        const size_t placement = get_colocated_function_placement_size(sub_function);
        auto backend = m_backend_list[(placement - 1)];

        // Collect the argument tensors
        vector<shared_ptr<runtime::Tensor>> arguments;
        for (auto parameter_node : sub_function->get_parameters())
        {
            auto bound = tensor_of.find(parameter_node);
            if (bound != tensor_of.end())
            {
                arguments.push_back(bound->second);
                continue;
            }

            // Fed by another sub-function's result: copy it onto this backend
            auto producer = instance.m_map_parameter_to_result.at(parameter_node);
            auto produced = tensor_of.at(producer);
            auto staged = backend->create_tensor(parameter_node->get_element_type(),
                                                 parameter_node->get_shape());
            copy_data(staged, read_vector<float>(produced));
            tensor_of[parameter_node] = staged;
            arguments.push_back(staged);
        }

        // Collect the result tensors, creating intermediates on demand
        vector<shared_ptr<runtime::Tensor>> results;
        for (auto result_node : sub_function->get_results())
        {
            auto bound = tensor_of.find(result_node);
            if (bound != tensor_of.end())
            {
                results.push_back(bound->second);
                continue;
            }

            auto fresh = backend->create_tensor(result_node->get_element_type(),
                                                result_node->get_shape());
            tensor_of[result_node] = fresh;
            results.push_back(fresh);
        }

        backend->call_with_validate(sub_function, results, arguments);
    }
    return true;
}
// Creates the wrapped backend by name and records the set of op names this
// wrapper reports as supported, together with a label for the wrapper.
BackendWrapper::BackendWrapper(const string& backend_name,
                               const set<string>& supported_ops,
                               const string& name)
    : m_backend(runtime::Backend::create(backend_name))
    , m_supported_ops(supported_ops)
    , m_name(name)
{
}
// Forwards tensor creation to the wrapped backend.
shared_ptr<runtime::Tensor> BackendWrapper::create_tensor(const element::Type& element_type,
                                                          const Shape& shape)
{
    shared_ptr<runtime::Tensor> tensor = m_backend->create_tensor(element_type, shape);
    return tensor;
}
// Forwards tensor creation, including the caller-supplied memory pointer,
// to the wrapped backend.
shared_ptr<runtime::Tensor> BackendWrapper::create_tensor(const element::Type& element_type,
                                                          const Shape& shape,
                                                          void* memory_pointer)
{
    shared_ptr<runtime::Tensor> tensor =
        m_backend->create_tensor(element_type, shape, memory_pointer);
    return tensor;
}
// Forwards compilation to the wrapped backend and returns its result.
bool BackendWrapper::compile(shared_ptr<Function> func)
{
    const bool compiled = m_backend->compile(func);
    return compiled;
}
// Forwards execution to the wrapped backend and returns its result.
bool BackendWrapper::call(shared_ptr<Function> func,
                          const vector<shared_ptr<runtime::Tensor>>& outputs,
                          const vector<shared_ptr<runtime::Tensor>>& inputs)
{
    const bool succeeded = m_backend->call(func, outputs, inputs);
    return succeeded;
}
// A node is supported exactly when its description() is one of the op names
// given to the constructor.
bool BackendWrapper::is_supported(const Node& node) const
{
    return m_supported_ops.count(node.description()) != 0;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/ngraph.hpp"
#include "ngraph/node.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
#include "util/random.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"
// Test-only backend that spreads one function over a list of backends.
// compile() clones the function, assigns placements, splits it into
// sub-functions and compiles each one on its backend; call() runs the
// sub-functions in sequence.
class TestBackend : public ngraph::runtime::Backend
{
public:
// Throws std::runtime_error when backend_list is empty.
TestBackend(const std::vector<std::shared_ptr<ngraph::runtime::Backend>>& backend_list);
// Creates the tensor on the first backend in the list.
std::shared_ptr<ngraph::runtime::Tensor>
create_tensor(const ngraph::element::Type& element_type,
const ngraph::Shape& shape) override;
// Same as above, forwarding the caller-supplied memory pointer.
std::shared_ptr<ngraph::runtime::Tensor>
create_tensor(const ngraph::element::Type& element_type,
const ngraph::Shape& shape,
void* memory_pointer) override;
// Compiles func if it has not been compiled yet. Always returns true.
bool compile(std::shared_ptr<ngraph::Function> func) override;
// Runs func, compiling it first if needed. Always returns true.
bool call(std::shared_ptr<ngraph::Function> func,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& inputs) override;
private:
// This list of backends is in order of priority with the first backend higher priority
// than the second.
std::vector<std::shared_ptr<ngraph::runtime::Backend>> m_backend_list;
protected:
// Everything compile() produces for one user function.
class FunctionInstance
{
public:
// Clone of the user function that the placement pass ran on.
std::shared_ptr<ngraph::Function> m_function;
// Pieces of m_function after splitting by placement.
std::vector<std::shared_ptr<ngraph::Function>> m_sub_functions;
// For a sub-function parameter, the result node of another sub-function
// whose tensor is copied into it.
std::unordered_map<std::shared_ptr<ngraph::op::Parameter>,
std::shared_ptr<ngraph::op::Result>>
m_map_parameter_to_result;
};
// Compiled instances, keyed by the function the caller passed in.
std::map<std::shared_ptr<ngraph::Function>, FunctionInstance> m_function_map;
};
// Wraps a backend created by name and restricts which ops it reports as
// supported. All other operations are forwarded to the wrapped backend.
class BackendWrapper : public ngraph::runtime::Backend
{
public:
// backend_name:  passed to ngraph::runtime::Backend::create
// supported_ops: node descriptions for which is_supported() returns true
// name:          label stored with the wrapper
BackendWrapper(const std::string& backend_name,
const std::set<std::string>& supported_ops,
const std::string& name);
// Forwarded to the wrapped backend.
std::shared_ptr<ngraph::runtime::Tensor>
create_tensor(const ngraph::element::Type& element_type,
const ngraph::Shape& shape) override;
// Forwarded to the wrapped backend.
std::shared_ptr<ngraph::runtime::Tensor>
create_tensor(const ngraph::element::Type& element_type,
const ngraph::Shape& shape,
void* memory_pointer) override;
// Forwarded to the wrapped backend.
bool compile(std::shared_ptr<ngraph::Function> func) override;
// Forwarded to the wrapped backend.
bool call(std::shared_ptr<ngraph::Function> func,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& inputs) override;
// True when node.description() is in the supported-op set.
bool is_supported(const ngraph::Node& node) const override;
private:
// The backend all work is forwarded to.
std::shared_ptr<ngraph::runtime::Backend> m_backend;
// Op names accepted by is_supported().
const std::set<std::string> m_supported_ops;
// Label given at construction.
const std::string m_name;
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment