From 7266693296ccb4f36461b64957615e0781dff251 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 22 Jan 2026 17:01:27 +0000 Subject: [PATCH 1/3] Initial plan From 2163a808ec8f4b15a33d74af077b1b7065f69f33 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 22 Jan 2026 17:04:32 +0000 Subject: [PATCH 2/3] Add ruff configuration to suppress linter errors and re-enable CI check Co-authored-by: kevinbackhouse <4358136+kevinbackhouse@users.noreply.github.com> --- .github/workflows/ci.yml | 4 +- pyproject.toml | 53 ++ .../mcp_servers/alert_results_models.py | 26 +- .../codeql_python/codeql_sqlite_models.py | 11 +- .../mcp_servers/codeql_python/mcp_server.py | 121 ++-- .../mcp_servers/gh_actions.py | 191 +++--- .../mcp_servers/gh_code_scanning.py | 208 +++--- .../mcp_servers/gh_file_viewer.py | 146 +++-- src/seclab_taskflows/mcp_servers/ghsa.py | 27 +- .../mcp_servers/local_file_viewer.py | 70 +- .../mcp_servers/local_gh_resources.py | 44 +- .../mcp_servers/repo_context.py | 607 +++++++++++------- .../mcp_servers/repo_context_models.py | 79 ++- .../mcp_servers/report_alert_state.py | 182 ++++-- src/seclab_taskflows/mcp_servers/utils.py | 1 + tests/test_00.py | 6 +- 16 files changed, 1115 insertions(+), 661 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1fbe21a..811a5a2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,9 +34,7 @@ jobs: run: pip install --upgrade hatch - name: Run static analysis - run: | - # hatch fmt --check - echo linter errors will be fixed in a separate PR + run: hatch fmt --check - name: Run tests run: hatch test --python ${{ matrix.python-version }} --cover --randomize --parallel --retries 2 --retry-delay 1 diff --git a/pyproject.toml b/pyproject.toml index 0d3c42c..a876d22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,3 +60,56 @@ exclude_lines = [ "if __name__ == .__main__.:", "if TYPE_CHECKING:", ] + +[tool.ruff.lint] +ignore = [ + "A001", # Variable name shadows a builtin + "A002", # Argument name shadows a builtin + "ARG001", # Unused function argument + "B007", # Loop control variable not used within loop body + "B008", # Do not perform function calls in argument defaults + "BLE001", # Do not catch blind exception + "C403", # Unnecessary list comprehension - rewrite as a set comprehension + "C405", # Unnecessary list literal - rewrite as a set literal + "E713", # Test for membership should be `not in` + "EM102", # Exception must not use an f-string literal + "F401", # Imported but unused + "F541", # f-string without any placeholders + "F841", # Local variable assigned but never used + "FA100", # Missing `from __future__ import annotations` + "FA102", # Missing `from __future__ import annotations` for `typing.Optional` + "FBT001", # Boolean-typed positional argument in function definition + "FBT002", # Boolean default positional argument in function definition + "G004", # Logging statement uses f-string + "I001", # Import block is un-sorted or un-formatted + "INP001", # File is part of an implicit namespace package + "LOG015", # `root` should be used instead of logger + "N802", # Function name should be lowercase + "PERF102", # Incorrect `dict` comprehension for combining two dicts + "PERF401", # Use a list comprehension to create a transformed list + "PIE810", # Call `startswith` once with a tuple + "PLC0206", # Dict should be used instead of tuple + "PLR1730", # Replace `if` statement with `min()` + "PLR2004", # Magic value used in comparison + "PLW0602", # Using global for variable but no assignment is done + "PLW1508", # Invalid type for environment variable default + "PLW1510", # `subprocess.run` without explicit `check` argument + "RET504", # Unnecessary assignment before `return` statement + "RET505", # Unnecessary `else` after `return` statement + "RUF003", # Comment contains ambiguous character + "RUF013", # PEP 484 prohibits implicit `Optional` + "RUF015", # Prefer `next(iter())` over single element slice + "S607", # Starting a process with a partial executable path + "SIM101", # Use a ternary expression instead of if-else-block + "SIM114", # Combine `if` branches using logical `or` operator + "SIM117", # Use a single `with` statement with multiple contexts + "SIM118", # Use `key in dict` instead of `key in dict.keys()` + "SIM300", # Yoda condition detected + "T201", # `print` found + "TID252", # Prefer absolute imports over relative imports + "TRY003", # Avoid specifying long messages outside the exception class + "TRY300", # Consider moving this statement to an `else` block + "UP032", # Use f-string instead of `format` call + "W291", # Trailing whitespace + "W293", # Blank line contains whitespace +] diff --git a/src/seclab_taskflows/mcp_servers/alert_results_models.py b/src/seclab_taskflows/mcp_servers/alert_results_models.py index 53efc2c..a20852f 100644 --- a/src/seclab_taskflows/mcp_servers/alert_results_models.py +++ b/src/seclab_taskflows/mcp_servers/alert_results_models.py @@ -5,11 +5,13 @@ from sqlalchemy.orm import DeclarativeBase, mapped_column, Mapped, relationship from typing import Optional + class Base(DeclarativeBase): pass + class AlertResults(Base): - __tablename__ = 'alert_results' + __tablename__ = "alert_results" canonical_id: Mapped[int] = mapped_column(primary_key=True) alert_id: Mapped[str] @@ -22,18 +24,21 @@ class AlertResults(Base): valid: Mapped[bool] = mapped_column(nullable=False, default=True) completed: Mapped[bool] = mapped_column(nullable=False, default=False) - relationship('AlertFlowGraph', cascade='all, delete') + relationship("AlertFlowGraph", cascade="all, delete") def __repr__(self): - return (f"") + return ( + f"" + ) + class AlertFlowGraph(Base): - __tablename__ = 'alert_flow_graph' + __tablename__ = "alert_flow_graph" id: Mapped[int] = mapped_column(primary_key=True) - alert_canonical_id = Column(Integer, ForeignKey('alert_results.canonical_id', ondelete='CASCADE')) + alert_canonical_id = Column(Integer, ForeignKey("alert_results.canonical_id", ondelete="CASCADE")) flow_data: Mapped[str] = mapped_column(Text) repo: Mapped[str] prev: Mapped[Optional[str]] @@ -41,6 +46,7 @@ class AlertFlowGraph(Base): started: Mapped[bool] = mapped_column(nullable=False, default=False) def __repr__(self): - return (f"") - + return ( + f"" + ) diff --git a/src/seclab_taskflows/mcp_servers/codeql_python/codeql_sqlite_models.py b/src/seclab_taskflows/mcp_servers/codeql_python/codeql_sqlite_models.py index 51d1224..4e1604c 100644 --- a/src/seclab_taskflows/mcp_servers/codeql_python/codeql_sqlite_models.py +++ b/src/seclab_taskflows/mcp_servers/codeql_python/codeql_sqlite_models.py @@ -5,12 +5,13 @@ from sqlalchemy.orm import DeclarativeBase, mapped_column, Mapped from typing import Optional + class Base(DeclarativeBase): pass class Source(Base): - __tablename__ = 'source' + __tablename__ = "source" id: Mapped[int] = mapped_column(primary_key=True) repo: Mapped[str] @@ -20,6 +21,8 @@ class Source(Base): notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True) def __repr__(self): - return (f"") + return ( + f"" + ) diff --git a/src/seclab_taskflows/mcp_servers/codeql_python/mcp_server.py b/src/seclab_taskflows/mcp_servers/codeql_python/mcp_server.py index 74fade9..9af9452 100644 --- a/src/seclab_taskflows/mcp_servers/codeql_python/mcp_server.py +++ b/src/seclab_taskflows/mcp_servers/codeql_python/mcp_server.py @@ -6,8 +6,9 @@ from seclab_taskflow_agent.mcp_servers.codeql.client import run_query, _debug_log from pydantic import Field -#from mcp.server.fastmcp import FastMCP, Context -from fastmcp import FastMCP # use FastMCP 2.0 + +# from mcp.server.fastmcp import FastMCP, Context +from fastmcp import FastMCP # use FastMCP 2.0 from pathlib import Path import os import csv @@ -23,22 +24,20 @@ logging.basicConfig( level=logging.DEBUG, - format='%(asctime)s - %(levelname)s - %(message)s', - filename=log_file_name('mcp_codeql_python.log'), - filemode='a' + format="%(asctime)s - %(levelname)s - %(message)s", + filename=log_file_name("mcp_codeql_python.log"), + filemode="a", ) -MEMORY = mcp_data_dir('seclab-taskflows', 'codeql', 'DATA_DIR') -CODEQL_DBS_BASE_PATH = mcp_data_dir('seclab-taskflows', 'codeql', 'CODEQL_DBS_BASE_PATH') +MEMORY = mcp_data_dir("seclab-taskflows", "codeql", "DATA_DIR") +CODEQL_DBS_BASE_PATH = mcp_data_dir("seclab-taskflows", "codeql", "CODEQL_DBS_BASE_PATH") mcp = FastMCP("CodeQL-Python") # tool name -> templated query lookup for supported languages TEMPLATED_QUERY_PATHS = { # to add a language, port the templated query pack and add its definition here - 'python': { - 'remote_sources': 'queries/mcp-python/remote_sources.ql' - } + "python": {"remote_sources": "queries/mcp-python/remote_sources.ql"} } @@ -49,9 +48,10 @@ def source_to_dict(result): "source_location": result.source_location, "line": result.line, "source_type": result.source_type, - "notes": result.notes + "notes": result.notes, } + def _resolve_query_path(language: str, query: str) -> Path: global TEMPLATED_QUERY_PATHS if language not in TEMPLATED_QUERY_PATHS: @@ -66,7 +66,7 @@ def _resolve_db_path(relative_db_path: str | Path): global CODEQL_DBS_BASE_PATH # path joins will return "/B" if "/A" / "////B" etc. as well # not windows compatible and probably needs additional hardening - relative_db_path = str(relative_db_path).strip().lstrip('/') + relative_db_path = str(relative_db_path).strip().lstrip("/") relative_db_path = Path(relative_db_path) absolute_path = (CODEQL_DBS_BASE_PATH / relative_db_path).resolve() if not absolute_path.is_relative_to(CODEQL_DBS_BASE_PATH.resolve()): @@ -76,21 +76,21 @@ def _resolve_db_path(relative_db_path: str | Path): raise RuntimeError(f"Error: Database not found at {absolute_path}!") return str(absolute_path) + # This sqlite database is specifically made for CodeQL for Python MCP. class CodeqlSqliteBackend: def __init__(self, memcache_state_dir: str): self.memcache_state_dir = memcache_state_dir if not Path(self.memcache_state_dir).exists(): - db_dir = 'sqlite://' + db_dir = "sqlite://" else: - db_dir = f'sqlite:///{self.memcache_state_dir}/codeql_sqlite.db' + db_dir = f"sqlite:///{self.memcache_state_dir}/codeql_sqlite.db" self.engine = create_engine(db_dir, echo=False) Base.metadata.create_all(self.engine, tables=[Source.__table__]) - - def store_new_source(self, repo, source_location, line, source_type, notes, update = False): + def store_new_source(self, repo, source_location, line, source_type, notes, update=False): with Session(self.engine) as session: - existing = session.query(Source).filter_by(repo = repo, source_location = source_location, line = line).first() + existing = session.query(Source).filter_by(repo=repo, source_location=source_location, line=line).first() if existing: existing.notes = (existing.notes or "") + notes session.commit() @@ -98,14 +98,16 @@ def store_new_source(self, repo, source_location, line, source_type, notes, upda else: if update: return f"No source exists at repo {repo}, location {source_location}, line {line} to update." - new_source = Source(repo = repo, source_location = source_location, line = line, source_type = source_type, notes = notes) + new_source = Source( + repo=repo, source_location=source_location, line=line, source_type=source_type, notes=notes + ) session.add(new_source) session.commit() return f"Added new source for {source_location} in {repo}." def get_sources(self, repo): with Session(self.engine) as session: - results = session.query(Source).filter_by(repo = repo).all() + results = session.query(Source).filter_by(repo=repo).all() sources = [source_to_dict(source) for source in results] return sources @@ -119,8 +121,8 @@ def _csv_parse(raw): if i == 0: continue # col1 has what we care about, but offer flexibility - keys = row[1].split(',') - this_obj = {'description': row[0].format(*row[2:])} + keys = row[1].split(",") + this_obj = {"description": row[0].format(*row[2:])} for j, k in enumerate(keys): this_obj[k.strip()] = row[j + 2] results.append(this_obj) @@ -141,27 +143,32 @@ def _run_query(query_name: str, database_path: str, language: str, template_valu except RuntimeError: return f"The query {query_name} is not supported for language: {language}" try: - csv = run_query(Path(__file__).parent.resolve() / - query_path, - database_path, - fmt='csv', - template_values=template_values, - log_stderr=True) + csv = run_query( + Path(__file__).parent.resolve() / query_path, + database_path, + fmt="csv", + template_values=template_values, + log_stderr=True, + ) return _csv_parse(csv) except Exception as e: return f"The query {query_name} encountered an error: {e}" + backend = CodeqlSqliteBackend(MEMORY) + @mcp.tool() -def remote_sources(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - database_path: str = Field(description="The CodeQL database path."), - language: str = Field(description="The language used for the CodeQL database.")): +def remote_sources( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + database_path: str = Field(description="The CodeQL database path."), + language: str = Field(description="The language used for the CodeQL database."), +): """List all remote sources and their locations in a CodeQL database, then store the results in a database.""" repo = process_repo(owner, repo) - results = _run_query('remote_sources', database_path, language, {}) + results = _run_query("remote_sources", database_path, language, {}) # Check if results is an error (list of strings) or valid data (list of dicts) if isinstance(results, str): @@ -172,53 +179,67 @@ def remote_sources(owner: str = Field(description="The owner of the GitHub repos for result in results: backend.store_new_source( repo=repo, - source_location=result.get('location', ''), - source_type=result.get('source', ''), - line=int(result.get('line', '0')), - notes=None, #result.get('description', ''), - update=False + source_location=result.get("location", ""), + source_type=result.get("source", ""), + line=int(result.get("line", "0")), + notes=None, # result.get('description', ''), + update=False, ) stored_count += 1 return f"Stored {stored_count} remote sources in {repo}." + @mcp.tool() -def fetch_sources(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository")): +def fetch_sources( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), +): """ Fetch all sources from the repo """ repo = process_repo(owner, repo) return json.dumps(backend.get_sources(repo)) + @mcp.tool() -def add_source_notes(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - source_location: str = Field(description="The path to the file"), - line: int = Field(description="The line number of the source"), - notes: str = Field(description="The notes to append to this source")): +def add_source_notes( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + source_location: str = Field(description="The path to the file"), + line: int = Field(description="The line number of the source"), + notes: str = Field(description="The notes to append to this source"), +): """ Add new notes to an existing source. The notes will be appended to any existing notes. """ repo = process_repo(owner, repo) - return backend.store_new_source(repo = repo, source_location = source_location, line = line, source_type = "", notes = notes, update=True) + return backend.store_new_source( + repo=repo, source_location=source_location, line=line, source_type="", notes=notes, update=True + ) + @mcp.tool() -def clear_codeql_repo(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository")): +def clear_codeql_repo( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), +): """ Clear all data for a given repo from the database """ repo = process_repo(owner, repo) with Session(backend.engine) as session: - deleted_sources = session.query(Source).filter_by(repo = repo).delete() + deleted_sources = session.query(Source).filter_by(repo=repo).delete() session.commit() return f"Cleared {deleted_sources} sources from repo {repo}." + if __name__ == "__main__": # Check if codeql/python-all pack is installed, if not install it - if not os.path.isdir('/.codeql/packages/codeql/python-all'): - pack_path = importlib.resources.files('seclab_taskflows.mcp_servers.codeql_python.queries').joinpath('mcp-python') + if not os.path.isdir("/.codeql/packages/codeql/python-all"): + pack_path = importlib.resources.files("seclab_taskflows.mcp_servers.codeql_python.queries").joinpath( + "mcp-python" + ) print(f"Installing CodeQL pack from {pack_path}") subprocess.run(["codeql", "pack", "install", pack_path]) mcp.run(show_banner=False, transport="http", host="127.0.0.1", port=9998) diff --git a/src/seclab_taskflows/mcp_servers/gh_actions.py b/src/seclab_taskflows/mcp_servers/gh_actions.py index 51f7451..1deaa1f 100644 --- a/src/seclab_taskflows/mcp_servers/gh_actions.py +++ b/src/seclab_taskflows/mcp_servers/gh_actions.py @@ -16,16 +16,18 @@ logging.basicConfig( level=logging.DEBUG, - format='%(asctime)s - %(levelname)s - %(message)s', - filename=log_file_name('mcp_gh_actions.log'), - filemode='a' + format="%(asctime)s - %(levelname)s - %(message)s", + filename=log_file_name("mcp_gh_actions.log"), + filemode="a", ) + class Base(DeclarativeBase): pass + class WorkflowUses(Base): - __tablename__ = 'workflow_uses' + __tablename__ = "workflow_uses" id: Mapped[int] = mapped_column(primary_key=True) user: Mapped[str] @@ -34,34 +36,45 @@ class WorkflowUses(Base): repo: Mapped[str] def __repr__(self): - return (f"") + return f"" mcp = FastMCP("GitHubCodeScanning") -high_privileged_triggers = set(["issues", "issue_comment", "pull_request_comment", "pull_request_review", "pull_request_review_comment", - "pull_request_target"]) +high_privileged_triggers = set( + [ + "issues", + "issue_comment", + "pull_request_comment", + "pull_request_review", + "pull_request_review_comment", + "pull_request_target", + ] +) -unimportant_triggers = set(['pull_request', 'workflow_dispatch']) +unimportant_triggers = set(["pull_request", "workflow_dispatch"]) -GH_TOKEN = os.getenv('GH_TOKEN', default='') +GH_TOKEN = os.getenv("GH_TOKEN", default="") -ACTIONS_DB_DIR = mcp_data_dir('seclab-taskflows', 'gh_actions', 'ACTIONS_DB_DIR') +ACTIONS_DB_DIR = mcp_data_dir("seclab-taskflows", "gh_actions", "ACTIONS_DB_DIR") -engine = create_engine(f'sqlite:///{os.path.abspath(ACTIONS_DB_DIR)}/actions.db', echo=False) -Base.metadata.create_all(engine, tables = [WorkflowUses.__table__]) +engine = create_engine(f"sqlite:///{os.path.abspath(ACTIONS_DB_DIR)}/actions.db", echo=False) +Base.metadata.create_all(engine, tables=[WorkflowUses.__table__]) -async def call_api(url: str, params: dict, raw = False) -> str: +async def call_api(url: str, params: dict, raw=False) -> str: """Call the GitHub code scanning API to fetch alert.""" - headers = {"Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28", - "Authorization": f"Bearer {GH_TOKEN}"} + headers = { + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + "Authorization": f"Bearer {GH_TOKEN}", + } if raw: headers["Accept"] = "application/vnd.github.raw+json" + async def _fetch(url, headers, params): try: - async with httpx.AsyncClient(headers = headers) as client: + async with httpx.AsyncClient(headers=headers) as client: r = await client.get(url, params=params) r.raise_for_status() return r @@ -74,41 +87,40 @@ async def _fetch(url, headers, params): except httpx.AuthenticationError as e: return f"Authentication error: {e}" - r = await _fetch(url, headers = headers, params=params) + r = await _fetch(url, headers=headers, params=params) return r + @mcp.tool() async def fetch_workflow( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - workflow_id: str = Field(description="The ID or name of the workflow")) -> str: + workflow_id: str = Field(description="The ID or name of the workflow"), +) -> str: """ Fetch the details of a GitHub Actions workflow. """ - r = await call_api( - url=f"https://api.github.com/repos/{owner}/{repo}/actions/workflows/{workflow_id}", - params={} - ) + r = await call_api(url=f"https://api.github.com/repos/{owner}/{repo}/actions/workflows/{workflow_id}", params={}) if isinstance(r, str): return r return r.json() + @mcp.tool() async def check_workflow_active( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - workflow_id: str = Field(description="The ID or name of the workflow")) -> str: + workflow_id: str = Field(description="The ID or name of the workflow"), +) -> str: """ Check if a GitHub Actions workflow is active. """ - r = await call_api( - url=f"https://api.github.com/repos/{owner}/{repo}/actions/workflows/{workflow_id}", - params={} - ) + r = await call_api(url=f"https://api.github.com/repos/{owner}/{repo}/actions/workflows/{workflow_id}", params={}) if isinstance(r, str): return r return f"Workflow {workflow_id} is {'active' if r.json().get('state') == 'active' else 'inactive'}." + def find_in_yaml(key, node): if isinstance(node, dict): for k, v in node.items(): @@ -122,12 +134,11 @@ def find_in_yaml(key, node): for result in find_in_yaml(key, item): yield result -async def get_workflow_triggers(owner: str, repo: str, workflow_file_path: str) -> str: +async def get_workflow_triggers(owner: str, repo: str, workflow_file_path: str) -> str: r = await call_api( - url=f"https://api.github.com/repos/{owner}/{repo}/contents/{workflow_file_path}", - params={}, raw = True - ) + url=f"https://api.github.com/repos/{owner}/{repo}/contents/{workflow_file_path}", params={}, raw=True + ) if isinstance(r, str): return json.dumps([r]) data = yaml.safe_load(r.text) @@ -135,81 +146,76 @@ async def get_workflow_triggers(owner: str, repo: str, workflow_file_path: str) triggers = list(find_in_yaml(True, data)) return triggers + @mcp.tool() async def find_workflow_run_dependency( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), workflow_file_path: str = Field(description="The file path of the workflow that is triggered by `workflow_run`"), - high_privileged: bool = Field(description="Whether to return high privileged dependencies only.") -)->str: + high_privileged: bool = Field(description="Whether to return high privileged dependencies only."), +) -> str: """ Find the workflow that triggers this workflow_run. """ r = await call_api( - url=f"https://api.github.com/repos/{owner}/{repo}/contents/{workflow_file_path}", - params={}, raw=True + url=f"https://api.github.com/repos/{owner}/{repo}/contents/{workflow_file_path}", params={}, raw=True ) if isinstance(r, str): return json.dumps([r]) data = yaml.safe_load(r.text) - trigger_workflow = list(find_in_yaml('workflow_run', data))[0].get('workflows', []) + trigger_workflow = list(find_in_yaml("workflow_run", data))[0].get("workflows", []) if not trigger_workflow: return json.dumps([], indent=2) r = await call_api( - url=f"https://api.github.com/repos/{owner}/{repo}/contents/.github/workflows", - params={}, raw=True + url=f"https://api.github.com/repos/{owner}/{repo}/contents/.github/workflows", params={}, raw=True ) if isinstance(r, str): return json.dumps([r]) if not r.json(): return json.dumps([], indent=2) - paths_list = [item['path'] for item in r.json() if item['path'].endswith('.yml') or item['path'].endswith('.yaml')] + paths_list = [item["path"] for item in r.json() if item["path"].endswith(".yml") or item["path"].endswith(".yaml")] results = [] for path in paths_list: - workflow_id = path.split('/')[-1] - active = await call_api( - url=f"https://api.github.com/repos/{owner}/{repo}/actions/workflows/{workflow_id}", - params={} + workflow_id = path.split("/")[-1] + active = await call_api( + url=f"https://api.github.com/repos/{owner}/{repo}/actions/workflows/{workflow_id}", params={} ) - if not isinstance(active, str) and active.json().get('state') == "active": - r = await call_api( - url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", - params={}, raw=True - ) + if not isinstance(active, str) and active.json().get("state") == "active": + r = await call_api(url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", params={}, raw=True) if isinstance(r, str): return json.dumps([r]) data = yaml.safe_load(r.text) - name = data.get('name', '') + name = data.get("name", "") if name in trigger_workflow or "*" in trigger_workflow: triggers = data.get(True, {}) if not high_privileged or high_privileged_triggers.intersection(set(triggers)): - results.append({ - "path": path, - "name": name, - "triggers": triggers - }) + results.append({"path": path, "name": name, "triggers": triggers}) return json.dumps(results, indent=2) + @mcp.tool() async def get_workflow_trigger( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - workflow_file_path: str = Field(description="The file path of the workflow")) -> str: + workflow_file_path: str = Field(description="The file path of the workflow"), +) -> str: """ Get the trigger of a GitHub Actions workflow. """ return json.dumps(await get_workflow_triggers(owner, repo, workflow_file_path), indent=2) + @mcp.tool() async def check_workflow_reusable( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - workflow_file_path: str = Field(description="The file path of the workflow")) -> str: + workflow_file_path: str = Field(description="The file path of the workflow"), +) -> str: """ Check if a GitHub Actions workflow is reusable. """ - if workflow_file_path.endswith('/action.yml') or workflow_file_path.endswith('/action.yaml'): + if workflow_file_path.endswith("/action.yml") or workflow_file_path.endswith("/action.yaml"): return "This workflow is reusable as an action." triggers = await get_workflow_triggers(owner, repo, workflow_file_path) print(f"Triggers found: {triggers}") @@ -218,15 +224,17 @@ async def check_workflow_reusable( return "This workflow is reusable as a workflow call." elif isinstance(trigger, dict): for k, v in trigger.items(): - if 'workflow_call' == k: + if "workflow_call" == k: return "This workflow is reusable." return "This workflow is not reusable." + @mcp.tool() async def get_high_privileged_workflow_triggers( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - workflow_file_path: str = Field(description="The file path of the workflow")) -> str: + workflow_file_path: str = Field(description="The file path of the workflow"), +) -> str: """ Gets the high privileged triggers for a workflow, if none returns, then the workflow is not high privileged. """ @@ -236,55 +244,55 @@ async def get_high_privileged_workflow_triggers( if isinstance(trigger, str): if trigger in high_privileged_triggers: results.append(trigger) - elif trigger == 'workflow_run': + elif trigger == "workflow_run": results.append(trigger) elif isinstance(trigger, dict): this_results = {} for k, v in trigger.items(): if k in high_privileged_triggers: this_results[k] = v - elif k == 'workflow_run': + elif k == "workflow_run": if not v or isinstance(v, str): this_results[k] = v - elif isinstance(v, dict) and not 'branches' in v: + elif isinstance(v, dict) and not "branches" in v: this_results[k] = v if this_results: results.append(this_results) - return json.dumps(["Workflow is high privileged" if results else "Workflow is not high privileged", results], indent = 2) + return json.dumps( + ["Workflow is high privileged" if results else "Workflow is not high privileged", results], indent=2 + ) + @mcp.tool() async def get_workflow_user( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), workflow_file_path: str = Field(description="The file path of the workflow"), - save_to_db: bool = Field(description="Save the results to database.", default=False)) -> str: + save_to_db: bool = Field(description="Save the results to database.", default=False), +) -> str: """ Get the user of a reusable workflow in repo. """ - paths = workflow_file_path.split('/') - if workflow_file_path.endswith('/action.yml') or workflow_file_path.endswith('/action.yaml'): + paths = workflow_file_path.split("/") + if workflow_file_path.endswith("/action.yml") or workflow_file_path.endswith("/action.yaml"): action_name = paths[-2] else: - action_name = paths[-1].replace('.yml', '').replace('.yaml', '') - paths = await call_api( - url=f"https://api.github.com/repos/{owner}/{repo}/contents/.github/workflows", - params={} - ) + action_name = paths[-1].replace(".yml", "").replace(".yaml", "") + paths = await call_api(url=f"https://api.github.com/repos/{owner}/{repo}/contents/.github/workflows", params={}) if isinstance(paths, str) or not paths.json(): return json.dumps([], indent=2) - paths_list = [item['path'] for item in paths.json() if item['path'].endswith('.yml') or item['path'].endswith('.yaml')] + paths_list = [ + item["path"] for item in paths.json() if item["path"].endswith(".yml") or item["path"].endswith(".yaml") + ] results = [] for path in paths_list: - r = await call_api( - url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", - params={}, raw=True - ) + r = await call_api(url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", params={}, raw=True) if isinstance(r, str): continue data = yaml.safe_load(r.text) - uses = list(find_in_yaml('uses', data)) + uses = list(find_in_yaml("uses", data)) lines = r.text.splitlines() actual_name = {} for use in uses: @@ -297,26 +305,24 @@ async def get_workflow_user( for use, line_numbers in actual_name.items(): if not line_numbers: continue - results.append({ - "user": path, - "lines": line_numbers, - "action_name": workflow_file_path, - "repo": f"{owner}/{repo}" - }) + results.append( + {"user": path, "lines": line_numbers, "action_name": workflow_file_path, "repo": f"{owner}/{repo}"} + ) if not results: return json.dumps([]) if save_to_db: with Session(engine) as session: for result in results: - result['lines'] = json.dumps(result['lines']) # Convert list of lines to JSON string - result['repo'] = result['repo'].lower() + result["lines"] = json.dumps(result["lines"]) # Convert list of lines to JSON string + result["repo"] = result["repo"].lower() workflow_use = WorkflowUses(**result) session.add(workflow_use) session.commit() return f"Search results saved to database." return json.dumps(results) + @mcp.tool() def fetch_last_workflow_users_results() -> str: """ @@ -326,7 +332,18 @@ def fetch_last_workflow_users_results() -> str: results = session.query(WorkflowUses).all() session.query(WorkflowUses).delete() session.commit() - return json.dumps([{"user": result.user, "lines" : json.loads(result.lines), "action": result.action_name, "repo" : result.repo.lower()} for result in results]) + return json.dumps( + [ + { + "user": result.user, + "lines": json.loads(result.lines), + "action": result.action_name, + "repo": result.repo.lower(), + } + for result in results + ] + ) + if __name__ == "__main__": mcp.run(show_banner=False) diff --git a/src/seclab_taskflows/mcp_servers/gh_code_scanning.py b/src/seclab_taskflows/mcp_servers/gh_code_scanning.py index bbdff6f..bfb1300 100644 --- a/src/seclab_taskflows/mcp_servers/gh_code_scanning.py +++ b/src/seclab_taskflows/mcp_servers/gh_code_scanning.py @@ -20,62 +20,71 @@ logging.basicConfig( level=logging.DEBUG, - format='%(asctime)s - %(levelname)s - %(message)s', - filename=log_file_name('mcp_gh_code_scanning.log'), - filemode='a' + format="%(asctime)s - %(levelname)s - %(message)s", + filename=log_file_name("mcp_gh_code_scanning.log"), + filemode="a", ) mcp = FastMCP("GitHubCodeScanning") -GH_TOKEN = os.getenv('GH_TOKEN', default='') +GH_TOKEN = os.getenv("GH_TOKEN", default="") + +CODEQL_DBS_BASE_PATH = mcp_data_dir("seclab-taskflows", "codeql", "CODEQL_DBS_BASE_PATH") +ALERT_RESULTS_DIR = mcp_data_dir("seclab-taskflows", "gh_code_scanning", "ALERT_RESULTS_DIR") -CODEQL_DBS_BASE_PATH = mcp_data_dir('seclab-taskflows', 'codeql', 'CODEQL_DBS_BASE_PATH') -ALERT_RESULTS_DIR = mcp_data_dir('seclab-taskflows', 'gh_code_scanning', 'ALERT_RESULTS_DIR') def parse_alert(alert: dict) -> dict: """Parse the alert dictionary to extract relevant information.""" + def _parse_location(location: dict) -> str: """Parse the location dictionary to extract file and line information.""" if not location: - return 'No location information available' - file_path = location.get('path', '') - start_line = location.get('start_line', '') - end_line = location.get('end_line', '') - start_column = location.get('start_column', '') - end_column = location.get('end_column', '') + return "No location information available" + file_path = location.get("path", "") + start_line = location.get("start_line", "") + end_line = location.get("end_line", "") + start_column = location.get("start_column", "") + end_column = location.get("end_column", "") if not file_path or not start_line or not end_line or not start_column or not end_column: - return 'No location information available' + return "No location information available" return f"{file_path}:{start_line}:{start_column}:{end_line}:{end_column}" + def _get_language(category: str) -> str: - return category.split(':')[1] if category and ':' in category else '' + return category.split(":")[1] if category and ":" in category else "" + def _get_repo_from_html_url(html_url: str) -> str: """Extract the repository name from the HTML URL.""" if not html_url: - return '' - parts = html_url.split('/') + return "" + parts = html_url.split("/") if len(parts) < 5: - return '' + return "" return f"{parts[3]}/{parts[4]}".lower() parsed = { - 'alert_id': alert.get('number', 'No number'), - 'rule': alert.get('rule', {}).get('id', 'No rule'), - 'state': alert.get('state', 'No state'), - 'location': _parse_location(alert.get('most_recent_instance', {}).get('location', 'No location')), - 'language': _get_language(alert.get('most_recent_instance', {}).get('category', 'No language')), - 'created': alert.get('created_at', 'No created'), - 'updated': alert.get('updated_at', 'No updated'), - 'dismissed_comment': alert.get('dismissed_comment', ''), + "alert_id": alert.get("number", "No number"), + "rule": alert.get("rule", {}).get("id", "No rule"), + "state": alert.get("state", "No state"), + "location": _parse_location(alert.get("most_recent_instance", {}).get("location", "No location")), + "language": _get_language(alert.get("most_recent_instance", {}).get("category", "No language")), + "created": alert.get("created_at", "No created"), + "updated": alert.get("updated_at", "No updated"), + "dismissed_comment": alert.get("dismissed_comment", ""), } return parsed + async def call_api(url: str, params: dict) -> str | httpx.Response: """Call the GitHub code scanning API to fetch alert.""" - headers = {"Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28", - "Authorization": f"Bearer {GH_TOKEN}"} + headers = { + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + "Authorization": f"Bearer {GH_TOKEN}", + } + async def _fetch_alerts(url, headers, params): try: - async with httpx.AsyncClient(headers = headers) as client: + async with httpx.AsyncClient(headers=headers) as client: r = await client.get(url, params=params) r.raise_for_status() return r @@ -88,14 +97,16 @@ async def _fetch_alerts(url, headers, params): except httpx.AuthenticationError as e: return f"Authentication error: {e}" - r = await _fetch_alerts(url, headers = headers, params=params) + r = await _fetch_alerts(url, headers=headers, params=params) return r @mcp.tool() -async def get_alert_by_number(owner: str = Field(description="The owner of the repo"), - repo: str = Field(description="The repository name."), - alert_number: int = Field(description="The alert number to get the alert for. Example: 1")) -> str: +async def get_alert_by_number( + owner: str = Field(description="The owner of the repo"), + repo: str = Field(description="The repository name."), + alert_number: int = Field(description="The alert number to get the alert for. Example: 1"), +) -> str: """Get the alert by number for a specific repository.""" url = f"https://api.github.com/repos/{owner}/{repo}/code-scanning/alerts/{alert_number}" resp = await call_api(url, {}) @@ -105,24 +116,25 @@ async def get_alert_by_number(owner: str = Field(description="The owner of the r return json.dumps(parsed_alert) return resp -async def fetch_alerts_from_gh(owner: str, repo: str, state: str = 'open', rule = '') -> str: + +async def fetch_alerts_from_gh(owner: str, repo: str, state: str = "open", rule="") -> str: """Fetch all code scanning alerts for a specific repository.""" url = f"https://api.github.com/repos/{owner}/{repo}/code-scanning/alerts" - if state not in ['open', 'closed', 'dismissed']: - state = 'open' - params = {'state': state, 'per_page': 100} - #see https://github.com/octokit/plugin-paginate-rest.js/blob/8ec2713699ee473ee630be5c8a66b9665bcd4173/src/iterator.ts#L40 + if state not in ["open", "closed", "dismissed"]: + state = "open" + params = {"state": state, "per_page": 100} + # see https://github.com/octokit/plugin-paginate-rest.js/blob/8ec2713699ee473ee630be5c8a66b9665bcd4173/src/iterator.ts#L40 link_pattern = re.compile(r'<([^<>]+)>;\s*rel="next"') results = [] while True: resp = await call_api(url, params) resp_headers = resp.headers - link = resp_headers.get('link', '') + link = resp_headers.get("link", "") resp = resp.json() if isinstance(resp, list): this_results = [parse_alert(alert) for alert in resp] if rule: - this_results = [alert for alert in this_results if alert.get('rule') == rule] + this_results = [alert for alert in this_results if alert.get("rule") == rule] results += this_results else: return resp + " url: " + url @@ -136,63 +148,76 @@ async def fetch_alerts_from_gh(owner: str, repo: str, state: str = 'open', rule return results return "No alerts found." + @mcp.tool() -async def fetch_alerts(owner: str = Field(description="The owner of the repo"), - repo: str = Field(description="The repository name."), - state: str = Field(default='open', description="The state of the alert to filter by. Default is 'open'."), - rule: str = Field(description='The rule of the alert to fetch', default = '')) -> str: +async def fetch_alerts( + owner: str = Field(description="The owner of the repo"), + repo: str = Field(description="The repository name."), + state: str = Field(default="open", description="The state of the alert to filter by. Default is 'open'."), + rule: str = Field(description="The rule of the alert to fetch", default=""), +) -> str: """Fetch all code scanning alerts for a specific repository.""" results = await fetch_alerts_from_gh(owner, repo, state, rule) if isinstance(results, str): return results return json.dumps(results, indent=2) + @mcp.tool() async def fetch_alerts_to_sql( owner: str = Field(description="The owner of the repo"), repo: str = Field(description="The repository name."), - state: str = Field(default='open', description="The state of the alert to filter by. Default is 'open'."), - rule = Field(description='The rule of the alert to fetch', default = ''), - rename_repo: str = Field(description="An optional alternative repo name for storing the alerts, if not specify, repo is used ", default = '') - ) -> str: + state: str = Field(default="open", description="The state of the alert to filter by. Default is 'open'."), + rule=Field(description="The rule of the alert to fetch", default=""), + rename_repo: str = Field( + description="An optional alternative repo name for storing the alerts, if not specify, repo is used ", + default="", + ), +) -> str: """Fetch all code scanning alerts for a specific repository and store them in a SQL database.""" results = await fetch_alerts_from_gh(owner, repo, state, rule) - sql_db_path = f"sqlite:///{ALERT_RESULTS_DIR}/alert_results.db" + sql_db_path = f"sqlite:///{ALERT_RESULTS_DIR}/alert_results.db" if isinstance(results, str) or not results: return results engine = create_engine(sql_db_path, echo=False) Base.metadata.create_all(engine, tables=[AlertResults.__table__, AlertFlowGraph.__table__]) with Session(engine) as session: for alert in results: - session.add(AlertResults( - alert_id=alert.get('alert_id', ''), - repo = rename_repo.lower() if rename_repo else repo.lower(), - language=alert.get('language', ''), - rule=alert.get('rule', ''), - location=alert.get('location', ''), - result='', - created=alert.get('created', ''), - valid=True - )) + session.add( + AlertResults( + alert_id=alert.get("alert_id", ""), + repo=rename_repo.lower() if rename_repo else repo.lower(), + language=alert.get("language", ""), + rule=alert.get("rule", ""), + location=alert.get("location", ""), + result="", + created=alert.get("created", ""), + valid=True, + ) + ) session.commit() return f"Stored {len(results)} alerts in the SQL database at {sql_db_path}." + async def _fetch_codeql_databases(owner: str, repo: str, language: str): """Fetch the CodeQL databases for a given repo and language.""" url = f"https://api.github.com/repos/{owner}/{repo}/code-scanning/codeql/databases/{language}" - headers = {"Accept": "application/zip,application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28", - "Authorization": f"Bearer {os.getenv('GH_TOKEN')}"} + headers = { + "Accept": "application/zip,application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + "Authorization": f"Bearer {os.getenv('GH_TOKEN')}", + } try: async with httpx.AsyncClient() as client: - async with client.stream('GET', url, headers =headers, follow_redirects=True) as response: + async with client.stream("GET", url, headers=headers, follow_redirects=True) as response: response.raise_for_status() expected_path = f"{CODEQL_DBS_BASE_PATH}/{owner}/{repo}.zip" if os.path.realpath(expected_path) != expected_path: return f"Error: Invalid path for CodeQL database: {expected_path}" if not Path(f"{CODEQL_DBS_BASE_PATH}/{owner}").exists(): os.makedirs(f"{CODEQL_DBS_BASE_PATH}/{owner}", exist_ok=True) - async with aiofiles.open(f"{CODEQL_DBS_BASE_PATH}/{owner}/{repo}.zip", 'wb') as f: + async with aiofiles.open(f"{CODEQL_DBS_BASE_PATH}/{owner}/{repo}.zip", "wb") as f: async for chunk in response.aiter_bytes(): await f.write(chunk) # Unzip the downloaded file @@ -200,7 +225,7 @@ async def _fetch_codeql_databases(owner: str, repo: str, language: str): if not zip_path.exists(): return f"Error: CodeQL database for {repo} ({language}) does not exist." - with zipfile.ZipFile(zip_path, 'r') as zip_ref: + with zipfile.ZipFile(zip_path, "r") as zip_ref: zip_ref.extractall(Path(f"{CODEQL_DBS_BASE_PATH}/{owner}/{repo}")) # Remove the zip file after extraction os.remove(zip_path) @@ -209,7 +234,12 @@ async def _fetch_codeql_databases(owner: str, repo: str, language: str): if Path(f"{CODEQL_DBS_BASE_PATH}/{owner}/{repo}/codeql_db").exists(): qldb_subfolder = "codeql_db" - return json.dumps({'message': f"CodeQL database for {repo} ({language}) fetched successfully.", 'relative_database_path': f"{owner}/{repo}/{qldb_subfolder}"}) + return json.dumps( + { + "message": f"CodeQL database for {repo} ({language}) fetched successfully.", + "relative_database_path": f"{owner}/{repo}/{qldb_subfolder}", + } + ) except httpx.RequestError as e: return f"Error: Request error: {e}" except httpx.HTTPStatusError as e: @@ -217,19 +247,23 @@ async def _fetch_codeql_databases(owner: str, repo: str, language: str): except Exception as e: return f"Error: An unexpected error occurred: {e}" + @mcp.tool() -async def fetch_database(owner: str = Field(description="The owner of the repo."), - repo: str = Field(description="The name of the repo."), - language: str = Field(description="The language used for the CodeQL database.")): +async def fetch_database( + owner: str = Field(description="The owner of the repo."), + repo: str = Field(description="The name of the repo."), + language: str = Field(description="The language used for the CodeQL database."), +): """Fetch the CodeQL database for a given repo and language.""" return await _fetch_codeql_databases(owner, repo, language) + @mcp.tool() async def dismiss_alert( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), alert_id: str = Field(description="The ID of the alert to dismiss"), - reason: str = Field(description="The reason for dismissing the alert. It must be less than 280 characters.") + reason: str = Field(description="The reason for dismissing the alert. It must be less than 280 characters."), ) -> str: """ Dismiss a code scanning alert. @@ -238,31 +272,34 @@ async def dismiss_alert( headers = { "Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28", - "Authorization": f"Bearer {GH_TOKEN}" + "Authorization": f"Bearer {GH_TOKEN}", } async with httpx.AsyncClient(headers=headers) as client: - response = await client.patch(url, json={"state": "dismissed", "dismissed_reason": "false positive", "dismissed_comment": reason}) + response = await client.patch( + url, json={"state": "dismissed", "dismissed_reason": "false positive", "dismissed_comment": reason} + ) response.raise_for_status() return response.text + @mcp.tool() async def check_alert_issue_exists( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - alert_id: str = Field(description="The ID of the alert to check for an associated issue") + alert_id: str = Field(description="The ID of the alert to check for an associated issue"), ) -> str: """ Check if an issue exists for a specific alert in a repository. """ url = f"https://api.github.com/repos/{owner}/{repo}/issues" - #see https://github.com/octokit/plugin-paginate-rest.js/blob/8ec2713699ee473ee630be5c8a66b9665bcd4173/src/iterator.ts#L40 + # see https://github.com/octokit/plugin-paginate-rest.js/blob/8ec2713699ee473ee630be5c8a66b9665bcd4173/src/iterator.ts#L40 link_pattern = re.compile(r'<([^<>]+)>;\s*rel="next"') params = {"state": "open", "per_page": 100} while True: resp = await call_api(url, params=params) resp_headers = resp.headers - link = resp_headers.get('link', '') + link = resp_headers.get("link", "") resp = resp.json() if isinstance(resp, list): for issue in resp: @@ -277,12 +314,16 @@ async def check_alert_issue_exists( params = parse_qs(urlparse(url).query) return "No issue found for this alert." + @mcp.tool() async def fetch_issues_matches( - repo: str = Field(description="A comma separated list of repositories to search in. Each term is of the form owner/repo. For example: 'owner1/repo1,owner2/repo2'"), + repo: str = Field( + description="A comma separated list of repositories to search in. Each term is of the form owner/repo. For example: 'owner1/repo1,owner2/repo2'" + ), matches: str = Field(description="The search term to match against issue titles"), - state: str = Field(default='open', description="The state of the issues to filter by. Default is 'open'."), - labels: str = Field(default="", description="Labels to filter issues by")) -> str: + state: str = Field(default="open", description="The state of the issues to filter by. Default is 'open'."), + labels: str = Field(default="", description="Labels to filter issues by"), +) -> str: """ Fetch issues from a repository that match a specific title pattern. """ @@ -298,18 +339,25 @@ async def fetch_issues_matches( } if labels: params["labels"] = labels - #see https://github.com/octokit/plugin-paginate-rest.js/blob/8ec2713699ee473ee630be5c8a66b9665bcd4173/src/iterator.ts#L40 + # see https://github.com/octokit/plugin-paginate-rest.js/blob/8ec2713699ee473ee630be5c8a66b9665bcd4173/src/iterator.ts#L40 link_pattern = re.compile(r'<([^<>]+)>;\s*rel="next"') while True: resp = await call_api(url, params=params) resp_headers = resp.headers - link = resp_headers.get('link', '') + link = resp_headers.get("link", "") resp = resp.json() if isinstance(resp, list): for issue in resp: if matches in issue.get("title", "") or matches in issue.get("body", ""): - results.append({"title": issue["title"], "number": issue["number"], "repo": r, "body": issue.get("body", ""), - "labels": issue.get("labels", [])}) + results.append( + { + "title": issue["title"], + "number": issue["number"], + "repo": r, + "body": issue.get("body", ""), + "labels": issue.get("labels", []), + } + ) else: return resp + " url: " + url m = link_pattern.search(link) diff --git a/src/seclab_taskflows/mcp_servers/gh_file_viewer.py b/src/seclab_taskflows/mcp_servers/gh_file_viewer.py index ea9a40c..30519ca 100644 --- a/src/seclab_taskflows/mcp_servers/gh_file_viewer.py +++ b/src/seclab_taskflows/mcp_servers/gh_file_viewer.py @@ -19,16 +19,18 @@ logging.basicConfig( level=logging.DEBUG, - format='%(asctime)s - %(levelname)s - %(message)s', - filename=log_file_name('mcp_gh_file_viewer.log'), - filemode='a' + format="%(asctime)s - %(levelname)s - %(message)s", + filename=log_file_name("mcp_gh_file_viewer.log"), + filemode="a", ) + class Base(DeclarativeBase): pass + class SearchResults(Base): - __tablename__ = 'search_results' + __tablename__ = "search_results" id: Mapped[int] = mapped_column(primary_key=True) path: Mapped[str] @@ -38,26 +40,33 @@ class SearchResults(Base): repo: Mapped[str] def __repr__(self): - return (f"") + return ( + f"" + ) + mcp = FastMCP("GitHubFileViewer") -GH_TOKEN = os.getenv('GH_TOKEN', default='') +GH_TOKEN = os.getenv("GH_TOKEN", default="") -SEARCH_RESULT_DIR = mcp_data_dir('seclab-taskflows', 'gh_file_viewer', 'SEARCH_RESULTS_DIR') +SEARCH_RESULT_DIR = mcp_data_dir("seclab-taskflows", "gh_file_viewer", "SEARCH_RESULTS_DIR") -engine = create_engine(f'sqlite:///{os.path.abspath(SEARCH_RESULT_DIR)}/search_result.db', echo=False) -Base.metadata.create_all(engine, tables = [SearchResults.__table__]) +engine = create_engine(f"sqlite:///{os.path.abspath(SEARCH_RESULT_DIR)}/search_result.db", echo=False) +Base.metadata.create_all(engine, tables=[SearchResults.__table__]) async def call_api(url: str, params: dict) -> str: """Call the GitHub code scanning API to fetch alert.""" - headers = {"Accept": "application/vnd.github.raw+json", "X-GitHub-Api-Version": "2022-11-28", - "Authorization": f"Bearer {GH_TOKEN}"} + headers = { + "Accept": "application/vnd.github.raw+json", + "X-GitHub-Api-Version": "2022-11-28", + "Authorization": f"Bearer {GH_TOKEN}", + } + async def _fetch_file(url, headers, params): try: - async with httpx.AsyncClient(headers = headers) as client: + async with httpx.AsyncClient(headers=headers) as client: r = await client.get(url, params=params, follow_redirects=True) r.raise_for_status() return r @@ -70,19 +79,24 @@ async def _fetch_file(url, headers, params): except httpx.AuthenticationError as e: return f"Authentication error: {e}" - return await _fetch_file(url, headers = headers, params=params) + return await _fetch_file(url, headers=headers, params=params) + def remove_root_dir(path): - return '/'.join(path.split('/')[1:]) + return "/".join(path.split("/")[1:]) + async def _fetch_source_zip(owner: str, repo: str, tmp_dir): """Fetch the source code.""" url = f"https://api.github.com/repos/{owner}/{repo}/zipball" - headers = {"Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28", - "Authorization": f"Bearer {GH_TOKEN}"} + headers = { + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + "Authorization": f"Bearer {GH_TOKEN}", + } try: async with httpx.AsyncClient() as client: - async with client.stream('GET', url, headers =headers, follow_redirects=True) as response: + async with client.stream("GET", url, headers=headers, follow_redirects=True) as response: response.raise_for_status() expected_path = Path(tmp_dir) / owner / f"{repo}.zip" resolved_path = expected_path.resolve() @@ -90,7 +104,7 @@ async def _fetch_source_zip(owner: str, repo: str, tmp_dir): return f"Error: Invalid path for source code: {expected_path}" if not Path(f"{tmp_dir}/{owner}").exists(): os.makedirs(f"{tmp_dir}/{owner}", exist_ok=True) - async with aiofiles.open(f"{tmp_dir}/{owner}/{repo}.zip", 'wb') as f: + async with aiofiles.open(f"{tmp_dir}/{owner}/{repo}.zip", "wb") as f: async for chunk in response.aiter_bytes(): await f.write(chunk) return f"source code for {repo} fetched successfully." @@ -101,20 +115,21 @@ async def _fetch_source_zip(owner: str, repo: str, tmp_dir): except Exception as e: return f"Error: An unexpected error occurred: {e}" + def search_zipfile(database_path, term): results = {} with zipfile.ZipFile(database_path) as z: for entry in z.infolist(): if entry.is_dir(): continue - with z.open(entry, 'r') as f: + with z.open(entry, "r") as f: for i, line in enumerate(f): if term in str(line): filename = remove_root_dir(entry.filename) if not filename in results: - results[filename] = [i+1] + results[filename] = [i + 1] else: - results[filename].append(i+1) + results[filename].append(i + 1) return results @@ -122,40 +137,36 @@ def search_zipfile(database_path, term): async def fetch_file_from_gh( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - path: str = Field(description="The path to the file in the repository"))-> str: + path: str = Field(description="The path to the file in the repository"), +) -> str: """ Fetch the content of a file from a GitHub repository. """ owner = owner.lower() repo = repo.lower() - r = await call_api( - url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", - params={} - ) + r = await call_api(url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", params={}) if isinstance(r, str): return r lines = r.text.splitlines() for i in range(len(lines)): - lines[i] = f"{i+1}: {lines[i]}" + lines[i] = f"{i + 1}: {lines[i]}" return "\n".join(lines) + @mcp.tool() async def get_file_lines_from_gh( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), path: str = Field(description="The path to the file in the repository"), start_line: int = Field(description="The starting line number to fetch from the file", default=1), - length: int = Field(description="The ending line number to fetch from the file", default=10)) -> str: - """Fetch a range of lines from a file in a GitHub repository. - """ + length: int = Field(description="The ending line number to fetch from the file", default=10), +) -> str: + """Fetch a range of lines from a file in a GitHub repository.""" owner = owner.lower() repo = repo.lower() - r = await call_api( - url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", - params={} - ) + r = await call_api(url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", params={}) if isinstance(r, str): return r lines = r.text.splitlines() @@ -163,61 +174,63 @@ async def get_file_lines_from_gh( start_line = 1 if length < 1: length = 10 - lines = lines[start_line-1:start_line-1+length] + lines = lines[start_line - 1 : start_line - 1 + length] if not lines: return f"No lines found in the range {start_line} to {start_line + length - 1} in {path}." - return "\n".join([f"{i+start_line}: {line}" for i, line in enumerate(lines)]) + return "\n".join([f"{i + start_line}: {line}" for i, line in enumerate(lines)]) + @mcp.tool() async def search_file_from_gh( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), path: str = Field(description="The path to the file in the repository"), - search_term: str = Field(description="The term to search for in the file")) -> str: + search_term: str = Field(description="The term to search for in the file"), +) -> str: """ Search for a term in a file from a GitHub repository. """ owner = owner.lower() repo = repo.lower() - r = await call_api( - url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", - params={} - ) + r = await call_api(url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", params={}) if isinstance(r, str): return r lines = r.text.splitlines() - matches = [f"{i+1}: {line}" for i,line in enumerate(lines) if search_term in line] + matches = [f"{i + 1}: {line}" for i, line in enumerate(lines) if search_term in line] if not matches: return f"No matches found for '{search_term}' in {path}." return "\n".join(matches) + @mcp.tool() async def search_files_from_gh( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), paths: str = Field(description="A comma separated list of paths to the file in the repository"), search_term: str = Field(description="The term to search for in the file"), - save_to_db: bool = Field(description="Save the results to database.", default=False)) -> str: + save_to_db: bool = Field(description="Save the results to database.", default=False), +) -> str: """ Search for a term in a list of files from a GitHub repository. """ owner = owner.lower() repo = repo.lower() - paths_list = [path.strip() for path in paths.split(',')] + paths_list = [path.strip() for path in paths.split(",")] if not paths_list: return "No paths provided for search." results = [] for path in paths_list: - r = await call_api( - url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", - params={} - ) + r = await call_api(url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", params={}) if isinstance(r, str): return r lines = r.text.splitlines() - matches = [{"path": path, "line" : i+1, "search_term": search_term, "owner": owner.lower(), "repo" : repo.lower()} for i,line in enumerate(lines) if search_term in line] + matches = [ + {"path": path, "line": i + 1, "search_term": search_term, "owner": owner.lower(), "repo": repo.lower()} + for i, line in enumerate(lines) + if search_term in line + ] if matches: results.extend(matches) if not results: @@ -231,6 +244,7 @@ async def search_files_from_gh( return f"Search results saved to database." return json.dumps(results) + @mcp.tool() def fetch_last_search_results() -> str: """ @@ -240,43 +254,54 @@ def fetch_last_search_results() -> str: results = session.query(SearchResults).all() session.query(SearchResults).delete() session.commit() - return json.dumps([{"path": result.path, "line" : result.line, "search_term": result.search_term, "owner": result.owner.lower(), "repo" : result.repo.lower()} for result in results]) + return json.dumps( + [ + { + "path": result.path, + "line": result.line, + "search_term": result.search_term, + "owner": result.owner.lower(), + "repo": result.repo.lower(), + } + for result in results + ] + ) + @mcp.tool() async def list_directory_from_gh( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - path: str = Field(description="The path to the directory in the repository")) -> str: + path: str = Field(description="The path to the directory in the repository"), +) -> str: """ Fetch the content of a directory from a GitHub repository. """ owner = owner.lower() repo = repo.lower() - r = await call_api( - url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", - params={} - ) + r = await call_api(url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", params={}) if isinstance(r, str): return r if not r.json(): return json.dumps([], indent=2) - content = [item['path'] for item in r.json()] + content = [item["path"] for item in r.json()] return json.dumps(content, indent=2) + @mcp.tool() async def search_repo_from_gh( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - search_term: str = Field(description="The term to search within the repo.") + search_term: str = Field(description="The term to search within the repo."), ): """ Search for the search term in the entire repository. """ owner = owner.lower() repo = repo.lower() - + with tempfile.TemporaryDirectory() as tmp_dir: result = await _fetch_source_zip(owner, repo, tmp_dir) source_path = Path(f"{tmp_dir}/{owner}/{repo}.zip") @@ -284,9 +309,10 @@ async def search_repo_from_gh( return json.dumps([result], indent=2) results = search_zipfile(source_path, search_term) out = [] - for k,v in results.items(): + for k, v in results.items(): out.append({"owner": owner, "repo": repo, "path": k, "lines": v}) return json.dumps(out, indent=2) + if __name__ == "__main__": mcp.run(show_banner=False) diff --git a/src/seclab_taskflows/mcp_servers/ghsa.py b/src/seclab_taskflows/mcp_servers/ghsa.py index c149df6..4611c71 100644 --- a/src/seclab_taskflows/mcp_servers/ghsa.py +++ b/src/seclab_taskflows/mcp_servers/ghsa.py @@ -10,9 +10,9 @@ logging.basicConfig( level=logging.DEBUG, - format='%(asctime)s - %(levelname)s - %(message)s', - filename=log_file_name('mcp_ghsa.log'), - filemode='a' + format="%(asctime)s - %(levelname)s - %(message)s", + filename=log_file_name("mcp_ghsa.log"), + filemode="a", ) mcp = FastMCP("GitHubRepoAdvisories") @@ -30,10 +30,11 @@ def parse_advisory(advisory: dict) -> dict: "state": advisory.get("state", ""), } + async def fetch_GHSA_list_from_gh(owner: str, repo: str) -> str | list: """Fetch all security advisories for a specific repository.""" url = f"https://api.github.com/repos/{owner}/{repo}/security-advisories" - params = {'per_page': 100} + params = {"per_page": 100} # See https://github.com/octokit/plugin-paginate-rest.js/blob/8ec2713699ee473ee630be5c8a66b9665bcd4173/src/iterator.ts#L40 link_pattern = re.compile(r'<([^<>]+)>;\s*rel="next"') results = [] @@ -42,7 +43,7 @@ async def fetch_GHSA_list_from_gh(owner: str, repo: str) -> str | list: if isinstance(resp, str): return resp resp_headers = resp.headers - link = resp_headers.get('link', '') + link = resp_headers.get("link", "") resp = resp.json() if isinstance(resp, list): results += [parse_advisory(advisory) for advisory in resp] @@ -58,9 +59,11 @@ async def fetch_GHSA_list_from_gh(owner: str, repo: str) -> str | list: return results return "No advisories found." + @mcp.tool() -async def fetch_GHSA_list(owner: str = Field(description="The owner of the repo"), - repo: str = Field(description="The repository name")) -> str: +async def fetch_GHSA_list( + owner: str = Field(description="The owner of the repo"), repo: str = Field(description="The repository name") +) -> str: """Fetch all GitHub Security Advisories (GHSAs) for a specific repository.""" results = await fetch_GHSA_list_from_gh(owner, repo) if isinstance(results, str): @@ -78,15 +81,19 @@ async def fetch_GHSA_details_from_gh(owner: str, repo: str, ghsa_id: str) -> str return resp.json() return "Not found." + @mcp.tool() -async def fetch_GHSA_details(owner: str = Field(description="The owner of the repo"), - repo: str = Field(description="The repository name"), - ghsa_id: str = Field(description="The ghsa_id of the advisory")) -> str: +async def fetch_GHSA_details( + owner: str = Field(description="The owner of the repo"), + repo: str = Field(description="The repository name"), + ghsa_id: str = Field(description="The ghsa_id of the advisory"), +) -> str: """Fetch a GitHub Security Advisory for a specific repository and GHSA ID.""" results = await fetch_GHSA_details_from_gh(owner, repo, ghsa_id) if isinstance(results, str): return results return json.dumps(results, indent=2) + if __name__ == "__main__": mcp.run(show_banner=False) diff --git a/src/seclab_taskflows/mcp_servers/local_file_viewer.py b/src/seclab_taskflows/mcp_servers/local_file_viewer.py index 4dd73bc..a85297b 100644 --- a/src/seclab_taskflows/mcp_servers/local_file_viewer.py +++ b/src/seclab_taskflows/mcp_servers/local_file_viewer.py @@ -15,18 +15,19 @@ logging.basicConfig( level=logging.DEBUG, - format='%(asctime)s - %(levelname)s - %(message)s', - filename=log_file_name('mcp_local_file_viewer.log'), - filemode='a' + format="%(asctime)s - %(levelname)s - %(message)s", + filename=log_file_name("mcp_local_file_viewer.log"), + filemode="a", ) mcp = FastMCP("LocalFileViewer") -LOCAL_GH_DIR = mcp_data_dir('seclab-taskflows', 'local_file_viewer', 'LOCAL_GH_DIR') +LOCAL_GH_DIR = mcp_data_dir("seclab-taskflows", "local_file_viewer", "LOCAL_GH_DIR") -LINE_LIMIT_FOR_FETCHING_FILE_CONTENT = int(os.getenv('LINE_LIMIT_FOR_FETCHING_FILE_CONTENT', default=1000)) +LINE_LIMIT_FOR_FETCHING_FILE_CONTENT = int(os.getenv("LINE_LIMIT_FOR_FETCHING_FILE_CONTENT", default=1000)) + +FILE_LIMIT_FOR_LIST_FILES = int(os.getenv("FILE_LIMIT_FOR_LIST_FILES", default=100)) -FILE_LIMIT_FOR_LIST_FILES = int(os.getenv('FILE_LIMIT_FOR_LIST_FILES', default=100)) def is_subdirectory(directory, potential_subdirectory): directory_path = Path(directory) @@ -37,6 +38,7 @@ def is_subdirectory(directory, potential_subdirectory): except ValueError: return False + def sanitize_file_path(file_path, allow_paths): file_path = os.path.realpath(file_path) for allowed_path in allow_paths: @@ -44,15 +46,18 @@ def sanitize_file_path(file_path, allow_paths): return Path(file_path) return None + def remove_root_dir(path): - return '/'.join(path.split('/')[1:]) + return "/".join(path.split("/")[1:]) + def strip_leading_dash(path): - if path and path[0] == '/': + if path and path[0] == "/": path = path[1:] return path -def search_zipfile(database_path, term, search_dir = None): + +def search_zipfile(database_path, term, search_dir=None): results = {} search_dir = strip_leading_dash(search_dir) with zipfile.ZipFile(database_path) as z: @@ -61,17 +66,18 @@ def search_zipfile(database_path, term, search_dir = None): continue if search_dir and not is_subdirectory(search_dir, remove_root_dir(entry.filename)): continue - with z.open(entry, 'r') as f: + with z.open(entry, "r") as f: for i, line in enumerate(f): if term in str(line): filename = remove_root_dir(entry.filename) if not filename in results: - results[filename] = [i+1] + results[filename] = [i + 1] else: - results[filename].append(i+1) + results[filename].append(i + 1) return results -def _list_files(database_path, root_dir = None, recursive=True): + +def _list_files(database_path, root_dir=None, recursive=True): results = [] root_dir = strip_leading_dash(root_dir) with zipfile.ZipFile(database_path) as z: @@ -80,7 +86,7 @@ def _list_files(database_path, root_dir = None, recursive=True): if not recursive: dirname = remove_root_dir(entry.filename) if Path(dirname).parent == Path(root_dir): - results.append(dirname + '/') + results.append(dirname + "/") continue filename = remove_root_dir(entry.filename) if root_dir and not is_subdirectory(root_dir, filename): @@ -90,6 +96,7 @@ def _list_files(database_path, root_dir = None, recursive=True): results.append(filename) return results + def get_file(database_path, filename): results = [] filename = strip_leading_dash(filename) @@ -98,16 +105,18 @@ def get_file(database_path, filename): if entry.is_dir(): continue if remove_root_dir(entry.filename) == filename: - with z.open(entry, 'r') as f: + with z.open(entry, "r") as f: results = [line.rstrip() for line in f] return results return results + @mcp.tool() async def fetch_file_content( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - path: str = Field(description="The path to the file in the repository"))-> str: + path: str = Field(description="The path to the file in the repository"), +) -> str: """ Fetch the content of a file from a local GitHub repository. """ @@ -124,18 +133,19 @@ async def fetch_file_content( if not lines: return f"Unable to find file {path} in {owner}/{repo}" for i in range(len(lines)): - lines[i] = f"{i+1}: {lines[i]}" + lines[i] = f"{i + 1}: {lines[i]}" return "\n".join(lines) + @mcp.tool() async def get_file_lines( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), path: str = Field(description="The path to the file in the repository"), start_line: int = Field(description="The starting line number to fetch from the file", default=1), - length: int = Field(description="The ending line number to fetch from the file", default=10)) -> str: - """Fetch a range of lines from a file in a local GitHub repository. - """ + length: int = Field(description="The ending line number to fetch from the file", default=10), +) -> str: + """Fetch a range of lines from a file in a local GitHub repository.""" owner = owner.lower() repo = repo.lower() @@ -148,16 +158,18 @@ async def get_file_lines( start_line = 1 if length < 1: length = 10 - lines = lines[start_line-1:start_line-1+length] + lines = lines[start_line - 1 : start_line - 1 + length] if not lines: return f"No lines found in the range {start_line} to {start_line + length - 1} in {path}." - return "\n".join([f"{i+start_line}: {line}" for i, line in enumerate(lines)]) + return "\n".join([f"{i + start_line}: {line}" for i, line in enumerate(lines)]) + @mcp.tool() async def list_files( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - path: str = Field(description="The path to the directory in the repository")) -> str: + path: str = Field(description="The path to the directory in the repository"), +) -> str: """ Recursively list the files of a directory from a local GitHub repository. """ @@ -173,11 +185,13 @@ async def list_files( return f"Too many files to display in {owner}/{repo} at path {path} ({len(content)} files). Try using `list_files_non_recursive` instead." return json.dumps(content, indent=2) + @mcp.tool() async def list_files_non_recursive( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - path: str = Field(description="The path to the directory in the repository")) -> str: + path: str = Field(description="The path to the directory in the repository"), +) -> str: """ List the files of a directory from a local GitHub repository non-recursively. Subdirectories will be listed and indicated with a trailing slash. @@ -198,7 +212,10 @@ async def search_repo( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), search_term: str = Field(description="The term to search within the repo."), - directory: str = Field(description="The directory or file to restrict the search, if not provided, the whole repo is searched", default = '') + directory: str = Field( + description="The directory or file to restrict the search, if not provided, the whole repo is searched", + default="", + ), ): """ Search for the search term in the repository or a subdirectory/file in the repository. @@ -214,9 +231,10 @@ async def search_repo( return json.dumps([], indent=2) results = search_zipfile(source_path, search_term, directory) out = [] - for k,v in results.items(): + for k, v in results.items(): out.append({"owner": owner, "repo": repo, "path": k, "lines": v}) return json.dumps(out, indent=2) + if __name__ == "__main__": mcp.run(show_banner=False) diff --git a/src/seclab_taskflows/mcp_servers/local_gh_resources.py b/src/seclab_taskflows/mcp_servers/local_gh_resources.py index 3c48ad6..33fc641 100644 --- a/src/seclab_taskflows/mcp_servers/local_gh_resources.py +++ b/src/seclab_taskflows/mcp_servers/local_gh_resources.py @@ -15,16 +15,17 @@ logging.basicConfig( level=logging.DEBUG, - format='%(asctime)s - %(levelname)s - %(message)s', - filename=log_file_name('mcp_local_gh_resources.log'), - filemode='a' + format="%(asctime)s - %(levelname)s - %(message)s", + filename=log_file_name("mcp_local_gh_resources.log"), + filemode="a", ) mcp = FastMCP("LocalGHResources") -GH_TOKEN = os.getenv('GH_TOKEN') +GH_TOKEN = os.getenv("GH_TOKEN") + +LOCAL_GH_DIR = mcp_data_dir("seclab-taskflows", "local_gh_resources", "LOCAL_GH_DIR") -LOCAL_GH_DIR = mcp_data_dir('seclab-taskflows', 'local_gh_resources', 'LOCAL_GH_DIR') def is_subdirectory(directory, potential_subdirectory): directory_path = Path(directory) @@ -35,6 +36,7 @@ def is_subdirectory(directory, potential_subdirectory): except ValueError: return False + def sanitize_file_path(file_path, allow_paths): file_path = os.path.realpath(file_path) for allowed_path in allow_paths: @@ -42,13 +44,18 @@ def sanitize_file_path(file_path, allow_paths): return Path(file_path) return None + async def call_api(url: str, params: dict) -> str: """Call the GitHub code scanning API to fetch alert.""" - headers = {"Accept": "application/vnd.github.raw+json", "X-GitHub-Api-Version": "2022-11-28", - "Authorization": f"Bearer {GH_TOKEN}"} + headers = { + "Accept": "application/vnd.github.raw+json", + "X-GitHub-Api-Version": "2022-11-28", + "Authorization": f"Bearer {GH_TOKEN}", + } + async def _fetch_file(url, headers, params): try: - async with httpx.AsyncClient(headers = headers) as client: + async with httpx.AsyncClient(headers=headers) as client: r = await client.get(url, params=params, follow_redirects=True) r.raise_for_status() return r @@ -61,16 +68,20 @@ async def _fetch_file(url, headers, params): except httpx.AuthenticationError as e: return f"Authentication error: {e}" - return await _fetch_file(url, headers = headers, params=params) + return await _fetch_file(url, headers=headers, params=params) + async def _fetch_source_zip(owner: str, repo: str, tmp_dir): """Fetch the source code.""" url = f"https://api.github.com/repos/{owner}/{repo}/zipball" - headers = {"Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28", - "Authorization": f"Bearer {GH_TOKEN}"} + headers = { + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + "Authorization": f"Bearer {GH_TOKEN}", + } try: async with httpx.AsyncClient() as client: - async with client.stream('GET', url, headers =headers, follow_redirects=True) as response: + async with client.stream("GET", url, headers=headers, follow_redirects=True) as response: response.raise_for_status() expected_path = Path(tmp_dir) / owner / f"{repo}.zip" resolved_path = expected_path.resolve() @@ -78,7 +89,7 @@ async def _fetch_source_zip(owner: str, repo: str, tmp_dir): return f"Error: Invalid path for source code: {expected_path}" if not Path(f"{tmp_dir}/{owner}").exists(): os.makedirs(f"{tmp_dir}/{owner}", exist_ok=True) - async with aiofiles.open(f"{tmp_dir}/{owner}/{repo}.zip", 'wb') as f: + async with aiofiles.open(f"{tmp_dir}/{owner}/{repo}.zip", "wb") as f: async for chunk in response.aiter_bytes(): await f.write(chunk) return f"source code for {repo} fetched successfully." @@ -88,10 +99,10 @@ async def _fetch_source_zip(owner: str, repo: str, tmp_dir): return f"Error: HTTP error: {e}" except Exception as e: return f"Error: An unexpected error occurred: {e}" + + @mcp.tool() -async def fetch_repo_from_gh( - owner: str, repo: str -): +async def fetch_repo_from_gh(owner: str, repo: str): """ Download the source code from GitHub to the local file system to speed up file search. """ @@ -104,6 +115,7 @@ async def fetch_repo_from_gh( return result return f"Downloaded source code to {owner}/{repo}.zip" + @mcp.tool() async def clear_local_repo(owner: str, repo: str): """ diff --git a/src/seclab_taskflows/mcp_servers/repo_context.py b/src/seclab_taskflows/mcp_servers/repo_context.py index 3c71bfc..e69eaae 100644 --- a/src/seclab_taskflows/mcp_servers/repo_context.py +++ b/src/seclab_taskflows/mcp_servers/repo_context.py @@ -19,12 +19,13 @@ logging.basicConfig( level=logging.DEBUG, - format='%(asctime)s - %(levelname)s - %(message)s', - filename=log_file_name('mcp_repo_context.log'), - filemode='a' + format="%(asctime)s - %(levelname)s - %(message)s", + filename=log_file_name("mcp_repo_context.log"), + filemode="a", ) -MEMORY = mcp_data_dir('seclab-taskflows', 'repo_context', 'REPO_CONTEXT_DIR') +MEMORY = mcp_data_dir("seclab-taskflows", "repo_context", "REPO_CONTEXT_DIR") + def app_to_dict(result): return { @@ -33,9 +34,10 @@ def app_to_dict(result): "location": result.location, "notes": result.notes, "is_app": result.is_app, - "is_library": result.is_library + "is_library": result.is_library, } + def entry_point_to_dict(ep): return { "id": ep.id, @@ -44,9 +46,10 @@ def entry_point_to_dict(ep): "user_input": ep.user_input, "repo": ep.repo.lower(), "line": ep.line, - "notes": ep.notes + "notes": ep.notes, } + def user_action_to_dict(ua): return { "id": ua.id, @@ -54,9 +57,10 @@ def user_action_to_dict(ua): "file": ua.file, "line": ua.line, "repo": ua.repo.lower(), - "notes": ua.notes + "notes": ua.notes, } + def web_entry_point_to_dict(wep): return { "id": wep.id, @@ -68,36 +72,47 @@ def web_entry_point_to_dict(wep): "middleware": wep.middleware, "roles_scopes": wep.roles_scopes, "repo": wep.repo.lower(), - "notes": wep.notes + "notes": wep.notes, } + def audit_result_to_dict(res): return { - "id" : res.id, - "repo" : res.repo.lower(), - "component_id" : res.component_id, - "issue_type" : res.issue_type, - "issue_id" : res.issue_id, - "notes" : res.notes, + "id": res.id, + "repo": res.repo.lower(), + "component_id": res.component_id, + "issue_type": res.issue_type, + "issue_id": res.issue_id, + "notes": res.notes, "has_vulnerability": res.has_vulnerability, - "has_non_security_error": res.has_non_security_error + "has_non_security_error": res.has_non_security_error, } + class RepoContextBackend: def __init__(self, memcache_state_dir: str): self.memcache_state_dir = memcache_state_dir - self.location_pattern = r'^([a-zA-Z]+)(:\d+){4}$' + self.location_pattern = r"^([a-zA-Z]+)(:\d+){4}$" if not Path(self.memcache_state_dir).exists(): - db_dir = 'sqlite://' + db_dir = "sqlite://" else: - db_dir = f'sqlite:///{self.memcache_state_dir}/repo_context.db' + db_dir = f"sqlite:///{self.memcache_state_dir}/repo_context.db" self.engine = create_engine(db_dir, echo=False) - Base.metadata.create_all(self.engine, tables=[Application.__table__, EntryPoint.__table__, UserAction.__table__, - WebEntryPoint.__table__, ApplicationIssue.__table__, AuditResult.__table__]) + Base.metadata.create_all( + self.engine, + tables=[ + Application.__table__, + EntryPoint.__table__, + UserAction.__table__, + WebEntryPoint.__table__, + ApplicationIssue.__table__, + AuditResult.__table__, + ], + ) def store_new_application(self, repo, location, is_app, is_library, notes): with Session(self.engine) as session: - existing = session.query(Application).filter_by(repo = repo, location = location).first() + existing = session.query(Application).filter_by(repo=repo, location=location).first() if existing: if is_app is not None: existing.is_app = is_app @@ -105,25 +120,31 @@ def store_new_application(self, repo, location, is_app, is_library, notes): existing.is_library = is_library existing.notes += notes else: - new_application = Application(repo = repo, location = location, is_app = is_app, is_library = is_library, notes = notes) + new_application = Application( + repo=repo, location=location, is_app=is_app, is_library=is_library, notes=notes + ) session.add(new_application) session.commit() return f"Updated or added application for {location} in {repo}." def store_new_component_issue(self, repo, component_id, issue_type, notes): with Session(self.engine) as session: - existing = session.query(ApplicationIssue).filter_by(repo = repo, component_id = component_id, issue_type = issue_type).first() + existing = ( + session.query(ApplicationIssue) + .filter_by(repo=repo, component_id=component_id, issue_type=issue_type) + .first() + ) if existing: existing.notes += notes else: - new_issue = ApplicationIssue(repo = repo, component_id = component_id, issue_type = issue_type, notes = notes) + new_issue = ApplicationIssue(repo=repo, component_id=component_id, issue_type=issue_type, notes=notes) session.add(new_issue) session.commit() return f"Updated or added application issue for {repo} and {component_id}" def overwrite_component_issue_notes(self, id, notes): with Session(self.engine) as session: - existing = session.query(ApplicationIssue).filter_by(id = id).first() + existing = session.query(ApplicationIssue).filter_by(id=id).first() if not existing: return f"Component issue with id {id} does not exist!" else: @@ -131,36 +152,49 @@ def overwrite_component_issue_notes(self, id, notes): session.commit() return f"Updated notes for application issue with id {id}" - def store_new_audit_result(self, repo, component_id, issue_type, issue_id, has_non_security_error, has_vulnerability, notes): + def store_new_audit_result( + self, repo, component_id, issue_type, issue_id, has_non_security_error, has_vulnerability, notes + ): with Session(self.engine) as session: - existing = session.query(AuditResult).filter_by(repo = repo, issue_id = issue_id).first() + existing = session.query(AuditResult).filter_by(repo=repo, issue_id=issue_id).first() if existing: existing.notes += notes existing.has_non_security_error = has_non_security_error existing.has_vulnerability = has_vulnerability else: - new_result = AuditResult(repo = repo, component_id = component_id, issue_type = issue_type, issue_id = issue_id, notes = notes, - has_non_security_error = has_non_security_error, has_vulnerability = has_vulnerability) + new_result = AuditResult( + repo=repo, + component_id=component_id, + issue_type=issue_type, + issue_id=issue_id, + notes=notes, + has_non_security_error=has_non_security_error, + has_vulnerability=has_vulnerability, + ) session.add(new_result) session.commit() return f"Updated or added audit result for {repo} and {issue_id}" - def store_new_entry_point(self, repo, app_id, file, user_input, line, notes, update = False): + def store_new_entry_point(self, repo, app_id, file, user_input, line, notes, update=False): with Session(self.engine) as session: - existing = session.query(EntryPoint).filter_by(repo = repo, file = file, line = line).first() + existing = session.query(EntryPoint).filter_by(repo=repo, file=file, line=line).first() if existing: existing.notes += notes else: if update: return f"No entry point exists at repo {repo}, file {file} and line {line}" - new_entry_point = EntryPoint(repo = repo, app_id = app_id, file = file, user_input = user_input, line = line, notes = notes) + new_entry_point = EntryPoint( + repo=repo, app_id=app_id, file=file, user_input=user_input, line=line, notes=notes + ) session.add(new_entry_point) session.commit() return f"Updated or added entry point for {file} and {line} in {repo}." - def store_new_web_entry_point(self, repo, entry_point_id, method, path, component, auth, middleware, roles_scopes, notes, update = False): + def store_new_web_entry_point( + self, repo, entry_point_id, method, path, component, auth, middleware, roles_scopes, notes, update=False + ): with Session(self.engine) as session: - existing = session.query(WebEntryPoint).filter_by(repo = repo, entry_point_id = entry_point_id).first() + existing = session.query(WebEntryPoint).filter_by(repo=repo, entry_point_id=entry_point_id).first() if existing: existing.notes += notes if method: @@ -179,163 +213,188 @@ def store_new_web_entry_point(self, repo, entry_point_id, method, path, componen if update: return f"No web entry point exists at repo {repo} with entry_point_id {entry_point_id}." new_web_entry_point = WebEntryPoint( - repo = repo, - entry_point_id = entry_point_id, - method = method, - path = path, - component = component, - auth = auth, - middleware = middleware, - roles_scopes = roles_scopes, - notes = notes + repo=repo, + entry_point_id=entry_point_id, + method=method, + path=path, + component=component, + auth=auth, + middleware=middleware, + roles_scopes=roles_scopes, + notes=notes, ) session.add(new_web_entry_point) session.commit() return f"Updated or added web entry point for entry_point_id {entry_point_id} in {repo}." - def store_new_user_action(self, repo, app_id, file, line, notes, update = False): + def store_new_user_action(self, repo, app_id, file, line, notes, update=False): with Session(self.engine) as session: - existing = session.query(UserAction).filter_by(repo = repo, file = file, line = line).first() + existing = session.query(UserAction).filter_by(repo=repo, file=file, line=line).first() if existing: existing.notes += notes else: if update: return f"No user action exists at repo {repo}, file {file} and line {line}." - new_user_action = UserAction(repo = repo, app_id = app_id, file = file, line = line, notes = notes) + new_user_action = UserAction(repo=repo, app_id=app_id, file=file, line=line, notes=notes) session.add(new_user_action) session.commit() return f"Updated or added user action for {file} and {line} in {repo}." def get_app(self, repo, location): with Session(self.engine) as session: - existing = session.query(Application).filter_by(repo = repo, location = location).first() + existing = session.query(Application).filter_by(repo=repo, location=location).first() if not existing: return None return existing def get_apps(self, repo): with Session(self.engine) as session: - existing = session.query(Application).filter_by(repo = repo).all() + existing = session.query(Application).filter_by(repo=repo).all() return [app_to_dict(app) for app in existing] def get_app_issues(self, repo, component_id): with Session(self.engine) as session: issues = session.query(Application, ApplicationIssue).filter( - Application.repo == repo, - Application.id == ApplicationIssue.component_id + Application.repo == repo, Application.id == ApplicationIssue.component_id ) if component_id is not None: issues = issues.filter(Application.id == component_id) issues = issues.all() - return [{ - 'component_id': app.id, - 'location' : app.location, - 'repo' : app.repo, - 'component_notes' : app.notes, - 'issue_type' : issue.issue_type, - 'issue_notes': issue.notes, - 'issue_id' : issue.id - } for app, issue in issues] + return [ + { + "component_id": app.id, + "location": app.location, + "repo": app.repo, + "component_notes": app.notes, + "issue_type": issue.issue_type, + "issue_notes": issue.notes, + "issue_id": issue.id, + } + for app, issue in issues + ] def get_app_audit_results(self, repo, component_id, has_non_security_error, has_vulnerability): with Session(self.engine) as session: - issues = session.query(Application, AuditResult).filter(Application.repo == repo - ).filter(Application.id == AuditResult.component_id) + issues = ( + session.query(Application, AuditResult) + .filter(Application.repo == repo) + .filter(Application.id == AuditResult.component_id) + ) if component_id is not None: - issues = issues.filter(Application.id == component_id) + issues = issues.filter(Application.id == component_id) if has_non_security_error is not None: issues = issues.filter(AuditResult.has_non_security_error == has_non_security_error) if has_vulnerability is not None: issues = issues.filter(AuditResult.has_vulnerability == has_vulnerability) issues = issues.all() - return [{ - 'component_id': app.id, - 'location' : app.location, - 'repo' : app.repo, - 'issue_type' : issue.issue_type, - 'issue_id' : issue.issue_id, - 'notes': issue.notes, - 'has_vulnerability' : issue.has_vulnerability, - 'has_non_security_error' : issue.has_non_security_error - } for app, issue in issues] + return [ + { + "component_id": app.id, + "location": app.location, + "repo": app.repo, + "issue_type": issue.issue_type, + "issue_id": issue.issue_id, + "notes": issue.notes, + "has_vulnerability": issue.has_vulnerability, + "has_non_security_error": issue.has_non_security_error, + } + for app, issue in issues + ] def get_app_entries(self, repo, location): with Session(self.engine) as session: - results = session.query(Application, EntryPoint - ).filter(Application.repo == repo, Application.location == location - ).filter(EntryPoint.app_id == Application.id).all() + results = ( + session.query(Application, EntryPoint) + .filter(Application.repo == repo, Application.location == location) + .filter(EntryPoint.app_id == Application.id) + .all() + ) eps = [entry_point_to_dict(ep) for app, ep in results] return eps def get_app_entries_for_repo(self, repo): with Session(self.engine) as session: - results = session.query(Application, EntryPoint - ).filter(Application.repo == repo - ).filter(EntryPoint.app_id == Application.id).all() + results = ( + session.query(Application, EntryPoint) + .filter(Application.repo == repo) + .filter(EntryPoint.app_id == Application.id) + .all() + ) eps = [entry_point_to_dict(ep) for app, ep in results] return eps def get_web_entries_for_repo(self, repo): with Session(self.engine) as session: - results = session.query(WebEntryPoint).filter_by(repo = repo).all() - return [{ - 'repo' : r.repo, - 'entry_point_id' : r.entry_point_id, - 'method' : r.method, - 'path' : r.path, - 'component' : r.component, - 'auth' : r.auth, - 'middleware' : r.middleware, - 'roles_scopes' : r.roles_scopes, - 'notes' : r.notes - } for r in results] + results = session.query(WebEntryPoint).filter_by(repo=repo).all() + return [ + { + "repo": r.repo, + "entry_point_id": r.entry_point_id, + "method": r.method, + "path": r.path, + "component": r.component, + "auth": r.auth, + "middleware": r.middleware, + "roles_scopes": r.roles_scopes, + "notes": r.notes, + } + for r in results + ] def get_web_entries(self, repo, component_id): with Session(self.engine) as session: - results = session.query(WebEntryPoint).filter_by(repo = repo, component = component_id).all() - return [{ - 'repo' : r.repo, - 'entry_point_id' : r.entry_point_id, - 'method' : r.method, - 'path' : r.path, - 'component' : r.component, - 'auth' : r.auth, - 'middleware' : r.middleware, - 'roles_scopes' : r.roles_scopes, - 'notes' : r.notes - } for r in results] - + results = session.query(WebEntryPoint).filter_by(repo=repo, component=component_id).all() + return [ + { + "repo": r.repo, + "entry_point_id": r.entry_point_id, + "method": r.method, + "path": r.path, + "component": r.component, + "auth": r.auth, + "middleware": r.middleware, + "roles_scopes": r.roles_scopes, + "notes": r.notes, + } + for r in results + ] def get_user_actions(self, repo, location): with Session(self.engine) as session: - results = session.query(Application, UserAction - ).filter(Application.repo == repo, Application.location == location - ).filter(UserAction.app_id == Application.id).all() + results = ( + session.query(Application, UserAction) + .filter(Application.repo == repo, Application.location == location) + .filter(UserAction.app_id == Application.id) + .all() + ) uas = [user_action_to_dict(ua) for app, ua in results] return uas def get_user_actions_for_repo(self, repo): with Session(self.engine) as session: - results = session.query(Application, UserAction - ).filter(Application.repo == repo - ).filter(UserAction.app_id == Application.id).all() + results = ( + session.query(Application, UserAction) + .filter(Application.repo == repo) + .filter(UserAction.app_id == Application.id) + .all() + ) uas = [user_action_to_dict(ua) for app, ua in results] return uas def clear_repo(self, repo): with Session(self.engine) as session: - session.query(Application).filter_by(repo = repo).delete() - session.query(EntryPoint).filter_by(repo = repo).delete() - session.query(UserAction).filter_by(repo = repo).delete() - session.query(ApplicationIssue).filter_by(repo = repo).delete() - session.query(WebEntryPoint).filter_by(repo = repo).delete() - session.query(AuditResult).filter_by(repo = repo).delete() + session.query(Application).filter_by(repo=repo).delete() + session.query(EntryPoint).filter_by(repo=repo).delete() + session.query(UserAction).filter_by(repo=repo).delete() + session.query(ApplicationIssue).filter_by(repo=repo).delete() + session.query(WebEntryPoint).filter_by(repo=repo).delete() + session.query(AuditResult).filter_by(repo=repo).delete() session.commit() return f"Cleared results for repo {repo}" def clear_repo_issues(self, repo): with Session(self.engine) as session: - session.query(ApplicationIssue).filter_by(repo = repo).delete() + session.query(ApplicationIssue).filter_by(repo=repo).delete() session.commit() return f"Clear application issues for repo {repo}" @@ -344,23 +403,29 @@ def clear_repo_issues(self, repo): backend = RepoContextBackend(MEMORY) + @mcp.tool() -def store_new_component(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - location: str = Field(description="The directory of the component"), - is_app: bool = Field(description="Is this an application", default=None), - is_library: bool = Field(description="Is this a library", default=None), - notes: str = Field(description="The notes taken for this component", default="")): +def store_new_component( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + location: str = Field(description="The directory of the component"), + is_app: bool = Field(description="Is this an application", default=None), + is_library: bool = Field(description="Is this a library", default=None), + notes: str = Field(description="The notes taken for this component", default=""), +): """ Stores a new component in the database. """ return backend.store_new_application(process_repo(owner, repo), location, is_app, is_library, notes) + @mcp.tool() -def add_component_notes(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - location: str = Field(description="The directory of the component", default=None), - notes: str = Field(description="New notes taken for this component", default="")): +def add_component_notes( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + location: str = Field(description="The directory of the component", default=None), + notes: str = Field(description="New notes taken for this component", default=""), +): """ Add new notes to a component """ @@ -370,14 +435,17 @@ def add_component_notes(owner: str = Field(description="The owner of the GitHub return f"Error: No component exists in repo: {repo} and location {location}" return backend.store_new_application(repo, location, None, None, notes) + @mcp.tool() -def store_new_entry_point(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - location: str = Field(description="The directory of the component where the entry point belongs to"), - file: str = Field(description="The file that contains the entry point"), - line: int = Field(description="The file line that contains the entry point"), - user_input: str = Field(description="The variables that are considered as user input"), - notes: str = Field(description="The notes for this entry point", default = "")): +def store_new_entry_point( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + location: str = Field(description="The directory of the component where the entry point belongs to"), + file: str = Field(description="The file that contains the entry point"), + line: int = Field(description="The file line that contains the entry point"), + user_input: str = Field(description="The variables that are considered as user input"), + notes: str = Field(description="The notes for this entry point", default=""), +): """ Stores a new entry point in a component to the database. """ @@ -387,58 +455,76 @@ def store_new_entry_point(owner: str = Field(description="The owner of the GitHu return f"Error: No component exists in repo: {repo} and location {location}" return backend.store_new_entry_point(repo, app.id, file, user_input, line, notes) + @mcp.tool() -def store_new_component_issue(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - component_id: int = Field(description="The ID of the component"), - issue_type: str = Field(description="The type of issue"), - notes: str = Field(description="Notes about the issue")): +def store_new_component_issue( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + component_id: int = Field(description="The ID of the component"), + issue_type: str = Field(description="The type of issue"), + notes: str = Field(description="Notes about the issue"), +): """ Stores a type of common issue for a component. """ repo = process_repo(owner, repo) return backend.store_new_component_issue(repo, component_id, issue_type, notes) + @mcp.tool() -def store_new_audit_result(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - component_id: int = Field(description="The ID of the component"), - issue_type: str = Field(description="The type of issue"), - issue_id: int = Field(description="The ID of the issue"), - has_non_security_error: bool = Field(description="Set to true if there are security issues or logic error but may not be exploitable"), - has_vulnerability: bool = Field(description="Set to true if a security vulnerability is identified"), - notes: str = Field(description="The notes for the audit of this issue")): +def store_new_audit_result( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + component_id: int = Field(description="The ID of the component"), + issue_type: str = Field(description="The type of issue"), + issue_id: int = Field(description="The ID of the issue"), + has_non_security_error: bool = Field( + description="Set to true if there are security issues or logic error but may not be exploitable" + ), + has_vulnerability: bool = Field(description="Set to true if a security vulnerability is identified"), + notes: str = Field(description="The notes for the audit of this issue"), +): """ Stores the audit result for issue with issue_id. """ repo = process_repo(owner, repo) - return backend.store_new_audit_result(repo, component_id, issue_type, issue_id, has_non_security_error, has_vulnerability, notes) + return backend.store_new_audit_result( + repo, component_id, issue_type, issue_id, has_non_security_error, has_vulnerability, notes + ) + @mcp.tool() -def store_new_web_entry_point(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - entry_point_id: int = Field(description="The ID of the entry point this web entry point refers to"), - location: str = Field(description="The directory of the component where the web entry point belongs to"), - method: str = Field(description="HTTP method (GET, POST, etc)", default=""), - path: str = Field(description="URL path (e.g., /info)", default=""), - component: int = Field(description="Component identifier", default=0), - auth: str = Field(description="Authentication information", default=""), - middleware: str = Field(description="Middleware information", default=""), - roles_scopes: str = Field(description="Roles and scopes information", default=""), - notes: str = Field(description="Notes for this web entry point", default="")): +def store_new_web_entry_point( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + entry_point_id: int = Field(description="The ID of the entry point this web entry point refers to"), + location: str = Field(description="The directory of the component where the web entry point belongs to"), + method: str = Field(description="HTTP method (GET, POST, etc)", default=""), + path: str = Field(description="URL path (e.g., /info)", default=""), + component: int = Field(description="Component identifier", default=0), + auth: str = Field(description="Authentication information", default=""), + middleware: str = Field(description="Middleware information", default=""), + roles_scopes: str = Field(description="Roles and scopes information", default=""), + notes: str = Field(description="Notes for this web entry point", default=""), +): """ Stores a new web entry point in a component to the database. A web entry point extends a regular entry point with web-specific properties like HTTP method, path, authentication, middleware, and roles/scopes. """ - return backend.store_new_web_entry_point(process_repo(owner, repo), entry_point_id, method, path, component, auth, middleware, roles_scopes, notes) + return backend.store_new_web_entry_point( + process_repo(owner, repo), entry_point_id, method, path, component, auth, middleware, roles_scopes, notes + ) + @mcp.tool() -def add_entry_point_notes(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - location: str = Field(description="The directory of the component where the entry point belongs to"), - file: str = Field(description="The file that contains the entry point"), - line: int = Field(description="The file line that contains the entry point"), - notes: str = Field(description="The notes for this entry point", default = "")): +def add_entry_point_notes( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + location: str = Field(description="The directory of the component where the entry point belongs to"), + file: str = Field(description="The file that contains the entry point"), + line: int = Field(description="The file line that contains the entry point"), + notes: str = Field(description="The notes for this entry point", default=""), +): """ add new notes to an entry point. """ @@ -450,12 +536,14 @@ def add_entry_point_notes(owner: str = Field(description="The owner of the GitHu @mcp.tool() -def store_new_user_action(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - location: str = Field(description="The directory of the component where the user action belongs to"), - file: str = Field(description="The file that contains the user action"), - line: int = Field(description="The file line that contains the user action"), - notes: str = Field(description="New notes for this user action", default = "")): +def store_new_user_action( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + location: str = Field(description="The directory of the component where the user action belongs to"), + file: str = Field(description="The file that contains the user action"), + line: int = Field(description="The file line that contains the user action"), + notes: str = Field(description="New notes for this user action", default=""), +): """ Stores a new user action in a component to the database. """ @@ -465,23 +553,29 @@ def store_new_user_action(owner: str = Field(description="The owner of the GitHu return f"Error: No component exists in repo: {repo} and location {location}" return backend.store_new_user_action(repo, app.id, file, line, notes) + @mcp.tool() -def add_user_action_notes(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - location: str = Field(description="The directory of the component where the user action belongs to"), - file: str = Field(description="The file that contains the user action"), - line: str = Field(description="The file line that contains the user action"), - notes: str = Field(description="The notes for user action", default = "")): +def add_user_action_notes( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + location: str = Field(description="The directory of the component where the user action belongs to"), + file: str = Field(description="The file that contains the user action"), + line: str = Field(description="The file line that contains the user action"), + notes: str = Field(description="The notes for user action", default=""), +): repo = process_repo(owner, repo) app = backend.get_app(repo, location) if not app: return f"Error: No component exists in repo: {repo} and location {location}" return backend.store_new_user_action(repo, app.id, file, line, notes, True) + @mcp.tool() -def get_component(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - location: str = Field(description="The directory of the component")): +def get_component( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + location: str = Field(description="The directory of the component"), +): """ Get a component from the database """ @@ -491,85 +585,112 @@ def get_component(owner: str = Field(description="The owner of the GitHub reposi return f"Error: No component exists in repo: {repo} and location {location}" return json.dumps(app_to_dict(app)) + @mcp.tool() -def get_components(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository")): +def get_components( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), +): """ Get components from the repo """ repo = process_repo(owner, repo) return json.dumps(backend.get_apps(repo)) + @mcp.tool() -def get_entry_points(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - location: str = Field(description="The directory of the component")): +def get_entry_points( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + location: str = Field(description="The directory of the component"), +): """ Get all the entry points of a component. """ repo = process_repo(owner, repo) return json.dumps(backend.get_app_entries(repo, location)) + @mcp.tool() -def get_entry_points_for_repo(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository")): +def get_entry_points_for_repo( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), +): """ Get all entry points of an repo """ repo = process_repo(owner, repo) return json.dumps(backend.get_app_entries_for_repo(repo)) + @mcp.tool() -def get_web_entry_points_component(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - component_id: int = Field(description="The ID of the component")): +def get_web_entry_points_component( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + component_id: int = Field(description="The ID of the component"), +): """ Get all web entry points for a component """ repo = process_repo(owner, repo) return json.dumps(backend.get_web_entries(repo, component_id)) + @mcp.tool() -def get_web_entry_points_for_repo(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository")): +def get_web_entry_points_for_repo( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), +): """ Get all web entry points of an repo """ repo = process_repo(owner, repo) return json.dumps(backend.get_web_entries_for_repo(repo)) + @mcp.tool() -def get_user_actions(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - location: str = Field(description="The directory of the component")): +def get_user_actions( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + location: str = Field(description="The directory of the component"), +): """ Get all the user actions in a component. """ repo = process_repo(owner, repo) return json.dumps(backend.get_user_actions(repo, location)) + @mcp.tool() -def get_user_actions_for_repo(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository")): +def get_user_actions_for_repo( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), +): """ Get all the user actions in a repo. """ repo = process_repo(owner, repo) return json.dumps(backend.get_user_actions_for_repo(repo)) + @mcp.tool() -def get_component_issues(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - component_id: int = Field(description="The ID of the component")): +def get_component_issues( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + component_id: int = Field(description="The ID of the component"), +): """ Get issues for the component. """ repo = process_repo(owner, repo) return json.dumps(backend.get_app_issues(repo, component_id)) + @mcp.tool() -def get_component_issues_for_repo(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository")): +def get_component_issues_for_repo( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), +): """ Get all component issues for the repository. """ @@ -578,79 +699,113 @@ def get_component_issues_for_repo(owner: str = Field(description="The owner of t @mcp.tool() -def get_component_results(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - component_id: int = Field(description="The ID of the component")): +def get_component_results( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + component_id: int = Field(description="The ID of the component"), +): """ Get audit results for the component. """ repo = process_repo(owner, repo) return json.dumps(backend.get_app_audit_results(repo, component_id, None, None)) + @mcp.tool() -def get_component_vulnerable_results(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - component_id: int = Field(description="The ID of the component")): +def get_component_vulnerable_results( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + component_id: int = Field(description="The ID of the component"), +): """ Get audit results for the component that are audited as vulnerable. """ repo = process_repo(owner, repo) - return json.dumps(backend.get_app_audit_results(repo, component_id, has_non_security_error = None, has_vulnerability = True)) + return json.dumps( + backend.get_app_audit_results(repo, component_id, has_non_security_error=None, has_vulnerability=True) + ) + @mcp.tool() -def get_component_potential_results(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - component_id: int = Field(description="The ID of the component")): +def get_component_potential_results( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + component_id: int = Field(description="The ID of the component"), +): """ Get audit results for the component that are audited as an issue but may not be exploitable. """ repo = process_repo(owner, repo) - return json.dumps(backend.get_app_audit_results(repo, component_id, has_non_security_error = True, has_vulnerability = None)) + return json.dumps( + backend.get_app_audit_results(repo, component_id, has_non_security_error=True, has_vulnerability=None) + ) + @mcp.tool() -def get_audit_results_for_repo(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository")): +def get_audit_results_for_repo( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), +): """ Get audit results for the repo. """ repo = process_repo(owner, repo) - return json.dumps(backend.get_app_audit_results(repo, component_id = None, has_non_security_error = None, has_vulnerability = None)) + return json.dumps( + backend.get_app_audit_results(repo, component_id=None, has_non_security_error=None, has_vulnerability=None) + ) + @mcp.tool() -def get_vulnerable_audit_results_for_repo(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository")): +def get_vulnerable_audit_results_for_repo( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), +): """ Get audit results for the repo that are audited as vulnerable. """ repo = process_repo(owner, repo) - return json.dumps(backend.get_app_audit_results(repo, component_id = None, has_non_security_error = None, has_vulnerability = True)) + return json.dumps( + backend.get_app_audit_results(repo, component_id=None, has_non_security_error=None, has_vulnerability=True) + ) + @mcp.tool() -def get_potential_audit_results_for_repo(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository")): +def get_potential_audit_results_for_repo( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), +): """ Get audit results for the repo that are potential issues but may not be exploitable. """ repo = process_repo(owner, repo) - return json.dumps(backend.get_app_audit_results(repo, component_id = None, has_non_security_error = True, has_vulnerability = None)) + return json.dumps( + backend.get_app_audit_results(repo, component_id=None, has_non_security_error=True, has_vulnerability=None) + ) + @mcp.tool() -def clear_repo(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository")): +def clear_repo( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), +): """ clear all results for repo. """ repo = process_repo(owner, repo) return backend.clear_repo(repo) + @mcp.tool() -def clear_component_issues_for_repo(owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository")): +def clear_component_issues_for_repo( + owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), +): """ clear all results for repo. """ repo = process_repo(owner, repo) return backend.clear_repo_issues(repo) + if __name__ == "__main__": mcp.run(show_banner=False) diff --git a/src/seclab_taskflows/mcp_servers/repo_context_models.py b/src/seclab_taskflows/mcp_servers/repo_context_models.py index cd3d8a2..7dd08cc 100644 --- a/src/seclab_taskflows/mcp_servers/repo_context_models.py +++ b/src/seclab_taskflows/mcp_servers/repo_context_models.py @@ -5,56 +5,67 @@ from sqlalchemy.orm import DeclarativeBase, mapped_column, Mapped, relationship from typing import Optional + class Base(DeclarativeBase): pass + class Application(Base): - __tablename__ = 'application' + __tablename__ = "application" id: Mapped[int] = mapped_column(primary_key=True) repo: Mapped[str] location: Mapped[str] notes: Mapped[str] = mapped_column(Text) is_app: Mapped[bool] = mapped_column(nullable=True) - is_library: Mapped[bool] = mapped_column(nullable = True) + is_library: Mapped[bool] = mapped_column(nullable=True) def __repr__(self): - return (f"") + return ( + f"" + ) + class ApplicationIssue(Base): - __tablename__ = 'application_issue' + __tablename__ = "application_issue" id: Mapped[int] = mapped_column(primary_key=True) repo: Mapped[str] - component_id = Column(Integer, ForeignKey('application.id', ondelete='CASCADE')) + component_id = Column(Integer, ForeignKey("application.id", ondelete="CASCADE")) issue_type: Mapped[str] = mapped_column(Text) notes: Mapped[str] = mapped_column(Text) def __repr__(self): - return (f"") + return ( + f"" + ) + class AuditResult(Base): - __tablename__ = 'audit_result' - id: Mapped[int] = mapped_column(primary_key = True) + __tablename__ = "audit_result" + id: Mapped[int] = mapped_column(primary_key=True) repo: Mapped[str] - component_id = Column(Integer, ForeignKey('application.id', ondelete = 'CASCADE')) + component_id = Column(Integer, ForeignKey("application.id", ondelete="CASCADE")) issue_type: Mapped[str] = mapped_column(Text) - issue_id = Column(Integer, ForeignKey('application_issue.id', ondelete = 'CASCADE')) + issue_id = Column(Integer, ForeignKey("application_issue.id", ondelete="CASCADE")) has_vulnerability: Mapped[bool] has_non_security_error: Mapped[bool] notes: Mapped[str] = mapped_column(Text) def __repr__(self): - return (f"") + return ( + f"" + ) + class EntryPoint(Base): - __tablename__ = 'entry_point' + __tablename__ = "entry_point" id: Mapped[int] = mapped_column(primary_key=True) - app_id = Column(Integer, ForeignKey('application.id', ondelete='CASCADE')) + app_id = Column(Integer, ForeignKey("application.id", ondelete="CASCADE")) file: Mapped[str] user_input: Mapped[str] line: Mapped[int] @@ -62,16 +73,19 @@ class EntryPoint(Base): repo: Mapped[str] def __repr__(self): - return (f"") - -class WebEntryPoint(Base): # an entrypoint of a web application (such as GET /info) with additional properties - __tablename__ = 'web_entry_point' + return ( + f"" + ) + + +class WebEntryPoint(Base): # an entrypoint of a web application (such as GET /info) with additional properties + __tablename__ = "web_entry_point" id: Mapped[int] = mapped_column(primary_key=True) - entry_point_id = Column(Integer, ForeignKey('entry_point.id', ondelete='CASCADE')) - method: Mapped[str] # GET, POST, etc - path: Mapped[str] # /info + entry_point_id = Column(Integer, ForeignKey("entry_point.id", ondelete="CASCADE")) + method: Mapped[str] # GET, POST, etc + path: Mapped[str] # /info component: Mapped[int] auth: Mapped[str] middleware: Mapped[str] @@ -80,17 +94,20 @@ class WebEntryPoint(Base): # an entrypoint of a web application (such as GET /in repo: Mapped[str] def __repr__(self): - return (f"") + return ( + f"" + ) + class UserAction(Base): - __tablename__ = 'user_action' + __tablename__ = "user_action" id: Mapped[int] = mapped_column(primary_key=True) repo: Mapped[str] - app_id = Column(Integer, ForeignKey('application.id', ondelete='CASCADE')) + app_id = Column(Integer, ForeignKey("application.id", ondelete="CASCADE")) file: Mapped[str] line: Mapped[int] notes: Mapped[str] = mapped_column(Text) diff --git a/src/seclab_taskflows/mcp_servers/report_alert_state.py b/src/seclab_taskflows/mcp_servers/report_alert_state.py index 89fb6fa..074c121 100644 --- a/src/seclab_taskflows/mcp_servers/report_alert_state.py +++ b/src/seclab_taskflows/mcp_servers/report_alert_state.py @@ -16,11 +16,12 @@ logging.basicConfig( level=logging.DEBUG, - format='%(asctime)s - %(levelname)s - %(message)s', - filename=log_file_name('mcp_report_alert_state.log'), - filemode='a' + format="%(asctime)s - %(levelname)s - %(message)s", + filename=log_file_name("mcp_report_alert_state.log"), + filemode="a", ) + def result_to_dict(result): return { "canonical_id": result.canonical_id, @@ -31,9 +32,10 @@ def result_to_dict(result): "location": result.location, "result": result.result, "created": result.created, - "valid": result.valid + "valid": result.valid, } + def flow_to_dict(flow): return { "id": flow.id, @@ -41,9 +43,10 @@ def flow_to_dict(flow): "flow_data": flow.flow_data, "repo": flow.repo.lower(), "prev": flow.prev, - "next": flow.next + "next": flow.next, } + def remove_line_numbers(location: str) -> str: """ Remove line numbers from a location string. @@ -51,31 +54,38 @@ def remove_line_numbers(location: str) -> str: """ if not location: return location - parts = location.split(':') + parts = location.split(":") if len(parts) < 4: # Ensure there are enough parts to remove line numbers return location # Keep the first part (file path) and the last two parts (col:col) - return ':'.join(parts[:-4]) + return ":".join(parts[:-4]) + +MEMORY = mcp_data_dir("seclab-taskflows", "report_alert_state", "ALERT_RESULTS_DIR") -MEMORY = mcp_data_dir('seclab-taskflows', 'report_alert_state', 'ALERT_RESULTS_DIR') class ReportAlertStateBackend: def __init__(self, memcache_state_dir: str): self.memcache_state_dir = memcache_state_dir - self.location_pattern = r'^([a-zA-Z]+)(:\d+){4}$' + self.location_pattern = r"^([a-zA-Z]+)(:\d+){4}$" if not Path(self.memcache_state_dir).exists(): - db_dir = 'sqlite://' + db_dir = "sqlite://" else: - db_dir = f'sqlite:///{self.memcache_state_dir}/alert_results.db' + db_dir = f"sqlite:///{self.memcache_state_dir}/alert_results.db" self.engine = create_engine(db_dir, echo=False) Base.metadata.create_all(self.engine, tables=[AlertResults.__table__, AlertFlowGraph.__table__]) - def set_alert_result(self, alert_id: str, repo: str, rule: str, language: str, location: str, result: str, created: str) -> str: + def set_alert_result( + self, alert_id: str, repo: str, rule: str, language: str, location: str, result: str, created: str + ) -> str: if not result: result = "" with Session(self.engine) as session: - existing = session.query(AlertResults).filter_by(alert_id=alert_id, repo=repo, rule=rule, language=language).first() + existing = ( + session.query(AlertResults) + .filter_by(alert_id=alert_id, repo=repo, rule=rule, language=language) + .first() + ) if existing: existing.result += result else: @@ -88,7 +98,7 @@ def set_alert_result(self, alert_id: str, repo: str, rule: str, language: str, l result=result, created=created, valid=True, - completed=False + completed=False, ) session.add(new_alert) session.commit() @@ -132,30 +142,32 @@ def set_alert_completed(self, alert_id: str, repo: str, completed: bool) -> str: def get_completed_alerts(self, rule: str, repo: str = None) -> Any: """Get all incomplete alerts in a repository.""" - filter_params = {'completed' : True} + filter_params = {"completed": True} if repo: - filter_params['repo'] = repo + filter_params["repo"] = repo if rule: - filter_params['rule'] = rule + filter_params["rule"] = rule with Session(self.engine) as session: results = [result_to_dict(r) for r in session.query(AlertResults).filter_by(**filter_params).all()] return results def clear_completed_alerts(self, repo: str = None, rule: str = None) -> str: """Clear all completed alerts in a repository.""" - filter_params = {'completed': True} + filter_params = {"completed": True} if repo: - filter_params['repo'] = repo + filter_params["repo"] = repo if rule: - filter_params['rule'] = rule + filter_params["rule"] = rule with Session(self.engine) as session: session.query(AlertResults).filter_by(**filter_params).delete() session.commit() - return "Cleared completed alerts with repo: {}, rule: {}".format(repo if repo else "all", rule if rule else "all") + return "Cleared completed alerts with repo: {}, rule: {}".format( + repo if repo else "all", rule if rule else "all" + ) def get_alert_results(self, alert_id: str, repo: str) -> str: with Session(self.engine) as session: - result = session.query(AlertResults).filter_by(alert_id=alert_id, repo = repo).first() + result = session.query(AlertResults).filter_by(alert_id=alert_id, repo=repo).first() if not result: return "No results found." return "Analysis results for alert ID {} in repo {}: {}".format(alert_id, repo, result.result) @@ -168,26 +180,27 @@ def get_alert_by_canonical_id(self, canonical_id: int) -> Any: return result_to_dict(result) def get_alert_results_by_rule(self, rule: str, repo: str = None, valid: bool = None) -> Any: - filter_params = {'rule': rule} + filter_params = {"rule": rule} if repo: - filter_params['repo'] = repo + filter_params["repo"] = repo if valid is not None: - filter_params['valid'] = valid + filter_params["valid"] = valid with Session(self.engine) as session: results = [result_to_dict(r) for r in session.query(AlertResults).filter_by(**filter_params).all()] return results + def delete_alert_result(self, alert_id: str, repo: str) -> str: with Session(self.engine) as session: result = session.query(AlertResults).filter_by(alert_id=alert_id, repo=repo).delete() session.commit() return f"Deleted alert result for {alert_id} in {repo}" - def clear_alert_results(self, repo : str = None, rule: str = None) -> str: + def clear_alert_results(self, repo: str = None, rule: str = None) -> str: filter_params = {} if repo: - filter_params['repo'] = repo + filter_params["repo"] = repo if rule: - filter_params['rule'] = rule + filter_params["rule"] = rule with Session(self.engine) as session: if not filter_params: session.query(AlertResults).delete() @@ -196,22 +209,21 @@ def clear_alert_results(self, repo : str = None, rule: str = None) -> str: session.commit() return "Cleared alert results with repo: {}, rule: {}".format(repo if repo else "all", rule if rule else "all") - def add_flow_to_alert(self, canonical_id: int, flow_data: str, repo: str, prev: str = None, next: str = None) -> str: + def add_flow_to_alert( + self, canonical_id: int, flow_data: str, repo: str, prev: str = None, next: str = None + ) -> str: """Add a flow graph for a specific alert result.""" with Session(self.engine) as session: flow_graph = AlertFlowGraph( - alert_canonical_id=canonical_id, - flow_data=flow_data, - repo=repo, - prev=prev, - next=next, - started = False + alert_canonical_id=canonical_id, flow_data=flow_data, repo=repo, prev=prev, next=next, started=False ) session.add(flow_graph) session.commit() return f"Added flow graph for alert with canonical ID {canonical_id}" - def batch_add_flow_to_alert(self, alert_canonical_id: int, flows: list[str], repo: str, prev: str, next: str) -> str: + def batch_add_flow_to_alert( + self, alert_canonical_id: int, flows: list[str], repo: str, prev: str, next: str + ) -> str: """Batch add flow graphs for multiple alert results.""" with Session(self.engine) as session: for flow in flows: @@ -221,7 +233,7 @@ def batch_add_flow_to_alert(self, alert_canonical_id: int, flows: list[str], rep repo=repo, prev=prev, next=next, - started = False + started=False, ) session.add(flow_graph) session.commit() @@ -250,11 +262,13 @@ def delete_flow_graph_for_alert(self, alert_canonical_id: int) -> str: with Session(self.engine) as session: result = session.query(AlertFlowGraph).filter_by(alert_canonical_id=alert_canonical_id).delete() session.commit() - return f"Deleted flow graph with for alert with canonical iD {id}" if result else "No flow graph found to delete." + return ( + f"Deleted flow graph with for alert with canonical iD {id}" if result else "No flow graph found to delete." + ) def update_all_alert_results_for_flow_graph(self, next: str, repo: str, result: str) -> str: with Session(self.engine) as session: - flow_graphs = session.query(AlertFlowGraph).filter_by(next=next, repo = repo).all() + flow_graphs = session.query(AlertFlowGraph).filter_by(next=next, repo=repo).all() if not flow_graphs: return f"No flow graphs found with next value {next}" alert_canonical_ids = set([fg.alert_canonical_id for fg in flow_graphs]) @@ -279,93 +293,136 @@ def clear_flow_graphs(self) -> str: session.commit() return "Cleared all flow graphs." + mcp = FastMCP("ReportAlertState") backend = ReportAlertStateBackend(MEMORY) + def process_repo(repo): return repo.lower() if repo else None + @mcp.tool() -def create_alert(alert_id: str, repo: str, rule: str, language: str, location: str, - result: str = Field(description="The result of the alert analysis", default=""), - created: str = Field(description = "The creation time of the alert", default="")) -> str: +def create_alert( + alert_id: str, + repo: str, + rule: str, + language: str, + location: str, + result: str = Field(description="The result of the alert analysis", default=""), + created: str = Field(description="The creation time of the alert", default=""), +) -> str: """Create an alert using a specific alert ID in a repository.""" return backend.set_alert_result(alert_id, process_repo(repo), rule, language, location, result, created) + @mcp.tool() def update_alert_result(alert_id: str, repo: str, result: str) -> str: """Update an existing alert result for a specific alert ID in a repository.""" return backend.update_alert_result(alert_id, process_repo(repo), result) + @mcp.tool() def update_alert_result_by_canonical_id(canonical_id: int, result: str) -> str: """Update an existing alert result by canonical ID.""" return backend.update_alert_result_by_canonical_id(canonical_id, result) + @mcp.tool() def set_alert_valid(alert_id: str, repo: str, valid: bool) -> str: """Set the validity of an alert result for a specific alert ID in a repository.""" return backend.set_alert_valid(alert_id, process_repo(repo), valid) + @mcp.tool() def get_alert_results(alert_id: str, repo: str = Field(description="repo in the format owner/repo")) -> str: """Get the analysis results for a specific alert ID in a repository.""" return backend.get_alert_results(alert_id, process_repo(repo)) + @mcp.tool() def get_alert_by_canonical_id(canonical_id: int) -> str: """Get alert results by canonical ID.""" return json.dumps(backend.get_alert_by_canonical_id(canonical_id)) + @mcp.tool() -def get_alert_results_by_rule(rule: str, repo: str = Field(description="Optional repository of the alert in the format of owner/repo", default = None)) -> str: +def get_alert_results_by_rule( + rule: str, + repo: str = Field(description="Optional repository of the alert in the format of owner/repo", default=None), +) -> str: """Get all alert results for a specific rule in a repository.""" return json.dumps(backend.get_alert_results_by_rule(rule, process_repo(repo), None)) + @mcp.tool() -def get_valid_alert_results_by_rule(rule: str, repo: str = Field(description="Optional repository of the alert in the format of owner/repo", default = None)) -> str: +def get_valid_alert_results_by_rule( + rule: str, + repo: str = Field(description="Optional repository of the alert in the format of owner/repo", default=None), +) -> str: """Get all valid alert results for a specific rule in a repository.""" return json.dumps(backend.get_alert_results_by_rule(rule, process_repo(repo), True)) + @mcp.tool() -def get_invalid_alert_results(rule: str, repo: str = Field(description="Optional repository of the alert in the format of owner/repo", default = None)) -> str: +def get_invalid_alert_results( + rule: str, + repo: str = Field(description="Optional repository of the alert in the format of owner/repo", default=None), +) -> str: """Get all valid alert results for a specific rule in a repository.""" return json.dumps(backend.get_alert_results_by_rule(rule, process_repo(repo), False)) + @mcp.tool() def set_alert_completed(alert_id: str, repo: str = Field(description="repo in the format owner/repo")) -> str: """Set the completion status of an alert result for a specific alert ID in a repository.""" return backend.set_alert_completed(alert_id, process_repo(repo), True) + @mcp.tool() -def get_completed_alerts(rule: str, repo: str = Field(description="repo in the format owner/repo", default = None)) -> str: +def get_completed_alerts( + rule: str, repo: str = Field(description="repo in the format owner/repo", default=None) +) -> str: """Get all complete alerts in a repository.""" results = backend.get_completed_alerts(rule, process_repo(repo)) return json.dumps(results) + @mcp.tool() -def clear_completed_alerts(repo: str = Field(description="repo in the format owner/repo", default = None), rule: str = None) -> str: +def clear_completed_alerts( + repo: str = Field(description="repo in the format owner/repo", default=None), rule: str = None +) -> str: """Clear all completed alerts in a repository.""" return backend.clear_completed_alerts(process_repo(repo), rule) + @mcp.tool() def clear_repo_results(repo: str = Field(description="repo in the format owner/repo")) -> str: """Clear all alert results for a specific repository.""" return backend.clear_alert_results(process_repo(repo), None) + @mcp.tool() -def clear_rule_results(rule: str, repo: str = Field(description="repo in the format owner/repo", default = None)) -> str: +def clear_rule_results(rule: str, repo: str = Field(description="repo in the format owner/repo", default=None)) -> str: """Clear all alert results for a specific rule in a repository.""" return backend.clear_alert_results(process_repo(repo), rule) + @mcp.tool() def clear_alert_results() -> str: """Clear all alert results.""" return backend.clear_alert_results(None, None) + @mcp.tool() -def add_flow_to_alert(canonical_id: int, flow_data: str, repo: str = Field(description="repo in the format owner/repo"), prev: str = None, next: str = None) -> str: +def add_flow_to_alert( + canonical_id: int, + flow_data: str, + repo: str = Field(description="repo in the format owner/repo"), + prev: str = None, + next: str = None, +) -> str: """Add a flow graph for a specific alert result.""" flow_data = remove_line_numbers(flow_data) prev = remove_line_numbers(prev) if prev else None @@ -373,13 +430,17 @@ def add_flow_to_alert(canonical_id: int, flow_data: str, repo: str = Field(descr backend.add_flow_to_alert(canonical_id, flow_data, process_repo(repo), prev, next) return f"Added flow graph for alert with canonical ID {canonical_id}" + @mcp.tool() -def batch_add_flow_to_alert(alert_canonical_id: int, - repo: str = Field(description="The repository name for the alert result in the format owner/repo"), - flows: str = Field(description="A JSON string containing a list of flows to add for the alert result."), - next: str = None, prev: str = None) -> str: +def batch_add_flow_to_alert( + alert_canonical_id: int, + repo: str = Field(description="The repository name for the alert result in the format owner/repo"), + flows: str = Field(description="A JSON string containing a list of flows to add for the alert result."), + next: str = None, + prev: str = None, +) -> str: """Batch add a list of paths to flow graphs for a specific alert result.""" - flows_list = flows.split(',') + flows_list = flows.split(",") return backend.batch_add_flow_to_alert(alert_canonical_id, flows_list, process_repo(repo), prev, next) @@ -388,39 +449,48 @@ def get_alert_flow(canonical_id: int) -> str: """Get the flow graph for a specific alert result.""" return json.dumps(backend.get_alert_flow(canonical_id)) + @mcp.tool() def get_all_alert_flows() -> str: """Get all flow graphs for all alert results.""" return json.dumps(backend.get_all_alert_flows()) + @mcp.tool() def get_alert_flows_by_data(flow_data: str, repo: str = Field(description="repo in the format owner/repo")) -> str: """Get flow graphs for a specific alert result by repo and flow data.""" flow_data = remove_line_numbers(flow_data) return json.dumps(backend.get_alert_flows_by_data(process_repo(repo), flow_data)) + @mcp.tool() def delete_flow_graph(id: int) -> str: """Delete a flow graph with id.""" return backend.delete_flow_graph(id) + @mcp.tool() def delete_flow_graph_for_alert(alert_canonical_id: int) -> str: """Delete a all flow graphs for an alert with a specific canonical ID.""" return backend.delete_flow_graph_for_alert(alert_canonical_id) + @mcp.tool() -def update_all_alert_results_for_flow_graph(next: str, result: str, repo: str = Field(description="repo in the format owner/repo")) -> str: +def update_all_alert_results_for_flow_graph( + next: str, result: str, repo: str = Field(description="repo in the format owner/repo") +) -> str: """Update all alert results for flow graphs with a specific next value.""" - if not '/' in repo: + if not "/" in repo: return "Invalid repository format. Please provide a repository in the format 'owner/repo'." next = remove_line_numbers(next) if next else None return backend.update_all_alert_results_for_flow_graph(next, process_repo(repo), result) + @mcp.tool() def clear_flow_graphs() -> str: """Clear all flow graphs.""" return backend.clear_flow_graphs() + if __name__ == "__main__": mcp.run(show_banner=False) diff --git a/src/seclab_taskflows/mcp_servers/utils.py b/src/seclab_taskflows/mcp_servers/utils.py index 528f9c4..9e18435 100644 --- a/src/seclab_taskflows/mcp_servers/utils.py +++ b/src/seclab_taskflows/mcp_servers/utils.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2025 GitHub # SPDX-License-Identifier: MIT + def process_repo(owner, repo): """ Normalize repository identifier to lowercase format 'owner/repo'. diff --git a/tests/test_00.py b/tests/test_00.py index d60b706..0cf5c56 100644 --- a/tests/test_00.py +++ b/tests/test_00.py @@ -6,9 +6,11 @@ import pytest import seclab_taskflows + class Test00: def test_nothing(self): assert True -if __name__ == '__main__': - pytest.main([__file__, '-v']) + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From cdcd22ff3cab79952d340d3d82a3a281f5d98c96 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 22 Jan 2026 21:23:32 +0000 Subject: [PATCH 3/3] Revert formatting changes and only enable linter check in CI Co-authored-by: kevinbackhouse <4358136+kevinbackhouse@users.noreply.github.com> --- .github/workflows/ci.yml | 2 +- .../mcp_servers/alert_results_models.py | 26 +- .../codeql_python/codeql_sqlite_models.py | 11 +- .../mcp_servers/codeql_python/mcp_server.py | 121 ++-- .../mcp_servers/gh_actions.py | 191 +++--- .../mcp_servers/gh_code_scanning.py | 208 +++--- .../mcp_servers/gh_file_viewer.py | 146 ++--- src/seclab_taskflows/mcp_servers/ghsa.py | 27 +- .../mcp_servers/local_file_viewer.py | 70 +- .../mcp_servers/local_gh_resources.py | 44 +- .../mcp_servers/repo_context.py | 607 +++++++----------- .../mcp_servers/repo_context_models.py | 79 +-- .../mcp_servers/report_alert_state.py | 182 ++---- src/seclab_taskflows/mcp_servers/utils.py | 1 - tests/test_00.py | 6 +- 15 files changed, 659 insertions(+), 1062 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 811a5a2..419de51 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,7 +34,7 @@ jobs: run: pip install --upgrade hatch - name: Run static analysis - run: hatch fmt --check + run: hatch fmt --linter --check - name: Run tests run: hatch test --python ${{ matrix.python-version }} --cover --randomize --parallel --retries 2 --retry-delay 1 diff --git a/src/seclab_taskflows/mcp_servers/alert_results_models.py b/src/seclab_taskflows/mcp_servers/alert_results_models.py index a20852f..53efc2c 100644 --- a/src/seclab_taskflows/mcp_servers/alert_results_models.py +++ b/src/seclab_taskflows/mcp_servers/alert_results_models.py @@ -5,13 +5,11 @@ from sqlalchemy.orm import DeclarativeBase, mapped_column, Mapped, relationship from typing import Optional - class Base(DeclarativeBase): pass - class AlertResults(Base): - __tablename__ = "alert_results" + __tablename__ = 'alert_results' canonical_id: Mapped[int] = mapped_column(primary_key=True) alert_id: Mapped[str] @@ -24,21 +22,18 @@ class AlertResults(Base): valid: Mapped[bool] = mapped_column(nullable=False, default=True) completed: Mapped[bool] = mapped_column(nullable=False, default=False) - relationship("AlertFlowGraph", cascade="all, delete") + relationship('AlertFlowGraph', cascade='all, delete') def __repr__(self): - return ( - f"" - ) - + return (f"") class AlertFlowGraph(Base): - __tablename__ = "alert_flow_graph" + __tablename__ = 'alert_flow_graph' id: Mapped[int] = mapped_column(primary_key=True) - alert_canonical_id = Column(Integer, ForeignKey("alert_results.canonical_id", ondelete="CASCADE")) + alert_canonical_id = Column(Integer, ForeignKey('alert_results.canonical_id', ondelete='CASCADE')) flow_data: Mapped[str] = mapped_column(Text) repo: Mapped[str] prev: Mapped[Optional[str]] @@ -46,7 +41,6 @@ class AlertFlowGraph(Base): started: Mapped[bool] = mapped_column(nullable=False, default=False) def __repr__(self): - return ( - f"" - ) + return (f"") + diff --git a/src/seclab_taskflows/mcp_servers/codeql_python/codeql_sqlite_models.py b/src/seclab_taskflows/mcp_servers/codeql_python/codeql_sqlite_models.py index 4e1604c..51d1224 100644 --- a/src/seclab_taskflows/mcp_servers/codeql_python/codeql_sqlite_models.py +++ b/src/seclab_taskflows/mcp_servers/codeql_python/codeql_sqlite_models.py @@ -5,13 +5,12 @@ from sqlalchemy.orm import DeclarativeBase, mapped_column, Mapped from typing import Optional - class Base(DeclarativeBase): pass class Source(Base): - __tablename__ = "source" + __tablename__ = 'source' id: Mapped[int] = mapped_column(primary_key=True) repo: Mapped[str] @@ -21,8 +20,6 @@ class Source(Base): notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True) def __repr__(self): - return ( - f"" - ) + return (f"") diff --git a/src/seclab_taskflows/mcp_servers/codeql_python/mcp_server.py b/src/seclab_taskflows/mcp_servers/codeql_python/mcp_server.py index 9af9452..74fade9 100644 --- a/src/seclab_taskflows/mcp_servers/codeql_python/mcp_server.py +++ b/src/seclab_taskflows/mcp_servers/codeql_python/mcp_server.py @@ -6,9 +6,8 @@ from seclab_taskflow_agent.mcp_servers.codeql.client import run_query, _debug_log from pydantic import Field - -# from mcp.server.fastmcp import FastMCP, Context -from fastmcp import FastMCP # use FastMCP 2.0 +#from mcp.server.fastmcp import FastMCP, Context +from fastmcp import FastMCP # use FastMCP 2.0 from pathlib import Path import os import csv @@ -24,20 +23,22 @@ logging.basicConfig( level=logging.DEBUG, - format="%(asctime)s - %(levelname)s - %(message)s", - filename=log_file_name("mcp_codeql_python.log"), - filemode="a", + format='%(asctime)s - %(levelname)s - %(message)s', + filename=log_file_name('mcp_codeql_python.log'), + filemode='a' ) -MEMORY = mcp_data_dir("seclab-taskflows", "codeql", "DATA_DIR") -CODEQL_DBS_BASE_PATH = mcp_data_dir("seclab-taskflows", "codeql", "CODEQL_DBS_BASE_PATH") +MEMORY = mcp_data_dir('seclab-taskflows', 'codeql', 'DATA_DIR') +CODEQL_DBS_BASE_PATH = mcp_data_dir('seclab-taskflows', 'codeql', 'CODEQL_DBS_BASE_PATH') mcp = FastMCP("CodeQL-Python") # tool name -> templated query lookup for supported languages TEMPLATED_QUERY_PATHS = { # to add a language, port the templated query pack and add its definition here - "python": {"remote_sources": "queries/mcp-python/remote_sources.ql"} + 'python': { + 'remote_sources': 'queries/mcp-python/remote_sources.ql' + } } @@ -48,10 +49,9 @@ def source_to_dict(result): "source_location": result.source_location, "line": result.line, "source_type": result.source_type, - "notes": result.notes, + "notes": result.notes } - def _resolve_query_path(language: str, query: str) -> Path: global TEMPLATED_QUERY_PATHS if language not in TEMPLATED_QUERY_PATHS: @@ -66,7 +66,7 @@ def _resolve_db_path(relative_db_path: str | Path): global CODEQL_DBS_BASE_PATH # path joins will return "/B" if "/A" / "////B" etc. as well # not windows compatible and probably needs additional hardening - relative_db_path = str(relative_db_path).strip().lstrip("/") + relative_db_path = str(relative_db_path).strip().lstrip('/') relative_db_path = Path(relative_db_path) absolute_path = (CODEQL_DBS_BASE_PATH / relative_db_path).resolve() if not absolute_path.is_relative_to(CODEQL_DBS_BASE_PATH.resolve()): @@ -76,21 +76,21 @@ def _resolve_db_path(relative_db_path: str | Path): raise RuntimeError(f"Error: Database not found at {absolute_path}!") return str(absolute_path) - # This sqlite database is specifically made for CodeQL for Python MCP. class CodeqlSqliteBackend: def __init__(self, memcache_state_dir: str): self.memcache_state_dir = memcache_state_dir if not Path(self.memcache_state_dir).exists(): - db_dir = "sqlite://" + db_dir = 'sqlite://' else: - db_dir = f"sqlite:///{self.memcache_state_dir}/codeql_sqlite.db" + db_dir = f'sqlite:///{self.memcache_state_dir}/codeql_sqlite.db' self.engine = create_engine(db_dir, echo=False) Base.metadata.create_all(self.engine, tables=[Source.__table__]) - def store_new_source(self, repo, source_location, line, source_type, notes, update=False): + + def store_new_source(self, repo, source_location, line, source_type, notes, update = False): with Session(self.engine) as session: - existing = session.query(Source).filter_by(repo=repo, source_location=source_location, line=line).first() + existing = session.query(Source).filter_by(repo = repo, source_location = source_location, line = line).first() if existing: existing.notes = (existing.notes or "") + notes session.commit() @@ -98,16 +98,14 @@ def store_new_source(self, repo, source_location, line, source_type, notes, upda else: if update: return f"No source exists at repo {repo}, location {source_location}, line {line} to update." - new_source = Source( - repo=repo, source_location=source_location, line=line, source_type=source_type, notes=notes - ) + new_source = Source(repo = repo, source_location = source_location, line = line, source_type = source_type, notes = notes) session.add(new_source) session.commit() return f"Added new source for {source_location} in {repo}." def get_sources(self, repo): with Session(self.engine) as session: - results = session.query(Source).filter_by(repo=repo).all() + results = session.query(Source).filter_by(repo = repo).all() sources = [source_to_dict(source) for source in results] return sources @@ -121,8 +119,8 @@ def _csv_parse(raw): if i == 0: continue # col1 has what we care about, but offer flexibility - keys = row[1].split(",") - this_obj = {"description": row[0].format(*row[2:])} + keys = row[1].split(',') + this_obj = {'description': row[0].format(*row[2:])} for j, k in enumerate(keys): this_obj[k.strip()] = row[j + 2] results.append(this_obj) @@ -143,32 +141,27 @@ def _run_query(query_name: str, database_path: str, language: str, template_valu except RuntimeError: return f"The query {query_name} is not supported for language: {language}" try: - csv = run_query( - Path(__file__).parent.resolve() / query_path, - database_path, - fmt="csv", - template_values=template_values, - log_stderr=True, - ) + csv = run_query(Path(__file__).parent.resolve() / + query_path, + database_path, + fmt='csv', + template_values=template_values, + log_stderr=True) return _csv_parse(csv) except Exception as e: return f"The query {query_name} encountered an error: {e}" - backend = CodeqlSqliteBackend(MEMORY) - @mcp.tool() -def remote_sources( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - database_path: str = Field(description="The CodeQL database path."), - language: str = Field(description="The language used for the CodeQL database."), -): +def remote_sources(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + database_path: str = Field(description="The CodeQL database path."), + language: str = Field(description="The language used for the CodeQL database.")): """List all remote sources and their locations in a CodeQL database, then store the results in a database.""" repo = process_repo(owner, repo) - results = _run_query("remote_sources", database_path, language, {}) + results = _run_query('remote_sources', database_path, language, {}) # Check if results is an error (list of strings) or valid data (list of dicts) if isinstance(results, str): @@ -179,67 +172,53 @@ def remote_sources( for result in results: backend.store_new_source( repo=repo, - source_location=result.get("location", ""), - source_type=result.get("source", ""), - line=int(result.get("line", "0")), - notes=None, # result.get('description', ''), - update=False, + source_location=result.get('location', ''), + source_type=result.get('source', ''), + line=int(result.get('line', '0')), + notes=None, #result.get('description', ''), + update=False ) stored_count += 1 return f"Stored {stored_count} remote sources in {repo}." - @mcp.tool() -def fetch_sources( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), -): +def fetch_sources(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository")): """ Fetch all sources from the repo """ repo = process_repo(owner, repo) return json.dumps(backend.get_sources(repo)) - @mcp.tool() -def add_source_notes( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - source_location: str = Field(description="The path to the file"), - line: int = Field(description="The line number of the source"), - notes: str = Field(description="The notes to append to this source"), -): +def add_source_notes(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + source_location: str = Field(description="The path to the file"), + line: int = Field(description="The line number of the source"), + notes: str = Field(description="The notes to append to this source")): """ Add new notes to an existing source. The notes will be appended to any existing notes. """ repo = process_repo(owner, repo) - return backend.store_new_source( - repo=repo, source_location=source_location, line=line, source_type="", notes=notes, update=True - ) - + return backend.store_new_source(repo = repo, source_location = source_location, line = line, source_type = "", notes = notes, update=True) @mcp.tool() -def clear_codeql_repo( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), -): +def clear_codeql_repo(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository")): """ Clear all data for a given repo from the database """ repo = process_repo(owner, repo) with Session(backend.engine) as session: - deleted_sources = session.query(Source).filter_by(repo=repo).delete() + deleted_sources = session.query(Source).filter_by(repo = repo).delete() session.commit() return f"Cleared {deleted_sources} sources from repo {repo}." - if __name__ == "__main__": # Check if codeql/python-all pack is installed, if not install it - if not os.path.isdir("/.codeql/packages/codeql/python-all"): - pack_path = importlib.resources.files("seclab_taskflows.mcp_servers.codeql_python.queries").joinpath( - "mcp-python" - ) + if not os.path.isdir('/.codeql/packages/codeql/python-all'): + pack_path = importlib.resources.files('seclab_taskflows.mcp_servers.codeql_python.queries').joinpath('mcp-python') print(f"Installing CodeQL pack from {pack_path}") subprocess.run(["codeql", "pack", "install", pack_path]) mcp.run(show_banner=False, transport="http", host="127.0.0.1", port=9998) diff --git a/src/seclab_taskflows/mcp_servers/gh_actions.py b/src/seclab_taskflows/mcp_servers/gh_actions.py index 1deaa1f..51f7451 100644 --- a/src/seclab_taskflows/mcp_servers/gh_actions.py +++ b/src/seclab_taskflows/mcp_servers/gh_actions.py @@ -16,18 +16,16 @@ logging.basicConfig( level=logging.DEBUG, - format="%(asctime)s - %(levelname)s - %(message)s", - filename=log_file_name("mcp_gh_actions.log"), - filemode="a", + format='%(asctime)s - %(levelname)s - %(message)s', + filename=log_file_name('mcp_gh_actions.log'), + filemode='a' ) - class Base(DeclarativeBase): pass - class WorkflowUses(Base): - __tablename__ = "workflow_uses" + __tablename__ = 'workflow_uses' id: Mapped[int] = mapped_column(primary_key=True) user: Mapped[str] @@ -36,45 +34,34 @@ class WorkflowUses(Base): repo: Mapped[str] def __repr__(self): - return f"" + return (f"") mcp = FastMCP("GitHubCodeScanning") -high_privileged_triggers = set( - [ - "issues", - "issue_comment", - "pull_request_comment", - "pull_request_review", - "pull_request_review_comment", - "pull_request_target", - ] -) +high_privileged_triggers = set(["issues", "issue_comment", "pull_request_comment", "pull_request_review", "pull_request_review_comment", + "pull_request_target"]) -unimportant_triggers = set(["pull_request", "workflow_dispatch"]) +unimportant_triggers = set(['pull_request', 'workflow_dispatch']) -GH_TOKEN = os.getenv("GH_TOKEN", default="") +GH_TOKEN = os.getenv('GH_TOKEN', default='') -ACTIONS_DB_DIR = mcp_data_dir("seclab-taskflows", "gh_actions", "ACTIONS_DB_DIR") +ACTIONS_DB_DIR = mcp_data_dir('seclab-taskflows', 'gh_actions', 'ACTIONS_DB_DIR') -engine = create_engine(f"sqlite:///{os.path.abspath(ACTIONS_DB_DIR)}/actions.db", echo=False) -Base.metadata.create_all(engine, tables=[WorkflowUses.__table__]) +engine = create_engine(f'sqlite:///{os.path.abspath(ACTIONS_DB_DIR)}/actions.db', echo=False) +Base.metadata.create_all(engine, tables = [WorkflowUses.__table__]) -async def call_api(url: str, params: dict, raw=False) -> str: +async def call_api(url: str, params: dict, raw = False) -> str: """Call the GitHub code scanning API to fetch alert.""" - headers = { - "Accept": "application/vnd.github+json", - "X-GitHub-Api-Version": "2022-11-28", - "Authorization": f"Bearer {GH_TOKEN}", - } + headers = {"Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28", + "Authorization": f"Bearer {GH_TOKEN}"} if raw: headers["Accept"] = "application/vnd.github.raw+json" - async def _fetch(url, headers, params): try: - async with httpx.AsyncClient(headers=headers) as client: + async with httpx.AsyncClient(headers = headers) as client: r = await client.get(url, params=params) r.raise_for_status() return r @@ -87,40 +74,41 @@ async def _fetch(url, headers, params): except httpx.AuthenticationError as e: return f"Authentication error: {e}" - r = await _fetch(url, headers=headers, params=params) + r = await _fetch(url, headers = headers, params=params) return r - @mcp.tool() async def fetch_workflow( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - workflow_id: str = Field(description="The ID or name of the workflow"), -) -> str: + workflow_id: str = Field(description="The ID or name of the workflow")) -> str: """ Fetch the details of a GitHub Actions workflow. """ - r = await call_api(url=f"https://api.github.com/repos/{owner}/{repo}/actions/workflows/{workflow_id}", params={}) + r = await call_api( + url=f"https://api.github.com/repos/{owner}/{repo}/actions/workflows/{workflow_id}", + params={} + ) if isinstance(r, str): return r return r.json() - @mcp.tool() async def check_workflow_active( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - workflow_id: str = Field(description="The ID or name of the workflow"), -) -> str: + workflow_id: str = Field(description="The ID or name of the workflow")) -> str: """ Check if a GitHub Actions workflow is active. """ - r = await call_api(url=f"https://api.github.com/repos/{owner}/{repo}/actions/workflows/{workflow_id}", params={}) + r = await call_api( + url=f"https://api.github.com/repos/{owner}/{repo}/actions/workflows/{workflow_id}", + params={} + ) if isinstance(r, str): return r return f"Workflow {workflow_id} is {'active' if r.json().get('state') == 'active' else 'inactive'}." - def find_in_yaml(key, node): if isinstance(node, dict): for k, v in node.items(): @@ -134,11 +122,12 @@ def find_in_yaml(key, node): for result in find_in_yaml(key, item): yield result - async def get_workflow_triggers(owner: str, repo: str, workflow_file_path: str) -> str: + r = await call_api( - url=f"https://api.github.com/repos/{owner}/{repo}/contents/{workflow_file_path}", params={}, raw=True - ) + url=f"https://api.github.com/repos/{owner}/{repo}/contents/{workflow_file_path}", + params={}, raw = True + ) if isinstance(r, str): return json.dumps([r]) data = yaml.safe_load(r.text) @@ -146,76 +135,81 @@ async def get_workflow_triggers(owner: str, repo: str, workflow_file_path: str) triggers = list(find_in_yaml(True, data)) return triggers - @mcp.tool() async def find_workflow_run_dependency( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), workflow_file_path: str = Field(description="The file path of the workflow that is triggered by `workflow_run`"), - high_privileged: bool = Field(description="Whether to return high privileged dependencies only."), -) -> str: + high_privileged: bool = Field(description="Whether to return high privileged dependencies only.") +)->str: """ Find the workflow that triggers this workflow_run. """ r = await call_api( - url=f"https://api.github.com/repos/{owner}/{repo}/contents/{workflow_file_path}", params={}, raw=True + url=f"https://api.github.com/repos/{owner}/{repo}/contents/{workflow_file_path}", + params={}, raw=True ) if isinstance(r, str): return json.dumps([r]) data = yaml.safe_load(r.text) - trigger_workflow = list(find_in_yaml("workflow_run", data))[0].get("workflows", []) + trigger_workflow = list(find_in_yaml('workflow_run', data))[0].get('workflows', []) if not trigger_workflow: return json.dumps([], indent=2) r = await call_api( - url=f"https://api.github.com/repos/{owner}/{repo}/contents/.github/workflows", params={}, raw=True + url=f"https://api.github.com/repos/{owner}/{repo}/contents/.github/workflows", + params={}, raw=True ) if isinstance(r, str): return json.dumps([r]) if not r.json(): return json.dumps([], indent=2) - paths_list = [item["path"] for item in r.json() if item["path"].endswith(".yml") or item["path"].endswith(".yaml")] + paths_list = [item['path'] for item in r.json() if item['path'].endswith('.yml') or item['path'].endswith('.yaml')] results = [] for path in paths_list: - workflow_id = path.split("/")[-1] - active = await call_api( - url=f"https://api.github.com/repos/{owner}/{repo}/actions/workflows/{workflow_id}", params={} + workflow_id = path.split('/')[-1] + active = await call_api( + url=f"https://api.github.com/repos/{owner}/{repo}/actions/workflows/{workflow_id}", + params={} ) - if not isinstance(active, str) and active.json().get("state") == "active": - r = await call_api(url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", params={}, raw=True) + if not isinstance(active, str) and active.json().get('state') == "active": + r = await call_api( + url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", + params={}, raw=True + ) if isinstance(r, str): return json.dumps([r]) data = yaml.safe_load(r.text) - name = data.get("name", "") + name = data.get('name', '') if name in trigger_workflow or "*" in trigger_workflow: triggers = data.get(True, {}) if not high_privileged or high_privileged_triggers.intersection(set(triggers)): - results.append({"path": path, "name": name, "triggers": triggers}) + results.append({ + "path": path, + "name": name, + "triggers": triggers + }) return json.dumps(results, indent=2) - @mcp.tool() async def get_workflow_trigger( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - workflow_file_path: str = Field(description="The file path of the workflow"), -) -> str: + workflow_file_path: str = Field(description="The file path of the workflow")) -> str: """ Get the trigger of a GitHub Actions workflow. """ return json.dumps(await get_workflow_triggers(owner, repo, workflow_file_path), indent=2) - @mcp.tool() async def check_workflow_reusable( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - workflow_file_path: str = Field(description="The file path of the workflow"), -) -> str: + workflow_file_path: str = Field(description="The file path of the workflow")) -> str: """ Check if a GitHub Actions workflow is reusable. """ - if workflow_file_path.endswith("/action.yml") or workflow_file_path.endswith("/action.yaml"): + if workflow_file_path.endswith('/action.yml') or workflow_file_path.endswith('/action.yaml'): return "This workflow is reusable as an action." triggers = await get_workflow_triggers(owner, repo, workflow_file_path) print(f"Triggers found: {triggers}") @@ -224,17 +218,15 @@ async def check_workflow_reusable( return "This workflow is reusable as a workflow call." elif isinstance(trigger, dict): for k, v in trigger.items(): - if "workflow_call" == k: + if 'workflow_call' == k: return "This workflow is reusable." return "This workflow is not reusable." - @mcp.tool() async def get_high_privileged_workflow_triggers( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - workflow_file_path: str = Field(description="The file path of the workflow"), -) -> str: + workflow_file_path: str = Field(description="The file path of the workflow")) -> str: """ Gets the high privileged triggers for a workflow, if none returns, then the workflow is not high privileged. """ @@ -244,55 +236,55 @@ async def get_high_privileged_workflow_triggers( if isinstance(trigger, str): if trigger in high_privileged_triggers: results.append(trigger) - elif trigger == "workflow_run": + elif trigger == 'workflow_run': results.append(trigger) elif isinstance(trigger, dict): this_results = {} for k, v in trigger.items(): if k in high_privileged_triggers: this_results[k] = v - elif k == "workflow_run": + elif k == 'workflow_run': if not v or isinstance(v, str): this_results[k] = v - elif isinstance(v, dict) and not "branches" in v: + elif isinstance(v, dict) and not 'branches' in v: this_results[k] = v if this_results: results.append(this_results) - return json.dumps( - ["Workflow is high privileged" if results else "Workflow is not high privileged", results], indent=2 - ) - + return json.dumps(["Workflow is high privileged" if results else "Workflow is not high privileged", results], indent = 2) @mcp.tool() async def get_workflow_user( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), workflow_file_path: str = Field(description="The file path of the workflow"), - save_to_db: bool = Field(description="Save the results to database.", default=False), -) -> str: + save_to_db: bool = Field(description="Save the results to database.", default=False)) -> str: """ Get the user of a reusable workflow in repo. """ - paths = workflow_file_path.split("/") - if workflow_file_path.endswith("/action.yml") or workflow_file_path.endswith("/action.yaml"): + paths = workflow_file_path.split('/') + if workflow_file_path.endswith('/action.yml') or workflow_file_path.endswith('/action.yaml'): action_name = paths[-2] else: - action_name = paths[-1].replace(".yml", "").replace(".yaml", "") - paths = await call_api(url=f"https://api.github.com/repos/{owner}/{repo}/contents/.github/workflows", params={}) + action_name = paths[-1].replace('.yml', '').replace('.yaml', '') + paths = await call_api( + url=f"https://api.github.com/repos/{owner}/{repo}/contents/.github/workflows", + params={} + ) if isinstance(paths, str) or not paths.json(): return json.dumps([], indent=2) - paths_list = [ - item["path"] for item in paths.json() if item["path"].endswith(".yml") or item["path"].endswith(".yaml") - ] + paths_list = [item['path'] for item in paths.json() if item['path'].endswith('.yml') or item['path'].endswith('.yaml')] results = [] for path in paths_list: - r = await call_api(url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", params={}, raw=True) + r = await call_api( + url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", + params={}, raw=True + ) if isinstance(r, str): continue data = yaml.safe_load(r.text) - uses = list(find_in_yaml("uses", data)) + uses = list(find_in_yaml('uses', data)) lines = r.text.splitlines() actual_name = {} for use in uses: @@ -305,24 +297,26 @@ async def get_workflow_user( for use, line_numbers in actual_name.items(): if not line_numbers: continue - results.append( - {"user": path, "lines": line_numbers, "action_name": workflow_file_path, "repo": f"{owner}/{repo}"} - ) + results.append({ + "user": path, + "lines": line_numbers, + "action_name": workflow_file_path, + "repo": f"{owner}/{repo}" + }) if not results: return json.dumps([]) if save_to_db: with Session(engine) as session: for result in results: - result["lines"] = json.dumps(result["lines"]) # Convert list of lines to JSON string - result["repo"] = result["repo"].lower() + result['lines'] = json.dumps(result['lines']) # Convert list of lines to JSON string + result['repo'] = result['repo'].lower() workflow_use = WorkflowUses(**result) session.add(workflow_use) session.commit() return f"Search results saved to database." return json.dumps(results) - @mcp.tool() def fetch_last_workflow_users_results() -> str: """ @@ -332,18 +326,7 @@ def fetch_last_workflow_users_results() -> str: results = session.query(WorkflowUses).all() session.query(WorkflowUses).delete() session.commit() - return json.dumps( - [ - { - "user": result.user, - "lines": json.loads(result.lines), - "action": result.action_name, - "repo": result.repo.lower(), - } - for result in results - ] - ) - + return json.dumps([{"user": result.user, "lines" : json.loads(result.lines), "action": result.action_name, "repo" : result.repo.lower()} for result in results]) if __name__ == "__main__": mcp.run(show_banner=False) diff --git a/src/seclab_taskflows/mcp_servers/gh_code_scanning.py b/src/seclab_taskflows/mcp_servers/gh_code_scanning.py index bfb1300..bbdff6f 100644 --- a/src/seclab_taskflows/mcp_servers/gh_code_scanning.py +++ b/src/seclab_taskflows/mcp_servers/gh_code_scanning.py @@ -20,71 +20,62 @@ logging.basicConfig( level=logging.DEBUG, - format="%(asctime)s - %(levelname)s - %(message)s", - filename=log_file_name("mcp_gh_code_scanning.log"), - filemode="a", + format='%(asctime)s - %(levelname)s - %(message)s', + filename=log_file_name('mcp_gh_code_scanning.log'), + filemode='a' ) mcp = FastMCP("GitHubCodeScanning") -GH_TOKEN = os.getenv("GH_TOKEN", default="") - -CODEQL_DBS_BASE_PATH = mcp_data_dir("seclab-taskflows", "codeql", "CODEQL_DBS_BASE_PATH") -ALERT_RESULTS_DIR = mcp_data_dir("seclab-taskflows", "gh_code_scanning", "ALERT_RESULTS_DIR") +GH_TOKEN = os.getenv('GH_TOKEN', default='') +CODEQL_DBS_BASE_PATH = mcp_data_dir('seclab-taskflows', 'codeql', 'CODEQL_DBS_BASE_PATH') +ALERT_RESULTS_DIR = mcp_data_dir('seclab-taskflows', 'gh_code_scanning', 'ALERT_RESULTS_DIR') def parse_alert(alert: dict) -> dict: """Parse the alert dictionary to extract relevant information.""" - def _parse_location(location: dict) -> str: """Parse the location dictionary to extract file and line information.""" if not location: - return "No location information available" - file_path = location.get("path", "") - start_line = location.get("start_line", "") - end_line = location.get("end_line", "") - start_column = location.get("start_column", "") - end_column = location.get("end_column", "") + return 'No location information available' + file_path = location.get('path', '') + start_line = location.get('start_line', '') + end_line = location.get('end_line', '') + start_column = location.get('start_column', '') + end_column = location.get('end_column', '') if not file_path or not start_line or not end_line or not start_column or not end_column: - return "No location information available" + return 'No location information available' return f"{file_path}:{start_line}:{start_column}:{end_line}:{end_column}" - def _get_language(category: str) -> str: - return category.split(":")[1] if category and ":" in category else "" - + return category.split(':')[1] if category and ':' in category else '' def _get_repo_from_html_url(html_url: str) -> str: """Extract the repository name from the HTML URL.""" if not html_url: - return "" - parts = html_url.split("/") + return '' + parts = html_url.split('/') if len(parts) < 5: - return "" + return '' return f"{parts[3]}/{parts[4]}".lower() parsed = { - "alert_id": alert.get("number", "No number"), - "rule": alert.get("rule", {}).get("id", "No rule"), - "state": alert.get("state", "No state"), - "location": _parse_location(alert.get("most_recent_instance", {}).get("location", "No location")), - "language": _get_language(alert.get("most_recent_instance", {}).get("category", "No language")), - "created": alert.get("created_at", "No created"), - "updated": alert.get("updated_at", "No updated"), - "dismissed_comment": alert.get("dismissed_comment", ""), + 'alert_id': alert.get('number', 'No number'), + 'rule': alert.get('rule', {}).get('id', 'No rule'), + 'state': alert.get('state', 'No state'), + 'location': _parse_location(alert.get('most_recent_instance', {}).get('location', 'No location')), + 'language': _get_language(alert.get('most_recent_instance', {}).get('category', 'No language')), + 'created': alert.get('created_at', 'No created'), + 'updated': alert.get('updated_at', 'No updated'), + 'dismissed_comment': alert.get('dismissed_comment', ''), } return parsed - async def call_api(url: str, params: dict) -> str | httpx.Response: """Call the GitHub code scanning API to fetch alert.""" - headers = { - "Accept": "application/vnd.github+json", - "X-GitHub-Api-Version": "2022-11-28", - "Authorization": f"Bearer {GH_TOKEN}", - } - + headers = {"Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28", + "Authorization": f"Bearer {GH_TOKEN}"} async def _fetch_alerts(url, headers, params): try: - async with httpx.AsyncClient(headers=headers) as client: + async with httpx.AsyncClient(headers = headers) as client: r = await client.get(url, params=params) r.raise_for_status() return r @@ -97,16 +88,14 @@ async def _fetch_alerts(url, headers, params): except httpx.AuthenticationError as e: return f"Authentication error: {e}" - r = await _fetch_alerts(url, headers=headers, params=params) + r = await _fetch_alerts(url, headers = headers, params=params) return r @mcp.tool() -async def get_alert_by_number( - owner: str = Field(description="The owner of the repo"), - repo: str = Field(description="The repository name."), - alert_number: int = Field(description="The alert number to get the alert for. Example: 1"), -) -> str: +async def get_alert_by_number(owner: str = Field(description="The owner of the repo"), + repo: str = Field(description="The repository name."), + alert_number: int = Field(description="The alert number to get the alert for. Example: 1")) -> str: """Get the alert by number for a specific repository.""" url = f"https://api.github.com/repos/{owner}/{repo}/code-scanning/alerts/{alert_number}" resp = await call_api(url, {}) @@ -116,25 +105,24 @@ async def get_alert_by_number( return json.dumps(parsed_alert) return resp - -async def fetch_alerts_from_gh(owner: str, repo: str, state: str = "open", rule="") -> str: +async def fetch_alerts_from_gh(owner: str, repo: str, state: str = 'open', rule = '') -> str: """Fetch all code scanning alerts for a specific repository.""" url = f"https://api.github.com/repos/{owner}/{repo}/code-scanning/alerts" - if state not in ["open", "closed", "dismissed"]: - state = "open" - params = {"state": state, "per_page": 100} - # see https://github.com/octokit/plugin-paginate-rest.js/blob/8ec2713699ee473ee630be5c8a66b9665bcd4173/src/iterator.ts#L40 + if state not in ['open', 'closed', 'dismissed']: + state = 'open' + params = {'state': state, 'per_page': 100} + #see https://github.com/octokit/plugin-paginate-rest.js/blob/8ec2713699ee473ee630be5c8a66b9665bcd4173/src/iterator.ts#L40 link_pattern = re.compile(r'<([^<>]+)>;\s*rel="next"') results = [] while True: resp = await call_api(url, params) resp_headers = resp.headers - link = resp_headers.get("link", "") + link = resp_headers.get('link', '') resp = resp.json() if isinstance(resp, list): this_results = [parse_alert(alert) for alert in resp] if rule: - this_results = [alert for alert in this_results if alert.get("rule") == rule] + this_results = [alert for alert in this_results if alert.get('rule') == rule] results += this_results else: return resp + " url: " + url @@ -148,76 +136,63 @@ async def fetch_alerts_from_gh(owner: str, repo: str, state: str = "open", rule= return results return "No alerts found." - @mcp.tool() -async def fetch_alerts( - owner: str = Field(description="The owner of the repo"), - repo: str = Field(description="The repository name."), - state: str = Field(default="open", description="The state of the alert to filter by. Default is 'open'."), - rule: str = Field(description="The rule of the alert to fetch", default=""), -) -> str: +async def fetch_alerts(owner: str = Field(description="The owner of the repo"), + repo: str = Field(description="The repository name."), + state: str = Field(default='open', description="The state of the alert to filter by. Default is 'open'."), + rule: str = Field(description='The rule of the alert to fetch', default = '')) -> str: """Fetch all code scanning alerts for a specific repository.""" results = await fetch_alerts_from_gh(owner, repo, state, rule) if isinstance(results, str): return results return json.dumps(results, indent=2) - @mcp.tool() async def fetch_alerts_to_sql( owner: str = Field(description="The owner of the repo"), repo: str = Field(description="The repository name."), - state: str = Field(default="open", description="The state of the alert to filter by. Default is 'open'."), - rule=Field(description="The rule of the alert to fetch", default=""), - rename_repo: str = Field( - description="An optional alternative repo name for storing the alerts, if not specify, repo is used ", - default="", - ), -) -> str: + state: str = Field(default='open', description="The state of the alert to filter by. Default is 'open'."), + rule = Field(description='The rule of the alert to fetch', default = ''), + rename_repo: str = Field(description="An optional alternative repo name for storing the alerts, if not specify, repo is used ", default = '') + ) -> str: """Fetch all code scanning alerts for a specific repository and store them in a SQL database.""" results = await fetch_alerts_from_gh(owner, repo, state, rule) - sql_db_path = f"sqlite:///{ALERT_RESULTS_DIR}/alert_results.db" + sql_db_path = f"sqlite:///{ALERT_RESULTS_DIR}/alert_results.db" if isinstance(results, str) or not results: return results engine = create_engine(sql_db_path, echo=False) Base.metadata.create_all(engine, tables=[AlertResults.__table__, AlertFlowGraph.__table__]) with Session(engine) as session: for alert in results: - session.add( - AlertResults( - alert_id=alert.get("alert_id", ""), - repo=rename_repo.lower() if rename_repo else repo.lower(), - language=alert.get("language", ""), - rule=alert.get("rule", ""), - location=alert.get("location", ""), - result="", - created=alert.get("created", ""), - valid=True, - ) - ) + session.add(AlertResults( + alert_id=alert.get('alert_id', ''), + repo = rename_repo.lower() if rename_repo else repo.lower(), + language=alert.get('language', ''), + rule=alert.get('rule', ''), + location=alert.get('location', ''), + result='', + created=alert.get('created', ''), + valid=True + )) session.commit() return f"Stored {len(results)} alerts in the SQL database at {sql_db_path}." - async def _fetch_codeql_databases(owner: str, repo: str, language: str): """Fetch the CodeQL databases for a given repo and language.""" url = f"https://api.github.com/repos/{owner}/{repo}/code-scanning/codeql/databases/{language}" - headers = { - "Accept": "application/zip,application/vnd.github+json", - "X-GitHub-Api-Version": "2022-11-28", - "Authorization": f"Bearer {os.getenv('GH_TOKEN')}", - } + headers = {"Accept": "application/zip,application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28", + "Authorization": f"Bearer {os.getenv('GH_TOKEN')}"} try: async with httpx.AsyncClient() as client: - async with client.stream("GET", url, headers=headers, follow_redirects=True) as response: + async with client.stream('GET', url, headers =headers, follow_redirects=True) as response: response.raise_for_status() expected_path = f"{CODEQL_DBS_BASE_PATH}/{owner}/{repo}.zip" if os.path.realpath(expected_path) != expected_path: return f"Error: Invalid path for CodeQL database: {expected_path}" if not Path(f"{CODEQL_DBS_BASE_PATH}/{owner}").exists(): os.makedirs(f"{CODEQL_DBS_BASE_PATH}/{owner}", exist_ok=True) - async with aiofiles.open(f"{CODEQL_DBS_BASE_PATH}/{owner}/{repo}.zip", "wb") as f: + async with aiofiles.open(f"{CODEQL_DBS_BASE_PATH}/{owner}/{repo}.zip", 'wb') as f: async for chunk in response.aiter_bytes(): await f.write(chunk) # Unzip the downloaded file @@ -225,7 +200,7 @@ async def _fetch_codeql_databases(owner: str, repo: str, language: str): if not zip_path.exists(): return f"Error: CodeQL database for {repo} ({language}) does not exist." - with zipfile.ZipFile(zip_path, "r") as zip_ref: + with zipfile.ZipFile(zip_path, 'r') as zip_ref: zip_ref.extractall(Path(f"{CODEQL_DBS_BASE_PATH}/{owner}/{repo}")) # Remove the zip file after extraction os.remove(zip_path) @@ -234,12 +209,7 @@ async def _fetch_codeql_databases(owner: str, repo: str, language: str): if Path(f"{CODEQL_DBS_BASE_PATH}/{owner}/{repo}/codeql_db").exists(): qldb_subfolder = "codeql_db" - return json.dumps( - { - "message": f"CodeQL database for {repo} ({language}) fetched successfully.", - "relative_database_path": f"{owner}/{repo}/{qldb_subfolder}", - } - ) + return json.dumps({'message': f"CodeQL database for {repo} ({language}) fetched successfully.", 'relative_database_path': f"{owner}/{repo}/{qldb_subfolder}"}) except httpx.RequestError as e: return f"Error: Request error: {e}" except httpx.HTTPStatusError as e: @@ -247,23 +217,19 @@ async def _fetch_codeql_databases(owner: str, repo: str, language: str): except Exception as e: return f"Error: An unexpected error occurred: {e}" - @mcp.tool() -async def fetch_database( - owner: str = Field(description="The owner of the repo."), - repo: str = Field(description="The name of the repo."), - language: str = Field(description="The language used for the CodeQL database."), -): +async def fetch_database(owner: str = Field(description="The owner of the repo."), + repo: str = Field(description="The name of the repo."), + language: str = Field(description="The language used for the CodeQL database.")): """Fetch the CodeQL database for a given repo and language.""" return await _fetch_codeql_databases(owner, repo, language) - @mcp.tool() async def dismiss_alert( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), alert_id: str = Field(description="The ID of the alert to dismiss"), - reason: str = Field(description="The reason for dismissing the alert. It must be less than 280 characters."), + reason: str = Field(description="The reason for dismissing the alert. It must be less than 280 characters.") ) -> str: """ Dismiss a code scanning alert. @@ -272,34 +238,31 @@ async def dismiss_alert( headers = { "Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28", - "Authorization": f"Bearer {GH_TOKEN}", + "Authorization": f"Bearer {GH_TOKEN}" } async with httpx.AsyncClient(headers=headers) as client: - response = await client.patch( - url, json={"state": "dismissed", "dismissed_reason": "false positive", "dismissed_comment": reason} - ) + response = await client.patch(url, json={"state": "dismissed", "dismissed_reason": "false positive", "dismissed_comment": reason}) response.raise_for_status() return response.text - @mcp.tool() async def check_alert_issue_exists( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - alert_id: str = Field(description="The ID of the alert to check for an associated issue"), + alert_id: str = Field(description="The ID of the alert to check for an associated issue") ) -> str: """ Check if an issue exists for a specific alert in a repository. """ url = f"https://api.github.com/repos/{owner}/{repo}/issues" - # see https://github.com/octokit/plugin-paginate-rest.js/blob/8ec2713699ee473ee630be5c8a66b9665bcd4173/src/iterator.ts#L40 + #see https://github.com/octokit/plugin-paginate-rest.js/blob/8ec2713699ee473ee630be5c8a66b9665bcd4173/src/iterator.ts#L40 link_pattern = re.compile(r'<([^<>]+)>;\s*rel="next"') params = {"state": "open", "per_page": 100} while True: resp = await call_api(url, params=params) resp_headers = resp.headers - link = resp_headers.get("link", "") + link = resp_headers.get('link', '') resp = resp.json() if isinstance(resp, list): for issue in resp: @@ -314,16 +277,12 @@ async def check_alert_issue_exists( params = parse_qs(urlparse(url).query) return "No issue found for this alert." - @mcp.tool() async def fetch_issues_matches( - repo: str = Field( - description="A comma separated list of repositories to search in. Each term is of the form owner/repo. For example: 'owner1/repo1,owner2/repo2'" - ), + repo: str = Field(description="A comma separated list of repositories to search in. Each term is of the form owner/repo. For example: 'owner1/repo1,owner2/repo2'"), matches: str = Field(description="The search term to match against issue titles"), - state: str = Field(default="open", description="The state of the issues to filter by. Default is 'open'."), - labels: str = Field(default="", description="Labels to filter issues by"), -) -> str: + state: str = Field(default='open', description="The state of the issues to filter by. Default is 'open'."), + labels: str = Field(default="", description="Labels to filter issues by")) -> str: """ Fetch issues from a repository that match a specific title pattern. """ @@ -339,25 +298,18 @@ async def fetch_issues_matches( } if labels: params["labels"] = labels - # see https://github.com/octokit/plugin-paginate-rest.js/blob/8ec2713699ee473ee630be5c8a66b9665bcd4173/src/iterator.ts#L40 + #see https://github.com/octokit/plugin-paginate-rest.js/blob/8ec2713699ee473ee630be5c8a66b9665bcd4173/src/iterator.ts#L40 link_pattern = re.compile(r'<([^<>]+)>;\s*rel="next"') while True: resp = await call_api(url, params=params) resp_headers = resp.headers - link = resp_headers.get("link", "") + link = resp_headers.get('link', '') resp = resp.json() if isinstance(resp, list): for issue in resp: if matches in issue.get("title", "") or matches in issue.get("body", ""): - results.append( - { - "title": issue["title"], - "number": issue["number"], - "repo": r, - "body": issue.get("body", ""), - "labels": issue.get("labels", []), - } - ) + results.append({"title": issue["title"], "number": issue["number"], "repo": r, "body": issue.get("body", ""), + "labels": issue.get("labels", [])}) else: return resp + " url: " + url m = link_pattern.search(link) diff --git a/src/seclab_taskflows/mcp_servers/gh_file_viewer.py b/src/seclab_taskflows/mcp_servers/gh_file_viewer.py index 30519ca..ea9a40c 100644 --- a/src/seclab_taskflows/mcp_servers/gh_file_viewer.py +++ b/src/seclab_taskflows/mcp_servers/gh_file_viewer.py @@ -19,18 +19,16 @@ logging.basicConfig( level=logging.DEBUG, - format="%(asctime)s - %(levelname)s - %(message)s", - filename=log_file_name("mcp_gh_file_viewer.log"), - filemode="a", + format='%(asctime)s - %(levelname)s - %(message)s', + filename=log_file_name('mcp_gh_file_viewer.log'), + filemode='a' ) - class Base(DeclarativeBase): pass - class SearchResults(Base): - __tablename__ = "search_results" + __tablename__ = 'search_results' id: Mapped[int] = mapped_column(primary_key=True) path: Mapped[str] @@ -40,33 +38,26 @@ class SearchResults(Base): repo: Mapped[str] def __repr__(self): - return ( - f"" - ) - + return (f"") mcp = FastMCP("GitHubFileViewer") -GH_TOKEN = os.getenv("GH_TOKEN", default="") +GH_TOKEN = os.getenv('GH_TOKEN', default='') -SEARCH_RESULT_DIR = mcp_data_dir("seclab-taskflows", "gh_file_viewer", "SEARCH_RESULTS_DIR") +SEARCH_RESULT_DIR = mcp_data_dir('seclab-taskflows', 'gh_file_viewer', 'SEARCH_RESULTS_DIR') -engine = create_engine(f"sqlite:///{os.path.abspath(SEARCH_RESULT_DIR)}/search_result.db", echo=False) -Base.metadata.create_all(engine, tables=[SearchResults.__table__]) +engine = create_engine(f'sqlite:///{os.path.abspath(SEARCH_RESULT_DIR)}/search_result.db', echo=False) +Base.metadata.create_all(engine, tables = [SearchResults.__table__]) async def call_api(url: str, params: dict) -> str: """Call the GitHub code scanning API to fetch alert.""" - headers = { - "Accept": "application/vnd.github.raw+json", - "X-GitHub-Api-Version": "2022-11-28", - "Authorization": f"Bearer {GH_TOKEN}", - } - + headers = {"Accept": "application/vnd.github.raw+json", "X-GitHub-Api-Version": "2022-11-28", + "Authorization": f"Bearer {GH_TOKEN}"} async def _fetch_file(url, headers, params): try: - async with httpx.AsyncClient(headers=headers) as client: + async with httpx.AsyncClient(headers = headers) as client: r = await client.get(url, params=params, follow_redirects=True) r.raise_for_status() return r @@ -79,24 +70,19 @@ async def _fetch_file(url, headers, params): except httpx.AuthenticationError as e: return f"Authentication error: {e}" - return await _fetch_file(url, headers=headers, params=params) - + return await _fetch_file(url, headers = headers, params=params) def remove_root_dir(path): - return "/".join(path.split("/")[1:]) - + return '/'.join(path.split('/')[1:]) async def _fetch_source_zip(owner: str, repo: str, tmp_dir): """Fetch the source code.""" url = f"https://api.github.com/repos/{owner}/{repo}/zipball" - headers = { - "Accept": "application/vnd.github+json", - "X-GitHub-Api-Version": "2022-11-28", - "Authorization": f"Bearer {GH_TOKEN}", - } + headers = {"Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28", + "Authorization": f"Bearer {GH_TOKEN}"} try: async with httpx.AsyncClient() as client: - async with client.stream("GET", url, headers=headers, follow_redirects=True) as response: + async with client.stream('GET', url, headers =headers, follow_redirects=True) as response: response.raise_for_status() expected_path = Path(tmp_dir) / owner / f"{repo}.zip" resolved_path = expected_path.resolve() @@ -104,7 +90,7 @@ async def _fetch_source_zip(owner: str, repo: str, tmp_dir): return f"Error: Invalid path for source code: {expected_path}" if not Path(f"{tmp_dir}/{owner}").exists(): os.makedirs(f"{tmp_dir}/{owner}", exist_ok=True) - async with aiofiles.open(f"{tmp_dir}/{owner}/{repo}.zip", "wb") as f: + async with aiofiles.open(f"{tmp_dir}/{owner}/{repo}.zip", 'wb') as f: async for chunk in response.aiter_bytes(): await f.write(chunk) return f"source code for {repo} fetched successfully." @@ -115,21 +101,20 @@ async def _fetch_source_zip(owner: str, repo: str, tmp_dir): except Exception as e: return f"Error: An unexpected error occurred: {e}" - def search_zipfile(database_path, term): results = {} with zipfile.ZipFile(database_path) as z: for entry in z.infolist(): if entry.is_dir(): continue - with z.open(entry, "r") as f: + with z.open(entry, 'r') as f: for i, line in enumerate(f): if term in str(line): filename = remove_root_dir(entry.filename) if not filename in results: - results[filename] = [i + 1] + results[filename] = [i+1] else: - results[filename].append(i + 1) + results[filename].append(i+1) return results @@ -137,36 +122,40 @@ def search_zipfile(database_path, term): async def fetch_file_from_gh( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - path: str = Field(description="The path to the file in the repository"), -) -> str: + path: str = Field(description="The path to the file in the repository"))-> str: """ Fetch the content of a file from a GitHub repository. """ owner = owner.lower() repo = repo.lower() - r = await call_api(url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", params={}) + r = await call_api( + url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", + params={} + ) if isinstance(r, str): return r lines = r.text.splitlines() for i in range(len(lines)): - lines[i] = f"{i + 1}: {lines[i]}" + lines[i] = f"{i+1}: {lines[i]}" return "\n".join(lines) - @mcp.tool() async def get_file_lines_from_gh( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), path: str = Field(description="The path to the file in the repository"), start_line: int = Field(description="The starting line number to fetch from the file", default=1), - length: int = Field(description="The ending line number to fetch from the file", default=10), -) -> str: - """Fetch a range of lines from a file in a GitHub repository.""" + length: int = Field(description="The ending line number to fetch from the file", default=10)) -> str: + """Fetch a range of lines from a file in a GitHub repository. + """ owner = owner.lower() repo = repo.lower() - r = await call_api(url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", params={}) + r = await call_api( + url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", + params={} + ) if isinstance(r, str): return r lines = r.text.splitlines() @@ -174,63 +163,61 @@ async def get_file_lines_from_gh( start_line = 1 if length < 1: length = 10 - lines = lines[start_line - 1 : start_line - 1 + length] + lines = lines[start_line-1:start_line-1+length] if not lines: return f"No lines found in the range {start_line} to {start_line + length - 1} in {path}." - return "\n".join([f"{i + start_line}: {line}" for i, line in enumerate(lines)]) - + return "\n".join([f"{i+start_line}: {line}" for i, line in enumerate(lines)]) @mcp.tool() async def search_file_from_gh( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), path: str = Field(description="The path to the file in the repository"), - search_term: str = Field(description="The term to search for in the file"), -) -> str: + search_term: str = Field(description="The term to search for in the file")) -> str: """ Search for a term in a file from a GitHub repository. """ owner = owner.lower() repo = repo.lower() - r = await call_api(url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", params={}) + r = await call_api( + url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", + params={} + ) if isinstance(r, str): return r lines = r.text.splitlines() - matches = [f"{i + 1}: {line}" for i, line in enumerate(lines) if search_term in line] + matches = [f"{i+1}: {line}" for i,line in enumerate(lines) if search_term in line] if not matches: return f"No matches found for '{search_term}' in {path}." return "\n".join(matches) - @mcp.tool() async def search_files_from_gh( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), paths: str = Field(description="A comma separated list of paths to the file in the repository"), search_term: str = Field(description="The term to search for in the file"), - save_to_db: bool = Field(description="Save the results to database.", default=False), -) -> str: + save_to_db: bool = Field(description="Save the results to database.", default=False)) -> str: """ Search for a term in a list of files from a GitHub repository. """ owner = owner.lower() repo = repo.lower() - paths_list = [path.strip() for path in paths.split(",")] + paths_list = [path.strip() for path in paths.split(',')] if not paths_list: return "No paths provided for search." results = [] for path in paths_list: - r = await call_api(url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", params={}) + r = await call_api( + url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", + params={} + ) if isinstance(r, str): return r lines = r.text.splitlines() - matches = [ - {"path": path, "line": i + 1, "search_term": search_term, "owner": owner.lower(), "repo": repo.lower()} - for i, line in enumerate(lines) - if search_term in line - ] + matches = [{"path": path, "line" : i+1, "search_term": search_term, "owner": owner.lower(), "repo" : repo.lower()} for i,line in enumerate(lines) if search_term in line] if matches: results.extend(matches) if not results: @@ -244,7 +231,6 @@ async def search_files_from_gh( return f"Search results saved to database." return json.dumps(results) - @mcp.tool() def fetch_last_search_results() -> str: """ @@ -254,54 +240,43 @@ def fetch_last_search_results() -> str: results = session.query(SearchResults).all() session.query(SearchResults).delete() session.commit() - return json.dumps( - [ - { - "path": result.path, - "line": result.line, - "search_term": result.search_term, - "owner": result.owner.lower(), - "repo": result.repo.lower(), - } - for result in results - ] - ) - + return json.dumps([{"path": result.path, "line" : result.line, "search_term": result.search_term, "owner": result.owner.lower(), "repo" : result.repo.lower()} for result in results]) @mcp.tool() async def list_directory_from_gh( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - path: str = Field(description="The path to the directory in the repository"), -) -> str: + path: str = Field(description="The path to the directory in the repository")) -> str: """ Fetch the content of a directory from a GitHub repository. """ owner = owner.lower() repo = repo.lower() - r = await call_api(url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", params={}) + r = await call_api( + url=f"https://api.github.com/repos/{owner}/{repo}/contents/{path}", + params={} + ) if isinstance(r, str): return r if not r.json(): return json.dumps([], indent=2) - content = [item["path"] for item in r.json()] + content = [item['path'] for item in r.json()] return json.dumps(content, indent=2) - @mcp.tool() async def search_repo_from_gh( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - search_term: str = Field(description="The term to search within the repo."), + search_term: str = Field(description="The term to search within the repo.") ): """ Search for the search term in the entire repository. """ owner = owner.lower() repo = repo.lower() - + with tempfile.TemporaryDirectory() as tmp_dir: result = await _fetch_source_zip(owner, repo, tmp_dir) source_path = Path(f"{tmp_dir}/{owner}/{repo}.zip") @@ -309,10 +284,9 @@ async def search_repo_from_gh( return json.dumps([result], indent=2) results = search_zipfile(source_path, search_term) out = [] - for k, v in results.items(): + for k,v in results.items(): out.append({"owner": owner, "repo": repo, "path": k, "lines": v}) return json.dumps(out, indent=2) - if __name__ == "__main__": mcp.run(show_banner=False) diff --git a/src/seclab_taskflows/mcp_servers/ghsa.py b/src/seclab_taskflows/mcp_servers/ghsa.py index 4611c71..c149df6 100644 --- a/src/seclab_taskflows/mcp_servers/ghsa.py +++ b/src/seclab_taskflows/mcp_servers/ghsa.py @@ -10,9 +10,9 @@ logging.basicConfig( level=logging.DEBUG, - format="%(asctime)s - %(levelname)s - %(message)s", - filename=log_file_name("mcp_ghsa.log"), - filemode="a", + format='%(asctime)s - %(levelname)s - %(message)s', + filename=log_file_name('mcp_ghsa.log'), + filemode='a' ) mcp = FastMCP("GitHubRepoAdvisories") @@ -30,11 +30,10 @@ def parse_advisory(advisory: dict) -> dict: "state": advisory.get("state", ""), } - async def fetch_GHSA_list_from_gh(owner: str, repo: str) -> str | list: """Fetch all security advisories for a specific repository.""" url = f"https://api.github.com/repos/{owner}/{repo}/security-advisories" - params = {"per_page": 100} + params = {'per_page': 100} # See https://github.com/octokit/plugin-paginate-rest.js/blob/8ec2713699ee473ee630be5c8a66b9665bcd4173/src/iterator.ts#L40 link_pattern = re.compile(r'<([^<>]+)>;\s*rel="next"') results = [] @@ -43,7 +42,7 @@ async def fetch_GHSA_list_from_gh(owner: str, repo: str) -> str | list: if isinstance(resp, str): return resp resp_headers = resp.headers - link = resp_headers.get("link", "") + link = resp_headers.get('link', '') resp = resp.json() if isinstance(resp, list): results += [parse_advisory(advisory) for advisory in resp] @@ -59,11 +58,9 @@ async def fetch_GHSA_list_from_gh(owner: str, repo: str) -> str | list: return results return "No advisories found." - @mcp.tool() -async def fetch_GHSA_list( - owner: str = Field(description="The owner of the repo"), repo: str = Field(description="The repository name") -) -> str: +async def fetch_GHSA_list(owner: str = Field(description="The owner of the repo"), + repo: str = Field(description="The repository name")) -> str: """Fetch all GitHub Security Advisories (GHSAs) for a specific repository.""" results = await fetch_GHSA_list_from_gh(owner, repo) if isinstance(results, str): @@ -81,19 +78,15 @@ async def fetch_GHSA_details_from_gh(owner: str, repo: str, ghsa_id: str) -> str return resp.json() return "Not found." - @mcp.tool() -async def fetch_GHSA_details( - owner: str = Field(description="The owner of the repo"), - repo: str = Field(description="The repository name"), - ghsa_id: str = Field(description="The ghsa_id of the advisory"), -) -> str: +async def fetch_GHSA_details(owner: str = Field(description="The owner of the repo"), + repo: str = Field(description="The repository name"), + ghsa_id: str = Field(description="The ghsa_id of the advisory")) -> str: """Fetch a GitHub Security Advisory for a specific repository and GHSA ID.""" results = await fetch_GHSA_details_from_gh(owner, repo, ghsa_id) if isinstance(results, str): return results return json.dumps(results, indent=2) - if __name__ == "__main__": mcp.run(show_banner=False) diff --git a/src/seclab_taskflows/mcp_servers/local_file_viewer.py b/src/seclab_taskflows/mcp_servers/local_file_viewer.py index a85297b..4dd73bc 100644 --- a/src/seclab_taskflows/mcp_servers/local_file_viewer.py +++ b/src/seclab_taskflows/mcp_servers/local_file_viewer.py @@ -15,19 +15,18 @@ logging.basicConfig( level=logging.DEBUG, - format="%(asctime)s - %(levelname)s - %(message)s", - filename=log_file_name("mcp_local_file_viewer.log"), - filemode="a", + format='%(asctime)s - %(levelname)s - %(message)s', + filename=log_file_name('mcp_local_file_viewer.log'), + filemode='a' ) mcp = FastMCP("LocalFileViewer") -LOCAL_GH_DIR = mcp_data_dir("seclab-taskflows", "local_file_viewer", "LOCAL_GH_DIR") +LOCAL_GH_DIR = mcp_data_dir('seclab-taskflows', 'local_file_viewer', 'LOCAL_GH_DIR') -LINE_LIMIT_FOR_FETCHING_FILE_CONTENT = int(os.getenv("LINE_LIMIT_FOR_FETCHING_FILE_CONTENT", default=1000)) - -FILE_LIMIT_FOR_LIST_FILES = int(os.getenv("FILE_LIMIT_FOR_LIST_FILES", default=100)) +LINE_LIMIT_FOR_FETCHING_FILE_CONTENT = int(os.getenv('LINE_LIMIT_FOR_FETCHING_FILE_CONTENT', default=1000)) +FILE_LIMIT_FOR_LIST_FILES = int(os.getenv('FILE_LIMIT_FOR_LIST_FILES', default=100)) def is_subdirectory(directory, potential_subdirectory): directory_path = Path(directory) @@ -38,7 +37,6 @@ def is_subdirectory(directory, potential_subdirectory): except ValueError: return False - def sanitize_file_path(file_path, allow_paths): file_path = os.path.realpath(file_path) for allowed_path in allow_paths: @@ -46,18 +44,15 @@ def sanitize_file_path(file_path, allow_paths): return Path(file_path) return None - def remove_root_dir(path): - return "/".join(path.split("/")[1:]) - + return '/'.join(path.split('/')[1:]) def strip_leading_dash(path): - if path and path[0] == "/": + if path and path[0] == '/': path = path[1:] return path - -def search_zipfile(database_path, term, search_dir=None): +def search_zipfile(database_path, term, search_dir = None): results = {} search_dir = strip_leading_dash(search_dir) with zipfile.ZipFile(database_path) as z: @@ -66,18 +61,17 @@ def search_zipfile(database_path, term, search_dir=None): continue if search_dir and not is_subdirectory(search_dir, remove_root_dir(entry.filename)): continue - with z.open(entry, "r") as f: + with z.open(entry, 'r') as f: for i, line in enumerate(f): if term in str(line): filename = remove_root_dir(entry.filename) if not filename in results: - results[filename] = [i + 1] + results[filename] = [i+1] else: - results[filename].append(i + 1) + results[filename].append(i+1) return results - -def _list_files(database_path, root_dir=None, recursive=True): +def _list_files(database_path, root_dir = None, recursive=True): results = [] root_dir = strip_leading_dash(root_dir) with zipfile.ZipFile(database_path) as z: @@ -86,7 +80,7 @@ def _list_files(database_path, root_dir=None, recursive=True): if not recursive: dirname = remove_root_dir(entry.filename) if Path(dirname).parent == Path(root_dir): - results.append(dirname + "/") + results.append(dirname + '/') continue filename = remove_root_dir(entry.filename) if root_dir and not is_subdirectory(root_dir, filename): @@ -96,7 +90,6 @@ def _list_files(database_path, root_dir=None, recursive=True): results.append(filename) return results - def get_file(database_path, filename): results = [] filename = strip_leading_dash(filename) @@ -105,18 +98,16 @@ def get_file(database_path, filename): if entry.is_dir(): continue if remove_root_dir(entry.filename) == filename: - with z.open(entry, "r") as f: + with z.open(entry, 'r') as f: results = [line.rstrip() for line in f] return results return results - @mcp.tool() async def fetch_file_content( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - path: str = Field(description="The path to the file in the repository"), -) -> str: + path: str = Field(description="The path to the file in the repository"))-> str: """ Fetch the content of a file from a local GitHub repository. """ @@ -133,19 +124,18 @@ async def fetch_file_content( if not lines: return f"Unable to find file {path} in {owner}/{repo}" for i in range(len(lines)): - lines[i] = f"{i + 1}: {lines[i]}" + lines[i] = f"{i+1}: {lines[i]}" return "\n".join(lines) - @mcp.tool() async def get_file_lines( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), path: str = Field(description="The path to the file in the repository"), start_line: int = Field(description="The starting line number to fetch from the file", default=1), - length: int = Field(description="The ending line number to fetch from the file", default=10), -) -> str: - """Fetch a range of lines from a file in a local GitHub repository.""" + length: int = Field(description="The ending line number to fetch from the file", default=10)) -> str: + """Fetch a range of lines from a file in a local GitHub repository. + """ owner = owner.lower() repo = repo.lower() @@ -158,18 +148,16 @@ async def get_file_lines( start_line = 1 if length < 1: length = 10 - lines = lines[start_line - 1 : start_line - 1 + length] + lines = lines[start_line-1:start_line-1+length] if not lines: return f"No lines found in the range {start_line} to {start_line + length - 1} in {path}." - return "\n".join([f"{i + start_line}: {line}" for i, line in enumerate(lines)]) - + return "\n".join([f"{i+start_line}: {line}" for i, line in enumerate(lines)]) @mcp.tool() async def list_files( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - path: str = Field(description="The path to the directory in the repository"), -) -> str: + path: str = Field(description="The path to the directory in the repository")) -> str: """ Recursively list the files of a directory from a local GitHub repository. """ @@ -185,13 +173,11 @@ async def list_files( return f"Too many files to display in {owner}/{repo} at path {path} ({len(content)} files). Try using `list_files_non_recursive` instead." return json.dumps(content, indent=2) - @mcp.tool() async def list_files_non_recursive( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), - path: str = Field(description="The path to the directory in the repository"), -) -> str: + path: str = Field(description="The path to the directory in the repository")) -> str: """ List the files of a directory from a local GitHub repository non-recursively. Subdirectories will be listed and indicated with a trailing slash. @@ -212,10 +198,7 @@ async def search_repo( owner: str = Field(description="The owner of the repository"), repo: str = Field(description="The name of the repository"), search_term: str = Field(description="The term to search within the repo."), - directory: str = Field( - description="The directory or file to restrict the search, if not provided, the whole repo is searched", - default="", - ), + directory: str = Field(description="The directory or file to restrict the search, if not provided, the whole repo is searched", default = '') ): """ Search for the search term in the repository or a subdirectory/file in the repository. @@ -231,10 +214,9 @@ async def search_repo( return json.dumps([], indent=2) results = search_zipfile(source_path, search_term, directory) out = [] - for k, v in results.items(): + for k,v in results.items(): out.append({"owner": owner, "repo": repo, "path": k, "lines": v}) return json.dumps(out, indent=2) - if __name__ == "__main__": mcp.run(show_banner=False) diff --git a/src/seclab_taskflows/mcp_servers/local_gh_resources.py b/src/seclab_taskflows/mcp_servers/local_gh_resources.py index 33fc641..3c48ad6 100644 --- a/src/seclab_taskflows/mcp_servers/local_gh_resources.py +++ b/src/seclab_taskflows/mcp_servers/local_gh_resources.py @@ -15,17 +15,16 @@ logging.basicConfig( level=logging.DEBUG, - format="%(asctime)s - %(levelname)s - %(message)s", - filename=log_file_name("mcp_local_gh_resources.log"), - filemode="a", + format='%(asctime)s - %(levelname)s - %(message)s', + filename=log_file_name('mcp_local_gh_resources.log'), + filemode='a' ) mcp = FastMCP("LocalGHResources") -GH_TOKEN = os.getenv("GH_TOKEN") - -LOCAL_GH_DIR = mcp_data_dir("seclab-taskflows", "local_gh_resources", "LOCAL_GH_DIR") +GH_TOKEN = os.getenv('GH_TOKEN') +LOCAL_GH_DIR = mcp_data_dir('seclab-taskflows', 'local_gh_resources', 'LOCAL_GH_DIR') def is_subdirectory(directory, potential_subdirectory): directory_path = Path(directory) @@ -36,7 +35,6 @@ def is_subdirectory(directory, potential_subdirectory): except ValueError: return False - def sanitize_file_path(file_path, allow_paths): file_path = os.path.realpath(file_path) for allowed_path in allow_paths: @@ -44,18 +42,13 @@ def sanitize_file_path(file_path, allow_paths): return Path(file_path) return None - async def call_api(url: str, params: dict) -> str: """Call the GitHub code scanning API to fetch alert.""" - headers = { - "Accept": "application/vnd.github.raw+json", - "X-GitHub-Api-Version": "2022-11-28", - "Authorization": f"Bearer {GH_TOKEN}", - } - + headers = {"Accept": "application/vnd.github.raw+json", "X-GitHub-Api-Version": "2022-11-28", + "Authorization": f"Bearer {GH_TOKEN}"} async def _fetch_file(url, headers, params): try: - async with httpx.AsyncClient(headers=headers) as client: + async with httpx.AsyncClient(headers = headers) as client: r = await client.get(url, params=params, follow_redirects=True) r.raise_for_status() return r @@ -68,20 +61,16 @@ async def _fetch_file(url, headers, params): except httpx.AuthenticationError as e: return f"Authentication error: {e}" - return await _fetch_file(url, headers=headers, params=params) - + return await _fetch_file(url, headers = headers, params=params) async def _fetch_source_zip(owner: str, repo: str, tmp_dir): """Fetch the source code.""" url = f"https://api.github.com/repos/{owner}/{repo}/zipball" - headers = { - "Accept": "application/vnd.github+json", - "X-GitHub-Api-Version": "2022-11-28", - "Authorization": f"Bearer {GH_TOKEN}", - } + headers = {"Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28", + "Authorization": f"Bearer {GH_TOKEN}"} try: async with httpx.AsyncClient() as client: - async with client.stream("GET", url, headers=headers, follow_redirects=True) as response: + async with client.stream('GET', url, headers =headers, follow_redirects=True) as response: response.raise_for_status() expected_path = Path(tmp_dir) / owner / f"{repo}.zip" resolved_path = expected_path.resolve() @@ -89,7 +78,7 @@ async def _fetch_source_zip(owner: str, repo: str, tmp_dir): return f"Error: Invalid path for source code: {expected_path}" if not Path(f"{tmp_dir}/{owner}").exists(): os.makedirs(f"{tmp_dir}/{owner}", exist_ok=True) - async with aiofiles.open(f"{tmp_dir}/{owner}/{repo}.zip", "wb") as f: + async with aiofiles.open(f"{tmp_dir}/{owner}/{repo}.zip", 'wb') as f: async for chunk in response.aiter_bytes(): await f.write(chunk) return f"source code for {repo} fetched successfully." @@ -99,10 +88,10 @@ async def _fetch_source_zip(owner: str, repo: str, tmp_dir): return f"Error: HTTP error: {e}" except Exception as e: return f"Error: An unexpected error occurred: {e}" - - @mcp.tool() -async def fetch_repo_from_gh(owner: str, repo: str): +async def fetch_repo_from_gh( + owner: str, repo: str +): """ Download the source code from GitHub to the local file system to speed up file search. """ @@ -115,7 +104,6 @@ async def fetch_repo_from_gh(owner: str, repo: str): return result return f"Downloaded source code to {owner}/{repo}.zip" - @mcp.tool() async def clear_local_repo(owner: str, repo: str): """ diff --git a/src/seclab_taskflows/mcp_servers/repo_context.py b/src/seclab_taskflows/mcp_servers/repo_context.py index e69eaae..3c71bfc 100644 --- a/src/seclab_taskflows/mcp_servers/repo_context.py +++ b/src/seclab_taskflows/mcp_servers/repo_context.py @@ -19,13 +19,12 @@ logging.basicConfig( level=logging.DEBUG, - format="%(asctime)s - %(levelname)s - %(message)s", - filename=log_file_name("mcp_repo_context.log"), - filemode="a", + format='%(asctime)s - %(levelname)s - %(message)s', + filename=log_file_name('mcp_repo_context.log'), + filemode='a' ) -MEMORY = mcp_data_dir("seclab-taskflows", "repo_context", "REPO_CONTEXT_DIR") - +MEMORY = mcp_data_dir('seclab-taskflows', 'repo_context', 'REPO_CONTEXT_DIR') def app_to_dict(result): return { @@ -34,10 +33,9 @@ def app_to_dict(result): "location": result.location, "notes": result.notes, "is_app": result.is_app, - "is_library": result.is_library, + "is_library": result.is_library } - def entry_point_to_dict(ep): return { "id": ep.id, @@ -46,10 +44,9 @@ def entry_point_to_dict(ep): "user_input": ep.user_input, "repo": ep.repo.lower(), "line": ep.line, - "notes": ep.notes, + "notes": ep.notes } - def user_action_to_dict(ua): return { "id": ua.id, @@ -57,10 +54,9 @@ def user_action_to_dict(ua): "file": ua.file, "line": ua.line, "repo": ua.repo.lower(), - "notes": ua.notes, + "notes": ua.notes } - def web_entry_point_to_dict(wep): return { "id": wep.id, @@ -72,47 +68,36 @@ def web_entry_point_to_dict(wep): "middleware": wep.middleware, "roles_scopes": wep.roles_scopes, "repo": wep.repo.lower(), - "notes": wep.notes, + "notes": wep.notes } - def audit_result_to_dict(res): return { - "id": res.id, - "repo": res.repo.lower(), - "component_id": res.component_id, - "issue_type": res.issue_type, - "issue_id": res.issue_id, - "notes": res.notes, + "id" : res.id, + "repo" : res.repo.lower(), + "component_id" : res.component_id, + "issue_type" : res.issue_type, + "issue_id" : res.issue_id, + "notes" : res.notes, "has_vulnerability": res.has_vulnerability, - "has_non_security_error": res.has_non_security_error, + "has_non_security_error": res.has_non_security_error } - class RepoContextBackend: def __init__(self, memcache_state_dir: str): self.memcache_state_dir = memcache_state_dir - self.location_pattern = r"^([a-zA-Z]+)(:\d+){4}$" + self.location_pattern = r'^([a-zA-Z]+)(:\d+){4}$' if not Path(self.memcache_state_dir).exists(): - db_dir = "sqlite://" + db_dir = 'sqlite://' else: - db_dir = f"sqlite:///{self.memcache_state_dir}/repo_context.db" + db_dir = f'sqlite:///{self.memcache_state_dir}/repo_context.db' self.engine = create_engine(db_dir, echo=False) - Base.metadata.create_all( - self.engine, - tables=[ - Application.__table__, - EntryPoint.__table__, - UserAction.__table__, - WebEntryPoint.__table__, - ApplicationIssue.__table__, - AuditResult.__table__, - ], - ) + Base.metadata.create_all(self.engine, tables=[Application.__table__, EntryPoint.__table__, UserAction.__table__, + WebEntryPoint.__table__, ApplicationIssue.__table__, AuditResult.__table__]) def store_new_application(self, repo, location, is_app, is_library, notes): with Session(self.engine) as session: - existing = session.query(Application).filter_by(repo=repo, location=location).first() + existing = session.query(Application).filter_by(repo = repo, location = location).first() if existing: if is_app is not None: existing.is_app = is_app @@ -120,31 +105,25 @@ def store_new_application(self, repo, location, is_app, is_library, notes): existing.is_library = is_library existing.notes += notes else: - new_application = Application( - repo=repo, location=location, is_app=is_app, is_library=is_library, notes=notes - ) + new_application = Application(repo = repo, location = location, is_app = is_app, is_library = is_library, notes = notes) session.add(new_application) session.commit() return f"Updated or added application for {location} in {repo}." def store_new_component_issue(self, repo, component_id, issue_type, notes): with Session(self.engine) as session: - existing = ( - session.query(ApplicationIssue) - .filter_by(repo=repo, component_id=component_id, issue_type=issue_type) - .first() - ) + existing = session.query(ApplicationIssue).filter_by(repo = repo, component_id = component_id, issue_type = issue_type).first() if existing: existing.notes += notes else: - new_issue = ApplicationIssue(repo=repo, component_id=component_id, issue_type=issue_type, notes=notes) + new_issue = ApplicationIssue(repo = repo, component_id = component_id, issue_type = issue_type, notes = notes) session.add(new_issue) session.commit() return f"Updated or added application issue for {repo} and {component_id}" def overwrite_component_issue_notes(self, id, notes): with Session(self.engine) as session: - existing = session.query(ApplicationIssue).filter_by(id=id).first() + existing = session.query(ApplicationIssue).filter_by(id = id).first() if not existing: return f"Component issue with id {id} does not exist!" else: @@ -152,49 +131,36 @@ def overwrite_component_issue_notes(self, id, notes): session.commit() return f"Updated notes for application issue with id {id}" - def store_new_audit_result( - self, repo, component_id, issue_type, issue_id, has_non_security_error, has_vulnerability, notes - ): + def store_new_audit_result(self, repo, component_id, issue_type, issue_id, has_non_security_error, has_vulnerability, notes): with Session(self.engine) as session: - existing = session.query(AuditResult).filter_by(repo=repo, issue_id=issue_id).first() + existing = session.query(AuditResult).filter_by(repo = repo, issue_id = issue_id).first() if existing: existing.notes += notes existing.has_non_security_error = has_non_security_error existing.has_vulnerability = has_vulnerability else: - new_result = AuditResult( - repo=repo, - component_id=component_id, - issue_type=issue_type, - issue_id=issue_id, - notes=notes, - has_non_security_error=has_non_security_error, - has_vulnerability=has_vulnerability, - ) + new_result = AuditResult(repo = repo, component_id = component_id, issue_type = issue_type, issue_id = issue_id, notes = notes, + has_non_security_error = has_non_security_error, has_vulnerability = has_vulnerability) session.add(new_result) session.commit() return f"Updated or added audit result for {repo} and {issue_id}" - def store_new_entry_point(self, repo, app_id, file, user_input, line, notes, update=False): + def store_new_entry_point(self, repo, app_id, file, user_input, line, notes, update = False): with Session(self.engine) as session: - existing = session.query(EntryPoint).filter_by(repo=repo, file=file, line=line).first() + existing = session.query(EntryPoint).filter_by(repo = repo, file = file, line = line).first() if existing: existing.notes += notes else: if update: return f"No entry point exists at repo {repo}, file {file} and line {line}" - new_entry_point = EntryPoint( - repo=repo, app_id=app_id, file=file, user_input=user_input, line=line, notes=notes - ) + new_entry_point = EntryPoint(repo = repo, app_id = app_id, file = file, user_input = user_input, line = line, notes = notes) session.add(new_entry_point) session.commit() return f"Updated or added entry point for {file} and {line} in {repo}." - def store_new_web_entry_point( - self, repo, entry_point_id, method, path, component, auth, middleware, roles_scopes, notes, update=False - ): + def store_new_web_entry_point(self, repo, entry_point_id, method, path, component, auth, middleware, roles_scopes, notes, update = False): with Session(self.engine) as session: - existing = session.query(WebEntryPoint).filter_by(repo=repo, entry_point_id=entry_point_id).first() + existing = session.query(WebEntryPoint).filter_by(repo = repo, entry_point_id = entry_point_id).first() if existing: existing.notes += notes if method: @@ -213,188 +179,163 @@ def store_new_web_entry_point( if update: return f"No web entry point exists at repo {repo} with entry_point_id {entry_point_id}." new_web_entry_point = WebEntryPoint( - repo=repo, - entry_point_id=entry_point_id, - method=method, - path=path, - component=component, - auth=auth, - middleware=middleware, - roles_scopes=roles_scopes, - notes=notes, + repo = repo, + entry_point_id = entry_point_id, + method = method, + path = path, + component = component, + auth = auth, + middleware = middleware, + roles_scopes = roles_scopes, + notes = notes ) session.add(new_web_entry_point) session.commit() return f"Updated or added web entry point for entry_point_id {entry_point_id} in {repo}." - def store_new_user_action(self, repo, app_id, file, line, notes, update=False): + def store_new_user_action(self, repo, app_id, file, line, notes, update = False): with Session(self.engine) as session: - existing = session.query(UserAction).filter_by(repo=repo, file=file, line=line).first() + existing = session.query(UserAction).filter_by(repo = repo, file = file, line = line).first() if existing: existing.notes += notes else: if update: return f"No user action exists at repo {repo}, file {file} and line {line}." - new_user_action = UserAction(repo=repo, app_id=app_id, file=file, line=line, notes=notes) + new_user_action = UserAction(repo = repo, app_id = app_id, file = file, line = line, notes = notes) session.add(new_user_action) session.commit() return f"Updated or added user action for {file} and {line} in {repo}." def get_app(self, repo, location): with Session(self.engine) as session: - existing = session.query(Application).filter_by(repo=repo, location=location).first() + existing = session.query(Application).filter_by(repo = repo, location = location).first() if not existing: return None return existing def get_apps(self, repo): with Session(self.engine) as session: - existing = session.query(Application).filter_by(repo=repo).all() + existing = session.query(Application).filter_by(repo = repo).all() return [app_to_dict(app) for app in existing] def get_app_issues(self, repo, component_id): with Session(self.engine) as session: issues = session.query(Application, ApplicationIssue).filter( - Application.repo == repo, Application.id == ApplicationIssue.component_id + Application.repo == repo, + Application.id == ApplicationIssue.component_id ) if component_id is not None: issues = issues.filter(Application.id == component_id) issues = issues.all() - return [ - { - "component_id": app.id, - "location": app.location, - "repo": app.repo, - "component_notes": app.notes, - "issue_type": issue.issue_type, - "issue_notes": issue.notes, - "issue_id": issue.id, - } - for app, issue in issues - ] + return [{ + 'component_id': app.id, + 'location' : app.location, + 'repo' : app.repo, + 'component_notes' : app.notes, + 'issue_type' : issue.issue_type, + 'issue_notes': issue.notes, + 'issue_id' : issue.id + } for app, issue in issues] def get_app_audit_results(self, repo, component_id, has_non_security_error, has_vulnerability): with Session(self.engine) as session: - issues = ( - session.query(Application, AuditResult) - .filter(Application.repo == repo) - .filter(Application.id == AuditResult.component_id) - ) + issues = session.query(Application, AuditResult).filter(Application.repo == repo + ).filter(Application.id == AuditResult.component_id) if component_id is not None: - issues = issues.filter(Application.id == component_id) + issues = issues.filter(Application.id == component_id) if has_non_security_error is not None: issues = issues.filter(AuditResult.has_non_security_error == has_non_security_error) if has_vulnerability is not None: issues = issues.filter(AuditResult.has_vulnerability == has_vulnerability) issues = issues.all() - return [ - { - "component_id": app.id, - "location": app.location, - "repo": app.repo, - "issue_type": issue.issue_type, - "issue_id": issue.issue_id, - "notes": issue.notes, - "has_vulnerability": issue.has_vulnerability, - "has_non_security_error": issue.has_non_security_error, - } - for app, issue in issues - ] + return [{ + 'component_id': app.id, + 'location' : app.location, + 'repo' : app.repo, + 'issue_type' : issue.issue_type, + 'issue_id' : issue.issue_id, + 'notes': issue.notes, + 'has_vulnerability' : issue.has_vulnerability, + 'has_non_security_error' : issue.has_non_security_error + } for app, issue in issues] def get_app_entries(self, repo, location): with Session(self.engine) as session: - results = ( - session.query(Application, EntryPoint) - .filter(Application.repo == repo, Application.location == location) - .filter(EntryPoint.app_id == Application.id) - .all() - ) + results = session.query(Application, EntryPoint + ).filter(Application.repo == repo, Application.location == location + ).filter(EntryPoint.app_id == Application.id).all() eps = [entry_point_to_dict(ep) for app, ep in results] return eps def get_app_entries_for_repo(self, repo): with Session(self.engine) as session: - results = ( - session.query(Application, EntryPoint) - .filter(Application.repo == repo) - .filter(EntryPoint.app_id == Application.id) - .all() - ) + results = session.query(Application, EntryPoint + ).filter(Application.repo == repo + ).filter(EntryPoint.app_id == Application.id).all() eps = [entry_point_to_dict(ep) for app, ep in results] return eps def get_web_entries_for_repo(self, repo): with Session(self.engine) as session: - results = session.query(WebEntryPoint).filter_by(repo=repo).all() - return [ - { - "repo": r.repo, - "entry_point_id": r.entry_point_id, - "method": r.method, - "path": r.path, - "component": r.component, - "auth": r.auth, - "middleware": r.middleware, - "roles_scopes": r.roles_scopes, - "notes": r.notes, - } - for r in results - ] + results = session.query(WebEntryPoint).filter_by(repo = repo).all() + return [{ + 'repo' : r.repo, + 'entry_point_id' : r.entry_point_id, + 'method' : r.method, + 'path' : r.path, + 'component' : r.component, + 'auth' : r.auth, + 'middleware' : r.middleware, + 'roles_scopes' : r.roles_scopes, + 'notes' : r.notes + } for r in results] def get_web_entries(self, repo, component_id): with Session(self.engine) as session: - results = session.query(WebEntryPoint).filter_by(repo=repo, component=component_id).all() - return [ - { - "repo": r.repo, - "entry_point_id": r.entry_point_id, - "method": r.method, - "path": r.path, - "component": r.component, - "auth": r.auth, - "middleware": r.middleware, - "roles_scopes": r.roles_scopes, - "notes": r.notes, - } - for r in results - ] + results = session.query(WebEntryPoint).filter_by(repo = repo, component = component_id).all() + return [{ + 'repo' : r.repo, + 'entry_point_id' : r.entry_point_id, + 'method' : r.method, + 'path' : r.path, + 'component' : r.component, + 'auth' : r.auth, + 'middleware' : r.middleware, + 'roles_scopes' : r.roles_scopes, + 'notes' : r.notes + } for r in results] + def get_user_actions(self, repo, location): with Session(self.engine) as session: - results = ( - session.query(Application, UserAction) - .filter(Application.repo == repo, Application.location == location) - .filter(UserAction.app_id == Application.id) - .all() - ) + results = session.query(Application, UserAction + ).filter(Application.repo == repo, Application.location == location + ).filter(UserAction.app_id == Application.id).all() uas = [user_action_to_dict(ua) for app, ua in results] return uas def get_user_actions_for_repo(self, repo): with Session(self.engine) as session: - results = ( - session.query(Application, UserAction) - .filter(Application.repo == repo) - .filter(UserAction.app_id == Application.id) - .all() - ) + results = session.query(Application, UserAction + ).filter(Application.repo == repo + ).filter(UserAction.app_id == Application.id).all() uas = [user_action_to_dict(ua) for app, ua in results] return uas def clear_repo(self, repo): with Session(self.engine) as session: - session.query(Application).filter_by(repo=repo).delete() - session.query(EntryPoint).filter_by(repo=repo).delete() - session.query(UserAction).filter_by(repo=repo).delete() - session.query(ApplicationIssue).filter_by(repo=repo).delete() - session.query(WebEntryPoint).filter_by(repo=repo).delete() - session.query(AuditResult).filter_by(repo=repo).delete() + session.query(Application).filter_by(repo = repo).delete() + session.query(EntryPoint).filter_by(repo = repo).delete() + session.query(UserAction).filter_by(repo = repo).delete() + session.query(ApplicationIssue).filter_by(repo = repo).delete() + session.query(WebEntryPoint).filter_by(repo = repo).delete() + session.query(AuditResult).filter_by(repo = repo).delete() session.commit() return f"Cleared results for repo {repo}" def clear_repo_issues(self, repo): with Session(self.engine) as session: - session.query(ApplicationIssue).filter_by(repo=repo).delete() + session.query(ApplicationIssue).filter_by(repo = repo).delete() session.commit() return f"Clear application issues for repo {repo}" @@ -403,29 +344,23 @@ def clear_repo_issues(self, repo): backend = RepoContextBackend(MEMORY) - @mcp.tool() -def store_new_component( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - location: str = Field(description="The directory of the component"), - is_app: bool = Field(description="Is this an application", default=None), - is_library: bool = Field(description="Is this a library", default=None), - notes: str = Field(description="The notes taken for this component", default=""), -): +def store_new_component(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + location: str = Field(description="The directory of the component"), + is_app: bool = Field(description="Is this an application", default=None), + is_library: bool = Field(description="Is this a library", default=None), + notes: str = Field(description="The notes taken for this component", default="")): """ Stores a new component in the database. """ return backend.store_new_application(process_repo(owner, repo), location, is_app, is_library, notes) - @mcp.tool() -def add_component_notes( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - location: str = Field(description="The directory of the component", default=None), - notes: str = Field(description="New notes taken for this component", default=""), -): +def add_component_notes(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + location: str = Field(description="The directory of the component", default=None), + notes: str = Field(description="New notes taken for this component", default="")): """ Add new notes to a component """ @@ -435,17 +370,14 @@ def add_component_notes( return f"Error: No component exists in repo: {repo} and location {location}" return backend.store_new_application(repo, location, None, None, notes) - @mcp.tool() -def store_new_entry_point( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - location: str = Field(description="The directory of the component where the entry point belongs to"), - file: str = Field(description="The file that contains the entry point"), - line: int = Field(description="The file line that contains the entry point"), - user_input: str = Field(description="The variables that are considered as user input"), - notes: str = Field(description="The notes for this entry point", default=""), -): +def store_new_entry_point(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + location: str = Field(description="The directory of the component where the entry point belongs to"), + file: str = Field(description="The file that contains the entry point"), + line: int = Field(description="The file line that contains the entry point"), + user_input: str = Field(description="The variables that are considered as user input"), + notes: str = Field(description="The notes for this entry point", default = "")): """ Stores a new entry point in a component to the database. """ @@ -455,76 +387,58 @@ def store_new_entry_point( return f"Error: No component exists in repo: {repo} and location {location}" return backend.store_new_entry_point(repo, app.id, file, user_input, line, notes) - @mcp.tool() -def store_new_component_issue( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - component_id: int = Field(description="The ID of the component"), - issue_type: str = Field(description="The type of issue"), - notes: str = Field(description="Notes about the issue"), -): +def store_new_component_issue(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + component_id: int = Field(description="The ID of the component"), + issue_type: str = Field(description="The type of issue"), + notes: str = Field(description="Notes about the issue")): """ Stores a type of common issue for a component. """ repo = process_repo(owner, repo) return backend.store_new_component_issue(repo, component_id, issue_type, notes) - @mcp.tool() -def store_new_audit_result( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - component_id: int = Field(description="The ID of the component"), - issue_type: str = Field(description="The type of issue"), - issue_id: int = Field(description="The ID of the issue"), - has_non_security_error: bool = Field( - description="Set to true if there are security issues or logic error but may not be exploitable" - ), - has_vulnerability: bool = Field(description="Set to true if a security vulnerability is identified"), - notes: str = Field(description="The notes for the audit of this issue"), -): +def store_new_audit_result(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + component_id: int = Field(description="The ID of the component"), + issue_type: str = Field(description="The type of issue"), + issue_id: int = Field(description="The ID of the issue"), + has_non_security_error: bool = Field(description="Set to true if there are security issues or logic error but may not be exploitable"), + has_vulnerability: bool = Field(description="Set to true if a security vulnerability is identified"), + notes: str = Field(description="The notes for the audit of this issue")): """ Stores the audit result for issue with issue_id. """ repo = process_repo(owner, repo) - return backend.store_new_audit_result( - repo, component_id, issue_type, issue_id, has_non_security_error, has_vulnerability, notes - ) - + return backend.store_new_audit_result(repo, component_id, issue_type, issue_id, has_non_security_error, has_vulnerability, notes) @mcp.tool() -def store_new_web_entry_point( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - entry_point_id: int = Field(description="The ID of the entry point this web entry point refers to"), - location: str = Field(description="The directory of the component where the web entry point belongs to"), - method: str = Field(description="HTTP method (GET, POST, etc)", default=""), - path: str = Field(description="URL path (e.g., /info)", default=""), - component: int = Field(description="Component identifier", default=0), - auth: str = Field(description="Authentication information", default=""), - middleware: str = Field(description="Middleware information", default=""), - roles_scopes: str = Field(description="Roles and scopes information", default=""), - notes: str = Field(description="Notes for this web entry point", default=""), -): +def store_new_web_entry_point(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + entry_point_id: int = Field(description="The ID of the entry point this web entry point refers to"), + location: str = Field(description="The directory of the component where the web entry point belongs to"), + method: str = Field(description="HTTP method (GET, POST, etc)", default=""), + path: str = Field(description="URL path (e.g., /info)", default=""), + component: int = Field(description="Component identifier", default=0), + auth: str = Field(description="Authentication information", default=""), + middleware: str = Field(description="Middleware information", default=""), + roles_scopes: str = Field(description="Roles and scopes information", default=""), + notes: str = Field(description="Notes for this web entry point", default="")): """ Stores a new web entry point in a component to the database. A web entry point extends a regular entry point with web-specific properties like HTTP method, path, authentication, middleware, and roles/scopes. """ - return backend.store_new_web_entry_point( - process_repo(owner, repo), entry_point_id, method, path, component, auth, middleware, roles_scopes, notes - ) - + return backend.store_new_web_entry_point(process_repo(owner, repo), entry_point_id, method, path, component, auth, middleware, roles_scopes, notes) @mcp.tool() -def add_entry_point_notes( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - location: str = Field(description="The directory of the component where the entry point belongs to"), - file: str = Field(description="The file that contains the entry point"), - line: int = Field(description="The file line that contains the entry point"), - notes: str = Field(description="The notes for this entry point", default=""), -): +def add_entry_point_notes(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + location: str = Field(description="The directory of the component where the entry point belongs to"), + file: str = Field(description="The file that contains the entry point"), + line: int = Field(description="The file line that contains the entry point"), + notes: str = Field(description="The notes for this entry point", default = "")): """ add new notes to an entry point. """ @@ -536,14 +450,12 @@ def add_entry_point_notes( @mcp.tool() -def store_new_user_action( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - location: str = Field(description="The directory of the component where the user action belongs to"), - file: str = Field(description="The file that contains the user action"), - line: int = Field(description="The file line that contains the user action"), - notes: str = Field(description="New notes for this user action", default=""), -): +def store_new_user_action(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + location: str = Field(description="The directory of the component where the user action belongs to"), + file: str = Field(description="The file that contains the user action"), + line: int = Field(description="The file line that contains the user action"), + notes: str = Field(description="New notes for this user action", default = "")): """ Stores a new user action in a component to the database. """ @@ -553,29 +465,23 @@ def store_new_user_action( return f"Error: No component exists in repo: {repo} and location {location}" return backend.store_new_user_action(repo, app.id, file, line, notes) - @mcp.tool() -def add_user_action_notes( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - location: str = Field(description="The directory of the component where the user action belongs to"), - file: str = Field(description="The file that contains the user action"), - line: str = Field(description="The file line that contains the user action"), - notes: str = Field(description="The notes for user action", default=""), -): +def add_user_action_notes(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + location: str = Field(description="The directory of the component where the user action belongs to"), + file: str = Field(description="The file that contains the user action"), + line: str = Field(description="The file line that contains the user action"), + notes: str = Field(description="The notes for user action", default = "")): repo = process_repo(owner, repo) app = backend.get_app(repo, location) if not app: return f"Error: No component exists in repo: {repo} and location {location}" return backend.store_new_user_action(repo, app.id, file, line, notes, True) - @mcp.tool() -def get_component( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - location: str = Field(description="The directory of the component"), -): +def get_component(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + location: str = Field(description="The directory of the component")): """ Get a component from the database """ @@ -585,112 +491,85 @@ def get_component( return f"Error: No component exists in repo: {repo} and location {location}" return json.dumps(app_to_dict(app)) - @mcp.tool() -def get_components( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), -): +def get_components(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository")): """ Get components from the repo """ repo = process_repo(owner, repo) return json.dumps(backend.get_apps(repo)) - @mcp.tool() -def get_entry_points( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - location: str = Field(description="The directory of the component"), -): +def get_entry_points(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + location: str = Field(description="The directory of the component")): """ Get all the entry points of a component. """ repo = process_repo(owner, repo) return json.dumps(backend.get_app_entries(repo, location)) - @mcp.tool() -def get_entry_points_for_repo( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), -): +def get_entry_points_for_repo(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository")): """ Get all entry points of an repo """ repo = process_repo(owner, repo) return json.dumps(backend.get_app_entries_for_repo(repo)) - @mcp.tool() -def get_web_entry_points_component( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - component_id: int = Field(description="The ID of the component"), -): +def get_web_entry_points_component(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + component_id: int = Field(description="The ID of the component")): """ Get all web entry points for a component """ repo = process_repo(owner, repo) return json.dumps(backend.get_web_entries(repo, component_id)) - @mcp.tool() -def get_web_entry_points_for_repo( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), -): +def get_web_entry_points_for_repo(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository")): """ Get all web entry points of an repo """ repo = process_repo(owner, repo) return json.dumps(backend.get_web_entries_for_repo(repo)) - @mcp.tool() -def get_user_actions( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - location: str = Field(description="The directory of the component"), -): +def get_user_actions(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + location: str = Field(description="The directory of the component")): """ Get all the user actions in a component. """ repo = process_repo(owner, repo) return json.dumps(backend.get_user_actions(repo, location)) - @mcp.tool() -def get_user_actions_for_repo( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), -): +def get_user_actions_for_repo(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository")): """ Get all the user actions in a repo. """ repo = process_repo(owner, repo) return json.dumps(backend.get_user_actions_for_repo(repo)) - @mcp.tool() -def get_component_issues( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - component_id: int = Field(description="The ID of the component"), -): +def get_component_issues(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + component_id: int = Field(description="The ID of the component")): """ Get issues for the component. """ repo = process_repo(owner, repo) return json.dumps(backend.get_app_issues(repo, component_id)) - @mcp.tool() -def get_component_issues_for_repo( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), -): +def get_component_issues_for_repo(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository")): """ Get all component issues for the repository. """ @@ -699,113 +578,79 @@ def get_component_issues_for_repo( @mcp.tool() -def get_component_results( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - component_id: int = Field(description="The ID of the component"), -): +def get_component_results(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + component_id: int = Field(description="The ID of the component")): """ Get audit results for the component. """ repo = process_repo(owner, repo) return json.dumps(backend.get_app_audit_results(repo, component_id, None, None)) - @mcp.tool() -def get_component_vulnerable_results( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - component_id: int = Field(description="The ID of the component"), -): +def get_component_vulnerable_results(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + component_id: int = Field(description="The ID of the component")): """ Get audit results for the component that are audited as vulnerable. """ repo = process_repo(owner, repo) - return json.dumps( - backend.get_app_audit_results(repo, component_id, has_non_security_error=None, has_vulnerability=True) - ) - + return json.dumps(backend.get_app_audit_results(repo, component_id, has_non_security_error = None, has_vulnerability = True)) @mcp.tool() -def get_component_potential_results( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), - component_id: int = Field(description="The ID of the component"), -): +def get_component_potential_results(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository"), + component_id: int = Field(description="The ID of the component")): """ Get audit results for the component that are audited as an issue but may not be exploitable. """ repo = process_repo(owner, repo) - return json.dumps( - backend.get_app_audit_results(repo, component_id, has_non_security_error=True, has_vulnerability=None) - ) - + return json.dumps(backend.get_app_audit_results(repo, component_id, has_non_security_error = True, has_vulnerability = None)) @mcp.tool() -def get_audit_results_for_repo( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), -): +def get_audit_results_for_repo(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository")): """ Get audit results for the repo. """ repo = process_repo(owner, repo) - return json.dumps( - backend.get_app_audit_results(repo, component_id=None, has_non_security_error=None, has_vulnerability=None) - ) - + return json.dumps(backend.get_app_audit_results(repo, component_id = None, has_non_security_error = None, has_vulnerability = None)) @mcp.tool() -def get_vulnerable_audit_results_for_repo( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), -): +def get_vulnerable_audit_results_for_repo(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository")): """ Get audit results for the repo that are audited as vulnerable. """ repo = process_repo(owner, repo) - return json.dumps( - backend.get_app_audit_results(repo, component_id=None, has_non_security_error=None, has_vulnerability=True) - ) - + return json.dumps(backend.get_app_audit_results(repo, component_id = None, has_non_security_error = None, has_vulnerability = True)) @mcp.tool() -def get_potential_audit_results_for_repo( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), -): +def get_potential_audit_results_for_repo(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository")): """ Get audit results for the repo that are potential issues but may not be exploitable. """ repo = process_repo(owner, repo) - return json.dumps( - backend.get_app_audit_results(repo, component_id=None, has_non_security_error=True, has_vulnerability=None) - ) - + return json.dumps(backend.get_app_audit_results(repo, component_id = None, has_non_security_error = True, has_vulnerability = None)) @mcp.tool() -def clear_repo( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), -): +def clear_repo(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository")): """ clear all results for repo. """ repo = process_repo(owner, repo) return backend.clear_repo(repo) - @mcp.tool() -def clear_component_issues_for_repo( - owner: str = Field(description="The owner of the GitHub repository"), - repo: str = Field(description="The name of the GitHub repository"), -): +def clear_component_issues_for_repo(owner: str = Field(description="The owner of the GitHub repository"), + repo: str = Field(description="The name of the GitHub repository")): """ clear all results for repo. """ repo = process_repo(owner, repo) return backend.clear_repo_issues(repo) - if __name__ == "__main__": mcp.run(show_banner=False) diff --git a/src/seclab_taskflows/mcp_servers/repo_context_models.py b/src/seclab_taskflows/mcp_servers/repo_context_models.py index 7dd08cc..cd3d8a2 100644 --- a/src/seclab_taskflows/mcp_servers/repo_context_models.py +++ b/src/seclab_taskflows/mcp_servers/repo_context_models.py @@ -5,67 +5,56 @@ from sqlalchemy.orm import DeclarativeBase, mapped_column, Mapped, relationship from typing import Optional - class Base(DeclarativeBase): pass - class Application(Base): - __tablename__ = "application" + __tablename__ = 'application' id: Mapped[int] = mapped_column(primary_key=True) repo: Mapped[str] location: Mapped[str] notes: Mapped[str] = mapped_column(Text) is_app: Mapped[bool] = mapped_column(nullable=True) - is_library: Mapped[bool] = mapped_column(nullable=True) + is_library: Mapped[bool] = mapped_column(nullable = True) def __repr__(self): - return ( - f"" - ) - + return (f"") class ApplicationIssue(Base): - __tablename__ = "application_issue" + __tablename__ = 'application_issue' id: Mapped[int] = mapped_column(primary_key=True) repo: Mapped[str] - component_id = Column(Integer, ForeignKey("application.id", ondelete="CASCADE")) + component_id = Column(Integer, ForeignKey('application.id', ondelete='CASCADE')) issue_type: Mapped[str] = mapped_column(Text) notes: Mapped[str] = mapped_column(Text) def __repr__(self): - return ( - f"" - ) - + return (f"") class AuditResult(Base): - __tablename__ = "audit_result" - id: Mapped[int] = mapped_column(primary_key=True) + __tablename__ = 'audit_result' + id: Mapped[int] = mapped_column(primary_key = True) repo: Mapped[str] - component_id = Column(Integer, ForeignKey("application.id", ondelete="CASCADE")) + component_id = Column(Integer, ForeignKey('application.id', ondelete = 'CASCADE')) issue_type: Mapped[str] = mapped_column(Text) - issue_id = Column(Integer, ForeignKey("application_issue.id", ondelete="CASCADE")) + issue_id = Column(Integer, ForeignKey('application_issue.id', ondelete = 'CASCADE')) has_vulnerability: Mapped[bool] has_non_security_error: Mapped[bool] notes: Mapped[str] = mapped_column(Text) def __repr__(self): - return ( - f"" - ) - + return (f"") class EntryPoint(Base): - __tablename__ = "entry_point" + __tablename__ = 'entry_point' id: Mapped[int] = mapped_column(primary_key=True) - app_id = Column(Integer, ForeignKey("application.id", ondelete="CASCADE")) + app_id = Column(Integer, ForeignKey('application.id', ondelete='CASCADE')) file: Mapped[str] user_input: Mapped[str] line: Mapped[int] @@ -73,19 +62,16 @@ class EntryPoint(Base): repo: Mapped[str] def __repr__(self): - return ( - f"" - ) - - -class WebEntryPoint(Base): # an entrypoint of a web application (such as GET /info) with additional properties - __tablename__ = "web_entry_point" + return (f"") + +class WebEntryPoint(Base): # an entrypoint of a web application (such as GET /info) with additional properties + __tablename__ = 'web_entry_point' id: Mapped[int] = mapped_column(primary_key=True) - entry_point_id = Column(Integer, ForeignKey("entry_point.id", ondelete="CASCADE")) - method: Mapped[str] # GET, POST, etc - path: Mapped[str] # /info + entry_point_id = Column(Integer, ForeignKey('entry_point.id', ondelete='CASCADE')) + method: Mapped[str] # GET, POST, etc + path: Mapped[str] # /info component: Mapped[int] auth: Mapped[str] middleware: Mapped[str] @@ -94,20 +80,17 @@ class WebEntryPoint(Base): # an entrypoint of a web application (such as GET /i repo: Mapped[str] def __repr__(self): - return ( - f"" - ) - + return (f"") class UserAction(Base): - __tablename__ = "user_action" + __tablename__ = 'user_action' id: Mapped[int] = mapped_column(primary_key=True) repo: Mapped[str] - app_id = Column(Integer, ForeignKey("application.id", ondelete="CASCADE")) + app_id = Column(Integer, ForeignKey('application.id', ondelete='CASCADE')) file: Mapped[str] line: Mapped[int] notes: Mapped[str] = mapped_column(Text) diff --git a/src/seclab_taskflows/mcp_servers/report_alert_state.py b/src/seclab_taskflows/mcp_servers/report_alert_state.py index 074c121..89fb6fa 100644 --- a/src/seclab_taskflows/mcp_servers/report_alert_state.py +++ b/src/seclab_taskflows/mcp_servers/report_alert_state.py @@ -16,12 +16,11 @@ logging.basicConfig( level=logging.DEBUG, - format="%(asctime)s - %(levelname)s - %(message)s", - filename=log_file_name("mcp_report_alert_state.log"), - filemode="a", + format='%(asctime)s - %(levelname)s - %(message)s', + filename=log_file_name('mcp_report_alert_state.log'), + filemode='a' ) - def result_to_dict(result): return { "canonical_id": result.canonical_id, @@ -32,10 +31,9 @@ def result_to_dict(result): "location": result.location, "result": result.result, "created": result.created, - "valid": result.valid, + "valid": result.valid } - def flow_to_dict(flow): return { "id": flow.id, @@ -43,10 +41,9 @@ def flow_to_dict(flow): "flow_data": flow.flow_data, "repo": flow.repo.lower(), "prev": flow.prev, - "next": flow.next, + "next": flow.next } - def remove_line_numbers(location: str) -> str: """ Remove line numbers from a location string. @@ -54,38 +51,31 @@ def remove_line_numbers(location: str) -> str: """ if not location: return location - parts = location.split(":") + parts = location.split(':') if len(parts) < 4: # Ensure there are enough parts to remove line numbers return location # Keep the first part (file path) and the last two parts (col:col) - return ":".join(parts[:-4]) - + return ':'.join(parts[:-4]) -MEMORY = mcp_data_dir("seclab-taskflows", "report_alert_state", "ALERT_RESULTS_DIR") +MEMORY = mcp_data_dir('seclab-taskflows', 'report_alert_state', 'ALERT_RESULTS_DIR') class ReportAlertStateBackend: def __init__(self, memcache_state_dir: str): self.memcache_state_dir = memcache_state_dir - self.location_pattern = r"^([a-zA-Z]+)(:\d+){4}$" + self.location_pattern = r'^([a-zA-Z]+)(:\d+){4}$' if not Path(self.memcache_state_dir).exists(): - db_dir = "sqlite://" + db_dir = 'sqlite://' else: - db_dir = f"sqlite:///{self.memcache_state_dir}/alert_results.db" + db_dir = f'sqlite:///{self.memcache_state_dir}/alert_results.db' self.engine = create_engine(db_dir, echo=False) Base.metadata.create_all(self.engine, tables=[AlertResults.__table__, AlertFlowGraph.__table__]) - def set_alert_result( - self, alert_id: str, repo: str, rule: str, language: str, location: str, result: str, created: str - ) -> str: + def set_alert_result(self, alert_id: str, repo: str, rule: str, language: str, location: str, result: str, created: str) -> str: if not result: result = "" with Session(self.engine) as session: - existing = ( - session.query(AlertResults) - .filter_by(alert_id=alert_id, repo=repo, rule=rule, language=language) - .first() - ) + existing = session.query(AlertResults).filter_by(alert_id=alert_id, repo=repo, rule=rule, language=language).first() if existing: existing.result += result else: @@ -98,7 +88,7 @@ def set_alert_result( result=result, created=created, valid=True, - completed=False, + completed=False ) session.add(new_alert) session.commit() @@ -142,32 +132,30 @@ def set_alert_completed(self, alert_id: str, repo: str, completed: bool) -> str: def get_completed_alerts(self, rule: str, repo: str = None) -> Any: """Get all incomplete alerts in a repository.""" - filter_params = {"completed": True} + filter_params = {'completed' : True} if repo: - filter_params["repo"] = repo + filter_params['repo'] = repo if rule: - filter_params["rule"] = rule + filter_params['rule'] = rule with Session(self.engine) as session: results = [result_to_dict(r) for r in session.query(AlertResults).filter_by(**filter_params).all()] return results def clear_completed_alerts(self, repo: str = None, rule: str = None) -> str: """Clear all completed alerts in a repository.""" - filter_params = {"completed": True} + filter_params = {'completed': True} if repo: - filter_params["repo"] = repo + filter_params['repo'] = repo if rule: - filter_params["rule"] = rule + filter_params['rule'] = rule with Session(self.engine) as session: session.query(AlertResults).filter_by(**filter_params).delete() session.commit() - return "Cleared completed alerts with repo: {}, rule: {}".format( - repo if repo else "all", rule if rule else "all" - ) + return "Cleared completed alerts with repo: {}, rule: {}".format(repo if repo else "all", rule if rule else "all") def get_alert_results(self, alert_id: str, repo: str) -> str: with Session(self.engine) as session: - result = session.query(AlertResults).filter_by(alert_id=alert_id, repo=repo).first() + result = session.query(AlertResults).filter_by(alert_id=alert_id, repo = repo).first() if not result: return "No results found." return "Analysis results for alert ID {} in repo {}: {}".format(alert_id, repo, result.result) @@ -180,27 +168,26 @@ def get_alert_by_canonical_id(self, canonical_id: int) -> Any: return result_to_dict(result) def get_alert_results_by_rule(self, rule: str, repo: str = None, valid: bool = None) -> Any: - filter_params = {"rule": rule} + filter_params = {'rule': rule} if repo: - filter_params["repo"] = repo + filter_params['repo'] = repo if valid is not None: - filter_params["valid"] = valid + filter_params['valid'] = valid with Session(self.engine) as session: results = [result_to_dict(r) for r in session.query(AlertResults).filter_by(**filter_params).all()] return results - def delete_alert_result(self, alert_id: str, repo: str) -> str: with Session(self.engine) as session: result = session.query(AlertResults).filter_by(alert_id=alert_id, repo=repo).delete() session.commit() return f"Deleted alert result for {alert_id} in {repo}" - def clear_alert_results(self, repo: str = None, rule: str = None) -> str: + def clear_alert_results(self, repo : str = None, rule: str = None) -> str: filter_params = {} if repo: - filter_params["repo"] = repo + filter_params['repo'] = repo if rule: - filter_params["rule"] = rule + filter_params['rule'] = rule with Session(self.engine) as session: if not filter_params: session.query(AlertResults).delete() @@ -209,21 +196,22 @@ def clear_alert_results(self, repo: str = None, rule: str = None) -> str: session.commit() return "Cleared alert results with repo: {}, rule: {}".format(repo if repo else "all", rule if rule else "all") - def add_flow_to_alert( - self, canonical_id: int, flow_data: str, repo: str, prev: str = None, next: str = None - ) -> str: + def add_flow_to_alert(self, canonical_id: int, flow_data: str, repo: str, prev: str = None, next: str = None) -> str: """Add a flow graph for a specific alert result.""" with Session(self.engine) as session: flow_graph = AlertFlowGraph( - alert_canonical_id=canonical_id, flow_data=flow_data, repo=repo, prev=prev, next=next, started=False + alert_canonical_id=canonical_id, + flow_data=flow_data, + repo=repo, + prev=prev, + next=next, + started = False ) session.add(flow_graph) session.commit() return f"Added flow graph for alert with canonical ID {canonical_id}" - def batch_add_flow_to_alert( - self, alert_canonical_id: int, flows: list[str], repo: str, prev: str, next: str - ) -> str: + def batch_add_flow_to_alert(self, alert_canonical_id: int, flows: list[str], repo: str, prev: str, next: str) -> str: """Batch add flow graphs for multiple alert results.""" with Session(self.engine) as session: for flow in flows: @@ -233,7 +221,7 @@ def batch_add_flow_to_alert( repo=repo, prev=prev, next=next, - started=False, + started = False ) session.add(flow_graph) session.commit() @@ -262,13 +250,11 @@ def delete_flow_graph_for_alert(self, alert_canonical_id: int) -> str: with Session(self.engine) as session: result = session.query(AlertFlowGraph).filter_by(alert_canonical_id=alert_canonical_id).delete() session.commit() - return ( - f"Deleted flow graph with for alert with canonical iD {id}" if result else "No flow graph found to delete." - ) + return f"Deleted flow graph with for alert with canonical iD {id}" if result else "No flow graph found to delete." def update_all_alert_results_for_flow_graph(self, next: str, repo: str, result: str) -> str: with Session(self.engine) as session: - flow_graphs = session.query(AlertFlowGraph).filter_by(next=next, repo=repo).all() + flow_graphs = session.query(AlertFlowGraph).filter_by(next=next, repo = repo).all() if not flow_graphs: return f"No flow graphs found with next value {next}" alert_canonical_ids = set([fg.alert_canonical_id for fg in flow_graphs]) @@ -293,136 +279,93 @@ def clear_flow_graphs(self) -> str: session.commit() return "Cleared all flow graphs." - mcp = FastMCP("ReportAlertState") backend = ReportAlertStateBackend(MEMORY) - def process_repo(repo): return repo.lower() if repo else None - @mcp.tool() -def create_alert( - alert_id: str, - repo: str, - rule: str, - language: str, - location: str, - result: str = Field(description="The result of the alert analysis", default=""), - created: str = Field(description="The creation time of the alert", default=""), -) -> str: +def create_alert(alert_id: str, repo: str, rule: str, language: str, location: str, + result: str = Field(description="The result of the alert analysis", default=""), + created: str = Field(description = "The creation time of the alert", default="")) -> str: """Create an alert using a specific alert ID in a repository.""" return backend.set_alert_result(alert_id, process_repo(repo), rule, language, location, result, created) - @mcp.tool() def update_alert_result(alert_id: str, repo: str, result: str) -> str: """Update an existing alert result for a specific alert ID in a repository.""" return backend.update_alert_result(alert_id, process_repo(repo), result) - @mcp.tool() def update_alert_result_by_canonical_id(canonical_id: int, result: str) -> str: """Update an existing alert result by canonical ID.""" return backend.update_alert_result_by_canonical_id(canonical_id, result) - @mcp.tool() def set_alert_valid(alert_id: str, repo: str, valid: bool) -> str: """Set the validity of an alert result for a specific alert ID in a repository.""" return backend.set_alert_valid(alert_id, process_repo(repo), valid) - @mcp.tool() def get_alert_results(alert_id: str, repo: str = Field(description="repo in the format owner/repo")) -> str: """Get the analysis results for a specific alert ID in a repository.""" return backend.get_alert_results(alert_id, process_repo(repo)) - @mcp.tool() def get_alert_by_canonical_id(canonical_id: int) -> str: """Get alert results by canonical ID.""" return json.dumps(backend.get_alert_by_canonical_id(canonical_id)) - @mcp.tool() -def get_alert_results_by_rule( - rule: str, - repo: str = Field(description="Optional repository of the alert in the format of owner/repo", default=None), -) -> str: +def get_alert_results_by_rule(rule: str, repo: str = Field(description="Optional repository of the alert in the format of owner/repo", default = None)) -> str: """Get all alert results for a specific rule in a repository.""" return json.dumps(backend.get_alert_results_by_rule(rule, process_repo(repo), None)) - @mcp.tool() -def get_valid_alert_results_by_rule( - rule: str, - repo: str = Field(description="Optional repository of the alert in the format of owner/repo", default=None), -) -> str: +def get_valid_alert_results_by_rule(rule: str, repo: str = Field(description="Optional repository of the alert in the format of owner/repo", default = None)) -> str: """Get all valid alert results for a specific rule in a repository.""" return json.dumps(backend.get_alert_results_by_rule(rule, process_repo(repo), True)) - @mcp.tool() -def get_invalid_alert_results( - rule: str, - repo: str = Field(description="Optional repository of the alert in the format of owner/repo", default=None), -) -> str: +def get_invalid_alert_results(rule: str, repo: str = Field(description="Optional repository of the alert in the format of owner/repo", default = None)) -> str: """Get all valid alert results for a specific rule in a repository.""" return json.dumps(backend.get_alert_results_by_rule(rule, process_repo(repo), False)) - @mcp.tool() def set_alert_completed(alert_id: str, repo: str = Field(description="repo in the format owner/repo")) -> str: """Set the completion status of an alert result for a specific alert ID in a repository.""" return backend.set_alert_completed(alert_id, process_repo(repo), True) - @mcp.tool() -def get_completed_alerts( - rule: str, repo: str = Field(description="repo in the format owner/repo", default=None) -) -> str: +def get_completed_alerts(rule: str, repo: str = Field(description="repo in the format owner/repo", default = None)) -> str: """Get all complete alerts in a repository.""" results = backend.get_completed_alerts(rule, process_repo(repo)) return json.dumps(results) - @mcp.tool() -def clear_completed_alerts( - repo: str = Field(description="repo in the format owner/repo", default=None), rule: str = None -) -> str: +def clear_completed_alerts(repo: str = Field(description="repo in the format owner/repo", default = None), rule: str = None) -> str: """Clear all completed alerts in a repository.""" return backend.clear_completed_alerts(process_repo(repo), rule) - @mcp.tool() def clear_repo_results(repo: str = Field(description="repo in the format owner/repo")) -> str: """Clear all alert results for a specific repository.""" return backend.clear_alert_results(process_repo(repo), None) - @mcp.tool() -def clear_rule_results(rule: str, repo: str = Field(description="repo in the format owner/repo", default=None)) -> str: +def clear_rule_results(rule: str, repo: str = Field(description="repo in the format owner/repo", default = None)) -> str: """Clear all alert results for a specific rule in a repository.""" return backend.clear_alert_results(process_repo(repo), rule) - @mcp.tool() def clear_alert_results() -> str: """Clear all alert results.""" return backend.clear_alert_results(None, None) - @mcp.tool() -def add_flow_to_alert( - canonical_id: int, - flow_data: str, - repo: str = Field(description="repo in the format owner/repo"), - prev: str = None, - next: str = None, -) -> str: +def add_flow_to_alert(canonical_id: int, flow_data: str, repo: str = Field(description="repo in the format owner/repo"), prev: str = None, next: str = None) -> str: """Add a flow graph for a specific alert result.""" flow_data = remove_line_numbers(flow_data) prev = remove_line_numbers(prev) if prev else None @@ -430,17 +373,13 @@ def add_flow_to_alert( backend.add_flow_to_alert(canonical_id, flow_data, process_repo(repo), prev, next) return f"Added flow graph for alert with canonical ID {canonical_id}" - @mcp.tool() -def batch_add_flow_to_alert( - alert_canonical_id: int, - repo: str = Field(description="The repository name for the alert result in the format owner/repo"), - flows: str = Field(description="A JSON string containing a list of flows to add for the alert result."), - next: str = None, - prev: str = None, -) -> str: +def batch_add_flow_to_alert(alert_canonical_id: int, + repo: str = Field(description="The repository name for the alert result in the format owner/repo"), + flows: str = Field(description="A JSON string containing a list of flows to add for the alert result."), + next: str = None, prev: str = None) -> str: """Batch add a list of paths to flow graphs for a specific alert result.""" - flows_list = flows.split(",") + flows_list = flows.split(',') return backend.batch_add_flow_to_alert(alert_canonical_id, flows_list, process_repo(repo), prev, next) @@ -449,48 +388,39 @@ def get_alert_flow(canonical_id: int) -> str: """Get the flow graph for a specific alert result.""" return json.dumps(backend.get_alert_flow(canonical_id)) - @mcp.tool() def get_all_alert_flows() -> str: """Get all flow graphs for all alert results.""" return json.dumps(backend.get_all_alert_flows()) - @mcp.tool() def get_alert_flows_by_data(flow_data: str, repo: str = Field(description="repo in the format owner/repo")) -> str: """Get flow graphs for a specific alert result by repo and flow data.""" flow_data = remove_line_numbers(flow_data) return json.dumps(backend.get_alert_flows_by_data(process_repo(repo), flow_data)) - @mcp.tool() def delete_flow_graph(id: int) -> str: """Delete a flow graph with id.""" return backend.delete_flow_graph(id) - @mcp.tool() def delete_flow_graph_for_alert(alert_canonical_id: int) -> str: """Delete a all flow graphs for an alert with a specific canonical ID.""" return backend.delete_flow_graph_for_alert(alert_canonical_id) - @mcp.tool() -def update_all_alert_results_for_flow_graph( - next: str, result: str, repo: str = Field(description="repo in the format owner/repo") -) -> str: +def update_all_alert_results_for_flow_graph(next: str, result: str, repo: str = Field(description="repo in the format owner/repo")) -> str: """Update all alert results for flow graphs with a specific next value.""" - if not "/" in repo: + if not '/' in repo: return "Invalid repository format. Please provide a repository in the format 'owner/repo'." next = remove_line_numbers(next) if next else None return backend.update_all_alert_results_for_flow_graph(next, process_repo(repo), result) - @mcp.tool() def clear_flow_graphs() -> str: """Clear all flow graphs.""" return backend.clear_flow_graphs() - if __name__ == "__main__": mcp.run(show_banner=False) diff --git a/src/seclab_taskflows/mcp_servers/utils.py b/src/seclab_taskflows/mcp_servers/utils.py index 9e18435..528f9c4 100644 --- a/src/seclab_taskflows/mcp_servers/utils.py +++ b/src/seclab_taskflows/mcp_servers/utils.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: 2025 GitHub # SPDX-License-Identifier: MIT - def process_repo(owner, repo): """ Normalize repository identifier to lowercase format 'owner/repo'. diff --git a/tests/test_00.py b/tests/test_00.py index 0cf5c56..d60b706 100644 --- a/tests/test_00.py +++ b/tests/test_00.py @@ -6,11 +6,9 @@ import pytest import seclab_taskflows - class Test00: def test_nothing(self): assert True - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) +if __name__ == '__main__': + pytest.main([__file__, '-v'])