Commit ff98d02a authored by Jayaram Bobba, committed by Robert Kimball

Faster argmax/argmin kernels (#2032)

* Faster argmax/argmin kernels

* Use switch statement for macro
parent be9f031e
......@@ -18,7 +18,7 @@
#include "ngraph/op/argmax.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/reference/argmax.hpp"
#include "ngraph/runtime/cpu/kernel/argmax.hpp"
using namespace std;
using namespace ngraph;
......@@ -55,26 +55,27 @@ namespace ngraph
{
if (is_int64)
{
functor = [&, in_shape, out_shape, axis](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::argmax<float, int64_t>(
static_cast<float*>(arg_tensor),
static_cast<int64_t*>(out_tensor),
in_shape,
out_shape,
axis);
std::function<decltype(runtime::cpu::kernel::argmax<float, int64_t, 1>)>
kernel;
SELECT_RANK2(
kernel, float, int64_t, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&, kernel, in_shape, out_shape, axis](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena);
};
}
else
{
functor = [&, in_shape, out_shape, axis](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::argmax<float, int32_t>(
static_cast<float*>(arg_tensor),
static_cast<int*>(out_tensor),
in_shape,
out_shape,
axis);
std::function<decltype(runtime::cpu::kernel::argmax<float, int, 1>)> kernel;
SELECT_RANK2(
kernel, float, int, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&, kernel, in_shape, out_shape, axis](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena);
};
}
}
......@@ -82,26 +83,28 @@ namespace ngraph
{
if (is_int64)
{
functor = [&, in_shape, out_shape, axis](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::argmax<double, int64_t>(
static_cast<double*>(arg_tensor),
static_cast<int64_t*>(out_tensor),
in_shape,
out_shape,
axis);
std::function<decltype(runtime::cpu::kernel::argmax<double, int64_t, 1>)>
kernel;
SELECT_RANK2(
kernel, double, int64_t, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&, kernel, in_shape, out_shape, axis](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena);
};
}
else
{
functor = [&, in_shape, out_shape, axis](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::argmax<double, int32_t>(
static_cast<double*>(arg_tensor),
static_cast<int*>(out_tensor),
in_shape,
out_shape,
axis);
std::function<decltype(runtime::cpu::kernel::argmax<double, int, 1>)>
kernel;
SELECT_RANK2(
kernel, double, int, in_shape.size(), runtime::cpu::kernel::argmax);
functor = [&, kernel, in_shape, out_shape, axis](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena);
};
}
}
......
......@@ -18,7 +18,7 @@
#include "ngraph/op/argmin.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/reference/argmin.hpp"
#include "ngraph/runtime/cpu/kernel/argmin.hpp"
using namespace std;
using namespace ngraph;
......@@ -55,26 +55,27 @@ namespace ngraph
{
if (is_int64)
{
functor = [&, in_shape, out_shape, axis](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::argmin<float, int64_t>(
static_cast<float*>(arg_tensor),
static_cast<int64_t*>(out_tensor),
in_shape,
out_shape,
axis);
std::function<decltype(runtime::cpu::kernel::argmin<float, int64_t, 1>)>
kernel;
SELECT_RANK2(
kernel, float, int64_t, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&, kernel, in_shape, out_shape, axis](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena);
};
}
else
{
functor = [&, in_shape, out_shape, axis](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::argmin<float, int32_t>(
static_cast<float*>(arg_tensor),
static_cast<int*>(out_tensor),
in_shape,
out_shape,
axis);
std::function<decltype(runtime::cpu::kernel::argmin<float, int, 1>)> kernel;
SELECT_RANK2(
kernel, float, int, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&, kernel, in_shape, out_shape, axis](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena);
};
}
}
......@@ -82,26 +83,28 @@ namespace ngraph
{
if (is_int64)
{
functor = [&, in_shape, out_shape, axis](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::argmin<double, int64_t>(
static_cast<double*>(arg_tensor),
static_cast<int64_t*>(out_tensor),
in_shape,
out_shape,
axis);
std::function<decltype(runtime::cpu::kernel::argmin<double, int64_t, 1>)>
kernel;
SELECT_RANK2(
kernel, double, int64_t, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&, kernel, in_shape, out_shape, axis](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena);
};
}
else
{
functor = [&, in_shape, out_shape, axis](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::argmin<double, int32_t>(
static_cast<double*>(arg_tensor),
static_cast<int*>(out_tensor),
in_shape,
out_shape,
axis);
std::function<decltype(runtime::cpu::kernel::argmin<double, int, 1>)>
kernel;
SELECT_RANK2(
kernel, double, int, in_shape.size(), runtime::cpu::kernel::argmin);
functor = [&, kernel, in_shape, out_shape, axis](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
kernel(arg_tensor, out_tensor, in_shape, out_shape, axis, ectx->arena);
};
}
}
......
......@@ -94,6 +94,19 @@
else \
throw ngraph_error("Unsupported rank " + std::to_string(R) + " for kernel " #K);
#define SELECT_RANK2(KV, IT, OT, R, K) \
switch (R) \
{ \
case 1: KV = K<IT, OT, 1>; break; \
case 2: KV = K<IT, OT, 2>; break; \
case 3: KV = K<IT, OT, 3>; break; \
case 4: KV = K<IT, OT, 4>; break; \
case 5: KV = K<IT, OT, 5>; break; \
case 6: KV = K<IT, OT, 6>; break; \
case 7: KV = K<IT, OT, 7>; break; \
default: throw ngraph_error("Unsupported rank " + std::to_string(R) + " for kernel " #K); \
}
// Per-type and rank kernel macro
#define SELECT_KERNEL_BY_RANK(KV, ET, R, K) \
if (ET == element::boolean) \
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>
#include "ngraph/axis_set.hpp"
#include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace kernel
{
template <typename InType, typename OutType, unsigned int Rank>
void argmax(void* input,
void* output,
const Shape& input_shape,
const Shape& output_shape,
size_t axis,
int arena)
{
Eigen::array<Eigen::Index, Rank - 1> out_dims;
Eigen::array<Eigen::Index, Rank> in_dims;
for (int i = 0; i < Rank; i++)
{
in_dims[i] = input_shape[i];
}
for (int i = 0; i < Rank - 1; i++)
{
out_dims[i] = output_shape[i];
}
Eigen::TensorMap<Eigen::Tensor<OutType, Rank - 1, Eigen::RowMajor>> out(
static_cast<OutType*>(output), out_dims);
Eigen::TensorMap<Eigen::Tensor<InType, Rank, Eigen::RowMajor>> in(
static_cast<InType*>(input), in_dims);
out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) =
in.argmax(axis).template cast<OutType>();
}
}
}
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>
#include "ngraph/axis_set.hpp"
#include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace kernel
{
template <typename InType, typename OutType, unsigned int Rank>
void argmin(void* input,
void* output,
const Shape& input_shape,
const Shape& output_shape,
size_t axis,
int arena)
{
Eigen::array<Eigen::Index, Rank - 1> out_dims;
Eigen::array<Eigen::Index, Rank> in_dims;
for (int i = 0; i < Rank; i++)
{
in_dims[i] = input_shape[i];
}
for (int i = 0; i < Rank - 1; i++)
{
out_dims[i] = output_shape[i];
}
Eigen::TensorMap<Eigen::Tensor<OutType, Rank - 1, Eigen::RowMajor>> out(
static_cast<OutType*>(output), out_dims);
Eigen::TensorMap<Eigen::Tensor<InType, Rank, Eigen::RowMajor>> in(
static_cast<InType*>(input), in_dims);
out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) =
in.argmin(axis).template cast<OutType>();
}
}
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment