Commit 91156dc

feat: add client-side utility for getting OAuth tokens simply (#230)
Our model sucks at correct tool calling, but otherwise things are working. [screenshot attached]
1 parent 3de187f commit 91156dc
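The new helper is exported from llama_stack_client.lib (see src/llama_stack_client/lib/__init__.py below) and is exercised by the new examples/mcp_agent.py. A minimal sketch of the intended usage, distilled from that example — the server URL is just the example's default, and the token is assumed to be a plain bearer string:

import json

from llama_stack_client.lib import get_oauth_token_for_mcp_server

# Example default MCP server; any OAuth-protected MCP SSE endpoint works the same way.
server_url = "https://mcp.asana.com/sse"

# Walks the OAuth flow for the server; a falsy result means no token was obtained.
token = get_oauth_token_for_mcp_server(server_url)
if token:
    # Per-server auth headers are forwarded to Llama Stack via provider data,
    # exactly as examples/mcp_agent.py does below.
    provider_data = json.dumps(
        {"mcp_headers": {server_url: {"Authorization": f"Bearer {token}"}}}
    )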

5 files changed: +462 −86 lines changed

examples/mcp_agent.py

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
import json
import logging
from urllib.parse import urlparse

import fire
import httpx
from llama_stack_client import Agent, AgentEventLogger, LlamaStackClient
from llama_stack_client.lib import get_oauth_token_for_mcp_server
from rich import print as rprint

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


import tempfile
from pathlib import Path

TMP_DIR = Path(tempfile.gettempdir()) / "llama-stack"
TMP_DIR.mkdir(parents=True, exist_ok=True)

CACHE_FILE = TMP_DIR / "mcp_tokens.json"


def main(model_id: str, mcp_servers: str = "https://mcp.asana.com/sse", llama_stack_url: str = "http://localhost:8321"):
    """Run an MCP agent with the specified model and servers.

    Args:
        model_id: The model to use for the agent.
        mcp_servers: Comma-separated list of MCP servers to use for the agent.
        llama_stack_url: The URL of the Llama Stack server to use.

    Examples:
        python mcp_agent.py "meta-llama/Llama-4-Scout-17B-16E-Instruct" \
            -m "https://mcp.asana.com/sse" \
            -l "http://localhost:8321"
    """
    client = LlamaStackClient(base_url=llama_stack_url)
    if not check_model_exists(client, model_id):
        return

    servers = [s.strip() for s in mcp_servers.split(",")]
    mcp_headers = get_and_cache_mcp_headers(servers)

    toolgroup_ids = []
    for server in servers:
        # we cannot use "/" in the toolgroup_id because we have some tech debt from earlier which uses
        # "/" as a separator for toolgroup_id and tool_name. We should fix this in the future.
        group_id = urlparse(server).netloc
        toolgroup_ids.append(group_id)
        client.toolgroups.register(
            toolgroup_id=group_id, mcp_endpoint=dict(uri=server), provider_id="model-context-protocol"
        )

    agent = Agent(
        client=client,
        model=model_id,
        instructions="You are a helpful assistant who can use tools when necessary to answer questions.",
        tools=toolgroup_ids,
        extra_headers={
            "X-LlamaStack-Provider-Data": json.dumps(
                {
                    "mcp_headers": mcp_headers,
                }
            ),
        },
    )

    session_id = agent.create_session("test-session")

    while True:
        user_input = input("Enter a question: ")
        if user_input.lower() in ("q", "quit", "exit", "bye", ""):
            print("Exiting...")
            break
        response = agent.create_turn(
            session_id=session_id,
            messages=[{"role": "user", "content": user_input}],
            stream=True,
        )
        for log in AgentEventLogger().log(response):
            log.print()


def check_model_exists(client: LlamaStackClient, model_id: str) -> bool:
    models = [m for m in client.models.list() if m.model_type == "llm"]
    if model_id not in [m.identifier for m in models]:
        rprint(f"[red]Model {model_id} not found[/red]")
        rprint("[yellow]Available models:[/yellow]")
        for model in models:
            rprint(f" - {model.identifier}")
        return False
    return True


def get_and_cache_mcp_headers(servers: list[str]) -> dict[str, dict[str, str]]:
    mcp_headers = {}

    logger.info(f"Using cache file: {CACHE_FILE} for MCP tokens")
    tokens = {}
    if CACHE_FILE.exists():
        with open(CACHE_FILE, "r") as f:
            tokens = json.load(f)
            for server, token in tokens.items():
                mcp_headers[server] = {
                    "Authorization": f"Bearer {token}",
                }

    for server in servers:
        with httpx.Client() as http_client:
            headers = mcp_headers.get(server, {})
            try:
                response = http_client.get(server, headers=headers, timeout=1.0)
            except httpx.TimeoutException:
                # timeout means success since we did not get an immediate 40X
                continue

            if response.status_code in (401, 403):
                logger.info(f"Server {server} requires authentication, getting token")
                token = get_oauth_token_for_mcp_server(server)
                if not token:
                    logger.error(f"No token obtained for {server}")
                    return

                tokens[server] = token
                mcp_headers[server] = {
                    "Authorization": f"Bearer {token}",
                }

    with open(CACHE_FILE, "w") as f:
        json.dump(tokens, f, indent=2)

    return mcp_headers


if __name__ == "__main__":
    fire.Fire(main)

pyproject.toml

Lines changed: 1 addition & 80 deletions
@@ -4,9 +4,7 @@ version = "0.2.7"
 description = "The official Python library for the llama-stack-client API"
 dynamic = ["readme"]
 license = "Apache-2.0"
-authors = [
-    { name = "Llama Stack Client", email = "dev-feedback@llama-stack-client.com" },
-]
+authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
 dependencies = [
     "httpx>=0.23.0, <1",
     "pydantic>=1.9.0, <3",
@@ -48,52 +46,6 @@ Repository = "https://github.com/meta-llama/llama-stack-client-python"



-[tool.rye]
-managed = true
-# version pins are in requirements-dev.lock
-dev-dependencies = [
-    "pyright>=1.1.359",
-    "mypy",
-    "respx",
-    "pytest",
-    "pytest-asyncio",
-    "ruff",
-    "time-machine",
-    "nox",
-    "dirty-equals>=0.6.0",
-    "importlib-metadata>=6.7.0",
-    "rich>=13.7.1",
-]
-
-[tool.rye.scripts]
-format = { chain = [
-    "format:ruff",
-    "format:docs",
-    "fix:ruff",
-]}
-"format:black" = "black ."
-"format:docs" = "python scripts/utils/ruffen-docs.py README.md api.md"
-"format:ruff" = "ruff format"
-"format:isort" = "isort ."
-
-"lint" = { chain = [
-    "check:ruff",
-    "typecheck",
-    "check:importable",
-]}
-"check:ruff" = "ruff check ."
-"fix:ruff" = "ruff check --fix ."
-
-"check:importable" = "python -c 'import llama_stack_client'"
-
-typecheck = { chain = [
-    "typecheck:pyright",
-    "typecheck:mypy"
-]}
-"typecheck:pyright" = "pyright"
-"typecheck:verify-types" = "pyright --verifytypes llama_stack_client --ignoreexternal"
-"typecheck:mypy" = "mypy ."
-
 [build-system]
 requires = ["hatchling", "hatch-fancy-pypi-readme"]
 build-backend = "hatchling.build"
@@ -132,37 +84,6 @@ path = "README.md"
 pattern = '\[(.+?)\]\(((?!https?://)\S+?)\)'
 replacement = '[\1](https://github.com/meta-llama/llama-stack-client-python/tree/main/\g<2>)'

-[tool.black]
-line-length = 120
-
-[tool.pytest.ini_options]
-testpaths = ["tests"]
-addopts = "--tb=short"
-xfail_strict = true
-asyncio_mode = "auto"
-filterwarnings = [
-    "error"
-]
-
-[tool.pyright]
-# this enables practically every flag given by pyright.
-# there are a couple of flags that are still disabled by
-# default in strict mode as they are experimental and niche.
-typeCheckingMode = "strict"
-pythonVersion = "3.7"
-
-exclude = [
-    "_dev",
-    ".venv",
-    ".nox",
-]
-
-reportImplicitOverride = true
-
-reportImportCycles = false
-reportPrivateUsage = false
-
-
 [tool.ruff]
 line-length = 120
 output-format = "grouped"

src/llama_stack_client/lib/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -3,3 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+
+from .tools.mcp_oauth import get_oauth_token_for_mcp_server
+
+__all__ = ["get_oauth_token_for_mcp_server"]
