Commit 73bff556 authored by Amy Zhuang's avatar Amy Zhuang Committed by Robert Kimball

Modify DEX OneHot op: use generator. (#1446)

* Modify DEX OneHot op: use generator.

* Cast index to int.
parent 58f9af01
...@@ -58,14 +58,9 @@ namespace ngraph ...@@ -58,14 +58,9 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::one_hot_rank_1<float>)> kernel; std::function<decltype(runtime::cpu::kernel::one_hot_rank_1<float>)> kernel;
SELECT_KERNEL( SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::one_hot_rank_1); kernel, out[0].get_element_type(), runtime::cpu::kernel::one_hot_rank_1);
auto functor = [&, kernel, arg_shape, out_shape, out_strides, one_hot_axis]( auto functor =
CPURuntimeContext* ctx) { [&, kernel, arg_shape, out_shape, one_hot_axis](CPURuntimeContext* ctx) {
kernel(arg_tensor, kernel(arg_tensor, out_tensor, arg_shape, out_shape, one_hot_axis);
out_tensor,
arg_shape,
out_shape,
out_strides,
one_hot_axis);
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
......
...@@ -38,16 +38,10 @@ namespace ngraph ...@@ -38,16 +38,10 @@ namespace ngraph
size_t one_hot_axis) size_t one_hot_axis)
{ {
Eigen::array<Eigen::Index, 1> out_dims; memset(out, 0, sizeof(ElementType) * shape_size(out_shape));
out_dims[0] = out_shape[0];
Eigen::TensorMap<Eigen::Tensor<ElementType, 1, Eigen::RowMajor>> out_tensor(
static_cast<ElementType*>(out), out_dims);
out_tensor.setZero();
auto pos_raw = (static_cast<ElementType*>(arg))[0]; auto pos_raw = (static_cast<ElementType*>(arg))[0];
size_t pos = pos_raw; size_t pos = pos_raw;
out_tensor(pos) = 1; (static_cast<ElementType*>(out))[pos] = 1;
} }
template <typename ElementType> template <typename ElementType>
...@@ -55,7 +49,6 @@ namespace ngraph ...@@ -55,7 +49,6 @@ namespace ngraph
void* out, void* out,
const Shape& arg_shape, const Shape& arg_shape,
const Shape& out_shape, const Shape& out_shape,
const Strides& out_strides,
size_t one_hot_axis) size_t one_hot_axis)
{ {
...@@ -67,16 +60,21 @@ namespace ngraph ...@@ -67,16 +60,21 @@ namespace ngraph
Eigen::TensorMap<Eigen::Tensor<ElementType, 2, Eigen::RowMajor>> out_tensor( Eigen::TensorMap<Eigen::Tensor<ElementType, 2, Eigen::RowMajor>> out_tensor(
static_cast<ElementType*>(out), out_dims); static_cast<ElementType*>(out), out_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, 1, Eigen::RowMajor>> in_tensor( Eigen::TensorMap<Eigen::Tensor<ElementType, 1, Eigen::RowMajor>> in_tensor(
static_cast<ElementType*>(arg), in_dims); static_cast<ElementType*>(arg), in_dims);
out_tensor.setZero(); auto generator = [&](const Eigen::array<Eigen::DenseIndex, 2>& idx) {
for (size_t i = 0; i < arg_shape[0]; i++) if ((one_hot_axis == 0 && idx[0] == static_cast<int>(in_tensor(idx[1]))) ||
(one_hot_axis == 1 && idx[1] == static_cast<int>(in_tensor(idx[0]))))
{ {
auto pos_raw = in_tensor(i); return 1;
size_t pos = pos_raw;
one_hot_axis == 0 ? out_tensor(pos, i) = 1 : out_tensor(i, pos) = 1;
} }
return 0;
};
out_tensor.device(eigen::global_thread_pool_device) =
out_tensor.generate(generator);
} }
template <typename ElementType> template <typename ElementType>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment