[Model] vLLM v1 supports Medusa (#17956)

Signed-off-by: lisiqi23 <lisiqi23@xiaomi.com>
Signed-off-by: skylee-01 <497627264@qq.com>
Co-authored-by: lisiqi23 <lisiqi23@xiaomi.com>
Sky Lee 2025-05-16 12:05:31 +08:00 committed by GitHub
parent ee659e3b60
commit f4937a51c1
4 changed files with 108 additions and 2 deletions

View File: vllm/engine/arg_utils.py

@@ -1324,19 +1324,22 @@ class EngineArgs:
         # Only Ngram speculative decoding so far.
         is_ngram_enabled = False
         is_eagle_enabled = False
+        is_medusa_enabled = False
         if self.speculative_config is not None:
             # This is supported but experimental (handled below).
             speculative_method = self.speculative_config.get("method")
             if speculative_method:
                 if speculative_method in ("ngram", "[ngram]"):
                     is_ngram_enabled = True
+                elif speculative_method == "medusa":
+                    is_medusa_enabled = True
                 elif speculative_method in ("eagle", "eagle3"):
                     is_eagle_enabled = True
             else:
                 speculative_model = self.speculative_config.get("model")
                 if speculative_model in ("ngram", "[ngram]"):
                     is_ngram_enabled = True
-        if not (is_ngram_enabled or is_eagle_enabled):
+        if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled):
             # Other speculative decoding methods are not supported yet.
             _raise_or_fallback(feature_name="Speculative Decoding",
                                recommend_to_remove=False)
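The V1 oracle above only whitelists `method == "medusa"`; the draft model itself is still supplied through the `speculative_config` dict. A minimal usage sketch (the model names and `num_speculative_tokens` value below are illustrative assumptions, not part of this commit):

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",   # hypothetical target model
        speculative_config={
            "method": "medusa",                # routes V1 to the new MedusaProposer
            "model": "path/to/medusa-heads",   # checkpoint holding the Medusa heads
            "num_speculative_tokens": 3,       # typically one token per Medusa head
        },
    )
    out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
    print(out[0].outputs[0].text)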

View File: vllm/model_executor/models/medusa.py

@@ -51,7 +51,10 @@ class Medusa(nn.Module):
     needs to have truncated_vocab_size (=k) as an attribute."""

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
-        config = vllm_config.model_config.hf_config
+        if hasattr(vllm_config, 'draft_model_config'):
+            config = vllm_config.draft_model_config.hf_config
+        else:
+            config = vllm_config.model_config.hf_config
         super().__init__()
         self.config = config
         self.blocks = nn.ModuleList([
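The `hasattr` branch exists because the V1 proposer (new file below) instantiates the heads with `Medusa(vllm_config=self.vllm_config.speculative_config)`: a `SpeculativeConfig` exposes `draft_model_config` while a full `VllmConfig` does not, so the module reads its HF config from whichever object it was handed. A minimal sketch of that selection logic with dummy stand-ins (illustrative only, not the real vLLM classes):

    from types import SimpleNamespace

    def pick_hf_config(cfg):
        # Prefer the draft model's HF config when given a SpeculativeConfig-like object.
        if hasattr(cfg, "draft_model_config"):
            return cfg.draft_model_config.hf_config
        return cfg.model_config.hf_config

    full_cfg = SimpleNamespace(model_config=SimpleNamespace(hf_config="target-hf-config"))
    spec_cfg = SimpleNamespace(draft_model_config=SimpleNamespace(hf_config="medusa-hf-config"))

    assert pick_hf_config(full_cfg) == "target-hf-config"
    assert pick_hf_config(spec_cfg) == "medusa-hf-config"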

View File: vllm/v1/spec_decode/medusa.py (new file)

@@ -0,0 +1,74 @@
+# SPDX-License-Identifier: Apache-2.0
+import torch
+import torch.nn as nn
+
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.forward_context import set_forward_context
+from vllm.logger import init_logger
+from vllm.model_executor.model_loader import get_model_loader
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.models.medusa import Medusa
+from vllm.v1.sample.metadata import SamplingMetadata
+
+# Initialize logger
+logger = init_logger(__name__)
+
+
+class MedusaProposer:
+    """
+    Medusa proposer class for generating token sequences
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        device: torch.device,
+    ):
+        # Save config parameters
+        self.vllm_config = vllm_config
+        self.device = device
+        self.max_num_tokens = (
+            vllm_config.scheduler_config.max_num_batched_tokens)
+        self.hidden_size = (
+            vllm_config.speculative_config.draft_model_config.get_hidden_size())
+        self.dtype = vllm_config.model_config.dtype
+
+    def propose(
+        self,
+        target_hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> list[list[int]]:
+        # Generate blocks and compute logits
+        blocks = self.model(target_hidden_states)
+        logits = self.model.compute_logits(blocks, None)
+
+        # Get draft tokens and transpose the result
+        draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]
+        return [list(row) for row in zip(*draft_tokens)]
+
+    def load_model(self, target_model: nn.Module) -> None:
+        # Get model loader and config
+        loader = get_model_loader(self.vllm_config.load_config)
+        draft_config = self.vllm_config.speculative_config.draft_model_config
+
+        # Load model with proper dtype and config
+        with set_default_torch_dtype(draft_config.dtype), \
+                set_current_vllm_config(self.vllm_config):
+            self.model = Medusa(
+                vllm_config=self.vllm_config.speculative_config).to(self.device)
+
+        # Load model weights
+        weights = loader.get_all_weights(draft_config, self.model)
+        self.model.load_weights(weights)
+
+    @torch.inference_mode()
+    def dummy_run(self, num_tokens: int) -> None:
+        hidden_states = torch.zeros((self.max_num_tokens, self.hidden_size),
+                                    dtype=self.dtype,
+                                    device=self.device)
+        with set_forward_context(None, self.vllm_config,
+                                 num_tokens=num_tokens):
+            self.model(hidden_states)
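For intuition on the `zip(*draft_tokens)` transpose in `propose()`: `compute_logits` returns one logits tensor per Medusa head, each covering the whole batch, while the model runner wants one draft-token list per request. A tiny worked example with plain lists (token values are illustrative):

    # head_tokens[h][r]: greedy token from head h for request r (2 requests, 3 heads).
    head_tokens = [
        [101, 201],  # head 0
        [102, 202],  # head 1
        [103, 203],  # head 2
    ]

    # Transpose to per-request drafts, one token per head in head order.
    per_request = [list(row) for row in zip(*head_tokens)]
    assert per_request == [[101, 102, 103], [201, 202, 203]]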

View File: vllm/v1/worker/gpu_model_runner.py

@@ -47,6 +47,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import RejectionSampler
 from vllm.v1.sample.sampler import Sampler
 from vllm.v1.spec_decode.eagle import EagleProposer
+from vllm.v1.spec_decode.medusa import MedusaProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.spec_decode.utils import is_spec_decode_supported
@@ -156,6 +157,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                              self.device)  # type: ignore
                 if self.speculative_config.method == "eagle3":
                     self.use_aux_hidden_state_outputs = True
+            elif self.speculative_config.method == "medusa":
+                self.drafter = MedusaProposer(
+                    vllm_config=self.vllm_config,
+                    device=self.device)  # type: ignore
             else:
                 raise ValueError("Unknown speculative decoding method: "
                                  f"{self.speculative_config.method}")
@@ -1254,6 +1259,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             assert isinstance(self.drafter, NgramProposer)
             spec_token_ids = self.generate_draft_token_ids(
                 valid_sampled_token_ids, sampling_metadata)
+        elif self.speculative_config.method == "medusa":
+            assert isinstance(self.drafter, MedusaProposer)
+            if max_gen_len == 1:
+                hidden_states = sample_hidden_states
+            else:
+                indices = []
+                offset = 0
+                for num_draft, tokens in zip(
+                        spec_decode_metadata.num_draft_tokens,
+                        valid_sampled_token_ids):
+                    indices.append(offset + len(tokens) - 1)
+                    offset += num_draft + 1
+                indices = torch.tensor(indices,
+                                       device=sample_hidden_states.device)
+                hidden_states = sample_hidden_states[indices]
+            spec_token_ids = self.drafter.propose(
+                target_hidden_states=hidden_states,
+                sampling_metadata=sampling_metadata,
+            )
         elif self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             # TODO(woosuk): Refactor the loop.
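The `indices` loop above selects, for each request, the hidden state of its last accepted token from the flattened `sample_hidden_states`, which holds `num_draft_tokens[i] + 1` rows per request when speculation ran. A small worked example of that bookkeeping (values are illustrative):

    # Request 0 had 3 draft tokens (4 sampled rows); request 1 had 2 (3 rows).
    num_draft_tokens = [3, 2]
    # Only the accepted tokens survive per request: 2 accepted, then 3 accepted.
    valid_sampled_token_ids = [[11, 12], [21, 22, 23]]

    indices, offset = [], 0
    for num_draft, tokens in zip(num_draft_tokens, valid_sampled_token_ids):
        indices.append(offset + len(tokens) - 1)  # row of the last accepted token
        offset += num_draft + 1                   # rows reserved for this request

    assert indices == [1, 6]  # request 0 -> row 1; request 1 -> row 4 + 2 = 6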