update punica_wrapper_mapping

Signed-off-by: bk-201 <joy25810@foxmail.com>
bk-201 2025-12-15 07:49:52 +00:00
parent c1bb71ef6b
commit 58d2c47b9a
2 changed files with 56 additions and 46 deletions


@@ -42,6 +42,7 @@ from vllm.v1.worker.utils import MultiModalBudget
 logger = init_logger(__name__)
 T = TypeVar("T")
+DEFAULT_WRAPPER_KEY = "__default__"
 class AdapterLRUCache(LRUCache[int, T]):
@ -117,6 +118,11 @@ class LoRAModelManager:
device=self.device, device=self.device,
max_loras=self.lora_config.max_loras, max_loras=self.lora_config.max_loras,
) )
self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {
DEFAULT_WRAPPER_KEY: self.punica_wrapper
}
self._maybe_init_mm(vllm_config) self._maybe_init_mm(vllm_config)
def _maybe_init_mm(self, vllm_config: VllmConfig) -> None: def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
@@ -132,8 +138,8 @@ class LoRAModelManager:
         self.supports_tower_connector_lora = False
         model_config: ModelConfig = vllm_config.model_config
         self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
         if self.lora_config.enable_tower_connector_lora:
             self.info = MULTIMODAL_REGISTRY.create_processor(model_config).info
         self.supports_tower_connector_lora = self.supports_mm and hasattr(
@@ -153,24 +159,26 @@ class LoRAModelManager:
             MULTIMODAL_REGISTRY,
         )
         limit_per_prompt: int = max(self.info.get_allowed_mm_limits().values())
-        # For vision tower
         num_encoder_tokens = self.info.get_num_mm_encoder_tokens(
             mm_budget.get_encoder_budget()
         )
-        self.mm_punica_wrapper_mapping = {
-            name: get_punica_wrapper(
+        self.punica_wrapper_mapping = {}
+        # Tower wrappers
+        for name in self.mm_mapping.tower_model:
+            self.punica_wrapper_mapping[name] = get_punica_wrapper(
                 num_encoder_tokens,
                 max_batches=self.max_num_seqs * limit_per_prompt,
                 device=self.device,
                 max_loras=self.lora_config.max_loras,
             )
-            for name in self.mm_mapping.tower_model
-        }
-        # For language model
-        self.mm_punica_wrapper_mapping.update(
-            {self.mm_mapping.language_model[0]: self.punica_wrapper}
+        # Language wrapper
+        self.punica_wrapper_mapping[self.mm_mapping.language_model[0]] = (
+            self.punica_wrapper
         )
         # Use wrapper for connector if present.
         if self.mm_mapping.connector:
             if hasattr(self.info, "get_num_mm_connector_tokens"):
@@ -183,7 +191,7 @@ class LoRAModelManager:
                 device=self.device,
                 max_loras=self.lora_config.max_loras,
             )
-            self.mm_punica_wrapper_mapping.update(
+            self.punica_wrapper_mapping.update(
                 {
                     name: connector_punica_wrapper
                     for name in self.mm_mapping.connector
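For orientation, on a multimodal model with tower/connector LoRA enabled the code above leaves punica_wrapper_mapping keyed by module-name prefixes taken from MultiModelKeys. A minimal sketch of the resulting shape, with illustrative prefix names and placeholder wrapper objects (the real keys and PunicaWrapperBase instances are model-dependent and not spelled out in this patch):

# Illustrative only: the prefixes and wrapper objects below are placeholders.
tower_wrapper, language_wrapper, connector_wrapper = object(), object(), object()

punica_wrapper_mapping = {
    "vision_tower": tower_wrapper,              # one entry per tower prefix
    "language_model": language_wrapper,         # reuses the main punica_wrapper
    "multi_modal_projector": connector_wrapper, # connector entries share one wrapper
}
# A text-only model skips this path and keeps the single entry set in __init__:
# {"__default__": main_punica_wrapper}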
@@ -323,20 +331,19 @@ class LoRAModelManager:
     def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
         # Default to the main language model wrapper
-        target_wrapper = self.punica_wrapper
-        if self.supports_mm and self.supports_tower_connector_lora:
+        if not (self.supports_mm and self.supports_tower_connector_lora):
+            target_wrapper = self.punica_wrapper_mapping[DEFAULT_WRAPPER_KEY]
+        else:
             if mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model:
-                target_name = self.mm_mapping.tower_model[0]
-                target_wrapper = self.mm_punica_wrapper_mapping[target_name]
+                target_prefix = self.mm_mapping.tower_model[0]
             elif (
                 mapping.type == LoRAMappingType.CONNECTOR and self.mm_mapping.connector
             ):
-                target_name = self.mm_mapping.connector[0]
-                target_wrapper = self.mm_punica_wrapper_mapping[target_name]
+                target_prefix = self.mm_mapping.connector[0]
             else:
-                target_name = self.mm_mapping.language_model[0]
-                target_wrapper = self.mm_punica_wrapper_mapping[target_name]
+                target_prefix = self.mm_mapping.language_model[0]
+            target_wrapper = self.punica_wrapper_mapping[target_prefix]

         target_wrapper.update_metadata(
             mapping,
@@ -369,7 +376,7 @@ class LoRAModelManager:
             if self._filter_unsupported_mm_module(module_name):
                 logger.warning(
                     "Regarding %s, vLLM currently only supports adding LoRA to"
-                    " language model, {module_name} will be ignored.",
+                    " language model, %s will be ignored.",
                     self.model.__class__.__name__,
                     module_name,
                 )
@@ -432,10 +439,10 @@ class LoRAModelManager:
             self._register_packed_modules(module_name)
             # All lora layers share the same punica_wrapper based on reference.
-            if self.supports_mm and self.supports_tower_connector_lora:
-                new_module.set_mapping(self._get_mm_punica_wrapper(module_name))
-            else:
-                new_module.set_mapping(self.punica_wrapper)
+            wrapper = self._get_punica_wrapper_for_module(module_name)
+            if wrapper is None:
+                continue
+            new_module.set_mapping(wrapper)

     def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
         assert isinstance(module, BaseLayerWithLoRA), (
@@ -551,31 +558,36 @@ class LoRAModelManager:
         language model. LoRA for other modules, such as the vision tower, will
         be filtered out.
         """
-        if self.supports_mm:
-            prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
-            if self.supports_tower_connector_lora:
-                return self._get_mm_punica_wrapper(module_name) is None
-            else:
-                return any([module_name.startswith(prefix) for prefix in prefix_lst])
-        return False
+        if not self.supports_mm:
+            return False
+        if self.supports_tower_connector_lora:
+            return self._get_punica_wrapper_for_module(module_name) is None
+        prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
+        return any(module_name.startswith(prefix) for prefix in prefix_lst)

-    def _get_mm_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None:
+    def _get_punica_wrapper_for_module(
+        self, module_name: str
+    ) -> PunicaWrapperBase | None:
         """
         Match the corresponding punica_wrapper based on module_name,
         and return None if lora is not supported for this module.
         """
-        if self.supports_tower_connector_lora:
+        best_prefix = None
+        for prefix in self.punica_wrapper_mapping:
+            if prefix == DEFAULT_WRAPPER_KEY:
+                continue
             # Ensure matching by the longest prefix.
-            sorted_prefixes = sorted(
-                self.mm_punica_wrapper_mapping.keys(),
-                key=lambda x: len(x),
-                reverse=True,
-            )
-            for prefix in sorted_prefixes:
-                if module_name.startswith(prefix):
-                    return self.mm_punica_wrapper_mapping[prefix]
-        return None
+            if module_name.startswith(prefix) and (
+                best_prefix is None or len(prefix) > len(best_prefix)
+            ):
+                best_prefix = prefix
+        if best_prefix is not None:
+            return self.punica_wrapper_mapping[best_prefix]
+        return self.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY)

     def _register_packed_modules(self, module_full_name: str) -> None:
         parts = module_full_name.split(".")
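The helper above resolves a layer name to a wrapper by longest matching prefix, skipping the reserved default key and falling back to it only when no prefix matches. A standalone sketch of that lookup, assuming illustrative prefixes and string stand-ins for the wrappers (it mirrors _get_punica_wrapper_for_module but is not the vLLM implementation itself):

DEFAULT_WRAPPER_KEY = "__default__"

def resolve_wrapper(mapping: dict[str, object], module_name: str) -> object | None:
    # Longest matching prefix among the registered keys, ignoring the default entry.
    candidates = [
        p for p in mapping
        if p != DEFAULT_WRAPPER_KEY and module_name.startswith(p)
    ]
    if candidates:
        return mapping[max(candidates, key=len)]
    # No prefix matched: text-only setups still hold the default wrapper here,
    # while the multimodal path returns None so the module gets filtered out.
    return mapping.get(DEFAULT_WRAPPER_KEY)

wrappers = {"language_model": "lm", "language_model.model.layers": "inner"}
assert resolve_wrapper(wrappers, "language_model.model.layers.0.q_proj") == "inner"
assert resolve_wrapper(wrappers, "vision_tower.blocks.0") is None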


@@ -2161,9 +2161,7 @@ class GPUModelRunner(
                 # pos_info.length may overcount (e.g., special tokens in Qwen-VL).
                 # Fall back to length if is_embed is None.
                 num_tokens = self.info.get_num_mm_encoder_tokens(  # type: ignore[attr-defined]
-                    pos_info.length
-                    if pos_info.is_embed is None
-                    else pos_info.is_embed.sum()
+                    pos_info.get_num_embeds()
                 )
                 prompt_lora_mapping.append(lora_id)
                 token_lora_mapping.extend([lora_id] * num_tokens)
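The runner change above replaces the inline is_embed branch with pos_info.get_num_embeds(). A hedged sketch of the computation the removed lines performed, written as a hypothetical free function (whether get_num_embeds() is implemented exactly this way is not shown in this diff):

import torch

def num_embeds(length: int, is_embed: torch.Tensor | None) -> int:
    # Removed inline logic: fall back to the full placeholder length when no
    # embedding mask exists, otherwise count only the positions that are embedded.
    if is_embed is None:
        return length
    return int(is_embed.sum())

# e.g. a 6-token placeholder where only 4 positions carry embeddings
assert num_embeds(6, None) == 6
assert num_embeds(6, torch.tensor([True, True, False, True, True, False])) == 4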