Commit f9ded0b1 authored by shssf's avatar shssf Committed by Robert Kimball

IntelGPU backend: Greater, Less, Equal operations (#1331)

parent c5889b2b
...@@ -65,6 +65,16 @@ static void arguments_check(const shared_ptr<Node>& op, size_t input, size_t out ...@@ -65,6 +65,16 @@ static void arguments_check(const shared_ptr<Node>& op, size_t input, size_t out
} }
} }
static void argument_type_check(const element::Type& type)
{
if (type != element::f32)
{
ostringstream os;
os << "Kernel data type " << type << " is not supported";
throw invalid_argument(os.str());
}
}
static void do_eltwise_operation(cldnn::topology& topology, static void do_eltwise_operation(cldnn::topology& topology,
const shared_ptr<Node>& op, const shared_ptr<Node>& op,
cldnn::eltwise_mode mode) cldnn::eltwise_mode mode)
...@@ -98,6 +108,33 @@ static void do_unary_operation(cldnn::topology& topology, ...@@ -98,6 +108,33 @@ static void do_unary_operation(cldnn::topology& topology,
topology.add(cldnn_unary); topology.add(cldnn_unary);
} }
static void do_logical_operation(cldnn::topology& topology,
const shared_ptr<Node>& op,
const string& operation)
{
arguments_check(op, 2, 1);
const string& inputA_name = op->get_inputs().at(0).get_tensor().get_name();
const Shape& inputA_shape = op->get_inputs().at(0).get_shape();
argument_type_check(op->get_inputs().at(0).get_tensor().get_element_type());
const string& inputB_name = op->get_inputs().at(1).get_tensor().get_name();
const Shape& inputB_shape = op->get_inputs().at(1).get_shape();
argument_type_check(op->get_inputs().at(1).get_tensor().get_element_type());
const string& output_name = op->get_outputs().begin()->get_tensor().get_name();
const Shape& output_shape = op->get_outputs().begin()->get_shape();
const element::Type& output_type = op->get_outputs().begin()->get_tensor().get_element_type();
runtime::intelgpu::do_logic_kernel(topology,
inputA_name,
inputA_shape,
inputB_name,
inputB_shape,
output_name,
output_shape,
output_type,
operation);
}
// This function needed to only change the name of the data in topology // This function needed to only change the name of the data in topology
// No real data copy needed // No real data copy needed
static void do_equal_propagation(cldnn::topology& topology, static void do_equal_propagation(cldnn::topology& topology,
...@@ -487,6 +524,30 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func) ...@@ -487,6 +524,30 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func)
{ {
do_unary_operation(topology, op, activation_logistic); do_unary_operation(topology, op, activation_logistic);
} }
else if ("Greater" == op->description())
{
do_logical_operation(topology, op, " > ");
}
else if ("GreaterEq" == op->description())
{
do_logical_operation(topology, op, " >= ");
}
else if ("Equal" == op->description())
{
do_logical_operation(topology, op, " == ");
}
else if ("NotEqual" == op->description())
{
do_logical_operation(topology, op, " != ");
}
else if ("Less" == op->description())
{
do_logical_operation(topology, op, " < ");
}
else if ("LessEq" == op->description())
{
do_logical_operation(topology, op, " <= ");
}
else if ("Subtract" == op->description()) else if ("Subtract" == op->description())
{ {
do_eltwise_operation(topology, op, cldnn::eltwise_mode::sub); do_eltwise_operation(topology, op, cldnn::eltwise_mode::sub);
......
...@@ -592,3 +592,67 @@ void runtime::intelgpu::do_select_operation(cldnn::topology& topology, ...@@ -592,3 +592,67 @@ void runtime::intelgpu::do_select_operation(cldnn::topology& topology,
{1}); {1});
topology.add(op_select); topology.add(op_select);
} }
void runtime::intelgpu::do_logic_kernel(cldnn::topology& topology,
const string& inputA_name,
const Shape& inputA_shape,
const string& inputB_name,
const Shape& inputB_shape,
const string& output_name,
const Shape& output_shape,
const element::Type& output_type,
const string& operation)
{
const cldnn::layout layout = IntelGPULayout::create_cldnn_layout(output_type, output_shape);
const string entry_point_name = "logic_" + output_name;
codegen::CodeWriter writer;
writer << "__kernel void " << entry_point_name << "(const __global float inputA"
<< array_dims(inputA_shape) << ", const __global float inputB"
<< array_dims(inputB_shape) << ", __global char output" << array_dims(output_shape)
<< ")\n";
writer.block_begin();
{
size_t var_idx = 0;
// Main loops
for (auto const& i : output_shape)
{
writer << "for (uint i" << var_idx << " = 0; i" << var_idx << " < " << i << "; ++i"
<< var_idx << ")\n";
writer.block_begin();
++var_idx;
}
writer << "if (inputA" << access_dims(inputA_shape) << operation << "inputB"
<< access_dims(inputB_shape) << ")\n";
writer.block_begin();
{
writer << "output" << access_dims(output_shape) << " = 1;\n";
}
writer.block_end();
writer << "else\n";
writer.block_begin();
{
writer << "output" << access_dims(output_shape) << " = 0;\n";
}
writer.block_end();
// Closing brackets for main loops
for (auto const& i : output_shape)
{
writer.block_end();
}
}
writer.block_end();
const cldnn::custom_gpu_primitive op_logical(output_name,
{inputA_name, inputB_name},
{writer.get_code()},
entry_point_name,
get_kernel_args(2, 1),
"",
layout,
{1});
topology.add(op_logical);
}
...@@ -70,6 +70,16 @@ namespace ngraph ...@@ -70,6 +70,16 @@ namespace ngraph
const Shape& output_shape, const Shape& output_shape,
const element::Type& output_type); const element::Type& output_type);
void do_logic_kernel(cldnn::topology& topology,
const std::string& inputA_name,
const Shape& inputA_shape,
const std::string& inputB_name,
const Shape& inputB_shape,
const std::string& output_name,
const Shape& output_shape,
const element::Type& output_type,
const std::string& operation);
// Helper functions used in cldnn::custom_gpu_primitive kernels // Helper functions used in cldnn::custom_gpu_primitive kernels
std::vector<cldnn_arg> get_kernel_args(size_t input, size_t output); std::vector<cldnn_arg> get_kernel_args(size_t input, size_t output);
std::string array_dims(const Shape& dimentions); std::string array_dims(const Shape& dimentions);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment