Commit 3e80ef25 authored by Adam Procter's avatar Adam Procter

De-Eigenize abs (missed that one)

parent 3d5f9f5e
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void abs(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
// TODO: generic "abs" doesn't work here for some reason.
out[i] = (arg[i] < 0 ? -arg[i] : arg[i]);
}
}
}
}
}
...@@ -70,7 +70,7 @@ ...@@ -70,7 +70,7 @@
#include "ngraph/ops/tuple.hpp" #include "ngraph/ops/tuple.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/topological_sort.hpp" #include "ngraph/pass/topological_sort.hpp"
#include "ngraph/runtime/ngvm/eigen/abs.hpp" #include "ngraph/runtime/ngvm/instruction/abs.hpp"
#include "ngraph/runtime/ngvm/instruction/acos.hpp" #include "ngraph/runtime/ngvm/instruction/acos.hpp"
#include "ngraph/runtime/ngvm/instruction/add.hpp" #include "ngraph/runtime/ngvm/instruction/add.hpp"
#include "ngraph/runtime/ngvm/instruction/asin.hpp" #include "ngraph/runtime/ngvm/instruction/asin.hpp"
...@@ -230,51 +230,12 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func ...@@ -230,51 +230,12 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
} \ } \
} }
#define DO_ON_SIGNED_NUMERIC_TYPE(et, err_msg, macro, ...) \
{ \
if (et == element::Float32::element_type()) \
{ \
macro(element::Float32, ##__VA_ARGS__); \
} \
else if (et == element::Int8::element_type()) \
{ \
macro(element::Int8, ##__VA_ARGS__); \
} \
else if (et == element::Int32::element_type()) \
{ \
macro(element::Int32, ##__VA_ARGS__); \
} \
else if (et == element::Int64::element_type()) \
{ \
macro(element::Int64, ##__VA_ARGS__); \
} \
else \
{ \
throw ngraph_error(err_msg); \
} \
}
#define REGISTER_INSTRUCTION(op_class, instr_class, ...) \ #define REGISTER_INSTRUCTION(op_class, instr_class, ...) \
REGISTER_TO_OP_MAP(op_class) \ REGISTER_TO_OP_MAP(op_class) \
{ \ { \
ef->get_instructions()->push_back(make_shared<instr_class>(__VA_ARGS__)); \ ef->get_instructions()->push_back(make_shared<instr_class>(__VA_ARGS__)); \
} }
#define M_REGISTER_SIGNED_NUMERIC_UNOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], out[0]));
#define REGISTER_SIGNED_NUMERIC_UNOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \
{ \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \
n->get_arguments().at(0)->get_value_type())) \
->get_element_type(); \
DO_ON_SIGNED_NUMERIC_TYPE( \
et, \
"Internal error: signed numeric unop has unhandled element type", \
M_REGISTER_SIGNED_NUMERIC_UNOP, \
instr_class); \
}
#define M_REGISTER_NUMERIC_UNOP(T, instr_class) \ #define M_REGISTER_NUMERIC_UNOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], out[0])); ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], out[0]));
#define REGISTER_NUMERIC_UNOP(op_class, instr_class) \ #define REGISTER_NUMERIC_UNOP(op_class, instr_class) \
...@@ -388,6 +349,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -388,6 +349,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
static OpMap op_map; static OpMap op_map;
if (!initialized) if (!initialized)
{ {
REGISTER_NUMERIC_UNOP(op::Abs, instruction::AbsInstruction);
REGISTER_NUMERIC_UNOP(op::Acos, instruction::AcosInstruction); REGISTER_NUMERIC_UNOP(op::Acos, instruction::AcosInstruction);
REGISTER_NUMERIC_UNOP(op::Asin, instruction::AsinInstruction); REGISTER_NUMERIC_UNOP(op::Asin, instruction::AsinInstruction);
REGISTER_NUMERIC_UNOP(op::Atan, instruction::AtanInstruction); REGISTER_NUMERIC_UNOP(op::Atan, instruction::AtanInstruction);
...@@ -405,8 +367,6 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -405,8 +367,6 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
REGISTER_NUMERIC_UNOP(op::Tan, instruction::TanInstruction); REGISTER_NUMERIC_UNOP(op::Tan, instruction::TanInstruction);
REGISTER_NUMERIC_UNOP(op::Tanh, instruction::TanhInstruction); REGISTER_NUMERIC_UNOP(op::Tanh, instruction::TanhInstruction);
REGISTER_SIGNED_NUMERIC_UNOP(op::Abs, eigen::AbsInstruction);
REGISTER_NUMERIC_BINOP(op::Add, instruction::AddInstruction); REGISTER_NUMERIC_BINOP(op::Add, instruction::AddInstruction);
REGISTER_NUMERIC_BINOP(op::Divide, instruction::DivideInstruction); REGISTER_NUMERIC_BINOP(op::Divide, instruction::DivideInstruction);
REGISTER_NUMERIC_BINOP(op::Maximum, instruction::MaximumInstruction); REGISTER_NUMERIC_BINOP(op::Maximum, instruction::MaximumInstruction);
......
...@@ -14,11 +14,11 @@ ...@@ -14,11 +14,11 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/abs.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp" #include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -26,13 +26,14 @@ namespace ngraph ...@@ -26,13 +26,14 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class AbsInstruction : public Instruction class AbsInstruction : public Instruction
{ {
public: public:
AbsInstruction(const TensorViewInfo& arg, const TensorViewInfo& out) AbsInstruction(const TensorViewInfo& arg,
const TensorViewInfo& out)
: m_arg(arg) : m_arg(arg)
, m_out(out) , m_out(out)
{ {
...@@ -40,8 +41,12 @@ namespace ngraph ...@@ -40,8 +41,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET>(call_frame, m_out) = typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
Eigen::abs(EigenArray1d<ET>(call_frame, m_arg)); typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::abs<typename ET::type>(arg, out, count);
} }
protected: protected:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment