Skip to content

Commit 10fe308

Browse files
authored
Merge pull request #949 from ccx0912/MY_COOL_STUFF_BRANCH
Added support for cudnn v6 and cuda 8.0
2 parents e823c95 + 18ebeec commit 10fe308

File tree

1 file changed

+42
-2
lines changed

1 file changed

+42
-2
lines changed

paddle/cuda/src/hl_cuda_cudnn.cc

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,11 +175,15 @@ void hl_cudnn_init(cudnnHandle_t* cudnn_handle, cudaStream_t stream) {
175175
<< "PaddlePaddle Requirement: "
176176
<< "(header v[2-3] with libcudnn v[2-3]) Or "
177177
<< "(header v4 with libcudnn v4) Or "
178-
<< "(header v5 with libcudnn v5).";
178+
<< "(header v5 with libcudnn v5) Or"
179+
<< "(header v6 with libcudnn v6).";
179180

180-
CHECK(!(CUDNN_VERSION >= 5000 && CUDA_VERSION < 7050))
181+
CHECK(!(CUDNN_VERSION < 6000 && CUDNN_VERSION >= 5000 && CUDA_VERSION < 7050))
181182
<< "cudnn v5 requires cuda version >= 7.5";
182183

184+
CHECK(!(CUDNN_VERSION >= 6000 && CUDA_VERSION < 8000))
185+
<< "cudnn v6 requires cuda version >= 8.0";
186+
183187
CHECK_CUDNN(dynload::cudnnCreate(cudnn_handle));
184188
CHECK_CUDNN(dynload::cudnnSetStream(*cudnn_handle, stream));
185189

@@ -610,6 +614,23 @@ void hl_create_convolution_descriptor(hl_convolution_descriptor* conv,
610614
CHECK_CUDNN(dynload::cudnnCreateConvolutionDescriptor(&hl_conv->desc));
611615

612616
cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION;
617+
618+
#if CUDNN_VERSION >= 6000
619+
#ifndef PADDLE_TYPE_DOUBLE
620+
cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
621+
#else
622+
cudnnDataType_t data_type = CUDNN_DATA_DOUBLE;
623+
#endif
624+
CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(hl_conv->desc,
625+
padding_height,
626+
padding_width,
627+
stride_height,
628+
stride_width,
629+
1,
630+
1,
631+
mode,
632+
data_type));
633+
#else
613634
CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(hl_conv->desc,
614635
padding_height,
615636
padding_width,
@@ -618,6 +639,7 @@ void hl_create_convolution_descriptor(hl_convolution_descriptor* conv,
618639
1,
619640
1,
620641
mode));
642+
#endif
621643

622644
hl_conv->input_image = image;
623645
hl_conv->filter = filter;
@@ -645,6 +667,23 @@ void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
645667

646668
cudnnConvolutionDescriptor_t conv_desc = GET_CONVOLUTION_DESCRIPTOR(conv);
647669
cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION;
670+
671+
#if CUDNN_VERSION >= 6000
672+
#ifndef PADDLE_TYPE_DOUBLE
673+
cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
674+
#else
675+
cudnnDataType_t data_type = CUDNN_DATA_DOUBLE;
676+
#endif
677+
CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(conv_desc,
678+
padding_height,
679+
padding_width,
680+
stride_height,
681+
stride_width,
682+
1,
683+
1,
684+
mode,
685+
data_type));
686+
#else
648687
CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(conv_desc,
649688
padding_height,
650689
padding_width,
@@ -653,6 +692,7 @@ void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
653692
1,
654693
1,
655694
mode));
695+
#endif
656696

657697
cudnn_convolution_descriptor hl_conv = (cudnn_convolution_descriptor)conv;
658698
hl_conv->input_image = image;

0 commit comments

Comments
 (0)