Commit 7d9ce6d7 authored by Ewa Tusień's avatar Ewa Tusień Committed by Scott Cyphers

[ONNX] Remove broadcasting from ONNX importer. (#3660)

* [ONNX] Remove broadcasting from ONNX importer.

* [ONNX] Used Xor op instead of And,OR and NOT ops.
parent 1ddda541
...@@ -47,8 +47,10 @@ namespace ngraph ...@@ -47,8 +47,10 @@ namespace ngraph
{ {
inline NodeVector add(const Node& node) inline NodeVector add(const Node& node)
{ {
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())}; return {std::make_shared<ngraph::op::Add>(
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))}; node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
} }
} // namespace set_7 } // namespace set_7
......
...@@ -47,8 +47,10 @@ namespace ngraph ...@@ -47,8 +47,10 @@ namespace ngraph
{ {
inline NodeVector div(const Node& node) inline NodeVector div(const Node& node)
{ {
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())}; return {std::make_shared<ngraph::op::Divide>(
return {std::make_shared<ngraph::op::Divide>(ng_inputs.at(0), ng_inputs.at(1))}; node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -31,8 +31,10 @@ namespace ngraph ...@@ -31,8 +31,10 @@ namespace ngraph
{ {
inline NodeVector equal(const Node& node) inline NodeVector equal(const Node& node)
{ {
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())}; return {std::make_shared<ngraph::op::Equal>(
return {std::make_shared<ngraph::op::Equal>(ng_inputs.at(0), ng_inputs.at(1))}; node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -31,9 +31,10 @@ namespace ngraph ...@@ -31,9 +31,10 @@ namespace ngraph
{ {
inline NodeVector greater(const Node& node) inline NodeVector greater(const Node& node)
{ {
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())}; return {std::make_shared<ngraph::op::Greater>(
return { node.get_ng_inputs().at(0),
std::make_shared<ngraph::op::Greater>(ng_inputs.at(0), ng_inputs.at(1))}; node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -31,8 +31,10 @@ namespace ngraph ...@@ -31,8 +31,10 @@ namespace ngraph
{ {
inline NodeVector less(const Node& node) inline NodeVector less(const Node& node)
{ {
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())}; return {std::make_shared<ngraph::op::Less>(
return {std::make_shared<ngraph::op::Less>(ng_inputs.at(0), ng_inputs.at(1))}; node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -49,9 +49,10 @@ namespace ngraph ...@@ -49,9 +49,10 @@ namespace ngraph
{ {
inline NodeVector mul(const Node& node) inline NodeVector mul(const Node& node)
{ {
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())}; return {std::make_shared<ngraph::op::Multiply>(
return { node.get_ng_inputs().at(0),
std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(1))}; node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
} }
} // namespace set_7 } // namespace set_7
......
...@@ -31,8 +31,10 @@ namespace ngraph ...@@ -31,8 +31,10 @@ namespace ngraph
{ {
inline NodeVector logical_or(const Node& node) inline NodeVector logical_or(const Node& node)
{ {
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())}; return {std::make_shared<ngraph::op::Or>(
return {std::make_shared<ngraph::op::Or>(ng_inputs.at(0), ng_inputs.at(1))}; node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -31,8 +31,10 @@ namespace ngraph ...@@ -31,8 +31,10 @@ namespace ngraph
{ {
inline NodeVector pow(const Node& node) inline NodeVector pow(const Node& node)
{ {
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())}; return {std::make_shared<ngraph::op::Power>(
return {std::make_shared<ngraph::op::Power>(ng_inputs.at(0), ng_inputs.at(1))}; node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -48,9 +48,10 @@ namespace ngraph ...@@ -48,9 +48,10 @@ namespace ngraph
{ {
inline NodeVector sub(const Node& node) inline NodeVector sub(const Node& node)
{ {
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())}; return {std::make_shared<ngraph::op::Subtract>(
return { node.get_ng_inputs().at(0),
std::make_shared<ngraph::op::Subtract>(ng_inputs.at(0), ng_inputs.at(1))}; node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -18,10 +18,8 @@ ...@@ -18,10 +18,8 @@
#include "core/node.hpp" #include "core/node.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/util/broadcasting.hpp" #include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/op/xor.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -33,14 +31,10 @@ namespace ngraph ...@@ -33,14 +31,10 @@ namespace ngraph
{ {
inline NodeVector logical_xor(const Node& node) inline NodeVector logical_xor(const Node& node)
{ {
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())}; return {std::make_shared<ngraph::op::Xor>(
auto left = ng_inputs.at(0); node.get_ng_inputs().at(0),
auto not_left = std::make_shared<ngraph::op::Not>(left); node.get_ng_inputs().at(1),
auto right = ng_inputs.at(1); ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
auto not_right = std::make_shared<ngraph::op::Not>(right);
return {std::make_shared<ngraph::op::Or>(
std::make_shared<ngraph::op::And>(left, not_right),
std::make_shared<ngraph::op::And>(not_left, right))};
} }
} // namespace set_1 } // namespace set_1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment