update punica_wrapper_mapping

Signed-off-by: bk-201 <joy25810@foxmail.com>
bk-201 2025-12-15 07:49:52 +00:00
parent c1bb71ef6b
commit 58d2c47b9a
2 changed files with 56 additions and 46 deletions


@@ -42,6 +42,7 @@ from vllm.v1.worker.utils import MultiModalBudget
 logger = init_logger(__name__)
 
 T = TypeVar("T")
+DEFAULT_WRAPPER_KEY = "__default__"
 
 
 class AdapterLRUCache(LRUCache[int, T]):
@@ -117,6 +118,11 @@ class LoRAModelManager:
             device=self.device,
             max_loras=self.lora_config.max_loras,
         )
+        self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {
+            DEFAULT_WRAPPER_KEY: self.punica_wrapper
+        }
+
         self._maybe_init_mm(vllm_config)
 
     def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
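For context on the shape of the new registry: a text-only model keeps exactly one entry under the reserved key, so every lookup resolves to the main language-model wrapper. A minimal standalone sketch of that invariant (object() stands in for a real PunicaWrapperBase; not part of the diff):

# Sketch only: object() is a stand-in for a PunicaWrapperBase instance.
DEFAULT_WRAPPER_KEY = "__default__"

language_wrapper = object()
punica_wrapper_mapping: dict[str, object] = {DEFAULT_WRAPPER_KEY: language_wrapper}

# Text-only models never register module prefixes, so the default entry
# serves every request.
assert punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY) is language_wrapper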
@@ -132,8 +138,8 @@ class LoRAModelManager:
         self.supports_tower_connector_lora = False
         model_config: ModelConfig = vllm_config.model_config
+        self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
         if self.lora_config.enable_tower_connector_lora:
             self.info = MULTIMODAL_REGISTRY.create_processor(model_config).info
             self.supports_tower_connector_lora = self.supports_mm and hasattr(
@@ -153,24 +159,26 @@ class LoRAModelManager:
             MULTIMODAL_REGISTRY,
         )
         limit_per_prompt: int = max(self.info.get_allowed_mm_limits().values())
         # For vision tower
         num_encoder_tokens = self.info.get_num_mm_encoder_tokens(
             mm_budget.get_encoder_budget()
         )
-        self.mm_punica_wrapper_mapping = {
-            name: get_punica_wrapper(
+        self.punica_wrapper_mapping = {}
+        # Tower wrappers
+        for name in self.mm_mapping.tower_model:
+            self.punica_wrapper_mapping[name] = get_punica_wrapper(
                 num_encoder_tokens,
                 max_batches=self.max_num_seqs * limit_per_prompt,
                 device=self.device,
                 max_loras=self.lora_config.max_loras,
             )
-            for name in self.mm_mapping.tower_model
-        }
-        # For language model
-        self.mm_punica_wrapper_mapping.update(
-            {self.mm_mapping.language_model[0]: self.punica_wrapper}
+        # Language wrapper
+        self.punica_wrapper_mapping[self.mm_mapping.language_model[0]] = (
+            self.punica_wrapper
         )
         # Use wrapper for connector if present.
         if self.mm_mapping.connector:
             if hasattr(self.info, "get_num_mm_connector_tokens"):
@@ -183,7 +191,7 @@ class LoRAModelManager:
                 device=self.device,
                 max_loras=self.lora_config.max_loras,
             )
-            self.mm_punica_wrapper_mapping.update(
+            self.punica_wrapper_mapping.update(
                 {
                     name: connector_punica_wrapper
                     for name in self.mm_mapping.connector
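On a multimodal model, _maybe_init_mm thus rebuilds the registry keyed by module prefix. For a hypothetical VLM the result might look as follows (prefixes and placeholder objects are illustrative, not taken from this diff; real keys come from get_mm_mapping()):

# Hypothetical post-_maybe_init_mm contents; object() stands in for wrappers.
tower_wrapper, language_wrapper, connector_wrapper = object(), object(), object()
punica_wrapper_mapping = {
    "vision_tower": tower_wrapper,               # sized for encoder tokens
    "language_model": language_wrapper,          # the pre-existing decode wrapper
    "multi_modal_projector": connector_wrapper,  # sized for connector tokens
}

Note that the reserved __default__ entry is dropped on this path, which is what lets _get_punica_wrapper_for_module return None for multimodal modules that match no registered prefix.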
@@ -323,20 +331,19 @@ class LoRAModelManager:
     def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
-        # Default to the main language model wrapper
-        target_wrapper = self.punica_wrapper
-        if self.supports_mm and self.supports_tower_connector_lora:
+        if not (self.supports_mm and self.supports_tower_connector_lora):
+            target_wrapper = self.punica_wrapper_mapping[DEFAULT_WRAPPER_KEY]
+        else:
             if mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model:
-                target_name = self.mm_mapping.tower_model[0]
-                target_wrapper = self.mm_punica_wrapper_mapping[target_name]
+                target_prefix = self.mm_mapping.tower_model[0]
             elif (
                 mapping.type == LoRAMappingType.CONNECTOR and self.mm_mapping.connector
             ):
-                target_name = self.mm_mapping.connector[0]
-                target_wrapper = self.mm_punica_wrapper_mapping[target_name]
+                target_prefix = self.mm_mapping.connector[0]
             else:
-                target_name = self.mm_mapping.language_model[0]
-                target_wrapper = self.mm_punica_wrapper_mapping[target_name]
+                target_prefix = self.mm_mapping.language_model[0]
+            target_wrapper = self.punica_wrapper_mapping[target_prefix]
 
         target_wrapper.update_metadata(
             mapping,
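With every wrapper behind one dict, _set_adapter_mapping reduces to choosing a prefix and doing a single lookup. A runnable sketch of the branch order, with the enum stubbed for illustration (MappingType below is a stand-in, not vLLM's LoRAMappingType):

from enum import Enum, auto

class MappingType(Enum):  # stand-in for LoRAMappingType, for illustration
    TOWER = auto()
    CONNECTOR = auto()
    DECODE = auto()

def pick_prefix(kind: MappingType, tower: list[str],
                connector: list[str], language: list[str]) -> str:
    # Mirrors the branch order above: tower first, then connector,
    # then the language model as the fallback.
    if kind is MappingType.TOWER and tower:
        return tower[0]
    if kind is MappingType.CONNECTOR and connector:
        return connector[0]
    return language[0]

assert pick_prefix(MappingType.TOWER, ["vision_tower"], [], ["lm"]) == "vision_tower"
assert pick_prefix(MappingType.CONNECTOR, [], [], ["lm"]) == "lm"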
@@ -369,7 +376,7 @@ class LoRAModelManager:
             if self._filter_unsupported_mm_module(module_name):
                 logger.warning(
                     "Regarding %s, vLLM currently only supports adding LoRA to"
-                    " language model, {module_name} will be ignored.",
+                    " language model, %s will be ignored.",
                     self.model.__class__.__name__,
                     module_name,
                 )
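The replaced line had slipped an f-string-style placeholder into a %-style logging call: {module_name} was never interpolated, and with two arguments against a single %s the logging machinery reports a formatting error instead of the warning. A minimal reproduction, outside this diff:

import logging

logging.basicConfig(level=logging.WARNING)
log = logging.getLogger("demo")
name, module = "MyModel", "vision_tower.blocks.0"

# Buggy: literal {module} plus an argument-count mismatch; the handler
# prints "--- Logging error ---" instead of the intended message.
log.warning("Regarding %s, {module} will be ignored.", name, module)

# Fixed: one %s per argument, formatted lazily by the logging framework.
log.warning("Regarding %s, %s will be ignored.", name, module)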
@@ -432,10 +439,10 @@ class LoRAModelManager:
                 self._register_packed_modules(module_name)
             # All lora layers share the same punica_wrapper based on reference.
-            if self.supports_mm and self.supports_tower_connector_lora:
-                new_module.set_mapping(self._get_mm_punica_wrapper(module_name))
-            else:
-                new_module.set_mapping(self.punica_wrapper)
+            wrapper = self._get_punica_wrapper_for_module(module_name)
+            if wrapper is None:
+                continue
+            new_module.set_mapping(wrapper)
 
     def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
         assert isinstance(module, BaseLayerWithLoRA), (
@@ -551,31 +558,36 @@ class LoRAModelManager:
         language model. LoRA for other modules, such as the vision tower, will
         be filtered out.
         """
-        if self.supports_mm:
-            prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
-            if self.supports_tower_connector_lora:
-                return self._get_mm_punica_wrapper(module_name) is None
-            else:
-                return any([module_name.startswith(prefix) for prefix in prefix_lst])
-        return False
+        if not self.supports_mm:
+            return False
+        if self.supports_tower_connector_lora:
+            return self._get_punica_wrapper_for_module(module_name) is None
+        prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
+        return any(module_name.startswith(prefix) for prefix in prefix_lst)
 
-    def _get_mm_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None:
+    def _get_punica_wrapper_for_module(
+        self, module_name: str
+    ) -> PunicaWrapperBase | None:
         """
         Match the corresponding punica_wrapper based on module_name,
         and return None if lora is not supported for this module.
         """
-        if self.supports_tower_connector_lora:
+        best_prefix = None
+        for prefix in self.punica_wrapper_mapping:
+            if prefix == DEFAULT_WRAPPER_KEY:
+                continue
             # Ensure matching by the longest prefix.
-            sorted_prefixes = sorted(
-                self.mm_punica_wrapper_mapping.keys(),
-                key=lambda x: len(x),
-                reverse=True,
-            )
-            for prefix in sorted_prefixes:
-                if module_name.startswith(prefix):
-                    return self.mm_punica_wrapper_mapping[prefix]
-        return None
+            if module_name.startswith(prefix) and (
+                best_prefix is None or len(prefix) > len(best_prefix)
+            ):
+                best_prefix = prefix
+        if best_prefix is not None:
+            return self.punica_wrapper_mapping[best_prefix]
+        return self.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY)
 
     def _register_packed_modules(self, module_full_name: str) -> None:
         parts = module_full_name.split(".")
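Longest-prefix matching is what keeps nested prefixes safe: if one registered prefix is a prefix of another (say a projector hanging off a tower), the more specific entry must win. The single linear pass above is equivalent to the old sort-by-length approach but skips the sort. A standalone sketch with hypothetical prefixes and string stand-ins for wrappers:

DEFAULT_KEY = "__default__"

def match_longest(mapping: dict[str, str], module_name: str) -> str | None:
    """Value for the longest registered prefix of module_name, else the
    default entry (None when no default is registered)."""
    best = None
    for prefix in mapping:
        if prefix == DEFAULT_KEY:
            continue
        if module_name.startswith(prefix) and (
            best is None or len(prefix) > len(best)
        ):
            best = prefix
    return mapping[best] if best is not None else mapping.get(DEFAULT_KEY)

wrappers = {
    DEFAULT_KEY: "language",
    "vision_model": "tower",
    "vision_model.projector": "connector",  # nested under the tower prefix
}
assert match_longest(wrappers, "vision_model.projector.fc1") == "connector"
assert match_longest(wrappers, "vision_model.blocks.0") == "tower"
assert match_longest(wrappers, "lm_head") == "language"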


@@ -2161,9 +2161,7 @@ class GPUModelRunner(
                 # pos_info.length may overcount (e.g., special tokens in Qwen-VL).
                 # Fall back to length if is_embed is None.
                 num_tokens = self.info.get_num_mm_encoder_tokens(  # type: ignore[attr-defined]
-                    pos_info.length
-                    if pos_info.is_embed is None
-                    else pos_info.is_embed.sum()
+                    pos_info.get_num_embeds()
                 )
                 prompt_lora_mapping.append(lora_id)
                 token_lora_mapping.extend([lora_id] * num_tokens)
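The runner-side change folds the inline is_embed branch into a single call. Presumably get_num_embeds() on the placeholder implements exactly the logic it replaces: sum the boolean mask when present, otherwise fall back to the raw length. A hedged sketch of that assumed behavior (the Placeholder class below is illustrative, not vLLM's real placeholder type):

from dataclasses import dataclass

import torch

@dataclass
class Placeholder:  # illustrative stand-in for the real placeholder type
    length: int
    is_embed: torch.Tensor | None = None  # bool mask over the span

    def get_num_embeds(self) -> int:
        # Assumed behavior, mirroring the replaced inline expression:
        # mask sum when available, otherwise the raw placeholder length.
        if self.is_embed is None:
            return self.length
        return int(self.is_embed.sum())

assert Placeholder(length=5).get_num_embeds() == 5
mask = torch.tensor([True, False, True, True, False])
assert Placeholder(length=5, is_embed=mask).get_num_embeds() == 3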