Skip to content

Commit 1f4162a

Browse files
authored
[RAPTOR-14353] add gunicorn with gevent to DRUM (#1655)
* added * working * lint * add server selector * run drum server * corrected * import flask * rollback * == "gunicorn" * reformat * workers = 1 for gunicorn * add logger * fix error * fix error * fix sys.argv
1 parent 16f821c commit 1f4162a

File tree

10 files changed

+392
-8
lines changed

10 files changed

+392
-8
lines changed

custom_model_runner/bin/drum

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
#!/usr/bin/env python3
22
# PYTHON_ARGCOMPLETE_OK
3-
4-
import os
5-
from datarobot_drum.drum.main import main
3+
from datarobot_drum.drum.entry_point import run_drum_server
64

75
if __name__ == "__main__":
8-
main()
6+
run_drum_server()
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from datarobot_drum.drum.gunicorn.run_gunicorn import main_gunicorn
2+
from datarobot_drum.drum.main import main
3+
from datarobot_drum import RuntimeParameters
4+
5+
6+
def run_drum_server():
    """Start DRUM using the server implementation selected at runtime.

    If the ``DRUM_SERVER_TYPE`` runtime parameter is defined and its string
    value equals ``"gunicorn"`` (compared case-insensitively), control is
    handed to ``main_gunicorn``. In every other case the regular ``main``
    entry point runs.
    """
    server_type = None
    if RuntimeParameters.has("DRUM_SERVER_TYPE"):
        server_type = str(RuntimeParameters.get("DRUM_SERVER_TYPE")).lower()

    entry_point = main_gunicorn if server_type == "gunicorn" else main
    entry_point()
14+
15+
16+
if __name__ == "__main__":
17+
run_drum_server()

custom_model_runner/datarobot_drum/drum/gunicorn/__init__.py

Whitespace-only changes.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from datarobot_drum.drum.server import create_flask_app
2+
3+
app = create_flask_app()
4+
5+
# Worker context shared through this module; stays None until
# set_worker_ctx() has been called.
_worker_ctx = None


def set_worker_ctx(ctx):
    """Store *ctx* as this module's current worker context."""
    global _worker_ctx
    _worker_ctx = ctx


def get_worker_ctx():
    """Return the context stored by ``set_worker_ctx``, or ``None`` if unset."""
    return _worker_ctx
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
import logging
2+
import threading
3+
from typing import Callable, Any, List, Tuple, Optional
4+
5+
from datarobot_drum.drum.enum import LOGGER_NAME_PREFIX
6+
7+
logger = logging.getLogger(LOGGER_NAME_PREFIX + "." + __name__)
8+
9+
10+
class WorkerCtx:
    """
    Context of a single gunicorn worker:
    - .start()    — start background tasks
    - .stop()     — stop background tasks (graceful)
    - .cleanup()  — close resources (DB/clients) in the correct order
    - add_* methods for registering objects/callbacks for stopping/cleaning up
    """

    def __init__(self, app):
        self.app = app
        self._running = False
        self._threads: List[threading.Thread] = []
        self._greenlets: List[Any] = []
        # (order, callback, description) triples consumed by stop() / cleanup().
        self._on_stop: List[Tuple[int, Callable[[], None], str]] = []
        self._on_cleanup: List[Tuple[int, Callable[[], None], str]] = []

    def add_thread(
        self, t: threading.Thread, *, join_timeout: float = 2.0, name: Optional[str] = None
    ):
        """
        Adds a thread to the worker context and registers it for graceful stopping.

        The thread is switched to daemon mode, so it must not have been started yet.

        Args:
            t (threading.Thread): The thread to be added.
            join_timeout (float, optional): The timeout for joining the thread during stop. Defaults to 2.0 seconds.
            name (Optional[str], optional): A descriptive name for the thread. Defaults to None.
        """
        t.daemon = True
        self._threads.append(t)

        def _join():
            if t.is_alive():
                t.join(join_timeout)

        self.defer_stop(_join, desc=name or f"thread:{id(t)}")

    def add_greenlet(self, g, *, kill_timeout: float = 2.0, name: Optional[str] = None):
        """
        Adds a greenlet to the worker context and registers it for graceful stopping.

        Args:
            g: The greenlet to be added.
            kill_timeout (float, optional): The timeout for killing the greenlet during stop. Defaults to 2.0 seconds.
            name (Optional[str], optional): A descriptive name for the greenlet. Defaults to None.
        """
        self._greenlets.append(g)

        def _kill():
            # Any failure propagates to stop(), which logs it together with
            # the callback description and continues with the next callback.
            g.kill(block=True, timeout=kill_timeout)

        self.defer_stop(_kill, desc=name or f"greenlet:{id(g)}")

    def add_closeable(
        self, obj: Any, method: str = "close", *, order: int = 0, desc: Optional[str] = None
    ):
        """
        Registers an object to be closed using a specified method (e.g., close/quit/disconnect/shutdown) during cleanup (after stop).

        Nothing is registered when ``obj`` has no callable attribute named ``method``.

        Args:
            obj (Any): The object to be closed.
            method (str, optional): The method to call on the object for closing. Defaults to "close".
            order (int, optional): The priority order for cleanup. Higher values are executed first. Defaults to 0.
            desc (Optional[str], optional): A description for the cleanup action. Defaults to None.
        """
        fn = getattr(obj, method, None)
        if callable(fn):
            self.defer_cleanup(fn, order=order, desc=desc or f"{obj}.{method}")

    def defer_stop(self, fn: Callable[[], None], *, order: int = 0, desc: str = "on_stop"):
        """Registers ``fn`` to run in stop(); lower ``order`` values run first."""
        self._on_stop.append((order, fn, desc))

    def defer_cleanup(self, fn: Callable[[], None], *, order: int = 0, desc: str = "on_cleanup"):
        """Registers ``fn`` to run in cleanup(); higher ``order`` values run first."""
        self._on_cleanup.append((order, fn, desc))

    @staticmethod
    def _patch_otel_run_async():
        """
        Replaces ``run_async`` in the OpenTelemetry OpenAI instrumentation utils.

        Works around "RuntimeError: asyncio.run() cannot be called from a running
        event loop": when a loop is already running, the coroutine is scheduled on
        it instead of being passed to ``asyncio.run``.
        """
        import asyncio

        try:
            import opentelemetry.instrumentation.openai.utils as otel_utils
        except ImportError:
            # Nothing to patch when the instrumentation package is absent.
            logger.warning(
                "opentelemetry OpenAI instrumentation is not installed; skipping run_async patch"
            )
            return

        # The event loop keeps only weak references to tasks, so hold a strong
        # reference until each task is done.
        pending_tasks = set()

        def handle_task_result(task):
            pending_tasks.discard(task)
            try:
                task.result()  # retrieve exception if any
            except asyncio.CancelledError:
                logger.debug("OpenTelemetry async task was cancelled")
            except Exception:
                logger.exception("OpenTelemetry async task error")

        def fixed_run_async(method):
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = None

            if loop is not None:
                # Schedule the coroutine on the running loop and report errors
                # from the done callback.
                task = loop.create_task(method)
                pending_tasks.add(task)
                task.add_done_callback(handle_task_result)
            else:
                asyncio.run(method)

        # Apply monkey patch
        otel_utils.run_async = fixed_run_async

    def start(self):
        """
        Starts background tasks for the worker context.

        This method sets the running flag to True, patches the OpenTelemetry
        ``run_async`` helper, and then imports and calls the `main` function from
        `datarobot_drum.drum.main`, passing the application instance and the
        worker context as arguments.
        """
        self._running = True
        self._patch_otel_run_async()

        # Imported here rather than at module level, as in the original code.
        from datarobot_drum.drum.main import main

        main(self.app, self)

    def stop(self):
        """
        Gracefully stops the worker context.

        This method sets the running flag to False and executes all registered
        stop callbacks in ascending ``order`` (registration order for ties).
        Threads and greenlets added through add_thread()/add_greenlet() are joined
        or killed by the stop callbacks those methods registered.

        A failing callback is logged and does not prevent the remaining ones
        from running.
        """
        self._running = False
        for _, fn, desc in sorted(self._on_stop, key=lambda item: item[0]):
            try:
                fn()
            except Exception:
                logger.exception("WorkerCtx stop callback failed: %s", desc)

    def cleanup(self):
        """
        Cleans up resources in the worker context.

        This method is called at the end of the worker's lifecycle to release resources.
        It executes all registered cleanup callbacks in descending ``order``
        (registration order for ties).

        Exceptions during cleanup are logged and do not interrupt the remaining callbacks.
        """
        for _, fn, desc in sorted(self._on_cleanup, key=lambda item: item[0], reverse=True):
            logger.info("WorkerCtx cleanup: %s", desc)
            try:
                fn()
            except Exception:
                logger.exception("WorkerCtx cleanup failed: %s", desc)

    def running(self) -> bool:
        """Returns True between start() and stop()."""
        return self._running
164+
165+
166+
def create_ctx(app):
    """
    Build the WorkerCtx for ``app``.

    Only the context object is constructed here; nothing long-running is
    started. Callers register their stop/cleanup objects and callbacks on the
    returned context, and the actual startup happens later through
    ``ctx.start()`` (invoked from the gunicorn ``post_worker_init`` hook).

    Args:
        app: The application instance.

    Returns:
        WorkerCtx: A new, not yet started worker context.
    """
    return WorkerCtx(app)
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Import DRUM's WSGI application
2+
import os
3+
from datarobot_drum import RuntimeParameters
4+
5+
def _bounded_int(raw, default, low, high):
    """Return ``int(raw)`` if it parses and lies within ``[low, high]``, else ``default``."""
    try:
        value = int(raw)
    except (TypeError, ValueError):
        # A malformed value is ignored in the same way as an out-of-range one,
        # instead of aborting the config load.
        return default
    return value if low <= value <= high else default


def _runtime_int(name, default, low, high):
    """Read runtime parameter ``name`` as a bounded int; ``default`` if absent or invalid."""
    if not RuntimeParameters.has(name):
        return default
    return _bounded_int(RuntimeParameters.get(name), default, low, high)


def _runtime_choice(name, default, allowed):
    """Read runtime parameter ``name`` lower-cased; ``default`` if absent or not in ``allowed``."""
    if not RuntimeParameters.has(name):
        return default
    value = str(RuntimeParameters.get(name)).lower()
    return value if value in allowed else default


# Worker count: the CUSTOM_MODEL_WORKERS runtime parameter takes precedence over
# the MAX_WORKERS environment variable. Accepted values are 1..199; anything
# else leaves the single-worker default.
if RuntimeParameters.has("CUSTOM_MODEL_WORKERS"):
    workers = _bounded_int(RuntimeParameters.get("CUSTOM_MODEL_WORKERS"), 1, 1, 199)
elif os.environ.get("MAX_WORKERS"):
    workers = _bounded_int(os.environ.get("MAX_WORKERS"), 1, 1, 199)
else:
    workers = 1

backlog = _runtime_int("DRUM_WEBSERVER_BACKLOG", 2048, 1, 10000)

timeout = _runtime_int("DRUM_CLIENT_REQUEST_TIMEOUT", 120, 0, 3600)

max_requests = _runtime_int("DRUM_GUNICORN_MAX_REQUESTS", 0, 0, 10000)

max_requests_jitter = _runtime_int("DRUM_GUNICORN_MAX_REQUESTS_JITTER", 0, 1, 10000)

worker_connections = _runtime_int("DRUM_WORKER_CONNECTIONS", 5, 1, 10000)

worker_class = _runtime_choice("DRUM_GUNICORN_WORKER_CLASS", "gevent", {"sync", "gevent"})

# graceful_timeout and keepalive are only defined when a valid value is
# supplied, so gunicorn's own defaults apply otherwise.
_graceful_timeout = _runtime_int("DRUM_GUNICORN_GRACEFUL_TIMEOUT", None, 1, 3600)
if _graceful_timeout is not None:
    graceful_timeout = _graceful_timeout

_keepalive = _runtime_int("DRUM_GUNICORN_KEEP_ALIVE", None, 1, 3600)
if _keepalive is not None:
    keepalive = _keepalive

loglevel = _runtime_choice(
    "DRUM_GUNICORN_LOG_LEVEL", "info", {"debug", "info", "warning", "error", "critical"}
)

bind = os.environ.get("ADDRESS", "0.0.0.0:8080")
accesslog = "-"
access_log_format = '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"'
73+
74+
75+
def post_worker_init(worker):
    """
    Gunicorn server hook: boot DRUM inside a freshly initialized worker.

    Restores the DRUM command line from ``DRUM_GUNICORN_DRUM_ARGS`` into
    ``sys.argv``, pins ``MAX_WORKERS`` to ``"1"``, drops the
    ``MLOPS_RUNTIME_PARAM_CUSTOM_MODEL_WORKERS`` environment variable when the
    ``CUSTOM_MODEL_WORKERS`` runtime parameter is defined, then creates the
    worker context, publishes it via ``set_worker_ctx`` and starts it.

    Args:
        worker: The gunicorn worker instance (unused).
    """
    import shlex
    import sys

    from app import app, set_worker_ctx
    from datarobot_drum.drum.gunicorn.context import create_ctx

    # shlex.split(None) reads from stdin (deprecated) and raises on newer
    # Python versions, so only replace sys.argv when the variable is set.
    drum_args = os.environ.get("DRUM_GUNICORN_DRUM_ARGS")
    if drum_args is not None:
        sys.argv = shlex.split(drum_args)

    # NOTE(review): presumably this stops DRUM inside the worker from scaling
    # out again, since gunicorn already owns the worker count — confirm.
    os.environ["MAX_WORKERS"] = "1"
    if RuntimeParameters.has("CUSTOM_MODEL_WORKERS"):
        os.environ.pop("MLOPS_RUNTIME_PARAM_CUSTOM_MODEL_WORKERS", None)

    ctx = create_ctx(app)
    set_worker_ctx(ctx)
    ctx.start()
89+
90+
91+
def worker_exit(worker, code):
    """
    Gunicorn server hook: shut down and clean up a single worker.

    Gunicorn is launched from the command line (outside of DRUM), so the
    per-worker teardown has to be triggered explicitly, and only for the
    worker that is terminating or restarting. The steps meant to run per
    worker are:
    - predictor.terminate()
    - stats_collector.mark("end")
    - runtime.cm_runner.terminate()
    - runtime.trace_provider.shutdown()
    - runtime.metric_provider.shutdown()
    - runtime.log_provider.shutdown()
    - etc.

    This hook itself only calls ``stop()`` and then ``cleanup()`` on the
    stored worker context; ``cleanup()`` runs even if ``stop()`` raises.

    Args:
        worker: The Gunicorn worker instance being terminated.
        code: The exit code for the worker.
    """
    from app import get_worker_ctx

    ctx = get_worker_ctx()
    if not ctx:
        # No context was stored for this worker, so there is nothing to tear down.
        return

    try:
        ctx.stop()
    finally:
        ctx.cleanup()
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import logging
2+
import subprocess
3+
from pathlib import Path
4+
import sys
5+
import os
6+
import shlex
7+
8+
from datarobot_drum.drum.enum import LOGGER_NAME_PREFIX
9+
10+
logger = logging.getLogger(LOGGER_NAME_PREFIX + "." + __name__)
11+
12+
13+
def main_gunicorn():
    """
    Run DRUM under gunicorn as a child process and wait for it to exit.

    The gunicorn config (``gunicorn.conf.py``) and the WSGI module (``app.py``)
    are resolved relative to this file, and gunicorn is started with that
    directory as its working directory. The current command line is exported
    through the ``DRUM_GUNICORN_DRUM_ARGS`` environment variable.

    Raises:
        FileNotFoundError: If the gunicorn config file is missing, or the
            child process could not be launched.
        subprocess.CalledProcessError: If gunicorn exits with a non-zero status.
    """
    # Resolve directory containing this script so we can always find config and app.py
    base_dir = Path(__file__).resolve().parent
    config_path = base_dir / "gunicorn.conf.py"

    if not config_path.is_file():
        raise FileNotFoundError(f"Gunicorn config not found: {config_path}")

    # Export the complete command line, program name included, because the
    # worker assigns the split value straight back to sys.argv.
    if sys.argv:
        # Same result as shlex.join, but also available where shlex.join is not.
        os.environ["DRUM_GUNICORN_DRUM_ARGS"] = " ".join(shlex.quote(arg) for arg in sys.argv)

    # Use the gunicorn module explicitly to avoid issues where a shadowed
    # console script named "gunicorn" actually invokes the DRUM CLI.
    gunicorn_command = [
        sys.executable,
        "-m",
        "gunicorn",
        "-c",
        str(config_path),
        "app:app",  # module:variable; app.py sits next to this script
    ]

    try:
        subprocess.run(gunicorn_command, cwd=base_dir, check=True)
    except FileNotFoundError as e:
        # Raised when the interpreter or working directory cannot be found.
        # A missing gunicorn module surfaces as CalledProcessError instead.
        logger.error("Failed to launch the gunicorn process: %s", e)
        raise
    except subprocess.CalledProcessError as e:
        logger.error("Gunicorn exited with non-zero status %s", e.returncode)
        raise
48+
49+
50+
if __name__ == "__main__":
51+
main_gunicorn()

0 commit comments

Comments
 (0)