[V1][Spec Decode] Optimize Medusa proposer to avoid GPU-CPU sync (#29723)

Signed-off-by: dongbo910220 <1275604947@qq.com>
This commit is contained in:
dongbo910220 2025-12-10 08:15:01 +08:00 committed by GitHub
parent 2e7054da06
commit 03b5f940fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -38,16 +38,16 @@ class MedusaProposer:
self,
target_hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> list[list[int]]:
) -> torch.Tensor:
# Generate blocks and compute logits
blocks = self.model(target_hidden_states)
logits = self.model.compute_logits(blocks)
# Get draft tokens and transpose the result
# TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU
# synchronization.
draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]
return [list(row) for row in zip(*draft_tokens)]
# Compute argmax for each Medusa head and stack into a single tensor
# Shape: [batch_size, num_heads]
draft_tokens = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1)
return draft_tokens
def load_model(self, target_model: nn.Module) -> None:
from vllm.compilation.backends import set_model_tag