1 change: 1 addition & 0 deletions .gitignore
@@ -23,6 +23,7 @@ build/

# IDE
.vscode/
.pytest_cache/

# Jupyter
notebooks/.ipynb_checkpoints/
110 changes: 108 additions & 2 deletions README.md
@@ -53,8 +53,9 @@ Test Dataset + Model Predictions --> [benchmark.py] --> Metrics Report
- **`train_lora.py`** - LoRA fine-tuning using HuggingFace Trainer + PEFT. Supports
QLoRA (4-bit quantization) for training on 1-2 A100 GPUs.

- **`serve.py`** - FastAPI inference server that loads the fine-tuned model and
serves docstring generation via HTTP.
- **`serve.py`** - FastAPI inference server that calls the ollama API to generate
docstrings, using a fixed system prompt (see `src/training/prompts/system_prompt.md`)
for NumPy-style docstring generation.

### Evaluation (`src/evaluation/`)

@@ -87,6 +88,111 @@ python -m src.data.convert_seed \
    --output-dir data/processed/python-method
```

## Serving

The FastAPI inference server provides HTTP endpoints for docstring generation using
ollama as the backend. The server uses a system prompt stored in
`src/training/prompts/system_prompt.md` to generate NumPy-style docstrings.

### Prerequisites

1. **Install ollama**: Make sure [ollama](https://ollama.ai/) is installed and running locally
2. **Pull a model**: Download a code model (e.g., `qwen2.5-coder:32b`):
   ```bash
   ollama pull qwen2.5-coder:32b
   ```
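
Before starting the server, it can help to confirm that ollama is reachable and the model has been pulled. A minimal Python check against ollama's `/api/tags` endpoint is sketched below; the port and model tag are the defaults assumed throughout this README:

```python
# Pre-flight check (sketch): verify ollama is up and the model is pulled.
# Assumes the default local ollama port (11434) and the model tag used above.
import requests

resp = requests.get("http://localhost:11434/api/tags", timeout=5)
resp.raise_for_status()
available = {m["name"] for m in resp.json().get("models", [])}
print("qwen2.5-coder:32b pulled:", "qwen2.5-coder:32b" in available)
```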

### Starting the Server

Start the FastAPI server using uvicorn:

```bash
# Using uvicorn directly
uvicorn src.training.serve:app --host 0.0.0.0 --port 8000

# Or run the module directly
python -m src.training.serve
```

The server will start on `http://localhost:8000` by default.

### Configuration

The server can be configured using environment variables:

- `OLLAMA_URL` - Ollama API endpoint (default: `http://localhost:11434/api/chat`)
- `OLLAMA_MODEL` - Model name to use (default: `qwen2.5-coder:32b`)
- `REQUEST_TIMEOUT` - Request timeout in seconds (default: `120.0`)

Example:
```bash
OLLAMA_MODEL=qwen2.5-coder:7b uvicorn src.training.serve:app --port 8000
```
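
For reference, a minimal sketch of how these variables are typically resolved with `os.getenv`; the actual lookup in `src/training/serve.py` is not shown in this diff and may differ:

```python
# Sketch only: assumed environment-variable handling, not the exact serve.py code.
import os

OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434/api/chat")
OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "qwen2.5-coder:32b")
REQUEST_TIMEOUT = float(os.getenv("REQUEST_TIMEOUT", "120.0"))
```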

### API Endpoints

#### Health Check

Check if the service is healthy and ollama is accessible:

```bash
curl http://localhost:8000/health
```

**Response (200 OK):**
```json
{
  "status": "healthy",
  "service": "ollama"
}
```

**Response (503 Service Unavailable):**
```json
{
  "detail": "Service unhealthy: ollama is not running or not accessible"
}
```
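
The same check from Python, using `requests` (already a runtime dependency); the endpoint and response shape follow the documentation above:

```python
# Probe the /health endpoint of a locally running server.
import requests

resp = requests.get("http://localhost:8000/health", timeout=10)
if resp.ok:
    print(resp.json())  # e.g. {"status": "healthy", "service": "ollama"}
else:
    print("unhealthy:", resp.status_code, resp.json().get("detail"))
```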

#### Generate Docstring

Generate a docstring for a Python function:

```bash
curl -X POST http://localhost:8000/generate \
  -H "Content-Type: application/json" \
  -d '{
    "code": "def add(x, y):\n return x + y",
    "max_new_tokens": 256
  }'
```

**Request Body:**
- `code` (required): Python function code as a string
- `max_new_tokens` (optional): Maximum number of tokens to generate (default: 256)

**Response (200 OK):**
```json
{
  "docstring": "\"\"\"Compute the sum of two numbers.\n\nParameters\n----------\nx : int\n First number.\ny : int\n Second number.\n\nReturns\n-------\nint\n Sum of x and y.\n\"\"\""
}
```

**Response (500 Internal Server Error):**
```json
{
  "detail": "Failed to generate docstring: <error message>"
}
```
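
The equivalent request from Python, mirroring the curl example above (the timeout value is only a suggestion):

```python
# Ask the local server to generate a docstring for a small function.
import requests

payload = {
    "code": "def add(x, y):\n    return x + y",
    "max_new_tokens": 256,
}
resp = requests.post("http://localhost:8000/generate", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["docstring"])
```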

### Testing

Run the test suite to verify the API endpoints:

```bash
pytest tests/test_serve.py -v
```
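
The test module itself is not among the files shown above; a hypothetical test in that style, using FastAPI's `TestClient` (which is why `httpx` appears in the dev dependencies), might look like:

```python
# Hypothetical excerpt of tests/test_serve.py; the real tests may differ.
from fastapi.testclient import TestClient

from src.training.serve import app

client = TestClient(app)


def test_generate_rejects_missing_code():
    # FastAPI's request validation should return 422 when the required
    # "code" field is absent from the request body.
    resp = client.post("/generate", json={})
    assert resp.status_code == 422
```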

## Dataset

The seed dataset comes from the [NeuralCodeSum](https://github.com/wasiahmad/NeuralCodeSum)
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -24,12 +24,14 @@ dependencies = [
"safetensors",
"fastapi>=0.104.0",
"uvicorn>=0.24.0",
"requests>=2.31.0",
]

[project.optional-dependencies]
dev = [
"pytest>=7.0",
"ruff>=0.1.0",
"httpx>=0.24.0",
]

[tool.hatch.build.targets.wheel]
100 changes: 100 additions & 0 deletions scripts/run_ollama.py
@@ -0,0 +1,100 @@
import argparse
import requests
import sys
import json
import time
from pathlib import Path


def load_system_prompt() -> str:
    """Load the default system prompt from the prompts directory."""
    prompt_path = Path(__file__).parent.parent / "src" / "training" / "prompts" / "system_prompt.md"
    if not prompt_path.exists():
        raise FileNotFoundError(
            f"System prompt file not found: {prompt_path}. "
            "Please ensure the prompt file exists."
        )
    return prompt_path.read_text(encoding="utf-8")


DEFAULT_URL = "http://localhost:11434/api/chat" # Changed from /api/generate


def build_payload(model, system_msgs, user_msgs, stream):
    messages = []
    for s in system_msgs:
        messages.append({"role": "system", "content": s})
    for u in user_msgs:
        messages.append({"role": "user", "content": u})

    return {
        "model": model,
        "messages": messages,
        "stream": stream,
        "keep_alive": 0  # Unload model after request
    }


def main():
    parser = argparse.ArgumentParser(description="Send system and user prompts to model endpoint")
    parser.add_argument("--url", default=DEFAULT_URL, help="API endpoint URL")
    parser.add_argument("--model", required=True, help="Model name, e.g. qwen2.5-coder:32b")
    parser.add_argument("--system", action="append", default=None, help="System prompt (repeatable). Overrides default.")
    parser.add_argument("--user", action="append", default=[], help="User prompt (repeatable)")
    parser.add_argument("--stream", action="store_true", help="Enable streaming mode")
    parser.add_argument("--timeout", type=float, default=120.0, help="Request timeout in seconds")

    args = parser.parse_args()

    # Use default system prompt if none provided
    if args.system:
        system_msgs = args.system
    else:
        try:
            system_msgs = [load_system_prompt()]
        except FileNotFoundError as e:
            print(f"Error: {e}", file=sys.stderr)
            sys.exit(1)

    if not args.user:
        print("Error: at least one --user prompt is required.", file=sys.stderr)
        sys.exit(2)

    payload = build_payload(args.model, system_msgs, args.user, args.stream)

    # Track execution time
    start_time = time.time()
    try:
        resp = requests.post(args.url, json=payload, timeout=args.timeout)
        resp.raise_for_status()
    except requests.RequestException as e:
        elapsed_time = time.time() - start_time
        print(f"Request failed: {e}", file=sys.stderr)
        print(f"Execution time: {elapsed_time:.2f}s", file=sys.stderr)
        sys.exit(1)

    try:
        data = resp.json()
    except ValueError:
        print("Response is not valid JSON", file=sys.stderr)
        print(resp.text, file=sys.stderr)
        sys.exit(1)

    # Handle /api/chat response format
    if "message" in data:
        print(data["message"].get("content", ""))
    elif "response" in data:
        print(data["response"])
    elif "choices" in data and isinstance(data["choices"], list):
        for c in data["choices"]:
            print(c.get("message", {}).get("content", c.get("text", "")))
    else:
        print(json.dumps(data, indent=2))

    # Print execution time
    elapsed_time = time.time() - start_time
    print(f"Execution time: {elapsed_time:.2f}s", file=sys.stdout)


if __name__ == "__main__":
    main()
77 changes: 77 additions & 0 deletions src/training/prompts/system_prompt.md
@@ -0,0 +1,77 @@
You are a Python documentation expert. Your task is to generate a docstring for the provided Python function, strictly following the NumPy docstring format.

## Output Rules
- Output ONLY the docstring content (including the triple quotes)
- Do NOT include the function signature or body
- Do NOT add any explanation before or after the docstring

## NumPy Docstring Format

### Structure (include sections only when applicable)
"""
Short one-line summary (imperative mood, e.g., "Compute", "Return", "Parse").

Extended summary providing more details about the function behavior,
algorithm, or implementation notes. Optional but recommended for
complex functions.

Parameters
----------
param_name : type
Description of the parameter. If the description spans multiple
lines, indent continuation lines.
param_name : type, optional
For optional parameters, specify default value in description.
Default is `default_value`.
*args : type
Description of variable positional arguments.
**kwargs : type
Description of variable keyword arguments.

Returns
-------
type
Description of return value.
name : type
Use this format when returning named values or multiple values.

Yields
------
type
For generator functions, describe yielded values.

Raises
------
ExceptionType
Explanation of when this exception is raised.

See Also
--------
related_function : Brief description of relation.

Notes
-----
Additional technical notes, mathematical formulas (using LaTeX),
or implementation details.

Examples
--------
>>> function_name(arg1, arg2)
expected_output
"""

### Type Annotation Conventions
- Basic types: `int`, `float`, `str`, `bool`, `None`
- Collections: `list of int`, `dict of {str: int}`, `tuple of (int, str)`
- Multiple types: `int or float`, `str or None`
- Array-like: `array_like`, `numpy.ndarray of shape (n, m)`
- Callable: `callable`
- Optional params: append `, optional` after type

### Guidelines
1. First line: concise, imperative verb, no variable names, ends with period
2. Leave one blank line after the summary before Parameters
3. Align parameter descriptions consistently
4. Include realistic, runnable Examples when behavior isn't obvious
5. Document all exceptions that may be explicitly raised
6. For boolean params, describe what True/False means