Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
8ef1ec04
Unverified
Commit
8ef1ec04
authored
Oct 30, 2018
by
Michał Karzyński
Committed by
GitHub
Oct 30, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ONNX] Support for legacy broadcasting rules (#1924)
parent
c637d629
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
141 additions
and
1 deletion
+141
-1
add.hpp
src/ngraph/frontend/onnx_import/op/add.hpp
+13
-0
div.hpp
src/ngraph/frontend/onnx_import/op/div.hpp
+13
-0
mul.hpp
src/ngraph/frontend/onnx_import/op/mul.hpp
+16
-1
sub.hpp
src/ngraph/frontend/onnx_import/op/sub.hpp
+14
-0
ops_bridge.cpp
src/ngraph/frontend/onnx_import/ops_bridge.cpp
+4
-0
broadcasting.cpp
src/ngraph/frontend/onnx_import/utils/broadcasting.cpp
+61
-0
broadcasting.hpp
src/ngraph/frontend/onnx_import/utils/broadcasting.hpp
+20
-0
No files found.
src/ngraph/frontend/onnx_import/op/add.hpp
View file @
8ef1ec04
...
...
@@ -29,6 +29,19 @@ namespace ngraph
namespace
op
{
namespace
set_1
{
inline
NodeVector
add
(
const
Node
&
node
)
{
auto
axis
=
node
.
get_attribute_value
<
int64_t
>
(
"axis"
,
0
);
NodeVector
ng_inputs
{
legacy_style_broadcast_for_binary_operation
(
node
.
get_ng_inputs
().
at
(
0
),
node
.
get_ng_inputs
().
at
(
1
),
axis
)};
return
{
std
::
make_shared
<
ngraph
::
op
::
Add
>
(
ng_inputs
.
at
(
0
),
ng_inputs
.
at
(
1
))};
}
}
// namespace set_1
namespace
set_7
{
inline
NodeVector
add
(
const
Node
&
node
)
{
...
...
src/ngraph/frontend/onnx_import/op/div.hpp
View file @
8ef1ec04
...
...
@@ -29,6 +29,19 @@ namespace ngraph
namespace
op
{
namespace
set_1
{
inline
NodeVector
div
(
const
Node
&
node
)
{
auto
axis
=
node
.
get_attribute_value
<
int64_t
>
(
"axis"
,
0
);
NodeVector
ng_inputs
{
legacy_style_broadcast_for_binary_operation
(
node
.
get_ng_inputs
().
at
(
0
),
node
.
get_ng_inputs
().
at
(
1
),
axis
)};
return
{
std
::
make_shared
<
ngraph
::
op
::
Divide
>
(
ng_inputs
.
at
(
0
),
ng_inputs
.
at
(
1
))};
}
}
// namespace set_1
namespace
set_7
{
inline
NodeVector
div
(
const
Node
&
node
)
{
...
...
src/ngraph/frontend/onnx_import/op/mul.hpp
View file @
8ef1ec04
...
...
@@ -17,6 +17,7 @@
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/multiply.hpp"
#include "core/node.hpp"
...
...
@@ -29,6 +30,20 @@ namespace ngraph
namespace
op
{
namespace
set_1
{
inline
NodeVector
mul
(
const
Node
&
node
)
{
auto
axis
=
node
.
get_attribute_value
<
int64_t
>
(
"axis"
,
0
);
NodeVector
ng_inputs
{
legacy_style_broadcast_for_binary_operation
(
node
.
get_ng_inputs
().
at
(
0
),
node
.
get_ng_inputs
().
at
(
1
),
axis
)};
return
{
std
::
make_shared
<
ngraph
::
op
::
Multiply
>
(
ng_inputs
.
at
(
0
),
ng_inputs
.
at
(
1
))};
}
}
// namespace set_1
namespace
set_7
{
inline
NodeVector
mul
(
const
Node
&
node
)
{
...
...
@@ -38,7 +53,7 @@ namespace ngraph
std
::
make_shared
<
ngraph
::
op
::
Multiply
>
(
ng_inputs
.
at
(
0
),
ng_inputs
.
at
(
1
))};
}
}
// namespace set_
1
}
// namespace set_
7
}
//namespace op
...
...
src/ngraph/frontend/onnx_import/op/sub.hpp
View file @
8ef1ec04
...
...
@@ -29,6 +29,20 @@ namespace ngraph
namespace
op
{
namespace
set_1
{
inline
NodeVector
sub
(
const
Node
&
node
)
{
auto
axis
=
node
.
get_attribute_value
<
int64_t
>
(
"axis"
,
0
);
NodeVector
ng_inputs
{
legacy_style_broadcast_for_binary_operation
(
node
.
get_ng_inputs
().
at
(
0
),
node
.
get_ng_inputs
().
at
(
1
),
axis
)};
return
{
std
::
make_shared
<
ngraph
::
op
::
Subtract
>
(
ng_inputs
.
at
(
0
),
ng_inputs
.
at
(
1
))};
}
}
// namespace set_1
namespace
set_7
{
inline
NodeVector
sub
(
const
Node
&
node
)
{
...
...
src/ngraph/frontend/onnx_import/ops_bridge.cpp
View file @
8ef1ec04
...
...
@@ -146,6 +146,7 @@ namespace ngraph
REGISTER_OPERATOR
(
"Abs"
,
1
,
abs
);
REGISTER_OPERATOR
(
"Acos"
,
1
,
acos
);
REGISTER_OPERATOR
(
"Add"
,
1
,
add
);
REGISTER_OPERATOR
(
"Add"
,
7
,
add
);
REGISTER_OPERATOR
(
"And"
,
1
,
logical_and
);
REGISTER_OPERATOR
(
"ArgMin"
,
1
,
argmin
);
REGISTER_OPERATOR
(
"ArgMax"
,
1
,
argmax
);
...
...
@@ -161,6 +162,7 @@ namespace ngraph
REGISTER_OPERATOR
(
"Conv"
,
1
,
conv
);
REGISTER_OPERATOR
(
"Cos"
,
1
,
cos
);
REGISTER_OPERATOR
(
"Div"
,
1
,
div
);
REGISTER_OPERATOR
(
"Div"
,
7
,
div
);
REGISTER_OPERATOR
(
"Dropout"
,
1
,
identity
);
REGISTER_OPERATOR
(
"Elu"
,
1
,
elu
);
REGISTER_OPERATOR
(
"Equal"
,
1
,
equal
);
...
...
@@ -184,6 +186,7 @@ namespace ngraph
REGISTER_OPERATOR
(
"Mean"
,
1
,
mean
);
REGISTER_OPERATOR
(
"Min"
,
1
,
min
);
REGISTER_OPERATOR
(
"Mul"
,
1
,
mul
);
REGISTER_OPERATOR
(
"Mul"
,
7
,
mul
);
REGISTER_OPERATOR
(
"Neg"
,
1
,
neg
);
REGISTER_OPERATOR
(
"Not"
,
1
,
logical_not
);
REGISTER_OPERATOR
(
"Or"
,
1
,
logical_or
);
...
...
@@ -214,6 +217,7 @@ namespace ngraph
REGISTER_OPERATOR
(
"Sqrt"
,
1
,
sqrt
);
REGISTER_OPERATOR
(
"Squeeze"
,
1
,
squeeze
);
REGISTER_OPERATOR
(
"Sub"
,
1
,
sub
);
REGISTER_OPERATOR
(
"Sub"
,
7
,
sub
);
REGISTER_OPERATOR
(
"Sum"
,
1
,
sum
);
REGISTER_OPERATOR
(
"Tan"
,
1
,
tan
);
REGISTER_OPERATOR
(
"Tanh"
,
1
,
tanh
);
...
...
src/ngraph/frontend/onnx_import/utils/broadcasting.cpp
View file @
8ef1ec04
...
...
@@ -103,6 +103,67 @@ namespace ngraph
return
{
broadcasted_left
,
broadcasted_right
};
}
NodeVector
legacy_style_broadcast_for_binary_operation
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
left
,
const
std
::
shared_ptr
<
ngraph
::
Node
>&
right
,
std
::
size_t
start_match_axis
)
{
auto
left_shape
=
left
->
get_shape
();
auto
right_shape
=
right
->
get_shape
();
bool
dimensions_identical
=
(
left_shape
==
right_shape
);
if
(
dimensions_identical
)
{
return
{
left
,
right
};
}
// Prepare new shape of right operand for broadcasting
// Remove dimensions with length=1 from back
auto
new_right_shape
=
right_shape
;
for
(
int
dimension
=
new_right_shape
.
size
()
-
1
;
dimension
>=
0
;
--
dimension
)
{
if
(
new_right_shape
[
dimension
]
==
1
)
{
new_right_shape
.
pop_back
();
}
else
{
break
;
}
}
// Find first dimensions at front with length different from 1
size_t
num_ones
=
0
;
for
(
size_t
dimension
:
new_right_shape
)
{
if
(
dimension
==
1
)
{
++
num_ones
;
}
else
{
break
;
}
}
// Remove dimensions with length=1 from front
new_right_shape
.
erase
(
std
::
begin
(
new_right_shape
),
std
::
next
(
std
::
begin
(
new_right_shape
),
num_ones
));
auto
reshape_right
=
std
::
make_shared
<
ngraph
::
op
::
Reshape
>
(
right
,
reshape
::
get_default_axis_vector
(
right_shape
.
size
()),
new_right_shape
);
// Move broadcast start axis parameter to right
start_match_axis
+=
num_ones
;
auto
broadcast_right
=
std
::
make_shared
<
ngraph
::
op
::
Broadcast
>
(
reshape_right
,
left_shape
,
calculate_broadcast_axes
(
left_shape
,
new_right_shape
,
start_match_axis
));
return
{
left
,
broadcast_right
};
}
AxisSet
calculate_broadcast_axes
(
const
Shape
&
output_shape
,
const
Shape
&
input_shape
,
std
::
size_t
start_match_axis
)
...
...
src/ngraph/frontend/onnx_import/utils/broadcasting.hpp
View file @
8ef1ec04
...
...
@@ -47,6 +47,26 @@ namespace ngraph
return
numpy_style_broadcast_for_binary_operation
(
inputs
.
at
(
0
),
inputs
.
at
(
1
));
}
/// \brief Cast shape of two nodes to make them compatible for an element-wise binary operation.
///
/// If necessary the right-hand-side argument will be broadcast to match the shape
/// of left-hand-side argument. The starting of the mutually equal shape is
/// specified by the argument "start_match_axis", and if it is not set,
/// suffix matching is assumed.
///
/// This style of broadcast was used in ONNX Op sets prior to version 7, where it was
/// replaced by numpy-style broadcasting.
///
/// \param left Node which contain input of binary op.
/// \param right Node which contain input of binary op.
/// \param start_match_axis position in shape denoting start of the mutually equal shape
///
/// \return Left and right node after broadcasting.
NodeVector
legacy_style_broadcast_for_binary_operation
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
left
,
const
std
::
shared_ptr
<
ngraph
::
Node
>&
right
,
std
::
size_t
start_match_axis
);
/// \brief Generate a list of broadcast axes.
///
/// \details Informally, a broadcast "adds" axes to the input tensor, replicating
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment