Unverified Commit b57c7661 authored by Adam Osewski's avatar Adam Osewski Committed by GitHub

Update broadcasting helpers to use v1 operators. (#4194)

* Helper function get_axes_mapping.

* Enhance Broadcast:v1 NUMPY broadcasting.

- Enable NUMPY broadcasting mechanism to work in bothdirections:
    target_shape <-> arg_shape

* Add opset1:squeeze and fix bug in reading squeezed axis idx.

* Fix and enhance downgrade pass for Broadcast:v1

* Use Broadcast:v1 in ONNX Expand operator.

* Replace Broadcast:v0 with v1 in some helper functions.

* Remove call to deprecated legacy_broadcasting helper function.

* Add helper get_axes_mapping_output function.

* Use directly Broadcast:v1 instead of helper function.

* Get back operators from v0 in helper function.

* Use helper function and some refactoring.

* Add legacy style broadcast helper function for opset1.

* User helper broadcasting function for arithmetic operators.

* Add empty axis only if its size is equal to one.

* Aplly review remarks:

- Rename broadcasting function deleting _values_ infix
- Remove variables used only once.
- Use STL library where possible.
- Remove unnecessary conditions.

* Add helper for Broadcast:v1.

* Fix merge artifact and force unsigned type for argument.

* Review. Add additional check for static output.

* Apply clang-format.

* Fix: call v0 ops in ngraph::builder namespace.
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
parent 8086aeba
......@@ -91,8 +91,8 @@ NodeVector builder::MatmulFactory::make_matmul_op()
// Broadcast input arguments only if both of them are not vectors.
if (left_rank > 1 && right_rank > 1)
{
const NodeVector& broadcasted_nodes = op::numpy_style_broadcast_for_matmul_operation(
left.get_node_shared_ptr(), right.get_node_shared_ptr());
const OutputVector& broadcasted_nodes =
op::numpy_style_broadcast_for_matmul_operation(left, right);
left = broadcasted_nodes.at(0);
right = broadcasted_nodes.at(1);
......
......@@ -229,7 +229,7 @@ shared_ptr<Node> builder::opset1::expand_dims(const Output<Node>& value, size_t
return builder::opset1::reshape(value, output_shape);
}
shared_ptr<Node> builder::opset1::squeeze(const Output<Node>& value, vector<int64_t> axes)
shared_ptr<Node> builder::opset1::squeeze(const Output<Node>& value, vector<size_t> axes)
{
if (axes.empty())
{
......
......@@ -156,7 +156,7 @@ namespace ngraph
///
/// \return Reshape:v1 op.
std::shared_ptr<Node> squeeze(const Output<Node>& value,
std::vector<std::int64_t> axes = {0});
std::vector<std::size_t> axes = {0});
}
} // namespace builder
} // namespace ngraph
......@@ -17,6 +17,7 @@
#include "add.hpp"
#include "default_opset.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
......@@ -28,15 +29,16 @@ namespace ngraph
{
NodeVector add(const Node& node)
{
auto left_rank = node.get_ng_inputs().at(0)->get_shape().size();
auto right_rank = node.get_ng_inputs().at(1)->get_shape().size();
auto axis =
node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank);
NodeVector ng_inputs{ngraph::op::legacy_style_broadcast_for_binary_operation(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
auto lhs_rank = lhs_node.get_shape().size();
auto rhs_rank = rhs_node.get_shape().size();
auto axis = node.get_attribute_value<std::int64_t>("axis", lhs_rank - rhs_rank);
// Unidirectional broadcast right node to left shape.
rhs_node = ngraph::op::opset1::legacy_style_broadcast_for_binary_operation(
lhs_node, rhs_node, axis);
return {std::make_shared<default_opset::Add>(
ng_inputs.at(0), ng_inputs.at(1), ngraph::op::AutoBroadcastSpec::NONE)};
lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)};
}
} // namespace set_1
......
......@@ -22,6 +22,7 @@
#include "default_opset.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
......@@ -33,15 +34,16 @@ namespace ngraph
{
inline NodeVector div(const Node& node)
{
auto left_rank = node.get_ng_inputs().at(0)->get_shape().size();
auto right_rank = node.get_ng_inputs().at(1)->get_shape().size();
auto axis =
node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank);
NodeVector ng_inputs{ngraph::op::legacy_style_broadcast_for_binary_operation(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
auto lhs_rank = lhs_node.get_shape().size();
auto rhs_rank = rhs_node.get_shape().size();
auto axis = node.get_attribute_value<std::int64_t>("axis", lhs_rank - rhs_rank);
// Unidirectional broadcast right node to left shape.
rhs_node = ngraph::op::opset1::legacy_style_broadcast_for_binary_operation(
lhs_node, rhs_node, axis);
return {std::make_shared<default_opset::Divide>(
ng_inputs.at(0), ng_inputs.at(1), ngraph::op::AutoBroadcastSpec::NONE)};
lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)};
}
} // namespace set_1
......
......@@ -28,6 +28,7 @@
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
......@@ -50,7 +51,10 @@ namespace ngraph
->get_vector<std::size_t>();
const ngraph::Shape shape_shape{shape_vector};
return {ngraph::op::numpy_style_broadcast(data, shape_shape)};
return {std::make_shared<default_opset::Broadcast>(
data,
default_opset::Constant::create(
element::i64, Shape{shape_shape.size()}, shape_shape))};
}
} // namespace set_1
......
......@@ -42,47 +42,52 @@ namespace ngraph
NodeVector instance_norm(const Node& node)
{
const std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
std::shared_ptr<ngraph::Node> scale{node.get_ng_inputs().at(1)};
std::shared_ptr<ngraph::Node> bias{node.get_ng_inputs().at(2)};
Output<ngraph::Node> scale(node.get_ng_inputs().at(1));
Output<ngraph::Node> bias(node.get_ng_inputs().at(2));
const Shape& data_shape = data->get_shape();
const Shape& scale_shape = scale.get_shape();
const Shape& bias_shape = bias.get_shape();
const float epsilon{node.get_attribute_value<float>("epsilon", 1e-5f)};
CHECK_VALID_NODE(node,
(scale->get_shape().size() == 1 &&
scale->get_shape()[0] == data->get_shape().at(1)),
CHECK_VALID_NODE(
node,
(scale_shape.size() == 1 && scale_shape[0] == data_shape.at(1)),
"Scale input must be one dimensional vector of number of "
"input data channels size.");
CHECK_VALID_NODE(node,
(bias->get_shape().size() == 1 &&
bias->get_shape()[0] == data->get_shape().at(1)),
(bias_shape.size() == 1 && bias_shape[0] == data_shape.at(1)),
"Bias input must be one dimensional vector of number of "
"input data channels size.");
// all dimensions except spatial/feature
const AxisSet reduction_axes{
common::get_monotonic_range<std::size_t>(data->get_shape().size(), 2)};
common::get_monotonic_range<std::size_t>(data_shape.size(), 2)};
const std::shared_ptr<ngraph::Node> eps_node =
std::make_shared<default_opset::Constant>(data->get_element_type(),
data->get_shape(),
std::vector<float>{epsilon});
scale = ngraph::op::legacy_style_broadcast_for_binary_operation(data, scale, 1)
.at(1);
bias = ngraph::op::legacy_style_broadcast_for_binary_operation(data, bias, 1)
.at(1);
std::shared_ptr<ngraph::Node> mean = builder::mean(data, reduction_axes);
mean = std::make_shared<ngraph::opset0::Broadcast>(
mean, data->get_shape(), reduction_axes);
std::shared_ptr<ngraph::Node> variance =
builder::variance(data, reduction_axes);
variance = std::make_shared<ngraph::opset0::Broadcast>(
variance, data->get_shape(), reduction_axes);
std::make_shared<default_opset::Constant>(
data->get_element_type(), data_shape, std::vector<float>{epsilon});
scale = ngraph::op::opset1::make_broadcast(scale, data_shape, 1);
bias = ngraph::op::opset1::make_broadcast(bias, data_shape, 1);
Output<ngraph::Node> mean = builder::mean(data, reduction_axes);
mean = ngraph::op::opset1::make_broadcast(mean, data_shape, reduction_axes);
Output<ngraph::Node> variance = builder::variance(data, reduction_axes);
variance =
ngraph::op::opset1::make_broadcast(variance, data_shape, reduction_axes);
const auto sqrt = std::make_shared<default_opset::Sqrt>(variance + eps_node);
return {scale * (data - mean) / sqrt + bias};
// scale * (data - mean) / sqrt + bias
std::shared_ptr<ngraph::Node> result{
std::make_shared<default_opset::Subtract>(data, mean)};
result = std::make_shared<default_opset::Multiply>(scale, result);
result = std::make_shared<default_opset::Divide>(result, sqrt);
result = std::make_shared<default_opset::Add>(result, bias);
return {result};
}
} // namespace set_1
......
......@@ -35,15 +35,16 @@ namespace ngraph
{
inline NodeVector mul(const Node& node)
{
auto left_rank = node.get_ng_inputs().at(0)->get_shape().size();
auto right_rank = node.get_ng_inputs().at(1)->get_shape().size();
auto axis =
node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank);
NodeVector ng_inputs{ngraph::op::legacy_style_broadcast_for_binary_operation(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
auto lhs_rank = lhs_node.get_shape().size();
auto rhs_rank = rhs_node.get_shape().size();
auto axis = node.get_attribute_value<std::int64_t>("axis", lhs_rank - rhs_rank);
// Unidirectional broadcast right node to left shape.
rhs_node = ngraph::op::opset1::legacy_style_broadcast_for_binary_operation(
lhs_node, rhs_node, axis);
return {std::make_shared<default_opset::Multiply>(
ng_inputs.at(0), ng_inputs.at(1), ngraph::op::AutoBroadcastSpec::NONE)};
lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)};
}
} // namespace set_1
......
......@@ -31,15 +31,16 @@ namespace ngraph
{
inline NodeVector sub(const Node& node)
{
auto left_rank = node.get_ng_inputs().at(0)->get_shape().size();
auto right_rank = node.get_ng_inputs().at(1)->get_shape().size();
auto axis =
node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank);
NodeVector ng_inputs{ngraph::op::legacy_style_broadcast_for_binary_operation(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
auto lhs_rank = lhs_node.get_shape().size();
auto rhs_rank = rhs_node.get_shape().size();
auto axis = node.get_attribute_value<std::int64_t>("axis", lhs_rank - rhs_rank);
// Unidirectional broadcast right node to left shape.
rhs_node = ngraph::op::opset1::legacy_style_broadcast_for_binary_operation(
lhs_node, rhs_node, axis);
return {std::make_shared<default_opset::Subtract>(
ng_inputs.at(0), ng_inputs.at(1), ngraph::op::AutoBroadcastSpec::NONE)};
lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)};
}
} // namespace set_1
......
......@@ -63,8 +63,8 @@ namespace ngraph
if (m_keep_dims == 0)
{
const auto reshaped_indices =
ngraph::builder::opset1::squeeze(indices, {m_normalized_axis});
const auto reshaped_indices = ngraph::builder::opset1::squeeze(
indices, {static_cast<std::size_t>(m_normalized_axis)});
return std::make_shared<default_opset::Convert>(reshaped_indices, element::i64);
}
return std::make_shared<default_opset::Convert>(indices, element::i64);
......
......@@ -81,20 +81,17 @@ std::pair<bool, AxisSet> op::v1::Broadcast::get_broadcast_axes() const
else if (m_broadcast_spec.m_type == AutoBroadcastType::NUMPY ||
m_broadcast_spec.m_type == AutoBroadcastType::PDPD)
{
if (input(0).get_partial_shape().is_static() &&
input_value(1).get_node_shared_ptr()->is_constant())
if (input(0).get_partial_shape().is_static() && output(0).get_partial_shape().is_static())
{
auto arg_shape = input(0).get_shape();
auto target_shape =
static_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr())
->get_shape_val();
auto result_shape = output(0).get_shape();
auto start_axis = (m_broadcast_spec.m_type == AutoBroadcastType::PDPD)
? m_broadcast_spec.m_axis
: target_shape.size() - arg_shape.size();
: result_shape.size() - arg_shape.size();
NGRAPH_CHECK(start_axis >= 0);
for (size_t i = 0; i < target_shape.size(); i++)
for (size_t i = 0; i < result_shape.size(); i++)
{
if (i < start_axis || target_shape[i] != arg_shape[i - start_axis])
if (i < start_axis || result_shape[i] != arg_shape[i - start_axis])
{
broadcast_axes.insert(i);
}
......@@ -229,13 +226,15 @@ void op::v1::Broadcast::validate_and_infer_types()
arg_shape.size());
for (auto i = start_axis; i < target_shape.size(); i++)
{
NODE_VALIDATION_CHECK(this,
arg_shape[i - start_axis] == 1 ||
NODE_VALIDATION_CHECK(
this,
arg_shape[i - start_axis] == 1 || target_shape[i] == 1 ||
arg_shape[i - start_axis] == target_shape[i],
"Broadcast incorrect target shape. Expecting ",
"Broadcast incorrect target shape. Expecting either 1 or ",
arg_shape[i - start_axis],
" . Got ",
target_shape[i]);
result_shape[i] = std::max(arg_shape[i - start_axis], target_shape[i]);
}
}
}
......
......@@ -120,7 +120,7 @@ shared_ptr<Node> op::LSTMSequence::get_masked_node(const Output<Node>& data,
shared_ptr<Node> curr_time_step_node = op::Constant::create(
element::i32, data.get_shape(), vector<int32_t>(shape_size(data.get_shape()), time_step));
shared_ptr<Node> batch_seq_length =
Output<Node> batch_seq_length =
op::legacy_style_broadcast_for_binary_operation(
curr_time_step_node, input_value(3).get_node_shared_ptr(), batch_axis)
.at(1);
......
......@@ -21,8 +21,10 @@
#include "broadcasting.hpp"
#include "ngraph/axis_vector.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/util.hpp"
......@@ -274,44 +276,7 @@ namespace ngraph
return broadcast_node_numpy_style(value, bcast_shape.first, bcast_shape.second[0]);
}
NodeVector
numpy_style_broadcast_for_matmul_operation(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right)
{
const auto& left_shape = left->get_shape();
const auto& right_shape = right->get_shape();
// Broadcast only _stack of matrices_ axes.
const auto& numpy_shapes = get_numpy_broadcast_shapes(
{Shape{std::begin(left_shape), std::next(std::end(left_shape), -2)},
Shape{std::begin(right_shape), std::next(std::end(right_shape), -2)}});
// Prepare tensors output shapes with broadcasted _stack of matrices_ axes.
auto left_output_shape = numpy_shapes.first;
auto right_output_shape = numpy_shapes.first;
// Append the last two axes original dimensions.
left_output_shape.insert(std::end(left_output_shape),
std::next(std::begin(left_shape), left_shape.size() - 2),
std::end(left_shape));
right_output_shape.insert(std::end(right_output_shape),
std::next(std::begin(right_shape), right_shape.size() - 2),
std::end(right_shape));
auto left_full_shape = numpy_shapes.second.at(0);
auto right_full_shape = numpy_shapes.second.at(1);
// Append the last two axes original dimensions.
left_full_shape.insert(std::end(left_full_shape),
std::next(std::begin(left_shape), left_shape.size() - 2),
std::end(left_shape));
right_full_shape.insert(std::end(right_full_shape),
std::next(std::begin(right_shape), right_shape.size() - 2),
std::end(right_shape));
return {broadcast_node_numpy_style(left, left_output_shape, left_full_shape),
broadcast_node_numpy_style(right, right_output_shape, right_full_shape)};
}
OutputVector
numpy_style_broadcast_values_for_matmul_operation(const Output<ngraph::Node>& left,
OutputVector numpy_style_broadcast_for_matmul_operation(const Output<ngraph::Node>& left,
const Output<ngraph::Node>& right)
{
const auto& left_shape = left.get_shape();
......@@ -346,13 +311,12 @@ namespace ngraph
broadcast_node_numpy_style(right, right_output_shape, right_full_shape)};
}
NodeVector
legacy_style_broadcast_for_binary_operation(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right,
std::size_t start_match_axis)
OutputVector legacy_style_broadcast_for_binary_operation(const Output<ngraph::Node>& left,
const Output<ngraph::Node>& right,
size_t start_match_axis)
{
const auto& left_shape = left->get_shape();
const auto& right_shape = right->get_shape();
const auto& left_shape = left.get_shape();
const auto& right_shape = right.get_shape();
bool dimensions_identical = (left_shape == right_shape);
if (dimensions_identical)
......@@ -407,9 +371,57 @@ namespace ngraph
return {left, broadcast_right};
}
OutputVector
legacy_style_broadcast_values_for_binary_operation(const Output<ngraph::Node>& left,
const Output<ngraph::Node>& right,
NodeVector pdpd_style_broadcast(const NodeVector& inputs, int64_t axis)
{
if (inputs.size() <= 1)
{
return inputs;
}
NodeVector broadcasted_inputs{inputs[0]};
for (std::size_t i = 1; i < inputs.size(); ++i)
{
broadcasted_inputs.push_back(
broadcast_value_pdpd_style(inputs[i], inputs[0]->get_shape(), axis));
}
return broadcasted_inputs;
}
OutputVector pdpd_style_broadcast(const OutputVector& inputs, int64_t axis)
{
if (inputs.size() <= 1)
{
return inputs;
}
OutputVector broadcasted_inputs{inputs[0]};
for (std::size_t i = 1; i < inputs.size(); ++i)
{
broadcasted_inputs.push_back(
broadcast_value_pdpd_style(inputs[i], inputs[0].get_shape(), axis));
}
return broadcasted_inputs;
}
AxisSet calculate_broadcast_axes(const Shape& output_shape,
const Shape& input_shape,
std::size_t start_match_axis)
{
std::vector<std::size_t> result(output_shape.size() - input_shape.size());
// Populate the result vector with monotonic increasing series from 0 until
// output_shape_size, excluding values in range:
// [start_match_axis, start_match_axis + input_shape.size()]
std::iota(std::begin(result), std::begin(result) + start_match_axis, 0);
std::iota(std::begin(result) + start_match_axis,
std::end(result),
start_match_axis + input_shape.size());
return result;
}
namespace opset1
{
Output<Node> legacy_style_broadcast_for_binary_operation(const Output<Node>& left,
const Output<Node>& right,
size_t start_match_axis)
{
const auto& left_shape = left.get_shape();
......@@ -418,7 +430,7 @@ namespace ngraph
bool dimensions_identical = (left_shape == right_shape);
if (dimensions_identical)
{
return {left, right};
return right;
}
// Prepare new shape of right operand for broadcasting
......@@ -426,7 +438,7 @@ namespace ngraph
auto new_right_shape = right_shape;
for (int dimension = new_right_shape.size() - 1; dimension >= 0; --dimension)
{
if (new_right_shape[dimension] == 1)
if (new_right_shape.at(dimension) == 1)
{
new_right_shape.pop_back();
}
......@@ -454,67 +466,73 @@ namespace ngraph
new_right_shape.erase(std::begin(new_right_shape),
std::next(std::begin(new_right_shape), num_ones));
auto reshape_right = std::make_shared<ngraph::op::Reshape>(
auto reshape_right = std::make_shared<Reshape>(
right, ngraph::get_default_order(right_shape), new_right_shape);
// Move broadcast start axis parameter to right
start_match_axis += num_ones;
auto broadcast_right = std::make_shared<ngraph::op::Broadcast>(
auto broadcasted_right = std::make_shared<v1::Broadcast>(
reshape_right,
left_shape,
calculate_broadcast_axes(left_shape, new_right_shape, start_match_axis));
Constant::create(element::i64, Shape{left_shape.size()}, left_shape),
get_axes_mapping_output(left_shape, new_right_shape, start_match_axis));
return {left, broadcast_right};
return broadcasted_right;
}
NodeVector pdpd_style_broadcast(const NodeVector& inputs, int64_t axis)
std::vector<std::size_t> get_axes_mapping(const Shape& output_shape,
const AxisSet& broadcast_axes)
{
if (inputs.size() <= 1)
NGRAPH_CHECK((broadcast_axes.size() <= output_shape.size()));
std::vector<size_t> axes_mapping(output_shape.size());
std::iota(axes_mapping.begin(), axes_mapping.end(), 0);
for (auto i = broadcast_axes.rbegin(); i != broadcast_axes.rend(); ++i)
{
return inputs;
axes_mapping.erase(axes_mapping.begin() + *i);
}
return axes_mapping;
}
NodeVector broadcasted_inputs{inputs[0]};
for (std::size_t i = 1; i < inputs.size(); ++i)
Output<Node> get_axes_mapping_output(const Shape& output_shape,
const Shape& input_shape,
std::size_t start_match_axis)
{
broadcasted_inputs.push_back(
broadcast_value_pdpd_style(inputs[i], inputs[0]->get_shape(), axis));
}
return broadcasted_inputs;
NGRAPH_CHECK((input_shape.size() + start_match_axis <= output_shape.size()));
std::vector<std::size_t> mapping(input_shape.size());
std::iota(std::begin(mapping), std::end(mapping), start_match_axis);
return Constant::create(element::i64, Shape{mapping.size()}, mapping);
}
OutputVector pdpd_style_broadcast(const OutputVector& inputs, int64_t axis)
{
if (inputs.size() <= 1)
Output<Node> get_axes_mapping_output(const Shape& output_shape,
const AxisSet& broadcast_axes)
{
return inputs;
std::vector<size_t> axes_mapping{get_axes_mapping(output_shape, broadcast_axes)};
return Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping);
}
OutputVector broadcasted_inputs{inputs[0]};
for (std::size_t i = 1; i < inputs.size(); ++i)
Output<Node> make_broadcast(const Output<Node>& node,
const Shape& target_shape,
const AxisSet& broadcast_axes)
{
broadcasted_inputs.push_back(
broadcast_value_pdpd_style(inputs[i], inputs[0].get_shape(), axis));
}
return broadcasted_inputs;
return std::make_shared<v1::Broadcast>(
node,
Constant::create(element::i64, Shape{target_shape.size()}, target_shape),
get_axes_mapping_output(target_shape, broadcast_axes));
}
AxisSet calculate_broadcast_axes(const Shape& output_shape,
const Shape& input_shape,
Output<Node> make_broadcast(const Output<Node>& node,
const Shape& target_shape,
std::size_t start_match_axis)
{
std::vector<std::size_t> result(output_shape.size() - input_shape.size());
// Populate the result vector with monotonic increasing series from 0 until
// output_shape_size, excluding values in range:
// [start_match_axis, start_match_axis + input_shape.size()]
std::iota(std::begin(result), std::begin(result) + start_match_axis, 0);
std::iota(std::begin(result) + start_match_axis,
std::end(result),
start_match_axis + input_shape.size());
return result;
return std::make_shared<v1::Broadcast>(
node,
Constant::create(element::i64, Shape{target_shape.size()}, target_shape),
get_axes_mapping_output(target_shape, node.get_shape(), start_match_axis));
}
} // namespace opset1
} // namespace op
} // namespace ngraph
......@@ -16,7 +16,9 @@
#pragma once
#include <cstddef>
#include <memory>
#include <vector>
#include "ngraph/axis_set.hpp"
#include "ngraph/node.hpp"
......@@ -69,51 +71,10 @@ namespace ngraph
/// \param start_match_axis position in shape denoting start of the mutually equal shape
///
/// \return Left and right node after broadcasting.
NodeVector legacy_style_broadcast_for_binary_operation(const std::shared_ptr<Node>& left,
const std::shared_ptr<Node>& right,
size_t start_match_axis)
NGRAPH_DEPRECATED("Replace with legacy_style_broadcast_values_for_binary_operation");
/// \brief Cast shape of two outputs to make them compatible for an element-wise binary
/// operation.
///
/// If necessary the right-hand-side argument will be broadcast to match the shape
/// of left-hand-side argument. The starting of the mutually equal shape is
/// specified by the argument "start_match_axis", and if it is not set,
/// suffix matching is assumed.
///
/// This style of broadcast was used in ONNX Op sets prior to version 7, where it was
/// replaced by numpy-style broadcasting.
///
/// \param left Node which contain input of binary op.
/// \param right Node which contain input of binary op.
/// \param start_match_axis position in shape denoting start of the mutually equal shape
///
/// \return Left and right node after broadcasting.
OutputVector legacy_style_broadcast_values_for_binary_operation(const Output<Node>& left,
OutputVector legacy_style_broadcast_for_binary_operation(const Output<Node>& left,
const Output<Node>& right,
size_t start_match_axis);
/// \brief Broadcast shape of two nodes to make them compatible for a matrix
/// multiplication.
///
/// \note This function is reflecting broadcasting behaviour of NumPy's `matmul`
/// operation.
/// (https://docs.scipy.org/doc/numpy/reference/generated/numpy.matmul.html).
/// This mean that only \"stack of matrices\" axes are bidirectionally
/// broadcasted. The last two dimension are left untouched.
///
/// \param[in] left The Node providing data for the left-hand side of matrix
/// multiplication.
/// \param[in] right The Node providing data for the right-hand side of matrix
/// multiplication.
///
/// \return The vector containing both nodes broadcasted.
///
NodeVector numpy_style_broadcast_for_matmul_operation(const std::shared_ptr<Node>& left,
const std::shared_ptr<Node>& right)
NGRAPH_DEPRECATED("Replace with numpy_style_broadcast_values_for_matmul_operation.");
/// \brief Broadcast shape of two nodes to make them compatible for a matrix
/// multiplication.
///
......@@ -130,7 +91,7 @@ namespace ngraph
///
/// \return The vector containing both outputs broadcasted.
///
OutputVector numpy_style_broadcast_values_for_matmul_operation(const Output<Node>& left,
OutputVector numpy_style_broadcast_for_matmul_operation(const Output<Node>& left,
const Output<Node>& right);
/// \brief Cast shape of all input nodes for an element-wise operation that requires
......@@ -197,5 +158,68 @@ namespace ngraph
new_shape,
calculate_broadcast_axes(new_shape, value.get_shape(), start_match_axis));
}
namespace opset1
{
///
/// \brief Broadcast right node to left node's shape using legacy scheme.
///
/// \param[in] left The left hand side node of binary operation.
/// \param[in] right The right hand side node of binary operation. The one
/// to be broadcasted.
/// \param[in] start_match_axis The axis index starting mutually equal shapes
/// of both nodes.
///
/// \return The Output object connected to node producing broadcasted right node.
///
Output<Node> legacy_style_broadcast_for_binary_operation(const Output<Node>& left,
const Output<Node>& right,
size_t start_match_axis);
///
/// \brief Reconstructs axes mapping vector for Broadcast:v1 operation.
///
/// \param[in] output_shape The output shape of Broadcast operation.
/// \param[in] broadcast_axes The broadcast axes used for Broadcast:v0 operator.
///
/// \return The vector with axes indexes mapping .
///
std::vector<std::size_t> get_axes_mapping(const Shape& output_shape,
const AxisSet& broadcast_axes);
///
/// \brief Creates Node returning the axes mapping for Broadcast:v1 operation.
///
/// \param[in] output_shape The output shape of Broadcast operation.
/// \param[in] input_shape The input shape.
/// \param[in] start_match_axis The axis index at which input shape starts to be
/// identical as the output shape.
///
/// \return Returns the Output object pointing to node with the axes mapping.
///
Output<Node> get_axes_mapping_output(const Shape& output_shape,
const Shape& input_shape,
std::size_t start_match_axis);
///
/// \brief Creates Node returning the axes mapping for Broadcast:v1 operation.
///
/// \param[in] output_shape The output shape of Broadcast operation.
/// \param[in] broadcast_axes The broadcast axes used for Broadcast:v0 operator.
///
/// \return The Output object with Node returning axes mapping.
///
Output<Node> get_axes_mapping_output(const Shape& output_shape,
const AxisSet& broadcast_axes);
Output<Node> make_broadcast(const Output<Node>& node,
const Shape& target_shape,
const AxisSet& broadcast_axes);
Output<Node> make_broadcast(const Output<Node>& node,
const Shape& target_shape,
std::size_t start_match_axis);
} // namespace opset1
} // namespace op
} // namespace ngraph
......@@ -148,13 +148,48 @@ namespace
shared_ptr<Node> op_cast(shared_ptr<op::v1::Broadcast> node)
{
auto arg = node->input_value(0);
const auto& arg_shape = arg.get_shape();
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant());
auto target_shape =
static_pointer_cast<op::Constant>(node->input_value(1).get_node_shared_ptr())
->get_shape_val();
auto target_shape = node->output(0).get_shape();
NGRAPH_CHECK(node->get_broadcast_axes().first);
// (Re)construct axes_mapping.
AxisSet broadcast_axes = node->get_broadcast_axes().second;
std::vector<size_t> axes_mapping{
ngraph::op::opset1::get_axes_mapping(target_shape, broadcast_axes)};
Output<Node> squeezed_arg = arg;
// Collect axes to squeeze. Broadcast v0 "adds" new axes, thus we have to squeeze
// the empty ones (dim:=1), which would be broadcasted by Broadcast v1.
std::vector<size_t> empty_axes;
for (size_t a{0}; a < axes_mapping.size(); ++a)
{
if (arg_shape.at(a) == 1 && target_shape.at(axes_mapping.at(a)) != 1)
{
empty_axes.push_back(a);
}
}
// Check if arg_shape contains some more empty dimensions marked to broadcast.
// If axes_mapping size is less than arg_shape size, then some of arg dimensions may
// be equal to one and marked to broadcast.
if (axes_mapping.size() < arg_shape.size())
{
for (size_t a{axes_mapping.size()}; a < arg_shape.size(); ++a)
{
if (arg_shape.at(a) == 1)
{
empty_axes.push_back(a);
}
}
}
if (!empty_axes.empty())
{
squeezed_arg = builder::squeeze(arg, empty_axes);
}
auto replacement_node =
make_shared<op::v0::Broadcast>(arg, target_shape, node->get_broadcast_axes().second);
make_shared<op::v0::Broadcast>(squeezed_arg, target_shape, broadcast_axes);
replace_node(node, replacement_node);
return replacement_node;
......
......@@ -20,6 +20,7 @@
#include "ngraph/builder/reshape.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "ngraph/provenance.hpp"
......@@ -106,25 +107,10 @@ namespace
shared_ptr<Node> op_cast(shared_ptr<op::Broadcast> node)
{
auto result_shape = node->get_broadcast_shape();
auto result_shape_node =
op::Constant::create(element::i64, Shape{result_shape.size()}, result_shape);
auto broadcast_axes = node->get_broadcast_axes();
// Flip broadcast_axes to get axes_mapping
std::vector<size_t> axes_mapping(result_shape.size());
std::iota(axes_mapping.begin(), axes_mapping.end(), 0);
for (auto i = broadcast_axes.rbegin(); i != broadcast_axes.rend(); i++)
{
axes_mapping.erase(axes_mapping.begin() + *i);
}
auto axes_mapping_node =
op::Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping);
auto replacement_node = make_shared<op::v1::Broadcast>(
node->input_value(0), result_shape_node->output(0), axes_mapping_node->output(0));
replace_node(node, replacement_node);
return replacement_node;
auto replacement_node = ngraph::op::opset1::make_broadcast(
node->input_value(0), node->get_broadcast_shape(), node->get_broadcast_axes());
replace_node(node, replacement_node.get_node_shared_ptr());
return replacement_node.get_node_shared_ptr();
}
shared_ptr<Node> op_cast(shared_ptr<op::BroadcastLike> node) { return nullptr; }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment