13 | 13 |
14 | 14 |
15 | 15 | class Settings(BaseSettings): |
16 | | - model: str |
17 | | - n_ctx: int = 2048 |
18 | | - n_batch: int = 512 |
19 | | - n_threads: int = max((os.cpu_count() or 2) // 2, 1) |
20 | | - f16_kv: bool = True |
21 | | - use_mlock: bool = False  # Silently has no effect on platforms that don't support mlock (e.g. Windows). |
22 | | - use_mmap: bool = True |
23 | | - embedding: bool = True |
24 | | - last_n_tokens_size: int = 64 |
25 | | - logits_all: bool = False |
26 | | - cache: bool = False # WARNING: This is an experimental feature |
27 | | - vocab_only: bool = False |
| 16 | + model: str = Field( |
| 17 | + description="The path to the model to use for generating completions." |
| 18 | + ) |
| 19 | + n_ctx: int = Field(default=2048, ge=1, description="The context size.") |
| 20 | + n_batch: int = Field( |
| 21 | + default=512, ge=1, description="The batch size to use per eval." |
| 22 | + ) |
| 23 | + n_threads: int = Field( |
| 24 | + default=max((os.cpu_count() or 2) // 2, 1), |
| 25 | + ge=1, |
| 26 | + description="The number of threads to use.", |
| 27 | + ) |
| 28 | + f16_kv: bool = Field(default=True, description="Whether to use f16 key/value.") |
| 29 | + use_mlock: bool = Field( |
| 30 | + default=bool(llama_cpp.llama_mlock_supported().value), |
| 31 | + description="Use mlock.", |
| 32 | + ) |
| 33 | + use_mmap: bool = Field( |
| 34 | + default=bool(llama_cpp.llama_mmap_supported().value), |
| 35 | + description="Use mmap.", |
| 36 | + ) |
| 37 | + embedding: bool = Field(default=True, description="Whether to use embeddings.") |
| 38 | + last_n_tokens_size: int = Field( |
| 39 | + default=64, |
| 40 | + ge=0, |
| 41 | + description="Last n tokens to keep for repeat penalty calculation.", |
| 42 | + ) |
| 43 | + logits_all: bool = Field(default=True, description="Whether to return logits.") |
| 44 | + cache: bool = Field( |
| 45 | + default=False, |
| 46 | + description="Use a cache to reduce processing times for evaluated prompts.", |
| 47 | + ) |
| 48 | + vocab_only: bool = Field( |
| 49 | + default=False, description="Whether to only return the vocabulary." |
| 50 | + ) |
28 | 51 |
29 | 52 |
30 | 53 | router = APIRouter() |
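A quick sketch of how these Settings fields resolve in practice (an assumption based on pydantic v1 behavior, where BaseSettings is importable from pydantic and matches environment variables to field names case-insensitively); the class and values below are illustrative only, not part of this change:

import os
from pydantic import BaseSettings, Field

class DemoSettings(BaseSettings):
    # Mirrors two of the fields above purely for illustration.
    model: str = Field(description="The path to the model to use for generating completions.")
    n_ctx: int = Field(default=2048, ge=1, description="The context size.")

os.environ["MODEL"] = "./models/ggml-model.bin"  # satisfies the required `model` field
os.environ["N_CTX"] = "4096"                     # overrides the declared default of 2048

settings = DemoSettings()
print(settings.model, settings.n_ctx)  # ./models/ggml-model.bin 4096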
@@ -74,79 +97,75 @@ def get_llama(): |
74 | 97 | with llama_lock: |
75 | 98 | yield llama |
76 | 99 |
77 | | -model_field = Field( |
78 | | - description="The model to use for generating completions." |
79 | | -) |
| 100 | + |
| 101 | +model_field = Field(description="The model to use for generating completions.") |
80 | 102 |
81 | 103 | max_tokens_field = Field( |
82 | | - default=16, |
83 | | - ge=1, |
84 | | - le=2048, |
85 | | - description="The maximum number of tokens to generate." |
| 104 | + default=16, ge=1, le=2048, description="The maximum number of tokens to generate." |
86 | 105 | ) |
87 | 106 |
88 | 107 | temperature_field = Field( |
89 | 108 | default=0.8, |
90 | 109 | ge=0.0, |
91 | 110 | le=2.0, |
92 | | - description="Adjust the randomness of the generated text.\n\n" + |
93 | | - "Temperature is a hyperparameter that controls the randomness of the generated text. It affects the probability distribution of the model's output tokens. A higher temperature (e.g., 1.5) makes the output more random and creative, while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, and conservative. The default value is 0.8, which provides a balance between randomness and determinism. At the extreme, a temperature of 0 will always pick the most likely next token, leading to identical outputs in each run." |
| 111 | + description="Adjust the randomness of the generated text.\n\n" |
| 112 | + + "Temperature is a hyperparameter that controls the randomness of the generated text. It affects the probability distribution of the model's output tokens. A higher temperature (e.g., 1.5) makes the output more random and creative, while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, and conservative. The default value is 0.8, which provides a balance between randomness and determinism. At the extreme, a temperature of 0 will always pick the most likely next token, leading to identical outputs in each run.", |
94 | 113 | ) |
95 | 114 |
96 | 115 | top_p_field = Field( |
97 | 116 | default=0.95, |
98 | 117 | ge=0.0, |
99 | 118 | le=1.0, |
100 | | - description="Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P.\n\n" + |
101 | | - "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text." |
| 119 | + description="Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P.\n\n" |
| 120 | + + "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.", |
102 | 121 | ) |
103 | 122 |
104 | 123 | stop_field = Field( |
105 | 124 | default=None, |
106 | | - description="A list of tokens at which to stop generation. If None, no stop tokens are used." |
| 125 | + description="A list of tokens at which to stop generation. If None, no stop tokens are used.", |
107 | 126 | ) |
108 | 127 |
109 | 128 | stream_field = Field( |
110 | 129 | default=False, |
111 | | - description="Whether to stream the results as they are generated. Useful for chatbots." |
| 130 | + description="Whether to stream the results as they are generated. Useful for chatbots.", |
112 | 131 | ) |
113 | 132 |
114 | 133 | top_k_field = Field( |
115 | 134 | default=40, |
116 | 135 | ge=0, |
117 | | - description="Limit the next token selection to the K most probable tokens.\n\n" + |
118 | | - "Top-k sampling is a text generation method that selects the next token only from the top k most likely tokens predicted by the model. It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit the diversity of the output. A higher value for top_k (e.g., 100) will consider more tokens and lead to more diverse text, while a lower value (e.g., 10) will focus on the most probable tokens and generate more conservative text." |
| 136 | + description="Limit the next token selection to the K most probable tokens.\n\n" |
| 137 | + + "Top-k sampling is a text generation method that selects the next token only from the top k most likely tokens predicted by the model. It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit the diversity of the output. A higher value for top_k (e.g., 100) will consider more tokens and lead to more diverse text, while a lower value (e.g., 10) will focus on the most probable tokens and generate more conservative text.", |
119 | 138 | ) |
120 | 139 |
121 | 140 | repeat_penalty_field = Field( |
122 | 141 | default=1.0, |
123 | 142 | ge=0.0, |
124 | | - description="A penalty applied to each token that is already generated. This helps prevent the model from repeating itself.\n\n" + |
125 | | - "Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient." |
| 143 | + description="A penalty applied to each token that is already generated. This helps prevent the model from repeating itself.\n\n" |
| 144 | + + "Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient.", |
126 | 145 | ) |
127 | 146 |
| 147 | + |
128 | 148 | class CreateCompletionRequest(BaseModel): |
129 | 149 | prompt: Optional[str] = Field( |
130 | | - default="", |
131 | | - description="The prompt to generate completions for." |
| 150 | + default="", description="The prompt to generate completions for." |
132 | 151 | ) |
133 | 152 | suffix: Optional[str] = Field( |
134 | 153 | default=None, |
135 | | - description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots." |
| 154 | + description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.", |
136 | 155 | ) |
137 | 156 | max_tokens: int = max_tokens_field |
138 | 157 | temperature: float = temperature_field |
139 | 158 | top_p: float = top_p_field |
140 | 159 | echo: bool = Field( |
141 | 160 | default=False, |
142 | | - description="Whether to echo the prompt in the generated text. Useful for chatbots." |
| 161 | + description="Whether to echo the prompt in the generated text. Useful for chatbots.", |
143 | 162 | ) |
144 | 163 | stop: Optional[List[str]] = stop_field |
145 | 164 | stream: bool = stream_field |
146 | 165 | logprobs: Optional[int] = Field( |
147 | 166 | default=None, |
148 | 167 | ge=0, |
149 | | - description="The number of logprobs to generate. If None, no logprobs are generated." |
| 168 | + description="The number of logprobs to generate. If None, no logprobs are generated.", |
150 | 169 | ) |
151 | 170 |
152 | 171 | # ignored or currently unsupported |
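Since the shared Field objects above carry ge/le bounds, pydantic enforces those ranges whenever a request model is validated; a minimal illustration (the model name below is made up, not part of this diff):

from pydantic import BaseModel, Field, ValidationError

temperature_demo = Field(default=0.8, ge=0.0, le=2.0)

class DemoCompletionRequest(BaseModel):
    temperature: float = temperature_demo

print(DemoCompletionRequest().temperature)   # 0.8, the shared default
try:
    DemoCompletionRequest(temperature=3.5)   # outside le=2.0
except ValidationError as err:
    print(err)  # "ensure this value is less than or equal to 2"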
@@ -204,9 +223,7 @@ def create_completion( |
204 | 223 |
205 | 224 | class CreateEmbeddingRequest(BaseModel): |
206 | 225 | model: Optional[str] = model_field |
207 | | - input: str = Field( |
208 | | - description="The input to embed." |
209 | | - ) |
| 226 | + input: str = Field(description="The input to embed.") |
210 | 227 | user: Optional[str] |
211 | 228 |
212 | 229 | class Config: |
@@ -239,8 +256,7 @@ class ChatCompletionRequestMessage(BaseModel): |
239 | 256 |
240 | 257 | class CreateChatCompletionRequest(BaseModel): |
241 | 258 | messages: List[ChatCompletionRequestMessage] = Field( |
242 | | - default=[], |
243 | | - description="A list of messages to generate completions for." |
| 259 | + default=[], description="A list of messages to generate completions for." |
244 | 260 | ) |
245 | 261 | max_tokens: int = max_tokens_field |
246 | 262 | temperature: float = temperature_field |
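For context, a request body built from these fields might look like the sketch below; the /v1/chat/completions route, host, and port are assumptions, since the route definitions fall outside this diff:

import json
import urllib.request

payload = {
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
    "max_tokens": 16,      # bounded to [1, 2048] by max_tokens_field
    "temperature": 0.8,    # defaults taken from the shared Field objects
    "top_p": 0.95,
    "stream": False,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/chat/completions",  # assumed host/port/route
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["choices"][0]["message"]["content"])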