[BugFix] llama4 fa3 fix - RuntimeError: scheduler_metadata must have shape (metadata_size) (#16998)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
parent b2f195c429
commit d0da99fb70
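With FA3 the metadata builder precomputes ahead-of-time (AOT) scheduler metadata via `get_scheduler_metadata()`, and that metadata is sized for the batch it was computed for. The old `schedule()` helper hard-coded `batch_size=num_reqs`, but llama4's chunked local attention runs over a "virtual batch" of chunked sequences whose count generally differs from `num_reqs`, so reusing the global metadata on the local path tripped the kernel's shape check from the title. This commit makes `batch_size` an explicit argument of `schedule()`, builds a dedicated `local_scheduler_metadata`, and selects the matching tensor at the call site. A minimal runnable sketch of the mismatch, with a hypothetical `metadata_size()` standing in for FA3's internal sizing rule:

```python
import torch

def metadata_size(batch_size: int) -> int:
    # Hypothetical stand-in: the real FA3 sizing rule differs, but it does
    # depend on the batch size the schedule was computed for.
    return batch_size + 1

num_reqs = 4          # real requests in this step
virtual_batch = 10    # chunked local-attention sequences for llama4

global_meta = torch.empty(metadata_size(num_reqs), dtype=torch.int32)

# Reusing metadata sized for num_reqs on the local path is exactly the
# mismatch the kernel rejects with:
#   RuntimeError: scheduler_metadata must have shape (metadata_size)
assert global_meta.numel() != metadata_size(virtual_batch)
```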
```diff
@@ -105,6 +105,7 @@ class FlashAttentionMetadata:
         local_block_table: torch.Tensor
         local_max_query_len: int
         local_max_seq_len: int
+        local_scheduler_metadata: Optional[torch.Tensor]
 
     local_attn_metadata: Optional[LocalAttentionMetadata] = None
```
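For orientation, here is the nested dataclass as it stands after this hunk, reconstructed only from the field names visible in this diff (a sketch, not the verbatim source):

```python
from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class LocalAttentionMetadata:
    local_query_start_loc: torch.Tensor  # cu_seqlens over the virtual batch
    local_seqused_k: torch.Tensor        # per-virtual-sequence KV lengths
    local_block_table: torch.Tensor
    local_max_query_len: int
    local_max_seq_len: int
    local_scheduler_metadata: Optional[torch.Tensor]  # added by this commit
```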
```diff
@@ -282,7 +283,9 @@ class FlashAttentionMetadataBuilder:
         self.runner = runner
         self.aot_schedule = (get_flash_attn_version() == 3)
-        self.num_heads = model_config.get_num_attention_heads(
+        self.num_heads_q = model_config.get_num_attention_heads(
+            runner.parallel_config)
+        self.num_heads_kv = model_config.get_num_kv_heads(
             runner.parallel_config)
         self.headdim = model_config.get_head_size()
         self.page_size = self.runner.block_size
```
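Splitting `self.num_heads` into `num_heads_q` and `num_heads_kv` matters for grouped-query attention (GQA) models such as llama4: the removed `schedule()` further down passed `num_heads` for both counts, which misstates the KV head count whenever the two differ. A sketch with hypothetical head counts, assuming the two accessors divide by tensor-parallel size as their names suggest:

```python
# Hypothetical GQA configuration, not llama4's real numbers.
total_q_heads, total_kv_heads, tp_size = 40, 8, 2

num_heads_q = total_q_heads // tp_size             # 20 query heads per rank
num_heads_kv = max(1, total_kv_heads // tp_size)   # 4 KV heads per rank

# Under GQA these differ, so the scheduler must be told both counts.
assert num_heads_q != num_heads_kv
```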
```diff
@@ -304,6 +307,23 @@ class FlashAttentionMetadataBuilder:
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             self.runner.device, non_blocking=True).long()
 
+        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
+                     max_seq_len, causal):
+            if self.aot_schedule:
+                return get_scheduler_metadata(
+                    batch_size=batch_size,
+                    max_seqlen_q=max_query_len,
+                    max_seqlen_k=max_seq_len,
+                    cache_seqlens=seqlens,
+                    num_heads_q=self.num_heads_q,
+                    num_heads_kv=self.num_heads_kv,
+                    headdim=self.headdim,
+                    page_size=self.page_size,
+                    cu_seqlens_q=cu_query_lens,
+                    causal=causal,
+                )
+            return None
+
         # for local attention
         local_attn_metadata = None
         if self.runner.attention_chunk_size is not None:
```
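The helper degrades gracefully: on pre-FA3 builds `self.aot_schedule` is `False` and it returns `None`, which callers pass straight through. A usage sketch with the same keyword set the later hunks use:

```python
scheduler_metadata = schedule(batch_size=num_reqs,
                              cu_query_lens=query_start_loc,
                              max_query_len=max_query_len,
                              seqlens=seq_lens,
                              max_seq_len=max_seq_len,
                              causal=True)
# None on FA2; flash_attn_varlen_func then runs without an AOT schedule.
```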
```diff
@@ -315,36 +335,31 @@ class FlashAttentionMetadataBuilder:
                 block_table,
                 self.runner.block_size,
             )
+            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
+                self.runner.device, non_blocking=True)
+            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
+                self.runner.device, non_blocking=True)
+            local_max_query_len = seqlens_q_local_np.max()
+            local_max_seq_len = virt_k_seqlens_np.max()
+            local_scheduler_metadata = schedule(
+                batch_size=local_query_start_loc.shape[0] - 1,
+                cu_query_lens=local_query_start_loc,
+                max_query_len=local_max_query_len,
+                seqlens=local_seqused_k,
+                max_seq_len=local_max_seq_len,
+                causal=True)
+
             local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
-                local_query_start_loc=torch.from_numpy(
-                    virt_q_cu_seqlens_np).to(self.runner.device,
-                                             non_blocking=True),
-                local_seqused_k=torch.from_numpy(virt_k_seqlens_np).to(
-                    self.runner.device, non_blocking=True),
+                local_query_start_loc=local_query_start_loc,
+                local_seqused_k=local_seqused_k,
                 local_block_table=virt_block_table,
-                local_max_query_len=seqlens_q_local_np.max(),
-                local_max_seq_len=virt_k_seqlens_np.max(),
+                local_max_query_len=local_max_query_len,
+                local_max_seq_len=local_max_seq_len,
+                local_scheduler_metadata=local_scheduler_metadata,
             )
 
         use_cascade = common_prefix_len > 0
 
-        def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
-                     causal):
-            if self.aot_schedule:
-                return get_scheduler_metadata(
-                    batch_size=num_reqs,
-                    max_seqlen_q=max_query_len,
-                    max_seqlen_k=max_seq_len,
-                    cache_seqlens=seqlens,
-                    num_heads_q=self.num_heads,
-                    num_heads_kv=self.num_heads,
-                    headdim=self.headdim,
-                    page_size=self.page_size,
-                    cu_seqlens_q=cu_query_lens,
-                    causal=causal,
-                )
-            return None
-
         if use_cascade:
             cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                 dtype=torch.int32,
```
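The decisive line is `batch_size=local_query_start_loc.shape[0] - 1`: the local schedule is sized for the virtual batch of chunked sequences rather than for `num_reqs`. A small numeric illustration with made-up lengths:

```python
import numpy as np

# Two real requests, chunked for local attention into five virtual sequences.
virt_q_cu_seqlens_np = np.array([0, 3, 6, 9, 11, 12], dtype=np.int32)

num_reqs = 2
local_batch_size = len(virt_q_cu_seqlens_np) - 1   # 5 virtual sequences

# The counts differ, which is why the local path needs its own schedule.
assert local_batch_size != num_reqs
```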
```diff
@@ -357,12 +372,14 @@ class FlashAttentionMetadataBuilder:
             suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
                 self.runner.device)
             prefix_scheduler_metadata = schedule(
+                batch_size=num_reqs,
                 cu_query_lens=cu_prefix_query_lens,
                 max_query_len=num_actual_tokens,
                 seqlens=prefix_kv_lens,
                 max_seq_len=common_prefix_len,
                 causal=False)
-            scheduler_metadata = schedule(cu_query_lens=query_start_loc,
+            scheduler_metadata = schedule(batch_size=num_reqs,
+                                          cu_query_lens=query_start_loc,
                                           max_query_len=max_query_len,
                                           seqlens=suffix_kv_lens,
                                           max_seq_len=max_seq_len -
```
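For cascade attention the shared prefix is scheduled as a single non-causal pass over all tokens (note `cu_prefix_query_lens = [0, num_actual_tokens]` above), while the suffixes keep per-request boundaries and causal masking; both calls now name `batch_size=num_reqs` explicitly. A sketch of the two `cu_seqlens` layouts, with made-up values:

```python
import torch

num_actual_tokens = 12

# Prefix pass: one flattened span covering every token, causal=False.
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], dtype=torch.int32)

# Suffix pass: per-request boundaries (3 requests here), causal=True.
query_start_loc = torch.tensor([0, 5, 9, 12], dtype=torch.int32)
```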
```diff
@@ -373,7 +390,8 @@ class FlashAttentionMetadataBuilder:
             prefix_kv_lens = None
             suffix_kv_lens = None
             prefix_scheduler_metadata = None
-            scheduler_metadata = schedule(cu_query_lens=query_start_loc,
+            scheduler_metadata = schedule(batch_size=num_reqs,
+                                          cu_query_lens=query_start_loc,
                                           max_query_len=max_query_len,
                                           seqlens=seq_lens,
                                           max_seq_len=max_seq_len,
```
```diff
@@ -540,12 +558,14 @@ class FlashAttentionImpl(AttentionImpl):
                 max_seqlen_q = local_metadata.local_max_query_len
                 max_seqlen_k = local_metadata.local_max_seq_len
                 block_table = local_metadata.local_block_table
+                scheduler_metadata = local_metadata.local_scheduler_metadata
             else:
                 cu_seqlens_q = attn_metadata.query_start_loc
                 seqused_k = attn_metadata.seq_lens
                 max_seqlen_q = attn_metadata.max_query_len
                 max_seqlen_k = attn_metadata.max_seq_len
                 block_table = attn_metadata.block_table
+                scheduler_metadata = attn_metadata.scheduler_metadata
 
             descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
```
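Both branches now bind `scheduler_metadata` alongside the other kernel inputs, so the call site in the final hunk no longer reaches into `attn_metadata` directly. The invariant, restated as a hypothetical helper (`pick_scheduler_metadata` is illustrative, not vLLM API):

```python
from typing import Optional

import torch

def pick_scheduler_metadata(
    use_local_attn: bool,
    local_scheduler_metadata: Optional[torch.Tensor],
    scheduler_metadata: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    # The local path must never reuse the global schedule, whose metadata
    # was sized for num_reqs rather than the virtual local batch.
    if use_local_attn:
        return local_scheduler_metadata
    return scheduler_metadata
```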
```diff
@@ -564,7 +584,7 @@ class FlashAttentionImpl(AttentionImpl):
                 window_size=self.sliding_window,
                 block_table=block_table,
                 softcap=self.logits_soft_cap,
-                scheduler_metadata=attn_metadata.scheduler_metadata,
+                scheduler_metadata=scheduler_metadata,
                 fa_version=self.vllm_flash_attn_version,
                 q_descale=layer._q_scale.expand(descale_shape),
                 k_descale=layer._k_scale.expand(descale_shape),
```