Commit 2fc0bbb4 authored by tsocha's avatar tsocha Committed by Michał Karzyński

[Py] API helper function broadcast_to (#1170)

parent c086eb2d
...@@ -24,6 +24,7 @@ from ngraph.ops import atan ...@@ -24,6 +24,7 @@ from ngraph.ops import atan
from ngraph.ops import avg_pool from ngraph.ops import avg_pool
from ngraph.ops import batch_norm from ngraph.ops import batch_norm
from ngraph.ops import broadcast from ngraph.ops import broadcast
from ngraph.ops import broadcast_to
from ngraph.ops import ceiling from ngraph.ops import ceiling
from ngraph.ops import ceiling as ceil from ngraph.ops import ceiling as ceil
from ngraph.ops import concat from ngraph.ops import concat
......
...@@ -420,14 +420,59 @@ Node.__ge__ = greater_eq ...@@ -420,14 +420,59 @@ Node.__ge__ = greater_eq
# Custom ops # Custom ops
@nameable_op @nameable_op
def broadcast(node, new_shape, axis=None, name=None): # type: (Node, TensorShape, int, str) -> Node def broadcast(node, new_shape, broadcast_axes, name=None):
"""Return node which broadcasts input node values to specified shape. # type: (Node, TensorShape, Iterable[int], str) -> Node
"""Create a node which broadcasts the input node's values along specified axes to a desired shape.
:param node: The node with input tensor data.
:param new_shape: The new shape we want to broadcast tensor to.
:param broadcast_axes: The axis positions (0-based) in the result that are being broadcast.
:param name: Optional new name for output node.
:return: New node with broadcast shape.
"""
return Broadcast(node, Shape(new_shape), AxisSet(broadcast_axes))
@nameable_op
def broadcast_to(node, new_shape, axis=None, name=None):
# type: (Node, TensorShape, int, str) -> Node
"""Create a node which broadcasts the input node's values to a desired shape.
`broadcast_to` will attempt to automatically determine which axes need broadcasting.
The optional `axis` parameter specifies the starting axis position (0-based) in the output
shape from which the current shape of the tensor matches the desired new shape.
e.g. current_shape: [4, 5], new_shape: [2, 3, 4, 5, 6], axis: 2
By using the `axis` parameter you can control which output axis to broadcast along.
Example:
>>> input_node = ng.constant([1, 2, 3])
>>> current_shape = [3]
>>> new_shape = [3, 3]
>>> ng.broadcast_to(input_node, new_shape, axis=1)
array([[1, 2, 3],
[1, 2, 3],
[1, 2, 3]])
>>> ng.broadcast_to(input_node, new_shape, axis=0)
array([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])
If the `axis` parameter is not specified, `broadcast_to` will attempt to match shapes,
assuming the current shape matches the rightmost positions of the desired new shape.
This behaviour is similar to NumPy's broadcasting.
i.e. default `axis = len(new_shape) - len(current_shape)`
:param node: The node with input tensor data. :param node: The node with input tensor data.
:param new_shape: The new shape we want to broadcast tensor to. :param new_shape: The new shape we want to broadcast tensor to.
:param axis: The axis along which we perform broadcasting. :param axis: The axis along which we perform broadcasting.
:param name: Optional new name for output node. :param name: Optional new name for output node.
:return: New node with broadcasted shape. :return: New node with broadcast shape.
""" """
return Broadcast(node, Shape(new_shape), get_broadcast_axes(new_shape, node.shape, axis)) return Broadcast(node, Shape(new_shape), get_broadcast_axes(new_shape, node.shape, axis))
......
...@@ -75,10 +75,10 @@ def as_elementwise_compatible_nodes(*input_values): # type: (*NodeInput) -> Lis ...@@ -75,10 +75,10 @@ def as_elementwise_compatible_nodes(*input_values): # type: (*NodeInput) -> Lis
output_nodes = [] output_nodes = []
for input_value in input_values: for input_value in input_values:
if issubclass(type(input_value), Node): if issubclass(type(input_value), Node):
input_value = ng.broadcast(input_value, broadcast_shape) input_value = ng.broadcast_to(input_value, broadcast_shape)
output_nodes.append(input_value) output_nodes.append(input_value)
else: else:
input_value = make_constant_node(input_value, dtype=broadcast_dtype) input_value = make_constant_node(input_value, dtype=broadcast_dtype)
output_nodes.append(ng.broadcast(input_value, broadcast_shape)) output_nodes.append(ng.broadcast_to(input_value, broadcast_shape))
return output_nodes return output_nodes
...@@ -14,9 +14,13 @@ ...@@ -14,9 +14,13 @@
* limitations under the License. * limitations under the License.
*******************************************************************************/ *******************************************************************************/
#include "ngraph/axis_set.hpp" //ngraph::AxisSet #include <iterator>
#include <sstream>
#include <string>
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include "ngraph/axis_set.hpp" //ngraph::AxisSet
#include "pyngraph/axis_set.hpp" #include "pyngraph/axis_set.hpp"
namespace py = pybind11; namespace py = pybind11;
...@@ -27,5 +31,19 @@ void regclass_pyngraph_AxisSet(py::module m) ...@@ -27,5 +31,19 @@ void regclass_pyngraph_AxisSet(py::module m)
axis_set.doc() = "ngraph.impl.AxisSet wraps ngraph::AxisSet"; axis_set.doc() = "ngraph.impl.AxisSet wraps ngraph::AxisSet";
axis_set.def(py::init<const std::initializer_list<size_t>&>()); axis_set.def(py::init<const std::initializer_list<size_t>&>());
axis_set.def(py::init<const std::set<size_t>&>()); axis_set.def(py::init<const std::set<size_t>&>());
axis_set.def(py::init<const std::vector<size_t>&>());
axis_set.def(py::init<const ngraph::AxisSet&>()); axis_set.def(py::init<const ngraph::AxisSet&>());
axis_set.def("__len__", [](const ngraph::AxisSet& v) { return v.size(); });
axis_set.def("__iter__",
[](ngraph::AxisSet& v) { return py::make_iterator(v.begin(), v.end()); },
py::keep_alive<0, 1>()); /* Keep set alive while iterator is used */
axis_set.def("__repr__", [](const ngraph::AxisSet& self) -> std::string {
std::stringstream data_ss;
std::copy(self.begin(), self.end(), std::ostream_iterator<int>(data_ss, ", "));
std::string data_str = data_ss.str();
return "<AxisSet {" + data_str.substr(0, data_str.size() - 2) + "}>";
});
} }
...@@ -93,7 +93,7 @@ def test_serialization(): ...@@ -93,7 +93,7 @@ def test_serialization():
expected = [[1, 2, 3], expected = [[1, 2, 3],
[1, 2, 3], [1, 2, 3],
[1, 2, 3]] [1, 2, 3]]
result = run_op_node([input_data], ng.broadcast, new_shape) result = run_op_node([input_data], ng.broadcast_to, new_shape)
assert np.allclose(result, expected) assert np.allclose(result, expected)
axis = 0 axis = 0
...@@ -101,13 +101,13 @@ def test_serialization(): ...@@ -101,13 +101,13 @@ def test_serialization():
[2, 2, 2], [2, 2, 2],
[3, 3, 3]] [3, 3, 3]]
result = run_op_node([input_data], ng.broadcast, new_shape, axis) result = run_op_node([input_data], ng.broadcast_to, new_shape, axis)
assert np.allclose(result, expected) assert np.allclose(result, expected)
input_data = np.arange(4) input_data = np.arange(4)
new_shape = [3, 4, 2, 4] new_shape = [3, 4, 2, 4]
expected = np.broadcast_to(input_data, new_shape) expected = np.broadcast_to(input_data, new_shape)
result = run_op_node([input_data], ng.broadcast, new_shape) result = run_op_node([input_data], ng.broadcast_to, new_shape)
assert np.allclose(result, expected) assert np.allclose(result, expected)
......
...@@ -736,6 +736,23 @@ def test_concat(): ...@@ -736,6 +736,23 @@ def test_concat():
assert np.allclose(result_arr, result_arr_ref) assert np.allclose(result_arr, result_arr_ref)
@pytest.config.gpu_skip(reason="Not implemented")
def test_axisset():
set_axisset = AxisSet({1, 2, 3})
list_axisset = AxisSet([1, 2, 3])
tuple_axisset = AxisSet((1, 2, 3))
assert len(set_axisset) == 3
assert set(set_axisset) == {1, 2, 3}
assert len(list_axisset) == 3
assert set(list_axisset) == set(set_axisset)
assert len(tuple_axisset) == 3
assert set(tuple_axisset) == set(set_axisset)
@pytest.config.gpu_skip(reason="Not implemented") @pytest.config.gpu_skip(reason="Not implemented")
def test_select(): def test_select():
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <cstddef> #include <cstddef>
#include <set> #include <set>
#include <vector>
namespace ngraph namespace ngraph
{ {
...@@ -36,6 +37,11 @@ namespace ngraph ...@@ -36,6 +37,11 @@ namespace ngraph
{ {
} }
AxisSet(const std::vector<size_t>& axes)
: std::set<size_t>(axes.begin(), axes.end())
{
}
AxisSet(const AxisSet& axes) AxisSet(const AxisSet& axes)
: std::set<size_t>(axes) : std::set<size_t>(axes)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment