Commit a2521cf9 authored by tsocha's avatar tsocha Committed by Michał Karzyński

[ONNX] Numpy style binary broadcasting (#1549)

parent cc989301
......@@ -20,6 +20,7 @@
#include "ngraph/op/add.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
......@@ -29,7 +30,8 @@ namespace ngraph
{
/// \brief Convert an ONNX Add operation to an nGraph node.
///
/// The two inputs are first made shape-compatible via numpy-style
/// broadcasting, then added elementwise.
inline NodeVector add(const Node& node)
{
    // NOTE: the diff scrape had left both the old and the new declaration of
    // ng_inputs in place, which redeclares the variable; only the broadcasting
    // form is kept.
    NodeVector ng_inputs{
        numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
    return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
......@@ -20,6 +20,7 @@
#include "ngraph/op/divide.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
......@@ -29,7 +30,8 @@ namespace ngraph
{
/// \brief Convert an ONNX Div operation to an nGraph node.
///
/// The two inputs are first made shape-compatible via numpy-style
/// broadcasting, then divided elementwise.
inline NodeVector div(const Node& node)
{
    // NOTE: the diff scrape had left both the old and the new declaration of
    // ng_inputs in place, which redeclares the variable; only the broadcasting
    // form is kept.
    NodeVector ng_inputs{
        numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
    return {std::make_shared<ngraph::op::Divide>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
......@@ -20,6 +20,7 @@
#include "ngraph/op/multiply.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
......@@ -29,7 +30,8 @@ namespace ngraph
{
/// \brief Convert an ONNX Mul operation to an nGraph node.
///
/// The two inputs are first made shape-compatible via numpy-style
/// broadcasting, then multiplied elementwise.
inline NodeVector mul(const Node& node)
{
    // NOTE: the diff scrape had left both the old and the new declaration of
    // ng_inputs in place, which redeclares the variable; only the broadcasting
    // form is kept.
    NodeVector ng_inputs{
        numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
    return {std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
......@@ -20,6 +20,7 @@
#include "ngraph/op/subtract.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
......@@ -29,7 +30,8 @@ namespace ngraph
{
/// \brief Convert an ONNX Sub operation to an nGraph node.
///
/// The two inputs are first made shape-compatible via numpy-style
/// broadcasting, then subtracted elementwise.
inline NodeVector sub(const Node& node)
{
    // NOTE: the diff scrape had left both the old and the new declaration of
    // ng_inputs in place, which redeclares the variable; only the broadcasting
    // form is kept.
    NodeVector ng_inputs{
        numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
    return {std::make_shared<ngraph::op::Subtract>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
......@@ -17,12 +17,97 @@
#include <numeric>
#include <vector>
#include "ngraph/axis_vector.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/reshape.hpp"
#include "broadcasting.hpp"
/// \brief Calculate the output shape of a numpy-style broadcast operation.
/// https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules
///
/// \param left_shape Shape of the first input tensor (by value; padded in place).
/// \param right_shape Shape of the second input tensor (by value; padded in place).
/// \return Three shapes: the broadcast output shape, followed by the left and
///         right input shapes left-padded with 1s to the common rank.
static std::vector<ngraph::Shape> calculate_numpy_broadcast_shape(ngraph::Shape left_shape,
                                                                  ngraph::Shape right_shape)
{
    ngraph::Shape output_shape;
    auto rank_left = left_shape.size();
    auto rank_right = right_shape.size();
    auto max_rank = std::max(rank_left, rank_right);

    // Left-pad the lower-rank shape with 1s so both shapes share max_rank dims.
    // (Count-fill insert replaces the previous one-at-a-time front-insert loops
    // whose int counters triggered signed/unsigned comparison warnings.)
    left_shape.insert(std::begin(left_shape), max_rank - rank_left, 1);
    right_shape.insert(std::begin(right_shape), max_rank - rank_right, 1);

    // Per numpy rules each output dim is the larger of the two; a dim of 1
    // broadcasts against anything.
    for (std::size_t index = 0; index < max_rank; ++index)
    {
        output_shape.push_back(std::max(left_shape.at(index), right_shape.at(index)));
    }
    return {output_shape, left_shape, right_shape};
}
namespace ngraph
{
namespace onnx_import
{
NodeVector numpy_style_broadcast_for_binary_operation(const std::shared_ptr<Node>& left,
                                                      const std::shared_ptr<Node>& right)
{
    auto left_shape = left->get_shape();
    auto right_shape = right->get_shape();
    auto numpy_shapes = calculate_numpy_broadcast_shape(left_shape, right_shape);
    auto output_shape = numpy_shapes.at(0);
    auto left_full_shape = numpy_shapes.at(1);
    auto right_full_shape = numpy_shapes.at(2);

    AxisVector left_broadcast_axes;
    AxisVector right_broadcast_axes;
    Shape new_left_shape;
    Shape new_right_shape;

    // Positions of length-1 dims become broadcast axes for the nGraph
    // Broadcast op; every other dim is kept in the squeezed shape fed to
    // Reshape. Dropping the 1s first avoids a broadcast-axis conflict.
    // (if/else replaces the original ternaries-as-statements; std::size_t
    // counter avoids the signed/unsigned comparison of the original int.)
    for (std::size_t index = 0; index < output_shape.size(); ++index)
    {
        if (left_full_shape.at(index) == 1)
        {
            left_broadcast_axes.push_back(index);
        }
        else
        {
            new_left_shape.push_back(left_full_shape.at(index));
        }
        if (right_full_shape.at(index) == 1)
        {
            right_broadcast_axes.push_back(index);
        }
        else
        {
            new_right_shape.push_back(right_full_shape.at(index));
        }
    }

    // Identity input_order (0,1,2,...): Reshape here only drops length-1
    // dims, it does not transpose.
    std::vector<std::size_t> left_input_order(left->get_shape().size());
    std::iota(std::begin(left_input_order), std::end(left_input_order), 0);
    std::shared_ptr<Node> broadcasted_left =
        std::make_shared<op::Reshape>(left, left_input_order, new_left_shape);

    std::vector<std::size_t> right_input_order(right->get_shape().size());
    std::iota(std::begin(right_input_order), std::end(right_input_order), 0);
    std::shared_ptr<Node> broadcasted_right =
        std::make_shared<op::Reshape>(right, right_input_order, new_right_shape);

    // Re-expand the squeezed tensors to the full output shape along the
    // recorded broadcast axes.
    broadcasted_left =
        std::make_shared<op::Broadcast>(broadcasted_left, output_shape, left_broadcast_axes);
    broadcasted_right =
        std::make_shared<op::Broadcast>(broadcasted_right, output_shape, right_broadcast_axes);

    return {broadcasted_left, broadcasted_right};
}
AxisSet calculate_broadcast_axes(const Shape& output_shape,
const Shape& input_shape,
std::size_t start_match_axis)
......
......@@ -25,6 +25,26 @@ namespace ngraph
{
namespace onnx_import
{
/// \brief Cast shapes of two nodes to make them compatible for an
///        elementwise binary operation.
///
/// \param left Node providing the left-hand input of the binary op.
/// \param right Node providing the right-hand input of the binary op.
///
/// \return The left and right node after numpy-style broadcasting.
NodeVector
    numpy_style_broadcast_for_binary_operation(const std::shared_ptr<ngraph::Node>& left,
                                               const std::shared_ptr<ngraph::Node>& right);

/// \brief Cast shapes of two nodes to make them compatible for an
///        elementwise binary operation.
///
/// \param inputs Left and right node (the two inputs of the binary op).
///
/// \return The left and right node after numpy-style broadcasting.
// Taking the vector by const reference avoids copying a vector of shared_ptr
// (one atomic refcount increment/decrement per element) on every call.
inline NodeVector numpy_style_broadcast_for_binary_operation(const NodeVector& inputs)
{
    return numpy_style_broadcast_for_binary_operation(inputs.at(0), inputs.at(1));
}
/// \brief Generate a list of broadcast axes.
///
/// \details Informally, a broadcast "adds" axes to the input tensor, replicating
......
 backend-test:g

x
ysum"Addtest_add_bcastZ
x



Z
y

b
sum



B
\ No newline at end of file
......@@ -576,6 +576,31 @@ TEST(onnx, model_div)
EXPECT_TRUE(test::all_close_f(expected_output, result_vectors.front()));
}
// Verifies that the importer applies numpy-style broadcasting to Add:
// a (3,4,5) tensor of ones plus a rank-1 tensor {1,2,3,4,5} broadcast
// along the last axis.
TEST(onnx, model_add_bcast)
{
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/add_bcast.onnx"));
// First input: 3x4x5 tensor filled with ones.
Inputs inputs;
inputs.emplace_back(test::NDArray<float, 3>(
{{{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
{{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
{{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}})
.get_vector());
// Second input: rank-1 tensor broadcast across the last dimension.
inputs.emplace_back(test::NDArray<float, 1>({1, 2, 3, 4, 5}).get_vector());
// NOTE(review): the expected output is written as NDArray<float, 4> (shape
// 1x3x4x5) although the broadcast result of (3,4,5)+(5) is rank 3;
// presumably only the flattened data is compared by all_close_f - confirm.
Outputs expected_output{
test::NDArray<float, 4>(
{{{{2, 3, 4, 5, 6}, {2, 3, 4, 5, 6}, {2, 3, 4, 5, 6}, {2, 3, 4, 5, 6}},
{{2, 3, 4, 5, 6}, {2, 3, 4, 5, 6}, {2, 3, 4, 5, 6}, {2, 3, 4, 5, 6}},
{{2, 3, 4, 5, 6}, {2, 3, 4, 5, 6}, {2, 3, 4, 5, 6}, {2, 3, 4, 5, 6}}}})
.get_vector()};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
// Elementwise float comparison of the flattened expected and actual data.
EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
}
TEST(onnx, model_reshape_reduced_dims)
{
auto function = onnx_import::import_onnx_function(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment