[Feature] Batch-Invariant Support for FA2 and LoRA (#30018)

Signed-off-by: quanliu <18646313696@163.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
quanliu 2025-12-09 23:01:38 +08:00 committed by GitHub
parent 5c213d2899
commit 5dcd593baf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 23 additions and 3 deletions

View File

@ -10,6 +10,7 @@ from utils import (
BACKENDS,
_extract_step_logprobs,
_random_prompt,
is_device_capability_below_90,
resolve_model_name,
skip_unsupported,
)
@ -17,6 +18,8 @@ from utils import (
import vllm.model_executor.layers.batch_invariant as batch_invariant
from vllm import LLM, SamplingParams
IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
@skip_unsupported
@pytest.mark.timeout(1000)
@ -190,6 +193,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
max_model_len=8192,
dtype="bfloat16", # not everything is supported
gpu_memory_utilization=0.9,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
)
# Use more realistic prompts for better token generation
@ -393,6 +397,8 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
gpu_memory_utilization=0.9,
max_model_len=2048,
dtype="bfloat16",
enable_prefix_caching=False,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
)
prompt = "the capital of france is"
@ -459,6 +465,7 @@ def test_logprobs_without_batch_invariance_should_fail(
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16",
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
)
# build ragged prompts to change shapes significantly across BS=1 vs BS=N
@ -682,6 +689,7 @@ def test_decode_logprobs_match_prefill_logprobs(
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16",
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
)
# Use a few test prompts
@ -925,6 +933,8 @@ def LLM_with_max_seqs(
max_model_len=max_model_len,
dtype="bfloat16",
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
enable_prefix_caching=False,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
# Enable for MOE models
# enable_expert_parallel=True,
)

View File

@ -11,8 +11,10 @@ from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
skip_unsupported = pytest.mark.skipif(
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
reason="Requires CUDA and >= Hopper (SM90)",
not (current_platform.is_cuda() and current_platform.has_device_capability(80)),
# Supports testing on Ampere and Ada Lovelace devices.
# Note: For devices with SM < 90, batch invariance does not support CUDA Graphs.
reason="Requires CUDA and >= Ampere (SM80)",
)
BACKENDS: list[str] = [
@ -97,3 +99,7 @@ def _extract_step_logprobs(request_output):
return t, inner.token_ids
return None, None
def is_device_capability_below_90() -> bool:
return not current_platform.has_device_capability(90)

View File

@ -935,7 +935,11 @@ def enable_batch_invariant_mode():
# Batch invariant matmuls are no longer needed after cublas overrides
if not is_torch_equal_or_newer("2.10.0.dev"):
if current_platform.is_device_capability(100):
if (
current_platform.is_device_capability(100)
or current_platform.is_device_capability(80)
or current_platform.is_device_capability(89)
):
# For PyTorch 2.9, B200 uses GEMV for bs=1
# Requires https://github.com/pytorch/pytorch/pull/166735
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")