Commit 5e307d6b authored by Amy Zhuang, committed by Scott Cyphers

Add OneHot op to direct execution. (#1340)

* Add OneHot op to direct execution.

* Remove exceptions from kernel.

* Use Eigen::Tensor.
parent bcdfb7e1
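
For context, OneHot expands an index tensor into an output with one extra dimension of a given depth, writing 1 at each selected position and 0 elsewhere. The standalone sketch below mirrors the rank-1 case handled by the new kernels, with the class axis as the last output dimension; the names `indices`, `depth`, and `one_hot` are illustrative and not part of this commit.

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative only: rank-1 one-hot with the class axis as the last dimension,
// i.e. the output has shape {indices.size(), depth}.
std::vector<float> one_hot(const std::vector<size_t>& indices, size_t depth)
{
    std::vector<float> out(indices.size() * depth, 0.0f);
    for (size_t i = 0; i < indices.size(); i++)
    {
        out[i * depth + indices[i]] = 1.0f; // assumes indices[i] < depth
    }
    return out;
}

int main()
{
    auto out = one_hot({1, 0, 2}, 3);
    for (size_t i = 0; i < out.size(); i++)
    {
        std::cout << out[i] << ((i + 1) % 3 == 0 ? '\n' : ' ');
    }
}
```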
@@ -44,6 +44,7 @@ set(SRC
     builder/max.cpp
     builder/max_pool.cpp
     builder/min.cpp
+    builder/one_hot.cpp
     builder/relu.cpp
     builder/pad.cpp
     builder/product.cpp
...
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/cpu/kernel/one_hot.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::OneHot)
{
auto oh = static_cast<const ngraph::op::OneHot*>(node);
auto one_hot_axis = oh->get_one_hot_axis();
auto arg_shape = args[0].get_shape();
auto out_shape = out[0].get_shape();
auto out_strides = out[0].get_strides();
auto arg_rank = arg_shape.size();
auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg_tensor = tensor_data[args[0].get_name()];
auto& out_tensor = tensor_data[out[0].get_name()];
if (arg_rank == 0)
{
std::function<decltype(runtime::cpu::kernel::one_hot_rank_0<float>)> kernel;
SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::one_hot_rank_0);
auto functor = [&, kernel, out_shape, one_hot_axis](CPURuntimeContext* ctx) {
kernel(arg_tensor, out_tensor, out_shape, one_hot_axis);
};
functors.emplace_back(functor);
}
else if (arg_rank == 1)
{
std::function<decltype(runtime::cpu::kernel::one_hot_rank_1<float>)> kernel;
SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::one_hot_rank_1);
auto functor = [&, kernel, arg_shape, out_shape, out_strides, one_hot_axis](
CPURuntimeContext* ctx) {
kernel(arg_tensor,
out_tensor,
arg_shape,
out_shape,
out_strides,
one_hot_axis);
};
functors.emplace_back(functor);
}
else
{
std::function<decltype(runtime::cpu::kernel::one_hot_rank_2_or_more<float>)>
kernel;
SELECT_KERNEL(kernel,
out[0].get_element_type(),
runtime::cpu::kernel::one_hot_rank_2_or_more);
auto functor =
[&, kernel, arg_shape, out_shape, one_hot_axis](CPURuntimeContext* ctx) {
kernel(arg_tensor, out_tensor, arg_shape, out_shape, one_hot_axis);
};
functors.emplace_back(functor);
}
}
REGISTER_OP_BUILDER(OneHot);
}
}
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once

#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>

#include "ngraph/runtime/cpu/kernel/eigen_thread_pool.hpp"
#include "ngraph/runtime/reference/one_hot.hpp"
#include "ngraph/shape.hpp"

namespace ngraph
{
    namespace runtime
    {
        namespace cpu
        {
            namespace kernel
            {
                // Rank-0 input: the output is a vector with a single 1 at the index
                // given by the scalar argument.
                template <typename ElementType>
                void one_hot_rank_0(void* arg,
                                    void* out,
                                    const Shape& out_shape,
                                    size_t one_hot_axis)
                {
                    Eigen::array<Eigen::Index, 1> out_dims;
                    out_dims[0] = out_shape[0];

                    Eigen::TensorMap<Eigen::Tensor<ElementType, 1, Eigen::RowMajor>> out_tensor(
                        static_cast<ElementType*>(out), out_dims);

                    out_tensor.setZero();
                    auto pos_raw = (static_cast<ElementType*>(arg))[0];
                    size_t pos = pos_raw;
                    out_tensor(pos) = 1;
                }

                // Rank-1 input: each element selects the hot position along
                // one_hot_axis of the rank-2 output.
                template <typename ElementType>
                void one_hot_rank_1(void* arg,
                                    void* out,
                                    const Shape& arg_shape,
                                    const Shape& out_shape,
                                    const Strides& out_strides,
                                    size_t one_hot_axis)
                {
                    Eigen::array<Eigen::Index, 2> out_dims;
                    Eigen::array<Eigen::Index, 1> in_dims;

                    out_dims[0] = out_shape[0];
                    out_dims[1] = out_shape[1];
                    in_dims[0] = arg_shape[0];

                    Eigen::TensorMap<Eigen::Tensor<ElementType, 2, Eigen::RowMajor>> out_tensor(
                        static_cast<ElementType*>(out), out_dims);
                    Eigen::TensorMap<Eigen::Tensor<ElementType, 1, Eigen::RowMajor>> in_tensor(
                        static_cast<ElementType*>(arg), in_dims);

                    out_tensor.setZero();
                    for (size_t i = 0; i < arg_shape[0]; i++)
                    {
                        auto pos_raw = in_tensor(i);
                        size_t pos = pos_raw;
                        one_hot_axis == 0 ? out_tensor(pos, i) = 1 : out_tensor(i, pos) = 1;
                    }
                }

                // Rank 2 or higher: delegate to the reference implementation.
                template <typename ElementType>
                void one_hot_rank_2_or_more(void* arg,
                                            void* out,
                                            const Shape& arg_shape,
                                            const Shape& out_shape,
                                            size_t one_hot_axis)
                {
                    reference::one_hot<ElementType>(static_cast<const ElementType*>(arg),
                                                    static_cast<ElementType*>(out),
                                                    arg_shape,
                                                    out_shape,
                                                    one_hot_axis);
                }
            }
        }
    }
}
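
For reference, here is a minimal sketch of how the rank-1 kernel above could be exercised on its own, assuming this header and the ngraph sources it includes are on the include path; the buffer names and shapes are illustrative, not part of the commit. With `one_hot_axis == 1`, the kernel writes 1 at `out(i, arg[i])`; `out_strides` is part of the signature but is not used by this Eigen path.

```cpp
// Sketch only: calls one_hot_rank_1<float> directly with raw buffers.
#include "ngraph/runtime/cpu/kernel/one_hot.hpp"

#include <vector>

int main()
{
    std::vector<float> arg{1.0f, 0.0f, 2.0f}; // class indices, rank-1 input
    std::vector<float> out(3 * 3, 0.0f);      // rank-2 output, shape {3, 3}

    ngraph::Shape arg_shape{3};
    ngraph::Shape out_shape{3, 3};
    ngraph::Strides out_strides{3, 1};        // row-major strides for {3, 3}

    // one_hot_axis == 1: the hot position varies along the second dimension,
    // so out(i, arg[i]) becomes 1.
    ngraph::runtime::cpu::kernel::one_hot_rank_1<float>(
        arg.data(), out.data(), arg_shape, out_shape, out_strides, 1);
    return 0;
}
```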