Unverified Commit a2593991 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Merge pull request #3276 from NervanaSystems/tomdol/pycapsule

Python capsules support for ngraph::Function
parents f6f3a032 2a3c7935
...@@ -23,6 +23,8 @@ ...@@ -23,6 +23,8 @@
namespace py = pybind11; namespace py = pybind11;
static const char* CAPSULE_NAME = "ngraph_function";
void regclass_pyngraph_Function(py::module m) void regclass_pyngraph_Function(py::module m)
{ {
py::class_<ngraph::Function, std::shared_ptr<ngraph::Function>> function(m, "Function"); py::class_<ngraph::Function, std::shared_ptr<ngraph::Function>> function(m, "Function");
...@@ -49,4 +51,41 @@ void regclass_pyngraph_Function(py::module m) ...@@ -49,4 +51,41 @@ void regclass_pyngraph_Function(py::module m)
py::cast(self.get_output_shape(0)).attr("__str__")().cast<std::string>(); py::cast(self.get_output_shape(0)).attr("__str__")().cast<std::string>();
return "<" + class_name + ": '" + self.get_friendly_name() + "' (" + shape + ")>"; return "<" + class_name + ": '" + self.get_friendly_name() + "' (" + shape + ")>";
}); });
function.def_static("from_capsule", [](py::object* capsule) {
// get the underlying PyObject* which is a PyCapsule pointer
auto* pybind_capsule_ptr = capsule->ptr();
// extract the pointer stored in the PyCapsule under the name CAPSULE_NAME
auto* capsule_ptr = PyCapsule_GetPointer(pybind_capsule_ptr, CAPSULE_NAME);
auto* ngraph_function = static_cast<std::shared_ptr<ngraph::Function>*>(capsule_ptr);
if (ngraph_function)
{
return *ngraph_function;
}
else
{
throw std::runtime_error("The provided capsule does not contain an ngraph::Function");
}
});
function.def_static("to_capsule", [](std::shared_ptr<ngraph::Function>& ngraph_function) {
// create a shared pointer on the heap before putting it in the capsule
// this secures the lifetime of the object transferred by the capsule
auto* sp_copy = new std::shared_ptr<ngraph::Function>(ngraph_function);
// a destructor callback that will delete the heap allocated shared_ptr
// when the capsule is destructed
auto sp_deleter = [](PyObject* capsule) {
auto* capsule_ptr = PyCapsule_GetPointer(capsule, CAPSULE_NAME);
auto* function_sp = static_cast<std::shared_ptr<ngraph::Function>*>(capsule_ptr);
if (function_sp)
{
delete function_sp;
}
};
// put the shared_ptr in a new capsule under the same name as in "from_capsule"
auto pybind_capsule = py::capsule(sp_copy, CAPSULE_NAME, sp_deleter);
return pybind_capsule;
});
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment