Unverified Commit 96604f12 authored by Robert Kimball, committed by GitHub

back out api change (#842)

* back out api change
parent db788de8
@@ -43,6 +43,6 @@ vector<size_t> runtime::Backend::get_subdevices(const string& type)
 return manager->get_subdevices();
 }
-void runtime::Backend::remove_compiled_function(const Function& func)
+void runtime::Backend::remove_compiled_function(std::shared_ptr<Function> func)
 {
 }
@@ -93,13 +93,13 @@ namespace ngraph
 return create_tensor(element::from<T>(), shape);
 }
-virtual bool compile(const ngraph::Function& func) = 0;
+virtual bool compile(std::shared_ptr<Function> func) = 0;
-virtual bool call(const ngraph::Function& func,
+virtual bool call(std::shared_ptr<Function> func,
 const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
 const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) = 0;
-virtual void remove_compiled_function(const ngraph::Function& func);
+virtual void remove_compiled_function(std::shared_ptr<Function> func);
 };
 }
 }
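For orientation, the interface being restored takes the function by `std::shared_ptr<Function>` in `compile`, `call`, and `remove_compiled_function`. The helper below is an illustrative caller-side sketch only; creating the backend, the `Function`, and the tensor views is backend-specific, happens outside this diff, and is assumed to have been done elsewhere.

```cpp
#include <memory>
#include <vector>
// nGraph headers for Backend/Function/TensorView omitted; include whatever
// your build uses.

// Illustrative helper (not from this commit): drive one function through the
// shared_ptr-based Backend interface shown above.
bool run_once(ngraph::runtime::Backend& backend,
              const std::shared_ptr<ngraph::Function>& f,
              const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& outputs,
              const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& inputs)
{
    backend.compile(f);                  // caches a compiled instance keyed by `f`
    bool ok = backend.call(f, outputs, inputs);
    backend.remove_compiled_function(f); // drop the cached instance when done
    return ok;
}
```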
@@ -46,30 +46,30 @@ std::shared_ptr<ngraph::runtime::TensorView>
 return make_shared<runtime::cpu::CPUTensorView>(element_type, shape);
 }
-bool runtime::cpu::CPU_Backend::compile(const ngraph::Function& func)
+bool runtime::cpu::CPU_Backend::compile(std::shared_ptr<Function> func)
 {
-if (!contains_key(m_function_map, &func))
+if (!contains_key(m_function_map, func))
 {
 FunctionInstance instance;
-instance.m_function = clone_function(func);
+instance.m_function = func;
 instance.m_external_function = make_shared<CPU_ExternalFunction>(instance.m_function);
 auto cf = instance.m_external_function->make_call_frame();
 instance.m_call_frame = dynamic_pointer_cast<CPU_CallFrame>(cf);
-m_function_map.insert({&func, instance});
+m_function_map.insert({func, instance});
 }
 return true;
 }
-bool runtime::cpu::CPU_Backend::call(const Function& func,
+bool runtime::cpu::CPU_Backend::call(std::shared_ptr<Function> func,
 const vector<shared_ptr<runtime::TensorView>>& outputs,
 const vector<shared_ptr<runtime::TensorView>>& inputs)
 {
 bool rc = true;
-auto it = m_function_map.find(&func);
+auto it = m_function_map.find(func);
 if (it == m_function_map.end())
 {
 compile(func);
-it = m_function_map.find(&func);
+it = m_function_map.find(func);
 }
 if (it == m_function_map.end())
@@ -96,9 +96,9 @@ bool runtime::cpu::CPU_Backend::call(
 return true;
 }
-void runtime::cpu::CPU_Backend::remove_compiled_function(const Function& func)
+void runtime::cpu::CPU_Backend::remove_compiled_function(std::shared_ptr<Function> func)
 {
-m_function_map.erase(&func);
+m_function_map.erase(func);
 }
 std::shared_ptr<ngraph::runtime::TensorView> runtime::cpu::CPU_Backend::make_primary_tensor_view(
@@ -51,16 +51,16 @@ namespace ngraph
 create_tensor(const ngraph::element::Type& element_type,
 const Shape& shape) override;
-bool compile(const ngraph::Function& fun) override;
+bool compile(std::shared_ptr<Function> func) override;
 bool call(const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
 const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override;
-bool call(const ngraph::Function& fun,
+bool call(std::shared_ptr<Function> func,
 const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
 const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override;
-void remove_compiled_function(const Function& func) override;
+void remove_compiled_function(std::shared_ptr<Function> func) override;
 private:
 class FunctionInstance
@@ -71,7 +71,7 @@ namespace ngraph
 std::shared_ptr<Function> m_function;
 };
-std::map<const Function*, FunctionInstance> m_function_map;
+std::map<std::shared_ptr<Function>, FunctionInstance> m_function_map;
 };
 }
 }
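The header change above reverts the cache key from `const Function*` back to `std::shared_ptr<Function>`. A `std::map` keyed by `shared_ptr` still orders and compares entries by the stored pointer value, but it also shares ownership of each cached `Function`, so the key cannot dangle once the caller drops its handle. The standalone sketch below (using a stand-in `Function` type, not the nGraph class) illustrates both points.

```cpp
#include <cassert>
#include <map>
#include <memory>

// Stand-in for ngraph::Function; only object identity matters for the key.
struct Function { };
struct FunctionInstance { int id = 0; };

int main()
{
    std::map<std::shared_ptr<Function>, FunctionInstance> cache;

    auto f = std::make_shared<Function>();
    cache.insert({f, FunctionInstance{1}});

    // Lookups compare the stored pointer, so any copy of the same
    // shared_ptr finds the cached entry.
    auto copy = f;
    assert(cache.find(copy) != cache.end());

    // A distinct Function object is a distinct key.
    auto g = std::make_shared<Function>();
    assert(cache.find(g) == cache.end());

    // The map shares ownership: dropping the caller's handle does not
    // leave a dangling key the way a raw Function* key could.
    f.reset();
    assert(cache.size() == 1);

    cache.erase(copy); // mirrors remove_compiled_function(func)
    assert(cache.empty());
    return 0;
}
```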
@@ -46,31 +46,31 @@ std::shared_ptr<ngraph::runtime::TensorView>
 return dynamic_pointer_cast<runtime::TensorView>(rc);
 }
-bool runtime::gpu::GPU_Backend::compile(const ngraph::Function& func)
+bool runtime::gpu::GPU_Backend::compile(std::shared_ptr<Function> func)
 {
-if (!contains_key(m_function_map, &func))
+if (!contains_key(m_function_map, func))
 {
 FunctionInstance instance;
-instance.m_function = clone_function(func);
+instance.m_function = func;
 instance.m_external_function = make_shared<GPU_ExternalFunction>(instance.m_function);
 auto cf = instance.m_external_function->make_call_frame();
 instance.m_call_frame = dynamic_pointer_cast<GPU_CallFrame>(cf);
-m_function_map.insert({&func, instance});
+m_function_map.insert({func, instance});
 }
 return true;
 }
 bool runtime::gpu::GPU_Backend::call(
-const ngraph::Function& func,
+std::shared_ptr<Function> func,
 const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
 const std::vector<std::shared_ptr<runtime::TensorView>>& inputs)
 {
 bool rc = true;
-auto it = m_function_map.find(&func);
+auto it = m_function_map.find(func);
 if (it == m_function_map.end())
 {
 compile(func);
-it = m_function_map.find(&func);
+it = m_function_map.find(func);
 }
 if (it == m_function_map.end())
@@ -52,12 +52,12 @@ namespace ngraph
 create_tensor(const ngraph::element::Type& element_type,
 const Shape& shape) override;
-bool compile(const ngraph::Function& fun) override;
+bool compile(std::shared_ptr<Function> func) override;
 bool call(const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
 const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override;
-bool call(const ngraph::Function& fun,
+bool call(std::shared_ptr<Function> func,
 const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
 const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override;
@@ -70,7 +70,7 @@ namespace ngraph
 std::shared_ptr<Function> m_function;
 };
-std::map<const Function*, FunctionInstance> m_function_map;
+std::map<std::shared_ptr<Function>, FunctionInstance> m_function_map;
 };
 }
 }
@@ -53,31 +53,31 @@ shared_ptr<ngraph::runtime::TensorView>
 return static_pointer_cast<runtime::TensorView>(rc);
 }
-bool runtime::interpreter::INT_Backend::compile(const ngraph::Function& func)
+bool runtime::interpreter::INT_Backend::compile(std::shared_ptr<Function> func)
 {
-if (!contains_key(m_function_map, &func))
+if (!contains_key(m_function_map, func))
 {
 FunctionInstance instance;
-instance.m_function = clone_function(func);
+instance.m_function = func;
 instance.m_external_function =
 make_shared<interpreter::ExternalFunction>(instance.m_function);
 auto cf = instance.m_external_function->make_call_frame();
 instance.m_call_frame = dynamic_pointer_cast<interpreter::INT_CallFrame>(cf);
-m_function_map.insert({&func, instance});
+m_function_map.insert({func, instance});
 }
 return true;
 }
-bool runtime::interpreter::INT_Backend::call(const Function& fun,
+bool runtime::interpreter::INT_Backend::call(std::shared_ptr<Function> func,
 const vector<shared_ptr<runtime::TensorView>>& outputs,
 const vector<shared_ptr<runtime::TensorView>>& inputs)
 {
 bool rc = true;
-auto it = m_function_map.find(&fun);
+auto it = m_function_map.find(func);
 if (it == m_function_map.end())
 {
-compile(fun);
-it = m_function_map.find(&fun);
+compile(func);
+it = m_function_map.find(func);
 }
 if (it == m_function_map.end())
@@ -52,12 +52,12 @@ namespace ngraph
 create_tensor(const ngraph::element::Type& element_type,
 const Shape& shape) override;
-bool compile(const ngraph::Function& fun) override;
+bool compile(std::shared_ptr<Function> func) override;
 bool call(const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
 const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override;
-bool call(const ngraph::Function& fun,
+bool call(std::shared_ptr<Function> func,
 const std::vector<std::shared_ptr<runtime::TensorView>>& outputs,
 const std::vector<std::shared_ptr<runtime::TensorView>>& inputs) override;
@@ -70,7 +70,7 @@ namespace ngraph
 std::shared_ptr<Function> m_function;
 };
-std::map<const Function*, FunctionInstance> m_function_map;
+std::map<std::shared_ptr<Function>, FunctionInstance> m_function_map;
 };
 }
 }
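Each backend's `call` above follows the same lazy-compilation idiom: look the function up in `m_function_map`, compile it if absent, then look it up again before running. A self-contained sketch of that pattern, using the same stand-in types as before rather than the real nGraph classes:

```cpp
#include <map>
#include <memory>

struct Function { };
struct FunctionInstance { bool compiled = false; };

class Backend
{
public:
    bool compile(std::shared_ptr<Function> func)
    {
        // Create the cache entry if needed and mark it compiled.
        m_function_map[func].compiled = true;
        return true;
    }

    bool call(std::shared_ptr<Function> func)
    {
        auto it = m_function_map.find(func);
        if (it == m_function_map.end())
        {
            compile(func);                 // lazy compilation on first call
            it = m_function_map.find(func);
        }
        if (it == m_function_map.end())
        {
            return false;                  // compilation failed; nothing to run
        }
        // ... run the cached FunctionInstance against inputs/outputs ...
        return true;
    }

private:
    std::map<std::shared_ptr<Function>, FunctionInstance> m_function_map;
};

int main()
{
    Backend backend;
    auto f = std::make_shared<Function>();
    return backend.call(f) ? 0 : 1; // compiles f on demand, then "runs" it
}
```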
This source diff could not be displayed because it is too large.