fix and add unit test

Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
This commit is contained in:
baonudesifeizhai 2025-12-23 14:13:20 -05:00
parent 0b5e466c8d
commit e713ba4039
2 changed files with 59 additions and 10 deletions

View File

@ -13,7 +13,11 @@ from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmCo
from vllm.config.compilation import CompilationMode, PassConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import _is_torch_equal_or_newer
from vllm.utils.torch_utils import (
_is_torch_equal_or_newer,
is_torch_equal,
is_torch_equal_or_newer,
)
# This import automatically registers `torch.ops.silly.attention`
from . import silly_attention # noqa: F401
@ -29,7 +33,7 @@ def test_version():
def test_get_raw_stream_patch():
"""Test that get_raw_stream patch is applied only for torch 2.9.x."""
"""Test that get_raw_stream patch is applied only for torch 2.9.0 or 2.9.1."""
import builtins
# Check if get_raw_stream exists in builtins
@ -37,11 +41,7 @@ def test_get_raw_stream_patch():
# Import torch to get actual version
from vllm.utils.torch_utils import is_torch_equal_or_newer
is_torch_2_9 = is_torch_equal_or_newer("2.9.0") and not is_torch_equal_or_newer(
"2.10.0"
)
is_torch_2_9 = is_torch_equal("2.9.0") or is_torch_equal("2.9.1")
if is_torch_2_9:
# For torch 2.9.x, the patch should be applied
@ -55,6 +55,55 @@ def test_get_raw_stream_patch():
assert get_raw_stream is _cuda_getCurrentRawStream
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="compile_sizes with autotune requires CUDA",
)
def test_get_raw_stream_patch_e2e(vllm_runner, monkeypatch):
"""
E2E test to verify get_raw_stream patch works correctly with compile_sizes.
This test should run in both torch 2.9 and 2.10 to ensure no regression.
When compile_sizes > 1, TorchInductor autotune uses get_raw_stream().
In torch 2.9.0/2.9.1, this function needs to be patched,
but in 2.10+ it should work without patch.
"""
import torch
# Verify torch version >= 2.9.0
if not is_torch_equal_or_newer("2.9.0"):
pytest.skip(f"Test requires torch >= 2.9.0, got {torch.__version__}")
# Determine version context for logging
is_torch_2_9 = is_torch_equal("2.9.0") or is_torch_equal("2.9.1")
version_context = (
"2.9.x (patch applied)" if is_torch_2_9 else "2.10+ (patch not needed)"
)
print(f"Running test with torch {torch.__version__} ({version_context})")
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
compilation_config = {
"mode": CompilationMode.VLLM_COMPILE,
"compile_sizes": [1, 2], # Triggers autotune which uses get_raw_stream
"cudagraph_mode": CUDAGraphMode.NONE,
}
with vllm_runner(
"facebook/opt-125m",
compilation_config=compilation_config,
gpu_memory_utilization=0.4,
) as llm:
from vllm import SamplingParams
outputs = llm.generate(
["Hello, my name is"], SamplingParams(temperature=0, max_tokens=5)
)
assert len(outputs) == 1
assert len(outputs[0].outputs) > 0
def test_copy_pass():
vllm_config = VllmConfig()
inductor_pass = FixFunctionalizationPass(vllm_config)

View File

@ -371,10 +371,10 @@ def _update_scheduler_patched(self) -> None:
# For more context, see https://github.com/vllm-project/vllm/issues/30905.
def _patch_get_raw_stream_if_needed():
"""Workaround for TorchInductor autotune get_raw_stream() bug."""
from vllm.utils.torch_utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal
# Only apply the patch for torch 2.9.x versions
if is_torch_equal_or_newer("2.9.0") and not is_torch_equal_or_newer("2.10.0"):
# Only apply the patch for torch 2.9.0 or 2.9.1
if is_torch_equal("2.9.0") or is_torch_equal("2.9.1"):
import builtins
from torch._C import _cuda_getCurrentRawStream as _get_raw_stream