commit 598052b04e
parent c0cc07e7ee
Author: bk-201 <joy25810@foxmail.com>
Date:   2025-12-04 16:57:49 +00:00

Signed-off-by: bk-201 <joy25810@foxmail.com>

5 changed files with 29 additions and 31 deletions

diff --git a/vllm/lora/models.py b/vllm/lora/models.py

@@ -11,6 +11,7 @@ import safetensors.torch
 import torch
 from torch import nn

+from vllm.config import VllmConfig
 from vllm.config.lora import LoRAConfig, ModelConfig
 from vllm.logger import init_logger
 from vllm.lora.layers import (
@@ -42,6 +43,7 @@ from vllm.model_executor.utils import get_packed_modules_mapping
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.utils.cache import LRUCache
 from vllm.utils.platform_utils import is_pin_memory_available
+from vllm.v1.worker.utils import MultiModalBudget

 logger = init_logger(__name__)
@@ -302,7 +304,7 @@ class LoRAModelManager:
         max_num_batched_tokens: int,
         vocab_size: int,
         lora_config: LoRAConfig,
-        model_config: ModelConfig | None,
+        vllm_config: VllmConfig,
         device: torch.device,
     ):
         """Create a LoRAModelManager and adapter for a given model.
@@ -340,7 +342,7 @@ class LoRAModelManager:
             f" {self.model.__class__.__name__}."
         )
         self.packed_modules_mapping = get_packed_modules_mapping(self.model)
-        self._init_multimodal_config(model_config)
+        self._init_multimodal_config(vllm_config)
         self.is_pooling_model = is_pooling_model(self.model)
         self.packed_modules: dict[str, list[str]] = {}
         self.modules: dict[str, BaseLayerWithLoRA] = {}
@@ -351,7 +353,7 @@ class LoRAModelManager:
         self.model.lora_manager = self

-    def _init_multimodal_config(self, model_config):
+    def _init_multimodal_config(self, vllm_config: VllmConfig):
         # Used to indicate whether the model is a multimodal model
         self.supports_mm: bool = (
             supports_multimodal(self.model)
@@ -359,25 +361,27 @@ class LoRAModelManager:
             # text modules (e.g. ChatGLM)
             and hasattr(self.model, "get_mm_mapping")
         )
-        # For v0 compatibility
-        self.supports_mm_lora = False
-        if model_config is not None:
-            self.mm_registry = MULTIMODAL_REGISTRY
-            self.info = self.mm_registry.create_processor(model_config).info
-            self.supports_mm_lora = self.supports_mm and hasattr(
-                self.info, "get_num_mm_encoder_tokens"
-            )
+        model_config: ModelConfig = vllm_config.model_config
+        self.info = MULTIMODAL_REGISTRY.create_processor(model_config).info
+        self.supports_mm_lora = self.supports_mm and hasattr(
+            self.info, "get_num_mm_encoder_tokens"
+        )
         if not self.supports_mm_lora:
             return
+        mm_budget = MultiModalBudget(
+            model_config,
+            vllm_config.scheduler_config,
+            MULTIMODAL_REGISTRY,
+        )
         self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
         self.mm_config = model_config.multimodal_config
         limit_per_prompt: int = max(self.info.get_allowed_mm_limits().values())
         # For vision tower
         num_encoder_tokens = self.info.get_num_mm_encoder_tokens(
-            self.max_num_batched_tokens
+            mm_budget.get_encoder_budget()
         )
         self.mm_punica_wrapper_mapping = {
             name: get_punica_wrapper(
@@ -911,7 +915,7 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
         max_num_batched_tokens: int,
         vocab_size: int,
         lora_config: LoRAConfig,
-        model_config: ModelConfig,
+        vllm_config: VllmConfig,
         device: torch.device,
     ):
         super().__init__(
@@ -920,7 +924,7 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
             max_num_batched_tokens,
             vocab_size,
             lora_config,
-            model_config,
+            vllm_config,
             device,
         )
         self._registered_adapters: LoRALRUCache = LoRALRUCache(
@@ -994,7 +998,7 @@ def create_lora_manager(
     max_num_batched_tokens: int,
     vocab_size: int,
     lora_config: LoRAConfig,
-    model_config: ModelConfig,
+    vllm_config: VllmConfig,
     device: torch.device,
     lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
     **kwargs,
@@ -1008,7 +1012,7 @@ def create_lora_manager(
         max_num_batched_tokens=max_num_batched_tokens,
         vocab_size=vocab_size,
         lora_config=lora_config,
-        model_config=model_config,
+        vllm_config=vllm_config,
         device=device,
         **kwargs,
     )
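
Note: LoRAModelManager now receives the full VllmConfig instead of an optional ModelConfig, and _init_multimodal_config sizes the encoder LoRA buffers from a profiled MultiModalBudget rather than the scheduler-wide max_num_batched_tokens. A minimal standalone sketch of the new flow (not the vLLM source; it only assumes the names visible in the diff above, and the function name is illustrative):

    from vllm.config import VllmConfig
    from vllm.multimodal import MULTIMODAL_REGISTRY
    from vllm.v1.worker.utils import MultiModalBudget

    def encoder_lora_token_budget(vllm_config: VllmConfig) -> int | None:
        # The model config is no longer passed separately; it is derived
        # from the engine-wide config.
        model_config = vllm_config.model_config
        info = MULTIMODAL_REGISTRY.create_processor(model_config).info
        if not hasattr(info, "get_num_mm_encoder_tokens"):
            return None  # processor has no encoder-token hook: no MM LoRA
        mm_budget = MultiModalBudget(
            model_config,
            vllm_config.scheduler_config,
            MULTIMODAL_REGISTRY,
        )
        # Encoder LoRA buffers are sized from the profiled encoder budget,
        # not from max_num_batched_tokens as before this commit.
        return info.get_num_mm_encoder_tokens(mm_budget.get_encoder_budget())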

diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py

@@ -6,7 +6,7 @@ from typing import Any, Literal
 import torch

-from vllm.config import ModelConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.lora.models import (
     LoRAModel,
@@ -69,7 +69,7 @@ class WorkerLoRAManager:
     def create_lora_manager(
         self,
         model: torch.nn.Module,
-        model_config: ModelConfig | None = None,
+        vllm_config: VllmConfig,
     ) -> Any:
         lora_manager = create_lora_manager(
             model,
@@ -79,7 +79,7 @@ class WorkerLoRAManager:
             lora_config=self.lora_config,
             device=self.device,
             lora_manager_cls=self._manager_cls,
-            model_config=model_config,
+            vllm_config=vllm_config,
         )
         self._adapter_manager = lora_manager
         return lora_manager.model
@@ -212,7 +212,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
     def create_lora_manager(
         self,
         model: torch.nn.Module,
-        model_config: ModelConfig | None = None,
+        vllm_config: VllmConfig,
     ) -> Any:
         lora_manager = create_lora_manager(
             model,
@@ -222,7 +222,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
             lora_config=self.lora_config,
             device=self.device,
             max_num_batched_tokens=self.max_num_batched_tokens,
-            model_config=model_config,
+            vllm_config=vllm_config,
         )
         self._adapter_manager = lora_manager
         return lora_manager.model
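
Note: create_lora_manager on both worker managers now requires the engine VllmConfig; the old model_config: ModelConfig | None = None default is gone, so callers must pass the config explicitly. A hypothetical call site (per the diff, the method returns the LoRA-wrapped model):

    # worker_manager: a (LRUCache)WorkerLoRAManager; vllm_config: VllmConfig
    model = worker_manager.create_lora_manager(model, vllm_config)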

diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py

@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import copy
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from dataclasses import dataclass, field
@@ -59,7 +58,6 @@ class DummyDecoderData(NamedTuple):
     prompt_token_ids: list[int]
     multi_modal_data: MultiModalKwargsItems
     multi_modal_placeholders: MultiModalPlaceholderDict
-    multi_modal_token_ids: list[int]


 _I = TypeVar("_I", bound=BaseProcessingInfo)
@@ -324,13 +322,10 @@ class MultiModalProfiler(Generic[_I]):
         if total_len < seq_len:
             prompt_token_ids.extend([0] * (seq_len - total_len))

-        multi_modal_token_ids = copy.deepcopy(prompt_token_ids)
-
         return DummyDecoderData(
             prompt_token_ids=prompt_token_ids,
             multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
             multi_modal_placeholders=mm_inputs["mm_placeholders"],
-            multi_modal_token_ids=multi_modal_token_ids,
         )

     def _get_mm_max_tokens(
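
Note: DummyDecoderData drops its multi_modal_token_ids field, which also removes a copy.deepcopy of the full dummy prompt on every profiling run. The resulting tuple, reconstructed from the diff (the two multimodal types come from this file's existing imports):

    from typing import NamedTuple

    class DummyDecoderData(NamedTuple):
        prompt_token_ids: list[int]
        multi_modal_data: MultiModalKwargsItems
        multi_modal_placeholders: MultiModalPlaceholderDict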

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py

@@ -3620,7 +3620,7 @@ class GPUModelRunner(
         )
         if self.lora_config:
             self.model = self.load_lora_model(
-                self.model, self.vllm_config, self.device, self.model_config
+                self.model, self.vllm_config, self.device
             )

         if hasattr(self, "drafter"):
             logger.info_once("Loading drafter model...")

diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py

@@ -11,7 +11,7 @@ import numpy as np
 import torch
 import torch.nn as nn

-from vllm.config import ModelConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.config.lora import LoRAConfig
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping, LoRAMappingType
@@ -33,7 +33,6 @@ class LoRAModelRunnerMixin:
         model: nn.Module,
         vllm_config: VllmConfig,
         device: torch.device,
-        model_config: ModelConfig | None = None,
     ) -> nn.Module:
         if not supports_lora(model):
             raise ValueError(f"{model.__class__.__name__} does not support LoRA yet.")
@@ -44,7 +43,7 @@ class LoRAModelRunnerMixin:
             device,
             model.embedding_modules,
         )
-        return self.lora_manager.create_lora_manager(model, model_config)
+        return self.lora_manager.create_lora_manager(model, vllm_config)

     def _set_active_loras(
         self,
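
Note: load_lora_model loses its trailing model_config parameter because vllm_config already carries it (vllm_config.model_config); the GPUModelRunner call site above changes accordingly. A sketch of the updated mixin method (the worker-manager constructor arguments other than device and model.embedding_modules are not shown in the diff and are assumed):

    def load_lora_model(self, model, vllm_config, device):
        if not supports_lora(model):
            raise ValueError(f"{model.__class__.__name__} does not support LoRA yet.")
        self.lora_manager = LRUCacheWorkerLoRAManager(
            vllm_config,  # assumed; leading constructor args not in the diff
            device,
            model.embedding_modules,
        )
        # The engine config is now threaded through to the adapter manager.
        return self.lora_manager.create_lora_manager(model, vllm_config)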