Unverified Commit c5b082c6 authored by Artur Wojcik's avatar Artur Wojcik Committed by GitHub

[ONNX] return single nGraph function with multiple outputs (#2017)

* onnx: return signle nGraph function with multiple outputs
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: remove excessive code
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* Update ngraph python unit tests
parent aaf25652
......@@ -32,7 +32,5 @@ else:
flags = sys.getdlopenflags() | ctypes.RTLD_GLOBAL
sys.setdlopenflags(flags)
from _pyngraph_onnx_import import load_onnx_model
from _pyngraph_onnx_import import load_onnx_model_file
from _pyngraph_onnx_import import import_onnx_function
from _pyngraph_onnx_import import import_onnx_function_file
from _pyngraph_onnx_import import import_onnx_model
from _pyngraph_onnx_import import import_onnx_model_file
......@@ -67,6 +67,7 @@ class Computation(object):
self.runtime = runtime
self.function = ng_function
self.parameters = ng_function.get_parameters()
self.results = ng_function.get_results()
self.tensor_views = [] # type: List[Tensor]
for parameter in self.parameters:
......@@ -74,6 +75,12 @@ class Computation(object):
element_type = parameter.get_element_type()
self.tensor_views.append(runtime.backend.create_tensor(element_type, shape))
self.result_views = [] # type: List[Tensor]
for result in self.results:
shape = result.get_shape()
element_type = result.get_element_type()
self.result_views.append(runtime.backend.create_tensor(element_type, shape))
def __repr__(self): # type: () -> str
params_string = ', '.join([param.name for param in self.parameters])
return '<Computation: {}({})>'.format(self.function.get_name(), params_string)
......@@ -85,18 +92,15 @@ class Computation(object):
value = np.array(value)
Computation._write_ndarray_to_tensor_view(value, tensor_view)
result_element_type = self.function.get_output_element_type(0)
result_shape = self.function.get_output_shape(0)
result_dtype = get_dtype(result_element_type)
result_view = self.runtime.backend.create_tensor(result_element_type, result_shape)
result_arr = np.empty(result_shape, dtype=result_dtype)
self.runtime.backend.call(self.function, self.result_views, self.tensor_views)
self.runtime.backend.call(self.function, [result_view], self.tensor_views)
results = []
for result_view in self.result_views:
result = np.ndarray(result_view.shape, dtype=get_dtype(result_view.element_type))
Computation._read_tensor_view_to_ndarray(result_view, result)
results.append(result)
Computation._read_tensor_view_to_ndarray(result_view, result_arr)
result_arr = result_arr.reshape(result_shape)
return result_arr
return results
def serialize(self, indent=0): # type: (int) -> str
"""Serialize function (compute graph) to a JSON string.
......
......@@ -28,34 +28,19 @@
namespace py = pybind11;
static std::vector<std::shared_ptr<ngraph::Function>>
load_onnx_model(const std::string& model_proto)
static std::shared_ptr<ngraph::Function> import_onnx_model(const std::string& model_proto)
{
std::istringstream iss(model_proto, std::ios_base::binary | std::ios_base::in);
return ngraph::onnx_import::load_onnx_model(iss);
return ngraph::onnx_import::import_onnx_model(iss);
}
static std::shared_ptr<ngraph::Function> import_onnx_function(const std::string& model_proto)
static std::shared_ptr<ngraph::Function> import_onnx_model_file(const std::string& filename)
{
std::istringstream iss(model_proto, std::ios_base::binary | std::ios_base::in);
return ngraph::onnx_import::import_onnx_function(iss);
}
static std::vector<std::shared_ptr<ngraph::Function>>
load_onnx_model_file(const std::string& filename)
{
return ngraph::onnx_import::load_onnx_model(filename);
}
static std::shared_ptr<ngraph::Function> import_onnx_function_file(const std::string& filename)
{
return ngraph::onnx_import::import_onnx_function(filename);
return ngraph::onnx_import::import_onnx_model(filename);
}
void regmodule_pyngraph_onnx_import(py::module mod)
{
mod.def("load_onnx_model", &load_onnx_model);
mod.def("import_onnx_function", &import_onnx_function);
mod.def("load_onnx_model_file", &load_onnx_model_file);
mod.def("import_onnx_function_file", &import_onnx_function_file);
mod.def("import_onnx_model", &import_onnx_model);
mod.def("import_onnx_model_file", &import_onnx_model_file);
}
......@@ -93,4 +93,5 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Tan(m_op);
regclass_pyngraph_op_Tanh(m_op);
regclass_pyngraph_op_TopK(m_op);
regclass_pyngraph_op_Result(m_op);
}
......@@ -68,6 +68,7 @@
#include "pyngraph/ops/relu.hpp"
#include "pyngraph/ops/replace_slice.hpp"
#include "pyngraph/ops/reshape.hpp"
#include "pyngraph/ops/result.hpp"
#include "pyngraph/ops/reverse.hpp"
#include "pyngraph/ops/select.hpp"
#include "pyngraph/ops/sign.hpp"
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <string>
#include "ngraph/node.hpp"
#include "ngraph/op/result.hpp"
#include "pyngraph/ops/result.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Result(py::module m)
{
py::class_<ngraph::op::Result, std::shared_ptr<ngraph::op::Result>, ngraph::Node> result(
m, "Result");
result.doc() = "ngraph.impl.op.Result wraps ngraph::op::Result";
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_Result(py::module m);
......@@ -27,6 +27,7 @@
#include "pyngraph/ops/util/regmodule_pyngraph_op_util.hpp"
#include "pyngraph/parameter_vector.hpp"
#include "pyngraph/passes/regmodule_pyngraph_passes.hpp"
#include "pyngraph/result_vector.hpp"
#include "pyngraph/runtime/regmodule_pyngraph_runtime.hpp"
#include "pyngraph/serializer.hpp"
#include "pyngraph/shape.hpp"
......@@ -58,4 +59,5 @@ PYBIND11_MODULE(_pyngraph, m)
regmodule_pyngraph_runtime(m);
regmodule_pyngraph_passes(m);
regmodule_pyngraph_util(m);
regclass_pyngraph_ResultVector(m);
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/result.hpp" // ngraph::op::Result
#include "ngraph/result_vector.hpp"
#include "pyngraph/ops/result.hpp"
#include "pyngraph/result_vector.hpp"
namespace py = pybind11;
void regclass_pyngraph_ResultVector(py::module m)
{
py::class_<ngraph::ResultVector, std::shared_ptr<ngraph::ResultVector>> result_vector(
m, "ResultVector");
result_vector.doc() = "ngraph.impl.ResultVector wraps ngraph::ResultVector";
result_vector.def(
py::init<const std::initializer_list<std::shared_ptr<ngraph::op::Result>>&>());
result_vector.def(py::init<const std::vector<std::shared_ptr<ngraph::op::Result>>&>());
result_vector.def(py::init<const ngraph::ResultVector&>());
result_vector.def("__len__", [](const ngraph::ResultVector& v) { return v.size(); });
result_vector.def("__getitem__", [](const ngraph::ResultVector& v, int key) { return v[key]; });
result_vector.def("__iter__",
[](ngraph::ResultVector& v) { return py::make_iterator(v.begin(), v.end()); },
py::keep_alive<0, 1>()); /* Keep vector alive while iterator is used */
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_ResultVector(py::module m);
......@@ -149,6 +149,7 @@ sources = [
'pyngraph/parameter_vector.cpp',
'pyngraph/pyngraph.cpp',
'pyngraph/util.cpp',
'pyngraph/result_vector.cpp',
'pyngraph/ops/util/arithmetic_reduction.cpp',
'pyngraph/ops/util/binary_elementwise_comparison.cpp',
'pyngraph/ops/util/op_annotations.cpp',
......@@ -223,6 +224,7 @@ sources = [
'pyngraph/ops/min.cpp',
'pyngraph/ops/batch_norm.cpp',
'pyngraph/ops/softmax.cpp',
'pyngraph/ops/result.cpp',
'pyngraph/runtime/backend.cpp',
'pyngraph/runtime/regmodule_pyngraph_runtime.cpp',
'pyngraph/runtime/tensor.cpp',
......
......@@ -17,13 +17,13 @@
import os
import numpy as np
from ngraph.impl.onnx_import import load_onnx_model_file
from ngraph.impl.onnx_import import import_onnx_model_file
from test.ngraph.util import get_runtime
def test_import_onnx_function():
model_path = os.path.join(os.path.dirname(__file__), 'models/add_abc.onnx')
ng_function = load_onnx_model_file(model_path)[0]
ng_function = import_onnx_model_file(model_path)
dtype = np.float32
value_a = np.array([1.0], dtype=dtype)
......
......@@ -48,10 +48,10 @@ def test_unary_op_array(ng_api_fn, numpy_fn, range_start, range_end):
input_data = range_start + np.random.rand(2, 3, 4) * (range_end - range_start)
expected = numpy_fn(input_data)
result = run_op_node([input_data], ng_api_fn)
result = run_op_node([input_data], ng_api_fn)[0]
np.testing.assert_allclose(result, expected, rtol=0.001)
result = run_op_numeric_data(input_data, ng_api_fn)
result = run_op_numeric_data(input_data, ng_api_fn)[0]
np.testing.assert_allclose(result, expected, rtol=0.001)
......
......@@ -95,6 +95,16 @@ namespace ngraph
}
}
NodeVector Graph::get_ng_outputs() const
{
NodeVector results;
for (const auto& output : m_graph_proto->output())
{
results.emplace_back(get_ng_node_from_cache(output.name()));
}
return results;
}
} // namespace onnx_import
} // namespace ngraph
......@@ -38,6 +38,7 @@ namespace ngraph
const std::vector<Node>& get_nodes() const { return m_nodes; }
const std::vector<ValueInfo>& get_inputs() const { return m_inputs; }
const std::vector<ValueInfo>& get_outputs() const { return m_outputs; }
NodeVector get_ng_outputs() const;
const ParameterVector& get_ng_parameters() const { return m_parameters; }
std::shared_ptr<ngraph::Node> get_ng_node_from_cache(const std::string& name) const
{
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include <fstream>
#include <memory>
#include "core/graph.hpp"
#include "core/model.hpp"
......@@ -50,45 +51,32 @@ namespace ngraph
} // namespace error
} // namespace detail
std::vector<std::shared_ptr<Function>> load_onnx_model(std::istream& sin,
const Weights& weights)
std::shared_ptr<Function> import_onnx_model(std::istream& sin, const Weights& weights)
{
onnx::ModelProto model_proto;
if (!model_proto.ParseFromIstream(&sin))
{
throw detail::error::stream_parse{sin};
}
std::vector<std::shared_ptr<Function>> output_functions;
Model model{model_proto};
Graph graph{model_proto.graph(), model, weights};
for (const auto& output : graph.get_outputs())
auto function = std::make_shared<Function>(
graph.get_ng_outputs(), graph.get_ng_parameters(), graph.get_name());
for (std::size_t i{0}; i < function->get_output_size(); ++i)
{
output_functions.emplace_back(std::make_shared<Function>(
graph.get_ng_node_from_cache(output.get_name()), graph.get_ng_parameters()));
function->get_output_op(i)->set_name(graph.get_outputs().at(i).get_name());
}
return output_functions;
return function;
}
std::vector<std::shared_ptr<Function>> load_onnx_model(const std::string& path,
const Weights& weights)
std::shared_ptr<Function> import_onnx_model(const std::string& path, const Weights& weights)
{
std::ifstream ifs{path, std::ios::in | std::ios::binary};
if (!ifs.is_open())
{
throw detail::error::file_open{path};
}
return load_onnx_model(ifs, weights);
}
std::shared_ptr<Function> import_onnx_function(std::istream& sin, const Weights& weights)
{
return load_onnx_model(sin, weights).front();
}
std::shared_ptr<Function> import_onnx_function(const std::string& path,
const Weights& weights)
{
return load_onnx_model(path, weights).front();
return import_onnx_model(ifs, weights);
}
void register_operator(const std::string& name,
......
......@@ -40,31 +40,6 @@ namespace ngraph
const std::string& domain,
Operator fn);
/// \brief Convert an ONNX model to nGraph functions
/// The function translated serialized ONNX model to nGraph functions. The serialized
/// ONNX model is read from input stream.
/// \param sin input stream (e.g. file stream, memory stream, etc),
/// \param weights weights associated with the model. If weights are embedded into
/// the model this parameter shall be empty. Having weights in a model
/// and providing through this parameters is invalid (the weights from
/// the model will take precedence).
/// \return The function returns a vector of nGraph functions. The number of functions
/// depends on number of outputs from ONNX graph.
std::vector<std::shared_ptr<Function>> load_onnx_model(std::istream& sin,
const Weights& weights = {});
/// \brief Convert an ONNX model to nGraph functions
/// The function translated serialized ONNX model to nGraph functions. The ONNX model
/// is read from ONNX file.
/// \param filename file name (relative or absolute path name),
/// \param weights weights associated with the model. If weights are embedded into
/// the model this parameter shall be empty. Having weights in a model
/// and providing through this parameters is invalid (the weights from
/// the model will take precedence).
/// \return The function returns a vector of nGraph functions. The number of functions
/// depends on number of outputs from ONNX graph.
std::vector<std::shared_ptr<Function>> load_onnx_model(const std::string& filename,
const Weights& weights = {});
/// \brief Convert an ONNX model to nGraph function
/// The function translated serialized ONNX model to nGraph function. The serialized
/// ONNX model is read from input stream.
......@@ -74,8 +49,7 @@ namespace ngraph
/// and providing through this parameters is invalid (the weights from
/// the model will take precedence).
/// \return The function returns a nGraph function representing single output from graph.
std::shared_ptr<Function> import_onnx_function(std::istream& sin,
const Weights& weights = {});
std::shared_ptr<Function> import_onnx_model(std::istream& sin, const Weights& weights = {});
/// \brief Convert an ONNX model to nGraph functions
/// The function translated serialized ONNX model to nGraph functions. The ONNX model
......@@ -86,8 +60,8 @@ namespace ngraph
/// and providing through this parameters is invalid (the weights from
/// the model will take precedence).
/// \return The function returns a nGraph function representing single output from graph.
std::shared_ptr<Function> import_onnx_function(const std::string& filename,
const Weights& weights = {});
std::shared_ptr<Function> import_onnx_model(const std::string& filename,
const Weights& weights = {});
} // namespace onnx_import
......
......@@ -66,7 +66,7 @@ int main(int argc, char** argv)
ifstream f(input);
if (f)
{
shared_ptr<ngraph::Function> function = ngraph::onnx_import::import_onnx_function(input);
std::shared_ptr<ngraph::Function> function = ngraph::onnx_import::import_onnx_model(input);
ngraph::stopwatch timer;
timer.start();
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment