Commit 82fc63c0 authored by dmyershov's avatar dmyershov Committed by Robert Kimball

IntelGPU backend: Concat operation implementation (#1363)

* IntelGPU backend: Concat operation implementation

* Several remarks were fixed

* Remaining remarks were fixed; List of tests for INTELGPU was updated

* PR1363: Minor Fixes
parent 104fd3ee
......@@ -17,6 +17,7 @@
#include <CPP/activation.hpp>
#include <CPP/activation_grad.hpp>
#include <CPP/batch_norm.hpp>
#include <CPP/concatenation.hpp>
#include <CPP/convolution.hpp>
#include <CPP/data.hpp>
#include <CPP/eltwise.hpp>
......@@ -40,6 +41,7 @@
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/dot.hpp"
......@@ -324,6 +326,33 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func)
reversed_axes);
}
}
else if ("Concat" == op->description())
{
if (op->get_inputs().empty() || op->get_outputs().size() != 1)
{
arguments_check(op, 1, 1);
}
const size_t ngraph_tensor_dims = get_input_shape(op, 0).size();
const shared_ptr<op::Concat> concat_op = static_pointer_cast<op::Concat>(op);
const size_t ngraph_concat_axis = concat_op->get_concatenation_axis();
vector<cldnn::primitive_id> inputs;
cldnn::concatenation::concatenation_axis cldnn_axis =
runtime::intelgpu::IntelGPULayout::get_cldnn_axis(ngraph_tensor_dims,
ngraph_concat_axis);
for (auto const& input : op->get_inputs())
{
const Shape& input_shape = input.get_shape();
if (shape_size(input_shape))
{
inputs.push_back(input.get_tensor().get_name());
}
}
const cldnn::concatenation cldnn_concat(get_output_name(op), inputs, cldnn_axis);
topology.add(cldnn_concat);
}
else if ("Add" == op->description())
{
do_eltwise_operation(topology, op, cldnn::eltwise_mode::sum);
......
......@@ -131,19 +131,23 @@ cldnn::layout runtime::intelgpu::IntelGPULayout::create_cldnn_layout(
}
cldnn::concatenation::concatenation_axis
runtime::intelgpu::IntelGPULayout::get_cldnn_axis(size_t tensor_channel)
runtime::intelgpu::IntelGPULayout::get_cldnn_axis(size_t shape_size, size_t axis)
{
switch (tensor_channel)
const size_t t_channel = shape_size - axis - 1;
switch (t_channel)
{
case 0: return cldnn::concatenation::along_b;
case 1: return cldnn::concatenation::along_f;
case 2: return cldnn::concatenation::along_y;
case 3: return cldnn::concatenation::along_x;
default:
case 0: return cldnn::concatenation::along_x;
case 1: return cldnn::concatenation::along_y;
case 2: return cldnn::concatenation::along_f;
case 3:
if (shape_size < 5)
{
ostringstream os;
os << "IntelGPULayout::get_cldnn_axis: wrong tensor channel " << tensor_channel;
throw invalid_argument(os.str());
return cldnn::concatenation::along_b;
}
/* no break */
default:
throw invalid_argument("IntelGPULayout::get_cldnn_axis: wrong tensor channel " +
to_string(t_channel));
}
}
......@@ -53,7 +53,7 @@ public:
static cldnn::tensor create_cldnn_tensor(const Shape& element_shape);
static cldnn::tensor create_cldnn_offset(const Shape& pad_below);
// This function converts Shape dimension_id into cldnn::concatenation id
static cldnn::concatenation::concatenation_axis get_cldnn_axis(size_t tensor_channel);
static cldnn::concatenation::concatenation_axis get_cldnn_axis(size_t shape_size, size_t axis);
private:
Strides strides;
......
......@@ -17,9 +17,6 @@ backwards_avgpool_n2_c2_hw4x4_numeric
backwards_avgpool_n2_c2_hw4x4_win_2x2_str_1x1_numeric
backwards_batch_norm_three_outputs
backwards_ceiling
backwards_concat_axis_0
backwards_concat_axis_1
backwards_concat_vector
backwards_cos
backwards_cosh
backwards_dot_scalar_tensor
......@@ -56,16 +53,7 @@ batch_norm_one_output
batch_norm_three_outputs
broadcast_vector_rowwise_int64
ceiling
concat_2d_tensor
concat_4d_tensor
concat_5d
concat_matrix_colwise
concat_matrix_int64
concat_matrix_rowwise
concat_vector
concat_zero_length_1d_last
concat_zero_length_1d_middle
concat_zero_length_4d_middle
constant_multi_use
convert_float32_bool
convert_int32_bool
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment