mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-31 22:37:09 +08:00
fix
Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
This commit is contained in:
parent
e713ba4039
commit
919ad836d3
@ -13,7 +13,10 @@ from vllm import LLM, SamplingParams
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
from vllm.utils.torch_utils import (
|
||||
is_torch_equal,
|
||||
is_torch_equal_or_newer,
|
||||
)
|
||||
|
||||
from ...utils import create_new_process_for_each_test
|
||||
|
||||
@ -156,6 +159,20 @@ def test_full_graph(
|
||||
)
|
||||
for model_info in models_list(all=False)
|
||||
if is_torch_equal_or_newer("2.9.0.dev")
|
||||
]
|
||||
+ [
|
||||
# Test get_raw_stream patch with compile_sizes
|
||||
# This tests that TorchInductor autotune works correctly with get_raw_stream
|
||||
# patch in torch 2.9 and without patch in torch 2.10+
|
||||
(
|
||||
CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
compile_sizes=[1, 2], # Triggers autotune which uses get_raw_stream
|
||||
cudagraph_mode=CUDAGraphMode.NONE,
|
||||
),
|
||||
"facebook/opt-125m",
|
||||
{},
|
||||
),
|
||||
],
|
||||
)
|
||||
# only test some of the models
|
||||
@ -178,6 +195,26 @@ def test_custom_compile_config(
|
||||
):
|
||||
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||
|
||||
# For get_raw_stream patch test: verify version and log context
|
||||
if compilation_config.compile_sizes and len(compilation_config.compile_sizes) > 1:
|
||||
# Verify torch version >= 2.9.0 for compile_sizes autotune
|
||||
if not is_torch_equal_or_newer("2.9.0"):
|
||||
pytest.skip(
|
||||
f"compile_sizes autotune requires torch >= 2.9.0, "
|
||||
f"got {torch.__version__}"
|
||||
)
|
||||
|
||||
# Log version context for get_raw_stream patch testing
|
||||
is_torch_2_9 = is_torch_equal("2.9.0") or is_torch_equal("2.9.1")
|
||||
version_context = (
|
||||
"2.9.x (get_raw_stream patch applied)"
|
||||
if is_torch_2_9
|
||||
else "2.10+ (get_raw_stream patch not needed)"
|
||||
)
|
||||
print(
|
||||
f"Testing compile_sizes with torch {torch.__version__} ({version_context})"
|
||||
)
|
||||
|
||||
print(f"MODEL={model}")
|
||||
run_model(compilation_config, model, **model_kwargs)
|
||||
|
||||
|
||||
@ -16,7 +16,6 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import (
|
||||
_is_torch_equal_or_newer,
|
||||
is_torch_equal,
|
||||
is_torch_equal_or_newer,
|
||||
)
|
||||
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
@ -55,55 +54,6 @@ def test_get_raw_stream_patch():
|
||||
assert get_raw_stream is _cuda_getCurrentRawStream
|
||||
|
||||
|
||||
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
|
||||
@pytest.mark.forked
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(),
|
||||
reason="compile_sizes with autotune requires CUDA",
|
||||
)
|
||||
def test_get_raw_stream_patch_e2e(vllm_runner, monkeypatch):
|
||||
"""
|
||||
E2E test to verify get_raw_stream patch works correctly with compile_sizes.
|
||||
This test should run in both torch 2.9 and 2.10 to ensure no regression.
|
||||
When compile_sizes > 1, TorchInductor autotune uses get_raw_stream().
|
||||
In torch 2.9.0/2.9.1, this function needs to be patched,
|
||||
but in 2.10+ it should work without patch.
|
||||
"""
|
||||
import torch
|
||||
|
||||
# Verify torch version >= 2.9.0
|
||||
if not is_torch_equal_or_newer("2.9.0"):
|
||||
pytest.skip(f"Test requires torch >= 2.9.0, got {torch.__version__}")
|
||||
|
||||
# Determine version context for logging
|
||||
is_torch_2_9 = is_torch_equal("2.9.0") or is_torch_equal("2.9.1")
|
||||
version_context = (
|
||||
"2.9.x (patch applied)" if is_torch_2_9 else "2.10+ (patch not needed)"
|
||||
)
|
||||
print(f"Running test with torch {torch.__version__} ({version_context})")
|
||||
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
|
||||
compilation_config = {
|
||||
"mode": CompilationMode.VLLM_COMPILE,
|
||||
"compile_sizes": [1, 2], # Triggers autotune which uses get_raw_stream
|
||||
"cudagraph_mode": CUDAGraphMode.NONE,
|
||||
}
|
||||
|
||||
with vllm_runner(
|
||||
"facebook/opt-125m",
|
||||
compilation_config=compilation_config,
|
||||
gpu_memory_utilization=0.4,
|
||||
) as llm:
|
||||
from vllm import SamplingParams
|
||||
|
||||
outputs = llm.generate(
|
||||
["Hello, my name is"], SamplingParams(temperature=0, max_tokens=5)
|
||||
)
|
||||
assert len(outputs) == 1
|
||||
assert len(outputs[0].outputs) > 0
|
||||
|
||||
|
||||
def test_copy_pass():
|
||||
vllm_config = VllmConfig()
|
||||
inductor_pass = FixFunctionalizationPass(vllm_config)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user