Unverified Commit 3dc2a915 authored by Chris Sullivan's avatar Chris Sullivan Committed by GitHub

Merge branch 'master' into master

parents a0446a2f 1b71fdca
...@@ -32,5 +32,7 @@ else: ...@@ -32,5 +32,7 @@ else:
flags = sys.getdlopenflags() | ctypes.RTLD_GLOBAL flags = sys.getdlopenflags() | ctypes.RTLD_GLOBAL
sys.setdlopenflags(flags) sys.setdlopenflags(flags)
from _pyngraph_onnx_import import import_onnx_model from _pyngraph_onnx_import import load_onnx_model
from _pyngraph_onnx_import import import_onnx_model_file from _pyngraph_onnx_import import load_onnx_model_file
from _pyngraph_onnx_import import import_onnx_function
from _pyngraph_onnx_import import import_onnx_function_file
...@@ -67,7 +67,6 @@ class Computation(object): ...@@ -67,7 +67,6 @@ class Computation(object):
self.runtime = runtime self.runtime = runtime
self.function = ng_function self.function = ng_function
self.parameters = ng_function.get_parameters() self.parameters = ng_function.get_parameters()
self.results = ng_function.get_results()
self.tensor_views = [] # type: List[Tensor] self.tensor_views = [] # type: List[Tensor]
for parameter in self.parameters: for parameter in self.parameters:
...@@ -75,12 +74,6 @@ class Computation(object): ...@@ -75,12 +74,6 @@ class Computation(object):
element_type = parameter.get_element_type() element_type = parameter.get_element_type()
self.tensor_views.append(runtime.backend.create_tensor(element_type, shape)) self.tensor_views.append(runtime.backend.create_tensor(element_type, shape))
self.result_views = [] # type: List[Tensor]
for result in self.results:
shape = result.get_shape()
element_type = result.get_element_type()
self.result_views.append(runtime.backend.create_tensor(element_type, shape))
def __repr__(self): # type: () -> str def __repr__(self): # type: () -> str
params_string = ', '.join([param.name for param in self.parameters]) params_string = ', '.join([param.name for param in self.parameters])
return '<Computation: {}({})>'.format(self.function.get_name(), params_string) return '<Computation: {}({})>'.format(self.function.get_name(), params_string)
...@@ -92,15 +85,18 @@ class Computation(object): ...@@ -92,15 +85,18 @@ class Computation(object):
value = np.array(value) value = np.array(value)
Computation._write_ndarray_to_tensor_view(value, tensor_view) Computation._write_ndarray_to_tensor_view(value, tensor_view)
self.runtime.backend.call(self.function, self.result_views, self.tensor_views) result_element_type = self.function.get_output_element_type(0)
result_shape = self.function.get_output_shape(0)
result_dtype = get_dtype(result_element_type)
result_view = self.runtime.backend.create_tensor(result_element_type, result_shape)
result_arr = np.empty(result_shape, dtype=result_dtype)
results = [] self.runtime.backend.call(self.function, [result_view], self.tensor_views)
for result_view in self.result_views:
result = np.ndarray(result_view.shape, dtype=get_dtype(result_view.element_type))
Computation._read_tensor_view_to_ndarray(result_view, result)
results.append(result)
return results Computation._read_tensor_view_to_ndarray(result_view, result_arr)
result_arr = result_arr.reshape(result_shape)
return result_arr
def serialize(self, indent=0): # type: (int) -> str def serialize(self, indent=0): # type: (int) -> str
"""Serialize function (compute graph) to a JSON string. """Serialize function (compute graph) to a JSON string.
......
...@@ -28,19 +28,34 @@ ...@@ -28,19 +28,34 @@
namespace py = pybind11; namespace py = pybind11;
static std::shared_ptr<ngraph::Function> import_onnx_model(const std::string& model_proto) static std::vector<std::shared_ptr<ngraph::Function>>
load_onnx_model(const std::string& model_proto)
{ {
std::istringstream iss(model_proto, std::ios_base::binary | std::ios_base::in); std::istringstream iss(model_proto, std::ios_base::binary | std::ios_base::in);
return ngraph::onnx_import::import_onnx_model(iss); return ngraph::onnx_import::load_onnx_model(iss);
} }
static std::shared_ptr<ngraph::Function> import_onnx_model_file(const std::string& filename) static std::shared_ptr<ngraph::Function> import_onnx_function(const std::string& model_proto)
{ {
return ngraph::onnx_import::import_onnx_model(filename); std::istringstream iss(model_proto, std::ios_base::binary | std::ios_base::in);
return ngraph::onnx_import::import_onnx_function(iss);
}
static std::vector<std::shared_ptr<ngraph::Function>>
load_onnx_model_file(const std::string& filename)
{
return ngraph::onnx_import::load_onnx_model(filename);
}
static std::shared_ptr<ngraph::Function> import_onnx_function_file(const std::string& filename)
{
return ngraph::onnx_import::import_onnx_function(filename);
} }
void regmodule_pyngraph_onnx_import(py::module mod) void regmodule_pyngraph_onnx_import(py::module mod)
{ {
mod.def("import_onnx_model", &import_onnx_model); mod.def("load_onnx_model", &load_onnx_model);
mod.def("import_onnx_model_file", &import_onnx_model_file); mod.def("import_onnx_function", &import_onnx_function);
mod.def("load_onnx_model_file", &load_onnx_model_file);
mod.def("import_onnx_function_file", &import_onnx_function_file);
} }
...@@ -93,5 +93,4 @@ void regmodule_pyngraph_op(py::module m_op) ...@@ -93,5 +93,4 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Tan(m_op); regclass_pyngraph_op_Tan(m_op);
regclass_pyngraph_op_Tanh(m_op); regclass_pyngraph_op_Tanh(m_op);
regclass_pyngraph_op_TopK(m_op); regclass_pyngraph_op_TopK(m_op);
regclass_pyngraph_op_Result(m_op);
} }
...@@ -68,7 +68,6 @@ ...@@ -68,7 +68,6 @@
#include "pyngraph/ops/relu.hpp" #include "pyngraph/ops/relu.hpp"
#include "pyngraph/ops/replace_slice.hpp" #include "pyngraph/ops/replace_slice.hpp"
#include "pyngraph/ops/reshape.hpp" #include "pyngraph/ops/reshape.hpp"
#include "pyngraph/ops/result.hpp"
#include "pyngraph/ops/reverse.hpp" #include "pyngraph/ops/reverse.hpp"
#include "pyngraph/ops/select.hpp" #include "pyngraph/ops/select.hpp"
#include "pyngraph/ops/sign.hpp" #include "pyngraph/ops/sign.hpp"
......
...@@ -27,7 +27,6 @@ ...@@ -27,7 +27,6 @@
#include "pyngraph/ops/util/regmodule_pyngraph_op_util.hpp" #include "pyngraph/ops/util/regmodule_pyngraph_op_util.hpp"
#include "pyngraph/parameter_vector.hpp" #include "pyngraph/parameter_vector.hpp"
#include "pyngraph/passes/regmodule_pyngraph_passes.hpp" #include "pyngraph/passes/regmodule_pyngraph_passes.hpp"
#include "pyngraph/result_vector.hpp"
#include "pyngraph/runtime/regmodule_pyngraph_runtime.hpp" #include "pyngraph/runtime/regmodule_pyngraph_runtime.hpp"
#include "pyngraph/serializer.hpp" #include "pyngraph/serializer.hpp"
#include "pyngraph/shape.hpp" #include "pyngraph/shape.hpp"
...@@ -59,5 +58,4 @@ PYBIND11_MODULE(_pyngraph, m) ...@@ -59,5 +58,4 @@ PYBIND11_MODULE(_pyngraph, m)
regmodule_pyngraph_runtime(m); regmodule_pyngraph_runtime(m);
regmodule_pyngraph_passes(m); regmodule_pyngraph_passes(m);
regmodule_pyngraph_util(m); regmodule_pyngraph_util(m);
regclass_pyngraph_ResultVector(m);
} }
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/result.hpp" // ngraph::op::Result
#include "ngraph/result_vector.hpp"
#include "pyngraph/ops/result.hpp"
#include "pyngraph/result_vector.hpp"
namespace py = pybind11;
void regclass_pyngraph_ResultVector(py::module m)
{
py::class_<ngraph::ResultVector, std::shared_ptr<ngraph::ResultVector>> result_vector(
m, "ResultVector");
result_vector.doc() = "ngraph.impl.ResultVector wraps ngraph::ResultVector";
result_vector.def(
py::init<const std::initializer_list<std::shared_ptr<ngraph::op::Result>>&>());
result_vector.def(py::init<const std::vector<std::shared_ptr<ngraph::op::Result>>&>());
result_vector.def(py::init<const ngraph::ResultVector&>());
result_vector.def("__len__", [](const ngraph::ResultVector& v) { return v.size(); });
result_vector.def("__getitem__", [](const ngraph::ResultVector& v, int key) { return v[key]; });
result_vector.def("__iter__",
[](ngraph::ResultVector& v) { return py::make_iterator(v.begin(), v.end()); },
py::keep_alive<0, 1>()); /* Keep vector alive while iterator is used */
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_ResultVector(py::module m);
...@@ -149,7 +149,6 @@ sources = [ ...@@ -149,7 +149,6 @@ sources = [
'pyngraph/parameter_vector.cpp', 'pyngraph/parameter_vector.cpp',
'pyngraph/pyngraph.cpp', 'pyngraph/pyngraph.cpp',
'pyngraph/util.cpp', 'pyngraph/util.cpp',
'pyngraph/result_vector.cpp',
'pyngraph/ops/util/arithmetic_reduction.cpp', 'pyngraph/ops/util/arithmetic_reduction.cpp',
'pyngraph/ops/util/binary_elementwise_comparison.cpp', 'pyngraph/ops/util/binary_elementwise_comparison.cpp',
'pyngraph/ops/util/op_annotations.cpp', 'pyngraph/ops/util/op_annotations.cpp',
...@@ -224,7 +223,6 @@ sources = [ ...@@ -224,7 +223,6 @@ sources = [
'pyngraph/ops/min.cpp', 'pyngraph/ops/min.cpp',
'pyngraph/ops/batch_norm.cpp', 'pyngraph/ops/batch_norm.cpp',
'pyngraph/ops/softmax.cpp', 'pyngraph/ops/softmax.cpp',
'pyngraph/ops/result.cpp',
'pyngraph/runtime/backend.cpp', 'pyngraph/runtime/backend.cpp',
'pyngraph/runtime/regmodule_pyngraph_runtime.cpp', 'pyngraph/runtime/regmodule_pyngraph_runtime.cpp',
'pyngraph/runtime/tensor.cpp', 'pyngraph/runtime/tensor.cpp',
......
...@@ -17,13 +17,13 @@ ...@@ -17,13 +17,13 @@
import os import os
import numpy as np import numpy as np
from ngraph.impl.onnx_import import import_onnx_model_file from ngraph.impl.onnx_import import load_onnx_model_file
from test.ngraph.util import get_runtime from test.ngraph.util import get_runtime
def test_import_onnx_function(): def test_import_onnx_function():
model_path = os.path.join(os.path.dirname(__file__), 'models/add_abc.onnx') model_path = os.path.join(os.path.dirname(__file__), 'models/add_abc.onnx')
ng_function = import_onnx_model_file(model_path) ng_function = load_onnx_model_file(model_path)[0]
dtype = np.float32 dtype = np.float32
value_a = np.array([1.0], dtype=dtype) value_a = np.array([1.0], dtype=dtype)
......
...@@ -48,10 +48,10 @@ def test_unary_op_array(ng_api_fn, numpy_fn, range_start, range_end): ...@@ -48,10 +48,10 @@ def test_unary_op_array(ng_api_fn, numpy_fn, range_start, range_end):
input_data = range_start + np.random.rand(2, 3, 4) * (range_end - range_start) input_data = range_start + np.random.rand(2, 3, 4) * (range_end - range_start)
expected = numpy_fn(input_data) expected = numpy_fn(input_data)
result = run_op_node([input_data], ng_api_fn)[0] result = run_op_node([input_data], ng_api_fn)
np.testing.assert_allclose(result, expected, rtol=0.001) np.testing.assert_allclose(result, expected, rtol=0.001)
result = run_op_numeric_data(input_data, ng_api_fn)[0] result = run_op_numeric_data(input_data, ng_api_fn)
np.testing.assert_allclose(result, expected, rtol=0.001) np.testing.assert_allclose(result, expected, rtol=0.001)
......
...@@ -102,6 +102,8 @@ add_library(onnx_import STATIC ...@@ -102,6 +102,8 @@ add_library(onnx_import STATIC
op/neg.hpp op/neg.hpp
op/not.hpp op/not.hpp
op/or.hpp op/or.hpp
op/pad.cpp
op/pad.hpp
op/pow.hpp op/pow.hpp
op/prelu.cpp op/prelu.cpp
op/prelu.hpp op/prelu.hpp
......
...@@ -95,16 +95,6 @@ namespace ngraph ...@@ -95,16 +95,6 @@ namespace ngraph
} }
} }
NodeVector Graph::get_ng_outputs() const
{
NodeVector results;
for (const auto& output : m_graph_proto->output())
{
results.emplace_back(get_ng_node_from_cache(output.name()));
}
return results;
}
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -38,7 +38,6 @@ namespace ngraph ...@@ -38,7 +38,6 @@ namespace ngraph
const std::vector<Node>& get_nodes() const { return m_nodes; } const std::vector<Node>& get_nodes() const { return m_nodes; }
const std::vector<ValueInfo>& get_inputs() const { return m_inputs; } const std::vector<ValueInfo>& get_inputs() const { return m_inputs; }
const std::vector<ValueInfo>& get_outputs() const { return m_outputs; } const std::vector<ValueInfo>& get_outputs() const { return m_outputs; }
NodeVector get_ng_outputs() const;
const ParameterVector& get_ng_parameters() const { return m_parameters; } const ParameterVector& get_ng_parameters() const { return m_parameters; }
std::shared_ptr<ngraph::Node> get_ng_node_from_cache(const std::string& name) const std::shared_ptr<ngraph::Node> get_ng_node_from_cache(const std::string& name) const
{ {
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
//***************************************************************************** //*****************************************************************************
#include <fstream> #include <fstream>
#include <memory>
#include "core/graph.hpp" #include "core/graph.hpp"
#include "core/model.hpp" #include "core/model.hpp"
...@@ -51,32 +50,45 @@ namespace ngraph ...@@ -51,32 +50,45 @@ namespace ngraph
} // namespace error } // namespace error
} // namespace detail } // namespace detail
std::shared_ptr<Function> import_onnx_model(std::istream& sin, const Weights& weights) std::vector<std::shared_ptr<Function>> load_onnx_model(std::istream& sin,
const Weights& weights)
{ {
onnx::ModelProto model_proto; onnx::ModelProto model_proto;
if (!model_proto.ParseFromIstream(&sin)) if (!model_proto.ParseFromIstream(&sin))
{ {
throw detail::error::stream_parse{sin}; throw detail::error::stream_parse{sin};
} }
std::vector<std::shared_ptr<Function>> output_functions;
Model model{model_proto}; Model model{model_proto};
Graph graph{model_proto.graph(), model, weights}; Graph graph{model_proto.graph(), model, weights};
auto function = std::make_shared<Function>( for (const auto& output : graph.get_outputs())
graph.get_ng_outputs(), graph.get_ng_parameters(), graph.get_name());
for (std::size_t i{0}; i < function->get_output_size(); ++i)
{ {
function->get_output_op(i)->set_name(graph.get_outputs().at(i).get_name()); output_functions.emplace_back(std::make_shared<Function>(
graph.get_ng_node_from_cache(output.get_name()), graph.get_ng_parameters()));
} }
return function; return output_functions;
} }
std::shared_ptr<Function> import_onnx_model(const std::string& path, const Weights& weights) std::vector<std::shared_ptr<Function>> load_onnx_model(const std::string& path,
const Weights& weights)
{ {
std::ifstream ifs{path, std::ios::in | std::ios::binary}; std::ifstream ifs{path, std::ios::in | std::ios::binary};
if (!ifs.is_open()) if (!ifs.is_open())
{ {
throw detail::error::file_open{path}; throw detail::error::file_open{path};
} }
return import_onnx_model(ifs, weights); return load_onnx_model(ifs, weights);
}
std::shared_ptr<Function> import_onnx_function(std::istream& sin, const Weights& weights)
{
return load_onnx_model(sin, weights).front();
}
std::shared_ptr<Function> import_onnx_function(const std::string& path,
const Weights& weights)
{
return load_onnx_model(path, weights).front();
} }
void register_operator(const std::string& name, void register_operator(const std::string& name,
......
...@@ -40,6 +40,31 @@ namespace ngraph ...@@ -40,6 +40,31 @@ namespace ngraph
const std::string& domain, const std::string& domain,
Operator fn); Operator fn);
/// \brief Convert an ONNX model to nGraph functions
/// The function translated serialized ONNX model to nGraph functions. The serialized
/// ONNX model is read from input stream.
/// \param sin input stream (e.g. file stream, memory stream, etc),
/// \param weights weights associated with the model. If weights are embedded into
/// the model this parameter shall be empty. Having weights in a model
/// and providing through this parameters is invalid (the weights from
/// the model will take precedence).
/// \return The function returns a vector of nGraph functions. The number of functions
/// depends on number of outputs from ONNX graph.
std::vector<std::shared_ptr<Function>> load_onnx_model(std::istream& sin,
const Weights& weights = {});
/// \brief Convert an ONNX model to nGraph functions
/// The function translated serialized ONNX model to nGraph functions. The ONNX model
/// is read from ONNX file.
/// \param filename file name (relative or absolute path name),
/// \param weights weights associated with the model. If weights are embedded into
/// the model this parameter shall be empty. Having weights in a model
/// and providing through this parameters is invalid (the weights from
/// the model will take precedence).
/// \return The function returns a vector of nGraph functions. The number of functions
/// depends on number of outputs from ONNX graph.
std::vector<std::shared_ptr<Function>> load_onnx_model(const std::string& filename,
const Weights& weights = {});
/// \brief Convert an ONNX model to nGraph function /// \brief Convert an ONNX model to nGraph function
/// The function translated serialized ONNX model to nGraph function. The serialized /// The function translated serialized ONNX model to nGraph function. The serialized
/// ONNX model is read from input stream. /// ONNX model is read from input stream.
...@@ -49,7 +74,8 @@ namespace ngraph ...@@ -49,7 +74,8 @@ namespace ngraph
/// and providing through this parameters is invalid (the weights from /// and providing through this parameters is invalid (the weights from
/// the model will take precedence). /// the model will take precedence).
/// \return The function returns a nGraph function representing single output from graph. /// \return The function returns a nGraph function representing single output from graph.
std::shared_ptr<Function> import_onnx_model(std::istream& sin, const Weights& weights = {}); std::shared_ptr<Function> import_onnx_function(std::istream& sin,
const Weights& weights = {});
/// \brief Convert an ONNX model to nGraph functions /// \brief Convert an ONNX model to nGraph functions
/// The function translated serialized ONNX model to nGraph functions. The ONNX model /// The function translated serialized ONNX model to nGraph functions. The ONNX model
...@@ -60,7 +86,7 @@ namespace ngraph ...@@ -60,7 +86,7 @@ namespace ngraph
/// and providing through this parameters is invalid (the weights from /// and providing through this parameters is invalid (the weights from
/// the model will take precedence). /// the model will take precedence).
/// \return The function returns a nGraph function representing single output from graph. /// \return The function returns a nGraph function representing single output from graph.
std::shared_ptr<Function> import_onnx_model(const std::string& filename, std::shared_ptr<Function> import_onnx_function(const std::string& filename,
const Weights& weights = {}); const Weights& weights = {});
} // namespace onnx_import } // namespace onnx_import
......
...@@ -40,13 +40,11 @@ namespace ngraph ...@@ -40,13 +40,11 @@ namespace ngraph
std::shared_ptr<ngraph::Node> var{nullptr}; std::shared_ptr<ngraph::Node> var{nullptr};
std::int64_t is_test{node.get_attribute_value<std::int64_t>("is_test", 1)}; std::int64_t is_test{node.get_attribute_value<std::int64_t>("is_test", 1)};
std::int64_t spatial{node.get_attribute_value<std::int64_t>("spatial", 1)};
double epsilon{node.get_attribute_value<double>("epsilon", 1e-5)}; double epsilon{node.get_attribute_value<double>("epsilon", 1e-5)};
// TODO: Implement learning mode support // TODO: Implement learning mode support
// float momentum{node.get_attribute_value<float>("momentum", 0.9f)}; // float momentum{node.get_attribute_value<float>("momentum", 0.9f)};
ASSERT_IS_SUPPORTED(node, is_test) << "only 'is_test' mode is supported."; ASSERT_IS_SUPPORTED(node, is_test) << "only 'is_test' mode is supported.";
ASSERT_IS_SUPPORTED(node, spatial) << "only 'spatial' mode is supported.";
if (inputs.size() >= 5) if (inputs.size() >= 5)
{ {
......
...@@ -14,19 +14,47 @@ ...@@ -14,19 +14,47 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <pybind11/pybind11.h> #include <memory>
#include <pybind11/stl.h>
#include <string>
#include "ngraph/node.hpp" #include "ngraph/coordinate_diff.hpp"
#include "ngraph/op/result.hpp" #include "ngraph/frontend/onnx_import/op/pad.hpp"
#include "pyngraph/ops/result.hpp" #include "ngraph/frontend/onnx_import/utils/convpool.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/shape.hpp"
namespace py = pybind11; namespace ngraph
void regclass_pyngraph_op_Result(py::module m)
{ {
py::class_<ngraph::op::Result, std::shared_ptr<ngraph::op::Result>, ngraph::Node> result( namespace onnx_import
m, "Result"); {
result.doc() = "ngraph.impl.op.Result wraps ngraph::op::Result"; namespace op
} {
namespace set_1
{
NodeVector pad(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
const Shape& data_shape = data->get_shape();
double value = node.get_attribute_value<double>("value", 0);
auto paddings = convpool::get_pads(node, data_shape);
ngraph::CoordinateDiff padding_below = paddings.first;
ngraph::CoordinateDiff padding_above = paddings.second;
return {std::make_shared<ngraph::op::Pad>(
data,
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{value}),
Shape(padding_below.begin(), padding_below.end()),
Shape(padding_above.begin(), padding_above.end()),
Shape(data_shape.size(), 0))};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -16,8 +16,23 @@ ...@@ -16,8 +16,23 @@
#pragma once #pragma once
#include <pybind11/pybind11.h> #include "ngraph/frontend/onnx_import/core/node.hpp"
#include "ngraph/node_vector.hpp"
namespace py = pybind11; namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector pad(const Node& node);
void regclass_pyngraph_op_Result(py::module m); } // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -66,6 +66,8 @@ ...@@ -66,6 +66,8 @@
#include "op/neg.hpp" #include "op/neg.hpp"
#include "op/not.hpp" #include "op/not.hpp"
#include "op/or.hpp" #include "op/or.hpp"
#include "op/pad.cpp"
#include "op/pad.hpp"
#include "op/pow.hpp" #include "op/pow.hpp"
#include "op/prelu.hpp" #include "op/prelu.hpp"
#include "op/reciprocal.hpp" #include "op/reciprocal.hpp"
...@@ -195,6 +197,7 @@ namespace ngraph ...@@ -195,6 +197,7 @@ namespace ngraph
REGISTER_OPERATOR("Neg", 1, neg); REGISTER_OPERATOR("Neg", 1, neg);
REGISTER_OPERATOR("Not", 1, logical_not); REGISTER_OPERATOR("Not", 1, logical_not);
REGISTER_OPERATOR("Or", 1, logical_or); REGISTER_OPERATOR("Or", 1, logical_or);
REGISTER_OPERATOR("Pad", 1, pad);
REGISTER_OPERATOR("Pow", 1, pow); REGISTER_OPERATOR("Pow", 1, pow);
REGISTER_OPERATOR("PRelu", 1, prelu); REGISTER_OPERATOR("PRelu", 1, prelu);
REGISTER_OPERATOR("Reciprocal", 1, reciprocal); REGISTER_OPERATOR("Reciprocal", 1, reciprocal);
......
...@@ -66,7 +66,7 @@ int main(int argc, char** argv) ...@@ -66,7 +66,7 @@ int main(int argc, char** argv)
ifstream f(input); ifstream f(input);
if (f) if (f)
{ {
std::shared_ptr<ngraph::Function> function = ngraph::onnx_import::import_onnx_model(input); shared_ptr<ngraph::Function> function = ngraph::onnx_import::import_onnx_function(input);
ngraph::stopwatch timer; ngraph::stopwatch timer;
timer.start(); timer.start();
......
...@@ -31,24 +31,12 @@ using namespace ngraph; ...@@ -31,24 +31,12 @@ using namespace ngraph;
using Inputs = std::vector<std::vector<float>>; using Inputs = std::vector<std::vector<float>>;
using Outputs = std::vector<std::vector<float>>; using Outputs = std::vector<std::vector<float>>;
using Model = std::vector<std::shared_ptr<Function>>;
TEST(onnx, model_output_names_check)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/split_equal_parts_default.onnx"));
std::size_t size = function->get_output_size();
for (std::size_t i{0}; i < size; ++i)
{
std::shared_ptr<Node> node = function->get_output_op(i);
EXPECT_EQ(node->get_friendly_name(), "output_" + std::to_string(i + 1));
}
}
TEST(onnx, model_add_abc) TEST(onnx, model_add_abc)
{ {
auto function = auto function = onnx_import::import_onnx_function(
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/add_abc.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/add_abc.onnx"));
Inputs inputs{{1}, {2}, {3}}; Inputs inputs{{1}, {2}, {3}};
Outputs expected_outputs{{6}}; Outputs expected_outputs{{6}};
...@@ -59,7 +47,7 @@ TEST(onnx, model_add_abc) ...@@ -59,7 +47,7 @@ TEST(onnx, model_add_abc)
TEST(onnx, model_add_abc_initializers) TEST(onnx, model_add_abc_initializers)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/add_abc_initializers.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/add_abc_initializers.onnx"));
Inputs inputs{{1, 2, 3, 4}}; Inputs inputs{{1, 2, 3, 4}};
...@@ -71,7 +59,7 @@ TEST(onnx, model_add_abc_initializers) ...@@ -71,7 +59,7 @@ TEST(onnx, model_add_abc_initializers)
TEST(onnx, model_addmul_abc) TEST(onnx, model_addmul_abc)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/addmul_abc.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/addmul_abc.onnx"));
std::vector<std::vector<float>> inputs; std::vector<std::vector<float>> inputs;
...@@ -89,7 +77,7 @@ TEST(onnx, model_addmul_abc) ...@@ -89,7 +77,7 @@ TEST(onnx, model_addmul_abc)
TEST(onnx, model_argmin_no_keepdims) TEST(onnx, model_argmin_no_keepdims)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/argmin_no_keepdims.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/argmin_no_keepdims.onnx"));
Inputs inputs{test::NDArray<float, 2>{{2, 1}, {3, 10}}.get_vector()}; Inputs inputs{test::NDArray<float, 2>{{2, 1}, {3, 10}}.get_vector()};
...@@ -101,57 +89,51 @@ TEST(onnx, model_argmin_no_keepdims) ...@@ -101,57 +89,51 @@ TEST(onnx, model_argmin_no_keepdims)
TEST(onnx, model_split_equal_parts_default) TEST(onnx, model_split_equal_parts_default)
{ {
auto function = onnx_import::import_onnx_model( Model model{onnx_import::load_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/split_equal_parts_default.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/split_equal_parts_default.onnx"))};
Inputs inputs{{1, 2, 3, 4, 5, 6}}; Inputs inputs{{1, 2, 3, 4, 5, 6}};
Outputs expected_outputs{{1, 2}, {3, 4}, {5, 6}}; Outputs expected_outputs{{1, 2}, {3, 4}, {5, 6}};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_EQ(outputs.size(), expected_outputs.size());
for (std::size_t i = 0; i < expected_outputs.size(); ++i) for (std::size_t i = 0; i < expected_outputs.size(); ++i)
{ {
EXPECT_EQ(outputs[i].size(), expected_outputs[i].size()); Outputs outputs{execute(model[i], inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(outputs[i], expected_outputs[i])); EXPECT_EQ(outputs.size(), 1);
EXPECT_TRUE(test::all_close_f(expected_outputs[i], outputs.front()));
} }
} }
TEST(onnx, model_split_equal_parts_2d) TEST(onnx, model_split_equal_parts_2d)
{ {
// Split into 2 equal parts along axis=1 // Split into 2 equal parts along axis=1
auto function = onnx_import::import_onnx_model( Model model{onnx_import::load_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/split_equal_parts_2d.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/split_equal_parts_2d.onnx"))};
Inputs inputs{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}}; Inputs inputs{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}};
Outputs expected_outputs{{0, 1, 2, 6, 7, 8}, {3, 4, 5, 9, 10, 11}}; Outputs expected_outputs{{0, 1, 2, 6, 7, 8}, {3, 4, 5, 9, 10, 11}};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_EQ(outputs.size(), expected_outputs.size());
for (std::size_t i = 0; i < expected_outputs.size(); ++i) for (std::size_t i = 0; i < expected_outputs.size(); ++i)
{ {
EXPECT_EQ(outputs[i].size(), expected_outputs[i].size()); Outputs outputs{execute(model[i], inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(outputs[i], expected_outputs[i])); EXPECT_EQ(outputs.size(), 1);
EXPECT_TRUE(test::all_close_f(expected_outputs[i], outputs.front()));
} }
} }
TEST(onnx, model_split_variable_parts_2d) TEST(onnx, model_split_variable_parts_2d)
{ {
// Split into variable parts {2, 4} along axis=1 // Split into variable parts {2, 4} along axis=1
auto function = onnx_import::import_onnx_model( Model model{onnx_import::load_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/split_variable_parts_2d.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/split_variable_parts_2d.onnx"))};
Inputs inputs{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}}; Inputs inputs{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}};
Outputs expected_outputs{{0, 1, 6, 7}, {2, 3, 4, 5, 8, 9, 10, 11}}; Outputs expected_outputs{{0, 1, 6, 7}, {2, 3, 4, 5, 8, 9, 10, 11}};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_EQ(outputs.size(), expected_outputs.size());
for (std::size_t i = 0; i < expected_outputs.size(); ++i) for (std::size_t i = 0; i < expected_outputs.size(); ++i)
{ {
EXPECT_EQ(outputs[i].size(), expected_outputs[i].size()); Outputs outputs{execute(model[i], inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(outputs[i], expected_outputs[i])); EXPECT_EQ(outputs.size(), 1);
EXPECT_TRUE(test::all_close_f(expected_outputs[i], outputs.front()));
} }
} }
...@@ -178,13 +160,12 @@ namespace ...@@ -178,13 +160,12 @@ namespace
return execute(function, args, "INTERPRETER"); return execute(function, args, "INTERPRETER");
} }
} // namespace } // namespace
TEST(onnx, model_conv2d_strides_padding) TEST(onnx, model_conv2d_strides_padding)
{ {
// Convolution with strides=2 and padding=1 // Convolution with strides=2 and padding=1
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/conv_with_strides_padding.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/conv_with_strides_padding.onnx"));
// (1, 1, 4, 3) // (1, 1, 4, 3)
...@@ -201,7 +182,7 @@ TEST(onnx, model_conv2d_strides_padding) ...@@ -201,7 +182,7 @@ TEST(onnx, model_conv2d_strides_padding)
TEST(onnx, model_conv2d_strides_no_padding) TEST(onnx, model_conv2d_strides_no_padding)
{ {
// Convolution with strides=2 and padding=1 // Convolution with strides=2 and padding=1
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/conv_with_strides_no_padding.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/conv_with_strides_no_padding.onnx"));
// (1, 1, 3, 2) // (1, 1, 3, 2)
...@@ -215,7 +196,7 @@ TEST(onnx, model_conv2d_strides_no_padding) ...@@ -215,7 +196,7 @@ TEST(onnx, model_conv2d_strides_no_padding)
TEST(onnx, model_conv2d_strides_assymetric_padding) TEST(onnx, model_conv2d_strides_assymetric_padding)
{ {
// Convolution with strides=2 and padding=1 // Convolution with strides=2 and padding=1
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/conv_with_strides_and_asymmetric_padding.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/conv_with_strides_and_asymmetric_padding.onnx"));
// (1, 1, 4, 2) // (1, 1, 4, 2)
...@@ -230,7 +211,7 @@ TEST(onnx, model_conv2d_strides_assymetric_padding) ...@@ -230,7 +211,7 @@ TEST(onnx, model_conv2d_strides_assymetric_padding)
TEST(onnx, model_average_pool_2d) TEST(onnx, model_average_pool_2d)
{ {
// Pooling with strides=2 and no padding // Pooling with strides=2 and no padding
auto function = onnx_import::import_onnx_model( auto model = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/average_pool_2d.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/average_pool_2d.onnx"));
// input data shape (1, 1, 4, 4) // input data shape (1, 1, 4, 4)
...@@ -244,7 +225,7 @@ TEST(onnx, model_average_pool_2d) ...@@ -244,7 +225,7 @@ TEST(onnx, model_average_pool_2d)
// (1, 1, 2, 2) // (1, 1, 2, 2)
auto expected_output = test::NDArray<float, 4>({{{{2.5f, 4.5f}, {10.5f, 12.5f}}}}).get_vector(); auto expected_output = test::NDArray<float, 4>({{{{2.5f, 4.5f}, {10.5f, 12.5f}}}}).get_vector();
Outputs outputs{execute(function, inputs, "INTERPRETER")}; Outputs outputs{execute(model, inputs, "INTERPRETER")};
EXPECT_EQ(expected_output, outputs.front()); EXPECT_EQ(expected_output, outputs.front());
} }
...@@ -252,7 +233,7 @@ TEST(onnx, model_average_pool_2d) ...@@ -252,7 +233,7 @@ TEST(onnx, model_average_pool_2d)
TEST(onnx, model_average_pool_2d_pads) TEST(onnx, model_average_pool_2d_pads)
{ {
// Pooling with strides=2 and padding=1 // Pooling with strides=2 and padding=1
auto function = onnx_import::import_onnx_model( auto model = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/average_pool_2d_pads.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/average_pool_2d_pads.onnx"));
// input data shape (1, 1, 4, 4) // input data shape (1, 1, 4, 4)
...@@ -268,7 +249,7 @@ TEST(onnx, model_average_pool_2d_pads) ...@@ -268,7 +249,7 @@ TEST(onnx, model_average_pool_2d_pads)
test::NDArray<float, 4>({{{{0.f, 1.5f, 3.f}, {6.f, 7.5f, 9.f}, {12.f, 13.5f, 15.f}}}}) test::NDArray<float, 4>({{{{0.f, 1.5f, 3.f}, {6.f, 7.5f, 9.f}, {12.f, 13.5f, 15.f}}}})
.get_vector(); .get_vector();
Outputs outputs = execute(function, inputs, "INTERPRETER"); Outputs outputs = execute(model, inputs, "INTERPRETER");
EXPECT_EQ(expected_output, outputs.front()); EXPECT_EQ(expected_output, outputs.front());
} }
...@@ -276,7 +257,7 @@ TEST(onnx, model_average_pool_2d_pads) ...@@ -276,7 +257,7 @@ TEST(onnx, model_average_pool_2d_pads)
TEST(onnx, model_max_pool_2d_pads) TEST(onnx, model_max_pool_2d_pads)
{ {
// Pooling with strides=2 and padding=1 // Pooling with strides=2 and padding=1
auto function = onnx_import::import_onnx_model( auto model = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/max_pool_2d_pads.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/max_pool_2d_pads.onnx"));
// input data shape (1, 1, 4, 4) // input data shape (1, 1, 4, 4)
...@@ -292,7 +273,7 @@ TEST(onnx, model_max_pool_2d_pads) ...@@ -292,7 +273,7 @@ TEST(onnx, model_max_pool_2d_pads)
test::NDArray<float, 4>({{{{0.f, 2.f, 3.f}, {8.f, 10.f, 11.f}, {12.f, 14.f, 15.f}}}}) test::NDArray<float, 4>({{{{0.f, 2.f, 3.f}, {8.f, 10.f, 11.f}, {12.f, 14.f, 15.f}}}})
.get_vector(); .get_vector();
Outputs outputs{execute(function, inputs, "INTERPRETER")}; Outputs outputs{execute(model, inputs, "INTERPRETER")};
EXPECT_EQ(expected_output, outputs.front()); EXPECT_EQ(expected_output, outputs.front());
} }
...@@ -300,8 +281,8 @@ TEST(onnx, model_max_pool_2d_pads) ...@@ -300,8 +281,8 @@ TEST(onnx, model_max_pool_2d_pads)
TEST(onnx, model_batchnorm_default) TEST(onnx, model_batchnorm_default)
{ {
// Batch Normalization with default parameters // Batch Normalization with default parameters
auto function = onnx_import::import_onnx_model( Model model{onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/batchnorm_default.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/batchnorm_default.onnx"))};
Inputs inputs; Inputs inputs;
...@@ -323,7 +304,7 @@ TEST(onnx, model_batchnorm_default) ...@@ -323,7 +304,7 @@ TEST(onnx, model_batchnorm_default)
{{{{-0.999995f, 0.f, 0.999995f}}, {{-0.22474074f, 1.f, 2.2247407f}}}}} {{{{-0.999995f, 0.f, 0.999995f}}, {{-0.22474074f, 1.f, 2.2247407f}}}}}
.get_vector()}; .get_vector()};
Outputs outputs{execute(function, inputs, "INTERPRETER")}; Outputs outputs{execute(model.front(), inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front())); EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
} }
...@@ -331,7 +312,7 @@ TEST(onnx, model_relu) ...@@ -331,7 +312,7 @@ TEST(onnx, model_relu)
{ {
// Simple ReLU test // Simple ReLU test
auto function = auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/relu.onnx")); onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/relu.onnx"));
Inputs inputs{{-1, -2, 0, 1, 2, 3}}; Inputs inputs{{-1, -2, 0, 1, 2, 3}};
Outputs expected_outputs{{0, 0, 0, 1, 2, 3}}; Outputs expected_outputs{{0, 0, 0, 1, 2, 3}};
...@@ -344,7 +325,7 @@ TEST(onnx, model_sum) ...@@ -344,7 +325,7 @@ TEST(onnx, model_sum)
{ {
// Simple Sum test // Simple Sum test
auto function = auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/sum.onnx")); onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/sum.onnx"));
// input data shape (3, ) // input data shape (3, )
Inputs inputs; Inputs inputs;
...@@ -359,7 +340,7 @@ TEST(onnx, model_sum) ...@@ -359,7 +340,7 @@ TEST(onnx, model_sum)
TEST(onnx, model_sum_one_input) TEST(onnx, model_sum_one_input)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/sum_one_input.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/sum_one_input.onnx"));
// input data shape (3, ) // input data shape (3, )
...@@ -371,7 +352,7 @@ TEST(onnx, model_sum_one_input) ...@@ -371,7 +352,7 @@ TEST(onnx, model_sum_one_input)
TEST(onnx, model_min_two_inputs) TEST(onnx, model_min_two_inputs)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/min_two_inputs.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/min_two_inputs.onnx"));
// input data shape (3, ) // input data shape (3, )
...@@ -387,7 +368,7 @@ TEST(onnx, model_min_two_inputs) ...@@ -387,7 +368,7 @@ TEST(onnx, model_min_two_inputs)
TEST(onnx, model_max) TEST(onnx, model_max)
{ {
auto function = auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/max.onnx")); onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/max.onnx"));
// input data shape (3, ) // input data shape (3, )
Inputs inputs; Inputs inputs;
...@@ -403,7 +384,7 @@ TEST(onnx, model_max) ...@@ -403,7 +384,7 @@ TEST(onnx, model_max)
TEST(onnx, model_mean) TEST(onnx, model_mean)
{ {
auto function = auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/mean.onnx")); onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/mean.onnx"));
// input data shape (3, ) // input data shape (3, )
Inputs inputs; Inputs inputs;
...@@ -418,8 +399,8 @@ TEST(onnx, model_mean) ...@@ -418,8 +399,8 @@ TEST(onnx, model_mean)
TEST(onnx, model_gemm_abc) TEST(onnx, model_gemm_abc)
{ {
auto function = auto function = onnx_import::import_onnx_function(
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/gemm_abc.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/gemm_abc.onnx"));
Inputs inputs; Inputs inputs;
inputs.emplace_back(test::NDArray<float, 2>( inputs.emplace_back(test::NDArray<float, 2>(
...@@ -449,7 +430,7 @@ TEST(onnx, model_gemm_abc) ...@@ -449,7 +430,7 @@ TEST(onnx, model_gemm_abc)
TEST(onnx, model_matmul) TEST(onnx, model_matmul)
{ {
auto function = auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/matmul.onnx")); onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/matmul.onnx"));
std::vector<std::vector<float>> inputs; std::vector<std::vector<float>> inputs;
...@@ -469,8 +450,8 @@ TEST(onnx, model_matmul) ...@@ -469,8 +450,8 @@ TEST(onnx, model_matmul)
TEST(onnx, model_softmax) TEST(onnx, model_softmax)
{ {
auto function = auto function = onnx_import::import_onnx_function(
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/softmax.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/softmax.onnx"));
Inputs inputs; Inputs inputs;
inputs.emplace_back( inputs.emplace_back(
...@@ -525,7 +506,7 @@ TEST(onnx, model_softmax) ...@@ -525,7 +506,7 @@ TEST(onnx, model_softmax)
TEST(onnx, model_concat) TEST(onnx, model_concat)
{ {
auto function = auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/concat.onnx")); onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/concat.onnx"));
Inputs inputs; Inputs inputs;
...@@ -540,8 +521,8 @@ TEST(onnx, model_concat) ...@@ -540,8 +521,8 @@ TEST(onnx, model_concat)
TEST(onnx, model_flatten) TEST(onnx, model_flatten)
{ {
auto function = auto function = onnx_import::import_onnx_function(
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/flatten.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/flatten.onnx"));
Inputs inputs; Inputs inputs;
...@@ -557,7 +538,7 @@ TEST(onnx, model_flatten) ...@@ -557,7 +538,7 @@ TEST(onnx, model_flatten)
TEST(onnx, model_sub) TEST(onnx, model_sub)
{ {
auto function = auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/sub.onnx")); onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/sub.onnx"));
Inputs inputs; Inputs inputs;
inputs.emplace_back(test::NDArray<float, 3>({{{1, 2, 3}}}).get_vector()); inputs.emplace_back(test::NDArray<float, 3>({{{1, 2, 3}}}).get_vector());
...@@ -572,8 +553,8 @@ TEST(onnx, model_sub) ...@@ -572,8 +553,8 @@ TEST(onnx, model_sub)
TEST(onnx, model_unsqueeze) TEST(onnx, model_unsqueeze)
{ {
auto function = auto function = onnx_import::import_onnx_function(
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/unsqueeze.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/unsqueeze.onnx"));
Inputs inputs; Inputs inputs;
inputs.emplace_back(test::NDArray<float, 3>( inputs.emplace_back(test::NDArray<float, 3>(
...@@ -595,7 +576,7 @@ TEST(onnx, model_unsqueeze) ...@@ -595,7 +576,7 @@ TEST(onnx, model_unsqueeze)
TEST(onnx, model_squeeze) TEST(onnx, model_squeeze)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/squeeze_duplicate_axes.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/squeeze_duplicate_axes.onnx"));
// {1, 4, 1, 1, 2} // {1, 4, 1, 1, 2}
...@@ -615,7 +596,7 @@ TEST(onnx, model_squeeze) ...@@ -615,7 +596,7 @@ TEST(onnx, model_squeeze)
TEST(onnx, model_div) TEST(onnx, model_div)
{ {
auto function = auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/div.onnx")); onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/div.onnx"));
Inputs inputs; Inputs inputs;
inputs.emplace_back(test::NDArray<float, 3>({{{1, 2, 3}}}).get_vector()); inputs.emplace_back(test::NDArray<float, 3>({{{1, 2, 3}}}).get_vector());
...@@ -630,8 +611,8 @@ TEST(onnx, model_div) ...@@ -630,8 +611,8 @@ TEST(onnx, model_div)
TEST(onnx, model_add_bcast) TEST(onnx, model_add_bcast)
{ {
auto function = auto function = onnx_import::import_onnx_function(
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/add_bcast.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/add_bcast.onnx"));
Inputs inputs; Inputs inputs;
inputs.emplace_back(test::NDArray<float, 3>( inputs.emplace_back(test::NDArray<float, 3>(
...@@ -655,7 +636,7 @@ TEST(onnx, model_add_bcast) ...@@ -655,7 +636,7 @@ TEST(onnx, model_add_bcast)
TEST(onnx, model_reshape_reduced_dims) TEST(onnx, model_reshape_reduced_dims)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_reduced_dims.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_reduced_dims.onnx"));
// input data shape (2, 3, 4) // input data shape (2, 3, 4)
...@@ -675,7 +656,7 @@ TEST(onnx, model_reshape_reduced_dims) ...@@ -675,7 +656,7 @@ TEST(onnx, model_reshape_reduced_dims)
TEST(onnx, model_reshape_reordered_dims) TEST(onnx, model_reshape_reordered_dims)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_reordered_dims.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_reordered_dims.onnx"));
// input data shape (2, 3, 4) // input data shape (2, 3, 4)
...@@ -696,7 +677,7 @@ TEST(onnx, model_reshape_reordered_dims) ...@@ -696,7 +677,7 @@ TEST(onnx, model_reshape_reordered_dims)
TEST(onnx, model_reshape_extended_dims) TEST(onnx, model_reshape_extended_dims)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_extended_dims.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_extended_dims.onnx"));
// input data shape (2, 3, 4) // input data shape (2, 3, 4)
...@@ -716,7 +697,7 @@ TEST(onnx, model_reshape_extended_dims) ...@@ -716,7 +697,7 @@ TEST(onnx, model_reshape_extended_dims)
TEST(onnx, model_reshape_single_dim) TEST(onnx, model_reshape_single_dim)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_single_dim.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_single_dim.onnx"));
// input data shape (2, 3, 4) // input data shape (2, 3, 4)
...@@ -736,7 +717,7 @@ TEST(onnx, model_reshape_single_dim) ...@@ -736,7 +717,7 @@ TEST(onnx, model_reshape_single_dim)
TEST(onnx, model_reshape_negative_dim) TEST(onnx, model_reshape_negative_dim)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_negative_dim.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_negative_dim.onnx"));
// input data shape (2, 3, 4) // input data shape (2, 3, 4)
...@@ -759,7 +740,7 @@ TEST(onnx, model_reshape_negative_dim) ...@@ -759,7 +740,7 @@ TEST(onnx, model_reshape_negative_dim)
TEST(onnx, model_reshape_negative_with_zero_dim) TEST(onnx, model_reshape_negative_with_zero_dim)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_negative_with_zero_dims.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_negative_with_zero_dims.onnx"));
// input data shape (2, 3, 4) // input data shape (2, 3, 4)
...@@ -779,7 +760,7 @@ TEST(onnx, model_reshape_negative_with_zero_dim) ...@@ -779,7 +760,7 @@ TEST(onnx, model_reshape_negative_with_zero_dim)
TEST(onnx, model_reshape_output_shape_as_input) TEST(onnx, model_reshape_output_shape_as_input)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_output_shape_as_input.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_output_shape_as_input.onnx"));
// input data shape (2, 3, 4) // input data shape (2, 3, 4)
...@@ -799,7 +780,7 @@ TEST(onnx, model_reshape_output_shape_as_input) ...@@ -799,7 +780,7 @@ TEST(onnx, model_reshape_output_shape_as_input)
TEST(onnx, model_reduce_log_sum) TEST(onnx, model_reduce_log_sum)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_log_sum.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_log_sum.onnx"));
// input data shape (1, 1, 4, 4) // input data shape (1, 1, 4, 4)
...@@ -816,7 +797,7 @@ TEST(onnx, model_reduce_log_sum) ...@@ -816,7 +797,7 @@ TEST(onnx, model_reduce_log_sum)
TEST(onnx, model_reduce_log_sum_exp) TEST(onnx, model_reduce_log_sum_exp)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_log_sum_exp.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_log_sum_exp.onnx"));
// input data shape (1, 1, 4, 4) // input data shape (1, 1, 4, 4)
...@@ -833,8 +814,8 @@ TEST(onnx, model_reduce_log_sum_exp) ...@@ -833,8 +814,8 @@ TEST(onnx, model_reduce_log_sum_exp)
TEST(onnx, model_reduce_l1) TEST(onnx, model_reduce_l1)
{ {
auto function = auto function = onnx_import::import_onnx_function(
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_l1.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_l1.onnx"));
// input data shape (1, 1, 4, 4) // input data shape (1, 1, 4, 4)
Inputs inputs{ Inputs inputs{
...@@ -850,8 +831,8 @@ TEST(onnx, model_reduce_l1) ...@@ -850,8 +831,8 @@ TEST(onnx, model_reduce_l1)
TEST(onnx, model_reduce_l2) TEST(onnx, model_reduce_l2)
{ {
auto function = auto function = onnx_import::import_onnx_function(
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_l2.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_l2.onnx"));
// input data shape (1, 1, 4, 4) // input data shape (1, 1, 4, 4)
Inputs inputs{ Inputs inputs{
...@@ -867,7 +848,7 @@ TEST(onnx, model_reduce_l2) ...@@ -867,7 +848,7 @@ TEST(onnx, model_reduce_l2)
TEST(onnx, model_reduce_max) TEST(onnx, model_reduce_max)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_max.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_max.onnx"));
// input data shape (1, 1, 4, 4) // input data shape (1, 1, 4, 4)
...@@ -884,7 +865,7 @@ TEST(onnx, model_reduce_max) ...@@ -884,7 +865,7 @@ TEST(onnx, model_reduce_max)
TEST(onnx, model_reduce_mean) TEST(onnx, model_reduce_mean)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_mean.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_mean.onnx"));
// input data shape (1, 1, 4, 4) // input data shape (1, 1, 4, 4)
...@@ -901,7 +882,7 @@ TEST(onnx, model_reduce_mean) ...@@ -901,7 +882,7 @@ TEST(onnx, model_reduce_mean)
TEST(onnx, model_reduce_min) TEST(onnx, model_reduce_min)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_min.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_min.onnx"));
// input data shape (1, 1, 4, 4) // input data shape (1, 1, 4, 4)
...@@ -918,7 +899,7 @@ TEST(onnx, model_reduce_min) ...@@ -918,7 +899,7 @@ TEST(onnx, model_reduce_min)
TEST(onnx, model_reduce_prod) TEST(onnx, model_reduce_prod)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_prod.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_prod.onnx"));
// input data shape (1, 1, 4, 4) // input data shape (1, 1, 4, 4)
...@@ -935,7 +916,7 @@ TEST(onnx, model_reduce_prod) ...@@ -935,7 +916,7 @@ TEST(onnx, model_reduce_prod)
TEST(onnx, model_reduce_sum) TEST(onnx, model_reduce_sum)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_sum.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_sum.onnx"));
// input data shape (1, 1, 4, 4) // input data shape (1, 1, 4, 4)
...@@ -952,7 +933,7 @@ TEST(onnx, model_reduce_sum) ...@@ -952,7 +933,7 @@ TEST(onnx, model_reduce_sum)
TEST(onnx, model_reduce_sum_square) TEST(onnx, model_reduce_sum_square)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_sum_square.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_sum_square.onnx"));
// input data shape (1, 1, 4, 4) // input data shape (1, 1, 4, 4)
...@@ -970,7 +951,7 @@ TEST(onnx, model_reduce_sum_square) ...@@ -970,7 +951,7 @@ TEST(onnx, model_reduce_sum_square)
TEST(onnx, model_shape) TEST(onnx, model_shape)
{ {
auto function = auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/shape.onnx")); onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/shape.onnx"));
Inputs inputs; Inputs inputs;
inputs.emplace_back(test::NDArray<float, 3>( inputs.emplace_back(test::NDArray<float, 3>(
...@@ -989,7 +970,7 @@ TEST(onnx, model_shape) ...@@ -989,7 +970,7 @@ TEST(onnx, model_shape)
TEST(onnx, model_elu) TEST(onnx, model_elu)
{ {
auto function = auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/elu.onnx")); onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/elu.onnx"));
Inputs inputs; Inputs inputs;
inputs.emplace_back( inputs.emplace_back(
...@@ -1035,7 +1016,7 @@ TEST(onnx, model_elu) ...@@ -1035,7 +1016,7 @@ TEST(onnx, model_elu)
TEST(onnx, model_leaky_relu) TEST(onnx, model_leaky_relu)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/leaky_relu.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/leaky_relu.onnx"));
Inputs inputs; Inputs inputs;
...@@ -1067,7 +1048,7 @@ TEST(onnx, model_leaky_relu) ...@@ -1067,7 +1048,7 @@ TEST(onnx, model_leaky_relu)
TEST(onnx, prelu) TEST(onnx, prelu)
{ {
auto function = auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/prelu.onnx")); onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/prelu.onnx"));
Inputs inputs; Inputs inputs;
inputs.emplace_back( inputs.emplace_back(
...@@ -1097,7 +1078,7 @@ TEST(onnx, prelu) ...@@ -1097,7 +1078,7 @@ TEST(onnx, prelu)
TEST(onnx, model_selu) TEST(onnx, model_selu)
{ {
auto function = auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/selu.onnx")); onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/selu.onnx"));
Inputs inputs; Inputs inputs;
inputs.emplace_back( inputs.emplace_back(
...@@ -1137,8 +1118,8 @@ TEST(onnx, model_selu) ...@@ -1137,8 +1118,8 @@ TEST(onnx, model_selu)
TEST(onnx, model_sigmoid) TEST(onnx, model_sigmoid)
{ {
auto function = auto function = onnx_import::import_onnx_function(
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/sigmoid.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/sigmoid.onnx"));
Inputs inputs; Inputs inputs;
inputs.emplace_back( inputs.emplace_back(
...@@ -1213,7 +1194,7 @@ TEST(onnx, model_sigmoid) ...@@ -1213,7 +1194,7 @@ TEST(onnx, model_sigmoid)
TEST(onnx, model_tanh) TEST(onnx, model_tanh)
{ {
auto function = auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/tanh.onnx")); onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/tanh.onnx"));
Inputs inputs; Inputs inputs;
inputs.emplace_back( inputs.emplace_back(
...@@ -1287,7 +1268,7 @@ TEST(onnx, model_tanh) ...@@ -1287,7 +1268,7 @@ TEST(onnx, model_tanh)
TEST(onnx, model_thresholded_relu) TEST(onnx, model_thresholded_relu)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/thresholded_relu.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/thresholded_relu.onnx"));
Inputs inputs; Inputs inputs;
...@@ -1313,7 +1294,7 @@ TEST(onnx, model_unsupported_op) ...@@ -1313,7 +1294,7 @@ TEST(onnx, model_unsupported_op)
{ {
try try
{ {
onnx_import::import_onnx_model( onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/unsupported_op.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/unsupported_op.onnx"));
FAIL() << "Expected ngraph::ngraph_error"; FAIL() << "Expected ngraph::ngraph_error";
} }
...@@ -1338,7 +1319,7 @@ TEST(onnx, model_custom_op) ...@@ -1338,7 +1319,7 @@ TEST(onnx, model_custom_op)
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))}; return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
}); });
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/custom_operator.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/custom_operator.onnx"));
Inputs inputs{{1, 2, 3, 4}}; Inputs inputs{{1, 2, 3, 4}};
...@@ -1356,7 +1337,7 @@ TEST(onnx, model_custom_op_default_domain) ...@@ -1356,7 +1337,7 @@ TEST(onnx, model_custom_op_default_domain)
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))}; return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
}); });
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/custom_operator_default_domain.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/custom_operator_default_domain.onnx"));
Inputs inputs{{1, 2, 3, 4}}; Inputs inputs{{1, 2, 3, 4}};
...@@ -1368,7 +1349,7 @@ TEST(onnx, model_custom_op_default_domain) ...@@ -1368,7 +1349,7 @@ TEST(onnx, model_custom_op_default_domain)
TEST(onnx, model_conv2d_dilation_assymetric_pads_strides) TEST(onnx, model_conv2d_dilation_assymetric_pads_strides)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/conv2d_dilation_assym_pads_strides.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/conv2d_dilation_assym_pads_strides.onnx"));
// "", // auto_pad // "", // auto_pad
...@@ -1408,7 +1389,7 @@ TEST(onnx, model_conv2d_dilation_assymetric_pads_strides) ...@@ -1408,7 +1389,7 @@ TEST(onnx, model_conv2d_dilation_assymetric_pads_strides)
TEST(onnx, model_conv3d_bias) TEST(onnx, model_conv3d_bias)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/conv3d_bias.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/conv3d_bias.onnx"));
// "", // auto_pad // "", // auto_pad
...@@ -1522,7 +1503,7 @@ TEST(onnx, model_conv3d_bias) ...@@ -1522,7 +1503,7 @@ TEST(onnx, model_conv3d_bias)
TEST(onnx, model_matmul_vec_ten3d) TEST(onnx, model_matmul_vec_ten3d)
{ {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/matmul_vec_ten3d.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/matmul_vec_ten3d.onnx"));
Inputs inputs; Inputs inputs;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment