Commit cf15ef32 authored by shssf's avatar shssf Committed by Robert Kimball

IntelGPU backend: Compilation fix after pr1828 (#1892)

parent e7b4106e
...@@ -1105,12 +1105,12 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func) ...@@ -1105,12 +1105,12 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func)
pad_interior); pad_interior);
break; break;
} }
case OP_TYPEID::BatchNormBackprop: case OP_TYPEID::BatchNormTrainingBackprop:
{ {
arguments_check(op, 6, 3); arguments_check(op, 6, 3);
const shared_ptr<op::BatchNormBackprop> batch_norm = const shared_ptr<op::BatchNormTrainingBackprop> batch_norm =
static_pointer_cast<op::BatchNormBackprop>(op); static_pointer_cast<op::BatchNormTrainingBackprop>(op);
const double eps = batch_norm->get_eps_value(); const double eps = batch_norm->get_eps_value();
do_create_mean(topology, do_create_mean(topology,
...@@ -1145,9 +1145,32 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func) ...@@ -1145,9 +1145,32 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func)
get_output_name(op, 2)); get_output_name(op, 2));
break; break;
} }
case OP_TYPEID::BatchNorm: case OP_TYPEID::BatchNormInference:
{ {
const shared_ptr<op::BatchNorm> batch_norm = static_pointer_cast<op::BatchNorm>(op); const shared_ptr<op::BatchNormInference> batch_norm =
static_pointer_cast<op::BatchNormInference>(op);
const double eps = batch_norm->get_eps_value();
string mean_name;
string variance_name;
arguments_check(op, 5, 1);
do_batch_norm_operation(topology,
get_output_name(op),
get_output_type(op),
eps,
get_input_name(op, 2),
get_input_shape(op, 2),
get_input_name(op, 0),
get_input_name(op, 1),
get_input_name(op, 3),
get_input_name(op, 4));
break;
}
case OP_TYPEID::BatchNormTraining:
{
const shared_ptr<op::BatchNormTraining> batch_norm =
static_pointer_cast<op::BatchNormTraining>(op);
const double eps = batch_norm->get_eps_value(); const double eps = batch_norm->get_eps_value();
string mean_name; string mean_name;
string variance_name; string variance_name;
......
...@@ -9,6 +9,7 @@ backwards_dot_tensor_vector ...@@ -9,6 +9,7 @@ backwards_dot_tensor_vector
backwards_exp backwards_exp
backwards_maxpool_n2_c1_hw5_3x3_str2_max backwards_maxpool_n2_c1_hw5_3x3_str2_max
backwards_maxpool_n4_c1_hw4_2x2_max backwards_maxpool_n4_c1_hw4_2x2_max
backwards_relu
backwards_replace_slice backwards_replace_slice
backwards_reverse_sequence_n3_c2_h3 backwards_reverse_sequence_n3_c2_h3
backwards_reverse_sequence_n4d2c3h2w2 backwards_reverse_sequence_n4d2c3h2w2
...@@ -28,6 +29,14 @@ max_pool_3d ...@@ -28,6 +29,14 @@ max_pool_3d
numeric_double_inf numeric_double_inf
numeric_double_nan numeric_double_nan
quantize quantize
quantize_ROUND_DOWN
quantize_ROUND_NEAREST_DOWNWARD
quantize_ROUND_NEAREST_TOWARD_EVEN
quantize_ROUND_NEAREST_TOWARD_ZERO
quantize_ROUND_NEAREST_UPWARD
quantize_ROUND_TOWARD_INFINITY
quantize_ROUND_TOWARD_ZERO
quantize_ROUND_UP
quantize_axes quantize_axes
quantize_clamp quantize_clamp
quantize_int8 quantize_int8
...@@ -48,8 +57,8 @@ reverse_sequence_n2c3h4w2 ...@@ -48,8 +57,8 @@ reverse_sequence_n2c3h4w2
reverse_sequence_n4c3h2w2 reverse_sequence_n4c3h2w2
reverse_sequence_n4d2c3h2w2 reverse_sequence_n4d2c3h2w2
select_and_scatter_3d_without_overlap select_and_scatter_3d_without_overlap
select_and_scatter_without_overlap
select_and_scatter_with_overlap select_and_scatter_with_overlap
select_and_scatter_without_overlap
topk_1d_max_all topk_1d_max_all
topk_1d_max_one topk_1d_max_one
topk_1d_max_partial topk_1d_max_partial
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment