-from typing import Optional, AsyncGenerator
-
+import anyio
 import httpx
 import orjson
 
+async def chat_with_endpoint():
+    async with httpx.AsyncClient() as client:
+        while True:
+            # Get user input
+            prompt = input("\nYou: ")
+            if prompt.lower() == "exit":
+                break
 
-class StreamLLMService:
-    def __init__(self, base_url: str = "http://localhost:11434/v1"):
-        self.base_url = base_url
-        self.model = "llama3.2"
-
-    async def stream_chat(self, prompt: str) -> AsyncGenerator[bytes, None]:
-        """Stream chat completion responses from LLM."""
-        # Send user message first
-        user_msg = {
-            "role": "user",
-            "content": prompt,
-        }
-        yield orjson.dumps(user_msg) + b"\n"
-
-        # Open client as context manager and stream responses
-        async with httpx.AsyncClient(base_url=self.base_url) as client:
+            # Send request to the API
+            print("\nModel: ", end="", flush=True)
             async with client.stream(
                 "POST",
-                "/chat/completions",
-                json={
-                    "model": self.model,
-                    "messages": [{"role": "user", "content": prompt}],
-                    "stream": True,
-                },
-                timeout=60.0,
+                "http://localhost:8000/chat/",
+                data={"prompt": prompt},
+                timeout=60
             ) as response:
-                async for line in response.aiter_lines():
-                    print(line)
-                    if line.startswith("data: ") and line != "data: [DONE]":
+                async for chunk in response.aiter_lines():
+                    if chunk:
                         try:
-                            json_line = line[6:]  # Remove "data: " prefix
-                            data = orjson.loads(json_line)
-                            content = (
-                                data.get("choices", [{}])[0]
-                                .get("delta", {})
-                                .get("content", "")
-                            )
-                            if content:
-                                model_msg = {"role": "model", "content": content}
-                                yield orjson.dumps(model_msg) + b"\n"
-                        except Exception:
-                            pass
-
+                            data = orjson.loads(chunk)
+                            print(data["content"], end="", flush=True)
+                        except Exception as e:
+                            print(f"\nError parsing chunk: {e}")
 
-# FastAPI dependency
-def get_llm_service(base_url: Optional[str] = None) -> StreamLLMService:
-    return StreamLLMService(base_url=base_url or "http://localhost:11434/v1")
+if __name__ == "__main__":
+    anyio.run(chat_with_endpoint)
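The new client posts a form field named "prompt" to http://localhost:8000/chat/ and expects the response body to be newline-delimited JSON objects carrying a "content" key, which it parses with orjson line by line. The server side is not part of this commit; below is a minimal sketch of an endpoint compatible with that client, assuming FastAPI with a Form parameter and a StreamingResponse. The app object, the fake_stream placeholder generator, and the application/x-ndjson media type are illustrative assumptions, not code from this diff.

# sketch of a compatible server endpoint (assumed, not part of this commit)
from typing import Annotated

import orjson
from fastapi import FastAPI, Form
from fastapi.responses import StreamingResponse

app = FastAPI()


async def fake_stream(prompt: str):
    # Placeholder generator: yields newline-delimited JSON chunks with a
    # "content" key, which is what the client's aiter_lines() loop expects.
    for word in f"echo: {prompt}".split():
        yield orjson.dumps({"role": "model", "content": word + " "}) + b"\n"


@app.post("/chat/")
async def chat(prompt: Annotated[str, Form()]):
    # The client posts form data, so the field is read with Form();
    # the streamed body is consumed line by line on the client side.
    return StreamingResponse(fake_stream(prompt), media_type="application/x-ndjson")

With an endpoint like this running (for example via uvicorn), the client above connects to http://localhost:8000/chat/ and prints each streamed "content" fragment as it arrives.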