Commit b1b0ea87 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

More hybrid fixes (#2788)

* wip

* Make a copy of the FunctionCall's internal function prior to compilation

* Make sure the unit test is only included when both the interpreter and CPU backends are enabled
parent 67248fdb
...@@ -225,7 +225,7 @@ void runtime::hybrid::rewrite_function(const shared_ptr<Function>& f, ...@@ -225,7 +225,7 @@ void runtime::hybrid::rewrite_function(const shared_ptr<Function>& f,
sub_function->set_placement(placement); sub_function->set_placement(placement);
auto fc = make_shared<runtime::hybrid::op::FunctionCall>(function_call_outputs, auto fc = make_shared<runtime::hybrid::op::FunctionCall>(function_call_outputs,
function_call_inputs, function_call_inputs,
sub_function, *sub_function,
backend_list[placement]); backend_list[placement]);
fc->set_placement_index(0); fc->set_placement_index(0);
for (size_t i = 0; i < function_call_outputs.size(); i++) for (size_t i = 0; i < function_call_outputs.size(); i++)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "function_call.hpp" #include "function_call.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/runtime/backend.hpp" #include "ngraph/runtime/backend.hpp"
using namespace std; using namespace std;
...@@ -22,13 +23,13 @@ using namespace ngraph; ...@@ -22,13 +23,13 @@ using namespace ngraph;
runtime::hybrid::op::FunctionCall::FunctionCall(const NodeVector& outputs, runtime::hybrid::op::FunctionCall::FunctionCall(const NodeVector& outputs,
const NodeVector& inputs, const NodeVector& inputs,
shared_ptr<Function> function, const Function& function,
shared_ptr<Backend> backend) shared_ptr<Backend> backend)
: Op("FunctionCall", inputs) : Op("FunctionCall", inputs)
, m_function_outputs{outputs} , m_function_outputs{outputs}
, m_function{function} , m_function{ngraph::clone_function(function)}
, m_backend{backend} , m_backend{backend}
, m_executable{backend->compile(function)} , m_executable{backend->compile(m_function)}
{ {
set_output_size(outputs.size()); set_output_size(outputs.size());
for (size_t i = 0; i < outputs.size(); i++) for (size_t i = 0; i < outputs.size(); i++)
...@@ -40,7 +41,7 @@ runtime::hybrid::op::FunctionCall::FunctionCall(const NodeVector& outputs, ...@@ -40,7 +41,7 @@ runtime::hybrid::op::FunctionCall::FunctionCall(const NodeVector& outputs,
// Create a copy of this FunctionCall node that reads from new_args instead of
// the original inputs. The outputs, backend, and wrapped function carry over.
// The stored function is dereferenced because the constructor takes
// `const Function&` (and clones it internally), not a shared_ptr.
shared_ptr<Node>
    runtime::hybrid::op::FunctionCall::copy_with_new_args(const NodeVector& new_args) const
{
    return make_shared<FunctionCall>(m_function_outputs, new_args, *m_function, m_backend);
}
shared_ptr<runtime::Backend> runtime::hybrid::op::FunctionCall::get_backend() const shared_ptr<runtime::Backend> runtime::hybrid::op::FunctionCall::get_backend() const
......
...@@ -38,7 +38,7 @@ class ngraph::runtime::hybrid::op::FunctionCall : public ngraph::op::Op ...@@ -38,7 +38,7 @@ class ngraph::runtime::hybrid::op::FunctionCall : public ngraph::op::Op
public: public:
FunctionCall(const NodeVector& outputs, FunctionCall(const NodeVector& outputs,
const NodeVector& inputs, const NodeVector& inputs,
std::shared_ptr<Function> function, const Function& function,
std::shared_ptr<Backend> backend); std::shared_ptr<Backend> backend);
std::shared_ptr<Backend> get_backend() const; std::shared_ptr<Backend> get_backend() const;
......
...@@ -85,8 +85,10 @@ if (NGRAPH_INTERPRETER_ENABLE) ...@@ -85,8 +85,10 @@ if (NGRAPH_INTERPRETER_ENABLE)
list(APPEND SRC list(APPEND SRC
backend_debug_api.cpp backend_debug_api.cpp
builder.cpp builder.cpp
backend_api.cpp backend_api.cpp)
hybrid_backend.cpp) if (NGRAPH_CPU_ENABLE)
list(APPEND SRC hybrid_backend.cpp)
endif()
set(ACTIVE_BACKEND_LIST ${ACTIVE_BACKEND_LIST} INTERPRETER) set(ACTIVE_BACKEND_LIST ${ACTIVE_BACKEND_LIST} INTERPRETER)
endif() endif()
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "ngraph/pass/visualize_tree.hpp" #include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/runtime/backend.hpp" #include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/backend_manager.hpp" #include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/runtime/cpu/cpu_backend.hpp"
#include "ngraph/runtime/hybrid/hybrid_backend.hpp" #include "ngraph/runtime/hybrid/hybrid_backend.hpp"
#include "ngraph/runtime/hybrid/hybrid_util.hpp" #include "ngraph/runtime/hybrid/hybrid_util.hpp"
#include "ngraph/runtime/hybrid/op/function_call.hpp" #include "ngraph/runtime/hybrid/op/function_call.hpp"
...@@ -40,11 +41,11 @@ using namespace ngraph; ...@@ -40,11 +41,11 @@ using namespace ngraph;
// Factory used to register the "H1" hybrid backend for these tests.
// Builds a two-backend list: an interpreter that rejects Add/Max (so those
// ops fall through to the next backend) followed by the CPU backend.
static runtime::Backend* hybrid_creator(const char* config)
{
    // Ops the first (interpreter) backend reports as unsupported.
    vector<string> unsupported_0 = {"Add", "Max"};
    // NOTE(review): unsupported_1 is no longer referenced now that the second
    // backend is CPU_Backend rather than a second interpreter -- candidate
    // for removal.
    vector<string> unsupported_1 = {"Multiply"};
    vector<shared_ptr<runtime::Backend>> backend_list = {
        make_shared<runtime::interpreter::INTBackend>(unsupported_0),
        make_shared<runtime::cpu::CPU_Backend>()};
    // Ownership passes to BackendManager via the registered creator.
    return new runtime::hybrid::HybridBackend(backend_list);
}
...@@ -71,7 +72,7 @@ TEST(HYBRID, function_call) ...@@ -71,7 +72,7 @@ TEST(HYBRID, function_call)
auto C = make_shared<op::Parameter>(element::f32, shape); auto C = make_shared<op::Parameter>(element::f32, shape);
NodeVector fcall_args{A, B, C}; NodeVector fcall_args{A, B, C};
auto H = make_shared<runtime::hybrid::op::FunctionCall>( auto H = make_shared<runtime::hybrid::op::FunctionCall>(
inner_Result, fcall_args, inner_function, backend_list[0]); inner_Result, fcall_args, *inner_function, backend_list[0]);
auto G0 = make_shared<ngraph::op::GetOutputElement>(H, 0); auto G0 = make_shared<ngraph::op::GetOutputElement>(H, 0);
auto G1 = make_shared<ngraph::op::GetOutputElement>(H, 1); auto G1 = make_shared<ngraph::op::GetOutputElement>(H, 1);
NodeVector out{G0, G1}; NodeVector out{G0, G1};
...@@ -134,3 +135,25 @@ TEST(HYBRID, abc) ...@@ -134,3 +135,25 @@ TEST(HYBRID, abc)
EXPECT_TRUE( EXPECT_TRUE(
test::all_close_f(read_vector<float>(result2), (vector<float>{150, 576, 1176, 1536}))); test::all_close_f(read_vector<float>(result2), (vector<float>{150, 576, 1176, 1536})));
} }
// Sanity test: run a single Max reduction (an op the interpreter backend
// rejects, so it lands on the CPU backend) through the hybrid "H1" backend
// and verify the reduced scalar result.
TEST(HYBRID, simple)
{
    const string backend_name = "H1";
    runtime::BackendManager::register_backend(backend_name, hybrid_creator);

    Shape shape{2, 2};
    auto param = make_shared<op::Parameter>(element::i8, shape);
    auto f = make_shared<Function>(make_shared<op::Max>(param, AxisSet{0, 1}),
                                   ParameterVector{param});

    shared_ptr<runtime::Backend> backend = runtime::Backend::create("H1");
    static_pointer_cast<runtime::hybrid::HybridBackend>(backend)->set_debug_enabled(true);

    // Create some tensors for input/output
    auto input = backend->create_tensor(element::i8, shape);
    copy_data(input, vector<int8_t>{1, 2, 3, 4});
    auto output = backend->create_tensor(element::i8, Shape{});

    auto exec = backend->compile(f);
    exec->call_with_validate({output}, {input});
    EXPECT_EQ((vector<int8_t>{4}), read_vector<int8_t>(output));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment