Commit 0b95efa6 authored by Jayaram Bobba, committed by Scott Cyphers

Optimized eigen kernel for spatial mean (#1094)

* Optimized eigen kernel for 2D reduction on a 4D tensor used for spatial mean

* revert change to serializer
parent abb68627
......@@ -2095,6 +2095,16 @@ namespace ngraph
<< "{" << join(sum->get_reduction_axes()) << "}"
<< ");\n";
}
else if (args[0].get_element_type() == element::f32 &&
args[0].get_shape().size() == 4 && sum->get_reduction_axes().size() == 2)
{
writer << "cpu::kernel::reduce_sum_4d_2rd_float32(" << args[0].get_name()
<< ", " << out[0].get_name() << ", "
<< "{" << join(args[0].get_shape()) << "}, "
<< "{" << join(out[0].get_shape()) << "}, "
<< "{" << join(sum->get_reduction_axes()) << "}"
<< ");\n";
}
else if (args[0].get_element_type() == element::f32 &&
args[0].get_shape().size() == 4 && sum->get_reduction_axes().size() == 4)
{
......
......@@ -158,6 +158,12 @@ namespace ngraph
const Shape& output_shape,
const AxisSet& reduction_axes);
// Sum-reduction kernel for a rank-4 float32 tensor reduced over exactly two
// axes. The CPU emitter generates a call to this kernel when the input is
// element::f32, has 4 dimensions, and the Sum op has 2 reduction axes.
//
//   input          - pointer to the source tensor data
//   output         - pointer to the destination buffer for the reduced result
//   input_shape    - shape of the input tensor
//   output_shape   - shape of the output tensor
//   reduction_axes - the axes of the input that are summed over
//
// NOTE(review): buffer layout/alignment expectations and whether output_shape
// must already exclude the reduced axes are not visible from this
// declaration - confirm against the reduce_sum template.
void reduce_sum_4d_2rd_float32(float* input,
float* output,
const Shape& input_shape,
const Shape& output_shape,
const AxisSet& reduction_axes);
void reduce_sum_all_4d_float32(float* input,
float* output,
const Shape& input_shape,
......
......@@ -57,6 +57,15 @@ namespace ngraph
{
reduce_sum_all<float, 4>(input, output, input_shape, output_shape);
}
// Non-template entry point for summing a rank-4 float32 tensor over two axes.
// Exists so that generated code can call a plain function by name instead of
// instantiating the reduce_sum template itself; all arguments are forwarded
// unchanged.
//
//   input          - pointer to the source tensor data
//   output         - pointer to the destination buffer for the reduced result
//   input_shape    - shape of the input tensor
//   output_shape   - shape of the output tensor
//   reduction_axes - the axes of the input that are summed over
void reduce_sum_4d_2rd_float32(float* input,
float* output,
const Shape& input_shape,
const Shape& output_shape,
const AxisSet& reduction_axes)
{
// Template arguments: presumably <element type, input rank, number of
// reduction axes> - TODO confirm against the reduce_sum definition.
reduce_sum<float, 4, 2>(
input, output, input_shape, output_shape, reduction_axes);
}
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment