Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
2fc0bbb4
Commit
2fc0bbb4
authored
Jul 03, 2018
by
tsocha
Committed by
Michał Karzyński
Jul 03, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Py] API helper function broadcast_to (#1170)
parent
c086eb2d
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
96 additions
and
9 deletions
+96
-9
__init__.py
python/ngraph/__init__.py
+1
-0
ops.py
python/ngraph/ops.py
+48
-3
broadcasting.py
python/ngraph/utils/broadcasting.py
+2
-2
axis_set.cpp
python/pyngraph/axis_set.cpp
+19
-1
test_basic.py
python/test/ngraph/test_basic.py
+3
-3
test_ops.py
python/test/test_ops.py
+17
-0
axis_set.hpp
src/ngraph/axis_set.hpp
+6
-0
No files found.
python/ngraph/__init__.py
View file @
2fc0bbb4
...
...
@@ -24,6 +24,7 @@ from ngraph.ops import atan
from
ngraph.ops
import
avg_pool
from
ngraph.ops
import
batch_norm
from
ngraph.ops
import
broadcast
from
ngraph.ops
import
broadcast_to
from
ngraph.ops
import
ceiling
from
ngraph.ops
import
ceiling
as
ceil
from
ngraph.ops
import
concat
...
...
python/ngraph/ops.py
View file @
2fc0bbb4
...
...
@@ -420,14 +420,59 @@ Node.__ge__ = greater_eq
# Custom ops
@nameable_op
def
broadcast
(
node
,
new_shape
,
axis
=
None
,
name
=
None
):
# type: (Node, TensorShape, int, str) -> Node
"""Return node which broadcasts input node values to specified shape.
def
broadcast
(
node
,
new_shape
,
broadcast_axes
,
name
=
None
):
# type: (Node, TensorShape, Iterable[int], str) -> Node
"""Create a node which broadcasts the input node's values along specified axes to a desired shape.
:param node: The node with input tensor data.
:param new_shape: The new shape we want to broadcast tensor to.
:param broadcast_axes: The axis positions (0-based) in the result that are being broadcast.
:param name: Optional new name for output node.
:return: New node with broadcast shape.
"""
return
Broadcast
(
node
,
Shape
(
new_shape
),
AxisSet
(
broadcast_axes
))
@nameable_op
def
broadcast_to
(
node
,
new_shape
,
axis
=
None
,
name
=
None
):
# type: (Node, TensorShape, int, str) -> Node
"""Create a node which broadcasts the input node's values to a desired shape.
`broadcast_to` will attempt to automatically determine which axes need broadcasting.
The optional `axis` parameter specifies the starting axis position (0-based) in the output
shape from which the current shape of the tensor matches the desired new shape.
e.g. current_shape: [4, 5], new_shape: [2, 3, 4, 5, 6], axis: 2
By using the `axis` parameter you can control which output axis to broadcast along.
Example:
>>> input_node = ng.constant([1, 2, 3])
>>> current_shape = [3]
>>> new_shape = [3, 3]
>>> ng.broadcast_to(input_node, new_shape, axis=1)
array([[1, 2, 3],
[1, 2, 3],
[1, 2, 3]])
>>> ng.broadcast_to(input_node, new_shape, axis=0)
array([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])
If the `axis` parameter is not specified, `broadcast_to` will attempt to match shapes,
assuming the current shape matches the rightmost positions of the desired new shape.
This behaviour is similar to NumPy's broadcasting.
i.e. default `axis = len(new_shape) - len(current_shape)`
:param node: The node with input tensor data.
:param new_shape: The new shape we want to broadcast tensor to.
:param axis: The axis along which we perform broadcasting.
:param name: Optional new name for output node.
:return: New node with broadcast
ed
shape.
:return: New node with broadcast shape.
"""
return
Broadcast
(
node
,
Shape
(
new_shape
),
get_broadcast_axes
(
new_shape
,
node
.
shape
,
axis
))
...
...
python/ngraph/utils/broadcasting.py
View file @
2fc0bbb4
...
...
@@ -75,10 +75,10 @@ def as_elementwise_compatible_nodes(*input_values): # type: (*NodeInput) -> Lis
output_nodes
=
[]
for
input_value
in
input_values
:
if
issubclass
(
type
(
input_value
),
Node
):
input_value
=
ng
.
broadcast
(
input_value
,
broadcast_shape
)
input_value
=
ng
.
broadcast
_to
(
input_value
,
broadcast_shape
)
output_nodes
.
append
(
input_value
)
else
:
input_value
=
make_constant_node
(
input_value
,
dtype
=
broadcast_dtype
)
output_nodes
.
append
(
ng
.
broadcast
(
input_value
,
broadcast_shape
))
output_nodes
.
append
(
ng
.
broadcast
_to
(
input_value
,
broadcast_shape
))
return
output_nodes
python/pyngraph/axis_set.cpp
View file @
2fc0bbb4
...
...
@@ -14,9 +14,13 @@
* limitations under the License.
*******************************************************************************/
#include "ngraph/axis_set.hpp" //ngraph::AxisSet
#include <iterator>
#include <sstream>
#include <string>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/axis_set.hpp" //ngraph::AxisSet
#include "pyngraph/axis_set.hpp"
namespace
py
=
pybind11
;
...
...
@@ -27,5 +31,19 @@ void regclass_pyngraph_AxisSet(py::module m)
axis_set
.
doc
()
=
"ngraph.impl.AxisSet wraps ngraph::AxisSet"
;
axis_set
.
def
(
py
::
init
<
const
std
::
initializer_list
<
size_t
>&>
());
axis_set
.
def
(
py
::
init
<
const
std
::
set
<
size_t
>&>
());
axis_set
.
def
(
py
::
init
<
const
std
::
vector
<
size_t
>&>
());
axis_set
.
def
(
py
::
init
<
const
ngraph
::
AxisSet
&>
());
axis_set
.
def
(
"__len__"
,
[](
const
ngraph
::
AxisSet
&
v
)
{
return
v
.
size
();
});
axis_set
.
def
(
"__iter__"
,
[](
ngraph
::
AxisSet
&
v
)
{
return
py
::
make_iterator
(
v
.
begin
(),
v
.
end
());
},
py
::
keep_alive
<
0
,
1
>
());
/* Keep set alive while iterator is used */
axis_set
.
def
(
"__repr__"
,
[](
const
ngraph
::
AxisSet
&
self
)
->
std
::
string
{
std
::
stringstream
data_ss
;
std
::
copy
(
self
.
begin
(),
self
.
end
(),
std
::
ostream_iterator
<
int
>
(
data_ss
,
", "
));
std
::
string
data_str
=
data_ss
.
str
();
return
"<AxisSet {"
+
data_str
.
substr
(
0
,
data_str
.
size
()
-
2
)
+
"}>"
;
});
}
python/test/ngraph/test_basic.py
View file @
2fc0bbb4
...
...
@@ -93,7 +93,7 @@ def test_serialization():
expected
=
[[
1
,
2
,
3
],
[
1
,
2
,
3
],
[
1
,
2
,
3
]]
result
=
run_op_node
([
input_data
],
ng
.
broadcast
,
new_shape
)
result
=
run_op_node
([
input_data
],
ng
.
broadcast
_to
,
new_shape
)
assert
np
.
allclose
(
result
,
expected
)
axis
=
0
...
...
@@ -101,13 +101,13 @@ def test_serialization():
[
2
,
2
,
2
],
[
3
,
3
,
3
]]
result
=
run_op_node
([
input_data
],
ng
.
broadcast
,
new_shape
,
axis
)
result
=
run_op_node
([
input_data
],
ng
.
broadcast
_to
,
new_shape
,
axis
)
assert
np
.
allclose
(
result
,
expected
)
input_data
=
np
.
arange
(
4
)
new_shape
=
[
3
,
4
,
2
,
4
]
expected
=
np
.
broadcast_to
(
input_data
,
new_shape
)
result
=
run_op_node
([
input_data
],
ng
.
broadcast
,
new_shape
)
result
=
run_op_node
([
input_data
],
ng
.
broadcast
_to
,
new_shape
)
assert
np
.
allclose
(
result
,
expected
)
...
...
python/test/test_ops.py
View file @
2fc0bbb4
...
...
@@ -736,6 +736,23 @@ def test_concat():
assert
np
.
allclose
(
result_arr
,
result_arr_ref
)
@pytest.config.gpu_skip
(
reason
=
"Not implemented"
)
def
test_axisset
():
set_axisset
=
AxisSet
({
1
,
2
,
3
})
list_axisset
=
AxisSet
([
1
,
2
,
3
])
tuple_axisset
=
AxisSet
((
1
,
2
,
3
))
assert
len
(
set_axisset
)
==
3
assert
set
(
set_axisset
)
==
{
1
,
2
,
3
}
assert
len
(
list_axisset
)
==
3
assert
set
(
list_axisset
)
==
set
(
set_axisset
)
assert
len
(
tuple_axisset
)
==
3
assert
set
(
tuple_axisset
)
==
set
(
set_axisset
)
@pytest.config.gpu_skip
(
reason
=
"Not implemented"
)
def
test_select
():
...
...
src/ngraph/axis_set.hpp
View file @
2fc0bbb4
...
...
@@ -18,6 +18,7 @@
#include <cstddef>
#include <set>
#include <vector>
namespace
ngraph
{
...
...
@@ -36,6 +37,11 @@ namespace ngraph
{
}
AxisSet
(
const
std
::
vector
<
size_t
>&
axes
)
:
std
::
set
<
size_t
>
(
axes
.
begin
(),
axes
.
end
())
{
}
AxisSet
(
const
AxisSet
&
axes
)
:
std
::
set
<
size_t
>
(
axes
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment