Commit 045ab6bb authored by Jaikrishnan Menon's avatar Jaikrishnan Menon Committed by Scott Cyphers

CPU Direct Execution: Implement ReplaceSlice (#1357)

* CPU Direct Execution: Implement ReplaceSlice

* Remove scalar variant
parent a2ba381d
......@@ -46,6 +46,7 @@ set(SRC
builder/product.cpp
builder/reduce_function.cpp
builder/reduce_function_window.cpp
builder/replace_slice.cpp
builder/reshape.cpp
builder/reverse.cpp
builder/reverse_sequence.cpp
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <cstring>
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/kernel/replace_slice.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::ReplaceSlice)
{
auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg0_tensor = tensor_data[args[0].get_name()];
auto& arg1_tensor = tensor_data[args[1].get_name()];
auto& out_tensor = tensor_data[out[0].get_name()];
auto replace_slice = static_cast<const ngraph::op::ReplaceSlice*>(node);
auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape();
auto strides = replace_slice->get_strides();
auto lower_bounds = replace_slice->get_lower_bounds();
auto upper_bounds = replace_slice->get_upper_bounds();
bool strided = false;
for (auto stride : strides)
{
if (stride != 1)
{
strided = true;
break;
}
}
if (!arg0_shape.size())
{
size_t size = args[0].get_element_type().size();
auto functor = [&, size](CPURuntimeContext* ctx) {
memcpy(out_tensor, arg1_tensor, size);
};
functors.emplace_back(functor);
return;
}
if (strided)
{
std::function<decltype(runtime::cpu::kernel::strided_replace_slice<float, 2>)>
kernel;
SELECT_KERNEL_BY_RANK(kernel,
args[0].get_element_type(),
arg0_shape.size(),
runtime::cpu::kernel::strided_replace_slice);
auto functor =
[&, kernel, arg0_shape, arg1_shape, lower_bounds, upper_bounds, strides](
CPURuntimeContext* ctx) {
kernel(arg0_tensor,
arg1_tensor,
out_tensor,
arg0_shape,
arg1_shape,
lower_bounds,
upper_bounds,
strides);
};
functors.emplace_back(functor);
}
else
{
std::function<decltype(runtime::cpu::kernel::replace_slice<float, 2>)> kernel;
SELECT_KERNEL_BY_RANK(kernel,
args[0].get_element_type(),
arg0_shape.size(),
runtime::cpu::kernel::replace_slice);
auto functor =
[&, kernel, arg0_shape, arg1_shape, lower_bounds](CPURuntimeContext* ctx) {
kernel(arg0_tensor,
arg1_tensor,
out_tensor,
arg0_shape,
arg1_shape,
lower_bounds);
};
functors.emplace_back(functor);
}
}
REGISTER_OP_BUILDER(ReplaceSlice);
}
}
}
......@@ -110,7 +110,7 @@
else if (R == 16) \
KV = K<ET, 16>; \
else \
throw ngraph_error("Unsupported rank " #R " for kernel " #K);
throw ngraph_error("Unsupported rank " + std::to_string(R) + " for kernel " #K);
// Per-type and rank kernel macro
#define SELECT_KERNEL_BY_RANK(KV, ET, R, K) \
......@@ -181,7 +181,7 @@
else if (R == 6) \
KV = K<ET, 6>; \
else \
throw ngraph_error("Unsupported rank " #R " for kernel " #K);
throw ngraph_error("Unsupported rank " + std::to_string(R) + " for kernel " #K);
// Partial per-type and rank kernel macro
#define PARTIAL_SELECT_KERNEL_BY_RANK(KV, ET, R, K) \
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>
#include "ngraph/coordinate.hpp"
#include "ngraph/runtime/cpu/kernel/eigen_thread_pool.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace kernel
{
template <typename ElementType, unsigned int Rank>
void replace_slice(void* input0,
void* input1,
void* output,
const Shape& input0_shape,
const Shape& input1_shape,
const Coordinate& lower_bounds)
{
Eigen::array<Eigen::Index, Rank> in0_dims, in1_dims;
Eigen::array<Eigen::Index, Rank> indices;
for (int i = 0; i < Rank; i++)
{
in0_dims[i] = input0_shape[i];
in1_dims[i] = input1_shape[i];
indices[i] = lower_bounds[i];
}
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> out(
static_cast<ElementType*>(output), in0_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> in0(
static_cast<ElementType*>(input0), in0_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> in1(
static_cast<ElementType*>(input1), in1_dims);
out.device(eigen::global_thread_pool_device) = in0;
out.slice(indices, in1_dims).device(eigen::global_thread_pool_device) = in1;
}
template <typename ElementType, unsigned int Rank>
void strided_replace_slice(void* input0,
void* input1,
void* output,
const Shape& input0_shape,
const Shape& input1_shape,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds,
const Strides& slice_strides)
{
Eigen::array<Eigen::Index, Rank> in0_dims, in1_dims;
Eigen::array<Eigen::Index, Rank> start_indices, stop_indices, strides;
for (int i = 0; i < Rank; i++)
{
in0_dims[i] = input0_shape[i];
in1_dims[i] = input1_shape[i];
start_indices[i] = lower_bounds[i];
stop_indices[i] = upper_bounds[i];
strides[i] = slice_strides[i];
}
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> out(
static_cast<ElementType*>(output), in0_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> in0(
static_cast<ElementType*>(input0), in0_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> in1(
static_cast<ElementType*>(input1), in1_dims);
out.device(eigen::global_thread_pool_device) = in0;
out.stridedSlice(start_indices, stop_indices, strides)
.device(eigen::global_thread_pool_device) = in1;
}
}
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment