[V1][Spec Decoding] Use model_loader.get_model() to load models (#18273)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
commit c6b636f9fb
parent 04eb88dc80
@@ -117,34 +117,13 @@ def test_prepare_inputs():
 ])
 @mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
 @mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
-@mock.patch('vllm.v1.spec_decode.eagle.ModelRegistry')
-@mock.patch('vllm.v1.spec_decode.eagle.get_model_loader')
-@mock.patch('vllm.v1.spec_decode.eagle.set_default_torch_dtype')
-@mock.patch('vllm.v1.spec_decode.eagle.set_current_vllm_config')
-def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
-                    mock_registry, mock_get_layers, mock_get_pp_group, method,
+@mock.patch('vllm.v1.spec_decode.eagle.get_model')
+def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
                     proposer_helper, draft_model_dir, target_attribute_path):
 
-    # Setup mock for model class
-    mock_model_cls = mock.MagicMock()
-    mock_registry.resolve_model_cls.return_value = (mock_model_cls,
-                                                    "test_arch")
-
-    # Create a real context manager for mocks
-    class MockContextManager:
-
-        def __init__(self):
-            pass
-
-        def __enter__(self):
-            return None
-
-        def __exit__(self, exc_type, exc_val, exc_tb):
-            return False
-
-    # Make the mocks return actual context manager objects
-    mock_set_dtype.return_value = MockContextManager()
-    mock_set_config.return_value = MockContextManager()
+    # Setup model mock
+    mock_model = mock.MagicMock()
+    mock_get_model.return_value = mock_model
 
     # Setup mocks for attention layers
     target_attn_layers = {
@@ -164,25 +143,6 @@ def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
     mock_pp_group.world_size = 2 if method == "eagle" else 1
     mock_get_pp_group.return_value = mock_pp_group
 
-    # Setup model loader mock
-    mock_loader = mock.MagicMock()
-    mock_get_loader.return_value = mock_loader
-
-    # Setup model mock
-    mock_model = mock.MagicMock()
-    mock_model_cls.return_value = mock_model
-    mock_model.to.return_value = mock_model
-
-    # Configure mock to test the attribute sharing path
-    if method == "eagle":
-        # For eagle, test the lm_head path
-        mock_model.load_weights.return_value = {
-            "model.embed_tokens.weight": torch.zeros(1)
-        }
-    else:
-        # For eagle3, test the embed_tokens path
-        mock_model.load_weights.return_value = {}
-
     # Setup target model with the appropriate attributes
     target_model = mock.MagicMock()
 
@@ -204,13 +164,7 @@ def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
     proposer.load_model(target_model)
 
     # Verify common interactions
-    mock_get_loader.assert_called_once()
-    mock_model_cls.assert_called_once()
-    mock_model.to.assert_called_once()
-    mock_model.load_weights.assert_called_once()
-
-    # Verify the loader was called with the right config
-    mock_get_loader.assert_called_once_with(proposer.vllm_config.load_config)
+    mock_get_model.assert_called_once()
 
     # Verify the specific attribute sharing based on the method
     if method == "eagle":
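With loading consolidated behind get_model(), the test above only needs one patch target and one assertion. A minimal sketch of that pattern; the helper and its proposer/target_model arguments are stand-ins for what the real parametrized test builds via fixtures, only the patch target is taken from the diff:

from unittest import mock


def check_load_model_goes_through_get_model(proposer, target_model):
    # Patch the single entry point the proposer now goes through.
    with mock.patch('vllm.v1.spec_decode.eagle.get_model') as mock_get_model:
        mock_get_model.return_value = mock.MagicMock()
        proposer.load_model(target_model)
    # One call into the shared loading path replaces the old chain of
    # loader, registry and dtype/config context-manager assertions.
    mock_get_model.assert_called_once()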
@@ -1,8 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from typing import Optional
+
 from torch import nn
 
-from vllm.config import LoadConfig, LoadFormat, VllmConfig
+from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
 from vllm.model_executor.model_loader.bitsandbytes_loader import (
     BitsAndBytesModelLoader)
@@ -47,9 +49,14 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     return DefaultModelLoader(load_config)
 
 
-def get_model(*, vllm_config: VllmConfig) -> nn.Module:
+def get_model(*,
+              vllm_config: VllmConfig,
+              model_config: Optional[ModelConfig] = None) -> nn.Module:
     loader = get_model_loader(vllm_config.load_config)
-    return loader.load_model(vllm_config=vllm_config)
+    if model_config is None:
+        model_config = vllm_config.model_config
+    return loader.load_model(vllm_config=vllm_config,
+                             model_config=model_config)
 
 
 __all__ = [
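A short usage sketch of the extended entry point. The wrapper function is illustrative, not part of this commit; the attribute path for the draft config is the one the spec-decode proposers use later in the diff:

from vllm.model_executor.model_loader import get_model


def load_target_and_draft(vllm_config):
    # Omitting model_config keeps the old behaviour: the target model
    # described by vllm_config.model_config is loaded.
    target = get_model(vllm_config=vllm_config)
    # Passing a different ModelConfig points the same loading machinery at
    # another model, e.g. the speculative draft model.
    draft_cfg = vllm_config.speculative_config.draft_model_config
    draft = get_model(vllm_config=vllm_config, model_config=draft_cfg)
    return target, draft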
@@ -18,6 +18,7 @@ class BaseModelLoader(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, *, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         """Load a model with the given configurations."""
         raise NotImplementedError
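For loader subclasses the contract changes accordingly: load_model() now receives the ModelConfig to build instead of reading vllm_config.model_config itself. A hedged sketch of a conforming subclass; the class and its body are hypothetical:

from torch import nn

from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader


class MyLoader(BaseModelLoader):

    def download_model(self, model_config: ModelConfig) -> None:
        pass  # nothing to fetch in this sketch

    def load_model(self, *, vllm_config: VllmConfig,
                   model_config: ModelConfig) -> nn.Module:
        # Build the module described by model_config; device, dtype and
        # parallel settings still come from vllm_config, as in the in-tree
        # loaders below.
        raise NotImplementedError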
@@ -569,10 +569,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
 
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
@@ -264,13 +264,14 @@ class DefaultModelLoader(BaseModelLoader):
             fall_back_to_pt=True,
             allow_patterns_overrides=None)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
-                model = initialize_model(vllm_config=vllm_config)
+                model = initialize_model(vllm_config=vllm_config,
+                                         model_config=model_config)
 
                 weights_to_load = {name for name, _ in model.named_parameters()}
                 loaded_weights = model.load_weights(
@@ -22,9 +22,9 @@ class DummyModelLoader(BaseModelLoader):
     def download_model(self, model_config: ModelConfig) -> None:
         pass  # Nothing to download
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
@@ -92,9 +92,9 @@ class GGUFModelLoader(BaseModelLoader):
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
         local_model_path = self._prepare_weights(model_config.model)
         gguf_weights_map = self._get_gguf_weights_map(model_config)
         # we can only know if tie word embeddings after mapping weights
@@ -100,11 +100,10 @@ class RunaiModelStreamerLoader(BaseModelLoader):
         """Download model if necessary"""
         self._prepare_weights(model_config.model, model_config.revision)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         """Perform streaming of the model to destination"""
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
-
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
@@ -100,9 +100,9 @@ class ShardedStateLoader(BaseModelLoader):
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
         target_device = torch.device(device_config.device)
 
         from vllm.distributed import get_tensor_model_parallel_rank
@@ -93,8 +93,8 @@ class TensorizerLoader(BaseModelLoader):
         with self.tensorizer_config.open_stream():
             pass
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
-        model_config = vllm_config.model_config
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         parallel_config = vllm_config.parallel_config
         self._verify_config(model_config, parallel_config)
 
@@ -42,9 +42,11 @@ def initialize_model(
     *,
     prefix: str = "",
     model_class: Optional[type[nn.Module]] = None,
+    model_config: Optional[ModelConfig] = None,
 ) -> nn.Module:
     """Initialize a model with the given configurations."""
-    model_config = vllm_config.model_config
+    if model_config is None:
+        model_config = vllm_config.model_config
     if model_class is None:
         model_class, _ = get_model_architecture(model_config)
 
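initialize_model() gets the same fallback pattern as get_model(): it builds vllm_config.model_config unless the caller hands it a different config, which is how the DefaultModelLoader hunk above threads the draft config through. A small sketch, assuming initialize_model is importable from the model_loader utils module this hunk belongs to:

from vllm.model_executor.model_loader.utils import initialize_model


def build_draft_module(vllm_config):
    draft_cfg = vllm_config.speculative_config.draft_model_config
    # model_class is left as None, so the architecture is resolved from the
    # config that was passed in rather than from vllm_config.model_config.
    return initialize_model(vllm_config=vllm_config, model_config=draft_cfg)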
@@ -130,13 +130,15 @@ class LlamaModel(nn.Module):
 
 class EagleLlamaForCausalLM(LlamaForCausalLM):
 
-    def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         nn.Module.__init__(self)
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
+        target_layer_num = vllm_config.model_config.get_num_layers(
+            vllm_config.parallel_config)
         self.model = LlamaModel(vllm_config=vllm_config,
                                 prefix="model",
-                                start_layer_id=start_layer_id)
+                                start_layer_id=target_layer_num)
 
         logit_scale = getattr(self.config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.config.vocab_size,
@@ -175,13 +175,15 @@ class LlamaModel(nn.Module):
 
 class Eagle3LlamaForCausalLM(LlamaForCausalLM):
 
-    def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         nn.Module.__init__(self)
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
+        target_layer_num = vllm_config.model_config.get_num_layers(
+            vllm_config.parallel_config)
         self.model = LlamaModel(vllm_config=vllm_config,
-                                start_layer_id=start_layer_id,
-                                prefix="model")
+                                prefix="model",
+                                start_layer_id=target_layer_num)
 
         logit_scale = getattr(self.config, "logit_scale", 1.0)
         self.lm_head = ParallelLMHead(
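Both EAGLE heads now take the standard prefix keyword and derive start_layer_id themselves: the draft decoder layers are numbered after the target model's layers, so their attention layers get names that do not collide with the target's, which the proposer relies on further down when it diffs the attention-layer sets. Schematically (the helper below is illustrative):

def draft_start_layer_id(vllm_config) -> int:
    # Mirrors the two lines added to each head above: with a 32-layer
    # target, the draft's first decoder layer becomes layer 32.
    return vllm_config.model_config.get_num_layers(vllm_config.parallel_config)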
@@ -193,8 +195,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
         self.logits_processor = LogitsProcessor(self.config.draft_vocab_size,
                                                 scale=logit_scale)
         self.draft_id_to_target_id = nn.Parameter(
-            torch.zeros((self.config.draft_vocab_size),
-                        dtype=torch.long).type(torch.LongTensor),
+            torch.zeros(self.config.draft_vocab_size, dtype=torch.long),
             requires_grad=False,
         )
 
@@ -51,10 +51,7 @@ class Medusa(nn.Module):
     needs to have truncated_vocab_size (=k) as an attribute."""
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
-        if hasattr(vllm_config, 'draft_model_config'):
-            config = vllm_config.draft_model_config.hf_config
-        else:
-            config = vllm_config.model_config.hf_config
+        config = vllm_config.speculative_config.draft_model_config.hf_config
         super().__init__()
         self.config = config
         self.blocks = nn.ModuleList([
@@ -4,14 +4,11 @@ import torch.nn as nn
 
 from vllm.attention.layer import Attention
 from vllm.config import (CompilationLevel, VllmConfig,
-                         get_layers_from_vllm_config, set_current_vllm_config)
+                         get_layers_from_vllm_config)
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader import get_model_loader
-from vllm.model_executor.model_loader.utils import (
-    process_weights_after_loading, set_default_torch_dtype)
-from vllm.model_executor.models import ModelRegistry
+from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.triton_utils import tl, triton
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
@@ -280,51 +277,28 @@ class EagleProposer:
         return cu_num_tokens, token_indices
 
     def load_model(self, target_model: nn.Module) -> None:
-        loader = get_model_loader(self.vllm_config.load_config)
-        target_layer_num = self.vllm_config.model_config.get_num_layers(
-            self.vllm_config.parallel_config)
+        draft_model_config = \
+            self.vllm_config.speculative_config.draft_model_config
         target_attn_layer_names = set(
             get_layers_from_vllm_config(self.vllm_config, Attention).keys())
 
-        draft_model_config = \
-            self.vllm_config.speculative_config.draft_model_config
-        # FIXME(lily): This does not handle with distributed inference.
-        target_device = self.vllm_config.device_config.device
-        # We need to set the vllm_config here to register attention
-        # layers in the forward context.
-        with set_default_torch_dtype(
-                draft_model_config.dtype), set_current_vllm_config(
-                    self.vllm_config):
-            draft_model_cls, arch = ModelRegistry.resolve_model_cls(
-                draft_model_config.architectures)
-            self.model = draft_model_cls(
-                vllm_config=self.vllm_config,
-                start_layer_id=target_layer_num).to(target_device)
+        self.model = get_model(vllm_config=self.vllm_config,
+                               model_config=draft_model_config)
 
         draft_attn_layer_names = (
             get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
             target_attn_layer_names)
         assert len(draft_attn_layer_names) == 1
         self.attn_layer_name = next(iter(draft_attn_layer_names))
-        loaded_weights = self.model.load_weights(
-            loader.get_all_weights(draft_model_config, self.model))
-
-        process_weights_after_loading(self.model, draft_model_config,
-                                      target_device)
 
         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1:
-            assert "model.embed_tokens.weight" not in loaded_weights, \
-                "For PP = 1, Eagle draft should share embed with target model"
             logger.info(
                 "The EAGLE head shares the same vocab embedding" \
                 " with the target model."
             )
            self.model.model.embed_tokens = target_model.model.embed_tokens
         else:
-            assert "model.embed_tokens.weight" in loaded_weights, \
-                "For PP > 1, Eagle draft checkpoint should its own copy of "
-                " the model.embed_tokens.weight"
             logger.info(
                 "Since PP > 1, the EAGLE head loaded its own vocab embedding" \
                 " weights instead of sharing them with the target model."
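After this switch, the dtype/config contexts, architecture resolution, weight loading and post-processing all happen inside get_model(); the proposer only works out which attention layer belongs to the draft and wires up embedding sharing. A sketch of that bookkeeping; the function and the layer names are illustrative, not from the diff:

def find_draft_attn_layer_name(layers_before: set[str],
                               layers_after: set[str]) -> str:
    # Attention layers registered while the draft was being built are
    # exactly those absent before the get_model() call.
    draft_attn_layer_names = layers_after - layers_before
    assert len(draft_attn_layer_names) == 1
    return next(iter(draft_attn_layer_names))


# Example with made-up layer names:
# find_draft_attn_layer_name(
#     {"model.layers.0.self_attn.attn"},
#     {"model.layers.0.self_attn.attn", "model.layers.32.self_attn.attn"},
# ) -> "model.layers.32.self_attn.attn"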
@@ -3,12 +3,10 @@
 import torch
 import torch.nn as nn
 
-from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader import get_model_loader
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
-from vllm.model_executor.models.medusa import Medusa
+from vllm.model_executor.model_loader import get_model
 from vllm.v1.sample.metadata import SamplingMetadata
 
 # Initialize logger
@@ -49,20 +47,9 @@ class MedusaProposer:
         return [list(row) for row in zip(*draft_tokens)]
 
     def load_model(self, target_model: nn.Module) -> None:
-        # Get model loader and config
-        loader = get_model_loader(self.vllm_config.load_config)
-        draft_config = self.vllm_config.speculative_config.draft_model_config
-
-        # Load model with proper dtype and config
-        with set_default_torch_dtype(draft_config.dtype), \
-                set_current_vllm_config(self.vllm_config):
-            self.model = Medusa(
-                vllm_config=self.vllm_config.speculative_config).to(
-                    self.device)
-
-        # Load model weights
-        weights = loader.get_all_weights(draft_config, self.model)
-        self.model.load_weights(weights)
+        self.model = get_model(vllm_config=self.vllm_config,
+                               model_config=self.vllm_config.
+                               speculative_config.draft_model_config)
 
     @torch.inference_mode()
     def dummy_run(self, num_tokens: int) -> None:
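This hunk and the Medusa model hunk earlier are two halves of one change: the proposer hands the draft ModelConfig to get_model(), and Medusa.__init__ now reads its hf_config from the speculative config unconditionally, so the old hasattr(vllm_config, 'draft_model_config') special case is gone. The attribute path both hunks rely on, as a trivial sketch:

def medusa_draft_hf_config(vllm_config):
    # Same path used by the proposer (via model_config=...) and by
    # Medusa.__init__ above.
    return vllm_config.speculative_config.draft_model_config.hf_config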