@@ -175,11 +175,15 @@ void hl_cudnn_init(cudnnHandle_t* cudnn_handle, cudaStream_t stream) {
175175 << " PaddlePaddle Requirement: "
176176 << " (header v[2-3] with libcudnn v[2-3]) Or "
177177 << " (header v4 with libcudnn v4) Or "
178- << " (header v5 with libcudnn v5)." ;
178+ << " (header v5 with libcudnn v5) Or"
179+ << " (header v6 with libcudnn v6)." ;
179180
180- CHECK (!(CUDNN_VERSION >= 5000 && CUDA_VERSION < 7050 ))
181+ CHECK (!(CUDNN_VERSION < 6000 && CUDNN_VERSION >= 5000 && CUDA_VERSION < 7050 ))
181182 << " cudnn v5 requires cuda version >= 7.5" ;
182183
184+ CHECK (!(CUDNN_VERSION >= 6000 && CUDA_VERSION < 8000 ))
185+ << " cudnn v6 requires cuda version >= 8.0" ;
186+
183187 CHECK_CUDNN (dynload::cudnnCreate (cudnn_handle));
184188 CHECK_CUDNN (dynload::cudnnSetStream (*cudnn_handle, stream));
185189
@@ -610,6 +614,23 @@ void hl_create_convolution_descriptor(hl_convolution_descriptor* conv,
610614 CHECK_CUDNN (dynload::cudnnCreateConvolutionDescriptor (&hl_conv->desc ));
611615
612616 cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION;
617+
618+ #if CUDNN_VERSION >= 6000
619+ #ifndef PADDLE_TYPE_DOUBLE
620+ cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
621+ #else
622+ cudnnDataType_t data_type = CUDNN_DATA_DOUBLE;
623+ #endif
624+ CHECK_CUDNN (dynload::cudnnSetConvolution2dDescriptor (hl_conv->desc ,
625+ padding_height,
626+ padding_width,
627+ stride_height,
628+ stride_width,
629+ 1 ,
630+ 1 ,
631+ mode,
632+ data_type));
633+ #else
613634 CHECK_CUDNN (dynload::cudnnSetConvolution2dDescriptor (hl_conv->desc ,
614635 padding_height,
615636 padding_width,
@@ -618,6 +639,7 @@ void hl_create_convolution_descriptor(hl_convolution_descriptor* conv,
618639 1 ,
619640 1 ,
620641 mode));
642+ #endif
621643
622644 hl_conv->input_image = image;
623645 hl_conv->filter = filter;
@@ -645,6 +667,23 @@ void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
645667
646668 cudnnConvolutionDescriptor_t conv_desc = GET_CONVOLUTION_DESCRIPTOR (conv);
647669 cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION;
670+
671+ #if CUDNN_VERSION >= 6000
672+ #ifndef PADDLE_TYPE_DOUBLE
673+ cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
674+ #else
675+ cudnnDataType_t data_type = CUDNN_DATA_DOUBLE;
676+ #endif
677+ CHECK_CUDNN (dynload::cudnnSetConvolution2dDescriptor (conv_desc,
678+ padding_height,
679+ padding_width,
680+ stride_height,
681+ stride_width,
682+ 1 ,
683+ 1 ,
684+ mode,
685+ data_type));
686+ #else
648687 CHECK_CUDNN (dynload::cudnnSetConvolution2dDescriptor (conv_desc,
649688 padding_height,
650689 padding_width,
@@ -653,6 +692,7 @@ void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
653692 1 ,
654693 1 ,
655694 mode));
695+ #endif
656696
657697 cudnn_convolution_descriptor hl_conv = (cudnn_convolution_descriptor)conv;
658698 hl_conv->input_image = image;
0 commit comments