Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
cc989301
Commit
cc989301
authored
6 years ago
by
Artur Wojcik
Committed by
Michał Karzyński
6 years ago
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ONNX] Tensor: add support for raw_data (#1552)
parent
3ffccb33
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
49 additions
and
0 deletions
+49
-0
tensor.hpp
src/ngraph/frontend/onnx_import/core/tensor.hpp
+49
-0
No files found.
src/ngraph/frontend/onnx_import/core/tensor.hpp
View file @
cc989301
...
...
@@ -65,6 +65,22 @@ namespace ngraph
}
};
struct
data_type_undefined
:
ngraph_error
{
data_type_undefined
()
:
ngraph_error
{
"data type is not defined"
}
{
}
};
struct
segments_unsupported
:
ngraph_error
{
segments_unsupported
()
:
ngraph_error
{
"loading segments not supported"
}
{
}
};
}
// namespace tensor
}
// namespace error
...
...
@@ -82,6 +98,13 @@ namespace ngraph
{
return
{
std
::
begin
(
container
),
std
::
end
(
container
)};
}
template
<
typename
T
>
inline
std
::
vector
<
T
>
__get_raw_data
(
const
std
::
string
&
raw_data
)
{
auto
it
=
reinterpret_cast
<
const
T
*>
(
raw_data
.
data
());
return
{
it
,
it
+
(
raw_data
.
size
()
/
sizeof
(
T
))};
}
}
}
...
...
@@ -94,6 +117,10 @@ namespace ngraph
template
<>
inline
std
::
vector
<
double
>
get_data
(
const
onnx
::
TensorProto
&
tensor
)
{
if
(
tensor
.
has_raw_data
())
{
return
detail
::
__get_raw_data
<
double
>
(
tensor
.
raw_data
());
}
if
(
tensor
.
data_type
()
==
onnx
::
TensorProto_DataType_DOUBLE
)
{
return
detail
::
__get_data
<
double
>
(
tensor
.
double_data
());
...
...
@@ -121,6 +148,10 @@ namespace ngraph
template
<>
inline
std
::
vector
<
float
>
get_data
(
const
onnx
::
TensorProto
&
tensor
)
{
if
(
tensor
.
has_raw_data
())
{
return
detail
::
__get_raw_data
<
float
>
(
tensor
.
raw_data
());
}
if
((
tensor
.
data_type
()
==
onnx
::
TensorProto_DataType_FLOAT
)
or
(
tensor
.
data_type
()
==
onnx
::
TensorProto_DataType_FLOAT16
))
{
...
...
@@ -144,6 +175,10 @@ namespace ngraph
template
<>
inline
std
::
vector
<
int32_t
>
get_data
(
const
onnx
::
TensorProto
&
tensor
)
{
if
(
tensor
.
has_raw_data
())
{
return
detail
::
__get_raw_data
<
int32_t
>
(
tensor
.
raw_data
());
}
if
(
tensor
.
data_type
()
==
onnx
::
TensorProto_DataType_INT32
)
{
return
detail
::
__get_data
<
int32_t
>
(
tensor
.
int32_data
());
...
...
@@ -154,6 +189,10 @@ namespace ngraph
template
<>
inline
std
::
vector
<
int64_t
>
get_data
(
const
onnx
::
TensorProto
&
tensor
)
{
if
(
tensor
.
has_raw_data
())
{
return
detail
::
__get_raw_data
<
int64_t
>
(
tensor
.
raw_data
());
}
if
(
tensor
.
data_type
()
!=
onnx
::
TensorProto_DataType_INT64
)
{
throw
error
::
tensor
::
invalid_data_type
{
tensor
.
data_type
()};
...
...
@@ -164,6 +203,10 @@ namespace ngraph
template
<>
inline
std
::
vector
<
uint64_t
>
get_data
(
const
onnx
::
TensorProto
&
tensor
)
{
if
(
tensor
.
has_raw_data
())
{
return
detail
::
__get_raw_data
<
uint64_t
>
(
tensor
.
raw_data
());
}
if
(
tensor
.
data_type
()
!=
onnx
::
TensorProto_DataType_UINT64
)
{
throw
error
::
tensor
::
invalid_data_type
{
tensor
.
data_type
()};
...
...
@@ -213,6 +256,10 @@ namespace ngraph
template
<
typename
T
>
std
::
vector
<
T
>
get_data
()
const
{
if
(
m_tensor_proto
->
has_segment
())
{
throw
error
::
tensor
::
segments_unsupported
{};
}
return
detail
::
tensor
::
get_data
<
T
>
(
*
m_tensor_proto
);
}
...
...
@@ -254,6 +301,8 @@ namespace ngraph
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT16
:
return
element
::
u16
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT32
:
return
element
::
u32
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT64
:
return
element
::
u64
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UNDEFINED
:
throw
error
::
tensor
::
data_type_undefined
{};
default
:
throw
error
::
tensor
::
unsupported_data_type
{
m_tensor_proto
->
data_type
()};
}
}
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment