[V1][Spec Decoding] Use model_loader.get_model() to load models (#18273)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Mark McLoughlin 2025-05-23 03:05:44 +01:00 committed by GitHub
parent 04eb88dc80
commit c6b636f9fb
16 changed files with 59 additions and 135 deletions


@@ -117,34 +117,13 @@ def test_prepare_inputs():
])
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.ModelRegistry')
@mock.patch('vllm.v1.spec_decode.eagle.get_model_loader')
@mock.patch('vllm.v1.spec_decode.eagle.set_default_torch_dtype')
@mock.patch('vllm.v1.spec_decode.eagle.set_current_vllm_config')
def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
mock_registry, mock_get_layers, mock_get_pp_group, method,
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
proposer_helper, draft_model_dir, target_attribute_path):
# Setup mock for model class
mock_model_cls = mock.MagicMock()
mock_registry.resolve_model_cls.return_value = (mock_model_cls,
"test_arch")
# Create a real context manager for mocks
class MockContextManager:
def __init__(self):
pass
def __enter__(self):
return None
def __exit__(self, exc_type, exc_val, exc_tb):
return False
# Make the mocks return actual context manager objects
mock_set_dtype.return_value = MockContextManager()
mock_set_config.return_value = MockContextManager()
# Setup model mock
mock_model = mock.MagicMock()
mock_get_model.return_value = mock_model
# Setup mocks for attention layers
target_attn_layers = {
@@ -164,25 +143,6 @@ def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
mock_pp_group.world_size = 2 if method == "eagle" else 1
mock_get_pp_group.return_value = mock_pp_group
# Setup model loader mock
mock_loader = mock.MagicMock()
mock_get_loader.return_value = mock_loader
# Setup model mock
mock_model = mock.MagicMock()
mock_model_cls.return_value = mock_model
mock_model.to.return_value = mock_model
# Configure mock to test the attribute sharing path
if method == "eagle":
# For eagle, test the lm_head path
mock_model.load_weights.return_value = {
"model.embed_tokens.weight": torch.zeros(1)
}
else:
# For eagle3, test the embed_tokens path
mock_model.load_weights.return_value = {}
# Setup target model with the appropriate attributes
target_model = mock.MagicMock()
@@ -204,13 +164,7 @@ def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
proposer.load_model(target_model)
# Verify common interactions
mock_get_loader.assert_called_once()
mock_model_cls.assert_called_once()
mock_model.to.assert_called_once()
mock_model.load_weights.assert_called_once()
# Verify the loader was called with the right config
mock_get_loader.assert_called_once_with(proposer.vllm_config.load_config)
mock_get_model.assert_called_once()
# Verify the specific attribute sharing based on the method
if method == "eagle":
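The refactored test no longer fakes ModelRegistry, the model loader, or the dtype/config context managers; patching the single get_model() entry point is enough. A minimal sketch of that shape (fixture setup elided; it assumes a vLLM checkout so the patch target is importable):

from unittest import mock

@mock.patch("vllm.v1.spec_decode.eagle.get_model")
def test_load_model_sketch(mock_get_model):
    # get_model() returns the fully loaded draft model, so a bare MagicMock
    # can stand in for it; the old ModelRegistry/loader/context-manager mocks
    # are no longer needed.
    mock_get_model.return_value = mock.MagicMock()
    # ...build the proposer and target_model, call proposer.load_model(...),
    # then the single remaining loader assertion is:
    #     mock_get_model.assert_called_once()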


@@ -1,8 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
from torch import nn
from vllm.config import LoadConfig, LoadFormat, VllmConfig
from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.bitsandbytes_loader import (
BitsAndBytesModelLoader)
@@ -47,9 +49,14 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
return DefaultModelLoader(load_config)
def get_model(*, vllm_config: VllmConfig) -> nn.Module:
def get_model(*,
vllm_config: VllmConfig,
model_config: Optional[ModelConfig] = None) -> nn.Module:
loader = get_model_loader(vllm_config.load_config)
return loader.load_model(vllm_config=vllm_config)
if model_config is None:
model_config = vllm_config.model_config
return loader.load_model(vllm_config=vllm_config,
model_config=model_config)
__all__ = [
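The new optional model_config parameter is what lets speculative decoding reuse the whole loading pipeline for the draft model. A usage sketch, assuming vllm_config is an already-populated VllmConfig whose speculative_config is set:

from torch import nn

from vllm.config import VllmConfig
from vllm.model_executor.model_loader import get_model


def load_target_and_draft(vllm_config: VllmConfig) -> tuple[nn.Module, nn.Module]:
    # Target model: model_config defaults to vllm_config.model_config.
    target = get_model(vllm_config=vllm_config)
    # Draft model: pass its config explicitly so dtype selection, device
    # placement, weight loading and post-processing all take the same path.
    draft = get_model(
        vllm_config=vllm_config,
        model_config=vllm_config.speculative_config.draft_model_config)
    return target, draft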


@@ -18,6 +18,7 @@ class BaseModelLoader(ABC):
raise NotImplementedError
@abstractmethod
def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
def load_model(self, *, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
"""Load a model with the given configurations."""
raise NotImplementedError
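Out-of-tree loaders have to accept the extra argument too. A minimal sketch of a conforming subclass, assuming download_model and load_model are the only abstract hooks (the NoopModelLoader name and its trivial behaviour are purely illustrative):

from torch import nn

from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader


class NoopModelLoader(BaseModelLoader):
    """Illustrative loader matching the updated abstract interface."""

    def download_model(self, model_config: ModelConfig) -> None:
        pass  # nothing to download

    def load_model(self, *, vllm_config: VllmConfig,
                   model_config: ModelConfig) -> nn.Module:
        # model_config now arrives explicitly instead of being read from
        # vllm_config.model_config, so the same loader can also build a
        # draft model whose config differs from the target's.
        return nn.Identity()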


@@ -569,10 +569,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):


@@ -264,13 +264,14 @@ class DefaultModelLoader(BaseModelLoader):
fall_back_to_pt=True,
allow_patterns_overrides=None)
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config)
model = initialize_model(vllm_config=vllm_config,
model_config=model_config)
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(
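Every concrete loader follows the same pattern after this change: the caller-supplied model_config drives dtype and architecture selection, while vllm_config keeps providing device and parallelism settings. A condensed sketch of that shared shape (the initialize_model import path is assumed here; weight loading and post-processing are omitted):

import torch
from torch import nn

from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.model_loader.utils import (initialize_model,
                                                    set_default_torch_dtype)


def build_model_skeleton(vllm_config: VllmConfig,
                         model_config: ModelConfig) -> nn.Module:
    target_device = torch.device(vllm_config.device_config.device)
    with set_default_torch_dtype(model_config.dtype):
        with target_device:
            # The architecture comes from the passed model_config rather
            # than always from vllm_config.model_config.
            return initialize_model(vllm_config=vllm_config,
                                    model_config=model_config)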


@@ -22,9 +22,9 @@ class DummyModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None:
pass # Nothing to download
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:


@@ -92,9 +92,9 @@ class GGUFModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model)
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
local_model_path = self._prepare_weights(model_config.model)
gguf_weights_map = self._get_gguf_weights_map(model_config)
# we can only know if tie word embeddings after mapping weights


@@ -100,11 +100,10 @@ class RunaiModelStreamerLoader(BaseModelLoader):
"""Download model if necessary"""
self._prepare_weights(model_config.model, model_config.revision)
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
"""Perform streaming of the model to destination"""
device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:


@@ -100,9 +100,9 @@ class ShardedStateLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
from vllm.distributed import get_tensor_model_parallel_rank


@@ -93,8 +93,8 @@ class TensorizerLoader(BaseModelLoader):
with self.tensorizer_config.open_stream():
pass
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
model_config = vllm_config.model_config
def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
parallel_config = vllm_config.parallel_config
self._verify_config(model_config, parallel_config)


@@ -42,8 +42,10 @@ def initialize_model(
*,
prefix: str = "",
model_class: Optional[type[nn.Module]] = None,
model_config: Optional[ModelConfig] = None,
) -> nn.Module:
"""Initialize a model with the given configurations."""
if model_config is None:
model_config = vllm_config.model_config
if model_class is None:
model_class, _ = get_model_architecture(model_config)


@@ -130,13 +130,15 @@ class LlamaModel(nn.Module):
class EagleLlamaForCausalLM(LlamaForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
target_layer_num = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config)
self.model = LlamaModel(vllm_config=vllm_config,
prefix="model",
start_layer_id=start_layer_id)
start_layer_id=target_layer_num)
logit_scale = getattr(self.config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.config.vocab_size,
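The constructor takes the standard (vllm_config, prefix) arguments because get_model()/initialize_model() build every architecture the same way, so the layer offset the proposer used to pass in is now derived from the target model's layer count. A rough sketch of the construction step that forces this (the get_model_architecture import path and the exact call are assumptions, not the literal vLLM code):

from torch import nn

from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.model_loader.utils import get_model_architecture


def construct_draft(vllm_config: VllmConfig, model_config: ModelConfig,
                    prefix: str = "") -> nn.Module:
    # Architectures are resolved from model_config and instantiated with the
    # standard keyword arguments, so a draft head can no longer expect a
    # bespoke start_layer_id parameter.
    model_class, _ = get_model_architecture(model_config)
    return model_class(vllm_config=vllm_config, prefix=prefix)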


@@ -175,13 +175,15 @@ class LlamaModel(nn.Module):
class Eagle3LlamaForCausalLM(LlamaForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
target_layer_num = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config)
self.model = LlamaModel(vllm_config=vllm_config,
start_layer_id=start_layer_id,
prefix="model")
prefix="model",
start_layer_id=target_layer_num)
logit_scale = getattr(self.config, "logit_scale", 1.0)
self.lm_head = ParallelLMHead(
@@ -193,8 +195,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
self.logits_processor = LogitsProcessor(self.config.draft_vocab_size,
scale=logit_scale)
self.draft_id_to_target_id = nn.Parameter(
torch.zeros((self.config.draft_vocab_size),
dtype=torch.long).type(torch.LongTensor),
torch.zeros(self.config.draft_vocab_size, dtype=torch.long),
requires_grad=False,
)
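The Parameter initialization also drops a redundant cast. A quick self-contained check of the equivalence (the vocab size below is an arbitrary stand-in):

import torch

n = 8  # stand-in for self.config.draft_vocab_size
old_style = torch.zeros((n), dtype=torch.long).type(torch.LongTensor)
new_style = torch.zeros(n, dtype=torch.long)
# torch.zeros(..., dtype=torch.long) is already int64, so the extra
# .type(torch.LongTensor) added nothing (and would pull a CUDA tensor
# back to the CPU).
assert old_style.dtype == new_style.dtype == torch.int64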


@@ -51,10 +51,7 @@ class Medusa(nn.Module):
needs to have truncated_vocab_size (=k) as an attribute."""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
if hasattr(vllm_config, 'draft_model_config'):
config = vllm_config.draft_model_config.hf_config
else:
config = vllm_config.model_config.hf_config
config = vllm_config.speculative_config.draft_model_config.hf_config
super().__init__()
self.config = config
self.blocks = nn.ModuleList([


@@ -4,14 +4,11 @@ import torch.nn as nn
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config, set_current_vllm_config)
get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import (
process_weights_after_loading, set_default_torch_dtype)
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
@@ -280,51 +277,28 @@ class EagleProposer:
return cu_num_tokens, token_indices
def load_model(self, target_model: nn.Module) -> None:
loader = get_model_loader(self.vllm_config.load_config)
target_layer_num = self.vllm_config.model_config.get_num_layers(
self.vllm_config.parallel_config)
draft_model_config = \
self.vllm_config.speculative_config.draft_model_config
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, Attention).keys())
draft_model_config = \
self.vllm_config.speculative_config.draft_model_config
# FIXME(lily): This does not handle distributed inference.
target_device = self.vllm_config.device_config.device
# We need to set the vllm_config here to register attention
# layers in the forward context.
with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config(
self.vllm_config):
draft_model_cls, arch = ModelRegistry.resolve_model_cls(
draft_model_config.architectures)
self.model = draft_model_cls(
vllm_config=self.vllm_config,
start_layer_id=target_layer_num).to(target_device)
self.model = get_model(vllm_config=self.vllm_config,
model_config=draft_model_config)
draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
target_attn_layer_names)
assert len(draft_attn_layer_names) == 1
self.attn_layer_name = next(iter(draft_attn_layer_names))
loaded_weights = self.model.load_weights(
loader.get_all_weights(draft_model_config, self.model))
process_weights_after_loading(self.model, draft_model_config,
target_device)
# share embed_tokens with the target model if needed
if get_pp_group().world_size == 1:
assert "model.embed_tokens.weight" not in loaded_weights, \
"For PP = 1, Eagle draft should share embed with target model"
logger.info(
"The EAGLE head shares the same vocab embedding" \
" with the target model."
)
self.model.model.embed_tokens = target_model.model.embed_tokens
else:
assert "model.embed_tokens.weight" in loaded_weights, \
"For PP > 1, Eagle draft checkpoint should its own copy of "
" the model.embed_tokens.weight"
logger.info(
"Since PP > 1, the EAGLE head loaded its own vocab embedding" \
" weights instead of sharing them with the target model."


@@ -3,12 +3,10 @@
import torch
import torch.nn as nn
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.medusa import Medusa
from vllm.model_executor.model_loader import get_model
from vllm.v1.sample.metadata import SamplingMetadata
# Initialize logger
@@ -49,20 +47,9 @@ class MedusaProposer:
return [list(row) for row in zip(*draft_tokens)]
def load_model(self, target_model: nn.Module) -> None:
# Get model loader and config
loader = get_model_loader(self.vllm_config.load_config)
draft_config = self.vllm_config.speculative_config.draft_model_config
# Load model with proper dtype and config
with set_default_torch_dtype(draft_config.dtype), \
set_current_vllm_config(self.vllm_config):
self.model = Medusa(
vllm_config=self.vllm_config.speculative_config).to(
self.device)
# Load model weights
weights = loader.get_all_weights(draft_config, self.model)
self.model.load_weights(weights)
self.model = get_model(vllm_config=self.vllm_config,
model_config=self.vllm_config.
speculative_config.draft_model_config)
@torch.inference_mode()
def dummy_run(self, num_tokens: int) -> None:
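MedusaProposer previously built the head directly with Medusa(vllm_config=self.vllm_config.speculative_config), which is why medusa.py needed the hasattr fallback removed earlier in this commit; with get_model() the top-level VllmConfig is always what reaches the constructor. A tiny sketch of the resulting single source of truth (the helper name is illustrative only):

from vllm.config import VllmConfig


def medusa_draft_hf_config(vllm_config: VllmConfig):
    # With get_model() passing the full VllmConfig, the draft model's HF
    # config lives in exactly one place.
    return vllm_config.speculative_config.draft_model_config.hf_config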