Commit cf7a4384 authored by Adam Procter's avatar Adam Procter Committed by GitHub

Implement greater-than, less-or-equal, greater-or-equal (#164)

parent ae085010
......@@ -63,7 +63,9 @@
#include "ngraph/ops/function_call.hpp"
#include "ngraph/ops/get_tuple_element.hpp"
#include "ngraph/ops/greater.hpp"
#include "ngraph/ops/greater_eq.hpp"
#include "ngraph/ops/less.hpp"
#include "ngraph/ops/less_eq.hpp"
#include "ngraph/ops/log.hpp"
#include "ngraph/ops/maximum.hpp"
#include "ngraph/ops/minimum.hpp"
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class GreaterEq : public BinaryElementwiseComparison
{
public:
GreaterEq(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseComparison(arg0, arg1)
{
}
virtual std::string description() const override { return "GreaterEq"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class LessEq : public BinaryElementwiseComparison
{
public:
LessEq(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseComparison(arg0, arg1)
{
}
virtual std::string description() const override { return "LessEq"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename TI, typename TO>
void greater_eq(TI arg0, TI arg1, TO out)
{
auto result_as_float = get_map_array(&*arg0) <= get_map_array(&*arg1);
auto result_as_char = result_as_float.template cast<char>();
set_map_array(&*out, result_as_char);
}
template <typename ET>
class GreaterEqInstruction : public Instruction
{
public:
GreaterEqInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) >=
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename ET>
class GreaterThanInstruction : public Instruction
{
public:
GreaterThanInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) >
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename ET>
class LessEqInstruction : public Instruction
{
public:
LessEqInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) <=
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
......@@ -25,14 +25,6 @@ namespace ngraph
{
namespace eigen
{
template <typename TI, typename TO>
void less_than(TI arg0, TI arg1, TO out)
{
auto result_as_float = get_map_array(&*arg0) < get_map_array(&*arg1);
auto result_as_char = result_as_float.template cast<char>();
set_map_array(&*out, result_as_char);
}
template <typename ET>
class LessThanInstruction : public Instruction
{
......
......@@ -33,7 +33,10 @@
#include "ngraph/ops/equal.hpp"
#include "ngraph/ops/function_call.hpp"
#include "ngraph/ops/get_tuple_element.hpp"
#include "ngraph/ops/greater.hpp"
#include "ngraph/ops/greater_eq.hpp"
#include "ngraph/ops/less.hpp"
#include "ngraph/ops/less_eq.hpp"
#include "ngraph/ops/log.hpp"
#include "ngraph/ops/maximum.hpp"
#include "ngraph/ops/multiply.hpp"
......@@ -60,6 +63,9 @@
#include "ngraph/runtime/eigen/divide.hpp"
#include "ngraph/runtime/eigen/dot.hpp"
#include "ngraph/runtime/eigen/equal.hpp"
#include "ngraph/runtime/eigen/greater_eq.hpp"
#include "ngraph/runtime/eigen/greater_than.hpp"
#include "ngraph/runtime/eigen/less_eq.hpp"
#include "ngraph/runtime/eigen/less_than.hpp"
#include "ngraph/runtime/eigen/log.hpp"
#include "ngraph/runtime/eigen/matrix_mult.hpp"
......@@ -121,7 +127,10 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
REGISTER_BINOP(op::Add, runtime::eigen::AddInstruction<element::Float32>);
REGISTER_BINOP(op::Divide, runtime::eigen::DivideInstruction<element::Float32>);
REGISTER_BINOP(op::Equal, runtime::eigen::EqualInstruction<element::Float32>);
REGISTER_BINOP(op::Greater, runtime::eigen::GreaterThanInstruction<element::Float32>);
REGISTER_BINOP(op::GreaterEq, runtime::eigen::GreaterEqInstruction<element::Float32>);
REGISTER_BINOP(op::Less, runtime::eigen::LessThanInstruction<element::Float32>);
REGISTER_BINOP(op::LessEq, runtime::eigen::LessEqInstruction<element::Float32>);
REGISTER_UNOP(op::Log, runtime::eigen::LogInstruction<element::Float32>);
REGISTER_BINOP(op::Maximum, runtime::eigen::MaximumInstruction<element::Float32>);
REGISTER_BINOP(op::Multiply, runtime::eigen::MultiplyInstruction<element::Float32>);
......
......@@ -417,7 +417,51 @@ TEST(execute, test_dot_matrix_vector)
ASSERT_EQ((vector<float>{190, 486, 782, 1078}), result->get_vector());
}
TEST(execute, test_lessthan)
TEST(execute, test_greater)
{
auto shape = Shape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto B = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt = make_shared<TensorViewType>(element::Bool::element_type(), shape);
auto f = make_shared<Function>(make_shared<op::Greater>(A, B), rt, op::Parameters{A, B});
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Float32>(shape);
*a = vector<float>{1, 8, -8, 17, -0.5, 0.5, 2, 1};
auto b = ngraph::runtime::make_tensor<element::Float32>(shape);
*b = vector<float>{1, 2, 4, 8, 0, 0, 1, 1.5};
auto result = ngraph::runtime::make_tensor<element::Bool>(shape);
(*cf)({a, b}, {result});
ASSERT_EQ((vector<char>{0, 1, 0, 1, 0, 1, 1, 0}), result->get_vector());
}
TEST(execute, test_greatereq)
{
auto shape = Shape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto B = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt = make_shared<TensorViewType>(element::Bool::element_type(), shape);
auto f = make_shared<Function>(make_shared<op::GreaterEq>(A, B), rt, op::Parameters{A, B});
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Float32>(shape);
*a = vector<float>{1, 8, -8, 17, -0.5, 0, 2, 1};
auto b = ngraph::runtime::make_tensor<element::Float32>(shape);
*b = vector<float>{1, 2, -8, 8, 0, 0, 0.5, 1.5};
auto result = ngraph::runtime::make_tensor<element::Bool>(shape);
(*cf)({a, b}, {result});
ASSERT_EQ((vector<char>{1, 1, 1, 1, 0, 1, 1, 0}), result->get_vector());
}
TEST(execute, test_less)
{
auto shape = Shape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
......@@ -439,6 +483,28 @@ TEST(execute, test_lessthan)
ASSERT_EQ((vector<char>{0, 0, 1, 0, 1, 0, 0, 1}), result->get_vector());
}
TEST(execute, test_lesseq)
{
auto shape = Shape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto B = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto rt = make_shared<TensorViewType>(element::Bool::element_type(), shape);
auto f = make_shared<Function>(make_shared<op::LessEq>(A, B), rt, op::Parameters{A, B});
auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
auto cf = external->make_call_frame();
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Float32>(shape);
*a = vector<float>{1, 8, -8, 17, -0.5, 0, 2, 1};
auto b = ngraph::runtime::make_tensor<element::Float32>(shape);
*b = vector<float>{1, 2, -8, 8, 0, 0, 0.5, 1.5};
auto result = ngraph::runtime::make_tensor<element::Bool>(shape);
(*cf)({a, b}, {result});
ASSERT_EQ((vector<char>{1, 0, 1, 0, 1, 1, 0, 1}), result->get_vector());
}
TEST(execute, test_log)
{
auto shape = Shape{2, 2, 2};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment