From 9d70afe6c6b49e0708785017dbd85264d0efd4a8 Mon Sep 17 00:00:00 2001 From: baonudesifeizhai Date: Wed, 17 Dec 2025 18:13:53 -0500 Subject: [PATCH 1/8] Add workaround for TorchInductor get_raw_stream bug --- vllm/env_override.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/vllm/env_override.py b/vllm/env_override.py index 9ae1af3af46cf..02001209145a4 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -363,6 +363,24 @@ def _update_scheduler_patched(self) -> None: self.scheduler = Scheduler(self.operations) +# =================================================== +# torch 2.9 Inductor get_raw_stream workaround +# =================================================== +# Workaround for TorchInductor autotune using get_raw_stream() without defining it. +# This occurs when compile_sizes > 1 in compilation_config. +# For more context, see https://github.com/vllm-project/vllm/issues/30905. +def _patch_get_raw_stream_if_needed(): + """Workaround for TorchInductor autotune get_raw_stream() bug.""" + if is_torch_equal("2.9.0") and os.getenv("VLLM_PATCH_GET_RAW_STREAM", "1") == "1": + import builtins + + from torch._C import _cuda_getCurrentRawStream as _get_raw_stream + + builtins.get_raw_stream = _get_raw_stream + + +_patch_get_raw_stream_if_needed() + if is_torch_equal("2.9.0"): from torch._inductor.codegen.wrapper import PythonWrapperCodegen from torch._inductor.graph import GraphLowering From e8985d97165de85f79b6e939e0382debfefa11e7 Mon Sep 17 00:00:00 2001 From: baonudesifeizhai Date: Wed, 17 Dec 2025 18:16:48 -0500 Subject: [PATCH 2/8] Add workaround for TorchInductor get_raw_stream bug --- vllm/env_override.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/env_override.py b/vllm/env_override.py index 02001209145a4..8ff61ef7593c2 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -371,12 +371,11 @@ def _update_scheduler_patched(self) -> None: # For more context, see https://github.com/vllm-project/vllm/issues/30905. def _patch_get_raw_stream_if_needed(): """Workaround for TorchInductor autotune get_raw_stream() bug.""" - if is_torch_equal("2.9.0") and os.getenv("VLLM_PATCH_GET_RAW_STREAM", "1") == "1": - import builtins + import builtins - from torch._C import _cuda_getCurrentRawStream as _get_raw_stream + from torch._C import _cuda_getCurrentRawStream as _get_raw_stream - builtins.get_raw_stream = _get_raw_stream + builtins.get_raw_stream = _get_raw_stream _patch_get_raw_stream_if_needed() From 0b5e466c8dd9cc74d9dc361abf8d4ccd09ca2683 Mon Sep 17 00:00:00 2001 From: baonudesifeizhai Date: Mon, 22 Dec 2025 21:39:34 -0500 Subject: [PATCH 3/8] fix Signed-off-by: baonudesifeizhai --- tests/compile/test_config.py | 27 +++++++++++++++++++++++++++ vllm/env_override.py | 10 +++++++--- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 04bb56ecb6470..757d87125c575 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -28,6 +28,33 @@ def test_version(): assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev") +def test_get_raw_stream_patch(): + """Test that get_raw_stream patch is applied only for torch 2.9.x.""" + import builtins + + # Check if get_raw_stream exists in builtins + has_patch = hasattr(builtins, "get_raw_stream") + + # Import torch to get actual version + + from vllm.utils.torch_utils import is_torch_equal_or_newer + + is_torch_2_9 = is_torch_equal_or_newer("2.9.0") and not is_torch_equal_or_newer( + "2.10.0" + ) + + if is_torch_2_9: + # For torch 2.9.x, the patch should be applied + assert has_patch, "get_raw_stream should be patched for torch 2.9.x" + # Verify it's callable (it should be the _cuda_getCurrentRawStream function) + get_raw_stream = builtins.get_raw_stream # type: ignore[attr-defined] + assert callable(get_raw_stream) + # Verify it's the correct function from torch._C + from torch._C import _cuda_getCurrentRawStream + + assert get_raw_stream is _cuda_getCurrentRawStream + + def test_copy_pass(): vllm_config = VllmConfig() inductor_pass = FixFunctionalizationPass(vllm_config) diff --git a/vllm/env_override.py b/vllm/env_override.py index 8ff61ef7593c2..fb0a70239cf40 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -371,11 +371,15 @@ def _update_scheduler_patched(self) -> None: # For more context, see https://github.com/vllm-project/vllm/issues/30905. def _patch_get_raw_stream_if_needed(): """Workaround for TorchInductor autotune get_raw_stream() bug.""" - import builtins + from vllm.utils.torch_utils import is_torch_equal_or_newer - from torch._C import _cuda_getCurrentRawStream as _get_raw_stream + # Only apply the patch for torch 2.9.x versions + if is_torch_equal_or_newer("2.9.0") and not is_torch_equal_or_newer("2.10.0"): + import builtins - builtins.get_raw_stream = _get_raw_stream + from torch._C import _cuda_getCurrentRawStream as _get_raw_stream + + builtins.get_raw_stream = _get_raw_stream _patch_get_raw_stream_if_needed() From e713ba403943fe4fe967b548b798dde013b06c69 Mon Sep 17 00:00:00 2001 From: baonudesifeizhai Date: Tue, 23 Dec 2025 14:13:20 -0500 Subject: [PATCH 4/8] fix and add unit test Signed-off-by: baonudesifeizhai --- tests/compile/test_config.py | 63 ++++++++++++++++++++++++++++++++---- vllm/env_override.py | 6 ++-- 2 files changed, 59 insertions(+), 10 deletions(-) diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 757d87125c575..c5d49f8cfd01e 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -13,7 +13,11 @@ from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmCo from vllm.config.compilation import CompilationMode, PassConfig from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform -from vllm.utils.torch_utils import _is_torch_equal_or_newer +from vllm.utils.torch_utils import ( + _is_torch_equal_or_newer, + is_torch_equal, + is_torch_equal_or_newer, +) # This import automatically registers `torch.ops.silly.attention` from . import silly_attention # noqa: F401 @@ -29,7 +33,7 @@ def test_version(): def test_get_raw_stream_patch(): - """Test that get_raw_stream patch is applied only for torch 2.9.x.""" + """Test that get_raw_stream patch is applied only for torch 2.9.0 or 2.9.1.""" import builtins # Check if get_raw_stream exists in builtins @@ -37,11 +41,7 @@ def test_get_raw_stream_patch(): # Import torch to get actual version - from vllm.utils.torch_utils import is_torch_equal_or_newer - - is_torch_2_9 = is_torch_equal_or_newer("2.9.0") and not is_torch_equal_or_newer( - "2.10.0" - ) + is_torch_2_9 = is_torch_equal("2.9.0") or is_torch_equal("2.9.1") if is_torch_2_9: # For torch 2.9.x, the patch should be applied @@ -55,6 +55,55 @@ def test_get_raw_stream_patch(): assert get_raw_stream is _cuda_getCurrentRawStream +# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073 +@pytest.mark.forked +@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="compile_sizes with autotune requires CUDA", +) +def test_get_raw_stream_patch_e2e(vllm_runner, monkeypatch): + """ + E2E test to verify get_raw_stream patch works correctly with compile_sizes. + This test should run in both torch 2.9 and 2.10 to ensure no regression. + When compile_sizes > 1, TorchInductor autotune uses get_raw_stream(). + In torch 2.9.0/2.9.1, this function needs to be patched, + but in 2.10+ it should work without patch. + """ + import torch + + # Verify torch version >= 2.9.0 + if not is_torch_equal_or_newer("2.9.0"): + pytest.skip(f"Test requires torch >= 2.9.0, got {torch.__version__}") + + # Determine version context for logging + is_torch_2_9 = is_torch_equal("2.9.0") or is_torch_equal("2.9.1") + version_context = ( + "2.9.x (patch applied)" if is_torch_2_9 else "2.10+ (patch not needed)" + ) + print(f"Running test with torch {torch.__version__} ({version_context})") + + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + + compilation_config = { + "mode": CompilationMode.VLLM_COMPILE, + "compile_sizes": [1, 2], # Triggers autotune which uses get_raw_stream + "cudagraph_mode": CUDAGraphMode.NONE, + } + + with vllm_runner( + "facebook/opt-125m", + compilation_config=compilation_config, + gpu_memory_utilization=0.4, + ) as llm: + from vllm import SamplingParams + + outputs = llm.generate( + ["Hello, my name is"], SamplingParams(temperature=0, max_tokens=5) + ) + assert len(outputs) == 1 + assert len(outputs[0].outputs) > 0 + + def test_copy_pass(): vllm_config = VllmConfig() inductor_pass = FixFunctionalizationPass(vllm_config) diff --git a/vllm/env_override.py b/vllm/env_override.py index fb0a70239cf40..e091867fd0ca8 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -371,10 +371,10 @@ def _update_scheduler_patched(self) -> None: # For more context, see https://github.com/vllm-project/vllm/issues/30905. def _patch_get_raw_stream_if_needed(): """Workaround for TorchInductor autotune get_raw_stream() bug.""" - from vllm.utils.torch_utils import is_torch_equal_or_newer + from vllm.utils.torch_utils import is_torch_equal - # Only apply the patch for torch 2.9.x versions - if is_torch_equal_or_newer("2.9.0") and not is_torch_equal_or_newer("2.10.0"): + # Only apply the patch for torch 2.9.0 or 2.9.1 + if is_torch_equal("2.9.0") or is_torch_equal("2.9.1"): import builtins from torch._C import _cuda_getCurrentRawStream as _get_raw_stream From 919ad836d3b31f1840f11751c50f104a5a92ef32 Mon Sep 17 00:00:00 2001 From: baonudesifeizhai Date: Tue, 23 Dec 2025 14:57:46 -0500 Subject: [PATCH 5/8] fix Signed-off-by: baonudesifeizhai --- tests/compile/fullgraph/test_full_graph.py | 39 ++++++++++++++++- tests/compile/test_config.py | 50 ---------------------- 2 files changed, 38 insertions(+), 51 deletions(-) diff --git a/tests/compile/fullgraph/test_full_graph.py b/tests/compile/fullgraph/test_full_graph.py index 22af2d57f4f3d..87a2f37af540f 100644 --- a/tests/compile/fullgraph/test_full_graph.py +++ b/tests/compile/fullgraph/test_full_graph.py @@ -13,7 +13,10 @@ from vllm import LLM, SamplingParams from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig from vllm.platforms import current_platform -from vllm.utils.torch_utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import ( + is_torch_equal, + is_torch_equal_or_newer, +) from ...utils import create_new_process_for_each_test @@ -156,6 +159,20 @@ def test_full_graph( ) for model_info in models_list(all=False) if is_torch_equal_or_newer("2.9.0.dev") + ] + + [ + # Test get_raw_stream patch with compile_sizes + # This tests that TorchInductor autotune works correctly with get_raw_stream + # patch in torch 2.9 and without patch in torch 2.10+ + ( + CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + compile_sizes=[1, 2], # Triggers autotune which uses get_raw_stream + cudagraph_mode=CUDAGraphMode.NONE, + ), + "facebook/opt-125m", + {}, + ), ], ) # only test some of the models @@ -178,6 +195,26 @@ def test_custom_compile_config( ): pytest.skip("inductor graph partition is only available in PyTorch 2.9+") + # For get_raw_stream patch test: verify version and log context + if compilation_config.compile_sizes and len(compilation_config.compile_sizes) > 1: + # Verify torch version >= 2.9.0 for compile_sizes autotune + if not is_torch_equal_or_newer("2.9.0"): + pytest.skip( + f"compile_sizes autotune requires torch >= 2.9.0, " + f"got {torch.__version__}" + ) + + # Log version context for get_raw_stream patch testing + is_torch_2_9 = is_torch_equal("2.9.0") or is_torch_equal("2.9.1") + version_context = ( + "2.9.x (get_raw_stream patch applied)" + if is_torch_2_9 + else "2.10+ (get_raw_stream patch not needed)" + ) + print( + f"Testing compile_sizes with torch {torch.__version__} ({version_context})" + ) + print(f"MODEL={model}") run_model(compilation_config, model, **model_kwargs) diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index c5d49f8cfd01e..2502b62645371 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -16,7 +16,6 @@ from vllm.platforms import current_platform from vllm.utils.torch_utils import ( _is_torch_equal_or_newer, is_torch_equal, - is_torch_equal_or_newer, ) # This import automatically registers `torch.ops.silly.attention` @@ -55,55 +54,6 @@ def test_get_raw_stream_patch(): assert get_raw_stream is _cuda_getCurrentRawStream -# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073 -@pytest.mark.forked -@pytest.mark.skipif( - not current_platform.is_cuda(), - reason="compile_sizes with autotune requires CUDA", -) -def test_get_raw_stream_patch_e2e(vllm_runner, monkeypatch): - """ - E2E test to verify get_raw_stream patch works correctly with compile_sizes. - This test should run in both torch 2.9 and 2.10 to ensure no regression. - When compile_sizes > 1, TorchInductor autotune uses get_raw_stream(). - In torch 2.9.0/2.9.1, this function needs to be patched, - but in 2.10+ it should work without patch. - """ - import torch - - # Verify torch version >= 2.9.0 - if not is_torch_equal_or_newer("2.9.0"): - pytest.skip(f"Test requires torch >= 2.9.0, got {torch.__version__}") - - # Determine version context for logging - is_torch_2_9 = is_torch_equal("2.9.0") or is_torch_equal("2.9.1") - version_context = ( - "2.9.x (patch applied)" if is_torch_2_9 else "2.10+ (patch not needed)" - ) - print(f"Running test with torch {torch.__version__} ({version_context})") - - monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - - compilation_config = { - "mode": CompilationMode.VLLM_COMPILE, - "compile_sizes": [1, 2], # Triggers autotune which uses get_raw_stream - "cudagraph_mode": CUDAGraphMode.NONE, - } - - with vllm_runner( - "facebook/opt-125m", - compilation_config=compilation_config, - gpu_memory_utilization=0.4, - ) as llm: - from vllm import SamplingParams - - outputs = llm.generate( - ["Hello, my name is"], SamplingParams(temperature=0, max_tokens=5) - ) - assert len(outputs) == 1 - assert len(outputs[0].outputs) > 0 - - def test_copy_pass(): vllm_config = VllmConfig() inductor_pass = FixFunctionalizationPass(vllm_config) From 226c3dc98ad785be149b705aca99e002171a1bde Mon Sep 17 00:00:00 2001 From: baonudesifeizhai <85092850+baonudesifeizhai@users.noreply.github.com> Date: Tue, 23 Dec 2025 15:18:39 -0500 Subject: [PATCH 6/8] Update tests/compile/fullgraph/test_full_graph.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Luka Govedič Signed-off-by: baonudesifeizhai <85092850+baonudesifeizhai@users.noreply.github.com> --- tests/compile/fullgraph/test_full_graph.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/tests/compile/fullgraph/test_full_graph.py b/tests/compile/fullgraph/test_full_graph.py index 87a2f37af540f..5a3ed71d82df6 100644 --- a/tests/compile/fullgraph/test_full_graph.py +++ b/tests/compile/fullgraph/test_full_graph.py @@ -195,26 +195,6 @@ def test_custom_compile_config( ): pytest.skip("inductor graph partition is only available in PyTorch 2.9+") - # For get_raw_stream patch test: verify version and log context - if compilation_config.compile_sizes and len(compilation_config.compile_sizes) > 1: - # Verify torch version >= 2.9.0 for compile_sizes autotune - if not is_torch_equal_or_newer("2.9.0"): - pytest.skip( - f"compile_sizes autotune requires torch >= 2.9.0, " - f"got {torch.__version__}" - ) - - # Log version context for get_raw_stream patch testing - is_torch_2_9 = is_torch_equal("2.9.0") or is_torch_equal("2.9.1") - version_context = ( - "2.9.x (get_raw_stream patch applied)" - if is_torch_2_9 - else "2.10+ (get_raw_stream patch not needed)" - ) - print( - f"Testing compile_sizes with torch {torch.__version__} ({version_context})" - ) - print(f"MODEL={model}") run_model(compilation_config, model, **model_kwargs) From a2774c4d69743180ea763b5938aebf44e9605a25 Mon Sep 17 00:00:00 2001 From: baonudesifeizhai Date: Tue, 23 Dec 2025 15:25:43 -0500 Subject: [PATCH 7/8] fix Signed-off-by: baonudesifeizhai --- tests/compile/fullgraph/test_full_graph.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/compile/fullgraph/test_full_graph.py b/tests/compile/fullgraph/test_full_graph.py index 5a3ed71d82df6..ffec14d36fd17 100644 --- a/tests/compile/fullgraph/test_full_graph.py +++ b/tests/compile/fullgraph/test_full_graph.py @@ -14,7 +14,6 @@ from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig from vllm.platforms import current_platform from vllm.utils.torch_utils import ( - is_torch_equal, is_torch_equal_or_newer, ) From ff82fce3b23699be1ea22a87b7b6dfe2682414e5 Mon Sep 17 00:00:00 2001 From: baonudesifeizhai Date: Tue, 23 Dec 2025 15:58:39 -0500 Subject: [PATCH 8/8] fix for cpu Signed-off-by: baonudesifeizhai --- tests/compile/fullgraph/test_full_graph.py | 4 +--- vllm/env_override.py | 7 +++++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/compile/fullgraph/test_full_graph.py b/tests/compile/fullgraph/test_full_graph.py index ffec14d36fd17..c5baa66cbeb07 100644 --- a/tests/compile/fullgraph/test_full_graph.py +++ b/tests/compile/fullgraph/test_full_graph.py @@ -13,9 +13,7 @@ from vllm import LLM, SamplingParams from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig from vllm.platforms import current_platform -from vllm.utils.torch_utils import ( - is_torch_equal_or_newer, -) +from vllm.utils.torch_utils import is_torch_equal_or_newer from ...utils import create_new_process_for_each_test diff --git a/vllm/env_override.py b/vllm/env_override.py index e091867fd0ca8..474ac69919eb1 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -377,9 +377,12 @@ def _patch_get_raw_stream_if_needed(): if is_torch_equal("2.9.0") or is_torch_equal("2.9.1"): import builtins - from torch._C import _cuda_getCurrentRawStream as _get_raw_stream + # Check if CUDA functionality is available without initializing CUDA + # _cuda_getCurrentRawStream only exists in CUDA builds of PyTorch + if hasattr(torch._C, "_cuda_getCurrentRawStream"): + from torch._C import _cuda_getCurrentRawStream as _get_raw_stream - builtins.get_raw_stream = _get_raw_stream + builtins.get_raw_stream = _get_raw_stream _patch_get_raw_stream_if_needed()