[Perf] Convert np array to torch tensor to index into block table for attn chunking (#24474)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
Yong Hoon Shin 2025-09-09 20:01:06 -07:00 committed by GitHub
parent b23fb78623
commit dc625ea6b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -542,7 +542,14 @@ def make_local_attention_virtual_batches(
1) 1)
batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
local_blocks * pages_per_local_batch) local_blocks * pages_per_local_batch)
block_table_local = block_table[batch_indices, block_indices]\
# NOTE: https://github.com/pytorch/pytorch/pull/160256 causes performance
# regression when using numpy arrays (batch and block indices) to index into
# torch tensor (block_table). As a workaround, convert numpy arrays to torch
# tensor first, which recovers perf.
batch_indices_torch = torch.from_numpy(batch_indices)
block_indices_torch = torch.from_numpy(block_indices)
block_table_local = block_table[batch_indices_torch, block_indices_torch]\
.view(virtual_batches, -1) .view(virtual_batches, -1)
query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local) query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)