diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 5e6bc331835b..94dd3d2629eb 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -464,8 +464,9 @@ def make_local_attention_virtual_batches( attn_chunk_size)[arange > 0] # convert from q_seqlens to cu_seqlens_q - cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\ - .astype(np.int32) + cu_seqlens_q_local = np.empty(virtual_batches + 1, dtype=np.int32) + np.cumsum(seqlens_q_local, out=cu_seqlens_q_local[1:]) + cu_seqlens_q_local[0] = 0 # compute the seqlens_k_local, # basically a full local attention block for all but the last block in each @@ -508,11 +509,10 @@ def make_local_attention_virtual_batches( # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) # ] - block_indices= np.broadcast_to( - np.arange(pages_per_local_batch, dtype=np.int32), - (virtual_batches, pages_per_local_batch)) \ - + np.expand_dims(block_starts, axis=1) - block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1) + block_indices = (block_starts[:, None] + + np.arange(pages_per_local_batch, dtype=np.int32)) + block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] - + 1) batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), local_blocks * pages_per_local_batch) block_table_local = block_table[batch_indices, block_indices]\