Commit 49d15902 authored by dmyershov, committed by Scott Cyphers

IntelGPU backend: Reverse operation implementation (#1338)

parent 91a3bf87
...@@ -50,6 +50,7 @@ ...@@ -50,6 +50,7 @@
#include "ngraph/op/pad.hpp" #include "ngraph/op/pad.hpp"
#include "ngraph/op/product.hpp" #include "ngraph/op/product.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
...@@ -320,6 +321,35 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func) ...@@ -320,6 +321,35 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func)
output_shape, output_shape,
output_type); output_type);
} }
else if ("Reverse" == op->description())
{
arguments_check(op, 1, 1);
const string& input_name = op->get_inputs().at(0).get_tensor().get_name();
const Shape& input_shape = op->get_inputs().at(0).get_shape();
const string& output_name = op->get_outputs().begin()->get_tensor().get_name();
const Shape& output_shape = op->get_outputs().begin()->get_shape();
const element::Type& output_type =
op->get_outputs().begin()->get_tensor().get_element_type();
const shared_ptr<op::Reverse> reverse_op = static_pointer_cast<op::Reverse>(op);
const AxisSet& reversed_axes = reverse_op->get_reversed_axes();
if (reversed_axes.empty())
{
do_equal_propagation(topology, input_name, output_name);
}
else
{
do_reverse_operation(topology,
input_name,
input_shape,
output_name,
output_shape,
output_type,
reversed_axes);
}
}
else if ("Add" == op->description()) else if ("Add" == op->description())
{ {
do_eltwise_operation(topology, op, cldnn::eltwise_mode::sum); do_eltwise_operation(topology, op, cldnn::eltwise_mode::sum);
......
...@@ -61,7 +61,8 @@ string runtime::intelgpu::array_dims(const Shape& dimentions) ...@@ -61,7 +61,8 @@ string runtime::intelgpu::array_dims(const Shape& dimentions)
return buffer; return buffer;
} }
string runtime::intelgpu::access_dims(const Shape& dimentions, const AxisSet& axis) string
runtime::intelgpu::access_dims(const Shape& dimentions, const AxisSet& axis, bool is_reversed)
{ {
size_t var_idx = 0; size_t var_idx = 0;
string buffer; string buffer;
...@@ -72,6 +73,10 @@ string runtime::intelgpu::access_dims(const Shape& dimentions, const AxisSet& ax ...@@ -72,6 +73,10 @@ string runtime::intelgpu::access_dims(const Shape& dimentions, const AxisSet& ax
{ {
buffer += "[i" + to_string(var_idx) + "]"; buffer += "[i" + to_string(var_idx) + "]";
} }
else if (is_reversed)
{
buffer += "[" + to_string(i) + " - i" + to_string(var_idx) + " - 1]";
}
++var_idx; ++var_idx;
} }
...@@ -83,6 +88,46 @@ string runtime::intelgpu::access_dims(const Shape& dimentions, const AxisSet& ax ...@@ -83,6 +88,46 @@ string runtime::intelgpu::access_dims(const Shape& dimentions, const AxisSet& ax
return buffer; return buffer;
} }
// Writes one half of the loop nest that iterates over every element of `shape`.
//
// The first three dimensions are not looped over in the kernel text: each one
// reads its index from get_global_id() and contributes its extent to the
// returned global work size. Every further dimension becomes an explicit
// `for` loop in the generated code.
//
// writer   - code writer that receives the generated kernel text
// shape    - extents of the tensor being iterated
// is_begin - true:  emit the index declarations and loop headers (opening half)
//            false: emit only the closing of the explicit `for` blocks
//
// Returns the global work size collected while emitting the opening half.
// When nothing was collected (empty shape, or is_begin == false) the result
// is the single-element vector {1}.
static vector<size_t> generate_loops(codegen::CodeWriter& writer, const Shape& shape, bool is_begin)
{
    const size_t max_gws_dims = 3;
    vector<size_t> global_sizes;
    size_t dim_idx = 0;

    for (auto const& extent : shape)
    {
        const bool uses_global_id = dim_idx < max_gws_dims;

        if (is_begin && uses_global_id)
        {
            // Dimension is covered by the work-item id, no loop is needed.
            writer << "const unsigned i" << dim_idx << " = get_global_id(" << dim_idx << ");\n";
            global_sizes.push_back(extent);
        }
        else if (is_begin)
        {
            // Dimension beyond the work-size limit: iterate explicitly.
            writer << "for (uint i" << dim_idx << " = 0; i" << dim_idx << " < " << extent
                   << "; ++i" << dim_idx << ")\n";
            writer.block_begin();
        }
        else if (!uses_global_id)
        {
            // Closing half: only explicit loops opened a block.
            writer.block_end();
        }

        ++dim_idx;
    }

    if (global_sizes.empty())
    {
        global_sizes.push_back(1);
    }

    return global_sizes;
}
static string access_dims_strided(const Shape& dimentions, static string access_dims_strided(const Shape& dimentions,
const Shape& pad_below, const Shape& pad_below,
const Shape& pad_interior, const Shape& pad_interior,
...@@ -669,3 +714,42 @@ void runtime::intelgpu::do_logic_kernel(cldnn::topology& topology, ...@@ -669,3 +714,42 @@ void runtime::intelgpu::do_logic_kernel(cldnn::topology& topology,
{1}); {1});
topology.add(op_logical); topology.add(op_logical);
} }
// Generates a custom OpenCL kernel for the Reverse operation and adds it to
// the clDNN topology as a custom_gpu_primitive named `output_name`.
//
// topology      - topology that receives the new primitive
// input_name    - id of the primitive whose output feeds the kernel
// input_shape   - used only to declare the dimensions of the kernel's input argument
// output_name   - id of the new primitive; also the suffix of the kernel entry point
// output_shape  - drives the loop nest and both the output and the input index expressions
// output_type   - element type used to build the output layout
// reversed_axes - axes passed to access_dims() with is_reversed = true for the input index
void runtime::intelgpu::do_reverse_operation(cldnn::topology& topology,
                                             const string& input_name,
                                             const Shape& input_shape,
                                             const string& output_name,
                                             const Shape& output_shape,
                                             const element::Type& output_type,
                                             const AxisSet& reversed_axes)
{
    const string kernel_name = "reverse_" + output_name;
    codegen::CodeWriter code;

    // NOTE(review): the kernel arguments are declared as `float` regardless of
    // output_type, which is only used for the layout below - confirm whether
    // non-float element types are expected to reach this function.
    code << "__kernel void " << kernel_name << "(const __global float input"
         << array_dims(input_shape) << ", __global float output" << array_dims(output_shape)
         << ")\n";
    code.block_begin();

    // Opening half of the loop nest; also yields the global work size.
    const vector<size_t> global_work_size = generate_loops(code, output_shape, true);

    // Single assignment: plain index on the output, reversed-axes index on the input.
    code << "output" << access_dims(output_shape) << " = input"
         << access_dims(output_shape, reversed_axes, true) << ";\n";

    // Closing half of the loop nest; the returned value is not needed here.
    generate_loops(code, output_shape, false);

    code.block_end();

    const cldnn::layout output_layout =
        IntelGPULayout::create_cldnn_layout(output_type, output_shape);
    const cldnn::custom_gpu_primitive reverse_primitive(output_name,
                                                        {input_name},
                                                        {code.get_code()},
                                                        kernel_name,
                                                        get_kernel_args(1, 1),
                                                        "",
                                                        output_layout,
                                                        global_work_size);
    topology.add(reverse_primitive);
}
...@@ -82,10 +82,20 @@ namespace ngraph ...@@ -82,10 +82,20 @@ namespace ngraph
const element::Type& output_type, const element::Type& output_type,
const std::string& operation); const std::string& operation);
// Builds a custom OpenCL kernel with entry point "reverse_" + output_name and
// adds it to the topology as a custom GPU primitive named output_name.
// reversed_axes selects the axes handed to access_dims() in reversed mode when
// the input index expression is generated; output_type is used for the output
// layout.
void do_reverse_operation(cldnn::topology& topology,
const std::string& input_name,
const Shape& input_shape,
const std::string& output_name,
const Shape& output_shape,
const element::Type& output_type,
const AxisSet& reversed_axes);
// Helper functions used in cldnn::custom_gpu_primitive kernels // Helper functions used in cldnn::custom_gpu_primitive kernels
std::vector<cldnn_arg> get_kernel_args(size_t input, size_t output); std::vector<cldnn_arg> get_kernel_args(size_t input, size_t output);
std::string array_dims(const Shape& dimentions); std::string array_dims(const Shape& dimentions);
std::string access_dims(const Shape& dimentions, const AxisSet& axis = {}); std::string access_dims(const Shape& dimentions,
const AxisSet& axis = {},
bool is_reversed = false);
} }
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment