Commit 8c6a5be0 authored by shssf, committed by Robert Kimball

IntelGPU backend: Convolution workaround operations (#1402)

parent a3a9a9fa
......@@ -21,6 +21,7 @@ set(SRC
intelgpu_op_batchnorm.cpp
intelgpu_op_broadcast.cpp
intelgpu_op_custom_kernels.cpp
+intelgpu_op_convolution.cpp
code_writer.cpp
)
......
......@@ -34,6 +34,7 @@
#include "ngraph/runtime/intelgpu/intelgpu_layout.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_op_batchnorm.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_op_broadcast.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_op_convolution.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_op_custom_kernels.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_tensor_view.hpp"
......@@ -758,62 +759,107 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func)
arguments_check(op, 2, 1);
const shared_ptr<op::Convolution> conv_op = static_pointer_cast<op::Convolution>(op);
-const Strides& conv_stride = conv_op->get_window_movement_strides();
-const Strides& conv_dilation = conv_op->get_window_dilation_strides();
-const CoordinateDiff& conv_padding_below = conv_op->get_padding_below();
-const CoordinateDiff& conv_padding_above = conv_op->get_padding_above();
-const Strides& conv_data_dilation = conv_op->get_data_dilation_strides();
-if (conv_stride.size() > 2)
-{
-ostringstream os;
-os << "Unsupported strides for \"" << op->description() << '\"';
-throw std::invalid_argument(os.str());
-}
-if (conv_padding_below.size() > 2 || conv_padding_above.size() > 2)
-{
-ostringstream os;
-os << "Unsupported padding for \"" << op->description() << '\"';
-throw std::invalid_argument(os.str());
-}
-//TODO: A future clDNN version will support different paddings above and below
-if (conv_padding_below.at(0) != conv_padding_above.at(0) ||
-conv_padding_below.at(1) != conv_padding_above.at(1))
-{
-ostringstream os;
-os << "Paddings above and below are different for \"" << op->description() << '\"';
-throw std::invalid_argument(os.str());
-}
-if (conv_dilation.size() > 2)
-{
-ostringstream os;
-os << "Unsupported dilation for \"" << op->description() << '\"';
-throw std::invalid_argument(os.str());
-}
-if (conv_data_dilation.size() > 2 || conv_data_dilation.at(0) != 1 ||
-conv_data_dilation.at(1) != 1)
-{
-ostringstream os;
-os << "Unsupported data dilation for \"" << op->description() << '\"';
-throw std::invalid_argument(os.str());
-}
-const cldnn::tensor input_offset(
-0, 0, -conv_padding_below.at(1), -conv_padding_below.at(0));
-const cldnn::tensor strides(1, 1, conv_stride.at(1), conv_stride.at(0));
-const cldnn::tensor dilation(1, 1, conv_dilation.at(1), conv_dilation.at(0));
-const cldnn::convolution cldnn_conv(get_output_name(op),
-get_input_name(op, 0),
-{get_input_name(op, 1)},
-strides,
-input_offset,
-dilation);
-topology.add(cldnn_conv);
+const Strides& win_stride = conv_op->get_window_movement_strides();
+const Strides& win_dilation = conv_op->get_window_dilation_strides();
+const Strides& data_dilation = conv_op->get_data_dilation_strides();
+const CoordinateDiff& pad_below = conv_op->get_padding_below();
+const CoordinateDiff& pad_above = conv_op->get_padding_above();
+// clDNN has quite limited support for the Convolution operation;
+// the following checks select the cases that must go through the workaround kernel
+if ((win_stride.size() > 2) || (pad_below.size() > 2 || pad_above.size() > 2) ||
+(pad_below.at(0) != pad_above.at(0) || pad_below.at(1) != pad_above.at(1)) ||
+(win_dilation.size() > 2) ||
+(data_dilation.size() > 2 || data_dilation.at(0) != 1 || data_dilation.at(1) != 1))
+{
+do_convolution_operation(topology,
+get_input_name(op, 0),
+get_input_shape(op, 0),
+get_input_name(op, 1),
+get_input_shape(op, 1),
+get_output_name(op),
+get_output_shape(op),
+get_output_type(op),
+conv_op->get_padding_below(),
+conv_op->get_window_movement_strides(),
+conv_op->get_window_dilation_strides(),
+conv_op->get_data_dilation_strides(),
+0,
+1,
+1,
+"input[batch][input_channel]",
+"filter[output_channel][input_channel]",
+"output[batch][output_channel]",
+false);
+}
+else
+{
+const cldnn::tensor input_offset(0, 0, -pad_below.at(1), -pad_below.at(0));
+const cldnn::tensor strides(1, 1, win_stride.at(1), win_stride.at(0));
+const cldnn::tensor dilation(1, 1, win_dilation.at(1), win_dilation.at(0));
+const cldnn::convolution cldnn_conv(get_output_name(op),
+get_input_name(op, 0),
+{get_input_name(op, 1)},
+strides,
+input_offset,
+dilation);
+topology.add(cldnn_conv);
+}
+}
+else if ("ConvolutionBackpropFilters" == op->description())
+{
+arguments_check(op, 2, 1);
+const shared_ptr<op::ConvolutionBackpropFilters> conv_op =
+static_pointer_cast<op::ConvolutionBackpropFilters>(op);
+do_convolution_operation(topology,
+get_input_name(op, 0),
+get_input_shape(op, 0),
+get_input_name(op, 1),
+get_input_shape(op, 1),
+get_output_name(op),
+get_output_shape(op),
+get_output_type(op),
+conv_op->get_padding_below_backward(),
+conv_op->get_window_movement_strides_backward(),
+conv_op->get_window_dilation_strides_backward(),
+conv_op->get_data_dilation_strides_backward(),
+1,
+0,
+0,
+"input[input_channel][batch]",
+"filter[input_channel][output_channel]",
+"output[output_channel][batch]",
+false);
+}
+else if ("ConvolutionBackpropData" == op->description())
+{
+arguments_check(op, 2, 1);
+const shared_ptr<op::ConvolutionBackpropData> conv_op =
+static_pointer_cast<op::ConvolutionBackpropData>(op);
+do_convolution_operation(topology,
+get_input_name(op, 1),
+get_input_shape(op, 1),
+get_input_name(op, 0),
+get_input_shape(op, 0),
+get_output_name(op),
+get_output_shape(op),
+get_output_type(op),
+conv_op->get_padding_below_backward(),
+conv_op->get_window_movement_strides_backward(),
+conv_op->get_window_dilation_strides_backward(),
+conv_op->get_data_dilation_strides_backward(),
+0,
+1,
+1,
+"input[batch][input_channel]",
+"filter[input_channel][output_channel]",
+"output[batch][output_channel]",
+true);
}
else if ("Min" == op->description())
{
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <CPP/custom_gpu_primitive.hpp>
#include "ngraph/runtime/intelgpu/code_writer.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_layout.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_op_convolution.hpp"
#include "ngraph/runtime/intelgpu/intelgpu_op_custom_kernels.hpp"
using namespace std;
using namespace ngraph;
// This duplicates runtime::intelgpu::access_dims; the two should be merged,
// but not as part of this change.
static string array_dim(const Shape& dimensions, const string& var = "i", bool is_reversed = false)
{
size_t var_idx = 0;
string buffer;
for (auto const& i : dimensions)
{
if (is_reversed)
{
buffer += "[" + to_string(i) + " - " + var + to_string(var_idx) + " - 1]";
}
else
{
buffer += "[" + var + to_string(var_idx) + "]";
}
++var_idx;
}
if (buffer.empty())
{ // an empty shape means a scalar
buffer = "[0]";
}
return buffer;
}
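// Illustrative examples (not part of the original code): for Shape{2, 3},
// array_dim({2, 3}) produces "[i0][i1]", array_dim({2, 3}, "f") produces
// "[f0][f1]", and array_dim({2, 3}, "f", true) produces
// "[2 - f0 - 1][3 - f1 - 1]", i.e. the same dimensions indexed in reverse.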
// Padding, strides and dilation are nicely explained with animations at
// https://github.com/vdumoulin/conv_arithmetic
//
// The batch axis of both input data and output data is 0.
// The input channel axis of both input data and filters is 1.
// The output channel axis of filters is 0.
// The output channel axis of output data is 1.
// Example (Convolution):
// data[ 2, 1, 3, 5, 8 ]
// filter[ 2, 1, 2, 2, 3 ]
// output[ 2, 2, 2, 4, 6 ]
// it is like
// data[ batch, data_channel, 3, 5, 8 ]
// filter[ output_channel, data_channel, 2, 2, 3 ]
// output[ batch, output_channel, 2, 4, 6 ]
//
// Example (ConvolutionBackpropFilters):
// data[ 2, 1, 3, 5 ]
// filter[ 2, 2, 2, 4 ]
// output[ 2, 1, 2, 2 ]
// it is like
// data[ data_channel, batch, 3, 5 ]
// filter[ data_channel, output_channel, 2, 4 ]
// output[ output_channel, batch, 2, 2 ]
//
// Example (ConvolutionBackpropData):
// data[ 2, 2, 2, 4 ]
// filter[ 2, 1, 2, 2 ]
// output[ 2, 1, 3, 5 ]
// pad_below[ 1, 1 ]
// pad_above[ 1, 1 ]
// it is like
// data[ batch, data_channel, 2, 4 ]
// filter[ data_channel, output_channel, 2, 2 ]
// output[ batch, output_channel, 3, 5 ]
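//
// As a worked sketch of what the generated kernel computes: each output
// element is a sum over input channels and filter positions, where the
// source coordinate along spatial axis d is
//   input_idx_d = i_d * win_stride[d] + f_d * win_dilation[d] - pad_below[d]
// and the product input_pad * filter is accumulated only when input_idx_d
// lies inside the input bounds and the data-dilation condition holds.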
void runtime::intelgpu::do_convolution_operation(cldnn::topology& topology,
const string& input_name,
const Shape& input_shape,
const string& filter_name,
const Shape& filter_shape,
const string& output_name,
const Shape& output_shape,
const element::Type& output_type,
const CoordinateDiff& pad_below,
const Strides& win_stride,
const Strides& win_dilation,
const Strides& data_dilation,
size_t batch_axis_data,
size_t input_channel_axis_data,
size_t output_channel_axis_result,
const string& input_order,
const string& filter_order,
const string& output_order,
bool reverse_filter)
{
const string& default_pad_value = "0.0f";
const cldnn::layout layout = IntelGPULayout::create_cldnn_layout(output_type, output_shape);
const string entry_point_name = "convolution_" + output_name;
const Shape input_data(input_shape.cbegin() + 2, input_shape.cend());
const Shape filter_data(filter_shape.cbegin() + 2, filter_shape.cend());
const Shape output_data(output_shape.cbegin() + 2, output_shape.cend());
codegen::CodeWriter writer;
vector<size_t> gws;
writer << "__kernel void " << entry_point_name << "(const __global float input"
<< array_dims(input_shape) << ", const __global float filter" << array_dims(filter_shape)
<< ", __global float output" << array_dims(output_shape) << ")\n";
writer.block_begin();
{ // Main function body
writer << "const unsigned batch = get_global_id(0);\n";
gws.push_back(output_shape.at(batch_axis_data));
writer << "// for (uint batch = 0; batch < " << output_shape.at(batch_axis_data)
<< "; ++batch)\n";
writer.block_begin();
{
writer << "const unsigned output_channel = get_global_id(1);\n";
gws.push_back(output_shape.at(output_channel_axis_result));
writer << "// for (uint output_channel = 0; output_channel < "
<< output_shape.at(output_channel_axis_result) << "; ++output_channel)\n";
writer.block_begin();
{
// The first loop over output dimensions
writer << "const unsigned i0 = get_global_id(2);\n";
gws.push_back(output_data.at(0));
writer << "// for (uint i0 = 0; i0 < " << output_data.at(0) << "; ++i0)\n";
writer.block_begin();
{
// Loops over other output dimensions
size_t var_idx = 1;
for (auto i = output_data.begin() + 1; i != output_data.end(); ++i)
{
writer << "for (uint i" << var_idx << " = 0; i" << var_idx << " < " << *i
<< "; ++i" << var_idx << ")\n";
writer.block_begin();
++var_idx;
}
writer << "float result = 0.0f;\n";
writer << "\n// Loop over input_channel\n"
<< "for (uint input_channel = 0; input_channel < "
<< input_shape.at(input_channel_axis_data) << "; ++input_channel)\n";
writer.block_begin();
{
// Loop over filter
// Since the first two dimensions are special, start from the third one
writer << "// Over filter iterations\n";
var_idx = 0;
for (auto const& i : filter_data)
{
writer << "for (uint f" << var_idx << " = 0; f" << var_idx << " < " << i
<< "; ++f" << var_idx << ")\n";
writer.block_begin();
writer << "int input_idx" << var_idx << " = (i" << var_idx << " * "
<< win_stride.at(var_idx) << " /*win_stride*/"
<< ") + (f" << var_idx << " * " << win_dilation.at(var_idx)
<< " /*win_dilation*/) - " << pad_below.at(var_idx)
<< " /*pad_below*/;\n";
++var_idx;
}
// Get the input value
writer << "float input_pad = " << default_pad_value << ";\n";
// Generate dilation conditionals
writer << "if (";
var_idx = 0;
for (auto const& i : output_data)
{
if (var_idx)
{
writer << " && ";
}
writer << "(((i" << var_idx << " + f" << var_idx << ") % "
<< data_dilation.at(var_idx) << ") == 0)";
++var_idx;
}
writer << ") /*data_dilation. If we are in a dilation gap"
", we have no source coordinate.*/\n";
writer.block_begin();
{
// Generate other conditionals
writer << "if (";
var_idx = 0;
for (auto const& i : input_data)
{
if (var_idx)
{
writer << " && ";
}
writer << "((input_idx" << var_idx << " >= 0) && (input_idx"
<< var_idx << " < " << i << "))";
++var_idx;
}
writer << ")\n";
writer.block_begin();
{
writer << "input_pad = " << input_order
<< array_dim(input_data, "input_idx") << ";\n";
}
writer.block_end();
// End of other conditional generation
}
writer.block_end();
// End of dilation conditional generation
// Output element calculation
writer << "result += input_pad * " << filter_order
<< array_dim(filter_data, "f", reverse_filter) << ";\n";
// Closing brackets for filter loop
for (auto const& i : filter_data)
{
writer.block_end();
}
}
writer.block_end();
writer << "// End input_channel loop\n";
writer << output_order << access_dims(output_data) << " = result;\n";
// Closing brackets for other output dimensions
for (auto i = output_data.begin() + 1; i != output_data.end(); ++i)
{
writer.block_end();
}
} // Closing brackets for the first loop over output dimensions
writer.block_end();
} // End of loop over output_channel
writer.block_end();
} // End of loop over batch
writer.block_end();
} // Main function body
writer.block_end();
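// The get_global_id(0..2) calls above map one-to-one to the sizes collected
// in gws (batch, output_channel, first output dimension); clDNN launches the
// kernel over this global work size.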
const cldnn::custom_gpu_primitive op_convolution(output_name,
{input_name, filter_name},
{writer.get_code()},
entry_point_name,
get_kernel_args(2, 1),
"",
layout,
gws);
topology.add(op_convolution);
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <CPP/topology.hpp>
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/strides.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph
{
namespace runtime
{
namespace intelgpu
{
// This implements the nGraph Convolution operation.
// nGraph describes the operation in terms of channel axes, while clDNN
// operates on the full input data.
void do_convolution_operation(cldnn::topology& topology,
const std::string& input_name,
const Shape& input_shape,
const std::string& filter_name,
const Shape& filter_shape,
const std::string& output_name,
const Shape& output_shape,
const element::Type& output_type,
const CoordinateDiff& pad_below,
const Strides& win_stride,
const Strides& win_dilation,
const Strides& data_dilation,
size_t batch_axis_data,
size_t input_channel_axis_data,
size_t output_channel_axis_result,
const std::string& input_order,
const std::string& filter_order,
const std::string& output_order,
bool reverse_filter);
}
}
}
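For illustration only, a minimal sketch of a forward-convolution call to this
entry point, assuming hypothetical tensor names and the 2D example shapes from
the comments above; in the backend the call is made from
IntelGPUBackend::compile() as shown in the diff:

// Sketch, not part of the commit: names and shapes are illustrative.
cldnn::topology topology;
ngraph::runtime::intelgpu::do_convolution_operation(
    topology,
    "data", Shape{2, 1, 3, 5},    // input[batch][input_channel][y][x]
    "weights", Shape{2, 1, 2, 2}, // filter[output_channel][input_channel][y][x]
    "result", Shape{2, 2, 2, 4},  // output[batch][output_channel][y][x]
    element::f32,
    CoordinateDiff{0, 0},         // pad_below
    Strides{1, 1},                // win_stride
    Strides{1, 1},                // win_dilation
    Strides{1, 1},                // data_dilation
    0,                            // batch_axis_data
    1,                            // input_channel_axis_data
    1,                            // output_channel_axis_result
    "input[batch][input_channel]",
    "filter[output_channel][input_channel]",
    "output[batch][output_channel]",
    false);                       // reverse_filter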
......@@ -13,14 +13,11 @@ backwards_avgpool_n2_c2_hw4x4_numeric
backwards_avgpool_n2_c2_hw4x4_win_2x2_str_1x1_numeric
backwards_batch_norm_three_outputs
backwards_ceiling
backwards_cos
backwards_cosh
backwards_dot_scalar_tensor
backwards_dot_tensor3_tensor3
backwards_dot_tensor_scalar
backwards_dot_tensor_vector
backwards_floor
backwards_maximum
backwards_maxpool_n2c1h5w5_kh3kw3_sh2sw2
backwards_maxpool_n2_c1_hw5_3x3_str2_max
backwards_maxpool_n4c1h4w4_kh2kw2_sh1sw1
......@@ -48,17 +45,11 @@ concat_matrix_int64
constant_multi_use
convert_int32_bool
convert_int32_float32
convolution_2d_1item
convolution_2d_1item_1o1i_data_dilated
convolution_2d_1item_2o1i_data_dilated
convolution_2d_1item_2o2i_data_dilated
convolution_2d_1item_5o3i_data_dilated
convolution_2d_1item_padded_1_1x1_1
convolution_2d_1item_padded_2_3x4_5
convolution_2d_2item_5o3i_data_dilated
convolution_2d_2items
convolution_2d_2items_dilated
convolution_2d_2items_dilated_padded
convolution_2d_2items_strided
convolution_2d_2items_strided_padded
convolution_2d_2items_strided_padded_same
......@@ -69,16 +60,6 @@ convolution_3d_1item_large_5o3i_padded_uneven_filter_uneven_data_dilation_data_d
convolution_3d_2item_large_5o3i_padded_strided_uneven_filter_uneven_data_dilation_data_dilated
convolution_3d_2item_large_5o3i_padded_strided_uneven_filter_uneven_data_dilation_filter_dilated_data_dilated
convolution_3d_2item_large_5o3i_uneven_filter_uneven_data_dilation_data_dilated
convolution_3d_2items
convolution_4d_2items
convolution_4d_4items
convolution_4d_4items_dilated
convolution_4d_4items_padded_neg
convolution_4d_4items_strided
convolution_4d_4items_strided_dilated
convolution_4d_4items_strided_dilated_padded
convolution_4d_4items_strided_dilated_padded_neg
convolution_4d_4items_strided_dilated_padded_same
convolution_outlining
divide_by_zero_int32
dot_matrix_vector_int64
......