diff --git a/tests/compile/fullgraph/test_full_graph.py b/tests/compile/fullgraph/test_full_graph.py index 22af2d57f4f3d..c5baa66cbeb07 100644 --- a/tests/compile/fullgraph/test_full_graph.py +++ b/tests/compile/fullgraph/test_full_graph.py @@ -156,6 +156,20 @@ def test_full_graph( ) for model_info in models_list(all=False) if is_torch_equal_or_newer("2.9.0.dev") + ] + + [ + # Test get_raw_stream patch with compile_sizes + # This tests that TorchInductor autotune works correctly with get_raw_stream + # patch in torch 2.9 and without patch in torch 2.10+ + ( + CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + compile_sizes=[1, 2], # Triggers autotune which uses get_raw_stream + cudagraph_mode=CUDAGraphMode.NONE, + ), + "facebook/opt-125m", + {}, + ), ], ) # only test some of the models diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 6435d87ba7631..5f9d4ac53b80e 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -13,7 +13,10 @@ from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmCo from vllm.config.compilation import CompilationMode, PassConfig from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform -from vllm.utils.torch_utils import _is_torch_equal_or_newer +from vllm.utils.torch_utils import ( + _is_torch_equal_or_newer, + is_torch_equal, +) # This import automatically registers `torch.ops.silly.attention` from . 
def test_get_raw_stream_patch():
    """Test that get_raw_stream patch is applied only for torch 2.9.0 or 2.9.1.

    vllm.env_override installs ``builtins.get_raw_stream`` at import time as a
    workaround for TorchInductor autotune generated code referencing
    ``get_raw_stream()`` without importing it. The patch is applied only for
    torch 2.9.0/2.9.1 *CUDA builds* (env_override guards on the presence of
    ``torch._C._cuda_getCurrentRawStream``), so CPU-only builds must not be
    expected to carry the patch.
    """
    import builtins

    import torch

    # Check if get_raw_stream exists in builtins
    has_patch = hasattr(builtins, "get_raw_stream")

    is_torch_2_9 = is_torch_equal("2.9.0") or is_torch_equal("2.9.1")
    # The patch is only installed when the CUDA symbol exists; probing with
    # hasattr does not initialize CUDA.
    is_cuda_build = hasattr(torch._C, "_cuda_getCurrentRawStream")

    if is_torch_2_9 and is_cuda_build:
        # For torch 2.9.x CUDA builds, the patch should be applied
        assert has_patch, "get_raw_stream should be patched for torch 2.9.x"
        # Verify it's callable (it should be the _cuda_getCurrentRawStream function)
        get_raw_stream = builtins.get_raw_stream  # type: ignore[attr-defined]
        assert callable(get_raw_stream)
        # Verify it's the correct function from torch._C
        from torch._C import _cuda_getCurrentRawStream

        assert get_raw_stream is _cuda_getCurrentRawStream
    else:
        # Outside torch 2.9.x CUDA builds the workaround must not leak
        # into builtins.
        assert not has_patch, (
            "get_raw_stream should only be patched for torch 2.9.x CUDA builds"
        )
def _patch_get_raw_stream_if_needed():
    """Install ``builtins.get_raw_stream`` for torch 2.9.0/2.9.1.

    TorchInductor autotune code in these releases calls ``get_raw_stream()``
    without defining it (hit when ``compile_sizes`` > 1 in
    compilation_config); exposing the CUDA raw-stream getter through
    ``builtins`` lets that generated code run.
    For more context, see https://github.com/vllm-project/vllm/issues/30905.
    """
    from vllm.utils.torch_utils import is_torch_equal

    # The Inductor bug exists only in the 2.9.0 / 2.9.1 releases.
    if not (is_torch_equal("2.9.0") or is_torch_equal("2.9.1")):
        return

    import builtins

    # _cuda_getCurrentRawStream exists only in CUDA builds of PyTorch;
    # probing via getattr checks availability without initializing CUDA.
    raw_stream_getter = getattr(torch._C, "_cuda_getCurrentRawStream", None)
    if raw_stream_getter is not None:
        builtins.get_raw_stream = raw_stream_getter


_patch_get_raw_stream_if_needed()