Commit b1b0ea87 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

More hybrid fixes (#2788)

* wip

* Make a copy of the FunctionCall's internal function prior to compilation

* Make sure unit test is only included with interpreter and cpu backends enabled
parent 67248fdb
......@@ -225,7 +225,7 @@ void runtime::hybrid::rewrite_function(const shared_ptr<Function>& f,
sub_function->set_placement(placement);
auto fc = make_shared<runtime::hybrid::op::FunctionCall>(function_call_outputs,
function_call_inputs,
sub_function,
*sub_function,
backend_list[placement]);
fc->set_placement_index(0);
for (size_t i = 0; i < function_call_outputs.size(); i++)
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "function_call.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/runtime/backend.hpp"
using namespace std;
......@@ -22,13 +23,13 @@ using namespace ngraph;
runtime::hybrid::op::FunctionCall::FunctionCall(const NodeVector& outputs,
const NodeVector& inputs,
shared_ptr<Function> function,
const Function& function,
shared_ptr<Backend> backend)
: Op("FunctionCall", inputs)
, m_function_outputs{outputs}
, m_function{function}
, m_function{ngraph::clone_function(function)}
, m_backend{backend}
, m_executable{backend->compile(function)}
, m_executable{backend->compile(m_function)}
{
set_output_size(outputs.size());
for (size_t i = 0; i < outputs.size(); i++)
......@@ -40,7 +41,7 @@ runtime::hybrid::op::FunctionCall::FunctionCall(const NodeVector& outputs,
shared_ptr<Node>
runtime::hybrid::op::FunctionCall::copy_with_new_args(const NodeVector& new_args) const
{
return make_shared<FunctionCall>(m_function_outputs, new_args, m_function, m_backend);
return make_shared<FunctionCall>(m_function_outputs, new_args, *m_function, m_backend);
}
shared_ptr<runtime::Backend> runtime::hybrid::op::FunctionCall::get_backend() const
......
......@@ -38,7 +38,7 @@ class ngraph::runtime::hybrid::op::FunctionCall : public ngraph::op::Op
public:
FunctionCall(const NodeVector& outputs,
const NodeVector& inputs,
std::shared_ptr<Function> function,
const Function& function,
std::shared_ptr<Backend> backend);
std::shared_ptr<Backend> get_backend() const;
......
......@@ -85,8 +85,10 @@ if (NGRAPH_INTERPRETER_ENABLE)
list(APPEND SRC
backend_debug_api.cpp
builder.cpp
backend_api.cpp
hybrid_backend.cpp)
backend_api.cpp)
if (NGRAPH_CPU_ENABLE)
list(APPEND SRC hybrid_backend.cpp)
endif()
set(ACTIVE_BACKEND_LIST ${ACTIVE_BACKEND_LIST} INTERPRETER)
endif()
......
......@@ -25,6 +25,7 @@
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/runtime/cpu/cpu_backend.hpp"
#include "ngraph/runtime/hybrid/hybrid_backend.hpp"
#include "ngraph/runtime/hybrid/hybrid_util.hpp"
#include "ngraph/runtime/hybrid/op/function_call.hpp"
......@@ -40,11 +41,11 @@ using namespace ngraph;
static runtime::Backend* hybrid_creator(const char* config)
{
vector<string> unsupported_0 = {"Add"};
vector<string> unsupported_0 = {"Add", "Max"};
vector<string> unsupported_1 = {"Multiply"};
vector<shared_ptr<runtime::Backend>> backend_list = {
make_shared<runtime::interpreter::INTBackend>(unsupported_0),
make_shared<runtime::interpreter::INTBackend>(unsupported_1)};
make_shared<runtime::cpu::CPU_Backend>()};
return new runtime::hybrid::HybridBackend(backend_list);
}
......@@ -71,7 +72,7 @@ TEST(HYBRID, function_call)
auto C = make_shared<op::Parameter>(element::f32, shape);
NodeVector fcall_args{A, B, C};
auto H = make_shared<runtime::hybrid::op::FunctionCall>(
inner_Result, fcall_args, inner_function, backend_list[0]);
inner_Result, fcall_args, *inner_function, backend_list[0]);
auto G0 = make_shared<ngraph::op::GetOutputElement>(H, 0);
auto G1 = make_shared<ngraph::op::GetOutputElement>(H, 1);
NodeVector out{G0, G1};
......@@ -134,3 +135,25 @@ TEST(HYBRID, abc)
EXPECT_TRUE(
test::all_close_f(read_vector<float>(result2), (vector<float>{150, 576, 1176, 1536})));
}
TEST(HYBRID, simple)
{
const string backend_name = "H1";
runtime::BackendManager::register_backend(backend_name, hybrid_creator);
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::i8, shape);
auto f = make_shared<Function>(make_shared<op::Max>(A, AxisSet{0, 1}), ParameterVector{A});
shared_ptr<runtime::Backend> backend = runtime::Backend::create("H1");
static_pointer_cast<runtime::hybrid::HybridBackend>(backend)->set_debug_enabled(true);
// Create some tensors for input/output
auto a = backend->create_tensor(element::i8, shape);
copy_data(a, vector<int8_t>{1, 2, 3, 4});
auto result = backend->create_tensor(element::i8, Shape{});
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_EQ((vector<int8_t>{4}), read_vector<int8_t>(result));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment