Unverified Commit e7e4e860 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge pull request #3363 from NervanaSystems/aprocter/logical-xor

Add logical xor op
parents 9476e0f4 0a0dff09
......@@ -77,6 +77,7 @@ Not currently a comprehensive list.
* :doc:`tan`
* :doc:`tanh`
* :doc:`transpose`
* :doc:`xor`
......@@ -149,6 +150,7 @@ Not currently a comprehensive list.
tan.rst
tanh.rst
transpose.rst
xor.rst
.. _more_about:
......
.. xor.rst:
###
Xor
###
.. code-block:: cpp
Xor // Elementwise logical-xor operation
Description
===========
Produces tensor with boolean element type and shape as the two inputs,
which must themselves have boolean element type, where the value at each
coordinate of ``output`` is ``0`` (true) if ``arg0`` or ``arg1`` both
zero or both nonzero, or ``1`` otherwise.
Inputs
------
+-----------------+------------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+==============================+================================+
| ``arg0`` | ``ngraph::element::boolean`` | any |
+-----------------+------------------------------+--------------------------------+
| ``arg1`` | ``ngraph::element::boolean`` | same as ``arg0`` |
+-----------------+------------------------------+--------------------------------+
Outputs
-------
+-----------------+------------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+==============================+================================+
| ``output`` | ``ngraph::element::boolean`` | same as ``arg0`` |
+-----------------+------------------------------+--------------------------------+
Mathematical Definition
=======================
.. math::
\mathtt{output}_{i_0, \ldots, i_{n-1}} = \mathtt{arg0}_{i_0, \ldots, i_{n-1}}\, \mathtt{XOR}\, \mathtt{arg1}_{i_0, \ldots, i_{n-1}}
C++ Interface
=============
.. doxygenclass:: ngraph::op::Xor
:project: ngraph
:members:
......@@ -304,6 +304,8 @@ set (SRC
op/tanh.hpp
op/topk.cpp
op/topk.hpp
op/xor.cpp
op/xor.hpp
op/fused/clamp.cpp
op/fused/clamp.hpp
op/fused/conv_fused.cpp
......
......@@ -175,6 +175,7 @@
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/partial_shape.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/tensor.hpp"
......
......@@ -160,3 +160,4 @@ NGRAPH_OP(Tanh, ngraph::op)
NGRAPH_OP(Tile, ngraph::op)
NGRAPH_OP(TopK, ngraph::op)
NGRAPH_OP(Transpose, ngraph::op)
NGRAPH_OP(Xor, ngraph::op)
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/xor.hpp"
using namespace std;
using namespace ngraph;
const string op::Xor::type_name{"Xor"};
op::Xor::Xor(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
: BinaryElementwiseLogical(arg0, arg1, autob)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Xor::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Xor>(new_args.at(0), new_args.at(1), this->get_autob());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/op/util/binary_elementwise_logical.hpp"
namespace ngraph
{
namespace op
{
/// \brief Elementwise logical-xor operation.
///
class Xor : public util::BinaryElementwiseLogical
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a logical-xor operation.
///
/// \param arg0 Node that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Node that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param autob Auto broadcast specification
///
/// Output `[d0, ...]`
///
Xor(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
};
}
}
......@@ -65,6 +65,7 @@
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/reference/abs.hpp"
......@@ -107,6 +108,7 @@
#include "ngraph/runtime/reference/sqrt.hpp"
#include "ngraph/runtime/reference/subtract.hpp"
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/runtime/reference/xor.hpp"
#include "ngraph/slice_plan.hpp"
#include "ngraph/util.hpp"
......@@ -994,6 +996,17 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
shape_size(out_shape));
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (std::dynamic_pointer_cast<op::Xor>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape));
runtime::reference::logical_xor<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
shape_size(out_shape));
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else
{
NGRAPH_CHECK(false,
......@@ -1034,14 +1047,15 @@ shared_ptr<op::Constant> fold_constant_binary_helper(const element::Type& et_out
}
bool is_supported_binary_op(std::shared_ptr<Node> n)
{
return (
std::dynamic_pointer_cast<op::Add>(n) || std::dynamic_pointer_cast<op::And>(n) ||
std::dynamic_pointer_cast<op::Divide>(n) || std::dynamic_pointer_cast<op::Equal>(n) ||
std::dynamic_pointer_cast<op::Greater>(n) || std::dynamic_pointer_cast<op::GreaterEq>(n) ||
std::dynamic_pointer_cast<op::Less>(n) || std::dynamic_pointer_cast<op::LessEq>(n) ||
std::dynamic_pointer_cast<op::Maximum>(n) || std::dynamic_pointer_cast<op::Minimum>(n) ||
std::dynamic_pointer_cast<op::Multiply>(n) || std::dynamic_pointer_cast<op::NotEqual>(n) ||
std::dynamic_pointer_cast<op::Or>(n) || std::dynamic_pointer_cast<op::Subtract>(n));
return (std::dynamic_pointer_cast<op::Add>(n) || std::dynamic_pointer_cast<op::And>(n) ||
std::dynamic_pointer_cast<op::Divide>(n) || std::dynamic_pointer_cast<op::Equal>(n) ||
std::dynamic_pointer_cast<op::Greater>(n) ||
std::dynamic_pointer_cast<op::GreaterEq>(n) || std::dynamic_pointer_cast<op::Less>(n) ||
std::dynamic_pointer_cast<op::LessEq>(n) || std::dynamic_pointer_cast<op::Maximum>(n) ||
std::dynamic_pointer_cast<op::Minimum>(n) ||
std::dynamic_pointer_cast<op::Multiply>(n) ||
std::dynamic_pointer_cast<op::NotEqual>(n) || std::dynamic_pointer_cast<op::Or>(n) ||
std::dynamic_pointer_cast<op::Subtract>(n) || std::dynamic_pointer_cast<op::Xor>(n));
}
void pass::ConstantFolding::construct_constant_binary()
......
......@@ -65,6 +65,7 @@
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/runtime/cpu/cpu_builder_registry.hpp"
#include "ngraph/runtime/cpu/cpu_kernels.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
......@@ -104,6 +105,7 @@
#include "ngraph/runtime/cpu/kernel/subtract.hpp"
#include "ngraph/runtime/cpu/kernel/tan.hpp"
#include "ngraph/runtime/cpu/kernel/tanh.hpp"
#include "ngraph/runtime/cpu/kernel/xor.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/halide_op.hpp"
#include "ngraph/type/element_type.hpp"
......@@ -243,6 +245,28 @@ namespace ngraph
functors.emplace_back(functor);
}
template <>
void Builder::BUILDER_DECL(ngraph::op::Xor)
{
auto& functors = external_function->get_functors();
auto element_count = out[0].get_size();
auto arg0_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto arg1_buffer_index = external_function->get_buffer_index(args[1].get_name());
auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto functor =
[&, element_count, arg0_buffer_index, arg1_buffer_index, out0_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
runtime::cpu::kernel::logical_xor(ctx->buffer_data[arg0_buffer_index],
ctx->buffer_data[arg1_buffer_index],
ctx->buffer_data[out0_buffer_index],
element_count,
ectx->arena);
};
functors.emplace_back(functor);
}
template <>
void Builder::BUILDER_DECL(ngraph::op::Maximum)
{
......@@ -545,6 +569,19 @@ namespace ngraph
return functor;
}
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Xor)
{
auto element_count = shape_size(node->get_shape());
auto functor = [&, element_count](const std::vector<void*>& inputs,
std::vector<void*>& outputs) {
runtime::cpu::kernel::logical_xor(
inputs[0], inputs[1], outputs[0], element_count, 0);
};
return functor;
}
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Sign)
{
......@@ -612,6 +649,7 @@ namespace ngraph
REGISTER_OP_BUILDER(Minimum);
REGISTER_OP_BUILDER(And);
REGISTER_OP_BUILDER(Or);
REGISTER_OP_BUILDER(Xor);
REGISTER_CF_BUILDER(Add);
REGISTER_CF_BUILDER(Subtract);
......@@ -633,6 +671,7 @@ namespace ngraph
REGISTER_CF_BUILDER(LessEq);
REGISTER_CF_BUILDER(And);
REGISTER_CF_BUILDER(Or);
REGISTER_CF_BUILDER(Xor);
REGISTER_CF_BUILDER(Sign);
REGISTER_CF_BUILDER(Not);
}
......
......@@ -113,6 +113,7 @@
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/runtime/cpu/cpu_kernel_emitters.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
......@@ -3822,6 +3823,15 @@ namespace ngraph
<< " " << out[0].get_size() << ");\n";
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Xor)
{
writer << "reference::logical_xor(" << args[0].get_name() << ",\n"
<< " " << args[1].get_name() << ",\n"
<< " " << out[0].get_name() << ",\n"
<< " " << out[0].get_size() << ");\n";
}
#define TI(x) std::type_index(typeid(x))
static std::string emit_infix_operator(const std::string& opname,
......
......@@ -134,6 +134,7 @@
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/batch_fusion.hpp"
#include "ngraph/pass/common_function_collection.hpp"
......@@ -431,6 +432,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::SigmoidBackprop), &runtime::cpu::CPU_Emitter::emit<op::SigmoidBackprop>},
{TI(ngraph::op::And), &runtime::cpu::CPU_Emitter::emit<op::And>},
{TI(ngraph::op::Or), &runtime::cpu::CPU_Emitter::emit<op::Or>},
{TI(ngraph::op::Xor), &runtime::cpu::CPU_Emitter::emit<op::Xor>},
{TI(ngraph::op::CPULeakyRelu), &runtime::cpu::CPU_Emitter::emit<op::CPULeakyRelu>},
{TI(ngraph::op::CompiledKernel), &runtime::cpu::CPU_Emitter::emit<op::CompiledKernel>},
{TI(ngraph::op::LRN), &runtime::cpu::CPU_Emitter::emit<ngraph::op::LRN>},
......@@ -567,6 +569,7 @@ void runtime::cpu::CPU_ExternalFunction::compile(ngraph::pass::PassConfig& pass_
#include "ngraph/runtime/reference/slice.hpp"
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/runtime/reference/topk.hpp"
#include "ngraph/runtime/reference/xor.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/state/rng_state.hpp"
#include "ngraph/strides.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>
#include "ngraph/runtime/cpu/cpu_executor.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace kernel
{
void logical_xor(void* input0, void* input1, void* output, size_t count, int arena)
{
Eigen::array<Eigen::Index, 1> out_dims, in_dims;
out_dims[0] = in_dims[0] = count;
Eigen::TensorMap<Eigen::Tensor<char, 1, Eigen::RowMajor>> out(
static_cast<char*>(output), out_dims);
Eigen::TensorMap<Eigen::Tensor<char, 1, Eigen::RowMajor>> in0(
static_cast<char*>(input0), in_dims);
Eigen::TensorMap<Eigen::Tensor<char, 1, Eigen::RowMajor>> in1(
static_cast<char*>(input1), in_dims);
out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) =
(in0 != in1).template cast<char>();
}
}
}
}
}
......@@ -51,6 +51,7 @@
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/passthrough.hpp"
#include "ngraph/op/product.hpp"
......@@ -67,6 +68,7 @@
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/generic_cpu/kernel/broadcast.hpp"
......@@ -154,6 +156,7 @@
#include "ngraph/runtime/reference/tan.hpp"
#include "ngraph/runtime/reference/tanh.hpp"
#include "ngraph/runtime/reference/topk.hpp"
#include "ngraph/runtime/reference/xor.hpp"
#include "ngraph/runtime/tensor.hpp"
#include "ngraph/state/rng_state.hpp"
......@@ -1607,6 +1610,15 @@ private:
}
break;
}
case OP_TYPEID::Xor:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_xor(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
case OP_TYPEID::DynBroadcast:
case OP_TYPEID::Transpose:
case OP_TYPEID::DynPad:
......
......@@ -129,6 +129,7 @@
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_ops.hpp"
#include "ngraph/runtime/gpu/gpu_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_kernel_emitters.hpp"
......@@ -1463,6 +1464,11 @@ std::string runtime::gpu::GPU_Emitter::emit_TopK(EMIT_ARGS)
return compiled_function->add_to_runtime(index, function_name, args, out);
}
std::string runtime::gpu::GPU_Emitter::emit_Xor(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_DynBroadcast(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
......
......@@ -215,3 +215,4 @@ send_recv
send_recv_ring
gelu_f32
gelu_f64
logical_xor
......@@ -2106,6 +2106,7 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::Tile:
case OP_TYPEID::Transpose:
case OP_TYPEID::Unsqueeze:
case OP_TYPEID::Xor:
default:
{
throw unsupported_op("Unsupported op '" + op->description() +
......
......@@ -112,6 +112,7 @@ send_recv
send_recv_ring
gelu_f32
gelu_f64
logical_xor
# Not supported quant ops
model_dequantize_linear_1d_zero_scale_int8
......
......@@ -154,6 +154,7 @@
#include "ngraph/runtime/reference/tan.hpp"
#include "ngraph/runtime/reference/tanh.hpp"
#include "ngraph/runtime/reference/topk.hpp"
#include "ngraph/runtime/reference/xor.hpp"
#include "ngraph/runtime/tensor.hpp"
#include "ngraph/state/rng_state.hpp"
......@@ -1616,6 +1617,15 @@ private:
}
break;
}
case OP_TYPEID::Xor:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_xor(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
case OP_TYPEID::DynBroadcast:
case OP_TYPEID::Transpose:
case OP_TYPEID::DynPad:
......
......@@ -17,6 +17,7 @@
#include "ngraph/op/and.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
namespace ngraph
......@@ -28,6 +29,7 @@ namespace ngraph
NGRAPH_PLAIDML_OP_CLASS(ImplAnd, OpImpl<op::And>);
NGRAPH_PLAIDML_OP_CLASS(ImplNot, OpImpl<op::Not>);
NGRAPH_PLAIDML_OP_CLASS(ImplOr, OpImpl<op::Or>);
NGRAPH_PLAIDML_OP_CLASS(ImplXor, OpImpl<op::Xor>);
}
}
}
......@@ -69,3 +71,16 @@ void ngraph::runtime::plaidml::ImplOr::Apply()
.add(builder::Elementwise{"C", "A ? A : B"})
.finalize());
}
// Xor performs a simple elementwise logical xor.
void ngraph::runtime::plaidml::ImplXor::Apply()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
.add(builder::Input{op_input(0), "A"})
.add(builder::Input{op_input(1), "B"})
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A ? (B ? 0 : A) : B"})
.finalize());
}
......@@ -27,6 +27,7 @@
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/passthrough.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/any_of.hpp"
......@@ -46,7 +47,8 @@ void ngraph::runtime::plaidml::pass::ExplicitLogicals::construct_logical_to_data
std::type_index{typeid(ngraph::op::LessEq)},
std::type_index{typeid(ngraph::op::Not)},
std::type_index{typeid(ngraph::op::NotEqual)},
std::type_index{typeid(ngraph::op::Or)}};
std::type_index{typeid(ngraph::op::Or)},
std::type_index{typeid(ngraph::op::Xor)}};
const ngraph::Node* node_ptr = node.get();
......@@ -62,7 +64,8 @@ void ngraph::runtime::plaidml::pass::ExplicitLogicals::construct_logical_to_data
std::type_index{typeid(ngraph::op::Equal)},
std::type_index{typeid(ngraph::op::Not)},
std::type_index{typeid(ngraph::op::NotEqual)},
std::type_index{typeid(ngraph::op::Or)}};
std::type_index{typeid(ngraph::op::Or)},
std::type_index{typeid(ngraph::op::Xor)}};
const ngraph::Node* node_ptr = node.get();
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstddef>
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T>
void logical_xor(const T* arg0, const T* arg1, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = static_cast<T>((arg0[i] || arg1[i]) && !(arg0[i] && arg1[i]));
}
}
}
}
}
......@@ -146,6 +146,7 @@
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/provenance.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
......@@ -1920,6 +1921,11 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::Unsqueeze>(args[0], args[1]);
break;
}
case OP_TYPEID::Xor:
{
node = make_shared<op::Xor>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
break;
}
case OP_TYPEID::UnknownOp:
{
stringstream ss;
......@@ -2912,6 +2918,15 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Unsqueeze: { break;
}
case OP_TYPEID::Xor:
{
auto tmp = dynamic_cast<const op::Xor*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["autob"] = write_auto_broadcast(tmp->get_autob());
}
break;
}
case OP_TYPEID::UnknownOp: { break;
}
}
......
......@@ -245,6 +245,7 @@ set(MULTI_TEST_SRC
backend/generate_mask.in.cpp
backend/logical_and.in.cpp
backend/logical_or.in.cpp
backend/logical_xor.in.cpp
backend/lrn.in.cpp
backend/max.in.cpp
backend/min.in.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"
using namespace std;
using namespace ngraph;
static string s_manifest = "${MANIFEST}";
NGRAPH_TEST(${BACKEND_NAME}, logical_xor)
{
Shape shape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::boolean, shape);
auto B = make_shared<op::Parameter>(element::boolean, shape);
auto f = make_shared<Function>(make_shared<op::Xor>(A, B), ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::boolean, shape);
copy_data(a, vector<char>{1, 0, 1, 1, 1, 0, 1, 0});
auto b = backend->create_tensor(element::boolean, shape);
copy_data(b, vector<char>{0, 0, 1, 0, 0, 1, 1, 0});
auto result = backend->create_tensor(element::boolean, shape);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, b});
EXPECT_EQ((vector<char>{1, 0, 0, 1, 1, 1, 0, 0}), read_vector<char>(result));
}
......@@ -822,6 +822,32 @@ TEST(constant_folding, const_or)
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_xor)
{
auto constant0 =
op::Constant::create(element::boolean, Shape{2, 3}, vector<int32_t>{0, 0, 1, 0, 1, 1});
auto constant1 =
op::Constant::create(element::boolean, Shape{2, 3}, vector<int32_t>{0, 1, 1, 1, 0, 1});
auto eq = make_shared<op::Xor>(constant0, constant1);
auto f = make_shared<Function>(eq, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Xor>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<char>();
vector<char> values_expected{0, 1, 0, 1, 1, 0};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_ceiling)
{
auto constant = op::Constant::create(
......
......@@ -1216,14 +1216,16 @@ TEST(cpu_test, constant_unary_binary)
auto less_eq = make_shared<op::LessEq>(g, h);
auto logical_and = make_shared<op::And>(i, j);
auto logical_or = make_shared<op::Or>(i, j);
auto logical_xor = make_shared<op::Xor>(i, j);
auto ceil = make_shared<op::Ceiling>(k);
auto floor = make_shared<op::Floor>(k);
auto logical_not = make_shared<op::Not>(j);
auto func = make_shared<Function>(
NodeVector{add, sub, mul, divn, min, max, absn, neg,
sqrt, relu, sign, equal, not_equal, greater, greater_eq, less,
less_eq, logical_and, logical_or, ceil, floor, logical_not},
NodeVector{add, sub, mul, divn, min, max,
absn, neg, sqrt, relu, sign, equal,
not_equal, greater, greater_eq, less, less_eq, logical_and,
logical_or, logical_xor, ceil, floor, logical_not},
ParameterVector{});
auto func_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{});
......@@ -1252,6 +1254,7 @@ TEST(cpu_test, constant_unary_binary)
ASSERT_EQ(count_ops_of_type<op::LessEq>(func), 0);
ASSERT_EQ(count_ops_of_type<op::And>(func), 0);
ASSERT_EQ(count_ops_of_type<op::Or>(func), 0);
ASSERT_EQ(count_ops_of_type<op::Xor>(func), 0);
ASSERT_EQ(count_ops_of_type<op::Ceiling>(func), 0);
ASSERT_EQ(count_ops_of_type<op::Floor>(func), 0);
ASSERT_EQ(count_ops_of_type<op::Not>(func), 0);
......@@ -1275,6 +1278,7 @@ TEST(cpu_test, constant_unary_binary)
vector<char> less_eq_expected{1, 1, 1, 0};
vector<char> and_expected{0, 0, 0, 1};
vector<char> or_expected{0, 1, 1, 1};
vector<char> xor_expected{0, 1, 1, 0};
vector<float> ceil_expected{0.0f, 0.0f, -1.0f, 3.0f};
vector<float> floor_expected{-1.0f, 0.0f, -2.0f, 2.0f};
vector<char> not_expected{1, 0, 1, 0};
......@@ -1298,11 +1302,12 @@ TEST(cpu_test, constant_unary_binary)
ASSERT_EQ(get_result_constant<char>(func, 16), less_eq_expected);
ASSERT_EQ(get_result_constant<char>(func, 17), and_expected);
ASSERT_EQ(get_result_constant<char>(func, 18), or_expected);
ASSERT_EQ(get_result_constant<char>(func, 19), xor_expected);
ASSERT_TRUE(test::all_close_f(
get_result_constant<float>(func, 19), ceil_expected, MIN_FLOAT_TOLERANCE_BITS));
get_result_constant<float>(func, 20), ceil_expected, MIN_FLOAT_TOLERANCE_BITS));
ASSERT_TRUE(test::all_close_f(
get_result_constant<float>(func, 20), floor_expected, MIN_FLOAT_TOLERANCE_BITS));
ASSERT_EQ(get_result_constant<char>(func, 21), not_expected);
get_result_constant<float>(func, 21), floor_expected, MIN_FLOAT_TOLERANCE_BITS));
ASSERT_EQ(get_result_constant<char>(func, 22), not_expected);
ASSERT_ANY_THROW(pass_manager.run_passes(func_error));
}
......
......@@ -209,6 +209,14 @@ TEST(type_prop, or_bad_arguments)
});
}
TEST(type_prop, xor_bad_arguments)
{
test_binary_logical(
"Xor", [](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Xor>(x, y);
});
}
template <typename T>
void test_binary_eltwise_numpy(const element::Type& et, const op::AutoBroadcastSpec& autob)
{
......@@ -242,6 +250,7 @@ TEST(type_prop, eltwise_auto_bcast)
test_binary_eltwise_numpy<op::Or>(element::boolean, op::AutoBroadcastType::NUMPY);
test_binary_eltwise_numpy<op::Power>(element::f32, op::AutoBroadcastType::NUMPY);
test_binary_eltwise_numpy<op::Subtract>(element::f32, op::AutoBroadcastType::NUMPY);
test_binary_eltwise_numpy<op::Xor>(element::boolean, op::AutoBroadcastType::NUMPY);
}
TEST(type_prop, comparison_good)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment