mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 00:05:52 +08:00
Add TP parameter to attention tests (#27683)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
parent
786030721e
commit
01baefe674
@ -347,8 +347,7 @@ steps:
|
||||
- vllm/v1/attention
|
||||
- tests/v1/attention
|
||||
commands:
|
||||
- export VLLM_DISABLE_FLASHINFER_PREFILL=1 # TODO: FI prefill is bugged and causes incorrectness, fix this
|
||||
- pytest -v -s v1/attention
|
||||
- VLLM_DISABLE_FLASHINFER_PREFILL=1 pytest -v -s v1/attention # TODO: FI prefill is bugged and causes incorrectness, fix this
|
||||
|
||||
- label: V1 Test others (CPU) # 5 mins
|
||||
source_file_dependencies:
|
||||
|
||||
@ -295,6 +295,7 @@ def _test_backend_correctness(
|
||||
block_size: int = 16,
|
||||
atol: float = 1e-2,
|
||||
rtol: float = 1e-2,
|
||||
tensor_parallel_size: int = 1,
|
||||
):
|
||||
"""
|
||||
Test that all backends produce similar outputs to a reference implementation
|
||||
@ -310,13 +311,38 @@ def _test_backend_correctness(
|
||||
4. Running each vLLM attention backend with the new queries and the
|
||||
simulated paged KV cache.
|
||||
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
|
||||
|
||||
Note: When tensor_parallel_size > 1, we simulate the head partitioning
|
||||
by overriding the model config to use fewer heads, without requiring
|
||||
multiple GPUs. This tests that backends work correctly with different
|
||||
head counts.
|
||||
"""
|
||||
current_platform.seed_everything(42)
|
||||
|
||||
hf_config_override = None
|
||||
if tensor_parallel_size > 1:
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
temp_config = ModelConfig(model=model, max_model_len=1)
|
||||
original_num_heads = temp_config.hf_text_config.num_attention_heads
|
||||
original_num_kv_heads = getattr(
|
||||
temp_config.hf_text_config, "num_key_value_heads", None
|
||||
)
|
||||
hf_config_override = {
|
||||
"num_attention_heads": original_num_heads // tensor_parallel_size,
|
||||
}
|
||||
if original_num_kv_heads is not None:
|
||||
hf_config_override["num_key_value_heads"] = max(
|
||||
1, original_num_kv_heads // tensor_parallel_size
|
||||
)
|
||||
|
||||
vllm_config = create_vllm_config(
|
||||
model_name=model,
|
||||
tensor_parallel_size=1, # Always use TP=1 to avoid multi-GPU requirements
|
||||
max_model_len=max(batch_spec.seq_lens),
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=8192,
|
||||
hf_config_override=hf_config_override,
|
||||
)
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
@ -503,7 +529,10 @@ def _test_backend_correctness(
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
|
||||
def test_causal_backend_correctness(batch_spec_name: str, model: str):
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
|
||||
def test_causal_backend_correctness(
|
||||
batch_spec_name: str, model: str, tensor_parallel_size: int
|
||||
):
|
||||
"""Test backend's correctness with causal attention."""
|
||||
|
||||
def causal_mask_mod(
|
||||
@ -523,12 +552,23 @@ def test_causal_backend_correctness(batch_spec_name: str, model: str):
|
||||
SMALL_BLOCK_BACKENDS = [
|
||||
x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
|
||||
]
|
||||
_test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS, causal_mask_mod)
|
||||
_test_backend_correctness(
|
||||
batch_spec,
|
||||
model,
|
||||
SMALL_BLOCK_BACKENDS,
|
||||
causal_mask_mod,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
)
|
||||
|
||||
# Fast FlexAttention needs to run with block_size=128
|
||||
if LARGE_BLOCK_BACKENDS:
|
||||
_test_backend_correctness(
|
||||
batch_spec, model, LARGE_BLOCK_BACKENDS, causal_mask_mod, block_size=128
|
||||
batch_spec,
|
||||
model,
|
||||
LARGE_BLOCK_BACKENDS,
|
||||
causal_mask_mod,
|
||||
block_size=128,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
)
|
||||
|
||||
|
||||
@ -545,7 +585,10 @@ SLIDING_WINDOW_BACKENDS_TO_TEST = [
|
||||
["small_decode", "small_prefill", "mixed_medium", "large_decode", "large_prefill"],
|
||||
)
|
||||
@pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
|
||||
def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
|
||||
def test_sliding_window_backend_correctness(
|
||||
batch_spec_name: str, model: str, tensor_parallel_size: int
|
||||
):
|
||||
"""Test backend's correctness with sliding window attention."""
|
||||
|
||||
def sliding_window_mask_mod(
|
||||
@ -575,7 +618,11 @@ def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
|
||||
x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
|
||||
]
|
||||
_test_backend_correctness(
|
||||
batch_spec, model, SMALL_BLOCK_BACKENDS, sliding_window_mask_mod_fn
|
||||
batch_spec,
|
||||
model,
|
||||
SMALL_BLOCK_BACKENDS,
|
||||
sliding_window_mask_mod_fn,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
)
|
||||
|
||||
# Fast FlexAttention needs to run with block_size=128
|
||||
@ -586,4 +633,5 @@ def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
|
||||
LARGE_BLOCK_BACKENDS,
|
||||
sliding_window_mask_mod_fn,
|
||||
block_size=128,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
)
|
||||
|
||||
@ -394,8 +394,11 @@ def run_attention_backend(
|
||||
"spec_decode_medium",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"])
|
||||
def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-R1"])
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16])
|
||||
def test_backend_correctness(
|
||||
dist_init, batch_spec_name: str, model: str, tensor_parallel_size: int
|
||||
):
|
||||
"""
|
||||
Test that all backends produce similar outputs to a reference implementation
|
||||
using torch.nn.functional.scaled_dot_product_attention.
|
||||
@ -410,6 +413,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
4. Running each vLLM attention backend with the new queries and the
|
||||
simulated paged KV cache.
|
||||
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
|
||||
|
||||
Note: When tensor_parallel_size > 1, we simulate the head partitioning
|
||||
by overriding the model config to use fewer heads, without requiring
|
||||
multiple GPUs. This tests that backends work correctly with different
|
||||
head counts.
|
||||
"""
|
||||
|
||||
batch_spec = BATCH_SPECS[batch_spec_name]
|
||||
@ -423,11 +431,30 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
# Add 1 for null block at index 0, and some buffer
|
||||
num_gpu_blocks = required_blocks + 1 + 100
|
||||
|
||||
hf_config_override = None
|
||||
if tensor_parallel_size > 1:
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
temp_config = ModelConfig(model=model, max_model_len=1)
|
||||
original_num_heads = temp_config.hf_text_config.num_attention_heads
|
||||
original_num_kv_heads = getattr(
|
||||
temp_config.hf_text_config, "num_key_value_heads", None
|
||||
)
|
||||
hf_config_override = {
|
||||
"num_attention_heads": original_num_heads // tensor_parallel_size,
|
||||
}
|
||||
if original_num_kv_heads is not None:
|
||||
hf_config_override["num_key_value_heads"] = max(
|
||||
1, original_num_kv_heads // tensor_parallel_size
|
||||
)
|
||||
|
||||
vllm_config = create_vllm_config(
|
||||
model_name=model,
|
||||
tensor_parallel_size=1, # Always use TP=1 to avoid multi-GPU requirements
|
||||
max_model_len=max(batch_spec.seq_lens),
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
block_size=default_block_size,
|
||||
hf_config_override=hf_config_override,
|
||||
)
|
||||
|
||||
# For spec decode tests, add a speculative_config to set the reorder_batch_threshold
|
||||
|
||||
@ -113,7 +113,10 @@ def _quantize_dequantize_fp8_ds_mla(
|
||||
|
||||
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
|
||||
def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype):
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
|
||||
def test_sparse_backend_decode_correctness(
|
||||
dist_init, batch_name, kv_cache_dtype, tensor_parallel_size
|
||||
):
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is required for sparse MLA decode test")
|
||||
|
||||
@ -135,8 +138,11 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype
|
||||
total_cache_tokens = sum(batch_spec.seq_lens)
|
||||
block_size = 64
|
||||
|
||||
# Note: We use TP=1 to avoid multi-GPU requirements in CI.
|
||||
# The test simulates head partitioning via mocked methods below.
|
||||
vllm_config = create_vllm_config(
|
||||
model_name="deepseek-ai/DeepSeek-V2-Lite-Chat",
|
||||
tensor_parallel_size=1,
|
||||
max_model_len=max_seqlen,
|
||||
num_gpu_blocks=max(2048, cdiv(total_cache_tokens, block_size) + 1),
|
||||
block_size=block_size,
|
||||
@ -156,7 +162,8 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype
|
||||
)
|
||||
model_config.dtype = dtype
|
||||
model_config.get_num_attention_heads = MethodType(
|
||||
lambda self, parallel_config: num_heads, model_config
|
||||
lambda self, parallel_config: max(1, num_heads // tensor_parallel_size),
|
||||
model_config,
|
||||
)
|
||||
model_config.get_num_kv_heads = MethodType(
|
||||
lambda self, parallel_config: 1, model_config
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user