Unverified commit a0446a2f authored by Chris Sullivan, committed by GitHub

Merge branch 'master' into master

parents a1392e91 c5b082c6
......@@ -32,7 +32,5 @@ else:
flags = sys.getdlopenflags() | ctypes.RTLD_GLOBAL
sys.setdlopenflags(flags)
from _pyngraph_onnx_import import load_onnx_model
from _pyngraph_onnx_import import load_onnx_model_file
from _pyngraph_onnx_import import import_onnx_function
from _pyngraph_onnx_import import import_onnx_function_file
from _pyngraph_onnx_import import import_onnx_model
from _pyngraph_onnx_import import import_onnx_model_file
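After this change the Python ONNX importer exposes a single pair of entry points: import_onnx_model (serialized protobuf) and import_onnx_model_file (path to an .onnx file), each returning one nGraph Function. A minimal usage sketch; the model path is illustrative:

```python
# Minimal sketch of the renamed Python API; 'add_abc.onnx' is an illustrative path.
from ngraph.impl.onnx_import import import_onnx_model_file

ng_function = import_onnx_model_file('add_abc.onnx')   # a single Function
print(len(ng_function.get_results()))                  # one Result per ONNX graph output
```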
......@@ -67,6 +67,7 @@ class Computation(object):
self.runtime = runtime
self.function = ng_function
self.parameters = ng_function.get_parameters()
self.results = ng_function.get_results()
self.tensor_views = [] # type: List[Tensor]
for parameter in self.parameters:
......@@ -74,6 +75,12 @@ class Computation(object):
element_type = parameter.get_element_type()
self.tensor_views.append(runtime.backend.create_tensor(element_type, shape))
self.result_views = [] # type: List[Tensor]
for result in self.results:
shape = result.get_shape()
element_type = result.get_element_type()
self.result_views.append(runtime.backend.create_tensor(element_type, shape))
def __repr__(self): # type: () -> str
params_string = ', '.join([param.name for param in self.parameters])
return '<Computation: {}({})>'.format(self.function.get_name(), params_string)
......@@ -85,18 +92,15 @@ class Computation(object):
value = np.array(value)
Computation._write_ndarray_to_tensor_view(value, tensor_view)
result_element_type = self.function.get_output_element_type(0)
result_shape = self.function.get_output_shape(0)
result_dtype = get_dtype(result_element_type)
result_view = self.runtime.backend.create_tensor(result_element_type, result_shape)
result_arr = np.empty(result_shape, dtype=result_dtype)
self.runtime.backend.call(self.function, self.result_views, self.tensor_views)
self.runtime.backend.call(self.function, [result_view], self.tensor_views)
results = []
for result_view in self.result_views:
result = np.ndarray(result_view.shape, dtype=get_dtype(result_view.element_type))
Computation._read_tensor_view_to_ndarray(result_view, result)
results.append(result)
Computation._read_tensor_view_to_ndarray(result_view, result_arr)
result_arr = result_arr.reshape(result_shape)
return result_arr
return results
def serialize(self, indent=0): # type: (int) -> str
"""Serialize function (compute graph) to a JSON string.
......
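With a result_views buffer allocated per Function result in the constructor, Computation.__call__ now writes each graph output into its own tensor view and returns a list of NumPy arrays rather than a single array. A hedged sketch of calling code after this change, assuming Runtime.computation accepts an imported Function and the INTERPRETER backend is available:

```python
# Sketch: Computation now returns one ndarray per graph output.
import numpy as np
import ngraph as ng
from ngraph.impl.onnx_import import import_onnx_model_file

# Illustrative three-output model (an ONNX Split); path and backend are assumptions.
ng_function = import_onnx_model_file('split_equal_parts_default.onnx')
runtime = ng.runtime(backend_name='INTERPRETER')
computation = runtime.computation(ng_function)

outputs = computation(np.array([1, 2, 3, 4, 5, 6], dtype=np.float32))
assert isinstance(outputs, list)          # one entry per Function result
for out in outputs:
    print(out.shape, out.dtype)
```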
......@@ -28,34 +28,19 @@
namespace py = pybind11;
static std::vector<std::shared_ptr<ngraph::Function>>
load_onnx_model(const std::string& model_proto)
static std::shared_ptr<ngraph::Function> import_onnx_model(const std::string& model_proto)
{
std::istringstream iss(model_proto, std::ios_base::binary | std::ios_base::in);
return ngraph::onnx_import::load_onnx_model(iss);
return ngraph::onnx_import::import_onnx_model(iss);
}
static std::shared_ptr<ngraph::Function> import_onnx_function(const std::string& model_proto)
static std::shared_ptr<ngraph::Function> import_onnx_model_file(const std::string& filename)
{
std::istringstream iss(model_proto, std::ios_base::binary | std::ios_base::in);
return ngraph::onnx_import::import_onnx_function(iss);
}
static std::vector<std::shared_ptr<ngraph::Function>>
load_onnx_model_file(const std::string& filename)
{
return ngraph::onnx_import::load_onnx_model(filename);
}
static std::shared_ptr<ngraph::Function> import_onnx_function_file(const std::string& filename)
{
return ngraph::onnx_import::import_onnx_function(filename);
return ngraph::onnx_import::import_onnx_model(filename);
}
void regmodule_pyngraph_onnx_import(py::module mod)
{
mod.def("load_onnx_model", &load_onnx_model);
mod.def("import_onnx_function", &import_onnx_function);
mod.def("load_onnx_model_file", &load_onnx_model_file);
mod.def("import_onnx_function_file", &import_onnx_function_file);
mod.def("import_onnx_model", &import_onnx_model);
mod.def("import_onnx_model_file", &import_onnx_model_file);
}
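The rebuilt pybind module keeps only the two renamed bindings: import_onnx_model wraps the istream overload (its argument is the protobuf-serialized model, so Python bytes convert to the bound std::string parameter), and import_onnx_model_file wraps the path overload. A short sketch of the in-memory variant; reading the bytes from a file here is purely illustrative:

```python
# Sketch: importing from an in-memory serialized ONNX ModelProto.
from ngraph.impl.onnx_import import import_onnx_model

with open('add_abc.onnx', 'rb') as f:     # illustrative source of the bytes
    model_bytes = f.read()

ng_function = import_onnx_model(model_bytes)   # bytes -> std::string via pybind11
```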
......@@ -93,4 +93,5 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Tan(m_op);
regclass_pyngraph_op_Tanh(m_op);
regclass_pyngraph_op_TopK(m_op);
regclass_pyngraph_op_Result(m_op);
}
......@@ -68,6 +68,7 @@
#include "pyngraph/ops/relu.hpp"
#include "pyngraph/ops/replace_slice.hpp"
#include "pyngraph/ops/reshape.hpp"
#include "pyngraph/ops/result.hpp"
#include "pyngraph/ops/reverse.hpp"
#include "pyngraph/ops/select.hpp"
#include "pyngraph/ops/sign.hpp"
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <string>
#include "ngraph/node.hpp"
#include "ngraph/op/result.hpp"
#include "pyngraph/ops/result.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Result(py::module m)
{
py::class_<ngraph::op::Result, std::shared_ptr<ngraph::op::Result>, ngraph::Node> result(
m, "Result");
result.doc() = "ngraph.impl.op.Result wraps ngraph::op::Result";
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_Result(py::module m);
......@@ -27,6 +27,7 @@
#include "pyngraph/ops/util/regmodule_pyngraph_op_util.hpp"
#include "pyngraph/parameter_vector.hpp"
#include "pyngraph/passes/regmodule_pyngraph_passes.hpp"
#include "pyngraph/result_vector.hpp"
#include "pyngraph/runtime/regmodule_pyngraph_runtime.hpp"
#include "pyngraph/serializer.hpp"
#include "pyngraph/shape.hpp"
......@@ -58,4 +59,5 @@ PYBIND11_MODULE(_pyngraph, m)
regmodule_pyngraph_runtime(m);
regmodule_pyngraph_passes(m);
regmodule_pyngraph_util(m);
regclass_pyngraph_ResultVector(m);
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/result.hpp" // ngraph::op::Result
#include "ngraph/result_vector.hpp"
#include "pyngraph/ops/result.hpp"
#include "pyngraph/result_vector.hpp"
namespace py = pybind11;
void regclass_pyngraph_ResultVector(py::module m)
{
py::class_<ngraph::ResultVector, std::shared_ptr<ngraph::ResultVector>> result_vector(
m, "ResultVector");
result_vector.doc() = "ngraph.impl.ResultVector wraps ngraph::ResultVector";
result_vector.def(
py::init<const std::initializer_list<std::shared_ptr<ngraph::op::Result>>&>());
result_vector.def(py::init<const std::vector<std::shared_ptr<ngraph::op::Result>>&>());
result_vector.def(py::init<const ngraph::ResultVector&>());
result_vector.def("__len__", [](const ngraph::ResultVector& v) { return v.size(); });
result_vector.def("__getitem__", [](const ngraph::ResultVector& v, int key) { return v[key]; });
result_vector.def("__iter__",
[](ngraph::ResultVector& v) { return py::make_iterator(v.begin(), v.end()); },
py::keep_alive<0, 1>()); /* Keep vector alive while iterator is used */
}
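The ResultVector binding follows the existing ParameterVector one: it is constructible from Python sequences of Result ops and supports len(), indexing and iteration, so Function.get_results() can be inspected directly from Python. A short sketch; the model path and the ngraph.utils.types location of get_dtype are assumptions:

```python
# Sketch: iterating an imported Function's results via the new ResultVector binding.
from ngraph.impl.onnx_import import import_onnx_model_file
from ngraph.utils.types import get_dtype   # assumed helper, as used by Computation above

ng_function = import_onnx_model_file('split_equal_parts_default.onnx')
results = ng_function.get_results()
print(len(results))                        # __len__
for result in results:                     # __iter__
    print(result.get_shape(), get_dtype(result.get_element_type()))
```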
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_ResultVector(py::module m);
......@@ -149,6 +149,7 @@ sources = [
'pyngraph/parameter_vector.cpp',
'pyngraph/pyngraph.cpp',
'pyngraph/util.cpp',
'pyngraph/result_vector.cpp',
'pyngraph/ops/util/arithmetic_reduction.cpp',
'pyngraph/ops/util/binary_elementwise_comparison.cpp',
'pyngraph/ops/util/op_annotations.cpp',
......@@ -223,6 +224,7 @@ sources = [
'pyngraph/ops/min.cpp',
'pyngraph/ops/batch_norm.cpp',
'pyngraph/ops/softmax.cpp',
'pyngraph/ops/result.cpp',
'pyngraph/runtime/backend.cpp',
'pyngraph/runtime/regmodule_pyngraph_runtime.cpp',
'pyngraph/runtime/tensor.cpp',
......
......@@ -17,13 +17,13 @@
import os
import numpy as np
from ngraph.impl.onnx_import import load_onnx_model_file
from ngraph.impl.onnx_import import import_onnx_model_file
from test.ngraph.util import get_runtime
def test_import_onnx_function():
model_path = os.path.join(os.path.dirname(__file__), 'models/add_abc.onnx')
ng_function = load_onnx_model_file(model_path)[0]
ng_function = import_onnx_model_file(model_path)
dtype = np.float32
value_a = np.array([1.0], dtype=dtype)
......
......@@ -48,10 +48,10 @@ def test_unary_op_array(ng_api_fn, numpy_fn, range_start, range_end):
input_data = range_start + np.random.rand(2, 3, 4) * (range_end - range_start)
expected = numpy_fn(input_data)
result = run_op_node([input_data], ng_api_fn)
result = run_op_node([input_data], ng_api_fn)[0]
np.testing.assert_allclose(result, expected, rtol=0.001)
result = run_op_numeric_data(input_data, ng_api_fn)
result = run_op_numeric_data(input_data, ng_api_fn)[0]
np.testing.assert_allclose(result, expected, rtol=0.001)
......
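Because Computation always returns a list now, single-output helpers such as run_op_node and run_op_numeric_data index [0] to get their lone result. A sketch of the same pattern outside the test helpers; the ng.parameter signature and operator overloading are assumptions about the high-level API:

```python
# Sketch: even a single-output computation yields a one-element list.
import numpy as np
import ngraph as ng

shape = [2, 2]
A = ng.parameter(shape, dtype=np.float32, name='A')   # assumed high-level API
B = ng.parameter(shape, dtype=np.float32, name='B')
model = A + B                                         # elementwise add via operator overload

runtime = ng.runtime(backend_name='INTERPRETER')
computation = runtime.computation(model, A, B)
result = computation(np.ones(shape, np.float32), np.ones(shape, np.float32))[0]
print(result)                                         # [[2. 2.] [2. 2.]]
```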
......@@ -95,6 +95,16 @@ namespace ngraph
}
}
NodeVector Graph::get_ng_outputs() const
{
NodeVector results;
for (const auto& output : m_graph_proto->output())
{
results.emplace_back(get_ng_node_from_cache(output.name()));
}
return results;
}
} // namespace onnx_import
} // namespace ngraph
......@@ -38,6 +38,7 @@ namespace ngraph
const std::vector<Node>& get_nodes() const { return m_nodes; }
const std::vector<ValueInfo>& get_inputs() const { return m_inputs; }
const std::vector<ValueInfo>& get_outputs() const { return m_outputs; }
NodeVector get_ng_outputs() const;
const ParameterVector& get_ng_parameters() const { return m_parameters; }
std::shared_ptr<ngraph::Node> get_ng_node_from_cache(const std::string& name) const
{
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include <fstream>
#include <memory>
#include "core/graph.hpp"
#include "core/model.hpp"
......@@ -50,45 +51,32 @@ namespace ngraph
} // namespace error
} // namespace detail
std::vector<std::shared_ptr<Function>> load_onnx_model(std::istream& sin,
const Weights& weights)
std::shared_ptr<Function> import_onnx_model(std::istream& sin, const Weights& weights)
{
onnx::ModelProto model_proto;
if (!model_proto.ParseFromIstream(&sin))
{
throw detail::error::stream_parse{sin};
}
std::vector<std::shared_ptr<Function>> output_functions;
Model model{model_proto};
Graph graph{model_proto.graph(), model, weights};
for (const auto& output : graph.get_outputs())
auto function = std::make_shared<Function>(
graph.get_ng_outputs(), graph.get_ng_parameters(), graph.get_name());
for (std::size_t i{0}; i < function->get_output_size(); ++i)
{
output_functions.emplace_back(std::make_shared<Function>(
graph.get_ng_node_from_cache(output.get_name()), graph.get_ng_parameters()));
function->get_output_op(i)->set_name(graph.get_outputs().at(i).get_name());
}
return output_functions;
return function;
}
std::vector<std::shared_ptr<Function>> load_onnx_model(const std::string& path,
const Weights& weights)
std::shared_ptr<Function> import_onnx_model(const std::string& path, const Weights& weights)
{
std::ifstream ifs{path, std::ios::in | std::ios::binary};
if (!ifs.is_open())
{
throw detail::error::file_open{path};
}
return load_onnx_model(ifs, weights);
}
std::shared_ptr<Function> import_onnx_function(std::istream& sin, const Weights& weights)
{
return load_onnx_model(sin, weights).front();
}
std::shared_ptr<Function> import_onnx_function(const std::string& path,
const Weights& weights)
{
return load_onnx_model(path, weights).front();
return import_onnx_model(ifs, weights);
}
void register_operator(const std::string& name,
......
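On the C++ side the importer now builds a single Function whose Result ops cover every ONNX graph output, and each output op is renamed after the corresponding ONNX output (exercised by the model_output_names_check test further down). A hedged Python-level sketch of the effect; it assumes the name property is exposed on these nodes as it is for Parameter above:

```python
# Sketch: one Function per model, output ops named after the ONNX graph outputs.
from ngraph.impl.onnx_import import import_onnx_model_file

ng_function = import_onnx_model_file('split_equal_parts_default.onnx')  # illustrative model
for result in ng_function.get_results():
    print(result.name)   # output names follow the ONNX graph ('output_1', ... per the test below)
```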
......@@ -40,31 +40,6 @@ namespace ngraph
const std::string& domain,
Operator fn);
/// \brief Convert an ONNX model to nGraph functions
/// The function translated serialized ONNX model to nGraph functions. The serialized
/// ONNX model is read from input stream.
/// \param sin input stream (e.g. file stream, memory stream, etc),
/// \param weights weights associated with the model. If weights are embedded into
/// the model this parameter shall be empty. Having weights in a model
/// and providing through this parameters is invalid (the weights from
/// the model will take precedence).
/// \return The function returns a vector of nGraph functions. The number of functions
/// depends on number of outputs from ONNX graph.
std::vector<std::shared_ptr<Function>> load_onnx_model(std::istream& sin,
const Weights& weights = {});
/// \brief Convert an ONNX model to nGraph functions
/// The function translated serialized ONNX model to nGraph functions. The ONNX model
/// is read from ONNX file.
/// \param filename file name (relative or absolute path name),
/// \param weights weights associated with the model. If weights are embedded into
/// the model this parameter shall be empty. Having weights in a model
/// and providing through this parameters is invalid (the weights from
/// the model will take precedence).
/// \return The function returns a vector of nGraph functions. The number of functions
/// depends on number of outputs from ONNX graph.
std::vector<std::shared_ptr<Function>> load_onnx_model(const std::string& filename,
const Weights& weights = {});
/// \brief Convert an ONNX model to nGraph function
/// The function translated serialized ONNX model to nGraph function. The serialized
/// ONNX model is read from input stream.
......@@ -74,8 +49,7 @@ namespace ngraph
/// and providing them through this parameter is invalid (the weights from
/// the model will take precedence).
/// \return The function returns an nGraph function representing a single output from the graph.
std::shared_ptr<Function> import_onnx_function(std::istream& sin,
const Weights& weights = {});
std::shared_ptr<Function> import_onnx_model(std::istream& sin, const Weights& weights = {});
/// \brief Convert an ONNX model to nGraph functions
/// The function translated serialized ONNX model to nGraph functions. The ONNX model
......@@ -86,8 +60,8 @@ namespace ngraph
/// and providing them through this parameter is invalid (the weights from
/// the model will take precedence).
/// \return The function returns an nGraph function representing a single output from the graph.
std::shared_ptr<Function> import_onnx_function(const std::string& filename,
const Weights& weights = {});
std::shared_ptr<Function> import_onnx_model(const std::string& filename,
const Weights& weights = {});
} // namespace onnx_import
......
......@@ -66,7 +66,7 @@ int main(int argc, char** argv)
ifstream f(input);
if (f)
{
shared_ptr<ngraph::Function> function = ngraph::onnx_import::import_onnx_function(input);
std::shared_ptr<ngraph::Function> function = ngraph::onnx_import::import_onnx_model(input);
ngraph::stopwatch timer;
timer.start();
......
......@@ -31,12 +31,24 @@ using namespace ngraph;
using Inputs = std::vector<std::vector<float>>;
using Outputs = std::vector<std::vector<float>>;
using Model = std::vector<std::shared_ptr<Function>>;
TEST(onnx, model_output_names_check)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/split_equal_parts_default.onnx"));
std::size_t size = function->get_output_size();
for (std::size_t i{0}; i < size; ++i)
{
std::shared_ptr<Node> node = function->get_output_op(i);
EXPECT_EQ(node->get_friendly_name(), "output_" + std::to_string(i + 1));
}
}
TEST(onnx, model_add_abc)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/add_abc.onnx"));
auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/add_abc.onnx"));
Inputs inputs{{1}, {2}, {3}};
Outputs expected_outputs{{6}};
......@@ -47,7 +59,7 @@ TEST(onnx, model_add_abc)
TEST(onnx, model_add_abc_initializers)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/add_abc_initializers.onnx"));
Inputs inputs{{1, 2, 3, 4}};
......@@ -59,7 +71,7 @@ TEST(onnx, model_add_abc_initializers)
TEST(onnx, model_addmul_abc)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/addmul_abc.onnx"));
std::vector<std::vector<float>> inputs;
......@@ -77,7 +89,7 @@ TEST(onnx, model_addmul_abc)
TEST(onnx, model_argmin_no_keepdims)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/argmin_no_keepdims.onnx"));
Inputs inputs{test::NDArray<float, 2>{{2, 1}, {3, 10}}.get_vector()};
......@@ -89,51 +101,57 @@ TEST(onnx, model_argmin_no_keepdims)
TEST(onnx, model_split_equal_parts_default)
{
Model model{onnx_import::load_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/split_equal_parts_default.onnx"))};
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/split_equal_parts_default.onnx"));
Inputs inputs{{1, 2, 3, 4, 5, 6}};
Outputs expected_outputs{{1, 2}, {3, 4}, {5, 6}};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_EQ(outputs.size(), expected_outputs.size());
for (std::size_t i = 0; i < expected_outputs.size(); ++i)
{
Outputs outputs{execute(model[i], inputs, "INTERPRETER")};
EXPECT_EQ(outputs.size(), 1);
EXPECT_TRUE(test::all_close_f(expected_outputs[i], outputs.front()));
EXPECT_EQ(outputs[i].size(), expected_outputs[i].size());
EXPECT_TRUE(test::all_close_f(outputs[i], expected_outputs[i]));
}
}
TEST(onnx, model_split_equal_parts_2d)
{
// Split into 2 equal parts along axis=1
Model model{onnx_import::load_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/split_equal_parts_2d.onnx"))};
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/split_equal_parts_2d.onnx"));
Inputs inputs{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}};
Outputs expected_outputs{{0, 1, 2, 6, 7, 8}, {3, 4, 5, 9, 10, 11}};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_EQ(outputs.size(), expected_outputs.size());
for (std::size_t i = 0; i < expected_outputs.size(); ++i)
{
Outputs outputs{execute(model[i], inputs, "INTERPRETER")};
EXPECT_EQ(outputs.size(), 1);
EXPECT_TRUE(test::all_close_f(expected_outputs[i], outputs.front()));
EXPECT_EQ(outputs[i].size(), expected_outputs[i].size());
EXPECT_TRUE(test::all_close_f(outputs[i], expected_outputs[i]));
}
}
TEST(onnx, model_split_variable_parts_2d)
{
// Split into variable parts {2, 4} along axis=1
Model model{onnx_import::load_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/split_variable_parts_2d.onnx"))};
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/split_variable_parts_2d.onnx"));
Inputs inputs{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}};
Outputs expected_outputs{{0, 1, 6, 7}, {2, 3, 4, 5, 8, 9, 10, 11}};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_EQ(outputs.size(), expected_outputs.size());
for (std::size_t i = 0; i < expected_outputs.size(); ++i)
{
Outputs outputs{execute(model[i], inputs, "INTERPRETER")};
EXPECT_EQ(outputs.size(), 1);
EXPECT_TRUE(test::all_close_f(expected_outputs[i], outputs.front()));
EXPECT_EQ(outputs[i].size(), expected_outputs[i].size());
EXPECT_TRUE(test::all_close_f(outputs[i], expected_outputs[i]));
}
}
......@@ -160,12 +178,13 @@ namespace
return execute(function, args, "INTERPRETER");
}
} // namespace
TEST(onnx, model_conv2d_strides_padding)
{
// Convolution with strides=2 and padding=1
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/conv_with_strides_padding.onnx"));
// (1, 1, 4, 3)
......@@ -182,7 +201,7 @@ TEST(onnx, model_conv2d_strides_padding)
TEST(onnx, model_conv2d_strides_no_padding)
{
// Convolution with strides=2 and no padding
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/conv_with_strides_no_padding.onnx"));
// (1, 1, 3, 2)
......@@ -196,7 +215,7 @@ TEST(onnx, model_conv2d_strides_no_padding)
TEST(onnx, model_conv2d_strides_assymetric_padding)
{
// Convolution with strides=2 and asymmetric padding
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/conv_with_strides_and_asymmetric_padding.onnx"));
// (1, 1, 4, 2)
......@@ -211,7 +230,7 @@ TEST(onnx, model_conv2d_strides_assymetric_padding)
TEST(onnx, model_average_pool_2d)
{
// Pooling with strides=2 and no padding
auto model = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/average_pool_2d.onnx"));
// input data shape (1, 1, 4, 4)
......@@ -225,7 +244,7 @@ TEST(onnx, model_average_pool_2d)
// (1, 1, 2, 2)
auto expected_output = test::NDArray<float, 4>({{{{2.5f, 4.5f}, {10.5f, 12.5f}}}}).get_vector();
Outputs outputs{execute(model, inputs, "INTERPRETER")};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_EQ(expected_output, outputs.front());
}
......@@ -233,7 +252,7 @@ TEST(onnx, model_average_pool_2d)
TEST(onnx, model_average_pool_2d_pads)
{
// Pooling with strides=2 and padding=1
auto model = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/average_pool_2d_pads.onnx"));
// input data shape (1, 1, 4, 4)
......@@ -249,7 +268,7 @@ TEST(onnx, model_average_pool_2d_pads)
test::NDArray<float, 4>({{{{0.f, 1.5f, 3.f}, {6.f, 7.5f, 9.f}, {12.f, 13.5f, 15.f}}}})
.get_vector();
Outputs outputs = execute(model, inputs, "INTERPRETER");
Outputs outputs = execute(function, inputs, "INTERPRETER");
EXPECT_EQ(expected_output, outputs.front());
}
......@@ -257,7 +276,7 @@ TEST(onnx, model_average_pool_2d_pads)
TEST(onnx, model_max_pool_2d_pads)
{
// Pooling with strides=2 and padding=1
auto model = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/max_pool_2d_pads.onnx"));
// input data shape (1, 1, 4, 4)
......@@ -273,7 +292,7 @@ TEST(onnx, model_max_pool_2d_pads)
test::NDArray<float, 4>({{{{0.f, 2.f, 3.f}, {8.f, 10.f, 11.f}, {12.f, 14.f, 15.f}}}})
.get_vector();
Outputs outputs{execute(model, inputs, "INTERPRETER")};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_EQ(expected_output, outputs.front());
}
......@@ -281,8 +300,8 @@ TEST(onnx, model_max_pool_2d_pads)
TEST(onnx, model_batchnorm_default)
{
// Batch Normalization with default parameters
Model model{onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/batchnorm_default.onnx"))};
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/batchnorm_default.onnx"));
Inputs inputs;
......@@ -304,7 +323,7 @@ TEST(onnx, model_batchnorm_default)
{{{{-0.999995f, 0.f, 0.999995f}}, {{-0.22474074f, 1.f, 2.2247407f}}}}}
.get_vector()};
Outputs outputs{execute(model.front(), inputs, "INTERPRETER")};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
......@@ -312,7 +331,7 @@ TEST(onnx, model_relu)
{
// Simple ReLU test
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/relu.onnx"));
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/relu.onnx"));
Inputs inputs{{-1, -2, 0, 1, 2, 3}};
Outputs expected_outputs{{0, 0, 0, 1, 2, 3}};
......@@ -325,7 +344,7 @@ TEST(onnx, model_sum)
{
// Simple Sum test
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/sum.onnx"));
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/sum.onnx"));
// input data shape (3, )
Inputs inputs;
......@@ -340,7 +359,7 @@ TEST(onnx, model_sum)
TEST(onnx, model_sum_one_input)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/sum_one_input.onnx"));
// input data shape (3, )
......@@ -352,7 +371,7 @@ TEST(onnx, model_sum_one_input)
TEST(onnx, model_min_two_inputs)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/min_two_inputs.onnx"));
// input data shape (3, )
......@@ -368,7 +387,7 @@ TEST(onnx, model_min_two_inputs)
TEST(onnx, model_max)
{
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/max.onnx"));
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/max.onnx"));
// input data shape (3, )
Inputs inputs;
......@@ -384,7 +403,7 @@ TEST(onnx, model_max)
TEST(onnx, model_mean)
{
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/mean.onnx"));
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/mean.onnx"));
// input data shape (3, )
Inputs inputs;
......@@ -399,8 +418,8 @@ TEST(onnx, model_mean)
TEST(onnx, model_gemm_abc)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/gemm_abc.onnx"));
auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/gemm_abc.onnx"));
Inputs inputs;
inputs.emplace_back(test::NDArray<float, 2>(
......@@ -430,7 +449,7 @@ TEST(onnx, model_gemm_abc)
TEST(onnx, model_matmul)
{
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/matmul.onnx"));
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/matmul.onnx"));
std::vector<std::vector<float>> inputs;
......@@ -450,8 +469,8 @@ TEST(onnx, model_matmul)
TEST(onnx, model_softmax)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/softmax.onnx"));
auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/softmax.onnx"));
Inputs inputs;
inputs.emplace_back(
......@@ -506,7 +525,7 @@ TEST(onnx, model_softmax)
TEST(onnx, model_concat)
{
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/concat.onnx"));
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/concat.onnx"));
Inputs inputs;
......@@ -521,8 +540,8 @@ TEST(onnx, model_concat)
TEST(onnx, model_flatten)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/flatten.onnx"));
auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/flatten.onnx"));
Inputs inputs;
......@@ -538,7 +557,7 @@ TEST(onnx, model_flatten)
TEST(onnx, model_sub)
{
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/sub.onnx"));
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/sub.onnx"));
Inputs inputs;
inputs.emplace_back(test::NDArray<float, 3>({{{1, 2, 3}}}).get_vector());
......@@ -553,8 +572,8 @@ TEST(onnx, model_sub)
TEST(onnx, model_unsqueeze)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/unsqueeze.onnx"));
auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/unsqueeze.onnx"));
Inputs inputs;
inputs.emplace_back(test::NDArray<float, 3>(
......@@ -576,7 +595,7 @@ TEST(onnx, model_unsqueeze)
TEST(onnx, model_squeeze)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/squeeze_duplicate_axes.onnx"));
// {1, 4, 1, 1, 2}
......@@ -596,7 +615,7 @@ TEST(onnx, model_squeeze)
TEST(onnx, model_div)
{
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/div.onnx"));
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/div.onnx"));
Inputs inputs;
inputs.emplace_back(test::NDArray<float, 3>({{{1, 2, 3}}}).get_vector());
......@@ -611,8 +630,8 @@ TEST(onnx, model_div)
TEST(onnx, model_add_bcast)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/add_bcast.onnx"));
auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/add_bcast.onnx"));
Inputs inputs;
inputs.emplace_back(test::NDArray<float, 3>(
......@@ -636,7 +655,7 @@ TEST(onnx, model_add_bcast)
TEST(onnx, model_reshape_reduced_dims)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_reduced_dims.onnx"));
// input data shape (2, 3, 4)
......@@ -656,7 +675,7 @@ TEST(onnx, model_reshape_reduced_dims)
TEST(onnx, model_reshape_reordered_dims)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_reordered_dims.onnx"));
// input data shape (2, 3, 4)
......@@ -677,7 +696,7 @@ TEST(onnx, model_reshape_reordered_dims)
TEST(onnx, model_reshape_extended_dims)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_extended_dims.onnx"));
// input data shape (2, 3, 4)
......@@ -697,7 +716,7 @@ TEST(onnx, model_reshape_extended_dims)
TEST(onnx, model_reshape_single_dim)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_single_dim.onnx"));
// input data shape (2, 3, 4)
......@@ -717,7 +736,7 @@ TEST(onnx, model_reshape_single_dim)
TEST(onnx, model_reshape_negative_dim)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_negative_dim.onnx"));
// input data shape (2, 3, 4)
......@@ -740,7 +759,7 @@ TEST(onnx, model_reshape_negative_dim)
TEST(onnx, model_reshape_negative_with_zero_dim)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_negative_with_zero_dims.onnx"));
// input data shape (2, 3, 4)
......@@ -760,7 +779,7 @@ TEST(onnx, model_reshape_negative_with_zero_dim)
TEST(onnx, model_reshape_output_shape_as_input)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_output_shape_as_input.onnx"));
// input data shape (2, 3, 4)
......@@ -780,7 +799,7 @@ TEST(onnx, model_reshape_output_shape_as_input)
TEST(onnx, model_reduce_log_sum)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_log_sum.onnx"));
// input data shape (1, 1, 4, 4)
......@@ -797,7 +816,7 @@ TEST(onnx, model_reduce_log_sum)
TEST(onnx, model_reduce_log_sum_exp)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_log_sum_exp.onnx"));
// input data shape (1, 1, 4, 4)
......@@ -814,8 +833,8 @@ TEST(onnx, model_reduce_log_sum_exp)
TEST(onnx, model_reduce_l1)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_l1.onnx"));
auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_l1.onnx"));
// input data shape (1, 1, 4, 4)
Inputs inputs{
......@@ -831,8 +850,8 @@ TEST(onnx, model_reduce_l1)
TEST(onnx, model_reduce_l2)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_l2.onnx"));
auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_l2.onnx"));
// input data shape (1, 1, 4, 4)
Inputs inputs{
......@@ -848,7 +867,7 @@ TEST(onnx, model_reduce_l2)
TEST(onnx, model_reduce_max)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_max.onnx"));
// input data shape (1, 1, 4, 4)
......@@ -865,7 +884,7 @@ TEST(onnx, model_reduce_max)
TEST(onnx, model_reduce_mean)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_mean.onnx"));
// input data shape (1, 1, 4, 4)
......@@ -882,7 +901,7 @@ TEST(onnx, model_reduce_mean)
TEST(onnx, model_reduce_min)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_min.onnx"));
// input data shape (1, 1, 4, 4)
......@@ -899,7 +918,7 @@ TEST(onnx, model_reduce_min)
TEST(onnx, model_reduce_prod)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_prod.onnx"));
// input data shape (1, 1, 4, 4)
......@@ -916,7 +935,7 @@ TEST(onnx, model_reduce_prod)
TEST(onnx, model_reduce_sum)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_sum.onnx"));
// input data shape (1, 1, 4, 4)
......@@ -933,7 +952,7 @@ TEST(onnx, model_reduce_sum)
TEST(onnx, model_reduce_sum_square)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_sum_square.onnx"));
// input data shape (1, 1, 4, 4)
......@@ -951,7 +970,7 @@ TEST(onnx, model_reduce_sum_square)
TEST(onnx, model_shape)
{
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/shape.onnx"));
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/shape.onnx"));
Inputs inputs;
inputs.emplace_back(test::NDArray<float, 3>(
......@@ -970,7 +989,7 @@ TEST(onnx, model_shape)
TEST(onnx, model_elu)
{
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/elu.onnx"));
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/elu.onnx"));
Inputs inputs;
inputs.emplace_back(
......@@ -1016,7 +1035,7 @@ TEST(onnx, model_elu)
TEST(onnx, model_leaky_relu)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/leaky_relu.onnx"));
Inputs inputs;
......@@ -1048,7 +1067,7 @@ TEST(onnx, model_leaky_relu)
TEST(onnx, prelu)
{
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/prelu.onnx"));
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/prelu.onnx"));
Inputs inputs;
inputs.emplace_back(
......@@ -1078,7 +1097,7 @@ TEST(onnx, prelu)
TEST(onnx, model_selu)
{
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/selu.onnx"));
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/selu.onnx"));
Inputs inputs;
inputs.emplace_back(
......@@ -1118,8 +1137,8 @@ TEST(onnx, model_selu)
TEST(onnx, model_sigmoid)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/sigmoid.onnx"));
auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/sigmoid.onnx"));
Inputs inputs;
inputs.emplace_back(
......@@ -1194,7 +1213,7 @@ TEST(onnx, model_sigmoid)
TEST(onnx, model_tanh)
{
auto function =
onnx_import::import_onnx_function(file_util::path_join(SERIALIZED_ZOO, "onnx/tanh.onnx"));
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/tanh.onnx"));
Inputs inputs;
inputs.emplace_back(
......@@ -1268,7 +1287,7 @@ TEST(onnx, model_tanh)
TEST(onnx, model_thresholded_relu)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/thresholded_relu.onnx"));
Inputs inputs;
......@@ -1294,7 +1313,7 @@ TEST(onnx, model_unsupported_op)
{
try
{
onnx_import::import_onnx_function(
onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/unsupported_op.onnx"));
FAIL() << "Expected ngraph::ngraph_error";
}
......@@ -1319,7 +1338,7 @@ TEST(onnx, model_custom_op)
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
});
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/custom_operator.onnx"));
Inputs inputs{{1, 2, 3, 4}};
......@@ -1337,7 +1356,7 @@ TEST(onnx, model_custom_op_default_domain)
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
});
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/custom_operator_default_domain.onnx"));
Inputs inputs{{1, 2, 3, 4}};
......@@ -1349,7 +1368,7 @@ TEST(onnx, model_custom_op_default_domain)
TEST(onnx, model_conv2d_dilation_assymetric_pads_strides)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/conv2d_dilation_assym_pads_strides.onnx"));
// "", // auto_pad
......@@ -1389,7 +1408,7 @@ TEST(onnx, model_conv2d_dilation_assymetric_pads_strides)
TEST(onnx, model_conv3d_bias)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/conv3d_bias.onnx"));
// "", // auto_pad
......@@ -1503,7 +1522,7 @@ TEST(onnx, model_conv3d_bias)
TEST(onnx, model_matmul_vec_ten3d)
{
auto function = onnx_import::import_onnx_function(
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/matmul_vec_ten3d.onnx"));
Inputs inputs;
......