update mm filter

Signed-off-by: bk-201 <joy25810@foxmail.com>
bk-201 2025-07-09 13:55:45 +08:00
parent 98debc2424
commit 7db0d5990a
2 changed files with 28 additions and 12 deletions

@@ -386,19 +386,33 @@ class LoRAModelManager(AdapterModelManager):
             self.supports_mm_lora = False
         if self.supports_mm_lora:
             self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
+            self.mm_config = model_config.multimodal_config
+            limit_per_prompt: int = max(
+                [1] +
+                list(self.mm_config.limit_per_prompt.values())
+            )
+            # For vision tower
             self.mm_punica_wrapper_mapping = {
                 name:
                 get_punica_wrapper(
                     self.info.get_num_mm_encoder_tokens(
                         max_num_batched_tokens),
-                    max_batches=self.max_num_seqs,  # TODO
+                    max_batches=self.max_num_seqs * limit_per_prompt,
                     device=self.device,
                     max_loras=self.lora_config.max_loras,
                 )
                 for name in self.mm_mapping.tower_model
             }
-            self.mm_punica_wrapper_mapping[
-                self.mm_mapping.language_model[0]] = self.punica_wrapper
+            # For language model
+            self.mm_punica_wrapper_mapping.update(
+                {
+                    self.mm_mapping.language_model[0]: self.punica_wrapper
+                }
+            )
+            # For other
+            # TODO
         self.is_pooling_model = is_pooling_model(self.model)
         self.packed_modules: dict[str, list[str]] = {}
         self.modules: dict[str, BaseLayerWithLoRA] = {}
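A minimal sketch of the new max_batches sizing above, assuming multimodal_config.limit_per_prompt maps each modality name to the maximum number of items per prompt (e.g. {"image": 2}); scaled_max_batches is a hypothetical helper used only for illustration and is not part of this commit.

def scaled_max_batches(max_num_seqs: int,
                       limit_per_prompt: dict[str, int]) -> int:
    # Fall back to 1 when no per-modality limits are configured, so the
    # vision-tower Punica wrappers are never sized with zero batches.
    limit = max([1] + list(limit_per_prompt.values()))
    # Each sequence may carry up to `limit` multimodal items, so the tower
    # wrappers are sized for that many batches per sequence.
    return max_num_seqs * limit

assert scaled_max_batches(8, {}) == 8
assert scaled_max_batches(8, {"image": 2, "video": 1}) == 16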
@@ -539,9 +553,7 @@ class LoRAModelManager(AdapterModelManager):
                 continue
             # A temporary approach for multimodal models to support LoRA
             # TODO: Remove this restriction
-            if (self._filter_unsupported_mm_module(module_name)
-                    and not self.supports_mm_lora
-                    or self._get_mm_punica_wrapper(module_name) is None):
+            if self._filter_unsupported_mm_module(module_name):
                 logger.warning(
                     "Regarding multimodal models, vLLM currently only supports "
                     "adding LoRA to language model, %s will be ignored.",
@@ -678,10 +690,13 @@ class LoRAModelManager(AdapterModelManager):
         be filtered out.
         """
         if self.supports_mm:
-            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
-            prefix_lst = module_mapping.connector + module_mapping.tower_model
-            return any(
-                [module_name.startswith(prefix) for prefix in prefix_lst])
+            prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
+            if self.supports_mm_lora:
+                return self._get_mm_punica_wrapper(module_name) is None
+            else:
+                return any(
+                    [module_name.startswith(prefix) for prefix in prefix_lst])
         return False
 
     def _get_mm_punica_wrapper(self, module_name: str) -> PunicaWrapperBase:
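A standalone restatement of the updated filter decision in _filter_unsupported_mm_module, under the assumption that should_filter and wrapper_lookup are illustrative stand-ins for the method and for self._get_mm_punica_wrapper:

from typing import Callable, Optional

def should_filter(module_name: str,
                  supports_mm: bool,
                  supports_mm_lora: bool,
                  mm_prefixes: list[str],
                  wrapper_lookup: Callable[[str], Optional[object]]) -> bool:
    if not supports_mm:
        # Text-only model: nothing is filtered.
        return False
    if supports_mm_lora:
        # LoRA on multimodal modules is supported; a module is dropped only
        # when no Punica wrapper was registered for it.
        return wrapper_lookup(module_name) is None
    # Otherwise keep the previous behaviour: filter every module under a
    # connector or tower-model prefix.
    return any(module_name.startswith(prefix) for prefix in mm_prefixes)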

@@ -157,10 +157,11 @@ class LoRAModelRunnerMixin:
     @contextmanager
     def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
-                                  num_scheduled_tokens: np.ndarray):
+                                  num_scheduled_tokens: np.ndarray,
+                                  is_mm_input: bool = False):
         with self.maybe_setup_dummy_loras(
                 lora_config), self.maybe_select_dummy_loras(
-                    lora_config, num_scheduled_tokens):
+                    lora_config, num_scheduled_tokens, is_mm_input):
             yield
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
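A self-contained sketch of how the new is_mm_input flag threads through maybe_dummy_run_with_lora into maybe_select_dummy_loras; DummyRunnerMixin below only mimics the nesting of the two context managers and is not vLLM code.

from contextlib import contextmanager

import numpy as np

class DummyRunnerMixin:

    @contextmanager
    def maybe_setup_dummy_loras(self, lora_config):
        yield  # dummy LoRA adapters would be added and removed here

    @contextmanager
    def maybe_select_dummy_loras(self, lora_config, num_scheduled_tokens,
                                 is_mm_input=False):
        # The flag lets the dummy-run path size multimodal LoRA state
        # differently from the text-only path.
        yield

    @contextmanager
    def maybe_dummy_run_with_lora(self, lora_config, num_scheduled_tokens,
                                  is_mm_input=False):
        with self.maybe_setup_dummy_loras(
                lora_config), self.maybe_select_dummy_loras(
                    lora_config, num_scheduled_tokens, is_mm_input):
            yield

with DummyRunnerMixin().maybe_dummy_run_with_lora(
        None, np.array([16, 32]), is_mm_input=True):
    pass  # the dummy forward pass would run here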