Commit 6e221a86 authored by tomdol's avatar tomdol

Convert the ngraph::Function from and to python capsule

parent c5b976c8
......@@ -14,6 +14,8 @@
// limitations under the License.
//*****************************************************************************
// #include <Python.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
......@@ -49,4 +51,23 @@ void regclass_pyngraph_Function(py::module m)
py::cast(self.get_output_shape(0)).attr("__str__")().cast<std::string>();
return "<" + class_name + ": '" + self.get_friendly_name() + "' (" + shape + ")>";
});
function.def_static("from_capsule", [](py::object* capsule) {
auto* pycapsule_ptr = capsule->ptr();
auto* ngraph_function = reinterpret_cast<std::shared_ptr<ngraph::Function>*>(PyCapsule_GetPointer(pycapsule_ptr, "ngraph_function"));
// std::cout << "from_capsule: " << (*fun)->get_name() << " " << (*fun)->get_friendly_name() << " " << (*fun).get() << std::endl;
return *ngraph_function;
});
function.def_static("to_capsule", [](std::shared_ptr<ngraph::Function>& ngraph_function) {
// std::cout << "to_capsule_1: " << ngraph_function->get_name() << " " << ngraph_function->get_friendly_name() << " "
// << ngraph_function.get() << " " << ngraph_function.use_count() << std::endl;
auto pybind_capsule = py::capsule(&ngraph_function, "ngraph_function", nullptr);
auto* ptr = pybind_capsule.ptr();
auto* fun = reinterpret_cast<std::shared_ptr<ngraph::Function>*>(PyCapsule_GetPointer(ptr, "ngraph_function"));
// std::cout << "to_capsule_2: " << (*fun)->get_name() << " " << (*fun)->get_friendly_name() << " " << (*fun).get()
// << " " << fun->use_count() << std::endl;
return pybind_capsule;
});
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment