Unverified Commit 8ef1ec04 authored by Michał Karzyński's avatar Michał Karzyński Committed by GitHub

[ONNX] Support for legacy broadcasting rules (#1924)

parent c637d629
......@@ -29,6 +29,19 @@ namespace ngraph
namespace op
{
namespace set_1
{
inline NodeVector add(const Node& node)
{
auto axis = node.get_attribute_value<int64_t>("axis", 0);
NodeVector ng_inputs{legacy_style_broadcast_for_binary_operation(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace set_1
namespace set_7
{
inline NodeVector add(const Node& node)
{
......
......@@ -29,6 +29,19 @@ namespace ngraph
namespace op
{
namespace set_1
{
inline NodeVector div(const Node& node)
{
auto axis = node.get_attribute_value<int64_t>("axis", 0);
NodeVector ng_inputs{legacy_style_broadcast_for_binary_operation(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
return {std::make_shared<ngraph::op::Divide>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace set_1
namespace set_7
{
inline NodeVector div(const Node& node)
{
......
......@@ -17,6 +17,7 @@
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/multiply.hpp"
#include "core/node.hpp"
......@@ -29,6 +30,20 @@ namespace ngraph
namespace op
{
namespace set_1
{
inline NodeVector mul(const Node& node)
{
auto axis = node.get_attribute_value<int64_t>("axis", 0);
NodeVector ng_inputs{legacy_style_broadcast_for_binary_operation(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
return {
std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace set_1
namespace set_7
{
inline NodeVector mul(const Node& node)
{
......@@ -38,7 +53,7 @@ namespace ngraph
std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace set_1
} // namespace set_7
} //namespace op
......
......@@ -29,6 +29,20 @@ namespace ngraph
namespace op
{
namespace set_1
{
inline NodeVector sub(const Node& node)
{
auto axis = node.get_attribute_value<int64_t>("axis", 0);
NodeVector ng_inputs{legacy_style_broadcast_for_binary_operation(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
return {
std::make_shared<ngraph::op::Subtract>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace set_1
namespace set_7
{
inline NodeVector sub(const Node& node)
{
......
......@@ -146,6 +146,7 @@ namespace ngraph
REGISTER_OPERATOR("Abs", 1, abs);
REGISTER_OPERATOR("Acos", 1, acos);
REGISTER_OPERATOR("Add", 1, add);
REGISTER_OPERATOR("Add", 7, add);
REGISTER_OPERATOR("And", 1, logical_and);
REGISTER_OPERATOR("ArgMin", 1, argmin);
REGISTER_OPERATOR("ArgMax", 1, argmax);
......@@ -161,6 +162,7 @@ namespace ngraph
REGISTER_OPERATOR("Conv", 1, conv);
REGISTER_OPERATOR("Cos", 1, cos);
REGISTER_OPERATOR("Div", 1, div);
REGISTER_OPERATOR("Div", 7, div);
REGISTER_OPERATOR("Dropout", 1, identity);
REGISTER_OPERATOR("Elu", 1, elu);
REGISTER_OPERATOR("Equal", 1, equal);
......@@ -184,6 +186,7 @@ namespace ngraph
REGISTER_OPERATOR("Mean", 1, mean);
REGISTER_OPERATOR("Min", 1, min);
REGISTER_OPERATOR("Mul", 1, mul);
REGISTER_OPERATOR("Mul", 7, mul);
REGISTER_OPERATOR("Neg", 1, neg);
REGISTER_OPERATOR("Not", 1, logical_not);
REGISTER_OPERATOR("Or", 1, logical_or);
......@@ -214,6 +217,7 @@ namespace ngraph
REGISTER_OPERATOR("Sqrt", 1, sqrt);
REGISTER_OPERATOR("Squeeze", 1, squeeze);
REGISTER_OPERATOR("Sub", 1, sub);
REGISTER_OPERATOR("Sub", 7, sub);
REGISTER_OPERATOR("Sum", 1, sum);
REGISTER_OPERATOR("Tan", 1, tan);
REGISTER_OPERATOR("Tanh", 1, tanh);
......
......@@ -103,6 +103,67 @@ namespace ngraph
return {broadcasted_left, broadcasted_right};
}
NodeVector
legacy_style_broadcast_for_binary_operation(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right,
std::size_t start_match_axis)
{
auto left_shape = left->get_shape();
auto right_shape = right->get_shape();
bool dimensions_identical = (left_shape == right_shape);
if (dimensions_identical)
{
return {left, right};
}
// Prepare new shape of right operand for broadcasting
// Remove dimensions with length=1 from back
auto new_right_shape = right_shape;
for (int dimension = new_right_shape.size() - 1; dimension >= 0; --dimension)
{
if (new_right_shape[dimension] == 1)
{
new_right_shape.pop_back();
}
else
{
break;
}
}
// Find first dimensions at front with length different from 1
size_t num_ones = 0;
for (size_t dimension : new_right_shape)
{
if (dimension == 1)
{
++num_ones;
}
else
{
break;
}
}
// Remove dimensions with length=1 from front
new_right_shape.erase(std::begin(new_right_shape),
std::next(std::begin(new_right_shape), num_ones));
auto reshape_right = std::make_shared<ngraph::op::Reshape>(
right, reshape::get_default_axis_vector(right_shape.size()), new_right_shape);
// Move broadcast start axis parameter to right
start_match_axis += num_ones;
auto broadcast_right = std::make_shared<ngraph::op::Broadcast>(
reshape_right,
left_shape,
calculate_broadcast_axes(left_shape, new_right_shape, start_match_axis));
return {left, broadcast_right};
}
AxisSet calculate_broadcast_axes(const Shape& output_shape,
const Shape& input_shape,
std::size_t start_match_axis)
......
......@@ -47,6 +47,26 @@ namespace ngraph
return numpy_style_broadcast_for_binary_operation(inputs.at(0), inputs.at(1));
}
/// \brief Cast shape of two nodes to make them compatible for an element-wise binary operation.
///
/// If necessary the right-hand-side argument will be broadcast to match the shape
/// of left-hand-side argument. The starting of the mutually equal shape is
/// specified by the argument "start_match_axis", and if it is not set,
/// suffix matching is assumed.
///
/// This style of broadcast was used in ONNX Op sets prior to version 7, where it was
/// replaced by numpy-style broadcasting.
///
/// \param left Node which contain input of binary op.
/// \param right Node which contain input of binary op.
/// \param start_match_axis position in shape denoting start of the mutually equal shape
///
/// \return Left and right node after broadcasting.
NodeVector
legacy_style_broadcast_for_binary_operation(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right,
std::size_t start_match_axis);
/// \brief Generate a list of broadcast axes.
///
/// \details Informally, a broadcast "adds" axes to the input tensor, replicating
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment