[BugFix] Fix new nightly failures (#29578)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Author: Lucas Wilkinson, 2025-11-27 16:45:38 -05:00 (committed by GitHub)
Parent: ae0ce1be27
Commit: be493e0b3c
2 changed files with 37 additions and 1 deletion


@@ -100,6 +100,32 @@ class CommonAttentionMetadata:
     dcp_local_seq_lens_cpu: torch.Tensor | None = None
     """Sequence lengths of the local rank in decode context parallelism world"""
 
+    # TODO(lucas): remove once we have FULL-CG spec-decode support
+    def unpadded(
+        self, num_actual_tokens: int, num_actual_reqs: int
+    ) -> "CommonAttentionMetadata":
+        maybe_slice_reqs = lambda x: x[:num_actual_reqs] if x is not None else None
+        return CommonAttentionMetadata(
+            query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
+            query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
+            seq_lens=self.seq_lens[:num_actual_reqs],
+            seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs],
+            num_computed_tokens_cpu=self.num_computed_tokens_cpu[:num_actual_reqs],
+            num_reqs=num_actual_reqs,
+            num_actual_tokens=num_actual_tokens,
+            max_query_len=self.max_query_len,
+            max_seq_len=self.max_seq_len,
+            block_table_tensor=self.block_table_tensor[:num_actual_reqs],
+            slot_mapping=self.slot_mapping[:num_actual_tokens],
+            causal=self.causal,
+            logits_indices_padded=self.logits_indices_padded,
+            num_logits_indices=self.num_logits_indices,
+            encoder_seq_lens=maybe_slice_reqs(self.encoder_seq_lens),
+            encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu),
+            dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens),
+            dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu),
+        )
+
 
 def slice_query_start_locs(
     query_start_loc: torch.Tensor,
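
The new `unpadded` helper simply re-slices every per-request and per-token field of the metadata back to the true batch size. A minimal standalone sketch of that slicing, using plain tensors rather than vLLM's actual dataclass (all names and sizes here are illustrative):

import torch

# Padded sizes as a full CUDA graph would see them vs. the real batch.
num_reqs_padded, num_reqs = 8, 5
num_tokens_padded, num_tokens = 64, 12

seq_lens = torch.zeros(num_reqs_padded, dtype=torch.int32)        # per request
slot_mapping = torch.zeros(num_tokens_padded, dtype=torch.int64)  # per token
query_start_loc = torch.zeros(num_reqs_padded + 1, dtype=torch.int32)

# Per-request fields keep the first num_reqs entries and per-token fields the
# first num_tokens entries; query_start_loc stores prefix offsets, so it needs
# one extra element (num_reqs + 1), matching the slices in the diff above.
assert seq_lens[:num_reqs].shape[0] == num_reqs
assert slot_mapping[:num_tokens].shape[0] == num_tokens
assert query_start_loc[: num_reqs + 1].shape[0] == num_reqs + 1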


@@ -1551,7 +1551,7 @@ class GPUModelRunner(
             # Encoder-only layers do not have KV cache, so we need to
             # create a dummy block table and slot mapping for them.
             blk_table_tensor = torch.zeros(
-                (num_tokens_padded, 1),
+                (num_reqs_padded, 1),
                 dtype=torch.int32,
                 device=self.device,
             )
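
The one-line fix above works because a block table is indexed per request, one row per request (consistent with `block_table_tensor[:num_actual_reqs]` in the first hunk), while the slot mapping is indexed per token; sizing the dummy table by the padded token count gave it the wrong leading dimension. A hedged sketch of that shape contract (names illustrative, device pinned to CPU for the example; the runner uses self.device):

import torch

num_reqs_padded, num_tokens_padded = 8, 64
device = torch.device("cpu")

blk_table_tensor = torch.zeros((num_reqs_padded, 1), dtype=torch.int32, device=device)
slot_mapping = torch.zeros(num_tokens_padded, dtype=torch.int64, device=device)

assert blk_table_tensor.shape == (num_reqs_padded, 1)  # one row per request
assert slot_mapping.shape == (num_tokens_padded,)      # one entry per token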
@@ -1652,6 +1652,16 @@ class GPUModelRunner(
             for layer_name in attn_group.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i
 
+        if spec_decode_common_attn_metadata is not None and (
+            num_reqs != num_reqs_padded or num_tokens != num_tokens_padded
+        ):
+            # Currently the drafter still only uses piecewise cudagraphs (and
+            # modifies the attention metadata in-place directly), and therefore
+            # does not want to use padded attention metadata.
+            spec_decode_common_attn_metadata = (
+                spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs)
+            )
+
         return attn_metadata, spec_decode_common_attn_metadata
 
     def _compute_cascade_attn_prefix_lens(
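
A condensed sketch of the control flow this hunk adds, assuming only the `unpadded` method from the first hunk (the function name `maybe_unpad` is illustrative, not part of vLLM):

def maybe_unpad(metadata, num_tokens, num_reqs, num_tokens_padded, num_reqs_padded):
    # Skip the slicing entirely when no CUDA-graph padding was applied.
    if metadata is not None and (
        num_reqs != num_reqs_padded or num_tokens != num_tokens_padded
    ):
        # The drafter mutates its attention metadata in-place, so it must
        # see tensors sized to the real batch rather than the padded one.
        metadata = metadata.unpadded(num_tokens, num_reqs)
    return metadata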