From d053aa73e1eae20342235eea9715ff0c380dc264 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Sat, 20 Dec 2025 01:47:11 +0000
Subject: [PATCH] Fix

Signed-off-by: Jee Jee Li
---
 vllm/lora/model_manager.py | 112 ++++++++++++++++---------------------
 1 file changed, 47 insertions(+), 65 deletions(-)

diff --git a/vllm/lora/model_manager.py b/vllm/lora/model_manager.py
index f685db14af17c..05713bda91236 100644
--- a/vllm/lora/model_manager.py
+++ b/vllm/lora/model_manager.py
@@ -2,9 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import math
-from collections import OrderedDict
 from collections.abc import Callable
-from dataclasses import dataclass
 from typing import TypeVar
 
 import regex as re
@@ -47,12 +45,6 @@ T = TypeVar("T")
 DEFAULT_LANGUAGE_WRAPPER_KEY = "language_model"
 
 
-@dataclass(frozen=True)
-class LoRATarget:
-    wrapper: PunicaWrapperBase
-    prefix: str
-
-
 class AdapterLRUCache(LRUCache[int, T]):
     def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
         super().__init__(capacity)
@@ -120,21 +112,6 @@ class LoRAModelManager:
     def _init_punica_wrapper(
         self, max_num_batched_tokens: int, vllm_config: VllmConfig
     ) -> None:
-        self.punica_wrapper_mapping: OrderedDict[str, PunicaWrapperBase] = OrderedDict()
-        llm_punica_wrapper = get_punica_wrapper(
-            max_num_batched_tokens,
-            max_batches=self.max_num_seqs,
-            device=self.device,
-            max_loras=self.lora_config.max_loras,
-        )
-
-        # NOTE This assumes the existence of a language model LoRA
-        self.punica_wrapper_mapping.setdefault(
-            DEFAULT_LANGUAGE_WRAPPER_KEY, llm_punica_wrapper
-        )
-        self._maybe_init_mm(vllm_config)
-
-    def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
         # Used to indicate whether the model is a multimodal model
         self.supports_mm: bool = (
             supports_multimodal(self.model)
@@ -142,13 +119,39 @@ class LoRAModelManager:
             # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping")
         )
-        if not self.supports_mm:
-            return
+        self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {}
+        if self.supports_mm:
+            self._maybe_init_mm(vllm_config, max_num_batched_tokens)
+        else:
+            llm_punica_wrapper = get_punica_wrapper(
+                max_num_batched_tokens,
+                max_batches=self.max_num_seqs,
+                device=self.device,
+                max_loras=self.lora_config.max_loras,
+            )
+            self.punica_wrapper_mapping[DEFAULT_LANGUAGE_WRAPPER_KEY] = (
+                llm_punica_wrapper
+            )
+
+    def _maybe_init_mm(
+        self, vllm_config: VllmConfig, max_num_batched_tokens: int
+    ) -> None:
         self.supports_tower_connector_lora = False
         model_config: ModelConfig = vllm_config.model_config
         self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
+        # Only one language model can be included in the model.
+        assert len(self.mm_mapping.language_model) == 1
+
+        # Language model punica wrapper
+        llm_punica_wrapper = get_punica_wrapper(
+            max_num_batched_tokens,
+            max_batches=self.max_num_seqs,
+            device=self.device,
+            max_loras=self.lora_config.max_loras,
+        )
+        lm_prefix = self.mm_mapping.language_model[0]
+        self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
+
         if self.lora_config.enable_tower_connector_lora:
             self.info = MULTIMODAL_REGISTRY.create_processor(model_config).info
             self.supports_tower_connector_lora = self.supports_mm and hasattr(
             )
             if not self.supports_tower_connector_lora:
                 return
+
             logger.warning(
                 "LoRA for the tower and connector of multimodal models is "
                 "experimental and may contain bugs. "
                 "Please report any related issues on "
@@ -172,20 +176,6 @@ class LoRAModelManager:
                 mm_budget.get_encoder_budget()
             )
 
-        # Only one language model can be included in the model.
-        assert len(self.mm_mapping.language_model) == 1
-
-        # Update prefix of language model
-        lm_prefix = (
-            self.mm_mapping.language_model[0]
-            if self.supports_mm
-            else DEFAULT_LANGUAGE_WRAPPER_KEY
-        )
-        llm_punica_wrapper = self.punica_wrapper_mapping.pop(
-            DEFAULT_LANGUAGE_WRAPPER_KEY
-        )
-        self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
-
         # Tower wrappers
         tower_punica_wrapper = get_punica_wrapper(
             num_encoder_tokens,
@@ -217,15 +207,6 @@ class LoRAModelManager:
                 "determine the connector's token budget for LoRA operations."
             )
 
-        # Longest-prefix-first
-        self.punica_wrapper_mapping = OrderedDict(
-            sorted(
-                self.punica_wrapper_mapping.items(),
-                key=lambda x: len(x[0]),
-                reverse=True,
-            )
-        )
-
     def __len__(self) -> int:
         return len(self._registered_adapters)
 
@@ -366,10 +347,10 @@ class LoRAModelManager:
         else:
             target_prefix = self.mm_mapping.language_model[0]
 
-        target = self._get_lora_target(target_prefix)
-        assert target is not None
+        punica_wrapper = self._get_punica_wrapper(target_prefix)
+        assert punica_wrapper is not None
 
-        target.wrapper.update_metadata(
+        punica_wrapper.update_metadata(
             mapping,
             self.lora_index_to_id,
             self.lora_slots + 1,
@@ -397,8 +378,8 @@ class LoRAModelManager:
             if not self._match_target_modules(module_name):
                 continue
 
-            target = self._get_lora_target(module_name)
-            if target is None:
+            punica_wrapper = self._get_punica_wrapper(module_name)
+            if punica_wrapper is None:
                 logger.warning(
                     "Regarding %s, vLLM currently only supports adding LoRA to"
                     " language model, %s will be ignored.",
@@ -464,7 +445,7 @@ class LoRAModelManager:
             self._register_packed_modules(module_name)
 
             # All lora layers share the same punica_wrapper based on reference.
-            new_module.set_mapping(target.wrapper)
+            new_module.set_mapping(punica_wrapper)
 
     def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
         assert isinstance(module, BaseLayerWithLoRA), (
@@ -485,7 +466,7 @@ class LoRAModelManager:
             if (
                 not self._match_target_modules(module_name)
                 or not isinstance(module, BaseLayerWithLoRA)
-                or self._get_lora_target(module_name) is None
+                or self._get_punica_wrapper(module_name) is None
             ):
                 continue
             parts = module_name.split(".")
@@ -574,23 +555,24 @@ class LoRAModelManager:
             for target_module in self.supported_lora_modules
         )
 
-    def _get_lora_target(self, module_name: str) -> LoRATarget | None:
+    def _get_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None:
         """
         Determine whether this module supports LoRA and which wrapper to use.
""" # For language Model (early return) if not self.supports_mm: - wrapper = list(self.punica_wrapper_mapping.values())[0] - return LoRATarget(wrapper=wrapper, prefix=DEFAULT_LANGUAGE_WRAPPER_KEY) + return self.punica_wrapper_mapping[DEFAULT_LANGUAGE_WRAPPER_KEY] + + # For multimodal model - for prefix, wrapper in self.punica_wrapper_mapping.items(): - is_language_model = ( - prefix == DEFAULT_LANGUAGE_WRAPPER_KEY - and module_name.startswith(self.mm_mapping.language_model[0]) - ) - if is_language_model or module_name.startswith(prefix): - return LoRATarget(wrapper=wrapper, prefix=prefix) + # for prefix, wrapper in self.punica_wrapper_mapping.items(): + # is_language_model = ( + # prefix == DEFAULT_LANGUAGE_WRAPPER_KEY + # and module_name.startswith(self.mm_mapping.language_model[0]) + # ) + # if is_language_model or module_name.startswith(prefix): + # return LoRATarget(wrapper=wrapper, prefix=prefix) return None