11import logging
22from pathlib import Path
3- from typing import Sequence
3+ from typing import Sequence , Optional
44from mcp .server import Server
55from mcp .server .session import ServerSession
66from mcp .server .stdio import stdio_server
1313)
1414from enum import Enum
1515import git
16- from pydantic import BaseModel
16+ from pydantic import BaseModel , Field
1717
1818# Default number of context lines to show in diff output
1919DEFAULT_CONTEXT_LINES = 3
@@ -65,6 +65,24 @@ class GitShow(BaseModel):
6565class GitInit (BaseModel ):
6666 repo_path : str
6767
68+ class GitBranch (BaseModel ):
69+ repo_path : str = Field (
70+ ...,
71+ description = "The path to the Git repository." ,
72+ )
73+ branch_type : str = Field (
74+ ...,
75+ description = "Whether to list local branches ('local'), remote branches ('remote') or all branches('all')." ,
76+ )
77+ contains : Optional [str ] = Field (
78+ None ,
79+ description = "The commit sha that branch should contain. Do not pass anything to this param if no commit sha is specified" ,
80+ )
81+ not_contains : Optional [str ] = Field (
82+ None ,
83+ description = "The commit sha that branch should NOT contain. Do not pass anything to this param if no commit sha is specified" ,
84+ )
85+
6886class GitTools (str , Enum ):
6987 STATUS = "git_status"
7088 DIFF_UNSTAGED = "git_diff_unstaged"
@@ -78,6 +96,7 @@ class GitTools(str, Enum):
7896 CHECKOUT = "git_checkout"
7997 SHOW = "git_show"
8098 INIT = "git_init"
99+ BRANCH = "git_branch"
81100
82101def git_status (repo : git .Repo ) -> str :
83102 return repo .git .status ()
@@ -153,6 +172,34 @@ def git_show(repo: git.Repo, revision: str) -> str:
153172 output .append (d .diff .decode ('utf-8' ))
154173 return "" .join (output )
155174
175+ def git_branch (repo : git .Repo , branch_type : str , contains : str | None = None , not_contains : str | None = None ) -> str :
176+ match contains :
177+ case None :
178+ contains_sha = (None ,)
179+ case _:
180+ contains_sha = ("--contains" , contains )
181+
182+ match not_contains :
183+ case None :
184+ not_contains_sha = (None ,)
185+ case _:
186+ not_contains_sha = ("--no-contains" , not_contains )
187+
188+ match branch_type :
189+ case 'local' :
190+ b_type = None
191+ case 'remote' :
192+ b_type = "-r"
193+ case 'all' :
194+ b_type = "-a"
195+ case _:
196+ return f"Invalid branch type: { branch_type } "
197+
198+ # None value will be auto deleted by GitPython
199+ branch_info = repo .git .branch (b_type , * contains_sha , * not_contains_sha )
200+
201+ return branch_info
202+
156203async def serve (repository : Path | None ) -> None :
157204 logger = logging .getLogger (__name__ )
158205
@@ -228,6 +275,11 @@ async def list_tools() -> list[Tool]:
228275 name = GitTools .INIT ,
229276 description = "Initialize a new Git repository" ,
230277 inputSchema = GitInit .model_json_schema (),
278+ ),
279+ Tool (
280+ name = GitTools .BRANCH ,
281+ description = "List Git branches" ,
282+ inputSchema = GitBranch .model_json_schema (),
231283 )
232284 ]
233285
@@ -357,6 +409,18 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]:
357409 text = result
358410 )]
359411
412+ case GitTools .BRANCH :
413+ result = git_branch (
414+ repo ,
415+ arguments .get ("branch_type" , 'local' ),
416+ arguments .get ("contains" , None ),
417+ arguments .get ("not_contains" , None ),
418+ )
419+ return [TextContent (
420+ type = "text" ,
421+ text = result
422+ )]
423+
360424 case _:
361425 raise ValueError (f"Unknown tool: { name } " )
362426
0 commit comments