Commit 50334cbf authored by tsocha, committed by Robert Kimball

[ONNX] numpy broadcasting refactoring (#2496)

* Remove get_numpy_broadcast_shape helper function

* Remove numpy_style_broadcast_for_binary_operation helper function

* Remove TODO

* Review fix pt. 1

* Remove parameters as shape containers

* Fix LSTM

* Review fix pt. 1

* Style apply

* Use old comment
parent b8106133
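The pattern this commit applies at every call site: the removed two-node helper numpy_style_broadcast_for_binary_operation is replaced by the general numpy_style_broadcast overload, which takes a NodeVector. A minimal before/after sketch (node names are illustrative, not from the diff):

// Before: helper restricted to exactly two inputs
auto args = numpy_style_broadcast_for_binary_operation(lhs, rhs);
// After: one helper that broadcasts any number of inputs to a common shape
auto args = numpy_style_broadcast({lhs, rhs});
auto sum = std::make_shared<ngraph::op::Add>(args.at(0), args.at(1));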
@@ -44,8 +44,7 @@ namespace ngraph
{
inline NodeVector add(const Node& node)
{
-NodeVector ng_inputs{
-    numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
+NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
@@ -31,8 +31,7 @@ namespace ngraph
{
inline NodeVector logical_and(const Node& node)
{
-NodeVector ng_inputs{
-    numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
+NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::And>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
@@ -44,8 +44,7 @@ namespace ngraph
{
inline NodeVector div(const Node& node)
{
-NodeVector ng_inputs{
-    numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
+NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Divide>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
@@ -31,8 +31,7 @@ namespace ngraph
{
inline NodeVector equal(const Node& node)
{
-NodeVector ng_inputs{
-    numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
+NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Equal>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
@@ -79,8 +79,7 @@ namespace ngraph
input_c = std::make_shared<ngraph::op::Multiply>(beta_node, input_c);
// alpha * A' * B' + beta * C
-NodeVector broadcasted_nodes =
-    numpy_style_broadcast_for_binary_operation(a_dot_b, input_c);
+NodeVector broadcasted_nodes = numpy_style_broadcast({a_dot_b, input_c});
// The ONNX documentation says that `input_c` should be "unidirectional broadcastable"
// to the `a_dot_b` tensor. Since numpy style broadcasting is bidirectional, below we
// only use the second output from above broadcasting. In other words we want to
......
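A hedged shape walk-through of the Gemm hunk above (the shapes are illustrative, not from the source):

// Suppose a_dot_b has shape {3, 4} and input_c has shape {4}.
NodeVector broadcasted_nodes = numpy_style_broadcast({a_dot_b, input_c});
// Bidirectional numpy broadcasting would reshape both inputs to {3, 4}, but ONNX
// only requires C to be unidirectionally broadcastable to A*B. The importer
// therefore keeps a_dot_b untouched and uses only the broadcast copy of input_c:
auto c_broadcast = broadcasted_nodes.at(1); // input_c stretched to {3, 4}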
@@ -31,8 +31,7 @@ namespace ngraph
{
inline NodeVector greater(const Node& node)
{
-NodeVector ng_inputs{
-    numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
+NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
return {
std::make_shared<ngraph::op::Greater>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
@@ -31,8 +31,7 @@ namespace ngraph
{
inline NodeVector less(const Node& node)
{
-NodeVector ng_inputs{
-    numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
+NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Less>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
@@ -62,21 +62,21 @@ namespace ngraph
std::shared_ptr<ngraph::Node> add(const std::shared_ptr<ngraph::Node>& lhs,
const std::shared_ptr<ngraph::Node>& rhs)
{
-auto args = numpy_style_broadcast_for_binary_operation(lhs, rhs);
+auto args = numpy_style_broadcast({lhs, rhs});
return {std::make_shared<ngraph::op::Add>(args.at(0), args.at(1))};
}
std::shared_ptr<ngraph::Node> sub(const std::shared_ptr<ngraph::Node>& lhs,
const std::shared_ptr<ngraph::Node>& rhs)
{
-auto args = numpy_style_broadcast_for_binary_operation(lhs, rhs);
+auto args = numpy_style_broadcast({lhs, rhs});
return {std::make_shared<ngraph::op::Subtract>(args.at(0), args.at(1))};
}
std::shared_ptr<ngraph::Node> mul(const std::shared_ptr<ngraph::Node>& lhs,
const std::shared_ptr<ngraph::Node>& rhs)
{
-auto args = numpy_style_broadcast_for_binary_operation(lhs, rhs);
+auto args = numpy_style_broadcast({lhs, rhs});
return {std::make_shared<ngraph::op::Multiply>(args.at(0), args.at(1))};
}
......
@@ -46,8 +46,7 @@ namespace ngraph
{
inline NodeVector mul(const Node& node)
{
-NodeVector ng_inputs{
-    numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
+NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
return {
std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
@@ -79,8 +79,9 @@ namespace ngraph
std::shared_ptr<ngraph::Node> one_hot = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::OneHot>(indices, output_shape, axis),
values->get_element_type());
-on_value = numpy_style_broadcast_for_binary_operation(one_hot, on_value)[1];
-off_value = numpy_style_broadcast_for_binary_operation(one_hot, off_value)[1];
+auto broadcasted_values = numpy_style_broadcast({one_hot, on_value, off_value});
+on_value = broadcasted_values[1];
+off_value = broadcasted_values[2];
one_hot = one_hot * (on_value - off_value) + off_value;
return {one_hot};
}
......
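To make the arithmetic in the one_hot hunk concrete, a worked example with made-up values:

// Suppose on_value = 5, off_value = -1, and a one_hot row of {0, 1, 0}.
// one_hot * (on_value - off_value) + off_value
//   = {0, 1, 0} * 6 + (-1)
//   = {-1, 5, -1}
// The hot position takes on_value; every other position takes off_value.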
@@ -31,8 +31,7 @@ namespace ngraph
{
inline NodeVector logical_or(const Node& node)
{
-NodeVector ng_inputs{
-    numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
+NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Or>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
@@ -31,8 +31,7 @@ namespace ngraph
{
inline NodeVector pow(const Node& node)
{
-NodeVector ng_inputs{
-    numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
+NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Power>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
@@ -57,8 +57,7 @@ namespace ngraph
}
else if (data_shape != slope_shape)
{
-auto params = numpy_style_broadcast_for_binary_operation(slope, data);
-slope = params.at(0);
+slope = numpy_style_broadcast({slope, data})[0];
}
// x < 0 => f(x) = x * slope
......
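A hedged example of the PRelu branch above (shapes are illustrative): when the slope tensor's shape differs from the data tensor's, both are broadcast together and only element 0 of the result, the stretched slope, is kept:

// data shape {2, 3, 4}, slope shape {3, 1}
// numpy broadcast left-pads slope to {1, 3, 1}, then stretches it to {2, 3, 4}
slope = numpy_style_broadcast({slope, data})[0]; // slope now matches data's shape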
@@ -45,8 +45,7 @@ namespace ngraph
{
inline NodeVector sub(const Node& node)
{
-NodeVector ng_inputs{
-    numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
+NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
return {
std::make_shared<ngraph::op::Subtract>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
@@ -33,8 +33,7 @@ namespace ngraph
{
inline NodeVector logical_xor(const Node& node)
{
-NodeVector ng_inputs{
-    numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
+NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
auto left = ng_inputs.at(0);
auto not_left = std::make_shared<ngraph::op::Not>(left);
auto right = ng_inputs.at(1);
......
@@ -26,59 +26,59 @@
#include "ngraph/op/reshape.hpp"
#include "reshape.hpp"
-/// \brief Calculate output shape of numpy - style broadcast operation.
-/// https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules
-/// \param left_shape Shape of first input tensor.
-/// \param right_shape Shape of the second input tensor.
-/// \return Shape of the output tensor and full shape of input tensors.
-static std::vector<ngraph::Shape> get_numpy_broadcast_shape(ngraph::Shape left_shape,
-                                                            ngraph::Shape right_shape)
+/// \brief Calculate the output shape of numpy-style broadcast operation for two shapes.
+///
+/// more info:
+/// https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules
+/// example:
+/// left: [3, 1, 10] right: [5, 1]
+/// return: [3, 5, 10]
+///
+/// \param left_shape First input shape.
+/// \param right_shape Second input shape.
+/// \return Broadcast shape of input shapes.
+ngraph::Shape calculate_broadcast_shape(ngraph::Shape left_shape, ngraph::Shape right_shape)
{
-    ngraph::Shape output_shape;
-    auto rank_left = left_shape.size();
-    auto rank_right = right_shape.size();
-    auto max_rank = std::max(rank_left, rank_right);
+    ngraph::Shape result;
+    auto left_rank = left_shape.size();
+    auto right_rank = right_shape.size();
+    auto max_rank = std::max(left_rank, right_rank);
// left-pad the left_shape with ones
-    left_shape.insert(std::begin(left_shape), max_rank - rank_left, 1);
+    left_shape.insert(std::begin(left_shape), max_rank - left_rank, 1);
// left-pad the right_shape with ones
-    right_shape.insert(std::begin(right_shape), max_rank - rank_right, 1);
+    right_shape.insert(std::begin(right_shape), max_rank - right_rank, 1);
for (std::size_t index = 0; index < max_rank; ++index)
{
-        output_shape.push_back(std::max(left_shape.at(index), right_shape.at(index)));
+        result.push_back(std::max(left_shape.at(index), right_shape.at(index)));
}
-    return {output_shape, left_shape, right_shape};
-}
+    return result;
+};
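The doc comment's example, restated as a hedged usage sketch of the new helper:

ngraph::Shape out = calculate_broadcast_shape({3, 1, 10}, {5, 1});
// out == ngraph::Shape{3, 5, 10}: the shorter shape is left-padded to {1, 5, 1}
// and each output dimension is the maximum of the two aligned dimensions.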
-/// \brief Calculate the output shape of numpy-style broadcast operation for all input nodes.
+/// \brief Calculate the output shape of numpy-style broadcast operation for all input shapes.
///
/// This function finds the maximum tensor shape that will be the result of element-wise operation
-/// that will be applied to the inputs vector. The function also prepares the shape of each input
+/// that will be applied to the input shapes vector. The function also prepares the shape of each input
/// for the element-wise operation by left-padding those shapes so that their rank is equal to
-/// the target_shape's rank.
+/// the left_shape's rank.
///
-/// \param inputs A vector of input nodes for which a common shape should be found
+/// \param input_shapes A vector of input shapes for which a common shape should be found
/// \return A pair that contains the target shape as its first object and a vector of padded
/// input shapes ready to be broadcasted as the second object
static std::pair<ngraph::Shape, std::vector<ngraph::Shape>>
-    get_numpy_broadcast_shapes(const ngraph::NodeVector& inputs)
+    get_numpy_broadcast_shapes(const std::vector<ngraph::Shape>& input_shapes)
{
-    auto shape_left_fold = [](const ngraph::Shape& accumulator,
-                              const std::shared_ptr<ngraph::Node>& input) {
-        // TODO: in a separate PR remove the 'get_numpy_broadcast_shape' function
-        return get_numpy_broadcast_shape(accumulator, input->get_shape()).at(0);
-    };
-    ngraph::Shape target_shape =
-        std::accumulate(std::begin(inputs), std::end(inputs), ngraph::Shape{}, shape_left_fold);
+    ngraph::Shape target_shape = std::accumulate(std::begin(input_shapes),
+                                                 std::end(input_shapes),
+                                                 ngraph::Shape{},
+                                                 calculate_broadcast_shape);
std::vector<ngraph::Shape> full_shapes;
-    for (const std::shared_ptr<ngraph::Node>& input : inputs)
+    for (const ngraph::Shape& input : input_shapes)
{
-        ngraph::Shape padded_shape = input->get_shape();
+        ngraph::Shape padded_shape{input};
padded_shape.insert(std::begin(padded_shape), target_shape.size() - padded_shape.size(), 1);
full_shapes.push_back(std::move(padded_shape));
}
@@ -86,6 +86,24 @@ static std::pair<ngraph::Shape, std::vector<ngraph::Shape>>
return {target_shape, full_shapes};
}
+/// \brief Calculate the output shape of numpy-style broadcast operation for all input nodes.
+///
+/// \param inputs A vector of input nodes for which a common shape should be found
+/// \return A pair that contains the target shape as its first object and a vector of padded
+///         input shapes ready to be broadcasted as the second object
+static std::pair<ngraph::Shape, std::vector<ngraph::Shape>>
+    get_numpy_broadcast_shapes(const ngraph::NodeVector& inputs)
+{
+    std::vector<ngraph::Shape> input_shapes;
+    for (const auto& input : inputs)
+    {
+        input_shapes.push_back(input->get_shape());
+    }
+    return get_numpy_broadcast_shapes(input_shapes);
+}
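A hedged sketch of what the returned pair holds, reusing the shapes from the earlier example:

auto shapes = get_numpy_broadcast_shapes({ngraph::Shape{3, 1, 10}, ngraph::Shape{5, 1}});
// shapes.first        == {3, 5, 10}  (common target shape)
// shapes.second.at(0) == {3, 1, 10}  (already rank 3, unchanged)
// shapes.second.at(1) == {1, 5, 1}   (left-padded with ones, not yet stretched)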
/// \brief Broadcast input node.
///
/// \note The source shape does not have to be the actual shape of input node. However
@@ -141,22 +159,7 @@ namespace ngraph
{
namespace onnx_import
{
-NodeVector
-    numpy_style_broadcast_for_binary_operation(const std::shared_ptr<ngraph::Node>& left,
-                                               const std::shared_ptr<ngraph::Node>& right)
-{
-    const auto& left_shape = left->get_shape();
-    const auto& right_shape = right->get_shape();
-    const auto& numpy_shapes = get_numpy_broadcast_shape(left_shape, right_shape);
-    auto output_shape = numpy_shapes.at(0);
-    auto left_full_shape = numpy_shapes.at(1);
-    auto right_full_shape = numpy_shapes.at(2);
-    return {broadcast_node_numpy_style(left, output_shape, left_full_shape),
-            broadcast_node_numpy_style(right, output_shape, right_full_shape)};
-}
-NodeVector numpy_style_broadcast(NodeVector inputs)
+NodeVector numpy_style_broadcast(const NodeVector& inputs)
{
if (inputs.size() <= 1)
{
@@ -186,13 +189,13 @@ namespace ngraph
const auto& left_shape = left->get_shape();
const auto& right_shape = right->get_shape();
// Broadcast only _stack of matrices_ axes.
-    const auto& numpy_shapes = get_numpy_broadcast_shape(
-        Shape{std::begin(left_shape), std::next(std::end(left_shape), -2)},
-        Shape{std::begin(right_shape), std::next(std::end(right_shape), -2)});
+    const auto& numpy_shapes = get_numpy_broadcast_shapes(
+        {Shape{std::begin(left_shape), std::next(std::end(left_shape), -2)},
+         Shape{std::begin(right_shape), std::next(std::end(right_shape), -2)}});
// Prepare tensors output shapes with broadcasted _stack of matrices_ axes.
-    auto left_output_shape = numpy_shapes.at(0);
-    auto right_output_shape = numpy_shapes.at(0);
+    auto left_output_shape = numpy_shapes.first;
+    auto right_output_shape = numpy_shapes.first;
// Append the last two axes original dimensions.
left_output_shape.insert(std::end(left_output_shape),
std::next(std::begin(left_shape), left_shape.size() - 2),
@@ -201,8 +204,8 @@ namespace ngraph
std::next(std::begin(right_shape), right_shape.size() - 2),
std::end(right_shape));
-    auto left_full_shape = numpy_shapes.at(1);
-    auto right_full_shape = numpy_shapes.at(2);
+    auto left_full_shape = numpy_shapes.second.at(0);
+    auto right_full_shape = numpy_shapes.second.at(1);
// Append the last two axes original dimensions.
left_full_shape.insert(std::end(left_full_shape),
std::next(std::begin(left_shape), left_shape.size() - 2),
......
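A hedged shape walk-through of the matmul-style broadcast above (shapes are illustrative):

// left shape {2, 1, 4, 3, 2}, right shape {5, 4, 2, 4}
// Stack-of-matrices axes (all but the last two): {2, 1, 4} and {5, 4}
// get_numpy_broadcast_shapes -> target shape {2, 5, 4}
// Re-appending each input's last two axes gives the output shapes:
//   left_output_shape  == {2, 5, 4, 3, 2}
//   right_output_shape == {2, 5, 4, 2, 4}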
@@ -27,32 +27,12 @@ namespace ngraph
{
namespace onnx_import
{
-    /// \brief Cast shape of two nodes to make them compatible for an element-wise binary operation.
-    ///
-    /// \param left Node which contain input of binary op.
-    /// \param right Node which contain input of binary op.
-    ///
-    /// \return Left and right node after broadcasting.
-    NodeVector
-        numpy_style_broadcast_for_binary_operation(const std::shared_ptr<ngraph::Node>& left,
-                                                   const std::shared_ptr<ngraph::Node>& right);
-    /// \brief Cast shape of two nodes to make them compatible for an element-wise binary operation.
-    ///
-    /// \param inputs Left and right node (inputs of the binary op).
-    ///
-    /// \return Left and right node after broadcasting.
-    inline NodeVector numpy_style_broadcast_for_binary_operation(NodeVector inputs)
-    {
-        return numpy_style_broadcast_for_binary_operation(inputs.at(0), inputs.at(1));
-    }
/// \brief Cast shape of all input nodes for an element-wise operation that requires shape-compatibility
///
/// \param inputs Original list of inputs
///
/// \return Numpy-style broadcasted list of nodes.
-    NodeVector numpy_style_broadcast(NodeVector inputs);
+    NodeVector numpy_style_broadcast(const NodeVector& inputs);
/// \brief Cast shape of two nodes to make them compatible for an element-wise binary operation.
///
......
@@ -77,7 +77,7 @@ namespace ngraph
// Templated binary operation - Creates Add, Minimum, Maximum, etc.
auto binary_operation = [](const std::shared_ptr<ngraph::Node>& arg0,
const std::shared_ptr<ngraph::Node>& arg1) {
-    NodeVector args{numpy_style_broadcast_for_binary_operation(arg0, arg1)};
+    NodeVector args{numpy_style_broadcast({arg0, arg1})};
return std::make_shared<T>(args.at(0), args.at(1));
};
......
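Instantiated with T = ngraph::op::Maximum, for example, the lambda above expands to roughly the following (a sketch, not source code):

auto binary_max = [](const std::shared_ptr<ngraph::Node>& arg0,
                     const std::shared_ptr<ngraph::Node>& arg1) {
    // Broadcast both arguments to a common shape, then build the binary node.
    NodeVector args{numpy_style_broadcast({arg0, arg1})};
    return std::make_shared<ngraph::op::Maximum>(args.at(0), args.at(1));
};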