Commit 134b0ae2 authored by shssf's avatar shssf Committed by Scott Cyphers

IntelGPU backend: BatchNorm, Dot, Pad operations optimization (#1393)

parent 9c1c5b59
......@@ -216,6 +216,7 @@ void runtime::intelgpu::do_batch_norm_operation(cldnn::topology& topology,
const cldnn::layout layout = IntelGPULayout::create_cldnn_layout(output_type, output_shape);
const string entry_point_name = "batch_norm_" + output_name;
codegen::CodeWriter writer;
vector<size_t> gws;
writer << "__kernel void " << entry_point_name << "( const __global float input"
<< array_dims(input_shape) << ", const __global float gamma" << array_dims(gamma_shape)
......@@ -227,45 +228,17 @@ void runtime::intelgpu::do_batch_norm_operation(cldnn::topology& topology,
writer.block_begin();
{ // Main function body
// Loop for Channel axis 1
writer << "for (uint i" << channel_axis << " = 0; i" << channel_axis << " < "
<< output_shape.at(channel_axis) << "; ++i" << channel_axis << ")\n";
writer.block_begin();
{
size_t var_idx = 0;
// Main loops
for (auto const& i : output_shape)
{
if (var_idx != channel_axis)
{
writer << "for (uint i" << var_idx << " = 0; i" << var_idx << " < " << i
<< "; ++i" << var_idx << ")\n";
writer.block_begin();
}
++var_idx;
}
gws = generate_loops(writer, output_shape, true);
writer << "float normalized = (input" << access_dims(input_shape) << " - mean[i"
<< channel_axis << "]) / ("
<< "sqrt(variance[i" << channel_axis << "] + " << eps << ")"
<< ");\n";
writer << "output" << access_dims(output_shape) << " = normalized * gamma[i"
<< channel_axis << "] + beta[i" << channel_axis << "];\n";
var_idx = 0;
// Closing brackets for main loops
for (auto const& i : output_shape)
{
if (var_idx != channel_axis)
{
writer.block_end();
}
++var_idx;
}
writer << "output" << access_dims(output_shape) << " = normalized * gamma[i" << channel_axis
<< "] + beta[i" << channel_axis << "];\n";
} // Closing brackets for Channel axis loop
writer.block_end();
generate_loops(writer, output_shape, false);
} // Main function body
writer.block_end();
......@@ -279,6 +252,6 @@ void runtime::intelgpu::do_batch_norm_operation(cldnn::topology& topology,
get_kernel_args(5, 1),
"",
layout,
{1});
gws);
topology.add(op_batch_norm);
}
......@@ -18,6 +18,8 @@
#include <CPP/topology.hpp>
#include "ngraph/runtime/intelgpu/code_writer.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/coordinate.hpp"
#include "ngraph/shape.hpp"
......@@ -96,6 +98,8 @@ namespace ngraph
std::string access_dims(const Shape& dimentions,
const AxisSet& axis = {},
bool is_reversed = false);
std::vector<size_t>
generate_loops(codegen::CodeWriter& writer, const Shape& shape, bool is_begin);
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment