Commit 36e1de51 authored by shssf, committed by Robert Kimball

IntelGPU backend: BatchNorm operation optimization (#1579)

* IntelGPU backend: BatchNorm operation optimization

* PR1579. Function moved by request
parent 4341c6ac
...@@ -48,6 +48,11 @@ static Shape get_channel_shape(const Shape& shape, const string& function_name) ...@@ -48,6 +48,11 @@ static Shape get_channel_shape(const Shape& shape, const string& function_name)
return {shape.at(channel_axis)}; return {shape.at(channel_axis)};
} }
static size_t get_idx_size(const Shape& shape, size_t pos)
{
return accumulate(shape.cbegin() + pos, shape.cend(), 1, multiplies<size_t>());
}
void runtime::intelgpu::do_create_mean(cldnn::topology& topology, void runtime::intelgpu::do_create_mean(cldnn::topology& topology,
const string& output_name, const string& output_name,
const element::Type& output_type, const element::Type& output_type,
...@@ -210,32 +215,46 @@ void runtime::intelgpu::do_batch_norm_operation(cldnn::topology& topology, ...@@ -210,32 +215,46 @@ void runtime::intelgpu::do_batch_norm_operation(cldnn::topology& topology,
{ {
const Shape channel_shape = get_channel_shape(input_shape, "batch_norm"); const Shape channel_shape = get_channel_shape(input_shape, "batch_norm");
const cldnn::layout layout = IntelGPULayout::create_cldnn_layout(output_type, input_shape); const cldnn::layout layout = IntelGPULayout::create_cldnn_layout(output_type, input_shape);
const vector<size_t> gws(input_shape.begin(), input_shape.begin() + 2);
const string entry_point_name = "batch_norm_" + output_name; const string entry_point_name = "batch_norm_" + output_name;
codegen::CodeWriter writer; codegen::CodeWriter writer;
vector<size_t> gws;
writer << "__kernel void " << entry_point_name << "(const __global float input"
<< array_dims(input_shape) << ", const __global float gamma" << array_dims(channel_shape)
<< ", const __global float beta" << array_dims(channel_shape)
<< ", const __global float mean" << array_dims(channel_shape)
<< ", const __global float variance" << array_dims(channel_shape)
<< ", __global float output" << array_dims(input_shape) << ")\n";
// The kernel name and parameters
writer << "__attribute__((reqd_work_group_size(1,1,1)))\n"
<< "__kernel void " << entry_point_name
<< "(const __global float *input0, const __global float *input1,"
<< " const __global float *input2, const __global float *input3,"
<< " const __global float *input4, __global float *output)\n";
writer.block_begin(); writer.block_begin();
{ // Main function body { // Main function body
gws = generate_loops(writer, input_shape, true); writer << "// input array dims: input0" << array_dims(input_shape);
// Channel axis loop
writer << "\nconst uint i" << channel_axis << " = get_global_id(" << channel_axis
<< "); /* channel_axis trip count " << input_shape.at(channel_axis) << "*/\n";
writer << "float normalized = (input" << access_dims(input_shape) << " - mean[i" // Invariants for the rest of the loops
<< channel_axis << "]) / (" writer << "const float gamma = input1[i" << channel_axis << "];\n"
<< "sqrt(variance[i" << channel_axis << "] + " << eps << ")" << "const float beta = input2[i" << channel_axis << "];\n"
<< ");\n"; << "const float mean = input3[i" << channel_axis << "];\n"
<< "const float variance = input4[i" << channel_axis << "];\n"
<< "const float var_sqrt = (gamma / sqrt(variance + " << eps << "));\n";
writer << "output" << access_dims(input_shape) << " = normalized * gamma[i" << channel_axis writer << "const uint i0 = get_global_id(0);"
<< "] + beta[i" << channel_axis << "];\n"; << " /* batch axis trip count " << input_shape.at(0) << "*/\n";
generate_loops(writer, input_shape, false); // loop index invariants
writer << "const uint idx0 = (i0 * " << get_idx_size(input_shape, 1) << ") + (i1 * "
<< get_idx_size(input_shape, 2) << ");\n";
// SIMD loop
writer << "for (uint i3 = 0; i3 < " << get_idx_size(input_shape, 2) << "; ++i3)\n";
writer.block_begin();
{
writer << "const uint idx = idx0 + i3;\n";
writer << "output[idx] = (input0[idx] - mean) * var_sqrt + beta;\n";
} // Closing brackets for SIMD loop
writer.block_end();
} // Main function body } // Main function body
writer.block_end(); writer.block_end();
...@@ -248,7 +267,8 @@ void runtime::intelgpu::do_batch_norm_operation(cldnn::topology& topology, ...@@ -248,7 +267,8 @@ void runtime::intelgpu::do_batch_norm_operation(cldnn::topology& topology,
get_kernel_args(5, 1), get_kernel_args(5, 1),
"", "",
layout, layout,
gws); gws,
{1, 1, 1});
topology.add(op_batch_norm); topology.add(op_batch_norm);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment