[Spec Decode][V0] Fix spec decode correctness test in V0 eagle/medusa (#18175)

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
This commit is contained in:
wwl2755 2025-05-18 21:49:46 -05:00 committed by GitHub
parent d1211f8794
commit 9da1095daf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 21 additions and 3 deletions

View File

@ -178,8 +178,6 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
batch_size, output_len, seed)
# TRACKING: https://github.com/vllm-project/vllm/issues/18166
@pytest.mark.skip(reason="RE-ENABLE: Failing on main.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{

View File

@ -146,6 +146,17 @@ class EAGLE(nn.Module):
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
# Handle both empty previous_hidden_states
# and mismatched batch size
batch_size = inputs_embeds.size(0)
if previous_hidden_states.size(0) == 0 or \
previous_hidden_states.size(0) != batch_size:
hidden_dim = self.config.model.hidden_size
device = inputs_embeds.device
# Create zero tensor with matching batch size
previous_hidden_states = \
torch.zeros(batch_size, hidden_dim, device=device)
if self.add_para_norm:
inputs_embeds = torch.cat([
self.enorm(inputs_embeds),

View File

@ -164,7 +164,14 @@ class Medusa(nn.Module):
self,
previous_hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> list[SamplerOutput]:
) -> Optional[list[SamplerOutput]]:
# During preemption, we may receive an empty tensor (batch_size=0)
if previous_hidden_states.size(0) == 0:
# Return None to signal the Top1Proposer that no proposals
# were generated for this batch, allowing it to handle this
# special case appropriately
return None
return self.sample(
logits=self.compute_logits(
hidden_states=self.forward(previous_hidden_states),

View File

@ -1330,6 +1330,8 @@ class HiddenStates(msgspec.Struct, array_like=True,
# may be "paused" then "resumed" later. This should only prune sequences
# which are confirmed to be aborted.
seq_ids = get_all_seq_ids(seq_group_metadata_list)
# Only keep sequence IDs that exist in self._seq_ids
seq_ids = [seq_id for seq_id in seq_ids if seq_id in self._seq_ids]
if seq_ids != self._seq_ids:
# Batch contents changed - prune removed sequences.
index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]