Commit 59e119bf authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU Direct Execution: Implement ReverseSequence

parent cb336bce
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/kernel/reverse_sequence.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::ReverseSequence)
{
auto rev_seq = static_cast<const ngraph::op::ReverseSequence*>(node);
auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg_tensor = tensor_data[args[0].get_name()];
auto& seq_len_tensor = tensor_data[args[1].get_name()];
auto& out_tensor = tensor_data[out[0].get_name()];
auto arg_shape = args[0].get_shape();
auto sequence_axis = rev_seq->get_sequence_axis();
auto batch_axis = rev_seq->get_batch_axis();
std::function<decltype(runtime::cpu::kernel::reverse_sequence<int, int, 4>)> kernel;
if (args[1].get_element_type() == element::i32)
{
SELECT_KERNEL_BY_RANK(kernel,
args[0].get_element_type(),
arg_shape.size(),
runtime::cpu::kernel::reverse_sequence_sli32);
}
else
{
throw ngraph_error("Unsupported sequence length type " +
args[1].get_element_type().c_type_string() +
" requires a kernel instantiation to handle this type");
}
auto functor =
[&, kernel, arg_shape, batch_axis, sequence_axis](CPURuntimeContext* ctx) {
kernel(arg_tensor,
out_tensor,
arg_shape,
batch_axis,
sequence_axis,
seq_len_tensor);
};
functors.emplace_back(functor);
}
REGISTER_OP_BUILDER(ReverseSequence);
}
}
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <cstdint>
#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>
#include "ngraph/runtime/cpu/kernel/eigen_thread_pool.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace kernel
{
template <typename InputElementType, typename SeqLenType, unsigned int Rank>
void reverse_sequence(void* input,
void* output,
const Shape& input_shape,
size_t batch_axis,
size_t sequence_axis,
void* sequence_lengths)
{
Eigen::array<Eigen::Index, Rank> in_dims;
for (int i = 0; i < Rank; i++)
{
in_dims[i] = input_shape[i];
}
Eigen::TensorMap<Eigen::Tensor<InputElementType, Rank, Eigen::RowMajor>> out(
static_cast<InputElementType*>(output), in_dims);
Eigen::TensorMap<Eigen::Tensor<InputElementType, Rank, Eigen::RowMajor>> in(
static_cast<InputElementType*>(input), in_dims);
auto slv = static_cast<SeqLenType*>(sequence_lengths);
auto generator = [&](const Eigen::array<Eigen::DenseIndex, Rank>& i) {
Eigen::array<Eigen::DenseIndex, Rank> k = i;
if (i[sequence_axis] < slv[i[batch_axis]])
{
k[sequence_axis] = slv[i[batch_axis]] - i[sequence_axis] - 1;
}
return in(k);
};
out.device(eigen::global_thread_pool_device) = in.generate(generator);
}
template <typename InputElementType, unsigned int Rank>
void reverse_sequence_sli32(void* input,
void* output,
const Shape& input_shape,
size_t batch_axis,
size_t sequence_axis,
void* sequence_lengths)
{
reverse_sequence<InputElementType, int32_t, Rank>(
input, output, input_shape, batch_axis, sequence_axis, sequence_lengths);
}
}
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment