[TPU] support attention head dim smaller than 128 (#19620)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
commit a77aea59fd
parent b692e9cd07
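In short, the change zero-pads the query, key, and value head dimension up to the next multiple of 128 before calling the Pallas attention kernel and slices the extra columns off the output afterwards, so models whose head dim is below 128 (such as head dim 96) can run on TPU. A minimal sketch of that padding scheme; the helper name and tensor shapes below are illustrative and not part of the diff itself:

import math

import torch

TPU_HEAD_SIZE_ALIGNMENT = 128  # Pallas attention kernels expect the head size to be a multiple of 128


def pad_head_dim(x: torch.Tensor, head_size: int) -> torch.Tensor:
    """Zero-pad the last (head) dimension of x up to the next multiple of 128."""
    padded = math.ceil(head_size / TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
    if padded == head_size:
        return x
    return torch.nn.functional.pad(x, (0, padded - head_size), value=0.0)


# Illustrative shapes: 4 tokens, 32 heads, head dim 96 (as in the Phi-3 test below).
q = torch.zeros(4, 32, 96)
q_padded = pad_head_dim(q, 96)   # shape becomes (4, 32, 128)
out_padded = q_padded            # stand-in for the padded kernel output
out = out_padded[..., :96]       # slice back to the original head dim
assert out.shape == (4, 32, 96)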
@@ -67,6 +67,43 @@ def test_basic(
     assert "1024" in output or "0, 1" in output


+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This is a basic test for TPU only")
+@pytest.mark.parametrize("max_tokens", [8])
+@pytest.mark.parametrize("max_num_seqs", [16])
+def test_phi3(
+    vllm_runner: type[VllmRunner],
+    monkeypatch: pytest.MonkeyPatch,
+    max_tokens: int,
+    max_num_seqs: int,
+) -> None:
+    prompts = [
+        "A robot may not injure a human being",
+        "It is only with the heart that one can see rightly;",
+        "The greatest glory in living lies not in never falling,",
+    ]
+    answers = [
+        " or, by violating privacy",
+        " what is essential is love.",
+        " but in rising every time we fall.",
+    ]
+    # test head dim = 96
+    model = "microsoft/Phi-3-mini-128k-instruct"
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        with vllm_runner(model,
+                         max_num_batched_tokens=256,
+                         max_num_seqs=max_num_seqs) as vllm_model:
+            vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)
+        # vllm_outputs is a list of tuples whose first element is the token id
+        # and the second element is the output (including the prompt).
+        for output, answer in zip(vllm_outputs, answers):
+            generated_text = output[1]
+            assert answer in generated_text
+
+
 TP_SIZE_8 = 8
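For context on why Phi-3-mini is the model under test: its attention head dim is 96, as the "# test head dim = 96" comment notes, which the old backend rejected outright. A tiny arithmetic check; the hidden size and head count are assumed from memory of the Phi-3-mini config, only head dim 96 is stated in the diff:

# Assumed Phi-3-mini config values; only head_size == 96 is stated in the diff.
hidden_size, num_heads = 3072, 32
head_size = hidden_size // num_heads   # 96
padded = -(-head_size // 128) * 128    # ceil-divide to the 128 alignment -> 128
assert (head_size, padded) == (96, 128)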
@@ -17,6 +17,9 @@ from vllm.utils import cdiv, next_power_of_2

 logger = init_logger(__name__)

+# TPU requires the head size to be a multiple of 128.
+TPU_HEAD_SIZE_ALIGNMENT = 128
+

 class PallasAttentionBackend(AttentionBackend):
@@ -43,6 +46,14 @@ class PallasAttentionBackend(AttentionBackend):
         num_kv_heads: int,
         head_size: int,
     ) -> tuple[int, ...]:
+        padded_head_size = cdiv(
+            head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+        num_blocks = num_blocks * head_size // padded_head_size
+        if padded_head_size != head_size:
+            logger.warning_once(
+                "head size is padded to %d, and num_blocks is adjusted to %d"
+                " accordingly", padded_head_size, num_blocks)
+            head_size = padded_head_size
         return (num_blocks, block_size, num_kv_heads * 2, head_size)

     @staticmethod
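Because the KV cache is allocated with the padded head size, num_blocks is scaled down by head_size / padded_head_size so the cache stays within roughly the same memory budget. A worked example with assumed numbers:

# Assumed values: 1000 blocks available for an unpadded head size of 96.
num_blocks, head_size = 1000, 96
padded_head_size = -(-head_size // 128) * 128             # 128
num_blocks = num_blocks * head_size // padded_head_size   # 750
# 1000 * 96 == 750 * 128, so the per-block element count is unchanged overall.
assert 1000 * 96 == 750 * 128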
@@ -132,8 +143,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-        if head_size % 128 != 0:
-            raise NotImplementedError("Head size must be a multiple of 128.")
         if alibi_slopes is not None:
             raise NotImplementedError("Alibi slopes is not supported.")
         if kv_cache_dtype != "auto":
@@ -187,6 +196,18 @@ class PallasAttentionBackendImpl(AttentionImpl):
             assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
         num_tokens, hidden_size = query.shape
         query = query.view(num_tokens, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+        if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
+            padded_head_size = cdiv(
+                self.head_size,
+                TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+            query = torch.nn.functional.pad(
+                query, (0, padded_head_size - self.head_size), value=0.0)
+            key = torch.nn.functional.pad(
+                key, (0, padded_head_size - self.head_size), value=0.0)
+            value = torch.nn.functional.pad(
+                value, (0, padded_head_size - self.head_size), value=0.0)

         if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
             # Write input keys and values to the KV cache.
@@ -213,6 +234,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
             soft_cap=self.logits_soft_cap,
         )

+        if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
+            output = output[:, :, :self.head_size]
+
         return output.reshape(num_tokens, hidden_size)

@@ -231,11 +255,8 @@ def write_to_kv_cache(

     """
     _, _, num_combined_kv_heads, head_size = kv_cache.shape
-    num_kv_heads = num_combined_kv_heads // 2
-    key = key.view(-1, num_kv_heads, head_size)
-    value = value.view(-1, num_kv_heads, head_size)
+    head_size = cdiv(head_size,
+                     TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT

     kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
                                                   head_size)