mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 03:05:01 +08:00
[Feature] Batch-Invariant Support for FA2 and LoRA (#30018)
Signed-off-by: quanliu <18646313696@163.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
parent
5c213d2899
commit
5dcd593baf
@ -10,6 +10,7 @@ from utils import (
|
||||
BACKENDS,
|
||||
_extract_step_logprobs,
|
||||
_random_prompt,
|
||||
is_device_capability_below_90,
|
||||
resolve_model_name,
|
||||
skip_unsupported,
|
||||
)
|
||||
@ -17,6 +18,8 @@ from utils import (
|
||||
import vllm.model_executor.layers.batch_invariant as batch_invariant
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
|
||||
|
||||
|
||||
@skip_unsupported
|
||||
@pytest.mark.timeout(1000)
|
||||
@ -190,6 +193,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||
max_model_len=8192,
|
||||
dtype="bfloat16", # not everything is supported
|
||||
gpu_memory_utilization=0.9,
|
||||
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
|
||||
)
|
||||
|
||||
# Use more realistic prompts for better token generation
|
||||
@ -393,6 +397,8 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
|
||||
gpu_memory_utilization=0.9,
|
||||
max_model_len=2048,
|
||||
dtype="bfloat16",
|
||||
enable_prefix_caching=False,
|
||||
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
|
||||
)
|
||||
|
||||
prompt = "the capital of france is"
|
||||
@ -459,6 +465,7 @@ def test_logprobs_without_batch_invariance_should_fail(
|
||||
max_num_seqs=32,
|
||||
max_model_len=8192,
|
||||
dtype="bfloat16",
|
||||
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
|
||||
)
|
||||
|
||||
# build ragged prompts to change shapes significantly across BS=1 vs BS=N
|
||||
@ -682,6 +689,7 @@ def test_decode_logprobs_match_prefill_logprobs(
|
||||
max_num_seqs=32,
|
||||
max_model_len=8192,
|
||||
dtype="bfloat16",
|
||||
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
|
||||
)
|
||||
|
||||
# Use a few test prompts
|
||||
@ -925,6 +933,8 @@ def LLM_with_max_seqs(
|
||||
max_model_len=max_model_len,
|
||||
dtype="bfloat16",
|
||||
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
|
||||
enable_prefix_caching=False,
|
||||
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
|
||||
# Enable for MOE models
|
||||
# enable_expert_parallel=True,
|
||||
)
|
||||
|
||||
@ -11,8 +11,10 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
|
||||
skip_unsupported = pytest.mark.skipif(
|
||||
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
|
||||
reason="Requires CUDA and >= Hopper (SM90)",
|
||||
not (current_platform.is_cuda() and current_platform.has_device_capability(80)),
|
||||
# Supports testing on Ampere and Ada Lovelace devices.
|
||||
# Note: For devices with SM < 90, batch invariance does not support CUDA Graphs.
|
||||
reason="Requires CUDA and >= Ampere (SM80)",
|
||||
)
|
||||
|
||||
BACKENDS: list[str] = [
|
||||
@ -97,3 +99,7 @@ def _extract_step_logprobs(request_output):
|
||||
return t, inner.token_ids
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def is_device_capability_below_90() -> bool:
|
||||
return not current_platform.has_device_capability(90)
|
||||
|
||||
@ -935,7 +935,11 @@ def enable_batch_invariant_mode():
|
||||
|
||||
# Batch invariant matmuls are no longer needed after cublas overrides
|
||||
if not is_torch_equal_or_newer("2.10.0.dev"):
|
||||
if current_platform.is_device_capability(100):
|
||||
if (
|
||||
current_platform.is_device_capability(100)
|
||||
or current_platform.is_device_capability(80)
|
||||
or current_platform.is_device_capability(89)
|
||||
):
|
||||
# For PyTorch 2.9, B200 uses GEMV for bs=1
|
||||
# Requires https://github.com/pytorch/pytorch/pull/166735
|
||||
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user