From 9da1095daf1710a41da3f79f7954f5dee956ce62 Mon Sep 17 00:00:00 2001
From: wwl2755
Date: Sun, 18 May 2025 21:49:46 -0500
Subject: [PATCH] [Spec Decode][V0] Fix spec decode correctness test in V0
 eagle/medusa (#18175)

Signed-off-by: wwl2755
---
 tests/spec_decode/e2e/test_eagle_correctness.py |  2 --
 vllm/model_executor/models/eagle.py             | 11 +++++++++++
 vllm/model_executor/models/medusa.py            |  9 ++++++++-
 vllm/sequence.py                                |  2 ++
 4 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/tests/spec_decode/e2e/test_eagle_correctness.py b/tests/spec_decode/e2e/test_eagle_correctness.py
index 2814bb6d3773..eee535a146f4 100644
--- a/tests/spec_decode/e2e/test_eagle_correctness.py
+++ b/tests/spec_decode/e2e/test_eagle_correctness.py
@@ -178,8 +178,6 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
                                   batch_size, output_len, seed)
 
 
-# TRACKING: https://github.com/vllm-project/vllm/issues/18166
-@pytest.mark.skip(reason="RE-ENABLE: Failing on main.")
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py
index 726660796a6f..fb1675d29915 100644
--- a/vllm/model_executor/models/eagle.py
+++ b/vllm/model_executor/models/eagle.py
@@ -146,6 +146,17 @@ class EAGLE(nn.Module):
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings(input_ids)
 
+        # Handle both empty previous_hidden_states
+        # and mismatched batch size
+        batch_size = inputs_embeds.size(0)
+        if previous_hidden_states.size(0) == 0 or \
+            previous_hidden_states.size(0) != batch_size:
+            hidden_dim = self.config.model.hidden_size
+            device = inputs_embeds.device
+            # Create zero tensor with matching batch size
+            previous_hidden_states = \
+                torch.zeros(batch_size, hidden_dim, device=device)
+
         if self.add_para_norm:
             inputs_embeds = torch.cat([
                 self.enorm(inputs_embeds),
diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py
index 4724cbe56445..588bcb628f8c 100644
--- a/vllm/model_executor/models/medusa.py
+++ b/vllm/model_executor/models/medusa.py
@@ -164,7 +164,14 @@ class Medusa(nn.Module):
         self,
         previous_hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> list[SamplerOutput]:
+    ) -> Optional[list[SamplerOutput]]:
+        # During preemption, we may receive an empty tensor (batch_size=0)
+        if previous_hidden_states.size(0) == 0:
+            # Return None to signal the Top1Proposer that no proposals
+            # were generated for this batch, allowing it to handle this
+            # special case appropriately
+            return None
+
         return self.sample(
             logits=self.compute_logits(
                 hidden_states=self.forward(previous_hidden_states),
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 91f769d6dbd9..5aa9ae62f542 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -1330,6 +1330,8 @@ class HiddenStates(msgspec.Struct, array_like=True,
         # may be "paused" then "resumed" later. This should only prune sequences
         # which are confirmed to be aborted.
         seq_ids = get_all_seq_ids(seq_group_metadata_list)
+        # Only keep sequence IDs that exist in self._seq_ids
+        seq_ids = [seq_id for seq_id in seq_ids if seq_id in self._seq_ids]
         if seq_ids != self._seq_ids:
             # Batch contents changed - prune removed sequences.
             index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
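
For reference, the EAGLE hunk boils down to the guard sketched below. This is a
minimal, standalone illustration of the same logic, not code from the vLLM tree:
the helper name pad_previous_hidden_states and the toy dimensions are assumed
for the example only.

import torch


def pad_previous_hidden_states(previous_hidden_states: torch.Tensor,
                               inputs_embeds: torch.Tensor,
                               hidden_dim: int) -> torch.Tensor:
    """Fall back to zero hidden states when the draft input is empty or its
    batch size disagrees with inputs_embeds (e.g. after preemption)."""
    batch_size = inputs_embeds.size(0)
    if previous_hidden_states.size(0) == 0 or \
            previous_hidden_states.size(0) != batch_size:
        # Zero placeholders keep the shapes consistent so the draft model
        # can still run for the current batch.
        return torch.zeros(batch_size,
                           hidden_dim,
                           device=inputs_embeds.device)
    return previous_hidden_states


# Example: an empty (0, 32) tensor is replaced by zeros of shape (4, 32),
# matching the batch size of inputs_embeds.
embeds = torch.randn(4, 16)
prev = torch.empty(0, 32)
print(pad_previous_hidden_states(prev, embeds, hidden_dim=32).shape)
# torch.Size([4, 32])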