Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 22:05:44 +08:00)
[V1][Spec Decoding] Use model_loader.get_model() to load models (#18273)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Parent: 04eb88dc80
Commit: c6b636f9fb
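
In short: instead of hand-rolling draft-model construction (resolving the class via ModelRegistry, instantiating it under set_default_torch_dtype/set_current_vllm_config, then pulling weights through get_model_loader().get_all_weights()), the V1 EAGLE and Medusa proposers now call the shared model_loader.get_model() entry point with the draft model's ModelConfig. A condensed before/after sketch of the proposer-side change (not the verbatim diff; see the hunks below):

    # Before: manual class resolution, construction and weight loading.
    loader = get_model_loader(self.vllm_config.load_config)
    with set_default_torch_dtype(draft_model_config.dtype), \
            set_current_vllm_config(self.vllm_config):
        draft_model_cls, _ = ModelRegistry.resolve_model_cls(
            draft_model_config.architectures)
        self.model = draft_model_cls(
            vllm_config=self.vllm_config,
            start_layer_id=target_layer_num).to(target_device)
    self.model.load_weights(
        loader.get_all_weights(draft_model_config, self.model))

    # After: one call; dtype, device, initialization and weight loading are
    # handled by the loader machinery.
    self.model = get_model(vllm_config=self.vllm_config,
                           model_config=draft_model_config)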
@@ -117,34 +117,13 @@ def test_prepare_inputs():
 ])
 @mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
 @mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
-@mock.patch('vllm.v1.spec_decode.eagle.ModelRegistry')
-@mock.patch('vllm.v1.spec_decode.eagle.get_model_loader')
-@mock.patch('vllm.v1.spec_decode.eagle.set_default_torch_dtype')
-@mock.patch('vllm.v1.spec_decode.eagle.set_current_vllm_config')
-def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
-                    mock_registry, mock_get_layers, mock_get_pp_group, method,
+@mock.patch('vllm.v1.spec_decode.eagle.get_model')
+def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
                     proposer_helper, draft_model_dir, target_attribute_path):
 
-    # Setup mock for model class
-    mock_model_cls = mock.MagicMock()
-    mock_registry.resolve_model_cls.return_value = (mock_model_cls,
-                                                    "test_arch")
-
-    # Create a real context manager for mocks
-    class MockContextManager:
-
-        def __init__(self):
-            pass
-
-        def __enter__(self):
-            return None
-
-        def __exit__(self, exc_type, exc_val, exc_tb):
-            return False
-
-    # Make the mocks return actual context manager objects
-    mock_set_dtype.return_value = MockContextManager()
-    mock_set_config.return_value = MockContextManager()
+    # Setup model mock
+    mock_model = mock.MagicMock()
+    mock_get_model.return_value = mock_model
 
     # Setup mocks for attention layers
     target_attn_layers = {
@@ -164,25 +143,6 @@ def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
     mock_pp_group.world_size = 2 if method == "eagle" else 1
     mock_get_pp_group.return_value = mock_pp_group
 
-    # Setup model loader mock
-    mock_loader = mock.MagicMock()
-    mock_get_loader.return_value = mock_loader
-
-    # Setup model mock
-    mock_model = mock.MagicMock()
-    mock_model_cls.return_value = mock_model
-    mock_model.to.return_value = mock_model
-
-    # Configure mock to test the attribute sharing path
-    if method == "eagle":
-        # For eagle, test the lm_head path
-        mock_model.load_weights.return_value = {
-            "model.embed_tokens.weight": torch.zeros(1)
-        }
-    else:
-        # For eagle3, test the embed_tokens path
-        mock_model.load_weights.return_value = {}
-
     # Setup target model with the appropriate attributes
     target_model = mock.MagicMock()
 
@@ -204,13 +164,7 @@ def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
     proposer.load_model(target_model)
 
     # Verify common interactions
-    mock_get_loader.assert_called_once()
-    mock_model_cls.assert_called_once()
-    mock_model.to.assert_called_once()
-    mock_model.load_weights.assert_called_once()
-
-    # Verify the loader was called with the right config
-    mock_get_loader.assert_called_once_with(proposer.vllm_config.load_config)
+    mock_get_model.assert_called_once()
 
     # Verify the specific attribute sharing based on the method
     if method == "eagle":
@@ -1,8 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from typing import Optional
+
 from torch import nn
 
-from vllm.config import LoadConfig, LoadFormat, VllmConfig
+from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
 from vllm.model_executor.model_loader.bitsandbytes_loader import (
     BitsAndBytesModelLoader)
@@ -47,9 +49,14 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     return DefaultModelLoader(load_config)
 
 
-def get_model(*, vllm_config: VllmConfig) -> nn.Module:
+def get_model(*,
+              vllm_config: VllmConfig,
+              model_config: Optional[ModelConfig] = None) -> nn.Module:
     loader = get_model_loader(vllm_config.load_config)
-    return loader.load_model(vllm_config=vllm_config)
+    if model_config is None:
+        model_config = vllm_config.model_config
+    return loader.load_model(vllm_config=vllm_config,
+                             model_config=model_config)
 
 
 __all__ = [
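
With the new optional parameter, existing get_model() call sites keep their behaviour (model_config falls back to vllm_config.model_config), while speculative decoding can point the same entry point at the draft model's config. A minimal usage sketch; the variable names are illustrative:

    # Unchanged default: loads the target model described by vllm_config.model_config.
    target = get_model(vllm_config=vllm_config)

    # New: load a different model, e.g. the EAGLE draft head, from its own ModelConfig
    # while reusing the same load_config and device/dtype handling.
    draft = get_model(
        vllm_config=vllm_config,
        model_config=vllm_config.speculative_config.draft_model_config)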
@@ -18,6 +18,7 @@ class BaseModelLoader(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, *, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         """Load a model with the given configurations."""
         raise NotImplementedError
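
Because the abstract method now requires model_config, any BaseModelLoader implementation outside this diff (including out-of-tree loaders) has to adopt the two-argument signature and honour the passed config rather than reading vllm_config.model_config. A hypothetical minimal subclass, only to illustrate the expected shape (the class name and the weights helper are made up):

    class MyModelLoader(BaseModelLoader):

        def download_model(self, model_config: ModelConfig) -> None:
            pass  # nothing to download in this sketch

        def load_model(self, *, vllm_config: VllmConfig,
                       model_config: ModelConfig) -> nn.Module:
            device_config = vllm_config.device_config
            # Build and load on the target device using the *passed* model_config,
            # so draft models are handled the same way as the target model.
            with set_default_torch_dtype(model_config.dtype):
                with torch.device(device_config.device):
                    model = initialize_model(vllm_config=vllm_config,
                                             model_config=model_config)
            model.load_weights(self._iter_weights(model_config))  # hypothetical helper
            return model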
@@ -569,10 +569,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
 
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
 
@@ -264,13 +264,14 @@ class DefaultModelLoader(BaseModelLoader):
             fall_back_to_pt=True,
             allow_patterns_overrides=None)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
-                model = initialize_model(vllm_config=vllm_config)
+                model = initialize_model(vllm_config=vllm_config,
+                                         model_config=model_config)
 
                 weights_to_load = {name for name, _ in model.named_parameters()}
                 loaded_weights = model.load_weights(
@@ -22,9 +22,9 @@ class DummyModelLoader(BaseModelLoader):
     def download_model(self, model_config: ModelConfig) -> None:
         pass  # Nothing to download
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
@@ -92,9 +92,9 @@ class GGUFModelLoader(BaseModelLoader):
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
         local_model_path = self._prepare_weights(model_config.model)
         gguf_weights_map = self._get_gguf_weights_map(model_config)
         # we can only know if tie word embeddings after mapping weights
@@ -100,11 +100,10 @@ class RunaiModelStreamerLoader(BaseModelLoader):
         """Download model if necessary"""
         self._prepare_weights(model_config.model, model_config.revision)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         """Perform streaming of the model to destination"""
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
 
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
@@ -100,9 +100,9 @@ class ShardedStateLoader(BaseModelLoader):
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
         target_device = torch.device(device_config.device)
 
         from vllm.distributed import get_tensor_model_parallel_rank
@@ -93,8 +93,8 @@ class TensorizerLoader(BaseModelLoader):
         with self.tensorizer_config.open_stream():
             pass
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
-        model_config = vllm_config.model_config
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         parallel_config = vllm_config.parallel_config
         self._verify_config(model_config, parallel_config)
 
@@ -42,9 +42,11 @@ def initialize_model(
     *,
     prefix: str = "",
     model_class: Optional[type[nn.Module]] = None,
+    model_config: Optional[ModelConfig] = None,
 ) -> nn.Module:
     """Initialize a model with the given configurations."""
-    model_config = vllm_config.model_config
+    if model_config is None:
+        model_config = vllm_config.model_config
     if model_class is None:
         model_class, _ = get_model_architecture(model_config)
 
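
The same fallback pattern as in get_model(): omitting model_config keeps the previous behaviour, while passing one makes initialize_model() resolve the architecture (and hence the model class) from that config instead of the target model's. A short illustrative sketch:

    # As before: the architecture comes from vllm_config.model_config.
    model = initialize_model(vllm_config=vllm_config)

    # New: the architecture comes from the draft model's config, which is what
    # lets DefaultModelLoader build the EAGLE/Medusa draft heads through this path.
    draft = initialize_model(
        vllm_config=vllm_config,
        model_config=vllm_config.speculative_config.draft_model_config)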
@@ -130,13 +130,15 @@ class LlamaModel(nn.Module):
 
 class EagleLlamaForCausalLM(LlamaForCausalLM):
 
-    def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         nn.Module.__init__(self)
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
+        target_layer_num = vllm_config.model_config.get_num_layers(
+            vllm_config.parallel_config)
         self.model = LlamaModel(vllm_config=vllm_config,
                                 prefix="model",
-                                start_layer_id=start_layer_id)
+                                start_layer_id=target_layer_num)
 
         logit_scale = getattr(self.config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.config.vocab_size,
@@ -175,13 +175,15 @@ class LlamaModel(nn.Module):
 
 class Eagle3LlamaForCausalLM(LlamaForCausalLM):
 
-    def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         nn.Module.__init__(self)
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
+        target_layer_num = vllm_config.model_config.get_num_layers(
+            vllm_config.parallel_config)
         self.model = LlamaModel(vllm_config=vllm_config,
-                                start_layer_id=start_layer_id,
-                                prefix="model")
+                                prefix="model",
+                                start_layer_id=target_layer_num)
 
         logit_scale = getattr(self.config, "logit_scale", 1.0)
         self.lm_head = ParallelLMHead(
@@ -193,8 +195,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
         self.logits_processor = LogitsProcessor(self.config.draft_vocab_size,
                                                 scale=logit_scale)
         self.draft_id_to_target_id = nn.Parameter(
-            torch.zeros((self.config.draft_vocab_size),
-                        dtype=torch.long).type(torch.LongTensor),
+            torch.zeros(self.config.draft_vocab_size, dtype=torch.long),
             requires_grad=False,
         )
 
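
The draft models now expose the standard (*, vllm_config, prefix="") constructor that the generic initialize_model() path expects, and derive the starting layer id themselves from the target model's layer count. Roughly, the offset appears intended to keep the draft layer names from colliding with the target's layers, which is presumably also what lets EagleProposer identify the draft attention layer by set difference later on; an illustrative sketch:

    # The draft decoder layer is numbered after the target's layers, so with e.g. a
    # 32-layer target the draft layer is registered as "model.layers.32.…" rather
    # than clashing with an existing "model.layers.N.…" name.
    target_layer_num = vllm_config.model_config.get_num_layers(
        vllm_config.parallel_config)
    self.model = LlamaModel(vllm_config=vllm_config,
                            prefix="model",
                            start_layer_id=target_layer_num)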
@@ -51,10 +51,7 @@ class Medusa(nn.Module):
     needs to have truncated_vocab_size (=k) as an attribute."""
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
-        if hasattr(vllm_config, 'draft_model_config'):
-            config = vllm_config.draft_model_config.hf_config
-        else:
-            config = vllm_config.model_config.hf_config
+        config = vllm_config.speculative_config.draft_model_config.hf_config
         super().__init__()
         self.config = config
         self.blocks = nn.ModuleList([
@@ -4,14 +4,11 @@ import torch.nn as nn
 
 from vllm.attention.layer import Attention
 from vllm.config import (CompilationLevel, VllmConfig,
-                         get_layers_from_vllm_config, set_current_vllm_config)
+                         get_layers_from_vllm_config)
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader import get_model_loader
-from vllm.model_executor.model_loader.utils import (
-    process_weights_after_loading, set_default_torch_dtype)
-from vllm.model_executor.models import ModelRegistry
+from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.triton_utils import tl, triton
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
@@ -280,51 +277,28 @@ class EagleProposer:
         return cu_num_tokens, token_indices
 
     def load_model(self, target_model: nn.Module) -> None:
-        loader = get_model_loader(self.vllm_config.load_config)
-        target_layer_num = self.vllm_config.model_config.get_num_layers(
-            self.vllm_config.parallel_config)
-        draft_model_config = \
-            self.vllm_config.speculative_config.draft_model_config
         target_attn_layer_names = set(
             get_layers_from_vllm_config(self.vllm_config, Attention).keys())
 
-        # FIXME(lily): This does not handle with distributed inference.
-        target_device = self.vllm_config.device_config.device
-        # We need to set the vllm_config here to register attention
-        # layers in the forward context.
-        with set_default_torch_dtype(
-                draft_model_config.dtype), set_current_vllm_config(
-                    self.vllm_config):
-            draft_model_cls, arch = ModelRegistry.resolve_model_cls(
-                draft_model_config.architectures)
-            self.model = draft_model_cls(
-                vllm_config=self.vllm_config,
-                start_layer_id=target_layer_num).to(target_device)
+        draft_model_config = \
+            self.vllm_config.speculative_config.draft_model_config
+        self.model = get_model(vllm_config=self.vllm_config,
+                               model_config=draft_model_config)
 
         draft_attn_layer_names = (
             get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
             target_attn_layer_names)
         assert len(draft_attn_layer_names) == 1
         self.attn_layer_name = next(iter(draft_attn_layer_names))
-        loaded_weights = self.model.load_weights(
-            loader.get_all_weights(draft_model_config, self.model))
-
-        process_weights_after_loading(self.model, draft_model_config,
-                                      target_device)
 
         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1:
-            assert "model.embed_tokens.weight" not in loaded_weights, \
-                "For PP = 1, Eagle draft should share embed with target model"
             logger.info(
                 "The EAGLE head shares the same vocab embedding" \
                 " with the target model."
             )
             self.model.model.embed_tokens = target_model.model.embed_tokens
         else:
-            assert "model.embed_tokens.weight" in loaded_weights, \
-                "For PP > 1, Eagle draft checkpoint should its own copy of "
-                " the model.embed_tokens.weight"
             logger.info(
                 "Since PP > 1, the EAGLE head loaded its own vocab embedding" \
                 " weights instead of sharing them with the target model."
@@ -3,12 +3,10 @@
 import torch
 import torch.nn as nn
 
-from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader import get_model_loader
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
-from vllm.model_executor.models.medusa import Medusa
+from vllm.model_executor.model_loader import get_model
 from vllm.v1.sample.metadata import SamplingMetadata
 
 # Initialize logger
@@ -49,20 +47,9 @@ class MedusaProposer:
         return [list(row) for row in zip(*draft_tokens)]
 
     def load_model(self, target_model: nn.Module) -> None:
-        # Get model loader and config
-        loader = get_model_loader(self.vllm_config.load_config)
-        draft_config = self.vllm_config.speculative_config.draft_model_config
-
-        # Load model with proper dtype and config
-        with set_default_torch_dtype(draft_config.dtype), \
-                set_current_vllm_config(self.vllm_config):
-            self.model = Medusa(
-                vllm_config=self.vllm_config.speculative_config).to(
-                    self.device)
-
-        # Load model weights
-        weights = loader.get_all_weights(draft_config, self.model)
-        self.model.load_weights(weights)
+        self.model = get_model(vllm_config=self.vllm_config,
+                               model_config=self.vllm_config.
+                               speculative_config.draft_model_config)
 
     @torch.inference_mode()
     def dummy_run(self, num_tokens: int) -> None: