Commit de672abe authored by Scott Cyphers, committed by GitHub

Simplify external function use. (#127)

* Simplify external function use.

* By default, release the original function once compilation finishes.
  An optional constructor argument delays the release.
parent d93260a9
...@@ -33,8 +33,12 @@ ...@@ -33,8 +33,12 @@
using namespace std; using namespace std;
using namespace ngraph::runtime::eigen; using namespace ngraph::runtime::eigen;
ExternalFunction::ExternalFunction() ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
: m_instructions(make_shared<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>>()) bool release_function)
: m_function(function)
, m_release_function(release_function)
, m_is_compiled(false)
, m_instructions(make_shared<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>>())
{ {
} }
...@@ -83,12 +87,17 @@ std::unordered_map<std::type_index, ...@@ -83,12 +87,17 @@ std::unordered_map<std::type_index,
return op_map; return op_map;
} }
void ExternalFunction::compile(std::shared_ptr<ngraph::Function> f) void ExternalFunction::compile()
{ {
if (m_is_compiled)
{
return;
}
// This will be replaced with the pass manager // This will be replaced with the pass manager
// Get the ordered list of ops in execution order // Get the ordered list of ops in execution order
pass::TopologicalSort ts; pass::TopologicalSort ts;
ts.run_on_tree(f->get_result()); ts.run_on_tree(m_function->get_result());
auto nodes = ts.get_call_graph(); auto nodes = ts.get_call_graph();
// Types // Types
for (auto node : nodes) for (auto node : nodes)
...@@ -104,7 +113,7 @@ void ExternalFunction::compile(std::shared_ptr<ngraph::Function> f) ...@@ -104,7 +113,7 @@ void ExternalFunction::compile(std::shared_ptr<ngraph::Function> f)
// Determine tensor requirements for the call frame // Determine tensor requirements for the call frame
unordered_map<shared_ptr<ngraph::descriptor::TensorView>, size_t> tensor_index; unordered_map<shared_ptr<ngraph::descriptor::TensorView>, size_t> tensor_index;
// First come the function inputs // First come the function inputs
for (auto param : f->get_parameters()) for (auto param : m_function->get_parameters())
{ {
for (auto output : param->get_outputs()) for (auto output : param->get_outputs())
{ {
...@@ -116,7 +125,7 @@ void ExternalFunction::compile(std::shared_ptr<ngraph::Function> f) ...@@ -116,7 +125,7 @@ void ExternalFunction::compile(std::shared_ptr<ngraph::Function> f)
m_n_inputs = tensor_index.size(); m_n_inputs = tensor_index.size();
// Next are the function outputs // Next are the function outputs
for (auto output : f->get_result()->get_outputs()) for (auto output : m_function->get_result()->get_outputs())
{ {
auto tv = output.get_tensor_view(); auto tv = output.get_tensor_view();
size_t index = tensor_index.size(); size_t index = tensor_index.size();
...@@ -164,10 +173,19 @@ void ExternalFunction::compile(std::shared_ptr<ngraph::Function> f) ...@@ -164,10 +173,19 @@ void ExternalFunction::compile(std::shared_ptr<ngraph::Function> f)
handler_it->second(node, this, in, out); handler_it->second(node, this, in, out);
} }
m_instructions->push_back(make_shared<runtime::eigen::ReturnInstruction>()); m_instructions->push_back(make_shared<runtime::eigen::ReturnInstruction>());
m_is_compiled = true;
if (m_release_function)
{
release_function();
}
} }
shared_ptr<ngraph::runtime::CallFrame> ExternalFunction::make_call_frame() shared_ptr<ngraph::runtime::CallFrame> ExternalFunction::make_call_frame()
{ {
if (!m_is_compiled)
{
compile();
}
std::vector<std::shared_ptr<ngraph::runtime::PrimaryTensorView>> temps; std::vector<std::shared_ptr<ngraph::runtime::PrimaryTensorView>> temps;
for (auto tv : m_temp_views) for (auto tv : m_temp_views)
{ {
......
...@@ -30,9 +30,8 @@ namespace ngraph ...@@ -30,9 +30,8 @@ namespace ngraph
class ExternalFunction class ExternalFunction
{ {
public: public:
ExternalFunction(); ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
bool release_function = true);
void compile(std::shared_ptr<ngraph::Function> f);
std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame(); std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame();
std::shared_ptr<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>> std::shared_ptr<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>>
get_instructions() get_instructions()
...@@ -40,9 +39,17 @@ namespace ngraph ...@@ -40,9 +39,17 @@ namespace ngraph
return m_instructions; return m_instructions;
} }
// Release original function's resources
void release_function() { m_function = nullptr; }
protected: protected:
size_t m_n_inputs; void compile();
size_t m_n_outputs;
std::shared_ptr<ngraph::Function> m_function;
bool m_release_function;
bool m_is_compiled;
size_t m_n_inputs;
size_t m_n_outputs;
std::shared_ptr<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>> std::shared_ptr<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>>
m_instructions; m_instructions;
std::vector<std::shared_ptr<ngraph::descriptor::TensorView>> m_temp_views; std::vector<std::shared_ptr<ngraph::descriptor::TensorView>> m_temp_views;
......
...@@ -28,8 +28,7 @@ TEST(execute, test_abc) ...@@ -28,8 +28,7 @@ TEST(execute, test_abc)
auto f = make_shared<Function>(make_shared<op::Multiply>(make_shared<op::Add>(A, B), C), auto f = make_shared<Function>(make_shared<op::Multiply>(make_shared<op::Add>(A, B), C),
op::Parameters{A, B, C}); op::Parameters{A, B, C});
auto external = make_shared<ngraph::runtime::eigen::ExternalFunction>(); auto external = make_shared<ngraph::runtime::eigen::ExternalFunction>(f);
external->compile(f);
auto cf = external->make_call_frame(); auto cf = external->make_call_frame();
// Create some tensors for input/output // Create some tensors for input/output
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment