1 | | -import anyio |
| 1 | +from typing import Optional, AsyncGenerator |
| 2 | + |
2 | 3 | import httpx |
3 | 4 | import orjson |
4 | 5 |
5 | | -API_URL = "http://localhost:8000/chat/" |
6 | | - |
7 | 6 |
8 | | -async def chat_with_endpoint(): |
9 | | - async with httpx.AsyncClient() as client: |
10 | | - while True: |
11 | | - prompt = input("\nYou: ") |
12 | | - if prompt.lower() == "exit": |
13 | | - break |
| 7 | +class StreamLLMService: |
| 8 | + def __init__(self, base_url: str = "http://localhost:11434/v1"): |
| 9 | + self.base_url = base_url |
| 10 | + self.model = "llama3.2" |
14 | 11 |
15 | | - print("\nModel: ", end="", flush=True) |
16 | | - try: |
17 | | - async with client.stream( |
18 | | - "POST", API_URL, data={"prompt": prompt}, timeout=60 |
19 | | - ) as response: |
20 | | - async for chunk in response.aiter_lines(): |
21 | | - if not chunk: |
22 | | - continue |
| 12 | + async def stream_chat(self, prompt: str) -> AsyncGenerator[bytes, None]: |
| 13 | + """Stream chat completion responses from LLM.""" |
| 14 | + # Send user message first |
| 15 | + user_msg = { |
| 16 | + "role": "user", |
| 17 | + "content": prompt, |
| 18 | + } |
| 19 | + yield orjson.dumps(user_msg) + b"\n" |
23 | 20 |
| 21 | + # Open client as context manager and stream responses |
| 22 | + async with httpx.AsyncClient(base_url=self.base_url) as client: |
| 23 | + async with client.stream( |
| 24 | + "POST", |
| 25 | + "/chat/completions", |
| 26 | + json={ |
| 27 | + "model": self.model, |
| 28 | + "messages": [{"role": "user", "content": prompt}], |
| 29 | + "stream": True, |
| 30 | + }, |
| 31 | + timeout=60.0, |
| 32 | + ) as response: |
| 33 | + async for line in response.aiter_lines(): |
| 34 | +                    # Each SSE line arrives as "data: {...}" or "data: [DONE]" |
| 35 | + if line.startswith("data: ") and line != "data: [DONE]": |
24 | 36 | try: |
25 | | - print(orjson.loads(chunk)["content"], end="", flush=True) |
26 | | - except Exception as e: |
27 | | - print(f"\nError parsing chunk: {e}") |
28 | | - except httpx.RequestError as e: |
29 | | - print(f"\nConnection error: {e}") |
| 37 | + json_line = line[6:] # Remove "data: " prefix |
| 38 | + data = orjson.loads(json_line) |
| 39 | + content = ( |
| 40 | + data.get("choices", [{}])[0] |
| 41 | + .get("delta", {}) |
| 42 | + .get("content", "") |
| 43 | + ) |
| 44 | + if content: |
| 45 | + model_msg = {"role": "model", "content": content} |
| 46 | + yield orjson.dumps(model_msg) + b"\n" |
| 47 | +                        except (orjson.JSONDecodeError, IndexError): |
| 48 | +                            continue  # skip malformed or empty chunks |
30 | 49 |
31 | 50 |
32 | | -if __name__ == "__main__": |
33 | | - anyio.run(chat_with_endpoint) |
| 51 | +# FastAPI dependency |
| 52 | +def get_llm_service(base_url: Optional[str] = None) -> StreamLLMService: |
| 53 | + return StreamLLMService(base_url=base_url or "http://localhost:11434/v1") |
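
The new `get_llm_service` dependency is only defined in this commit; the route that consumes it is not shown. As a minimal sketch (assuming a FastAPI app and a `/chat/` endpoint, both hypothetical), the generator can be handed straight to a `StreamingResponse`, since each yielded item is already a newline-terminated JSON document:

```python
# Hypothetical FastAPI route wired to the service above; the app instance,
# route path, and query-style `prompt` parameter are assumptions, not part
# of this commit.
from fastapi import Depends, FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.post("/chat/")
async def chat(
    prompt: str,
    llm: StreamLLMService = Depends(get_llm_service),
) -> StreamingResponse:
    # stream_chat() yields bytes ending in b"\n", so the body can be
    # forwarded as-is as newline-delimited JSON.
    return StreamingResponse(
        llm.stream_chat(prompt), media_type="application/x-ndjson"
    )
```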