mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 08:45:01 +08:00
[ROCm][AMD][Model] llama 3.2 support upstreaming (#12421)
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
This commit is contained in:
parent
9798b2fb00
commit
a1fc18c030
@ -90,6 +90,17 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
seq_lens: Optional[List[int]]
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
|
||||
# Whether or not if cuda graph is enabled.
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
use_cuda_graph: bool
|
||||
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
@ -100,30 +111,18 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
# |-- query_len ---|
|
||||
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int]
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
max_query_len: Optional[int] = None
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor]
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor]
|
||||
|
||||
# Whether or not if cuda graph is enabled.
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
use_cuda_graph: bool
|
||||
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor]
|
||||
context_lens_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# Max number of query tokens among request in the batch.
|
||||
max_decode_query_len: Optional[int] = None
|
||||
@ -131,6 +130,23 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
_cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
|
||||
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
|
||||
|
||||
# Begin encoder attn & enc/dec cross-attn fields...
|
||||
|
||||
# Encoder sequence lengths representation
|
||||
encoder_seq_lens: Optional[List[int]] = None
|
||||
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# Maximum sequence length among encoder sequences
|
||||
max_encoder_seq_len: Optional[int] = None
|
||||
|
||||
# Number of tokens input to encoder
|
||||
num_encoder_tokens: Optional[int] = None
|
||||
|
||||
# Cross-attention memory-mapping data structures: slot mapping
|
||||
# and block tables
|
||||
cross_slot_mapping: Optional[torch.Tensor] = None
|
||||
cross_block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
@ -141,10 +157,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert self.query_start_loc is not None
|
||||
assert self.context_lens_tensor is not None
|
||||
assert self.block_tables is not None
|
||||
assert self.seq_start_loc is not None
|
||||
|
||||
self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
@ -159,12 +172,20 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
|
||||
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
|
||||
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
|
||||
query_start_loc=None if self.query_start_loc is None else
|
||||
self.query_start_loc[:self.num_prefills + 1],
|
||||
seq_start_loc=None if self.seq_start_loc is None else
|
||||
self.seq_start_loc[:self.num_prefills + 1],
|
||||
context_lens_tensor=None if self.context_lens_tensor is None else
|
||||
self.context_lens_tensor[:self.num_prefills],
|
||||
block_tables=self.block_tables[:self.num_prefills],
|
||||
use_cuda_graph=False,
|
||||
)
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
@ -194,7 +215,12 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
context_lens_tensor=None,
|
||||
block_tables=self.block_tables[self.num_prefills:],
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
)
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables)
|
||||
# Batch may be composed of prefill|decodes, adjust query start indices
|
||||
# to refer to the start of decodes when the two are split apart.
|
||||
# E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
|
||||
@ -304,6 +330,97 @@ def _make_alibi_bias(alibi_slopes: torch.Tensor,
|
||||
return attn_biases
|
||||
|
||||
|
||||
def _get_seq_len_block_table_args(
|
||||
attn_metadata: ROCmFlashAttentionMetadata,
|
||||
attn_type: str,
|
||||
) -> tuple:
|
||||
'''
|
||||
The particular choice of sequence-length
|
||||
attributes which should be extracted from attn_metadata is dependent
|
||||
on the type of attention operation.
|
||||
|
||||
Decoder attn -> select entirely decoder self-attention-related fields
|
||||
Encoder/decoder cross-attn -> select encoder sequence lengths
|
||||
Encoder attn -> select encoder sequence lengths fields
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention op
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
|
||||
Returns:
|
||||
|
||||
* Appropriate sequence-lengths tensors for query and key
|
||||
* Appropriate max sequence-length scalar
|
||||
'''
|
||||
|
||||
partial_prefix_sum = 0
|
||||
if attn_type == AttentionType.ENCODER:
|
||||
assert attn_metadata.encoder_seq_lens is not None
|
||||
assert attn_metadata.encoder_seq_lens_tensor is not None
|
||||
query_seq_start_loc = torch.tensor(
|
||||
[0] + [
|
||||
partial_prefix_sum := partial_prefix_sum + i
|
||||
for i in attn_metadata.encoder_seq_lens
|
||||
],
|
||||
device=attn_metadata.encoder_seq_lens_tensor.device,
|
||||
dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
|
||||
causal_mask = False
|
||||
|
||||
# No block tables associated with encoder attention
|
||||
return (query_seq_start_loc, attn_metadata.max_encoder_seq_len,
|
||||
query_seq_start_loc, attn_metadata.max_encoder_seq_len,
|
||||
attn_metadata.encoder_seq_lens, causal_mask)
|
||||
elif attn_type == AttentionType.DECODER:
|
||||
# Decoder self-attention
|
||||
# Choose max_seq_len based on whether we are in prompt_run
|
||||
assert attn_metadata.seq_lens is not None
|
||||
assert attn_metadata.seq_lens_tensor is not None
|
||||
query_seq_start_loc = torch.tensor(
|
||||
[0] + [
|
||||
partial_prefix_sum := partial_prefix_sum + i
|
||||
for i in attn_metadata.seq_lens
|
||||
],
|
||||
device=attn_metadata.seq_lens_tensor.device,
|
||||
dtype=attn_metadata.seq_lens_tensor.dtype)
|
||||
max_seq_len = attn_metadata.max_prefill_seq_len
|
||||
causal_mask = True
|
||||
|
||||
return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
|
||||
max_seq_len, attn_metadata.seq_lens, causal_mask)
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
assert attn_metadata.seq_lens is not None
|
||||
assert attn_metadata.encoder_seq_lens_tensor is not None
|
||||
query_start_loc = torch.tensor(
|
||||
[0] + [
|
||||
partial_prefix_sum := partial_prefix_sum + i
|
||||
for i in attn_metadata.seq_lens
|
||||
],
|
||||
device=attn_metadata.encoder_seq_lens_tensor.device,
|
||||
dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
|
||||
|
||||
partial_prefix_sum = 0
|
||||
assert attn_metadata.encoder_seq_lens is not None
|
||||
assert attn_metadata.seq_lens_tensor is not None
|
||||
key_seq_start_loc = torch.tensor(
|
||||
[0] + [
|
||||
partial_prefix_sum := partial_prefix_sum + i
|
||||
for i in attn_metadata.encoder_seq_lens
|
||||
],
|
||||
device=attn_metadata.seq_lens_tensor.device,
|
||||
dtype=attn_metadata.seq_lens_tensor.dtype)
|
||||
causal_mask = False
|
||||
|
||||
# Enc/dec cross-attention KVs match encoder sequence length;
|
||||
# cross-attention utilizes special "cross" block tables
|
||||
return (query_start_loc, attn_metadata.max_prefill_seq_len,
|
||||
key_seq_start_loc, attn_metadata.max_encoder_seq_len,
|
||||
attn_metadata.seq_lens, causal_mask)
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
|
||||
class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
"""
|
||||
If the input tensors contain prompt tokens, the layout is as follows:
|
||||
@ -346,10 +463,13 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"ROCmFlashAttention does not support blocksparse attention.")
|
||||
if logits_soft_cap is not None:
|
||||
raise ValueError(
|
||||
"ROCmFlashAttention does not support attention logits soft "
|
||||
"capping.")
|
||||
|
||||
if logits_soft_cap is None:
|
||||
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
|
||||
self.logits_soft_cap = 0.0
|
||||
else:
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.attn_type = attn_type
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
@ -374,6 +494,14 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
|
||||
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
|
||||
if self.use_triton_flash_attn:
|
||||
if logits_soft_cap is not None:
|
||||
raise ValueError(
|
||||
"ROCm Triton FlashAttention does not support attention"
|
||||
"logits soft capping."
|
||||
" please try using the ROCm CK "
|
||||
"FA backend instead by setting the env var "
|
||||
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
|
||||
|
||||
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
|
||||
triton_attention)
|
||||
self.attn_func = triton_attention
|
||||
@ -398,14 +526,13 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
self.use_naive_attn = True
|
||||
|
||||
if self.use_naive_attn:
|
||||
self.attn_func = _sdpa_attention
|
||||
logger.debug("Using naive attention in ROCmBackend")
|
||||
if logits_soft_cap is not None:
|
||||
raise ValueError(
|
||||
"ROCm Naive FlashAttention does not support"
|
||||
"attention logits soft capping.")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"ROCmFlashAttentionImpl")
|
||||
self.attn_func = _sdpa_attention
|
||||
logger.debug("Using naive (SDPA) attention in ROCmBackend")
|
||||
|
||||
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
|
||||
@ -427,6 +554,37 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention and PagedAttention.
|
||||
|
||||
For decoder-only models: query, key and value must be non-None.
|
||||
|
||||
For encoder/decoder models:
|
||||
* ROCmFlashAttentionImpl.forward() may be invoked for both self- and
|
||||
cross-attention layers.
|
||||
* For self-attention: query, key and value must be non-None.
|
||||
* For cross-attention:
|
||||
* Query must be non-None
|
||||
* During prefill, key and value must be non-None; key and value
|
||||
get cached for use during decode.
|
||||
* During decode, key and value may be None, since:
|
||||
(1) key and value tensors were cached during prefill, and
|
||||
(2) cross-attention key and value tensors do not grow during
|
||||
decode
|
||||
|
||||
A note on how the attn_type (attention type enum) argument impacts
|
||||
attention forward() behavior:
|
||||
|
||||
* DECODER: normal decoder-only behavior;
|
||||
use decoder self-attention block table
|
||||
* ENCODER: no KV caching; pass encoder sequence
|
||||
attributes (encoder_seq_lens/encoder_seq_lens_tensor/
|
||||
max_encoder_seq_len) to kernel, in lieu of decoder
|
||||
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
|
||||
* ENCODER_DECODER: cross-attention behavior;
|
||||
use cross-attention block table for caching KVs derived
|
||||
from encoder hidden states; since KV sequence lengths
|
||||
will match encoder sequence lengths, pass encoder sequence
|
||||
attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
|
||||
max_encoder_seq_len)
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
@ -435,54 +593,80 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
NOTE: kv_cache will be an empty tensor with shape [0]
|
||||
for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
attn_type: Select attention type, between encoder attention,
|
||||
decoder self-attention, or encoder/decoder cross-
|
||||
attention. Defaults to decoder self-attention,
|
||||
which is the vLLM default generally
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
# Reminder: Please update docs/source/features/compatibility_matrix.md
|
||||
# If the feature combo become valid
|
||||
num_tokens, hidden_size = query.shape
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
if key is not None:
|
||||
assert value is not None
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
else:
|
||||
assert value is None
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
if self.attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# If kv_cache is not provided, the new key and value tensors are
|
||||
# not cached. This happens during the initial memory profiling run.
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
if key is not None and value is not None:
|
||||
# Reshape the input keys and values and store them in the
|
||||
# cache. If kv_cache is not provided, the new key and value
|
||||
# tensors are not cached. This happens during the initial
|
||||
# memory profiling run.
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping
|
||||
if self.attn_type != AttentionType.ENCODER_DECODER else
|
||||
attn_metadata.cross_slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
if self.attn_type != AttentionType.ENCODER:
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
else:
|
||||
assert attn_metadata.num_encoder_tokens is not None
|
||||
num_prefill_tokens = attn_metadata.num_encoder_tokens
|
||||
|
||||
output = torch.empty_like(query)
|
||||
# Query for decode. KV is not needed because it is already cached.
|
||||
decode_query = query[num_prefill_tokens:]
|
||||
# QKV for prefill.
|
||||
query = query[:num_prefill_tokens]
|
||||
key = key[:num_prefill_tokens]
|
||||
value = value[:num_prefill_tokens]
|
||||
|
||||
assert query.shape[0] == num_prefill_tokens
|
||||
assert decode_query.shape[0] == num_decode_tokens
|
||||
if key is not None and value is not None \
|
||||
and self.attn_type != AttentionType.ENCODER_DECODER:
|
||||
key = key[:num_prefill_tokens]
|
||||
value = value[:num_prefill_tokens]
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
assert prefill_meta.seq_lens is not None
|
||||
# normal attention and DECODER
|
||||
if self.attn_type == AttentionType.DECODER and (
|
||||
kv_cache.numel() == 0 or prefill_meta.block_tables is None
|
||||
or prefill_meta.block_tables.numel() == 0):
|
||||
(query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
|
||||
key_max_seq_len, seq_lens,
|
||||
causal_mask) = (prefill_meta.seq_start_loc,
|
||||
prefill_meta.max_prefill_seq_len,
|
||||
prefill_meta.seq_start_loc,
|
||||
prefill_meta.max_prefill_seq_len,
|
||||
attn_metadata.seq_lens, True)
|
||||
# prefix-enabled attention and ENCODER/ENCODER_DECODER
|
||||
else:
|
||||
(query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
|
||||
key_max_seq_len, seq_lens,
|
||||
causal_mask) = _get_seq_len_block_table_args(
|
||||
prefill_meta, self.attn_type)
|
||||
# Prompt run.
|
||||
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
|
||||
# triton attention
|
||||
# When block_tables are not filled, it means q and k are the
|
||||
@ -493,18 +677,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
attn_masks = _make_alibi_bias(
|
||||
self.alibi_slopes,
|
||||
query.dtype,
|
||||
attn_metadata.seq_lens,
|
||||
seq_lens,
|
||||
make_attn_mask=False) # type: ignore
|
||||
out, _ = self.attn_func(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
prefill_meta.seq_start_loc,
|
||||
prefill_meta.seq_start_loc,
|
||||
prefill_meta.max_prefill_seq_len,
|
||||
prefill_meta.max_prefill_seq_len,
|
||||
True,
|
||||
query_seq_start_loc,
|
||||
key_seq_start_loc,
|
||||
query_max_seq_len,
|
||||
key_max_seq_len,
|
||||
causal_mask,
|
||||
self.scale,
|
||||
attn_masks[0][None]
|
||||
if attn_masks is not None else None,
|
||||
@ -528,11 +712,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
prefill_meta.seq_lens,
|
||||
num_tokens,
|
||||
query_seq_start_loc,
|
||||
num_prefill_tokens,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.scale,
|
||||
causal_mask,
|
||||
attn_masks,
|
||||
)
|
||||
else:
|
||||
@ -540,19 +725,23 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
cu_seqlens_q=query_seq_start_loc,
|
||||
cu_seqlens_k=key_seq_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_prefill_seq_len,
|
||||
max_seqlen_k=prefill_meta.max_prefill_seq_len,
|
||||
max_seqlen_k=key_max_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
window_size=self.sliding_window,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
softcap=self.logits_soft_cap,
|
||||
)
|
||||
|
||||
# common code for prefill
|
||||
assert output[:num_prefill_tokens].shape == out.shape
|
||||
output[:num_prefill_tokens] = out
|
||||
if output.shape[0] > num_prefill_tokens:
|
||||
output[:num_prefill_tokens] = out
|
||||
else:
|
||||
output = out
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
output[:num_prefill_tokens] = PagedAttention.forward_prefix(
|
||||
@ -583,7 +772,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
decode_query.dtype, head_size, block_size, gqa_ratio,
|
||||
decode_meta.max_decode_seq_len)
|
||||
if use_custom:
|
||||
max_seq_len = decode_meta.max_decode_seq_len
|
||||
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
|
||||
!= AttentionType.ENCODER_DECODER else
|
||||
decode_meta.max_encoder_seq_len)
|
||||
assert max_seq_len is not None
|
||||
max_num_partitions = (
|
||||
(max_seq_len + _PARTITION_SIZE_ROCM - 1) //
|
||||
_PARTITION_SIZE_ROCM)
|
||||
@ -599,8 +791,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
device=output.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
if num_prefill_tokens > 0:
|
||||
out = output[num_prefill_tokens:]
|
||||
else:
|
||||
out = output
|
||||
ops.paged_attention_rocm(
|
||||
output[num_prefill_tokens:],
|
||||
out,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
@ -609,8 +805,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
value_cache,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.seq_lens_tensor,
|
||||
decode_meta.block_tables
|
||||
if self.attn_type != AttentionType.ENCODER_DECODER else
|
||||
decode_meta.cross_block_tables,
|
||||
decode_meta.seq_lens_tensor
|
||||
if self.attn_type != AttentionType.ENCODER_DECODER else
|
||||
decode_meta.encoder_seq_lens_tensor,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
self.alibi_slopes,
|
||||
@ -623,9 +823,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
decode_query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.seq_lens_tensor,
|
||||
decode_meta.max_decode_seq_len,
|
||||
decode_meta.block_tables
|
||||
if self.attn_type != AttentionType.ENCODER_DECODER else
|
||||
decode_meta.cross_block_tables,
|
||||
decode_meta.seq_lens_tensor
|
||||
if self.attn_type != AttentionType.ENCODER_DECODER else
|
||||
decode_meta.encoder_seq_lens_tensor,
|
||||
decode_meta.max_decode_seq_len
|
||||
if self.attn_type != AttentionType.ENCODER_DECODER else
|
||||
decode_meta.max_encoder_seq_len,
|
||||
self.kv_cache_dtype,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
@ -635,7 +841,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
)
|
||||
|
||||
# Reshape the output tensor.
|
||||
return output.view(num_tokens, hidden_size)
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
|
||||
|
||||
def _sdpa_attention(
|
||||
|
||||
@ -48,7 +48,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import SequenceData
|
||||
@ -847,7 +848,8 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
i,
|
||||
i,
|
||||
)
|
||||
elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA):
|
||||
elif self.attn.backend in (_Backend.XFORMERS, _Backend.ROCM_FLASH,
|
||||
_Backend.TORCH_SDPA):
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_local_key_value_heads, self.head_dim)
|
||||
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
|
||||
@ -859,7 +861,8 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
raise ValueError(
|
||||
f"Unsupported Attention backend {self.attn.backend} "
|
||||
"enum found. Expected the Attention backend to be "
|
||||
"FLASH_ATTN, FLASH_ATTN_VLLM_V1, XFORMERS or TORCH_SDPA.")
|
||||
"FLASH_ATTN, FLASH_ATTN_VLLM_V1, "
|
||||
"XFORMERS or TORCH_SDPA.")
|
||||
|
||||
# We have to call torch.sdpa for prefill when using a
|
||||
# custom cross-attention mask. Because the mask is not a
|
||||
@ -1452,6 +1455,13 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
orig_name = name
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
logger.debug("Missing name %s, orig name %s", name,
|
||||
orig_name)
|
||||
continue
|
||||
|
||||
param = params_dict.pop(name)
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user