[Feature] Batch-Invariant Support for FA2 and LoRA (#30018)
Signed-off-by: quanliu <18646313696@163.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
parent 5c213d2899
commit 5dcd593baf
@@ -10,6 +10,7 @@ from utils import (
     BACKENDS,
     _extract_step_logprobs,
     _random_prompt,
+    is_device_capability_below_90,
     resolve_model_name,
     skip_unsupported,
 )
@@ -17,6 +18,8 @@ from utils import (
 import vllm.model_executor.layers.batch_invariant as batch_invariant
 from vllm import LLM, SamplingParams
 
+IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
+
 
 @skip_unsupported
 @pytest.mark.timeout(1000)
@@ -190,6 +193,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
         max_model_len=8192,
         dtype="bfloat16",  # not everything is supported
         gpu_memory_utilization=0.9,
+        enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
     )
 
     # Use more realistic prompts for better token generation
@@ -393,6 +397,8 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
         gpu_memory_utilization=0.9,
         max_model_len=2048,
         dtype="bfloat16",
+        enable_prefix_caching=False,
+        enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
     )
 
     prompt = "the capital of france is"
@@ -459,6 +465,7 @@ def test_logprobs_without_batch_invariance_should_fail(
         max_num_seqs=32,
         max_model_len=8192,
         dtype="bfloat16",
+        enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
     )
 
     # build ragged prompts to change shapes significantly across BS=1 vs BS=N
@@ -682,6 +689,7 @@ def test_decode_logprobs_match_prefill_logprobs(
         max_num_seqs=32,
         max_model_len=8192,
         dtype="bfloat16",
+        enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
     )
 
     # Use a few test prompts
@@ -925,6 +933,8 @@ def LLM_with_max_seqs(
         max_model_len=max_model_len,
         dtype="bfloat16",
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
+        enable_prefix_caching=False,
+        enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
         # Enable for MOE models
         # enable_expert_parallel=True,
     )
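
Every hunk in the test file above makes essentially the same change: each LLM(...) constructor gains enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90 (and, in two places, enable_prefix_caching=False), so CUDA Graphs stay enabled on Hopper and newer but are disabled on older GPUs. A condensed sketch of the resulting pattern follows; the model name and sampling parameters are placeholders, not taken from the tests.

# Condensed, illustrative sketch of the constructor pattern introduced above.
from vllm import LLM, SamplingParams

from utils import is_device_capability_below_90  # test-local helper added by this PR

# Evaluated once at import time, mirroring the module-level constant in the tests.
IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()

llm = LLM(
    model="facebook/opt-125m",  # placeholder model, not from the diff
    dtype="bfloat16",
    max_model_len=8192,
    gpu_memory_utilization=0.9,
    enable_prefix_caching=False,
    # Below SM90, batch invariance does not support CUDA Graphs, so force eager mode.
    enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
)
outputs = llm.generate(
    ["the capital of france is"],
    SamplingParams(temperature=0.0, max_tokens=8, logprobs=1),
)
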
@@ -11,8 +11,10 @@ from vllm.platforms import current_platform
 from vllm.utils.flashinfer import has_flashinfer
 
 skip_unsupported = pytest.mark.skipif(
-    not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
-    reason="Requires CUDA and >= Hopper (SM90)",
+    not (current_platform.is_cuda() and current_platform.has_device_capability(80)),
+    # Supports testing on Ampere and Ada Lovelace devices.
+    # Note: For devices with SM < 90, batch invariance does not support CUDA Graphs.
+    reason="Requires CUDA and >= Ampere (SM80)",
 )
 
 BACKENDS: list[str] = [
@@ -97,3 +99,7 @@ def _extract_step_logprobs(request_output):
         return t, inner.token_ids
 
     return None, None
+
+
+def is_device_capability_below_90() -> bool:
+    return not current_platform.has_device_capability(90)
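
The two utils.py hunks above widen the hardware gate from Hopper (SM90) to Ampere (SM80) and add a helper for the CUDA Graphs restriction. A small, hedged sketch of how a test would consume them; the test body is illustrative only.

# Illustrative only: how the widened gate and the new helper behave together.
from utils import is_device_capability_below_90, skip_unsupported


@skip_unsupported  # collected but skipped on non-CUDA or pre-Ampere (SM < 80) hardware
def test_capability_gate_smoke():
    below_90 = is_device_capability_below_90()
    # A100 (SM80) or L40S (SM89): below_90 is True, so tests pass enforce_eager=True.
    # H100 (SM90) or B200 (SM100): below_90 is False and CUDA Graphs stay enabled.
    assert isinstance(below_90, bool)
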
@@ -935,7 +935,11 @@ def enable_batch_invariant_mode():
 
     # Batch invariant matmuls are no longer needed after cublas overrides
     if not is_torch_equal_or_newer("2.10.0.dev"):
-        if current_platform.is_device_capability(100):
+        if (
+            current_platform.is_device_capability(100)
+            or current_platform.is_device_capability(80)
+            or current_platform.is_device_capability(89)
+        ):
             # For PyTorch 2.9, B200 uses GEMV for bs=1
             # Requires https://github.com/pytorch/pytorch/pull/166735
             _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
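
For context on the last hunk: _batch_invariant_LIB is a torch.library override registry (created elsewhere in this module, presumably as torch.library.Library("aten", "IMPL")), and .impl("aten::mm", mm_batch_invariant, "CUDA") reroutes CUDA matmuls through the batch-invariant kernel on the listed device capabilities. A minimal, hedged sketch of that override mechanism with a stand-in function; everything except the torch.library API itself is a placeholder.

# Minimal sketch of the aten override mechanism used above. `my_mm` is a
# placeholder, not the real mm_batch_invariant Triton kernel from vllm.
import torch

_lib = torch.library.Library("aten", "IMPL")  # register CUDA-dispatch overrides for aten ops


def my_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # A real batch-invariant kernel fixes the reduction order so results do not
    # change with batch size. This stand-in routes through bmm (not mm) purely
    # to avoid recursing back into the override we are about to install.
    return torch.bmm(a.unsqueeze(0), b.unsqueeze(0)).squeeze(0)


# After this call, aten::mm on CUDA tensors dispatches to my_mm, which is the
# same mechanism as _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA").
_lib.impl("aten::mm", my_mm, "CUDA")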