Commit 50334cbf authored by tsocha's avatar tsocha Committed by Robert Kimball

[ONNX] numpy broadcasting refactoring (#2496)

* Remove get_numpy_broadcast_shape helper function

* Remove numpy_style_broadcast_for_binary_operation helper function

* Remove TODO

* Review fix pt. 1

* Remove parameters as shape containers

* Fix LSTM

* Review fix pt. 1

* Style apply

* Use old comment
parent b8106133
...@@ -44,8 +44,7 @@ namespace ngraph ...@@ -44,8 +44,7 @@ namespace ngraph
{ {
inline NodeVector add(const Node& node) inline NodeVector add(const Node& node)
{ {
NodeVector ng_inputs{ NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))}; return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
} }
......
...@@ -31,8 +31,7 @@ namespace ngraph ...@@ -31,8 +31,7 @@ namespace ngraph
{ {
inline NodeVector logical_and(const Node& node) inline NodeVector logical_and(const Node& node)
{ {
NodeVector ng_inputs{ NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::And>(ng_inputs.at(0), ng_inputs.at(1))}; return {std::make_shared<ngraph::op::And>(ng_inputs.at(0), ng_inputs.at(1))};
} }
......
...@@ -44,8 +44,7 @@ namespace ngraph ...@@ -44,8 +44,7 @@ namespace ngraph
{ {
inline NodeVector div(const Node& node) inline NodeVector div(const Node& node)
{ {
NodeVector ng_inputs{ NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Divide>(ng_inputs.at(0), ng_inputs.at(1))}; return {std::make_shared<ngraph::op::Divide>(ng_inputs.at(0), ng_inputs.at(1))};
} }
......
...@@ -31,8 +31,7 @@ namespace ngraph ...@@ -31,8 +31,7 @@ namespace ngraph
{ {
inline NodeVector equal(const Node& node) inline NodeVector equal(const Node& node)
{ {
NodeVector ng_inputs{ NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Equal>(ng_inputs.at(0), ng_inputs.at(1))}; return {std::make_shared<ngraph::op::Equal>(ng_inputs.at(0), ng_inputs.at(1))};
} }
......
...@@ -79,8 +79,7 @@ namespace ngraph ...@@ -79,8 +79,7 @@ namespace ngraph
input_c = std::make_shared<ngraph::op::Multiply>(beta_node, input_c); input_c = std::make_shared<ngraph::op::Multiply>(beta_node, input_c);
// alpha * A' * B' + beta * C // alpha * A' * B' + beta * C
NodeVector broadcasted_nodes = NodeVector broadcasted_nodes = numpy_style_broadcast({a_dot_b, input_c});
numpy_style_broadcast_for_binary_operation(a_dot_b, input_c);
// The ONNX documentation says that `input_c` should be "unidirectional broadcastable" // The ONNX documentation says that `input_c` should be "unidirectional broadcastable"
// to the `a_dot_b` tensor. Since numpy style broadcasting is bidirectional, below we // to the `a_dot_b` tensor. Since numpy style broadcasting is bidirectional, below we
// only use the second output from above broadcasting. In other words we want to // only use the second output from above broadcasting. In other words we want to
......
...@@ -31,8 +31,7 @@ namespace ngraph ...@@ -31,8 +31,7 @@ namespace ngraph
{ {
inline NodeVector greater(const Node& node) inline NodeVector greater(const Node& node)
{ {
NodeVector ng_inputs{ NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return { return {
std::make_shared<ngraph::op::Greater>(ng_inputs.at(0), ng_inputs.at(1))}; std::make_shared<ngraph::op::Greater>(ng_inputs.at(0), ng_inputs.at(1))};
} }
......
...@@ -31,8 +31,7 @@ namespace ngraph ...@@ -31,8 +31,7 @@ namespace ngraph
{ {
inline NodeVector less(const Node& node) inline NodeVector less(const Node& node)
{ {
NodeVector ng_inputs{ NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Less>(ng_inputs.at(0), ng_inputs.at(1))}; return {std::make_shared<ngraph::op::Less>(ng_inputs.at(0), ng_inputs.at(1))};
} }
......
...@@ -62,21 +62,21 @@ namespace ngraph ...@@ -62,21 +62,21 @@ namespace ngraph
std::shared_ptr<ngraph::Node> add(const std::shared_ptr<ngraph::Node>& lhs, std::shared_ptr<ngraph::Node> add(const std::shared_ptr<ngraph::Node>& lhs,
const std::shared_ptr<ngraph::Node>& rhs) const std::shared_ptr<ngraph::Node>& rhs)
{ {
auto args = numpy_style_broadcast_for_binary_operation(lhs, rhs); auto args = numpy_style_broadcast({lhs, rhs});
return {std::make_shared<ngraph::op::Add>(args.at(0), args.at(1))}; return {std::make_shared<ngraph::op::Add>(args.at(0), args.at(1))};
} }
std::shared_ptr<ngraph::Node> sub(const std::shared_ptr<ngraph::Node>& lhs, std::shared_ptr<ngraph::Node> sub(const std::shared_ptr<ngraph::Node>& lhs,
const std::shared_ptr<ngraph::Node>& rhs) const std::shared_ptr<ngraph::Node>& rhs)
{ {
auto args = numpy_style_broadcast_for_binary_operation(lhs, rhs); auto args = numpy_style_broadcast({lhs, rhs});
return {std::make_shared<ngraph::op::Subtract>(args.at(0), args.at(1))}; return {std::make_shared<ngraph::op::Subtract>(args.at(0), args.at(1))};
} }
std::shared_ptr<ngraph::Node> mul(const std::shared_ptr<ngraph::Node>& lhs, std::shared_ptr<ngraph::Node> mul(const std::shared_ptr<ngraph::Node>& lhs,
const std::shared_ptr<ngraph::Node>& rhs) const std::shared_ptr<ngraph::Node>& rhs)
{ {
auto args = numpy_style_broadcast_for_binary_operation(lhs, rhs); auto args = numpy_style_broadcast({lhs, rhs});
return {std::make_shared<ngraph::op::Multiply>(args.at(0), args.at(1))}; return {std::make_shared<ngraph::op::Multiply>(args.at(0), args.at(1))};
} }
......
...@@ -46,8 +46,7 @@ namespace ngraph ...@@ -46,8 +46,7 @@ namespace ngraph
{ {
inline NodeVector mul(const Node& node) inline NodeVector mul(const Node& node)
{ {
NodeVector ng_inputs{ NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return { return {
std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(1))}; std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(1))};
} }
......
...@@ -79,8 +79,9 @@ namespace ngraph ...@@ -79,8 +79,9 @@ namespace ngraph
std::shared_ptr<ngraph::Node> one_hot = std::make_shared<ngraph::op::Convert>( std::shared_ptr<ngraph::Node> one_hot = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::OneHot>(indices, output_shape, axis), std::make_shared<ngraph::op::OneHot>(indices, output_shape, axis),
values->get_element_type()); values->get_element_type());
on_value = numpy_style_broadcast_for_binary_operation(one_hot, on_value)[1]; auto broadcasted_values = numpy_style_broadcast({one_hot, on_value, off_value});
off_value = numpy_style_broadcast_for_binary_operation(one_hot, off_value)[1]; on_value = broadcasted_values[1];
off_value = broadcasted_values[2];
one_hot = one_hot * (on_value - off_value) + off_value; one_hot = one_hot * (on_value - off_value) + off_value;
return {one_hot}; return {one_hot};
} }
......
...@@ -31,8 +31,7 @@ namespace ngraph ...@@ -31,8 +31,7 @@ namespace ngraph
{ {
inline NodeVector logical_or(const Node& node) inline NodeVector logical_or(const Node& node)
{ {
NodeVector ng_inputs{ NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Or>(ng_inputs.at(0), ng_inputs.at(1))}; return {std::make_shared<ngraph::op::Or>(ng_inputs.at(0), ng_inputs.at(1))};
} }
......
...@@ -31,8 +31,7 @@ namespace ngraph ...@@ -31,8 +31,7 @@ namespace ngraph
{ {
inline NodeVector pow(const Node& node) inline NodeVector pow(const Node& node)
{ {
NodeVector ng_inputs{ NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Power>(ng_inputs.at(0), ng_inputs.at(1))}; return {std::make_shared<ngraph::op::Power>(ng_inputs.at(0), ng_inputs.at(1))};
} }
......
...@@ -57,8 +57,7 @@ namespace ngraph ...@@ -57,8 +57,7 @@ namespace ngraph
} }
else if (data_shape != slope_shape) else if (data_shape != slope_shape)
{ {
auto params = numpy_style_broadcast_for_binary_operation(slope, data); slope = numpy_style_broadcast({slope, data})[0];
slope = params.at(0);
} }
// x < 0 => f(x) = x * slope // x < 0 => f(x) = x * slope
......
...@@ -45,8 +45,7 @@ namespace ngraph ...@@ -45,8 +45,7 @@ namespace ngraph
{ {
inline NodeVector sub(const Node& node) inline NodeVector sub(const Node& node)
{ {
NodeVector ng_inputs{ NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return { return {
std::make_shared<ngraph::op::Subtract>(ng_inputs.at(0), ng_inputs.at(1))}; std::make_shared<ngraph::op::Subtract>(ng_inputs.at(0), ng_inputs.at(1))};
} }
......
...@@ -33,8 +33,7 @@ namespace ngraph ...@@ -33,8 +33,7 @@ namespace ngraph
{ {
inline NodeVector logical_xor(const Node& node) inline NodeVector logical_xor(const Node& node)
{ {
NodeVector ng_inputs{ NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
auto left = ng_inputs.at(0); auto left = ng_inputs.at(0);
auto not_left = std::make_shared<ngraph::op::Not>(left); auto not_left = std::make_shared<ngraph::op::Not>(left);
auto right = ng_inputs.at(1); auto right = ng_inputs.at(1);
......
...@@ -26,59 +26,59 @@ ...@@ -26,59 +26,59 @@
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "reshape.hpp" #include "reshape.hpp"
/// \brief Calculate output shape of numpy - style broadcast operation. /// \brief Calculate the output shape of numpy-style broadcast operation for two shapes.
///
/// more info:
/// https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules /// https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules
/// example:
/// left: [3, 1, 10] right: [5, 1]
/// return: [3, 5, 10]
/// ///
/// \param left_shape Shape of first input tensor. /// \param left_shape First input shape.
/// \param right_shape Shape of the second input tensor. /// \param right_shape Second input Shape.
/// \return Shape of the output tensor and full shape of input tensors. /// \return Broadcast shape of input shapes.
static std::vector<ngraph::Shape> get_numpy_broadcast_shape(ngraph::Shape left_shape, ngraph::Shape calculate_broadcast_shape(ngraph::Shape left_shape, ngraph::Shape right_shape)
ngraph::Shape right_shape)
{ {
ngraph::Shape output_shape; ngraph::Shape result;
auto rank_left = left_shape.size(); auto left_rank = left_shape.size();
auto rank_right = right_shape.size(); auto right_rank = right_shape.size();
auto max_rank = std::max(rank_left, rank_right); auto max_rank = std::max(left_rank, right_rank);
// left-pad the left_shape with ones // left-pad the left_shape with ones
left_shape.insert(std::begin(left_shape), max_rank - rank_left, 1); left_shape.insert(std::begin(left_shape), max_rank - left_rank, 1);
// left-pad the right_shape with ones // left-pad the right_shape with ones
right_shape.insert(std::begin(right_shape), max_rank - rank_right, 1); right_shape.insert(std::begin(right_shape), max_rank - right_rank, 1);
for (std::size_t index = 0; index < max_rank; ++index) for (std::size_t index = 0; index < max_rank; ++index)
{ {
output_shape.push_back(std::max(left_shape.at(index), right_shape.at(index))); result.push_back(std::max(left_shape.at(index), right_shape.at(index)));
} }
return {output_shape, left_shape, right_shape}; return result;
} };
/// \brief Calculate the output shape of numpy-style broadcast operation for all input nodes. /// \brief Calculate the output shape of numpy-style broadcast operation for all input shapes.
/// ///
/// This function finds the maximum tensor shape that will be the result of element-wise operation /// This function finds the maximum tensor shape that will be the result of element-wise operation
/// that will be applied to the inputs vector. The function also prepares the shape of each input /// that will be applied to the input shapes vector. The function also prepares the shape of each input
/// for the element-wise operation by left-padding those shapes so that their rank is equal to /// for the element-wise operation by left-padding those shapes so that their rank is equal to
/// the target_shape's rank. /// the left_shape's rank.
/// ///
/// \param inputs A vector of input nodes for which a common shape should be found /// \param input_shapes A vector of input shapes for which a common shape should be found
/// \return A pair that contains the target shape as its first object and a vector of padded /// \return A pair that contains the target shape as its first object and a vector of padded
/// input shapes ready to be broadcasted as the second object /// input shapes ready to be broadcasted as the second object
static std::pair<ngraph::Shape, std::vector<ngraph::Shape>> static std::pair<ngraph::Shape, std::vector<ngraph::Shape>>
get_numpy_broadcast_shapes(const ngraph::NodeVector& inputs) get_numpy_broadcast_shapes(const std::vector<ngraph::Shape>& input_shapes)
{ {
auto shape_left_fold = [](const ngraph::Shape& accumulator, ngraph::Shape target_shape = std::accumulate(std::begin(input_shapes),
const std::shared_ptr<ngraph::Node>& input) { std::end(input_shapes),
// TODO: in a separate PR remove the 'get_numpy_broadcast_shape' function ngraph::Shape{},
return get_numpy_broadcast_shape(accumulator, input->get_shape()).at(0); calculate_broadcast_shape);
};
ngraph::Shape target_shape =
std::accumulate(std::begin(inputs), std::end(inputs), ngraph::Shape{}, shape_left_fold);
std::vector<ngraph::Shape> full_shapes; std::vector<ngraph::Shape> full_shapes;
for (const std::shared_ptr<ngraph::Node>& input : inputs) for (const ngraph::Shape& input : input_shapes)
{ {
ngraph::Shape padded_shape = input->get_shape(); ngraph::Shape padded_shape{input};
padded_shape.insert(std::begin(padded_shape), target_shape.size() - padded_shape.size(), 1); padded_shape.insert(std::begin(padded_shape), target_shape.size() - padded_shape.size(), 1);
full_shapes.push_back(std::move(padded_shape)); full_shapes.push_back(std::move(padded_shape));
} }
...@@ -86,6 +86,24 @@ static std::pair<ngraph::Shape, std::vector<ngraph::Shape>> ...@@ -86,6 +86,24 @@ static std::pair<ngraph::Shape, std::vector<ngraph::Shape>>
return {target_shape, full_shapes}; return {target_shape, full_shapes};
} }
/// \brief Calculate the output shape of numpy-style broadcast operation for all input nodes.
///
/// \param inputs A vector of input nodes for which a common shape should be found
/// \return A pair that contains the target shape as its first object and a vector of padded
/// input shapes ready to be broadcasted as the second object
static std::pair<ngraph::Shape, std::vector<ngraph::Shape>>
get_numpy_broadcast_shapes(const ngraph::NodeVector& inputs)
{
std::vector<ngraph::Shape> input_shapes;
for (const auto& input : inputs)
{
input_shapes.push_back(input->get_shape());
}
return get_numpy_broadcast_shapes(input_shapes);
}
/// \brief Broadcast input node. /// \brief Broadcast input node.
/// ///
/// \note The source shape does not have to be the actual shape of input node. However /// \note The source shape does not have to be the actual shape of input node. However
...@@ -141,22 +159,7 @@ namespace ngraph ...@@ -141,22 +159,7 @@ namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
NodeVector NodeVector numpy_style_broadcast(const NodeVector& inputs)
numpy_style_broadcast_for_binary_operation(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right)
{
const auto& left_shape = left->get_shape();
const auto& right_shape = right->get_shape();
const auto& numpy_shapes = get_numpy_broadcast_shape(left_shape, right_shape);
auto output_shape = numpy_shapes.at(0);
auto left_full_shape = numpy_shapes.at(1);
auto right_full_shape = numpy_shapes.at(2);
return {broadcast_node_numpy_style(left, output_shape, left_full_shape),
broadcast_node_numpy_style(right, output_shape, right_full_shape)};
}
NodeVector numpy_style_broadcast(NodeVector inputs)
{ {
if (inputs.size() <= 1) if (inputs.size() <= 1)
{ {
...@@ -186,13 +189,13 @@ namespace ngraph ...@@ -186,13 +189,13 @@ namespace ngraph
const auto& left_shape = left->get_shape(); const auto& left_shape = left->get_shape();
const auto& right_shape = right->get_shape(); const auto& right_shape = right->get_shape();
// Broadcast only _stack of matrices_ axes. // Broadcast only _stack of matrices_ axes.
const auto& numpy_shapes = get_numpy_broadcast_shape( const auto& numpy_shapes = get_numpy_broadcast_shapes(
Shape{std::begin(left_shape), std::next(std::end(left_shape), -2)}, {Shape{std::begin(left_shape), std::next(std::end(left_shape), -2)},
Shape{std::begin(right_shape), std::next(std::end(right_shape), -2)}); Shape{std::begin(right_shape), std::next(std::end(right_shape), -2)}});
// Prepare tensors output shapes with broadcasted _stack of matrices_ axes. // Prepare tensors output shapes with broadcasted _stack of matrices_ axes.
auto left_output_shape = numpy_shapes.at(0); auto left_output_shape = numpy_shapes.first;
auto right_output_shape = numpy_shapes.at(0); auto right_output_shape = numpy_shapes.first;
// Append the last two axes original dimensions. // Append the last two axes original dimensions.
left_output_shape.insert(std::end(left_output_shape), left_output_shape.insert(std::end(left_output_shape),
std::next(std::begin(left_shape), left_shape.size() - 2), std::next(std::begin(left_shape), left_shape.size() - 2),
...@@ -201,8 +204,8 @@ namespace ngraph ...@@ -201,8 +204,8 @@ namespace ngraph
std::next(std::begin(right_shape), right_shape.size() - 2), std::next(std::begin(right_shape), right_shape.size() - 2),
std::end(right_shape)); std::end(right_shape));
auto left_full_shape = numpy_shapes.at(1); auto left_full_shape = numpy_shapes.second.at(0);
auto right_full_shape = numpy_shapes.at(2); auto right_full_shape = numpy_shapes.second.at(1);
// Append the last two axes original dimensions. // Append the last two axes original dimensions.
left_full_shape.insert(std::end(left_full_shape), left_full_shape.insert(std::end(left_full_shape),
std::next(std::begin(left_shape), left_shape.size() - 2), std::next(std::begin(left_shape), left_shape.size() - 2),
......
...@@ -27,32 +27,12 @@ namespace ngraph ...@@ -27,32 +27,12 @@ namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
/// \brief Cast shape of two nodes to make them compatible for an element-wise binary operation.
///
/// \param left Node which contain input of binary op.
/// \param right Node which contain input of binary op.
///
/// \return Left and right node after broadcasting.
NodeVector
numpy_style_broadcast_for_binary_operation(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right);
/// \brief Cast shape of two nodes to make them compatible for an element-wise binary operation.
///
/// \param inputs Left and right node (inputs of the binary op).
///
/// \return Left and right node after broadcasting.
inline NodeVector numpy_style_broadcast_for_binary_operation(NodeVector inputs)
{
return numpy_style_broadcast_for_binary_operation(inputs.at(0), inputs.at(1));
}
/// \brief Cast shape of all input nodes for an element-wise operation that requires shape-compatibility /// \brief Cast shape of all input nodes for an element-wise operation that requires shape-compatibility
/// ///
/// \param inputs Original list of inputs /// \param inputs Original list of inputs
/// ///
/// \return Numpy-style broadcasted list of nodes. /// \return Numpy-style broadcasted list of nodes.
NodeVector numpy_style_broadcast(NodeVector inputs); NodeVector numpy_style_broadcast(const NodeVector& inputs);
/// \brief Cast shape of two nodes to make them compatible for an element-wise binary operation. /// \brief Cast shape of two nodes to make them compatible for an element-wise binary operation.
/// ///
......
...@@ -77,7 +77,7 @@ namespace ngraph ...@@ -77,7 +77,7 @@ namespace ngraph
// Templated binary operation - Creates Add, Minimum, Maximum, etc. // Templated binary operation - Creates Add, Minimum, Maximum, etc.
auto binary_operation = [](const std::shared_ptr<ngraph::Node>& arg0, auto binary_operation = [](const std::shared_ptr<ngraph::Node>& arg0,
const std::shared_ptr<ngraph::Node>& arg1) { const std::shared_ptr<ngraph::Node>& arg1) {
NodeVector args{numpy_style_broadcast_for_binary_operation(arg0, arg1)}; NodeVector args{numpy_style_broadcast({arg0, arg1})};
return std::make_shared<T>(args.at(0), args.at(1)); return std::make_shared<T>(args.at(0), args.at(1));
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment