Unverified Commit fe054d67 authored by Adam Osewski's avatar Adam Osewski Committed by GitHub

Unify (static) auto-broadcasting helpers. (#4242)

* Helper function get_axes_mapping.

* Enhance Broadcast:v1 NUMPY broadcasting.

- Enable NUMPY broadcasting mechanism to work in bothdirections:
    target_shape <-> arg_shape

* Add opset1:squeeze and fix bug in reading squeezed axis idx.

* Fix and enhance downgrade pass for Broadcast:v1

* Use Broadcast:v1 in ONNX Expand operator.

* Replace Broadcast:v0 with v1 in some helper functions.

* Remove call to deprecated legacy_broadcasting helper function.

* Add helper get_axes_mapping_output function.

* Use directly Broadcast:v1 instead of helper function.

* Get back operators from v0 in helper function.

* Use helper function and some refactoring.

* Add legacy style broadcast helper function for opset1.

* User helper broadcasting function for arithmetic operators.

* Add empty axis only if its size is equal to one.

* Aplly review remarks:

- Rename broadcasting function deleting _values_ infix
- Remove variables used only once.
- Use STL library where possible.
- Remove unnecessary conditions.

* Add helper for Broadcast:v1.

* Fix merge artifact and force unsigned type for argument.

* Review. Add additional check for static output.

* Apply clang-format.

* Fix: call v0 ops in ngraph::builder namespace.

* Move opset1 boradcasting helpers from util/broadcasting.hpp

* Use autobroadcast built-in mechanism for arithmetic operators in RNN.

* Move helper functions to autobroadcast.hpp file.

- Update calls with new namespace ngraph::builder
- Remove calls using shared_ptr<ngraph::Node> and replace them with
  one accepting Output<ngraph::Node>
- Some small formatting (remove unnecesary namespace prefix)

* Remove unused function.

* Rename error class to reflect it's NumPy related.

* Fix thrown error name in autobroadcast UT.

* Code refactoring.

- Use one one set of helpers to broadcast node according to NumPy scheme

* Documentation formatting.

* Remove include to deleted header.

* Apply style format.

* Remove std:: prefix.

* Do reshape and/or broadcast only when necessary.

* Remove std:: and ngraph:: prefixes.

* UT for numpy_broadcast_for_matmul and legacy boradcast.

* Rename helper function.

* UT for opset1 legacy broadcast helper function.

* Add more UT for get_axes_mapping and style-format.

* Review comments.

* Restrict if with NGRAPH_WARN to NGRAPH_CHECK.
Co-authored-by: 's avatarMichał Karzyński <postrational@users.noreply.github.com>
Co-authored-by: 's avatarSang Ik Lee <sang.ik.lee@intel.com>
parent 639ff3f1
...@@ -445,8 +445,6 @@ set (SRC ...@@ -445,8 +445,6 @@ set (SRC
op/util/binary_elementwise_comparison.hpp op/util/binary_elementwise_comparison.hpp
op/util/binary_elementwise_logical.cpp op/util/binary_elementwise_logical.cpp
op/util/binary_elementwise_logical.hpp op/util/binary_elementwise_logical.hpp
op/util/broadcasting.cpp
op/util/broadcasting.hpp
op/util/fused_op.cpp op/util/fused_op.cpp
op/util/fused_op.hpp op/util/fused_op.hpp
op/util/index_reduction.cpp op/util/index_reduction.cpp
......
This diff is collapsed.
This diff is collapsed.
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <iterator> #include <iterator>
#include <memory> #include <memory>
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/make_constant.hpp" #include "ngraph/builder/make_constant.hpp"
#include "ngraph/builder/matmul_factory.hpp" #include "ngraph/builder/matmul_factory.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
...@@ -26,7 +27,6 @@ ...@@ -26,7 +27,6 @@
#include "ngraph/op/quantized_dot.hpp" #include "ngraph/op/quantized_dot.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/op/util/broadcasting.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
...@@ -92,7 +92,7 @@ NodeVector builder::MatmulFactory::make_matmul_op() ...@@ -92,7 +92,7 @@ NodeVector builder::MatmulFactory::make_matmul_op()
if (left_rank > 1 && right_rank > 1) if (left_rank > 1 && right_rank > 1)
{ {
const OutputVector& broadcasted_nodes = const OutputVector& broadcasted_nodes =
op::numpy_style_broadcast_for_matmul_operation(left, right); builder::numpy_broadcast_for_matmul_operation(left, right);
left = broadcasted_nodes.at(0); left = broadcasted_nodes.at(0);
right = broadcasted_nodes.at(1); right = broadcasted_nodes.at(1);
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include "add.hpp" #include "add.hpp"
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/op/util/broadcasting.hpp" #include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
namespace ngraph namespace ngraph
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
auto rhs_rank = rhs_node.get_shape().size(); auto rhs_rank = rhs_node.get_shape().size();
auto axis = node.get_attribute_value<std::int64_t>("axis", lhs_rank - rhs_rank); auto axis = node.get_attribute_value<std::int64_t>("axis", lhs_rank - rhs_rank);
// Unidirectional broadcast right node to left shape. // Unidirectional broadcast right node to left shape.
rhs_node = ngraph::op::opset1::legacy_style_broadcast_for_binary_operation( rhs_node = ngraph::builder::opset1::legacy_broadcast_for_binary_operation(
lhs_node, rhs_node, axis); lhs_node, rhs_node, axis);
return {std::make_shared<default_opset::Add>( return {std::make_shared<default_opset::Add>(
lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)}; lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)};
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include "ngraph/op/fused/group_conv.hpp" #include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/broadcasting.hpp" #include "ngraph/opsets/opset0.hpp"
#include "utils/convpool.hpp" #include "utils/convpool.hpp"
namespace ngraph namespace ngraph
......
...@@ -20,14 +20,13 @@ ...@@ -20,14 +20,13 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "builder/reshape.hpp"
#include "conv_transpose.hpp" #include "conv_transpose.hpp"
#include "default_opset.hpp" #include "default_opset.hpp"
#include "exceptions.hpp" #include "exceptions.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/coordinate_diff.hpp" #include "ngraph/coordinate_diff.hpp"
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "utils/convpool.hpp" #include "utils/convpool.hpp"
...@@ -93,7 +92,7 @@ namespace ngraph ...@@ -93,7 +92,7 @@ namespace ngraph
Shape new_filters_shape{weights_shape}; Shape new_filters_shape{weights_shape};
new_filters_shape.at(0) /= groups; new_filters_shape.at(0) /= groups;
new_filters_shape.insert(std::begin(new_filters_shape), groups); new_filters_shape.insert(std::begin(new_filters_shape), groups);
filters = builder::reshape(filters, new_filters_shape); filters = builder::opset1::reshape(filters, new_filters_shape);
std::shared_ptr<ngraph::Node> conv_node; std::shared_ptr<ngraph::Node> conv_node;
if (!output_shape.empty()) if (!output_shape.empty())
......
...@@ -20,8 +20,8 @@ ...@@ -20,8 +20,8 @@
#include "core/node.hpp" #include "core/node.hpp"
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
namespace ngraph namespace ngraph
...@@ -40,7 +40,7 @@ namespace ngraph ...@@ -40,7 +40,7 @@ namespace ngraph
auto rhs_rank = rhs_node.get_shape().size(); auto rhs_rank = rhs_node.get_shape().size();
auto axis = node.get_attribute_value<std::int64_t>("axis", lhs_rank - rhs_rank); auto axis = node.get_attribute_value<std::int64_t>("axis", lhs_rank - rhs_rank);
// Unidirectional broadcast right node to left shape. // Unidirectional broadcast right node to left shape.
rhs_node = ngraph::op::opset1::legacy_style_broadcast_for_binary_operation( rhs_node = ngraph::builder::opset1::legacy_broadcast_for_binary_operation(
lhs_node, rhs_node, axis); lhs_node, rhs_node, axis);
return {std::make_shared<default_opset::Divide>( return {std::make_shared<default_opset::Divide>(
lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)}; lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)};
......
...@@ -27,7 +27,6 @@ ...@@ -27,7 +27,6 @@
#include "ngraph/op/experimental/range.hpp" #include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/shape_of.hpp" #include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
namespace ngraph namespace ngraph
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/greater.hpp" #include "ngraph/op/greater.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -21,13 +21,13 @@ ...@@ -21,13 +21,13 @@
#include "exceptions.hpp" #include "exceptions.hpp"
#include "instance_norm.hpp" #include "instance_norm.hpp"
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reduce_ops.hpp" #include "ngraph/builder/reduce_ops.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/sqrt.hpp" #include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/opsets/opset0.hpp" #include "ngraph/opsets/opset0.hpp"
#include "utils/common.hpp" #include "utils/common.hpp"
...@@ -68,15 +68,16 @@ namespace ngraph ...@@ -68,15 +68,16 @@ namespace ngraph
std::make_shared<default_opset::Constant>( std::make_shared<default_opset::Constant>(
data->get_element_type(), data_shape, std::vector<float>{epsilon}); data->get_element_type(), data_shape, std::vector<float>{epsilon});
scale = ngraph::op::opset1::make_broadcast(scale, data_shape, 1); scale = ngraph::builder::opset1::make_broadcast(scale, data_shape, 1);
bias = ngraph::op::opset1::make_broadcast(bias, data_shape, 1); bias = ngraph::builder::opset1::make_broadcast(bias, data_shape, 1);
Output<ngraph::Node> mean = builder::mean(data, reduction_axes); Output<ngraph::Node> mean = builder::mean(data, reduction_axes);
mean = ngraph::op::opset1::make_broadcast(mean, data_shape, reduction_axes); mean =
ngraph::builder::opset1::make_broadcast(mean, data_shape, reduction_axes);
Output<ngraph::Node> variance = builder::variance(data, reduction_axes); Output<ngraph::Node> variance = builder::variance(data, reduction_axes);
variance = variance = ngraph::builder::opset1::make_broadcast(
ngraph::op::opset1::make_broadcast(variance, data_shape, reduction_axes); variance, data_shape, reduction_axes);
const auto sqrt = std::make_shared<default_opset::Sqrt>(variance + eps_node); const auto sqrt = std::make_shared<default_opset::Sqrt>(variance + eps_node);
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/less.hpp" #include "ngraph/op/less.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
#include "core/node.hpp" #include "core/node.hpp"
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "utils/variadic.hpp" #include "utils/variadic.hpp"
namespace ngraph namespace ngraph
......
...@@ -20,10 +20,10 @@ ...@@ -20,10 +20,10 @@
#include "core/node.hpp" #include "core/node.hpp"
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -41,7 +41,7 @@ namespace ngraph ...@@ -41,7 +41,7 @@ namespace ngraph
auto rhs_rank = rhs_node.get_shape().size(); auto rhs_rank = rhs_node.get_shape().size();
auto axis = node.get_attribute_value<std::int64_t>("axis", lhs_rank - rhs_rank); auto axis = node.get_attribute_value<std::int64_t>("axis", lhs_rank - rhs_rank);
// Unidirectional broadcast right node to left shape. // Unidirectional broadcast right node to left shape.
rhs_node = ngraph::op::opset1::legacy_style_broadcast_for_binary_operation( rhs_node = ngraph::builder::opset1::legacy_broadcast_for_binary_operation(
lhs_node, rhs_node, axis); lhs_node, rhs_node, axis);
return {std::make_shared<default_opset::Multiply>( return {std::make_shared<default_opset::Multiply>(
lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)}; lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)};
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
#include "core/node.hpp" #include "core/node.hpp"
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -18,10 +18,10 @@ ...@@ -18,10 +18,10 @@
#include <vector> #include <vector>
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/op/abs.hpp" #include "ngraph/op/abs.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "softsign.hpp" #include "softsign.hpp"
...@@ -40,7 +40,7 @@ namespace ngraph ...@@ -40,7 +40,7 @@ namespace ngraph
std::shared_ptr<ngraph::Node> one_node = std::shared_ptr<ngraph::Node> one_node =
std::make_shared<default_opset::Constant>( std::make_shared<default_opset::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{1}); data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = ngraph::op::make_broadcast_node(one_node, data->get_shape()); one_node = ngraph::builder::make_broadcast_node(one_node, data->get_shape());
return {data / (std::make_shared<default_opset::Abs>(data) + one_node)}; return {data / (std::make_shared<default_opset::Abs>(data) + one_node)};
} }
......
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
#include "core/node.hpp" #include "core/node.hpp"
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -37,7 +37,7 @@ namespace ngraph ...@@ -37,7 +37,7 @@ namespace ngraph
auto rhs_rank = rhs_node.get_shape().size(); auto rhs_rank = rhs_node.get_shape().size();
auto axis = node.get_attribute_value<std::int64_t>("axis", lhs_rank - rhs_rank); auto axis = node.get_attribute_value<std::int64_t>("axis", lhs_rank - rhs_rank);
// Unidirectional broadcast right node to left shape. // Unidirectional broadcast right node to left shape.
rhs_node = ngraph::op::opset1::legacy_style_broadcast_for_binary_operation( rhs_node = ngraph::builder::opset1::legacy_broadcast_for_binary_operation(
lhs_node, rhs_node, axis); lhs_node, rhs_node, axis);
return {std::make_shared<default_opset::Subtract>( return {std::make_shared<default_opset::Subtract>(
lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)}; lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)};
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
#include "core/node.hpp" #include "core/node.hpp"
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/opsets/opset0.hpp" #include "ngraph/opsets/opset0.hpp"
#include "utils/variadic.hpp" #include "utils/variadic.hpp"
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
#include "core/node.hpp" #include "core/node.hpp"
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -18,9 +18,9 @@ ...@@ -18,9 +18,9 @@
#include "common.hpp" #include "common.hpp"
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/opsets/opset0.hpp" #include "ngraph/opsets/opset0.hpp"
#include "validation_util.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -28,7 +28,6 @@ ...@@ -28,7 +28,6 @@
#include "core/node.hpp" #include "core/node.hpp"
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
#include "ngraph/coordinate_diff.hpp" #include "ngraph/coordinate_diff.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
namespace ngraph namespace ngraph
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
#include "ngraph/attribute_visitor.hpp" #include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/partial_shape.hpp" #include "ngraph/partial_shape.hpp"
#include <numeric> #include <numeric>
......
...@@ -27,7 +27,6 @@ ...@@ -27,7 +27,6 @@
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
#include "ngraph/op/util/broadcasting.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/fused/elu.hpp" #include "ngraph/op/fused/elu.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/make_constant.hpp" #include "ngraph/builder/make_constant.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
...@@ -23,7 +24,6 @@ ...@@ -23,7 +24,6 @@
#include "ngraph/op/minimum.hpp" #include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -43,7 +43,7 @@ NodeVector op::Elu::decompose_op() const ...@@ -43,7 +43,7 @@ NodeVector op::Elu::decompose_op() const
shared_ptr<Node> alpha_node = shared_ptr<Node> alpha_node =
make_shared<op::Constant>(data.get_element_type(), Shape{}, vector<double>{m_alpha}); make_shared<op::Constant>(data.get_element_type(), Shape{}, vector<double>{m_alpha});
alpha_node = ngraph::op::numpy_style_broadcast(alpha_node, data.get_shape()); alpha_node = builder::numpy_broadcast(alpha_node, data.get_shape());
shared_ptr<ngraph::Node> zero_node = shared_ptr<ngraph::Node> zero_node =
builder::make_constant(data.get_element_type(), data.get_shape(), 0); builder::make_constant(data.get_element_type(), data.get_shape(), 0);
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <memory> #include <memory>
#include "fake_quantize.hpp" #include "fake_quantize.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp" #include "ngraph/op/convert.hpp"
...@@ -30,7 +31,6 @@ ...@@ -30,7 +31,6 @@
#include "ngraph/op/quantize.hpp" #include "ngraph/op/quantize.hpp"
#include "ngraph/op/select.hpp" #include "ngraph/op/select.hpp"
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
using namespace std; using namespace std;
...@@ -90,7 +90,7 @@ NodeVector op::FakeQuantize::decompose_op() const ...@@ -90,7 +90,7 @@ NodeVector op::FakeQuantize::decompose_op() const
if (m_auto_broadcast.m_type == AutoBroadcastType::NUMPY) if (m_auto_broadcast.m_type == AutoBroadcastType::NUMPY)
{ {
OutputVector broadcasted_nodes = numpy_style_broadcast_values( OutputVector broadcasted_nodes = builder::numpy_broadcast_outputs(
OutputVector{data, input_low, input_high, output_low, output_high}); OutputVector{data, input_low, input_high, output_low, output_high});
data = broadcasted_nodes.at(0); data = broadcasted_nodes.at(0);
...@@ -101,9 +101,9 @@ NodeVector op::FakeQuantize::decompose_op() const ...@@ -101,9 +101,9 @@ NodeVector op::FakeQuantize::decompose_op() const
} }
else if (m_auto_broadcast.m_type == AutoBroadcastType::PDPD) else if (m_auto_broadcast.m_type == AutoBroadcastType::PDPD)
{ {
OutputVector broadcasted_nodes = OutputVector broadcasted_nodes = builder::pdpd_broadcast(
pdpd_style_broadcast(OutputVector{data, input_low, input_high, output_low, output_high}, OutputVector{data, input_low, input_high, output_low, output_high},
m_auto_broadcast.m_axis); m_auto_broadcast.m_axis);
data = broadcasted_nodes.at(0); data = broadcasted_nodes.at(0);
input_low = broadcasted_nodes.at(1); input_low = broadcasted_nodes.at(1);
......
...@@ -15,13 +15,13 @@ ...@@ -15,13 +15,13 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/fused/gemm.hpp" #include "ngraph/op/fused/gemm.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
#include "ngraph/op/fused/matmul.hpp" #include "ngraph/op/fused/matmul.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/util/broadcasting.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -52,36 +52,34 @@ NodeVector op::Gemm::decompose_op() const ...@@ -52,36 +52,34 @@ NodeVector op::Gemm::decompose_op() const
if (m_transA) if (m_transA)
{ {
A = ngraph::builder::transpose(A); A = builder::transpose(A);
} }
if (m_transB) if (m_transB)
{ {
B = ngraph::builder::transpose(B); B = builder::transpose(B);
} }
A = ngraph::builder::flatten(A, 1); A = builder::flatten(A, 1);
B = ngraph::builder::flatten(B, 1); B = builder::flatten(B, 1);
// A' * B' // A' * B'
std::shared_ptr<ngraph::Node> a_dot_b = std::make_shared<ngraph::op::Dot>(A, B); std::shared_ptr<Node> a_dot_b = std::make_shared<op::Dot>(A, B);
// alpha // alpha
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<Node> alpha_node = std::make_shared<op::Constant>(
a_dot_b->get_element_type(), a_dot_b->get_shape(), std::vector<double>{m_alpha}); a_dot_b->get_element_type(), a_dot_b->get_shape(), std::vector<double>{m_alpha});
// alpha * A' * B' // alpha * A' * B'
a_dot_b = std::make_shared<ngraph::op::Multiply>(alpha_node, a_dot_b); a_dot_b = std::make_shared<op::Multiply>(alpha_node, a_dot_b);
// beta * C // beta * C
std::shared_ptr<ngraph::Node> beta_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<Node> beta_node = std::make_shared<op::Constant>(
C.get_element_type(), C.get_shape(), std::vector<double>{m_beta}); C.get_element_type(), C.get_shape(), std::vector<double>{m_beta});
C = std::make_shared<ngraph::op::Multiply>(beta_node, C); C = std::make_shared<op::Multiply>(beta_node, C);
// alpha * A' * B' + beta * C // alpha * A' * B' + beta * C
OutputVector broadcasted_nodes =
ngraph::op::numpy_style_broadcast_values(OutputVector{a_dot_b, C});
// The input tensor `C` should be "unidirectionally broadcastable" to the `a_dot_b` tensor. // The input tensor `C` should be "unidirectionally broadcastable" to the `a_dot_b` tensor.
// Numpy style broadcast is bidirectional, so we only use the second output from broadcasting. auto broadcasted_c = builder::numpy_broadcast(C, a_dot_b->get_shape());
return {std::make_shared<ngraph::op::Add>(a_dot_b, broadcasted_nodes.at(1))}; return {std::make_shared<op::Add>(a_dot_b, broadcasted_c)};
} }
shared_ptr<Node> op::Gemm::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Gemm::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -29,7 +29,6 @@ ...@@ -29,7 +29,6 @@
#include "ngraph/op/sqrt.hpp" #include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
#include "ngraph/op/util/broadcasting.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/fused/lstm_sequence.hpp" #include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp" #include "ngraph/builder/split.hpp"
#include "ngraph/frontend/onnx_import/utils/reshape.hpp" #include "ngraph/frontend/onnx_import/utils/reshape.hpp"
...@@ -25,7 +27,6 @@ ...@@ -25,7 +27,6 @@
#include "ngraph/op/greater.hpp" #include "ngraph/op/greater.hpp"
#include "ngraph/op/reverse_sequence.hpp" #include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/select.hpp" #include "ngraph/op/select.hpp"
#include "ngraph/op/util/broadcasting.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
...@@ -121,7 +122,7 @@ shared_ptr<Node> op::LSTMSequence::get_masked_node(const Output<Node>& data, ...@@ -121,7 +122,7 @@ shared_ptr<Node> op::LSTMSequence::get_masked_node(const Output<Node>& data,
element::i32, data.get_shape(), vector<int32_t>(shape_size(data.get_shape()), time_step)); element::i32, data.get_shape(), vector<int32_t>(shape_size(data.get_shape()), time_step));
Output<Node> batch_seq_length = Output<Node> batch_seq_length =
op::legacy_style_broadcast_for_binary_operation( builder::legacy_broadcast_for_binary_operation(
curr_time_step_node, input_value(3).get_node_shared_ptr(), batch_axis) curr_time_step_node, input_value(3).get_node_shared_ptr(), batch_axis)
.at(1); .at(1);
......
...@@ -18,11 +18,11 @@ ...@@ -18,11 +18,11 @@
#include "mvn.hpp" #include "mvn.hpp"
#include "ngraph/builder/reduce_ops.hpp" #include "ngraph/builder/reduce_ops.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/sqrt.hpp" #include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/fused/normalize_l2.hpp" #include "ngraph/op/fused/normalize_l2.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/util/broadcasting.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/fused/prelu.hpp" #include "ngraph/op/fused/prelu.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
...@@ -22,7 +23,6 @@ ...@@ -22,7 +23,6 @@
#include "ngraph/op/greater.hpp" #include "ngraph/op/greater.hpp"
#include "ngraph/op/less.hpp" #include "ngraph/op/less.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/util/broadcasting.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -47,11 +47,11 @@ NodeVector op::PRelu::decompose_op() const ...@@ -47,11 +47,11 @@ NodeVector op::PRelu::decompose_op() const
{ {
auto it = std::find(std::begin(data_shape), std::end(data_shape), slope_shape.at(0)); auto it = std::find(std::begin(data_shape), std::end(data_shape), slope_shape.at(0));
auto index = std::distance(std::begin(data_shape), it); auto index = std::distance(std::begin(data_shape), it);
slope = make_broadcast_node(slope, data.get_shape(), index); slope = builder::make_broadcast_node(slope, data.get_shape(), index);
} }
else if (data_shape != slope_shape) else if (data_shape != slope_shape)
{ {
slope = numpy_style_broadcast_values({slope, data})[0]; slope = builder::numpy_broadcast(slope, data.get_shape());
} }
// x < 0 => f(x) = x * slope // x < 0 => f(x) = x * slope
...@@ -59,7 +59,7 @@ NodeVector op::PRelu::decompose_op() const ...@@ -59,7 +59,7 @@ NodeVector op::PRelu::decompose_op() const
std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>(
data.get_element_type(), ngraph::Shape{}, std::vector<double>{0}); data.get_element_type(), ngraph::Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data.get_shape()); zero_node = builder::make_broadcast_node(zero_node, data.get_shape());
std::shared_ptr<ngraph::Node> negative_map = std::make_shared<ngraph::op::Convert>( std::shared_ptr<ngraph::Node> negative_map = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::Less>(data, zero_node), data.get_element_type()); std::make_shared<ngraph::op::Less>(data, zero_node), data.get_element_type());
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include "scale_shift.hpp" #include "scale_shift.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/util/broadcasting.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -38,7 +39,7 @@ NodeVector op::ScaleShift::decompose_op() const ...@@ -38,7 +39,7 @@ NodeVector op::ScaleShift::decompose_op() const
auto shift = input_value(2); auto shift = input_value(2);
// broadcast all data // broadcast all data
auto broadcasted_nodes = numpy_style_broadcast_values({data, scale, shift}); auto broadcasted_nodes = builder::numpy_broadcast_outputs({data, scale, shift});
data = broadcasted_nodes[0]; data = broadcasted_nodes[0];
scale = broadcasted_nodes[1]; scale = broadcasted_nodes[1];
shift = broadcasted_nodes[2]; shift = broadcasted_nodes[2];
......
...@@ -29,7 +29,6 @@ ...@@ -29,7 +29,6 @@
#include "ngraph/op/softmax.hpp" #include "ngraph/op/softmax.hpp"
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
#include "ngraph/op/util/broadcasting.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/op/util/fused_op.hpp" #include "ngraph/op/util/fused_op.hpp"
using namespace std; using namespace std;
......
This diff is collapsed.
This diff is collapsed.
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
#include "ngraph/op/fused/clamp.hpp" #include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/op/util/rnn_cell_base.hpp" #include "ngraph/op/util/rnn_cell_base.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
...@@ -68,20 +67,19 @@ op::util::ActivationFunction op::util::RNNCellBase::get_activation_function(size ...@@ -68,20 +67,19 @@ op::util::ActivationFunction op::util::RNNCellBase::get_activation_function(size
shared_ptr<Node> op::util::RNNCellBase::add(const Output<Node>& lhs, const Output<Node>& rhs) shared_ptr<Node> op::util::RNNCellBase::add(const Output<Node>& lhs, const Output<Node>& rhs)
{ {
auto args = op::numpy_style_broadcast_values({lhs, rhs}); return {make_shared<op::Add>(lhs, rhs, op::AutoBroadcastSpec(op::AutoBroadcastType::NUMPY))};
return {make_shared<op::Add>(args.at(0), args.at(1))};
} }
shared_ptr<Node> op::util::RNNCellBase::sub(const Output<Node>& lhs, const Output<Node>& rhs) shared_ptr<Node> op::util::RNNCellBase::sub(const Output<Node>& lhs, const Output<Node>& rhs)
{ {
auto args = op::numpy_style_broadcast_values({lhs, rhs}); return {
return {make_shared<op::Subtract>(args.at(0), args.at(1))}; make_shared<op::Subtract>(lhs, rhs, op::AutoBroadcastSpec(op::AutoBroadcastType::NUMPY))};
} }
shared_ptr<Node> op::util::RNNCellBase::mul(const Output<Node>& lhs, const Output<Node>& rhs) shared_ptr<Node> op::util::RNNCellBase::mul(const Output<Node>& lhs, const Output<Node>& rhs)
{ {
auto args = op::numpy_style_broadcast_values({lhs, rhs}); return {
return {make_shared<op::Multiply>(args.at(0), args.at(1))}; make_shared<op::Multiply>(lhs, rhs, op::AutoBroadcastSpec(op::AutoBroadcastType::NUMPY))};
} }
shared_ptr<Node> op::util::RNNCellBase::clip(const Output<Node>& data) const shared_ptr<Node> op::util::RNNCellBase::clip(const Output<Node>& data) const
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "ngraph/pass/implicit_broadcast_elimination.hpp" #include "ngraph/pass/implicit_broadcast_elimination.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp" #include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "ngraph/op/util/binary_elementwise_comparison.hpp" #include "ngraph/op/util/binary_elementwise_comparison.hpp"
...@@ -24,7 +25,7 @@ ...@@ -24,7 +25,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
bool ngraph::pass::ImplicitBroadcastElimination::run_on_node(std::shared_ptr<ngraph::Node> node) bool ngraph::pass::ImplicitBroadcastElimination::run_on_node(std::shared_ptr<Node> node)
{ {
if (node->supports_auto_broadcast()) if (node->supports_auto_broadcast())
{ {
...@@ -53,11 +54,11 @@ NodeVector ngraph::pass::explicit_broadcast(std::shared_ptr<Node>& node) ...@@ -53,11 +54,11 @@ NodeVector ngraph::pass::explicit_broadcast(std::shared_ptr<Node>& node)
} }
else if (autob.m_type == op::AutoBroadcastType::NUMPY) else if (autob.m_type == op::AutoBroadcastType::NUMPY)
{ {
rc = op::numpy_style_broadcast(node->get_arguments()); rc = as_node_vector(builder::numpy_broadcast_outputs(node->input_values()));
} }
else if (autob.m_type == op::AutoBroadcastType::PDPD) else if (autob.m_type == op::AutoBroadcastType::PDPD)
{ {
rc = op::pdpd_style_broadcast(node->get_arguments(), autob.m_axis); rc = as_node_vector(builder::pdpd_broadcast(node->input_values(), autob.m_axis));
} }
else else
{ {
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/op/util/broadcasting.hpp" #include "ngraph/node.hpp"
#include "ngraph/pass/pass.hpp" #include "ngraph/pass/pass.hpp"
namespace ngraph namespace ngraph
......
...@@ -19,11 +19,11 @@ ...@@ -19,11 +19,11 @@
#include <functional> #include <functional>
#include <numeric> #include <numeric>
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/util/attr_types.hpp" #include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/ops.hpp" #include "ngraph/ops.hpp"
#include "ngraph/pass/implicit_broadcast_elimination.hpp" #include "ngraph/pass/implicit_broadcast_elimination.hpp"
#include "ngraph/pass/opset0_downgrade.hpp" #include "ngraph/pass/opset0_downgrade.hpp"
...@@ -157,7 +157,7 @@ namespace ...@@ -157,7 +157,7 @@ namespace
// (Re)construct axes_mapping. // (Re)construct axes_mapping.
AxisSet broadcast_axes = node->get_broadcast_axes().second; AxisSet broadcast_axes = node->get_broadcast_axes().second;
std::vector<size_t> axes_mapping{ std::vector<size_t> axes_mapping{
ngraph::op::opset1::get_axes_mapping(target_shape, broadcast_axes)}; ngraph::builder::opset1::get_axes_mapping(target_shape, broadcast_axes)};
Output<Node> squeezed_arg = arg; Output<Node> squeezed_arg = arg;
// Collect axes to squeeze. Broadcast v0 "adds" new axes, thus we have to squeeze // Collect axes to squeeze. Broadcast v0 "adds" new axes, thus we have to squeeze
...@@ -536,10 +536,10 @@ namespace ...@@ -536,10 +536,10 @@ namespace
shared_ptr<Node> op_cast(shared_ptr<op::v1::OneHot> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::OneHot> node)
{ {
const auto indices = node->input_value(0).get_node_shared_ptr(); const auto indices = node->input_value(0);
const auto depth = node->input_value(1).get_node_shared_ptr(); const auto depth = node->input_value(1).get_node_shared_ptr();
auto on_value = node->input_value(2).get_node_shared_ptr(); auto on_value = node->input_value(2);
auto off_value = node->input_value(3).get_node_shared_ptr(); auto off_value = node->input_value(3);
const auto axis = node->get_axis(); const auto axis = node->get_axis();
NGRAPH_CHECK(depth->is_constant(), "depth input must be constant", *node); NGRAPH_CHECK(depth->is_constant(), "depth input must be constant", *node);
...@@ -549,9 +549,9 @@ namespace ...@@ -549,9 +549,9 @@ namespace
auto one_hot = std::make_shared<ngraph::op::Convert>( auto one_hot = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::OneHot>(indices, output_shape, axis), std::make_shared<ngraph::op::OneHot>(indices, output_shape, axis),
on_value->get_element_type()); on_value.get_element_type());
auto broadcasted_values = op::numpy_style_broadcast({one_hot, on_value, off_value}); auto broadcasted_values = builder::numpy_broadcast_outputs({one_hot, on_value, off_value});
on_value = broadcasted_values[1]; on_value = broadcasted_values[1];
off_value = broadcasted_values[2]; off_value = broadcasted_values[2];
......
...@@ -13,17 +13,17 @@ ...@@ -13,17 +13,17 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include "ngraph/pass/opset1_upgrade.hpp"
#include <functional> #include <functional>
#include <iterator> #include <iterator>
#include <limits> #include <limits>
#include <numeric> #include <numeric>
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/ops.hpp" #include "ngraph/ops.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "ngraph/provenance.hpp" #include "ngraph/provenance.hpp"
using namespace std; using namespace std;
...@@ -108,7 +108,7 @@ namespace ...@@ -108,7 +108,7 @@ namespace
shared_ptr<Node> op_cast(shared_ptr<op::Broadcast> node) shared_ptr<Node> op_cast(shared_ptr<op::Broadcast> node)
{ {
auto replacement_node = ngraph::op::opset1::make_broadcast( auto replacement_node = ngraph::builder::opset1::make_broadcast(
node->input_value(0), node->get_broadcast_shape(), node->get_broadcast_axes()); node->input_value(0), node->get_broadcast_shape(), node->get_broadcast_axes());
replace_node(node, replacement_node.get_node_shared_ptr()); replace_node(node, replacement_node.get_node_shared_ptr());
return replacement_node.get_node_shared_ptr(); return replacement_node.get_node_shared_ptr();
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment