Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
394b74fa
Commit
394b74fa
authored
Aug 30, 2018
by
Artur Wojcik
Committed by
Michał Karzyński
Aug 30, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ONNX] re-enable unit tests for CentOS (#1517)
parent
c4970542
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
82 additions
and
92 deletions
+82
-92
attribute.cpp
src/ngraph/frontend/onnx_import/core/attribute.cpp
+2
-2
attribute.hpp
src/ngraph/frontend/onnx_import/core/attribute.hpp
+19
-19
graph.cpp
src/ngraph/frontend/onnx_import/core/graph.cpp
+5
-5
graph.hpp
src/ngraph/frontend/onnx_import/core/graph.hpp
+2
-2
model.hpp
src/ngraph/frontend/onnx_import/core/model.hpp
+6
-6
node.cpp
src/ngraph/frontend/onnx_import/core/node.cpp
+1
-1
node.hpp
src/ngraph/frontend/onnx_import/core/node.hpp
+5
-5
tensor.hpp
src/ngraph/frontend/onnx_import/core/tensor.hpp
+11
-11
value_info.hpp
src/ngraph/frontend/onnx_import/core/value_info.hpp
+8
-8
CMakeLists.txt
test/CMakeLists.txt
+0
-8
onnx_import.cpp
test/onnx_import.cpp
+23
-25
No files found.
src/ngraph/frontend/onnx_import/core/attribute.cpp
View file @
394b74fa
...
@@ -23,10 +23,10 @@ namespace ngraph
...
@@ -23,10 +23,10 @@ namespace ngraph
{
{
std
::
vector
<
Graph
>
Attribute
::
get_graph_array
()
const
std
::
vector
<
Graph
>
Attribute
::
get_graph_array
()
const
{
{
return
{
std
::
begin
(
m_attribute_proto
.
graphs
()),
std
::
end
(
m_attribute_proto
.
graphs
())};
return
{
std
::
begin
(
m_attribute_proto
->
graphs
()),
std
::
end
(
m_attribute_proto
->
graphs
())};
}
}
Graph
Attribute
::
get_graph
()
const
{
return
Graph
{
m_attribute_proto
.
g
()};
}
Graph
Attribute
::
get_graph
()
const
{
return
Graph
{
m_attribute_proto
->
g
()};
}
}
// namespace onnx_import
}
// namespace onnx_import
}
// namespace ngraph
}
// namespace ngraph
src/ngraph/frontend/onnx_import/core/attribute.hpp
View file @
394b74fa
...
@@ -38,8 +38,8 @@ namespace ngraph
...
@@ -38,8 +38,8 @@ namespace ngraph
{
{
struct
Attribute
:
ngraph_error
struct
Attribute
:
ngraph_error
{
{
Attribute
(
std
::
string
msg
,
onnx
::
AttributeProto_AttributeType
type
)
Attribute
(
const
std
::
string
&
msg
,
onnx
::
AttributeProto_AttributeType
type
)
:
ngraph_error
{
std
::
move
(
msg
)
+
": "
+
:
ngraph_error
{
msg
+
": "
+
onnx
::
AttributeProto_AttributeType_Name
(
type
)}
onnx
::
AttributeProto_AttributeType_Name
(
type
)}
{
{
}
}
...
@@ -246,7 +246,7 @@ namespace ngraph
...
@@ -246,7 +246,7 @@ namespace ngraph
Attribute
()
=
delete
;
Attribute
()
=
delete
;
explicit
Attribute
(
const
onnx
::
AttributeProto
&
attribute_proto
)
explicit
Attribute
(
const
onnx
::
AttributeProto
&
attribute_proto
)
:
m_attribute_proto
{
attribute_proto
}
:
m_attribute_proto
{
&
attribute_proto
}
{
{
}
}
...
@@ -256,8 +256,8 @@ namespace ngraph
...
@@ -256,8 +256,8 @@ namespace ngraph
Attribute
&
operator
=
(
Attribute
&&
)
noexcept
=
delete
;
Attribute
&
operator
=
(
Attribute
&&
)
noexcept
=
delete
;
Attribute
&
operator
=
(
const
Attribute
&
)
=
delete
;
Attribute
&
operator
=
(
const
Attribute
&
)
=
delete
;
const
std
::
string
&
get_name
()
const
{
return
m_attribute_proto
.
name
();
}
const
std
::
string
&
get_name
()
const
{
return
m_attribute_proto
->
name
();
}
Type
get_type
()
const
{
return
static_cast
<
Type
>
(
m_attribute_proto
.
type
());
}
Type
get_type
()
const
{
return
static_cast
<
Type
>
(
m_attribute_proto
->
type
());
}
bool
is_tensor
()
const
{
return
get_type
()
==
Type
::
tensor
;
}
bool
is_tensor
()
const
{
return
get_type
()
==
Type
::
tensor
;
}
bool
is_tensor_array
()
const
{
return
get_type
()
==
Type
::
tensor_array
;
}
bool
is_tensor_array
()
const
{
return
get_type
()
==
Type
::
tensor_array
;
}
bool
is_float
()
const
{
return
get_type
()
==
Type
::
float_point
;
}
bool
is_float
()
const
{
return
get_type
()
==
Type
::
float_point
;
}
...
@@ -268,50 +268,50 @@ namespace ngraph
...
@@ -268,50 +268,50 @@ namespace ngraph
bool
is_string_array
()
const
{
return
get_type
()
==
Type
::
string_array
;
}
bool
is_string_array
()
const
{
return
get_type
()
==
Type
::
string_array
;
}
bool
is_graph
()
const
{
return
get_type
()
==
Type
::
graph
;
}
bool
is_graph
()
const
{
return
get_type
()
==
Type
::
graph
;
}
bool
is_graph_array
()
const
{
return
get_type
()
==
Type
::
graph_array
;
}
bool
is_graph_array
()
const
{
return
get_type
()
==
Type
::
graph_array
;
}
Tensor
get_tensor
()
const
{
return
Tensor
{
m_attribute_proto
.
t
()};
}
Tensor
get_tensor
()
const
{
return
Tensor
{
m_attribute_proto
->
t
()};
}
float
get_float
()
const
{
return
m_attribute_proto
.
f
();
}
float
get_float
()
const
{
return
m_attribute_proto
->
f
();
}
int64_t
get_integer
()
const
{
return
m_attribute_proto
.
i
();
}
int64_t
get_integer
()
const
{
return
m_attribute_proto
->
i
();
}
const
std
::
string
&
get_string
()
const
{
return
m_attribute_proto
.
s
();
}
const
std
::
string
&
get_string
()
const
{
return
m_attribute_proto
->
s
();
}
Graph
get_graph
()
const
;
Graph
get_graph
()
const
;
std
::
vector
<
Tensor
>
get_tensor_array
()
const
std
::
vector
<
Tensor
>
get_tensor_array
()
const
{
{
return
{
std
::
begin
(
m_attribute_proto
.
tensors
()),
return
{
std
::
begin
(
m_attribute_proto
->
tensors
()),
std
::
end
(
m_attribute_proto
.
tensors
())};
std
::
end
(
m_attribute_proto
->
tensors
())};
}
}
std
::
vector
<
float
>
get_float_array
()
const
std
::
vector
<
float
>
get_float_array
()
const
{
{
return
{
std
::
begin
(
m_attribute_proto
.
floats
()),
return
{
std
::
begin
(
m_attribute_proto
->
floats
()),
std
::
end
(
m_attribute_proto
.
floats
())};
std
::
end
(
m_attribute_proto
->
floats
())};
}
}
std
::
vector
<
int64_t
>
get_integer_array
()
const
std
::
vector
<
int64_t
>
get_integer_array
()
const
{
{
return
{
std
::
begin
(
m_attribute_proto
.
ints
()),
std
::
end
(
m_attribute_proto
.
ints
())};
return
{
std
::
begin
(
m_attribute_proto
->
ints
()),
std
::
end
(
m_attribute_proto
->
ints
())};
}
}
std
::
vector
<
std
::
string
>
get_string_array
()
const
std
::
vector
<
std
::
string
>
get_string_array
()
const
{
{
return
{
std
::
begin
(
m_attribute_proto
.
strings
()),
return
{
std
::
begin
(
m_attribute_proto
->
strings
()),
std
::
end
(
m_attribute_proto
.
strings
())};
std
::
end
(
m_attribute_proto
->
strings
())};
}
}
std
::
vector
<
Graph
>
get_graph_array
()
const
;
std
::
vector
<
Graph
>
get_graph_array
()
const
;
/* explicit */
operator
onnx
::
AttributeProto_AttributeType
()
const
/* explicit */
operator
onnx
::
AttributeProto_AttributeType
()
const
{
{
return
m_attribute_proto
.
type
();
return
m_attribute_proto
->
type
();
}
}
template
<
typename
T
>
template
<
typename
T
>
T
get_value
()
const
T
get_value
()
const
{
{
return
detail
::
attribute
::
get_value
<
T
>
(
m_attribute_proto
);
return
detail
::
attribute
::
get_value
<
T
>
(
*
m_attribute_proto
);
}
}
private
:
private
:
const
onnx
::
AttributeProto
&
m_attribute_proto
;
const
onnx
::
AttributeProto
*
m_attribute_proto
;
};
};
}
// namespace onnx_import
}
// namespace onnx_import
...
...
src/ngraph/frontend/onnx_import/core/graph.cpp
View file @
394b74fa
...
@@ -22,9 +22,9 @@ namespace ngraph
...
@@ -22,9 +22,9 @@ namespace ngraph
namespace
onnx_import
namespace
onnx_import
{
{
Graph
::
Graph
(
const
onnx
::
GraphProto
&
graph_proto
)
Graph
::
Graph
(
const
onnx
::
GraphProto
&
graph_proto
)
:
m_graph_proto
(
graph_proto
)
:
m_graph_proto
{
&
graph_proto
}
{
{
for
(
const
auto
&
tensor
:
m_graph_proto
.
initializer
())
for
(
const
auto
&
tensor
:
m_graph_proto
->
initializer
())
{
{
if
(
tensor
.
has_name
())
if
(
tensor
.
has_name
())
{
{
...
@@ -33,20 +33,20 @@ namespace ngraph
...
@@ -33,20 +33,20 @@ namespace ngraph
}
}
// Process all ONNX graph inputs, convert them to nGraph nodes and store in cache
// Process all ONNX graph inputs, convert them to nGraph nodes and store in cache
for
(
const
auto
&
input
:
m_graph_proto
.
input
())
for
(
const
auto
&
input
:
m_graph_proto
->
input
())
{
{
m_inputs
.
emplace_back
(
input
);
m_inputs
.
emplace_back
(
input
);
m_ng_node_cache
[
input
.
name
()]
=
m_ng_node_cache
[
input
.
name
()]
=
m_inputs
.
back
().
get_ng_node
(
m_parameters
,
m_initializers
);
m_inputs
.
back
().
get_ng_node
(
m_parameters
,
m_initializers
);
}
}
for
(
const
auto
&
output
:
m_graph_proto
.
output
())
for
(
const
auto
&
output
:
m_graph_proto
->
output
())
{
{
m_outputs
.
emplace_back
(
output
);
m_outputs
.
emplace_back
(
output
);
}
}
// Process ONNX graph nodes, convert to nGraph nodes
// Process ONNX graph nodes, convert to nGraph nodes
for
(
const
auto
&
node_proto
:
m_graph_proto
.
node
())
for
(
const
auto
&
node_proto
:
m_graph_proto
->
node
())
{
{
m_nodes
.
emplace_back
(
node_proto
,
this
);
m_nodes
.
emplace_back
(
node_proto
,
this
);
const
Node
&
node
{
m_nodes
.
back
()};
const
Node
&
node
{
m_nodes
.
back
()};
...
...
src/ngraph/frontend/onnx_import/core/graph.hpp
View file @
394b74fa
...
@@ -42,9 +42,9 @@ namespace ngraph
...
@@ -42,9 +42,9 @@ namespace ngraph
return
m_ng_node_cache
.
at
(
name
);
return
m_ng_node_cache
.
at
(
name
);
}
}
const
std
::
string
&
get_name
()
const
{
return
m_graph_proto
.
name
();
}
const
std
::
string
&
get_name
()
const
{
return
m_graph_proto
->
name
();
}
private
:
private
:
const
onnx
::
GraphProto
&
m_graph_proto
;
const
onnx
::
GraphProto
*
m_graph_proto
;
std
::
vector
<
Node
>
m_nodes
;
std
::
vector
<
Node
>
m_nodes
;
std
::
vector
<
ValueInfo
>
m_inputs
;
std
::
vector
<
ValueInfo
>
m_inputs
;
std
::
vector
<
ValueInfo
>
m_outputs
;
std
::
vector
<
ValueInfo
>
m_outputs
;
...
...
src/ngraph/frontend/onnx_import/core/model.hpp
View file @
394b74fa
...
@@ -28,7 +28,7 @@ namespace ngraph
...
@@ -28,7 +28,7 @@ namespace ngraph
public
:
public
:
Model
()
=
delete
;
Model
()
=
delete
;
explicit
Model
(
const
onnx
::
ModelProto
&
model_proto
)
explicit
Model
(
const
onnx
::
ModelProto
&
model_proto
)
:
m_model_proto
{
model_proto
}
:
m_model_proto
{
&
model_proto
}
{
{
}
}
...
@@ -38,16 +38,16 @@ namespace ngraph
...
@@ -38,16 +38,16 @@ namespace ngraph
Model
&
operator
=
(
Model
&&
)
noexcept
=
delete
;
Model
&
operator
=
(
Model
&&
)
noexcept
=
delete
;
Model
&
operator
=
(
const
Model
&
)
=
delete
;
Model
&
operator
=
(
const
Model
&
)
=
delete
;
const
std
::
string
&
get_producer_name
()
const
{
return
m_model_proto
.
producer_name
();
}
const
std
::
string
&
get_producer_name
()
const
{
return
m_model_proto
->
producer_name
();
}
const
onnx
::
GraphProto
&
get_graph
()
const
{
return
m_model_proto
.
graph
();
}
const
onnx
::
GraphProto
&
get_graph
()
const
{
return
m_model_proto
->
graph
();
}
std
::
int64_t
get_model_version
()
const
{
return
m_model_proto
.
model_version
();
}
std
::
int64_t
get_model_version
()
const
{
return
m_model_proto
->
model_version
();
}
const
std
::
string
&
get_producer_version
()
const
const
std
::
string
&
get_producer_version
()
const
{
{
return
m_model_proto
.
producer_version
();
return
m_model_proto
->
producer_version
();
}
}
private
:
private
:
const
onnx
::
ModelProto
&
m_model_proto
;
const
onnx
::
ModelProto
*
m_model_proto
;
};
};
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
outs
,
const
Model
&
model
)
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
outs
,
const
Model
&
model
)
...
...
src/ngraph/frontend/onnx_import/core/node.cpp
View file @
394b74fa
...
@@ -26,7 +26,7 @@ namespace ngraph
...
@@ -26,7 +26,7 @@ namespace ngraph
NodeVector
Node
::
get_ng_inputs
()
const
NodeVector
Node
::
get_ng_inputs
()
const
{
{
NodeVector
result
;
NodeVector
result
;
for
(
const
auto
&
name
:
m_node_proto
.
input
())
for
(
const
auto
&
name
:
m_node_proto
->
input
())
{
{
result
.
push_back
(
m_graph
->
get_ng_node_from_cache
(
name
));
result
.
push_back
(
m_graph
->
get_ng_node_from_cache
(
name
));
}
}
...
...
src/ngraph/frontend/onnx_import/core/node.hpp
View file @
394b74fa
...
@@ -53,7 +53,7 @@ namespace ngraph
...
@@ -53,7 +53,7 @@ namespace ngraph
public
:
public
:
Node
()
=
delete
;
Node
()
=
delete
;
Node
(
const
onnx
::
NodeProto
&
node_proto
,
const
Graph
*
graph
)
Node
(
const
onnx
::
NodeProto
&
node_proto
,
const
Graph
*
graph
)
:
m_node_proto
{
node_proto
}
:
m_node_proto
{
&
node_proto
}
,
m_graph
{
graph
}
,
m_graph
{
graph
}
,
m_attributes
{
std
::
begin
(
node_proto
.
attribute
()),
std
::
end
(
node_proto
.
attribute
())}
,
m_attributes
{
std
::
begin
(
node_proto
.
attribute
()),
std
::
end
(
node_proto
.
attribute
())}
,
m_output_names
{
std
::
begin
(
node_proto
.
output
()),
std
::
end
(
node_proto
.
output
())}
,
m_output_names
{
std
::
begin
(
node_proto
.
output
()),
std
::
end
(
node_proto
.
output
())}
...
@@ -70,13 +70,13 @@ namespace ngraph
...
@@ -70,13 +70,13 @@ namespace ngraph
NodeVector
get_ng_nodes
()
const
;
NodeVector
get_ng_nodes
()
const
;
NodeVector
get_ng_inputs
()
const
;
NodeVector
get_ng_inputs
()
const
;
const
std
::
string
&
op_type
()
const
{
return
m_node_proto
.
op_type
();
}
const
std
::
string
&
op_type
()
const
{
return
m_node_proto
->
op_type
();
}
const
std
::
string
&
get_name
()
const
{
return
m_node_proto
.
name
();
}
const
std
::
string
&
get_name
()
const
{
return
m_node_proto
->
name
();
}
const
std
::
vector
<
std
::
reference_wrapper
<
const
std
::
string
>>&
get_output_names
()
const
const
std
::
vector
<
std
::
reference_wrapper
<
const
std
::
string
>>&
get_output_names
()
const
{
{
return
m_output_names
;
return
m_output_names
;
}
}
const
std
::
string
&
output
(
int
index
)
const
{
return
m_node_proto
.
output
(
index
);
}
const
std
::
string
&
output
(
int
index
)
const
{
return
m_node_proto
->
output
(
index
);
}
template
<
typename
T
>
template
<
typename
T
>
T
get_attribute_value
(
const
std
::
string
&
name
,
T
default_value
)
const
T
get_attribute_value
(
const
std
::
string
&
name
,
T
default_value
)
const
{
{
...
@@ -106,7 +106,7 @@ namespace ngraph
...
@@ -106,7 +106,7 @@ namespace ngraph
}
}
private
:
private
:
const
onnx
::
NodeProto
&
m_node_proto
;
const
onnx
::
NodeProto
*
m_node_proto
;
const
Graph
*
m_graph
;
const
Graph
*
m_graph
;
std
::
vector
<
Attribute
>
m_attributes
;
std
::
vector
<
Attribute
>
m_attributes
;
std
::
vector
<
std
::
reference_wrapper
<
const
std
::
string
>>
m_output_names
;
std
::
vector
<
std
::
reference_wrapper
<
const
std
::
string
>>
m_output_names
;
...
...
src/ngraph/frontend/onnx_import/core/tensor.hpp
View file @
394b74fa
...
@@ -198,7 +198,7 @@ namespace ngraph
...
@@ -198,7 +198,7 @@ namespace ngraph
Tensor
()
=
delete
;
Tensor
()
=
delete
;
explicit
Tensor
(
const
onnx
::
TensorProto
&
tensor
)
explicit
Tensor
(
const
onnx
::
TensorProto
&
tensor
)
:
m_tensor_proto
{
tensor
}
:
m_tensor_proto
{
&
tensor
}
,
m_shape
{
std
::
begin
(
tensor
.
dims
()),
std
::
end
(
tensor
.
dims
())}
,
m_shape
{
std
::
begin
(
tensor
.
dims
()),
std
::
end
(
tensor
.
dims
())}
{
{
}
}
...
@@ -213,34 +213,34 @@ namespace ngraph
...
@@ -213,34 +213,34 @@ namespace ngraph
template
<
typename
T
>
template
<
typename
T
>
std
::
vector
<
T
>
get_data
()
const
std
::
vector
<
T
>
get_data
()
const
{
{
return
detail
::
tensor
::
get_data
<
T
>
(
m_tensor_proto
);
return
detail
::
tensor
::
get_data
<
T
>
(
*
m_tensor_proto
);
}
}
const
std
::
string
&
get_name
()
const
const
std
::
string
&
get_name
()
const
{
{
if
(
!
m_tensor_proto
.
has_name
())
if
(
!
m_tensor_proto
->
has_name
())
{
{
throw
error
::
tensor
::
unspecified_name
{};
throw
error
::
tensor
::
unspecified_name
{};
}
}
return
m_tensor_proto
.
name
();
return
m_tensor_proto
->
name
();
}
}
Type
get_type
()
const
Type
get_type
()
const
{
{
if
(
!
m_tensor_proto
.
has_data_type
())
if
(
!
m_tensor_proto
->
has_data_type
())
{
{
throw
error
::
tensor
::
unspecified_data_type
{};
throw
error
::
tensor
::
unspecified_data_type
{};
}
}
return
static_cast
<
Type
>
(
m_tensor_proto
.
data_type
());
return
static_cast
<
Type
>
(
m_tensor_proto
->
data_type
());
}
}
const
element
::
Type
&
get_ng_type
()
const
const
element
::
Type
&
get_ng_type
()
const
{
{
if
(
!
m_tensor_proto
.
has_data_type
())
if
(
!
m_tensor_proto
->
has_data_type
())
{
{
throw
error
::
tensor
::
unspecified_data_type
{};
throw
error
::
tensor
::
unspecified_data_type
{};
}
}
switch
(
m_tensor_proto
.
data_type
())
switch
(
m_tensor_proto
->
data_type
())
{
{
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_BOOL
:
return
element
::
boolean
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_BOOL
:
return
element
::
boolean
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_FLOAT
:
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_FLOAT
:
...
@@ -254,13 +254,13 @@ namespace ngraph
...
@@ -254,13 +254,13 @@ namespace ngraph
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT16
:
return
element
::
u16
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT16
:
return
element
::
u16
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT32
:
return
element
::
u32
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT32
:
return
element
::
u32
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT64
:
return
element
::
u64
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT64
:
return
element
::
u64
;
default
:
throw
error
::
tensor
::
unsupported_data_type
{
m_tensor_proto
.
data_type
()};
default
:
throw
error
::
tensor
::
unsupported_data_type
{
m_tensor_proto
->
data_type
()};
}
}
}
}
operator
onnx
::
TensorProto_DataType
()
const
{
return
m_tensor_proto
.
data_type
();
}
operator
onnx
::
TensorProto_DataType
()
const
{
return
m_tensor_proto
->
data_type
();
}
private
:
private
:
const
onnx
::
TensorProto
&
m_tensor_proto
;
const
onnx
::
TensorProto
*
m_tensor_proto
;
Shape
m_shape
;
Shape
m_shape
;
};
};
...
...
src/ngraph/frontend/onnx_import/core/value_info.hpp
View file @
394b74fa
...
@@ -60,7 +60,7 @@ namespace ngraph
...
@@ -60,7 +60,7 @@ namespace ngraph
ValueInfo
()
=
delete
;
ValueInfo
()
=
delete
;
explicit
ValueInfo
(
const
onnx
::
ValueInfoProto
&
value_info_proto
)
explicit
ValueInfo
(
const
onnx
::
ValueInfoProto
&
value_info_proto
)
:
m_value_info_proto
{
value_info_proto
}
:
m_value_info_proto
{
&
value_info_proto
}
{
{
if
(
value_info_proto
.
type
().
has_tensor_type
())
if
(
value_info_proto
.
type
().
has_tensor_type
())
{
{
...
@@ -74,15 +74,15 @@ namespace ngraph
...
@@ -74,15 +74,15 @@ namespace ngraph
ValueInfo
&
operator
=
(
const
ValueInfo
&
)
=
delete
;
ValueInfo
&
operator
=
(
const
ValueInfo
&
)
=
delete
;
ValueInfo
&
operator
=
(
ValueInfo
&&
)
=
delete
;
ValueInfo
&
operator
=
(
ValueInfo
&&
)
=
delete
;
const
std
::
string
&
get_name
()
const
{
return
m_value_info_proto
.
name
();
}
const
std
::
string
&
get_name
()
const
{
return
m_value_info_proto
->
name
();
}
const
Shape
&
get_shape
()
const
{
return
m_shape
;
}
const
Shape
&
get_shape
()
const
{
return
m_shape
;
}
const
element
::
Type
&
get_element_type
()
const
const
element
::
Type
&
get_element_type
()
const
{
{
if
(
!
m_value_info_proto
.
type
().
tensor_type
().
has_elem_type
())
if
(
!
m_value_info_proto
->
type
().
tensor_type
().
has_elem_type
())
{
{
throw
error
::
value_info
::
unspecified_element_type
{};
throw
error
::
value_info
::
unspecified_element_type
{};
}
}
switch
(
m_value_info_proto
.
type
().
tensor_type
().
elem_type
())
switch
(
m_value_info_proto
->
type
().
tensor_type
().
elem_type
())
{
{
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_BOOL
:
return
element
::
boolean
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_BOOL
:
return
element
::
boolean
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_FLOAT
:
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_FLOAT
:
...
@@ -98,7 +98,7 @@ namespace ngraph
...
@@ -98,7 +98,7 @@ namespace ngraph
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT64
:
return
element
::
u64
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT64
:
return
element
::
u64
;
default
:
default
:
throw
error
::
value_info
::
unsupported_element_type
{
throw
error
::
value_info
::
unsupported_element_type
{
m_value_info_proto
.
type
().
tensor_type
().
elem_type
()};
m_value_info_proto
->
type
().
tensor_type
().
elem_type
()};
}
}
}
}
...
@@ -126,7 +126,7 @@ namespace ngraph
...
@@ -126,7 +126,7 @@ namespace ngraph
std
::
shared_ptr
<
op
::
Constant
>
get_ng_constant
(
const
Tensor
&
tensor
)
const
std
::
shared_ptr
<
op
::
Constant
>
get_ng_constant
(
const
Tensor
&
tensor
)
const
{
{
switch
(
m_value_info_proto
.
type
().
tensor_type
().
elem_type
())
switch
(
m_value_info_proto
->
type
().
tensor_type
().
elem_type
())
{
{
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_BOOL
:
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_BOOL
:
return
make_ng_constant
<
bool
>
(
element
::
boolean
,
tensor
);
return
make_ng_constant
<
bool
>
(
element
::
boolean
,
tensor
);
...
@@ -153,7 +153,7 @@ namespace ngraph
...
@@ -153,7 +153,7 @@ namespace ngraph
return
make_ng_constant
<
uint64_t
>
(
element
::
u64
,
tensor
);
return
make_ng_constant
<
uint64_t
>
(
element
::
u64
,
tensor
);
default
:
default
:
throw
error
::
value_info
::
unsupported_element_type
{
throw
error
::
value_info
::
unsupported_element_type
{
m_value_info_proto
.
type
().
tensor_type
().
elem_type
()};
m_value_info_proto
->
type
().
tensor_type
().
elem_type
()};
}
}
}
}
...
@@ -165,7 +165,7 @@ namespace ngraph
...
@@ -165,7 +165,7 @@ namespace ngraph
}
}
private
:
private
:
const
onnx
::
ValueInfoProto
&
m_value_info_proto
;
const
onnx
::
ValueInfoProto
*
m_value_info_proto
;
Shape
m_shape
;
Shape
m_shape
;
};
};
...
...
test/CMakeLists.txt
View file @
394b74fa
...
@@ -47,15 +47,7 @@ set(SRC
...
@@ -47,15 +47,7 @@ set(SRC
)
)
if
(
NGRAPH_ONNX_IMPORT_ENABLE
)
if
(
NGRAPH_ONNX_IMPORT_ENABLE
)
if
(
APPLE OR WIN32
)
list
(
APPEND SRC onnx_import.cpp
)
list
(
APPEND SRC onnx_import.cpp
)
else
()
# ONNX unit tests temporarly disabled if CentOS detected
# (Protobuf issue with interpreting messages)
if
(
NOT
${
DISTRIB_ID
}
STREQUAL
"CentOS Linux"
)
list
(
APPEND SRC onnx_import.cpp
)
endif
()
endif
()
endif
()
endif
()
if
(
NGRAPH_INTERPRETER_ENABLE
)
if
(
NGRAPH_INTERPRETER_ENABLE
)
...
...
test/onnx_import.cpp
View file @
394b74fa
...
@@ -56,12 +56,12 @@ TEST(onnx, model_add_abc_initializers)
...
@@ -56,12 +56,12 @@ TEST(onnx, model_add_abc_initializers)
TEST
(
onnx
,
model_addmul_abc
)
TEST
(
onnx
,
model_addmul_abc
)
{
{
auto
function
=
ngraph
::
onnx_import
::
import_onnx_function
(
auto
function
=
onnx_import
::
import_onnx_function
(
ngraph
::
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/addmul_abc.onnx"
));
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/addmul_abc.onnx"
));
std
::
vector
<
std
::
vector
<
float
>>
inputs
;
std
::
vector
<
std
::
vector
<
float
>>
inputs
;
ngraph
::
Shape
shape
{
1
,
2
,
2
};
Shape
shape
{
1
,
2
,
2
};
inputs
.
emplace_back
(
test
::
NDArray
<
float
,
3
>
({{{
9
,
10
}},
{{
11
,
12
}}}).
get_vector
());
inputs
.
emplace_back
(
test
::
NDArray
<
float
,
3
>
({{{
9
,
10
}},
{{
11
,
12
}}}).
get_vector
());
inputs
.
emplace_back
(
test
::
NDArray
<
float
,
3
>
({{{
5
,
6
}},
{{
7
,
8
}}}).
get_vector
());
inputs
.
emplace_back
(
test
::
NDArray
<
float
,
3
>
({{{
5
,
6
}},
{{
7
,
8
}}}).
get_vector
());
inputs
.
emplace_back
(
test
::
NDArray
<
float
,
3
>
({{{
1
,
2
}},
{{
3
,
4
}}}).
get_vector
());
inputs
.
emplace_back
(
test
::
NDArray
<
float
,
3
>
({{{
1
,
2
}},
{{
3
,
4
}}}).
get_vector
());
...
@@ -124,8 +124,7 @@ TEST(onnx, model_split_variable_parts_2d)
...
@@ -124,8 +124,7 @@ TEST(onnx, model_split_variable_parts_2d)
namespace
namespace
{
{
std
::
vector
<
std
::
vector
<
float
>>
std
::
vector
<
std
::
vector
<
float
>>
conv2d_execute
(
const
std
::
shared_ptr
<
Function
>&
function
)
conv2d_execute
(
const
std
::
shared_ptr
<
ngraph
::
Function
>&
function
)
{
{
std
::
vector
<
std
::
vector
<
float
>>
args
;
std
::
vector
<
std
::
vector
<
float
>>
args
;
...
@@ -151,8 +150,8 @@ namespace
...
@@ -151,8 +150,8 @@ namespace
TEST
(
onnx
,
model_conv2d_strides_padding
)
TEST
(
onnx
,
model_conv2d_strides_padding
)
{
{
// Convolution with strides=2 and padding=1
// Convolution with strides=2 and padding=1
auto
function
=
ngraph
::
onnx_import
::
import_onnx_function
(
auto
function
=
onnx_import
::
import_onnx_function
(
ngraph
::
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/conv_with_strides_padding.onnx"
));
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/conv_with_strides_padding.onnx"
));
// (1, 1, 4, 3)
// (1, 1, 4, 3)
auto
expected_output
=
test
::
NDArray
<
float
,
4
>
({{{{
12.
f
,
27.
f
,
24.
f
},
auto
expected_output
=
test
::
NDArray
<
float
,
4
>
({{{{
12.
f
,
27.
f
,
24.
f
},
...
@@ -168,8 +167,8 @@ TEST(onnx, model_conv2d_strides_padding)
...
@@ -168,8 +167,8 @@ TEST(onnx, model_conv2d_strides_padding)
TEST
(
onnx
,
model_conv2d_strides_no_padding
)
TEST
(
onnx
,
model_conv2d_strides_no_padding
)
{
{
// Convolution with strides=2 and padding=1
// Convolution with strides=2 and padding=1
auto
function
=
ngraph
::
onnx_import
::
import_onnx_function
(
auto
function
=
onnx_import
::
import_onnx_function
(
ngraph
::
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/conv_with_strides_no_padding.onnx"
));
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/conv_with_strides_no_padding.onnx"
));
// (1, 1, 3, 2)
// (1, 1, 3, 2)
auto
expected_output
=
auto
expected_output
=
...
@@ -182,8 +181,8 @@ TEST(onnx, model_conv2d_strides_no_padding)
...
@@ -182,8 +181,8 @@ TEST(onnx, model_conv2d_strides_no_padding)
TEST
(
onnx
,
model_conv2d_strides_assymetric_padding
)
TEST
(
onnx
,
model_conv2d_strides_assymetric_padding
)
{
{
// Convolution with strides=2 and padding=1
// Convolution with strides=2 and padding=1
auto
function
=
ngraph
::
onnx_import
::
import_onnx_function
(
ngraph
::
file_util
::
path_joi
n
(
auto
function
=
onnx_import
::
import_onnx_functio
n
(
SERIALIZED_ZOO
,
"onnx/conv_with_strides_and_asymmetric_padding.onnx"
));
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/conv_with_strides_and_asymmetric_padding.onnx"
));
// (1, 1, 4, 2)
// (1, 1, 4, 2)
auto
expected_output
=
auto
expected_output
=
...
@@ -297,8 +296,8 @@ TEST(onnx, model_batchnorm_default)
...
@@ -297,8 +296,8 @@ TEST(onnx, model_batchnorm_default)
TEST
(
onnx
,
model_relu
)
TEST
(
onnx
,
model_relu
)
{
{
// Simple ReLU test
// Simple ReLU test
auto
function
=
ngraph
::
onnx_import
::
import_onnx_function
(
auto
function
=
ngraph
::
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/relu.onnx"
));
onnx_import
::
import_onnx_function
(
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/relu.onnx"
));
Inputs
inputs
{{
-
1
,
-
2
,
0
,
1
,
2
,
3
}};
Inputs
inputs
{{
-
1
,
-
2
,
0
,
1
,
2
,
3
}};
Outputs
expected_outputs
{{
0
,
0
,
0
,
1
,
2
,
3
}};
Outputs
expected_outputs
{{
0
,
0
,
0
,
1
,
2
,
3
}};
...
@@ -385,11 +384,10 @@ TEST(onnx, model_mean)
...
@@ -385,11 +384,10 @@ TEST(onnx, model_mean)
TEST
(
onnx
,
model_gemm_abc
)
TEST
(
onnx
,
model_gemm_abc
)
{
{
auto
function
=
ngraph
::
onnx_import
::
import_onnx_function
(
auto
function
=
onnx_import
::
import_onnx_function
(
ngraph
::
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/gemm_abc.onnx"
));
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/gemm_abc.onnx"
));
std
::
vector
<
std
::
vector
<
float
>>
inputs
;
Inputs
inputs
;
inputs
.
emplace_back
(
test
::
NDArray
<
float
,
2
>
(
inputs
.
emplace_back
(
test
::
NDArray
<
float
,
2
>
(
{{
1
,
2
,
3
,
4
,
5
,
6
},
{
7
,
8
,
9
,
10
,
11
,
12
},
{
13
,
14
,
15
,
16
,
17
,
18
}})
{{
1
,
2
,
3
,
4
,
5
,
6
},
{
7
,
8
,
9
,
10
,
11
,
12
},
{
13
,
14
,
15
,
16
,
17
,
18
}})
.
get_vector
());
.
get_vector
());
...
@@ -405,13 +403,13 @@ TEST(onnx, model_gemm_abc)
...
@@ -405,13 +403,13 @@ TEST(onnx, model_gemm_abc)
inputs
.
emplace_back
(
inputs
.
emplace_back
(
test
::
NDArray
<
float
,
2
>
({{
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
}}).
get_vector
());
test
::
NDArray
<
float
,
2
>
({{
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
}}).
get_vector
());
auto
expected_output
=
Outputs
expected_outputs
{
test
::
NDArray
<
float
,
2
>
(
test
::
NDArray
<
float
,
2
>
(
{{
340
,
350.5
,
361
,
371.5
},
{
862
,
890.5
,
919
,
947.5
},
{
1384
,
1430.5
,
1477
,
1523.5
}})
{{
340
,
350.5
,
361
,
371.5
},
{
862
,
890.5
,
919
,
947.5
},
{
1384
,
1430.5
,
1477
,
1523.5
}})
.
get_vector
();
.
get_vector
()
}
;
auto
result_vectors
=
execute
(
function
,
inputs
,
"INTERPRETER"
)
;
Outputs
outputs
{
execute
(
function
,
inputs
,
"INTERPRETER"
)}
;
EXPECT_TRUE
(
test
::
all_close_f
(
expected_output
,
result_vector
s
.
front
()));
EXPECT_TRUE
(
test
::
all_close_f
(
expected_output
s
.
front
(),
output
s
.
front
()));
}
}
TEST
(
onnx
,
model_matmul
)
TEST
(
onnx
,
model_matmul
)
...
@@ -428,11 +426,11 @@ TEST(onnx, model_matmul)
...
@@ -428,11 +426,11 @@ TEST(onnx, model_matmul)
test
::
NDArray
<
float
,
2
>
({{
13
,
14
,
15
},
{
16
,
17
,
18
},
{
19
,
20
,
21
},
{
22
,
23
,
24
}})
test
::
NDArray
<
float
,
2
>
({{
13
,
14
,
15
},
{
16
,
17
,
18
},
{
19
,
20
,
21
},
{
22
,
23
,
24
}})
.
get_vector
());
.
get_vector
());
auto
expected_output
=
Outputs
expected_outputs
{
test
::
NDArray
<
float
,
2
>
({{
190
,
200
,
210
},
{
470
,
496
,
522
},
{
750
,
792
,
834
}}).
get_vector
();
test
::
NDArray
<
float
,
2
>
({{
190
,
200
,
210
},
{
470
,
496
,
522
},
{
750
,
792
,
834
}}).
get_vector
()
}
;
auto
result_vectors
=
execute
(
function
,
inputs
,
"INTERPRETER"
)
;
Outputs
outputs
{
execute
(
function
,
inputs
,
"INTERPRETER"
)}
;
EXPECT_TRUE
(
test
::
all_close_f
(
expected_output
,
result_vector
s
.
front
()));
EXPECT_TRUE
(
test
::
all_close_f
(
expected_output
s
.
front
(),
output
s
.
front
()));
}
}
TEST
(
onnx
,
model_softmax
)
TEST
(
onnx
,
model_softmax
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment