mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-17 16:57:12 +08:00
Merge 3306d9bb92d072bc3ae84e052cb3c40046281c49 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
1b2d428864
@ -156,6 +156,20 @@ def test_full_graph(
|
||||
)
|
||||
for model_info in models_list(all=False)
|
||||
if is_torch_equal_or_newer("2.9.0.dev")
|
||||
]
|
||||
+ [
|
||||
# Test get_raw_stream patch with compile_sizes
|
||||
# This tests that TorchInductor autotune works correctly with get_raw_stream
|
||||
# patch in torch 2.9 and without patch in torch 2.10+
|
||||
(
|
||||
CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
compile_sizes=[1, 2], # Triggers autotune which uses get_raw_stream
|
||||
cudagraph_mode=CUDAGraphMode.NONE,
|
||||
),
|
||||
"facebook/opt-125m",
|
||||
{},
|
||||
),
|
||||
],
|
||||
)
|
||||
# only test some of the models
|
||||
|
||||
@ -13,7 +13,10 @@ from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmCo
|
||||
from vllm.config.compilation import CompilationMode, PassConfig
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import _is_torch_equal_or_newer
|
||||
from vllm.utils.torch_utils import (
|
||||
_is_torch_equal_or_newer,
|
||||
is_torch_equal,
|
||||
)
|
||||
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
from . import silly_attention # noqa: F401
|
||||
@ -28,6 +31,29 @@ def test_version():
|
||||
assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
|
||||
|
||||
|
||||
def test_get_raw_stream_patch():
    """Check the ``get_raw_stream`` builtin workaround on torch 2.9.0 / 2.9.1.

    The workaround only installs ``builtins.get_raw_stream`` when the torch
    build exposes ``torch._C._cuda_getCurrentRawStream``, so the assertions
    are made under that same condition. On any other torch version, or on a
    build without that symbol, there is nothing to verify and the test
    returns without asserting.
    """
    import builtins

    import torch

    if not (is_torch_equal("2.9.0") or is_torch_equal("2.9.1")):
        return

    # Mirror the guard used by the workaround itself: builds that lack the
    # CUDA symbol never get the builtin, so asserting on them would be wrong.
    if not hasattr(torch._C, "_cuda_getCurrentRawStream"):
        return

    # For torch 2.9.x with the CUDA symbol available, the patch must be applied.
    assert hasattr(builtins, "get_raw_stream"), (
        "get_raw_stream should be patched for torch 2.9.x"
    )

    get_raw_stream = builtins.get_raw_stream  # type: ignore[attr-defined]
    assert callable(get_raw_stream)

    # The builtin must be the very function object exported by torch._C.
    from torch._C import _cuda_getCurrentRawStream

    assert get_raw_stream is _cuda_getCurrentRawStream
|
||||
|
||||
|
||||
def test_copy_pass():
|
||||
vllm_config = VllmConfig()
|
||||
inductor_pass = FixFunctionalizationPass(vllm_config)
|
||||
|
||||
@ -363,6 +363,30 @@ def _update_scheduler_patched(self) -> None:
|
||||
self.scheduler = Scheduler(self.operations)
|
||||
|
||||
|
||||
# ===================================================
|
||||
# torch 2.9 Inductor get_raw_stream workaround
|
||||
# ===================================================
|
||||
# Workaround for TorchInductor autotune using get_raw_stream() without defining it.
|
||||
# This occurs when compile_sizes in compilation_config triggers autotune (e.g. compile_sizes=[1, 2]).
|
||||
# For more context, see https://github.com/vllm-project/vllm/issues/30905.
|
||||
def _patch_get_raw_stream_if_needed():
    """Install ``get_raw_stream`` as a builtin on torch 2.9.0 / 2.9.1.

    Works around TorchInductor autotune code that calls ``get_raw_stream()``
    without defining it. Nothing is changed on other torch versions, or when
    ``torch._C`` has no ``_cuda_getCurrentRawStream`` attribute.
    """
    from vllm.utils.torch_utils import is_torch_equal

    # The workaround is limited to exactly these two releases.
    affected_release = is_torch_equal("2.9.0") or is_torch_equal("2.9.1")
    if not affected_release:
        return

    import builtins

    # Probe by attribute so CUDA itself is not initialized; the symbol only
    # exists in CUDA builds of PyTorch.
    if not hasattr(torch._C, "_cuda_getCurrentRawStream"):
        return

    builtins.get_raw_stream = torch._C._cuda_getCurrentRawStream
|
||||
|
||||
|
||||
# Apply the get_raw_stream workaround as soon as this module is imported.
_patch_get_raw_stream_if_needed()
|
||||
|
||||
if is_torch_equal("2.9.0"):
|
||||
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
|
||||
from torch._inductor.graph import GraphLowering
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user