Commit 2fc0bbb4 authored by tsocha's avatar tsocha Committed by Michał Karzyński

[Py] API helper function broadcast_to (#1170)

parent c086eb2d
......@@ -24,6 +24,7 @@ from ngraph.ops import atan
from ngraph.ops import avg_pool
from ngraph.ops import batch_norm
from ngraph.ops import broadcast
from ngraph.ops import broadcast_to
from ngraph.ops import ceiling
from ngraph.ops import ceiling as ceil
from ngraph.ops import concat
......
......@@ -420,14 +420,59 @@ Node.__ge__ = greater_eq
# Custom ops
@nameable_op
def broadcast(node, new_shape, axis=None, name=None): # type: (Node, TensorShape, int, str) -> Node
"""Return node which broadcasts input node values to specified shape.
def broadcast(node, new_shape, broadcast_axes, name=None):
# type: (Node, TensorShape, Iterable[int], str) -> Node
"""Create a node which broadcasts the input node's values along specified axes to a desired shape.
:param node: The node with input tensor data.
:param new_shape: The new shape we want to broadcast tensor to.
:param broadcast_axes: The axis positions (0-based) in the result that are being broadcast.
:param name: Optional new name for output node.
:return: New node with broadcast shape.
"""
return Broadcast(node, Shape(new_shape), AxisSet(broadcast_axes))
@nameable_op
def broadcast_to(node, new_shape, axis=None, name=None):
# type: (Node, TensorShape, int, str) -> Node
"""Create a node which broadcasts the input node's values to a desired shape.
`broadcast_to` will attempt to automatically determine which axes need broadcasting.
The optional `axis` parameter specifies the starting axis position (0-based) in the output
shape from which the current shape of the tensor matches the desired new shape.
e.g. current_shape: [4, 5], new_shape: [2, 3, 4, 5, 6], axis: 2
By using the `axis` parameter you can control which output axis to broadcast along.
Example:
>>> input_node = ng.constant([1, 2, 3])
>>> current_shape = [3]
>>> new_shape = [3, 3]
>>> ng.broadcast_to(input_node, new_shape, axis=1)
array([[1, 2, 3],
[1, 2, 3],
[1, 2, 3]])
>>> ng.broadcast_to(input_node, new_shape, axis=0)
array([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])
If the `axis` parameter is not specified, `broadcast_to` will attempt to match shapes,
assuming the current shape matches the rightmost positions of the desired new shape.
This behaviour is similar to NumPy's broadcasting.
i.e. default `axis = len(new_shape) - len(current_shape)`
:param node: The node with input tensor data.
:param new_shape: The new shape we want to broadcast tensor to.
:param axis: The axis along which we perform broadcasting.
:param name: Optional new name for output node.
:return: New node with broadcasted shape.
:return: New node with broadcast shape.
"""
return Broadcast(node, Shape(new_shape), get_broadcast_axes(new_shape, node.shape, axis))
......
......@@ -75,10 +75,10 @@ def as_elementwise_compatible_nodes(*input_values): # type: (*NodeInput) -> Lis
output_nodes = []
for input_value in input_values:
if issubclass(type(input_value), Node):
input_value = ng.broadcast(input_value, broadcast_shape)
input_value = ng.broadcast_to(input_value, broadcast_shape)
output_nodes.append(input_value)
else:
input_value = make_constant_node(input_value, dtype=broadcast_dtype)
output_nodes.append(ng.broadcast(input_value, broadcast_shape))
output_nodes.append(ng.broadcast_to(input_value, broadcast_shape))
return output_nodes
......@@ -14,9 +14,13 @@
* limitations under the License.
*******************************************************************************/
#include "ngraph/axis_set.hpp" //ngraph::AxisSet
#include <iterator>
#include <sstream>
#include <string>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/axis_set.hpp" //ngraph::AxisSet
#include "pyngraph/axis_set.hpp"
namespace py = pybind11;
......@@ -27,5 +31,19 @@ void regclass_pyngraph_AxisSet(py::module m)
axis_set.doc() = "ngraph.impl.AxisSet wraps ngraph::AxisSet";
axis_set.def(py::init<const std::initializer_list<size_t>&>());
axis_set.def(py::init<const std::set<size_t>&>());
axis_set.def(py::init<const std::vector<size_t>&>());
axis_set.def(py::init<const ngraph::AxisSet&>());
axis_set.def("__len__", [](const ngraph::AxisSet& v) { return v.size(); });
axis_set.def("__iter__",
[](ngraph::AxisSet& v) { return py::make_iterator(v.begin(), v.end()); },
py::keep_alive<0, 1>()); /* Keep set alive while iterator is used */
axis_set.def("__repr__", [](const ngraph::AxisSet& self) -> std::string {
std::stringstream data_ss;
std::copy(self.begin(), self.end(), std::ostream_iterator<int>(data_ss, ", "));
std::string data_str = data_ss.str();
return "<AxisSet {" + data_str.substr(0, data_str.size() - 2) + "}>";
});
}
......@@ -93,7 +93,7 @@ def test_serialization():
expected = [[1, 2, 3],
[1, 2, 3],
[1, 2, 3]]
result = run_op_node([input_data], ng.broadcast, new_shape)
result = run_op_node([input_data], ng.broadcast_to, new_shape)
assert np.allclose(result, expected)
axis = 0
......@@ -101,13 +101,13 @@ def test_serialization():
[2, 2, 2],
[3, 3, 3]]
result = run_op_node([input_data], ng.broadcast, new_shape, axis)
result = run_op_node([input_data], ng.broadcast_to, new_shape, axis)
assert np.allclose(result, expected)
input_data = np.arange(4)
new_shape = [3, 4, 2, 4]
expected = np.broadcast_to(input_data, new_shape)
result = run_op_node([input_data], ng.broadcast, new_shape)
result = run_op_node([input_data], ng.broadcast_to, new_shape)
assert np.allclose(result, expected)
......
......@@ -736,6 +736,23 @@ def test_concat():
assert np.allclose(result_arr, result_arr_ref)
@pytest.config.gpu_skip(reason="Not implemented")
def test_axisset():
set_axisset = AxisSet({1, 2, 3})
list_axisset = AxisSet([1, 2, 3])
tuple_axisset = AxisSet((1, 2, 3))
assert len(set_axisset) == 3
assert set(set_axisset) == {1, 2, 3}
assert len(list_axisset) == 3
assert set(list_axisset) == set(set_axisset)
assert len(tuple_axisset) == 3
assert set(tuple_axisset) == set(set_axisset)
@pytest.config.gpu_skip(reason="Not implemented")
def test_select():
......
......@@ -18,6 +18,7 @@
#include <cstddef>
#include <set>
#include <vector>
namespace ngraph
{
......@@ -36,6 +37,11 @@ namespace ngraph
{
}
AxisSet(const std::vector<size_t>& axes)
: std::set<size_t>(axes.begin(), axes.end())
{
}
AxisSet(const AxisSet& axes)
: std::set<size_t>(axes)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment