
Commit 35c4c6b

feat: add client-side utility for getting OAuth tokens simply

Parent: 3de187f

File tree

4 files changed: +432 -80 lines
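In brief, the new helper turns an MCP server URL into a bearer token that can be forwarded to the Llama Stack server as per-request provider data. A minimal sketch of the pattern, based only on how examples/mcp_agent.py (below) uses it; the Asana URL is just the example server from that script:

import json

from llama_stack_client.lib import get_oauth_token_for_mcp_server

server = "https://mcp.asana.com/sse"  # example MCP server used in the script below
token = get_oauth_token_for_mcp_server(server)  # runs the OAuth flow; falsy if no token was obtained
if token:
    # Forward the token as per-MCP-server headers via provider data.
    extra_headers = {
        "X-LlamaStack-Provider-Data": json.dumps(
            {"mcp_headers": {server: {"Authorization": f"Bearer {token}"}}}
        )
    }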

examples/mcp_agent.py

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
import json
import logging

import fire
import httpx
from llama_stack_client import Agent, AgentEventLogger, LlamaStackClient
from llama_stack_client.lib import get_oauth_token_for_mcp_server
from rich import print as rprint

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


import tempfile
from pathlib import Path

TMP_DIR = Path(tempfile.gettempdir()) / "llama-stack"
TMP_DIR.mkdir(parents=True, exist_ok=True)

CACHE_FILE = TMP_DIR / "mcp_tokens.json"


def main(model_id: str, mcp_servers: str = "https://mcp.asana.com/sse", llama_stack_url: str = "http://localhost:8321"):
    """Run an MCP agent with the specified model and servers.

    Args:
        model_id: The model to use for the agent.
        mcp_servers: Comma-separated list of MCP servers to use for the agent.
        llama_stack_url: The URL of the Llama Stack server to use.

    Examples:
        python mcp_agent.py "meta-llama/Llama-4-Scout-17B-16E-Instruct" \
            -m "https://mcp.asana.com/sse" \
            -l "http://localhost:8321"
    """
    client = LlamaStackClient(base_url=llama_stack_url)
    if not check_model_exists(client, model_id):
        return

    servers = [s.strip() for s in mcp_servers.split(",")]
    mcp_headers = get_and_cache_mcp_headers(servers)

    for server in servers:
        client.toolgroups.register(
            toolgroup_id=server, mcp_endpoint=dict(uri=server), provider_id="model-context-protocol"
        )

    agent = Agent(
        client=client,
        model=model_id,
        instructions="You are a helpful assistant who can use tools when necessary to answer questions.",
        tools=servers,
    )

    session_id = agent.create_session("test-session")

    while True:
        user_input = input("Enter a question: ")
        if user_input.lower() in ("q", "quit", "exit", "bye", ""):
            print("Exiting...")
            break
        response = agent.create_turn(
            session_id=session_id,
            messages=[{"role": "user", "content": user_input}],
            stream=True,
            extra_headers={
                "X-LlamaStack-Provider-Data": json.dumps(
                    {
                        "mcp_headers": mcp_headers,
                    }
                ),
            },
        )
        for log in AgentEventLogger().log(response):
            log.print()


def check_model_exists(client: LlamaStackClient, model_id: str) -> bool:
    models = [m for m in client.models.list() if m.model_type == "llm"]
    if model_id not in [m.identifier for m in models]:
        rprint(f"[red]Model {model_id} not found[/red]")
        rprint("[yellow]Available models:[/yellow]")
        for model in models:
            rprint(f" - {model.identifier}")
        return False
    return True


def get_and_cache_mcp_headers(servers: list[str]) -> dict[str, dict[str, str]]:
    mcp_headers = {}

    logger.info(f"Using cache file: {CACHE_FILE} for MCP tokens")
    tokens = {}
    if CACHE_FILE.exists():
        with open(CACHE_FILE, "r") as f:
            tokens = json.load(f)
        for server, token in tokens.items():
            mcp_headers[server] = {
                "Authorization": f"Bearer {token}",
            }

    for server in servers:
        with httpx.Client() as http_client:
            headers = mcp_headers.get(server, {})
            try:
                response = http_client.get(server, headers=headers, timeout=1.0)
            except httpx.TimeoutException:
                # timeout means success since we did not get an immediate 40X
                continue

            if response.status_code in (401, 403):
                logger.info(f"Server {server} requires authentication, getting token")
                token = get_oauth_token_for_mcp_server(server)
                if not token:
                    logger.error(f"No token obtained for {server}")
                    return

                tokens[server] = token
                mcp_headers[server] = {
                    "Authorization": f"Bearer {token}",
                }

    with open(CACHE_FILE, "w") as f:
        json.dump(tokens, f, indent=2)

    return mcp_headers


if __name__ == "__main__":
    fire.Fire(main)
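One operational detail worth noting: the script caches tokens in a JSON file keyed by server URL, so repeated runs skip the OAuth flow. A small sketch for inspecting and clearing that cache, with the path and format taken from the script above:

import json
import tempfile
from pathlib import Path

cache_file = Path(tempfile.gettempdir()) / "llama-stack" / "mcp_tokens.json"
if cache_file.exists():
    tokens = json.loads(cache_file.read_text())  # {server_url: token}
    print("cached tokens for:", list(tokens))
    cache_file.unlink()  # delete to force re-authentication on the next run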

pyproject.toml

Lines changed: 1 addition & 80 deletions
@@ -4,9 +4,7 @@ version = "0.2.7"
 description = "The official Python library for the llama-stack-client API"
 dynamic = ["readme"]
 license = "Apache-2.0"
-authors = [
-    { name = "Llama Stack Client", email = "dev-feedback@llama-stack-client.com" },
-]
+authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
 dependencies = [
     "httpx>=0.23.0, <1",
     "pydantic>=1.9.0, <3",
@@ -48,52 +46,6 @@ Repository = "https://github.com/meta-llama/llama-stack-client-python"



-[tool.rye]
-managed = true
-# version pins are in requirements-dev.lock
-dev-dependencies = [
-    "pyright>=1.1.359",
-    "mypy",
-    "respx",
-    "pytest",
-    "pytest-asyncio",
-    "ruff",
-    "time-machine",
-    "nox",
-    "dirty-equals>=0.6.0",
-    "importlib-metadata>=6.7.0",
-    "rich>=13.7.1",
-]
-
-[tool.rye.scripts]
-format = { chain = [
-    "format:ruff",
-    "format:docs",
-    "fix:ruff",
-]}
-"format:black" = "black ."
-"format:docs" = "python scripts/utils/ruffen-docs.py README.md api.md"
-"format:ruff" = "ruff format"
-"format:isort" = "isort ."
-
-"lint" = { chain = [
-    "check:ruff",
-    "typecheck",
-    "check:importable",
-]}
-"check:ruff" = "ruff check ."
-"fix:ruff" = "ruff check --fix ."
-
-"check:importable" = "python -c 'import llama_stack_client'"
-
-typecheck = { chain = [
-    "typecheck:pyright",
-    "typecheck:mypy"
-]}
-"typecheck:pyright" = "pyright"
-"typecheck:verify-types" = "pyright --verifytypes llama_stack_client --ignoreexternal"
-"typecheck:mypy" = "mypy ."
-
 [build-system]
 requires = ["hatchling", "hatch-fancy-pypi-readme"]
 build-backend = "hatchling.build"
@@ -132,37 +84,6 @@ path = "README.md"
 pattern = '\[(.+?)\]\(((?!https?://)\S+?)\)'
 replacement = '[\1](https://github.com/meta-llama/llama-stack-client-python/tree/main/\g<2>)'

-[tool.black]
-line-length = 120
-
-[tool.pytest.ini_options]
-testpaths = ["tests"]
-addopts = "--tb=short"
-xfail_strict = true
-asyncio_mode = "auto"
-filterwarnings = [
-    "error"
-]
-
-[tool.pyright]
-# this enables practically every flag given by pyright.
-# there are a couple of flags that are still disabled by
-# default in strict mode as they are experimental and niche.
-typeCheckingMode = "strict"
-pythonVersion = "3.7"
-
-exclude = [
-    "_dev",
-    ".venv",
-    ".nox",
-]
-
-reportImplicitOverride = true
-
-reportImportCycles = false
-reportPrivateUsage = false
-
 [tool.ruff]
 line-length = 120
 output-format = "grouped"

src/llama_stack_client/lib/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -3,3 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+
+from .tools.mcp_oauth import get_oauth_token_for_mcp_server
+
+__all__ = ["get_oauth_token_for_mcp_server"]
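This re-export makes the helper importable from the package's lib namespace, which the example script above relies on. A minimal sketch, assuming any OAuth-protected MCP endpoint; the Asana URL is just the example used above:

from llama_stack_client.lib import get_oauth_token_for_mcp_server

token = get_oauth_token_for_mcp_server("https://mcp.asana.com/sse")
headers = {"Authorization": f"Bearer {token}"} if token else {}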
