[Model] vLLM v1 supports Medusa (#17956)
Signed-off-by: lisiqi23 <lisiqi23@xiaomi.com>
Signed-off-by: skylee-01 <497627264@qq.com>
Co-authored-by: lisiqi23 <lisiqi23@xiaomi.com>

parent ee659e3b60
commit f4937a51c1
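With this change, the V1 engine no longer falls back to V0 when Medusa is requested as the speculative method. A rough usage sketch is shown below; the model names are placeholders and the `speculative_config` keys simply follow the convention already checked in the EngineArgs hunk below ("method", "model"), so treat it as an illustration rather than documented API.

from vllm import LLM

# Hypothetical example: target model plus a checkpoint holding the Medusa heads.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",        # placeholder target model
    speculative_config={
        "method": "medusa",                          # key checked by EngineArgs below
        "model": "path/to/medusa-heads-checkpoint",  # placeholder draft checkpoint
        "num_speculative_tokens": 3,                 # one token per Medusa head
    },
)
print(llm.generate("Hello, my name is")[0].outputs[0].text)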
@@ -1324,19 +1324,22 @@ class EngineArgs:
         # Only Ngram speculative decoding so far.
         is_ngram_enabled = False
         is_eagle_enabled = False
+        is_medusa_enabled = False
         if self.speculative_config is not None:
             # This is supported but experimental (handled below).
             speculative_method = self.speculative_config.get("method")
             if speculative_method:
                 if speculative_method in ("ngram", "[ngram]"):
                     is_ngram_enabled = True
+                elif speculative_method == "medusa":
+                    is_medusa_enabled = True
                 elif speculative_method in ("eagle", "eagle3"):
                     is_eagle_enabled = True
             else:
                 speculative_model = self.speculative_config.get("model")
                 if speculative_model in ("ngram", "[ngram]"):
                     is_ngram_enabled = True
-            if not (is_ngram_enabled or is_eagle_enabled):
+            if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled):
                 # Other speculative decoding methods are not supported yet.
                 _raise_or_fallback(feature_name="Speculative Decoding",
                                    recommend_to_remove=False)
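For a quick sanity check of which configurations now pass this gate, here is a standalone paraphrase of the logic above (plain Python, not vLLM code; "mlp_speculator" is used only as an example of a method that still falls back):

def v1_supports_spec_config(spec_config: dict | None) -> bool:
    """Mirror of the EngineArgs check above: ngram, eagle/eagle3 and medusa pass."""
    if spec_config is None:
        return True  # no speculative decoding requested
    method = spec_config.get("method")
    if method:
        return method in ("ngram", "[ngram]", "medusa", "eagle", "eagle3")
    # No explicit method: only the ngram shorthand via "model" is recognized.
    return spec_config.get("model") in ("ngram", "[ngram]")


assert v1_supports_spec_config({"method": "medusa"})               # newly accepted
assert not v1_supports_spec_config({"method": "mlp_speculator"})   # still falls back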
@@ -51,7 +51,10 @@ class Medusa(nn.Module):
     needs to have truncated_vocab_size (=k) as an attribute."""

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
-        config = vllm_config.model_config.hf_config
+        if hasattr(vllm_config, 'draft_model_config'):
+            config = vllm_config.draft_model_config.hf_config
+        else:
+            config = vllm_config.model_config.hf_config
         super().__init__()
         self.config = config
         self.blocks = nn.ModuleList([
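The `hasattr` branch exists because the new V1 proposer below constructs the heads with `Medusa(vllm_config=self.vllm_config.speculative_config)`, i.e. it hands in the speculative config (which carries `draft_model_config`) where the existing path hands in the full `VllmConfig`. A minimal standalone illustration of that dispatch, using stand-in namespace objects rather than real vLLM configs:

from types import SimpleNamespace

def pick_hf_config(vllm_config):
    # Same shape as the constructor above: prefer the draft model's hf_config
    # when the object handed in is a speculative config.
    if hasattr(vllm_config, 'draft_model_config'):
        return vllm_config.draft_model_config.hf_config
    return vllm_config.model_config.hf_config

full_config = SimpleNamespace(model_config=SimpleNamespace(hf_config="target-hf-config"))
spec_config = SimpleNamespace(draft_model_config=SimpleNamespace(hf_config="medusa-hf-config"))

assert pick_hf_config(full_config) == "target-hf-config"   # full VllmConfig passed in
assert pick_hf_config(spec_config) == "medusa-hf-config"   # MedusaProposer-style call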
vllm/v1/spec_decode/medusa.py (new file, 74 lines)
@@ -0,0 +1,74 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.nn as nn
+
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.forward_context import set_forward_context
+from vllm.logger import init_logger
+from vllm.model_executor.model_loader import get_model_loader
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.models.medusa import Medusa
+from vllm.v1.sample.metadata import SamplingMetadata
+
+# Initialize logger
+logger = init_logger(__name__)
+
+
+class MedusaProposer:
+    """
+    Medusa proposer class for generating token sequences
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        device: torch.device,
+    ):
+        # Save config parameters
+        self.vllm_config = vllm_config
+        self.device = device
+        self.max_num_tokens = (
+            vllm_config.scheduler_config.max_num_batched_tokens)
+        self.hidden_size = vllm_config.speculative_config.\
+            draft_model_config.get_hidden_size(
+            )
+        self.dtype = vllm_config.model_config.dtype
+
+    def propose(
+        self,
+        target_hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor:
+        # Generate blocks and compute logits
+        blocks = self.model(target_hidden_states)
+        logits = self.model.compute_logits(blocks, None)
+
+        # Get draft tokens and transpose the result
+        draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]
+        return [list(row) for row in zip(*draft_tokens)]
+
+    def load_model(self, target_model: nn.Module) -> None:
+        # Get model loader and config
+        loader = get_model_loader(self.vllm_config.load_config)
+        draft_config = self.vllm_config.speculative_config.draft_model_config
+
+        # Load model with proper dtype and config
+        with set_default_torch_dtype(draft_config.dtype), \
+                set_current_vllm_config(self.vllm_config):
+            self.model = Medusa(
+                vllm_config=self.vllm_config.speculative_config).to(
+                    self.device)
+
+        # Load model weights
+        weights = loader.get_all_weights(draft_config, self.model)
+        self.model.load_weights(weights)
+
+    @torch.inference_mode()
+    def dummy_run(self, num_tokens: int) -> None:
+        hidden_states = torch.zeros((self.max_num_tokens, self.hidden_size),
+                                    dtype=self.dtype,
+                                    device=self.device)
+        with set_forward_context(None, self.vllm_config,
+                                 num_tokens=num_tokens):
+            self.model(hidden_states)
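In `propose`, `compute_logits` returns one logits tensor per Medusa head; taking the argmax per head and transposing turns the head-major lists into one draft-token sequence per request. A small self-contained illustration of that reshaping (shapes chosen arbitrarily):

import torch

num_heads, batch_size, vocab_size = 3, 2, 10
# One logits tensor per Medusa head, each of shape [batch_size, vocab_size].
logits_per_head = [torch.randn(batch_size, vocab_size) for _ in range(num_heads)]

# Head-major: draft_tokens[h][b] is head h's prediction for request b.
draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits_per_head]

# Transpose to request-major: one list of num_heads draft tokens per request.
per_request = [list(row) for row in zip(*draft_tokens)]
assert len(per_request) == batch_size and len(per_request[0]) == num_heads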
@@ -47,6 +47,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import RejectionSampler
 from vllm.v1.sample.sampler import Sampler
 from vllm.v1.spec_decode.eagle import EagleProposer
+from vllm.v1.spec_decode.medusa import MedusaProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.spec_decode.utils import is_spec_decode_supported
@@ -156,6 +157,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                              self.device)  # type: ignore
                 if self.speculative_config.method == "eagle3":
                     self.use_aux_hidden_state_outputs = True
+            elif self.speculative_config.method == "medusa":
+                self.drafter = MedusaProposer(
+                    vllm_config=self.vllm_config,
+                    device=self.device)  # type: ignore
             else:
                 raise ValueError("Unknown speculative decoding method: "
                                  f"{self.speculative_config.method}")
@@ -1254,6 +1259,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             assert isinstance(self.drafter, NgramProposer)
             spec_token_ids = self.generate_draft_token_ids(
                 valid_sampled_token_ids, sampling_metadata)
+        elif self.speculative_config.method == "medusa":
+            assert isinstance(self.drafter, MedusaProposer)
+            if max_gen_len == 1:
+                hidden_states = sample_hidden_states
+            else:
+                indices = []
+                offset = 0
+                for num_draft, tokens in zip(
+                        spec_decode_metadata.num_draft_tokens,
+                        valid_sampled_token_ids):
+                    indices.append(offset + len(tokens) - 1)
+                    offset += num_draft + 1
+
+                indices = torch.tensor(indices,
+                                       device=sample_hidden_states.device)
+                hidden_states = sample_hidden_states[indices]
+
+            spec_token_ids = self.drafter.propose(
+                target_hidden_states=hidden_states,
+                sampling_metadata=sampling_metadata,
+            )
         elif self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             # TODO(woosuk): Refactor the loop.
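When draft tokens were already scheduled for this step (`max_gen_len > 1`), the offset arithmetic above implies `sample_hidden_states` holds `num_draft + 1` rows per request, and only the row of each request's last accepted token should feed the Medusa heads. A small worked example of that index computation (toy numbers, plain tensors):

import torch

# Two requests: request 0 had 2 draft tokens scheduled and all 3 sampled
# tokens were accepted; request 1 also had 2 drafts but only 1 token accepted.
num_draft_tokens = [2, 2]
valid_sampled_token_ids = [[11, 12, 13], [21]]

indices, offset = [], 0
for num_draft, tokens in zip(num_draft_tokens, valid_sampled_token_ids):
    indices.append(offset + len(tokens) - 1)  # row of the last accepted token
    offset += num_draft + 1                   # each request owns num_draft + 1 rows

sample_hidden_states = torch.arange(6).reshape(6, 1).float()  # 6 rows, hidden_size=1
hidden_states = sample_hidden_states[torch.tensor(indices)]
print(indices)  # [2, 3] -> rows 2 and 3 feed the Medusa heads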