Commit de672abe authored by Scott Cyphers, committed by GitHub

Simplify external function use. (#127)

* Simplify external function use.

* By default, release the original function once it has been compiled.
  An optional constructor argument delays the release.
parent d93260a9
......@@ -33,8 +33,12 @@
using namespace std;
using namespace ngraph::runtime::eigen;
ExternalFunction::ExternalFunction()
: m_instructions(make_shared<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>>())
// Construct an external function that wraps `function`.
//
// Nothing is compiled here: the function is only stored and m_is_compiled
// starts out false. `release_function` is remembered in m_release_function,
// which compile() consults to decide whether to call release_function()
// after compilation finishes. An empty instruction list is allocated up front.
ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
bool release_function)
: m_function(function)
, m_release_function(release_function)
, m_is_compiled(false)
, m_instructions(make_shared<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>>())
{
}
......@@ -83,12 +87,17 @@ std::unordered_map<std::type_index,
return op_map;
}
void ExternalFunction::compile(std::shared_ptr<ngraph::Function> f)
void ExternalFunction::compile()
{
if (m_is_compiled)
{
return;
}
// This will be replaced with the pass manager
// Get the ordered list of ops in execution order
pass::TopologicalSort ts;
ts.run_on_tree(f->get_result());
ts.run_on_tree(m_function->get_result());
auto nodes = ts.get_call_graph();
// Types
for (auto node : nodes)
......@@ -104,7 +113,7 @@ void ExternalFunction::compile(std::shared_ptr<ngraph::Function> f)
// Determine tensor requirements for the call frame
unordered_map<shared_ptr<ngraph::descriptor::TensorView>, size_t> tensor_index;
// First come the function inputs
for (auto param : f->get_parameters())
for (auto param : m_function->get_parameters())
{
for (auto output : param->get_outputs())
{
......@@ -116,7 +125,7 @@ void ExternalFunction::compile(std::shared_ptr<ngraph::Function> f)
m_n_inputs = tensor_index.size();
// Next are the function outputs
for (auto output : f->get_result()->get_outputs())
for (auto output : m_function->get_result()->get_outputs())
{
auto tv = output.get_tensor_view();
size_t index = tensor_index.size();
......@@ -164,10 +173,19 @@ void ExternalFunction::compile(std::shared_ptr<ngraph::Function> f)
handler_it->second(node, this, in, out);
}
m_instructions->push_back(make_shared<runtime::eigen::ReturnInstruction>());
m_is_compiled = true;
if (m_release_function)
{
release_function();
}
}
shared_ptr<ngraph::runtime::CallFrame> ExternalFunction::make_call_frame()
{
if (!m_is_compiled)
{
compile();
}
std::vector<std::shared_ptr<ngraph::runtime::PrimaryTensorView>> temps;
for (auto tv : m_temp_views)
{
......
......@@ -30,9 +30,8 @@ namespace ngraph
class ExternalFunction
{
public:
ExternalFunction();
void compile(std::shared_ptr<ngraph::Function> f);
ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
bool release_function = true);
std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame();
std::shared_ptr<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>>
get_instructions()
......@@ -40,9 +39,17 @@ namespace ngraph
return m_instructions;
}
// Release original function's resources by dropping this object's reference
// to it (m_function is reset to nullptr).
// NOTE(review): compile() reads m_function, so releasing before the function
// has been compiled looks unsafe -- confirm callers only release afterwards.
void release_function() { m_function = nullptr; }
protected:
size_t m_n_inputs;
size_t m_n_outputs;
void compile();
std::shared_ptr<ngraph::Function> m_function;
bool m_release_function;
bool m_is_compiled;
size_t m_n_inputs;
size_t m_n_outputs;
std::shared_ptr<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>>
m_instructions;
std::vector<std::shared_ptr<ngraph::descriptor::TensorView>> m_temp_views;
......
......@@ -28,8 +28,7 @@ TEST(execute, test_abc)
auto f = make_shared<Function>(make_shared<op::Multiply>(make_shared<op::Add>(A, B), C),
op::Parameters{A, B, C});
auto external = make_shared<ngraph::runtime::eigen::ExternalFunction>();
external->compile(f);
auto external = make_shared<ngraph::runtime::eigen::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment