Commit c5889b2b authored by shssf's avatar shssf Committed by Robert Kimball

IntelGPU backend: Dot_2x2 operation bug fix (#1329)

parent 0405a870
......@@ -204,22 +204,21 @@ static void do_1d_scalar_mul(codegen::CodeWriter& writer,
static void do_2d_2d_mul(codegen::CodeWriter& writer,
string& kernel_name,
const Shape& shapeA,
const Shape& shapeB)
const Shape& shapeB,
const Shape& shapeZ)
{
const size_t rows = shapeA.at(0);
const size_t colrow = shapeA.at(1);
const size_t cols = shapeB.back();
kernel_name += "_do_2d_2d_mul";
writer << "__kernel void " << kernel_name << "(const __global float inputA"
<< runtime::intelgpu::array_dims(shapeA) << ", const __global float inputB"
<< runtime::intelgpu::array_dims(shapeB) << ", __global float output"
<< runtime::intelgpu::array_dims({rows, cols}) << ")\n";
<< runtime::intelgpu::array_dims(shapeZ) << ")\n";
writer.block_begin();
{
size_t var_idx = 0;
// Main loops
for (auto const& i : shapeA)
for (auto const& i : shapeZ)
{
writer << "for (uint i" << var_idx << " = 0; i" << var_idx << " < " << i << "; ++i"
<< var_idx << ")\n";
......@@ -238,7 +237,7 @@ static void do_2d_2d_mul(codegen::CodeWriter& writer,
writer << "output[i0][i1] = sum;\n";
// Closing brackets for main loops
for (auto const& i : shapeA)
for (auto const& i : shapeZ)
{
writer.block_end();
}
......@@ -444,7 +443,7 @@ void runtime::intelgpu::do_dot_operation(cldnn::topology& topology,
}
else if (inputA_shape.size() == 2 && inputB_shape.size() == 2)
{
do_2d_2d_mul(writer, entry_point_name, inputA_shape, inputB_shape);
do_2d_2d_mul(writer, entry_point_name, inputA_shape, inputB_shape, output_shape);
}
else if (inputA_shape.size() == 3 && inputB_shape.size() == 3)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment