mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-23 23:41:20 +08:00
update punica_wrapper_mapping
Signed-off-by: bk-201 <joy25810@foxmail.com>
This commit is contained in:
parent
c1bb71ef6b
commit
58d2c47b9a
@ -42,6 +42,7 @@ from vllm.v1.worker.utils import MultiModalBudget
|
||||
logger = init_logger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
DEFAULT_WRAPPER_KEY = "__default__"
|
||||
|
||||
|
||||
class AdapterLRUCache(LRUCache[int, T]):
|
||||
@ -117,6 +118,11 @@ class LoRAModelManager:
|
||||
device=self.device,
|
||||
max_loras=self.lora_config.max_loras,
|
||||
)
|
||||
|
||||
self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {
|
||||
DEFAULT_WRAPPER_KEY: self.punica_wrapper
|
||||
}
|
||||
|
||||
self._maybe_init_mm(vllm_config)
|
||||
|
||||
def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
|
||||
@ -132,8 +138,8 @@ class LoRAModelManager:
|
||||
|
||||
self.supports_tower_connector_lora = False
|
||||
model_config: ModelConfig = vllm_config.model_config
|
||||
|
||||
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
|
||||
|
||||
if self.lora_config.enable_tower_connector_lora:
|
||||
self.info = MULTIMODAL_REGISTRY.create_processor(model_config).info
|
||||
self.supports_tower_connector_lora = self.supports_mm and hasattr(
|
||||
@ -153,24 +159,26 @@ class LoRAModelManager:
|
||||
MULTIMODAL_REGISTRY,
|
||||
)
|
||||
limit_per_prompt: int = max(self.info.get_allowed_mm_limits().values())
|
||||
|
||||
# For vision tower
|
||||
num_encoder_tokens = self.info.get_num_mm_encoder_tokens(
|
||||
mm_budget.get_encoder_budget()
|
||||
)
|
||||
self.mm_punica_wrapper_mapping = {
|
||||
name: get_punica_wrapper(
|
||||
|
||||
self.punica_wrapper_mapping = {}
|
||||
|
||||
# Tower wrappers
|
||||
for name in self.mm_mapping.tower_model:
|
||||
self.punica_wrapper_mapping[name] = get_punica_wrapper(
|
||||
num_encoder_tokens,
|
||||
max_batches=self.max_num_seqs * limit_per_prompt,
|
||||
device=self.device,
|
||||
max_loras=self.lora_config.max_loras,
|
||||
)
|
||||
for name in self.mm_mapping.tower_model
|
||||
}
|
||||
# For language model
|
||||
self.mm_punica_wrapper_mapping.update(
|
||||
{self.mm_mapping.language_model[0]: self.punica_wrapper}
|
||||
|
||||
# Language wrapper
|
||||
self.punica_wrapper_mapping[self.mm_mapping.language_model[0]] = (
|
||||
self.punica_wrapper
|
||||
)
|
||||
|
||||
# Use wrapper for connector if present.
|
||||
if self.mm_mapping.connector:
|
||||
if hasattr(self.info, "get_num_mm_connector_tokens"):
|
||||
@ -183,7 +191,7 @@ class LoRAModelManager:
|
||||
device=self.device,
|
||||
max_loras=self.lora_config.max_loras,
|
||||
)
|
||||
self.mm_punica_wrapper_mapping.update(
|
||||
self.punica_wrapper_mapping.update(
|
||||
{
|
||||
name: connector_punica_wrapper
|
||||
for name in self.mm_mapping.connector
|
||||
@ -323,20 +331,19 @@ class LoRAModelManager:
|
||||
|
||||
def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
|
||||
# Default to the main language model wrapper
|
||||
target_wrapper = self.punica_wrapper
|
||||
|
||||
if self.supports_mm and self.supports_tower_connector_lora:
|
||||
if not (self.supports_mm and self.supports_tower_connector_lora):
|
||||
target_wrapper = self.punica_wrapper_mapping[DEFAULT_WRAPPER_KEY]
|
||||
else:
|
||||
if mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model:
|
||||
target_name = self.mm_mapping.tower_model[0]
|
||||
target_wrapper = self.mm_punica_wrapper_mapping[target_name]
|
||||
target_prefix = self.mm_mapping.tower_model[0]
|
||||
elif (
|
||||
mapping.type == LoRAMappingType.CONNECTOR and self.mm_mapping.connector
|
||||
):
|
||||
target_name = self.mm_mapping.connector[0]
|
||||
target_wrapper = self.mm_punica_wrapper_mapping[target_name]
|
||||
target_prefix = self.mm_mapping.connector[0]
|
||||
else:
|
||||
target_name = self.mm_mapping.language_model[0]
|
||||
target_wrapper = self.mm_punica_wrapper_mapping[target_name]
|
||||
target_prefix = self.mm_mapping.language_model[0]
|
||||
|
||||
target_wrapper = self.punica_wrapper_mapping[target_prefix]
|
||||
|
||||
target_wrapper.update_metadata(
|
||||
mapping,
|
||||
@ -369,7 +376,7 @@ class LoRAModelManager:
|
||||
if self._filter_unsupported_mm_module(module_name):
|
||||
logger.warning(
|
||||
"Regarding %s, vLLM currently only supports adding LoRA to"
|
||||
" language model, {module_name} will be ignored.",
|
||||
" language model, %s will be ignored.",
|
||||
self.model.__class__.__name__,
|
||||
module_name,
|
||||
)
|
||||
@ -432,10 +439,10 @@ class LoRAModelManager:
|
||||
|
||||
self._register_packed_modules(module_name)
|
||||
# All lora layers share the same punica_wrapper based on reference.
|
||||
if self.supports_mm and self.supports_tower_connector_lora:
|
||||
new_module.set_mapping(self._get_mm_punica_wrapper(module_name))
|
||||
else:
|
||||
new_module.set_mapping(self.punica_wrapper)
|
||||
wrapper = self._get_punica_wrapper_for_module(module_name)
|
||||
if wrapper is None:
|
||||
continue
|
||||
new_module.set_mapping(wrapper)
|
||||
|
||||
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
|
||||
assert isinstance(module, BaseLayerWithLoRA), (
|
||||
@ -551,31 +558,36 @@ class LoRAModelManager:
|
||||
language model. LoRA for other modules, such as the vision tower, will
|
||||
be filtered out.
|
||||
"""
|
||||
if self.supports_mm:
|
||||
prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
|
||||
if self.supports_tower_connector_lora:
|
||||
return self._get_mm_punica_wrapper(module_name) is None
|
||||
else:
|
||||
return any([module_name.startswith(prefix) for prefix in prefix_lst])
|
||||
return False
|
||||
if not self.supports_mm:
|
||||
return False
|
||||
|
||||
def _get_mm_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None:
|
||||
if self.supports_tower_connector_lora:
|
||||
return self._get_punica_wrapper_for_module(module_name) is None
|
||||
|
||||
prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
|
||||
return any(module_name.startswith(prefix) for prefix in prefix_lst)
|
||||
|
||||
def _get_punica_wrapper_for_module(
|
||||
self, module_name: str
|
||||
) -> PunicaWrapperBase | None:
|
||||
"""
|
||||
Match the corresponding punica_wrapper based on module_name,
|
||||
and return None if lora is not supported for this module.
|
||||
"""
|
||||
if self.supports_tower_connector_lora:
|
||||
best_prefix = None
|
||||
for prefix in self.punica_wrapper_mapping:
|
||||
if prefix == DEFAULT_WRAPPER_KEY:
|
||||
continue
|
||||
# Ensure matching by the longest prefix.
|
||||
sorted_prefixes = sorted(
|
||||
self.mm_punica_wrapper_mapping.keys(),
|
||||
key=lambda x: len(x),
|
||||
reverse=True,
|
||||
)
|
||||
if module_name.startswith(prefix) and (
|
||||
best_prefix is None or len(prefix) > len(best_prefix)
|
||||
):
|
||||
best_prefix = prefix
|
||||
|
||||
for prefix in sorted_prefixes:
|
||||
if module_name.startswith(prefix):
|
||||
return self.mm_punica_wrapper_mapping[prefix]
|
||||
return None
|
||||
if best_prefix is not None:
|
||||
return self.punica_wrapper_mapping[best_prefix]
|
||||
|
||||
return self.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY)
|
||||
|
||||
def _register_packed_modules(self, module_full_name: str) -> None:
|
||||
parts = module_full_name.split(".")
|
||||
|
||||
@ -2161,9 +2161,7 @@ class GPUModelRunner(
|
||||
# pos_info.length may overcount (e.g., special tokens in Qwen-VL).
|
||||
# Fall back to length if is_embed is None.
|
||||
num_tokens = self.info.get_num_mm_encoder_tokens( # type: ignore[attr-defined]
|
||||
pos_info.length
|
||||
if pos_info.is_embed is None
|
||||
else pos_info.is_embed.sum()
|
||||
pos_info.get_num_embeds()
|
||||
)
|
||||
prompt_lora_mapping.append(lora_id)
|
||||
token_lora_mapping.extend([lora_id] * num_tokens)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user