Commit 7e001000 authored by Jaikrishnan Menon

CPU: Implement function calls

parent 31eb5c46
@@ -22,12 +22,14 @@ using namespace ngraph::runtime::cpu;
CallFrame::CallFrame(EntryPoint compiled_function,
size_t n_inputs,
size_t n_outputs,
const TensorViewPtrs& temps)
const TensorViewPtrs& temps,
const std::vector<std::shared_ptr<CallFrame>>& callees)
: m_n_inputs(n_inputs)
, m_n_outputs(n_outputs)
, m_tensor_views(n_inputs + n_outputs + temps.size())
, m_compiled_function(compiled_function)
, m_callees(callees)
{
copy(temps.begin(), temps.end(), m_tensor_views.begin() + m_n_inputs + m_n_outputs);
}
@@ -40,7 +42,7 @@ void CallFrame::tensor_call(
copy(outputs.begin(), outputs.end(), m_tensor_views.begin() + m_n_inputs);
// Invoke compiled computation
m_compiled_function(this, m_tensor_views);
m_compiled_function(this, m_tensor_views, m_callees);
// Don't hold onto inputs/outputs
fill_n(m_tensor_views.begin(), m_n_inputs + m_n_outputs, nullptr);
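For orientation: the compiled __entrypoint that m_compiled_function points at now receives the callee frames as a third argument. A minimal hand-written sketch with the same shape as a generated entry point (the callee index and the tensor-view indices are illustrative, not taken from this commit):

#include <memory>
#include <vector>

#include "ngraph/runtime/cpu/call_frame.hpp"

// Hand-written analog of a generated entry point; real entry points are emitted
// as C++ source by Emitter and JIT-compiled by ExternalFunction (see below).
void example_entrypoint(ngraph::runtime::cpu::CallFrame* call_frame,
                        ngraph::runtime::TensorViewPtrs& tensor_views,
                        const std::vector<std::shared_ptr<ngraph::runtime::cpu::CallFrame>>& callees)
{
    // Each op::FunctionCall in the compiled graph contributes one entry to callees.
    auto cf = callees.at(0);
    std::vector<std::shared_ptr<ngraph::runtime::Value>> inputs;
    std::vector<std::shared_ptr<ngraph::runtime::Value>> outputs;
    inputs.emplace_back(call_frame->get_tensor_view(0)); // illustrative indices
    inputs.emplace_back(call_frame->get_tensor_view(1));
    outputs.emplace_back(call_frame->get_tensor_view(2));
    (*cf)(inputs, outputs); // run the nested function's own compiled entry point
    (void)tensor_views;     // other emitted ops use tensor_views directly; this sketch does not
}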
@@ -31,8 +31,10 @@ namespace ngraph
namespace cpu
{
class CallFrame;
using EntryPoint = std::function<void(ngraph::runtime::cpu::CallFrame*,
ngraph::runtime::TensorViewPtrs&)>;
ngraph::runtime::TensorViewPtrs&,
const std::vector<std::shared_ptr<CallFrame>>&)>;
// Compile and execute graphs
class CallFrame : public ngraph::runtime::CallFrame
@@ -41,7 +43,8 @@ namespace ngraph
CallFrame(EntryPoint compiled_function,
size_t n_inputs,
size_t n_outputs,
const TensorViewPtrs& temps);
const TensorViewPtrs& temps,
const std::vector<std::shared_ptr<CallFrame>>& callees);
/// @brief Invoke the function with values matching the signature of the function.
///
@@ -73,6 +76,7 @@ namespace ngraph
TensorViewPtrs m_tensor_views;
bool m_return;
EntryPoint m_compiled_function;
std::vector<std::shared_ptr<CallFrame>> m_callees;
};
}
}
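Because the frame owns its nested frames through m_callees, the callee frames stay alive at least as long as the parent frame. A hedged sketch of building a parent frame with one callee directly; make_parent_frame is not part of this commit, and ExternalFunction::make_call_frame() (further below) does the equivalent with the callees gathered during compilation:

#include <memory>
#include <vector>

#include "ngraph/runtime/cpu/call_frame.hpp"

std::shared_ptr<ngraph::runtime::cpu::CallFrame> make_parent_frame(
    ngraph::runtime::cpu::EntryPoint compiled_function,
    size_t n_inputs,
    size_t n_outputs,
    const ngraph::runtime::TensorViewPtrs& temps,
    const std::shared_ptr<ngraph::runtime::cpu::CallFrame>& nested_frame)
{
    // The parent keeps shared ownership of every callee frame it may invoke.
    std::vector<std::shared_ptr<ngraph::runtime::cpu::CallFrame>> callees{nested_frame};
    return std::make_shared<ngraph::runtime::cpu::CallFrame>(
        compiled_function, n_inputs, n_outputs, temps, callees);
}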
@@ -24,11 +24,13 @@
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/function_call.hpp"
#include "ngraph/ops/get_tuple_element.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
#include "ngraph/runtime/cpu/call_frame.hpp"
#include "ngraph/runtime/cpu/emitter.hpp"
#include "ngraph/runtime/cpu/external_function.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
using namespace std;
using namespace ngraph::runtime::cpu;
@@ -1016,3 +1018,47 @@ void Emitter::EMITTER_DECL(EmitReshape)
"Axis permutation in reshape is not implemented yet for tensors with rank>2");
}
}
void Emitter::EMITTER_DECL(EmitFunctionCall)
{
auto function_call = static_cast<const op::FunctionCall*>(n);
auto function = function_call->get_function();
std::shared_ptr<ExternalFunction> external;
try
{
external = function_map.at(function);
}
catch (const std::out_of_range&)
{
external = make_shared<ExternalFunction>(function);
function_map.insert({function, external});
}
std::shared_ptr<CallFrame> cf = std::dynamic_pointer_cast<CallFrame>(
external->make_call_frame());
ef->get_callees().emplace_back(cf);
TU +=
" {\n"
" auto cf = callees.at(" + to_string(ef->get_callees().size() - 1) + ");\n"
" std::vector<std::shared_ptr<ngraph::runtime::Value>> inputs;\n"
" std::vector<std::shared_ptr<ngraph::runtime::Value>> outputs;\n";
for (const auto &in : inputs)
{
TU +=
" inputs.emplace_back(call_frame->get_tensor_view(" + to_string(in.get_index()) + "));\n";
}
for (const auto &out : outputs)
{
TU +=
" outputs.emplace_back(call_frame->get_tensor_view(" + to_string(out.get_index()) + "));\n";
}
TU +=
" (*cf)(inputs, outputs);\n"
" }\n";
}
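For concreteness, the code appended to TU above comes out roughly as follows for a FunctionCall that is the first callee registered on this ExternalFunction and whose inputs and output happen to live in tensor views 3, 4 and 5 (indices and indentation are illustrative):

    {
        auto cf = callees.at(0);
        std::vector<std::shared_ptr<ngraph::runtime::Value>> inputs;
        std::vector<std::shared_ptr<ngraph::runtime::Value>> outputs;
        inputs.emplace_back(call_frame->get_tensor_view(3));
        inputs.emplace_back(call_frame->get_tensor_view(4));
        outputs.emplace_back(call_frame->get_tensor_view(5));
        (*cf)(inputs, outputs);
    }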
@@ -77,6 +77,7 @@ namespace ngraph
void EMITTER_DECL(EmitConvert);
void EMITTER_DECL(EmitConstant);
void EMITTER_DECL(EmitReshape);
void EMITTER_DECL(EmitFunctionCall);
};
}
}
@@ -108,6 +108,7 @@ static const OpMap dispatcher{
{TI(ngraph::op::Convert), &Emitter::EmitConvert},
{TI(ngraph::op::Constant), &Emitter::EmitConstant},
{TI(ngraph::op::Reshape), &Emitter::EmitReshape},
{TI(ngraph::op::FunctionCall), &Emitter::EmitFunctionCall},
};
#undef TI
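The dispatcher maps the dynamic type of each op node to the Emitter member function that generates code for it; TI(x) is presumably shorthand for std::type_index(typeid(x)). A self-contained sketch of the general lookup pattern, using invented demo types rather than ngraph's (the actual dispatch loop lives outside this hunk):

#include <iostream>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>

// Demo types standing in for ngraph::Node and its op subclasses.
struct DemoNode { virtual ~DemoNode() = default; };
struct DemoFunctionCall : DemoNode {};
struct DemoReshape : DemoNode {};

struct DemoEmitter
{
    void EmitFunctionCall(const DemoNode&) { std::cout << "emit FunctionCall\n"; }
    void EmitReshape(const DemoNode&) { std::cout << "emit Reshape\n"; }
};

using DemoOpFunction = void (DemoEmitter::*)(const DemoNode&);

int main()
{
    const std::unordered_map<std::type_index, DemoOpFunction> dispatcher{
        {std::type_index(typeid(DemoFunctionCall)), &DemoEmitter::EmitFunctionCall},
        {std::type_index(typeid(DemoReshape)), &DemoEmitter::EmitReshape},
    };

    DemoEmitter emitter;
    DemoFunctionCall node;
    const DemoNode& n = node;                              // dispatch on the dynamic type
    auto it = dispatcher.find(std::type_index(typeid(n)));
    if (it != dispatcher.end())
    {
        (emitter.*(it->second))(n);                        // prints "emit FunctionCall"
    }
}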
@@ -203,14 +204,15 @@ void ExternalFunction::compile(FunctionMap& function_map)
#include "ngraph/runtime/cpu/cpu_kernels.hpp"
#include "ngraph/runtime/cpu/eigen_utils.hpp"
void *__dso_handle = 0;
using namespace ngraph::element;
using namespace ngraph::runtime;
using namespace ngraph::runtime::cpu::eigen;
void *__dso_handle = 0;
extern "C" void __entrypoint(ngraph::runtime::cpu::CallFrame* call_frame,
ngraph::runtime::TensorViewPtrs& tensor_views)
ngraph::runtime::TensorViewPtrs& tensor_views,
const std::vector<std::shared_ptr<ngraph::runtime::cpu::CallFrame>>& callees)
{
)";
@@ -262,7 +264,9 @@ extern "C" void __entrypoint(ngraph::runtime::cpu::CallFrame* call_frame,
estate.add_module(llvm_module);
estate.finalize();
compiled_function = estate.find_function<void(
ngraph::runtime::cpu::CallFrame*, ngraph::runtime::TensorViewPtrs&)>("__entrypoint");
ngraph::runtime::cpu::CallFrame*,
ngraph::runtime::TensorViewPtrs&,
const std::vector<std::shared_ptr<CallFrame>>&)>("__entrypoint");
assert(compiled_function);
m_is_compiled = true;
@@ -340,5 +344,5 @@ shared_ptr<ngraph::runtime::CallFrame> ExternalFunction::make_call_frame()
#undef M
}
return make_shared<ngraph::runtime::cpu::CallFrame>(
compiled_function, m_n_inputs, m_n_outputs, temps);
compiled_function, m_n_inputs, m_n_outputs, temps, callees);
}
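End to end, the re-enabled TEST(cpu, function_call) below drives this path. A hedged sketch of the calling pattern, assuming g is a Function whose graph contains op::FunctionCall nodes and x, y, z, result are tensor views prepared as in that test:

// Usage sketch only; g, x, y, z and result are assumed to be set up as in the test.
auto manager = ngraph::runtime::Manager::get("CPU");
auto external = manager->compile(g);          // EmitFunctionCall compiles each callee into its own CallFrame
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external); // a CPU CallFrame like the one constructed above, with its callees
(*cf)({x, z, y}, {result});                   // nested calls run inside the generated __entrypoint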
@@ -20,8 +20,8 @@
#include <typeinfo>
#include <unordered_map>
#include "ngraph/codegen/compiler.hpp"
#include "ngraph/function.hpp"
#include "ngraph/codegen/compiler.hpp"
#include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
@@ -48,7 +48,8 @@ namespace ngraph
using OpMap = std::unordered_map<std::type_index, OpFunction>;
using EntryPoint = std::function<void(ngraph::runtime::cpu::CallFrame*,
ngraph::runtime::TensorViewPtrs&)>;
ngraph::runtime::TensorViewPtrs&,
const std::vector<std::shared_ptr<ngraph::runtime::cpu::CallFrame>>&)>;
class ExternalFunction : public ngraph::runtime::ExternalFunction
{
@@ -56,6 +57,7 @@ namespace ngraph
ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
bool release_function = true);
std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame();
std::vector<std::shared_ptr<CallFrame>> &get_callees() { return callees; }
protected:
void compile(FunctionMap& function_map);
@@ -64,6 +66,7 @@ namespace ngraph
size_t m_n_outputs;
ngraph::descriptor::TensorViewPtrs m_temp_views;
EntryPoint compiled_function;
std::vector<std::shared_ptr<CallFrame>> callees;
};
}
}
@@ -1084,8 +1084,7 @@ TEST(cpu, tensor_constant_with_op)
ASSERT_EQ((vector<float>{1, 2, 3, 4, 5, 6, 7, 8}), result->get_vector());
}
/*
TEST(execute, function_call)
TEST(cpu, function_call)
{
// First create "f(A,B,C) = (A+B)*C".
auto shape = Shape{2, 2};
@@ -1106,7 +1105,7 @@ TEST(execute, function_call)
op::Parameters{X, Y, Z});
// Now call g on some test vectors.
auto manager = runtime::Manager::get("NGVM");
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(g);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
@@ -1128,7 +1127,6 @@ TEST(execute, function_call)
(*cf)({x, z, y}, {result});
ASSERT_EQ((vector<float>{100, 144, 196, 256}), result->get_vector());
}
*/
TEST(cpu, broadcast_scalar_vector)
{