Commit 12e8b9b7 authored by Adam Procter, committed by Scott Cyphers

Add logical-and, logical-or ops (#892)

* Add logical-and, logical-or ops

* Restore accidentally-deleted test

* add new ops to IE backend
parent c74da83e
.. and.rst:
###
And
###
.. code-block:: cpp
   And  // Elementwise logical-and operation
Description
===========
Produces a tensor with boolean element type and the same shape as the two
inputs, which must themselves have boolean element type. The value at each
coordinate of ``output`` is ``1`` (true) if ``arg0`` and ``arg1`` are both
nonzero, and ``0`` (false) otherwise.
Inputs
------
+-----------------+------------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+==============================+================================+
| ``arg0`` | ``ngraph::element::boolean`` | any |
+-----------------+------------------------------+--------------------------------+
| ``arg1`` | ``ngraph::element::boolean`` | same as ``arg0`` |
+-----------------+------------------------------+--------------------------------+
Outputs
-------
+-----------------+------------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+==============================+================================+
| ``output`` | ``ngraph::element::boolean`` | same as ``arg0`` |
+-----------------+------------------------------+--------------------------------+
Mathematical Definition
=======================
.. math::
   \texttt{output}_{i_0, \ldots, i_{n-1}} = \texttt{arg0}_{i_0, \ldots, i_{n-1}}\, \texttt{\&\&}\, \texttt{arg1}_{i_0, \ldots, i_{n-1}}
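A minimal usage sketch (mirroring the unit tests added in this commit; ``Shape{2, 2}`` is illustrative):

.. code-block:: cpp

   auto arg0 = std::make_shared<op::Parameter>(element::boolean, Shape{2, 2});
   auto arg1 = std::make_shared<op::Parameter>(element::boolean, Shape{2, 2});
   // Elementwise logical-and; the output is boolean with shape {2, 2}
   auto output = std::make_shared<op::And>(arg0, arg1);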
C++ Interface
=============
.. doxygenclass:: ngraph::op::And
:project: ngraph
:members:
......@@ -53,6 +53,7 @@ Not currently a comprehensive list.
acos.rst
add.rst
allreduce.rst
and.rst
asin.rst
atan.rst
avg_pool.rst
......@@ -83,6 +84,7 @@ Not currently a comprehensive list.
negative.rst
not_equal.rst
not.rst
or.rst
softmax.rst
.. or.rst:
##
Or
##
.. code-block:: cpp
   Or  // Elementwise logical-or operation
Description
===========
Produces a tensor with boolean element type and the same shape as the two
inputs, which must themselves have boolean element type. The value at each
coordinate of ``output`` is ``1`` (true) if ``arg0`` or ``arg1`` is nonzero,
and ``0`` (false) otherwise.
Inputs
------
+-----------------+------------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+==============================+================================+
| ``arg0`` | ``ngraph::element::boolean`` | any |
+-----------------+------------------------------+--------------------------------+
| ``arg1`` | ``ngraph::element::boolean`` | same as ``arg0`` |
+-----------------+------------------------------+--------------------------------+
Outputs
-------
+-----------------+------------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+==============================+================================+
| ``output`` | ``ngraph::element::boolean`` | same as ``arg0`` |
+-----------------+------------------------------+--------------------------------+
Mathematical Definition
=======================
.. math::
   \texttt{output}_{i_0, \ldots, i_{n-1}} = \texttt{arg0}_{i_0, \ldots, i_{n-1}}\, \texttt{||}\, \texttt{arg1}_{i_0, \ldots, i_{n-1}}
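A minimal usage sketch (mirroring the unit tests added in this commit; ``Shape{2, 2}`` is illustrative):

.. code-block:: cpp

   auto arg0 = std::make_shared<op::Parameter>(element::boolean, Shape{2, 2});
   auto arg1 = std::make_shared<op::Parameter>(element::boolean, Shape{2, 2});
   // Elementwise logical-or; the output is boolean with shape {2, 2}
   auto output = std::make_shared<op::Or>(arg0, arg1);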
C++ Interface
=============
.. doxygenclass:: ngraph::op::Or
:project: ngraph
:members:
......@@ -35,6 +35,7 @@ set (SRC
op/acos.cpp
op/add.cpp
op/allreduce.cpp
op/and.cpp
op/asin.cpp
op/atan.cpp
op/avg_pool.cpp
......@@ -70,6 +71,7 @@ set (SRC
op/not_equal.cpp
op/one_hot.cpp
op/op.cpp
op/or.cpp
op/pad.cpp
op/parameter.cpp
op/power.cpp
......@@ -97,6 +99,7 @@ set (SRC
op/util/arithmetic_reduction.cpp
op/util/binary_elementwise_arithmetic.cpp
op/util/binary_elementwise_comparison.cpp
op/util/binary_elementwise_logical.cpp
op/util/binary_elementwise.cpp
op/util/requires_tensor_view_args.cpp
op/util/unary_elementwise_arithmetic.cpp
......
......@@ -62,6 +62,7 @@
#include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/avg_pool.hpp"
......@@ -96,6 +97,7 @@
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/power.hpp"
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/and.hpp"
using namespace std;
using namespace ngraph;
op::And::And(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
: BinaryElementwiseLogical("And", arg0, arg1)
{
}
shared_ptr<Node> op::And::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<And>(new_args.at(0), new_args.at(1));
}
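A hedged usage sketch of the constructor and copy_with_new_args above; the parameter setup follows the backend tests later in this commit:

// Build an And node, then re-create it over swapped inputs, as graph-cloning
// machinery would. copy_with_new_args throws unless given exactly two args.
auto a = make_shared<op::Parameter>(element::boolean, Shape{2, 2});
auto b = make_shared<op::Parameter>(element::boolean, Shape{2, 2});
auto and_node = make_shared<op::And>(a, b);
auto clone = and_node->copy_with_new_args(NodeVector{b, a});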
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <memory>
#include "ngraph/op/util/binary_elementwise_logical.hpp"
namespace ngraph
{
namespace op
{
/// \brief Elementwise logical-and operation.
///
class And : public util::BinaryElementwiseLogical
{
public:
/// \brief Constructs a logical-and operation.
///
/// \param arg0 Node that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Node that produces the second input tensor.<br>
/// `[d0, ...]`
///
/// Output `[d0, ...]`
///
And(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual bool is_commutative() override { return true; }
};
}
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/or.hpp"
using namespace std;
using namespace ngraph;
op::Or::Or(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
: BinaryElementwiseLogical("Or", arg0, arg1)
{
}
shared_ptr<Node> op::Or::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Or>(new_args.at(0), new_args.at(1));
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <memory>
#include "ngraph/op/util/binary_elementwise_logical.hpp"
namespace ngraph
{
namespace op
{
/// \brief Elementwise logical-or operation.
///
class Or : public util::BinaryElementwiseLogical
{
public:
/// \brief Constructs a logical-or operation.
///
/// \param arg0 Node that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Node that produces the second input tensor.<br>
/// `[d0, ...]`
///
/// Output `[d0, ...]`
///
Or(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual bool is_commutative() override { return true; }
};
}
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/util/binary_elementwise_logical.hpp"
using namespace std;
using namespace ngraph;
op::util::BinaryElementwiseLogical::BinaryElementwiseLogical(const string& node_type,
const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1)
: BinaryElementwise(node_type, element::boolean, arg0, arg1)
{
if (arg0->get_element_type() != element::boolean ||
arg1->get_element_type() != element::boolean)
{
throw ngraph_error("Arguments must have boolean element type");
}
}
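A hedged sketch of the boolean element-type check above; the expected error string matches the type_prop tests at the end of this commit:

auto ok = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
auto bad = make_shared<op::Parameter>(element::i32, Shape{2, 4});
try
{
    auto node = make_shared<op::And>(ok, bad); // mixed element types: rejected
}
catch (const ngraph_error& error)
{
    // error.what() == "Arguments must have boolean element type"
}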
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/op/util/binary_elementwise.hpp"
namespace ngraph
{
namespace op
{
namespace util
{
/// \brief Abstract base class for elementwise binary logical operations, i.e., operations where the same
/// scalar binary logical operation is applied to each corresponding pair of elements in two same-shaped
/// boolean input tensors.
///
/// For example, if the underlying operation (determined by the subclass) is \f$\mathit{op}(x,y)\f$, the input tensors
/// \f$[[x_0,y_0],[z_0,w_0]]\f$ and \f$[[x_1,y_1],[z_1,w_1]]\f$ will be mapped to \f$[[\mathit{op}(x_0,x_1),\mathit{op}(y_0,y_1)],[\mathit{op}(z_0,z_1),\mathit{op}(w_0,w_1)]]\f$.
///
/// ## Inputs
///
/// | | Type | Description |
/// | ------ | --------------------------------------------- | ------------------------------------------------------ |
/// | `arg0` | \f$\texttt{bool}[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape, with element type `bool`. |
/// | `arg1` | \f$\texttt{bool}[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape and element type as `arg0`. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape as the input tensors, and the element type `bool`. |
class BinaryElementwiseLogical : public BinaryElementwise
{
public:
/// \brief Constructs a binary elementwise logical operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
BinaryElementwiseLogical(const std::string& node_type,
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
};
}
}
}
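A hedged extension sketch (not part of this commit): any new elementwise logical op subclasses BinaryElementwiseLogical exactly as And and Or do; the Xor name below is hypothetical, for illustration only.

#include "ngraph/op/util/binary_elementwise_logical.hpp"

namespace ngraph
{
    namespace op
    {
        // Hypothetical op: the base class supplies the boolean type checks
        // and the boolean output element type; the subclass only names itself.
        class Xor : public util::BinaryElementwiseLogical
        {
        public:
            Xor(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
                : BinaryElementwiseLogical("Xor", arg0, arg1)
            {
            }
        };
    }
}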
......@@ -27,6 +27,7 @@
#include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/avg_pool.hpp"
......@@ -62,6 +63,7 @@
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/power.hpp"
......@@ -3648,6 +3650,24 @@ namespace ngraph
writer << " " << out[0].get_name() << ",\n";
writer << " " << shape_size(node->get_shape()) << ");\n";
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::And)
{
writer << "reference::logical_and(" << args[0].get_name() << ",\n"
<< " " << args[1].get_name() << ",\n"
<< " " << out[0].get_name() << ",\n"
<< " " << out[0].get_size() << ");\n";
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Or)
{
writer << "reference::logical_or(" << args[0].get_name() << ",\n"
<< " " << args[1].get_name() << ",\n"
<< " " << out[0].get_name() << ",\n"
<< " " << out[0].get_size() << ");\n";
}
}
}
}
......
......@@ -37,6 +37,7 @@
#include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/avg_pool.hpp"
......@@ -72,6 +73,7 @@
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/power.hpp"
......@@ -274,6 +276,8 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Sigmoid), &runtime::cpu::CPU_Emitter::emit<op::Sigmoid>},
{TI(ngraph::op::Softmax), &runtime::cpu::CPU_Emitter::emit<op::Softmax>},
{TI(ngraph::op::SigmoidBackprop), &runtime::cpu::CPU_Emitter::emit<op::SigmoidBackprop>},
{TI(ngraph::op::And), &runtime::cpu::CPU_Emitter::emit<op::And>},
{TI(ngraph::op::Or), &runtime::cpu::CPU_Emitter::emit<op::Or>},
};
runtime::cpu::CPU_ExternalFunction::CPU_ExternalFunction(
......@@ -332,6 +336,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
#include "ngraph/runtime/cpu/cpu_kernels.hpp"
#include "ngraph/runtime/cpu/cpu_runtime_context.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/concat.hpp"
......@@ -342,6 +347,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
#include "ngraph/runtime/reference/min.hpp"
#include "ngraph/runtime/reference/not.hpp"
#include "ngraph/runtime/reference/one_hot.hpp"
#include "ngraph/runtime/reference/or.hpp"
#include "ngraph/runtime/reference/pad.hpp"
#include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/reduce.hpp"
......
......@@ -50,6 +50,7 @@
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp"
#include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/asin.hpp"
#include "ngraph/runtime/reference/atan.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
......@@ -81,6 +82,7 @@
#include "ngraph/runtime/reference/not.hpp"
#include "ngraph/runtime/reference/not_equal.hpp"
#include "ngraph/runtime/reference/one_hot.hpp"
#include "ngraph/runtime/reference/or.hpp"
#include "ngraph/runtime/reference/pad.hpp"
#include "ngraph/runtime/reference/power.hpp"
#include "ngraph/runtime/reference/product.hpp"
......@@ -169,6 +171,13 @@ namespace ngraph
static_cast<int>(args[0]->get_element_count()));
}
#endif
else if (node_op == "And")
{
reference::logical_and(reinterpret_cast<char*>(args[0]->get_data_ptr()),
reinterpret_cast<char*>(args[1]->get_data_ptr()),
reinterpret_cast<char*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Asin")
{
reference::asin<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
......@@ -516,6 +525,13 @@ namespace ngraph
out[0]->get_shape(),
oh->get_one_hot_axis());
}
else if (node_op == "Or")
{
reference::logical_or(reinterpret_cast<char*>(args[0]->get_data_ptr()),
reinterpret_cast<char*>(args[1]->get_data_ptr()),
reinterpret_cast<char*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Parameter")
{
}
......
......@@ -50,6 +50,7 @@
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp"
#include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/asin.hpp"
#include "ngraph/runtime/reference/atan.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
......@@ -82,6 +83,7 @@
#include "ngraph/runtime/reference/not.hpp"
#include "ngraph/runtime/reference/not_equal.hpp"
#include "ngraph/runtime/reference/one_hot.hpp"
#include "ngraph/runtime/reference/or.hpp"
#include "ngraph/runtime/reference/pad.hpp"
#include "ngraph/runtime/reference/power.hpp"
#include "ngraph/runtime/reference/product.hpp"
......@@ -250,6 +252,13 @@ private:
static_cast<int>(args[0]->get_element_count()));
}
#endif
else if (node_op == "And")
{
reference::logical_and(reinterpret_cast<char*>(args[0]->get_data_ptr()),
reinterpret_cast<char*>(args[1]->get_data_ptr()),
reinterpret_cast<char*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Asin")
{
reference::asin<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
......@@ -588,6 +597,13 @@ private:
out[0]->get_shape(),
oh->get_one_hot_axis());
}
else if (node_op == "Or")
{
reference::logical_or(reinterpret_cast<char*>(args[0]->get_data_ptr()),
reinterpret_cast<char*>(args[1]->get_data_ptr()),
reinterpret_cast<char*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Parameter")
{
}
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <cstddef>
namespace ngraph
{
namespace runtime
{
namespace reference
{
static inline void
logical_and(const char* arg0, const char* arg1, char* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] && arg1[i];
}
}
}
}
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <cstddef>
namespace ngraph
{
namespace runtime
{
namespace reference
{
static inline void
logical_or(const char* arg0, const char* arg1, char* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] || arg1[i];
}
}
}
}
}
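A minimal standalone sketch driving the two reference kernels above; the input values mirror the backend tests below:

#include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/or.hpp"

int main()
{
    // Boolean tensors are stored elementwise as char (0 or 1)
    const char a[] = {1, 0, 1, 1};
    const char b[] = {0, 0, 1, 0};
    char out_and[4];
    char out_or[4];
    ngraph::runtime::reference::logical_and(a, b, out_and, 4);
    ngraph::runtime::reference::logical_or(a, b, out_or, 4);
    // out_and == {0, 0, 1, 0}; out_or == {1, 0, 1, 1}
    return 0;
}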
......@@ -24,6 +24,7 @@
#include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/avg_pool.hpp"
......@@ -58,6 +59,7 @@
#include "ngraph/op/not.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/power.hpp"
......@@ -374,6 +376,10 @@ static shared_ptr<ngraph::Function>
{
node = make_shared<op::AllReduce>(args[0]);
}
else if (node_op == "And")
{
node = make_shared<op::And>(args[0], args[1]);
}
else if (node_op == "Asin")
{
node = make_shared<op::Asin>(args[0]);
......@@ -715,6 +721,10 @@ static shared_ptr<ngraph::Function>
auto one_hot_axis = node_js.at("one_hot_axis").get<size_t>();
node = make_shared<op::OneHot>(args[0], shape, one_hot_axis);
}
else if (node_op == "Or")
{
node = make_shared<op::Or>(args[0], args[1]);
}
else if (node_op == "Pad")
{
auto padding_below = node_js.at("padding_below").get<vector<size_t>>();
......
......@@ -7933,3 +7933,43 @@ TEST(${BACKEND_NAME}, validate_call_output_shape)
EXPECT_ANY_THROW(backend->call(f, {a}, {c, b}));
}
TEST(${BACKEND_NAME}, logical_and)
{
Shape shape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::boolean, shape);
auto B = make_shared<op::Parameter>(element::boolean, shape);
auto f = make_shared<Function>(make_shared<op::And>(A, B), op::ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::boolean, shape);
copy_data(a, vector<char>{1, 0, 1, 1, 1, 0, 1, 0});
auto b = backend->create_tensor(element::boolean, shape);
copy_data(b, vector<char>{0, 0, 1, 0, 0, 1, 1, 0});
auto result = backend->create_tensor(element::boolean, shape);
backend->call(f, {result}, {a, b});
EXPECT_EQ((vector<char>{0, 0, 1, 0, 0, 0, 1, 0}), read_vector<char>(result));
}
TEST(${BACKEND_NAME}, logical_or)
{
Shape shape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::boolean, shape);
auto B = make_shared<op::Parameter>(element::boolean, shape);
auto f = make_shared<Function>(make_shared<op::Or>(A, B), op::ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::boolean, shape);
copy_data(a, vector<char>{1, 0, 1, 1, 1, 0, 1, 0});
auto b = backend->create_tensor(element::boolean, shape);
copy_data(b, vector<char>{0, 0, 1, 0, 0, 1, 1, 0});
auto result = backend->create_tensor(element::boolean, shape);
backend->call(f, {result}, {a, b});
EXPECT_EQ((vector<char>{1, 0, 1, 1, 1, 1, 1, 0}), read_vector<char>(result));
}
......@@ -599,6 +599,83 @@ TEST(type_prop, subtract_bad_arguments)
});
}
//
// Tests for binary elementwise logical ops.
//
void test_binary_logical(std::string node_type,
shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr<Node>& y))
{
// Check for bad arguments
auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
auto tv0_2_4_param_2 = make_shared<op::Parameter>(element::i32, Shape{2, 4});
auto tv0_2_4_param_3 = make_shared<op::Parameter>(element::i32, Shape{2, 4});
auto tv0_4_2_param = make_shared<op::Parameter>(element::boolean, Shape{4, 2});
auto test_binary_bad_arguments_view_shapes = [&](const shared_ptr<Node>& x,
const shared_ptr<Node>& y) {
try
{
auto node = f(x, y);
// Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected.";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view shape"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
};
test_binary_bad_arguments_view_shapes(tv0_2_4_param_0, tv0_4_2_param);
auto test_binary_bad_arguments_view_element_types = [&](const shared_ptr<Node>& x,
const shared_ptr<Node>& y) {
try
{
auto node = f(x, y);
// Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected.";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Arguments must have boolean element type"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
};
test_binary_bad_arguments_view_element_types(tv0_2_4_param_0, tv0_2_4_param_2);
test_binary_bad_arguments_view_element_types(tv0_2_4_param_2, tv0_2_4_param_0);
test_binary_bad_arguments_view_element_types(tv0_2_4_param_2, tv0_2_4_param_3);
auto test_binary_good_arguments = [&](const shared_ptr<Node>& x, const shared_ptr<Node>& y) {
auto node = f(x, y);
EXPECT_TRUE(node->has_same_type(node->get_arguments()[0]));
};
test_binary_good_arguments(tv0_2_4_param_0, tv0_2_4_param_1);
}
TEST(type_prop, and_bad_arguments)
{
test_binary_logical(
"And", [](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::And>(x, y);
});
}
TEST(type_prop, or_bad_arguments)
{
test_binary_logical(
"Or", [](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Or>(x, y);
});
}
TEST(type_prop, comparison_good)
{
auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
......