Commit e49756c8 authored by Michał Karzyński's avatar Michał Karzyński Committed by Scott Cyphers

[ONNX] Unify broadcast code in ngraph core and in ONNX bridge (#2771)

* [ONNX] Unify broadcast code in ngraph core and in ONNX bridge

* Use using a bit less
parent 1572d31f
......@@ -179,8 +179,6 @@ add_library(onnx_import STATIC
op/xor.hpp
ops_bridge.cpp
ops_bridge.hpp
utils/broadcasting.cpp
utils/broadcasting.hpp
utils/common.hpp
utils/convpool.cpp
utils/convpool.hpp
......
......@@ -19,7 +19,7 @@
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/add.hpp"
#include "utils/broadcasting.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph
{
......@@ -35,7 +35,7 @@ namespace ngraph
auto right_rank = node.get_ng_inputs().at(1)->get_shape().size();
auto axis =
node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank);
NodeVector ng_inputs{legacy_style_broadcast_for_binary_operation(
NodeVector ng_inputs{ngraph::op::legacy_style_broadcast_for_binary_operation(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
......@@ -47,7 +47,7 @@ namespace ngraph
{
inline NodeVector add(const Node& node)
{
NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
......@@ -19,7 +19,7 @@
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/and.hpp"
#include "utils/broadcasting.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph
{
......@@ -31,7 +31,7 @@ namespace ngraph
{
inline NodeVector logical_and(const Node& node)
{
NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::And>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
......@@ -23,7 +23,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "utils/broadcasting.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph
{
......@@ -46,13 +46,15 @@ namespace ngraph
std::make_shared<ngraph::op::Constant>(data->get_element_type(),
ngraph::Shape{},
std::vector<double>{max_value});
max_value_node = make_broadcast_node(max_value_node, data->get_shape());
max_value_node =
ngraph::op::make_broadcast_node(max_value_node, data->get_shape());
std::shared_ptr<ngraph::Node> min_value_node =
std::make_shared<ngraph::op::Constant>(data->get_element_type(),
ngraph::Shape{},
std::vector<double>{min_value});
min_value_node = make_broadcast_node(min_value_node, data->get_shape());
min_value_node =
ngraph::op::make_broadcast_node(min_value_node, data->get_shape());
return {std::make_shared<ngraph::op::Minimum>(
max_value_node,
......
......@@ -20,13 +20,13 @@
#include "ngraph/frontend/onnx_import/exceptions.hpp"
#include "ngraph/frontend/onnx_import/op/conv.hpp"
#include "ngraph/frontend/onnx_import/utils/broadcasting.hpp"
#include "ngraph/frontend/onnx_import/utils/convpool.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "op/conv.hpp"
namespace ngraph
......@@ -142,7 +142,9 @@ namespace ngraph
const Shape& new_shape = conv_node->get_shape();
auto broadcasted_bias = std::make_shared<ngraph::op::Broadcast>(
bias, new_shape, calculate_broadcast_axes(new_shape, bias->get_shape(), 1));
bias,
new_shape,
ngraph::op::calculate_broadcast_axes(new_shape, bias->get_shape(), 1));
return {std::make_shared<ngraph::op::Add>(conv_node, broadcasted_bias)};
}
......
......@@ -23,7 +23,6 @@
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/frontend/onnx_import/exceptions.hpp"
#include "ngraph/frontend/onnx_import/op/conv_transpose.hpp"
#include "ngraph/frontend/onnx_import/utils/broadcasting.hpp"
#include "ngraph/frontend/onnx_import/utils/convpool.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
......@@ -33,6 +32,7 @@
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/strides.hpp"
......@@ -234,7 +234,8 @@ namespace ngraph
auto bias = inputs.at(2);
return {std::make_shared<ngraph::op::Add>(
conv_node, make_broadcast_node(bias, conv_node->get_shape(), 1))};
conv_node,
ngraph::op::make_broadcast_node(bias, conv_node->get_shape(), 1))};
}
} // namespace set_1
......
......@@ -19,7 +19,7 @@
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/divide.hpp"
#include "utils/broadcasting.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph
{
......@@ -35,7 +35,7 @@ namespace ngraph
auto right_rank = node.get_ng_inputs().at(1)->get_shape().size();
auto axis =
node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank);
NodeVector ng_inputs{legacy_style_broadcast_for_binary_operation(
NodeVector ng_inputs{ngraph::op::legacy_style_broadcast_for_binary_operation(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
return {std::make_shared<ngraph::op::Divide>(ng_inputs.at(0), ng_inputs.at(1))};
......@@ -47,7 +47,7 @@ namespace ngraph
{
inline NodeVector div(const Node& node)
{
NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Divide>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
......@@ -26,7 +26,7 @@
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "utils/broadcasting.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "elu.hpp"
......@@ -46,12 +46,12 @@ namespace ngraph
std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
alpha_node = ngraph::op::make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
zero_node = ngraph::op::make_broadcast_node(zero_node, data->get_shape());
return {std::make_shared<ngraph::op::Maximum>(data, zero_node) +
alpha_node *
......
......@@ -19,7 +19,7 @@
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/equal.hpp"
#include "utils/broadcasting.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph
{
......@@ -31,7 +31,7 @@ namespace ngraph
{
inline NodeVector equal(const Node& node)
{
NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Equal>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
......@@ -16,12 +16,12 @@
#include "op/gemm.hpp"
#include "ngraph/frontend/onnx_import/exceptions.hpp"
#include "ngraph/frontend/onnx_import/utils/broadcasting.hpp"
#include "ngraph/frontend/onnx_import/utils/reshape.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph
{
......@@ -76,7 +76,8 @@ namespace ngraph
input_c = std::make_shared<ngraph::op::Multiply>(beta_node, input_c);
// alpha * A' * B' + beta * C
NodeVector broadcasted_nodes = numpy_style_broadcast({a_dot_b, input_c});
NodeVector broadcasted_nodes =
ngraph::op::numpy_style_broadcast({a_dot_b, input_c});
// The ONNX documentation says that `input_c` should be "unidirectional broadcastable"
// to the `a_dot_b` tensor. Since numpy style broadcasting is bidirectional, below we
// only use the second output from above broadcasting. In other words we want to
......
......@@ -19,7 +19,7 @@
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/greater.hpp"
#include "utils/broadcasting.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph
{
......@@ -31,7 +31,7 @@ namespace ngraph
{
inline NodeVector greater(const Node& node)
{
NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
return {
std::make_shared<ngraph::op::Greater>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
......@@ -24,7 +24,9 @@
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "utils/broadcasting.hpp"
#include "ngraph/op/util/broadcasting.hpp"
using namespace ngraph::op;
namespace ngraph
{
......
......@@ -27,7 +27,7 @@
#include "core/node.hpp"
#include "leaky_relu.hpp"
#include "utils/broadcasting.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph
{
......@@ -48,7 +48,7 @@ namespace ngraph
std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
alpha_node = ngraph::op::make_broadcast_node(alpha_node, data->get_shape());
return {std::make_shared<ngraph::op::Maximum>(data * alpha_node, data)};
}
......
......@@ -19,7 +19,7 @@
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/less.hpp"
#include "utils/broadcasting.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph
{
......@@ -31,7 +31,7 @@ namespace ngraph
{
inline NodeVector less(const Node& node)
{
NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Less>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
......@@ -45,10 +45,10 @@
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
#include "utils/broadcasting.hpp"
#include "utils/common.hpp"
#include "utils/reshape.hpp"
#include "utils/rnn/activation_functions.hpp"
......@@ -64,21 +64,21 @@ namespace ngraph
std::shared_ptr<ngraph::Node> add(const std::shared_ptr<ngraph::Node>& lhs,
const std::shared_ptr<ngraph::Node>& rhs)
{
auto args = numpy_style_broadcast({lhs, rhs});
auto args = ngraph::op::numpy_style_broadcast({lhs, rhs});
return {std::make_shared<ngraph::op::Add>(args.at(0), args.at(1))};
}
std::shared_ptr<ngraph::Node> sub(const std::shared_ptr<ngraph::Node>& lhs,
const std::shared_ptr<ngraph::Node>& rhs)
{
auto args = numpy_style_broadcast({lhs, rhs});
auto args = ngraph::op::numpy_style_broadcast({lhs, rhs});
return {std::make_shared<ngraph::op::Subtract>(args.at(0), args.at(1))};
}
std::shared_ptr<ngraph::Node> mul(const std::shared_ptr<ngraph::Node>& lhs,
const std::shared_ptr<ngraph::Node>& rhs)
{
auto args = numpy_style_broadcast({lhs, rhs});
auto args = ngraph::op::numpy_style_broadcast({lhs, rhs});
return {std::make_shared<ngraph::op::Multiply>(args.at(0), args.at(1))};
}
......@@ -504,7 +504,7 @@ namespace ngraph
time_step));
std::shared_ptr<ngraph::Node> batch_seq_length =
legacy_style_broadcast_for_binary_operation(
ngraph::op::legacy_style_broadcast_for_binary_operation(
curr_time_step_node, m_seq_lengths, batch_axis)
.at(1);
......
......@@ -28,8 +28,8 @@
#include "ngraph/op/experimental/quantized_dot.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
#include "utils/broadcasting.hpp"
#include "utils/reshape.hpp"
/// \brief Slice the sub matrix from the input tensor.
......@@ -127,7 +127,7 @@ namespace ngraph
if (left_rank > 1 && right_rank > 1)
{
const NodeVector& broadcasted_nodes =
numpy_style_broadcast_for_matmul_operation(left, right);
ngraph::op::numpy_style_broadcast_for_matmul_operation(left, right);
left = broadcasted_nodes.at(0);
right = broadcasted_nodes.at(1);
......
......@@ -20,7 +20,7 @@
#include "ngraph/node_vector.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/multiply.hpp"
#include "utils/broadcasting.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph
{
......@@ -36,7 +36,7 @@ namespace ngraph
auto right_rank = node.get_ng_inputs().at(1)->get_shape().size();
auto axis =
node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank);
NodeVector ng_inputs{legacy_style_broadcast_for_binary_operation(
NodeVector ng_inputs{ngraph::op::legacy_style_broadcast_for_binary_operation(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
return {
......@@ -49,7 +49,7 @@ namespace ngraph
{
inline NodeVector mul(const Node& node)
{
NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
return {
std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
......@@ -19,7 +19,6 @@
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/not.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
......
......@@ -26,8 +26,8 @@
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "onehot.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
......@@ -79,7 +79,8 @@ namespace ngraph
std::shared_ptr<ngraph::Node> one_hot = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::OneHot>(indices, output_shape, axis),
values->get_element_type());
auto broadcasted_values = numpy_style_broadcast({one_hot, on_value, off_value});
auto broadcasted_values =
ngraph::op::numpy_style_broadcast({one_hot, on_value, off_value});
on_value = broadcasted_values[1];
off_value = broadcasted_values[2];
one_hot = one_hot * (on_value - off_value) + off_value;
......
......@@ -19,7 +19,7 @@
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/or.hpp"
#include "utils/broadcasting.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph
{
......@@ -31,7 +31,7 @@ namespace ngraph
{
inline NodeVector logical_or(const Node& node)
{
NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Or>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
......@@ -19,7 +19,7 @@
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/power.hpp"
#include "utils/broadcasting.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph
{
......@@ -31,7 +31,7 @@ namespace ngraph
{
inline NodeVector pow(const Node& node)
{
NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Power>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
......@@ -29,8 +29,8 @@
#include "ngraph/op/less.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "prelu.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
......@@ -53,11 +53,11 @@ namespace ngraph
auto it = std::find(
std::begin(data_shape), std::end(data_shape), slope_shape.at(0));
auto index = std::distance(std::begin(data_shape), it);
slope = make_broadcast_node(slope, data->get_shape(), index);
slope = ngraph::op::make_broadcast_node(slope, data->get_shape(), index);
}
else if (data_shape != slope_shape)
{
slope = numpy_style_broadcast({slope, data})[0];
slope = ngraph::op::numpy_style_broadcast({slope, data})[0];
}
// x < 0 => f(x) = x * slope
......@@ -66,7 +66,7 @@ namespace ngraph
std::shared_ptr<ngraph::Node> zero_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
zero_node = ngraph::op::make_broadcast_node(zero_node, data->get_shape());
std::shared_ptr<ngraph::Node> negative_map =
std::make_shared<ngraph::op::Convert>(
......
......@@ -22,7 +22,6 @@
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/frontend/onnx_import/exceptions.hpp"
#include "ngraph/frontend/onnx_import/op/conv.hpp"
#include "ngraph/frontend/onnx_import/utils/broadcasting.hpp"
#include "ngraph/frontend/onnx_import/utils/convpool.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
......@@ -31,6 +30,7 @@
#include "ngraph/op/experimental/quantized_conv.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/strides.hpp"
#include "quant_conv.hpp"
......
......@@ -19,8 +19,8 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
#include "utils/broadcasting.hpp"
#include "reciprocal.hpp"
......@@ -38,7 +38,7 @@ namespace ngraph
std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape());
one_node = ngraph::op::make_broadcast_node(one_node, data->get_shape());
return {one_node / data};
}
......
......@@ -28,7 +28,7 @@
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/sum.hpp"
#include "utils/broadcasting.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "utils/norm.hpp"
#include "utils/reduction.hpp"
......
......@@ -28,9 +28,11 @@
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
#include "selu.hpp"
#include "utils/broadcasting.hpp"
using namespace ngraph::op;
namespace ngraph
{
......
......@@ -21,8 +21,8 @@
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
#include "utils/broadcasting.hpp"
#include "softsign.hpp"
......@@ -40,7 +40,7 @@ namespace ngraph
std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape());
one_node = ngraph::op::make_broadcast_node(one_node, data->get_shape());
return {data / (std::make_shared<ngraph::op::Abs>(data) + one_node)};
}
......
......@@ -19,7 +19,7 @@
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/subtract.hpp"
#include "utils/broadcasting.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph
{
......@@ -35,7 +35,7 @@ namespace ngraph
auto right_rank = node.get_ng_inputs().at(1)->get_shape().size();
auto axis =
node.get_attribute_value<std::int64_t>("axis", left_rank - right_rank);
NodeVector ng_inputs{legacy_style_broadcast_for_binary_operation(
NodeVector ng_inputs{ngraph::op::legacy_style_broadcast_for_binary_operation(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1), axis)};
return {
......@@ -48,7 +48,7 @@ namespace ngraph
{
inline NodeVector sub(const Node& node)
{
NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
return {
std::make_shared<ngraph::op::Subtract>(ng_inputs.at(0), ng_inputs.at(1))};
}
......
......@@ -24,8 +24,8 @@
#include "ngraph/op/convert.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "thresholded_relu.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
......@@ -43,7 +43,7 @@ namespace ngraph
std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
alpha_node = ngraph::op::make_broadcast_node(alpha_node, data->get_shape());
auto data_map = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::Greater>(data, alpha_node),
......
......@@ -21,7 +21,7 @@
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/select.hpp"
#include "utils/broadcasting.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph
{
......@@ -33,7 +33,7 @@ namespace ngraph
{
inline NodeVector where(const Node& node)
{
NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Select>(
ng_inputs.at(0), ng_inputs.at(1), ng_inputs.at(2))};
......
......@@ -21,7 +21,7 @@
#include "ngraph/op/and.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/or.hpp"
#include "utils/broadcasting.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph
{
......@@ -33,7 +33,7 @@ namespace ngraph
{
inline NodeVector logical_xor(const Node& node)
{
NodeVector ng_inputs{numpy_style_broadcast(node.get_ng_inputs())};
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
auto left = ng_inputs.at(0);
auto not_left = std::make_shared<ngraph::op::Not>(left);
auto right = ng_inputs.at(1);
......
This diff is collapsed.
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/axis_set.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace onnx_import
{
/// \brief Cast shape of all input nodes for an element-wise operation that requires shape-compatibility
///
/// \param inputs Original list of inputs
///
/// \return Numpy-style broadcasted list of nodes.
NodeVector numpy_style_broadcast(const NodeVector& inputs);
/// \brief Cast shape of two nodes to make them compatible for an element-wise binary operation.
///
/// If necessary the right-hand-side argument will be broadcast to match the shape
/// of left-hand-side argument. The starting of the mutually equal shape is
/// specified by the argument "start_match_axis", and if it is not set,
/// suffix matching is assumed.
///
/// This style of broadcast was used in ONNX Op sets prior to version 7, where it was
/// replaced by numpy-style broadcasting.
///
/// \param left Node which contain input of binary op.
/// \param right Node which contain input of binary op.
/// \param start_match_axis position in shape denoting start of the mutually equal shape
///
/// \return Left and right node after broadcasting.
NodeVector
legacy_style_broadcast_for_binary_operation(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right,
std::size_t start_match_axis);
/// \brief Broadcast shape of two nodes to make them compatible for a matrix multiplication.
///
/// \note This function is reflecting broadcasting behaviour of NumPys' `matmul` operation
/// \link https://docs.scipy.org/doc/numpy/reference/generated/numpy.matmul.html
/// This mean that only \"stack of matrices\" axes are bidirectionally broadcasted.
/// The last two dimension are left untouched.
///
/// \param[in] left The Node providing data for the left-hand side of matrix multiplication.
/// \param[in] right The Node providing data for the right-hand side of matrix multiplication.
///
/// \return The vector containing both nodes broadcasted.
///
NodeVector
numpy_style_broadcast_for_matmul_operation(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right);
/// \brief Generate a list of broadcast axes.
///
/// \details Informally, a broadcast "adds" axes to the input tensor, replicating
/// elements from the input tensor as needed to fill the new dimensions.
/// Function calculate which of the output axes are added in this way.
///
/// \param output_shape The new shape for the output tensor.
/// \param input_shape The shape of input tensor.
/// \param start_match_axis The axis along which we want to replicate elements.
/// The starting axis position (0-based) int the output
/// shape from which the current shape of the tensor
/// matches the desired new shape.
///
/// \return The indices of added axes.
AxisSet calculate_broadcast_axes(const Shape& output_shape,
const Shape& input_shape,
std::size_t start_match_axis);
/// \brief Generate a list of broadcast along axes.
///
/// \details Broadcast "adds" elements along axes to the input tensor, replicating
/// elements from the input tensor as needed to fill the new dimensions.
/// Function calculate which of the output axes are added in this way.
///
/// This function will attempt to match shapes, assuming the current shape
/// matches the rightmost positions of the desired new shape. This behaviour
/// is similar to NumPy's broadcasting.
///
/// \param output_shape The new shape for the output tensor.
/// \param input_shape The shape of input tensor.
///
/// \return The indices of added axes.
inline AxisSet calculate_broadcast_axes(const Shape& output_shape, const Shape& input_shape)
{
return calculate_broadcast_axes(
output_shape, input_shape, output_shape.size() - input_shape.size());
}
inline std::shared_ptr<ngraph::Node>
make_broadcast_node(const std::shared_ptr<ngraph::Node>& node, ngraph::Shape new_shape)
{
return std::make_shared<ngraph::op::Broadcast>(
node, new_shape, calculate_broadcast_axes(new_shape, node->get_shape()));
}
inline std::shared_ptr<ngraph::Node>
make_broadcast_node(const std::shared_ptr<ngraph::Node>& node,
ngraph::Shape new_shape,
std::size_t start_match_axis)
{
return std::make_shared<ngraph::op::Broadcast>(
node,
new_shape,
calculate_broadcast_axes(new_shape, node->get_shape(), start_match_axis));
}
} // namespace onnx_import
} // namespace ngraph
......@@ -24,8 +24,8 @@
#include <vector>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
......@@ -125,7 +125,7 @@ namespace ngraph
if (data.size() == 1)
{
node = std::make_shared<ngraph::op::Constant>(type, ngraph::Shape{}, data);
node = make_broadcast_node(node, shape);
node = ngraph::op::make_broadcast_node(node, shape);
}
else
{
......
......@@ -23,8 +23,8 @@
#include "ngraph/node_vector.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
......@@ -77,7 +77,7 @@ namespace ngraph
// Templated binary operation - Creates Add, Minimum, Maximum, etc.
auto binary_operation = [](const std::shared_ptr<ngraph::Node>& arg0,
const std::shared_ptr<ngraph::Node>& arg1) {
NodeVector args{numpy_style_broadcast({arg0, arg1})};
NodeVector args{ngraph::op::numpy_style_broadcast({arg0, arg1})};
return std::make_shared<T>(args.at(0), args.at(1));
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment