Commit b21ff63d authored by tsocha's avatar tsocha Committed by Michał Karzyński

[ONNX] Logical ops (#1567)

parent 3548772b
...@@ -32,6 +32,7 @@ add_library(onnx_import STATIC ...@@ -32,6 +32,7 @@ add_library(onnx_import STATIC
core/value_info.hpp core/value_info.hpp
exceptions.hpp exceptions.hpp
op/add.hpp op/add.hpp
op/and.hpp
op/average_pool.cpp op/average_pool.cpp
op/average_pool.hpp op/average_pool.hpp
op/batch_norm.cpp op/batch_norm.cpp
...@@ -43,10 +44,13 @@ add_library(onnx_import STATIC ...@@ -43,10 +44,13 @@ add_library(onnx_import STATIC
op/conv.cpp op/conv.cpp
op/conv.hpp op/conv.hpp
op/div.hpp op/div.hpp
op/equal.hpp
op/flatten.cpp op/flatten.cpp
op/flatten.hpp op/flatten.hpp
op/gemm.cpp op/gemm.cpp
op/gemm.hpp op/gemm.hpp
op/greater.hpp
op/less.hpp
op/matmul.hpp op/matmul.hpp
op/max_pool.cpp op/max_pool.cpp
op/max_pool.hpp op/max_pool.hpp
...@@ -55,6 +59,8 @@ add_library(onnx_import STATIC ...@@ -55,6 +59,8 @@ add_library(onnx_import STATIC
op/mean.hpp op/mean.hpp
op/min.hpp op/min.hpp
op/mul.hpp op/mul.hpp
op/not.hpp
op/or.hpp
op/pow.hpp op/pow.hpp
op/relu.hpp op/relu.hpp
op/reshape.cpp op/reshape.cpp
...@@ -67,6 +73,7 @@ add_library(onnx_import STATIC ...@@ -67,6 +73,7 @@ add_library(onnx_import STATIC
op/sum.hpp op/sum.hpp
op/unsqueeze.cpp op/unsqueeze.cpp
op/unsqueeze.hpp op/unsqueeze.hpp
op/xor.hpp
ops_bridge.cpp ops_bridge.cpp
ops_bridge.hpp ops_bridge.hpp
utils/broadcasting.cpp utils/broadcasting.cpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/and.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
inline NodeVector logical_and(const Node& node)
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::And>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/equal.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
inline NodeVector equal(const Node& node)
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Equal>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/greater.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
inline NodeVector greater(const Node& node)
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Greater>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/less.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
inline NodeVector less(const Node& node)
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Less>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/not.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
inline NodeVector logical_not(const Node& node)
{
return {std::make_shared<ngraph::op::Not>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/or.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
inline NodeVector logical_or(const Node& node)
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Or>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/or.hpp"
#include "core/node.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
inline NodeVector logical_xor(const Node& node)
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
auto left = ng_inputs.at(0);
auto not_left = std::make_shared<ngraph::op::Not>(left);
auto right = ng_inputs.at(1);
auto not_right = std::make_shared<ngraph::op::Not>(right);
return {std::make_shared<ngraph::op::Or>(
std::make_shared<ngraph::op::And>(left, not_right),
std::make_shared<ngraph::op::And>(not_left, right))};
}
} // namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -19,20 +19,26 @@ ...@@ -19,20 +19,26 @@
#include "core/attribute.hpp" #include "core/attribute.hpp"
#include "op/add.hpp" #include "op/add.hpp"
#include "op/and.hpp"
#include "op/average_pool.hpp" #include "op/average_pool.hpp"
#include "op/batch_norm.hpp" #include "op/batch_norm.hpp"
#include "op/concat.hpp" #include "op/concat.hpp"
#include "op/constant.hpp" #include "op/constant.hpp"
#include "op/conv.hpp" #include "op/conv.hpp"
#include "op/div.hpp" #include "op/div.hpp"
#include "op/equal.hpp"
#include "op/flatten.hpp" #include "op/flatten.hpp"
#include "op/gemm.hpp" #include "op/gemm.hpp"
#include "op/greater.hpp"
#include "op/less.hpp"
#include "op/matmul.hpp" #include "op/matmul.hpp"
#include "op/max.hpp" #include "op/max.hpp"
#include "op/max_pool.hpp" #include "op/max_pool.hpp"
#include "op/mean.hpp" #include "op/mean.hpp"
#include "op/min.hpp" #include "op/min.hpp"
#include "op/mul.hpp" #include "op/mul.hpp"
#include "op/not.hpp"
#include "op/or.hpp"
#include "op/pow.hpp" #include "op/pow.hpp"
#include "op/relu.hpp" #include "op/relu.hpp"
#include "op/reshape.hpp" #include "op/reshape.hpp"
...@@ -41,6 +47,7 @@ ...@@ -41,6 +47,7 @@
#include "op/sub.hpp" #include "op/sub.hpp"
#include "op/sum.hpp" #include "op/sum.hpp"
#include "op/unsqueeze.hpp" #include "op/unsqueeze.hpp"
#include "op/xor.hpp"
#include "ops_bridge.hpp" #include "ops_bridge.hpp"
namespace ngraph namespace ngraph
...@@ -86,6 +93,7 @@ namespace ngraph ...@@ -86,6 +93,7 @@ namespace ngraph
ops_bridge() ops_bridge()
{ {
m_map.emplace("Add", std::bind(op::add, std::placeholders::_1)); m_map.emplace("Add", std::bind(op::add, std::placeholders::_1));
m_map.emplace("And", std::bind(op::logical_and, std::placeholders::_1));
m_map.emplace("AveragePool", m_map.emplace("AveragePool",
std::bind(op::average_pool, std::placeholders::_1)); std::bind(op::average_pool, std::placeholders::_1));
m_map.emplace("BatchNormalization", m_map.emplace("BatchNormalization",
...@@ -94,14 +102,19 @@ namespace ngraph ...@@ -94,14 +102,19 @@ namespace ngraph
m_map.emplace("Constant", std::bind(op::constant, std::placeholders::_1)); m_map.emplace("Constant", std::bind(op::constant, std::placeholders::_1));
m_map.emplace("Conv", std::bind(op::conv, std::placeholders::_1)); m_map.emplace("Conv", std::bind(op::conv, std::placeholders::_1));
m_map.emplace("Div", std::bind(op::div, std::placeholders::_1)); m_map.emplace("Div", std::bind(op::div, std::placeholders::_1));
m_map.emplace("Equal", std::bind(op::equal, std::placeholders::_1));
m_map.emplace("Flatten", std::bind(op::flatten, std::placeholders::_1)); m_map.emplace("Flatten", std::bind(op::flatten, std::placeholders::_1));
m_map.emplace("Gemm", std::bind(op::gemm, std::placeholders::_1)); m_map.emplace("Gemm", std::bind(op::gemm, std::placeholders::_1));
m_map.emplace("Greater", std::bind(op::greater, std::placeholders::_1));
m_map.emplace("Less", std::bind(op::less, std::placeholders::_1));
m_map.emplace("MatMul", std::bind(op::matmul, std::placeholders::_1)); m_map.emplace("MatMul", std::bind(op::matmul, std::placeholders::_1));
m_map.emplace("MaxPool", std::bind(op::max_pool, std::placeholders::_1)); m_map.emplace("MaxPool", std::bind(op::max_pool, std::placeholders::_1));
m_map.emplace("Max", std::bind(op::max, std::placeholders::_1)); m_map.emplace("Max", std::bind(op::max, std::placeholders::_1));
m_map.emplace("Mean", std::bind(op::mean, std::placeholders::_1)); m_map.emplace("Mean", std::bind(op::mean, std::placeholders::_1));
m_map.emplace("Min", std::bind(op::min, std::placeholders::_1)); m_map.emplace("Min", std::bind(op::min, std::placeholders::_1));
m_map.emplace("Mul", std::bind(op::mul, std::placeholders::_1)); m_map.emplace("Mul", std::bind(op::mul, std::placeholders::_1));
m_map.emplace("Not", std::bind(op::logical_not, std::placeholders::_1));
m_map.emplace("Or", std::bind(op::logical_or, std::placeholders::_1));
m_map.emplace("Pow", std::bind(op::pow, std::placeholders::_1)); m_map.emplace("Pow", std::bind(op::pow, std::placeholders::_1));
m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1)); m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1));
m_map.emplace("Reshape", std::bind(op::reshape, std::placeholders::_1)); m_map.emplace("Reshape", std::bind(op::reshape, std::placeholders::_1));
...@@ -110,6 +123,7 @@ namespace ngraph ...@@ -110,6 +123,7 @@ namespace ngraph
m_map.emplace("Sub", std::bind(op::sub, std::placeholders::_1)); m_map.emplace("Sub", std::bind(op::sub, std::placeholders::_1));
m_map.emplace("Sum", std::bind(op::sum, std::placeholders::_1)); m_map.emplace("Sum", std::bind(op::sum, std::placeholders::_1));
m_map.emplace("Unsqueeze", std::bind(op::unsqueeze, std::placeholders::_1)); m_map.emplace("Unsqueeze", std::bind(op::unsqueeze, std::placeholders::_1));
m_map.emplace("Xor", std::bind(op::logical_xor, std::placeholders::_1));
} }
NodeVector operator()(const Node& node) const NodeVector operator()(const Node& node) const
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment