Commit c21bbba0 authored by Jaikrishnan Menon

CPU Direct Execution: Commit missed kernel changes for Sum reduction

parent f60dd831
......@@ -20,6 +20,7 @@
#include <unsupported/Eigen/CXX11/Tensor>
#include "ngraph/runtime/cpu/kernel/eigen_thread_pool.hpp"
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
......@@ -31,8 +32,8 @@ namespace ngraph
namespace kernel
{
template <typename ElementType, unsigned int Rank>
void reduce_sum_all(ElementType* input,
ElementType* output,
void reduce_sum_all(void* input,
void* output,
const Shape& input_shape,
const Shape& output_shape)
{
......@@ -44,16 +45,16 @@ namespace ngraph
in_dims[i] = input_shape[i];
}
Eigen::TensorMap<Eigen::Tensor<ElementType, 0, Eigen::RowMajor>> out(output,
out_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> in(input,
in_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, 0, Eigen::RowMajor>> out(
static_cast<ElementType*>(output), out_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> in(
static_cast<ElementType*>(input), in_dims);
out.device(eigen::global_thread_pool_device) = in.sum();
}
template <typename ElementType, unsigned int Rank, unsigned int ReductionDims>
void reduce_sum(ElementType* input,
ElementType* output,
void reduce_sum(void* input,
void* output,
const Shape& input_shape,
const Shape& output_shape,
const AxisSet& reduction_axes)
......@@ -80,11 +81,69 @@ namespace ngraph
Eigen::TensorMap<
Eigen::Tensor<ElementType, Rank - ReductionDims, Eigen::RowMajor>>
out(output, out_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> in(input,
in_dims);
out(static_cast<ElementType*>(output), out_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, Rank, Eigen::RowMajor>> in(
static_cast<ElementType*>(input), in_dims);
out.device(eigen::global_thread_pool_device) = in.sum(reduction_dims);
}
// Forwarding wrapper around reduce_sum with the ReductionDims template
// argument fixed at 1, leaving only ElementType and Rank for the caller to
// choose. All runtime arguments are passed through unchanged; the untyped
// `input`/`output` pointers are handed on as-is for reduce_sum to interpret.
template <typename ElementType, unsigned int Rank>
void reduce_sum_1rd(void* input,
                    void* output,
                    const Shape& input_shape,
                    const Shape& output_shape,
                    const AxisSet& reduction_axes)
{
    // Number of reduction dimensions baked into this entry point.
    constexpr unsigned int reduced_axis_count = 1;

    reduce_sum<ElementType, Rank, reduced_axis_count>(
        input, output, input_shape, output_shape, reduction_axes);
}
// Forwarding wrapper around reduce_sum with Rank fixed at 3 and
// ReductionDims fixed at 2, so ElementType is the only template argument
// left to the caller. All runtime arguments are passed through unchanged.
template <typename ElementType>
void reduce_sum_3d_2rd(void* input,
                       void* output,
                       const Shape& input_shape,
                       const Shape& output_shape,
                       const AxisSet& reduction_axes)
{
    // Template arguments baked into this entry point.
    constexpr unsigned int tensor_rank = 3;
    constexpr unsigned int reduced_axis_count = 2;

    reduce_sum<ElementType, tensor_rank, reduced_axis_count>(
        input, output, input_shape, output_shape, reduction_axes);
}
// Forwarding wrapper around reduce_sum with Rank fixed at 4 and
// ReductionDims fixed at 2, so ElementType is the only template argument
// left to the caller. All runtime arguments are passed through unchanged.
template <typename ElementType>
void reduce_sum_4d_2rd(void* input,
                       void* output,
                       const Shape& input_shape,
                       const Shape& output_shape,
                       const AxisSet& reduction_axes)
{
    // Template arguments baked into this entry point.
    constexpr unsigned int tensor_rank = 4;
    constexpr unsigned int reduced_axis_count = 2;

    reduce_sum<ElementType, tensor_rank, reduced_axis_count>(
        input, output, input_shape, output_shape, reduction_axes);
}
// Forwarding wrapper around reduce_sum with Rank fixed at 5 and
// ReductionDims fixed at 2, so ElementType is the only template argument
// left to the caller. All runtime arguments are passed through unchanged.
template <typename ElementType>
void reduce_sum_5d_2rd(void* input,
                       void* output,
                       const Shape& input_shape,
                       const Shape& output_shape,
                       const AxisSet& reduction_axes)
{
    // Template arguments baked into this entry point.
    constexpr unsigned int tensor_rank = 5;
    constexpr unsigned int reduced_axis_count = 2;

    reduce_sum<ElementType, tensor_rank, reduced_axis_count>(
        input, output, input_shape, output_shape, reduction_axes);
}
// Type-erased front end for reference::sum. The untyped `arg` and `out`
// pointers are reinterpreted as ElementType buffers and handed, together
// with the shapes and the reduction axes, to the reference implementation.
// The caller is responsible for `arg` and `out` actually pointing at
// ElementType data; nothing is checked here.
template <typename ElementType>
void sum(void* arg,
         void* out,
         const Shape& in_shape,
         const Shape& out_shape,
         const AxisSet& reduction_axes)
{
    // Recover the element-typed views of the opaque buffers.
    ElementType* const typed_arg = static_cast<ElementType*>(arg);
    ElementType* const typed_out = static_cast<ElementType*>(out);

    reference::sum(typed_arg, typed_out, in_shape, out_shape, reduction_axes);
}
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment