Commit e858bfe

[Cleanup] Refactor profiling env vars into a CLI config (#29912)

Authored by benchislett, gemini-code-assist[bot], and hmellor.

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

1 parent d471b2a

File tree: 22 files changed, +437 −256 lines

benchmarks/auto_tune/auto_tune.sh

Lines changed: 3 additions & 2 deletions
@@ -96,8 +96,9 @@ start_server() {
     # This correctly passes each element as a separate argument.
     if [[ -n "$profile_dir" ]]; then
         # Start server with profiling enabled
-        VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir \
-            vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 &
+        local profile_config_json="{\"profiler\": \"torch\", \"torch_profiler_dir\": \"$profile_dir\"}"
+        VLLM_SERVER_DEV_MODE=1 \
+            vllm serve --profiler-config "$profile_config_json" "${common_args_array[@]}" > "$vllm_log" 2>&1 &
     else
         # Start server without profiling
         VLLM_SERVER_DEV_MODE=1 \
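
For reference, the JSON string the script now interpolates by hand can also be produced with `json.dumps`, which handles the quoting the shell version escapes manually. A minimal sketch, assuming a hypothetical `profile_dir` value:

```python
import json

# Hypothetical profile directory; the script takes this from its own
# $profile_dir variable.
profile_dir = "/mnt/traces"

# Produces the same string as the hand-escaped version in the diff above.
profile_config_json = json.dumps(
    {"profiler": "torch", "torch_profiler_dir": profile_dir}
)
print(profile_config_json)
# {"profiler": "torch", "torch_profiler_dir": "/mnt/traces"}
```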

benchmarks/benchmark_serving_structured_output.py

Lines changed: 1 addition & 2 deletions
@@ -963,8 +963,7 @@ def create_argument_parser():
     parser.add_argument(
         "--profile",
         action="store_true",
-        help="Use Torch Profiler. The endpoint must be launched with "
-        "VLLM_TORCH_PROFILER_DIR to enable profiler.",
+        help="Use vLLM Profiling. --profiler-config must be provided on the server.",
     )
     parser.add_argument(
         "--result-dir",

docs/api/README.md

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ API documentation for vLLM's configuration classes.
 - [vllm.config.MultiModalConfig][]
 - [vllm.config.PoolerConfig][]
 - [vllm.config.StructuredOutputsConfig][]
+- [vllm.config.ProfilerConfig][]
 - [vllm.config.ObservabilityConfig][]
 - [vllm.config.KVTransferConfig][]
 - [vllm.config.CompilationConfig][]

docs/contributing/profiling.md

Lines changed: 10 additions & 13 deletions
@@ -5,16 +5,15 @@

 ## Profile with PyTorch Profiler

-We support tracing vLLM workers using the `torch.profiler` module. You can enable tracing by setting the `VLLM_TORCH_PROFILER_DIR` environment variable to the directory where you want to save the traces: `VLLM_TORCH_PROFILER_DIR=/mnt/traces/`. Additionally, you can control the profiling content by specifying the following environment variables:
+We support tracing vLLM workers using the `torch.profiler` module. You can enable the torch profiler by passing `--profiler-config`
+when launching the server, setting its `profiler` entry to `'torch'` and `torch_profiler_dir` to the directory where you want to save the traces. Additionally, you can control the profiling content with the following additional arguments in the config:

-- `VLLM_TORCH_PROFILER_RECORD_SHAPES=1` to enable recording Tensor Shapes, off by default
-- `VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1` to record memory, off by default
-- `VLLM_TORCH_PROFILER_WITH_STACK=1` to enable recording stack information, on by default
-- `VLLM_TORCH_PROFILER_WITH_FLOPS=1` to enable recording FLOPs, off by default
-- `VLLM_TORCH_PROFILER_USE_GZIP=0` to disable gzip-compressing profiling files, on by default
-- `VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL=0` to disable dumping and printing the aggregated CUDA self time table, on by default
-
-The OpenAI server also needs to be started with the `VLLM_TORCH_PROFILER_DIR` environment variable set.
+- `torch_profiler_record_shapes` to enable recording Tensor Shapes, off by default
+- `torch_profiler_with_memory` to record memory, off by default
+- `torch_profiler_with_stack` to enable recording stack information, on by default
+- `torch_profiler_with_flops` to enable recording FLOPs, off by default
+- `torch_profiler_use_gzip` to control gzip-compressing profiling files, on by default
+- `torch_profiler_dump_cuda_time_total` to control dumping and printing the aggregated CUDA self time table, on by default

 When using `vllm bench serve`, you can enable profiling by passing the `--profile` flag.

@@ -40,8 +39,7 @@ Refer to [examples/offline_inference/simple_profiling.py](../../examples/offline
 #### OpenAI Server

 ```bash
-VLLM_TORCH_PROFILER_DIR=./vllm_profile \
-vllm serve meta-llama/Llama-3.1-8B-Instruct
+vllm serve meta-llama/Llama-3.1-8B-Instruct --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile"}'
 ```

 vllm bench command:

@@ -104,13 +102,12 @@ To profile the server, you will want to prepend your `vllm serve` command with `

 ```bash
 # server
-VLLM_TORCH_CUDA_PROFILE=1 \
 nsys profile \
 --trace-fork-before-exec=true \
 --cuda-graph-trace=node \
 --capture-range=cudaProfilerApi \
 --capture-range-end repeat \
-vllm serve meta-llama/Llama-3.1-8B-Instruct
+vllm serve meta-llama/Llama-3.1-8B-Instruct --profiler-config.profiler cuda

 # client
 vllm bench serve \
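
The same config is accepted by the offline `LLM` entry point, as the updated example in the next file shows for the two required keys. A hedged sketch combining the optional fields from the list above (the key names mirror the documented config entries; the values shown are the non-default choices a user might make):

```python
from vllm import LLM

# Sketch only: assumes the documented ProfilerConfig keys are accepted as a
# plain dict, as in examples/offline_inference/simple_profiling.py.
llm = LLM(
    model="facebook/opt-125m",
    profiler_config={
        "profiler": "torch",
        "torch_profiler_dir": "./vllm_profile",
        "torch_profiler_record_shapes": True,  # off by default
        "torch_profiler_with_memory": True,    # off by default
        "torch_profiler_use_gzip": False,      # on by default
    },
)
```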

examples/offline_inference/simple_profiling.py

Lines changed: 8 additions & 5 deletions
@@ -1,14 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import os
 import time

 from vllm import LLM, SamplingParams

-# enable torch profiler, can also be set on cmd line
-os.environ["VLLM_TORCH_PROFILER_DIR"] = "./vllm_profile"
-
 # Sample prompts.
 prompts = [
     "Hello, my name is",

@@ -22,7 +18,14 @@

 def main():
     # Create an LLM.
-    llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)
+    llm = LLM(
+        model="facebook/opt-125m",
+        tensor_parallel_size=1,
+        profiler_config={
+            "profiler": "torch",
+            "torch_profiler_dir": "./vllm_profile",
+        },
+    )

     llm.start_profile()

tests/v1/worker/test_gpu_profiler.py

Lines changed: 36 additions & 35 deletions
@@ -2,20 +2,20 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import pytest

-import vllm.envs as envs
-from vllm.profiler.gpu_profiler import WorkerProfiler
+from vllm.config import ProfilerConfig
+from vllm.profiler.wrapper import WorkerProfiler


 class ConcreteWorkerProfiler(WorkerProfiler):
     """
     A basic implementation of a worker profiler for testing purposes.
     """

-    def __init__(self):
+    def __init__(self, profiler_config: ProfilerConfig):
         self.start_call_count = 0
         self.stop_call_count = 0
         self.should_fail_start = False
-        super().__init__()
+        super().__init__(profiler_config)

     def _start(self) -> None:
         if self.should_fail_start:

@@ -26,17 +26,19 @@ def _stop(self) -> None:
         self.stop_call_count += 1


-@pytest.fixture(autouse=True)
-def reset_mocks():
-    """Fixture to reset mocks and env variables before each test."""
-    envs.VLLM_PROFILER_DELAY_ITERS = 0
-    envs.VLLM_PROFILER_MAX_ITERS = 0
+@pytest.fixture
+def default_profiler_config():
+    return ProfilerConfig(
+        profiler="torch",
+        torch_profiler_dir="/tmp/mock",
+        delay_iterations=0,
+        max_iterations=0,
+    )


-def test_immediate_start_stop():
+def test_immediate_start_stop(default_profiler_config):
     """Test standard start without delay."""
-    profiler = ConcreteWorkerProfiler()
-
+    profiler = ConcreteWorkerProfiler(default_profiler_config)
     profiler.start()
     assert profiler._running is True
     assert profiler._active is True

@@ -48,10 +50,10 @@ def test_immediate_start_stop():
     assert profiler.stop_call_count == 1


-def test_delayed_start():
+def test_delayed_start(default_profiler_config):
     """Test that profiler waits for N steps before actually starting."""
-    envs.VLLM_PROFILER_DELAY_ITERS = 2
-    profiler = ConcreteWorkerProfiler()
+    default_profiler_config.delay_iterations = 2
+    profiler = ConcreteWorkerProfiler(default_profiler_config)

     # User requests start
     profiler.start()

@@ -71,10 +73,10 @@ def test_delayed_start():
     assert profiler.start_call_count == 1


-def test_max_iterations():
+def test_max_iterations(default_profiler_config):
     """Test that profiler stops automatically after max iterations."""
-    envs.VLLM_PROFILER_MAX_ITERS = 2
-    profiler = ConcreteWorkerProfiler()
+    default_profiler_config.max_iterations = 2
+    profiler = ConcreteWorkerProfiler(default_profiler_config)

     profiler.start()
     assert profiler._running is True

@@ -95,12 +97,11 @@ def test_max_iterations():
     assert profiler.stop_call_count == 1


-def test_delayed_start_and_max_iters():
+def test_delayed_start_and_max_iters(default_profiler_config):
     """Test combined delayed start and max iterations."""
-    envs.VLLM_PROFILER_DELAY_ITERS = 2
-    envs.VLLM_PROFILER_MAX_ITERS = 2
-    profiler = ConcreteWorkerProfiler()
-
+    default_profiler_config.delay_iterations = 2
+    default_profiler_config.max_iterations = 2
+    profiler = ConcreteWorkerProfiler(default_profiler_config)
     profiler.start()

     # Step 1

@@ -127,9 +128,9 @@ def test_delayed_start_and_max_iters():
     assert profiler.stop_call_count == 1


-def test_idempotency():
+def test_idempotency(default_profiler_config):
     """Test that calling start/stop multiple times doesn't break logic."""
-    profiler = ConcreteWorkerProfiler()
+    profiler = ConcreteWorkerProfiler(default_profiler_config)

     # Double Start
     profiler.start()

@@ -142,10 +143,10 @@ def test_idempotency():
     assert profiler.stop_call_count == 1  # Should only stop once


-def test_step_inactive():
+def test_step_inactive(default_profiler_config):
     """Test that stepping while inactive does nothing."""
-    envs.VLLM_PROFILER_DELAY_ITERS = 2
-    profiler = ConcreteWorkerProfiler()
+    default_profiler_config.delay_iterations = 2
+    profiler = ConcreteWorkerProfiler(default_profiler_config)

     # Not started yet
     profiler.step()

@@ -155,9 +156,9 @@ def test_step_inactive():
     assert profiler.start_call_count == 0


-def test_start_failure():
+def test_start_failure(default_profiler_config):
     """Test behavior when the underlying _start method raises exception."""
-    profiler = ConcreteWorkerProfiler()
+    profiler = ConcreteWorkerProfiler(default_profiler_config)
     profiler.should_fail_start = True

     profiler.start()

@@ -168,9 +169,9 @@ def test_start_failure():
     assert profiler.start_call_count == 0  # Logic failed inside start


-def test_shutdown():
+def test_shutdown(default_profiler_config):
     """Test that shutdown calls stop only if running."""
-    profiler = ConcreteWorkerProfiler()
+    profiler = ConcreteWorkerProfiler(default_profiler_config)

     # Case 1: Not running
     profiler.shutdown()

@@ -182,10 +183,10 @@ def test_shutdown():
     assert profiler.stop_call_count == 1


-def test_mixed_delay_and_stop():
+def test_mixed_delay_and_stop(default_profiler_config):
     """Test manual stop during the delay period."""
-    envs.VLLM_PROFILER_DELAY_ITERS = 5
-    profiler = ConcreteWorkerProfiler()
+    default_profiler_config.delay_iterations = 5
+    profiler = ConcreteWorkerProfiler(default_profiler_config)

     profiler.start()
     profiler.step()
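
For orientation, the contract these tests pin down can be summarized in a standalone sketch. This is an illustration of the delay/max-iteration behavior only, not the actual `vllm.profiler.wrapper.WorkerProfiler` code:

```python
class SketchProfiler:
    """Illustrative stand-in: delay_iterations postpones the real start,
    and max_iterations bounds how long profiling stays active."""

    def __init__(self, delay_iterations: int = 0, max_iterations: int = 0):
        self._delay = delay_iterations
        self._max = max_iterations
        self._running = False  # user has requested profiling
        self._active = False   # underlying profiler actually started
        self._iters = 0

    def start(self) -> None:
        if self._running:
            return  # idempotent, as test_idempotency expects
        self._running = True
        self._iters = 0
        if self._delay == 0:
            self._active = True  # corresponds to _start() in the real class

    def step(self) -> None:
        if not self._running:
            return  # stepping while inactive does nothing
        self._iters += 1
        if not self._active:
            if self._iters >= self._delay:
                self._active = True  # delayed start kicks in
            return
        if self._max and self._iters >= self._delay + self._max:
            self.stop()  # automatic stop after max iterations

    def stop(self) -> None:
        if not self._running:
            return
        self._running = False
        self._active = False  # corresponds to _stop()
```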

vllm/benchmarks/latency.py

Lines changed: 5 additions & 7 deletions
@@ -12,7 +12,6 @@
 import numpy as np
 from tqdm import tqdm

-import vllm.envs as envs
 from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
 from vllm.engine.arg_utils import EngineArgs
 from vllm.inputs import PromptType

@@ -79,12 +78,11 @@ def add_cli_args(parser: argparse.ArgumentParser):


 def main(args: argparse.Namespace):
-    if args.profile and not envs.VLLM_TORCH_PROFILER_DIR:
-        raise OSError(
-            "The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. "
-            "Please set it to a valid path to use torch profiler."
-        )
     engine_args = EngineArgs.from_cli_args(args)
+    if args.profile and not engine_args.profiler_config.profiler == "torch":
+        raise ValueError(
+            "The torch profiler is not enabled. Please provide profiler_config."
+        )

     # Lazy import to avoid importing LLM when the bench command is not selected.
     from vllm import LLM, SamplingParams

@@ -144,7 +142,7 @@ def run_to_completion(profile_dir: str | None = None):
     run_to_completion(profile_dir=None)

     if args.profile:
-        profile_dir = envs.VLLM_TORCH_PROFILER_DIR
+        profile_dir = engine_args.profiler_config.torch_profiler_dir
         print(f"Profiling (results will be saved to '{profile_dir}')...")
         run_to_completion(profile_dir=profile_dir)
         return

vllm/benchmarks/serve.py

Lines changed: 1 addition & 2 deletions
@@ -1097,8 +1097,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--profile",
         action="store_true",
-        help="Use Torch Profiler. The endpoint must be launched with "
-        "VLLM_TORCH_PROFILER_DIR to enable profiler.",
+        help="Use vLLM Profiling. --profiler-config must be provided on the server.",
     )
     parser.add_argument(
         "--save-result",

vllm/benchmarks/throughput.py

Lines changed: 1 addition & 2 deletions
@@ -655,8 +655,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
         "--profile",
         action="store_true",
         default=False,
-        help="Use Torch Profiler. The env variable "
-        "VLLM_TORCH_PROFILER_DIR must be set to enable profiler.",
+        help="Use vLLM Profiling. --profiler-config must be provided on the server.",
     )

     # prefix repetition dataset

vllm/config/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,7 @@
 from vllm.config.observability import ObservabilityConfig
 from vllm.config.parallel import EPLBConfig, ParallelConfig
 from vllm.config.pooler import PoolerConfig
+from vllm.config.profiler import ProfilerConfig
 from vllm.config.scheduler import SchedulerConfig
 from vllm.config.speculative import SpeculativeConfig
 from vllm.config.speech_to_text import SpeechToTextConfig

@@ -89,6 +90,8 @@
     "SpeechToTextConfig",
     # From vllm.config.structured_outputs
     "StructuredOutputsConfig",
+    # From vllm.config.profiler
+    "ProfilerConfig",
     # From vllm.config.utils
     "ConfigType",
     "SupportsMetricsInfo",
