Commit 7e001000 authored by Jaikrishnan Menon

CPU: Implement function calls

parent 31eb5c46
...@@ -22,12 +22,14 @@ using namespace ngraph::runtime::cpu; ...@@ -22,12 +22,14 @@ using namespace ngraph::runtime::cpu;
CallFrame::CallFrame(EntryPoint compiled_function, CallFrame::CallFrame(EntryPoint compiled_function,
size_t n_inputs, size_t n_inputs,
size_t n_outputs, size_t n_outputs,
const TensorViewPtrs& temps) const TensorViewPtrs& temps,
const std::vector<std::shared_ptr<CallFrame>>& callees)
: m_n_inputs(n_inputs) : m_n_inputs(n_inputs)
, m_n_outputs(n_outputs) , m_n_outputs(n_outputs)
, m_tensor_views(n_inputs + n_outputs + temps.size()) , m_tensor_views(n_inputs + n_outputs + temps.size())
, m_compiled_function(compiled_function) , m_compiled_function(compiled_function)
, m_callees(callees)
{ {
copy(temps.begin(), temps.end(), m_tensor_views.begin() + m_n_inputs + m_n_outputs); copy(temps.begin(), temps.end(), m_tensor_views.begin() + m_n_inputs + m_n_outputs);
} }
...@@ -40,7 +42,7 @@ void CallFrame::tensor_call( ...@@ -40,7 +42,7 @@ void CallFrame::tensor_call(
copy(outputs.begin(), outputs.end(), m_tensor_views.begin() + m_n_inputs); copy(outputs.begin(), outputs.end(), m_tensor_views.begin() + m_n_inputs);
// Invoke compiled computation // Invoke compiled computation
m_compiled_function(this, m_tensor_views); m_compiled_function(this, m_tensor_views, m_callees);
// Don't hold onto inputs/outputs // Don't hold onto inputs/outputs
fill_n(m_tensor_views.begin(), m_n_inputs + m_n_outputs, nullptr); fill_n(m_tensor_views.begin(), m_n_inputs + m_n_outputs, nullptr);
......
...@@ -31,8 +31,10 @@ namespace ngraph ...@@ -31,8 +31,10 @@ namespace ngraph
namespace cpu namespace cpu
{ {
class CallFrame; class CallFrame;
using EntryPoint = std::function<void(ngraph::runtime::cpu::CallFrame*, using EntryPoint = std::function<void(ngraph::runtime::cpu::CallFrame*,
ngraph::runtime::TensorViewPtrs&)>; ngraph::runtime::TensorViewPtrs&,
const std::vector<std::shared_ptr<CallFrame>>&)>;
// Compile and execute graphs // Compile and execute graphs
class CallFrame : public ngraph::runtime::CallFrame class CallFrame : public ngraph::runtime::CallFrame
...@@ -41,7 +43,8 @@ namespace ngraph ...@@ -41,7 +43,8 @@ namespace ngraph
CallFrame(EntryPoint compiled_function, CallFrame(EntryPoint compiled_function,
size_t n_inputs, size_t n_inputs,
size_t n_outputs, size_t n_outputs,
const TensorViewPtrs& temps); const TensorViewPtrs& temps,
const std::vector<std::shared_ptr<CallFrame>>& callees);
/// @brief Invoke the function with values matching the signature of the function. /// @brief Invoke the function with values matching the signature of the function.
/// ///
...@@ -73,6 +76,7 @@ namespace ngraph ...@@ -73,6 +76,7 @@ namespace ngraph
TensorViewPtrs m_tensor_views; TensorViewPtrs m_tensor_views;
bool m_return; bool m_return;
EntryPoint m_compiled_function; EntryPoint m_compiled_function;
std::vector<std::shared_ptr<CallFrame>> m_callees;
}; };
} }
} }
......
...@@ -24,11 +24,13 @@ ...@@ -24,11 +24,13 @@
#include "ngraph/ops/broadcast.hpp" #include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/concatenate.hpp" #include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/constant.hpp" #include "ngraph/ops/constant.hpp"
#include "ngraph/ops/function_call.hpp"
#include "ngraph/ops/get_tuple_element.hpp" #include "ngraph/ops/get_tuple_element.hpp"
#include "ngraph/ops/reshape.hpp" #include "ngraph/ops/reshape.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
#include "ngraph/runtime/cpu/call_frame.hpp"
#include "ngraph/runtime/cpu/emitter.hpp" #include "ngraph/runtime/cpu/emitter.hpp"
#include "ngraph/runtime/cpu/external_function.hpp" #include "ngraph/runtime/cpu/external_function.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
using namespace std; using namespace std;
using namespace ngraph::runtime::cpu; using namespace ngraph::runtime::cpu;
...@@ -1016,3 +1018,47 @@ void Emitter::EMITTER_DECL(EmitReshape) ...@@ -1016,3 +1018,47 @@ void Emitter::EMITTER_DECL(EmitReshape)
"Axis permutation in reshape is not implemented yet for tensors with rank>2"); "Axis permutation in reshape is not implemented yet for tensors with rank>2");
} }
} }
// Emits the generated-code snippet for an op::FunctionCall node.
//
// The callee's ExternalFunction is taken from `function_map` when one has
// already been registered for the called function; otherwise a new one is
// created and registered. A call frame is then made for the callee and
// appended to the enclosing ExternalFunction's callee list (`ef`). The text
// appended to TU looks that call frame up by its position in `callees`,
// collects the input and output tensor views from `call_frame` by index,
// and invokes the callee with them.
void Emitter::EMITTER_DECL(EmitFunctionCall)
{
    auto function_call = static_cast<const op::FunctionCall*>(n);
    auto function = function_call->get_function();

    // Look the callee up with find() rather than at() + catch: a missing
    // entry is an expected case here, not an error, so no exception is
    // needed to handle it.
    std::shared_ptr<ExternalFunction> external;
    auto cached = function_map.find(function);
    if (cached != function_map.end())
    {
        external = cached->second;
    }
    else
    {
        external = make_shared<ExternalFunction>(function);
        function_map.insert({function, external});
    }

    std::shared_ptr<CallFrame> cf =
        std::dynamic_pointer_cast<CallFrame>(external->make_call_frame());

    // The new call frame is appended at the end, so its index is the size of
    // the callee list just before the insertion.
    auto& callee_list = ef->get_callees();
    const size_t callee_index = callee_list.size();
    callee_list.emplace_back(cf);

    TU +=
        " {\n"
        " auto cf = callees.at(" + to_string(callee_index) + ");\n"
        " std::vector<std::shared_ptr<ngraph::runtime::Value>> inputs;\n"
        " std::vector<std::shared_ptr<ngraph::runtime::Value>> outputs;\n";

    // One emplace_back per argument, in the order given by `inputs`.
    for (const auto& in : inputs)
    {
        TU +=
            " inputs.emplace_back(call_frame->get_tensor_view(" + to_string(in.get_index()) + "));\n";
    }

    // One emplace_back per result, in the order given by `outputs`.
    for (const auto& out : outputs)
    {
        TU +=
            " outputs.emplace_back(call_frame->get_tensor_view(" + to_string(out.get_index()) + "));\n";
    }

    TU +=
        " (*cf)(inputs, outputs);\n"
        " }\n";
}
...@@ -77,6 +77,7 @@ namespace ngraph ...@@ -77,6 +77,7 @@ namespace ngraph
void EMITTER_DECL(EmitConvert); void EMITTER_DECL(EmitConvert);
void EMITTER_DECL(EmitConstant); void EMITTER_DECL(EmitConstant);
void EMITTER_DECL(EmitReshape); void EMITTER_DECL(EmitReshape);
void EMITTER_DECL(EmitFunctionCall);
}; };
} }
} }
......
...@@ -108,6 +108,7 @@ static const OpMap dispatcher{ ...@@ -108,6 +108,7 @@ static const OpMap dispatcher{
{TI(ngraph::op::Convert), &Emitter::EmitConvert}, {TI(ngraph::op::Convert), &Emitter::EmitConvert},
{TI(ngraph::op::Constant), &Emitter::EmitConstant}, {TI(ngraph::op::Constant), &Emitter::EmitConstant},
{TI(ngraph::op::Reshape), &Emitter::EmitReshape}, {TI(ngraph::op::Reshape), &Emitter::EmitReshape},
{TI(ngraph::op::FunctionCall), &Emitter::EmitFunctionCall},
}; };
#undef TI #undef TI
...@@ -203,14 +204,15 @@ void ExternalFunction::compile(FunctionMap& function_map) ...@@ -203,14 +204,15 @@ void ExternalFunction::compile(FunctionMap& function_map)
#include "ngraph/runtime/cpu/cpu_kernels.hpp" #include "ngraph/runtime/cpu/cpu_kernels.hpp"
#include "ngraph/runtime/cpu/eigen_utils.hpp" #include "ngraph/runtime/cpu/eigen_utils.hpp"
void *__dso_handle = 0;
using namespace ngraph::element; using namespace ngraph::element;
using namespace ngraph::runtime; using namespace ngraph::runtime;
using namespace ngraph::runtime::cpu::eigen; using namespace ngraph::runtime::cpu::eigen;
void *__dso_handle = 0;
extern "C" void __entrypoint(ngraph::runtime::cpu::CallFrame* call_frame, extern "C" void __entrypoint(ngraph::runtime::cpu::CallFrame* call_frame,
ngraph::runtime::TensorViewPtrs& tensor_views) ngraph::runtime::TensorViewPtrs& tensor_views,
const std::vector<std::shared_ptr<ngraph::runtime::cpu::CallFrame>>& callees)
{ {
)"; )";
...@@ -262,7 +264,9 @@ extern "C" void __entrypoint(ngraph::runtime::cpu::CallFrame* call_frame, ...@@ -262,7 +264,9 @@ extern "C" void __entrypoint(ngraph::runtime::cpu::CallFrame* call_frame,
estate.add_module(llvm_module); estate.add_module(llvm_module);
estate.finalize(); estate.finalize();
compiled_function = estate.find_function<void( compiled_function = estate.find_function<void(
ngraph::runtime::cpu::CallFrame*, ngraph::runtime::TensorViewPtrs&)>("__entrypoint"); ngraph::runtime::cpu::CallFrame*,
ngraph::runtime::TensorViewPtrs&,
const std::vector<std::shared_ptr<CallFrame>>&)>("__entrypoint");
assert(compiled_function); assert(compiled_function);
m_is_compiled = true; m_is_compiled = true;
...@@ -340,5 +344,5 @@ shared_ptr<ngraph::runtime::CallFrame> ExternalFunction::make_call_frame() ...@@ -340,5 +344,5 @@ shared_ptr<ngraph::runtime::CallFrame> ExternalFunction::make_call_frame()
#undef M #undef M
} }
return make_shared<ngraph::runtime::cpu::CallFrame>( return make_shared<ngraph::runtime::cpu::CallFrame>(
compiled_function, m_n_inputs, m_n_outputs, temps); compiled_function, m_n_inputs, m_n_outputs, temps, callees);
} }
...@@ -20,8 +20,8 @@ ...@@ -20,8 +20,8 @@
#include <typeinfo> #include <typeinfo>
#include <unordered_map> #include <unordered_map>
#include "ngraph/codegen/compiler.hpp"
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/codegen/compiler.hpp"
#include "ngraph/runtime/external_function.hpp" #include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/tensor_view_info.hpp" #include "ngraph/runtime/tensor_view_info.hpp"
...@@ -48,7 +48,8 @@ namespace ngraph ...@@ -48,7 +48,8 @@ namespace ngraph
using OpMap = std::unordered_map<std::type_index, OpFunction>; using OpMap = std::unordered_map<std::type_index, OpFunction>;
using EntryPoint = std::function<void(ngraph::runtime::cpu::CallFrame*, using EntryPoint = std::function<void(ngraph::runtime::cpu::CallFrame*,
ngraph::runtime::TensorViewPtrs&)>; ngraph::runtime::TensorViewPtrs&,
const std::vector<std::shared_ptr<ngraph::runtime::cpu::CallFrame>>&)>;
class ExternalFunction : public ngraph::runtime::ExternalFunction class ExternalFunction : public ngraph::runtime::ExternalFunction
{ {
...@@ -56,6 +57,7 @@ namespace ngraph ...@@ -56,6 +57,7 @@ namespace ngraph
ExternalFunction(const std::shared_ptr<ngraph::Function>& function, ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
bool release_function = true); bool release_function = true);
std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame(); std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame();
std::vector<std::shared_ptr<CallFrame>> &get_callees() { return callees; }
protected: protected:
void compile(FunctionMap& function_map); void compile(FunctionMap& function_map);
...@@ -64,6 +66,7 @@ namespace ngraph ...@@ -64,6 +66,7 @@ namespace ngraph
size_t m_n_outputs; size_t m_n_outputs;
ngraph::descriptor::TensorViewPtrs m_temp_views; ngraph::descriptor::TensorViewPtrs m_temp_views;
EntryPoint compiled_function; EntryPoint compiled_function;
std::vector<std::shared_ptr<CallFrame>> callees;
}; };
} }
} }
......
...@@ -1084,8 +1084,7 @@ TEST(cpu, tensor_constant_with_op) ...@@ -1084,8 +1084,7 @@ TEST(cpu, tensor_constant_with_op)
ASSERT_EQ((vector<float>{1, 2, 3, 4, 5, 6, 7, 8}), result->get_vector()); ASSERT_EQ((vector<float>{1, 2, 3, 4, 5, 6, 7, 8}), result->get_vector());
} }
/* TEST(cpu, function_call)
TEST(execute, function_call)
{ {
// First create "f(A,B,C) = (A+B)*C". // First create "f(A,B,C) = (A+B)*C".
auto shape = Shape{2, 2}; auto shape = Shape{2, 2};
...@@ -1106,7 +1105,7 @@ TEST(execute, function_call) ...@@ -1106,7 +1105,7 @@ TEST(execute, function_call)
op::Parameters{X, Y, Z}); op::Parameters{X, Y, Z});
// Now call g on some test vectors. // Now call g on some test vectors.
auto manager = runtime::Manager::get("NGVM"); auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(g); auto external = manager->compile(g);
auto backend = manager->allocate_backend(); auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external); auto cf = backend->make_call_frame(external);
...@@ -1128,7 +1127,6 @@ TEST(execute, function_call) ...@@ -1128,7 +1127,6 @@ TEST(execute, function_call)
(*cf)({x, z, y}, {result}); (*cf)({x, z, y}, {result});
ASSERT_EQ((vector<float>{100, 144, 196, 256}), result->get_vector()); ASSERT_EQ((vector<float>{100, 144, 196, 256}), result->get_vector());
} }
*/
TEST(cpu, broadcast_scalar_vector) TEST(cpu, broadcast_scalar_vector)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment