Skip to content

Commit 3149be8

Browse files
authored
Fix probes=None server incompatibility (#3543)
This fixes server compatibility with clients prior to 0.20.8 that don't support `probes=None`, by replacing `None` with `[]` in responses for older clients. This incompatibility could be observed when there are both new and old clients in the same project, so old clients would fail when viewing runs submitted by new clients.
1 parent 78678c1 commit 3149be8

File tree

3 files changed

+271
-27
lines changed

3 files changed

+271
-27
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from typing import Optional
2+
3+
from packaging.version import Version
4+
5+
from dstack._internal.core.models.configurations import ServiceConfiguration
6+
from dstack._internal.core.models.runs import Run, RunPlan, RunSpec
7+
from dstack._internal.server.compatibility.common import patch_offers_list
8+
9+
10+
def patch_run_plan(run_plan: RunPlan, client_version: Optional[Version]) -> None:
11+
if client_version is None:
12+
return
13+
patch_run_spec(run_plan.run_spec, client_version)
14+
if run_plan.effective_run_spec is not None:
15+
patch_run_spec(run_plan.effective_run_spec, client_version)
16+
if run_plan.current_resource is not None:
17+
patch_run(run_plan.current_resource, client_version)
18+
for job_plan in run_plan.job_plans:
19+
patch_offers_list(job_plan.offers, client_version)
20+
21+
22+
def patch_run(run: Run, client_version: Optional[Version]) -> None:
23+
if client_version is None:
24+
return
25+
patch_run_spec(run.run_spec, client_version)
26+
27+
28+
def patch_run_spec(run_spec: RunSpec, client_version: Optional[Version]) -> None:
29+
if client_version is None:
30+
return
31+
# Clients prior to 0.20.8 do not support probes = None
32+
if client_version < Version("0.20.8") and isinstance(
33+
run_spec.configuration, ServiceConfiguration
34+
):
35+
if run_spec.configuration.probes is None:
36+
run_spec.configuration.probes = []

src/dstack/_internal/server/routers/runs.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from dstack._internal.core.errors import ResourceNotExistsError
88
from dstack._internal.core.models.runs import Run, RunPlan
9-
from dstack._internal.server.compatibility.common import patch_offers_list
9+
from dstack._internal.server.compatibility.runs import patch_run, patch_run_plan
1010
from dstack._internal.server.db import get_session
1111
from dstack._internal.server.models import ProjectModel, UserModel
1212
from dstack._internal.server.schemas.runs import (
@@ -52,6 +52,7 @@ async def list_runs(
5252
body: ListRunsRequest,
5353
session: AsyncSession = Depends(get_session),
5454
user: UserModel = Depends(Authenticated()),
55+
client_version: Optional[Version] = Depends(get_client_version),
5556
):
5657
"""
5758
Returns all runs visible to user sorted by descending `submitted_at`.
@@ -62,22 +63,23 @@ async def list_runs(
6263
The results are paginated. To get the next page, pass `submitted_at` and `id` of
6364
the last run from the previous page as `prev_submitted_at` and `prev_run_id`.
6465
"""
65-
return CustomORJSONResponse(
66-
await runs.list_user_runs(
67-
session=session,
68-
user=user,
69-
project_name=body.project_name,
70-
repo_id=body.repo_id,
71-
username=body.username,
72-
only_active=body.only_active,
73-
include_jobs=body.include_jobs,
74-
job_submissions_limit=body.job_submissions_limit,
75-
prev_submitted_at=body.prev_submitted_at,
76-
prev_run_id=body.prev_run_id,
77-
limit=body.limit,
78-
ascending=body.ascending,
79-
)
66+
run_list = await runs.list_user_runs(
67+
session=session,
68+
user=user,
69+
project_name=body.project_name,
70+
repo_id=body.repo_id,
71+
username=body.username,
72+
only_active=body.only_active,
73+
include_jobs=body.include_jobs,
74+
job_submissions_limit=body.job_submissions_limit,
75+
prev_submitted_at=body.prev_submitted_at,
76+
prev_run_id=body.prev_run_id,
77+
limit=body.limit,
78+
ascending=body.ascending,
8079
)
80+
for run in run_list:
81+
patch_run(run, client_version)
82+
return CustomORJSONResponse(run_list)
8183

8284

8385
@project_router.post(
@@ -88,6 +90,7 @@ async def get_run(
8890
body: GetRunRequest,
8991
session: AsyncSession = Depends(get_session),
9092
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
93+
client_version: Optional[Version] = Depends(get_client_version),
9194
):
9295
"""
9396
Returns a run given `run_name` or `id`.
@@ -103,6 +106,7 @@ async def get_run(
103106
)
104107
if run is None:
105108
raise ResourceNotExistsError("Run not found")
109+
patch_run(run, client_version)
106110
return CustomORJSONResponse(run)
107111

108112

@@ -132,8 +136,7 @@ async def get_plan(
132136
max_offers=body.max_offers,
133137
legacy_repo_dir=legacy_repo_dir,
134138
)
135-
for job_plan in run_plan.job_plans:
136-
patch_offers_list(job_plan.offers, client_version)
139+
patch_run_plan(run_plan, client_version)
137140
return CustomORJSONResponse(run_plan)
138141

139142

@@ -146,6 +149,7 @@ async def apply_plan(
146149
session: Annotated[AsyncSession, Depends(get_session)],
147150
user_project: Annotated[tuple[UserModel, ProjectModel], Depends(ProjectMember())],
148151
legacy_repo_dir: Annotated[bool, Depends(use_legacy_repo_dir)],
152+
client_version: Annotated[Optional[Version], Depends(get_client_version)],
149153
):
150154
"""
151155
Creates a new run or updates an existing run.
@@ -156,16 +160,16 @@ async def apply_plan(
156160
user, project = user_project
157161
if not user.ssh_public_key and not body.plan.run_spec.ssh_key_pub:
158162
await users.refresh_ssh_key(session=session, actor=user)
159-
return CustomORJSONResponse(
160-
await runs.apply_plan(
161-
session=session,
162-
user=user,
163-
project=project,
164-
plan=body.plan,
165-
force=body.force,
166-
legacy_repo_dir=legacy_repo_dir,
167-
)
163+
run = await runs.apply_plan(
164+
session=session,
165+
user=user,
166+
project=project,
167+
plan=body.plan,
168+
force=body.force,
169+
legacy_repo_dir=legacy_repo_dir,
168170
)
171+
patch_run(run, client_version)
172+
return CustomORJSONResponse(run)
169173

170174

171175
@project_router.post("/stop")

src/tests/_internal/server/routers/test_runs.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -941,6 +941,53 @@ async def test_limits_job_submissions(
941941
},
942942
]
943943

944+
@pytest.mark.asyncio
945+
@pytest.mark.parametrize(
946+
"client_version,expected_probes",
947+
[
948+
("0.20.7", []),
949+
("0.20.8", None),
950+
(None, None),
951+
],
952+
)
953+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
954+
async def test_patches_service_configuration_probes_for_old_clients(
955+
self,
956+
test_db,
957+
session: AsyncSession,
958+
client: AsyncClient,
959+
client_version: Optional[str],
960+
expected_probes: Optional[list],
961+
) -> None:
962+
user = await create_user(session=session)
963+
project = await create_project(session=session, owner=user)
964+
repo = await create_repo(session=session, project_id=project.id)
965+
966+
service_conf = ServiceConfiguration(
967+
commands=["echo hello"],
968+
port=80,
969+
probes=None, # This should be patched to [] for clients prior to 0.20.8
970+
)
971+
run_spec = get_run_spec(
972+
configuration=service_conf,
973+
repo_id=repo.name,
974+
)
975+
await create_run(session=session, project=project, repo=repo, user=user, run_spec=run_spec)
976+
977+
headers = get_auth_headers(user.token)
978+
if client_version is not None:
979+
headers["X-API-Version"] = client_version
980+
response = await client.post(
981+
"/api/runs/list",
982+
headers=headers,
983+
json={"project_name": project.name},
984+
)
985+
986+
assert response.status_code == 200
987+
runs_list = response.json()
988+
assert len(runs_list) == 1
989+
assert runs_list[0]["run_spec"]["configuration"]["probes"] == expected_probes
990+
944991

945992
class TestGetRun:
946993
@pytest.mark.asyncio
@@ -1020,6 +1067,53 @@ async def test_returns_deleted_run_given_id(
10201067
assert response.status_code == 200, response.json()
10211068
assert response.json()["id"] == str(run.id)
10221069

1070+
@pytest.mark.asyncio
1071+
@pytest.mark.parametrize(
1072+
"client_version,expected_probes",
1073+
[
1074+
("0.20.7", []),
1075+
("0.20.8", None),
1076+
(None, None),
1077+
],
1078+
)
1079+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
1080+
async def test_patches_service_configuration_probes_for_old_clients(
1081+
self,
1082+
test_db,
1083+
session: AsyncSession,
1084+
client: AsyncClient,
1085+
client_version: Optional[str],
1086+
expected_probes: Optional[list],
1087+
) -> None:
1088+
user = await create_user(session=session)
1089+
project = await create_project(session=session, owner=user)
1090+
repo = await create_repo(session=session, project_id=project.id)
1091+
1092+
service_conf = ServiceConfiguration(
1093+
commands=["echo hello"],
1094+
port=80,
1095+
probes=None, # This should be patched to [] for clients prior to 0.20.8
1096+
)
1097+
run_spec = get_run_spec(
1098+
configuration=service_conf,
1099+
repo_id=repo.name,
1100+
)
1101+
run = await create_run(
1102+
session=session, project=project, repo=repo, user=user, run_spec=run_spec
1103+
)
1104+
1105+
headers = get_auth_headers(user.token)
1106+
if client_version is not None:
1107+
headers["X-API-Version"] = client_version
1108+
response = await client.post(
1109+
f"/api/project/{project.name}/runs/get",
1110+
headers=headers,
1111+
json={"run_name": run.run_name},
1112+
)
1113+
1114+
assert response.status_code == 200
1115+
assert response.json()["run_spec"]["configuration"]["probes"] == expected_probes
1116+
10231117

10241118
class TestGetRunPlan:
10251119
@pytest.mark.asyncio
@@ -1477,6 +1571,65 @@ async def test_generates_user_ssh_key(self, session: AsyncSession, client: Async
14771571
assert user.ssh_public_key == run_spec_ssh_public_key
14781572
assert user.ssh_private_key is not None
14791573

1574+
@pytest.mark.asyncio
1575+
@pytest.mark.parametrize(
1576+
"client_version,expected_probes",
1577+
[
1578+
("0.20.7", []),
1579+
("0.20.8", None),
1580+
(None, None),
1581+
],
1582+
)
1583+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
1584+
async def test_patches_service_configuration_probes_for_old_clients(
1585+
self,
1586+
test_db,
1587+
session: AsyncSession,
1588+
client: AsyncClient,
1589+
client_version: Optional[str],
1590+
expected_probes: Optional[list],
1591+
) -> None:
1592+
user = await create_user(session=session)
1593+
project = await create_project(session=session, owner=user)
1594+
repo = await create_repo(session=session, project_id=project.id)
1595+
1596+
service_conf = ServiceConfiguration(
1597+
commands=["echo hello"],
1598+
port=80,
1599+
probes=None, # This should be patched to [] for clients prior to 0.20.8
1600+
)
1601+
run_spec = get_run_spec(
1602+
run_name="test-service",
1603+
configuration=service_conf,
1604+
repo_id=repo.name,
1605+
)
1606+
await create_run(
1607+
session=session,
1608+
project=project,
1609+
repo=repo,
1610+
user=user,
1611+
run_spec=run_spec,
1612+
run_name="test-service",
1613+
)
1614+
1615+
body = {"run_spec": run_spec.dict()}
1616+
headers = get_auth_headers(user.token)
1617+
if client_version is not None:
1618+
headers["X-API-Version"] = client_version
1619+
response = await client.post(
1620+
f"/api/project/{project.name}/runs/get_plan",
1621+
headers=headers,
1622+
json=body,
1623+
)
1624+
1625+
assert response.status_code == 200
1626+
run_plan = response.json()
1627+
assert run_plan["run_spec"]["configuration"]["probes"] == expected_probes
1628+
assert run_plan["effective_run_spec"]["configuration"]["probes"] == expected_probes
1629+
assert (
1630+
run_plan["current_resource"]["run_spec"]["configuration"]["probes"] == expected_probes
1631+
)
1632+
14801633

14811634
class TestApplyPlan:
14821635
@pytest.mark.asyncio
@@ -1668,6 +1821,57 @@ async def test_generates_user_ssh_key(self, session: AsyncSession, client: Async
16681821
assert user.ssh_public_key == run_spec_ssh_public_key
16691822
assert user.ssh_private_key is not None
16701823

1824+
@pytest.mark.asyncio
1825+
@pytest.mark.parametrize(
1826+
"client_version,expected_probes",
1827+
[
1828+
("0.20.7", []), # Prior to 0.20.8, probes=None should be patched to []
1829+
("0.20.8", None), # 0.20.8 and later should keep probes=None
1830+
(None, None), # None client version should keep probes=None
1831+
],
1832+
)
1833+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
1834+
async def test_patches_service_configuration_probes_for_old_clients(
1835+
self,
1836+
test_db,
1837+
session: AsyncSession,
1838+
client: AsyncClient,
1839+
client_version: Optional[str],
1840+
expected_probes: Optional[list],
1841+
) -> None:
1842+
user = await create_user(session=session)
1843+
project = await create_project(session=session, owner=user)
1844+
repo = await create_repo(session=session, project_id=project.id)
1845+
1846+
service_conf = ServiceConfiguration(
1847+
commands=["echo hello"],
1848+
port=80,
1849+
probes=None, # This should be patched to [] for clients prior to 0.20.8
1850+
)
1851+
run_spec = get_run_spec(
1852+
run_name="test-service",
1853+
configuration=service_conf,
1854+
repo_id=repo.name,
1855+
)
1856+
1857+
headers = get_auth_headers(user.token)
1858+
if client_version is not None:
1859+
headers["X-API-Version"] = client_version
1860+
response = await client.post(
1861+
f"/api/project/{project.name}/runs/apply",
1862+
headers=headers,
1863+
json={
1864+
"plan": {
1865+
"run_spec": run_spec.dict(),
1866+
"current_resource": None,
1867+
},
1868+
"force": False,
1869+
},
1870+
)
1871+
1872+
assert response.status_code == 200
1873+
assert response.json()["run_spec"]["configuration"]["probes"] == expected_probes
1874+
16711875

16721876
class TestSubmitRun:
16731877
@pytest.mark.asyncio

0 commit comments

Comments
 (0)