mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:15:01 +08:00
[Spec Decode][V0] Fix spec decode correctness test in V0 eagle/medusa (#18175)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
This commit is contained in:
parent
d1211f8794
commit
9da1095daf
@ -178,8 +178,6 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
# TRACKING: https://github.com/vllm-project/vllm/issues/18166
|
||||
@pytest.mark.skip(reason="RE-ENABLE: Failing on main.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
|
||||
@ -146,6 +146,17 @@ class EAGLE(nn.Module):
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings(input_ids)
|
||||
|
||||
# Handle both empty previous_hidden_states
|
||||
# and mismatched batch size
|
||||
batch_size = inputs_embeds.size(0)
|
||||
if previous_hidden_states.size(0) == 0 or \
|
||||
previous_hidden_states.size(0) != batch_size:
|
||||
hidden_dim = self.config.model.hidden_size
|
||||
device = inputs_embeds.device
|
||||
# Create zero tensor with matching batch size
|
||||
previous_hidden_states = \
|
||||
torch.zeros(batch_size, hidden_dim, device=device)
|
||||
|
||||
if self.add_para_norm:
|
||||
inputs_embeds = torch.cat([
|
||||
self.enorm(inputs_embeds),
|
||||
|
||||
@ -164,7 +164,14 @@ class Medusa(nn.Module):
|
||||
self,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> list[SamplerOutput]:
|
||||
) -> Optional[list[SamplerOutput]]:
|
||||
# During preemption, we may receive an empty tensor (batch_size=0)
|
||||
if previous_hidden_states.size(0) == 0:
|
||||
# Return None to signal the Top1Proposer that no proposals
|
||||
# were generated for this batch, allowing it to handle this
|
||||
# special case appropriately
|
||||
return None
|
||||
|
||||
return self.sample(
|
||||
logits=self.compute_logits(
|
||||
hidden_states=self.forward(previous_hidden_states),
|
||||
|
||||
@ -1330,6 +1330,8 @@ class HiddenStates(msgspec.Struct, array_like=True,
|
||||
# may be "paused" then "resumed" later. This should only prune sequences
|
||||
# which are confirmed to be aborted.
|
||||
seq_ids = get_all_seq_ids(seq_group_metadata_list)
|
||||
# Only keep sequence IDs that exist in self._seq_ids
|
||||
seq_ids = [seq_id for seq_id in seq_ids if seq_id in self._seq_ids]
|
||||
if seq_ids != self._seq_ids:
|
||||
# Batch contents changed - prune removed sequences.
|
||||
index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user