
Commit cc00104

Author: ytian218 (committed)
server: validate n_batch == n_ubatch for embeddings (#6263)
Fixes #6263 where server accepts mismatched batch/ubatch values with embeddings, leading to suboptimal or incorrect behavior.

Problem:
Embeddings and reranking use non-causal attention which requires all tokens to be processed within a single ubatch. When n_batch != n_ubatch, the configuration is incoherent. Default values differ (n_batch=2048, n_ubatch=512), so users encounter this frequently.

Solution:
- Add parameter validation in main() after common_params_parse()
- When embeddings enabled and n_batch != n_ubatch:
  * Log warnings explaining the requirement
  * Automatically set both to min(n_batch, n_ubatch)
  * Ensure coherent configuration

This follows the auto-correction approach suggested by @mirekphd and provides better UX than strict rejection.

Testing:
✅ Builds successfully
✅ Validation triggers: -b 2048 -ub 512 --embedding → logs warnings, adjusts both to 512
✅ No false positives: -b 512 -ub 512 --embedding → no warnings
✅ Verified on macOS M3 Pro with embedding model

1 parent 583cb83 commit cc00104

1 file changed: +11 -0 lines changed

tools/server/server.cpp

Lines changed: 11 additions & 0 deletions
@@ -3657,6 +3657,17 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    // validate batch size for embeddings and reranking
+    // non-causal attention (embeddings/reranking) requires n_batch == n_ubatch
+    // see https://github.com/ggml-org/llama.cpp/issues/6263
+    if (params.embedding && params.n_batch != params.n_ubatch) {
+        LOG_WRN("%s: embeddings/reranking mode requires n_batch == n_ubatch\n", __func__);
+        LOG_WRN("%s: setting both to min(%d, %d) = %d to avoid configuration issues\n",
+                __func__, params.n_batch, params.n_ubatch,
+                std::min(params.n_batch, params.n_ubatch));
+        params.n_batch = params.n_ubatch = std::min(params.n_batch, params.n_ubatch);
+    }
+
     // TODO: should we have a separate n_parallel parameter for the server?
     // https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
     // TODO: this is a common configuration that is suitable for most local use cases
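
For review purposes, the auto-correction rule in this hunk can be isolated as a small pure function that is easy to unit test. The following is only a sketch: `batch_config` and `normalize_embedding_batch` are hypothetical names that stand in for the three `params` fields the patch reads, and are not part of llama.cpp.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical stand-in for the fields of `params` used by the patch.
struct batch_config {
    int32_t n_batch;    // logical batch size (-b)
    int32_t n_ubatch;   // physical micro-batch size (-ub)
    bool    embedding;  // true when --embedding is set
};

// Applies the same rule as the patch: when embeddings are enabled and the two
// sizes differ, clamp both to the smaller value.
// Returns true if the configuration was changed, so the caller can log a warning.
inline bool normalize_embedding_batch(batch_config & cfg) {
    if (!cfg.embedding || cfg.n_batch == cfg.n_ubatch) {
        return false;  // nothing to fix
    }
    const int32_t n = std::min(cfg.n_batch, cfg.n_ubatch);
    cfg.n_batch  = n;
    cfg.n_ubatch = n;
    return true;
}
```

With the defaults quoted in the commit message (n_batch=2048, n_ubatch=512) and embeddings enabled, this sketch sets both fields to 512, matching the first test case listed above; with embeddings disabled it leaves the values untouched.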
