Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
a0446a2f
Unverified
Commit
a0446a2f
authored
Dec 04, 2018
by
Chris Sullivan
Committed by
GitHub
Dec 04, 2018
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into master
parents
a1392e91
c5b082c6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
175 additions
and
90 deletions
+175
-90
__init__.py
python/ngraph/impl/onnx_import/__init__.py
+2
-4
runtime.py
python/ngraph/runtime.py
+14
-10
onnx_import.cpp
python/pyngraph/onnx_import/onnx_import.cpp
+6
-21
regmodule_pyngraph_op.cpp
python/pyngraph/ops/regmodule_pyngraph_op.cpp
+1
-0
regmodule_pyngraph_op.hpp
python/pyngraph/ops/regmodule_pyngraph_op.hpp
+1
-0
result.cpp
python/pyngraph/ops/result.cpp
+32
-0
result.hpp
python/pyngraph/ops/result.hpp
+23
-0
pyngraph.cpp
python/pyngraph/pyngraph.cpp
+2
-0
result_vector.cpp
python/pyngraph/result_vector.cpp
+41
-0
result_vector.hpp
python/pyngraph/result_vector.hpp
+23
-0
setup.py
python/setup.py
+2
-0
test_onnx_import.py
python/test/ngraph/test_onnx_import.py
+2
-2
test_ops_unary.py
python/test/ngraph/test_ops_unary.py
+2
-2
graph.cpp
src/ngraph/frontend/onnx_import/core/graph.cpp
+10
-0
graph.hpp
src/ngraph/frontend/onnx_import/core/graph.hpp
+1
-0
onnx.cpp
src/ngraph/frontend/onnx_import/onnx.cpp
+9
-21
onnx.hpp
src/ngraph/frontend/onnx_import/onnx.hpp
+3
-29
serialize_onnx.cpp
src/tools/serialize_onnx/serialize_onnx.cpp
+1
-1
onnx_import.cpp
test/onnx_import.cpp
+0
-0
No files found.
python/ngraph/impl/onnx_import/__init__.py
View file @
a0446a2f
...
@@ -32,7 +32,5 @@ else:
...
@@ -32,7 +32,5 @@ else:
flags
=
sys
.
getdlopenflags
()
|
ctypes
.
RTLD_GLOBAL
flags
=
sys
.
getdlopenflags
()
|
ctypes
.
RTLD_GLOBAL
sys
.
setdlopenflags
(
flags
)
sys
.
setdlopenflags
(
flags
)
from
_pyngraph_onnx_import
import
load_onnx_model
from
_pyngraph_onnx_import
import
import_onnx_model
from
_pyngraph_onnx_import
import
load_onnx_model_file
from
_pyngraph_onnx_import
import
import_onnx_model_file
from
_pyngraph_onnx_import
import
import_onnx_function
from
_pyngraph_onnx_import
import
import_onnx_function_file
python/ngraph/runtime.py
View file @
a0446a2f
...
@@ -67,6 +67,7 @@ class Computation(object):
...
@@ -67,6 +67,7 @@ class Computation(object):
self
.
runtime
=
runtime
self
.
runtime
=
runtime
self
.
function
=
ng_function
self
.
function
=
ng_function
self
.
parameters
=
ng_function
.
get_parameters
()
self
.
parameters
=
ng_function
.
get_parameters
()
self
.
results
=
ng_function
.
get_results
()
self
.
tensor_views
=
[]
# type: List[Tensor]
self
.
tensor_views
=
[]
# type: List[Tensor]
for
parameter
in
self
.
parameters
:
for
parameter
in
self
.
parameters
:
...
@@ -74,6 +75,12 @@ class Computation(object):
...
@@ -74,6 +75,12 @@ class Computation(object):
element_type
=
parameter
.
get_element_type
()
element_type
=
parameter
.
get_element_type
()
self
.
tensor_views
.
append
(
runtime
.
backend
.
create_tensor
(
element_type
,
shape
))
self
.
tensor_views
.
append
(
runtime
.
backend
.
create_tensor
(
element_type
,
shape
))
self
.
result_views
=
[]
# type: List[Tensor]
for
result
in
self
.
results
:
shape
=
result
.
get_shape
()
element_type
=
result
.
get_element_type
()
self
.
result_views
.
append
(
runtime
.
backend
.
create_tensor
(
element_type
,
shape
))
def
__repr__
(
self
):
# type: () -> str
def
__repr__
(
self
):
# type: () -> str
params_string
=
', '
.
join
([
param
.
name
for
param
in
self
.
parameters
])
params_string
=
', '
.
join
([
param
.
name
for
param
in
self
.
parameters
])
return
'<Computation: {}({})>'
.
format
(
self
.
function
.
get_name
(),
params_string
)
return
'<Computation: {}({})>'
.
format
(
self
.
function
.
get_name
(),
params_string
)
...
@@ -85,18 +92,15 @@ class Computation(object):
...
@@ -85,18 +92,15 @@ class Computation(object):
value
=
np
.
array
(
value
)
value
=
np
.
array
(
value
)
Computation
.
_write_ndarray_to_tensor_view
(
value
,
tensor_view
)
Computation
.
_write_ndarray_to_tensor_view
(
value
,
tensor_view
)
result_element_type
=
self
.
function
.
get_output_element_type
(
0
)
self
.
runtime
.
backend
.
call
(
self
.
function
,
self
.
result_views
,
self
.
tensor_views
)
result_shape
=
self
.
function
.
get_output_shape
(
0
)
result_dtype
=
get_dtype
(
result_element_type
)
result_view
=
self
.
runtime
.
backend
.
create_tensor
(
result_element_type
,
result_shape
)
result_arr
=
np
.
empty
(
result_shape
,
dtype
=
result_dtype
)
self
.
runtime
.
backend
.
call
(
self
.
function
,
[
result_view
],
self
.
tensor_views
)
results
=
[]
for
result_view
in
self
.
result_views
:
result
=
np
.
ndarray
(
result_view
.
shape
,
dtype
=
get_dtype
(
result_view
.
element_type
))
Computation
.
_read_tensor_view_to_ndarray
(
result_view
,
result
)
results
.
append
(
result
)
Computation
.
_read_tensor_view_to_ndarray
(
result_view
,
result_arr
)
return
results
result_arr
=
result_arr
.
reshape
(
result_shape
)
return
result_arr
def
serialize
(
self
,
indent
=
0
):
# type: (int) -> str
def
serialize
(
self
,
indent
=
0
):
# type: (int) -> str
"""Serialize function (compute graph) to a JSON string.
"""Serialize function (compute graph) to a JSON string.
...
...
python/pyngraph/onnx_import/onnx_import.cpp
View file @
a0446a2f
...
@@ -28,34 +28,19 @@
...
@@ -28,34 +28,19 @@
namespace
py
=
pybind11
;
namespace
py
=
pybind11
;
static
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
Function
>>
static
std
::
shared_ptr
<
ngraph
::
Function
>
import_onnx_model
(
const
std
::
string
&
model_proto
)
load_onnx_model
(
const
std
::
string
&
model_proto
)
{
{
std
::
istringstream
iss
(
model_proto
,
std
::
ios_base
::
binary
|
std
::
ios_base
::
in
);
std
::
istringstream
iss
(
model_proto
,
std
::
ios_base
::
binary
|
std
::
ios_base
::
in
);
return
ngraph
::
onnx_import
::
load
_onnx_model
(
iss
);
return
ngraph
::
onnx_import
::
import
_onnx_model
(
iss
);
}
}
static
std
::
shared_ptr
<
ngraph
::
Function
>
import_onnx_
function
(
const
std
::
string
&
model_proto
)
static
std
::
shared_ptr
<
ngraph
::
Function
>
import_onnx_
model_file
(
const
std
::
string
&
filename
)
{
{
std
::
istringstream
iss
(
model_proto
,
std
::
ios_base
::
binary
|
std
::
ios_base
::
in
);
return
ngraph
::
onnx_import
::
import_onnx_model
(
filename
);
return
ngraph
::
onnx_import
::
import_onnx_function
(
iss
);
}
static
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
Function
>>
load_onnx_model_file
(
const
std
::
string
&
filename
)
{
return
ngraph
::
onnx_import
::
load_onnx_model
(
filename
);
}
static
std
::
shared_ptr
<
ngraph
::
Function
>
import_onnx_function_file
(
const
std
::
string
&
filename
)
{
return
ngraph
::
onnx_import
::
import_onnx_function
(
filename
);
}
}
void
regmodule_pyngraph_onnx_import
(
py
::
module
mod
)
void
regmodule_pyngraph_onnx_import
(
py
::
module
mod
)
{
{
mod
.
def
(
"load_onnx_model"
,
&
load_onnx_model
);
mod
.
def
(
"import_onnx_model"
,
&
import_onnx_model
);
mod
.
def
(
"import_onnx_function"
,
&
import_onnx_function
);
mod
.
def
(
"import_onnx_model_file"
,
&
import_onnx_model_file
);
mod
.
def
(
"load_onnx_model_file"
,
&
load_onnx_model_file
);
mod
.
def
(
"import_onnx_function_file"
,
&
import_onnx_function_file
);
}
}
python/pyngraph/ops/regmodule_pyngraph_op.cpp
View file @
a0446a2f
...
@@ -93,4 +93,5 @@ void regmodule_pyngraph_op(py::module m_op)
...
@@ -93,4 +93,5 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Tan
(
m_op
);
regclass_pyngraph_op_Tan
(
m_op
);
regclass_pyngraph_op_Tanh
(
m_op
);
regclass_pyngraph_op_Tanh
(
m_op
);
regclass_pyngraph_op_TopK
(
m_op
);
regclass_pyngraph_op_TopK
(
m_op
);
regclass_pyngraph_op_Result
(
m_op
);
}
}
python/pyngraph/ops/regmodule_pyngraph_op.hpp
View file @
a0446a2f
...
@@ -68,6 +68,7 @@
...
@@ -68,6 +68,7 @@
#include "pyngraph/ops/relu.hpp"
#include "pyngraph/ops/relu.hpp"
#include "pyngraph/ops/replace_slice.hpp"
#include "pyngraph/ops/replace_slice.hpp"
#include "pyngraph/ops/reshape.hpp"
#include "pyngraph/ops/reshape.hpp"
#include "pyngraph/ops/result.hpp"
#include "pyngraph/ops/reverse.hpp"
#include "pyngraph/ops/reverse.hpp"
#include "pyngraph/ops/select.hpp"
#include "pyngraph/ops/select.hpp"
#include "pyngraph/ops/sign.hpp"
#include "pyngraph/ops/sign.hpp"
...
...
python/pyngraph/ops/result.cpp
0 → 100644
View file @
a0446a2f
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <string>
#include "ngraph/node.hpp"
#include "ngraph/op/result.hpp"
#include "pyngraph/ops/result.hpp"
namespace
py
=
pybind11
;
void
regclass_pyngraph_op_Result
(
py
::
module
m
)
{
py
::
class_
<
ngraph
::
op
::
Result
,
std
::
shared_ptr
<
ngraph
::
op
::
Result
>
,
ngraph
::
Node
>
result
(
m
,
"Result"
);
result
.
doc
()
=
"ngraph.impl.op.Result wraps ngraph::op::Result"
;
}
python/pyngraph/ops/result.hpp
0 → 100644
View file @
a0446a2f
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace
py
=
pybind11
;
void
regclass_pyngraph_op_Result
(
py
::
module
m
);
python/pyngraph/pyngraph.cpp
View file @
a0446a2f
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include "pyngraph/ops/util/regmodule_pyngraph_op_util.hpp"
#include "pyngraph/ops/util/regmodule_pyngraph_op_util.hpp"
#include "pyngraph/parameter_vector.hpp"
#include "pyngraph/parameter_vector.hpp"
#include "pyngraph/passes/regmodule_pyngraph_passes.hpp"
#include "pyngraph/passes/regmodule_pyngraph_passes.hpp"
#include "pyngraph/result_vector.hpp"
#include "pyngraph/runtime/regmodule_pyngraph_runtime.hpp"
#include "pyngraph/runtime/regmodule_pyngraph_runtime.hpp"
#include "pyngraph/serializer.hpp"
#include "pyngraph/serializer.hpp"
#include "pyngraph/shape.hpp"
#include "pyngraph/shape.hpp"
...
@@ -58,4 +59,5 @@ PYBIND11_MODULE(_pyngraph, m)
...
@@ -58,4 +59,5 @@ PYBIND11_MODULE(_pyngraph, m)
regmodule_pyngraph_runtime
(
m
);
regmodule_pyngraph_runtime
(
m
);
regmodule_pyngraph_passes
(
m
);
regmodule_pyngraph_passes
(
m
);
regmodule_pyngraph_util
(
m
);
regmodule_pyngraph_util
(
m
);
regclass_pyngraph_ResultVector
(
m
);
}
}
python/pyngraph/result_vector.cpp
0 → 100644
View file @
a0446a2f
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/result.hpp" // ngraph::op::Result
#include "ngraph/result_vector.hpp"
#include "pyngraph/ops/result.hpp"
#include "pyngraph/result_vector.hpp"
namespace
py
=
pybind11
;
void
regclass_pyngraph_ResultVector
(
py
::
module
m
)
{
py
::
class_
<
ngraph
::
ResultVector
,
std
::
shared_ptr
<
ngraph
::
ResultVector
>>
result_vector
(
m
,
"ResultVector"
);
result_vector
.
doc
()
=
"ngraph.impl.ResultVector wraps ngraph::ResultVector"
;
result_vector
.
def
(
py
::
init
<
const
std
::
initializer_list
<
std
::
shared_ptr
<
ngraph
::
op
::
Result
>>&>
());
result_vector
.
def
(
py
::
init
<
const
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
op
::
Result
>>&>
());
result_vector
.
def
(
py
::
init
<
const
ngraph
::
ResultVector
&>
());
result_vector
.
def
(
"__len__"
,
[](
const
ngraph
::
ResultVector
&
v
)
{
return
v
.
size
();
});
result_vector
.
def
(
"__getitem__"
,
[](
const
ngraph
::
ResultVector
&
v
,
int
key
)
{
return
v
[
key
];
});
result_vector
.
def
(
"__iter__"
,
[](
ngraph
::
ResultVector
&
v
)
{
return
py
::
make_iterator
(
v
.
begin
(),
v
.
end
());
},
py
::
keep_alive
<
0
,
1
>
());
/* Keep vector alive while iterator is used */
}
python/pyngraph/result_vector.hpp
0 → 100644
View file @
a0446a2f
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace
py
=
pybind11
;
void
regclass_pyngraph_ResultVector
(
py
::
module
m
);
python/setup.py
View file @
a0446a2f
...
@@ -149,6 +149,7 @@ sources = [
...
@@ -149,6 +149,7 @@ sources = [
'pyngraph/parameter_vector.cpp'
,
'pyngraph/parameter_vector.cpp'
,
'pyngraph/pyngraph.cpp'
,
'pyngraph/pyngraph.cpp'
,
'pyngraph/util.cpp'
,
'pyngraph/util.cpp'
,
'pyngraph/result_vector.cpp'
,
'pyngraph/ops/util/arithmetic_reduction.cpp'
,
'pyngraph/ops/util/arithmetic_reduction.cpp'
,
'pyngraph/ops/util/binary_elementwise_comparison.cpp'
,
'pyngraph/ops/util/binary_elementwise_comparison.cpp'
,
'pyngraph/ops/util/op_annotations.cpp'
,
'pyngraph/ops/util/op_annotations.cpp'
,
...
@@ -223,6 +224,7 @@ sources = [
...
@@ -223,6 +224,7 @@ sources = [
'pyngraph/ops/min.cpp'
,
'pyngraph/ops/min.cpp'
,
'pyngraph/ops/batch_norm.cpp'
,
'pyngraph/ops/batch_norm.cpp'
,
'pyngraph/ops/softmax.cpp'
,
'pyngraph/ops/softmax.cpp'
,
'pyngraph/ops/result.cpp'
,
'pyngraph/runtime/backend.cpp'
,
'pyngraph/runtime/backend.cpp'
,
'pyngraph/runtime/regmodule_pyngraph_runtime.cpp'
,
'pyngraph/runtime/regmodule_pyngraph_runtime.cpp'
,
'pyngraph/runtime/tensor.cpp'
,
'pyngraph/runtime/tensor.cpp'
,
...
...
python/test/ngraph/test_onnx_import.py
View file @
a0446a2f
...
@@ -17,13 +17,13 @@
...
@@ -17,13 +17,13 @@
import
os
import
os
import
numpy
as
np
import
numpy
as
np
from
ngraph.impl.onnx_import
import
load
_onnx_model_file
from
ngraph.impl.onnx_import
import
import
_onnx_model_file
from
test.ngraph.util
import
get_runtime
from
test.ngraph.util
import
get_runtime
def
test_import_onnx_function
():
def
test_import_onnx_function
():
model_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'models/add_abc.onnx'
)
model_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'models/add_abc.onnx'
)
ng_function
=
load_onnx_model_file
(
model_path
)[
0
]
ng_function
=
import_onnx_model_file
(
model_path
)
dtype
=
np
.
float32
dtype
=
np
.
float32
value_a
=
np
.
array
([
1.0
],
dtype
=
dtype
)
value_a
=
np
.
array
([
1.0
],
dtype
=
dtype
)
...
...
python/test/ngraph/test_ops_unary.py
View file @
a0446a2f
...
@@ -48,10 +48,10 @@ def test_unary_op_array(ng_api_fn, numpy_fn, range_start, range_end):
...
@@ -48,10 +48,10 @@ def test_unary_op_array(ng_api_fn, numpy_fn, range_start, range_end):
input_data
=
range_start
+
np
.
random
.
rand
(
2
,
3
,
4
)
*
(
range_end
-
range_start
)
input_data
=
range_start
+
np
.
random
.
rand
(
2
,
3
,
4
)
*
(
range_end
-
range_start
)
expected
=
numpy_fn
(
input_data
)
expected
=
numpy_fn
(
input_data
)
result
=
run_op_node
([
input_data
],
ng_api_fn
)
result
=
run_op_node
([
input_data
],
ng_api_fn
)
[
0
]
np
.
testing
.
assert_allclose
(
result
,
expected
,
rtol
=
0.001
)
np
.
testing
.
assert_allclose
(
result
,
expected
,
rtol
=
0.001
)
result
=
run_op_numeric_data
(
input_data
,
ng_api_fn
)
result
=
run_op_numeric_data
(
input_data
,
ng_api_fn
)
[
0
]
np
.
testing
.
assert_allclose
(
result
,
expected
,
rtol
=
0.001
)
np
.
testing
.
assert_allclose
(
result
,
expected
,
rtol
=
0.001
)
...
...
src/ngraph/frontend/onnx_import/core/graph.cpp
View file @
a0446a2f
...
@@ -95,6 +95,16 @@ namespace ngraph
...
@@ -95,6 +95,16 @@ namespace ngraph
}
}
}
}
NodeVector
Graph
::
get_ng_outputs
()
const
{
NodeVector
results
;
for
(
const
auto
&
output
:
m_graph_proto
->
output
())
{
results
.
emplace_back
(
get_ng_node_from_cache
(
output
.
name
()));
}
return
results
;
}
}
// namespace onnx_import
}
// namespace onnx_import
}
// namespace ngraph
}
// namespace ngraph
src/ngraph/frontend/onnx_import/core/graph.hpp
View file @
a0446a2f
...
@@ -38,6 +38,7 @@ namespace ngraph
...
@@ -38,6 +38,7 @@ namespace ngraph
const
std
::
vector
<
Node
>&
get_nodes
()
const
{
return
m_nodes
;
}
const
std
::
vector
<
Node
>&
get_nodes
()
const
{
return
m_nodes
;
}
const
std
::
vector
<
ValueInfo
>&
get_inputs
()
const
{
return
m_inputs
;
}
const
std
::
vector
<
ValueInfo
>&
get_inputs
()
const
{
return
m_inputs
;
}
const
std
::
vector
<
ValueInfo
>&
get_outputs
()
const
{
return
m_outputs
;
}
const
std
::
vector
<
ValueInfo
>&
get_outputs
()
const
{
return
m_outputs
;
}
NodeVector
get_ng_outputs
()
const
;
const
ParameterVector
&
get_ng_parameters
()
const
{
return
m_parameters
;
}
const
ParameterVector
&
get_ng_parameters
()
const
{
return
m_parameters
;
}
std
::
shared_ptr
<
ngraph
::
Node
>
get_ng_node_from_cache
(
const
std
::
string
&
name
)
const
std
::
shared_ptr
<
ngraph
::
Node
>
get_ng_node_from_cache
(
const
std
::
string
&
name
)
const
{
{
...
...
src/ngraph/frontend/onnx_import/onnx.cpp
View file @
a0446a2f
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
//*****************************************************************************
//*****************************************************************************
#include <fstream>
#include <fstream>
#include <memory>
#include "core/graph.hpp"
#include "core/graph.hpp"
#include "core/model.hpp"
#include "core/model.hpp"
...
@@ -50,45 +51,32 @@ namespace ngraph
...
@@ -50,45 +51,32 @@ namespace ngraph
}
// namespace error
}
// namespace error
}
// namespace detail
}
// namespace detail
std
::
vector
<
std
::
shared_ptr
<
Function
>>
load_onnx_model
(
std
::
istream
&
sin
,
std
::
shared_ptr
<
Function
>
import_onnx_model
(
std
::
istream
&
sin
,
const
Weights
&
weights
)
const
Weights
&
weights
)
{
{
onnx
::
ModelProto
model_proto
;
onnx
::
ModelProto
model_proto
;
if
(
!
model_proto
.
ParseFromIstream
(
&
sin
))
if
(
!
model_proto
.
ParseFromIstream
(
&
sin
))
{
{
throw
detail
::
error
::
stream_parse
{
sin
};
throw
detail
::
error
::
stream_parse
{
sin
};
}
}
std
::
vector
<
std
::
shared_ptr
<
Function
>>
output_functions
;
Model
model
{
model_proto
};
Model
model
{
model_proto
};
Graph
graph
{
model_proto
.
graph
(),
model
,
weights
};
Graph
graph
{
model_proto
.
graph
(),
model
,
weights
};
for
(
const
auto
&
output
:
graph
.
get_outputs
())
auto
function
=
std
::
make_shared
<
Function
>
(
graph
.
get_ng_outputs
(),
graph
.
get_ng_parameters
(),
graph
.
get_name
());
for
(
std
::
size_t
i
{
0
};
i
<
function
->
get_output_size
();
++
i
)
{
{
output_functions
.
emplace_back
(
std
::
make_shared
<
Function
>
(
function
->
get_output_op
(
i
)
->
set_name
(
graph
.
get_outputs
().
at
(
i
).
get_name
());
graph
.
get_ng_node_from_cache
(
output
.
get_name
()),
graph
.
get_ng_parameters
()));
}
}
return
output_functions
;
return
function
;
}
}
std
::
vector
<
std
::
shared_ptr
<
Function
>>
load_onnx_model
(
const
std
::
string
&
path
,
std
::
shared_ptr
<
Function
>
import_onnx_model
(
const
std
::
string
&
path
,
const
Weights
&
weights
)
const
Weights
&
weights
)
{
{
std
::
ifstream
ifs
{
path
,
std
::
ios
::
in
|
std
::
ios
::
binary
};
std
::
ifstream
ifs
{
path
,
std
::
ios
::
in
|
std
::
ios
::
binary
};
if
(
!
ifs
.
is_open
())
if
(
!
ifs
.
is_open
())
{
{
throw
detail
::
error
::
file_open
{
path
};
throw
detail
::
error
::
file_open
{
path
};
}
}
return
load_onnx_model
(
ifs
,
weights
);
return
import_onnx_model
(
ifs
,
weights
);
}
std
::
shared_ptr
<
Function
>
import_onnx_function
(
std
::
istream
&
sin
,
const
Weights
&
weights
)
{
return
load_onnx_model
(
sin
,
weights
).
front
();
}
std
::
shared_ptr
<
Function
>
import_onnx_function
(
const
std
::
string
&
path
,
const
Weights
&
weights
)
{
return
load_onnx_model
(
path
,
weights
).
front
();
}
}
void
register_operator
(
const
std
::
string
&
name
,
void
register_operator
(
const
std
::
string
&
name
,
...
...
src/ngraph/frontend/onnx_import/onnx.hpp
View file @
a0446a2f
...
@@ -40,31 +40,6 @@ namespace ngraph
...
@@ -40,31 +40,6 @@ namespace ngraph
const
std
::
string
&
domain
,
const
std
::
string
&
domain
,
Operator
fn
);
Operator
fn
);
/// \brief Convert an ONNX model to nGraph functions
/// The function translated serialized ONNX model to nGraph functions. The serialized
/// ONNX model is read from input stream.
/// \param sin input stream (e.g. file stream, memory stream, etc),
/// \param weights weights associated with the model. If weights are embedded into
/// the model this parameter shall be empty. Having weights in a model
/// and providing through this parameters is invalid (the weights from
/// the model will take precedence).
/// \return The function returns a vector of nGraph functions. The number of functions
/// depends on number of outputs from ONNX graph.
std
::
vector
<
std
::
shared_ptr
<
Function
>>
load_onnx_model
(
std
::
istream
&
sin
,
const
Weights
&
weights
=
{});
/// \brief Convert an ONNX model to nGraph functions
/// The function translated serialized ONNX model to nGraph functions. The ONNX model
/// is read from ONNX file.
/// \param filename file name (relative or absolute path name),
/// \param weights weights associated with the model. If weights are embedded into
/// the model this parameter shall be empty. Having weights in a model
/// and providing through this parameters is invalid (the weights from
/// the model will take precedence).
/// \return The function returns a vector of nGraph functions. The number of functions
/// depends on number of outputs from ONNX graph.
std
::
vector
<
std
::
shared_ptr
<
Function
>>
load_onnx_model
(
const
std
::
string
&
filename
,
const
Weights
&
weights
=
{});
/// \brief Convert an ONNX model to nGraph function
/// \brief Convert an ONNX model to nGraph function
/// The function translated serialized ONNX model to nGraph function. The serialized
/// The function translated serialized ONNX model to nGraph function. The serialized
/// ONNX model is read from input stream.
/// ONNX model is read from input stream.
...
@@ -74,8 +49,7 @@ namespace ngraph
...
@@ -74,8 +49,7 @@ namespace ngraph
/// and providing through this parameters is invalid (the weights from
/// and providing through this parameters is invalid (the weights from
/// the model will take precedence).
/// the model will take precedence).
/// \return The function returns a nGraph function representing single output from graph.
/// \return The function returns a nGraph function representing single output from graph.
std
::
shared_ptr
<
Function
>
import_onnx_function
(
std
::
istream
&
sin
,
std
::
shared_ptr
<
Function
>
import_onnx_model
(
std
::
istream
&
sin
,
const
Weights
&
weights
=
{});
const
Weights
&
weights
=
{});
/// \brief Convert an ONNX model to nGraph functions
/// \brief Convert an ONNX model to nGraph functions
/// The function translated serialized ONNX model to nGraph functions. The ONNX model
/// The function translated serialized ONNX model to nGraph functions. The ONNX model
...
@@ -86,8 +60,8 @@ namespace ngraph
...
@@ -86,8 +60,8 @@ namespace ngraph
/// and providing through this parameters is invalid (the weights from
/// and providing through this parameters is invalid (the weights from
/// the model will take precedence).
/// the model will take precedence).
/// \return The function returns a nGraph function representing single output from graph.
/// \return The function returns a nGraph function representing single output from graph.
std
::
shared_ptr
<
Function
>
import_onnx_
function
(
const
std
::
string
&
filename
,
std
::
shared_ptr
<
Function
>
import_onnx_
model
(
const
std
::
string
&
filename
,
const
Weights
&
weights
=
{});
const
Weights
&
weights
=
{});
}
// namespace onnx_import
}
// namespace onnx_import
...
...
src/tools/serialize_onnx/serialize_onnx.cpp
View file @
a0446a2f
...
@@ -66,7 +66,7 @@ int main(int argc, char** argv)
...
@@ -66,7 +66,7 @@ int main(int argc, char** argv)
ifstream
f
(
input
);
ifstream
f
(
input
);
if
(
f
)
if
(
f
)
{
{
s
hared_ptr
<
ngraph
::
Function
>
function
=
ngraph
::
onnx_import
::
import_onnx_function
(
input
);
s
td
::
shared_ptr
<
ngraph
::
Function
>
function
=
ngraph
::
onnx_import
::
import_onnx_model
(
input
);
ngraph
::
stopwatch
timer
;
ngraph
::
stopwatch
timer
;
timer
.
start
();
timer
.
start
();
...
...
test/onnx_import.cpp
View file @
a0446a2f
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment