Commit 67248fdb authored by Sergey Shalnov, committed by Robert Kimball

IntelGPU backend: Custom kernels refactoring 3 (#2787)

parent 2b13ae40
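The core of the refactoring is visible in the hunks below: backend handlers stop assembling cldnn primitives in place and instead call kern.emit<OP>(op), which dispatches to a per-operation build_krnl() overload on CustomKernels. A minimal sketch of the assumed dispatch follows; the body of emit() is not part of this diff, so the exact wiring is an assumption based on the declarations shown further down:

    // Sketch only: emit() is assumed to forward the per-op kernel description
    // (krnl_info) from build_krnl() to queue_krnl(), which adds it to the clDNN stream.
    template <typename OP>
    void CustomKernels::emit(const std::shared_ptr<OP>& op)
    {
        krnl_info krnl = build_krnl(op); // resolves to the overload for OP, e.g. op::Greater
        queue_krnl(krnl, op);            // declared in intelgpu_kernels.hpp
    }

This keeps the clDNN queuing logic in one place and reduces each operation to producing a krnl_info description.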
@@ -52,7 +52,6 @@
#include "ngraph/runtime/intelgpu/intelgpu_executable.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_kernels.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_layout.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_op_batchnorm.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_op_custom_kernels.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_tensor_view.hpp"
#include "ngraph/runtime/intelgpu/visualize_tree.hpp"
@@ -61,6 +60,7 @@
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
@@ -73,13 +73,20 @@
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/erf.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp"
@@ -129,25 +136,13 @@ static OP_TYPEID get_typeid(const string& s)
return it->second;
}
static void arguments_check(const shared_ptr<Node>& op, size_t input, size_t output)
{
if (op->get_input_size() != input || op->get_output_size() != output)
{
ostringstream os;
os << "Operation \"" << op->description() << "\" input and output sizes mismatch."
<< " Expected input size=" << input << ", provided=" << op->get_input_size()
<< ". Expected output size=" << output << ", provided=" << op->get_output_size();
throw invalid_argument(os.str());
}
}
static void do_eltwise_operation(cldnn::topology& topology,
const shared_ptr<Node>& op,
const string& custom_op,
bool function_operation,
cldnn::eltwise_mode mode)
{
arguments_check(op, 2, 1);
runtime::intelgpu::arguments_check(op, 2, 1);
if (op->get_input_element_type(0) != element::f32 ||
op->get_input_element_type(1) != element::f32 ||
@@ -180,7 +175,7 @@ static void do_cldnn_unary(cldnn::topology& topology,
cldnn_activation_func mode,
const cldnn_activation_additional_params& param = {0.f, 0.f})
{
arguments_check(op, 1, 1);
runtime::intelgpu::arguments_check(op, 1, 1);
const cldnn::activation cldnn_unary(
op->get_output_tensor_name(0), op->get_input_tensor_name(0), mode, param);
@@ -190,7 +185,7 @@ static void do_cldnn_unary(cldnn::topology& topology,
static void
do_custom_unary(cldnn::topology& topology, const shared_ptr<Node>& op, const string& operation)
{
arguments_check(op, 1, 1);
runtime::intelgpu::arguments_check(op, 1, 1);
runtime::intelgpu::do_custom_unary_operation(topology,
op->get_input_tensor_name(0),
@@ -209,7 +204,7 @@ static void do_universal_unary(cldnn::topology& topology,
bool force_custom = false,
const cldnn_activation_additional_params& param = {0.f, 0.f})
{
arguments_check(op, 1, 1);
runtime::intelgpu::arguments_check(op, 1, 1);
if (force_custom || (op->get_input_element_type(0) != element::f32))
{
@@ -228,7 +223,7 @@ static void do_pooling_operation(cldnn::topology& topology,
const Shape& pad_below,
const cldnn::pooling_mode mode)
{
arguments_check(op, 1, 1);
runtime::intelgpu::arguments_check(op, 1, 1);
const cldnn::tensor output_size = intelgpu_space::create_cldnn_tensor(op->get_output_shape(0));
const cldnn::tensor input_offset = intelgpu_space::create_cldnn_offset(pad_below);
@@ -245,22 +240,12 @@ static void do_pooling_operation(cldnn::topology& topology,
topology.add(cldnn_pooling);
}
static void do_logical_operation(cldnn::topology& topology,
const shared_ptr<Node>& op,
const string& operation)
template <typename OP>
static void do_logical_operation(runtime::intelgpu::CustomKernels& kern, const shared_ptr<Node>& op)
{
arguments_check(op, 2, 1);
runtime::intelgpu::arguments_check(op, 2, 1);
runtime::intelgpu::do_logic_kernel(topology,
op->get_input_tensor_name(0),
op->get_input_shape(0),
op->get_input_element_type(0),
op->get_input_tensor_name(1),
op->get_input_shape(1),
op->get_output_tensor_name(0),
op->get_output_shape(0),
op->get_output_element_type(0),
operation);
kern.emit<OP>(static_pointer_cast<OP>(op));
}
// This function needed to only change the name of the data in topology
@@ -1246,42 +1231,42 @@ shared_ptr<runtime::Executable>
}
case OP_TYPEID::Greater:
{
do_logical_operation(topology, op, " > ");
do_logical_operation<op::Greater>(kern, op);
break;
}
case OP_TYPEID::GreaterEq:
{
do_logical_operation(topology, op, " >= ");
do_logical_operation<op::GreaterEq>(kern, op);
break;
}
case OP_TYPEID::Equal:
{
do_logical_operation(topology, op, " == ");
do_logical_operation<op::Equal>(kern, op);
break;
}
case OP_TYPEID::NotEqual:
{
do_logical_operation(topology, op, " != ");
do_logical_operation<op::NotEqual>(kern, op);
break;
}
case OP_TYPEID::Less:
{
do_logical_operation(topology, op, " < ");
do_logical_operation<op::Less>(kern, op);
break;
}
case OP_TYPEID::LessEq:
{
do_logical_operation(topology, op, " <= ");
do_logical_operation<op::LessEq>(kern, op);
break;
}
case OP_TYPEID::And:
{
do_logical_operation(topology, op, " && ");
do_logical_operation<op::And>(kern, op);
break;
}
case OP_TYPEID::Or:
{
do_logical_operation(topology, op, " || ");
do_logical_operation<op::Or>(kern, op);
break;
}
case OP_TYPEID::Pad:
@@ -1305,40 +1290,8 @@ shared_ptr<runtime::Executable>
{
arguments_check(op, 6, 3);
const shared_ptr<op::BatchNormTrainingBackprop> batch_norm =
static_pointer_cast<op::BatchNormTrainingBackprop>(op);
const double eps = batch_norm->get_eps_value();
do_create_mean(topology,
op->get_output_tensor_name(2), // d_beta
op->get_output_element_type(2),
op->get_input_tensor_name(5), // delta
op->get_input_shape(5),
true);
do_create_variance_back(topology,
op->get_output_tensor_name(1), // d_gamma
op->get_output_element_type(1),
eps,
op->get_input_tensor_name(2), // input
op->get_input_shape(2),
op->get_input_tensor_name(3), // gamma
op->get_input_tensor_name(4), // beta
op->get_input_tensor_name(5)); // delta
do_batch_norm_backprop_operation(topology,
op->get_input_shape(2),
op->get_input_element_type(2),
op->get_input_tensor_name(0),
op->get_input_tensor_name(1),
op->get_input_tensor_name(2),
op->get_input_tensor_name(3),
op->get_input_tensor_name(4),
op->get_input_tensor_name(5),
eps,
op->get_output_tensor_name(0),
op->get_output_tensor_name(1),
op->get_output_tensor_name(2));
kern.emit<op::BatchNormTrainingBackprop>(
static_pointer_cast<op::BatchNormTrainingBackprop>(op));
break;
}
case OP_TYPEID::BatchNormInference:
@@ -1367,16 +1320,7 @@ shared_ptr<runtime::Executable>
if (proceed_with_custom_kernel || (op->get_input_shape(2).size() != 4) ||
(op->get_input_element_type(0) != ngraph::element::f32))
{
do_batch_norm_operation(topology,
op->get_output_tensor_name(0),
op->get_output_element_type(0),
eps,
op->get_input_tensor_name(2),
op->get_input_shape(2),
op->get_input_tensor_name(0),
op->get_input_tensor_name(1),
op->get_input_tensor_name(3),
op->get_input_tensor_name(4));
kern.emit<op::BatchNormInference>(bnorm);
}
else
{
@@ -1400,61 +1344,7 @@ shared_ptr<runtime::Executable>
if ((op->get_input_shape(2).size() != 4) ||
(op->get_input_element_type(0) != ngraph::element::f32))
{
string mean_name;
string variance_name;
if (op->get_inputs().size() < 3 || op->get_outputs().empty())
{
arguments_check(op, 3, 1); // throw exception in this case
}
if (op->get_outputs().size() == 3)
{
arguments_check(op, 3, 3);
mean_name = op->get_output_tensor_name(1);
variance_name = op->get_output_tensor_name(2);
do_create_mean(topology,
mean_name,
op->get_output_element_type(0),
op->get_input_tensor_name(2),
op->get_input_shape(2),
false);
do_create_variance(topology,
variance_name,
op->get_output_element_type(0),
op->get_input_tensor_name(2),
op->get_input_shape(2),
mean_name);
}
if (op->get_outputs().size() == 1 || op->get_outputs().size() == 3)
{
if (mean_name.empty() || variance_name.empty())
{
arguments_check(op, 5, 1);
mean_name = op->get_input_tensor_name(3);
variance_name = op->get_input_tensor_name(4);
}
do_batch_norm_operation(topology,
op->get_output_tensor_name(0),
op->get_output_element_type(0),
eps,
op->get_input_tensor_name(2),
op->get_input_shape(2),
op->get_input_tensor_name(0),
op->get_input_tensor_name(1),
mean_name,
variance_name);
}
else
{
arguments_check(op, 5, 1); // throw exception in this case
}
kern.emit<op::BatchNormTraining>(bnorm);
}
else
{
@@ -44,3 +44,15 @@ void runtime::intelgpu::CustomKernels::queue_krnl(const krnl_info& krnl_info,
stream.add(kernel_item);
}
}
void runtime::intelgpu::arguments_check(const shared_ptr<Node>& op, size_t input, size_t output)
{
if (op->get_input_size() != input || op->get_output_size() != output)
{
ostringstream os;
os << "Operation \"" << op->description() << "\" input and output sizes mismatch."
<< " Expected input size=" << input << ", provided=" << op->get_input_size()
<< ". Expected output size=" << output << ", provided=" << op->get_output_size();
throw invalid_argument(os.str());
}
}
@@ -24,11 +24,20 @@
#include "ngraph/node.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/slice.hpp"
@@ -43,6 +52,8 @@ namespace ngraph
{
class CustomKernelInfo;
class CustomKernels;
void arguments_check(const std::shared_ptr<Node>& op, size_t input, size_t output);
}
}
}
@@ -107,13 +118,24 @@ private:
void queue_krnl(const krnl_info& krn_info, const std::shared_ptr<Node>& op);
krnl_info build_krnl(const std::shared_ptr<op::All>& op) const;
krnl_info build_krnl(const std::shared_ptr<op::And>& op) const;
krnl_info build_krnl(const std::shared_ptr<op::Any>& op) const;
krnl_info build_krnl(const std::shared_ptr<op::BatchNormInference>& op) const;
krnl_info build_krnl(const std::shared_ptr<op::BatchNormTraining>& op) const;
krnl_info build_krnl(const std::shared_ptr<op::BatchNormTrainingBackprop>& op) const;
krnl_info build_krnl(const std::shared_ptr<op::Broadcast>& op) const;
krnl_info build_krnl(const std::shared_ptr<op::Convolution>& op) const;
krnl_info build_krnl(const std::shared_ptr<op::ConvolutionBackpropData>& op) const;
krnl_info build_krnl(const std::shared_ptr<op::ConvolutionBackpropFilters>& op) const;
krnl_info build_krnl(const std::shared_ptr<op::Equal>& op) const;
krnl_info build_krnl(const std::shared_ptr<op::Greater>& op) const;
krnl_info build_krnl(const std::shared_ptr<op::GreaterEq>& op) const;
krnl_info build_krnl(const std::shared_ptr<op::Less>& op) const;
krnl_info build_krnl(const std::shared_ptr<op::LessEq>& op) const;
krnl_info build_krnl(const std::shared_ptr<op::Max>& op) const;
krnl_info build_krnl(const std::shared_ptr<op::Min>& op) const;
krnl_info build_krnl(const std::shared_ptr<op::NotEqual>& op) const;
krnl_info build_krnl(const std::shared_ptr<op::Or>& op) const;
krnl_info build_krnl(const std::shared_ptr<op::Product>& op) const;
krnl_info build_krnl(const std::shared_ptr<op::Select>& op) const;
krnl_info build_krnl(const std::shared_ptr<op::Slice>& op) const;
@@ -14,21 +14,15 @@
// limitations under the License.
//*****************************************************************************
#include <CPP/batch_norm.hpp>
#include <CPP/concatenation.hpp>
#include <CPP/custom_gpu_primitive.hpp>
#include <CPP/scale.hpp>
#include <CPP/split.hpp>
#include "ngraph/code_writer.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_layout.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_op_batchnorm.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_kernels.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_op_custom_kernels.hpp"
#include "ngraph/op/batch_norm.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::runtime::intelgpu;
// According to the documentation, input data channel is always being axis 1
// Assumed the second dimension from the left. Example {0, 1, 0, 0} or {0, 1}
@@ -39,9 +33,8 @@ static Shape get_channel_shape(const Shape& shape, const string& function_name)
{
if (shape.size() < channel_axis + 1)
{
const string err = "intelgpu::" + function_name + "() input_shape" +
runtime::intelgpu::array_dims(shape) + " should be at least " +
to_string(channel_axis + 1) + "D.";
const string err = "intelgpu::" + function_name + "() input_shape" + array_dims(shape) +
" should be at least " + to_string(channel_axis + 1) + "D.";
throw invalid_argument(err);
}
@@ -53,15 +46,14 @@ static size_t get_idx_size(const Shape& shape, size_t pos)
return accumulate(shape.cbegin() + pos, shape.cend(), 1, multiplies<size_t>());
}
void runtime::intelgpu::do_create_mean(cldnn::topology& topology,
const string& output_name,
const element::Type& output_type,
const string& input_name,
const Shape& input_shape,
bool backward)
// This creates mean of the input matrix by Channel axis
static CustomKernels::krnl_info do_create_mean(const string& output_name,
const element::Type& output_type,
const string& input_name,
const Shape& input_shape,
bool backward)
{
const Shape channel_shape = get_channel_shape(input_shape, "create_mean");
const cldnn::layout layout = IntelGPULayout::create_cldnn_layout(output_type, channel_shape);
const string entry_point_name = "create_mean_" + output_name;
const size_t output_counts = shape_size<Shape>(input_shape) / input_shape.at(channel_axis);
const string kernel_data_type = get_opencl_type_name(output_type);
@@ -118,26 +110,23 @@ void runtime::intelgpu::do_create_mean(cldnn::topology& topology,
} // Main function body
writer.block_end();
const cldnn::custom_gpu_primitive op_mean(output_name,
{input_name},
{writer.get_code()},
entry_point_name,
get_kernel_args(1, 1),
"",
layout,
{1});
topology.add(op_mean);
const CustomKernelInfo op_bcast_sum(output_name,
channel_shape,
output_type,
{input_name},
{writer.get_code()},
entry_point_name);
return {op_bcast_sum};
}
void runtime::intelgpu::do_create_variance(cldnn::topology& topology,
const string& output_name,
const element::Type& output_type,
const string& input_name,
const Shape& input_shape,
const std::string& mean_name)
// This creates variance of the input matrix by Channel axis
static CustomKernels::krnl_info do_create_variance(const string& output_name,
const element::Type& output_type,
const string& input_name,
const Shape& input_shape,
const std::string& mean_name)
{
const Shape channel_shape = get_channel_shape(input_shape, "create_variance");
const cldnn::layout layout = IntelGPULayout::create_cldnn_layout(output_type, channel_shape);
const string entry_point_name = "create_variance_" + output_name;
const size_t output_counts = shape_size<Shape>(input_shape) / input_shape.at(channel_axis);
const string kernel_data_type = get_opencl_type_name(output_type);
@@ -194,30 +183,26 @@ void runtime::intelgpu::do_create_variance(cldnn::topology& topology,
} // Main function body
writer.block_end();
const cldnn::custom_gpu_primitive op_variance(output_name,
{input_name, mean_name},
{writer.get_code()},
entry_point_name,
get_kernel_args(2, 1),
"",
layout,
{1});
topology.add(op_variance);
const CustomKernelInfo op_variance(output_name,
channel_shape,
output_type,
{input_name, mean_name},
{writer.get_code()},
entry_point_name);
return {op_variance};
}
void runtime::intelgpu::do_batch_norm_operation(cldnn::topology& topology,
const string& output_name,
const element::Type& output_type,
double eps,
const string& input_name,
const Shape& input_shape,
const string& gamma_name,
const string& beta_name,
const string& mean_name_inp,
const string& variance_name_inp)
static CustomKernels::krnl_info do_batch_norm_operation(const string& output_name,
const element::Type& output_type,
double eps,
const string& input_name,
const Shape& input_shape,
const string& gamma_name,
const string& beta_name,
const string& mean_name_inp,
const string& variance_name_inp)
{
const Shape channel_shape = get_channel_shape(input_shape, "batch_norm");
const cldnn::layout layout = IntelGPULayout::create_cldnn_layout(output_type, input_shape);
const vector<size_t> gws(input_shape.begin(), input_shape.begin() + 2);
const string entry_point_name = "batch_norm_" + output_name;
const string kernel_data_type = get_opencl_type_name(output_type);
@@ -265,32 +250,30 @@ void runtime::intelgpu::do_batch_norm_operation(cldnn::topology& topology,
} // Main function body
writer.block_end();
const vector<cldnn::primitive_id>& inputs = {
const vector<string>& inputs = {
input_name, gamma_name, beta_name, mean_name_inp, variance_name_inp};
const cldnn::custom_gpu_primitive op_batch_norm(output_name,
inputs,
{writer.get_code()},
entry_point_name,
get_kernel_args(5, 1),
"",
layout,
gws,
{1, 1, 1});
topology.add(op_batch_norm);
const CustomKernelInfo op_batch_norm(output_name,
input_shape,
output_type,
inputs,
{writer.get_code()},
entry_point_name,
gws,
{1, 1, 1});
return {op_batch_norm};
}
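For reference, the kernel assembled by do_batch_norm_operation() applies the standard per-channel batch normalization (mean and variance either arrive as inputs 3 and 4 or are produced by the helper kernels above). In LaTeX notation, for channel c:

    y = \gamma_c \cdot \frac{x - \mu_c}{\sqrt{\sigma_c^2 + \varepsilon}} + \beta_c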
void runtime::intelgpu::do_create_variance_back(cldnn::topology& topology,
const string& output_name,
const element::Type& output_type,
double eps,
const string& input_name,
const Shape& input_shape,
const string& mean_name,
const string& variance_name,
const string& delta_name)
// This creates variance backprop of the input matrix by Channel axis
static CustomKernels::krnl_info do_create_variance_back(const string& output_name,
const element::Type& output_type,
double eps,
const string& input_name,
const Shape& input_shape,
const string& mean_name,
const string& variance_name,
const string& delta_name)
{
const Shape channel_shape = get_channel_shape(input_shape, "create_variance_back");
const cldnn::layout layout = IntelGPULayout::create_cldnn_layout(output_type, channel_shape);
const string entry_point_name = "create_variance_back_" + output_name;
const string kernel_data_type = get_opencl_type_name(output_type);
CodeWriter writer;
@@ -343,34 +326,34 @@ void runtime::intelgpu::do_create_variance_back(cldnn::topology& topology,
} // Main function body
writer.block_end();
const vector<cldnn::primitive_id>& inputs = {input_name, delta_name, mean_name, variance_name};
const cldnn::custom_gpu_primitive op_create_variance_back(output_name,
inputs,
{writer.get_code()},
entry_point_name,
get_kernel_args(4, 1),
"",
layout,
gws);
topology.add(op_create_variance_back);
const vector<string>& inputs = {input_name, delta_name, mean_name, variance_name};
const CustomKernelInfo op_create_variance_back(output_name,
channel_shape,
output_type,
inputs,
{writer.get_code()},
entry_point_name,
gws);
return {op_create_variance_back};
}
void runtime::intelgpu::do_batch_norm_backprop_operation(cldnn::topology& topology,
const Shape& shape,
const element::Type& type,
const string& gamma_name,
const string& beta_name,
const string& input_name,
const string& mean_name,
const string& variance_name,
const string& delta_name,
double eps,
const string& output_name,
const string& output_gamma_name,
const string& output_beta_name)
// This function uses "shape" parameter as input or output Shape
// Shape of all other calculated as first axis from the left
// Example: output[ 4, 3, 2, 8 ] means out_gamma[ 3 ]
static CustomKernels::krnl_info do_batch_norm_backprop_operation(const Shape& shape,
const element::Type& type,
const string& gamma_name,
const string& beta_name,
const string& input_name,
const string& mean_name,
const string& variance_name,
const string& delta_name,
double eps,
const string& output_name,
const string& output_gamma_name,
const string& output_beta_name)
{
const Shape channel_shape = get_channel_shape(shape, "batch_norm_backprop");
const cldnn::layout layout = IntelGPULayout::create_cldnn_layout(type, shape);
const string entry_point_name = "batch_norm_backprop_" + output_name;
const size_t r_axes_size = shape_size(shape) / shape_size(channel_shape);
const string kernel_data_type = get_opencl_type_name(type);
@@ -391,7 +374,7 @@ void runtime::intelgpu::do_batch_norm_backprop_operation(cldnn::topology& topolo
{ // Main function body
// Main loops
gws = runtime::intelgpu::generate_loops(writer, shape, true);
gws = generate_loops(writer, shape, true);
writer << kernel_data_type << " stddev = sqrt(variance[i" << channel_axis << "] + " << eps
<< ");\n";
@@ -404,25 +387,139 @@ void runtime::intelgpu::do_batch_norm_backprop_operation(cldnn::topology& topolo
<< channel_axis << "]) / " << r_axes_size << ");\n";
// Closing brackets for main loops
runtime::intelgpu::generate_loops(writer, shape, false);
generate_loops(writer, shape, false);
} // Main function body
writer.block_end();
const vector<cldnn::primitive_id>& inputs = {input_name,
delta_name,
mean_name,
variance_name,
gamma_name,
output_gamma_name,
output_beta_name};
const cldnn::custom_gpu_primitive op_batch_norm_backprop(output_name,
inputs,
{writer.get_code()},
entry_point_name,
get_kernel_args(7, 1),
"",
layout,
gws);
topology.add(op_batch_norm_backprop);
const vector<string>& inputs = {input_name,
delta_name,
mean_name,
variance_name,
gamma_name,
output_gamma_name,
output_beta_name};
const CustomKernelInfo op_batch_norm_backprop(
output_name, shape, type, inputs, {writer.get_code()}, entry_point_name, gws);
return {op_batch_norm_backprop};
}
CustomKernels::krnl_info
CustomKernels::build_krnl(const shared_ptr<op::BatchNormInference>& op) const
{
return do_batch_norm_operation(op->get_output_tensor_name(0),
op->get_output_element_type(0),
op->get_eps_value(),
op->get_input_tensor_name(2),
op->get_input_shape(2),
op->get_input_tensor_name(0),
op->get_input_tensor_name(1),
op->get_input_tensor_name(3),
op->get_input_tensor_name(4));
}
CustomKernels::krnl_info
CustomKernels::build_krnl(const shared_ptr<op::BatchNormTraining>& op) const
{
CustomKernels::krnl_info result;
string mean_name;
string variance_name;
if (op->get_inputs().size() < 3 || op->get_outputs().empty())
{
arguments_check(op, 3, 1); // throw exception in this case
}
if (op->get_outputs().size() == 3)
{
arguments_check(op, 3, 3);
mean_name = op->get_output_tensor_name(1);
variance_name = op->get_output_tensor_name(2);
CustomKernels::krnl_info mean = do_create_mean(mean_name,
op->get_output_element_type(0),
op->get_input_tensor_name(2),
op->get_input_shape(2),
false);
result.insert(result.end(), mean.begin(), mean.end());
CustomKernels::krnl_info variance = do_create_variance(variance_name,
op->get_output_element_type(0),
op->get_input_tensor_name(2),
op->get_input_shape(2),
mean_name);
result.insert(result.end(), variance.begin(), variance.end());
}
if (op->get_outputs().size() == 1 || op->get_outputs().size() == 3)
{
if (mean_name.empty() || variance_name.empty())
{
arguments_check(op, 5, 1);
mean_name = op->get_input_tensor_name(3);
variance_name = op->get_input_tensor_name(4);
}
CustomKernels::krnl_info batch_norm =
do_batch_norm_operation(op->get_output_tensor_name(0),
op->get_output_element_type(0),
op->get_eps_value(),
op->get_input_tensor_name(2),
op->get_input_shape(2),
op->get_input_tensor_name(0),
op->get_input_tensor_name(1),
mean_name,
variance_name);
result.insert(result.end(), batch_norm.begin(), batch_norm.end());
}
else
{
arguments_check(op, 5, 1); // throw exception in this case
}
return result;
}
CustomKernels::krnl_info
CustomKernels::build_krnl(const shared_ptr<op::BatchNormTrainingBackprop>& op) const
{
CustomKernels::krnl_info result;
CustomKernels::krnl_info mean = do_create_mean(op->get_output_tensor_name(2), // d_beta
op->get_output_element_type(2),
op->get_input_tensor_name(5), // delta
op->get_input_shape(5),
true);
result.insert(result.end(), mean.begin(), mean.end());
CustomKernels::krnl_info variance =
do_create_variance_back(op->get_output_tensor_name(1), // d_gamma
op->get_output_element_type(1),
op->get_eps_value(),
op->get_input_tensor_name(2), // input
op->get_input_shape(2),
op->get_input_tensor_name(3), // gamma
op->get_input_tensor_name(4), // beta
op->get_input_tensor_name(5)); // delta
result.insert(result.end(), variance.begin(), variance.end());
CustomKernels::krnl_info batch_norm =
do_batch_norm_backprop_operation(op->get_input_shape(2),
op->get_input_element_type(2),
op->get_input_tensor_name(0),
op->get_input_tensor_name(1),
op->get_input_tensor_name(2),
op->get_input_tensor_name(3),
op->get_input_tensor_name(4),
op->get_input_tensor_name(5),
op->get_eps_value(),
op->get_output_tensor_name(0),
op->get_output_tensor_name(1),
op->get_output_tensor_name(2));
result.insert(result.end(), batch_norm.begin(), batch_norm.end());
return result;
}
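The three kernels chained in this build_krnl() correspond to the usual batch-norm backprop terms, computed per channel c with \hat{x} = (x - \mu_c)/\sqrt{\sigma_c^2 + \varepsilon} and N = shape_size(input) / channels (the same quantity as r_axes_size in the kernel generator). The generated kernel bodies are elided from this diff, so the following is a reference formulation rather than a transcription:

    d\beta_c  = \sum \delta
    d\gamma_c = \sum \delta \cdot \hat{x}
    dx = \frac{\gamma_c}{\sqrt{\sigma_c^2 + \varepsilon}} \left( \delta - \frac{d\beta_c}{N} - \hat{x} \cdot \frac{d\gamma_c}{N} \right)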
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <CPP/topology.hpp>
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph
{
namespace runtime
{
namespace intelgpu
{
// This implements BatchNorm nGraph operation
// nGraph uses channels in this operation but clDNN uses full input data
void do_batch_norm_operation(cldnn::topology& topology,
const std::string& output_name,
const element::Type& output_type,
double eps,
const std::string& input_name,
const Shape& input_shape,
const std::string& gamma_name,
const std::string& beta_name,
const std::string& mean_name,
const std::string& variance_name);
// This creates mean of the input matrix by Channel axis
void do_create_mean(cldnn::topology& topology,
const std::string& output_name,
const element::Type& output_type,
const std::string& input_name,
const Shape& input_shape,
bool backward);
// This creates variance of the input matrix by Channel axis
void do_create_variance(cldnn::topology& topology,
const std::string& output_name,
const element::Type& output_type,
const std::string& input_name,
const Shape& input_shape,
const std::string& mean_name);
// This creates variance backprop of the input matrix by Channel axis
void do_create_variance_back(cldnn::topology& topology,
const std::string& output_name,
const element::Type& output_type,
double eps,
const std::string& input_name,
const Shape& input_shape,
const std::string& mean_name,
const std::string& variance_name,
const std::string& delta_name);
// This function uses "shape" parameter as input or output Shape
// Shape of all other calculated as first axis from the left
// Example: output[ 4, 3, 2, 8 ] means out_gamma[ 3 ]
void do_batch_norm_backprop_operation(cldnn::topology& topology,
const Shape& shape,
const element::Type& type,
const std::string& gamma_name,
const std::string& beta_name,
const std::string& input_name,
const std::string& mean_name,
const std::string& variance_name,
const std::string& delta_name,
double eps,
const std::string& output_name,
const std::string& output_gamma_name,
const std::string& output_beta_name);
}
}
}
@@ -1277,18 +1277,16 @@ CustomKernels::krnl_info CustomKernels::build_krnl(const shared_ptr<op::Select>&
return {krn_ret};
}
void runtime::intelgpu::do_logic_kernel(cldnn::topology& topology,
const string& input0_name,
const Shape& input0_shape,
const element::Type& input0_type,
const string& input1_name,
const Shape& input1_shape,
const string& output_name,
const Shape& output_shape,
const element::Type& output_type,
const string& operation)
static CustomKernels::krnl_info do_logic_kernel(const shared_ptr<Node>& op, const string& operation)
{
const cldnn::layout layout = IntelGPULayout::create_cldnn_layout(output_type, output_shape);
const string& input0_name = op->get_input_tensor_name(0);
const Shape& input0_shape = op->get_input_shape(0);
const element::Type& input0_type = op->get_input_element_type(0);
const string& input1_name = op->get_input_tensor_name(1);
const Shape& input1_shape = op->get_input_shape(1);
const string& output_name = op->get_output_tensor_name(0);
const Shape& output_shape = op->get_output_shape(0);
const element::Type& output_type = op->get_output_element_type(0);
const string entry_point_name = "logic_" + output_name;
CodeWriter writer;
vector<size_t> gws;
@@ -1313,15 +1311,14 @@ void runtime::intelgpu::do_logic_kernel(cldnn::topology& topology,
}
writer.block_end();
const cldnn::custom_gpu_primitive op_logical(output_name,
{input0_name, input1_name},
{writer.get_code()},
entry_point_name,
get_kernel_args(2, 1),
"",
layout,
gws);
topology.add(op_logical);
const CustomKernelInfo op_logical(output_name,
output_shape,
output_type,
{input0_name, input1_name},
{writer.get_code()},
entry_point_name,
gws);
return {op_logical};
}
void runtime::intelgpu::do_eltwise_kernel(cldnn::topology& topology,
@@ -2333,3 +2330,43 @@ size_t runtime::intelgpu::get_max_memory_rss()
return result;
}
CustomKernels::krnl_info CustomKernels::build_krnl(const shared_ptr<op::And>& op) const
{
return do_logic_kernel(op, " && ");
}
CustomKernels::krnl_info CustomKernels::build_krnl(const shared_ptr<op::Equal>& op) const
{
return do_logic_kernel(op, " == ");
}
CustomKernels::krnl_info CustomKernels::build_krnl(const shared_ptr<op::Greater>& op) const
{
return do_logic_kernel(op, " > ");
}
CustomKernels::krnl_info CustomKernels::build_krnl(const shared_ptr<op::GreaterEq>& op) const
{
return do_logic_kernel(op, " >= ");
}
CustomKernels::krnl_info CustomKernels::build_krnl(const shared_ptr<op::Less>& op) const
{
return do_logic_kernel(op, " < ");
}
CustomKernels::krnl_info CustomKernels::build_krnl(const shared_ptr<op::LessEq>& op) const
{
return do_logic_kernel(op, " <= ");
}
CustomKernels::krnl_info CustomKernels::build_krnl(const shared_ptr<op::NotEqual>& op) const
{
return do_logic_kernel(op, " != ");
}
CustomKernels::krnl_info CustomKernels::build_krnl(const shared_ptr<op::Or>& op) const
{
return do_logic_kernel(op, " || ");
}
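With every comparison and logical op now routed through the shared do_logic_kernel() generator, adding another element-wise binary op touches only three places. An illustrative sketch; op::Xor and its OP_TYPEID entry are hypothetical examples rather than part of this change, and the file names are inferred from the #include lines above:

    // 1. Custom-kernels source: reuse the shared generator with the OpenCL operator string.
    CustomKernels::krnl_info CustomKernels::build_krnl(const shared_ptr<op::Xor>& op) const
    {
        return do_logic_kernel(op, " != "); // for 0/1 inputs, != behaves as logical xor
    }

    // 2. intelgpu_kernels.hpp: declare the matching overload.
    //        krnl_info build_krnl(const std::shared_ptr<op::Xor>& op) const;

    // 3. Backend op switch: dispatch through the templated helper.
    //        case OP_TYPEID::Xor: { do_logical_operation<op::Xor>(kern, op); break; }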
@@ -100,17 +100,6 @@ namespace ngraph
const element::Type& output_type,
size_t concat_axis);
void do_logic_kernel(cldnn::topology& topology,
const std::string& input0_name,
const Shape& input0_shape,
const element::Type& input0_type,
const std::string& input1_name,
const Shape& input1_shape,
const std::string& output_name,
const Shape& output_shape,
const element::Type& output_type,
const std::string& operation);
void do_eltwise_kernel(cldnn::topology& topology,
const std::string& input0_name,
const Shape& input0_shape,