Commit 83433ef2 authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

Implement 'not' in interpreter and CPU; add unit tests for same (#321)

* Implement 'not' in interpreter and CPU; add unit tests for same

* Fix compile failure on CI
parent 4c52d420
......@@ -1357,6 +1357,15 @@ void runtime::cpu::CPU_Emitter::EmitConvolution(const ngraph::Node* n,
<< "});\n";
}
void runtime::cpu::CPU_Emitter::EmitNot(const ngraph::Node* n,
const vector<runtime::cpu::TensorViewWrapper>& args,
const vector<runtime::cpu::TensorViewWrapper>& out)
{
m_out << "kernel::logical_not(" << args[0].get_name() << ",\n"
<< " " << out[0].get_name() << ",\n"
<< " " << out[0].get_size() << ");\n";
}
//------------------------------------------------------------------------------------------------
// Utility methods
//------------------------------------------------------------------------------------------------
......
......@@ -95,6 +95,7 @@ namespace ngraph
void EMITTER_DECL(EmitCeiling);
void EMITTER_DECL(EmitSqrt);
void EMITTER_DECL(EmitConvolution);
void EMITTER_DECL(EmitNot);
private:
void generate_call(const std::vector<TensorViewWrapper>& args,
......
......@@ -58,6 +58,7 @@
#include "ngraph/ops/minimum.hpp"
#include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/negative.hpp"
#include "ngraph/ops/not.hpp"
#include "ngraph/ops/not_equal.hpp"
#include "ngraph/ops/one_hot.hpp"
#include "ngraph/ops/power.hpp"
......@@ -152,6 +153,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Ceiling), &runtime::cpu::CPU_Emitter::EmitCeiling},
{TI(ngraph::op::Sqrt), &runtime::cpu::CPU_Emitter::EmitSqrt},
{TI(ngraph::op::Convolution), &runtime::cpu::CPU_Emitter::EmitConvolution},
{TI(ngraph::op::Not), &runtime::cpu::CPU_Emitter::EmitNot},
};
runtime::cpu::CPU_ExternalFunction::CPU_ExternalFunction(
......@@ -197,6 +199,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
#include "ngraph/runtime/kernel/concat.hpp"
#include "ngraph/runtime/kernel/convolution.hpp"
#include "ngraph/runtime/kernel/dot.hpp"
#include "ngraph/runtime/kernel/not.hpp"
#include "ngraph/runtime/kernel/one_hot.hpp"
#include "ngraph/runtime/kernel/reduce.hpp"
#include "ngraph/runtime/kernel/replace_slice.hpp"
......
......@@ -63,6 +63,7 @@
#include "ngraph/runtime/kernel/minimum.hpp"
#include "ngraph/runtime/kernel/multiply.hpp"
#include "ngraph/runtime/kernel/negate.hpp"
#include "ngraph/runtime/kernel/not.hpp"
#include "ngraph/runtime/kernel/not_equal.hpp"
#include "ngraph/runtime/kernel/one_hot.hpp"
#include "ngraph/runtime/kernel/power.hpp"
......@@ -406,6 +407,12 @@ private:
reinterpret_cast<T*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "Not")
{
kernel::logical_not(reinterpret_cast<char*>(args[0]->get_data_ptr()),
reinterpret_cast<char*>(out[0]->get_data_ptr()),
out[0]->get_element_count());
}
else if (node_op == "NotEqual")
{
kernel::not_equal<T>(reinterpret_cast<T*>(args[0]->get_data_ptr()),
......
......@@ -20,7 +20,8 @@ namespace ngraph
{
namespace kernel
{
void logical_not(char* arg,
static inline void
logical_not(char* arg,
char* out,
size_t count) // TODO: using char for bool, is this right?
{
......
......@@ -4463,3 +4463,24 @@ TEST(${BACKEND_NAME}, DISABLED_parameter_to_output)
cf->call({a}, {result});
EXPECT_EQ((vector<float>{1, -2, 0, -4.8f}), result->get_vector<float>());
}
TEST(${BACKEND_NAME}, not)
{
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::boolean, shape);
auto result_type = make_shared<TensorViewType>(element::boolean, shape);
auto f = make_shared<Function>(make_shared<op::Not>(A), result_type, op::Parameters{A});
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = backend->make_primary_tensor_view(element::boolean, shape);
copy_data(a, vector<char>{1, 0, 2, 0});
auto result = backend->make_primary_tensor_view(element::boolean, shape);
cf->call({a}, {result});
EXPECT_EQ((vector<char>{0, 1, 0, 1}), result->get_vector<char>());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment