Commit 7d9ce6d7 authored by Ewa Tusień's avatar Ewa Tusień Committed by Scott Cyphers

[ONNX] Remove broadcasting from ONNX importer. (#3660)

* [ONNX] Remove broadcasting from ONNX importer.

* [ONNX] Used Xor op instead of And,OR and NOT ops.
parent 1ddda541
......@@ -47,8 +47,10 @@ namespace ngraph
{
inline NodeVector add(const Node& node)
{
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
return {std::make_shared<ngraph::op::Add>(
node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
}
} // namespace set_7
......
......@@ -47,8 +47,10 @@ namespace ngraph
{
inline NodeVector div(const Node& node)
{
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Divide>(ng_inputs.at(0), ng_inputs.at(1))};
return {std::make_shared<ngraph::op::Divide>(
node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
}
} // namespace set_1
......
......@@ -31,8 +31,10 @@ namespace ngraph
{
inline NodeVector equal(const Node& node)
{
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Equal>(ng_inputs.at(0), ng_inputs.at(1))};
return {std::make_shared<ngraph::op::Equal>(
node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
}
} // namespace set_1
......
......@@ -31,9 +31,10 @@ namespace ngraph
{
inline NodeVector greater(const Node& node)
{
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
return {
std::make_shared<ngraph::op::Greater>(ng_inputs.at(0), ng_inputs.at(1))};
return {std::make_shared<ngraph::op::Greater>(
node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
}
} // namespace set_1
......
......@@ -31,8 +31,10 @@ namespace ngraph
{
inline NodeVector less(const Node& node)
{
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Less>(ng_inputs.at(0), ng_inputs.at(1))};
return {std::make_shared<ngraph::op::Less>(
node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
}
} // namespace set_1
......
......@@ -49,9 +49,10 @@ namespace ngraph
{
inline NodeVector mul(const Node& node)
{
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
return {
std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(1))};
return {std::make_shared<ngraph::op::Multiply>(
node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
}
} // namespace set_7
......
......@@ -31,8 +31,10 @@ namespace ngraph
{
inline NodeVector logical_or(const Node& node)
{
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Or>(ng_inputs.at(0), ng_inputs.at(1))};
return {std::make_shared<ngraph::op::Or>(
node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
}
} // namespace set_1
......
......@@ -31,8 +31,10 @@ namespace ngraph
{
inline NodeVector pow(const Node& node)
{
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Power>(ng_inputs.at(0), ng_inputs.at(1))};
return {std::make_shared<ngraph::op::Power>(
node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
}
} // namespace set_1
......
......@@ -48,9 +48,10 @@ namespace ngraph
{
inline NodeVector sub(const Node& node)
{
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
return {
std::make_shared<ngraph::op::Subtract>(ng_inputs.at(0), ng_inputs.at(1))};
return {std::make_shared<ngraph::op::Subtract>(
node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
}
} // namespace set_1
......
......@@ -18,10 +18,8 @@
#include "core/node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/op/xor.hpp"
namespace ngraph
{
......@@ -33,14 +31,10 @@ namespace ngraph
{
inline NodeVector logical_xor(const Node& node)
{
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
auto left = ng_inputs.at(0);
auto not_left = std::make_shared<ngraph::op::Not>(left);
auto right = ng_inputs.at(1);
auto not_right = std::make_shared<ngraph::op::Not>(right);
return {std::make_shared<ngraph::op::Or>(
std::make_shared<ngraph::op::And>(left, not_right),
std::make_shared<ngraph::op::And>(not_left, right))};
return {std::make_shared<ngraph::op::Xor>(
node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
}
} // namespace set_1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment