Mirror of https://git.datalinker.icu/vllm-project/vllm.git — synced 2026-01-01 11:55:20 +08:00
[V1][Spec Decode] Optimize Medusa proposer to avoid GPU-CPU sync (#29723)
Signed-off-by: dongbo910220 <1275604947@qq.com>
This commit is contained in:
parent
2e7054da06
commit
03b5f940fd
@@ -38,16 +38,16 @@ class MedusaProposer:
|
||||
self,
|
||||
target_hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> list[list[int]]:
|
||||
) -> torch.Tensor:
|
||||
# Generate blocks and compute logits
|
||||
blocks = self.model(target_hidden_states)
|
||||
logits = self.model.compute_logits(blocks)
|
||||
|
||||
# Get draft tokens and transpose the result
|
||||
# TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU
|
||||
# synchronization.
|
||||
draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]
|
||||
return [list(row) for row in zip(*draft_tokens)]
|
||||
# Compute argmax for each Medusa head and stack into a single tensor
|
||||
# Shape: [batch_size, num_heads]
|
||||
draft_tokens = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1)
|
||||
|
||||
return draft_tokens
|
||||
|
||||
def load_model(self, target_model: nn.Module) -> None:
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user