Commit 92adea38 authored by shssf, committed by Scott Cyphers

IntelGPU backend: Sum and redeveloped Broadcast operation (#1276)

parent cb84305e
@@ -289,6 +289,8 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func)
             const string& output_name = op->get_outputs().begin()->get_tensor().get_name();
             const Shape& output_shape = op->get_outputs().begin()->get_shape();
+            const element::Type& output_type =
+                op->get_outputs().begin()->get_tensor().get_element_type();
             const shared_ptr<op::Broadcast> broadcast = static_pointer_cast<op::Broadcast>(op);
             const AxisSet& axis = broadcast->get_broadcast_axes();
@@ -297,10 +299,67 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func)
             {
                 do_equal_propagation(topology, input_name, output_name);
             }
+            else if (input_shape.empty())
+            {
+                do_bcast_sum_operation_scalar(topology,
+                                              input_name,
+                                              input_shape,
+                                              output_name,
+                                              output_shape,
+                                              output_type,
+                                              true);
+            }
+            else
+            {
+                do_bcast_sum_operation(topology,
+                                       input_name,
+                                       input_shape,
+                                       output_name,
+                                       output_shape,
+                                       output_type,
+                                       axis,
+                                       true);
+            }
+        }
+        else if ("Sum" == op->description())
+        {
+            arguments_check(op, 1, 1);
+            const string& input_name = op->get_inputs().begin()->get_tensor().get_name();
+            const Shape& input_shape = op->get_inputs().begin()->get_shape();
+            const string& output_name = op->get_outputs().begin()->get_tensor().get_name();
+            const Shape& output_shape = op->get_outputs().begin()->get_shape();
+            const element::Type& output_type =
+                op->get_outputs().begin()->get_tensor().get_element_type();
+            const shared_ptr<op::Sum> sum = static_pointer_cast<op::Sum>(op);
+            const AxisSet& axis = sum->get_reduction_axes();
+            if (axis.empty())
+            {
+                do_equal_propagation(topology, input_name, output_name);
+            }
+            else if (output_shape.empty())
+            {
+                do_bcast_sum_operation_scalar(topology,
+                                              input_name,
+                                              input_shape,
+                                              output_name,
+                                              output_shape,
+                                              output_type,
+                                              false);
+            }
             else
             {
-                do_broadcast_operation(
-                    topology, input_name, input_shape, output_name, output_shape, axis);
+                do_bcast_sum_operation(topology,
+                                       input_name,
+                                       input_shape,
+                                       output_name,
+                                       output_shape,
+                                       output_type,
+                                       axis,
+                                       false);
             }
         }
         else if ("Reshape" == op->description())
...
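The dispatch above routes both Broadcast and Sum through a single helper pair, with the trailing bool (is_bcast) selecting between the two behaviors: copying one input element out across the broadcast axes, or accumulating many input elements down across the reduction axes. Below is a minimal host-side sketch of the reference semantics that shared helper is expected to implement; the names bcast_sum_reference, flat_index, drop_axes, and next_coord are illustrative only, not the backend's actual GPU kernel code.

#include <algorithm>
#include <cstddef>
#include <set>
#include <vector>

using Shape = std::vector<std::size_t>;
using AxisSet = std::set<std::size_t>;

// Row-major flat index of a coordinate within a shape.
static std::size_t flat_index(const std::vector<std::size_t>& coord, const Shape& shape)
{
    std::size_t idx = 0;
    for (std::size_t d = 0; d < shape.size(); ++d)
    {
        idx = idx * shape[d] + coord[d];
    }
    return idx;
}

// Copy a coordinate, dropping the dimensions listed in axis.
static std::vector<std::size_t> drop_axes(const std::vector<std::size_t>& coord,
                                          const AxisSet& axis)
{
    std::vector<std::size_t> result;
    for (std::size_t d = 0; d < coord.size(); ++d)
    {
        if (axis.count(d) == 0)
        {
            result.push_back(coord[d]);
        }
    }
    return result;
}

// Advance coord to the next row-major position in shape; false after the last one.
static bool next_coord(std::vector<std::size_t>& coord, const Shape& shape)
{
    for (std::size_t d = shape.size(); d-- > 0;)
    {
        if (++coord[d] < shape[d])
        {
            return true;
        }
        coord[d] = 0;
    }
    return false;
}

// Illustrative reference semantics for the shared Broadcast/Sum helper.
void bcast_sum_reference(const std::vector<float>& input,
                         const Shape& input_shape,
                         std::vector<float>& output,
                         const Shape& output_shape,
                         const AxisSet& axis,
                         bool is_bcast)
{
    if (is_bcast)
    {
        // Broadcast: every output element copies the input element reached by
        // dropping the broadcast axes from the output coordinate.
        std::vector<std::size_t> out_coord(output_shape.size(), 0);
        do
        {
            output[flat_index(out_coord, output_shape)] =
                input[flat_index(drop_axes(out_coord, axis), input_shape)];
        } while (next_coord(out_coord, output_shape));
    }
    else
    {
        // Sum: every input element is accumulated into the output element
        // reached by dropping the reduction axes from the input coordinate.
        std::fill(output.begin(), output.end(), 0.0f);
        std::vector<std::size_t> in_coord(input_shape.size(), 0);
        do
        {
            output[flat_index(drop_axes(in_coord, axis), output_shape)] +=
                input[flat_index(in_coord, input_shape)];
        } while (next_coord(in_coord, input_shape));
    }
}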
@@ -27,13 +27,26 @@ namespace ngraph
     {
         namespace intelgpu
         {
-            // This implements Broadcast nGraph operation
-            void do_broadcast_operation(cldnn::topology& topology,
-                                        const std::string& input_name,
-                                        const Shape& input_shape,
-                                        const std::string& output_name,
-                                        const Shape& output_shape,
-                                        const AxisSet& axis);
+            // This implements the Broadcast and Sum nGraph operations
+            // for the case where input_shape is not empty
+            void do_bcast_sum_operation(cldnn::topology& topology,
+                                        const std::string& input_name,
+                                        const Shape& input_shape,
+                                        const std::string& output_name,
+                                        const Shape& output_shape,
+                                        const element::Type& output_type,
+                                        const AxisSet& axis,
+                                        bool is_bcast);
+
+            // This implements the Broadcast and Sum nGraph operations
+            // for the case where input_shape is empty
+            void do_bcast_sum_operation_scalar(cldnn::topology& topology,
+                                               const std::string& input_name,
+                                               const Shape& input_shape,
+                                               const std::string& output_name,
+                                               const Shape& output_shape,
+                                               const element::Type& output_type,
+                                               bool is_bcast);
         }
     }
 }
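As a concrete reading of the new declarations, here are hypothetical call sites mirroring the dispatch in compile(); the topology object and the tensor names "in"/"out" are assumed to be in scope and are not taken from the commit.

// Broadcast a vector of shape {4} to {3, 4} along axis 0: non-scalar input,
// so the general helper is used with is_bcast == true.
do_bcast_sum_operation(topology, "in", Shape{4}, "out", Shape{3, 4},
                       element::f32, AxisSet{0}, true);

// Broadcast a scalar (empty input_shape) to {3, 4}: the _scalar variant is used.
do_bcast_sum_operation_scalar(topology, "in", Shape{}, "out", Shape{3, 4},
                              element::f32, true);

// Sum {3, 4} down to {4}, reducing along axis 0: same helper, is_bcast == false.
do_bcast_sum_operation(topology, "in", Shape{3, 4}, "out", Shape{4},
                       element::f32, AxisSet{0}, false);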