From 7cb4e8ec3130ad72dd3a0ec7ac56180562c3e4ff Mon Sep 17 00:00:00 2001 From: Marat Saidov Date: Sun, 25 Jan 2026 14:07:35 +0100 Subject: [PATCH 1/2] added a simple script to write docstrings locally based on ollama --- scripts/run_ollama.py | 153 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 scripts/run_ollama.py diff --git a/scripts/run_ollama.py b/scripts/run_ollama.py new file mode 100644 index 0000000..85f121a --- /dev/null +++ b/scripts/run_ollama.py @@ -0,0 +1,153 @@ +import argparse +import requests +import sys +import json + + +SYSTEM_PROMPT = r""" +You are a Python documentation expert. Your task is to generate a docstring for the provided Python function following the NumPy documentation format strictly. + +\#\# Output Rules +- Output ONLY the docstring content (including the triple quotes) +- Do NOT include the function signature or body +- Do NOT add any explanation before or after the docstring + +\#\# NumPy Docstring Format + +\#\#\# Structure (include sections only when applicable) +\"\"\" +Short one-line summary (imperative mood, e.g., "Compute", "Return", "Parse"). + +Extended summary providing more details about the function behavior, +algorithm, or implementation notes. Optional but recommended for +complex functions. + +Parameters +---------- +param_name : type + Description of the parameter. If the description spans multiple + lines, indent continuation lines. +param_name : type, optional + For optional parameters, specify default value in description. + Default is `default_value`. +*args : type + Description of variable positional arguments. +**kwargs : type + Description of variable keyword arguments. + +Returns +------- +type + Description of return value. +name : type + Use this format when returning named values or multiple values. + +Yields +------ +type + For generator functions, describe yielded values. + +Raises +------ +ExceptionType + Explanation of when this exception is raised. + +See Also +-------- +related_function : Brief description of relation. + +Notes +----- +Additional technical notes, mathematical formulas (using LaTeX), +or implementation details. + +Examples +-------- +>>> function_name(arg1, arg2) +expected_output +\"\"\" + +\#\#\# Type Annotation Conventions +- Basic types: `int`, `float`, `str`, `bool`, `None` +- Collections: `list of int`, `dict of {str: int}`, `tuple of (int, str)` +- Multiple types: `int or float`, `str or None` +- Array-like: `array_like`, `numpy.ndarray of shape (n, m)` +- Callable: `callable` +- Optional params: append `, optional` after type + +\#\#\# Guidelines +1. First line: concise, imperative verb, no variable names, ends with period +2. Leave one blank line after the summary before Parameters +3. Align parameter descriptions consistently +4. Include realistic, runnable Examples when behavior isn't obvious +5. Document all exceptions that may be explicitly raised +6. 
For boolean params, describe what True/False means
+"""
+
+
+DEFAULT_URL = "http://localhost:11434/api/chat"  # Changed from /api/generate
+
+
+def build_payload(model, system_msgs, user_msgs, stream):
+    messages = []
+    for s in system_msgs:
+        messages.append({"role": "system", "content": s})
+    for u in user_msgs:
+        messages.append({"role": "user", "content": u})
+
+    return {
+        "model": model,
+        "messages": messages,
+        "stream": stream,
+        "keep_alive": 0  # Unload model after request
+    }
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Send system and user prompts to model endpoint")
+    parser.add_argument("--url", default=DEFAULT_URL, help="API endpoint URL")
+    parser.add_argument("--model", required=True, help="Model name, e.g. qwen2.5-coder:32b")
+    parser.add_argument("--system", action="append", default=None, help="System prompt (repeatable). Overrides default.")
+    parser.add_argument("--user", action="append", default=[], help="User prompt (repeatable)")
+    parser.add_argument("--stream", action="store_true", help="Enable streaming mode")
+    parser.add_argument("--timeout", type=float, default=120.0, help="Request timeout in seconds")
+
+    args = parser.parse_args()
+
+    # Use default system prompt if none provided
+    system_msgs = args.system if args.system else [SYSTEM_PROMPT]
+
+    if not args.user:
+        print("Error: at least one --user prompt is required.", file=sys.stderr)
+        sys.exit(2)
+
+    payload = build_payload(args.model, system_msgs, args.user, args.stream)
+
+    try:
+        resp = requests.post(args.url, json=payload, timeout=args.timeout)
+        resp.raise_for_status()
+    except requests.RequestException as e:
+        print(f"Request failed: {e}", file=sys.stderr)
+        sys.exit(1)
+
+    try:
+        data = resp.json()
+    except ValueError:
+        print("Response is not valid JSON", file=sys.stderr)
+        print(resp.text, file=sys.stderr)
+        sys.exit(1)
+
+    # Handle /api/chat response format
+    if "message" in data:
+        print(data["message"].get("content", ""))
+    elif "response" in data:
+        print(data["response"])
+    elif "choices" in data and isinstance(data["choices"], list):
+        for c in data["choices"]:
+            print(c.get("message", {}).get("content", c.get("text", "")))
+    else:
+        print(json.dumps(data, indent=2))
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file

From 6c8901947d59dcf1f6f1ec5087400f162434a343 Mon Sep 17 00:00:00 2001
From: Marat Saidov
Date: Sun, 25 Jan 2026 16:19:43 +0100
Subject: [PATCH 2/2] filled serve.py with FastAPI-based endpoints, using
 ollama as the backend

---
 .gitignore                            |   1 +
 README.md                             | 110 ++++++++-
 pyproject.toml                        |   2 +
 scripts/run_ollama.py                 | 107 +++------
 src/training/prompts/system_prompt.md |  77 +++++++
 src/training/serve.py                 | 120 ++++++++--
 tests/test_serve.py                   | 312 ++++++++++++++++++++++++++
 7 files changed, 634 insertions(+), 95 deletions(-)
 create mode 100644 src/training/prompts/system_prompt.md
 create mode 100644 tests/test_serve.py

diff --git a/.gitignore b/.gitignore
index ad3c25b..d9c1a25 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,6 +23,7 @@ build/
 
 # IDE
 .vscode/
+.pytest_cache/
 
 # Jupyter
 notebooks/.ipynb_checkpoints/
diff --git a/README.md b/README.md
index b94317a..4d607f4 100644
--- a/README.md
+++ b/README.md
@@ -53,8 +53,9 @@ Test Dataset + Model Predictions --> [benchmark.py] --> Metrics Report
 
 - **`train_lora.py`** - LoRA fine-tuning using HuggingFace Trainer + PEFT.
   Supports QLoRA (4-bit quantization) for training on 1-2 A100 GPUs.
-- **`serve.py`** - FastAPI inference server that loads the fine-tuned model and
-  serves docstring generation via HTTP.
+- **`serve.py`** - FastAPI inference server that generates docstrings through
+  the ollama API, using the NumPy-style system prompt stored in
+  `src/training/prompts/system_prompt.md`.
 
 ### Evaluation (`src/evaluation/`)
 
@@ -87,6 +88,111 @@ python -m src.data.convert_seed \
   --output-dir data/processed/python-method
 ```
 
+## Serving
+
+The FastAPI inference server provides HTTP endpoints for docstring generation using
+ollama as the backend. The server uses a system prompt stored in
+`src/training/prompts/system_prompt.md` to generate NumPy-style docstrings.
+
+### Prerequisites
+
+1. **Install ollama**: Make sure [ollama](https://ollama.ai/) is installed and running locally
+2. **Pull a model**: Download a code model (e.g., `qwen2.5-coder:32b`):
+   ```bash
+   ollama pull qwen2.5-coder:32b
+   ```
+
+### Starting the Server
+
+Start the FastAPI server using uvicorn:
+
+```bash
+# Using uvicorn directly
+uvicorn src.training.serve:app --host 0.0.0.0 --port 8000
+
+# Or run the module directly
+python -m src.training.serve
+```
+
+The server will start on `http://localhost:8000` by default.
+
+### Configuration
+
+The server can be configured using environment variables:
+
+- `OLLAMA_URL` - Ollama API endpoint (default: `http://localhost:11434/api/chat`)
+- `OLLAMA_MODEL` - Model name to use (default: `qwen2.5-coder:32b`)
+- `REQUEST_TIMEOUT` - Request timeout in seconds (default: `120.0`)
+
+Example:
+```bash
+OLLAMA_MODEL=qwen2.5-coder:7b uvicorn src.training.serve:app --port 8000
+```
+
+### API Endpoints
+
+#### Health Check
+
+Check if the service is healthy and ollama is accessible:
+
+```bash
+curl http://localhost:8000/health
+```
+
+**Response (200 OK):**
+```json
+{
+  "status": "healthy",
+  "service": "ollama"
+}
+```
+
+**Response (503 Service Unavailable):**
+```json
+{
+  "detail": "Service unhealthy: ollama is not running or not accessible"
+}
+```
+
+#### Generate Docstring
+
+Generate a docstring for a Python function:
+
+```bash
+curl -X POST http://localhost:8000/generate \
+  -H "Content-Type: application/json" \
+  -d '{
+    "code": "def add(x, y):\n    return x + y",
+    "max_new_tokens": 256
+  }'
+```
+
+**Request Body:**
+- `code` (required): Python function code as a string
+- `max_new_tokens` (optional): Maximum number of tokens to generate (default: 256)
+
+**Response (200 OK):**
+```json
+{
+  "docstring": "\"\"\"Compute the sum of two numbers.\n\nParameters\n----------\nx : int\n    First number.\ny : int\n    Second number.\n\nReturns\n-------\nint\n    Sum of x and y.\n\"\"\""
+}
+```
+
+**Response (500 Internal Server Error):**
+```json
+{
+  "detail": "Failed to generate docstring: <error message>"
+}
+```
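+
+#### Python Client Example
+
+The same request can be sent from Python with `requests` (already a project
+dependency). A minimal client sketch, assuming the server is running locally
+on the default port 8000:
+
+```python
+import requests
+
+resp = requests.post(
+    "http://localhost:8000/generate",
+    json={"code": "def add(x, y):\n    return x + y", "max_new_tokens": 256},
+    timeout=120,
+)
+resp.raise_for_status()
+print(resp.json()["docstring"])
+```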
+
+### Testing
+
+Run the test suite to verify the API endpoints:
+
+```bash
+pytest tests/test_serve.py -v
+```
+
 ## Dataset
 
 The seed dataset comes from the [NeuralCodeSum](https://github.com/wasiahmad/NeuralCodeSum)
diff --git a/pyproject.toml b/pyproject.toml
index 50756b8..4778490 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,12 +24,14 @@ dependencies = [
     "safetensors",
     "fastapi>=0.104.0",
     "uvicorn>=0.24.0",
+    "requests>=2.31.0",
 ]
 
 [project.optional-dependencies]
 dev = [
     "pytest>=7.0",
     "ruff>=0.1.0",
+    "httpx>=0.24.0",
 ]
 
 [tool.hatch.build.targets.wheel]
diff --git a/scripts/run_ollama.py b/scripts/run_ollama.py
index 85f121a..bef6323 100644
--- a/scripts/run_ollama.py
+++ b/scripts/run_ollama.py
@@ -2,87 +2,19 @@
 import requests
 import sys
 import json
+import time
+from pathlib import Path
 
 
-SYSTEM_PROMPT = r"""
-You are a Python documentation expert. Your task is to generate a docstring for the provided Python function following the NumPy documentation format strictly.
-
-\#\# Output Rules
-- Output ONLY the docstring content (including the triple quotes)
-- Do NOT include the function signature or body
-- Do NOT add any explanation before or after the docstring
-
-\#\# NumPy Docstring Format
-
-\#\#\# Structure (include sections only when applicable)
-\"\"\"
-Short one-line summary (imperative mood, e.g., "Compute", "Return", "Parse").
-
-Extended summary providing more details about the function behavior,
-algorithm, or implementation notes. Optional but recommended for
-complex functions.
-
-Parameters
-----------
-param_name : type
-    Description of the parameter. If the description spans multiple
-    lines, indent continuation lines.
-param_name : type, optional
-    For optional parameters, specify default value in description.
-    Default is `default_value`.
-*args : type
-    Description of variable positional arguments.
-**kwargs : type
-    Description of variable keyword arguments.
-
-Returns
--------
-type
-    Description of return value.
-name : type
-    Use this format when returning named values or multiple values.
-
-Yields
-------
-type
-    For generator functions, describe yielded values.
-
-Raises
-------
-ExceptionType
-    Explanation of when this exception is raised.
-
-See Also
---------
-related_function : Brief description of relation.
-
-Notes
------
-Additional technical notes, mathematical formulas (using LaTeX),
-or implementation details.
-
-Examples
---------
->>> function_name(arg1, arg2)
-expected_output
-\"\"\"
-
-\#\#\# Type Annotation Conventions
-- Basic types: `int`, `float`, `str`, `bool`, `None`
-- Collections: `list of int`, `dict of {str: int}`, `tuple of (int, str)`
-- Multiple types: `int or float`, `str or None`
-- Array-like: `array_like`, `numpy.ndarray of shape (n, m)`
-- Callable: `callable`
-- Optional params: append `, optional` after type
-
-\#\#\# Guidelines
-1. First line: concise, imperative verb, no variable names, ends with period
-2. Leave one blank line after the summary before Parameters
-3. Align parameter descriptions consistently
-4. Include realistic, runnable Examples when behavior isn't obvious
-5. Document all exceptions that may be explicitly raised
-6. For boolean params, describe what True/False means
-"""
+def load_system_prompt() -> str:
+    """Load the default system prompt from the prompts directory."""
+    prompt_path = Path(__file__).parent.parent / "src" / "training" / "prompts" / "system_prompt.md"
+    if not prompt_path.exists():
+        raise FileNotFoundError(
+            f"System prompt file not found: {prompt_path}. "
+            "Please ensure the prompt file exists."
+        )
+    return prompt_path.read_text(encoding="utf-8")
 
 
 DEFAULT_URL = "http://localhost:11434/api/chat"  # Changed from /api/generate
@@ -115,7 +47,14 @@ def main():
     args = parser.parse_args()
 
     # Use default system prompt if none provided
-    system_msgs = args.system if args.system else [SYSTEM_PROMPT]
+    if args.system:
+        system_msgs = args.system
+    else:
+        try:
+            system_msgs = [load_system_prompt()]
+        except FileNotFoundError as e:
+            print(f"Error: {e}", file=sys.stderr)
+            sys.exit(1)
 
     if not args.user:
         print("Error: at least one --user prompt is required.", file=sys.stderr)
@@ -123,11 +62,15 @@ def main():
 
     payload = build_payload(args.model, system_msgs, args.user, args.stream)
 
+    # Track execution time
+    start_time = time.time()
     try:
         resp = requests.post(args.url, json=payload, timeout=args.timeout)
         resp.raise_for_status()
     except requests.RequestException as e:
+        elapsed_time = time.time() - start_time
        print(f"Request failed: {e}", file=sys.stderr)
+        print(f"Execution time: {elapsed_time:.2f}s", file=sys.stderr)
         sys.exit(1)
 
     try:
@@ -148,6 +91,10 @@ def main():
     else:
         print(json.dumps(data, indent=2))
 
+    # Print execution time to stderr so stdout carries only the model output
+    elapsed_time = time.time() - start_time
+    print(f"Execution time: {elapsed_time:.2f}s", file=sys.stderr)
+
 
 if __name__ == "__main__":
     main()
\ No newline at end of file
diff --git a/src/training/prompts/system_prompt.md b/src/training/prompts/system_prompt.md
new file mode 100644
index 0000000..770e1da
--- /dev/null
+++ b/src/training/prompts/system_prompt.md
@@ -0,0 +1,77 @@
+You are a Python documentation expert. Your task is to generate a docstring for the provided Python function following the NumPy documentation format strictly.
+
+## Output Rules
+- Output ONLY the docstring content (including the triple quotes)
+- Do NOT include the function signature or body
+- Do NOT add any explanation before or after the docstring
+
+## NumPy Docstring Format
+
+### Structure (include sections only when applicable)
+"""
+Short one-line summary (imperative mood, e.g., "Compute", "Return", "Parse").
+
+Extended summary providing more details about the function behavior,
+algorithm, or implementation notes. Optional but recommended for
+complex functions.
+
+Parameters
+----------
+param_name : type
+    Description of the parameter. If the description spans multiple
+    lines, indent continuation lines.
+param_name : type, optional
+    For optional parameters, specify default value in description.
+    Default is `default_value`.
+*args : type
+    Description of variable positional arguments.
+**kwargs : type
+    Description of variable keyword arguments.
+
+Returns
+-------
+type
+    Description of return value.
+name : type
+    Use this format when returning named values or multiple values.
+
+Yields
+------
+type
+    For generator functions, describe yielded values.
+
+Raises
+------
+ExceptionType
+    Explanation of when this exception is raised.
+
+See Also
+--------
+related_function : Brief description of relation.
+
+Notes
+-----
+Additional technical notes, mathematical formulas (using LaTeX),
+or implementation details.
+
+Examples
+--------
+>>> function_name(arg1, arg2)
+expected_output
+"""
+
+### Type Annotation Conventions
+- Basic types: `int`, `float`, `str`, `bool`, `None`
+- Collections: `list of int`, `dict of {str: int}`, `tuple of (int, str)`
+- Multiple types: `int or float`, `str or None`
+- Array-like: `array_like`, `numpy.ndarray of shape (n, m)`
+- Callable: `callable`
+- Optional params: append `, optional` after type
+
+### Guidelines
+1. 
First line: concise, imperative verb, no variable names, ends with period +2. Leave one blank line after the summary before Parameters +3. Align parameter descriptions consistently +4. Include realistic, runnable Examples when behavior isn't obvious +5. Document all exceptions that may be explicitly raised +6. For boolean params, describe what True/False means diff --git a/src/training/serve.py b/src/training/serve.py index 619ca5c..ea5f9e8 100644 --- a/src/training/serve.py +++ b/src/training/serve.py @@ -11,30 +11,124 @@ GET /health - Health check """ +import os +from pathlib import Path +from typing import Optional -def load_model(base_model: str, adapter_path: str): - """Load base model + LoRA adapter for inference.""" - raise NotImplementedError +import requests +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel + +# Configuration +OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434/api/chat") +OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "qwen2.5-coder:32b") +REQUEST_TIMEOUT = float(os.getenv("REQUEST_TIMEOUT", "120.0")) + + +def load_system_prompt() -> str: + """Load the default system prompt from the prompts directory.""" + prompt_path = Path(__file__).parent / "prompts" / "system_prompt.md" + if not prompt_path.exists(): + raise FileNotFoundError( + f"System prompt file not found: {prompt_path}. " + "Please ensure the prompt file exists." + ) + return prompt_path.read_text(encoding="utf-8") + + +# Load system prompt at module level +SYSTEM_PROMPT = load_system_prompt() def generate_docstring(code: str, max_new_tokens: int = 256) -> str: + """Generate a docstring for the given code snippet using ollama API.""" + payload = { + "model": OLLAMA_MODEL, + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": code} + ], + "stream": False, + "keep_alive": 0, + "options": { + "num_predict": max_new_tokens + } + } + + try: + resp = requests.post(OLLAMA_URL, json=payload, timeout=REQUEST_TIMEOUT) + resp.raise_for_status() + data = resp.json() + + # Handle /api/chat response format + if "message" in data: + return data["message"].get("content", "") + elif "response" in data: + return data["response"] + elif "choices" in data and isinstance(data["choices"], list): + content = "" + for c in data["choices"]: + content += c.get("message", {}).get("content", c.get("text", "")) + return content + else: + raise ValueError(f"Unexpected response format: {data}") + except requests.RequestException as e: + raise RuntimeError(f"Failed to generate docstring: {e}") from e + + +def check_ollama_health() -> bool: + """Check if ollama is running locally by making a test request.""" + try: + # Try to list models as a health check + health_url = OLLAMA_URL.replace("/api/chat", "/api/tags") + resp = requests.get(health_url, timeout=5.0) + return resp.status_code == 200 + except requests.RequestException: + return False + + +# FastAPI app +app = FastAPI(title="Docstring Generation API", version="0.1.0") + + +class GenerateRequest(BaseModel): + """Request model for docstring generation.""" + code: str + max_new_tokens: Optional[int] = 256 + + +class GenerateResponse(BaseModel): + """Response model for docstring generation.""" + docstring: str + + +@app.post("/generate", response_model=GenerateResponse) +async def generate(request: GenerateRequest): """Generate a docstring for the given code snippet.""" - raise NotImplementedError + try: + docstring = generate_docstring(request.code, request.max_new_tokens) + return 
GenerateResponse(docstring=docstring) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) -# FastAPI app will be defined here once dependencies are implemented. -# app = FastAPI() -# -# @app.post("/generate") -# async def generate(request: dict): ... -# -# @app.get("/health") -# async def health(): ... +@app.get("/health") +async def health(): + """Health check endpoint that verifies ollama is running.""" + is_healthy = check_ollama_health() + if is_healthy: + return {"status": "healthy", "service": "ollama"} + else: + raise HTTPException( + status_code=503, + detail="Service unhealthy: ollama is not running or not accessible" + ) def main(): """Start the inference server.""" - raise NotImplementedError + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000) if __name__ == "__main__": diff --git a/tests/test_serve.py b/tests/test_serve.py new file mode 100644 index 0000000..922211d --- /dev/null +++ b/tests/test_serve.py @@ -0,0 +1,312 @@ +"""Tests for src.training.serve module.""" + +import json +from unittest.mock import Mock, patch + +import pytest +from fastapi.testclient import TestClient + +from src.training.serve import app, check_ollama_health, generate_docstring + + +@pytest.fixture +def client(): + """Create a test client for the FastAPI app.""" + return TestClient(app) + + +class TestHealthEndpoint: + """Tests for the /health endpoint.""" + + @patch("src.training.serve.requests.get") + def test_health_success(self, mock_get, client): + """Health check should return 200 when ollama is running.""" + # Mock successful response from ollama + mock_response = Mock() + mock_response.status_code = 200 + mock_get.return_value = mock_response + + response = client.get("/health") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert data["service"] == "ollama" + mock_get.assert_called_once() + + @patch("src.training.serve.requests.get") + def test_health_failure_connection_error(self, mock_get, client): + """Health check should return 503 when ollama is not accessible.""" + # Mock connection error + import requests + mock_get.side_effect = requests.RequestException("Connection refused") + + response = client.get("/health") + + assert response.status_code == 503 + data = response.json() + assert "unhealthy" in data["detail"].lower() + assert "ollama" in data["detail"].lower() + + @patch("src.training.serve.requests.get") + def test_health_failure_non_200_status(self, mock_get, client): + """Health check should return 503 when ollama returns non-200 status.""" + # Mock non-200 response + mock_response = Mock() + mock_response.status_code = 500 + mock_get.return_value = mock_response + + response = client.get("/health") + + assert response.status_code == 503 + data = response.json() + assert "unhealthy" in data["detail"].lower() + + @patch("src.training.serve.requests.get") + def test_health_failure_timeout(self, mock_get, client): + """Health check should return 503 when ollama request times out.""" + import requests + mock_get.side_effect = requests.Timeout("Request timed out") + + response = client.get("/health") + + assert response.status_code == 503 + + +class TestGenerateEndpoint: + """Tests for the /generate endpoint.""" + + @patch("src.training.serve.requests.post") + def test_generate_success_message_format(self, mock_post, client): + """Generate should return docstring when ollama responds with message format.""" + # Mock successful ollama response with message format + mock_response = Mock() + 
mock_response.status_code = 200 + mock_response.json.return_value = { + "message": { + "content": '"""Compute the sum of two numbers.\n\nParameters\n----------\nx : int\n First number.\ny : int\n Second number.\n\nReturns\n-------\nint\n Sum of x and y.\n"""' + } + } + mock_post.return_value = mock_response + + request_data = { + "code": "def add(x, y):\n return x + y", + "max_new_tokens": 256 + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert "docstring" in data + assert "Compute the sum" in data["docstring"] + mock_post.assert_called_once() + + @patch("src.training.serve.requests.post") + def test_generate_success_response_format(self, mock_post, client): + """Generate should return docstring when ollama responds with response format.""" + # Mock successful ollama response with response format + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "response": '"""Return the product of two numbers.\n\nParameters\n----------\na : float\n First number.\nb : float\n Second number.\n\nReturns\n-------\nfloat\n Product of a and b.\n"""' + } + mock_post.return_value = mock_response + + request_data = { + "code": "def multiply(a, b):\n return a * b" + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert "docstring" in data + assert "Return the product" in data["docstring"] + + @patch("src.training.serve.requests.post") + def test_generate_success_choices_format(self, mock_post, client): + """Generate should return docstring when ollama responds with choices format.""" + # Mock successful ollama response with choices format + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [ + { + "message": { + "content": '"""Calculate the difference.\n\nParameters\n----------\nx : int\n First number.\ny : int\n Second number.\n\nReturns\n-------\nint\n Difference of x and y.\n"""' + } + } + ] + } + mock_post.return_value = mock_response + + request_data = { + "code": "def subtract(x, y):\n return x - y", + "max_new_tokens": 128 + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert "docstring" in data + assert "Calculate the difference" in data["docstring"] + + @patch("src.training.serve.requests.post") + def test_generate_default_max_new_tokens(self, mock_post, client): + """Generate should use default max_new_tokens when not provided.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "message": {"content": '"""Test docstring."""'} + } + mock_post.return_value = mock_response + + request_data = {"code": "def test(): pass"} + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + # Verify that max_new_tokens was included in the payload + call_args = mock_post.call_args + assert call_args is not None + payload = call_args[1]["json"] + assert "options" in payload + assert payload["options"]["num_predict"] == 256 + + @patch("src.training.serve.requests.post") + def test_generate_failure_connection_error(self, mock_post, client): + """Generate should return 500 when ollama connection fails.""" + import requests + mock_post.side_effect = requests.ConnectionError("Connection refused") + + request_data = { + "code": "def test(): pass" + } + response = client.post("/generate", 
json=request_data) + + assert response.status_code == 500 + data = response.json() + assert "detail" in data + assert "Failed to generate docstring" in data["detail"] + + @patch("src.training.serve.requests.post") + def test_generate_failure_timeout(self, mock_post, client): + """Generate should return 500 when ollama request times out.""" + import requests + mock_post.side_effect = requests.Timeout("Request timed out") + + request_data = { + "code": "def test(): pass" + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 500 + data = response.json() + assert "detail" in data + + @patch("src.training.serve.requests.post") + def test_generate_failure_non_200_status(self, mock_post, client): + """Generate should return 500 when ollama returns non-200 status.""" + mock_response = Mock() + mock_response.status_code = 500 + mock_response.raise_for_status.side_effect = Exception("Internal server error") + mock_post.return_value = mock_response + + request_data = { + "code": "def test(): pass" + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 500 + + @patch("src.training.serve.requests.post") + def test_generate_failure_unexpected_format(self, mock_post, client): + """Generate should return 500 when ollama returns unexpected format.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "unexpected": "format" + } + mock_post.return_value = mock_response + + request_data = { + "code": "def test(): pass" + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 500 + data = response.json() + assert "detail" in data + assert "Unexpected response format" in data["detail"] + + def test_generate_missing_code_field(self, client): + """Generate should return 422 when code field is missing.""" + request_data = {} + response = client.post("/generate", json=request_data) + + assert response.status_code == 422 + + def test_generate_empty_code(self, client): + """Generate should accept empty code string.""" + with patch("src.training.serve.requests.post") as mock_post: + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "message": {"content": '"""Empty function."""'} + } + mock_post.return_value = mock_response + + request_data = {"code": ""} + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + + +class TestHelperFunctions: + """Tests for helper functions.""" + + @patch("src.training.serve.requests.get") + def test_check_ollama_health_success(self, mock_get): + """check_ollama_health should return True when ollama is accessible.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_get.return_value = mock_response + + result = check_ollama_health() + + assert result is True + mock_get.assert_called_once() + + @patch("src.training.serve.requests.get") + def test_check_ollama_health_failure(self, mock_get): + """check_ollama_health should return False when ollama is not accessible.""" + import requests + mock_get.side_effect = requests.RequestException("Connection refused") + + result = check_ollama_health() + + assert result is False + + @patch("src.training.serve.requests.post") + def test_generate_docstring_success(self, mock_post): + """generate_docstring should return docstring content.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "message": {"content": '"""Test 
docstring."""'} + } + mock_post.return_value = mock_response + + result = generate_docstring("def test(): pass", max_new_tokens=128) + + assert result == '"""Test docstring."""' + mock_post.assert_called_once() + + @patch("src.training.serve.requests.post") + def test_generate_docstring_failure(self, mock_post): + """generate_docstring should raise RuntimeError on failure.""" + import requests + mock_post.side_effect = requests.RequestException("Connection error") + + with pytest.raises(RuntimeError) as exc_info: + generate_docstring("def test(): pass") + + assert "Failed to generate docstring" in str(exc_info.value)
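+
+    @patch("src.training.serve.requests.post")
+    def test_generate_docstring_unexpected_format(self, mock_post):
+        """generate_docstring should raise ValueError on an unrecognized payload.
+
+        Sketch of a direct-function check mirroring the endpoint-level
+        unexpected-format test above.
+        """
+        mock_response = Mock()
+        mock_response.status_code = 200
+        mock_response.json.return_value = {"unexpected": "format"}
+        mock_post.return_value = mock_response
+
+        with pytest.raises(ValueError) as exc_info:
+            generate_docstring("def test(): pass")
+
+        assert "Unexpected response format" in str(exc_info.value)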