Unverified Commit 96604f12 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

back out api change (#842)

* back out api change
parent db788de8
...@@ -43,6 +43,6 @@ vector<size_t> runtime::Backend::get_subdevices(const string& type) ...@@ -43,6 +43,6 @@ vector<size_t> runtime::Backend::get_subdevices(const string& type)
return manager->get_subdevices(); return manager->get_subdevices();
} }
void runtime::Backend::remove_compiled_function(const Function& func) void runtime::Backend::remove_compiled_function(std::shared_ptr<Function> func)
{ {
} }
...@@ -93,13 +93,13 @@ namespace ngraph ...@@ -93,13 +93,13 @@ namespace ngraph
return create_tensor(element::from<T>(), shape); return create_tensor(element::from<T>(), shape);
} }
virtual bool compile(const ngraph::Function& func) = 0; virtual bool compile(std::shared_ptr<Function> func) = 0;
virtual bool call(const ngraph::Function& func, virtual bool call(std::shared_ptr<Function> func,
const std::vector<std::shared_ptr<runtime::TensorView>>& outputs, const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) = 0; const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) = 0;
virtual void remove_compiled_function(const ngraph::Function& func); virtual void remove_compiled_function(std::shared_ptr<Function> func);
}; };
} }
} }
...@@ -46,30 +46,30 @@ std::shared_ptr<ngraph::runtime::TensorView> ...@@ -46,30 +46,30 @@ std::shared_ptr<ngraph::runtime::TensorView>
return make_shared<runtime::cpu::CPUTensorView>(element_type, shape); return make_shared<runtime::cpu::CPUTensorView>(element_type, shape);
} }
bool runtime::cpu::CPU_Backend::compile(const ngraph::Function& func) bool runtime::cpu::CPU_Backend::compile(std::shared_ptr<Function> func)
{ {
if (!contains_key(m_function_map, &func)) if (!contains_key(m_function_map, func))
{ {
FunctionInstance instance; FunctionInstance instance;
instance.m_function = clone_function(func); instance.m_function = func;
instance.m_external_function = make_shared<CPU_ExternalFunction>(instance.m_function); instance.m_external_function = make_shared<CPU_ExternalFunction>(instance.m_function);
auto cf = instance.m_external_function->make_call_frame(); auto cf = instance.m_external_function->make_call_frame();
instance.m_call_frame = dynamic_pointer_cast<CPU_CallFrame>(cf); instance.m_call_frame = dynamic_pointer_cast<CPU_CallFrame>(cf);
m_function_map.insert({&func, instance}); m_function_map.insert({func, instance});
} }
return true; return true;
} }
bool runtime::cpu::CPU_Backend::call(const Function& func, bool runtime::cpu::CPU_Backend::call(std::shared_ptr<Function> func,
const vector<shared_ptr<runtime::TensorView>>& outputs, const vector<shared_ptr<runtime::TensorView>>& outputs,
const vector<shared_ptr<runtime::TensorView>>& inputs) const vector<shared_ptr<runtime::TensorView>>& inputs)
{ {
bool rc = true; bool rc = true;
auto it = m_function_map.find(&func); auto it = m_function_map.find(func);
if (it == m_function_map.end()) if (it == m_function_map.end())
{ {
compile(func); compile(func);
it = m_function_map.find(&func); it = m_function_map.find(func);
} }
if (it == m_function_map.end()) if (it == m_function_map.end())
...@@ -96,9 +96,9 @@ bool runtime::cpu::CPU_Backend::call( ...@@ -96,9 +96,9 @@ bool runtime::cpu::CPU_Backend::call(
return true; return true;
} }
void runtime::cpu::CPU_Backend::remove_compiled_function(const Function& func) void runtime::cpu::CPU_Backend::remove_compiled_function(std::shared_ptr<Function> func)
{ {
m_function_map.erase(&func); m_function_map.erase(func);
} }
std::shared_ptr<ngraph::runtime::TensorView> runtime::cpu::CPU_Backend::make_primary_tensor_view( std::shared_ptr<ngraph::runtime::TensorView> runtime::cpu::CPU_Backend::make_primary_tensor_view(
......
...@@ -51,16 +51,16 @@ namespace ngraph ...@@ -51,16 +51,16 @@ namespace ngraph
create_tensor(const ngraph::element::Type& element_type, create_tensor(const ngraph::element::Type& element_type,
const Shape& shape) override; const Shape& shape) override;
bool compile(const ngraph::Function& fun) override; bool compile(std::shared_ptr<Function> func) override;
bool call(const std::vector<std::shared_ptr<runtime::TensorView>>& outputs, bool call(const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override; const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override;
bool call(const ngraph::Function& fun, bool call(std::shared_ptr<Function> func,
const std::vector<std::shared_ptr<runtime::TensorView>>& outputs, const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override; const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override;
void remove_compiled_function(const Function& func) override; void remove_compiled_function(std::shared_ptr<Function> func) override;
private: private:
class FunctionInstance class FunctionInstance
...@@ -71,7 +71,7 @@ namespace ngraph ...@@ -71,7 +71,7 @@ namespace ngraph
std::shared_ptr<Function> m_function; std::shared_ptr<Function> m_function;
}; };
std::map<const Function*, FunctionInstance> m_function_map; std::map<std::shared_ptr<Function>, FunctionInstance> m_function_map;
}; };
} }
} }
......
...@@ -46,31 +46,31 @@ std::shared_ptr<ngraph::runtime::TensorView> ...@@ -46,31 +46,31 @@ std::shared_ptr<ngraph::runtime::TensorView>
return dynamic_pointer_cast<runtime::TensorView>(rc); return dynamic_pointer_cast<runtime::TensorView>(rc);
} }
bool runtime::gpu::GPU_Backend::compile(const ngraph::Function& func) bool runtime::gpu::GPU_Backend::compile(std::shared_ptr<Function> func)
{ {
if (!contains_key(m_function_map, &func)) if (!contains_key(m_function_map, func))
{ {
FunctionInstance instance; FunctionInstance instance;
instance.m_function = clone_function(func); instance.m_function = func;
instance.m_external_function = make_shared<GPU_ExternalFunction>(instance.m_function); instance.m_external_function = make_shared<GPU_ExternalFunction>(instance.m_function);
auto cf = instance.m_external_function->make_call_frame(); auto cf = instance.m_external_function->make_call_frame();
instance.m_call_frame = dynamic_pointer_cast<GPU_CallFrame>(cf); instance.m_call_frame = dynamic_pointer_cast<GPU_CallFrame>(cf);
m_function_map.insert({&func, instance}); m_function_map.insert({func, instance});
} }
return true; return true;
} }
bool runtime::gpu::GPU_Backend::call( bool runtime::gpu::GPU_Backend::call(
const ngraph::Function& func, std::shared_ptr<Function> func,
const std::vector<std::shared_ptr<runtime::TensorView>>& outputs, const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) const std::vector<std::shared_ptr<runtime::TensorView>>& inputs)
{ {
bool rc = true; bool rc = true;
auto it = m_function_map.find(&func); auto it = m_function_map.find(func);
if (it == m_function_map.end()) if (it == m_function_map.end())
{ {
compile(func); compile(func);
it = m_function_map.find(&func); it = m_function_map.find(func);
} }
if (it == m_function_map.end()) if (it == m_function_map.end())
......
...@@ -52,12 +52,12 @@ namespace ngraph ...@@ -52,12 +52,12 @@ namespace ngraph
create_tensor(const ngraph::element::Type& element_type, create_tensor(const ngraph::element::Type& element_type,
const Shape& shape) override; const Shape& shape) override;
bool compile(const ngraph::Function& fun) override; bool compile(std::shared_ptr<Function> func) override;
bool call(const std::vector<std::shared_ptr<runtime::TensorView>>& outputs, bool call(const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override; const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override;
bool call(const ngraph::Function& fun, bool call(std::shared_ptr<Function> func,
const std::vector<std::shared_ptr<runtime::TensorView>>& outputs, const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override; const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override;
...@@ -70,7 +70,7 @@ namespace ngraph ...@@ -70,7 +70,7 @@ namespace ngraph
std::shared_ptr<Function> m_function; std::shared_ptr<Function> m_function;
}; };
std::map<const Function*, FunctionInstance> m_function_map; std::map<std::shared_ptr<Function>, FunctionInstance> m_function_map;
}; };
} }
} }
......
...@@ -53,31 +53,31 @@ shared_ptr<ngraph::runtime::TensorView> ...@@ -53,31 +53,31 @@ shared_ptr<ngraph::runtime::TensorView>
return static_pointer_cast<runtime::TensorView>(rc); return static_pointer_cast<runtime::TensorView>(rc);
} }
bool runtime::interpreter::INT_Backend::compile(const ngraph::Function& func) bool runtime::interpreter::INT_Backend::compile(std::shared_ptr<Function> func)
{ {
if (!contains_key(m_function_map, &func)) if (!contains_key(m_function_map, func))
{ {
FunctionInstance instance; FunctionInstance instance;
instance.m_function = clone_function(func); instance.m_function = func;
instance.m_external_function = instance.m_external_function =
make_shared<interpreter::ExternalFunction>(instance.m_function); make_shared<interpreter::ExternalFunction>(instance.m_function);
auto cf = instance.m_external_function->make_call_frame(); auto cf = instance.m_external_function->make_call_frame();
instance.m_call_frame = dynamic_pointer_cast<interpreter::INT_CallFrame>(cf); instance.m_call_frame = dynamic_pointer_cast<interpreter::INT_CallFrame>(cf);
m_function_map.insert({&func, instance}); m_function_map.insert({func, instance});
} }
return true; return true;
} }
bool runtime::interpreter::INT_Backend::call(const Function& fun, bool runtime::interpreter::INT_Backend::call(std::shared_ptr<Function> func,
const vector<shared_ptr<runtime::TensorView>>& outputs, const vector<shared_ptr<runtime::TensorView>>& outputs,
const vector<shared_ptr<runtime::TensorView>>& inputs) const vector<shared_ptr<runtime::TensorView>>& inputs)
{ {
bool rc = true; bool rc = true;
auto it = m_function_map.find(&fun); auto it = m_function_map.find(func);
if (it == m_function_map.end()) if (it == m_function_map.end())
{ {
compile(fun); compile(func);
it = m_function_map.find(&fun); it = m_function_map.find(func);
} }
if (it == m_function_map.end()) if (it == m_function_map.end())
......
...@@ -52,12 +52,12 @@ namespace ngraph ...@@ -52,12 +52,12 @@ namespace ngraph
create_tensor(const ngraph::element::Type& element_type, create_tensor(const ngraph::element::Type& element_type,
const Shape& shape) override; const Shape& shape) override;
bool compile(const ngraph::Function& fun) override; bool compile(std::shared_ptr<Function> func) override;
bool call(const std::vector<std::shared_ptr<runtime::TensorView>>& outputs, bool call(const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override; const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override;
bool call(const ngraph::Function& fun, bool call(std::shared_ptr<Function> func,
const std::vector<std::shared_ptr<runtime::TensorView>>& outputs, const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override; const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override;
...@@ -70,7 +70,7 @@ namespace ngraph ...@@ -70,7 +70,7 @@ namespace ngraph
std::shared_ptr<Function> m_function; std::shared_ptr<Function> m_function;
}; };
std::map<const Function*, FunctionInstance> m_function_map; std::map<std::shared_ptr<Function>, FunctionInstance> m_function_map;
}; };
} }
} }
......
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment