Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
500 changes: 486 additions & 14 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies = [
]

[project.optional-dependencies]
claude = ["claude-code-sdk (==0.0.22)"]
openai = ["openai (>=1.64.0,<2)"]

[project.scripts]
Expand Down
110 changes: 110 additions & 0 deletions src/git_draft/bots/claude_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""Claude code bot implementations

Useful links:

* https://github.com/anthropics/claude-code
* https://docs.anthropic.com/en/docs/claude-code/sdk/sdk-python
"""

from collections.abc import Mapping
import dataclasses
import logging
from typing import Any

import claude_code_sdk as sdk

from ..common import UnreachableError, reindent
from .common import ActionSummary, Bot, Goal, UserFeedback, Worktree


_logger = logging.getLogger(__name__)


def new_bot() -> Bot:
return _Bot()


_PROMPT_SUFFIX = reindent("""
ALWAYS use the feedback's MCP server ask_user tool if you need to request
any information from the user. NEVER repeat yourself by also asking your
question to the user in other ways.
""")


class _Bot(Bot):
def __init__(self) -> None:
self._options = sdk.ClaudeCodeOptions(
allowed_tools=["Read", "Write", "mcp__feedback__ask_user"],
permission_mode="bypassPermissions", # TODO: Tighten
append_system_prompt=_PROMPT_SUFFIX,
)

async def act(
self, goal: Goal, tree: Worktree, feedback: UserFeedback
) -> ActionSummary:
summary = ActionSummary()
with tree.edit_files() as tree_path:
options = dataclasses.replace(
self._options,
cwd=tree_path,
mcp_servers={"feedback": _feedback_mcp_server(feedback)},
)
async with sdk.ClaudeSDKClient(options) as client:
await client.query(goal.prompt)
async for msg in client.receive_response():
_logger.debug("SDK message: %s", msg)
match msg:
case sdk.UserMessage(content):
_notify(feedback, content)
case sdk.AssistantMessage(content, _):
_notify(feedback, content)
case sdk.ResultMessage() as message:
# This message's result appears to be identical to
# the last assistant message's content, so we do
# not need to show it.
summary.turn_count = message.num_turns
summary.cost = message.total_cost_usd
if usage := message.usage:
summary.token_count = _token_count(usage)
summary.usage_details = usage
case sdk.SystemMessage():
pass # TODO: Notify on tool usage?
return summary


def _token_count(usage: Mapping[str, Any]) -> int:
return (
usage["input_tokens"]
+ usage["cache_creation_input_tokens"]
+ usage["cache_read_input_tokens"]
+ usage["output_tokens"]
)


def _notify(
feedback: UserFeedback, content: str | list[sdk.ContentBlock]
) -> None:
if isinstance(content, str):
feedback.notify(content)
return

for block in content:
match block:
case sdk.TextBlock(text):
feedback.notify(text)
case sdk.ThinkingBlock(thinking, signature):
feedback.notify(thinking)
feedback.notify(signature)
case sdk.ToolUseBlock() | sdk.ToolResultBlock() as block:
_logger.debug("Using tool: %s", block)
case _:
raise UnreachableError()


def _feedback_mcp_server(feedback: UserFeedback) -> sdk.McpServerConfig:
@sdk.tool("ask_user", "Request feedback from the user", {"question": str})
async def ask_user(args: Any) -> Any:
question = args["question"]
return {"content": [{"type": "text", "text": feedback.ask(question)}]}

return sdk.create_sdk_mcp_server(name="feedback", tools=[ask_user])
12 changes: 7 additions & 5 deletions src/git_draft/bots/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pathlib import Path, PurePosixPath
from typing import Protocol

from ..common import ensure_state_home, qualified_class_name
from ..common import JSONObject, ensure_state_home, qualified_class_name


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -71,11 +71,13 @@ class ActionSummary:
"""

title: str | None = None
request_count: int | None = None
token_count: int | None = None
turn_count: int | None = None
token_count: int | None = None # TODO: Split into input and output.
cost: float | None = None
usage_details: JSONObject | None = None # TODO: Use.

def increment_request_count(self, n: int = 1, init: bool = False) -> None:
self._increment("request_count", n, init)
def increment_turn_count(self, n: int = 1, init: bool = False) -> None:
self._increment("turn_count", n, init)

def increment_token_count(self, n: int, init: bool = False) -> None:
self._increment("token_count", n, init)
Expand Down
11 changes: 2 additions & 9 deletions src/git_draft/bots/openai_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
"""OpenAI API-backed bots

They can be used with services other than OpenAI as long as them implement a
sufficient subset of the API. For example the `completions_bot` only requires
tools support.

See the following links for more resources:

* https://platform.openai.com/docs/assistants/tools/function-calling
* https://platform.openai.com/docs/assistants/deep-dive#runs-and-run-steps
* https://platform.openai.com/docs/api-reference/assistants-streaming/events
* https://github.com/openai/openai-python/blob/main/src/openai/resources/beta/threads/runs/runs.py
sufficient subset of the API. For example the `new_completions_bot` only
requires tools support.
"""

from .assistants import new_threads_bot
Expand Down
9 changes: 7 additions & 2 deletions src/git_draft/bots/openai_api/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

Note that this API is (will soon?) be deprecated in favor of the responses API.
It does not support the gpt-5 series of models.

* https://platform.openai.com/docs/assistants/tools/function-calling
* https://platform.openai.com/docs/assistants/deep-dive#runs-and-run-steps
* https://platform.openai.com/docs/api-reference/assistants-streaming/events
* https://github.com/openai/openai-python/blob/main/src/openai/resources/beta/threads/runs/runs.py
"""

from collections.abc import Sequence
Expand Down Expand Up @@ -66,7 +71,7 @@ async def act(

# We intentionally do not count the two requests above, to focus on
# "data requests" only.
action = ActionSummary(request_count=0, token_count=0)
action = ActionSummary(turn_count=0, token_count=0)
with self._client.beta.threads.runs.stream(
thread_id=thread.id,
assistant_id=assistant_id,
Expand All @@ -89,7 +94,7 @@ def __init__(
self._tree = tree
self._feedback = feedback
self._action = action
self._action.increment_request_count()
self._action.increment_turn_count()

def _clone(self) -> Self:
return self.__class__(
Expand Down
2 changes: 1 addition & 1 deletion src/git_draft/bots/openai_api/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Shared OpenAPI abstractions"""
"""Shared OpenAI abstractions"""

from collections.abc import Mapping, Sequence
import json
Expand Down
2 changes: 1 addition & 1 deletion src/git_draft/bots/openai_api/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async def act(
if done:
break

return ActionSummary(request_count=request_count)
return ActionSummary(turn_count=request_count)


class _CompletionsToolHandler(ToolHandler[str | None]):
Expand Down
6 changes: 6 additions & 0 deletions src/git_draft/bots/openai_api/responses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Responses API backed implementation

* https://platform.openai.com/docs/guides/function-calling?api-mode=responses
"""

# TODO: Implement.
10 changes: 8 additions & 2 deletions src/git_draft/drafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,11 @@ async def generate_draft(

# Ensure that we are in a folio.
folio = _maybe_active_folio(self._repo)
if not folio:
if folio:
self._progress.report(
"Reusing active draft branch.", name=folio.branch_name()
)
else:
folio = self._create_folio()
with self._store.cursor() as cursor:
[(prompt_id, seqno)] = cursor.execute(
Expand All @@ -183,6 +187,8 @@ async def generate_draft(
"Completed bot run.",
runtime=round(change.walltime.total_seconds(), 1),
tokens=change.action.token_count,
turns=change.action.turn_count,
cost=change.action.cost,
)

# Create git commits, references, and update branches.
Expand Down Expand Up @@ -218,7 +224,7 @@ async def generate_draft(
"prompt_id": prompt_id,
"bot_class": qualified_class_name(bot.__class__),
"walltime_seconds": change.walltime.total_seconds(),
"request_count": change.action.request_count,
"turn_count": change.action.turn_count,
"token_count": change.action.token_count,
"pending_question": feedback.pending_question,
},
Expand Down
4 changes: 2 additions & 2 deletions src/git_draft/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,12 @@ def __init__(

@override
def _notify(self, update: str) -> None:
self._spinner.update(update)
self._spinner.yaspin.write(f"○ {update}")

@override
def _ask(self, question: str) -> str | None:
with self._spinner.hidden():
answer = input(question + " ")
answer = input(f"● {question} ")
return answer or None


Expand Down
4 changes: 2 additions & 2 deletions src/git_draft/queries/add-action-summary.sql
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ insert into action_summaries (
prompt_id,
bot_class,
walltime_seconds,
request_count,
turn_count,
token_count,
pending_question)
values (
:prompt_id,
:bot_class,
:walltime_seconds,
:request_count,
:turn_count,
:token_count,
:pending_question);
2 changes: 1 addition & 1 deletion src/git_draft/queries/create-tables.sql
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ create table if not exists action_summaries (
created_at timestamp default current_timestamp,
bot_class text not null,
walltime_seconds real not null,
request_count int,
turn_count int,
token_count int,
pending_question text,
foreign key (prompt_id) references prompts (id) on delete cascade
Expand Down
12 changes: 6 additions & 6 deletions tests/git_draft/bots/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ class TestActionSummary:
def test_increment_noinit(self) -> None:
action = sut.ActionSummary()
with pytest.raises(ValueError):
action.increment_request_count()
action.increment_turn_count()

def test_increment_request_count(self) -> None:
def test_increment_turn_count(self) -> None:
action = sut.ActionSummary()
action.increment_request_count(init=True)
assert action.request_count == 1
action.increment_request_count()
assert action.request_count == 2
action.increment_turn_count(init=True)
assert action.turn_count == 1
action.increment_turn_count()
assert action.turn_count == 2

def test_increment_token_count(self) -> None:
action = sut.ActionSummary()
Expand Down