[ROCm][AMD][Model] llama 3.2 support upstreaming (#12421)

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Aleksandr Malyshev 2025-01-30 20:24:28 -08:00 committed by GitHub
parent 9798b2fb00
commit a1fc18c030
2 changed files with 303 additions and 87 deletions


@@ -90,6 +90,17 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
@@ -100,30 +111,18 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
# |-- query_len ---|
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
max_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor]
query_start_loc: Optional[torch.Tensor] = None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
seq_start_loc: Optional[torch.Tensor] = None
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
context_lens_tensor: Optional[torch.Tensor] = None
# Max number of query tokens among requests in the batch.
max_decode_query_len: Optional[int] = None
@@ -131,6 +130,23 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
_cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None
@property
def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
if self.num_prefills == 0:
@@ -141,10 +157,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
assert self.query_start_loc is not None
assert self.context_lens_tensor is not None
assert self.block_tables is not None
assert self.seq_start_loc is not None
self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
num_prefills=self.num_prefills,
@@ -159,12 +172,20 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
query_start_loc=None if self.query_start_loc is None else
self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=None if self.seq_start_loc is None else
self.seq_start_loc[:self.num_prefills + 1],
context_lens_tensor=None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False,
)
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
return self._cached_prefill_metadata
@property
@@ -194,7 +215,12 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
)
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
# Batch may be composed of prefill|decodes, adjust query start indices
# to refer to the start of decodes when the two are split apart.
# E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
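To see that adjustment concretely, here is a minimal standalone sketch (toy values, not part of this diff) of rebasing a combined query_start_loc for the decode-only view:

import torch

# Toy batch: 3 prefill tokens then 6 decode tokens, one prefill request,
# so the combined start locations are [0, 3, 9].
query_start_loc = torch.tensor([0, 3, 9])
num_prefills = 1
# Drop the prefill entries and rebase to zero: [3, 9] -> [0, 6].
decode_start_loc = (query_start_loc[num_prefills:] -
                    query_start_loc[num_prefills])
print(decode_start_loc)  # tensor([0, 6])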
@@ -304,6 +330,97 @@ def _make_alibi_bias(alibi_slopes: torch.Tensor,
return attn_biases
def _get_seq_len_block_table_args(
attn_metadata: ROCmFlashAttentionMetadata,
attn_type: str,
) -> tuple:
'''
Which sequence-length attributes should be extracted from attn_metadata
depends on the type of attention operation:
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select decoder query lengths and
encoder key/value sequence lengths
Encoder attn -> select encoder sequence-length fields
Arguments:
* attn_metadata: Attention metadata structure associated with attention op
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Cumulative sequence-start tensors for query and key
* Maximum sequence lengths for query and key
* The per-sequence lengths used to build attention masks
* Whether the attention should be causally masked
'''
partial_prefix_sum = 0
if attn_type == AttentionType.ENCODER:
assert attn_metadata.encoder_seq_lens is not None
assert attn_metadata.encoder_seq_lens_tensor is not None
query_seq_start_loc = torch.tensor(
[0] + [
partial_prefix_sum := partial_prefix_sum + i
for i in attn_metadata.encoder_seq_lens
],
device=attn_metadata.encoder_seq_lens_tensor.device,
dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
causal_mask = False
# No block tables associated with encoder attention
return (query_seq_start_loc, attn_metadata.max_encoder_seq_len,
query_seq_start_loc, attn_metadata.max_encoder_seq_len,
attn_metadata.encoder_seq_lens, causal_mask)
elif attn_type == AttentionType.DECODER:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
assert attn_metadata.seq_lens is not None
assert attn_metadata.seq_lens_tensor is not None
query_seq_start_loc = torch.tensor(
[0] + [
partial_prefix_sum := partial_prefix_sum + i
for i in attn_metadata.seq_lens
],
device=attn_metadata.seq_lens_tensor.device,
dtype=attn_metadata.seq_lens_tensor.dtype)
max_seq_len = attn_metadata.max_prefill_seq_len
causal_mask = True
return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
max_seq_len, attn_metadata.seq_lens, causal_mask)
elif attn_type == AttentionType.ENCODER_DECODER:
assert attn_metadata.seq_lens is not None
assert attn_metadata.encoder_seq_lens_tensor is not None
query_start_loc = torch.tensor(
[0] + [
partial_prefix_sum := partial_prefix_sum + i
for i in attn_metadata.seq_lens
],
device=attn_metadata.encoder_seq_lens_tensor.device,
dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
partial_prefix_sum = 0
assert attn_metadata.encoder_seq_lens is not None
assert attn_metadata.seq_lens_tensor is not None
key_seq_start_loc = torch.tensor(
[0] + [
partial_prefix_sum := partial_prefix_sum + i
for i in attn_metadata.encoder_seq_lens
],
device=attn_metadata.seq_lens_tensor.device,
dtype=attn_metadata.seq_lens_tensor.dtype)
causal_mask = False
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (query_start_loc, attn_metadata.max_prefill_seq_len,
key_seq_start_loc, attn_metadata.max_encoder_seq_len,
attn_metadata.seq_lens, causal_mask)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
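The walrus-operator expressions above are cumulative sums with a leading zero; a minimal sketch (toy lengths, independent of this diff) showing the construction and its torch.cumsum equivalent:

import torch

seq_lens = [4, 6]  # toy per-sequence lengths

# Running-sum construction, as in _get_seq_len_block_table_args:
partial = 0
cu_walrus = torch.tensor([0] + [partial := partial + n for n in seq_lens],
                         dtype=torch.int32)

# Equivalent: cumsum with a prepended zero -> [0, 4, 10].
cu_cumsum = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
cu_cumsum[1:] = torch.cumsum(torch.tensor(seq_lens, dtype=torch.int32), dim=0)

assert torch.equal(cu_walrus, cu_cumsum)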
class ROCmFlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
@@ -346,10 +463,13 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if blocksparse_params is not None:
raise ValueError(
"ROCmFlashAttention does not support blocksparse attention.")
if logits_soft_cap is not None:
raise ValueError(
"ROCmFlashAttention does not support attention logits soft "
"capping.")
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
self.logits_soft_cap = 0.0
else:
self.logits_soft_cap = logits_soft_cap
self.attn_type = attn_type
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
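The logits_soft_cap handling above follows flash-attn's convention that a cap of 0 disables soft capping; for reference, a minimal sketch (the math, not the kernel) of what a tanh soft cap computes:

import torch

def soft_cap(scores: torch.Tensor, logits_soft_cap: float) -> torch.Tensor:
    # flash-attn convention: logits_soft_cap == 0 means no soft capping.
    if logits_soft_cap == 0.0:
        return scores
    # Smoothly bound scores to (-cap, +cap) before the softmax.
    return logits_soft_cap * torch.tanh(scores / logits_soft_cap)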
@@ -374,6 +494,14 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
if self.use_triton_flash_attn:
if logits_soft_cap is not None:
raise ValueError(
"ROCm Triton FlashAttention does not support attention"
"logits soft capping."
" please try using the ROCm CK "
"FA backend instead by setting the env var "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
triton_attention)
self.attn_func = triton_attention
@@ -398,14 +526,13 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.use_naive_attn = True
if self.use_naive_attn:
self.attn_func = _sdpa_attention
logger.debug("Using naive attention in ROCmBackend")
if logits_soft_cap is not None:
raise ValueError(
"ROCm Naive FlashAttention does not support"
"attention logits soft capping.")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"ROCmFlashAttentionImpl")
self.attn_func = _sdpa_attention
logger.debug("Using naive (SDPA) attention in ROCmBackend")
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
@@ -427,6 +554,37 @@ class ROCmFlashAttentionImpl(AttentionImpl):
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
For decoder-only models: query, key and value must be non-None.
For encoder/decoder models:
* ROCmFlashAttentionImpl.forward() may be invoked for both self- and
cross-attention layers.
* For self-attention: query, key and value must be non-None.
* For cross-attention:
* Query must be non-None
* During prefill, key and value must be non-None; key and value
get cached for use during decode.
* During decode, key and value may be None, since:
(1) key and value tensors were cached during prefill, and
(2) cross-attention key and value tensors do not grow during
decode
A note on how the attn_type (attention type enum) argument impacts
attention forward() behavior:
* DECODER: normal decoder-only behavior;
use decoder self-attention block table
* ENCODER: no KV caching; pass encoder sequence
attributes (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len) to kernel, in lieu of decoder
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
* ENCODER_DECODER: cross-attention behavior;
use cross-attention block table for caching KVs derived
from encoder hidden states; since KV sequence lengths
will match encoder sequence lengths, pass encoder sequence
attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len)
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
@@ -435,54 +593,80 @@ class ROCmFlashAttentionImpl(AttentionImpl):
NOTE: kv_cache will be an empty tensor with shape [0]
for the profiling run.
attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-attention.
Defaults to decoder self-attention, which is the general
vLLM default.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# Reminder: Please update docs/source/features/compatibility_matrix.md
# if the feature combo becomes valid
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if key is not None:
assert value is not None
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
else:
assert value is None
if kv_cache.numel() > 0:
if self.attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if key is not None and value is not None:
# Reshape the input keys and values and store them in the
# cache. If kv_cache is not provided, the new key and value
# tensors are not cached. This happens during the initial
# memory profiling run.
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping
if self.attn_type != AttentionType.ENCODER_DECODER else
attn_metadata.cross_slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
if self.attn_type != AttentionType.ENCODER:
num_prefill_tokens = attn_metadata.num_prefill_tokens
else:
assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
if key is not None and value is not None \
and self.attn_type != AttentionType.ENCODER_DECODER:
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
assert prefill_meta.seq_lens is not None
# normal attention and DECODER
if self.attn_type == AttentionType.DECODER and (
kv_cache.numel() == 0 or prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
(query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
key_max_seq_len, seq_lens,
causal_mask) = (prefill_meta.seq_start_loc,
prefill_meta.max_prefill_seq_len,
prefill_meta.seq_start_loc,
prefill_meta.max_prefill_seq_len,
attn_metadata.seq_lens, True)
# prefix-enabled attention and ENCODER/ENCODER_DECODER
else:
(query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
key_max_seq_len, seq_lens,
causal_mask) = _get_seq_len_block_table_args(
prefill_meta, self.attn_type)
# Prompt run.
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
# triton attention
# When block_tables are not filled, it means q and k are the
@@ -493,18 +677,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks = _make_alibi_bias(
self.alibi_slopes,
query.dtype,
attn_metadata.seq_lens,
seq_lens,
make_attn_mask=False) # type: ignore
out, _ = self.attn_func(
query,
key,
value,
None,
prefill_meta.seq_start_loc,
prefill_meta.seq_start_loc,
prefill_meta.max_prefill_seq_len,
prefill_meta.max_prefill_seq_len,
True,
query_seq_start_loc,
key_seq_start_loc,
query_max_seq_len,
key_max_seq_len,
causal_mask,
self.scale,
attn_masks[0][None]
if attn_masks is not None else None,
@@ -528,11 +712,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query,
key,
value,
prefill_meta.seq_lens,
num_tokens,
query_seq_start_loc,
num_prefill_tokens,
self.num_heads,
self.head_size,
self.scale,
causal_mask,
attn_masks,
)
else:
@@ -540,19 +725,23 @@ class ROCmFlashAttentionImpl(AttentionImpl):
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
cu_seqlens_q=query_seq_start_loc,
cu_seqlens_k=key_seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
max_seqlen_k=key_max_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
)
# common code for prefill
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
if output.shape[0] > num_prefill_tokens:
output[:num_prefill_tokens] = out
else:
output = out
else:
# prefix-enabled attention
output[:num_prefill_tokens] = PagedAttention.forward_prefix(
@@ -583,7 +772,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len)
if use_custom:
max_seq_len = decode_meta.max_decode_seq_len
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else
decode_meta.max_encoder_seq_len)
assert max_seq_len is not None
max_num_partitions = (
(max_seq_len + _PARTITION_SIZE_ROCM - 1) //
_PARTITION_SIZE_ROCM)
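The (max_seq_len + _PARTITION_SIZE_ROCM - 1) // _PARTITION_SIZE_ROCM expression above is integer ceiling division, so a partial tail block still gets its own partition; a quick worked check, assuming a hypothetical partition size of 512:

_PARTITION_SIZE_ROCM = 512  # hypothetical value, for illustration only

def num_partitions(max_seq_len: int) -> int:
    # Ceiling division: round up to cover the whole sequence.
    return (max_seq_len + _PARTITION_SIZE_ROCM - 1) // _PARTITION_SIZE_ROCM

assert num_partitions(512) == 1   # exact fit
assert num_partitions(513) == 2   # one extra token needs a second partition
assert num_partitions(8192) == 16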
@@ -599,8 +791,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
if num_prefill_tokens > 0:
out = output[num_prefill_tokens:]
else:
out = output
ops.paged_attention_rocm(
output[num_prefill_tokens:],
out,
exp_sums,
max_logits,
tmp_output,
@@ -609,8 +805,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value_cache,
self.num_kv_heads,
self.scale,
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.block_tables
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.cross_block_tables,
decode_meta.seq_lens_tensor
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.encoder_seq_lens_tensor,
block_size,
max_seq_len,
self.alibi_slopes,
@@ -623,9 +823,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
decode_query,
key_cache,
value_cache,
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_decode_seq_len,
decode_meta.block_tables
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.cross_block_tables,
decode_meta.seq_lens_tensor
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.encoder_seq_lens_tensor,
decode_meta.max_decode_seq_len
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.max_encoder_seq_len,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
@@ -635,7 +841,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
)
# Reshape the output tensor.
return output.view(num_tokens, hidden_size)
return output.view(-1, self.num_heads * self.head_size)
def _sdpa_attention(


@@ -48,7 +48,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SequenceData
@@ -847,7 +848,8 @@ class MllamaTextCrossAttention(nn.Module):
i,
i,
)
elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA):
elif self.attn.backend in (_Backend.XFORMERS, _Backend.ROCM_FLASH,
_Backend.TORCH_SDPA):
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_local_key_value_heads, self.head_dim)
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
@@ -859,7 +861,8 @@ class MllamaTextCrossAttention(nn.Module):
raise ValueError(
f"Unsupported Attention backend {self.attn.backend} "
"enum found. Expected the Attention backend to be "
"FLASH_ATTN, FLASH_ATTN_VLLM_V1, XFORMERS or TORCH_SDPA.")
"FLASH_ATTN, FLASH_ATTN_VLLM_V1, "
"XFORMERS or TORCH_SDPA.")
# We have to call torch.sdpa for prefill when using a
# custom cross-attention mask. Because the mask is not a
@@ -1452,6 +1455,13 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
weight_loader(param, loaded_weight, shard_id)
break
else:
orig_name = name
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
logger.debug("Missing name %s, orig name %s", name,
orig_name)
continue
param = params_dict.pop(name)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
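maybe_remap_kv_scale_name translates checkpoint-specific k/v scale weight names to the model's registered attention parameters, returning None when no matching parameter exists (which the loader above then skips); a hedged sketch of that fallback pattern, with illustrative names rather than the real mapping table:

def remap_kv_scale_sketch(name: str, params_dict: dict):
    # Hypothetical mapping: a checkpoint may store '.k_proj.k_scale' while
    # the model registers '.attn.k_scale' (and likewise for v_scale).
    for src, dst in ((".k_proj.k_scale", ".attn.k_scale"),
                     (".v_proj.v_scale", ".attn.v_scale")):
        if name.endswith(src):
            remapped = name.replace(src, dst)
            return remapped if remapped in params_dict else None
    return name  # not a kv-scale weight; leave the name untouched

params = {"model.layers.0.self_attn.attn.k_scale": None}
assert (remap_kv_scale_sketch("model.layers.0.self_attn.k_proj.k_scale",
                              params)
        == "model.layers.0.self_attn.attn.k_scale")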