Commit a41c1baa authored by Fenglei's avatar Fenglei Committed by Robert Kimball

add back missing part (#1785)

parent 6cd35432
......@@ -1179,7 +1179,31 @@ size_t runtime::gpu::CUDNNEmitter::build_convolution(const std::string& dtype,
auto& filter_desc = get_cudnn_filter_descriptor(input_filter_shape, data_type, tensor_format);
auto& conv_desc = get_cudnn_convolution_descriptor(
padding_below, window_movement_strides, window_dilation_strides, mode, data_type);
const cudnnConvolutionFwdAlgo_t conv_fwd_algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
cudnnConvolutionFwdAlgo_t conv_fwd_algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
if (find_algo != algo_search::NONE)
{
int num_algos;
int max_algos = 0;
CUDNN_SAFE_CALL(
cudnnGetConvolutionForwardAlgorithmMaxCount(*m_ctx->cudnn_handle, &max_algos));
std::vector<cudnnConvolutionFwdAlgoPerf_t> results(max_algos);
auto cudnn_algo_search = (find_algo == algo_search::EXPLICIT)
? cudnnFindConvolutionForwardAlgorithm
: cudnnGetConvolutionForwardAlgorithm_v7;
CUDNN_SAFE_CALL((*cudnn_algo_search)(*m_ctx->cudnn_handle,
tensor_desc_0,
filter_desc,
conv_desc,
tensor_desc_1,
static_cast<int>(results.size()),
&num_algos,
results.data()));
results.resize(num_algos);
conv_fwd_algo =
select_cudnn_algo<cudnnConvolutionFwdAlgoPerf_t, cudnnConvolutionFwdAlgo_t>(results);
}
void* alpha = m_host_parameters.allocate_by_datatype(data_type, 1.0);
void* beta = m_host_parameters.allocate_by_datatype(data_type, 0);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment