Commit 3e80ef25 authored by Adam Procter's avatar Adam Procter

De-Eigenize abs (missed that one)

parent 3d5f9f5e
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void abs(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
// TODO: generic "abs" doesn't work here for some reason.
out[i] = (arg[i] < 0 ? -arg[i] : arg[i]);
}
}
}
}
}
......@@ -70,7 +70,7 @@
#include "ngraph/ops/tuple.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/runtime/ngvm/eigen/abs.hpp"
#include "ngraph/runtime/ngvm/instruction/abs.hpp"
#include "ngraph/runtime/ngvm/instruction/acos.hpp"
#include "ngraph/runtime/ngvm/instruction/add.hpp"
#include "ngraph/runtime/ngvm/instruction/asin.hpp"
......@@ -230,51 +230,12 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
} \
}
#define DO_ON_SIGNED_NUMERIC_TYPE(et, err_msg, macro, ...) \
{ \
if (et == element::Float32::element_type()) \
{ \
macro(element::Float32, ##__VA_ARGS__); \
} \
else if (et == element::Int8::element_type()) \
{ \
macro(element::Int8, ##__VA_ARGS__); \
} \
else if (et == element::Int32::element_type()) \
{ \
macro(element::Int32, ##__VA_ARGS__); \
} \
else if (et == element::Int64::element_type()) \
{ \
macro(element::Int64, ##__VA_ARGS__); \
} \
else \
{ \
throw ngraph_error(err_msg); \
} \
}
#define REGISTER_INSTRUCTION(op_class, instr_class, ...) \
REGISTER_TO_OP_MAP(op_class) \
{ \
ef->get_instructions()->push_back(make_shared<instr_class>(__VA_ARGS__)); \
}
#define M_REGISTER_SIGNED_NUMERIC_UNOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], out[0]));
#define REGISTER_SIGNED_NUMERIC_UNOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \
{ \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \
n->get_arguments().at(0)->get_value_type())) \
->get_element_type(); \
DO_ON_SIGNED_NUMERIC_TYPE( \
et, \
"Internal error: signed numeric unop has unhandled element type", \
M_REGISTER_SIGNED_NUMERIC_UNOP, \
instr_class); \
}
#define M_REGISTER_NUMERIC_UNOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], out[0]));
#define REGISTER_NUMERIC_UNOP(op_class, instr_class) \
......@@ -388,6 +349,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
static OpMap op_map;
if (!initialized)
{
REGISTER_NUMERIC_UNOP(op::Abs, instruction::AbsInstruction);
REGISTER_NUMERIC_UNOP(op::Acos, instruction::AcosInstruction);
REGISTER_NUMERIC_UNOP(op::Asin, instruction::AsinInstruction);
REGISTER_NUMERIC_UNOP(op::Atan, instruction::AtanInstruction);
......@@ -405,8 +367,6 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
REGISTER_NUMERIC_UNOP(op::Tan, instruction::TanInstruction);
REGISTER_NUMERIC_UNOP(op::Tanh, instruction::TanhInstruction);
REGISTER_SIGNED_NUMERIC_UNOP(op::Abs, eigen::AbsInstruction);
REGISTER_NUMERIC_BINOP(op::Add, instruction::AddInstruction);
REGISTER_NUMERIC_BINOP(op::Divide, instruction::DivideInstruction);
REGISTER_NUMERIC_BINOP(op::Maximum, instruction::MaximumInstruction);
......
......@@ -14,11 +14,11 @@
#pragma once
#include "ngraph/runtime/kernel/abs.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph
{
......@@ -26,13 +26,14 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class AbsInstruction : public Instruction
{
public:
AbsInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
AbsInstruction(const TensorViewInfo& arg,
const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
......@@ -40,8 +41,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) =
Eigen::abs(EigenArray1d<ET>(call_frame, m_arg));
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::abs<typename ET::type>(arg, out, count);
}
protected:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment