From 0b5e466c8dd9cc74d9dc361abf8d4ccd09ca2683 Mon Sep 17 00:00:00 2001
From: baonudesifeizhai
Date: Mon, 22 Dec 2025 21:39:34 -0500
Subject: [PATCH] fix

Signed-off-by: baonudesifeizhai
---
 tests/compile/test_config.py | 27 +++++++++++++++++++++++++++
 vllm/env_override.py         | 10 +++++++---
 2 files changed, 34 insertions(+), 3 deletions(-)

diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py
index 04bb56ecb6470..757d87125c575 100644
--- a/tests/compile/test_config.py
+++ b/tests/compile/test_config.py
@@ -28,6 +28,33 @@ def test_version():
     assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
 
 
+def test_get_raw_stream_patch():
+    """Test that get_raw_stream patch is applied only for torch 2.9.x."""
+    import builtins
+
+    # Check if get_raw_stream exists in builtins
+    has_patch = hasattr(builtins, "get_raw_stream")
+
+    # Import torch to get actual version
+
+    from vllm.utils.torch_utils import is_torch_equal_or_newer
+
+    is_torch_2_9 = is_torch_equal_or_newer("2.9.0") and not is_torch_equal_or_newer(
+        "2.10.0"
+    )
+
+    if is_torch_2_9:
+        # For torch 2.9.x, the patch should be applied
+        assert has_patch, "get_raw_stream should be patched for torch 2.9.x"
+        # Verify it's callable (it should be the _cuda_getCurrentRawStream function)
+        get_raw_stream = builtins.get_raw_stream  # type: ignore[attr-defined]
+        assert callable(get_raw_stream)
+        # Verify it's the correct function from torch._C
+        from torch._C import _cuda_getCurrentRawStream
+
+        assert get_raw_stream is _cuda_getCurrentRawStream
+
+
 def test_copy_pass():
     vllm_config = VllmConfig()
     inductor_pass = FixFunctionalizationPass(vllm_config)
diff --git a/vllm/env_override.py b/vllm/env_override.py
index 8ff61ef7593c2..fb0a70239cf40 100644
--- a/vllm/env_override.py
+++ b/vllm/env_override.py
@@ -371,11 +371,15 @@ def _update_scheduler_patched(self) -> None:
 # For more context, see https://github.com/vllm-project/vllm/issues/30905.
 def _patch_get_raw_stream_if_needed():
     """Workaround for TorchInductor autotune get_raw_stream() bug."""
-    import builtins
+    from vllm.utils.torch_utils import is_torch_equal_or_newer
 
-    from torch._C import _cuda_getCurrentRawStream as _get_raw_stream
+    # Only apply the patch for torch 2.9.x versions
+    if is_torch_equal_or_newer("2.9.0") and not is_torch_equal_or_newer("2.10.0"):
+        import builtins
 
-    builtins.get_raw_stream = _get_raw_stream
+        from torch._C import _cuda_getCurrentRawStream as _get_raw_stream
+
+        builtins.get_raw_stream = _get_raw_stream
 
 
 _patch_get_raw_stream_if_needed()