Commit 4034a0c2 authored by Sergey Shalnov, committed by Scott Cyphers

IntelGPU backend: Allow more cases for clDNN gemm (#2187)

parent 8fc481a3
@@ -685,28 +685,30 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func)
     if (get_input_type(op) == element::f32 && get_input_type(op, 1) == element::f32 &&
         get_output_type(op) == element::f32 && input0_elem_count && input1_elem_count &&
-        (axes_count == 1) && (input0_shape.size() < 3) && (input1_shape.size() < 3) &&
-        !input0_shape.empty() && !input1_shape.empty())
+        (axes_count < 2) && (input0_shape.size() < 3) && (input1_shape.size() < 3))
     {
-        string input1_name = get_input_name(op, 1);
+        bool transpose0 = false;
+        bool transpose1 = false;
+
+        // If we have A[5] and B[] here, in cldnn we have A[1, 1, 1, 5] and B[1, 1, 1, 1]
+        // it needs to be reshaped into A[1, 1, 5, 1] and B[1, 1, 1, 1]
+        if ((input0_shape.size() == 1) && input1_shape.empty())
+        {
+            transpose0 = true;
+        }
 
         // If we have A[5] and B[5] here, in cldnn we have A[1, 1, 1, 5] and B[1, 1, 1, 5]
         // it needs to be reshaped into A[1, 1, 1, 5] and B[1, 1, 5, 1]
         if (!input0_shape.empty() && (input1_shape.size() == 1))
         {
-            const string new_name = input1_name + "_reshaped";
-            Shape new_shape = input1_shape;
-            new_shape.push_back(1);
-            const cldnn::tensor reshaped_tensor =
-                intelgpu_space::create_cldnn_tensor(new_shape);
-            const cldnn::reshape reshape_op(new_name, input1_name, reshaped_tensor);
-            topology.add(reshape_op);
-            input1_name = new_name;
+            transpose1 = true;
         }
 
-        const cldnn::gemm dot_op(get_output_name(op), get_input_name(op, 0), input1_name);
+        const cldnn::gemm dot_op(get_output_name(op),
+                                 get_input_name(op, 0),
+                                 get_input_name(op, 1),
+                                 transpose0,
+                                 transpose1);
         topology.add(dot_op);
     }
     else
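Below is a minimal standalone sketch (not part of the commit) of the rank-to-transpose mapping the new code applies: instead of adding an explicit cldnn::reshape primitive for a rank-1 operand, the backend now passes transpose flags to the clDNN gemm primitive. The names GemmFlags and select_transpose_flags are illustrative only and do not exist in the backend; plain std::vector stands in for ngraph::Shape.

#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative container for the two flags forwarded to cldnn::gemm.
struct GemmFlags
{
    bool transpose0 = false;
    bool transpose1 = false;
};

// Hypothetical helper mirroring the logic of the patched code path:
// a rank-1 operand that clDNN pads to [1, 1, 1, N] is treated as a
// column by setting the corresponding transpose flag.
GemmFlags select_transpose_flags(const std::vector<std::size_t>& input0_shape,
                                 const std::vector<std::size_t>& input1_shape)
{
    GemmFlags flags;

    // A[5] dot B[] : transpose A so [1, 1, 1, 5] acts as [1, 1, 5, 1].
    if ((input0_shape.size() == 1) && input1_shape.empty())
    {
        flags.transpose0 = true;
    }

    // A[...] dot B[5] : transpose B so [1, 1, 1, 5] acts as [1, 1, 5, 1].
    if (!input0_shape.empty() && (input1_shape.size() == 1))
    {
        flags.transpose1 = true;
    }

    return flags;
}

int main()
{
    // Vector-vector dot product A[5] . B[5]: only the second operand is transposed.
    const GemmFlags f = select_transpose_flags({5}, {5});
    std::cout << "transpose0=" << f.transpose0 << " transpose1=" << f.transpose1 << "\n";
    return 0;
}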