Move forward

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-12-12 00:41:48 +00:00
parent 208dc0c954
commit d4f39dc38a

View File

@ -103,7 +103,7 @@ class LoRAModelManager:
f" {self.model.__class__.__name__}." f" {self.model.__class__.__name__}."
self.packed_modules_mapping = process_packed_modules_mapping(self.model) self.packed_modules_mapping = process_packed_modules_mapping(self.model)
self._init_multimodal_config(vllm_config) self._maybe_init_mm(vllm_config)
self.is_pooling_model = is_pooling_model(self.model) self.is_pooling_model = is_pooling_model(self.model)
self.packed_modules: dict[str, list[str]] = {} self.packed_modules: dict[str, list[str]] = {}
self.modules: dict[str, BaseLayerWithLoRA] = {} self.modules: dict[str, BaseLayerWithLoRA] = {}
@ -114,7 +114,7 @@ class LoRAModelManager:
self.model.lora_manager = self self.model.lora_manager = self
def _init_multimodal_config(self, vllm_config: VllmConfig | None = None): def _maybe_init_mm(self, vllm_config: VllmConfig):
# Used to indicate whether the model is a multimodal model # Used to indicate whether the model is a multimodal model
self.supports_mm: bool = ( self.supports_mm: bool = (
supports_multimodal(self.model) supports_multimodal(self.model)
@ -122,24 +122,26 @@ class LoRAModelManager:
# text modules (e.g. ChatGLM) # text modules (e.g. ChatGLM)
and hasattr(self.model, "get_mm_mapping") and hasattr(self.model, "get_mm_mapping")
) )
if not self.supports_mm:
self.supports_mm_lora = False
if self.supports_mm and vllm_config is not None:
model_config: ModelConfig = vllm_config.model_config
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
if self.lora_config.enable_tower_connector_lora:
self.info = MULTIMODAL_REGISTRY.create_processor(model_config).info
self.supports_mm_lora = self.supports_mm and hasattr(
self.info, "get_num_mm_encoder_tokens"
)
if not self.supports_mm_lora:
return return
assert vllm_config is not None, ( self.supports_tower_connector_lora = False
"vllm_config should not be None when supports_mm_lora is True" model_config: ModelConfig = vllm_config.model_config
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
if self.lora_config.enable_tower_connector_lora:
self.info = MULTIMODAL_REGISTRY.create_processor(model_config).info
self.supports_tower_connector_lora = self.supports_mm and hasattr(
self.info, "get_num_mm_encoder_tokens"
)
if not self.supports_tower_connector_lora:
return
logger.warning(
"LoRA for the tower and connector of multimodal models is "
"experimental and may contain bugs. Please report any related issues on "
"GitHub if you encounter them."
) )
mm_budget = MultiModalBudget( mm_budget = MultiModalBudget(
model_config, model_config,
vllm_config.scheduler_config, vllm_config.scheduler_config,
@ -318,7 +320,7 @@ class LoRAModelManager:
# Default to the main language model wrapper # Default to the main language model wrapper
target_wrapper = self.punica_wrapper target_wrapper = self.punica_wrapper
if self.supports_mm_lora: if self.supports_tower_connector_lora:
if mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model: if mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model:
target_name = self.mm_mapping.tower_model[0] target_name = self.mm_mapping.tower_model[0]
target_wrapper = self.mm_punica_wrapper_mapping[target_name] target_wrapper = self.mm_punica_wrapper_mapping[target_name]
@ -361,8 +363,9 @@ class LoRAModelManager:
if self._filter_unsupported_mm_module(module_name): if self._filter_unsupported_mm_module(module_name):
logger.warning( logger.warning(
"Module %s does not support adding LoRA for " "Regarding %s, vLLM currently only supports adding LoRA to"
"now and has been ignored.", " language model, {module_name} will be ignored.",
self.model.__class__.__name__,
module_name, module_name,
) )
continue continue
@ -424,7 +427,7 @@ class LoRAModelManager:
self._register_packed_modules(module_name) self._register_packed_modules(module_name)
# All lora layers share the same punica_wrapper based on reference. # All lora layers share the same punica_wrapper based on reference.
if self.supports_mm_lora: if self.supports_tower_connector_lora:
new_module.set_mapping(self._get_mm_punica_wrapper(module_name)) new_module.set_mapping(self._get_mm_punica_wrapper(module_name))
else: else:
new_module.set_mapping(self.punica_wrapper) new_module.set_mapping(self.punica_wrapper)
@ -545,7 +548,7 @@ class LoRAModelManager:
""" """
if self.supports_mm: if self.supports_mm:
prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
if self.supports_mm_lora: if self.supports_tower_connector_lora:
return self._get_mm_punica_wrapper(module_name) is None return self._get_mm_punica_wrapper(module_name) is None
else: else:
return any([module_name.startswith(prefix) for prefix in prefix_lst]) return any([module_name.startswith(prefix) for prefix in prefix_lst])
@ -556,7 +559,7 @@ class LoRAModelManager:
Match the corresponding punica_wrapper based on module_name, Match the corresponding punica_wrapper based on module_name,
and return None if lora is not supported for this module. and return None if lora is not supported for this module.
""" """
if self.supports_mm_lora: if self.supports_tower_connector_lora:
# Ensure matching by the longest prefix. # Ensure matching by the longest prefix.
sorted_prefixes = sorted( sorted_prefixes = sorted(
self.mm_punica_wrapper_mapping.keys(), self.mm_punica_wrapper_mapping.keys(),