Commit 17e59a8a authored by Adam Procter's avatar Adam Procter

De-Eigenize constant and convert

parent 7f5075af
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename TI,typename TO>
void convert(TI* arg, TO* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = TO(arg[i]);
}
}
}
}
}
......@@ -82,8 +82,8 @@
#include "ngraph/runtime/ngvm/instruction/ceiling.hpp"
#include "ngraph/runtime/ngvm/eigen/concat_matrix.hpp"
#include "ngraph/runtime/ngvm/eigen/concat_vector.hpp"
#include "ngraph/runtime/ngvm/eigen/constant.hpp"
#include "ngraph/runtime/ngvm/eigen/convert.hpp"
#include "ngraph/runtime/ngvm/instruction/constant.hpp"
#include "ngraph/runtime/ngvm/instruction/convert.hpp"
#include "ngraph/runtime/ngvm/instruction/copy.hpp"
#include "ngraph/runtime/ngvm/instruction/copy_by_index.hpp"
#include "ngraph/runtime/ngvm/instruction/cos.hpp"
......@@ -325,7 +325,7 @@ std::vector<typename ET::type>
{ \
REGISTER_INSTRUCTION( \
op::ParameterizedConstant<T>, \
eigen::ConstantInstruction<T>, \
instruction::ConstantInstruction<T>, \
std::vector<T::type>{ \
get_vector<T>(dynamic_cast<const op::ParameterizedConstant<T>*>(n)->get_value())}, \
out[0]); \
......@@ -385,7 +385,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
auto c_value_strings = c->get_value_strings();
#define M_REGISTER_POLYMORPHIC_CONSTANT(ET) \
ef->get_instructions()->push_back(make_shared<eigen::ConstantInstruction<ET>>( \
ef->get_instructions()->push_back(make_shared<instruction::ConstantInstruction<ET>>( \
parse_string<typename ET::type>(c_value_strings), out[0]));
DO_ON_ELEMENT_TYPE(c_element_type,
......@@ -532,7 +532,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
result_element_type == (TO::element_type())) \
{ \
ef->get_instructions()->push_back( \
make_shared<eigen::ConvertInstruction<TI, TO>>(in[0], out[0])); \
make_shared<instruction::ConvertInstruction<TI, TO>>(in[0], out[0])); \
}
// End hacky macro
......
......@@ -15,7 +15,6 @@
#pragma once
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
......@@ -26,7 +25,7 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class ConstantInstruction : public Instruction
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/kernel/convert.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace instruction
{
template <typename ETI,typename ETO>
class ConvertInstruction : public Instruction
{
public:
ConvertInstruction(const TensorViewInfo& arg,
const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
typename ETI::type* arg = get_tensor_data_ptr<ETI>(call_frame, m_arg);
typename ETO::type* out = get_tensor_data_ptr<ETO>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::convert<typename ETI::type,typename ETO::type>(arg, out, count);
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -17,7 +17,6 @@
#include <cassert>
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment