Commit e49756c8 authored by Michał Karzyński, committed by Scott Cyphers

[ONNX] Unify broadcast code in ngraph core and in ONNX bridge (#2771)

* [ONNX] Unify broadcast code in ngraph core and in ONNX bridge

* Use using a bit less
parent 1572d31f
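
Every file touched by this change follows the same mechanical pattern: the bridge-local header "utils/broadcasting.hpp" is replaced by the core header "ngraph/op/util/broadcasting.hpp", and the broadcast helpers are called through their new ngraph::op namespace. A minimal sketch of that pattern, modelled on the add()/lstm helper code in the diff below but not a verbatim excerpt of any single file (the function name broadcast_and_add is hypothetical):

#include <memory>

#include "ngraph/node_vector.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/util/broadcasting.hpp" // was: #include "utils/broadcasting.hpp"

// Hypothetical helper illustrating the migrated call style.
ngraph::NodeVector broadcast_and_add(const std::shared_ptr<ngraph::Node>& lhs,
                                     const std::shared_ptr<ngraph::Node>& rhs)
{
    // was: numpy_style_broadcast({lhs, rhs})
    ngraph::NodeVector args = ngraph::op::numpy_style_broadcast({lhs, rhs});
    return {std::make_shared<ngraph::op::Add>(args.at(0), args.at(1))};
}
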
@@ -179,8 +179,6 @@ add_library(onnx_import STATIC
 op/xor.hpp
 ops_bridge.cpp
 ops_bridge.hpp
-utils/broadcasting.cpp
-utils/broadcasting.hpp
 utils/common.hpp
 utils/convpool.cpp
 utils/convpool.hpp
...
@@ -19,7 +19,7 @@
 #include "core/node.hpp"
 #include "ngraph/node_vector.hpp"
 #include "ngraph/op/add.hpp"
-#include "utils/broadcasting.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 namespace ngraph
 {
@@ -35,7 +35,7 @@ namespace ngraph
 auto right_rank = node.get_ng_inputs().at(1)->get_shape().size();
 auto axis =
 node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank);
-NodeVector ng_inputs{legacy_style_broadcast_for_binary_operation(
+NodeVector ng_inputs{ngraph::op::legacy_style_broadcast_for_binary_operation(
 node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
 return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
@@ -47,7 +47,7 @@ namespace ngraph
 {
 inline NodeVector add(const Node& node)
 {
-NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
+NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
 return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
 }
...
@@ -19,7 +19,7 @@
 #include "core/node.hpp"
 #include "ngraph/node_vector.hpp"
 #include "ngraph/op/and.hpp"
-#include "utils/broadcasting.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 namespace ngraph
 {
@@ -31,7 +31,7 @@ namespace ngraph
 {
 inline NodeVector logical_and(const Node& node)
 {
-NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
+NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
 return {std::make_shared<ngraph::op::And>(ng_inputs.at(0), ng_inputs.at(1))};
 }
...
@@ -23,7 +23,7 @@
 #include "ngraph/op/constant.hpp"
 #include "ngraph/op/maximum.hpp"
 #include "ngraph/op/minimum.hpp"
-#include "utils/broadcasting.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 namespace ngraph
 {
@@ -46,13 +46,15 @@ namespace ngraph
 std::make_shared<ngraph::op::Constant>(data->get_element_type(),
 ngraph::Shape{},
 std::vector<double>{max_value});
-max_value_node = make_broadcast_node(max_value_node, data->get_shape());
+max_value_node =
+ngraph::op::make_broadcast_node(max_value_node, data->get_shape());
 std::shared_ptr<ngraph::Node> min_value_node =
 std::make_shared<ngraph::op::Constant>(data->get_element_type(),
 ngraph::Shape{},
 std::vector<double>{min_value});
-min_value_node = make_broadcast_node(min_value_node, data->get_shape());
+min_value_node =
+ngraph::op::make_broadcast_node(min_value_node, data->get_shape());
 return {std::make_shared<ngraph::op::Minimum>(
 max_value_node,
...
@@ -20,13 +20,13 @@
 #include "ngraph/frontend/onnx_import/exceptions.hpp"
 #include "ngraph/frontend/onnx_import/op/conv.hpp"
-#include "ngraph/frontend/onnx_import/utils/broadcasting.hpp"
 #include "ngraph/frontend/onnx_import/utils/convpool.hpp"
 #include "ngraph/op/add.hpp"
 #include "ngraph/op/broadcast.hpp"
 #include "ngraph/op/concat.hpp"
 #include "ngraph/op/convolution.hpp"
 #include "ngraph/op/slice.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 #include "op/conv.hpp"
 namespace ngraph
@@ -142,7 +142,9 @@ namespace ngraph
 const Shape& new_shape = conv_node->get_shape();
 auto broadcasted_bias = std::make_shared<ngraph::op::Broadcast>(
-bias, new_shape, calculate_broadcast_axes(new_shape, bias->get_shape(), 1));
+bias,
+new_shape,
+ngraph::op::calculate_broadcast_axes(new_shape, bias->get_shape(), 1));
 return {std::make_shared<ngraph::op::Add>(conv_node, broadcasted_bias)};
 }
...
@@ -23,7 +23,6 @@
 #include "ngraph/coordinate_diff.hpp"
 #include "ngraph/frontend/onnx_import/exceptions.hpp"
 #include "ngraph/frontend/onnx_import/op/conv_transpose.hpp"
-#include "ngraph/frontend/onnx_import/utils/broadcasting.hpp"
 #include "ngraph/frontend/onnx_import/utils/convpool.hpp"
 #include "ngraph/op/add.hpp"
 #include "ngraph/op/broadcast.hpp"
@@ -33,6 +32,7 @@
 #include "ngraph/op/multiply.hpp"
 #include "ngraph/op/slice.hpp"
 #include "ngraph/op/subtract.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 #include "ngraph/shape.hpp"
 #include "ngraph/strides.hpp"
@@ -234,7 +234,8 @@ namespace ngraph
 auto bias = inputs.at(2);
 return {std::make_shared<ngraph::op::Add>(
-conv_node, make_broadcast_node(bias, conv_node->get_shape(), 1))};
+conv_node,
+ngraph::op::make_broadcast_node(bias, conv_node->get_shape(), 1))};
 }
 } // namespace set_1
...
@@ -19,7 +19,7 @@
 #include "core/node.hpp"
 #include "ngraph/node_vector.hpp"
 #include "ngraph/op/divide.hpp"
-#include "utils/broadcasting.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 namespace ngraph
 {
@@ -35,7 +35,7 @@ namespace ngraph
 auto right_rank = node.get_ng_inputs().at(1)->get_shape().size();
 auto axis =
 node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank);
-NodeVector ng_inputs{legacy_style_broadcast_for_binary_operation(
+NodeVector ng_inputs{ngraph::op::legacy_style_broadcast_for_binary_operation(
 node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
 return {std::make_shared<ngraph::op::Divide>(ng_inputs.at(0), ng_inputs.at(1))};
@@ -47,7 +47,7 @@ namespace ngraph
 {
 inline NodeVector div(const Node& node)
 {
-NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
+NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
 return {std::make_shared<ngraph::op::Divide>(ng_inputs.at(0), ng_inputs.at(1))};
 }
...
@@ -26,7 +26,7 @@
 #include "ngraph/op/minimum.hpp"
 #include "ngraph/op/multiply.hpp"
 #include "ngraph/op/subtract.hpp"
-#include "utils/broadcasting.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 #include "elu.hpp"
@@ -46,12 +46,12 @@ namespace ngraph
 std::shared_ptr<ngraph::Node> alpha_node =
 std::make_shared<ngraph::op::Constant>(
 data->get_element_type(), Shape{}, std::vector<double>{alpha});
-alpha_node = make_broadcast_node(alpha_node, data->get_shape());
+alpha_node = ngraph::op::make_broadcast_node(alpha_node, data->get_shape());
 std::shared_ptr<ngraph::Node> zero_node =
 std::make_shared<ngraph::op::Constant>(
 data->get_element_type(), Shape{}, std::vector<double>{0});
-zero_node = make_broadcast_node(zero_node, data->get_shape());
+zero_node = ngraph::op::make_broadcast_node(zero_node, data->get_shape());
 return {std::make_shared<ngraph::op::Maximum>(data, zero_node) +
 alpha_node *
...
@@ -19,7 +19,7 @@
 #include "core/node.hpp"
 #include "ngraph/node_vector.hpp"
 #include "ngraph/op/equal.hpp"
-#include "utils/broadcasting.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 namespace ngraph
 {
@@ -31,7 +31,7 @@ namespace ngraph
 {
 inline NodeVector equal(const Node& node)
 {
-NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
+NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
 return {std::make_shared<ngraph::op::Equal>(ng_inputs.at(0), ng_inputs.at(1))};
 }
...
@@ -16,12 +16,12 @@
 #include "op/gemm.hpp"
 #include "ngraph/frontend/onnx_import/exceptions.hpp"
-#include "ngraph/frontend/onnx_import/utils/broadcasting.hpp"
 #include "ngraph/frontend/onnx_import/utils/reshape.hpp"
 #include "ngraph/op/add.hpp"
 #include "ngraph/op/constant.hpp"
 #include "ngraph/op/dot.hpp"
 #include "ngraph/op/multiply.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 namespace ngraph
 {
@@ -76,7 +76,8 @@ namespace ngraph
 input_c = std::make_shared<ngraph::op::Multiply>(beta_node, input_c);
 // alpha * A' * B' + beta * C
-NodeVector broadcasted_nodes = numpy_style_broadcast({a_dot_b, input_c});
+NodeVector broadcasted_nodes =
+ngraph::op::numpy_style_broadcast({a_dot_b, input_c});
 // The ONNX documentation says that `input_c` should be "unidirectional broadcastable"
 // to the `a_dot_b` tensor. Since numpy style broadcasting is bidirectional, below we
 // only use the second output from above broadcasting. In other words we want to
...
@@ -19,7 +19,7 @@
 #include "core/node.hpp"
 #include "ngraph/node_vector.hpp"
 #include "ngraph/op/greater.hpp"
-#include "utils/broadcasting.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 namespace ngraph
 {
@@ -31,7 +31,7 @@ namespace ngraph
 {
 inline NodeVector greater(const Node& node)
 {
-NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
+NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
 return {
 std::make_shared<ngraph::op::Greater>(ng_inputs.at(0), ng_inputs.at(1))};
 }
...
@@ -24,7 +24,9 @@
 #include "ngraph/op/maximum.hpp"
 #include "ngraph/op/minimum.hpp"
 #include "ngraph/op/multiply.hpp"
-#include "utils/broadcasting.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
+using namespace ngraph::op;
 namespace ngraph
 {
...
@@ -27,7 +27,7 @@
 #include "core/node.hpp"
 #include "leaky_relu.hpp"
-#include "utils/broadcasting.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 namespace ngraph
 {
@@ -48,7 +48,7 @@ namespace ngraph
 std::shared_ptr<ngraph::Node> alpha_node =
 std::make_shared<ngraph::op::Constant>(
 data->get_element_type(), Shape{}, std::vector<double>{alpha});
-alpha_node = make_broadcast_node(alpha_node, data->get_shape());
+alpha_node = ngraph::op::make_broadcast_node(alpha_node, data->get_shape());
 return {std::make_shared<ngraph::op::Maximum>(data * alpha_node, data)};
 }
...
@@ -19,7 +19,7 @@
 #include "core/node.hpp"
 #include "ngraph/node_vector.hpp"
 #include "ngraph/op/less.hpp"
-#include "utils/broadcasting.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 namespace ngraph
 {
@@ -31,7 +31,7 @@ namespace ngraph
 {
 inline NodeVector less(const Node& node)
 {
-NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
+NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
 return {std::make_shared<ngraph::op::Less>(ng_inputs.at(0), ng_inputs.at(1))};
 }
...
@@ -45,10 +45,10 @@
 #include "ngraph/op/sigmoid.hpp"
 #include "ngraph/op/subtract.hpp"
 #include "ngraph/op/tanh.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 #include "ngraph/shape.hpp"
 #include "ngraph/type/element_type.hpp"
 #include "ngraph/util.hpp"
-#include "utils/broadcasting.hpp"
 #include "utils/common.hpp"
 #include "utils/reshape.hpp"
 #include "utils/rnn/activation_functions.hpp"
@@ -64,21 +64,21 @@ namespace ngraph
 std::shared_ptr<ngraph::Node> add(const std::shared_ptr<ngraph::Node>& lhs,
 const std::shared_ptr<ngraph::Node>& rhs)
 {
-auto args = numpy_style_broadcast({lhs, rhs});
+auto args = ngraph::op::numpy_style_broadcast({lhs, rhs});
 return {std::make_shared<ngraph::op::Add>(args.at(0), args.at(1))};
 }
 std::shared_ptr<ngraph::Node> sub(const std::shared_ptr<ngraph::Node>& lhs,
 const std::shared_ptr<ngraph::Node>& rhs)
 {
-auto args = numpy_style_broadcast({lhs, rhs});
+auto args = ngraph::op::numpy_style_broadcast({lhs, rhs});
 return {std::make_shared<ngraph::op::Subtract>(args.at(0), args.at(1))};
 }
 std::shared_ptr<ngraph::Node> mul(const std::shared_ptr<ngraph::Node>& lhs,
 const std::shared_ptr<ngraph::Node>& rhs)
 {
-auto args = numpy_style_broadcast({lhs, rhs});
+auto args = ngraph::op::numpy_style_broadcast({lhs, rhs});
 return {std::make_shared<ngraph::op::Multiply>(args.at(0), args.at(1))};
 }
@@ -504,7 +504,7 @@ namespace ngraph
 time_step));
 std::shared_ptr<ngraph::Node> batch_seq_length =
-legacy_style_broadcast_for_binary_operation(
+ngraph::op::legacy_style_broadcast_for_binary_operation(
 curr_time_step_node, m_seq_lengths, batch_axis)
 .at(1);
...
@@ -28,8 +28,8 @@
 #include "ngraph/op/experimental/quantized_dot.hpp"
 #include "ngraph/op/reshape.hpp"
 #include "ngraph/op/slice.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 #include "ngraph/shape.hpp"
-#include "utils/broadcasting.hpp"
 #include "utils/reshape.hpp"
 /// \brief Slice the sub matrix from the input tensor.
@@ -127,7 +127,7 @@ namespace ngraph
 if (left_rank > 1 && right_rank > 1)
 {
 const NodeVector& broadcasted_nodes =
-numpy_style_broadcast_for_matmul_operation(left, right);
+ngraph::op::numpy_style_broadcast_for_matmul_operation(left, right);
 left = broadcasted_nodes.at(0);
 right = broadcasted_nodes.at(1);
...
@@ -20,7 +20,7 @@
 #include "ngraph/node_vector.hpp"
 #include "ngraph/op/broadcast.hpp"
 #include "ngraph/op/multiply.hpp"
-#include "utils/broadcasting.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 namespace ngraph
 {
@@ -36,7 +36,7 @@ namespace ngraph
 auto right_rank = node.get_ng_inputs().at(1)->get_shape().size();
 auto axis =
 node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank);
-NodeVector ng_inputs{legacy_style_broadcast_for_binary_operation(
+NodeVector ng_inputs{ngraph::op::legacy_style_broadcast_for_binary_operation(
 node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
 return {
@@ -49,7 +49,7 @@ namespace ngraph
 {
 inline NodeVector mul(const Node& node)
 {
-NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
+NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
 return {
 std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(1))};
 }
...
@@ -19,7 +19,6 @@
 #include "core/node.hpp"
 #include "ngraph/node_vector.hpp"
 #include "ngraph/op/not.hpp"
-#include "utils/broadcasting.hpp"
 namespace ngraph
 {
...
@@ -26,8 +26,8 @@
 #include "ngraph/op/one_hot.hpp"
 #include "ngraph/op/slice.hpp"
 #include "ngraph/op/subtract.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 #include "onehot.hpp"
-#include "utils/broadcasting.hpp"
 namespace ngraph
 {
@@ -79,7 +79,8 @@ namespace ngraph
 std::shared_ptr<ngraph::Node> one_hot = std::make_shared<ngraph::op::Convert>(
 std::make_shared<ngraph::op::OneHot>(indices, output_shape, axis),
 values->get_element_type());
-auto broadcasted_values = numpy_style_broadcast({one_hot, on_value, off_value});
+auto broadcasted_values =
+ngraph::op::numpy_style_broadcast({one_hot, on_value, off_value});
 on_value = broadcasted_values[1];
 off_value = broadcasted_values[2];
 one_hot = one_hot * (on_value - off_value) + off_value;
...
@@ -19,7 +19,7 @@
 #include "core/node.hpp"
 #include "ngraph/node_vector.hpp"
 #include "ngraph/op/or.hpp"
-#include "utils/broadcasting.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 namespace ngraph
 {
@@ -31,7 +31,7 @@ namespace ngraph
 {
 inline NodeVector logical_or(const Node& node)
 {
-NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
+NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
 return {std::make_shared<ngraph::op::Or>(ng_inputs.at(0), ng_inputs.at(1))};
 }
...
@@ -19,7 +19,7 @@
 #include "core/node.hpp"
 #include "ngraph/node_vector.hpp"
 #include "ngraph/op/power.hpp"
-#include "utils/broadcasting.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 namespace ngraph
 {
@@ -31,7 +31,7 @@ namespace ngraph
 {
 inline NodeVector pow(const Node& node)
 {
-NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
+NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
 return {std::make_shared<ngraph::op::Power>(ng_inputs.at(0), ng_inputs.at(1))};
 }
...
@@ -29,8 +29,8 @@
 #include "ngraph/op/less.hpp"
 #include "ngraph/op/multiply.hpp"
 #include "ngraph/op/reshape.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 #include "prelu.hpp"
-#include "utils/broadcasting.hpp"
 namespace ngraph
 {
@@ -53,11 +53,11 @@ namespace ngraph
 auto it = std::find(
 std::begin(data_shape), std::end(data_shape), slope_shape.at(0));
 auto index = std::distance(std::begin(data_shape), it);
-slope = make_broadcast_node(slope, data->get_shape(), index);
+slope = ngraph::op::make_broadcast_node(slope, data->get_shape(), index);
 }
 else if (data_shape != slope_shape)
 {
-slope = numpy_style_broadcast({slope, data})[0];
+slope = ngraph::op::numpy_style_broadcast({slope, data})[0];
 }
 // x < 0 => f(x) = x * slope
@@ -66,7 +66,7 @@ namespace ngraph
 std::shared_ptr<ngraph::Node> zero_node =
 std::make_shared<ngraph::op::Constant>(
 data->get_element_type(), ngraph::Shape{}, std::vector<double>{0});
-zero_node = make_broadcast_node(zero_node, data->get_shape());
+zero_node = ngraph::op::make_broadcast_node(zero_node, data->get_shape());
 std::shared_ptr<ngraph::Node> negative_map =
 std::make_shared<ngraph::op::Convert>(
...
@@ -22,7 +22,6 @@
 #include "ngraph/coordinate_diff.hpp"
 #include "ngraph/frontend/onnx_import/exceptions.hpp"
 #include "ngraph/frontend/onnx_import/op/conv.hpp"
-#include "ngraph/frontend/onnx_import/utils/broadcasting.hpp"
 #include "ngraph/frontend/onnx_import/utils/convpool.hpp"
 #include "ngraph/op/add.hpp"
 #include "ngraph/op/broadcast.hpp"
@@ -31,6 +30,7 @@
 #include "ngraph/op/experimental/quantized_conv.hpp"
 #include "ngraph/op/multiply.hpp"
 #include "ngraph/op/slice.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 #include "ngraph/strides.hpp"
 #include "quant_conv.hpp"
...
@@ -19,8 +19,8 @@
 #include "ngraph/op/constant.hpp"
 #include "ngraph/op/divide.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 #include "ngraph/shape.hpp"
-#include "utils/broadcasting.hpp"
 #include "reciprocal.hpp"
@@ -38,7 +38,7 @@ namespace ngraph
 std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
 data->get_element_type(), Shape{}, std::vector<double>{1});
-one_node = make_broadcast_node(one_node, data->get_shape());
+one_node = ngraph::op::make_broadcast_node(one_node, data->get_shape());
 return {one_node / data};
 }
...
@@ -28,7 +28,7 @@
 #include "ngraph/op/multiply.hpp"
 #include "ngraph/op/product.hpp"
 #include "ngraph/op/sum.hpp"
-#include "utils/broadcasting.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 #include "utils/norm.hpp"
 #include "utils/reduction.hpp"
...
@@ -28,9 +28,11 @@
 #include "ngraph/op/minimum.hpp"
 #include "ngraph/op/multiply.hpp"
 #include "ngraph/op/subtract.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 #include "ngraph/shape.hpp"
 #include "selu.hpp"
-#include "utils/broadcasting.hpp"
+using namespace ngraph::op;
 namespace ngraph
 {
...
@@ -21,8 +21,8 @@
 #include "ngraph/op/add.hpp"
 #include "ngraph/op/constant.hpp"
 #include "ngraph/op/divide.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 #include "ngraph/shape.hpp"
-#include "utils/broadcasting.hpp"
 #include "softsign.hpp"
@@ -40,7 +40,7 @@ namespace ngraph
 std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
 data->get_element_type(), Shape{}, std::vector<double>{1});
-one_node = make_broadcast_node(one_node, data->get_shape());
+one_node = ngraph::op::make_broadcast_node(one_node, data->get_shape());
 return {data / (std::make_shared<ngraph::op::Abs>(data) + one_node)};
 }
...
@@ -19,7 +19,7 @@
 #include "core/node.hpp"
 #include "ngraph/node_vector.hpp"
 #include "ngraph/op/subtract.hpp"
-#include "utils/broadcasting.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 namespace ngraph
 {
@@ -35,7 +35,7 @@ namespace ngraph
 auto right_rank = node.get_ng_inputs().at(1)->get_shape().size();
 auto axis =
 node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank);
-NodeVector ng_inputs{legacy_style_broadcast_for_binary_operation(
+NodeVector ng_inputs{ngraph::op::legacy_style_broadcast_for_binary_operation(
 node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
 return {
@@ -48,7 +48,7 @@ namespace ngraph
 {
 inline NodeVector sub(const Node& node)
 {
-NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
+NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
 return {
 std::make_shared<ngraph::op::Subtract>(ng_inputs.at(0), ng_inputs.at(1))};
 }
...
@@ -24,8 +24,8 @@
 #include "ngraph/op/convert.hpp"
 #include "ngraph/op/greater.hpp"
 #include "ngraph/op/multiply.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 #include "thresholded_relu.hpp"
-#include "utils/broadcasting.hpp"
 namespace ngraph
 {
@@ -43,7 +43,7 @@ namespace ngraph
 std::shared_ptr<ngraph::Node> alpha_node =
 std::make_shared<ngraph::op::Constant>(
 data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
-alpha_node = make_broadcast_node(alpha_node, data->get_shape());
+alpha_node = ngraph::op::make_broadcast_node(alpha_node, data->get_shape());
 auto data_map = std::make_shared<ngraph::op::Convert>(
 std::make_shared<ngraph::op::Greater>(data, alpha_node),
...
@@ -21,7 +21,7 @@
 #include "core/node.hpp"
 #include "ngraph/node_vector.hpp"
 #include "ngraph/op/select.hpp"
-#include "utils/broadcasting.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 namespace ngraph
 {
@@ -33,7 +33,7 @@ namespace ngraph
 {
 inline NodeVector where(const Node& node)
 {
-NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
+NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
 return {std::make_shared<ngraph::op::Select>(
 ng_inputs.at(0), ng_inputs.at(1), ng_inputs.at(2))};
...
@@ -21,7 +21,7 @@
 #include "ngraph/op/and.hpp"
 #include "ngraph/op/not.hpp"
 #include "ngraph/op/or.hpp"
-#include "utils/broadcasting.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 namespace ngraph
 {
@@ -33,7 +33,7 @@ namespace ngraph
 {
 inline NodeVector logical_xor(const Node& node)
 {
-NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
+NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
 auto left = ng_inputs.at(0);
 auto not_left = std::make_shared<ngraph::op::Not>(left);
 auto right = ng_inputs.at(1);
...
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/axis_set.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace onnx_import
{
/// \brief Cast shape of all input nodes for an element-wise operation that requires shape-compatibility
///
/// \param inputs Original list of inputs
///
/// \return Numpy-style broadcasted list of nodes.
NodeVector numpy_style_broadcast(const NodeVector& inputs);
/// \brief Cast shape of two nodes to make them compatible for an element-wise binary operation.
///
/// If necessary, the right-hand-side argument will be broadcast to match the shape
/// of the left-hand-side argument. The start of the mutually equal shape is
/// specified by the argument "start_match_axis"; if it is not set,
/// suffix matching is assumed.
///
/// This style of broadcast was used in ONNX Op sets prior to version 7, where it was
/// replaced by numpy-style broadcasting.
///
/// \param left Node which contains an input of the binary op.
/// \param right Node which contains an input of the binary op.
/// \param start_match_axis Position in the shape denoting the start of the mutually equal shape.
///
/// \return Left and right node after broadcasting.
NodeVector
legacy_style_broadcast_for_binary_operation(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right,
std::size_t start_match_axis);
/// \brief Broadcast shape of two nodes to make them compatible for a matrix multiplication.
///
/// \note This function reflects the broadcasting behaviour of NumPy's `matmul` operation
/// \link https://docs.scipy.org/doc/numpy/reference/generated/numpy.matmul.html
/// This means that only the \"stack of matrices\" axes are bidirectionally broadcast.
/// The last two dimensions are left untouched.
///
/// \param[in] left The Node providing data for the left-hand side of matrix multiplication.
/// \param[in] right The Node providing data for the right-hand side of matrix multiplication.
///
/// \return The vector containing both nodes broadcasted.
///
NodeVector
numpy_style_broadcast_for_matmul_operation(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right);
/// \brief Generate a list of broadcast axes.
///
/// \details Informally, a broadcast "adds" axes to the input tensor, replicating
/// elements from the input tensor as needed to fill the new dimensions.
/// This function calculates which of the output axes are added in this way.
///
/// \param output_shape The new shape for the output tensor.
/// \param input_shape The shape of input tensor.
/// \param start_match_axis The axis along which we want to replicate elements.
/// The starting axis position (0-based) in the output
/// shape from which the current shape of the tensor
/// matches the desired new shape.
///
/// \return The indices of added axes.
AxisSet calculate_broadcast_axes(const Shape& output_shape,
const Shape& input_shape,
std::size_t start_match_axis);
/// \brief Generate a list of broadcast along axes.
///
/// \details Broadcast "adds" elements along axes to the input tensor, replicating
/// elements from the input tensor as needed to fill the new dimensions.
/// This function calculates which of the output axes are added in this way.
///
/// This function will attempt to match shapes, assuming the current shape
/// matches the rightmost positions of the desired new shape. This behaviour
/// is similar to NumPy's broadcasting.
///
/// \param output_shape The new shape for the output tensor.
/// \param input_shape The shape of input tensor.
///
/// \return The indices of added axes.
inline AxisSet calculate_broadcast_axes(const Shape& output_shape, const Shape& input_shape)
{
return calculate_broadcast_axes(
output_shape, input_shape, output_shape.size() - input_shape.size());
}
inline std::shared_ptr<ngraph::Node>
make_broadcast_node(const std::shared_ptr<ngraph::Node>& node, ngraph::Shape new_shape)
{
return std::make_shared<ngraph::op::Broadcast>(
node, new_shape, calculate_broadcast_axes(new_shape, node->get_shape()));
}
inline std::shared_ptr<ngraph::Node>
make_broadcast_node(const std::shared_ptr<ngraph::Node>& node,
ngraph::Shape new_shape,
std::size_t start_match_axis)
{
return std::make_shared<ngraph::op::Broadcast>(
node,
new_shape,
calculate_broadcast_axes(new_shape, node->get_shape(), start_match_axis));
}
} // namespace onnx_import
} // namespace ngraph
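
For reference, the calls throughout the diff use these helpers with the same signatures documented above; only the header path ("ngraph/op/util/broadcasting.hpp") and the namespace (ngraph::op instead of ngraph::onnx_import) change. A short usage sketch against the new location; the Parameter nodes, element::f32 type and the concrete shapes are assumptions for illustration and do not appear in this commit:

#include <memory>

#include "ngraph/node_vector.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"

int main()
{
    // Two inputs with numpy-broadcastable shapes {3, 4} and {4}.
    auto a = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32,
                                                     ngraph::Shape{3, 4});
    auto b = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32,
                                                     ngraph::Shape{4});

    // Both returned nodes have shape {3, 4}.
    ngraph::NodeVector broadcasted = ngraph::op::numpy_style_broadcast({a, b});

    // Broadcast a scalar to an explicit target shape; the broadcast axes are
    // derived internally via calculate_broadcast_axes().
    auto scalar = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32,
                                                          ngraph::Shape{});
    auto widened = ngraph::op::make_broadcast_node(scalar, ngraph::Shape{3, 4});

    return 0;
}
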
@@ -24,8 +24,8 @@
 #include <vector>
 #include "ngraph/op/constant.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 #include "ngraph/shape.hpp"
-#include "utils/broadcasting.hpp"
 namespace ngraph
 {
@@ -125,7 +125,7 @@ namespace ngraph
 if (data.size() == 1)
 {
 node = std::make_shared<ngraph::op::Constant>(type, ngraph::Shape{}, data);
-node = make_broadcast_node(node, shape);
+node = ngraph::op::make_broadcast_node(node, shape);
 }
 else
 {
...
@@ -23,8 +23,8 @@
 #include "ngraph/node_vector.hpp"
 #include "ngraph/node_vector.hpp"
 #include "ngraph/op/add.hpp"
+#include "ngraph/op/util/broadcasting.hpp"
 #include "ngraph/shape.hpp"
-#include "utils/broadcasting.hpp"
 namespace ngraph
 {
@@ -77,7 +77,7 @@ namespace ngraph
 // Templated binary operation - Creates Add, Minimum, Maximum, etc.
 auto binary_operation = [](const std::shared_ptr<ngraph::Node>& arg0,
 const std::shared_ptr<ngraph::Node>& arg1) {
-NodeVector args{numpy_style_broadcast({arg0, arg1})};
+NodeVector args{ngraph::op::numpy_style_broadcast({arg0, arg1})};
 return std::make_shared<T>(args.at(0), args.at(1));
 };
...