Skip to content

Commit 1e49482

Browse files
Add utility methods to kr8s objects (#741)
1 parent 0c3ede7 commit 1e49482

File tree

2 files changed

+179
-1
lines changed

2 files changed

+179
-1
lines changed

dask_kubernetes/operator/controller/tests/test_controller.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
import yaml
1010
from dask.distributed import Client
1111

12+
from kr8s.asyncio.objects import Pod, Deployment, Service
1213
from dask_kubernetes.operator.controller import (
1314
KUBERNETES_DATETIME_FORMAT,
1415
get_job_runner_pod_name,
1516
)
17+
from dask_kubernetes.operator.objects import DaskCluster, DaskWorkerGroup, DaskJob
1618

1719
DIR = pathlib.Path(__file__).parent.absolute()
1820

@@ -590,3 +592,70 @@ async def test_failed_job(k8s_cluster, kopf_runner, gen_job):
590592

591593
assert "A DaskJob has been created" in runner.stdout
592594
assert "Job failed, deleting Dask cluster." in runner.stdout
595+
596+
597+
@pytest.mark.asyncio
async def test_object_dask_cluster(k8s_cluster, kopf_runner, gen_cluster):
    """Check the types returned by the DaskCluster helper methods."""
    # The runner object is never inspected in this test, so it is not bound.
    with kopf_runner:
        async with gen_cluster() as (cluster_name, ns):
            cluster = await DaskCluster.get(cluster_name, namespace=ns)

            # Poll until the cluster reports at least one worker group.
            worker_groups = []
            while not worker_groups:
                worker_groups = await cluster.worker_groups()
                await asyncio.sleep(0.1)
            assert len(worker_groups) == 1  # Just the default worker group
            wg = worker_groups[0]
            assert isinstance(wg, DaskWorkerGroup)

            scheduler_pod = await cluster.scheduler_pod()
            assert isinstance(scheduler_pod, Pod)

            scheduler_deployment = await cluster.scheduler_deployment()
            assert isinstance(scheduler_deployment, Deployment)

            scheduler_service = await cluster.scheduler_service()
            assert isinstance(scheduler_service, Service)
619+
620+
621+
@pytest.mark.asyncio
async def test_object_dask_worker_group(k8s_cluster, kopf_runner, gen_cluster):
    """Check the objects returned by the DaskWorkerGroup helper methods."""
    # The runner object is never inspected in this test, so it is not bound.
    with kopf_runner:
        async with gen_cluster() as (cluster_name, ns):
            cluster = await DaskCluster.get(cluster_name, namespace=ns)

            # Poll until the cluster reports at least one worker group.
            worker_groups = []
            while not worker_groups:
                worker_groups = await cluster.worker_groups()
                await asyncio.sleep(0.1)
            assert len(worker_groups) == 1  # Just the default worker group
            wg = worker_groups[0]
            assert isinstance(wg, DaskWorkerGroup)

            # Poll until the worker group lists at least one pod.
            pods = []
            while not pods:
                pods = await wg.pods()
                await asyncio.sleep(0.1)
            assert all(isinstance(p, Pod) for p in pods)

            # Poll until the worker group lists at least one deployment.
            deployments = []
            while not deployments:
                deployments = await wg.deployments()
                await asyncio.sleep(0.1)
            assert all(isinstance(d, Deployment) for d in deployments)

            assert (await wg.cluster()).name == cluster.name
648+
649+
650+
@pytest.mark.asyncio
@pytest.mark.skip(reason="Flaky in CI")
async def test_object_dask_job(k8s_cluster, kopf_runner, gen_job):
    """Check the types returned by the DaskJob helper methods."""
    # The runner object is never inspected in this test, so it is not bound.
    with kopf_runner:
        async with gen_job("simplejob.yaml") as (job_name, ns):
            job = await DaskJob.get(job_name, namespace=ns)

            job_pod = await job.pod()
            assert isinstance(job_pod, Pod)

            cluster = await job.cluster()
            assert isinstance(cluster, DaskCluster)

dask_kubernetes/operator/objects.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from kr8s.asyncio.objects import APIObject
1+
from __future__ import annotations
2+
from typing import List
3+
4+
from kr8s.asyncio.objects import APIObject, Pod, Deployment, Service
25

36

47
class DaskCluster(APIObject):
@@ -11,6 +14,61 @@ class DaskCluster(APIObject):
1114
scalable = True
1215
scalable_spec = "worker.replicas"
1316

17+
async def worker_groups(self) -> List[DaskWorkerGroup]:
    """Return the DaskWorkerGroup resources labelled with this cluster's name.

    Queries the API, within this cluster's namespace, for worker groups
    whose ``dask.org/cluster-name`` label equals ``self.name``.
    """
    selector = f"dask.org/cluster-name={self.name}"
    groups = await self.api.get(
        DaskWorkerGroup.endpoint,
        label_selector=selector,
        namespace=self.namespace,
    )
    return groups
23+
24+
async def scheduler_pod(self) -> Pod:
    """Wait for and return this cluster's scheduler Pod.

    Polls the API for Pods in this cluster's namespace labelled with
    ``dask.org/cluster-name=<self.name>`` and
    ``dask.org/component=scheduler`` until the query returns a non-empty
    result, pausing briefly between attempts.

    Returns:
        The single matching Pod.

    Raises:
        AssertionError: If the query returns more than one Pod.
    """
    import asyncio  # local import: only needed for the poll back-off

    # The selector does not change between attempts, so build it once.
    selector = ",".join(
        [
            f"dask.org/cluster-name={self.name}",
            "dask.org/component=scheduler",
        ]
    )
    while True:
        pods = await self.api.get(
            Pod.endpoint,
            label_selector=selector,
            namespace=self.namespace,
        )
        if pods:
            break
        # Back off before retrying so a not-yet-created pod does not turn
        # into a tight loop of API requests.
        await asyncio.sleep(0.1)
    assert len(pods) == 1, f"expected one scheduler pod, found {len(pods)}"
    return pods[0]
39+
40+
async def scheduler_deployment(self) -> Deployment:
    """Wait for and return this cluster's scheduler Deployment.

    Polls the API for Deployments in this cluster's namespace labelled with
    ``dask.org/cluster-name=<self.name>`` and
    ``dask.org/component=scheduler`` until the query returns a non-empty
    result, pausing briefly between attempts.

    Returns:
        The single matching Deployment.

    Raises:
        AssertionError: If the query returns more than one Deployment.
    """
    import asyncio  # local import: only needed for the poll back-off

    # The selector does not change between attempts, so build it once.
    selector = ",".join(
        [
            f"dask.org/cluster-name={self.name}",
            "dask.org/component=scheduler",
        ]
    )
    while True:
        deployments = await self.api.get(
            Deployment.endpoint,
            label_selector=selector,
            namespace=self.namespace,
        )
        if deployments:
            break
        # Back off before retrying so a not-yet-created deployment does not
        # turn into a tight loop of API requests.
        await asyncio.sleep(0.1)
    assert (
        len(deployments) == 1
    ), f"expected one scheduler deployment, found {len(deployments)}"
    return deployments[0]
55+
56+
async def scheduler_service(self) -> Service:
    """Wait for and return this cluster's scheduler Service.

    Polls the API for Services in this cluster's namespace labelled with
    ``dask.org/cluster-name=<self.name>`` and
    ``dask.org/component=scheduler`` until the query returns a non-empty
    result, pausing briefly between attempts.

    Returns:
        The single matching Service.

    Raises:
        AssertionError: If the query returns more than one Service.
    """
    import asyncio  # local import: only needed for the poll back-off

    # The selector does not change between attempts, so build it once.
    selector = ",".join(
        [
            f"dask.org/cluster-name={self.name}",
            "dask.org/component=scheduler",
        ]
    )
    while True:
        services = await self.api.get(
            Service.endpoint,
            label_selector=selector,
            namespace=self.namespace,
        )
        if services:
            break
        # Back off before retrying so a not-yet-created service does not
        # turn into a tight loop of API requests.
        await asyncio.sleep(0.1)
    assert (
        len(services) == 1
    ), f"expected one scheduler service, found {len(services)}"
    return services[0]
71+
1472

1573
class DaskWorkerGroup(APIObject):
1674
version = "kubernetes.dask.org/v1"
@@ -21,6 +79,35 @@ class DaskWorkerGroup(APIObject):
2179
namespaced = True
2280
scalable = True
2381

82+
async def pods(self) -> List[Pod]:
    """Return the Pods selected by this worker group's labels.

    The query is limited to this worker group's namespace and matches on
    the cluster name from ``spec["cluster"]``, the ``worker`` component and
    this worker group's own name.
    """
    labels = [
        f"dask.org/cluster-name={self.spec['cluster']}",
        "dask.org/component=worker",
        f"dask.org/workergroup-name={self.name}",
    ]
    matching = await self.api.get(
        Pod.endpoint,
        label_selector=",".join(labels),
        namespace=self.namespace,
    )
    return matching
94+
95+
async def deployments(self) -> List[Deployment]:
    """Return the Deployments selected by this worker group's labels.

    The query is limited to this worker group's namespace and matches on
    the cluster name from ``spec["cluster"]``, the ``worker`` component and
    this worker group's own name.
    """
    labels = [
        f"dask.org/cluster-name={self.spec['cluster']}",
        "dask.org/component=worker",
        f"dask.org/workergroup-name={self.name}",
    ]
    matching = await self.api.get(
        Deployment.endpoint,
        label_selector=",".join(labels),
        namespace=self.namespace,
    )
    return matching
107+
108+
async def cluster(self) -> DaskCluster:
    """Fetch the DaskCluster named by this worker group's ``spec["cluster"]``.

    The lookup is made in this worker group's own namespace.
    """
    cluster_name = self.spec["cluster"]
    return await DaskCluster.get(cluster_name, namespace=self.namespace)
110+
24111

25112
class DaskAutoscaler(APIObject):
26113
version = "kubernetes.dask.org/v1"
@@ -30,6 +117,9 @@ class DaskAutoscaler(APIObject):
30117
singular = "daskautoscaler"
31118
namespaced = True
32119

120+
async def cluster(self) -> DaskCluster:
    """Fetch the DaskCluster named by this autoscaler's ``spec["cluster"]``.

    The lookup is made in this autoscaler's own namespace.
    """
    cluster_name = self.spec["cluster"]
    return await DaskCluster.get(cluster_name, namespace=self.namespace)
122+
33123

34124
class DaskJob(APIObject):
35125
version = "kubernetes.dask.org/v1"
@@ -38,3 +128,22 @@ class DaskJob(APIObject):
38128
plural = "daskjobs"
39129
singular = "daskjob"
40130
namespaced = True
131+
132+
async def cluster(self) -> DaskCluster:
    """Fetch the DaskCluster that has the same name and namespace as this job."""
    job_namespace = self.namespace
    return await DaskCluster.get(self.name, namespace=job_namespace)
134+
135+
async def pod(self) -> Pod:
    """Wait for and return this job's runner Pod.

    Polls the API for Pods in this job's namespace labelled with
    ``dask.org/cluster-name=<self.name>`` and
    ``dask.org/component=job-runner`` until the query returns a non-empty
    result, pausing briefly between attempts.

    Returns:
        The single matching Pod.

    Raises:
        AssertionError: If the query returns more than one Pod.
    """
    import asyncio  # local import: only needed for the poll back-off

    # The selector does not change between attempts, so build it once.
    selector = ",".join(
        [
            f"dask.org/cluster-name={self.name}",
            "dask.org/component=job-runner",
        ]
    )
    while True:
        pods = await self.api.get(
            Pod.endpoint,
            label_selector=selector,
            namespace=self.namespace,
        )
        if pods:
            break
        # Back off before retrying so a not-yet-created pod does not turn
        # into a tight loop of API requests.
        await asyncio.sleep(0.1)
    assert len(pods) == 1, f"expected one job-runner pod, found {len(pods)}"
    return pods[0]

0 commit comments

Comments
 (0)