Commit 9b8002cd authored by Aaron Denney

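Remove use of constant memory in calc_all_iterations/compute_message/message_per_pixel; msg_step and disp_step are now passed as explicit arguments instead of being copied to constant symbols.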

parent b792419c
...@@ -662,13 +662,14 @@ namespace cv { namespace cuda { namespace device ...@@ -662,13 +662,14 @@ namespace cv { namespace cuda { namespace device
template <typename T> template <typename T>
__device__ void message_per_pixel(const T* data, T* msg_dst, const T* msg1, const T* msg2, const T* msg3, __device__ void message_per_pixel(const T* data, T* msg_dst, const T* msg1, const T* msg2, const T* msg3,
const T* dst_disp, const T* src_disp, int nr_plane, int max_disc_term, float disc_single_jump, volatile T* temp) const T* dst_disp, const T* src_disp, int nr_plane, int max_disc_term, float disc_single_jump, volatile T* temp,
size_t disp_step)
{ {
T minimum = numeric_limits<T>::max(); T minimum = numeric_limits<T>::max();
for(int d = 0; d < nr_plane; d++) for(int d = 0; d < nr_plane; d++)
{ {
int idx = d * cdisp_step1; int idx = d * disp_step;
T val = data[idx] + msg1[idx] + msg2[idx] + msg3[idx]; T val = data[idx] + msg1[idx] + msg2[idx] + msg3[idx];
if(val < minimum) if(val < minimum)
...@@ -681,43 +682,43 @@ namespace cv { namespace cuda { namespace device ...@@ -681,43 +682,43 @@ namespace cv { namespace cuda { namespace device
for(int d = 0; d < nr_plane; d++) for(int d = 0; d < nr_plane; d++)
{ {
float cost_min = minimum + max_disc_term; float cost_min = minimum + max_disc_term;
T src_disp_reg = src_disp[d * cdisp_step1]; T src_disp_reg = src_disp[d * disp_step];
for(int d2 = 0; d2 < nr_plane; d2++) for(int d2 = 0; d2 < nr_plane; d2++)
cost_min = fmin(cost_min, msg_dst[d2 * cdisp_step1] + disc_single_jump * ::abs(dst_disp[d2 * cdisp_step1] - src_disp_reg)); cost_min = fmin(cost_min, msg_dst[d2 * disp_step] + disc_single_jump * ::abs(dst_disp[d2 * disp_step] - src_disp_reg));
temp[d * cdisp_step1] = saturate_cast<T>(cost_min); temp[d * disp_step] = saturate_cast<T>(cost_min);
sum += cost_min; sum += cost_min;
} }
sum /= nr_plane; sum /= nr_plane;
for(int d = 0; d < nr_plane; d++) for(int d = 0; d < nr_plane; d++)
msg_dst[d * cdisp_step1] = saturate_cast<T>(temp[d * cdisp_step1] - sum); msg_dst[d * disp_step] = saturate_cast<T>(temp[d * disp_step] - sum);
} }
template <typename T> template <typename T>
__global__ void compute_message(uchar *ctemp, T* u_, T* d_, T* l_, T* r_, const T* data_cost_selected, const T* selected_disp_pyr_cur, int h, int w, int nr_plane, int i, int max_disc_term, float disc_single_jump) __global__ void compute_message(uchar *ctemp, T* u_, T* d_, T* l_, T* r_, const T* data_cost_selected, const T* selected_disp_pyr_cur, int h, int w, int nr_plane, int i, int max_disc_term, float disc_single_jump, size_t msg_step, size_t disp_step)
{ {
int y = blockIdx.y * blockDim.y + threadIdx.y; int y = blockIdx.y * blockDim.y + threadIdx.y;
int x = ((blockIdx.x * blockDim.x + threadIdx.x) << 1) + ((y + i) & 1); int x = ((blockIdx.x * blockDim.x + threadIdx.x) << 1) + ((y + i) & 1);
if (y > 0 && y < h - 1 && x > 0 && x < w - 1) if (y > 0 && y < h - 1 && x > 0 && x < w - 1)
{ {
const T* data = data_cost_selected + y * cmsg_step + x; const T* data = data_cost_selected + y * msg_step + x;
T* u = u_ + y * cmsg_step + x; T* u = u_ + y * msg_step + x;
T* d = d_ + y * cmsg_step + x; T* d = d_ + y * msg_step + x;
T* l = l_ + y * cmsg_step + x; T* l = l_ + y * msg_step + x;
T* r = r_ + y * cmsg_step + x; T* r = r_ + y * msg_step + x;
const T* disp = selected_disp_pyr_cur + y * cmsg_step + x; const T* disp = selected_disp_pyr_cur + y * msg_step + x;
T* temp = (T*)ctemp + y * cmsg_step + x; T* temp = (T*)ctemp + y * msg_step + x;
message_per_pixel(data, u, r - 1, u + cmsg_step, l + 1, disp, disp - cmsg_step, nr_plane, max_disc_term, disc_single_jump, temp); message_per_pixel(data, u, r - 1, u + msg_step, l + 1, disp, disp - msg_step, nr_plane, max_disc_term, disc_single_jump, temp, disp_step);
message_per_pixel(data, d, d - cmsg_step, r - 1, l + 1, disp, disp + cmsg_step, nr_plane, max_disc_term, disc_single_jump, temp); message_per_pixel(data, d, d - msg_step, r - 1, l + 1, disp, disp + msg_step, nr_plane, max_disc_term, disc_single_jump, temp, disp_step);
message_per_pixel(data, l, u + cmsg_step, d - cmsg_step, l + 1, disp, disp - 1, nr_plane, max_disc_term, disc_single_jump, temp); message_per_pixel(data, l, u + msg_step, d - msg_step, l + 1, disp, disp - 1, nr_plane, max_disc_term, disc_single_jump, temp, disp_step);
message_per_pixel(data, r, u + cmsg_step, d - cmsg_step, r - 1, disp, disp + 1, nr_plane, max_disc_term, disc_single_jump, temp); message_per_pixel(data, r, u + msg_step, d - msg_step, r - 1, disp, disp + 1, nr_plane, max_disc_term, disc_single_jump, temp, disp_step);
} }
} }
...@@ -727,8 +728,6 @@ namespace cv { namespace cuda { namespace device ...@@ -727,8 +728,6 @@ namespace cv { namespace cuda { namespace device
const T* selected_disp_pyr_cur, size_t msg_step, int h, int w, int nr_plane, int iters, int max_disc_term, float disc_single_jump, cudaStream_t stream) const T* selected_disp_pyr_cur, size_t msg_step, int h, int w, int nr_plane, int iters, int max_disc_term, float disc_single_jump, cudaStream_t stream)
{ {
size_t disp_step = msg_step * h; size_t disp_step = msg_step * h;
cudaSafeCall( cudaMemcpyToSymbol(cdisp_step1, &disp_step, sizeof(size_t)) );
cudaSafeCall( cudaMemcpyToSymbol(cmsg_step, &msg_step, sizeof(size_t)) );
dim3 threads(32, 8, 1); dim3 threads(32, 8, 1);
dim3 grid(1, 1, 1); dim3 grid(1, 1, 1);
...@@ -738,7 +737,7 @@ namespace cv { namespace cuda { namespace device ...@@ -738,7 +737,7 @@ namespace cv { namespace cuda { namespace device
for(int t = 0; t < iters; ++t) for(int t = 0; t < iters; ++t)
{ {
compute_message<<<grid, threads, 0, stream>>>(ctemp, u, d, l, r, data_cost_selected, selected_disp_pyr_cur, h, w, nr_plane, t & 1, max_disc_term, disc_single_jump); compute_message<<<grid, threads, 0, stream>>>(ctemp, u, d, l, r, data_cost_selected, selected_disp_pyr_cur, h, w, nr_plane, t & 1, max_disc_term, disc_single_jump, msg_step, disp_step);
cudaSafeCall( cudaGetLastError() ); cudaSafeCall( cudaGetLastError() );
} }
if (stream == 0) if (stream == 0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment