Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
31ee5658
Unverified
Commit
31ee5658
authored
6 years ago
by
Michał Karzyński
Committed by
GitHub
6 years ago
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ONNX] Attribute helper functions (#1468)
parent
33f4f394
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
76 additions
and
6 deletions
+76
-6
attribute.hpp
src/ngraph/frontend/onnx_import/attribute.hpp
+17
-6
node.hpp
src/ngraph/frontend/onnx_import/node.hpp
+59
-0
No files found.
src/ngraph/frontend/onnx_import/attribute.hpp
View file @
31ee5658
...
...
@@ -80,11 +80,12 @@ namespace ngraph
template
<>
inline
float
get_value
(
const
onnx
::
AttributeProto
&
attribute
)
{
if
(
unlikely
(
attribute
.
type
()
!=
onnx
::
AttributeProto_AttributeType_FLOAT
))
switch
(
attribute
.
type
(
))
{
throw
error
::
attribute
::
InvalidData
{
attribute
.
type
()};
case
onnx
:
:
AttributeProto_AttributeType_INT
:
return
attribute
.
i
();
case
onnx
:
:
AttributeProto_AttributeType_FLOAT
:
return
attribute
.
f
();
default
:
throw
error
::
attribute
::
InvalidData
{
attribute
.
type
()};
}
return
attribute
.
f
();
}
template
<>
...
...
@@ -92,6 +93,10 @@ namespace ngraph
{
switch
(
attribute
.
type
())
{
case
onnx
:
:
AttributeProto_AttributeType_INT
:
return
{
static_cast
<
float
>
(
attribute
.
i
())};
case
onnx
:
:
AttributeProto_AttributeType_INTS
:
return
{
std
::
begin
(
attribute
.
floats
()),
std
::
end
(
attribute
.
floats
())};
case
onnx
:
:
AttributeProto_AttributeType_FLOAT
:
return
{
attribute
.
f
()};
case
onnx
:
:
AttributeProto_AttributeType_FLOATS
:
return
{
std
::
begin
(
attribute
.
floats
()),
std
::
end
(
attribute
.
floats
())};
...
...
@@ -102,11 +107,13 @@ namespace ngraph
template
<>
inline
double
get_value
(
const
onnx
::
AttributeProto
&
attribute
)
{
if
(
unlikely
(
attribute
.
type
()
!=
onnx
::
AttributeProto_AttributeType_FLOAT
))
switch
(
attribute
.
type
(
))
{
throw
error
::
attribute
::
InvalidData
{
attribute
.
type
()};
case
onnx
:
:
AttributeProto_AttributeType_FLOAT
:
return
static_cast
<
double
>
(
attribute
.
f
());
case
onnx
:
:
AttributeProto_AttributeType_INT
:
return
attribute
.
i
();
default
:
throw
error
::
attribute
::
InvalidData
{
attribute
.
type
()};
}
return
static_cast
<
double
>
(
attribute
.
f
());
}
template
<>
...
...
@@ -114,6 +121,10 @@ namespace ngraph
{
switch
(
attribute
.
type
())
{
case
onnx
:
:
AttributeProto_AttributeType_INT
:
return
{
static_cast
<
double
>
(
attribute
.
i
())};
case
onnx
:
:
AttributeProto_AttributeType_INTS
:
return
{
std
::
begin
(
attribute
.
ints
()),
std
::
end
(
attribute
.
ints
())};
case
onnx
:
:
AttributeProto_AttributeType_FLOAT
:
return
{
static_cast
<
double
>
(
attribute
.
f
())};
case
onnx
:
:
AttributeProto_AttributeType_FLOATS
:
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/frontend/onnx_import/node.hpp
View file @
31ee5658
...
...
@@ -119,6 +119,65 @@ namespace ngraph
return
(
outs
<<
"<Node("
<<
node
.
op_type
()
<<
"): "
<<
node
.
get_name
()
<<
">"
);
}
namespace
attribute
{
/**
* @brief Get shape of kernel (filter) in pixels.
*
* @param node The Node ptr representing Conv or Pool operation.
* @return The kernel Shape object representing its dimensions (height, width, depth).
*/
inline
Shape
get_kernel_shape
(
const
Node
&
node
)
{
return
node
.
get_attribute_value
<
std
::
vector
<
std
::
size_t
>>
(
"kernel_shape"
,
{
1
,
1
});
}
namespace
detail
{
inline
Strides
get_strides_helper
(
const
Node
&
node
,
const
std
::
string
&
name
,
const
Shape
&
kernel_shape
)
{
return
node
.
get_attribute_value
<
std
::
vector
<
std
::
size_t
>>
(
name
,
std
::
vector
<
std
::
size_t
>
(
kernel_shape
.
size
(),
1UL
));
}
}
// namespace detail
/**
* @brief Get number of pixels to stride operation by in each direction.
*
* @param node The Node ptr representing Conv or Pool operation.
* @param kernel_shape The shape of the kernel which we retrieve strides for.
* @return The kernel Shape object representing its dimensions (height, width, depth).
*/
inline
Strides
get_strides
(
const
Node
&
node
,
const
Shape
&
kernel_shape
)
{
return
detail
::
get_strides_helper
(
node
,
"strides"
,
kernel_shape
);
}
/**
* @brief Get number of pixels to stride operation by in each direction.
*
* @param node The Node ptr representing Conv or Pool operation.
* @return The kernel Shape object representing its dimensions (height, width, depth).
*/
inline
Strides
get_strides
(
const
Node
&
node
)
{
return
get_strides
(
node
,
get_kernel_shape
(
node
));
}
/**
* @brief Get number of pixels for filter dilation in each direction.
*
* @param node The Node ptr representing ONNX operation.
* @return The Strides object containing number of pixels for filter dilation
* (height, width, depth).
*/
inline
Strides
get_dilations
(
const
Node
&
node
)
{
return
detail
::
get_strides_helper
(
node
,
"dilations"
,
get_kernel_shape
(
node
));
}
}
// namespace attribute
}
// namespace onnx_import
}
// namespace ngraph
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment