mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 16:58:03 +08:00
Merge 3306d9bb92d072bc3ae84e052cb3c40046281c49 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
1b2d428864
@ -156,6 +156,20 @@ def test_full_graph(
|
|||||||
)
|
)
|
||||||
for model_info in models_list(all=False)
|
for model_info in models_list(all=False)
|
||||||
if is_torch_equal_or_newer("2.9.0.dev")
|
if is_torch_equal_or_newer("2.9.0.dev")
|
||||||
|
]
|
||||||
|
+ [
|
||||||
|
# Test get_raw_stream patch with compile_sizes
|
||||||
|
# This tests that TorchInductor autotune works correctly with get_raw_stream
|
||||||
|
# patch in torch 2.9 and without patch in torch 2.10+
|
||||||
|
(
|
||||||
|
CompilationConfig(
|
||||||
|
mode=CompilationMode.VLLM_COMPILE,
|
||||||
|
compile_sizes=[1, 2], # Triggers autotune which uses get_raw_stream
|
||||||
|
cudagraph_mode=CUDAGraphMode.NONE,
|
||||||
|
),
|
||||||
|
"facebook/opt-125m",
|
||||||
|
{},
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
# only test some of the models
|
# only test some of the models
|
||||||
|
|||||||
@ -13,7 +13,10 @@ from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmCo
|
|||||||
from vllm.config.compilation import CompilationMode, PassConfig
|
from vllm.config.compilation import CompilationMode, PassConfig
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.torch_utils import _is_torch_equal_or_newer
|
from vllm.utils.torch_utils import (
|
||||||
|
_is_torch_equal_or_newer,
|
||||||
|
is_torch_equal,
|
||||||
|
)
|
||||||
|
|
||||||
# This import automatically registers `torch.ops.silly.attention`
|
# This import automatically registers `torch.ops.silly.attention`
|
||||||
from . import silly_attention # noqa: F401
|
from . import silly_attention # noqa: F401
|
||||||
@ -28,6 +31,29 @@ def test_version():
|
|||||||
assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
|
assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_raw_stream_patch():
    """Check the ``builtins.get_raw_stream`` workaround on torch 2.9.0 / 2.9.1.

    The workaround is only expected when both conditions hold:

    * the installed torch is exactly 2.9.0 or 2.9.1, and
    * ``torch._C`` exposes ``_cuda_getCurrentRawStream``.

    In every other configuration the test asserts nothing, so it passes
    trivially instead of failing on builds where the patch is skipped.
    """
    import builtins

    import torch

    # ``None`` here means this torch build has no raw-stream accessor, so
    # there is nothing that could have been installed into ``builtins``.
    cuda_raw_stream = getattr(torch._C, "_cuda_getCurrentRawStream", None)

    is_torch_2_9 = is_torch_equal("2.9.0") or is_torch_equal("2.9.1")

    if not is_torch_2_9 or cuda_raw_stream is None:
        return

    # For torch 2.9.x with the accessor available, the patch should be applied.
    assert hasattr(builtins, "get_raw_stream"), (
        "get_raw_stream should be patched for torch 2.9.x"
    )

    get_raw_stream = builtins.get_raw_stream  # type: ignore[attr-defined]
    assert callable(get_raw_stream)

    # The patched name must be the very function object exported by torch._C.
    assert get_raw_stream is cuda_raw_stream
|
||||||
|
|
||||||
|
|
||||||
def test_copy_pass():
|
def test_copy_pass():
|
||||||
vllm_config = VllmConfig()
|
vllm_config = VllmConfig()
|
||||||
inductor_pass = FixFunctionalizationPass(vllm_config)
|
inductor_pass = FixFunctionalizationPass(vllm_config)
|
||||||
|
|||||||
@ -363,6 +363,30 @@ def _update_scheduler_patched(self) -> None:
|
|||||||
self.scheduler = Scheduler(self.operations)
|
self.scheduler = Scheduler(self.operations)
|
||||||
|
|
||||||
|
|
||||||
|
# ===================================================
|
||||||
|
# torch 2.9 Inductor get_raw_stream workaround
|
||||||
|
# ===================================================
|
||||||
|
# Workaround for TorchInductor autotune using get_raw_stream() without defining it.
|
||||||
|
# This occurs when compile_sizes > 1 in compilation_config.
|
||||||
|
# For more context, see https://github.com/vllm-project/vllm/issues/30905.
|
||||||
|
def _patch_get_raw_stream_if_needed():
    """Install ``get_raw_stream`` into ``builtins`` on torch 2.9.0 / 2.9.1.

    Does nothing on any other torch version, or when ``torch._C`` has no
    ``_cuda_getCurrentRawStream`` attribute. Otherwise that function is
    bound to ``builtins.get_raw_stream``.
    """
    from vllm.utils.torch_utils import is_torch_equal

    # Only apply the patch for torch 2.9.0 or 2.9.1.
    affected_version = is_torch_equal("2.9.0") or is_torch_equal("2.9.1")
    if not affected_version:
        return

    import builtins

    # Probe with hasattr rather than calling into CUDA; the attribute is
    # reported to exist only in CUDA builds of PyTorch.
    if not hasattr(torch._C, "_cuda_getCurrentRawStream"):
        return

    builtins.get_raw_stream = torch._C._cuda_getCurrentRawStream


# Run the workaround as soon as this module is imported.
_patch_get_raw_stream_if_needed()
|
||||||
|
|
||||||
if is_torch_equal("2.9.0"):
|
if is_torch_equal("2.9.0"):
|
||||||
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
|
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
|
||||||
from torch._inductor.graph import GraphLowering
|
from torch._inductor.graph import GraphLowering
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user