Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-04-02 00:17:03 +08:00)

commit 598052b04e
parent c0cc07e7ee

fix bug

Signed-off-by: bk-201 <joy25810@foxmail.com>
@@ -11,6 +11,7 @@ import safetensors.torch
 import torch
 from torch import nn
 
+from vllm.config import VllmConfig
 from vllm.config.lora import LoRAConfig, ModelConfig
 from vllm.logger import init_logger
 from vllm.lora.layers import (
@@ -42,6 +43,7 @@ from vllm.model_executor.utils import get_packed_modules_mapping
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.utils.cache import LRUCache
 from vllm.utils.platform_utils import is_pin_memory_available
+from vllm.v1.worker.utils import MultiModalBudget
 
 logger = init_logger(__name__)
 
@@ -302,7 +304,7 @@ class LoRAModelManager:
         max_num_batched_tokens: int,
         vocab_size: int,
         lora_config: LoRAConfig,
-        model_config: ModelConfig | None,
+        vllm_config: VllmConfig,
         device: torch.device,
     ):
         """Create a LoRAModelManager and adapter for a given model.
@@ -340,7 +342,7 @@ class LoRAModelManager:
             f" {self.model.__class__.__name__}."
         )
         self.packed_modules_mapping = get_packed_modules_mapping(self.model)
-        self._init_multimodal_config(model_config)
+        self._init_multimodal_config(vllm_config)
         self.is_pooling_model = is_pooling_model(self.model)
         self.packed_modules: dict[str, list[str]] = {}
         self.modules: dict[str, BaseLayerWithLoRA] = {}
@@ -351,7 +353,7 @@ class LoRAModelManager:
 
         self.model.lora_manager = self
 
-    def _init_multimodal_config(self, model_config):
+    def _init_multimodal_config(self, vllm_config: VllmConfig):
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
@@ -359,25 +361,27 @@ class LoRAModelManager:
             # text modules (e.g. ChatGLM)
             and hasattr(self.model, "get_mm_mapping")
         )
-        # For v0 compatibility
-        self.supports_mm_lora = False
-        if model_config is not None:
-            self.mm_registry = MULTIMODAL_REGISTRY
-            self.info = self.mm_registry.create_processor(model_config).info
-            self.supports_mm_lora = self.supports_mm and hasattr(
-                self.info, "get_num_mm_encoder_tokens"
-            )
+        model_config: ModelConfig = vllm_config.model_config
+        self.info = MULTIMODAL_REGISTRY.create_processor(model_config).info
+        self.supports_mm_lora = self.supports_mm and hasattr(
+            self.info, "get_num_mm_encoder_tokens"
+        )
 
         if not self.supports_mm_lora:
             return
 
+        mm_budget = MultiModalBudget(
+            model_config,
+            vllm_config.scheduler_config,
+            MULTIMODAL_REGISTRY,
+        )
         self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
         self.mm_config = model_config.multimodal_config
         limit_per_prompt: int = max(self.info.get_allowed_mm_limits().values())
 
         # For vision tower
         num_encoder_tokens = self.info.get_num_mm_encoder_tokens(
-            self.max_num_batched_tokens
+            mm_budget.get_encoder_budget()
         )
         self.mm_punica_wrapper_mapping = {
             name: get_punica_wrapper(
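Note: the substantive fix in the hunks above is the final changed pair of lines. get_num_mm_encoder_tokens() was being fed the decoder-side self.max_num_batched_tokens; it is now fed mm_budget.get_encoder_budget(), an encoder-specific budget derived from the model and scheduler configs, which is why the whole VllmConfig must now reach _init_multimodal_config. A minimal runnable sketch of the idea follows; ToyMultiModalBudget and its fields are illustrative assumptions, not vLLM's actual MultiModalBudget internals.

from dataclasses import dataclass


@dataclass
class ToyMultiModalBudget:
    # Hypothetical stand-in for vllm.v1.worker.utils.MultiModalBudget;
    # the real class derives these values from model/scheduler config.
    max_num_encoder_input_tokens: int  # scheduler-side cap on encoder tokens
    max_tokens_per_mm_item: int        # encoder tokens for the largest item
    limit_per_prompt: int              # max multimodal items per prompt

    def get_encoder_budget(self) -> int:
        # The budget is encoder-specific, not the decoder batch budget.
        return min(
            self.max_num_encoder_input_tokens,
            self.max_tokens_per_mm_item * self.limit_per_prompt,
        )


max_num_batched_tokens = 8192  # decoder budget: the old sizing input
budget = ToyMultiModalBudget(
    max_num_encoder_input_tokens=8192,
    max_tokens_per_mm_item=1176,
    limit_per_prompt=4,
)
# 1176 * 4 = 4704: the encoder needs fewer slots than the decoder budget.
assert budget.get_encoder_budget() == 4704 < max_num_batched_tokens

Sizing encoder-side LoRA buffers from the decoder batch budget can over- or under-provision whenever the two budgets differ; taking the number from the multimodal budget keeps them consistent.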
@@ -911,7 +915,7 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
         max_num_batched_tokens: int,
         vocab_size: int,
         lora_config: LoRAConfig,
-        model_config: ModelConfig,
+        vllm_config: VllmConfig,
         device: torch.device,
     ):
         super().__init__(
@@ -920,7 +924,7 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
             max_num_batched_tokens,
             vocab_size,
             lora_config,
-            model_config,
+            vllm_config,
             device,
         )
         self._registered_adapters: LoRALRUCache = LoRALRUCache(
@@ -994,7 +998,7 @@ def create_lora_manager(
     max_num_batched_tokens: int,
     vocab_size: int,
     lora_config: LoRAConfig,
-    model_config: ModelConfig,
+    vllm_config: VllmConfig,
     device: torch.device,
     lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
     **kwargs,
@@ -1008,7 +1012,7 @@ def create_lora_manager(
         max_num_batched_tokens=max_num_batched_tokens,
         vocab_size=vocab_size,
         lora_config=lora_config,
-        model_config=model_config,
+        vllm_config=vllm_config,
         device=device,
         **kwargs,
     )
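Note: the four hunks above apply one mechanical signature migration: callers stop threading a bare model_config through and pass the whole vllm_config, from which both model_config and scheduler_config are reachable. A toy sketch of the pattern follows; the Toy* config classes are illustrative stand-ins, not the real vllm.config definitions.

from dataclasses import dataclass


@dataclass
class ToyModelConfig:
    model: str


@dataclass
class ToySchedulerConfig:
    max_num_batched_tokens: int


@dataclass
class ToyVllmConfig:
    """Stand-in for vllm.config.VllmConfig: one object bundling sub-configs."""

    model_config: ToyModelConfig
    scheduler_config: ToySchedulerConfig


def create_manager(vllm_config: ToyVllmConfig) -> str:
    # Before the change this function received only model_config, so it
    # could not see scheduler_config; now both are one attribute away.
    mc = vllm_config.model_config
    sc = vllm_config.scheduler_config
    return f"{mc.model} @ {sc.max_num_batched_tokens} tokens"


cfg = ToyVllmConfig(ToyModelConfig("llava"), ToySchedulerConfig(8192))
print(create_manager(cfg))  # llava @ 8192 tokens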
@@ -6,7 +6,7 @@ from typing import Any, Literal
 
 import torch
 
-from vllm.config import ModelConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.lora.models import (
     LoRAModel,
@@ -69,7 +69,7 @@ class WorkerLoRAManager:
     def create_lora_manager(
         self,
         model: torch.nn.Module,
-        model_config: ModelConfig | None = None,
+        vllm_config: VllmConfig,
     ) -> Any:
         lora_manager = create_lora_manager(
             model,
@@ -79,7 +79,7 @@ class WorkerLoRAManager:
             lora_config=self.lora_config,
             device=self.device,
             lora_manager_cls=self._manager_cls,
-            model_config=model_config,
+            vllm_config=vllm_config,
         )
         self._adapter_manager = lora_manager
         return lora_manager.model
@@ -212,7 +212,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
     def create_lora_manager(
         self,
         model: torch.nn.Module,
-        model_config: ModelConfig | None = None,
+        vllm_config: VllmConfig,
     ) -> Any:
         lora_manager = create_lora_manager(
             model,
@@ -222,7 +222,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
             lora_config=self.lora_config,
             device=self.device,
             max_num_batched_tokens=self.max_num_batched_tokens,
-            model_config=model_config,
+            vllm_config=vllm_config,
         )
         self._adapter_manager = lora_manager
         return lora_manager.model
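Note: besides the rename, both create_lora_manager overrides above drop the "= None" default, so the config becomes a required argument and a forgotten config fails at the call site instead of propagating as None. Toy illustration; the function names here are simplified stand-ins, not the vLLM methods.

def create_lora_manager_old(model, model_config=None):
    return model_config  # silently None when the caller forgets it


def create_lora_manager_new(model, vllm_config):
    return vllm_config  # required: omitting it raises immediately


assert create_lora_manager_old("m") is None  # a bug can hide here

try:
    create_lora_manager_new("m")
except TypeError as e:
    print(e)  # missing 1 required positional argument: 'vllm_config'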
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import copy
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from dataclasses import dataclass, field
@@ -59,7 +58,6 @@ class DummyDecoderData(NamedTuple):
     prompt_token_ids: list[int]
     multi_modal_data: MultiModalKwargsItems
     multi_modal_placeholders: MultiModalPlaceholderDict
-    multi_modal_token_ids: list[int]
 
 
 _I = TypeVar("_I", bound=BaseProcessingInfo)
@@ -324,13 +322,10 @@ class MultiModalProfiler(Generic[_I]):
         if total_len < seq_len:
             prompt_token_ids.extend([0] * (seq_len - total_len))
 
-        multi_modal_token_ids = copy.deepcopy(prompt_token_ids)
-
         return DummyDecoderData(
             prompt_token_ids=prompt_token_ids,
             multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
             multi_modal_placeholders=mm_inputs["mm_placeholders"],
-            multi_modal_token_ids=multi_modal_token_ids,
         )
 
     def _get_mm_max_tokens(
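Note: in the multimodal profiling code above, the multi_modal_token_ids field is deleted because, as the removed lines show, it was only ever populated with a deep copy of prompt_token_ids; the now-unused "import copy" goes with it. Toy illustration of the redundancy being removed; OldDummyData/NewDummyData are illustrative, not vLLM classes.

import copy
from typing import NamedTuple


class OldDummyData(NamedTuple):
    prompt_token_ids: list[int]
    multi_modal_token_ids: list[int]  # always a deep copy of the above


class NewDummyData(NamedTuple):
    prompt_token_ids: list[int]


tokens = [0, 1, 2, 3]
old = OldDummyData(tokens, copy.deepcopy(tokens))
assert old.prompt_token_ids == old.multi_modal_token_ids  # pure duplication
new = NewDummyData(tokens)  # same information, one copy fewer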
@@ -3620,7 +3620,7 @@ class GPUModelRunner(
         )
         if self.lora_config:
             self.model = self.load_lora_model(
-                self.model, self.vllm_config, self.device, self.model_config
+                self.model, self.vllm_config, self.device
             )
         if hasattr(self, "drafter"):
             logger.info_once("Loading drafter model...")
@@ -11,7 +11,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from vllm.config import ModelConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.config.lora import LoRAConfig
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping, LoRAMappingType
@@ -33,7 +33,6 @@ class LoRAModelRunnerMixin:
         model: nn.Module,
         vllm_config: VllmConfig,
         device: torch.device,
-        model_config: ModelConfig | None = None,
     ) -> nn.Module:
         if not supports_lora(model):
             raise ValueError(f"{model.__class__.__name__} does not support LoRA yet.")
@@ -44,7 +43,7 @@ class LoRAModelRunnerMixin:
             device,
             model.embedding_modules,
         )
-        return self.lora_manager.create_lora_manager(model, model_config)
+        return self.lora_manager.create_lora_manager(model, vllm_config)
 
     def _set_active_loras(
         self,
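Taken together, the commit threads vllm_config through the entire LoRA setup path. As reconstructed from the hunks above, the call chain is now:

GPUModelRunner: self.model = self.load_lora_model(self.model, self.vllm_config, self.device)
  -> LoRAModelRunnerMixin.load_lora_model(model, vllm_config, device)
  -> WorkerLoRAManager.create_lora_manager(model, vllm_config)
  -> create_lora_manager(model, ..., vllm_config=vllm_config)
  -> LoRAModelManager.__init__(..., vllm_config, ...)
  -> LoRAModelManager._init_multimodal_config(vllm_config)
  -> MultiModalBudget(model_config, vllm_config.scheduler_config, MULTIMODAL_REGISTRY)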