Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
a2521cf9
Commit
a2521cf9
authored
Sep 04, 2018
by
tsocha
Committed by
Michał Karzyński
Sep 04, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ONNX] Numpy style binary broadcasting (#1549)
parent
cc989301
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
161 additions
and
4 deletions
+161
-4
add.hpp
src/ngraph/frontend/onnx_import/op/add.hpp
+3
-1
div.hpp
src/ngraph/frontend/onnx_import/op/div.hpp
+3
-1
mul.hpp
src/ngraph/frontend/onnx_import/op/mul.hpp
+3
-1
sub.hpp
src/ngraph/frontend/onnx_import/op/sub.hpp
+3
-1
broadcasting.cpp
src/ngraph/frontend/onnx_import/utils/broadcasting.cpp
+85
-0
broadcasting.hpp
src/ngraph/frontend/onnx_import/utils/broadcasting.hpp
+20
-0
add_bcast.onnx
test/models/onnx/add_bcast.onnx
+19
-0
onnx_import.cpp
test/onnx_import.cpp
+25
-0
No files found.
src/ngraph/frontend/onnx_import/op/add.hpp
View file @
a2521cf9
...
...
@@ -20,6 +20,7 @@
#include "ngraph/op/add.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
namespace
ngraph
{
...
...
@@ -29,7 +30,8 @@ namespace ngraph
{
inline
NodeVector
add
(
const
Node
&
node
)
{
NodeVector
ng_inputs
{
node
.
get_ng_inputs
()};
NodeVector
ng_inputs
{
numpy_style_broadcast_for_binary_operation
(
node
.
get_ng_inputs
())};
return
{
std
::
make_shared
<
ngraph
::
op
::
Add
>
(
ng_inputs
.
at
(
0
),
ng_inputs
.
at
(
1
))};
}
...
...
src/ngraph/frontend/onnx_import/op/div.hpp
View file @
a2521cf9
...
...
@@ -20,6 +20,7 @@
#include "ngraph/op/divide.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
namespace
ngraph
{
...
...
@@ -29,7 +30,8 @@ namespace ngraph
{
inline
NodeVector
div
(
const
Node
&
node
)
{
NodeVector
ng_inputs
{
node
.
get_ng_inputs
()};
NodeVector
ng_inputs
{
numpy_style_broadcast_for_binary_operation
(
node
.
get_ng_inputs
())};
return
{
std
::
make_shared
<
ngraph
::
op
::
Divide
>
(
ng_inputs
.
at
(
0
),
ng_inputs
.
at
(
1
))};
}
...
...
src/ngraph/frontend/onnx_import/op/mul.hpp
View file @
a2521cf9
...
...
@@ -20,6 +20,7 @@
#include "ngraph/op/multiply.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
namespace
ngraph
{
...
...
@@ -29,7 +30,8 @@ namespace ngraph
{
inline
NodeVector
mul
(
const
Node
&
node
)
{
NodeVector
ng_inputs
{
node
.
get_ng_inputs
()};
NodeVector
ng_inputs
{
numpy_style_broadcast_for_binary_operation
(
node
.
get_ng_inputs
())};
return
{
std
::
make_shared
<
ngraph
::
op
::
Multiply
>
(
ng_inputs
.
at
(
0
),
ng_inputs
.
at
(
1
))};
}
...
...
src/ngraph/frontend/onnx_import/op/sub.hpp
View file @
a2521cf9
...
...
@@ -20,6 +20,7 @@
#include "ngraph/op/subtract.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
namespace
ngraph
{
...
...
@@ -29,7 +30,8 @@ namespace ngraph
{
inline
NodeVector
sub
(
const
Node
&
node
)
{
NodeVector
ng_inputs
{
node
.
get_ng_inputs
()};
NodeVector
ng_inputs
{
numpy_style_broadcast_for_binary_operation
(
node
.
get_ng_inputs
())};
return
{
std
::
make_shared
<
ngraph
::
op
::
Subtract
>
(
ng_inputs
.
at
(
0
),
ng_inputs
.
at
(
1
))};
}
...
...
src/ngraph/frontend/onnx_import/utils/broadcasting.cpp
View file @
a2521cf9
...
...
@@ -17,12 +17,97 @@
#include <numeric>
#include <vector>
#include "ngraph/axis_vector.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/reshape.hpp"
#include "broadcasting.hpp"
/// \brief Calculate output shape of numpy - style broadcast operation.
/// https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules
///
/// \param left_shape Shape of first input tensor.
/// \param right_shape Shape of the second input tensor.
/// \return Shape of the output tensor and full shape of input tensors.
static
std
::
vector
<
ngraph
::
Shape
>
calculate_numpy_broadcast_shape
(
ngraph
::
Shape
left_shape
,
ngraph
::
Shape
right_shape
)
{
ngraph
::
Shape
output_shape
;
auto
rank_left
=
left_shape
.
size
();
auto
rank_right
=
right_shape
.
size
();
auto
max_rank
=
std
::
max
(
rank_left
,
rank_right
);
for
(
auto
i
=
0
;
i
<
(
max_rank
-
rank_left
);
++
i
)
{
left_shape
.
insert
(
std
::
begin
(
left_shape
),
1
);
}
for
(
auto
i
=
0
;
i
<
(
max_rank
-
rank_right
);
++
i
)
{
right_shape
.
insert
(
std
::
begin
(
right_shape
),
1
);
}
for
(
auto
index
=
0
;
index
<
max_rank
;
++
index
)
{
output_shape
.
push_back
(
std
::
max
(
left_shape
.
at
(
index
),
right_shape
.
at
(
index
)));
}
return
{
output_shape
,
left_shape
,
right_shape
};
}
namespace
ngraph
{
namespace
onnx_import
{
NodeVector
numpy_style_broadcast_for_binary_operation
(
const
std
::
shared_ptr
<
Node
>&
left
,
const
std
::
shared_ptr
<
Node
>&
right
)
{
auto
left_shape
=
left
->
get_shape
();
auto
right_shape
=
right
->
get_shape
();
auto
numpy_shapes
=
calculate_numpy_broadcast_shape
(
left_shape
,
right_shape
);
auto
output_shape
=
numpy_shapes
.
at
(
0
);
auto
left_full_shape
=
numpy_shapes
.
at
(
1
);
auto
right_full_shape
=
numpy_shapes
.
at
(
2
);
AxisVector
left_broadcast_axes
;
AxisVector
right_broadcast_axes
;
Shape
new_left_shape
;
Shape
new_right_shape
;
// Positions of dims which have length of 1 are needed to calculate broadcast_axes for nGraph broadcast operation.
// We need to remove all ones from source shape (left_broadcast_axes) to avoid broadcasting axis conflict.
for
(
auto
index
=
0
;
index
<
output_shape
.
size
();
++
index
)
{
(
left_full_shape
.
at
(
index
)
==
1
)
?
left_broadcast_axes
.
push_back
(
index
)
:
new_left_shape
.
push_back
(
left_full_shape
.
at
(
index
));
(
right_full_shape
.
at
(
index
)
==
1
)
?
right_broadcast_axes
.
push_back
(
index
)
:
new_right_shape
.
push_back
(
right_full_shape
.
at
(
index
));
}
// Generate an increasing sequence (0,1,2,3..) as input_order for Reshape
std
::
vector
<
size_t
>
left_input_order
(
left
->
get_shape
().
size
());
std
::
iota
(
std
::
begin
(
left_input_order
),
std
::
end
(
left_input_order
),
0
);
// Remove dims which have length of 1 from source shape
std
::
shared_ptr
<
Node
>
broadcasted_left
=
std
::
make_shared
<
op
::
Reshape
>
(
left
,
left_input_order
,
new_left_shape
);
// Generate an increasing sequence (0,1,2,3..) as input_order for Reshape
std
::
vector
<
size_t
>
right_input_order
(
right
->
get_shape
().
size
());
std
::
iota
(
std
::
begin
(
right_input_order
),
std
::
end
(
right_input_order
),
0
);
// Remove dims which have length of 1 from source shape
std
::
shared_ptr
<
Node
>
broadcasted_right
=
std
::
make_shared
<
op
::
Reshape
>
(
right
,
right_input_order
,
new_right_shape
);
broadcasted_left
=
std
::
make_shared
<
op
::
Broadcast
>
(
broadcasted_left
,
output_shape
,
left_broadcast_axes
);
broadcasted_right
=
std
::
make_shared
<
op
::
Broadcast
>
(
broadcasted_right
,
output_shape
,
right_broadcast_axes
);
return
{
broadcasted_left
,
broadcasted_right
};
}
AxisSet
calculate_broadcast_axes
(
const
Shape
&
output_shape
,
const
Shape
&
input_shape
,
std
::
size_t
start_match_axis
)
...
...
src/ngraph/frontend/onnx_import/utils/broadcasting.hpp
View file @
a2521cf9
...
...
@@ -25,6 +25,26 @@ namespace ngraph
{
namespace
onnx_import
{
/// \brief Cast shape of two nodes to make them compatible for an element-wise binary operation.
///
/// \param left Node which contain input of binary op.
/// \param right Node which contain input of binary op.
///
/// \return Left and right node after broadcasting.
NodeVector
numpy_style_broadcast_for_binary_operation
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
left
,
const
std
::
shared_ptr
<
ngraph
::
Node
>&
right
);
/// \brief Cast shape of two nodes to make them compatible for an element-wise binary operation.
///
/// \param inputs Left and right node (inputs of the binary op).
///
/// \return Left and right node after broadcasting.
inline
NodeVector
numpy_style_broadcast_for_binary_operation
(
NodeVector
inputs
)
{
return
numpy_style_broadcast_for_binary_operation
(
inputs
.
at
(
0
),
inputs
.
at
(
1
));
}
/// \brief Generate a list of broadcast axes.
///
/// \details Informally, a broadcast "adds" axes to the input tensor, replicating
...
...
test/models/onnx/add_bcast.onnx
0 → 100644
View file @
a2521cf9
backend-test:g
x
ysum"Addtest_add_bcastZ
x
Z
y
b
sum
B
\ No newline at end of file
test/onnx_import.cpp
View file @
a2521cf9
...
...
@@ -576,6 +576,31 @@ TEST(onnx, model_div)
EXPECT_TRUE
(
test
::
all_close_f
(
expected_output
,
result_vectors
.
front
()));
}
TEST
(
onnx
,
model_add_bcast
)
{
auto
function
=
onnx_import
::
import_onnx_function
(
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/add_bcast.onnx"
));
Inputs
inputs
;
inputs
.
emplace_back
(
test
::
NDArray
<
float
,
3
>
(
{{{
1
,
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
,
1
}},
{{
1
,
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
,
1
}},
{{
1
,
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
,
1
}}})
.
get_vector
());
inputs
.
emplace_back
(
test
::
NDArray
<
float
,
1
>
({
1
,
2
,
3
,
4
,
5
}).
get_vector
());
Outputs
expected_output
{
test
::
NDArray
<
float
,
4
>
(
{{{{
2
,
3
,
4
,
5
,
6
},
{
2
,
3
,
4
,
5
,
6
},
{
2
,
3
,
4
,
5
,
6
},
{
2
,
3
,
4
,
5
,
6
}},
{{
2
,
3
,
4
,
5
,
6
},
{
2
,
3
,
4
,
5
,
6
},
{
2
,
3
,
4
,
5
,
6
},
{
2
,
3
,
4
,
5
,
6
}},
{{
2
,
3
,
4
,
5
,
6
},
{
2
,
3
,
4
,
5
,
6
},
{
2
,
3
,
4
,
5
,
6
},
{
2
,
3
,
4
,
5
,
6
}}}})
.
get_vector
()};
Outputs
outputs
{
execute
(
function
,
inputs
,
"INTERPRETER"
)};
EXPECT_TRUE
(
test
::
all_close_f
(
expected_output
.
front
(),
outputs
.
front
()));
}
TEST
(
onnx
,
model_reshape_reduced_dims
)
{
auto
function
=
onnx_import
::
import_onnx_function
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment