From f4937a51c138978928f38da6a2d3b30c53286240 Mon Sep 17 00:00:00 2001
From: Sky Lee <46676799+skylee-01@users.noreply.github.com>
Date: Fri, 16 May 2025 12:05:31 +0800
Subject: [PATCH] [Model] vLLM v1 supports Medusa (#17956)

Signed-off-by: lisiqi23
Signed-off-by: skylee-01 <497627264@qq.com>
Co-authored-by: lisiqi23
---
 vllm/engine/arg_utils.py             |  5 +-
 vllm/model_executor/models/medusa.py |  5 +-
 vllm/v1/spec_decode/medusa.py        | 74 ++++++++++++++++++++++++++++
 vllm/v1/worker/gpu_model_runner.py   | 26 ++++++++++
 4 files changed, 108 insertions(+), 2 deletions(-)
 create mode 100644 vllm/v1/spec_decode/medusa.py

diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 240142a1c5d1f..3e942b0f0ff9b 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1324,19 +1324,22 @@ class EngineArgs:
         # Only Ngram speculative decoding so far.
         is_ngram_enabled = False
         is_eagle_enabled = False
+        is_medusa_enabled = False
         if self.speculative_config is not None:
             # This is supported but experimental (handled below).
             speculative_method = self.speculative_config.get("method")
             if speculative_method:
                 if speculative_method in ("ngram", "[ngram]"):
                     is_ngram_enabled = True
+                elif speculative_method == "medusa":
+                    is_medusa_enabled = True
                 elif speculative_method in ("eagle", "eagle3"):
                     is_eagle_enabled = True
             else:
                 speculative_model = self.speculative_config.get("model")
                 if speculative_model in ("ngram", "[ngram]"):
                     is_ngram_enabled = True
-        if not (is_ngram_enabled or is_eagle_enabled):
+        if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled):
             # Other speculative decoding methods are not supported yet.
             _raise_or_fallback(feature_name="Speculative Decoding",
                                recommend_to_remove=False)
diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py
index ac0b281f359c3..4724cbe56445e 100644
--- a/vllm/model_executor/models/medusa.py
+++ b/vllm/model_executor/models/medusa.py
@@ -51,7 +51,10 @@ class Medusa(nn.Module):
     needs to have truncated_vocab_size (=k) as an attribute."""
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
-        config = vllm_config.model_config.hf_config
+        if hasattr(vllm_config, 'draft_model_config'):
+            config = vllm_config.draft_model_config.hf_config
+        else:
+            config = vllm_config.model_config.hf_config
         super().__init__()
         self.config = config
         self.blocks = nn.ModuleList([
diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py
new file mode 100644
index 0000000000000..14bc9c9e0d1a3
--- /dev/null
+++ b/vllm/v1/spec_decode/medusa.py
@@ -0,0 +1,74 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.nn as nn
+
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.forward_context import set_forward_context
+from vllm.logger import init_logger
+from vllm.model_executor.model_loader import get_model_loader
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.models.medusa import Medusa
+from vllm.v1.sample.metadata import SamplingMetadata
+
+# Initialize logger
+logger = init_logger(__name__)
+
+
+class MedusaProposer:
+    """
+    Proposer that drafts speculative tokens with Medusa heads.
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        device: torch.device,
+    ):
+        # Save config parameters
+        self.vllm_config = vllm_config
+        self.device = device
+        self.max_num_tokens = (
+            vllm_config.scheduler_config.max_num_batched_tokens)
+        self.hidden_size = vllm_config.speculative_config.\
+            draft_model_config.get_hidden_size(
+        )
+        self.dtype = vllm_config.model_config.dtype
+
+    def propose(
+        self,
+        target_hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> list[list[int]]:
+        # Run the Medusa heads and compute one set of logits per head.
+        blocks = self.model(target_hidden_states)
+        logits = self.model.compute_logits(blocks, None)
+
+        # Greedy draft token per head, transposed to per-request lists.
+        draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]
+        return [list(row) for row in zip(*draft_tokens)]
+
+    def load_model(self, target_model: nn.Module) -> None:
+        # Get model loader and config
+        loader = get_model_loader(self.vllm_config.load_config)
+        draft_config = self.vllm_config.speculative_config.draft_model_config
+
+        # Load model with proper dtype and config
+        with set_default_torch_dtype(draft_config.dtype), \
+                set_current_vllm_config(self.vllm_config):
+            self.model = Medusa(
+                vllm_config=self.vllm_config.speculative_config).to(
+                    self.device)
+
+        # Load model weights
+        weights = loader.get_all_weights(draft_config, self.model)
+        self.model.load_weights(weights)
+
+    @torch.inference_mode()
+    def dummy_run(self, num_tokens: int) -> None:
+        hidden_states = torch.zeros((self.max_num_tokens, self.hidden_size),
+                                    dtype=self.dtype,
+                                    device=self.device)
+        with set_forward_context(None, self.vllm_config,
+                                 num_tokens=num_tokens):
+            self.model(hidden_states)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 1b34a9fb06163..0788ac5adde8c 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -47,6 +47,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import RejectionSampler
 from vllm.v1.sample.sampler import Sampler
 from vllm.v1.spec_decode.eagle import EagleProposer
+from vllm.v1.spec_decode.medusa import MedusaProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.spec_decode.utils import is_spec_decode_supported
@@ -156,6 +157,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                                  self.device)  # type: ignore
                     if self.speculative_config.method == "eagle3":
                         self.use_aux_hidden_state_outputs = True
+                elif self.speculative_config.method == "medusa":
+                    self.drafter = MedusaProposer(
+                        vllm_config=self.vllm_config,
+                        device=self.device)  # type: ignore
                 else:
                     raise ValueError("Unknown speculative decoding method: "
                                      f"{self.speculative_config.method}")
@@ -1254,6 +1259,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             assert isinstance(self.drafter, NgramProposer)
             spec_token_ids = self.generate_draft_token_ids(
                 valid_sampled_token_ids, sampling_metadata)
+        elif self.speculative_config.method == "medusa":
+            assert isinstance(self.drafter, MedusaProposer)
+            if max_gen_len == 1:
+                hidden_states = sample_hidden_states
+            else:
+                indices = []
+                offset = 0
+                for num_draft, tokens in zip(
+                        spec_decode_metadata.num_draft_tokens,
+                        valid_sampled_token_ids):
+                    indices.append(offset + len(tokens) - 1)
+                    offset += num_draft + 1
+
+                indices = torch.tensor(indices,
+                                       device=sample_hidden_states.device)
+                hidden_states = sample_hidden_states[indices]
+
+            spec_token_ids = self.drafter.propose(
+                target_hidden_states=hidden_states,
+                sampling_metadata=sampling_metadata,
+            )
         elif self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             # TODO(woosuk): Refactor the loop.
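
Note on MedusaProposer.propose (an illustrative sketch, not part of the
diff): the final two lines transpose the per-head greedy picks into one
draft sequence per request. A self-contained toy version, assuming 3 heads,
a batch of 2 requests, and a vocab of 5:

    import torch

    # One [batch, vocab] logits tensor per Medusa head.
    logits = [torch.randn(2, 5) for _ in range(3)]

    # Greedy pick per head: 3 lists with 2 token ids each.
    draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]

    # Transpose: 2 lists with 3 token ids each, i.e. each request's
    # draft tokens ordered by head index.
    per_request = [list(row) for row in zip(*draft_tokens)]
    assert len(per_request) == 2 and len(per_request[0]) == 3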
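
Note on the new gpu_model_runner.py branch (an illustrative walk-through
with assumed toy values): sample_hidden_states is laid out with
num_draft + 1 rows per request, so the hidden state of a request's last
accepted token sits at row offset + len(tokens) - 1. In plain Python:

    # Request 0 had 2 draft tokens and accepted all 3 sampled tokens;
    # request 1 had 1 draft token and accepted only 1 token.
    num_draft_tokens = [2, 1]
    valid_sampled_token_ids = [[11, 12, 13], [21]]

    indices, offset = [], 0
    for num_draft, tokens in zip(num_draft_tokens, valid_sampled_token_ids):
        indices.append(offset + len(tokens) - 1)  # row of last accepted token
        offset += num_draft + 1                   # rows used by this request

    # Rows 0-2 belong to request 0 and rows 3-4 to request 1, so the last
    # accepted tokens sit at rows 2 and 3.
    assert indices == [2, 3]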
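
Usage sketch (illustrative, not part of the diff): with this patch applied,
Medusa can be enabled in the v1 engine through the speculative_config
engine argument, whose "method" and "model" keys are exactly what the
arg_utils.py change above inspects. The checkpoint names below are
placeholders, and num_speculative_tokens is assumed to match the number of
heads in the Medusa checkpoint:

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",  # placeholder base model
        speculative_config={
            "method": "medusa",
            "model": "org/medusa-heads-for-llama-7b",  # placeholder heads
            "num_speculative_tokens": 3,
        },
    )
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0.0, max_tokens=32))
    print(outputs[0].outputs[0].text)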