|
5 | 5 | import pytest |
6 | 6 | import torch |
7 | 7 |
|
8 | | -from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass |
| 8 | +from vllm.compilation.inductor_pass import ( |
| 9 | + CallableInductorPass, |
| 10 | + InductorPass, |
| 11 | + pass_context, |
| 12 | +) |
9 | 13 | from vllm.compilation.pass_manager import PostGradPassManager |
10 | 14 | from vllm.config import ModelConfig, VllmConfig |
| 15 | +from vllm.config.utils import Range |
11 | 16 |
|
12 | 17 |
|
13 | 18 | # dummy custom pass that doesn't inherit |
@@ -42,35 +47,37 @@ def __call__(self, graph: torch.fx.graph.Graph) -> None: |
42 | 47 | ], |
43 | 48 | ) |
44 | 49 | def test_pass_manager_uuid(callable): |
45 | | - # Some passes need dtype to be set |
46 | | - config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16)) |
47 | | - |
48 | | - pass_manager = PostGradPassManager() |
49 | | - pass_manager.configure(config) |
50 | | - |
51 | | - # Check that UUID is different if the same pass is added 2x |
52 | | - pass_manager.add(callable) |
53 | | - uuid1 = pass_manager.uuid() |
54 | | - pass_manager.add(callable) |
55 | | - uuid2 = pass_manager.uuid() |
56 | | - assert uuid1 != uuid2 |
57 | | - |
58 | | - # UUID should be the same as the original one, |
59 | | - # as we constructed in the same way. |
60 | | - pass_manager2 = PostGradPassManager() |
61 | | - pass_manager2.configure(config) |
62 | | - pass_manager2.add(callable) |
63 | | - assert uuid1 == pass_manager2.uuid() |
64 | | - |
65 | | - # UUID should be different due to config change |
66 | | - config2 = copy.deepcopy(config) |
67 | | - config2.compilation_config.pass_config.fuse_norm_quant = ( |
68 | | - not config2.compilation_config.pass_config.fuse_norm_quant |
69 | | - ) |
70 | | - config2.compilation_config.pass_config.fuse_act_quant = ( |
71 | | - not config2.compilation_config.pass_config.fuse_act_quant |
72 | | - ) |
73 | | - pass_manager3 = PostGradPassManager() |
74 | | - pass_manager3.configure(config2) |
75 | | - pass_manager3.add(callable) |
76 | | - assert uuid1 != pass_manager3.uuid() |
| 50 | + # Set the pass context as PassManager uuid uses it |
| 51 | + with pass_context(Range(start=1, end=8)): |
| 52 | + # Some passes need dtype to be set |
| 53 | + config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16)) |
| 54 | + |
| 55 | + pass_manager = PostGradPassManager() |
| 56 | + pass_manager.configure(config) |
| 57 | + |
| 58 | + # Check that UUID is different if the same pass is added 2x |
| 59 | + pass_manager.add(callable) |
| 60 | + uuid1 = pass_manager.uuid() |
| 61 | + pass_manager.add(callable) |
| 62 | + uuid2 = pass_manager.uuid() |
| 63 | + assert uuid1 != uuid2 |
| 64 | + |
| 65 | + # UUID should be the same as the original one, |
| 66 | + # as we constructed in the same way. |
| 67 | + pass_manager2 = PostGradPassManager() |
| 68 | + pass_manager2.configure(config) |
| 69 | + pass_manager2.add(callable) |
| 70 | + assert uuid1 == pass_manager2.uuid() |
| 71 | + |
| 72 | + # UUID should be different due to config change |
| 73 | + config2 = copy.deepcopy(config) |
| 74 | + config2.compilation_config.pass_config.fuse_norm_quant = ( |
| 75 | + not config2.compilation_config.pass_config.fuse_norm_quant |
| 76 | + ) |
| 77 | + config2.compilation_config.pass_config.fuse_act_quant = ( |
| 78 | + not config2.compilation_config.pass_config.fuse_act_quant |
| 79 | + ) |
| 80 | + pass_manager3 = PostGradPassManager() |
| 81 | + pass_manager3.configure(config2) |
| 82 | + pass_manager3.add(callable) |
| 83 | + assert uuid1 != pass_manager3.uuid() |
0 commit comments