[Misc] Minor refactoring for prepare_inputs (#23116)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon, 2025-08-18 16:58:05 -07:00, committed by GitHub
parent 498259ccce
commit 0dd3f4f5ab

@@ -757,10 +757,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Prepare the attention metadata.
         self.query_start_loc_np[0] = 0
         self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
+        # Note: pad query_start_loc to be non-decreasing, as kernels
+        # like FlashAttention requires that
+        self.query_start_loc_np[num_reqs + 1:].fill(cu_num_tokens[-1])
+        self.query_start_loc.copy_(self.query_start_loc_cpu, non_blocking=True)
+        query_start_loc = self.query_start_loc[:num_reqs + 1]
 
         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
             num_scheduled_tokens)
+        # Fill unused with 0 for full cuda graph mode.
+        self.seq_lens_np[num_reqs:].fill(0)
+        self.seq_lens.copy_(self.seq_lens_cpu, non_blocking=True)
+        seq_lens = self.seq_lens[:num_reqs]
 
         # Copy the tensors to the GPU.
         self.input_ids[:total_num_scheduled_tokens].copy_(
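
The padding added above is easiest to see with concrete numbers. A minimal NumPy sketch (sizes are made up; query_start_loc here stands in for the pinned-CPU query_start_loc_np buffer):

import numpy as np

# Hypothetical sizes: 3 active requests in a buffer sized for 5.
max_num_reqs, num_reqs = 5, 3
cu_num_tokens = np.array([4, 9, 11])  # cumulative scheduled tokens per request

query_start_loc = np.zeros(max_num_reqs + 1, dtype=np.int32)
query_start_loc[1:num_reqs + 1] = cu_num_tokens
# Pad the tail with the last cumulative value so the array stays
# non-decreasing. A zero tail ([0, 4, 9, 11, 0, 0]) would break kernels
# like FlashAttention; the padded entries just describe empty queries.
query_start_loc[num_reqs + 1:].fill(cu_num_tokens[-1])
print(query_start_loc)  # [ 0  4  9 11 11 11]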
@@ -776,22 +785,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.positions_cpu[:total_num_scheduled_tokens],
             non_blocking=True)
 
-        self.query_start_loc[:num_reqs + 1].copy_(
-            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
-        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
-                                       non_blocking=True)
-
-        # Fill unused with 0 for full cuda graph mode.
-        self.seq_lens[num_reqs:].fill_(0)
-        # Note: pad query_start_loc to be non-decreasing, as kernels
-        # like FlashAttention requires that
-        self.query_start_loc[num_reqs + 1:].fill_(
-            self.query_start_loc_cpu[num_reqs].item())
-
-        query_start_loc = self.query_start_loc[:num_reqs + 1]
-
-        spec_decode_common_attn_metadata = None
-
         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
         if not use_spec_decode:
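
The block removed here is the old GPU-side version of that padding: partial copies plus two fill_() kernel launches per step. The new code instead does all padding in the pinned CPU buffer and issues a single whole-buffer async copy. A rough sketch of the pattern (hypothetical names and sizes, assuming a CUDA device):

import torch

max_num_reqs = 5
seq_lens_cpu = torch.zeros(max_num_reqs, dtype=torch.int32, pin_memory=True)
seq_lens_np = seq_lens_cpu.numpy()  # NumPy view sharing the pinned storage
seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device="cuda")

num_reqs = 3
seq_lens_np[:num_reqs] = (10, 20, 30)  # fill the active slots on the CPU
seq_lens_np[num_reqs:] = 0             # pad the unused slots on the CPU too
# One async host-to-device copy of the whole buffer; no GPU-side fill_()
# kernels and no .item() on the hot path.
seq_lens.copy_(seq_lens_cpu, non_blocking=True)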
@@ -860,6 +853,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 per_layer_metadata[layer_name]
             attn_metadata[layer_name] = encoder_attn_metadata
 
+        # Used in the below loop.
+        query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1]
+        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
+        num_computed_tokens_cpu = (
+            self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
+        spec_decode_common_attn_metadata = None
+
         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
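
The hoisted slices are loop-invariant: every KV cache group gets exactly the same views. Hoisting avoids rebuilding identical view objects on each iteration, since slicing a PyTorch tensor allocates a new (cheap, but not free) view each time it is evaluated:

import torch

xs = torch.arange(10)
# Each evaluation of a slice builds a fresh view object...
assert xs[:5] is not xs[:5]
# ...so computing it once and reusing it inside the loop saves a little
# per-iteration work and keeps the loop body tidy.
head = xs[:5]
for _ in range(3):
    total = head.sum()  # stand-in for building per-group metadata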
@@ -874,12 +874,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
 
             common_attn_metadata = CommonAttentionMetadata(
-                query_start_loc=self.query_start_loc[:num_reqs + 1],
-                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
-                seq_lens=self.seq_lens[:num_reqs],
-                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
-                num_computed_tokens_cpu=self.input_batch.
-                num_computed_tokens_cpu_tensor[:num_reqs],
+                query_start_loc=query_start_loc,
+                query_start_loc_cpu=query_start_loc_cpu,
+                seq_lens=seq_lens,
+                seq_lens_cpu=seq_lens_cpu,
+                num_computed_tokens_cpu=num_computed_tokens_cpu,
                 num_reqs=num_reqs,
                 num_actual_tokens=total_num_scheduled_tokens,
                 max_query_len=max_num_scheduled_tokens,
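
With the invariants hoisted, the loop body reduces to passing the same precomputed views into each group's metadata. A condensed sketch of that structure (simplified dataclass, not the real CommonAttentionMetadata signature):

from dataclasses import dataclass

import torch

@dataclass
class AttnMeta:  # simplified stand-in for CommonAttentionMetadata
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor
    num_reqs: int

num_reqs = 3
query_start_loc = torch.tensor([0, 4, 9, 11])  # computed once, before the loop
seq_lens = torch.tensor([10, 20, 30])

attn_metadata = {}
for group_id in range(2):  # one shared metadata object per KV cache group
    attn_metadata[group_id] = AttnMeta(
        query_start_loc=query_start_loc,  # same views reused in every group
        seq_lens=seq_lens,
        num_reqs=num_reqs,
    )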