Merge 3306d9bb92d072bc3ae84e052cb3c40046281c49 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29

This commit is contained in:
baonudesifeizhai 2025-12-25 00:06:39 +00:00 committed by GitHub
commit 1b2d428864
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 65 additions and 1 deletions

View File

@ -156,6 +156,20 @@ def test_full_graph(
)
for model_info in models_list(all=False)
if is_torch_equal_or_newer("2.9.0.dev")
]
+ [
# Test the get_raw_stream patch with compile_sizes.
# This verifies that TorchInductor autotune works correctly with the
# get_raw_stream patch on torch 2.9 and without the patch on torch 2.10+.
(
CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
compile_sizes=[1, 2], # Triggers autotune which uses get_raw_stream
cudagraph_mode=CUDAGraphMode.NONE,
),
"facebook/opt-125m",
{},
),
],
)
# only test some of the models

View File

@ -13,7 +13,10 @@ from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmCo
from vllm.config.compilation import CompilationMode, PassConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import _is_torch_equal_or_newer
from vllm.utils.torch_utils import (
_is_torch_equal_or_newer,
is_torch_equal,
)
# This import automatically registers `torch.ops.silly.attention`
from . import silly_attention # noqa: F401
@ -28,6 +31,29 @@ def test_version():
assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
def test_get_raw_stream_patch():
    """Test that the get_raw_stream patch is applied only for torch 2.9.0/2.9.1.

    The patch installs ``torch._C._cuda_getCurrentRawStream`` into ``builtins``
    as ``get_raw_stream`` so that TorchInductor autotune-generated code can
    resolve the name. The patching code only runs on CUDA builds of PyTorch
    (it checks ``hasattr(torch._C, "_cuda_getCurrentRawStream")``), so the
    assertion must be gated on that as well — otherwise this test would fail
    spuriously on CPU-only builds of torch 2.9.x.
    """
    import builtins

    import torch

    # Whether the patch has already been installed at import time.
    has_patch = hasattr(builtins, "get_raw_stream")
    # The workaround only targets torch 2.9.0 and 2.9.1.
    is_torch_2_9 = is_torch_equal("2.9.0") or is_torch_equal("2.9.1")
    # The patch is only installed on CUDA builds of PyTorch, where the
    # private raw-stream accessor exists.
    has_cuda_hook = hasattr(torch._C, "_cuda_getCurrentRawStream")

    if is_torch_2_9 and has_cuda_hook:
        # For CUDA builds of torch 2.9.x, the patch must be present.
        assert has_patch, "get_raw_stream should be patched for torch 2.9.x"
        # Verify it is callable (it should be the raw-stream accessor itself).
        get_raw_stream = builtins.get_raw_stream  # type: ignore[attr-defined]
        assert callable(get_raw_stream)
        # Verify it is exactly the function from torch._C, not a lookalike.
        from torch._C import _cuda_getCurrentRawStream

        assert get_raw_stream is _cuda_getCurrentRawStream
def test_copy_pass():
vllm_config = VllmConfig()
inductor_pass = FixFunctionalizationPass(vllm_config)

View File

@ -363,6 +363,30 @@ def _update_scheduler_patched(self) -> None:
self.scheduler = Scheduler(self.operations)
# ===================================================
# torch 2.9 Inductor get_raw_stream workaround
# ===================================================
# Workaround for TorchInductor autotune emitting generated code that calls
# get_raw_stream() without importing or defining it. This occurs when
# compilation_config.compile_sizes has more than one entry.
# For more context, see https://github.com/vllm-project/vllm/issues/30905.
def _patch_get_raw_stream_if_needed():
    """Workaround for TorchInductor autotune get_raw_stream() bug."""
    from vllm.utils.torch_utils import is_torch_equal

    # The bug only affects torch 2.9.0 and 2.9.1; nothing to do otherwise.
    affected = is_torch_equal("2.9.0") or is_torch_equal("2.9.1")
    if not affected:
        return

    # _cuda_getCurrentRawStream only exists in CUDA builds of PyTorch;
    # probing with hasattr avoids initializing CUDA.
    if not hasattr(torch._C, "_cuda_getCurrentRawStream"):
        return

    import builtins

    # Expose the raw-stream accessor as a builtin so Inductor-generated
    # autotune code can resolve the bare name `get_raw_stream`.
    builtins.get_raw_stream = torch._C._cuda_getCurrentRawStream


_patch_get_raw_stream_if_needed()
if is_torch_equal("2.9.0"):
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
from torch._inductor.graph import GraphLowering