Commit 937b8e42 authored by Li Peng

dnn(ocl4dnn): support log softmax in ocl4dnn

Signed-off-by: Li Peng <peng.li@intel.com>
parent af8ed9d0
...@@ -100,6 +100,7 @@ public: ...@@ -100,6 +100,7 @@ public:
config.in_shape = shape(*inputs[0]); config.in_shape = shape(*inputs[0]);
config.axis = axisRaw; config.axis = axisRaw;
config.channels = inputs[0]->size[axisRaw]; config.channels = inputs[0]->size[axisRaw];
config.logsoftmax = logSoftMax;
softmaxOp = Ptr<OCL4DNNSoftmax<float> >(new OCL4DNNSoftmax<float>(config)); softmaxOp = Ptr<OCL4DNNSoftmax<float> >(new OCL4DNNSoftmax<float>(config));
} }
...@@ -108,7 +109,7 @@ public: ...@@ -108,7 +109,7 @@ public:
srcMat = inputs[0]->getUMat(ACCESS_READ); srcMat = inputs[0]->getUMat(ACCESS_READ);
dstMat = outputs[0].getUMat(ACCESS_WRITE); dstMat = outputs[0].getUMat(ACCESS_WRITE);
if (!logSoftMax && softmaxOp->Forward(srcMat, dstMat)) if (softmaxOp->Forward(srcMat, dstMat))
return true; return true;
const Mat &src = *inputs[0]; const Mat &src = *inputs[0];
......
...@@ -445,11 +445,12 @@ class OCL4DNNLRN ...@@ -445,11 +445,12 @@ class OCL4DNNLRN
struct OCL4DNNSoftmaxConfig struct OCL4DNNSoftmaxConfig
{ {
OCL4DNNSoftmaxConfig() : axis(0), channels(0) OCL4DNNSoftmaxConfig() : axis(0), channels(0), logsoftmax(false)
{} {}
MatShape in_shape; MatShape in_shape;
int axis; int axis;
int channels; int channels;
bool logsoftmax;
}; };
template<typename Dtype> template<typename Dtype>
...@@ -467,6 +468,7 @@ class OCL4DNNSoftmax ...@@ -467,6 +468,7 @@ class OCL4DNNSoftmax
int32_t channels_; int32_t channels_;
int32_t count_; int32_t count_;
bool use_slm_; bool use_slm_;
bool log_softmax_;
UMat scale_data_; UMat scale_data_;
}; };
#endif // HAVE_OPENCL #endif // HAVE_OPENCL
......
...@@ -52,6 +52,7 @@ OCL4DNNSoftmax<Dtype>::OCL4DNNSoftmax(OCL4DNNSoftmaxConfig config) ...@@ -52,6 +52,7 @@ OCL4DNNSoftmax<Dtype>::OCL4DNNSoftmax(OCL4DNNSoftmaxConfig config)
{ {
softmax_axis_ = config.axis; softmax_axis_ = config.axis;
channels_ = config.channels; channels_ = config.channels;
log_softmax_ = config.logsoftmax;
inner_num_ = 1; inner_num_ = 1;
outer_num_ = 1; outer_num_ = 1;
...@@ -90,6 +91,7 @@ bool OCL4DNNSoftmax<Dtype>::Forward(const UMat& bottom, UMat& top) ...@@ -90,6 +91,7 @@ bool OCL4DNNSoftmax<Dtype>::Forward(const UMat& bottom, UMat& top)
String kname; String kname;
ocl::Kernel oclk_softmax_forward_kernel; ocl::Kernel oclk_softmax_forward_kernel;
if (log_softmax_) opts += " -DLOG_SOFTMAX ";
if (use_slm_) if (use_slm_)
kname = CL_KERNEL_SELECT("softmax_forward_slm"); kname = CL_KERNEL_SELECT("softmax_forward_slm");
else else
......
...@@ -112,7 +112,11 @@ __kernel void TEMPLATE(softmax_forward_slm,Dtype)(const int num, const int chann ...@@ -112,7 +112,11 @@ __kernel void TEMPLATE(softmax_forward_slm,Dtype)(const int num, const int chann
for (int index = get_global_id(0); index < channels * spatial_dim; for (int index = get_global_id(0); index < channels * spatial_dim;
index += get_global_size(0)) { index += get_global_size(0)) {
int s = index % spatial_dim; int s = index % spatial_dim;
out[n * channels * spatial_dim + index] = out_tmp[index] / scale_tmp[s]; Dtype v = out_tmp[index] / scale_tmp[s];
#ifdef LOG_SOFTMAX
v = log(v);
#endif
out[n * channels * spatial_dim + index] = v;
} }
} }
...@@ -177,6 +181,10 @@ __kernel void TEMPLATE(softmax_forward,Dtype)(const int num, const int channels, ...@@ -177,6 +181,10 @@ __kernel void TEMPLATE(softmax_forward,Dtype)(const int num, const int channels,
for (int index = get_global_id(0); index < channels * spatial_dim; for (int index = get_global_id(0); index < channels * spatial_dim;
index += get_global_size(0)) { index += get_global_size(0)) {
int s = index % spatial_dim; int s = index % spatial_dim;
out[n * channels * spatial_dim + index] /= scale[n * spatial_dim + s]; Dtype v = out[n * channels * spatial_dim + index] / scale[n * spatial_dim + s];
#ifdef LOG_SOFTMAX
v = log(v);
#endif
out[n * channels * spatial_dim + index] = v;
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment