Unverified Commit b57c7661 authored by Adam Osewski's avatar Adam Osewski Committed by GitHub

Update broadcasting helpers to use v1 operators. (#4194)

* Helper function get_axes_mapping.

* Enhance Broadcast:v1 NUMPY broadcasting.

- Enable NUMPY broadcasting mechanism to work in bothdirections:
    target_shape <-> arg_shape

* Add opset1:squeeze and fix bug in reading squeezed axis idx.

* Fix and enhance downgrade pass for Broadcast:v1

* Use Broadcast:v1 in ONNX Expand operator.

* Replace Broadcast:v0 with v1 in some helper functions.

* Remove call to deprecated legacy_broadcasting helper function.

* Add helper get_axes_mapping_output function.

* Use directly Broadcast:v1 instead of helper function.

* Get back operators from v0 in helper function.

* Use helper function and some refactoring.

* Add legacy style broadcast helper function for opset1.

* User helper broadcasting function for arithmetic operators.

* Add empty axis only if its size is equal to one.

* Aplly review remarks:

- Rename broadcasting function deleting _values_ infix
- Remove variables used only once.
- Use STL library where possible.
- Remove unnecessary conditions.

* Add helper for Broadcast:v1.

* Fix merge artifact and force unsigned type for argument.

* Review. Add additional check for static output.

* Apply clang-format.

* Fix: call v0 ops in ngraph::builder namespace.
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
parent 8086aeba
......@@ -91,8 +91,8 @@ NodeVector builder::MatmulFactory::make_matmul_op()
// Broadcast input arguments only if both of them are not vectors.
if (left_rank > 1 && right_rank > 1)
{
const NodeVector& broadcasted_nodes = op::numpy_style_broadcast_for_matmul_operation(
left.get_node_shared_ptr(), right.get_node_shared_ptr());
const OutputVector& broadcasted_nodes =
op::numpy_style_broadcast_for_matmul_operation(left, right);
left = broadcasted_nodes.at(0);
right = broadcasted_nodes.at(1);
......
......@@ -229,7 +229,7 @@ shared_ptr<Node> builder::opset1::expand_dims(const Output<Node>& value, size_t
return builder::opset1::reshape(value, output_shape);
}
shared_ptr<Node> builder::opset1::squeeze(const Output<Node>& value, vector<int64_t> axes)
shared_ptr<Node> builder::opset1::squeeze(const Output<Node>& value, vector<size_t> axes)
{
if (axes.empty())
{
......
......@@ -156,7 +156,7 @@ namespace ngraph
///
/// \return Reshape:v1 op.
std::shared_ptr<Node> squeeze(const Output<Node>& value,
std::vector<std::int64_t> axes = {0});
std::vector<std::size_t> axes = {0});
}
} // namespace builder
} // namespace ngraph
......@@ -17,6 +17,7 @@
#include "add.hpp"
#include "default_opset.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
......@@ -28,15 +29,16 @@ namespace ngraph
{
NodeVector add(const Node& node)
{
auto left_rank = node.get_ng_inputs().at(0)->get_shape().size();
auto right_rank = node.get_ng_inputs().at(1)->get_shape().size();
auto axis =
node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank);
NodeVector ng_inputs{ngraph::op::legacy_style_broadcast_for_binary_operation(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
auto lhs_rank = lhs_node.get_shape().size();
auto rhs_rank = rhs_node.get_shape().size();
auto axis = node.get_attribute_value<std::int64_t>("axis", lhs_rank - rhs_rank);
// Unidirectional broadcast right node to left shape.
rhs_node = ngraph::op::opset1::legacy_style_broadcast_for_binary_operation(
lhs_node, rhs_node, axis);
return {std::make_shared<default_opset::Add>(
ng_inputs.at(0), ng_inputs.at(1), ngraph::op::AutoBroadcastSpec::NONE)};
lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)};
}
} // namespace set_1
......
......@@ -22,6 +22,7 @@
#include "default_opset.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
......@@ -33,15 +34,16 @@ namespace ngraph
{
inline NodeVector div(const Node& node)
{
auto left_rank = node.get_ng_inputs().at(0)->get_shape().size();
auto right_rank = node.get_ng_inputs().at(1)->get_shape().size();
auto axis =
node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank);
NodeVector ng_inputs{ngraph::op::legacy_style_broadcast_for_binary_operation(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
auto lhs_rank = lhs_node.get_shape().size();
auto rhs_rank = rhs_node.get_shape().size();
auto axis = node.get_attribute_value<std::int64_t>("axis", lhs_rank - rhs_rank);
// Unidirectional broadcast right node to left shape.
rhs_node = ngraph::op::opset1::legacy_style_broadcast_for_binary_operation(
lhs_node, rhs_node, axis);
return {std::make_shared<default_opset::Divide>(
ng_inputs.at(0), ng_inputs.at(1), ngraph::op::AutoBroadcastSpec::NONE)};
lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)};
}
} // namespace set_1
......
......@@ -28,6 +28,7 @@
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
......@@ -50,7 +51,10 @@ namespace ngraph
->get_vector<std::size_t>();
const ngraph::Shape shape_shape{shape_vector};
return {ngraph::op::numpy_style_broadcast(data, shape_shape)};
return {std::make_shared<default_opset::Broadcast>(
data,
default_opset::Constant::create(
element::i64, Shape{shape_shape.size()}, shape_shape))};
}
} // namespace set_1
......
......@@ -42,47 +42,52 @@ namespace ngraph
NodeVector instance_norm(const Node& node)
{
const std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
std::shared_ptr<ngraph::Node> scale{node.get_ng_inputs().at(1)};
std::shared_ptr<ngraph::Node> bias{node.get_ng_inputs().at(2)};
Output<ngraph::Node> scale(node.get_ng_inputs().at(1));
Output<ngraph::Node> bias(node.get_ng_inputs().at(2));
const Shape& data_shape = data->get_shape();
const Shape& scale_shape = scale.get_shape();
const Shape& bias_shape = bias.get_shape();
const float epsilon{node.get_attribute_value<float>("epsilon", 1e-5f)};
CHECK_VALID_NODE(node,
(scale->get_shape().size() == 1 &&
scale->get_shape()[0] == data->get_shape().at(1)),
"Scale input must be one dimensional vector of number of "
"input data channels size.");
CHECK_VALID_NODE(
node,
(scale_shape.size() == 1 && scale_shape[0] == data_shape.at(1)),
"Scale input must be one dimensional vector of number of "
"input data channels size.");
CHECK_VALID_NODE(node,
(bias->get_shape().size() == 1 &&
bias->get_shape()[0] == data->get_shape().at(1)),
(bias_shape.size() == 1 && bias_shape[0] == data_shape.at(1)),
"Bias input must be one dimensional vector of number of "
"input data channels size.");
// all dimensions except spatial/feature
const AxisSet reduction_axes{
common::get_monotonic_range<std::size_t>(data->get_shape().size(), 2)};
common::get_monotonic_range<std::size_t>(data_shape.size(), 2)};
const std::shared_ptr<ngraph::Node> eps_node =
std::make_shared<default_opset::Constant>(data->get_element_type(),
data->get_shape(),
std::vector<float>{epsilon});
scale = ngraph::op::legacy_style_broadcast_for_binary_operation(data, scale, 1)
.at(1);
bias = ngraph::op::legacy_style_broadcast_for_binary_operation(data, bias, 1)
.at(1);
std::shared_ptr<ngraph::Node> mean = builder::mean(data, reduction_axes);
mean = std::make_shared<ngraph::opset0::Broadcast>(
mean, data->get_shape(), reduction_axes);
std::shared_ptr<ngraph::Node> variance =
builder::variance(data, reduction_axes);
variance = std::make_shared<ngraph::opset0::Broadcast>(
variance, data->get_shape(), reduction_axes);
std::make_shared<default_opset::Constant>(
data->get_element_type(), data_shape, std::vector<float>{epsilon});
scale = ngraph::op::opset1::make_broadcast(scale, data_shape, 1);
bias = ngraph::op::opset1::make_broadcast(bias, data_shape, 1);
Output<ngraph::Node> mean = builder::mean(data, reduction_axes);
mean = ngraph::op::opset1::make_broadcast(mean, data_shape, reduction_axes);
Output<ngraph::Node> variance = builder::variance(data, reduction_axes);
variance =
ngraph::op::opset1::make_broadcast(variance, data_shape, reduction_axes);
const auto sqrt = std::make_shared<default_opset::Sqrt>(variance + eps_node);
return {scale * (data - mean) / sqrt + bias};
// scale * (data - mean) / sqrt + bias
std::shared_ptr<ngraph::Node> result{
std::make_shared<default_opset::Subtract>(data, mean)};
result = std::make_shared<default_opset::Multiply>(scale, result);
result = std::make_shared<default_opset::Divide>(result, sqrt);
result = std::make_shared<default_opset::Add>(result, bias);
return {result};
}
} // namespace set_1
......
......@@ -35,15 +35,16 @@ namespace ngraph
{
inline NodeVector mul(const Node& node)
{
auto left_rank = node.get_ng_inputs().at(0)->get_shape().size();
auto right_rank = node.get_ng_inputs().at(1)->get_shape().size();
auto axis =
node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank);
NodeVector ng_inputs{ngraph::op::legacy_style_broadcast_for_binary_operation(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
auto lhs_rank = lhs_node.get_shape().size();
auto rhs_rank = rhs_node.get_shape().size();
auto axis = node.get_attribute_value<std::int64_t>("axis", lhs_rank - rhs_rank);
// Unidirectional broadcast right node to left shape.
rhs_node = ngraph::op::opset1::legacy_style_broadcast_for_binary_operation(
lhs_node, rhs_node, axis);
return {std::make_shared<default_opset::Multiply>(
ng_inputs.at(0), ng_inputs.at(1), ngraph::op::AutoBroadcastSpec::NONE)};
lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)};
}
} // namespace set_1
......
......@@ -31,15 +31,16 @@ namespace ngraph
{
inline NodeVector sub(const Node& node)
{
auto left_rank = node.get_ng_inputs().at(0)->get_shape().size();
auto right_rank = node.get_ng_inputs().at(1)->get_shape().size();
auto axis =
node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank);
NodeVector ng_inputs{ngraph::op::legacy_style_broadcast_for_binary_operation(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
auto lhs_rank = lhs_node.get_shape().size();
auto rhs_rank = rhs_node.get_shape().size();
auto axis = node.get_attribute_value<std::int64_t>("axis", lhs_rank - rhs_rank);
// Unidirectional broadcast right node to left shape.
rhs_node = ngraph::op::opset1::legacy_style_broadcast_for_binary_operation(
lhs_node, rhs_node, axis);
return {std::make_shared<default_opset::Subtract>(
ng_inputs.at(0), ng_inputs.at(1), ngraph::op::AutoBroadcastSpec::NONE)};
lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)};
}
} // namespace set_1
......
......@@ -63,8 +63,8 @@ namespace ngraph
if (m_keep_dims == 0)
{
const auto reshaped_indices =
ngraph::builder::opset1::squeeze(indices, {m_normalized_axis});
const auto reshaped_indices = ngraph::builder::opset1::squeeze(
indices, {static_cast<std::size_t>(m_normalized_axis)});
return std::make_shared<default_opset::Convert>(reshaped_indices, element::i64);
}
return std::make_shared<default_opset::Convert>(indices, element::i64);
......
......@@ -81,20 +81,17 @@ std::pair<bool, AxisSet> op::v1::Broadcast::get_broadcast_axes() const
else if (m_broadcast_spec.m_type == AutoBroadcastType::NUMPY ||
m_broadcast_spec.m_type == AutoBroadcastType::PDPD)
{
if (input(0).get_partial_shape().is_static() &&
input_value(1).get_node_shared_ptr()->is_constant())
if (input(0).get_partial_shape().is_static() && output(0).get_partial_shape().is_static())
{
auto arg_shape = input(0).get_shape();
auto target_shape =
static_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr())
->get_shape_val();
auto result_shape = output(0).get_shape();
auto start_axis = (m_broadcast_spec.m_type == AutoBroadcastType::PDPD)
? m_broadcast_spec.m_axis
: target_shape.size() - arg_shape.size();
: result_shape.size() - arg_shape.size();
NGRAPH_CHECK(start_axis >= 0);
for (size_t i = 0; i < target_shape.size(); i++)
for (size_t i = 0; i < result_shape.size(); i++)
{
if (i < start_axis || target_shape[i] != arg_shape[i - start_axis])
if (i < start_axis || result_shape[i] != arg_shape[i - start_axis])
{
broadcast_axes.insert(i);
}
......@@ -229,13 +226,15 @@ void op::v1::Broadcast::validate_and_infer_types()
arg_shape.size());
for (auto i = start_axis; i < target_shape.size(); i++)
{
NODE_VALIDATION_CHECK(this,
arg_shape[i - start_axis] == 1 ||
arg_shape[i - start_axis] == target_shape[i],
"Broadcast incorrect target shape. Expecting ",
arg_shape[i - start_axis],
" . Got ",
target_shape[i]);
NODE_VALIDATION_CHECK(
this,
arg_shape[i - start_axis] == 1 || target_shape[i] == 1 ||
arg_shape[i - start_axis] == target_shape[i],
"Broadcast incorrect target shape. Expecting either 1 or ",
arg_shape[i - start_axis],
" . Got ",
target_shape[i]);
result_shape[i] = std::max(arg_shape[i - start_axis], target_shape[i]);
}
}
}
......
......@@ -120,7 +120,7 @@ shared_ptr<Node> op::LSTMSequence::get_masked_node(const Output<Node>& data,
shared_ptr<Node> curr_time_step_node = op::Constant::create(
element::i32, data.get_shape(), vector<int32_t>(shape_size(data.get_shape()), time_step));
shared_ptr<Node> batch_seq_length =
Output<Node> batch_seq_length =
op::legacy_style_broadcast_for_binary_operation(
curr_time_step_node, input_value(3).get_node_shared_ptr(), batch_axis)
.at(1);
......
This diff is collapsed.
......@@ -16,7 +16,9 @@
#pragma once
#include <cstddef>
#include <memory>
#include <vector>
#include "ngraph/axis_set.hpp"
#include "ngraph/node.hpp"
......@@ -69,50 +71,9 @@ namespace ngraph
/// \param start_match_axis position in shape denoting start of the mutually equal shape
///
/// \return Left and right node after broadcasting.
NodeVector legacy_style_broadcast_for_binary_operation(const std::shared_ptr<Node>& left,
const std::shared_ptr<Node>& right,
size_t start_match_axis)
NGRAPH_DEPRECATED("Replace with legacy_style_broadcast_values_for_binary_operation");
/// \brief Cast shape of two outputs to make them compatible for an element-wise binary
/// operation.
///
/// If necessary the right-hand-side argument will be broadcast to match the shape
/// of left-hand-side argument. The starting of the mutually equal shape is
/// specified by the argument "start_match_axis", and if it is not set,
/// suffix matching is assumed.
///
/// This style of broadcast was used in ONNX Op sets prior to version 7, where it was
/// replaced by numpy-style broadcasting.
///
/// \param left Node which contain input of binary op.
/// \param right Node which contain input of binary op.
/// \param start_match_axis position in shape denoting start of the mutually equal shape
///
/// \return Left and right node after broadcasting.
OutputVector legacy_style_broadcast_values_for_binary_operation(const Output<Node>& left,
const Output<Node>& right,
size_t start_match_axis);
/// \brief Broadcast shape of two nodes to make them compatible for a matrix
/// multiplication.
///
/// \note This function is reflecting broadcasting behaviour of NumPy's `matmul`
/// operation.
/// (https://docs.scipy.org/doc/numpy/reference/generated/numpy.matmul.html).
/// This mean that only \"stack of matrices\" axes are bidirectionally
/// broadcasted. The last two dimension are left untouched.
///
/// \param[in] left The Node providing data for the left-hand side of matrix
/// multiplication.
/// \param[in] right The Node providing data for the right-hand side of matrix
/// multiplication.
///
/// \return The vector containing both nodes broadcasted.
///
NodeVector numpy_style_broadcast_for_matmul_operation(const std::shared_ptr<Node>& left,
const std::shared_ptr<Node>& right)
NGRAPH_DEPRECATED("Replace with numpy_style_broadcast_values_for_matmul_operation.");
OutputVector legacy_style_broadcast_for_binary_operation(const Output<Node>& left,
const Output<Node>& right,
size_t start_match_axis);
/// \brief Broadcast shape of two nodes to make them compatible for a matrix
/// multiplication.
......@@ -130,8 +91,8 @@ namespace ngraph
///
/// \return The vector containing both outputs broadcasted.
///
OutputVector numpy_style_broadcast_values_for_matmul_operation(const Output<Node>& left,
const Output<Node>& right);
OutputVector numpy_style_broadcast_for_matmul_operation(const Output<Node>& left,
const Output<Node>& right);
/// \brief Cast shape of all input nodes for an element-wise operation that requires
/// shape-compatibility
......@@ -197,5 +158,68 @@ namespace ngraph
new_shape,
calculate_broadcast_axes(new_shape, value.get_shape(), start_match_axis));
}
} // namespace op
namespace opset1
{
///
/// \brief Broadcast right node to left node's shape using legacy scheme.
///
/// \param[in] left The left hand side node of binary operation.
/// \param[in] right The right hand side node of binary operation. The one
/// to be broadcasted.
/// \param[in] start_match_axis The axis index starting mutually equal shapes
/// of both nodes.
///
/// \return The Output object connected to node producing broadcasted right node.
///
Output<Node> legacy_style_broadcast_for_binary_operation(const Output<Node>& left,
const Output<Node>& right,
size_t start_match_axis);
///
/// \brief Reconstructs axes mapping vector for Broadcast:v1 operation.
///
/// \param[in] output_shape The output shape of Broadcast operation.
/// \param[in] broadcast_axes The broadcast axes used for Broadcast:v0 operator.
///
/// \return The vector with axes indexes mapping .
///
std::vector<std::size_t> get_axes_mapping(const Shape& output_shape,
const AxisSet& broadcast_axes);
///
/// \brief Creates Node returning the axes mapping for Broadcast:v1 operation.
///
/// \param[in] output_shape The output shape of Broadcast operation.
/// \param[in] input_shape The input shape.
/// \param[in] start_match_axis The axis index at which input shape starts to be
/// identical as the output shape.
///
/// \return Returns the Output object pointing to node with the axes mapping.
///
Output<Node> get_axes_mapping_output(const Shape& output_shape,
const Shape& input_shape,
std::size_t start_match_axis);
///
/// \brief Creates Node returning the axes mapping for Broadcast:v1 operation.
///
/// \param[in] output_shape The output shape of Broadcast operation.
/// \param[in] broadcast_axes The broadcast axes used for Broadcast:v0 operator.
///
/// \return The Output object with Node returning axes mapping.
///
Output<Node> get_axes_mapping_output(const Shape& output_shape,
const AxisSet& broadcast_axes);
Output<Node> make_broadcast(const Output<Node>& node,
const Shape& target_shape,
const AxisSet& broadcast_axes);
Output<Node> make_broadcast(const Output<Node>& node,
const Shape& target_shape,
std::size_t start_match_axis);
} // namespace opset1
} // namespace op
} // namespace ngraph
......@@ -148,13 +148,48 @@ namespace
shared_ptr<Node> op_cast(shared_ptr<op::v1::Broadcast> node)
{
auto arg = node->input_value(0);
const auto& arg_shape = arg.get_shape();
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant());
auto target_shape =
static_pointer_cast<op::Constant>(node->input_value(1).get_node_shared_ptr())
->get_shape_val();
auto target_shape = node->output(0).get_shape();
NGRAPH_CHECK(node->get_broadcast_axes().first);
// (Re)construct axes_mapping.
AxisSet broadcast_axes = node->get_broadcast_axes().second;
std::vector<size_t> axes_mapping{
ngraph::op::opset1::get_axes_mapping(target_shape, broadcast_axes)};
Output<Node> squeezed_arg = arg;
// Collect axes to squeeze. Broadcast v0 "adds" new axes, thus we have to squeeze
// the empty ones (dim:=1), which would be broadcasted by Broadcast v1.
std::vector<size_t> empty_axes;
for (size_t a{0}; a < axes_mapping.size(); ++a)
{
if (arg_shape.at(a) == 1 && target_shape.at(axes_mapping.at(a)) != 1)
{
empty_axes.push_back(a);
}
}
// Check if arg_shape contains some more empty dimensions marked to broadcast.
// If axes_mapping size is less than arg_shape size, then some of arg dimensions may
// be equal to one and marked to broadcast.
if (axes_mapping.size() < arg_shape.size())
{
for (size_t a{axes_mapping.size()}; a < arg_shape.size(); ++a)
{
if (arg_shape.at(a) == 1)
{
empty_axes.push_back(a);
}
}
}
if (!empty_axes.empty())
{
squeezed_arg = builder::squeeze(arg, empty_axes);
}
auto replacement_node =
make_shared<op::v0::Broadcast>(arg, target_shape, node->get_broadcast_axes().second);
make_shared<op::v0::Broadcast>(squeezed_arg, target_shape, broadcast_axes);
replace_node(node, replacement_node);
return replacement_node;
......
......@@ -20,6 +20,7 @@
#include "ngraph/builder/reshape.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "ngraph/provenance.hpp"
......@@ -106,25 +107,10 @@ namespace
shared_ptr<Node> op_cast(shared_ptr<op::Broadcast> node)
{
auto result_shape = node->get_broadcast_shape();
auto result_shape_node =
op::Constant::create(element::i64, Shape{result_shape.size()}, result_shape);
auto broadcast_axes = node->get_broadcast_axes();
// Flip broadcast_axes to get axes_mapping
std::vector<size_t> axes_mapping(result_shape.size());
std::iota(axes_mapping.begin(), axes_mapping.end(), 0);
for (auto i = broadcast_axes.rbegin(); i != broadcast_axes.rend(); i++)
{
axes_mapping.erase(axes_mapping.begin() + *i);
}
auto axes_mapping_node =
op::Constant::create(element::i64, Shape{axes_mapping.size()}, axes_mapping);
auto replacement_node = make_shared<op::v1::Broadcast>(
node->input_value(0), result_shape_node->output(0), axes_mapping_node->output(0));
replace_node(node, replacement_node);
return replacement_node;
auto replacement_node = ngraph::op::opset1::make_broadcast(
node->input_value(0), node->get_broadcast_shape(), node->get_broadcast_axes());
replace_node(node, replacement_node.get_node_shared_ptr());
return replacement_node.get_node_shared_ptr();
}
shared_ptr<Node> op_cast(shared_ptr<op::BroadcastLike> node) { return nullptr; }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment