Unverified Commit b5a0d734 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Update python wrapper to new Backend API (#863)

* remove obsolete

* change to use new Backend API

* rename parameter
parent ec501913
......@@ -106,3 +106,5 @@ python/share/*
\#*
\.#*
python/pybind11/
......@@ -17,21 +17,21 @@
import onnx
onnx_protobuf = onnx.load('/path/to/model/cntk_ResNet20_CIFAR10/model.onnx')
# Convert a serialized ONNX model to an ngraph model
from ngraph_onnx.onnx_importer.importer import import_onnx_model
ng_model = import_onnx_model(onnx_protobuf)[0]
# Using an ngraph runtime (CPU backend), create a callable computation
import ngraph as ng
runtime = ng.runtime(manager_name='CPU')
runtime = ng.runtime(backend_name='CPU')
resnet = runtime.computation(ng_model['output'], *ng_model['inputs'])
# Load or create an image
import numpy as np
picture = np.ones([1, 3, 32, 32])
# Run ResNet inference on picture
resnet(picture)
......@@ -29,7 +29,7 @@ model = (A + B) * C
# >>> print(model)
# <Node: 'Multiply_6'>
runtime = ng.runtime(manager_name='INTERPRETER')
runtime = ng.runtime(backend_name='INTERPRETER')
# >>> print(runtime)
# <Runtime: Manager='INTERPRETER'>
......
......@@ -28,7 +28,4 @@ else:
sys.setdlopenflags(flags)
from _pyngraph.runtime import Backend
from _pyngraph.runtime import CallFrame
from _pyngraph.runtime import ExternalFunction
from _pyngraph.runtime import Manager
from _pyngraph.runtime import TensorView
......@@ -20,7 +20,7 @@ from typing import List
import numpy as np
from ngraph.impl import Function, Node, serialize, TensorViewType, util
from ngraph.impl.runtime import Manager
from ngraph.impl.runtime import Backend
from ngraph.impl.op import Parameter
from ngraph.utils.types import get_dtype, NumericData
......@@ -28,24 +28,23 @@ from ngraph.utils.types import get_dtype, NumericData
log = logging.getLogger(__file__)
def runtime(manager_name='CPU'): # type: (str) -> 'Runtime'
def runtime(backend_name='CPU'): # type: (str) -> 'Runtime'
"""Create a Runtime object (helper factory).
Use signature to parametrize runtime as needed.
Use signature to parameterize runtime as needed.
"""
return Runtime(manager_name)
return Runtime(backend_name)
class Runtime:
"""Represents the ngraph++ runtime environment."""
def __init__(self, manager_name): # type: (str) -> None
self.manager_name = manager_name
self.manager = Manager.get(manager_name)
self.backend = self.manager.allocate_backend()
def __init__(self, backend_name): # type: (str) -> None
self.backend_name = backend_name
self.backend = Backend.create(backend_name)
def __repr__(self): # type: () -> str
return '<Runtime: Manager=\'{}\'>'.format(self.manager_name)
return '<Runtime: Backend=\'{}\'>'.format(self.backend_name)
def computation(self, node, *inputs): # type: (Node, *Node) -> 'Computation'
"""Return a callable Computation object."""
......@@ -63,10 +62,9 @@ class Computation:
for parameter in parameters:
shape = parameter.get_shape()
element_type = parameter.get_element_type()
self.tensor_views.append(runtime.backend.make_primary_tensor_view(element_type, shape))
self.tensor_views.append(runtime.backend.create_tensor(element_type, shape))
self.function = Function(self.node, self.parameters, 'ngraph_computation')
external = self.runtime.manager.compile(self.function)
self.call_frame = self.runtime.backend.make_call_frame(external)
self.backend = runtime.backend
def __repr__(self): # type: () -> str
params_string = ', '.join([param.name for param in self.parameters])
......@@ -83,11 +81,11 @@ class Computation:
result_shape = self.node.get_shape()
result_dtype = get_dtype(result_element_type)
result_view = self.runtime.backend.make_primary_tensor_view(
result_view = self.runtime.backend.create_tensor(
result_element_type, result_shape)
result_arr = np.empty(result_shape, dtype=result_dtype)
self.call_frame.call([result_view], self.tensor_views)
self.backend.call(self.function, [result_view], self.tensor_views)
Computation._read_tensor_view_to_ndarray(result_view, result_arr)
result_arr = result_arr.reshape(result_shape)
......
......@@ -18,8 +18,6 @@
#include <pybind11/stl.h>
//#include <string>
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "pyngraph/runtime/backend.hpp"
......@@ -30,9 +28,30 @@ void regclass_pyngraph_runtime_Backend(py::module m)
py::class_<ngraph::runtime::Backend, std::shared_ptr<ngraph::runtime::Backend>> backend(
m, "Backend");
backend.doc() = "ngraph.impl.runtime.Backend wraps ngraph::runtime::Backend";
backend.def("make_call_frame", &ngraph::runtime::Backend::make_call_frame);
backend.def("make_primary_tensor_view",
backend.def_static("create", &ngraph::runtime::Backend::create);
backend.def_static("get_registered_devices", &ngraph::runtime::Backend::get_registered_devices);
backend.def_static("get_subdevices", &ngraph::runtime::Backend::get_subdevices);
backend.def("create_tensor",
(std::shared_ptr<ngraph::runtime::TensorView>(ngraph::runtime::Backend::*)(
const ngraph::element::Type&, const ngraph::Shape&)) &
ngraph::runtime::Backend::create_tensor);
backend.def("compile",
(void (ngraph::runtime::Backend::*)(std::shared_ptr<ngraph::Function>)) &
ngraph::runtime::Backend::compile);
backend.def("call",
(void (ngraph::runtime::Backend::*)(
std::shared_ptr<ngraph::Function>,
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>&,
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>&)) &
ngraph::runtime::Backend::call);
backend.def("remove_compiled_function",
(void (ngraph::runtime::Backend::*)(std::shared_ptr<ngraph::Function>)) &
ngraph::runtime::Backend::remove_compiled_function);
backend.def("enable_performance_data",
(void (ngraph::runtime::Backend::*)(std::shared_ptr<ngraph::Function>, bool)) &
ngraph::runtime::Backend::enable_performance_data);
backend.def("get_performance_data",
(std::vector<ngraph::runtime::PerformanceCounter>(ngraph::runtime::Backend::*)(
std::shared_ptr<ngraph::Function>)) &
ngraph::runtime::Backend::get_performance_data);
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
//#include <string>
#include "pyngraph/runtime/call_frame.hpp"
namespace py = pybind11;
void regclass_pyngraph_runtime_CallFrame(py::module m)
{
py::class_<ngraph::runtime::CallFrame, std::shared_ptr<ngraph::runtime::CallFrame>> callFrame(
m, "CallFrame");
callFrame.doc() = "ngraph.impl.runtime.CallFrame wraps ngraph::runtime::CallFrame";
callFrame.def("call", &ngraph::runtime::CallFrame::call);
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_runtime_CallFrame(py::module m);
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
//#include <string>
#include "ngraph/runtime/external_function.hpp"
#include "pyngraph/runtime/external_function.hpp"
namespace py = pybind11;
void regclass_pyngraph_runtime_ExternalFunction(py::module m)
{
py::class_<ngraph::runtime::ExternalFunction,
std::shared_ptr<ngraph::runtime::ExternalFunction>>
externalFunction(m, "ExternalFunction");
externalFunction.doc() =
"ngraph.impl.runtime.ExternalFunction wraps ngraph::runtime::ExternalFunction";
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_runtime_ExternalFunction(py::module m);
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
//#include <string>
#include "ngraph/function.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/manager.hpp"
#include "pyngraph/runtime/manager.hpp"
namespace py = pybind11;
void regclass_pyngraph_runtime_Manager(py::module m)
{
py::class_<ngraph::runtime::Manager, std::shared_ptr<ngraph::runtime::Manager>> manager(
m, "Manager");
manager.doc() = "ngraph.impl.runtime.Manager wraps ngraph::runtime::Manager";
manager.def_static("get", &ngraph::runtime::Manager::get);
manager.def("compile", &ngraph::runtime::Manager::compile);
manager.def("allocate_backend", &ngraph::runtime::Manager::allocate_backend);
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_runtime_Manager(py::module m);
......@@ -25,7 +25,4 @@ void regmodule_pyngraph_runtime(py::module m)
m.def_submodule("runtime", "Package ngraph.impl.runtime wraps ngraph::runtime");
regclass_pyngraph_runtime_TensorView(m_runtime);
regclass_pyngraph_runtime_Backend(m_runtime);
regclass_pyngraph_runtime_CallFrame(m_runtime);
regclass_pyngraph_runtime_ExternalFunction(m_runtime);
regclass_pyngraph_runtime_Manager(m_runtime);
}
......@@ -18,9 +18,6 @@
#include <pybind11/pybind11.h>
#include "pyngraph/runtime/backend.hpp"
#include "pyngraph/runtime/call_frame.hpp"
#include "pyngraph/runtime/external_function.hpp"
#include "pyngraph/runtime/manager.hpp"
#include "pyngraph/runtime/tensor_view.hpp"
namespace py = pybind11;
......
......@@ -204,9 +204,6 @@ sources = ['pyngraph/function.cpp',
'pyngraph/ops/batch_norm.cpp',
'pyngraph/ops/softmax.cpp',
'pyngraph/runtime/backend.cpp',
'pyngraph/runtime/call_frame.cpp',
'pyngraph/runtime/external_function.cpp',
'pyngraph/runtime/manager.cpp',
'pyngraph/runtime/regmodule_pyngraph_runtime.cpp',
'pyngraph/runtime/tensor_view.cpp',
'pyngraph/passes/manager.cpp',
......
......@@ -71,14 +71,14 @@ def test_function_call():
def test_serialization():
dtype = np.float32
manager_name = pytest.config.getoption('backend', default='CPU')
backend_name = pytest.config.getoption('backend', default='CPU')
shape = [2, 2]
parameter_a = ng.parameter(shape, dtype=dtype, name='A')
parameter_b = ng.parameter(shape, dtype=dtype, name='B')
parameter_c = ng.parameter(shape, dtype=dtype, name='C')
model = (parameter_a + parameter_b) * parameter_c
runtime = ng.runtime(manager_name=manager_name)
runtime = ng.runtime(backend_name=backend_name)
computation = runtime.computation(model, parameter_a, parameter_b, parameter_c)
serialized = computation.serialize(2)
serial_json = json.loads(serialized)
......
......@@ -27,8 +27,8 @@ def _get_numpy_dtype(scalar):
def get_runtime():
"""Return runtime object."""
manager_name = pytest.config.getoption('backend', default='CPU')
return ng.runtime(manager_name=manager_name)
backend_name = pytest.config.getoption('backend', default='CPU')
return ng.runtime(backend_name=backend_name)
def run_op_node(input_data, op_fun, *args):
......
This diff is collapsed.
......@@ -40,7 +40,7 @@ namespace ngraph
///
/// | Type | Description |
/// | ------- | --------------------------------------------------------------------------------------------------------------------------- |
/// | \f$T\f$ | The value of the parameter, supplied by the `FunctionCall` to this function or in the initial `ngraph::runtime::CallFrame`. |
/// | \f$T\f$ | The value of the parameter, supplied by the `FunctionCall` to this function. |
class Parameter : public op::Op
{
protected:
......
......@@ -25,7 +25,6 @@ namespace ngraph
{
namespace runtime
{
class CallFrame;
namespace cpu
{
class CPU_ExternalFunction;
......
......@@ -27,8 +27,6 @@ namespace ngraph
{
namespace cpu
{
class CallFrame;
namespace eigen
{
using DynamicStrides = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment