Commit 7f5075af authored by Adam Procter's avatar Adam Procter

De-Eigenize copy instruction (technically it was never Eigenized, but whatever)

parent f0ce2244
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void copy(T* arg, T* out, size_t count)
{
memcpy(out,arg,sizeof(T) * count);
}
}
}
}
......@@ -84,7 +84,8 @@
#include "ngraph/runtime/ngvm/eigen/concat_vector.hpp"
#include "ngraph/runtime/ngvm/eigen/constant.hpp"
#include "ngraph/runtime/ngvm/eigen/convert.hpp"
#include "ngraph/runtime/ngvm/eigen/copy.hpp"
#include "ngraph/runtime/ngvm/instruction/copy.hpp"
#include "ngraph/runtime/ngvm/instruction/copy_by_index.hpp"
#include "ngraph/runtime/ngvm/instruction/cos.hpp"
#include "ngraph/runtime/ngvm/instruction/cosh.hpp"
#include "ngraph/runtime/ngvm/instruction/divide.hpp"
......@@ -432,9 +433,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Broadcast has unhandled element type",
eigen::CopyInstruction,
in[0].get_index(),
out[0].get_index());
instruction::CopyInstruction,
in[0],
out[0]);
}
else if (arg_shape.size() == 0)
{
......@@ -657,9 +658,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"GetTupleElement has unhandled element type",
eigen::CopyInstruction,
in.at(get_tuple_element->get_n()).get_index(),
out.at(0).get_index());
instruction::CopyInstruction,
in[get_tuple_element->get_n()],
out[0]);
};
// Tuple will be spliced out, with the users of out connected to the corresponding in's source, but, for now, we need to copy.
......@@ -670,9 +671,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
auto& et = in.at(i).get_tensor_view_layout()->get_element_type();
PUSH_POLYMORPHIC_INSTRUCTION(et,
"Tuple has unhandled element type",
eigen::CopyInstruction,
in.at(i).get_index(),
out.at(i).get_index());
instruction::CopyInstruction,
in[i],
out[i]);
}
};
......@@ -738,9 +739,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{
PUSH_POLYMORPHIC_INSTRUCTION(f_result_element_type,
"Reduce has unhandled element type",
runtime::ngvm::eigen::CopyInstruction,
in.at(0).get_index(),
out.at(0).get_index());
runtime::ngvm::instruction::CopyInstruction,
in[0],
out[0]);
}
// Behavior for zero-size axes bears some explanation here. XLA's reduce
// operator provides an "base" element (usually, but not necessarily,
......@@ -775,9 +776,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{
PUSH_POLYMORPHIC_INSTRUCTION(f_result_element_type,
"Reduce has unhandled element type",
runtime::ngvm::eigen::CopyInstruction,
in.at(1).get_index(),
out.at(0).get_index());
runtime::ngvm::instruction::CopyInstruction,
in[1],
out[0]);
}
else
{
......@@ -862,9 +863,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{
PUSH_POLYMORPHIC_INSTRUCTION(s_element_type,
"Sum has unhandled element type",
runtime::ngvm::eigen::CopyInstruction,
in.at(0).get_index(),
out.at(0).get_index());
runtime::ngvm::instruction::CopyInstruction,
in[0],
out[0]);
}
// Full reduction? Then sum to scalar.
else if ((arg_rank == 1 && reduction_axes == AxisSet{0}) ||
......@@ -929,9 +930,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Reshape has unhandled element type",
runtime::ngvm::eigen::CopyInstruction,
in.at(0).get_index(),
out.at(0).get_index());
runtime::ngvm::instruction::CopyInstruction,
in[0],
out[0]);
}
// If there *is* a layout change in the 2D case, we transpose the input.
else if (arg_rank == 2)
......@@ -978,9 +979,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{
PUSH_POLYMORPHIC_INSTRUCTION(arg_element_type,
"Slice has unhandled element type",
runtime::ngvm::eigen::CopyInstruction,
in.at(0).get_index(),
out.at(0).get_index());
runtime::ngvm::instruction::CopyInstruction,
in[0],
out[0]);
}
else if (arg_rank == 1)
{
......@@ -1045,9 +1046,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{
PUSH_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Replace-slice has unhandled element type",
runtime::ngvm::eigen::CopyInstruction,
in.at(1).get_index(),
out.at(0).get_index());
runtime::ngvm::instruction::CopyInstruction,
in[1],
out[0]);
}
else if (arg0_rank == 1)
{
......@@ -1142,9 +1143,10 @@ void ExternalFunction::compile(FunctionMap& function_map)
assert(nullptr != result_tensor_type);
auto& result_element_type = result_tensor_type->get_element_type();
auto ef = this;
// TODO: This is the one case where we can't use the new CopyInstruction that takes in a TensorViewInfo. (At least, I can't figure out how to do it.)
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Copy has unhandled element type",
eigen::CopyInstruction,
instruction::CopyByIndexInstruction,
prev_index_it->second,
index);
}
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/kernel/copy.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace instruction
{
template <typename ET>
class CopyInstruction : public Instruction
{
public:
CopyInstruction(const TensorViewInfo& arg,
const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::copy<typename ET::type>(arg, out, count);
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -27,16 +27,16 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
/// @brief Copies a tensor from in to out.
template <typename ET>
class CopyInstruction : public Instruction
class CopyByIndexInstruction : public Instruction
{
public:
/// @param in Index of input tensor in call frame.
/// @param out Index of output tensor in call frame.
CopyInstruction(size_t in, size_t out)
CopyByIndexInstruction(size_t in, size_t out)
: m_in(in)
, m_out(out)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment