Skip to content

Commit 4823075

Browse files
authored
Merge pull request #651 from Xreki/warpctc
Integrate warp-ctc as WarpCTCLayer, including unit test and layer interface.
2 parents adc23f6 + 78bdd32 commit 4823075

26 files changed

+1154
-27
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "warp-ctc"]
2+
path = warp-ctc
3+
url = https://github.com/baidu-research/warp-ctc.git

.pre-commit-config.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
sha: c25201a00e6b0514370501050cf2a8538ac12270
33
hooks:
44
- id: remove-crlf
5+
files: (?!.*warp-ctc)^.*$
56
- repo: https://github.com/reyoung/mirrors-yapf.git
67
sha: v0.13.2
78
hooks:
@@ -13,6 +14,7 @@
1314
- id: check-merge-conflict
1415
- id: check-symlinks
1516
- id: detect-private-key
17+
files: (?!.*warp-ctc)^.*$
1618
- id: end-of-file-fixer
1719
- repo: https://github.com/PaddlePaddle/clang-format-pre-commit-hook.git
1820
sha: 28c0ea8a67a3e2dbbf4822ef44e85b63a0080a29

CMakeLists.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,10 @@ find_package(Git REQUIRED)
7777
include(version)
7878
add_definitions(-DPADDLE_VERSION=\"${PADDLE_VERSION}\")
7979

80-
8180
if(NOT WITH_GPU)
8281
add_definitions(-DPADDLE_ONLY_CPU)
8382
add_definitions(-DHPPL_STUB_FUNC)
83+
8484
list(APPEND CMAKE_CXX_SOURCE_FILE_EXTENSIONS cu)
8585
else()
8686
if(${CUDA_VERSION_MAJOR} GREATER 6)
@@ -102,15 +102,15 @@ else()
102102
set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} "-Xcompiler ${SSE3_FLAG}")
103103
endif(WITH_AVX)
104104

105-
if(WITH_DSO)
106-
add_definitions(-DPADDLE_USE_DSO)
107-
endif(WITH_DSO)
108-
109105
# Include cuda and cudnn
110106
include_directories(${CUDNN_INCLUDE_DIR})
111107
include_directories(${CUDA_TOOLKIT_INCLUDE})
112108
endif(NOT WITH_GPU)
113109

110+
if(WITH_DSO)
111+
add_definitions(-DPADDLE_USE_DSO)
112+
endif(WITH_DSO)
113+
114114
if(WITH_DOUBLE)
115115
add_definitions(-DPADDLE_TYPE_DOUBLE)
116116
set(ACCURACY double)

cmake/util.cmake

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,11 @@ function(link_paddle_exe TARGET_NAME)
148148
target_link_libraries(${TARGET_NAME} rt)
149149
endif()
150150
endif()
151+
152+
if(NOT WITH_DSO)
153+
target_link_libraries(${TARGET_NAME}
154+
${WARPCTC_LIBRARY})
155+
endif()
151156
endfunction()
152157

153158
# link_paddle_test

paddle/cuda/CMakeLists.txt

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,28 @@ else()
1515
endif()
1616

1717
set(CUDA_CXX_WITH_GPU_SOURCES
18+
src/hl_cudart_wrap.cc
1819
src/hl_cuda_cublas.cc
1920
src/hl_cuda_cudnn.cc
2021
src/hl_cuda_device.cc)
2122

22-
set_source_files_properties(${CUDA_CXX_WITH_GPU_SOURCES}
23-
PROPERTIES COMPILE_FLAGS "-D__NVCC__")
23+
if(WITH_GPU)
24+
set(CUDA_CXX_SOURCES
25+
src/hl_dso_loader.cc
26+
src/hl_warpctc_wrap.cc
27+
${CUDA_CXX_WITH_GPU_SOURCES})
28+
29+
set_source_files_properties(${CUDA_CXX_SOURCES}
30+
PROPERTIES COMPILE_FLAGS "-D__NVCC__")
31+
else()
32+
set(CUDA_CXX_SOURCES
33+
src/hl_dso_loader.cc
34+
src/hl_warpctc_wrap.cc)
35+
endif()
2436

2537
set_source_files_properties(${AVX_SOURCES}
2638
PROPERTIES COMPILE_FLAGS "-mavx")
2739

28-
set(CUDA_DSO_SOURCES
29-
src/hl_dso_loader.cc
30-
src/hl_cudart_wrap.cc)
31-
3240
set(CUDA_CU_SOURCES
3341
src/hl_perturbation_util.cu
3442
src/hl_cuda_aggregate.cu
@@ -44,6 +52,7 @@ set(CUDA_CU_SOURCES
4452
set(CUDA_HEADERS
4553
include/hl_time.h
4654
include/hl_dso_loader.h
55+
include/hl_warpctc_wrap.h
4756
include/hl_sequence.h
4857
include/hl_cuda_cublas.h
4958
include/hl_batch_transpose.h
@@ -75,14 +84,14 @@ if(WITH_GPU)
7584
cuda_add_library(paddle_cuda
7685
${CUDA_SOURCES}
7786
${CUDA_CU_SOURCES}
78-
${CUDA_DSO_SOURCES}
79-
${CUDA_CXX_WITH_GPU_SOURCES})
87+
${CUDA_CXX_SOURCES})
8088
else()
81-
add_library(paddle_cuda ${CUDA_SOURCES})
89+
add_library(paddle_cuda
90+
${CUDA_SOURCES}
91+
${CUDA_CXX_SOURCES})
8292
endif()
8393

8494
add_style_check_target(paddle_cuda
8595
${CUDA_SOURCES}
8696
${CUDA_HEADERS}
87-
${CUDA_DSO_SOURCES}
88-
${CUDA_CXX_WITH_GPU_SOURCES})
97+
${CUDA_CXX_SOURCES})

paddle/cuda/include/hl_dso_loader.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@ limitations under the License. */
1818
#include <dlfcn.h>
1919
#include <string>
2020
#include <memory>
21-
#include <cuda_runtime.h>
22-
#include <cublas_v2.h>
23-
#include <curand.h>
24-
#include <cudnn.h>
2521
#include "hl_base.h"
2622

2723
/**
@@ -56,4 +52,12 @@ void GetCudartDsoHandle(void** dso_handle);
5652
*/
5753
void GetCurandDsoHandle(void** dso_handle);
5854

55+
/**
56+
* @brief load the DSO of warp-ctc
57+
*
58+
* @param **dso_handle dso handler
59+
*
60+
*/
61+
void GetWarpCTCDsoHandle(void** dso_handle);
62+
5963
#endif // HL_DSO_LOADER_H_

paddle/cuda/include/hl_gpu.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ limitations under the License. */
2525
#include "hl_sparse.h"
2626
#include "hl_lstm.h"
2727
#include "hl_sequence.h"
28+
#include "hl_warpctc_wrap.h"
2829

2930
#ifdef HPPL_STUB_FUNC
3031
#include "stub/hl_cuda_stub.h"

paddle/cuda/include/hl_sequence.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,39 @@ extern void hl_sequence2batch_add(real* batch,
172172
int batchCount,
173173
bool seq2batch);
174174

175+
/**
176+
* @brief Memory copy from sequence to batch,
177+
* while padding all sequences to the same length.
178+
*
179+
* if seq2batch == true
180+
*
181+
* copy from sequence to batch:
182+
* batch[i] = sequence[sequenceStartPositions[i]]
183+
*
184+
* if seq2batch == false
185+
*
186+
* copy from batch to sequence:
187+
* sequence[sequenceStartPositions[i]] = batch[i]
188+
*
189+
* @param[in,out] batch batch matrix.
190+
* @param[in,out] sequence sequence matrix.
191+
* @param[in] sequenceStartPositions index vector.
192+
* @param[in] sequenceWidth width of sequence.
193+
* @param[in] maxSequenceLength maximum length of sequences.
194+
* @param[in] numSequences number of sequences.
195+
* @param[in] normByTimes whether to divide by the sequence's length.
196+
* @param[in] seq2batch copy direction.
197+
*
198+
*/
199+
extern void hl_sequence2batch_copy_padding(real* batch,
200+
real* sequence,
201+
const int* sequenceStartPositions,
202+
const size_t sequenceWidth,
203+
const size_t maxSequenceLength,
204+
const size_t numSequences,
205+
bool normByTimes,
206+
bool seq2batch);
207+
175208
/**
176209
* @brief dst = Op(src), src is sequence.
177210
*
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#ifndef HL_WARPCTC_WRAP_H_
16+
#define HL_WARPCTC_WRAP_H_
17+
18+
#include "hl_base.h"
19+
#include "warp-ctc/include/ctc.h"
20+
21+
typedef ctcStatus_t hl_warpctc_status_t;
22+
typedef ctcOptions hl_warpctc_options_t;
23+
24+
/**
25+
* @brief Init ctc options.
26+
*
27+
* @param[in] blank blank label used in ctc loss function.
28+
* @param[in] useGpu whether use gpu.
29+
* @param[out] options handle to store cpu or gpu information.
30+
*
31+
*/
32+
extern void hl_warpctc_init(const size_t blank,
33+
bool useGpu,
34+
hl_warpctc_options_t* options);
35+
36+
/**
37+
* @brief Compute the connectionist temporal classification loss,
38+
* and optionally compute the gradient with respect to the inputs.
39+
*
40+
* if batchGrad == nullptr
41+
*
42+
* only compute the ctc loss.
43+
*
44+
* if batchGrad != nullptr
45+
*
46+
* compute both ctc loss and gradient.
47+
*
48+
* @param[in] batchInput batch matrix of input probabilities,
49+
* in maxSequenceLength x numSequences x numClasses
50+
* (row-major) format.
51+
* @param[out] batchGrad batch matrix of gradient.
52+
* @param[in] cpuLabels labels always in CPU memory.
53+
* @param[in] cpuLabelLengths length of all labels in CPU memory.
54+
* @param[in] cpuInputLengths length of all sequences in CPU memory.
55+
* @param[in] numClasses number of possible output symbols.
56+
* @param[in] numSequences number of sequences.
57+
* @param[out] cpuCosts cost of each sequence in CPU memory.
58+
* @param[out] workspace workspace to store some temporary results.
59+
* @param[in] options handle to store cpu or gpu information.
60+
*
61+
*/
62+
extern void hl_warpctc_compute_loss(const real* batchInput,
63+
real* batchGrad,
64+
const int* cpuLabels,
65+
const int* cpuLabelLengths,
66+
const int* cpuInputLengths,
67+
const size_t numClasses,
68+
const size_t numSequences,
69+
real* cpuCosts,
70+
void* workspace,
71+
hl_warpctc_options_t* options);
72+
73+
/**
74+
* @brief Compute the required workspace size.
75+
* There are no memory allocation operations within warp-ctc.
76+
*
77+
* @param[in] cpuLabelLengths length of all labels in CPU memory.
78+
* @param[in] cpuInputLengths length of all sequences in CPU memory.
79+
* @param[in] numClasses number of possible output symbols.
80+
* @param[in] numSequences number of sequences.
81+
* @param[in] options handle to store cpu or gpu information.
82+
* @param[out] bytes pointer to a scalar where the memory
83+
* requirement in bytes will be placed.
84+
*
85+
*/
86+
extern void hl_warpctc_get_workspace_size(const int* cpuLabelLengths,
87+
const int* cpuInputLengths,
88+
const size_t numClasses,
89+
const size_t numSequences,
90+
hl_warpctc_options_t* options,
91+
size_t* bytes);
92+
93+
#endif // HL_WARPCTC_WRAP_H_

paddle/cuda/include/stub/hl_sequence_stub.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,15 @@ inline void hl_sequence2batch_add(real* batch,
7070
int batchCount,
7171
bool seq2batch) {}
7272

73+
inline void hl_sequence2batch_copy_padding(real* batch,
74+
real* sequence,
75+
const int* sequenceStartPositions,
76+
const size_t sequenceWidth,
77+
const size_t maxSequenceLength,
78+
const size_t numSequences,
79+
bool normByTimes,
80+
bool seq2batch) {}
81+
7382
inline void hl_sequence_avg_forward(real* dst,
7483
real* src,
7584
const int* starts,

0 commit comments

Comments
 (0)