mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-06 13:15:45 +08:00
Move forward
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
208dc0c954
commit
d4f39dc38a
@ -103,7 +103,7 @@ class LoRAModelManager:
|
|||||||
f" {self.model.__class__.__name__}."
|
f" {self.model.__class__.__name__}."
|
||||||
|
|
||||||
self.packed_modules_mapping = process_packed_modules_mapping(self.model)
|
self.packed_modules_mapping = process_packed_modules_mapping(self.model)
|
||||||
self._init_multimodal_config(vllm_config)
|
self._maybe_init_mm(vllm_config)
|
||||||
self.is_pooling_model = is_pooling_model(self.model)
|
self.is_pooling_model = is_pooling_model(self.model)
|
||||||
self.packed_modules: dict[str, list[str]] = {}
|
self.packed_modules: dict[str, list[str]] = {}
|
||||||
self.modules: dict[str, BaseLayerWithLoRA] = {}
|
self.modules: dict[str, BaseLayerWithLoRA] = {}
|
||||||
@ -114,7 +114,7 @@ class LoRAModelManager:
|
|||||||
|
|
||||||
self.model.lora_manager = self
|
self.model.lora_manager = self
|
||||||
|
|
||||||
def _init_multimodal_config(self, vllm_config: VllmConfig | None = None):
|
def _maybe_init_mm(self, vllm_config: VllmConfig):
|
||||||
# Used to indicate whether the model is a multimodal model
|
# Used to indicate whether the model is a multimodal model
|
||||||
self.supports_mm: bool = (
|
self.supports_mm: bool = (
|
||||||
supports_multimodal(self.model)
|
supports_multimodal(self.model)
|
||||||
@ -122,24 +122,26 @@ class LoRAModelManager:
|
|||||||
# text modules (e.g. ChatGLM)
|
# text modules (e.g. ChatGLM)
|
||||||
and hasattr(self.model, "get_mm_mapping")
|
and hasattr(self.model, "get_mm_mapping")
|
||||||
)
|
)
|
||||||
|
if not self.supports_mm:
|
||||||
self.supports_mm_lora = False
|
|
||||||
|
|
||||||
if self.supports_mm and vllm_config is not None:
|
|
||||||
model_config: ModelConfig = vllm_config.model_config
|
|
||||||
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
|
|
||||||
if self.lora_config.enable_tower_connector_lora:
|
|
||||||
self.info = MULTIMODAL_REGISTRY.create_processor(model_config).info
|
|
||||||
self.supports_mm_lora = self.supports_mm and hasattr(
|
|
||||||
self.info, "get_num_mm_encoder_tokens"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.supports_mm_lora:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
assert vllm_config is not None, (
|
self.supports_tower_connector_lora = False
|
||||||
"vllm_config should not be None when supports_mm_lora is True"
|
model_config: ModelConfig = vllm_config.model_config
|
||||||
|
|
||||||
|
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
|
||||||
|
if self.lora_config.enable_tower_connector_lora:
|
||||||
|
self.info = MULTIMODAL_REGISTRY.create_processor(model_config).info
|
||||||
|
self.supports_tower_connector_lora = self.supports_mm and hasattr(
|
||||||
|
self.info, "get_num_mm_encoder_tokens"
|
||||||
|
)
|
||||||
|
if not self.supports_tower_connector_lora:
|
||||||
|
return
|
||||||
|
logger.warning(
|
||||||
|
"LoRA for the tower and connector of multimodal models is "
|
||||||
|
"experimental and may contain bugs. Please report any related issues on "
|
||||||
|
"GitHub if you encounter them."
|
||||||
)
|
)
|
||||||
|
|
||||||
mm_budget = MultiModalBudget(
|
mm_budget = MultiModalBudget(
|
||||||
model_config,
|
model_config,
|
||||||
vllm_config.scheduler_config,
|
vllm_config.scheduler_config,
|
||||||
@ -318,7 +320,7 @@ class LoRAModelManager:
|
|||||||
# Default to the main language model wrapper
|
# Default to the main language model wrapper
|
||||||
target_wrapper = self.punica_wrapper
|
target_wrapper = self.punica_wrapper
|
||||||
|
|
||||||
if self.supports_mm_lora:
|
if self.supports_tower_connector_lora:
|
||||||
if mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model:
|
if mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model:
|
||||||
target_name = self.mm_mapping.tower_model[0]
|
target_name = self.mm_mapping.tower_model[0]
|
||||||
target_wrapper = self.mm_punica_wrapper_mapping[target_name]
|
target_wrapper = self.mm_punica_wrapper_mapping[target_name]
|
||||||
@ -361,8 +363,9 @@ class LoRAModelManager:
|
|||||||
|
|
||||||
if self._filter_unsupported_mm_module(module_name):
|
if self._filter_unsupported_mm_module(module_name):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Module %s does not support adding LoRA for "
|
"Regarding %s, vLLM currently only supports adding LoRA to"
|
||||||
"now and has been ignored.",
|
" language model, {module_name} will be ignored.",
|
||||||
|
self.model.__class__.__name__,
|
||||||
module_name,
|
module_name,
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
@ -424,7 +427,7 @@ class LoRAModelManager:
|
|||||||
|
|
||||||
self._register_packed_modules(module_name)
|
self._register_packed_modules(module_name)
|
||||||
# All lora layers share the same punica_wrapper based on reference.
|
# All lora layers share the same punica_wrapper based on reference.
|
||||||
if self.supports_mm_lora:
|
if self.supports_tower_connector_lora:
|
||||||
new_module.set_mapping(self._get_mm_punica_wrapper(module_name))
|
new_module.set_mapping(self._get_mm_punica_wrapper(module_name))
|
||||||
else:
|
else:
|
||||||
new_module.set_mapping(self.punica_wrapper)
|
new_module.set_mapping(self.punica_wrapper)
|
||||||
@ -545,7 +548,7 @@ class LoRAModelManager:
|
|||||||
"""
|
"""
|
||||||
if self.supports_mm:
|
if self.supports_mm:
|
||||||
prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
|
prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
|
||||||
if self.supports_mm_lora:
|
if self.supports_tower_connector_lora:
|
||||||
return self._get_mm_punica_wrapper(module_name) is None
|
return self._get_mm_punica_wrapper(module_name) is None
|
||||||
else:
|
else:
|
||||||
return any([module_name.startswith(prefix) for prefix in prefix_lst])
|
return any([module_name.startswith(prefix) for prefix in prefix_lst])
|
||||||
@ -556,7 +559,7 @@ class LoRAModelManager:
|
|||||||
Match the corresponding punica_wrapper based on module_name,
|
Match the corresponding punica_wrapper based on module_name,
|
||||||
and return None if lora is not supported for this module.
|
and return None if lora is not supported for this module.
|
||||||
"""
|
"""
|
||||||
if self.supports_mm_lora:
|
if self.supports_tower_connector_lora:
|
||||||
# Ensure matching by the longest prefix.
|
# Ensure matching by the longest prefix.
|
||||||
sorted_prefixes = sorted(
|
sorted_prefixes = sorted(
|
||||||
self.mm_punica_wrapper_mapping.keys(),
|
self.mm_punica_wrapper_mapping.keys(),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user