Commit 40131cdb authored by tomdol's avatar tomdol

Fix the object lifetime problem

parent 6e221a86
...@@ -25,6 +25,8 @@ ...@@ -25,6 +25,8 @@
namespace py = pybind11; namespace py = pybind11;
static const char* CAPSULE_NAME = "ngraph_function";
void regclass_pyngraph_Function(py::module m) void regclass_pyngraph_Function(py::module m)
{ {
py::class_<ngraph::Function, std::shared_ptr<ngraph::Function>> function(m, "Function"); py::class_<ngraph::Function, std::shared_ptr<ngraph::Function>> function(m, "Function");
...@@ -52,21 +54,22 @@ void regclass_pyngraph_Function(py::module m) ...@@ -52,21 +54,22 @@ void regclass_pyngraph_Function(py::module m)
return "<" + class_name + ": '" + self.get_friendly_name() + "' (" + shape + ")>"; return "<" + class_name + ": '" + self.get_friendly_name() + "' (" + shape + ")>";
}); });
function.def_static("from_capsule", [](py::object* capsule) { function.def_static("from_capsule", [](py::object* capsule) {
auto* pycapsule_ptr = capsule->ptr(); auto* pybind_capsule_ptr = capsule->ptr();
auto* ngraph_function = reinterpret_cast<std::shared_ptr<ngraph::Function>*>(PyCapsule_GetPointer(pycapsule_ptr, "ngraph_function")); auto* capsule_ptr = PyCapsule_GetPointer(pybind_capsule_ptr, CAPSULE_NAME);
// std::cout << "from_capsule: " << (*fun)->get_name() << " " << (*fun)->get_friendly_name() << " " << (*fun).get() << std::endl; auto* ngraph_function = static_cast<std::shared_ptr<ngraph::Function>*>(capsule_ptr);
return *ngraph_function; return *ngraph_function;
}); });
function.def_static("to_capsule", [](std::shared_ptr<ngraph::Function>& ngraph_function) { function.def_static("to_capsule", [](std::shared_ptr<ngraph::Function>& ngraph_function) {
// std::cout << "to_capsule_1: " << ngraph_function->get_name() << " " << ngraph_function->get_friendly_name() << " " auto* sp_copy = new std::shared_ptr<ngraph::Function>(ngraph_function);
// << ngraph_function.get() << " " << ngraph_function.use_count() << std::endl; auto pybind_capsule = py::capsule(sp_copy, CAPSULE_NAME, [](PyObject* capsule) {
auto pybind_capsule = py::capsule(&ngraph_function, "ngraph_function", nullptr); auto* capsule_ptr = PyCapsule_GetPointer(capsule, CAPSULE_NAME);
auto* function_sp = static_cast<std::shared_ptr<ngraph::Function>*>(capsule_ptr);
auto* ptr = pybind_capsule.ptr(); if (function_sp)
auto* fun = reinterpret_cast<std::shared_ptr<ngraph::Function>*>(PyCapsule_GetPointer(ptr, "ngraph_function")); {
delete function_sp;
// std::cout << "to_capsule_2: " << (*fun)->get_name() << " " << (*fun)->get_friendly_name() << " " << (*fun).get() }
// << " " << fun->use_count() << std::endl; });
return pybind_capsule; return pybind_capsule;
}); });
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment