mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 00:15:01 +08:00
[bugfix] [AMD] add multi-step advance_step to ROCmFlashAttentionMetadata (#8474)
This commit is contained in:
parent
18ae428a0d
commit
9e5ec35b1f
@ -1,6 +1,6 @@
|
|||||||
"""Attention layer ROCm GPUs."""
|
"""Attention layer ROCm GPUs."""
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -15,6 +15,9 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
_PARTITION_SIZE_ROCM = 512
|
_PARTITION_SIZE_ROCM = 512
|
||||||
@ -180,6 +183,59 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
)
|
)
|
||||||
return self._cached_decode_metadata
|
return self._cached_decode_metadata
|
||||||
|
|
||||||
|
def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||||
|
sampled_token_ids: Optional[torch.Tensor],
|
||||||
|
block_size: int, num_seqs: int, num_queries: int):
|
||||||
|
"""
|
||||||
|
Update metadata in-place to advance one decode step.
|
||||||
|
"""
|
||||||
|
# When using cudagraph, the num_seqs is padded to the next captured
|
||||||
|
# batch sized, but num_queries tracks the actual number of requests in
|
||||||
|
# the batch. For --enforce-eager mode, num_seqs == num_queries
|
||||||
|
if num_seqs != num_queries:
|
||||||
|
assert num_seqs > num_queries
|
||||||
|
assert self.use_cuda_graph
|
||||||
|
|
||||||
|
assert self.num_prefills == 0
|
||||||
|
assert self.num_prefill_tokens == 0
|
||||||
|
assert self.num_decode_tokens == num_seqs
|
||||||
|
assert self.slot_mapping.shape == (num_seqs, )
|
||||||
|
|
||||||
|
assert self.seq_lens is not None
|
||||||
|
assert len(self.seq_lens) == num_seqs
|
||||||
|
assert self.seq_lens_tensor is not None
|
||||||
|
assert self.seq_lens_tensor.shape == (num_seqs, )
|
||||||
|
assert self.max_query_len == 1
|
||||||
|
assert self.max_prefill_seq_len == 0
|
||||||
|
assert self.max_decode_seq_len == max(self.seq_lens)
|
||||||
|
|
||||||
|
assert self.query_start_loc is not None
|
||||||
|
assert self.query_start_loc.shape == (num_queries + 1, )
|
||||||
|
assert self.seq_start_loc is not None
|
||||||
|
assert self.seq_start_loc.shape == (num_seqs + 1, )
|
||||||
|
|
||||||
|
assert self.context_lens_tensor is not None
|
||||||
|
assert self.context_lens_tensor.shape == (num_queries, )
|
||||||
|
|
||||||
|
assert self.block_tables is not None
|
||||||
|
assert self.block_tables.shape[0] == num_seqs
|
||||||
|
|
||||||
|
# Update query lengths. Note that we update only queries and not seqs,
|
||||||
|
# since tensors may be padded due to captured cuda graph batch size
|
||||||
|
for i in range(num_queries):
|
||||||
|
self.seq_lens[i] += 1
|
||||||
|
self.max_decode_seq_len = max(self.seq_lens)
|
||||||
|
|
||||||
|
ops.advance_step_flashattn(num_seqs=num_seqs,
|
||||||
|
num_queries=num_queries,
|
||||||
|
block_size=block_size,
|
||||||
|
input_tokens=model_input.input_tokens,
|
||||||
|
sampled_token_ids=sampled_token_ids,
|
||||||
|
input_positions=model_input.input_positions,
|
||||||
|
seq_lens=self.seq_lens_tensor,
|
||||||
|
slot_mapping=self.slot_mapping,
|
||||||
|
block_tables=self.block_tables)
|
||||||
|
|
||||||
|
|
||||||
class ROCmFlashAttentionMetadataBuilder(
|
class ROCmFlashAttentionMetadataBuilder(
|
||||||
CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
|
CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
|
||||||
|
|||||||
@ -29,7 +29,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "flashinfer"]
|
MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "rocm-flash-attn", "flashinfer"]
|
||||||
|
|
||||||
|
|
||||||
def seq_output_builder():
|
def seq_output_builder():
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user