Commit 92adea38 authored by shssf's avatar shssf Committed by Scott Cyphers

IntelGPU backend: Sum and redeveloped Broadcast operation (#1276)

parent cb84305e
......@@ -289,6 +289,8 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func)
const string& output_name = op->get_outputs().begin()->get_tensor().get_name();
const Shape& output_shape = op->get_outputs().begin()->get_shape();
const element::Type& output_type =
op->get_outputs().begin()->get_tensor().get_element_type();
const shared_ptr<op::Broadcast> broadcast = static_pointer_cast<op::Broadcast>(op);
const AxisSet& axis = broadcast->get_broadcast_axes();
......@@ -297,10 +299,67 @@ bool runtime::intelgpu::IntelGPUBackend::compile(shared_ptr<Function> func)
{
do_equal_propagation(topology, input_name, output_name);
}
else if (input_shape.empty())
{
do_bcast_sum_operation_scalar(topology,
input_name,
input_shape,
output_name,
output_shape,
output_type,
true);
}
else
{
do_bcast_sum_operation(topology,
input_name,
input_shape,
output_name,
output_shape,
output_type,
axis,
true);
}
}
else if ("Sum" == op->description())
{
arguments_check(op, 1, 1);
const string& input_name = op->get_inputs().begin()->get_tensor().get_name();
const Shape& input_shape = op->get_inputs().begin()->get_shape();
const string& output_name = op->get_outputs().begin()->get_tensor().get_name();
const Shape& output_shape = op->get_outputs().begin()->get_shape();
const element::Type& output_type =
op->get_outputs().begin()->get_tensor().get_element_type();
const shared_ptr<op::Sum> sum = static_pointer_cast<op::Sum>(op);
const AxisSet& axis = sum->get_reduction_axes();
if (axis.empty())
{
do_equal_propagation(topology, input_name, output_name);
}
else if (output_shape.empty())
{
do_bcast_sum_operation_scalar(topology,
input_name,
input_shape,
output_name,
output_shape,
output_type,
false);
}
else
{
do_broadcast_operation(
topology, input_name, input_shape, output_name, output_shape, axis);
do_bcast_sum_operation(topology,
input_name,
input_shape,
output_name,
output_shape,
output_type,
axis,
false);
}
}
else if ("Reshape" == op->description())
......
......@@ -27,13 +27,26 @@ namespace ngraph
{
namespace intelgpu
{
// This implements Broadcast nGraph operation
void do_broadcast_operation(cldnn::topology& topology,
// This implements Broadcast and Sum nGraph operations
// in case of input_shape is not empty
void do_bcast_sum_operation(cldnn::topology& topology,
const std::string& input_name,
const Shape& input_shape,
const std::string& output_name,
const Shape& output_shape,
const AxisSet& axis);
const element::Type& output_type,
const AxisSet& axis,
bool is_bcast);
// This implements Broadcast and Sum nGraph operations
// in case of input_shape is empty
void do_bcast_sum_operation_scalar(cldnn::topology& topology,
const std::string& input_name,
const Shape& input_shape,
const std::string& output_name,
const Shape& output_shape,
const element::Type& output_type,
bool is_bcast);
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment