Add TP parameter to attention tests (#27683)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Author: Matthew Bonanni
Date: 2025-11-03 16:04:40 -05:00 (committed by GitHub)
parent 786030721e
commit 01baefe674
4 changed files with 92 additions and 11 deletions

View File

@@ -347,8 +347,7 @@ steps:
   - vllm/v1/attention
   - tests/v1/attention
   commands:
-  - export VLLM_DISABLE_FLASHINFER_PREFILL=1 # TODO: FI prefill is bugged and causes incorrectness, fix this
-  - pytest -v -s v1/attention
+  - VLLM_DISABLE_FLASHINFER_PREFILL=1 pytest -v -s v1/attention # TODO: FI prefill is bugged and causes incorrectness, fix this
 
 - label: V1 Test others (CPU) # 5 mins
   source_file_dependencies:

View File

@@ -295,6 +295,7 @@ def _test_backend_correctness(
     block_size: int = 16,
     atol: float = 1e-2,
     rtol: float = 1e-2,
+    tensor_parallel_size: int = 1,
 ):
     """
     Test that all backends produce similar outputs to a reference implementation
@@ -310,13 +311,38 @@ def _test_backend_correctness(
     4. Running each vLLM attention backend with the new queries and the
        simulated paged KV cache.
     5. Comparing the vLLM backend's output to the ground-truth SDPA output.
+
+    Note: When tensor_parallel_size > 1, we simulate the head partitioning
+    by overriding the model config to use fewer heads, without requiring
+    multiple GPUs. This tests that backends work correctly with different
+    head counts.
     """
     current_platform.seed_everything(42)
 
+    hf_config_override = None
+    if tensor_parallel_size > 1:
+        from vllm.config import ModelConfig
+
+        temp_config = ModelConfig(model=model, max_model_len=1)
+        original_num_heads = temp_config.hf_text_config.num_attention_heads
+        original_num_kv_heads = getattr(
+            temp_config.hf_text_config, "num_key_value_heads", None
+        )
+        hf_config_override = {
+            "num_attention_heads": original_num_heads // tensor_parallel_size,
+        }
+        if original_num_kv_heads is not None:
+            hf_config_override["num_key_value_heads"] = max(
+                1, original_num_kv_heads // tensor_parallel_size
+            )
+
     vllm_config = create_vllm_config(
         model_name=model,
+        tensor_parallel_size=1,  # Always use TP=1 to avoid multi-GPU requirements
         max_model_len=max(batch_spec.seq_lens),
         block_size=block_size,
         num_gpu_blocks=8192,
+        hf_config_override=hf_config_override,
     )
     device = torch.device("cuda:0")
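
The block above is the heart of the change: rather than launching tensor_parallel_size GPUs, the test shrinks the HF config so the single-GPU backend sees the per-rank head counts it would get under real TP. Below is a minimal standalone sketch of the same arithmetic; the helper name and the worked numbers are illustrative, not taken from the diff (Meta-Llama-3-8B, used by these tests, advertises 32 query heads and 8 KV heads).

# Sketch only: mirrors the override logic above; the helper name is hypothetical.
def simulated_tp_override(num_heads, num_kv_heads, tp):
    override = {"num_attention_heads": num_heads // tp}
    if num_kv_heads is not None:
        # Clamp so GQA/MQA models keep at least one KV head when tp exceeds
        # the original KV-head count.
        override["num_key_value_heads"] = max(1, num_kv_heads // tp)
    return override

# e.g. a Llama-3-8B-style config (32 query heads, 8 KV heads) at tp=4:
assert simulated_tp_override(32, 8, 4) == {
    "num_attention_heads": 8,
    "num_key_value_heads": 2,
}
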
@@ -503,7 +529,10 @@
     ],
 )
 @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
-def test_causal_backend_correctness(batch_spec_name: str, model: str):
+@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
+def test_causal_backend_correctness(
+    batch_spec_name: str, model: str, tensor_parallel_size: int
+):
     """Test backend's correctness with causal attention."""
 
     def causal_mask_mod(
@@ -523,12 +552,23 @@ def test_causal_backend_correctness(batch_spec_name: str, model: str):
     SMALL_BLOCK_BACKENDS = [
         x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
     ]
-    _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS, causal_mask_mod)
+    _test_backend_correctness(
+        batch_spec,
+        model,
+        SMALL_BLOCK_BACKENDS,
+        causal_mask_mod,
+        tensor_parallel_size=tensor_parallel_size,
+    )
 
     # Fast FlexAttention needs to run with block_size=128
     if LARGE_BLOCK_BACKENDS:
         _test_backend_correctness(
-            batch_spec, model, LARGE_BLOCK_BACKENDS, causal_mask_mod, block_size=128
+            batch_spec,
+            model,
+            LARGE_BLOCK_BACKENDS,
+            causal_mask_mod,
+            block_size=128,
+            tensor_parallel_size=tensor_parallel_size,
         )
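
Because the new tensor_parallel_size parametrize is stacked on top of the existing batch_spec_name and model parametrizations, pytest collects the full cross product, so every batch spec now runs at TP 1, 2, and 4. A tiny self-contained illustration of that stacking behavior; the test name and parameter values here are hypothetical, not part of the change:

# Stacked parametrize decorators multiply out: this collects 2 x 3 = 6 test
# items, one per (model, tensor_parallel_size) combination.
import pytest

@pytest.mark.parametrize("model", ["model-a", "model-b"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
def test_cross_product(model, tensor_parallel_size):
    assert tensor_parallel_size in (1, 2, 4)
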
@@ -545,7 +585,10 @@ SLIDING_WINDOW_BACKENDS_TO_TEST = [
     ["small_decode", "small_prefill", "mixed_medium", "large_decode", "large_prefill"],
 )
 @pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
-def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
+@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
+def test_sliding_window_backend_correctness(
+    batch_spec_name: str, model: str, tensor_parallel_size: int
+):
     """Test backend's correctness with sliding window attention."""
 
     def sliding_window_mask_mod(
@@ -575,7 +618,11 @@ def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
         x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
     ]
     _test_backend_correctness(
-        batch_spec, model, SMALL_BLOCK_BACKENDS, sliding_window_mask_mod_fn
+        batch_spec,
+        model,
+        SMALL_BLOCK_BACKENDS,
+        sliding_window_mask_mod_fn,
+        tensor_parallel_size=tensor_parallel_size,
     )
 
     # Fast FlexAttention needs to run with block_size=128
@@ -586,4 +633,5 @@ def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
             LARGE_BLOCK_BACKENDS,
             sliding_window_mask_mod_fn,
             block_size=128,
+            tensor_parallel_size=tensor_parallel_size,
         )

View File

@@ -394,8 +394,11 @@ def run_attention_backend(
         "spec_decode_medium",
     ],
 )
-@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"])
-def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
+@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-R1"])
+@pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16])
+def test_backend_correctness(
+    dist_init, batch_spec_name: str, model: str, tensor_parallel_size: int
+):
     """
     Test that all backends produce similar outputs to a reference implementation
     using torch.nn.functional.scaled_dot_product_attention.
@@ -410,6 +413,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     4. Running each vLLM attention backend with the new queries and the
        simulated paged KV cache.
     5. Comparing the vLLM backend's output to the ground-truth SDPA output.
+
+    Note: When tensor_parallel_size > 1, we simulate the head partitioning
+    by overriding the model config to use fewer heads, without requiring
+    multiple GPUs. This tests that backends work correctly with different
+    head counts.
     """
 
     batch_spec = BATCH_SPECS[batch_spec_name]
@@ -423,11 +431,30 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     # Add 1 for null block at index 0, and some buffer
     num_gpu_blocks = required_blocks + 1 + 100
 
+    hf_config_override = None
+    if tensor_parallel_size > 1:
+        from vllm.config import ModelConfig
+
+        temp_config = ModelConfig(model=model, max_model_len=1)
+        original_num_heads = temp_config.hf_text_config.num_attention_heads
+        original_num_kv_heads = getattr(
+            temp_config.hf_text_config, "num_key_value_heads", None
+        )
+        hf_config_override = {
+            "num_attention_heads": original_num_heads // tensor_parallel_size,
+        }
+        if original_num_kv_heads is not None:
+            hf_config_override["num_key_value_heads"] = max(
+                1, original_num_kv_heads // tensor_parallel_size
+            )
+
     vllm_config = create_vllm_config(
         model_name=model,
+        tensor_parallel_size=1,  # Always use TP=1 to avoid multi-GPU requirements
         max_model_len=max(batch_spec.seq_lens),
         num_gpu_blocks=num_gpu_blocks,
         block_size=default_block_size,
+        hf_config_override=hf_config_override,
     )
 
     # For spec decode tests, add a speculative_config to set the reorder_batch_threshold
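
This file sweeps tensor_parallel_size up to 16, which is where the max(1, ...) clamp earns its keep: integer division would otherwise drive the simulated KV-head count to zero once TP exceeds the model's KV-head count. A quick worked illustration with a hypothetical 8-KV-head config (the numbers are not read from any real model config):

# Hypothetical GQA-style config with 8 KV heads; at tp=16, 8 // 16 == 0,
# so the clamp keeps one KV head per simulated rank.
num_kv_heads = 8
for tp in (1, 4, 8, 16):
    print(tp, max(1, num_kv_heads // tp))  # per-rank KV heads: 8, 2, 1, 1
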

View File

@@ -113,7 +113,10 @@ def _quantize_dequantize_fp8_ds_mla(
 @pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
 @pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
-def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype):
+@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
+def test_sparse_backend_decode_correctness(
+    dist_init, batch_name, kv_cache_dtype, tensor_parallel_size
+):
     if not torch.cuda.is_available():
         pytest.skip("CUDA is required for sparse MLA decode test")
@@ -135,8 +138,11 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype
     total_cache_tokens = sum(batch_spec.seq_lens)
     block_size = 64
 
+    # Note: We use TP=1 to avoid multi-GPU requirements in CI.
+    # The test simulates head partitioning via mocked methods below.
     vllm_config = create_vllm_config(
         model_name="deepseek-ai/DeepSeek-V2-Lite-Chat",
+        tensor_parallel_size=1,
         max_model_len=max_seqlen,
         num_gpu_blocks=max(2048, cdiv(total_cache_tokens, block_size) + 1),
         block_size=block_size,
@@ -156,7 +162,8 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype
     )
     model_config.dtype = dtype
     model_config.get_num_attention_heads = MethodType(
-        lambda self, parallel_config: num_heads, model_config
+        lambda self, parallel_config: max(1, num_heads // tensor_parallel_size),
+        model_config,
     )
     model_config.get_num_kv_heads = MethodType(
         lambda self, parallel_config: 1, model_config
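
The sparse MLA test takes a different route from the other files: instead of an HF config override, it rebinds get_num_attention_heads on the model_config instance via types.MethodType so it reports max(1, num_heads // tensor_parallel_size), while get_num_kv_heads stays pinned at 1 as before. A minimal standalone sketch of that binding trick; the Cfg class and head counts are hypothetical:

# types.MethodType binds a plain function (here a lambda) to one instance,
# so calling it supplies that instance as `self`, just like a real method.
from types import MethodType

class Cfg:
    pass

cfg = Cfg()
num_heads, tensor_parallel_size = 16, 4
cfg.get_num_attention_heads = MethodType(
    lambda self, parallel_config: max(1, num_heads // tensor_parallel_size), cfg
)

assert cfg.get_num_attention_heads(parallel_config=None) == 4  # 16 // 4
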