Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
3dc2a915
Unverified
Commit
3dc2a915
authored
Dec 04, 2018
by
Chris Sullivan
Committed by
GitHub
Dec 04, 2018
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into master
parents
a0446a2f
1b71fdca
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
22 changed files
with
154 additions
and
138 deletions
+154
-138
__init__.py
python/ngraph/impl/onnx_import/__init__.py
+4
-2
runtime.py
python/ngraph/runtime.py
+10
-14
onnx_import.cpp
python/pyngraph/onnx_import/onnx_import.cpp
+21
-6
regmodule_pyngraph_op.cpp
python/pyngraph/ops/regmodule_pyngraph_op.cpp
+0
-1
regmodule_pyngraph_op.hpp
python/pyngraph/ops/regmodule_pyngraph_op.hpp
+0
-1
pyngraph.cpp
python/pyngraph/pyngraph.cpp
+0
-2
result_vector.cpp
python/pyngraph/result_vector.cpp
+0
-41
result_vector.hpp
python/pyngraph/result_vector.hpp
+0
-23
setup.py
python/setup.py
+0
-2
test_onnx_import.py
python/test/ngraph/test_onnx_import.py
+2
-2
test_ops_unary.py
python/test/ngraph/test_ops_unary.py
+2
-2
CMakeLists.txt
src/ngraph/frontend/onnx_import/CMakeLists.txt
+2
-0
graph.cpp
src/ngraph/frontend/onnx_import/core/graph.cpp
+0
-10
graph.hpp
src/ngraph/frontend/onnx_import/core/graph.hpp
+0
-1
onnx.cpp
src/ngraph/frontend/onnx_import/onnx.cpp
+21
-9
onnx.hpp
src/ngraph/frontend/onnx_import/onnx.hpp
+29
-3
batch_norm.cpp
src/ngraph/frontend/onnx_import/op/batch_norm.cpp
+0
-2
pad.cpp
src/ngraph/frontend/onnx_import/op/pad.cpp
+41
-13
pad.hpp
src/ngraph/frontend/onnx_import/op/pad.hpp
+18
-3
ops_bridge.cpp
src/ngraph/frontend/onnx_import/ops_bridge.cpp
+3
-0
serialize_onnx.cpp
src/tools/serialize_onnx/serialize_onnx.cpp
+1
-1
onnx_import.cpp
test/onnx_import.cpp
+0
-0
No files found.
python/ngraph/impl/onnx_import/__init__.py
View file @
3dc2a915
...
...
@@ -32,5 +32,7 @@ else:
flags
=
sys
.
getdlopenflags
()
|
ctypes
.
RTLD_GLOBAL
sys
.
setdlopenflags
(
flags
)
from
_pyngraph_onnx_import
import
import_onnx_model
from
_pyngraph_onnx_import
import
import_onnx_model_file
from
_pyngraph_onnx_import
import
load_onnx_model
from
_pyngraph_onnx_import
import
load_onnx_model_file
from
_pyngraph_onnx_import
import
import_onnx_function
from
_pyngraph_onnx_import
import
import_onnx_function_file
python/ngraph/runtime.py
View file @
3dc2a915
...
...
@@ -67,7 +67,6 @@ class Computation(object):
self
.
runtime
=
runtime
self
.
function
=
ng_function
self
.
parameters
=
ng_function
.
get_parameters
()
self
.
results
=
ng_function
.
get_results
()
self
.
tensor_views
=
[]
# type: List[Tensor]
for
parameter
in
self
.
parameters
:
...
...
@@ -75,12 +74,6 @@ class Computation(object):
element_type
=
parameter
.
get_element_type
()
self
.
tensor_views
.
append
(
runtime
.
backend
.
create_tensor
(
element_type
,
shape
))
self
.
result_views
=
[]
# type: List[Tensor]
for
result
in
self
.
results
:
shape
=
result
.
get_shape
()
element_type
=
result
.
get_element_type
()
self
.
result_views
.
append
(
runtime
.
backend
.
create_tensor
(
element_type
,
shape
))
def
__repr__
(
self
):
# type: () -> str
params_string
=
', '
.
join
([
param
.
name
for
param
in
self
.
parameters
])
return
'<Computation: {}({})>'
.
format
(
self
.
function
.
get_name
(),
params_string
)
...
...
@@ -92,15 +85,18 @@ class Computation(object):
value
=
np
.
array
(
value
)
Computation
.
_write_ndarray_to_tensor_view
(
value
,
tensor_view
)
self
.
runtime
.
backend
.
call
(
self
.
function
,
self
.
result_views
,
self
.
tensor_views
)
result_element_type
=
self
.
function
.
get_output_element_type
(
0
)
result_shape
=
self
.
function
.
get_output_shape
(
0
)
result_dtype
=
get_dtype
(
result_element_type
)
result_view
=
self
.
runtime
.
backend
.
create_tensor
(
result_element_type
,
result_shape
)
result_arr
=
np
.
empty
(
result_shape
,
dtype
=
result_dtype
)
results
=
[]
for
result_view
in
self
.
result_views
:
result
=
np
.
ndarray
(
result_view
.
shape
,
dtype
=
get_dtype
(
result_view
.
element_type
))
Computation
.
_read_tensor_view_to_ndarray
(
result_view
,
result
)
results
.
append
(
result
)
self
.
runtime
.
backend
.
call
(
self
.
function
,
[
result_view
],
self
.
tensor_views
)
return
results
Computation
.
_read_tensor_view_to_ndarray
(
result_view
,
result_arr
)
result_arr
=
result_arr
.
reshape
(
result_shape
)
return
result_arr
def
serialize
(
self
,
indent
=
0
):
# type: (int) -> str
"""Serialize function (compute graph) to a JSON string.
...
...
python/pyngraph/onnx_import/onnx_import.cpp
View file @
3dc2a915
...
...
@@ -28,19 +28,34 @@
namespace
py
=
pybind11
;
static
std
::
shared_ptr
<
ngraph
::
Function
>
import_onnx_model
(
const
std
::
string
&
model_proto
)
static
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
Function
>>
load_onnx_model
(
const
std
::
string
&
model_proto
)
{
std
::
istringstream
iss
(
model_proto
,
std
::
ios_base
::
binary
|
std
::
ios_base
::
in
);
return
ngraph
::
onnx_import
::
import
_onnx_model
(
iss
);
return
ngraph
::
onnx_import
::
load
_onnx_model
(
iss
);
}
static
std
::
shared_ptr
<
ngraph
::
Function
>
import_onnx_
model_file
(
const
std
::
string
&
filename
)
static
std
::
shared_ptr
<
ngraph
::
Function
>
import_onnx_
function
(
const
std
::
string
&
model_proto
)
{
return
ngraph
::
onnx_import
::
import_onnx_model
(
filename
);
std
::
istringstream
iss
(
model_proto
,
std
::
ios_base
::
binary
|
std
::
ios_base
::
in
);
return
ngraph
::
onnx_import
::
import_onnx_function
(
iss
);
}
static
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
Function
>>
load_onnx_model_file
(
const
std
::
string
&
filename
)
{
return
ngraph
::
onnx_import
::
load_onnx_model
(
filename
);
}
static
std
::
shared_ptr
<
ngraph
::
Function
>
import_onnx_function_file
(
const
std
::
string
&
filename
)
{
return
ngraph
::
onnx_import
::
import_onnx_function
(
filename
);
}
void
regmodule_pyngraph_onnx_import
(
py
::
module
mod
)
{
mod
.
def
(
"import_onnx_model"
,
&
import_onnx_model
);
mod
.
def
(
"import_onnx_model_file"
,
&
import_onnx_model_file
);
mod
.
def
(
"load_onnx_model"
,
&
load_onnx_model
);
mod
.
def
(
"import_onnx_function"
,
&
import_onnx_function
);
mod
.
def
(
"load_onnx_model_file"
,
&
load_onnx_model_file
);
mod
.
def
(
"import_onnx_function_file"
,
&
import_onnx_function_file
);
}
python/pyngraph/ops/regmodule_pyngraph_op.cpp
View file @
3dc2a915
...
...
@@ -93,5 +93,4 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Tan
(
m_op
);
regclass_pyngraph_op_Tanh
(
m_op
);
regclass_pyngraph_op_TopK
(
m_op
);
regclass_pyngraph_op_Result
(
m_op
);
}
python/pyngraph/ops/regmodule_pyngraph_op.hpp
View file @
3dc2a915
...
...
@@ -68,7 +68,6 @@
#include "pyngraph/ops/relu.hpp"
#include "pyngraph/ops/replace_slice.hpp"
#include "pyngraph/ops/reshape.hpp"
#include "pyngraph/ops/result.hpp"
#include "pyngraph/ops/reverse.hpp"
#include "pyngraph/ops/select.hpp"
#include "pyngraph/ops/sign.hpp"
...
...
python/pyngraph/pyngraph.cpp
View file @
3dc2a915
...
...
@@ -27,7 +27,6 @@
#include "pyngraph/ops/util/regmodule_pyngraph_op_util.hpp"
#include "pyngraph/parameter_vector.hpp"
#include "pyngraph/passes/regmodule_pyngraph_passes.hpp"
#include "pyngraph/result_vector.hpp"
#include "pyngraph/runtime/regmodule_pyngraph_runtime.hpp"
#include "pyngraph/serializer.hpp"
#include "pyngraph/shape.hpp"
...
...
@@ -59,5 +58,4 @@ PYBIND11_MODULE(_pyngraph, m)
regmodule_pyngraph_runtime
(
m
);
regmodule_pyngraph_passes
(
m
);
regmodule_pyngraph_util
(
m
);
regclass_pyngraph_ResultVector
(
m
);
}
python/pyngraph/result_vector.cpp
deleted
100644 → 0
View file @
a0446a2f
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/result.hpp" // ngraph::op::Result
#include "ngraph/result_vector.hpp"
#include "pyngraph/ops/result.hpp"
#include "pyngraph/result_vector.hpp"
namespace
py
=
pybind11
;
void
regclass_pyngraph_ResultVector
(
py
::
module
m
)
{
py
::
class_
<
ngraph
::
ResultVector
,
std
::
shared_ptr
<
ngraph
::
ResultVector
>>
result_vector
(
m
,
"ResultVector"
);
result_vector
.
doc
()
=
"ngraph.impl.ResultVector wraps ngraph::ResultVector"
;
result_vector
.
def
(
py
::
init
<
const
std
::
initializer_list
<
std
::
shared_ptr
<
ngraph
::
op
::
Result
>>&>
());
result_vector
.
def
(
py
::
init
<
const
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
op
::
Result
>>&>
());
result_vector
.
def
(
py
::
init
<
const
ngraph
::
ResultVector
&>
());
result_vector
.
def
(
"__len__"
,
[](
const
ngraph
::
ResultVector
&
v
)
{
return
v
.
size
();
});
result_vector
.
def
(
"__getitem__"
,
[](
const
ngraph
::
ResultVector
&
v
,
int
key
)
{
return
v
[
key
];
});
result_vector
.
def
(
"__iter__"
,
[](
ngraph
::
ResultVector
&
v
)
{
return
py
::
make_iterator
(
v
.
begin
(),
v
.
end
());
},
py
::
keep_alive
<
0
,
1
>
());
/* Keep vector alive while iterator is used */
}
python/pyngraph/result_vector.hpp
deleted
100644 → 0
View file @
a0446a2f
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace
py
=
pybind11
;
void
regclass_pyngraph_ResultVector
(
py
::
module
m
);
python/setup.py
View file @
3dc2a915
...
...
@@ -149,7 +149,6 @@ sources = [
'pyngraph/parameter_vector.cpp'
,
'pyngraph/pyngraph.cpp'
,
'pyngraph/util.cpp'
,
'pyngraph/result_vector.cpp'
,
'pyngraph/ops/util/arithmetic_reduction.cpp'
,
'pyngraph/ops/util/binary_elementwise_comparison.cpp'
,
'pyngraph/ops/util/op_annotations.cpp'
,
...
...
@@ -224,7 +223,6 @@ sources = [
'pyngraph/ops/min.cpp'
,
'pyngraph/ops/batch_norm.cpp'
,
'pyngraph/ops/softmax.cpp'
,
'pyngraph/ops/result.cpp'
,
'pyngraph/runtime/backend.cpp'
,
'pyngraph/runtime/regmodule_pyngraph_runtime.cpp'
,
'pyngraph/runtime/tensor.cpp'
,
...
...
python/test/ngraph/test_onnx_import.py
View file @
3dc2a915
...
...
@@ -17,13 +17,13 @@
import
os
import
numpy
as
np
from
ngraph.impl.onnx_import
import
import
_onnx_model_file
from
ngraph.impl.onnx_import
import
load
_onnx_model_file
from
test.ngraph.util
import
get_runtime
def
test_import_onnx_function
():
model_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'models/add_abc.onnx'
)
ng_function
=
import_onnx_model_file
(
model_path
)
ng_function
=
load_onnx_model_file
(
model_path
)[
0
]
dtype
=
np
.
float32
value_a
=
np
.
array
([
1.0
],
dtype
=
dtype
)
...
...
python/test/ngraph/test_ops_unary.py
View file @
3dc2a915
...
...
@@ -48,10 +48,10 @@ def test_unary_op_array(ng_api_fn, numpy_fn, range_start, range_end):
input_data
=
range_start
+
np
.
random
.
rand
(
2
,
3
,
4
)
*
(
range_end
-
range_start
)
expected
=
numpy_fn
(
input_data
)
result
=
run_op_node
([
input_data
],
ng_api_fn
)
[
0
]
result
=
run_op_node
([
input_data
],
ng_api_fn
)
np
.
testing
.
assert_allclose
(
result
,
expected
,
rtol
=
0.001
)
result
=
run_op_numeric_data
(
input_data
,
ng_api_fn
)
[
0
]
result
=
run_op_numeric_data
(
input_data
,
ng_api_fn
)
np
.
testing
.
assert_allclose
(
result
,
expected
,
rtol
=
0.001
)
...
...
src/ngraph/frontend/onnx_import/CMakeLists.txt
View file @
3dc2a915
...
...
@@ -102,6 +102,8 @@ add_library(onnx_import STATIC
op/neg.hpp
op/not.hpp
op/or.hpp
op/pad.cpp
op/pad.hpp
op/pow.hpp
op/prelu.cpp
op/prelu.hpp
...
...
src/ngraph/frontend/onnx_import/core/graph.cpp
View file @
3dc2a915
...
...
@@ -95,16 +95,6 @@ namespace ngraph
}
}
NodeVector
Graph
::
get_ng_outputs
()
const
{
NodeVector
results
;
for
(
const
auto
&
output
:
m_graph_proto
->
output
())
{
results
.
emplace_back
(
get_ng_node_from_cache
(
output
.
name
()));
}
return
results
;
}
}
// namespace onnx_import
}
// namespace ngraph
src/ngraph/frontend/onnx_import/core/graph.hpp
View file @
3dc2a915
...
...
@@ -38,7 +38,6 @@ namespace ngraph
const
std
::
vector
<
Node
>&
get_nodes
()
const
{
return
m_nodes
;
}
const
std
::
vector
<
ValueInfo
>&
get_inputs
()
const
{
return
m_inputs
;
}
const
std
::
vector
<
ValueInfo
>&
get_outputs
()
const
{
return
m_outputs
;
}
NodeVector
get_ng_outputs
()
const
;
const
ParameterVector
&
get_ng_parameters
()
const
{
return
m_parameters
;
}
std
::
shared_ptr
<
ngraph
::
Node
>
get_ng_node_from_cache
(
const
std
::
string
&
name
)
const
{
...
...
src/ngraph/frontend/onnx_import/onnx.cpp
View file @
3dc2a915
...
...
@@ -15,7 +15,6 @@
//*****************************************************************************
#include <fstream>
#include <memory>
#include "core/graph.hpp"
#include "core/model.hpp"
...
...
@@ -51,32 +50,45 @@ namespace ngraph
}
// namespace error
}
// namespace detail
std
::
shared_ptr
<
Function
>
import_onnx_model
(
std
::
istream
&
sin
,
const
Weights
&
weights
)
std
::
vector
<
std
::
shared_ptr
<
Function
>>
load_onnx_model
(
std
::
istream
&
sin
,
const
Weights
&
weights
)
{
onnx
::
ModelProto
model_proto
;
if
(
!
model_proto
.
ParseFromIstream
(
&
sin
))
{
throw
detail
::
error
::
stream_parse
{
sin
};
}
std
::
vector
<
std
::
shared_ptr
<
Function
>>
output_functions
;
Model
model
{
model_proto
};
Graph
graph
{
model_proto
.
graph
(),
model
,
weights
};
auto
function
=
std
::
make_shared
<
Function
>
(
graph
.
get_ng_outputs
(),
graph
.
get_ng_parameters
(),
graph
.
get_name
());
for
(
std
::
size_t
i
{
0
};
i
<
function
->
get_output_size
();
++
i
)
for
(
const
auto
&
output
:
graph
.
get_outputs
())
{
function
->
get_output_op
(
i
)
->
set_name
(
graph
.
get_outputs
().
at
(
i
).
get_name
());
output_functions
.
emplace_back
(
std
::
make_shared
<
Function
>
(
graph
.
get_ng_node_from_cache
(
output
.
get_name
()),
graph
.
get_ng_parameters
()));
}
return
function
;
return
output_functions
;
}
std
::
shared_ptr
<
Function
>
import_onnx_model
(
const
std
::
string
&
path
,
const
Weights
&
weights
)
std
::
vector
<
std
::
shared_ptr
<
Function
>>
load_onnx_model
(
const
std
::
string
&
path
,
const
Weights
&
weights
)
{
std
::
ifstream
ifs
{
path
,
std
::
ios
::
in
|
std
::
ios
::
binary
};
if
(
!
ifs
.
is_open
())
{
throw
detail
::
error
::
file_open
{
path
};
}
return
import_onnx_model
(
ifs
,
weights
);
return
load_onnx_model
(
ifs
,
weights
);
}
std
::
shared_ptr
<
Function
>
import_onnx_function
(
std
::
istream
&
sin
,
const
Weights
&
weights
)
{
return
load_onnx_model
(
sin
,
weights
).
front
();
}
std
::
shared_ptr
<
Function
>
import_onnx_function
(
const
std
::
string
&
path
,
const
Weights
&
weights
)
{
return
load_onnx_model
(
path
,
weights
).
front
();
}
void
register_operator
(
const
std
::
string
&
name
,
...
...
src/ngraph/frontend/onnx_import/onnx.hpp
View file @
3dc2a915
...
...
@@ -40,6 +40,31 @@ namespace ngraph
const
std
::
string
&
domain
,
Operator
fn
);
/// \brief Convert an ONNX model to nGraph functions
/// The function translated serialized ONNX model to nGraph functions. The serialized
/// ONNX model is read from input stream.
/// \param sin input stream (e.g. file stream, memory stream, etc),
/// \param weights weights associated with the model. If weights are embedded into
/// the model this parameter shall be empty. Having weights in a model
/// and providing through this parameters is invalid (the weights from
/// the model will take precedence).
/// \return The function returns a vector of nGraph functions. The number of functions
/// depends on number of outputs from ONNX graph.
std
::
vector
<
std
::
shared_ptr
<
Function
>>
load_onnx_model
(
std
::
istream
&
sin
,
const
Weights
&
weights
=
{});
/// \brief Convert an ONNX model to nGraph functions
/// The function translated serialized ONNX model to nGraph functions. The ONNX model
/// is read from ONNX file.
/// \param filename file name (relative or absolute path name),
/// \param weights weights associated with the model. If weights are embedded into
/// the model this parameter shall be empty. Having weights in a model
/// and providing through this parameters is invalid (the weights from
/// the model will take precedence).
/// \return The function returns a vector of nGraph functions. The number of functions
/// depends on number of outputs from ONNX graph.
std
::
vector
<
std
::
shared_ptr
<
Function
>>
load_onnx_model
(
const
std
::
string
&
filename
,
const
Weights
&
weights
=
{});
/// \brief Convert an ONNX model to nGraph function
/// The function translated serialized ONNX model to nGraph function. The serialized
/// ONNX model is read from input stream.
...
...
@@ -49,7 +74,8 @@ namespace ngraph
/// and providing through this parameters is invalid (the weights from
/// the model will take precedence).
/// \return The function returns a nGraph function representing single output from graph.
std
::
shared_ptr
<
Function
>
import_onnx_model
(
std
::
istream
&
sin
,
const
Weights
&
weights
=
{});
std
::
shared_ptr
<
Function
>
import_onnx_function
(
std
::
istream
&
sin
,
const
Weights
&
weights
=
{});
/// \brief Convert an ONNX model to nGraph functions
/// The function translated serialized ONNX model to nGraph functions. The ONNX model
...
...
@@ -60,8 +86,8 @@ namespace ngraph
/// and providing through this parameters is invalid (the weights from
/// the model will take precedence).
/// \return The function returns a nGraph function representing single output from graph.
std
::
shared_ptr
<
Function
>
import_onnx_
model
(
const
std
::
string
&
filename
,
const
Weights
&
weights
=
{});
std
::
shared_ptr
<
Function
>
import_onnx_
function
(
const
std
::
string
&
filename
,
const
Weights
&
weights
=
{});
}
// namespace onnx_import
...
...
src/ngraph/frontend/onnx_import/op/batch_norm.cpp
View file @
3dc2a915
...
...
@@ -40,13 +40,11 @@ namespace ngraph
std
::
shared_ptr
<
ngraph
::
Node
>
var
{
nullptr
};
std
::
int64_t
is_test
{
node
.
get_attribute_value
<
std
::
int64_t
>
(
"is_test"
,
1
)};
std
::
int64_t
spatial
{
node
.
get_attribute_value
<
std
::
int64_t
>
(
"spatial"
,
1
)};
double
epsilon
{
node
.
get_attribute_value
<
double
>
(
"epsilon"
,
1e-5
)};
// TODO: Implement learning mode support
// float momentum{node.get_attribute_value<float>("momentum", 0.9f)};
ASSERT_IS_SUPPORTED
(
node
,
is_test
)
<<
"only 'is_test' mode is supported."
;
ASSERT_IS_SUPPORTED
(
node
,
spatial
)
<<
"only 'spatial' mode is supported."
;
if
(
inputs
.
size
()
>=
5
)
{
...
...
python/pyngraph/ops/result
.cpp
→
src/ngraph/frontend/onnx_import/op/pad
.cpp
View file @
3dc2a915
...
...
@@ -14,19 +14,47 @@
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <string>
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/op/result.hpp"
#include "pyngraph/ops/result.hpp"
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/frontend/onnx_import/op/pad.hpp"
#include "ngraph/frontend/onnx_import/utils/convpool.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/shape.hpp"
namespace
py
=
pybind11
;
void
regclass_pyngraph_op_Result
(
py
::
module
m
)
namespace
ngraph
{
py
::
class_
<
ngraph
::
op
::
Result
,
std
::
shared_ptr
<
ngraph
::
op
::
Result
>
,
ngraph
::
Node
>
result
(
m
,
"Result"
);
result
.
doc
()
=
"ngraph.impl.op.Result wraps ngraph::op::Result"
;
}
namespace
onnx_import
{
namespace
op
{
namespace
set_1
{
NodeVector
pad
(
const
Node
&
node
)
{
auto
data
=
node
.
get_ng_inputs
().
at
(
0
);
const
Shape
&
data_shape
=
data
->
get_shape
();
double
value
=
node
.
get_attribute_value
<
double
>
(
"value"
,
0
);
auto
paddings
=
convpool
::
get_pads
(
node
,
data_shape
);
ngraph
::
CoordinateDiff
padding_below
=
paddings
.
first
;
ngraph
::
CoordinateDiff
padding_above
=
paddings
.
second
;
return
{
std
::
make_shared
<
ngraph
::
op
::
Pad
>
(
data
,
std
::
make_shared
<
ngraph
::
op
::
Constant
>
(
data
->
get_element_type
(),
ngraph
::
Shape
{},
std
::
vector
<
double
>
{
value
}),
Shape
(
padding_below
.
begin
(),
padding_below
.
end
()),
Shape
(
padding_above
.
begin
(),
padding_above
.
end
()),
Shape
(
data_shape
.
size
(),
0
))};
}
}
// namespace set_1
}
//namespace op
}
// namespace onnx_import
}
// namespace ngraph
python/pyngraph/ops/result
.hpp
→
src/ngraph/frontend/onnx_import/op/pad
.hpp
View file @
3dc2a915
...
...
@@ -16,8 +16,23 @@
#pragma once
#include <pybind11/pybind11.h>
#include "ngraph/frontend/onnx_import/core/node.hpp"
#include "ngraph/node_vector.hpp"
namespace
py
=
pybind11
;
namespace
ngraph
{
namespace
onnx_import
{
namespace
op
{
namespace
set_1
{
NodeVector
pad
(
const
Node
&
node
);
void
regclass_pyngraph_op_Result
(
py
::
module
m
);
}
// namespace set_1
}
//namespace op
}
// namespace onnx_import
}
// namespace ngraph
src/ngraph/frontend/onnx_import/ops_bridge.cpp
View file @
3dc2a915
...
...
@@ -66,6 +66,8 @@
#include "op/neg.hpp"
#include "op/not.hpp"
#include "op/or.hpp"
#include "op/pad.cpp"
#include "op/pad.hpp"
#include "op/pow.hpp"
#include "op/prelu.hpp"
#include "op/reciprocal.hpp"
...
...
@@ -195,6 +197,7 @@ namespace ngraph
REGISTER_OPERATOR
(
"Neg"
,
1
,
neg
);
REGISTER_OPERATOR
(
"Not"
,
1
,
logical_not
);
REGISTER_OPERATOR
(
"Or"
,
1
,
logical_or
);
REGISTER_OPERATOR
(
"Pad"
,
1
,
pad
);
REGISTER_OPERATOR
(
"Pow"
,
1
,
pow
);
REGISTER_OPERATOR
(
"PRelu"
,
1
,
prelu
);
REGISTER_OPERATOR
(
"Reciprocal"
,
1
,
reciprocal
);
...
...
src/tools/serialize_onnx/serialize_onnx.cpp
View file @
3dc2a915
...
...
@@ -66,7 +66,7 @@ int main(int argc, char** argv)
ifstream
f
(
input
);
if
(
f
)
{
s
td
::
shared_ptr
<
ngraph
::
Function
>
function
=
ngraph
::
onnx_import
::
import_onnx_model
(
input
);
s
hared_ptr
<
ngraph
::
Function
>
function
=
ngraph
::
onnx_import
::
import_onnx_function
(
input
);
ngraph
::
stopwatch
timer
;
timer
.
start
();
...
...
test/onnx_import.cpp
View file @
3dc2a915
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment