Commit da352aa1 authored by shssf's avatar shssf Committed by Scott Cyphers

IntelGPU backend: BatchNormBackprop operation (#1443)

* IntelGPU backend: BatchNormBackprop operation

* PR1443. Requested refactoring done
parent 40ddf45a
......@@ -767,6 +767,45 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func)
pad_below,
pad_interior);
}
else if ("BatchNormBackprop" == op->description())
{
arguments_check(op, 6, 3);
const shared_ptr<op::BatchNormBackprop> batch_norm =
static_pointer_cast<op::BatchNormBackprop>(op);
const double eps = batch_norm->get_eps_value();
do_create_mean(topology,
get_output_name(op, 2), // d_beta
get_output_type(op, 2),
get_input_name(op, 5), // delta
get_input_shape(op, 5),
true);
do_create_variance_back(topology,
get_output_name(op, 1), // d_gamma
get_output_type(op, 1),
eps,
get_input_name(op, 2), // input
get_input_shape(op, 2),
get_input_name(op, 3), // gamma
get_input_name(op, 4), // beta
get_input_name(op, 5)); // delta
do_batch_norm_backprop_operation(topology,
get_input_shape(op, 2),
get_input_type(op, 2),
get_input_name(op, 0),
get_input_name(op, 1),
get_input_name(op, 2),
get_input_name(op, 3),
get_input_name(op, 4),
get_input_name(op, 5),
eps,
get_output_name(op, 0),
get_output_name(op, 1),
get_output_name(op, 2));
}
else if ("BatchNorm" == op->description())
{
const shared_ptr<op::BatchNorm> batch_norm = static_pointer_cast<op::BatchNorm>(op);
......@@ -788,14 +827,13 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func)
do_create_mean(topology,
mean_name,
get_input_shape(op),
get_output_type(op),
get_input_name(op, 2),
get_input_shape(op, 2));
get_input_shape(op, 2),
false);
do_create_variance(topology,
variance_name,
get_input_shape(op),
get_output_type(op),
get_input_name(op, 2),
get_input_shape(op, 2),
......@@ -814,13 +852,11 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func)
do_batch_norm_operation(topology,
get_output_name(op),
get_output_shape(op),
get_output_type(op),
eps,
get_input_name(op, 2),
get_input_shape(op, 2),
get_input_name(op, 0),
get_input_shape(op, 0),
get_input_name(op, 1),
mean_name,
variance_name);
......
......@@ -31,13 +31,11 @@ namespace ngraph
// nGraph uses channels in this operation but clDNN uses full input data
void do_batch_norm_operation(cldnn::topology& topology,
const std::string& output_name,
const Shape& output_shape,
const element::Type& output_type,
double eps,
const std::string& input_name,
const Shape& input_shape,
const std::string& gamma_name,
const Shape& gamma_shape,
const std::string& beta_name,
const std::string& mean_name,
const std::string& variance_name);
......@@ -45,19 +43,46 @@ namespace ngraph
// This creates mean of the input matrix by Channel axis
void do_create_mean(cldnn::topology& topology,
const std::string& output_name,
const Shape& output_shape,
const element::Type& output_type,
const std::string& input_name,
const Shape& input_shape);
const Shape& input_shape,
bool backward);
// This creates variance of the input matrix by Channel axis
void do_create_variance(cldnn::topology& topology,
const std::string& output_name,
const Shape& output_shape,
const element::Type& output_type,
const std::string& input_name,
const Shape& input_shape,
const std::string& mean_name);
// This creates variance backprop of the input matrix by Channel axis
void do_create_variance_back(cldnn::topology& topology,
const std::string& output_name,
const element::Type& output_type,
double eps,
const std::string& input_name,
const Shape& input_shape,
const std::string& mean_name,
const std::string& variance_name,
const std::string& delta_name);
// This function uses "shape" parameter as input or output Shape
// Shape of all other calculated as first axis from the left
// Example: output[ 4, 3, 2, 8 ] means out_gamma[ 3 ]
void do_batch_norm_backprop_operation(cldnn::topology& topology,
const Shape& shape,
const element::Type& type,
const std::string& gamma_name,
const std::string& beta_name,
const std::string& input_name,
const std::string& mean_name,
const std::string& variance_name,
const std::string& delta_name,
double eps,
const std::string& output_name,
const std::string& output_gamma_name,
const std::string& output_beta_name);
}
}
}
......@@ -128,7 +128,8 @@ vector<size_t> runtime::intelgpu::generate_loops(codegen::CodeWriter& writer,
{
if (is_begin)
{
writer << "const unsigned i" << var_idx << " = get_global_id(" << var_idx << ");\n";
writer << "const unsigned i" << var_idx << " = get_global_id(" << var_idx
<< "); /*trip count " << i << "*/\n";
gws.push_back(i);
}
}
......
......@@ -28,7 +28,7 @@ backwards_sigmoid
backwards_sign
backwards_slice
backwards_tan
batchnorm_bprop_n4c3h2w2
backwards_tanh
batch_norm_one_output
batch_norm_three_outputs
broadcast_vector_rowwise_int64
......@@ -108,7 +108,6 @@ zero_sized_maximum
zero_sized_minimum
zero_sized_multiply
zero_sized_negative
zero_sized_not
zero_sized_not_equal
zero_sized_power
zero_sized_sign
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment