Commit 28622bde authored by Sergey Shalnov's avatar Sergey Shalnov Committed by Robert Kimball

IntelGPU backend: All and Any operations (#2239)

parent e45caeb9
......@@ -65,6 +65,8 @@
#include "ngraph/file_util.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/avg_pool.hpp"
......@@ -418,7 +420,6 @@ runtime::Handle runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function>
{
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::AnyAllReplacement>();
pass_manager.register_pass<ngraph::pass::NopElimination>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
......@@ -959,6 +960,50 @@ runtime::Handle runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function>
}
break;
}
case OP_TYPEID::All:
{
arguments_check(op, 1, 1);
const shared_ptr<op::All> all_op = static_pointer_cast<op::All>(op);
const AxisSet& axis = all_op->get_reduction_axes();
const shared_ptr<Node> def_val = all_op->get_default_value();
const shared_ptr<op::Constant> def_const = static_pointer_cast<op::Constant>(def_val);
const vector<std::string>& values = def_const->get_value_strings();
// Empty axis is not a case for do_equal_propagation()
do_all_any_op(topology,
get_input_name(op, 0),
get_input_shape(op, 0),
get_output_name(op),
get_output_shape(op),
get_output_type(op),
axis,
"lhs && rhs",
values.at(0));
break;
}
case OP_TYPEID::Any:
{
arguments_check(op, 1, 1);
const shared_ptr<op::Any> any_op = static_pointer_cast<op::Any>(op);
const AxisSet& axis = any_op->get_reduction_axes();
const shared_ptr<Node> def_val = any_op->get_default_value();
const shared_ptr<op::Constant> def_const = static_pointer_cast<op::Constant>(def_val);
const vector<std::string>& values = def_const->get_value_strings();
// Empty axis is not a case for do_equal_propagation()
do_all_any_op(topology,
get_input_name(op, 0),
get_input_shape(op, 0),
get_output_name(op),
get_output_shape(op),
get_output_type(op),
axis,
"lhs || rhs",
values.at(0));
break;
}
case OP_TYPEID::Relu:
{
do_unary_operation(topology, op, activation_relu);
......@@ -1713,9 +1758,7 @@ runtime::Handle runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function>
topology.add(lrn);
break;
}
case OP_TYPEID::All:
case OP_TYPEID::AllReduce:
case OP_TYPEID::Any:
case OP_TYPEID::BroadcastLike:
case OP_TYPEID::FunctionCall:
case OP_TYPEID::Dequantize:
......
......@@ -24,6 +24,88 @@
using namespace std;
using namespace ngraph;
void runtime::intelgpu::do_all_any_op(cldnn::topology& topology,
const string& input0_name,
const Shape& input0_shape,
const string& output_name,
const Shape& output_shape,
const element::Type& output_type,
const AxisSet& axis,
const std::string& operation,
const std::string& init_val)
{
const string entry_point_name = "custom_op_all_any_" + output_name;
const string kernel_type_name = get_opencl_type_name(output_type);
const size_t input_size = shape_size<Shape>(input0_shape);
codegen::CodeWriter writer;
// The kernel name and parameters
gen_func_def(writer,
entry_point_name,
{1, kernel_type_name},
{input0_shape, {1}},
kernel_type_name,
output_shape);
writer.block_begin();
{
// Initialization loop
size_t var_idx = 0;
for (auto const& i : output_shape)
{
writer << "for (uint i" << var_idx << " = 0; i" << var_idx << " < " << i << "; ++i"
<< var_idx << ")\n";
writer.block_begin();
++var_idx;
}
writer << "output" << access_dims(output_shape) << " = " << init_val << ";\n";
// Closing brackets for initialization loop
for (auto const& i : output_shape)
{
writer.block_end();
}
if (input_size && !input0_shape.empty())
{
// Main operation loop
var_idx = 0;
for (auto const& i : input0_shape)
{
writer << "for (uint i" << var_idx << " = 0; i" << var_idx << " < " << i << "; ++i"
<< var_idx << ")\n";
writer.block_begin();
++var_idx;
}
writer << kernel_type_name << " lhs = output" << access_dims(input0_shape, "i", axis)
<< ";\n"
<< kernel_type_name << " rhs = input0" << access_dims(input0_shape) << ";\n"
<< "output" << access_dims(input0_shape, "i", axis) << " = (" << operation
<< ");\n";
// Closing brackets for loop
for (auto const& i : input0_shape)
{
writer.block_end();
}
}
} // End of function bracket
writer.block_end();
const cldnn::layout layout = IntelGPULayout::create_cldnn_layout(output_type, output_shape);
const cldnn::custom_gpu_primitive op_all_any(output_name,
{input0_name},
{writer.get_code()},
entry_point_name,
get_kernel_args(1, 1),
"",
layout,
{1});
topology.add(op_all_any);
}
static void get_custom_func_name(codegen::CodeWriter& writer,
vector<shared_ptr<Function>>& func,
const string& func_name,
......
......@@ -28,6 +28,16 @@ namespace ngraph
{
namespace intelgpu
{
void do_all_any_op(cldnn::topology& topology,
const std::string& input0_name,
const Shape& input0_shape,
const std::string& output_name,
const Shape& output_shape,
const element::Type& output_type,
const AxisSet& axis,
const std::string& operation,
const std::string& init_val);
void do_reduce_func_call(cldnn::topology& topology,
const std::string& input0_name,
const Shape& input0_shape,
......
......@@ -21,6 +21,8 @@
#include "ngraph/runtime/intelgpu/visualize_tree.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/avg_pool.hpp"
......@@ -210,6 +212,15 @@ void print_node_parameters(ostringstream& writer, const shared_ptr<Node>& node)
writer << print_table_row_dims("reduction_axis", arith_op->get_reduction_axes());
break;
}
case OP_TYPEID::All:
case OP_TYPEID::Any:
{
const shared_ptr<op::util::LogicalReduction> logical_op =
static_pointer_cast<op::util::LogicalReduction>(node);
writer << print_table_row_dims("reduction_axis", logical_op->get_reduction_axes());
break;
}
case OP_TYPEID::ArgMin:
case OP_TYPEID::ArgMax:
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment