[Spec Decode][V0] Fix spec decode correctness test in V0 eagle/medusa (#18175)

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>

commit 9da1095daf
parent d1211f8794
@@ -178,8 +178,6 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
                                  batch_size, output_len, seed)


-# TRACKING: https://github.com/vllm-project/vllm/issues/18166
-@pytest.mark.skip(reason="RE-ENABLE: Failing on main.")
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
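For context, removing the @pytest.mark.skip decorator is what re-enables the test: pytest collects the function again and executes it instead of reporting it as skipped. A minimal sketch of that mechanism, with hypothetical test names:

import pytest

@pytest.mark.skip(reason="RE-ENABLE: Failing on main.")
def test_disabled():
    # Reported as skipped; the body never runs.
    assert False

def test_enabled():
    # With the decorator removed, pytest collects and runs the test.
    assert 1 + 1 == 2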
@@ -146,6 +146,17 @@ class EAGLE(nn.Module):
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings(input_ids)

+        # Handle both empty previous_hidden_states
+        # and mismatched batch size
+        batch_size = inputs_embeds.size(0)
+        if previous_hidden_states.size(0) == 0 or \
+            previous_hidden_states.size(0) != batch_size:
+            hidden_dim = self.config.model.hidden_size
+            device = inputs_embeds.device
+            # Create zero tensor with matching batch size
+            previous_hidden_states = \
+                torch.zeros(batch_size, hidden_dim, device=device)
+
         if self.add_para_norm:
             inputs_embeds = torch.cat([
                 self.enorm(inputs_embeds),
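The guard added above can be exercised in isolation. A minimal sketch of the padding behavior, assuming a hidden size of 4096 and using plain tensors in place of the model's real inputs (the helper name is hypothetical):

import torch

def pad_previous_hidden_states(previous_hidden_states: torch.Tensor,
                               inputs_embeds: torch.Tensor,
                               hidden_dim: int) -> torch.Tensor:
    # Mirrors the hunk above: if the previous hidden states are empty or
    # their batch size disagrees with the current embeddings, replace them
    # with a zero tensor of the matching batch size.
    batch_size = inputs_embeds.size(0)
    if previous_hidden_states.size(0) == 0 or \
            previous_hidden_states.size(0) != batch_size:
        previous_hidden_states = torch.zeros(
            batch_size, hidden_dim, device=inputs_embeds.device)
    return previous_hidden_states

# Example: a preempted batch arrives with no previous hidden states.
embeds = torch.randn(3, 4096)
empty = torch.empty(0, 4096)
print(pad_previous_hidden_states(empty, embeds, 4096).shape)  # torch.Size([3, 4096])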
@@ -164,7 +164,14 @@ class Medusa(nn.Module):
         self,
         previous_hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> list[SamplerOutput]:
+    ) -> Optional[list[SamplerOutput]]:
+        # During preemption, we may receive an empty tensor (batch_size=0)
+        if previous_hidden_states.size(0) == 0:
+            # Return None to signal the Top1Proposer that no proposals
+            # were generated for this batch, allowing it to handle this
+            # special case appropriately
+            return None
+
         return self.sample(
             logits=self.compute_logits(
                 hidden_states=self.forward(previous_hidden_states),
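Note that the early return changes the method's contract from list[SamplerOutput] to Optional[list[SamplerOutput]], so the caller (here, the Top1Proposer) must treat None as "no proposals". A standalone sketch of that caller-side handling, with hypothetical names standing in for the real proposer:

from typing import Optional

import torch

def generate_proposals(previous_hidden_states: torch.Tensor) -> Optional[list]:
    # Mirrors the hunk above: an empty batch (e.g. after preemption)
    # yields None instead of running the Medusa heads.
    if previous_hidden_states.size(0) == 0:
        return None
    return ["<SamplerOutput per head>"]  # stand-in for real sampling

proposals = generate_proposals(torch.empty(0, 4096))
if proposals is None:
    # The caller falls back to plain decoding when no proposals exist.
    print("no proposals for this batch")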
@@ -1330,6 +1330,8 @@ class HiddenStates(msgspec.Struct, array_like=True,
         # may be "paused" then "resumed" later. This should only prune sequences
         # which are confirmed to be aborted.
         seq_ids = get_all_seq_ids(seq_group_metadata_list)
+        # Only keep sequence IDs that exist in self._seq_ids
+        seq_ids = [seq_id for seq_id in seq_ids if seq_id in self._seq_ids]
         if seq_ids != self._seq_ids:
             # Batch contents changed - prune removed sequences.
             index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
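The two added lines make the prune step tolerant of scheduled sequence IDs that were never recorded in self._seq_ids (for example, sequences resumed after preemption). A small sketch of the filtering, with made-up ID lists:

# The current schedule may contain IDs that HiddenStates never tracked.
scheduled_seq_ids = [0, 1, 2, 3]
tracked_seq_ids = [0, 2]          # stand-in for self._seq_ids

# Keep only IDs present in the tracked list, as in the hunk above.
seq_ids = [seq_id for seq_id in scheduled_seq_ids if seq_id in tracked_seq_ids]
assert seq_ids == [0, 2]

# Pruning can then look up positions without the ValueError that
# tracked_seq_ids.index(seq_id) would raise for an untracked ID.
index = [tracked_seq_ids.index(seq_id) for seq_id in seq_ids]
assert index == [0, 1]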