Commit 12def435 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Scott Cyphers

[ONNX] Expose onnx_import C++ interface in Python API. (#1499)

* Expose onnx_import C++ interface in Python API.

* Pretty printing.

* Add Computation interface using Ngraph Function objects plus UT.

* Apply code format.

* Remove unnecessary stream open mode.

- Code formatting.

* Fix onnx_import submodule visibility.

- Folder restructurization.

* Fix some small errors.

- Wrong function type annotations.
- Class doc.
- Code formatting.
- Class inheritance from object.

* Use modified Runtime class interface.

* Add model for test_onnx_import.

* Revert back to old API.

- Use of Function object in Computation class.

* Use of previous verions API.

* Small refactoring

* Code cleanup
parent 94c5acda
# ******************************************************************************
# Copyright 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
Package: ngraph
Low level wrappers for the nGraph c++ api in ngraph::onnx_import.
"""
# flake8: noqa
import sys
import six
# workaround to load the libngraph.so with RTLD_GLOBAL
if six.PY3:
import os
flags = os.RTLD_NOW | os.RTLD_GLOBAL
else:
import ctypes
flags = sys.getdlopenflags() | ctypes.RTLD_GLOBAL
sys.setdlopenflags(flags)
from _pyngraph.onnx_import import load_onnx_model
from _pyngraph.onnx_import import load_onnx_model_file
from _pyngraph.onnx_import import import_onnx_function
from _pyngraph.onnx_import import import_onnx_function_file
...@@ -15,13 +15,12 @@ ...@@ -15,13 +15,12 @@
# ****************************************************************************** # ******************************************************************************
"""Provide a layer of abstraction for the ngraph++ runtime environment.""" """Provide a layer of abstraction for the ngraph++ runtime environment."""
import logging import logging
from typing import List from typing import List, Union
import numpy as np import numpy as np
from ngraph.impl import Function, Node, serialize, TensorViewType, util from ngraph.impl import Function, Node, Shape, serialize, TensorViewType, util
from ngraph.impl.runtime import Backend from ngraph.impl.runtime import Backend
from ngraph.impl.op import Parameter
from ngraph.utils.types import get_dtype, NumericData from ngraph.utils.types import get_dtype, NumericData
from ngraph.exceptions import UserInputError from ngraph.exceptions import UserInputError
...@@ -47,29 +46,38 @@ class Runtime: ...@@ -47,29 +46,38 @@ class Runtime:
def __repr__(self): # type: () -> str def __repr__(self): # type: () -> str
return '<Runtime: Backend=\'{}\'>'.format(self.backend_name) return '<Runtime: Backend=\'{}\'>'.format(self.backend_name)
def computation(self, node, *inputs): # type: (Node, *Node) -> 'Computation' def computation(self, node_or_function, *inputs):
# type: (Union[Node, Function], *Node) -> 'Computation'
"""Return a callable Computation object.""" """Return a callable Computation object."""
return Computation(self, node, *inputs) if isinstance(node_or_function, Node):
ng_function = Function(node_or_function, inputs, node_or_function.name)
return Computation(self, ng_function)
class Computation: elif isinstance(node_or_function, Function):
return Computation(self, node_or_function)
else:
raise TypeError('Runtime.computation must be called with an nGraph Function object '
'or an nGraph node object an optionally Parameter node objects. '
'Called with: %s', node_or_function)
class Computation(object):
"""ngraph callable computation object.""" """ngraph callable computation object."""
def __init__(self, runtime, node, *parameters): # type: (Runtime, Node, *Parameter) -> None def __init__(self, runtime, ng_function):
# type: (Runtime, Function) -> None
self.runtime = runtime self.runtime = runtime
self.node = node self.function = ng_function
self.parameters = parameters self.parameters = ng_function.get_parameters()
self.tensor_views = [] # type: List[TensorViewType] self.tensor_views = [] # type: List[TensorViewType]
for parameter in parameters: for parameter in self.parameters:
shape = parameter.get_shape() shape = parameter.get_shape()
element_type = parameter.get_element_type() element_type = parameter.get_element_type()
self.tensor_views.append(runtime.backend.create_tensor(element_type, shape)) self.tensor_views.append(runtime.backend.create_tensor(element_type, shape))
self.function = Function(self.node, self.parameters, 'ngraph_computation')
self.backend = runtime.backend
def __repr__(self): # type: () -> str def __repr__(self): # type: () -> str
params_string = ', '.join([param.name for param in self.parameters]) params_string = ', '.join([param.name for param in self.parameters])
return '<Computation: {}({})>'.format(self.node.name, params_string) return '<Computation: {}({})>'.format(self.function.get_name(), params_string)
def __call__(self, *input_values): # type: (*NumericData) -> NumericData def __call__(self, *input_values): # type: (*NumericData) -> NumericData
"""Run computation on input values and return result.""" """Run computation on input values and return result."""
...@@ -78,15 +86,14 @@ class Computation: ...@@ -78,15 +86,14 @@ class Computation:
value = np.array(value) value = np.array(value)
Computation._write_ndarray_to_tensor_view(value, tensor_view) Computation._write_ndarray_to_tensor_view(value, tensor_view)
result_element_type = self.node.get_element_type() result_element_type = self.function.get_output_element_type(0)
result_shape = self.node.get_shape() result_shape = self.function.get_output_shape(0)
result_dtype = get_dtype(result_element_type) result_dtype = get_dtype(result_element_type)
result_view = self.runtime.backend.create_tensor( result_view = self.runtime.backend.create_tensor(result_element_type, result_shape)
result_element_type, result_shape)
result_arr = np.empty(result_shape, dtype=result_dtype) result_arr = np.empty(result_shape, dtype=result_dtype)
self.backend.call(self.function, [result_view], self.tensor_views) self.runtime.backend.call(self.function, [result_view], self.tensor_views)
Computation._read_tensor_view_to_ndarray(result_view, result_arr) Computation._read_tensor_view_to_ndarray(result_view, result_arr)
result_arr = result_arr.reshape(result_shape) result_arr = result_arr.reshape(result_shape)
......
...@@ -41,6 +41,13 @@ void regclass_pyngraph_Function(py::module m) ...@@ -41,6 +41,13 @@ void regclass_pyngraph_Function(py::module m)
function.def("get_parameters", &ngraph::Function::get_parameters); function.def("get_parameters", &ngraph::Function::get_parameters);
function.def("get_results", &ngraph::Function::get_results); function.def("get_results", &ngraph::Function::get_results);
function.def("get_result", &ngraph::Function::get_result); function.def("get_result", &ngraph::Function::get_result);
function.def("get_name", &ngraph::Function::get_name); function.def("get_unique_name", &ngraph::Function::get_name);
function.def("get_name", &ngraph::Function::get_friendly_name);
function.def("set_name", &ngraph::Function::set_name); function.def("set_name", &ngraph::Function::set_name);
function.def("__repr__", [](const ngraph::Function& self) {
std::string class_name = py::cast(self).get_type().attr("__name__").cast<std::string>();
std::string shape =
py::cast(self.get_output_shape(0)).attr("__str__")().cast<std::string>();
return "<" + class_name + ": '" + self.get_friendly_name() + "' (" + shape + ")>";
});
} }
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <istream>
#include <memory>
#include <string>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/frontend/onnx_import/onnx.hpp"
#include "ngraph/function.hpp"
#include "pyngraph/onnx_import/onnx_import.hpp"
namespace py = pybind11;
static std::vector<std::shared_ptr<ngraph::Function>>
load_onnx_model(const std::string& model_proto)
{
std::istringstream iss(model_proto, std::ios_base::binary | std::ios_base::in);
return ngraph::onnx_import::load_onnx_model(iss);
}
static std::shared_ptr<ngraph::Function> import_onnx_function(const std::string& model_proto)
{
std::istringstream iss(model_proto, std::ios_base::binary | std::ios_base::in);
return ngraph::onnx_import::import_onnx_function(iss);
}
void regmodule_pyngraph_onnx_import(py::module mod)
{
mod.def("load_onnx_model", &load_onnx_model);
mod.def("import_onnx_function", &import_onnx_function);
mod.def("load_onnx_model_file",
static_cast<std::vector<std::shared_ptr<ngraph::Function>> (*)(const std::string&)>(
&ngraph::onnx_import::load_onnx_model),
py::arg());
mod.def("import_onnx_function_file",
static_cast<std::shared_ptr<ngraph::Function> (*)(const std::string&)>(
&ngraph::onnx_import::import_onnx_function),
py::arg());
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regmodule_pyngraph_onnx_import(py::module m);
...@@ -32,4 +32,11 @@ void regclass_pyngraph_op_ParameterVector(py::module m) ...@@ -32,4 +32,11 @@ void regclass_pyngraph_op_ParameterVector(py::module m)
py::init<const std::initializer_list<std::shared_ptr<ngraph::op::Parameter>>&>()); py::init<const std::initializer_list<std::shared_ptr<ngraph::op::Parameter>>&>());
parameter_vector.def(py::init<const std::vector<std::shared_ptr<ngraph::op::Parameter>>&>()); parameter_vector.def(py::init<const std::vector<std::shared_ptr<ngraph::op::Parameter>>&>());
parameter_vector.def(py::init<const ngraph::op::ParameterVector&>()); parameter_vector.def(py::init<const ngraph::op::ParameterVector&>());
parameter_vector.def("__len__", [](const ngraph::op::ParameterVector& v) { return v.size(); });
parameter_vector.def("__getitem__",
[](const ngraph::op::ParameterVector& v, int key) { return v[key]; });
parameter_vector.def(
"__iter__",
[](ngraph::op::ParameterVector& v) { return py::make_iterator(v.begin(), v.end()); },
py::keep_alive<0, 1>()); /* Keep vector alive while iterator is used */
} }
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "pyngraph/function.hpp" #include "pyngraph/function.hpp"
#include "pyngraph/node.hpp" #include "pyngraph/node.hpp"
#include "pyngraph/node_vector.hpp" #include "pyngraph/node_vector.hpp"
#include "pyngraph/onnx_import/onnx_import.hpp"
#include "pyngraph/ops/op.hpp" #include "pyngraph/ops/op.hpp"
#include "pyngraph/ops/regmodule_pyngraph_op.hpp" #include "pyngraph/ops/regmodule_pyngraph_op.hpp"
#include "pyngraph/ops/util/regmodule_pyngraph_op_util.hpp" #include "pyngraph/ops/util/regmodule_pyngraph_op_util.hpp"
...@@ -48,6 +49,10 @@ PYBIND11_MODULE(_pyngraph, m) ...@@ -48,6 +49,10 @@ PYBIND11_MODULE(_pyngraph, m)
regmodule_pyngraph_types(m); regmodule_pyngraph_types(m);
regclass_pyngraph_Function(m); regclass_pyngraph_Function(m);
regclass_pyngraph_Serializer(m); regclass_pyngraph_Serializer(m);
py::module m_onnx_import = m.def_submodule("onnx_import",
"Package ngraph.impl.onnx_import "
"that wraps ngraph::onnx_import");
regmodule_pyngraph_onnx_import(m_onnx_import);
py::module m_op = m.def_submodule("op", "Package ngraph.impl.op that wraps ngraph::op"); py::module m_op = m.def_submodule("op", "Package ngraph.impl.op that wraps ngraph::op");
regclass_pyngraph_op_Op(m_op); regclass_pyngraph_op_Op(m_op);
regmodule_pyngraph_op_util(m_op); regmodule_pyngraph_op_util(m_op);
......
...@@ -126,6 +126,7 @@ sources = ['pyngraph/function.cpp', ...@@ -126,6 +126,7 @@ sources = ['pyngraph/function.cpp',
'pyngraph/serializer.cpp', 'pyngraph/serializer.cpp',
'pyngraph/node.cpp', 'pyngraph/node.cpp',
'pyngraph/node_vector.cpp', 'pyngraph/node_vector.cpp',
'pyngraph/onnx_import/onnx_import.cpp',
'pyngraph/shape.cpp', 'pyngraph/shape.cpp',
'pyngraph/strides.cpp', 'pyngraph/strides.cpp',
'pyngraph/coordinate_diff.cpp', 'pyngraph/coordinate_diff.cpp',
...@@ -298,11 +299,13 @@ setup( ...@@ -298,11 +299,13 @@ setup(
'ngraph.utils': PYNGRAPH_SOURCE_DIR + "/ngraph/utils", 'ngraph.utils': PYNGRAPH_SOURCE_DIR + "/ngraph/utils",
'ngraph.impl': PYNGRAPH_SOURCE_DIR + "/ngraph/impl", 'ngraph.impl': PYNGRAPH_SOURCE_DIR + "/ngraph/impl",
'ngraph.impl.op': PYNGRAPH_SOURCE_DIR + "/ngraph/impl/op", 'ngraph.impl.op': PYNGRAPH_SOURCE_DIR + "/ngraph/impl/op",
'ngraph.impl.onnx_import': PYNGRAPH_SOURCE_DIR + "/ngraph/impl/onnx_import",
'ngraph.impl.op.util': PYNGRAPH_SOURCE_DIR + "/ngraph/impl/op/util", 'ngraph.impl.op.util': PYNGRAPH_SOURCE_DIR + "/ngraph/impl/op/util",
'ngraph.impl.passes': PYNGRAPH_SOURCE_DIR + "/ngraph/impl/passes", 'ngraph.impl.passes': PYNGRAPH_SOURCE_DIR + "/ngraph/impl/passes",
'ngraph.impl.runtime': PYNGRAPH_SOURCE_DIR + "/ngraph/impl/runtime"}, 'ngraph.impl.runtime': PYNGRAPH_SOURCE_DIR + "/ngraph/impl/runtime"},
packages = ['ngraph', 'ngraph.utils', 'ngraph.impl', 'ngraph.impl.op', packages = ['ngraph', 'ngraph.utils', 'ngraph.impl', 'ngraph.impl.onnx_import',
'ngraph.impl.op.util', 'ngraph.impl.passes', 'ngraph.impl.runtime'], 'ngraph.impl.op', 'ngraph.impl.op.util', 'ngraph.impl.passes',
'ngraph.impl.runtime'],
cmdclass={'build_ext': BuildExt}, cmdclass={'build_ext': BuildExt},
data_files = data_files, data_files = data_files,
install_requires = requirements, install_requires = requirements,
......
ngraph ONNXImporter:

A
BX add_node1"Add

X
CY add_node2"Add
test_graphZ
A

Z
B

Z
C

b
Y

B
\ No newline at end of file
...@@ -43,8 +43,8 @@ def test_convolution_2d(): ...@@ -43,8 +43,8 @@ def test_convolution_2d():
[1., 0., -1.]], dtype=np.float32).reshape(1, 1, 3, 3)) [1., 0., -1.]], dtype=np.float32).reshape(1, 1, 3, 3))
# convolution with padding=1 should produce 9 x 9 output: # convolution with padding=1 should produce 9 x 9 output:
model = runtime.computation(ng.convolution(input_x, input_filter, model = runtime.computation(ng.convolution(input_x, input_filter, padding_above=[1, 1],
padding_above=[1, 1], padding_below=[1, 1])) padding_below=[1, 1]))
result = model() result = model()
assert np.allclose(result, assert np.allclose(result,
......
# ******************************************************************************
# Copyright 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import os
import numpy as np
from ngraph.impl.onnx_import import load_onnx_model_file
from test.ngraph.util import get_runtime
def test_import_onnx_function():
dtype = np.float32
cur_dir = os.path.dirname(__file__)
model_path = os.path.join(cur_dir, 'models/add_abc.onnx')
ng_function = load_onnx_model_file(model_path)[0]
value_a = np.array([1.0], dtype=dtype)
value_b = np.array([2.0], dtype=dtype)
value_c = np.array([3.0], dtype=dtype)
result = ng_function(value_a, value_b, value_c, runtime=get_runtime())
assert np.allclose(result, np.array([6], dtype=dtype))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment