Unverified commit b57c7661, authored by Adam Osewski, committed by GitHub

Update broadcasting helpers to use v1 operators. (#4194)

* Helper function get_axes_mapping.

* Enhance Broadcast:v1 NUMPY broadcasting.

- Enable NUMPY broadcasting mechanism to work in both directions:
    target_shape <-> arg_shape

* Add opset1:squeeze and fix bug in reading squeezed axis idx.

* Fix and enhance downgrade pass for Broadcast:v1

* Use Broadcast:v1 in ONNX Expand operator.

* Replace Broadcast:v0 with v1 in some helper functions.

* Remove call to deprecated legacy_broadcasting helper function.

* Add helper get_axes_mapping_output function.

* Use Broadcast:v1 directly instead of helper function.

* Get back operators from v0 in helper function.

* Use helper function and some refactoring.

* Add legacy style broadcast helper function for opset1.

* Use helper broadcasting function for arithmetic operators.

* Add empty axis only if its size is equal to one.

* Apply review remarks:

- Rename broadcasting function deleting _values_ infix
- Remove variables used only once.
- Use STL library where possible.
- Remove unnecessary conditions.

* Add helper for Broadcast:v1.

* Fix merge artifact and force unsigned type for argument.

* Review. Add additional check for static output.

* Apply clang-format.

* Fix: call v0 ops in ngraph::builder namespace.
Co-authored-by: Robert Kimball <robert.kimball@intel.com>
parent 8086aeba
...@@ -91,8 +91,8 @@ NodeVector builder::MatmulFactory::make_matmul_op() ...@@ -91,8 +91,8 @@ NodeVector builder::MatmulFactory::make_matmul_op()
// Broadcast input arguments only if both of them are not vectors. // Broadcast input arguments only if both of them are not vectors.
if (left_rank > 1 && right_rank > 1) if (left_rank > 1 && right_rank > 1)
{ {
const NodeVector& broadcasted_nodes = op::numpy_style_broadcast_for_matmul_operation( const OutputVector& broadcasted_nodes =
left.get_node_shared_ptr(), right.get_node_shared_ptr()); op::numpy_style_broadcast_for_matmul_operation(left, right);
left = broadcasted_nodes.at(0); left = broadcasted_nodes.at(0);
right = broadcasted_nodes.at(1); right = broadcasted_nodes.at(1);
......
...@@ -229,7 +229,7 @@ shared_ptr<Node> builder::opset1::expand_dims(const Output<Node>& value, size_t ...@@ -229,7 +229,7 @@ shared_ptr<Node> builder::opset1::expand_dims(const Output<Node>& value, size_t
return builder::opset1::reshape(value, output_shape); return builder::opset1::reshape(value, output_shape);
} }
shared_ptr<Node> builder::opset1::squeeze(const Output<Node>& value, vector<int64_t> axes) shared_ptr<Node> builder::opset1::squeeze(const Output<Node>& value, vector<size_t> axes)
{ {
if (axes.empty()) if (axes.empty())
{ {
......
...@@ -156,7 +156,7 @@ namespace ngraph ...@@ -156,7 +156,7 @@ namespace ngraph
/// ///
/// \return Reshape:v1 op. /// \return Reshape:v1 op.
std::shared_ptr<Node> squeeze(const Output<Node>& value, std::shared_ptr<Node> squeeze(const Output<Node>& value,
std::vector<std::int64_t> axes = {0}); std::vector<std::size_t> axes = {0});
} }
} // namespace builder } // namespace builder
} // namespace ngraph } // namespace ngraph
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "add.hpp" #include "add.hpp"
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/op/util/broadcasting.hpp" #include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -28,15 +29,16 @@ namespace ngraph ...@@ -28,15 +29,16 @@ namespace ngraph
{ {
NodeVector add(const Node& node) NodeVector add(const Node& node)
{ {
auto left_rank = node.get_ng_inputs().at(0)->get_shape().size(); const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
auto right_rank = node.get_ng_inputs().at(1)->get_shape().size(); Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
auto axis = auto lhs_rank = lhs_node.get_shape().size();
node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank); auto rhs_rank = rhs_node.get_shape().size();
NodeVector ng_inputs{ngraph::op::legacy_style_broadcast_for_binary_operation( auto axis = node.get_attribute_value<std::int64_t>("axis", lhs_rank - rhs_rank);
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)}; // Unidirectional broadcast right node to left shape.
rhs_node = ngraph::op::opset1::legacy_style_broadcast_for_binary_operation(
lhs_node, rhs_node, axis);
return {std::make_shared<default_opset::Add>( return {std::make_shared<default_opset::Add>(
ng_inputs.at(0), ng_inputs.at(1), ngraph::op::AutoBroadcastSpec::NONE)}; lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/util/broadcasting.hpp" #include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -33,15 +34,16 @@ namespace ngraph ...@@ -33,15 +34,16 @@ namespace ngraph
{ {
inline NodeVector div(const Node& node) inline NodeVector div(const Node& node)
{ {
auto left_rank = node.get_ng_inputs().at(0)->get_shape().size(); const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
auto right_rank = node.get_ng_inputs().at(1)->get_shape().size(); Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
auto axis = auto lhs_rank = lhs_node.get_shape().size();
node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank); auto rhs_rank = rhs_node.get_shape().size();
NodeVector ng_inputs{ngraph::op::legacy_style_broadcast_for_binary_operation( auto axis = node.get_attribute_value<std::int64_t>("axis", lhs_rank - rhs_rank);
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)}; // Unidirectional broadcast right node to left shape.
rhs_node = ngraph::op::opset1::legacy_style_broadcast_for_binary_operation(
lhs_node, rhs_node, axis);
return {std::make_shared<default_opset::Divide>( return {std::make_shared<default_opset::Divide>(
ng_inputs.at(0), ng_inputs.at(1), ngraph::op::AutoBroadcastSpec::NONE)}; lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "ngraph/op/experimental/shape_of.hpp" #include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/broadcasting.hpp" #include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -50,7 +51,10 @@ namespace ngraph ...@@ -50,7 +51,10 @@ namespace ngraph
->get_vector<std::size_t>(); ->get_vector<std::size_t>();
const ngraph::Shape shape_shape{shape_vector}; const ngraph::Shape shape_shape{shape_vector};
return {ngraph::op::numpy_style_broadcast(data, shape_shape)}; return {std::make_shared<default_opset::Broadcast>(
data,
default_opset::Constant::create(
element::i64, Shape{shape_shape.size()}, shape_shape))};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -42,47 +42,52 @@ namespace ngraph ...@@ -42,47 +42,52 @@ namespace ngraph
NodeVector instance_norm(const Node& node) NodeVector instance_norm(const Node& node)
{ {
const std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)}; const std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
std::shared_ptr<ngraph::Node> scale{node.get_ng_inputs().at(1)}; Output<ngraph::Node> scale(node.get_ng_inputs().at(1));
std::shared_ptr<ngraph::Node> bias{node.get_ng_inputs().at(2)}; Output<ngraph::Node> bias(node.get_ng_inputs().at(2));
const Shape& data_shape = data->get_shape();
const Shape& scale_shape = scale.get_shape();
const Shape& bias_shape = bias.get_shape();
const float epsilon{node.get_attribute_value<float>("epsilon", 1e-5f)}; const float epsilon{node.get_attribute_value<float>("epsilon", 1e-5f)};
CHECK_VALID_NODE(node, CHECK_VALID_NODE(
(scale->get_shape().size() == 1 && node,
scale->get_shape()[0] == data->get_shape().at(1)), (scale_shape.size() == 1 && scale_shape[0] == data_shape.at(1)),
"Scale input must be one dimensional vector of number of " "Scale input must be one dimensional vector of number of "
"input data channels size."); "input data channels size.");
CHECK_VALID_NODE(node, CHECK_VALID_NODE(node,
(bias->get_shape().size() == 1 && (bias_shape.size() == 1 && bias_shape[0] == data_shape.at(1)),
bias->get_shape()[0] == data->get_shape().at(1)),
"Bias input must be one dimensional vector of number of " "Bias input must be one dimensional vector of number of "
"input data channels size."); "input data channels size.");
// all dimensions except spatial/feature // all dimensions except spatial/feature
const AxisSet reduction_axes{ const AxisSet reduction_axes{
common::get_monotonic_range<std::size_t>(data->get_shape().size(), 2)}; common::get_monotonic_range<std::size_t>(data_shape.size(), 2)};
const std::shared_ptr<ngraph::Node> eps_node = const std::shared_ptr<ngraph::Node> eps_node =
std::make_shared<default_opset::Constant>(data->get_element_type(), std::make_shared<default_opset::Constant>(
data->get_shape(), data->get_element_type(), data_shape, std::vector<float>{epsilon});
std::vector<float>{epsilon});
scale = ngraph::op::opset1::make_broadcast(scale, data_shape, 1);
scale = ngraph::op::legacy_style_broadcast_for_binary_operation(data, scale, 1) bias = ngraph::op::opset1::make_broadcast(bias, data_shape, 1);
.at(1);
bias = ngraph::op::legacy_style_broadcast_for_binary_operation(data, bias, 1) Output<ngraph::Node> mean = builder::mean(data, reduction_axes);
.at(1); mean = ngraph::op::opset1::make_broadcast(mean, data_shape, reduction_axes);
std::shared_ptr<ngraph::Node> mean = builder::mean(data, reduction_axes); Output<ngraph::Node> variance = builder::variance(data, reduction_axes);
mean = std::make_shared<ngraph::opset0::Broadcast>( variance =
mean, data->get_shape(), reduction_axes); ngraph::op::opset1::make_broadcast(variance, data_shape, reduction_axes);
std::shared_ptr<ngraph::Node> variance =
builder::variance(data, reduction_axes);
variance = std::make_shared<ngraph::opset0::Broadcast>(
variance, data->get_shape(), reduction_axes);
const auto sqrt = std::make_shared<default_opset::Sqrt>(variance + eps_node); const auto sqrt = std::make_shared<default_opset::Sqrt>(variance + eps_node);
return {scale * (data - mean) / sqrt + bias}; // scale * (data - mean) / sqrt + bias
std::shared_ptr<ngraph::Node> result{
std::make_shared<default_opset::Subtract>(data, mean)};
result = std::make_shared<default_opset::Multiply>(scale, result);
result = std::make_shared<default_opset::Divide>(result, sqrt);
result = std::make_shared<default_opset::Add>(result, bias);
return {result};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -35,15 +35,16 @@ namespace ngraph ...@@ -35,15 +35,16 @@ namespace ngraph
{ {
inline NodeVector mul(const Node& node) inline NodeVector mul(const Node& node)
{ {
auto left_rank = node.get_ng_inputs().at(0)->get_shape().size(); const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
auto right_rank = node.get_ng_inputs().at(1)->get_shape().size(); Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
auto axis = auto lhs_rank = lhs_node.get_shape().size();
node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank); auto rhs_rank = rhs_node.get_shape().size();
NodeVector ng_inputs{ngraph::op::legacy_style_broadcast_for_binary_operation( auto axis = node.get_attribute_value<std::int64_t>("axis", lhs_rank - rhs_rank);
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)}; // Unidirectional broadcast right node to left shape.
rhs_node = ngraph::op::opset1::legacy_style_broadcast_for_binary_operation(
lhs_node, rhs_node, axis);
return {std::make_shared<default_opset::Multiply>( return {std::make_shared<default_opset::Multiply>(
ng_inputs.at(0), ng_inputs.at(1), ngraph::op::AutoBroadcastSpec::NONE)}; lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -31,15 +31,16 @@ namespace ngraph ...@@ -31,15 +31,16 @@ namespace ngraph
{ {
inline NodeVector sub(const Node& node) inline NodeVector sub(const Node& node)
{ {
auto left_rank = node.get_ng_inputs().at(0)->get_shape().size(); const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
auto right_rank = node.get_ng_inputs().at(1)->get_shape().size(); Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
auto axis = auto lhs_rank = lhs_node.get_shape().size();
node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank); auto rhs_rank = rhs_node.get_shape().size();
NodeVector ng_inputs{ngraph::op::legacy_style_broadcast_for_binary_operation( auto axis = node.get_attribute_value<std::int64_t>("axis", lhs_rank - rhs_rank);
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)}; // Unidirectional broadcast right node to left shape.
rhs_node = ngraph::op::opset1::legacy_style_broadcast_for_binary_operation(
lhs_node, rhs_node, axis);
return {std::make_shared<default_opset::Subtract>( return {std::make_shared<default_opset::Subtract>(
ng_inputs.at(0), ng_inputs.at(1), ngraph::op::AutoBroadcastSpec::NONE)}; lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -63,8 +63,8 @@ namespace ngraph ...@@ -63,8 +63,8 @@ namespace ngraph
if (m_keep_dims == 0) if (m_keep_dims == 0)
{ {
const auto reshaped_indices = const auto reshaped_indices = ngraph::builder::opset1::squeeze(
ngraph::builder::opset1::squeeze(indices, {m_normalized_axis}); indices, {static_cast<std::size_t>(m_normalized_axis)});
return std::make_shared<default_opset::Convert>(reshaped_indices, element::i64); return std::make_shared<default_opset::Convert>(reshaped_indices, element::i64);
} }
return std::make_shared<default_opset::Convert>(indices, element::i64); return std::make_shared<default_opset::Convert>(indices, element::i64);
......
...@@ -81,20 +81,17 @@ std::pair<bool, AxisSet> op::v1::Broadcast::get_broadcast_axes() const ...@@ -81,20 +81,17 @@ std::pair<bool, AxisSet> op::v1::Broadcast::get_broadcast_axes() const
else if (m_broadcast_spec.m_type == AutoBroadcastType::NUMPY || else if (m_broadcast_spec.m_type == AutoBroadcastType::NUMPY ||
m_broadcast_spec.m_type == AutoBroadcastType::PDPD) m_broadcast_spec.m_type == AutoBroadcastType::PDPD)
{ {
if (input(0).get_partial_shape().is_static() && if (input(0).get_partial_shape().is_static() && output(0).get_partial_shape().is_static())
input_value(1).get_node_shared_ptr()->is_constant())
{ {
auto arg_shape = input(0).get_shape(); auto arg_shape = input(0).get_shape();
auto target_shape = auto result_shape = output(0).get_shape();
static_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr())
->get_shape_val();
auto start_axis = (m_broadcast_spec.m_type == AutoBroadcastType::PDPD) auto start_axis = (m_broadcast_spec.m_type == AutoBroadcastType::PDPD)
? m_broadcast_spec.m_axis ? m_broadcast_spec.m_axis
: target_shape.size() - arg_shape.size(); : result_shape.size() - arg_shape.size();
NGRAPH_CHECK(start_axis >= 0); NGRAPH_CHECK(start_axis >= 0);
for (size_t i = 0; i < target_shape.size(); i++) for (size_t i = 0; i < result_shape.size(); i++)
{ {
if (i < start_axis || target_shape[i] != arg_shape[i - start_axis]) if (i < start_axis || result_shape[i] != arg_shape[i - start_axis])
{ {
broadcast_axes.insert(i); broadcast_axes.insert(i);
} }
...@@ -229,13 +226,15 @@ void op::v1::Broadcast::validate_and_infer_types() ...@@ -229,13 +226,15 @@ void op::v1::Broadcast::validate_and_infer_types()
arg_shape.size()); arg_shape.size());
for (auto i = start_axis; i < target_shape.size(); i++) for (auto i = start_axis; i < target_shape.size(); i++)
{ {
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(
arg_shape[i - start_axis] == 1 || this,
arg_shape[i - start_axis] == 1 || target_shape[i] == 1 ||
arg_shape[i - start_axis] == target_shape[i], arg_shape[i - start_axis] == target_shape[i],
"Broadcast incorrect target shape. Expecting ", "Broadcast incorrect target shape. Expecting either 1 or ",
arg_shape[i - start_axis], arg_shape[i - start_axis],
" . Got ", " . Got ",
target_shape[i]); target_shape[i]);
result_shape[i] = std::max(arg_shape[i - start_axis], target_shape[i]);
} }
} }
} }
......
...@@ -120,7 +120,7 @@ shared_ptr<Node> op::LSTMSequence::get_masked_node(const Output<Node>& data, ...@@ -120,7 +120,7 @@ shared_ptr<Node> op::LSTMSequence::get_masked_node(const Output<Node>& data,
shared_ptr<Node> curr_time_step_node = op::Constant::create( shared_ptr<Node> curr_time_step_node = op::Constant::create(
element::i32, data.get_shape(), vector<int32_t>(shape_size(data.get_shape()), time_step)); element::i32, data.get_shape(), vector<int32_t>(shape_size(data.get_shape()), time_step));
shared_ptr<Node> batch_seq_length = Output<Node> batch_seq_length =
op::legacy_style_broadcast_for_binary_operation( op::legacy_style_broadcast_for_binary_operation(
curr_time_step_node, input_value(3).get_node_shared_ptr(), batch_axis) curr_time_step_node, input_value(3).get_node_shared_ptr(), batch_axis)
.at(1); .at(1);
......
This diff is collapsed.
...@@ -16,7 +16,9 @@ ...@@ -16,7 +16,9 @@
#pragma once #pragma once
#include <cstddef>
#include <memory> #include <memory>
#include <vector>
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
...@@ -69,51 +71,10 @@ namespace ngraph ...@@ -69,51 +71,10 @@ namespace ngraph
/// \param start_match_axis position in shape denoting start of the mutually equal shape /// \param start_match_axis position in shape denoting start of the mutually equal shape
/// ///
/// \return Left and right node after broadcasting. /// \return Left and right node after broadcasting.
NodeVector legacy_style_broadcast_for_binary_operation(const std::shared_ptr<Node>& left, OutputVector legacy_style_broadcast_for_binary_operation(const Output<Node>& left,
const std::shared_ptr<Node>& right,
size_t start_match_axis)
NGRAPH_DEPRECATED("Replace with legacy_style_broadcast_values_for_binary_operation");
/// \brief Cast shape of two outputs to make them compatible for an element-wise binary
/// operation.
///
/// If necessary the right-hand-side argument will be broadcast to match the shape
/// of left-hand-side argument. The starting of the mutually equal shape is
/// specified by the argument "start_match_axis", and if it is not set,
/// suffix matching is assumed.
///
/// This style of broadcast was used in ONNX Op sets prior to version 7, where it was
/// replaced by numpy-style broadcasting.
///
/// \param left Node which contain input of binary op.
/// \param right Node which contain input of binary op.
/// \param start_match_axis position in shape denoting start of the mutually equal shape
///
/// \return Left and right node after broadcasting.
OutputVector legacy_style_broadcast_values_for_binary_operation(const Output<Node>& left,
const Output<Node>& right, const Output<Node>& right,
size_t start_match_axis); size_t start_match_axis);
/// \brief Broadcast shape of two nodes to make them compatible for a matrix
/// multiplication.
///
/// \note This function is reflecting broadcasting behaviour of NumPy's `matmul`
/// operation.
/// (https://docs.scipy.org/doc/numpy/reference/generated/numpy.matmul.html).
/// This mean that only \"stack of matrices\" axes are bidirectionally
/// broadcasted. The last two dimension are left untouched.
///
/// \param[in] left The Node providing data for the left-hand side of matrix
/// multiplication.
/// \param[in] right The Node providing data for the right-hand side of matrix
/// multiplication.
///
/// \return The vector containing both nodes broadcasted.
///
NodeVector numpy_style_broadcast_for_matmul_operation(const std::shared_ptr<Node>& left,
const std::shared_ptr<Node>& right)
NGRAPH_DEPRECATED("Replace with numpy_style_broadcast_values_for_matmul_operation.");
/// \brief Broadcast shape of two nodes to make them compatible for a matrix /// \brief Broadcast shape of two nodes to make them compatible for a matrix
/// multiplication. /// multiplication.
/// ///
...@@ -130,7 +91,7 @@ namespace ngraph ...@@ -130,7 +91,7 @@ namespace ngraph
/// ///
/// \return The vector containing both outputs broadcasted. /// \return The vector containing both outputs broadcasted.
/// ///
OutputVector numpy_style_broadcast_values_for_matmul_operation(const Output<Node>& left, OutputVector numpy_style_broadcast_for_matmul_operation(const Output<Node>& left,
const Output<Node>& right); const Output<Node>& right);
/// \brief Cast shape of all input nodes for an element-wise operation that requires /// \brief Cast shape of all input nodes for an element-wise operation that requires
...@@ -197,5 +158,68 @@ namespace ngraph ...@@ -197,5 +158,68 @@ namespace ngraph
new_shape, new_shape,
calculate_broadcast_axes(new_shape, value.get_shape(), start_match_axis)); calculate_broadcast_axes(new_shape, value.get_shape(), start_match_axis));
} }
namespace opset1
{
///
/// \brief Broadcast right node to left node's shape using legacy scheme.
///
/// \param[in] left The left hand side node of binary operation.
/// \param[in] right The right hand side node of binary operation. The one
/// to be broadcasted.
/// \param[in] start_match_axis The axis index starting mutually equal shapes
/// of both nodes.
///
/// \return The Output object connected to node producing broadcasted right node.
///
Output<Node> legacy_style_broadcast_for_binary_operation(const Output<Node>& left,
const Output<Node>& right,
size_t start_match_axis);
///
/// \brief Reconstructs axes mapping vector for Broadcast:v1 operation.
///
/// \param[in] output_shape The output shape of Broadcast operation.
/// \param[in] broadcast_axes The broadcast axes used for Broadcast:v0 operator.
///
/// \return The vector with axes indexes mapping .
///
std::vector<std::size_t> get_axes_mapping(const Shape& output_shape,
const AxisSet& broadcast_axes);
///
/// \brief Creates Node returning the axes mapping for Broadcast:v1 operation.
///
/// \param[in] output_shape The output shape of Broadcast operation.
/// \param[in] input_shape The input shape.
/// \param[in] start_match_axis The axis index at which input shape starts to be
/// identical as the output shape.
///
/// \return Returns the Output object pointing to node with the axes mapping.
///
Output<Node> get_axes_mapping_output(const Shape& output_shape,
const Shape& input_shape,
std::size_t start_match_axis);
///
/// \brief Creates Node returning the axes mapping for Broadcast:v1 operation.
///
/// \param[in] output_shape The output shape of Broadcast operation.
/// \param[in] broadcast_axes The broadcast axes used for Broadcast:v0 operator.
///
/// \return The Output object with Node returning axes mapping.
///
Output<Node> get_axes_mapping_output(const Shape& output_shape,
const AxisSet& broadcast_axes);
Output<Node> make_broadcast(const Output<Node>& node,
const Shape& target_shape,
const AxisSet& broadcast_axes);
Output<Node> make_broadcast(const Output<Node>& node,
const Shape& target_shape,
std::size_t start_match_axis);
} // namespace opset1
} // namespace op } // namespace op
} // namespace ngraph } // namespace ngraph
...@@ -148,13 +148,48 @@ namespace ...@@ -148,13 +148,48 @@ namespace
shared_ptr<Node> op_cast(shared_ptr<op::v1::Broadcast> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::Broadcast> node)
{ {
auto arg = node->input_value(0); auto arg = node->input_value(0);
const auto& arg_shape = arg.get_shape();
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant()); NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant());
auto target_shape = auto target_shape = node->output(0).get_shape();
static_pointer_cast<op::Constant>(node->input_value(1).get_node_shared_ptr())
->get_shape_val();
NGRAPH_CHECK(node->get_broadcast_axes().first); NGRAPH_CHECK(node->get_broadcast_axes().first);
// (Re)construct axes_mapping.
AxisSet broadcast_axes = node->get_broadcast_axes().second;
std::vector<size_t> axes_mapping{
ngraph::op::opset1::get_axes_mapping(target_shape, broadcast_axes)};
Output<Node> squeezed_arg = arg;
// Collect axes to squeeze. Broadcast v0 "adds" new axes, thus we have to squeeze
// the empty ones (dim:=1), which would be broadcasted by Broadcast v1.
std::vector<size_t> empty_axes;
for (size_t a{0}; a < axes_mapping.size(); ++a)
{
if (arg_shape.at(a) == 1 && target_shape.at(axes_mapping.at(a)) != 1)
{
empty_axes.push_back(a);
}
}
// Check if arg_shape contains some more empty dimensions marked to broadcast.
// If axes_mapping size is less than arg_shape size, then some of arg dimensions may
// be equal to one and marked to broadcast.
if (axes_mapping.size() < arg_shape.size())
{
for (size_t a{axes_mapping.size()}; a < arg_shape.size(); ++a)
{
if (arg_shape.at(a) == 1)
{
empty_axes.push_back(a);
}
}
}
if (!empty_axes.empty())
{
squeezed_arg = builder::squeeze(arg, empty_axes);
}
auto replacement_node = auto replacement_node =
make_shared<op::v0::Broadcast>(arg, target_shape, node->get_broadcast_axes().second); make_shared<op::v0::Broadcast>(squeezed_arg, target_shape, broadcast_axes);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return replacement_node; return replacement_node;
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/ops.hpp" #include "ngraph/ops.hpp"
#include "ngraph/pass/opset1_upgrade.hpp" #include "ngraph/pass/opset1_upgrade.hpp"
#include "ngraph/provenance.hpp" #include "ngraph/provenance.hpp"
...@@ -106,25 +107,10 @@ namespace ...@@ -106,25 +107,10 @@ namespace
shared_ptr<Node> op_cast(shared_ptr<op::Broadcast> node) shared_ptr<Node> op_cast(shared_ptr<op::Broadcast> node)
{ {
auto result_shape = node->get_broadcast_shape(); auto replacement_node = ngraph::op::opset1::make_broadcast(
auto result_shape_node = node->input_value(0), node->get_broadcast_shape(), node->get_broadcast_axes());
op::Constant::create(element::i64, Shape{result_shape.size()}, result_shape); replace_node(node, replacement_node.get_node_shared_ptr());
auto broadcast_axes = node->get_broadcast_axes(); return replacement_node.get_node_shared_ptr();
// Flip broadcast_axes to get axes_mapping
std::vector<size_t> axes_mapping(result_shape.size());
std::iota(axes_mapping.begin(), axes_mapping.end(), 0);
for (auto i = broadcast_axes.rbegin(); i != broadcast_axes.rend(); i++)
{
axes_mapping.erase(axes_mapping.begin() + *i);
}
auto axes_mapping_node =
op::Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping);
auto replacement_node = make_shared<op::v1::Broadcast>(
node->input_value(0), result_shape_node->output(0), axes_mapping_node->output(0));
replace_node(node, replacement_node);
return replacement_node;
} }
shared_ptr<Node> op_cast(shared_ptr<op::BroadcastLike> node) { return nullptr; } shared_ptr<Node> op_cast(shared_ptr<op::BroadcastLike> node) { return nullptr; }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment