mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-05-04 13:51:23 +08:00)

update punica_wrapper_mapping

Signed-off-by: bk-201 <joy25810@foxmail.com>

commit 58d2c47b9a
parent c1bb71ef6b
@@ -42,6 +42,7 @@ from vllm.v1.worker.utils import MultiModalBudget
 logger = init_logger(__name__)
 
 T = TypeVar("T")
 
+DEFAULT_WRAPPER_KEY = "__default__"
 
 class AdapterLRUCache(LRUCache[int, T]):
@@ -117,6 +118,11 @@ class LoRAModelManager:
             device=self.device,
             max_loras=self.lora_config.max_loras,
         )
 
+        self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {
+            DEFAULT_WRAPPER_KEY: self.punica_wrapper
+        }
+
         self._maybe_init_mm(vllm_config)
 
     def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
@@ -132,8 +138,8 @@ class LoRAModelManager:
 
         self.supports_tower_connector_lora = False
         model_config: ModelConfig = vllm_config.model_config
 
         self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
 
         if self.lora_config.enable_tower_connector_lora:
             self.info = MULTIMODAL_REGISTRY.create_processor(model_config).info
             self.supports_tower_connector_lora = self.supports_mm and hasattr(
@@ -153,24 +159,26 @@ class LoRAModelManager:
                 MULTIMODAL_REGISTRY,
             )
             limit_per_prompt: int = max(self.info.get_allowed_mm_limits().values())
 
-            # For vision tower
             num_encoder_tokens = self.info.get_num_mm_encoder_tokens(
                 mm_budget.get_encoder_budget()
             )
-            self.mm_punica_wrapper_mapping = {
-                name: get_punica_wrapper(
+            self.punica_wrapper_mapping = {}
+
+            # Tower wrappers
+            for name in self.mm_mapping.tower_model:
+                self.punica_wrapper_mapping[name] = get_punica_wrapper(
                     num_encoder_tokens,
                     max_batches=self.max_num_seqs * limit_per_prompt,
                     device=self.device,
                     max_loras=self.lora_config.max_loras,
                 )
-                for name in self.mm_mapping.tower_model
-            }
-            # For language model
-            self.mm_punica_wrapper_mapping.update(
-                {self.mm_mapping.language_model[0]: self.punica_wrapper}
+            # Language wrapper
+            self.punica_wrapper_mapping[self.mm_mapping.language_model[0]] = (
+                self.punica_wrapper
             )
 
             # Use wrapper for connector if present.
             if self.mm_mapping.connector:
                 if hasattr(self.info, "get_num_mm_connector_tokens"):
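For orientation, the mapping that this setup produces has roughly the shape sketched below. This is a hedged illustration only: the module prefixes ("visual", "mlp1", "language_model") and the placeholder wrapper objects are assumptions for the example, not values taken from this diff.

# Rough shape of self.punica_wrapper_mapping after _maybe_init_mm on a
# multimodal model. Prefixes and wrapper objects are illustrative placeholders,
# standing in for real PunicaWrapperBase instances.
main_wrapper = object()       # language-model wrapper created in __init__
tower_wrapper = object()      # per-tower wrapper
connector_wrapper = object()  # connector wrapper (added in the next hunk)

punica_wrapper_mapping: dict[str, object] = {
    "__default__": main_wrapper,     # DEFAULT_WRAPPER_KEY, installed in __init__
    "visual": tower_wrapper,         # hypothetical vision-tower prefix
    "language_model": main_wrapper,  # language prefix reuses the main wrapper
    "mlp1": connector_wrapper,       # hypothetical connector prefix
}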
@@ -183,7 +191,7 @@ class LoRAModelManager:
                     device=self.device,
                     max_loras=self.lora_config.max_loras,
                 )
-                self.mm_punica_wrapper_mapping.update(
+                self.punica_wrapper_mapping.update(
                     {
                         name: connector_punica_wrapper
                         for name in self.mm_mapping.connector
@@ -323,20 +331,19 @@ class LoRAModelManager:
 
     def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
         # Default to the main language model wrapper
-        target_wrapper = self.punica_wrapper
-        if self.supports_mm and self.supports_tower_connector_lora:
+        if not (self.supports_mm and self.supports_tower_connector_lora):
+            target_wrapper = self.punica_wrapper_mapping[DEFAULT_WRAPPER_KEY]
+        else:
             if mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model:
-                target_name = self.mm_mapping.tower_model[0]
-                target_wrapper = self.mm_punica_wrapper_mapping[target_name]
+                target_prefix = self.mm_mapping.tower_model[0]
             elif (
                 mapping.type == LoRAMappingType.CONNECTOR and self.mm_mapping.connector
             ):
-                target_name = self.mm_mapping.connector[0]
-                target_wrapper = self.mm_punica_wrapper_mapping[target_name]
+                target_prefix = self.mm_mapping.connector[0]
             else:
-                target_name = self.mm_mapping.language_model[0]
-                target_wrapper = self.mm_punica_wrapper_mapping[target_name]
+                target_prefix = self.mm_mapping.language_model[0]
+            target_wrapper = self.punica_wrapper_mapping[target_prefix]
 
         target_wrapper.update_metadata(
             mapping,
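In words, `_set_adapter_mapping` now resolves a prefix from the mapping type and indexes `punica_wrapper_mapping` once. Below is a condensed standalone sketch of that dispatch; the enum is a stand-in (only TOWER and CONNECTOR appear in the diff, LANGUAGE is an assumption), and the example prefixes are hypothetical.

from enum import Enum, auto


class MappingType(Enum):  # stand-in for vLLM's LoRAMappingType
    LANGUAGE = auto()
    TOWER = auto()
    CONNECTOR = auto()


def pick_prefix(
    mapping_type: MappingType,
    tower: list[str],
    connector: list[str],
    language: list[str],
) -> str:
    # Mirrors the branch order above: tower first, then connector,
    # then fall back to the language-model prefix.
    if mapping_type is MappingType.TOWER and tower:
        return tower[0]
    if mapping_type is MappingType.CONNECTOR and connector:
        return connector[0]
    return language[0]


# Example: a TOWER mapping resolves to the first tower prefix.
assert pick_prefix(MappingType.TOWER, ["visual"], ["mlp1"], ["language_model"]) == "visual"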
@@ -369,7 +376,7 @@ class LoRAModelManager:
             if self._filter_unsupported_mm_module(module_name):
                 logger.warning(
                     "Regarding %s, vLLM currently only supports adding LoRA to"
-                    " language model, {module_name} will be ignored.",
+                    " language model, %s will be ignored.",
                     self.model.__class__.__name__,
                     module_name,
                 )
@@ -432,10 +439,10 @@ class LoRAModelManager:
 
             self._register_packed_modules(module_name)
             # All lora layers share the same punica_wrapper based on reference.
-            if self.supports_mm and self.supports_tower_connector_lora:
-                new_module.set_mapping(self._get_mm_punica_wrapper(module_name))
-            else:
-                new_module.set_mapping(self.punica_wrapper)
+            wrapper = self._get_punica_wrapper_for_module(module_name)
+            if wrapper is None:
+                continue
+            new_module.set_mapping(wrapper)
 
     def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
         assert isinstance(module, BaseLayerWithLoRA), (
@@ -551,31 +558,36 @@ class LoRAModelManager:
         language model. LoRA for other modules, such as the vision tower, will
         be filtered out.
         """
-        if self.supports_mm:
-            prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
-            if self.supports_tower_connector_lora:
-                return self._get_mm_punica_wrapper(module_name) is None
-            else:
-                return any([module_name.startswith(prefix) for prefix in prefix_lst])
-        return False
+        if not self.supports_mm:
+            return False
 
-    def _get_mm_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None:
+        if self.supports_tower_connector_lora:
+            return self._get_punica_wrapper_for_module(module_name) is None
+
+        prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
+        return any(module_name.startswith(prefix) for prefix in prefix_lst)
+
+    def _get_punica_wrapper_for_module(
+        self, module_name: str
+    ) -> PunicaWrapperBase | None:
         """
         Match the corresponding punica_wrapper based on module_name,
         and return None if lora is not supported for this module.
         """
-        if self.supports_tower_connector_lora:
+        best_prefix = None
+        for prefix in self.punica_wrapper_mapping:
+            if prefix == DEFAULT_WRAPPER_KEY:
+                continue
             # Ensure matching by the longest prefix.
-            sorted_prefixes = sorted(
-                self.mm_punica_wrapper_mapping.keys(),
-                key=lambda x: len(x),
-                reverse=True,
-            )
+            if module_name.startswith(prefix) and (
+                best_prefix is None or len(prefix) > len(best_prefix)
+            ):
+                best_prefix = prefix
 
-            for prefix in sorted_prefixes:
-                if module_name.startswith(prefix):
-                    return self.mm_punica_wrapper_mapping[prefix]
-        return None
+        if best_prefix is not None:
+            return self.punica_wrapper_mapping[best_prefix]
+        return self.punica_wrapper_mapping.get(DEFAULT_WRAPPER_KEY)
 
     def _register_packed_modules(self, module_full_name: str) -> None:
         parts = module_full_name.split(".")
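The new lookup is a longest-prefix match over the mapping keys, with the `__default__` entry as the fallback. A minimal standalone sketch of that resolution logic follows; it mirrors `_get_punica_wrapper_for_module`, but the dict values and example module names are placeholders.

DEFAULT_WRAPPER_KEY = "__default__"


def resolve_wrapper(module_name: str, wrapper_mapping: dict[str, object]) -> object | None:
    """Return the value whose prefix is the longest match for module_name,
    falling back to the default entry when no prefix matches."""
    best_prefix = None
    for prefix in wrapper_mapping:
        if prefix == DEFAULT_WRAPPER_KEY:
            continue
        if module_name.startswith(prefix) and (
            best_prefix is None or len(prefix) > len(best_prefix)
        ):
            best_prefix = prefix
    if best_prefix is not None:
        return wrapper_mapping[best_prefix]
    return wrapper_mapping.get(DEFAULT_WRAPPER_KEY)


# Example with hypothetical prefixes: the longer "visual.blocks" wins over "visual",
# and an unmatched module name falls back to the default wrapper.
mapping = {DEFAULT_WRAPPER_KEY: "main", "visual": "tower", "visual.blocks": "blocks"}
assert resolve_wrapper("visual.blocks.0.attn.qkv", mapping) == "blocks"
assert resolve_wrapper("lm_head", mapping) == "main"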
@@ -2161,9 +2161,7 @@ class GPUModelRunner(
                     # pos_info.length may overcount (e.g., special tokens in Qwen-VL).
                     # Fall back to length if is_embed is None.
                     num_tokens = self.info.get_num_mm_encoder_tokens(  # type: ignore[attr-defined]
-                        pos_info.length
-                        if pos_info.is_embed is None
-                        else pos_info.is_embed.sum()
+                        pos_info.get_num_embeds()
                     )
                     prompt_lora_mapping.append(lora_id)
                     token_lora_mapping.extend([lora_id] * num_tokens)
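The old expression computed the token count inline; `pos_info.get_num_embeds()` presumably folds that same fallback into the placeholder-range object. A hedged sketch of what such a helper would compute, given the attributes the replaced code used (`length`, optional `is_embed` boolean tensor):

import torch


def get_num_embeds(length: int, is_embed: torch.Tensor | None) -> int:
    # Mirrors the replaced inline expression: use the full placeholder length
    # when is_embed is None, otherwise count only the embedded positions.
    if is_embed is None:
        return length
    return int(is_embed.sum())


# Example: 6 placeholder slots of which 4 actually carry embeddings.
mask = torch.tensor([True, True, False, True, True, False])
assert get_num_embeds(6, mask) == 4
assert get_num_embeds(6, None) == 6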