Skip to content

Commit 0c3ede7

Browse files
Create worker pods through Deployments (#730)
* Create worker pods through Deployments * Add a test --------- Co-authored-by: Jacob Tomlinson <jtomlinson@nvidia.com>
1 parent 62da268 commit 0c3ede7

File tree

2 files changed

+75
-21
lines changed

2 files changed

+75
-21
lines changed

dask_kubernetes/operator/controller/controller.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
import copy
32
from collections import defaultdict
43
import time
54
from contextlib import suppress
@@ -111,10 +110,9 @@ def build_scheduler_service_spec(cluster_name, spec, annotations, labels):
111110
}
112111

113112

114-
def build_worker_pod_spec(
115-
worker_group_name, namespace, cluster_name, uuid, spec, annotations, labels
113+
def build_worker_deployment_spec(
114+
worker_group_name, namespace, cluster_name, uuid, pod_spec, annotations, labels
116115
):
117-
spec = copy.deepcopy(spec)
118116
labels.update(
119117
**{
120118
"dask.org/cluster-name": cluster_name,
@@ -124,14 +122,24 @@ def build_worker_pod_spec(
124122
}
125123
)
126124
worker_name = f"{worker_group_name}-worker-{uuid}"
127-
pod_spec = {
128-
"apiVersion": "v1",
129-
"kind": "Pod",
130-
"metadata": {
131-
"name": worker_name,
132-
"labels": labels,
133-
"annotations": annotations,
134-
},
125+
metadata = {
126+
"name": worker_name,
127+
"labels": labels,
128+
"annotations": annotations,
129+
}
130+
spec = {}
131+
spec["replicas"] = 1 # make_worker_spec returns dict with a replicas key?
132+
spec["selector"] = {
133+
"matchLabels": labels,
134+
}
135+
spec["template"] = {
136+
"metadata": metadata,
137+
"spec": pod_spec,
138+
}
139+
deployment_spec = {
140+
"apiVersion": "apps/v1",
141+
"kind": "Deployment",
142+
"metadata": metadata,
135143
"spec": spec,
136144
}
137145
env = [
@@ -144,12 +152,14 @@ def build_worker_pod_spec(
144152
"value": f"tcp://{cluster_name}-scheduler.{namespace}.svc.cluster.local:8786",
145153
},
146154
]
147-
for i in range(len(pod_spec["spec"]["containers"])):
148-
if "env" in pod_spec["spec"]["containers"][i]:
149-
pod_spec["spec"]["containers"][i]["env"].extend(env)
155+
for i in range(len(deployment_spec["spec"]["template"]["spec"]["containers"])):
156+
if "env" in deployment_spec["spec"]["template"]["spec"]["containers"][i]:
157+
deployment_spec["spec"]["template"]["spec"]["containers"][i]["env"].extend(
158+
env
159+
)
150160
else:
151-
pod_spec["spec"]["containers"][i]["env"] = env
152-
return pod_spec
161+
deployment_spec["spec"]["template"]["spec"]["containers"][i]["env"] = env
162+
return deployment_spec
153163

154164

155165
def get_job_runner_pod_name(job_name):
@@ -632,18 +642,20 @@ async def daskworkergroup_replica_update(
632642
labels.update(**worker_spec["metadata"]["labels"])
633643
if workers_needed > 0:
634644
for _ in range(workers_needed):
635-
data = build_worker_pod_spec(
645+
data = build_worker_deployment_spec(
636646
worker_group_name=name,
637647
namespace=namespace,
638648
cluster_name=cluster_name,
639649
uuid=uuid4().hex[:10],
640-
spec=worker_spec["spec"],
650+
pod_spec=worker_spec["spec"],
641651
annotations=annotations,
642652
labels=labels,
643653
)
644654
kopf.adopt(data, owner=body)
645655
kopf.label(data, labels=cluster_labels)
646-
await corev1api.create_namespaced_pod(
656+
await kubernetes.client.AppsV1Api(
657+
api_client
658+
).create_namespaced_deployment(
647659
namespace=namespace,
648660
body=data,
649661
)
@@ -660,7 +672,9 @@ async def daskworkergroup_replica_update(
660672
)
661673
logger.info(f"Workers to close: {worker_ids}")
662674
for wid in worker_ids:
663-
await corev1api.delete_namespaced_pod(
675+
await kubernetes.client.AppsV1Api(
676+
api_client
677+
).delete_namespaced_deployment(
664678
name=wid,
665679
namespace=namespace,
666680
)

dask_kubernetes/operator/controller/tests/test_controller.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,46 @@ async def test_recreate_scheduler_pod(k8s_cluster, kopf_runner, gen_cluster):
374374
)
375375

376376

377+
@pytest.mark.asyncio
378+
async def test_recreate_worker_pods(k8s_cluster, kopf_runner, gen_cluster):
379+
with kopf_runner as runner:
380+
async with gen_cluster() as (cluster_name, ns):
381+
scheduler_deployment_name = "simple-scheduler"
382+
worker_deployment_name = "simple-default-worker"
383+
service_name = "simple-scheduler"
384+
while scheduler_deployment_name not in k8s_cluster.kubectl(
385+
"get", "pods", "-n", ns
386+
):
387+
await asyncio.sleep(0.1)
388+
while service_name not in k8s_cluster.kubectl("get", "svc", "-n", ns):
389+
await asyncio.sleep(0.1)
390+
while worker_deployment_name not in k8s_cluster.kubectl(
391+
"get", "pods", "-n", ns
392+
):
393+
await asyncio.sleep(0.1)
394+
k8s_cluster.kubectl(
395+
"delete",
396+
"pods",
397+
"-l",
398+
"dask.org/cluster-name=simple,dask.org/component=worker",
399+
"-n",
400+
ns,
401+
)
402+
k8s_cluster.kubectl(
403+
"wait",
404+
"--for=condition=Ready",
405+
"-l",
406+
"dask.org/cluster-name=simple,dask.org/component=worker",
407+
"pod",
408+
"-n",
409+
ns,
410+
"--timeout=60s",
411+
)
412+
assert worker_deployment_name in k8s_cluster.kubectl(
413+
"get", "pods", "-n", ns
414+
)
415+
416+
377417
def _get_job_status(k8s_cluster, ns):
378418
return json.loads(
379419
k8s_cluster.kubectl(

0 commit comments

Comments (0)