Commit 0af487e9 authored by fenglei.tian

comments and code style

parent 7525d6e1
...@@ -74,10 +74,10 @@ namespace ngraph ...@@ -74,10 +74,10 @@ namespace ngraph
std::string kernel; std::string kernel;
std::string data_type("float"); std::string data_type("float");
kernel = R"( kernel = R"(
extern "C" __global__ extern "C" __global__
void cuda_)" + name + "(" + void cuda_)" + name + "(" + data_type +
data_type + "* in, " + data_type + "* out, size_t m, size_t k, size_t n)\n" + R"( "* in, " + data_type + "* out, size_t m, size_t k, size_t n)\n" + R"(
{ {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x; size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < n) if(tid < n)
......
...@@ -462,6 +462,7 @@ void runtime::gpu::GPU_Emitter::EmitBroadcast( ...@@ -462,6 +462,7 @@ void runtime::gpu::GPU_Emitter::EmitBroadcast(
auto result_shape = out[0].get_shape(); auto result_shape = out[0].get_shape();
auto& axes = broadcast->get_broadcast_axes(); auto& axes = broadcast->get_broadcast_axes();
//broadcast axes are empty, so just do a copy
if (axes.empty()) if (axes.empty())
{ {
writer << "{ // " << n->get_name() << " \n"; writer << "{ // " << n->get_name() << " \n";
...@@ -473,8 +474,10 @@ void runtime::gpu::GPU_Emitter::EmitBroadcast( ...@@ -473,8 +474,10 @@ void runtime::gpu::GPU_Emitter::EmitBroadcast(
return; return;
} }
//broadcast axes size is 1, or can be grouped into 1 (several contiguous axes, like 01 or 12 or 123, etc.)
vector<int> axes_v; vector<int> axes_v;
std::copy(axes.begin(), axes.end(), std::back_inserter(axes_v)); std::copy(axes.begin(), axes.end(), std::back_inserter(axes_v));
std::sort(axes_v.begin(), axes_v.end());
bool is_one_axes = true; bool is_one_axes = true;
if (axes.size() != 1) if (axes.size() != 1)
{ {
...@@ -490,13 +493,13 @@ void runtime::gpu::GPU_Emitter::EmitBroadcast( ...@@ -490,13 +493,13 @@ void runtime::gpu::GPU_Emitter::EmitBroadcast(
if (is_one_axes) if (is_one_axes)
{ {
int repeat_times = 1; int repeat_times = 1;
for (int i = 0; i < axes.size(); i++) for (int i = 0; i < axes_v.size(); i++)
{ {
repeat_times *= result_shape[axes_v[i]]; repeat_times *= result_shape[axes_v[i]];
} }
int repeat_size = 1; int repeat_size = 1;
for (int i = *axes.rbegin(); i < result_shape.size(); i++) for (int i = *axes_v.rbegin() + 1; i < result_shape.size(); i++)
{ {
repeat_size *= result_shape[i]; repeat_size *= result_shape[i];
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment