Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions src/seclab_taskflows/mcp_servers/alert_results_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from sqlalchemy.orm import DeclarativeBase, mapped_column, Mapped, relationship
from typing import Optional


class Base(DeclarativeBase):
    """Declarative base that the ORM models in this module derive from."""


class AlertResults(Base):
__tablename__ = 'alert_results'
__tablename__ = "alert_results"

canonical_id: Mapped[int] = mapped_column(primary_key=True)
alert_id: Mapped[str]
Expand All @@ -22,25 +24,29 @@ class AlertResults(Base):
valid: Mapped[bool] = mapped_column(nullable=False, default=True)
completed: Mapped[bool] = mapped_column(nullable=False, default=False)

relationship('AlertFlowGraph', cascade='all, delete')
relationship("AlertFlowGraph", cascade="all, delete")

def __repr__(self):
return (f"<AlertResults(alert_id={self.alert_id}, repo={self.repo}, "
f"rule={self.rule}, language={self.language}, location={self.location}, "
f"result={self.result}, created_at={self.created}, valid={self.valid}, completed={self.completed})>")
return (
f"<AlertResults(alert_id={self.alert_id}, repo={self.repo}, "
f"rule={self.rule}, language={self.language}, location={self.location}, "
f"result={self.result}, created_at={self.created}, valid={self.valid}, completed={self.completed})>"
)


class AlertFlowGraph(Base):
    """ORM model for one row of the ``alert_flow_graph`` table.

    Each row points at an ``alert_results`` row through
    ``alert_canonical_id``; the foreign key is declared with
    ``ON DELETE CASCADE``.
    """

    __tablename__ = "alert_flow_graph"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Declared with Mapped/mapped_column like every other column of this model.
    # Optional[int] keeps the column a nullable integer, as it was with Column(Integer, ...).
    alert_canonical_id: Mapped[Optional[int]] = mapped_column(
        ForeignKey("alert_results.canonical_id", ondelete="CASCADE")
    )
    flow_data: Mapped[str] = mapped_column(Text)
    repo: Mapped[str]
    prev: Mapped[Optional[str]]
    next: Mapped[Optional[str]]
    started: Mapped[bool] = mapped_column(nullable=False, default=False)

    def __repr__(self):
        """Return a debugging representation listing the row's column values (except ``id``)."""
        return (
            f"<AlertFlowGraph(alert_canonical_id={self.alert_canonical_id}, "
            f"flow_data={self.flow_data}, repo={self.repo}, prev={self.prev}, next={self.next}, started={self.started})>"
        )
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from sqlalchemy.orm import DeclarativeBase, mapped_column, Mapped
from typing import Optional


class Base(DeclarativeBase):
    """Declarative base for the ``Source`` model declared in this module."""


class Source(Base):
__tablename__ = 'source'
__tablename__ = "source"

id: Mapped[int] = mapped_column(primary_key=True)
repo: Mapped[str]
Expand All @@ -20,6 +21,8 @@ class Source(Base):
notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

def __repr__(self):
return (f"<Source(id={self.id}, repo={self.repo}, "
f"location={self.source_location}, line={self.line}, source_type={self.source_type}, "
f"notes={self.notes})>")
return (
f"<Source(id={self.id}, repo={self.repo}, "
f"location={self.source_location}, line={self.line}, source_type={self.source_type}, "
f"notes={self.notes})>"
)
121 changes: 71 additions & 50 deletions src/seclab_taskflows/mcp_servers/codeql_python/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from seclab_taskflow_agent.mcp_servers.codeql.client import run_query, _debug_log

from pydantic import Field
#from mcp.server.fastmcp import FastMCP, Context
from fastmcp import FastMCP # use FastMCP 2.0

# from mcp.server.fastmcp import FastMCP, Context
from fastmcp import FastMCP # use FastMCP 2.0
from pathlib import Path
import os
import csv
Expand All @@ -23,22 +24,20 @@

# Send every record (DEBUG and above) to a dedicated log file, appending across runs.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(levelname)s - %(message)s",
    filename=log_file_name("mcp_codeql_python.log"),
    filemode="a",
)

# Directory handed to CodeqlSqliteBackend, which keeps its SQLite file there.
MEMORY = mcp_data_dir("seclab-taskflows", "codeql", "DATA_DIR")
# Root directory that _resolve_db_path resolves database paths against.
CODEQL_DBS_BASE_PATH = mcp_data_dir("seclab-taskflows", "codeql", "CODEQL_DBS_BASE_PATH")

mcp = FastMCP("CodeQL-Python")

# tool name -> templated query lookup for supported languages
TEMPLATED_QUERY_PATHS = {
    # to add a language, port the templated query pack and add its definition here
    "python": {"remote_sources": "queries/mcp-python/remote_sources.ql"}
}


Expand All @@ -49,9 +48,10 @@ def source_to_dict(result):
"source_location": result.source_location,
"line": result.line,
"source_type": result.source_type,
"notes": result.notes
"notes": result.notes,
}


def _resolve_query_path(language: str, query: str) -> Path:
global TEMPLATED_QUERY_PATHS
if language not in TEMPLATED_QUERY_PATHS:
Expand All @@ -66,7 +66,7 @@ def _resolve_db_path(relative_db_path: str | Path):
global CODEQL_DBS_BASE_PATH
# path joins will return "/B" if "/A" / "////B" etc. as well
# not windows compatible and probably needs additional hardening
relative_db_path = str(relative_db_path).strip().lstrip('/')
relative_db_path = str(relative_db_path).strip().lstrip("/")
relative_db_path = Path(relative_db_path)
absolute_path = (CODEQL_DBS_BASE_PATH / relative_db_path).resolve()
if not absolute_path.is_relative_to(CODEQL_DBS_BASE_PATH.resolve()):
Expand All @@ -76,36 +76,38 @@ def _resolve_db_path(relative_db_path: str | Path):
raise RuntimeError(f"Error: Database not found at {absolute_path}!")
return str(absolute_path)


# This sqlite database is specifically made for CodeQL for Python MCP.
class CodeqlSqliteBackend:
    """SQLite-backed store for the ``Source`` rows recorded by this server."""

    def __init__(self, memcache_state_dir: str):
        """Create the engine and make sure the ``Source`` table exists.

        Args:
            memcache_state_dir: Directory that holds ``codeql_sqlite.db``. When
                the directory does not exist, an in-memory SQLite database is
                used instead, so stored sources do not outlive the process.
        """
        self.memcache_state_dir = memcache_state_dir
        if not Path(self.memcache_state_dir).exists():
            # Make the non-persistent fallback visible instead of silent.
            logging.warning(
                "State directory %s does not exist; using an in-memory SQLite database",
                self.memcache_state_dir,
            )
            db_dir = "sqlite://"
        else:
            db_dir = f"sqlite:///{self.memcache_state_dir}/codeql_sqlite.db"
        self.engine = create_engine(db_dir, echo=False)
        Base.metadata.create_all(self.engine, tables=[Source.__table__])

    def store_new_source(self, repo, source_location, line, source_type, notes, update=False):
        """Insert a source, or append notes to the one already stored at that location.

        A source is identified by ``(repo, source_location, line)``.

        Args:
            repo: Repository identifier the source belongs to.
            source_location: Path of the file containing the source.
            line: Line number of the source.
            source_type: Kind of source; only used when a new row is inserted.
            notes: Text to store or append; may be ``None``.
            update: When true, never insert; report a missing source instead.

        Returns:
            A human-readable status message.
        """
        with Session(self.engine) as session:
            existing = session.query(Source).filter_by(repo=repo, source_location=source_location, line=line).first()
            if existing:
                # notes can be None (remote_sources stores sources without notes);
                # concatenating None to a str would raise TypeError, so skip the append.
                if notes:
                    existing.notes = (existing.notes or "") + notes
                session.commit()
                return f"Updated notes for source at {source_location}, line {line} in {repo}."
            if update:
                return f"No source exists at repo {repo}, location {source_location}, line {line} to update."
            new_source = Source(
                repo=repo, source_location=source_location, line=line, source_type=source_type, notes=notes
            )
            session.add(new_source)
            session.commit()
            return f"Added new source for {source_location} in {repo}."

    def get_sources(self, repo):
        """Return every stored source of ``repo``, each converted with ``source_to_dict``."""
        with Session(self.engine) as session:
            results = session.query(Source).filter_by(repo=repo).all()
            return [source_to_dict(source) for source in results]

Expand All @@ -119,8 +121,8 @@ def _csv_parse(raw):
if i == 0:
continue
# col1 has what we care about, but offer flexibility
keys = row[1].split(',')
this_obj = {'description': row[0].format(*row[2:])}
keys = row[1].split(",")
this_obj = {"description": row[0].format(*row[2:])}
for j, k in enumerate(keys):
this_obj[k.strip()] = row[j + 2]
results.append(this_obj)
Expand All @@ -141,27 +143,32 @@ def _run_query(query_name: str, database_path: str, language: str, template_valu
except RuntimeError:
return f"The query {query_name} is not supported for language: {language}"
try:
csv = run_query(Path(__file__).parent.resolve() /
query_path,
database_path,
fmt='csv',
template_values=template_values,
log_stderr=True)
csv = run_query(
Path(__file__).parent.resolve() / query_path,
database_path,
fmt="csv",
template_values=template_values,
log_stderr=True,
)
return _csv_parse(csv)
except Exception as e:
return f"The query {query_name} encountered an error: {e}"


# Single module-level backend instance used by the MCP tool functions in this file.
backend = CodeqlSqliteBackend(MEMORY)


@mcp.tool()
def remote_sources(owner: str = Field(description="The owner of the GitHub repository"),
repo: str = Field(description="The name of the GitHub repository"),
database_path: str = Field(description="The CodeQL database path."),
language: str = Field(description="The language used for the CodeQL database.")):
def remote_sources(
owner: str = Field(description="The owner of the GitHub repository"),
repo: str = Field(description="The name of the GitHub repository"),
database_path: str = Field(description="The CodeQL database path."),
language: str = Field(description="The language used for the CodeQL database."),
):
"""List all remote sources and their locations in a CodeQL database, then store the results in a database."""

repo = process_repo(owner, repo)
results = _run_query('remote_sources', database_path, language, {})
results = _run_query("remote_sources", database_path, language, {})

# Check if results is an error (a string) or valid data (list of dicts)
if isinstance(results, str):
Expand All @@ -172,53 +179,67 @@ def remote_sources(owner: str = Field(description="The owner of the GitHub repos
for result in results:
backend.store_new_source(
repo=repo,
source_location=result.get('location', ''),
source_type=result.get('source', ''),
line=int(result.get('line', '0')),
notes=None, #result.get('description', ''),
update=False
source_location=result.get("location", ""),
source_type=result.get("source", ""),
line=int(result.get("line", "0")),
notes=None, # result.get('description', ''),
update=False,
)
stored_count += 1

return f"Stored {stored_count} remote sources in {repo}."


@mcp.tool()
def fetch_sources(
    owner: str = Field(description="The owner of the GitHub repository"),
    repo: str = Field(description="The name of the GitHub repository"),
):
    """
    Fetch all sources from the repo
    """
    # Sources are stored under the identifier produced by process_repo.
    repo = process_repo(owner, repo)
    sources = backend.get_sources(repo)
    # Returned as a JSON string.
    return json.dumps(sources)


@mcp.tool()
def add_source_notes(
    owner: str = Field(description="The owner of the GitHub repository"),
    repo: str = Field(description="The name of the GitHub repository"),
    source_location: str = Field(description="The path to the file"),
    line: int = Field(description="The line number of the source"),
    notes: str = Field(description="The notes to append to this source"),
):
    """
    Add new notes to an existing source. The notes will be appended to any existing notes.
    """
    repo = process_repo(owner, repo)
    # update=True: only an existing source is modified, a missing one is reported
    # back instead of being created. source_type is unused on that path.
    return backend.store_new_source(
        repo=repo,
        source_location=source_location,
        line=line,
        source_type="",
        notes=notes,
        update=True,
    )


@mcp.tool()
def clear_codeql_repo(
    owner: str = Field(description="The owner of the GitHub repository"),
    repo: str = Field(description="The name of the GitHub repository"),
):
    """
    Clear all data for a given repo from the database
    """
    repo = process_repo(owner, repo)
    with Session(backend.engine) as session:
        # Bulk delete; the query's delete() result is reported as the number of removed sources.
        deleted_sources = session.query(Source).filter_by(repo=repo).delete()
        session.commit()
    return f"Cleared {deleted_sources} sources from repo {repo}."


if __name__ == "__main__":
    # Check if codeql/python-all pack is installed, if not install it.
    # NOTE(review): the original check only looked at the root-anchored
    # "/.codeql/..." path. The CodeQL CLI's default package cache lives under
    # the user's home directory, so that location is accepted as well --
    # confirm which one deployments rely on.
    pack_cache_dirs = (
        "/.codeql/packages/codeql/python-all",
        os.path.expanduser("~/.codeql/packages/codeql/python-all"),
    )
    if not any(os.path.isdir(cache_dir) for cache_dir in pack_cache_dirs):
        pack_path = importlib.resources.files("seclab_taskflows.mcp_servers.codeql_python.queries").joinpath(
            "mcp-python"
        )
        print(f"Installing CodeQL pack from {pack_path}")
        # str(): importlib.resources returns a Traversable, which is not guaranteed to be path-like.
        completed = subprocess.run(["codeql", "pack", "install", str(pack_path)])
        if completed.returncode != 0:
            # Best effort: still start the server, but leave a trace of the failed install.
            logging.error("codeql pack install exited with status %s", completed.returncode)
    mcp.run(show_banner=False, transport="http", host="127.0.0.1", port=9998)
Loading
Loading