diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py
index a4e38eb32f6a..e547e71e0cdb 100644
--- a/tests/v1/attention/utils.py
+++ b/tests/v1/attention/utils.py
@@ -58,6 +58,7 @@ def create_common_attn_metadata(
                              dtype=torch.int32,
                              device=device)
     seq_lens_cpu = seq_lens.cpu()
+    max_seq_len = int(seq_lens_cpu.max())

     # Create computed tokens (context length for each sequence)
     context_lens = [
@@ -101,6 +102,7 @@ def create_common_attn_metadata(
         num_reqs=batch_spec.batch_size,
         num_actual_tokens=num_tokens,
         max_query_len=max_query_len,
+        max_seq_len=max_seq_len,
         block_table_tensor=block_table_tensor,
         slot_mapping=slot_mapping,
         causal=True,
diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py
index 456ce712d36e..631781740866 100644
--- a/tests/v1/spec_decode/test_tree_attention.py
+++ b/tests/v1/spec_decode/test_tree_attention.py
@@ -50,6 +50,7 @@ def forward_attention(
         dtype=torch.int32,
     )
     context_lens = seq_lens - query_lens
+    max_seq_len = int(seq_lens.max())
     max_query_len = q_len
     num_actual_tokens = query_start_loc[-1]

@@ -81,6 +82,7 @@ def forward_attention(
         num_reqs=batch_size,
         num_actual_tokens=num_actual_tokens,
         max_query_len=max_query_len,
+        max_seq_len=max_seq_len,
         block_table_tensor=block_table,
         slot_mapping=slot_mapping,
     )
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index ab7a71a399b3..eed3cba9a2ca 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -233,7 +233,7 @@ class FlashAttentionMetadataBuilder(
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+        max_seq_len = common_attn_metadata.max_seq_len
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 53fafbc4af91..8a25088848a4 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -463,7 +463,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         page_size = self.page_size

         max_q_len = common_attn_metadata.max_query_len
-        max_seq_len = common_attn_metadata.seq_lens_cpu.max().item()
+        max_seq_len = common_attn_metadata.max_seq_len
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index e599411b2d7e..abca981035d9 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -305,7 +305,7 @@ class FlexAttentionMetadataBuilder(
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len

-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+        max_seq_len = common_attn_metadata.max_seq_len
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index 36b5853bfdcb..b9ff113573a1 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -270,7 +270,7 @@ class AiterFlashAttentionMetadataBuilder(
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len

-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+        max_seq_len = common_attn_metadata.max_seq_len
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor
diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py
index 5d10e9e26082..2a0c52377cc7 100644
--- a/vllm/v1/attention/backends/tree_attn.py
+++ b/vllm/v1/attention/backends/tree_attn.py
@@ -205,7 +205,7 @@ class TreeAttentionMetadataBuilder(
         q_start_loc = common_attn_metadata.query_start_loc
         max_query_len = common_attn_metadata.max_query_len
         kv_seqlens = common_attn_metadata.seq_lens
-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+        max_seq_len = common_attn_metadata.max_seq_len
         block_table = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping

diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 48a9af3decac..c69dd8415f92 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -90,7 +90,7 @@ class TritonAttentionMetadataBuilder(
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len

-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+        max_seq_len = common_attn_metadata.max_seq_len
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 94dd3d2629eb..57c4d436c5b6 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -58,6 +58,8 @@ class CommonAttentionMetadata:
     """Total number of tokens in batch"""
     max_query_len: int
     """Longest query in batch"""
+    max_seq_len: int
+    """Longest context length in batch"""

     block_table_tensor: torch.Tensor
     slot_mapping: torch.Tensor
@@ -107,6 +109,7 @@ def _make_metadata_with_slice(

     seq_lens = attn_metadata.seq_lens[request_slice]
     seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
+    max_seq_len = int(seq_lens_cpu.max())
     num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[
         request_slice]

@@ -128,6 +131,7 @@ def _make_metadata_with_slice(
         num_reqs=num_requests,
         num_actual_tokens=num_actual_tokens,
         max_query_len=max_query_len,
+        max_seq_len=max_seq_len,
         block_table_tensor=block_table_tensor,
         slot_mapping=slot_mapping,
     )
@@ -520,6 +524,7 @@ def make_local_attention_virtual_batches(

     query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
     seq_lens_cpu = torch.from_numpy(seqlens_k_local)
+    max_seq_len = int(seq_lens_cpu.max())

     return CommonAttentionMetadata(
         query_start_loc_cpu=query_start_loc_cpu,
@@ -531,6 +536,7 @@ def make_local_attention_virtual_batches(
         num_reqs=len(seq_lens_cpu),
         num_actual_tokens=common_attn_metadata.num_actual_tokens,
         max_query_len=seqlens_q_local.max(),
+        max_seq_len=max_seq_len,
         block_table_tensor=block_table_local,
         slot_mapping=common_attn_metadata.slot_mapping,
         causal=True,
diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py
index fe732c601770..b305bc153908 100644
--- a/vllm/v1/attention/backends/xformers.py
+++ b/vllm/v1/attention/backends/xformers.py
@@ -231,7 +231,7 @@ class XFormersAttentionMetadataBuilder(
         q_seqlens = torch.diff(q_start_loc)
         max_query_len = common_attn_metadata.max_query_len
         kv_seqlens = common_attn_metadata.seq_lens
-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+        max_seq_len = common_attn_metadata.max_seq_len
         block_table = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping

diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 8cd2ad12cfa3..cc2b2a139d5e 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -582,6 +582,7 @@ class EagleProposer:
             num_reqs=common_attn_metadata.num_reqs,
             num_actual_tokens=total_num_tokens,
             max_query_len=new_query_len_per_req.max().item(),
+            max_seq_len=new_seq_lens_cpu.max().item(),
             block_table_tensor=common_attn_metadata.block_table_tensor,
             slot_mapping=common_attn_metadata.slot_mapping[token_indices],
             causal=True,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index e0bab3367caf..d9770226b14e 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -774,6 +774,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.seq_lens_np[num_reqs:].fill(0)
         self.seq_lens.copy_(self.seq_lens_cpu, non_blocking=True)
         seq_lens = self.seq_lens[:num_reqs]
+        max_seq_len = self.seq_lens_np[:num_reqs].max().item()

         # Copy the tensors to the GPU.
         self.input_ids[:total_num_scheduled_tokens].copy_(
@@ -886,6 +887,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 num_reqs=num_reqs,
                 num_actual_tokens=total_num_scheduled_tokens,
                 max_query_len=max_num_scheduled_tokens,
+                max_seq_len=max_seq_len,
                 block_table_tensor=blk_table_tensor,
                 slot_mapping=slot_mapping,
                 causal=True,
@@ -2338,6 +2340,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 num_reqs=num_reqs,
                 num_actual_tokens=num_tokens,
                 max_query_len=max_query_len,
+                max_seq_len=self.max_model_len,
                 block_table_tensor=self.input_batch.block_table[
                     kv_cache_group_id].get_device_tensor()[:num_reqs],
                 slot_mapping=self.input_batch.
@@ -3343,6 +3346,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             num_reqs=num_reqs,
             num_actual_tokens=total_num_scheduled_tokens,
             max_query_len=max_num_scheduled_tokens,
+            max_seq_len=self.seq_lens_cpu[:num_reqs].max().item(),
             block_table_tensor=dummy_block_table,
             slot_mapping=dummy_slot_mapping,
             causal=False,
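
For illustration, a minimal sketch of the pattern this patch applies, using hypothetical stand-in classes rather than vLLM's real ones: max_seq_len is computed once, on the CPU, where the shared metadata is assembled, and each attention backend's metadata builder reads the cached integer instead of reducing seq_lens_cpu with .max() again.

# Minimal sketch (hypothetical names, not vLLM's actual classes): compute
# max_seq_len once at construction time and let every backend read it.
from dataclasses import dataclass

import torch


@dataclass
class CommonMetadataSketch:
    seq_lens_cpu: torch.Tensor  # per-request total sequence lengths, on CPU
    max_query_len: int
    max_seq_len: int            # new field: longest sequence in the batch


def build_common_metadata(seq_lens_cpu: torch.Tensor,
                          max_query_len: int) -> CommonMetadataSketch:
    # Single CPU-side reduction when the shared metadata is built.
    return CommonMetadataSketch(
        seq_lens_cpu=seq_lens_cpu,
        max_query_len=max_query_len,
        max_seq_len=int(seq_lens_cpu.max()),
    )


def backend_build(meta: CommonMetadataSketch) -> int:
    # Mirrors the one-line change in each *MetadataBuilder.build():
    # read the precomputed value instead of calling seq_lens_cpu.max().
    return meta.max_seq_len


if __name__ == "__main__":
    meta = build_common_metadata(
        torch.tensor([17, 4, 256], dtype=torch.int32), max_query_len=4)
    assert backend_build(meta) == 256

Computing the reduction once keeps the value consistent wherever the metadata is re-derived (sliced metadata, local-attention virtual batches) and removes the per-backend CPU-side .max() call from every build() shown above.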