Commit e3b42bf9 authored by Li Peng's avatar Li Peng

batch_norm and blank layer ocl implementation

Signed-off-by: 's avatarLi Peng <peng.li@intel.com>
parent 67f9406c
......@@ -22,6 +22,7 @@ class BatchNormLayerImpl : public BatchNormLayer
{
public:
Mat weights_, bias_;
Mat weightMat, biasMat;
BatchNormLayerImpl(const LayerParams& params)
{
......@@ -96,17 +97,81 @@ public:
return true;
}
void finalize(const std::vector<Mat*> &inputs, std::vector<Mat> &outputs)
{
if (inputs[0]->dims == 4)
{
int groups = inputs[0]->size[0];
int channels = inputs[0]->size[1];
int rows = inputs[0]->size[2];
int cols = inputs[0]->size[3];
MatShape s = shape(groups * channels, rows * cols);
weightMat = Mat(s[0], s[1], CV_32FC1);
biasMat = Mat(s[0], s[1], CV_32FC1);
for (int n = 0; n < s[0]; n++)
{
weightMat.row(n).setTo(weights_.at<float>(n % channels));
biasMat.row(n).setTo(bias_.at<float>(n % channels));
}
}
}
virtual bool supportBackend(int backendId)
{
return backendId == DNN_BACKEND_DEFAULT ||
backendId == DNN_BACKEND_HALIDE && haveHalide();
}
#ifdef HAVE_OPENCL
bool forward_ocl(InputArrayOfArrays inputs_, OutputArrayOfArrays outputs_, OutputArrayOfArrays internals_)
{
std::vector<UMat> inputs;
std::vector<UMat> outputs;
inputs_.getUMatVector(inputs);
outputs_.getUMatVector(outputs);
CV_Assert(blobs.size() >= 2);
CV_Assert(inputs.size() == 1);
UMat &inpBlob = inputs[0];
CV_Assert(inpBlob.dims == 2 || inpBlob.dims == 4);
int groups = inpBlob.size[0];
int channels = inpBlob.size[1];
int rows = inpBlob.dims > 2 ? inpBlob.size[2] : 1;
int cols = inpBlob.dims > 2 ? inpBlob.size[3] : 1;
for (size_t ii = 0; ii < outputs.size(); ii++)
{
if (inpBlob.dims == 2)
{
UMat& src = inputs[ii];
UMat& dst = outputs[ii];
multiply(src, weights_, dst);
add(dst, bias_, dst);
}
else
{
MatShape s = shape(groups * channels, rows * cols);
UMat src = inputs[ii].reshape(1, s.size(), &s[0]);
UMat dst = outputs[ii].reshape(1, s.size(), &s[0]);
multiply(src, weightMat, dst);
add(dst, biasMat, dst);
}
}
return true;
}
#endif
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr)
{
CV_TRACE_FUNCTION();
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
CV_OCL_RUN((preferableTarget == DNN_TARGET_OPENCL) &&
OCL_PERFORMANCE_CHECK(ocl::Device::getDefault().isIntel()),
forward_ocl(inputs_arr, outputs_arr, internals_arr))
Layer::forward_fallback(inputs_arr, outputs_arr, internals_arr);
}
......
......@@ -63,8 +63,22 @@ public:
}
#ifdef HAVE_OPENCL
bool forward_ocl(InputArrayOfArrays inputs, OutputArrayOfArrays outputs, OutputArrayOfArrays internals)
bool forward_ocl(InputArrayOfArrays inputs_, OutputArrayOfArrays outputs_, OutputArrayOfArrays internals_)
{
std::vector<UMat> inputs;
std::vector<UMat> outputs;
inputs_.getUMatVector(inputs);
outputs_.getUMatVector(outputs);
for (int i = 0, n = outputs.size(); i < n; ++i)
{
void *src_handle = inputs[i].handle(ACCESS_READ);
void *dst_handle = outputs[i].handle(ACCESS_WRITE);
if (src_handle != dst_handle)
inputs[i].copyTo(outputs[i]);
}
return true;
}
#endif
......
......@@ -152,6 +152,13 @@ TEST(Test_TensorFlow, batch_norm)
runTensorFlowNet("batch_norm_text", DNN_TARGET_CPU, true);
}
OCL_TEST(Test_TensorFlow, batch_norm)
{
runTensorFlowNet("batch_norm", DNN_TARGET_OPENCL);
runTensorFlowNet("fused_batch_norm", DNN_TARGET_OPENCL);
runTensorFlowNet("batch_norm_text", DNN_TARGET_OPENCL, true);
}
TEST(Test_TensorFlow, pooling)
{
runTensorFlowNet("max_pool_even");
......
......@@ -170,6 +170,11 @@ TEST(Torch_Importer, run_batch_norm)
runTorchNet("net_batch_norm", DNN_TARGET_CPU, "", false, true);
}
OCL_TEST(Torch_Importer, run_batch_norm)
{
runTorchNet("net_batch_norm", DNN_TARGET_OPENCL, "", false, true);
}
TEST(Torch_Importer, net_prelu)
{
runTorchNet("net_prelu");
......@@ -242,6 +247,11 @@ TEST(Torch_Importer, net_non_spatial)
runTorchNet("net_non_spatial", DNN_TARGET_CPU, "", false, true);
}
OCL_TEST(Torch_Importer, net_non_spatial)
{
runTorchNet("net_non_spatial", DNN_TARGET_OPENCL, "", false, true);
}
TEST(Torch_Importer, ENet_accuracy)
{
Net net;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment