From 58d2c47b9a65e43b464a6737caad0d4bce4e2375 Mon Sep 17 00:00:00 2001
From: bk-201
Date: Mon, 15 Dec 2025 07:49:52 +0000
Subject: [PATCH] update punica_wrapper_mapping

Replace the split between the language-model punica_wrapper and the
separate mm_punica_wrapper_mapping with a single punica_wrapper_mapping
keyed by module-name prefix. A DEFAULT_WRAPPER_KEY entry covers the
plain language-model path, and per-module lookups resolve by
longest-prefix matching.

Signed-off-by: bk-201
---
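Note: the mapping that _maybe_init_mm() leaves behind now holds one
wrapper per module-name prefix plus a default entry. A minimal,
self-contained sketch of that shape, with strings standing in for
PunicaWrapperBase instances and made-up prefixes ("visual", "mlp1")
in place of real model module names:

    DEFAULT_WRAPPER_KEY = "__default__"

    # Illustrative only; the real values are PunicaWrapperBase objects.
    punica_wrapper_mapping: dict[str, str] = {
        DEFAULT_WRAPPER_KEY: "language_wrapper",  # plain text-only path
        "visual": "tower_wrapper",                # hypothetical tower prefix
        "mlp1": "connector_wrapper",              # hypothetical connector prefix
        "language_model": "language_wrapper",
    }

    # Non-multimodal code paths read the default entry directly.
    assert punica_wrapper_mapping[DEFAULT_WRAPPER_KEY] == "language_wrapper"

For models that take the tower/connector path, _maybe_init_mm() rebuilds
the dict from scratch, so the default key is intentionally absent there
and unmatched modules resolve to None.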
 vllm/lora/model_manager.py         | 98 +++++++++++++++++-------------
 vllm/v1/worker/gpu_model_runner.py |  4 +-
 2 files changed, 56 insertions(+), 46 deletions(-)

diff --git a/vllm/lora/model_manager.py b/vllm/lora/model_manager.py
index 9bffa17b36712..33e147195a6f7 100644
--- a/vllm/lora/model_manager.py
+++ b/vllm/lora/model_manager.py
@@ -42,6 +42,7 @@ from vllm.v1.worker.utils import MultiModalBudget
 
 logger = init_logger(__name__)
 T = TypeVar("T")
+DEFAULT_WRAPPER_KEY = "__default__"
 
 
 class AdapterLRUCache(LRUCache[int, T]):
@@ -117,6 +118,11 @@ class LoRAModelManager:
             device=self.device,
             max_loras=self.lora_config.max_loras,
         )
+
+        self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {
+            DEFAULT_WRAPPER_KEY: self.punica_wrapper
+        }
+
         self._maybe_init_mm(vllm_config)
 
     def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
@@ -132,8 +138,8 @@ class LoRAModelManager:
         self.supports_tower_connector_lora = False
 
         model_config: ModelConfig = vllm_config.model_config
-
         self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
+
         if self.lora_config.enable_tower_connector_lora:
             self.info = MULTIMODAL_REGISTRY.create_processor(model_config).info
             self.supports_tower_connector_lora = self.supports_mm and hasattr(
@@ -153,24 +159,26 @@ class LoRAModelManager:
             MULTIMODAL_REGISTRY,
         )
         limit_per_prompt: int = max(self.info.get_allowed_mm_limits().values())
-
-        # For vision tower
         num_encoder_tokens = self.info.get_num_mm_encoder_tokens(
             mm_budget.get_encoder_budget()
         )
-        self.mm_punica_wrapper_mapping = {
-            name: get_punica_wrapper(
+
+        self.punica_wrapper_mapping = {}
+
+        # Tower wrappers
+        for name in self.mm_mapping.tower_model:
+            self.punica_wrapper_mapping[name] = get_punica_wrapper(
                 num_encoder_tokens,
                 max_batches=self.max_num_seqs * limit_per_prompt,
                 device=self.device,
                 max_loras=self.lora_config.max_loras,
             )
-            for name in self.mm_mapping.tower_model
-        }
-        # For language model
-        self.mm_punica_wrapper_mapping.update(
-            {self.mm_mapping.language_model[0]: self.punica_wrapper}
+
+        # Language wrapper
+        self.punica_wrapper_mapping[self.mm_mapping.language_model[0]] = (
+            self.punica_wrapper
         )
+        # Use a dedicated wrapper for the connector if present.
         if self.mm_mapping.connector:
             if hasattr(self.info, "get_num_mm_connector_tokens"):
@@ -183,7 +191,7 @@ class LoRAModelManager:
                     device=self.device,
                     max_loras=self.lora_config.max_loras,
                 )
-            self.mm_punica_wrapper_mapping.update(
+            self.punica_wrapper_mapping.update(
                 {
                     name: connector_punica_wrapper
                     for name in self.mm_mapping.connector
@@ -323,20 +331,19 @@ class LoRAModelManager:
 
     def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
         # Default to the main language model wrapper
-        target_wrapper = self.punica_wrapper
-
-        if self.supports_mm and self.supports_tower_connector_lora:
+        if not (self.supports_mm and self.supports_tower_connector_lora):
+            target_wrapper = self.punica_wrapper_mapping[DEFAULT_WRAPPER_KEY]
+        else:
             if mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model:
-                target_name = self.mm_mapping.tower_model[0]
-                target_wrapper = self.mm_punica_wrapper_mapping[target_name]
+                target_prefix = self.mm_mapping.tower_model[0]
             elif (
                 mapping.type == LoRAMappingType.CONNECTOR and self.mm_mapping.connector
             ):
-                target_name = self.mm_mapping.connector[0]
-                target_wrapper = self.mm_punica_wrapper_mapping[target_name]
+                target_prefix = self.mm_mapping.connector[0]
             else:
-                target_name = self.mm_mapping.language_model[0]
-                target_wrapper = self.mm_punica_wrapper_mapping[target_name]
+                target_prefix = self.mm_mapping.language_model[0]
+
+            target_wrapper = self.punica_wrapper_mapping[target_prefix]
 
         target_wrapper.update_metadata(
             mapping,
@@ -369,7 +376,7 @@ class LoRAModelManager:
             if self._filter_unsupported_mm_module(module_name):
                 logger.warning(
                     "Regarding %s, vLLM currently only supports adding LoRA to"
-                    " language model, {module_name} will be ignored.",
+                    " language model, %s will be ignored.",
                     self.model.__class__.__name__,
                     module_name,
                 )
@@ -432,10 +439,10 @@ class LoRAModelManager:
                 self._register_packed_modules(module_name)
 
             # All lora layers share the same punica_wrapper based on reference.
-            if self.supports_mm and self.supports_tower_connector_lora:
-                new_module.set_mapping(self._get_mm_punica_wrapper(module_name))
-            else:
-                new_module.set_mapping(self.punica_wrapper)
+            wrapper = self._get_punica_wrapper_for_module(module_name)
+            if wrapper is None:
+                continue
+            new_module.set_mapping(wrapper)
 
     def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
         assert isinstance(module, BaseLayerWithLoRA), (
@@ -551,31 +558,36 @@ class LoRAModelManager:
         language model. LoRA for other modules, such as the vision tower,
         will be filtered out.
         """
-        if self.supports_mm:
-            prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
-            if self.supports_tower_connector_lora:
-                return self._get_mm_punica_wrapper(module_name) is None
-            else:
-                return any([module_name.startswith(prefix) for prefix in prefix_lst])
-        return False
+        if not self.supports_mm:
+            return False
 
-    def _get_mm_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None:
+        if self.supports_tower_connector_lora:
+            return self._get_punica_wrapper_for_module(module_name) is None
+
+        prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
+        return any(module_name.startswith(prefix) for prefix in prefix_lst)
+
+    def _get_punica_wrapper_for_module(
+        self, module_name: str
+    ) -> PunicaWrapperBase | None:
         """
         Match the corresponding punica_wrapper based on module_name, and return
         None if lora is not supported for this module.
         """
-        if self.supports_tower_connector_lora:
+        best_prefix = None
+        for prefix in self.punica_wrapper_mapping:
+            if prefix == DEFAULT_WRAPPER_KEY:
+                continue
             # Ensure matching by the longest prefix.
-            sorted_prefixes = sorted(
-                self.mm_punica_wrapper_mapping.keys(),
-                key=lambda x: len(x),
-                reverse=True,
-            )
+            if module_name.startswith(prefix) and (
+                best_prefix is None or len(prefix) > len(best_prefix)
+            ):
+                best_prefix = prefix
 
-            for prefix in sorted_prefixes:
-                if module_name.startswith(prefix):
-                    return self.mm_punica_wrapper_mapping[prefix]
-        return None
+        if best_prefix is not None:
+            return self.punica_wrapper_mapping[best_prefix]
+
+        return self.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY)
 
     def _register_packed_modules(self, module_full_name: str) -> None:
         parts = module_full_name.split(".")
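Note: _get_punica_wrapper_for_module now resolves a module to its wrapper
with a single linear scan that keeps the longest matching prefix, instead
of sorting every prefix up front, and it falls back to the default entry
when nothing matches. A standalone sketch of that rule, with made-up
module prefixes and strings standing in for PunicaWrapperBase objects:

    DEFAULT_WRAPPER_KEY = "__default__"

    def resolve(mapping: dict[str, str], module_name: str) -> str | None:
        """Longest-prefix match, mirroring _get_punica_wrapper_for_module."""
        best_prefix = None
        for prefix in mapping:
            if prefix == DEFAULT_WRAPPER_KEY:
                continue
            if module_name.startswith(prefix) and (
                best_prefix is None or len(prefix) > len(best_prefix)
            ):
                best_prefix = prefix
        if best_prefix is not None:
            return mapping[best_prefix]
        return mapping.get(DEFAULT_WRAPPER_KEY)

    mapping = {
        "visual": "tower_wrapper",             # hypothetical tower prefix
        "visual.merger": "connector_wrapper",  # hypothetical nested prefix
    }
    # The longer prefix wins, so modules nested under the tower prefix are
    # not misrouted to the tower wrapper.
    assert resolve(mapping, "visual.merger.mlp.0") == "connector_wrapper"
    assert resolve(mapping, "visual.blocks.0.attn.qkv") == "tower_wrapper"
    # Without a DEFAULT_WRAPPER_KEY entry (the tower/connector path rebuilds
    # the dict without one), unmatched modules resolve to None, which is what
    # _filter_unsupported_mm_module keys off.
    assert resolve(mapping, "lm_head") is None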
""" - if self.supports_tower_connector_lora: + best_prefix = None + for prefix in self.punica_wrapper_mapping: + if prefix == DEFAULT_WRAPPER_KEY: + continue # Ensure matching by the longest prefix. - sorted_prefixes = sorted( - self.mm_punica_wrapper_mapping.keys(), - key=lambda x: len(x), - reverse=True, - ) + if module_name.startswith(prefix) and ( + best_prefix is None or len(prefix) > len(best_prefix) + ): + best_prefix = prefix - for prefix in sorted_prefixes: - if module_name.startswith(prefix): - return self.mm_punica_wrapper_mapping[prefix] - return None + if best_prefix is not None: + return self.punica_wrapper_mapping[best_prefix] + + return self.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY) def _register_packed_modules(self, module_full_name: str) -> None: parts = module_full_name.split(".") diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d4f214a20595c..ce5bc48ebaafb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2161,9 +2161,7 @@ class GPUModelRunner( # pos_info.length may overcount (e.g., special tokens in Qwen-VL). # Fall back to length if is_embed is None. num_tokens = self.info.get_num_mm_encoder_tokens( # type: ignore[attr-defined] - pos_info.length - if pos_info.is_embed is None - else pos_info.is_embed.sum() + pos_info.get_num_embeds() ) prompt_lora_mapping.append(lora_id) token_lora_mapping.extend([lora_id] * num_tokens)