[BugFix] Fix new nightly failures (#29578)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
commit be493e0b3c
parent ae0ce1be27
@@ -100,6 +100,32 @@ class CommonAttentionMetadata:
     dcp_local_seq_lens_cpu: torch.Tensor | None = None
     """Sequence lengths of the local rank in decode context parallelism world"""
 
+    # TODO(lucas): remove once we have FULL-CG spec-decode support
+    def unpadded(
+        self, num_actual_tokens: int, num_actual_reqs: int
+    ) -> "CommonAttentionMetadata":
+        maybe_slice_reqs = lambda x: x[:num_actual_reqs] if x is not None else None
+        return CommonAttentionMetadata(
+            query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
+            query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
+            seq_lens=self.seq_lens[:num_actual_reqs],
+            seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs],
+            num_computed_tokens_cpu=self.num_computed_tokens_cpu[:num_actual_reqs],
+            num_reqs=num_actual_reqs,
+            num_actual_tokens=num_actual_tokens,
+            max_query_len=self.max_query_len,
+            max_seq_len=self.max_seq_len,
+            block_table_tensor=self.block_table_tensor[:num_actual_reqs],
+            slot_mapping=self.slot_mapping[:num_actual_tokens],
+            causal=self.causal,
+            logits_indices_padded=self.logits_indices_padded,
+            num_logits_indices=self.num_logits_indices,
+            encoder_seq_lens=maybe_slice_reqs(self.encoder_seq_lens),
+            encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu),
+            dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens),
+            dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu),
+        )
+
 
 def slice_query_start_locs(
     query_start_loc: torch.Tensor,
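For context, a minimal sketch of the slicing pattern unpadded() applies (the sizes and tensor values below are illustrative, not taken from vLLM): per-request tensors are cut to num_actual_reqs, per-token tensors to num_actual_tokens, and query_start_loc keeps one extra boundary entry.

import torch

# Illustrative padded batch: 4 padded requests / 10 padded tokens,
# of which only 3 requests / 7 tokens are real.
num_actual_reqs, num_actual_tokens = 3, 7

query_start_loc = torch.tensor([0, 3, 5, 7, 10])  # num_reqs + 1 boundaries
seq_lens = torch.tensor([3, 2, 2, 3])             # one entry per request
slot_mapping = torch.arange(10)                   # one entry per token

# The same slices unpadded() takes above:
print(query_start_loc[: num_actual_reqs + 1])  # tensor([0, 3, 5, 7])
print(seq_lens[:num_actual_reqs])              # tensor([3, 2, 2])
print(slot_mapping[:num_actual_tokens])        # tensor([0, 1, 2, 3, 4, 5, 6])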
@@ -1551,7 +1551,7 @@ class GPUModelRunner(
         # Encoder-only layers do not have KV cache, so we need to
         # create a dummy block table and slot mapping for them.
         blk_table_tensor = torch.zeros(
-            (num_tokens_padded, 1),
+            (num_reqs_padded, 1),
             dtype=torch.int32,
             device=self.device,
         )
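The fix above changes only the first dimension of the dummy block table: block tables are indexed per request (one row of block ids per request), not per token, so the dummy tensor needs num_reqs_padded rows. A minimal sketch with hypothetical sizes:

import torch

# Hypothetical sizes; in the runner these come from CUDA-graph padding.
num_reqs_padded, num_tokens_padded = 4, 16

# One row of (dummy) block ids per request, as in the fixed code.
blk_table_tensor = torch.zeros((num_reqs_padded, 1), dtype=torch.int32)

# Slot mapping, by contrast, is sized per token.
slot_mapping = torch.zeros(num_tokens_padded, dtype=torch.int64)

req_idx = 2
print(blk_table_tensor[req_idx])  # per-request row lookup: tensor([0], dtype=torch.int32)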
@@ -1652,6 +1652,16 @@ class GPUModelRunner(
             for layer_name in attn_group.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i
 
+        if spec_decode_common_attn_metadata is not None and (
+            num_reqs != num_reqs_padded or num_tokens != num_tokens_padded
+        ):
+            # Currently the drafter still only uses piecewise cudagraphs (and
+            # modifies the attention metadata in place), and therefore does not
+            # want to use padded attention metadata.
+            spec_decode_common_attn_metadata = (
+                spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs)
+            )
+
         return attn_metadata, spec_decode_common_attn_metadata
 
     def _compute_cascade_attn_prefix_lens(
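As a rough illustration of the flow this hunk sets up (using a hypothetical toy class and made-up sizes, not vLLM's actual API): the runner pads the batch to num_reqs_padded / num_tokens_padded for CUDA graphs, and only rebuilds the drafter's metadata when padding actually changed the sizes.

import torch
from dataclasses import dataclass

@dataclass
class ToyMetadata:  # hypothetical stand-in for CommonAttentionMetadata
    seq_lens: torch.Tensor      # per-request
    slot_mapping: torch.Tensor  # per-token

    def unpadded(self, num_tokens: int, num_reqs: int) -> "ToyMetadata":
        # Same idea as the real unpadded(): slice away the padding.
        return ToyMetadata(self.seq_lens[:num_reqs], self.slot_mapping[:num_tokens])

num_reqs, num_reqs_padded = 3, 4
num_tokens, num_tokens_padded = 7, 8
meta = ToyMetadata(torch.zeros(num_reqs_padded), torch.zeros(num_tokens_padded))

# Mirrors the guard in the hunk above: skip the rebuild when nothing was padded.
if num_reqs != num_reqs_padded or num_tokens != num_tokens_padded:
    meta = meta.unpadded(num_tokens, num_reqs)

assert meta.seq_lens.numel() == num_reqs
assert meta.slot_mapping.numel() == num_tokens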