[Test] Test multiple attn backend for chunked prefill. (#4023)
commit 36729bac13
parent 7fd3949a0b
@@ -12,7 +12,13 @@ steps:
   command: pytest -v -s async_engine
 
 - label: Basic Correctness Test
-  command: pytest -v -s basic_correctness
+  commands:
+  - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
+  - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
+  - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_basic_correctness.py
+  - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
+  - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
+  - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_chunked_prefill.py
 
 - label: Core Test
   command: pytest -v -s core
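
Each CI entry above pins the attention backend through the VLLM_ATTENTION_BACKEND environment variable instead of a pytest parameter, so the same test file is exercised once per backend. A minimal, standalone sketch of this kind of env-driven selection (the pick_backend helper and the fallback order are illustrative assumptions, not the actual vLLM selector):

import os

# Name of the override variable set by the CI commands above.
BACKEND_ENV_VAR = "VLLM_ATTENTION_BACKEND"

# Hypothetical registry of backend names; the real selector lives in
# vllm.attention.selector and applies hardware checks as well.
SUPPORTED_BACKENDS = ["FLASH_ATTN", "XFORMERS", "ROCM_FLASH"]


def pick_backend(default: str = "XFORMERS") -> str:
    """Return the backend forced by the environment, or a default."""
    forced = os.environ.get(BACKEND_ENV_VAR)
    if forced is None:
        return default
    if forced not in SUPPORTED_BACKENDS:
        raise ValueError(f"Unknown attention backend: {forced!r}")
    return forced


if __name__ == "__main__":
    # e.g. VLLM_ATTENTION_BACKEND=FLASH_ATTN python pick_backend_demo.py
    print(pick_backend())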
@@ -4,8 +4,6 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`.
 """
 import pytest
 
-from vllm.attention.selector import VLLM_ATTENTION_BACKEND
-
 MODELS = [
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
@@ -16,7 +14,6 @@ MODELS = [
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False, True])
-@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"])
 def test_models(
     hf_runner,
     vllm_runner,
@@ -25,10 +22,7 @@ def test_models(
     dtype: str,
     max_tokens: int,
     enforce_eager: bool,
-    attn_backend: str,
-    monkeypatch,
 ) -> None:
-    monkeypatch.setenv(VLLM_ATTENTION_BACKEND, attn_backend)
     hf_model = hf_runner(model, dtype=dtype)
     hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
     del hf_model
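
The deleted lines parametrized the backend inside the test and forced it with monkeypatch; after this change the backend comes from the environment set by the CI commands above. For a local run, a sketch along these lines mirrors what the removed code did (the test name and assertion are made up for illustration):

import os

import pytest

BACKEND_ENV_VAR = "VLLM_ATTENTION_BACKEND"  # same variable the CI commands set


@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"])
def test_forced_backend(attn_backend: str, monkeypatch) -> None:
    # Equivalent of the deleted monkeypatch.setenv(...) call: the override
    # must be in place before the engine chooses its attention backend.
    monkeypatch.setenv(BACKEND_ENV_VAR, attn_backend)
    assert os.environ[BACKEND_ENV_VAR] == attn_backend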
@@ -33,10 +33,6 @@ def test_models(
     enforce_eager: bool,
     tensor_parallel_size: int,
 ) -> None:
-    if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16
-            and not enforce_eager):
-        pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} "
-                    "for high TP to save testing time.")
     max_num_seqs = min(chunked_prefill_token_size, 256)
     enable_chunked_prefill = False
     max_num_batched_tokens = None
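
The surviving context lines are the start of the usual derivation in these chunked-prefill tests: the parametrized token size becomes the engine's batching arguments. A rough sketch of that derivation, assuming a sentinel of -1 means chunked prefill is disabled (the sentinel and the exact ordering are assumptions, not part of this diff):

from typing import Optional, Tuple


def derive_engine_args(
        chunked_prefill_token_size: int) -> Tuple[int, bool, Optional[int]]:
    # Assumed sentinel: -1 disables chunked prefill entirely.
    enable_chunked_prefill = chunked_prefill_token_size != -1
    if not enable_chunked_prefill:
        return 256, False, None
    # With chunked prefill on, the batch token budget equals the token size,
    # and the sequence count is capped so a batch can always be formed.
    max_num_batched_tokens = chunked_prefill_token_size
    max_num_seqs = min(chunked_prefill_token_size, 256)
    return max_num_seqs, True, max_num_batched_tokens


print(derive_engine_args(16))   # (16, True, 16)
print(derive_engine_args(-1))   # (256, False, None)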
@@ -162,7 +162,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
             # AMD Radeon 7900 series (gfx1100) currently does not support
             # xFormers nor FlashAttention. As a temporary workaround, we use
             # naive PyTorch implementation of attention.
-            self.attn_fuc = _naive_attention()
+            self.attn_fuc = _naive_attention
             logger.debug("Using naive attention in ROCmBackend")
         elif self.use_triton_flash_attn:
             from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
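
The one-line change above stores the naive attention function itself instead of the result of calling it with no arguments, which would raise a TypeError as soon as this branch ran. A stripped-down illustration with toy names (not vLLM code):

def _toy_attention(scale: float) -> float:
    """Stand-in for a backend kernel that needs call-time arguments."""
    return 2.0 * scale


class ToyBackend:
    def __init__(self, broken: bool = False) -> None:
        if broken:
            # Old behaviour: calling the function at init time fails because
            # its required arguments are only known per forward pass.
            self.attn_func = _toy_attention()  # TypeError
        else:
            # Fixed behaviour: keep a reference, call it later with arguments.
            self.attn_func = _toy_attention

    def forward(self, scale: float) -> float:
        return self.attn_func(scale)


print(ToyBackend().forward(0.5))  # 1.0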
@@ -334,26 +334,21 @@ def _naive_attention(
     prompt_lens: List[int],
     scale: float,
 ) -> torch.Tensor:
-    num_tokens = query.shape[0]
     output = torch.empty_like(query)
     start = 0
     for _, prompt_len in enumerate(prompt_lens):
         end = start + prompt_len
         out = _naive_masked_attention(
-            query[None, start:end],
-            key[None, start:end],
-            value[None, start:end],
+            query[start:end],
+            key[start:end],
+            value[start:end],
             scale,
         )
         # TODO(woosuk): Unnecessary copy. Optimize.
         output[start:end].copy_(out)
         start += prompt_len
 
-    # Using view got RuntimeError: view size is not compatible
-    # with input tensor's size and stride (at least one
-    # dimension spans across two contiguous subspaces).
-    # Use reshape instead.
-    return output.reshape(num_tokens, -1)
+    return output
 
 
 def _naive_masked_attention(
@@ -362,14 +357,13 @@ def _naive_masked_attention(
     value: torch.Tensor,
     scale: float,
 ) -> torch.Tensor:
-    seq_len, _, _ = query.shape
+    seq_len, head_size, head_dim = query.shape
     attn_mask = torch.triu(torch.ones(seq_len,
                                       seq_len,
                                       dtype=query.dtype,
                                       device=query.device),
                            diagonal=1)
     attn_mask = attn_mask * torch.finfo(query.dtype).min
-
     attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
     attn_weights = attn_weights + attn_mask.float()
     attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
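
For reference, the naive path touched by the last two hunks computes plain causal attention over tensors shaped (seq_len, num_heads, head_dim). A self-contained sketch of that computation, checked against PyTorch's built-in scaled dot-product attention (shapes and tolerances are my own choices, not from the diff):

import torch


def naive_causal_attention(query: torch.Tensor, key: torch.Tensor,
                           value: torch.Tensor, scale: float) -> torch.Tensor:
    """Causal attention over (seq_len, num_heads, head_dim) tensors."""
    seq_len = query.shape[0]
    # Upper-triangular mask: position q may not attend to keys k > q.
    attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=query.dtype,
                                      device=query.device), diagonal=1)
    attn_mask = attn_mask * torch.finfo(query.dtype).min
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", attn_weights, value)


seq_len, num_heads, head_dim = 5, 2, 8
q, k, v = (torch.randn(seq_len, num_heads, head_dim) for _ in range(3))
scale = head_dim ** -0.5

out = naive_causal_attention(q, k, v, scale)
ref = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1),
    is_causal=True).transpose(0, 1)
torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-5)
print(out.shape)  # torch.Size([5, 2, 8])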