Commit 137f002b authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Michał Karzyński

[Py] Expose logical And, Or operations. (#1198)

parent 7cd38322
......@@ -45,6 +45,8 @@ from ngraph.ops import greater_eq
from ngraph.ops import less
from ngraph.ops import less_eq
from ngraph.ops import log
from ngraph.ops import logical_and
from ngraph.ops import logical_or
from ngraph.ops import logical_not
from ngraph.ops import max
from ngraph.ops import max_pool
......
......@@ -36,6 +36,7 @@ from _pyngraph.op import Abs
from _pyngraph.op import Acos
from _pyngraph.op import Add
from _pyngraph.op import AllReduce
from _pyngraph.op import And
from _pyngraph.op import Asin
from _pyngraph.op import Atan
from _pyngraph.op import AvgPool
......@@ -76,6 +77,7 @@ from _pyngraph.op import Not
from _pyngraph.op import NotEqual
from _pyngraph.op import OneHot
from _pyngraph.op import Op
from _pyngraph.op import Or
from _pyngraph.op import Pad
from _pyngraph.op import Parameter
from _pyngraph.op import ParameterVector
......
......@@ -37,5 +37,6 @@ from _pyngraph.op.util import UnaryElementwiseArithmetic
from _pyngraph.op.util import BinaryElementwise
from _pyngraph.op.util import BinaryElementwiseComparison
from _pyngraph.op.util import BinaryElementwiseArithmetic
from _pyngraph.op.util import BinaryElementwiseLogical
from _pyngraph.op.util import OpAnnotations
from _pyngraph.op.util import ArithmeticReduction
\ No newline at end of file
from _pyngraph.op.util import ArithmeticReduction
......@@ -20,10 +20,10 @@ import numpy as np
from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Function, Node, \
NodeVector, Shape, Strides
from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, BatchNorm, Broadcast, Ceiling, \
Concat, Constant, Convert, Convolution, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, \
from ngraph.impl.op import Abs, Acos, Add, And, Asin, Atan, AvgPool, BatchNorm, Broadcast, \
Ceiling, Concat, Constant, Convert, Convolution, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, \
FunctionCall, GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, Max, Maximum, MaxPool, \
Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Pad, Parameter, Product, Power, \
Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, Parameter, Product, Power, \
Reduce, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, Sqrt, \
Subtract, Sum, Tan, Tanh
......@@ -393,6 +393,30 @@ def less_eq(left_node, right_node, name=None): # type: (NodeInput, NodeInput, s
return LessEq(left_node, right_node)
@binary_op
def logical_and(left_node, right_node, name=None): # type: (NodeInput, NodeInput, str) -> Node
"""Return node which perform logical and operation on input nodes element-wise.
:param left_node: The first input node providing data.
:param right_node: The second input node providing data.
:param name: The optional new name for output node.
:return: The node performing logical and operation on input nodes corresponding elements.
"""
return And(left_node, right_node)
@binary_op
def logical_or(left_node, right_node, name=None): # type: (NodeInput, NodeInput, str) -> Node
"""Return node which performs logical or operation on input nodes element-wise.
:param left_node: The first input node providing data.
:param right_node: The second input node providing data.
:param name: The optional new name for output node.
:return: The node performing logical or operation on input nodes corresponding elements.
"""
return Or(left_node, right_node)
@unary_op
def logical_not(node, name=None): # type: (Node, str) -> Node
"""Return node which applies logical negation to the input node elementwise."""
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/and.hpp" // ngraph::op::And
#include "pyngraph/ops/and.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_And(py::module m)
{
py::class_<ngraph::op::And,
std::shared_ptr<ngraph::op::And>,
ngraph::op::util::BinaryElementwiseLogical>
logical_and(m, "And");
logical_and.doc() = "ngraph.impl.op.And wraps ngraph::op::And";
logical_and.def(
py::init<const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&>());
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_And(py::module m);
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/or.hpp" // ngraph::op::Or
#include "pyngraph/ops/or.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Or(py::module m)
{
py::class_<ngraph::op::Or,
std::shared_ptr<ngraph::op::Or>,
ngraph::op::util::BinaryElementwiseLogical>
logical_or(m, "Or");
logical_or.doc() = "ngraph.impl.op.Or wraps ngraph::op::Or";
logical_or.def(
py::init<const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&>());
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_Or(py::module m);
......@@ -30,6 +30,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Cos(m_op);
regclass_pyngraph_op_Cosh(m_op);
regclass_pyngraph_op_Add(m_op);
regclass_pyngraph_op_And(m_op);
regclass_pyngraph_op_Broadcast(m_op);
regclass_pyngraph_op_Ceiling(m_op);
regclass_pyngraph_op_Concat(m_op);
......@@ -62,6 +63,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Power(m_op);
regclass_pyngraph_op_OneHot(m_op);
// regclass_pyngraph_op_Op(m_op);
regclass_pyngraph_op_Or(m_op);
regclass_pyngraph_op_Reduce(m_op);
regclass_pyngraph_op_ReplaceSlice(m_op);
regclass_pyngraph_op_Reshape(m_op);
......
......@@ -20,6 +20,7 @@
#include "pyngraph/ops/abs.hpp"
#include "pyngraph/ops/acos.hpp"
#include "pyngraph/ops/add.hpp"
#include "pyngraph/ops/and.hpp"
#include "pyngraph/ops/asin.hpp"
#include "pyngraph/ops/atan.hpp"
#include "pyngraph/ops/avg_pool.hpp"
......@@ -56,6 +57,7 @@
#include "pyngraph/ops/max.hpp"
#include "pyngraph/ops/min.hpp"
#include "pyngraph/ops/one_hot.hpp"
#include "pyngraph/ops/or.hpp"
#include "pyngraph/ops/pad.hpp"
#include "pyngraph/ops/parameter.hpp"
#include "pyngraph/ops/parameter_vector.hpp"
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/util/binary_elementwise_logical.hpp"
#include "pyngraph/ops/util/binary_elementwise_logical.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_util_BinaryElementwiseLogical(py::module m)
{
py::class_<ngraph::op::util::BinaryElementwiseLogical,
std::shared_ptr<ngraph::op::util::BinaryElementwiseLogical>,
ngraph::op::util::BinaryElementwise>
binaryElementwiseLogical(m, "BinaryElementwiseLogical");
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_util_BinaryElementwiseLogical(py::module m);
......@@ -28,6 +28,7 @@ void regmodule_pyngraph_op_util(py::module m)
regclass_pyngraph_op_util_BinaryElementwise(m_util);
regclass_pyngraph_op_util_BinaryElementwiseArithmetic(m_util);
regclass_pyngraph_op_util_BinaryElementwiseComparison(m_util);
regclass_pyngraph_op_util_BinaryElementwiseLogical(m_util);
regclass_pyngraph_op_util_UnaryElementwise(m_util);
regclass_pyngraph_op_util_UnaryElementwiseArithmetic(m_util);
}
......@@ -21,6 +21,7 @@
#include "pyngraph/ops/util/binary_elementwise.hpp"
#include "pyngraph/ops/util/binary_elementwise_arithmetic.hpp"
#include "pyngraph/ops/util/binary_elementwise_comparison.hpp"
#include "pyngraph/ops/util/binary_elementwise_logical.hpp"
#include "pyngraph/ops/util/op_annotations.hpp"
#include "pyngraph/ops/util/requires_tensor_view_args.hpp"
#include "pyngraph/ops/util/unary_elementwise.hpp"
......
......@@ -141,11 +141,13 @@ sources = ['pyngraph/function.cpp',
'pyngraph/ops/util/op_annotations.cpp',
'pyngraph/ops/util/unary_elementwise.cpp',
'pyngraph/ops/util/binary_elementwise_arithmetic.cpp',
'pyngraph/ops/util/binary_elementwise_logical.cpp',
'pyngraph/ops/util/regmodule_pyngraph_op_util.cpp',
'pyngraph/ops/util/unary_elementwise_arithmetic.cpp',
'pyngraph/ops/abs.cpp',
'pyngraph/ops/acos.cpp',
'pyngraph/ops/add.cpp',
'pyngraph/ops/and.cpp',
'pyngraph/ops/asin.cpp',
'pyngraph/ops/atan.cpp',
'pyngraph/ops/avg_pool.cpp',
......@@ -178,6 +180,7 @@ sources = ['pyngraph/function.cpp',
'pyngraph/ops/not_equal.cpp',
'pyngraph/ops/op.cpp',
'pyngraph/ops/one_hot.cpp',
'pyngraph/ops/or.cpp',
'pyngraph/ops/pad.cpp',
'pyngraph/ops/parameter.cpp',
'pyngraph/ops/parameter_vector.cpp',
......
......@@ -85,6 +85,49 @@ def test_binary_op_with_scalar(ng_api_helper, numpy_function):
assert np.allclose(result, expected)
@pytest.mark.parametrize('ng_api_helper,numpy_function', [
(ng.logical_and, np.logical_and),
(ng.logical_or, np.logical_or),
])
def test_binary_logical_op(ng_api_helper, numpy_function):
runtime = get_runtime()
shape = [2, 2]
parameter_a = ng.parameter(shape, name='A', dtype=np.bool)
parameter_b = ng.parameter(shape, name='B', dtype=np.bool)
model = ng_api_helper(parameter_a, parameter_b)
computation = runtime.computation(model, parameter_a, parameter_b)
value_a = np.array([[True, False], [False, False]], dtype=np.bool)
value_b = np.array([[False, True], [False, True]], dtype=np.bool)
result = computation(value_a, value_b)
expected = numpy_function(value_a, value_b)
assert np.allclose(result, expected)
@pytest.mark.parametrize('ng_api_helper,numpy_function', [
(ng.logical_and, np.logical_and),
(ng.logical_or, np.logical_or),
])
def test_binary_logical_op_with_scalar(ng_api_helper, numpy_function):
runtime = get_runtime()
value_a = np.array([[True, False], [False, False]], dtype=np.bool)
value_b = np.array([[False, True], [False, True]], dtype=np.bool)
shape = [2, 2]
parameter_a = ng.parameter(shape, name='A', dtype=np.bool)
model = ng_api_helper(parameter_a, value_b)
computation = runtime.computation(model, parameter_a)
result = computation(value_a)
expected = numpy_function(value_a, value_b)
assert np.allclose(result, expected)
@pytest.mark.parametrize('operator,numpy_function', [
(operator.add, np.add),
(operator.sub, np.subtract),
......
......@@ -33,6 +33,7 @@ from test.ngraph.util import run_op_numeric_data, run_op_node
(ng.exp, np.exp, -100., 100.),
(ng.floor, np.floor, -100., 100.),
(ng.log, np.log, 0, 100.),
(ng.logical_not, np.logical_not, -10, 10),
(ng.relu, lambda x: np.maximum(0, x), -100., 100.),
(ng.sign, np.sign, -100., 100.),
(ng.sin, np.sin, -100., 100.),
......@@ -67,6 +68,7 @@ def test_unary_op_array(ng_api_fn, numpy_fn, range_start, range_end):
(ng.exp, np.exp, np.float32(1.5)),
(ng.floor, np.floor, np.float32(1.5)),
(ng.log, np.log, np.float32(1.5)),
(ng.logical_not, np.logical_not, np.int32(0)),
(ng.relu, lambda x: np.maximum(0, x), np.float32(-0.125)),
(ng.sign, np.sign, np.float32(0.)),
(ng.sin, np.sin, np.float32(np.pi / 4.0)),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment