Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 01:45:02 +08:00)
[Misc] Add max_seq_len to CommonAttentionMetadata (#23216)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
  parent 5efd6905bc
  commit d6d13bd49e
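For orientation, here is a hedged, simplified sketch of what this commit does: CommonAttentionMetadata gains a precomputed max_seq_len so that the attention backends can read a plain int instead of each recomputing it from seq_lens_cpu. The dataclass below is trimmed down and the helper is invented for illustration; the real definitions are in the hunks that follow.

# Simplified sketch only -- the real CommonAttentionMetadata has many more
# fields and is built by the GPU model runner (see the hunks below).
from dataclasses import dataclass
import torch

@dataclass
class CommonAttentionMetadata:
    num_actual_tokens: int
    max_query_len: int
    max_seq_len: int            # new field: longest sequence in the batch
    seq_lens: torch.Tensor      # per-request lengths, on device
    seq_lens_cpu: torch.Tensor  # same values, kept on the host

def make_metadata(seq_lens_cpu: torch.Tensor,
                  query_lens_cpu: torch.Tensor,
                  device: str = "cpu") -> CommonAttentionMetadata:
    # The producer computes the maximum once from the host copy; every
    # attention backend can then read an int instead of calling
    # seq_lens_cpu.max() itself.
    return CommonAttentionMetadata(
        num_actual_tokens=int(query_lens_cpu.sum()),
        max_query_len=int(query_lens_cpu.max()),
        max_seq_len=int(seq_lens_cpu.max()),
        seq_lens=seq_lens_cpu.to(device),
        seq_lens_cpu=seq_lens_cpu,
    )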
@@ -58,6 +58,7 @@ def create_common_attn_metadata(
                             dtype=torch.int32,
                             device=device)
     seq_lens_cpu = seq_lens.cpu()
+    max_seq_len = int(seq_lens_cpu.max())

     # Create computed tokens (context length for each sequence)
     context_lens = [
@@ -101,6 +102,7 @@ def create_common_attn_metadata(
         num_reqs=batch_spec.batch_size,
         num_actual_tokens=num_tokens,
         max_query_len=max_query_len,
+        max_seq_len=max_seq_len,
         block_table_tensor=block_table_tensor,
         slot_mapping=slot_mapping,
         causal=True,
@@ -50,6 +50,7 @@ def forward_attention(
         dtype=torch.int32,
     )
     context_lens = seq_lens - query_lens
+    max_seq_len = int(seq_lens.max())
     max_query_len = q_len
     num_actual_tokens = query_start_loc[-1]

@@ -81,6 +82,7 @@ def forward_attention(
         num_reqs=batch_size,
         num_actual_tokens=num_actual_tokens,
         max_query_len=max_query_len,
+        max_seq_len=max_seq_len,
         block_table_tensor=block_table,
         slot_mapping=slot_mapping,
     )
@@ -233,7 +233,7 @@ class FlashAttentionMetadataBuilder(
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+        max_seq_len = common_attn_metadata.max_seq_len
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
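This FlashAttention hunk and the FlashInfer, FlexAttention, AITER, TreeAttention, Triton, and XFormers hunks below all make the same one-line swap on the consumer side. A sketch of that pattern, with the surrounding builder code elided:

# Consumer-side pattern repeated across the backend metadata builders
# (illustrative only; everything except the max_seq_len line is elided).
def build(common_attn_metadata):
    # Before: each backend recomputed the maximum from the host tensor.
    #   max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
    # After: read the value the model runner computed once per step.
    max_seq_len = common_attn_metadata.max_seq_len
    return max_seq_len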
@@ -463,7 +463,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

         page_size = self.page_size
         max_q_len = common_attn_metadata.max_query_len
-        max_seq_len = common_attn_metadata.seq_lens_cpu.max().item()
+        max_seq_len = common_attn_metadata.max_seq_len
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor
@@ -305,7 +305,7 @@ class FlexAttentionMetadataBuilder(
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len

-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+        max_seq_len = common_attn_metadata.max_seq_len
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor
@@ -270,7 +270,7 @@ class AiterFlashAttentionMetadataBuilder(

         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+        max_seq_len = common_attn_metadata.max_seq_len
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor
@@ -205,7 +205,7 @@ class TreeAttentionMetadataBuilder(
         q_start_loc = common_attn_metadata.query_start_loc
         max_query_len = common_attn_metadata.max_query_len
         kv_seqlens = common_attn_metadata.seq_lens
-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+        max_seq_len = common_attn_metadata.max_seq_len
         block_table = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping

@@ -90,7 +90,7 @@ class TritonAttentionMetadataBuilder(
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len

-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+        max_seq_len = common_attn_metadata.max_seq_len
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor
@@ -58,6 +58,8 @@ class CommonAttentionMetadata:
     """Total number of tokens in batch"""
     max_query_len: int
     """Longest query in batch"""
+    max_seq_len: int
+    """Longest context length in batch"""

     block_table_tensor: torch.Tensor
     slot_mapping: torch.Tensor
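To make the two maxima easier to tell apart, here is a hedged illustration of how the new field relates to the existing max_query_len; the numbers are invented for the example and do not come from the commit.

# Illustrative numbers only.
context_lens = [100, 37, 0]   # tokens already in the KV cache per request
query_lens   = [1, 1, 8]      # new tokens scheduled for this step
seq_lens = [c + q for c, q in zip(context_lens, query_lens)]  # [101, 38, 8]

max_query_len = max(query_lens)  # 8   -> "Longest query in batch"
max_seq_len   = max(seq_lens)    # 101 -> the field added here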
@@ -107,6 +109,7 @@ def _make_metadata_with_slice(

     seq_lens = attn_metadata.seq_lens[request_slice]
     seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
+    max_seq_len = int(seq_lens_cpu.max())
     num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[
         request_slice]

@@ -128,6 +131,7 @@ def _make_metadata_with_slice(
         num_reqs=num_requests,
         num_actual_tokens=num_actual_tokens,
         max_query_len=max_query_len,
+        max_seq_len=max_seq_len,
         block_table_tensor=block_table_tensor,
         slot_mapping=slot_mapping,
     )
@@ -520,6 +524,7 @@ def make_local_attention_virtual_batches(

     query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
     seq_lens_cpu = torch.from_numpy(seqlens_k_local)
+    max_seq_len = int(seq_lens_cpu.max())

     return CommonAttentionMetadata(
         query_start_loc_cpu=query_start_loc_cpu,
@@ -531,6 +536,7 @@ def make_local_attention_virtual_batches(
         num_reqs=len(seq_lens_cpu),
         num_actual_tokens=common_attn_metadata.num_actual_tokens,
         max_query_len=seqlens_q_local.max(),
+        max_seq_len=max_seq_len,
         block_table_tensor=block_table_local,
         slot_mapping=common_attn_metadata.slot_mapping,
         causal=True,
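Note that both helpers above (_make_metadata_with_slice and make_local_attention_virtual_batches) recompute max_seq_len from their own seq_lens_cpu rather than copying the batch-wide value, since slicing or re-chunking the batch can lower the maximum. A toy illustration, with made-up values:

# Toy illustration of why sliced metadata recomputes max_seq_len.
import torch

seq_lens_cpu = torch.tensor([101, 38, 8], dtype=torch.int32)
request_slice = slice(1, 3)               # keep only the last two requests
sliced = seq_lens_cpu[request_slice]      # tensor([38, 8])
max_seq_len = int(sliced.max())           # 38, not the batch-wide 101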
@@ -231,7 +231,7 @@ class XFormersAttentionMetadataBuilder(
         q_seqlens = torch.diff(q_start_loc)
         max_query_len = common_attn_metadata.max_query_len
         kv_seqlens = common_attn_metadata.seq_lens
-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+        max_seq_len = common_attn_metadata.max_seq_len
         block_table = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping

@@ -582,6 +582,7 @@ class EagleProposer:
         num_reqs=common_attn_metadata.num_reqs,
         num_actual_tokens=total_num_tokens,
         max_query_len=new_query_len_per_req.max().item(),
+        max_seq_len=new_seq_lens_cpu.max().item(),
         block_table_tensor=common_attn_metadata.block_table_tensor,
         slot_mapping=common_attn_metadata.slot_mapping[token_indices],
         causal=True,
@@ -774,6 +774,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.seq_lens_np[num_reqs:].fill(0)
         self.seq_lens.copy_(self.seq_lens_cpu, non_blocking=True)
         seq_lens = self.seq_lens[:num_reqs]
+        max_seq_len = self.seq_lens_np[:num_reqs].max().item()

         # Copy the tensors to the GPU.
         self.input_ids[:total_num_scheduled_tokens].copy_(
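In the hunk above the model runner takes the maximum from self.seq_lens_np, the host-side array, rather than from the GPU tensor; presumably so the new field is a plain Python int and never forces a device synchronization. A minimal sketch of that idea, with simplified, hypothetical buffers rather than the runner's real state:

# Minimal sketch: taking the max on the host buffer yields a Python int
# without touching the GPU (buffer names are simplified for the example).
import numpy as np

seq_lens_np = np.zeros(8, dtype=np.int32)   # host-side per-request lengths
num_reqs = 3
seq_lens_np[:num_reqs] = [101, 38, 8]

max_seq_len = int(seq_lens_np[:num_reqs].max())   # -> 101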
@@ -886,6 +887,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_reqs=num_reqs,
         num_actual_tokens=total_num_scheduled_tokens,
         max_query_len=max_num_scheduled_tokens,
+        max_seq_len=max_seq_len,
         block_table_tensor=blk_table_tensor,
         slot_mapping=slot_mapping,
         causal=True,
@@ -2338,6 +2340,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_reqs=num_reqs,
         num_actual_tokens=num_tokens,
         max_query_len=max_query_len,
+        max_seq_len=self.max_model_len,
         block_table_tensor=self.input_batch.block_table[
             kv_cache_group_id].get_device_tensor()[:num_reqs],
         slot_mapping=self.input_batch.
@@ -3343,6 +3346,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_reqs=num_reqs,
         num_actual_tokens=total_num_scheduled_tokens,
         max_query_len=max_num_scheduled_tokens,
+        max_seq_len=self.seq_lens_cpu[:num_reqs].max().item(),
         block_table_tensor=dummy_block_table,
         slot_mapping=dummy_slot_mapping,
         causal=False,