|
| 1 | +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#ifndef HL_WARPCTC_WRAP_H_ |
| 16 | +#define HL_WARPCTC_WRAP_H_ |
| 17 | + |
| 18 | +#include "hl_base.h" |
| 19 | +#include "warp-ctc/include/ctc.h" |
| 20 | + |
| 21 | +typedef ctcStatus_t hl_warpctc_status_t; |
| 22 | +typedef ctcOptions hl_warpctc_options_t; |
| 23 | + |
| 24 | +/** |
| 25 | + * @brief Init ctc options. |
| 26 | + * |
| 27 | + * @param[in] blank blank label used in ctc loss function. |
| 28 | + * @param[in] useGpu whether use gpu. |
| 29 | + * @param[out] options handle to store cpu or gpu informations. |
| 30 | + * |
| 31 | + */ |
| 32 | +extern void hl_warpctc_init(const size_t blank, |
| 33 | + bool useGpu, |
| 34 | + hl_warpctc_options_t* options); |
| 35 | + |
| 36 | +/** |
| 37 | + * @brief Compute the connectionist temporal classification loss, |
| 38 | + * and optionally compute the gradient with respect to the inputs. |
| 39 | + * |
| 40 | + * if batchGrad == nullptr |
| 41 | + * |
| 42 | + * only compute the ctc loss. |
| 43 | + * |
| 44 | + * if batchGrad != nullptr |
| 45 | + * |
| 46 | + * compute both ctc loss and gradient. |
| 47 | + * |
| 48 | + * @param[in] batchInput batch matrix of input probabilities, |
| 49 | + * in maxSequenceLength x numSequence x numClasses |
| 50 | + * (row-major) format. |
| 51 | + * @param[out] batchGrad batch matrix of gradient. |
| 52 | + * @param[in] cpuLabels labels always in CPU memory. |
| 53 | + * @param[in] cpuLabelLengths length of all labels in CPU memory. |
| 54 | + * @param[in] cpuInputLengths length of all sequences in CPU memory. |
| 55 | + * @param[in] numClasses number of possible output symbols. |
| 56 | + * @param[in] numSequences number of sequence. |
| 57 | + * @param[out] cpuCosts cost of each sequence in CPU memory. |
| 58 | + * @param[out] workspace workspace to store some temporary results. |
| 59 | + * @param[in] options handle to store cpu or gpu informations. |
| 60 | + * |
| 61 | + */ |
| 62 | +extern void hl_warpctc_compute_loss(const real* batchInput, |
| 63 | + real* batchGrad, |
| 64 | + const int* cpuLabels, |
| 65 | + const int* cpuLabelLengths, |
| 66 | + const int* cpuInputLengths, |
| 67 | + const size_t numClasses, |
| 68 | + const size_t numSequences, |
| 69 | + real* cpuCosts, |
| 70 | + void* workspace, |
| 71 | + hl_warpctc_options_t* options); |
| 72 | + |
| 73 | +/** |
| 74 | + * @brief Compute the required workspace size. |
| 75 | + * There is no memory allocated operations within warp-ctc. |
| 76 | + * |
| 77 | + * @param[in] cpuLabelLengths length of all labels in CPU memory. |
| 78 | + * @param[in] cpuInputLengths length of all sequences in CPU memory. |
| 79 | + * @param[in] numClasses number of possible output symbols. |
| 80 | + * @param[in] numSequences number of sequence. |
| 81 | + * @param[in] options handle to store cpu or gpu informations. |
| 82 | + * @param[out] bytes pointer to a scalar where the memory |
| 83 | + * requirement in bytes will be placed. |
| 84 | + * |
| 85 | + */ |
| 86 | +extern void hl_warpctc_get_workspace_size(const int* cpuLabelLengths, |
| 87 | + const int* cpuInputLengths, |
| 88 | + const size_t numClasses, |
| 89 | + const size_t numSequences, |
| 90 | + hl_warpctc_options_t* options, |
| 91 | + size_t* bytes); |
| 92 | + |
| 93 | +#endif // HL_WARPCTC_WRAP_H_ |
0 commit comments