Commit 45b50d06 authored by shssf's avatar shssf Committed by Robert Kimball

IntelGPU backend: BatchNorm operation completly redeveloped (#1318)

parent 39278e7d
......@@ -533,35 +533,58 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func)
}
const string& output_name = op->get_outputs().begin()->get_tensor().get_name();
const Shape& output_shape = op->get_outputs().begin()->get_shape();
const element::Type& output_type =
op->get_outputs().begin()->get_tensor().get_element_type();
const string& gamma_name = op->get_inputs().at(0).get_tensor().get_name();
const Shape& gamma_shape = op->get_inputs().at(0).get_shape();
const string& beta_name = op->get_inputs().at(1).get_tensor().get_name();
const string& input_name = op->get_inputs().at(2).get_tensor().get_name();
const Shape& input_shape = op->get_inputs().at(2).get_shape();
string mean_name;
string variance_name;
if (op->get_outputs().size() == 3)
{
arguments_check(op, 3, 3);
mean_name = op->get_outputs().at(1).get_tensor().get_name();
variance_name = op->get_outputs().at(2).get_tensor().get_name();
do_create_mean(
topology, mean_name, gamma_shape, output_type, input_name, input_shape);
do_create_variance(topology,
variance_name,
gamma_shape,
output_type,
input_name,
input_shape,
mean_name);
}
if (op->get_outputs().size() == 1)
if (op->get_outputs().size() == 1 || op->get_outputs().size() == 3)
{
if (mean_name.empty() || variance_name.empty())
{
arguments_check(op, 5, 1);
const string& mean_name = op->get_inputs().at(3).get_tensor().get_name();
const string& variance_name = op->get_inputs().at(4).get_tensor().get_name();
mean_name = op->get_inputs().at(3).get_tensor().get_name();
variance_name = op->get_inputs().at(4).get_tensor().get_name();
}
do_batch_norm_operation(topology,
output_name,
output_shape,
output_type,
eps,
input_name,
input_shape,
gamma_name,
gamma_shape,
beta_name,
mean_name,
variance_name);
}
else if (op->get_outputs().size() == 3)
{
arguments_check(op, 3, 3);
do_batch_norm_operation(
topology, output_name, eps, input_name, input_shape, gamma_name, beta_name);
}
else
{
arguments_check(op, 5, 1); // throw exception in this case
......
......@@ -19,6 +19,7 @@
#include <CPP/topology.hpp>
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph
{
......@@ -27,22 +28,36 @@ namespace ngraph
namespace intelgpu
{
// This implements BatchNorm nGraph operation
// Since nGraph uses channels in this operation but clDNN uses full input data
// at one time we have to use following algorithm:
// 1. Split all input data arrays into several matrices by channel axis
// 2. Independently do cldnn::batch_norm on particular matrix
// 3. Every result of the cldnn::batch_norm must be scaled and
// shifted because cldnn::batch_norm dosn't use gamma and beta
// 4. Concatenate all results into output matrix by channel axis
// nGraph uses channels in this operation but clDNN uses full input data
void do_batch_norm_operation(cldnn::topology& topology,
const std::string& output_name,
const Shape& output_shape,
const element::Type& output_type,
double eps,
const std::string& input_name,
const Shape& input_shape,
const std::string& gamma_name,
const Shape& gamma_shape,
const std::string& beta_name,
const std::string& mean_name = std::string(),
const std::string& variance_name = std::string());
const std::string& mean_name,
const std::string& variance_name);
// This creates mean of the input matrix by Channel axis
void do_create_mean(cldnn::topology& topology,
const std::string& output_name,
const Shape& output_shape,
const element::Type& output_type,
const std::string& input_name,
const Shape& input_shape);
// This creates mean of the input matrix by Channel axis
void do_create_variance(cldnn::topology& topology,
const std::string& output_name,
const Shape& output_shape,
const element::Type& output_type,
const std::string& input_name,
const Shape& input_shape,
const std::string& mean_name);
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment