Commit 7f5075af authored by Adam Procter

De-Eigenize copy instruction (technically it was never Eigenized, but whatever)

parent f0ce2244
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#pragma once

#include <cstddef>
#include <cstring>

namespace ngraph
{
    namespace runtime
    {
        namespace kernel
        {
            // Copies `count` elements of type T from arg to out.
            template <typename T>
            void copy(T* arg, T* out, size_t count)
            {
                memcpy(out, arg, sizeof(T) * count);
            }
        }
    }
}
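As a quick, standalone illustration (not part of this commit), kernel::copy can be exercised directly on raw host buffers; the vectors below are hypothetical test data:

#include <cassert>
#include <vector>
#include "ngraph/runtime/kernel/copy.hpp"

int main()
{
    // Hypothetical buffers: copy three floats from src into dst.
    std::vector<float> src{1.0f, 2.0f, 3.0f};
    std::vector<float> dst(src.size());
    ngraph::runtime::kernel::copy<float>(src.data(), dst.data(), src.size());
    assert(dst[1] == 2.0f);
    return 0;
}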
@@ -84,7 +84,8 @@
 #include "ngraph/runtime/ngvm/eigen/concat_vector.hpp"
 #include "ngraph/runtime/ngvm/eigen/constant.hpp"
 #include "ngraph/runtime/ngvm/eigen/convert.hpp"
-#include "ngraph/runtime/ngvm/eigen/copy.hpp"
+#include "ngraph/runtime/ngvm/instruction/copy.hpp"
+#include "ngraph/runtime/ngvm/instruction/copy_by_index.hpp"
 #include "ngraph/runtime/ngvm/instruction/cos.hpp"
 #include "ngraph/runtime/ngvm/instruction/cosh.hpp"
 #include "ngraph/runtime/ngvm/instruction/divide.hpp"
@@ -432,9 +433,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
 {
 PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
     "Broadcast has unhandled element type",
-    eigen::CopyInstruction,
-    in[0].get_index(),
-    out[0].get_index());
+    instruction::CopyInstruction,
+    in[0],
+    out[0]);
 }
 else if (arg_shape.size() == 0)
 {
@@ -657,9 +658,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
 PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
     "GetTupleElement has unhandled element type",
-    eigen::CopyInstruction,
-    in.at(get_tuple_element->get_n()).get_index(),
-    out.at(0).get_index());
+    instruction::CopyInstruction,
+    in[get_tuple_element->get_n()],
+    out[0]);
 };
 // Tuple will be spliced out, with the users of out connected to the corresponding in's source, but, for now, we need to copy.
@@ -670,9 +671,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
 auto& et = in.at(i).get_tensor_view_layout()->get_element_type();
 PUSH_POLYMORPHIC_INSTRUCTION(et,
     "Tuple has unhandled element type",
-    eigen::CopyInstruction,
-    in.at(i).get_index(),
-    out.at(i).get_index());
+    instruction::CopyInstruction,
+    in[i],
+    out[i]);
 }
 };
@@ -738,9 +739,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
 {
 PUSH_POLYMORPHIC_INSTRUCTION(f_result_element_type,
     "Reduce has unhandled element type",
-    runtime::ngvm::eigen::CopyInstruction,
-    in.at(0).get_index(),
-    out.at(0).get_index());
+    runtime::ngvm::instruction::CopyInstruction,
+    in[0],
+    out[0]);
 }
 // Behavior for zero-size axes bears some explanation here. XLA's reduce
 // operator provides an "base" element (usually, but not necessarily,
@@ -775,9 +776,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
 {
 PUSH_POLYMORPHIC_INSTRUCTION(f_result_element_type,
     "Reduce has unhandled element type",
-    runtime::ngvm::eigen::CopyInstruction,
-    in.at(1).get_index(),
-    out.at(0).get_index());
+    runtime::ngvm::instruction::CopyInstruction,
+    in[1],
+    out[0]);
 }
 else
 {
@@ -862,9 +863,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
 {
 PUSH_POLYMORPHIC_INSTRUCTION(s_element_type,
     "Sum has unhandled element type",
-    runtime::ngvm::eigen::CopyInstruction,
-    in.at(0).get_index(),
-    out.at(0).get_index());
+    runtime::ngvm::instruction::CopyInstruction,
+    in[0],
+    out[0]);
 }
 // Full reduction? Then sum to scalar.
 else if ((arg_rank == 1 && reduction_axes == AxisSet{0}) ||
@@ -929,9 +930,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
 {
 PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
     "Reshape has unhandled element type",
-    runtime::ngvm::eigen::CopyInstruction,
-    in.at(0).get_index(),
-    out.at(0).get_index());
+    runtime::ngvm::instruction::CopyInstruction,
+    in[0],
+    out[0]);
 }
 // If there *is* a layout change in the 2D case, we transpose the input.
 else if (arg_rank == 2)
@@ -978,9 +979,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
 {
 PUSH_POLYMORPHIC_INSTRUCTION(arg_element_type,
     "Slice has unhandled element type",
-    runtime::ngvm::eigen::CopyInstruction,
-    in.at(0).get_index(),
-    out.at(0).get_index());
+    runtime::ngvm::instruction::CopyInstruction,
+    in[0],
+    out[0]);
 }
 else if (arg_rank == 1)
 {
@@ -1045,9 +1046,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
 {
 PUSH_POLYMORPHIC_INSTRUCTION(arg0_element_type,
     "Replace-slice has unhandled element type",
-    runtime::ngvm::eigen::CopyInstruction,
-    in.at(1).get_index(),
-    out.at(0).get_index());
+    runtime::ngvm::instruction::CopyInstruction,
+    in[1],
+    out[0]);
 }
 else if (arg0_rank == 1)
 {
@@ -1142,9 +1143,10 @@ void ExternalFunction::compile(FunctionMap& function_map)
 assert(nullptr != result_tensor_type);
 auto& result_element_type = result_tensor_type->get_element_type();
 auto ef = this;
+// TODO: This is the one case where we can't use the new CopyInstruction that takes in a TensorViewInfo. (At least, I can't figure out how to do it.)
 PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
     "Copy has unhandled element type",
-    eigen::CopyInstruction,
+    instruction::CopyByIndexInstruction,
     prev_index_it->second,
     index);
 }
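Taken together, the hunks above follow one pattern: every copy emitted while lowering an op now passes the TensorViewInfo objects themselves rather than their call-frame indices. The one exception is the temporary-copy case flagged by the TODO, which only has indices available and therefore uses the index-based CopyByIndexInstruction. A condensed before/after of the common pattern, using the names exactly as they appear in the hunks above:

// Before: index-based, Eigen-era instruction
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
    "Broadcast has unhandled element type",
    eigen::CopyInstruction,
    in[0].get_index(),
    out[0].get_index());

// After: TensorViewInfo-based instruction
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
    "Broadcast has unhandled element type",
    instruction::CopyInstruction,
    in[0],
    out[0]);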
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#pragma once

#include "ngraph/runtime/kernel/copy.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"

namespace ngraph
{
    namespace runtime
    {
        namespace ngvm
        {
            namespace instruction
            {
                /// @brief Copies a tensor from arg to out.
                template <typename ET>
                class CopyInstruction : public Instruction
                {
                public:
                    CopyInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
                        : m_arg(arg)
                        , m_out(out)
                    {
                    }

                    virtual void execute(CallFrame& call_frame) const override
                    {
                        typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
                        typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
                        size_t count = get_tensor_element_count(call_frame, m_arg);
                        kernel::copy<typename ET::type>(arg, out, count);
                    }

                protected:
                    TensorViewInfo m_arg;
                    TensorViewInfo m_out;
                };
            }
        }
    }
}
@@ -27,16 +27,16 @@ namespace ngraph
 {
 namespace ngvm
 {
-namespace eigen
+namespace instruction
 {
 /// @brief Copies a tensor from in to out.
 template <typename ET>
-class CopyInstruction : public Instruction
+class CopyByIndexInstruction : public Instruction
 {
 public:
 /// @param in Index of input tensor in call frame.
 /// @param out Index of output tensor in call frame.
-CopyInstruction(size_t in, size_t out)
+CopyByIndexInstruction(size_t in, size_t out)
 : m_in(in)
 , m_out(out)
 {