update mm filter

Signed-off-by: bk-201 <joy25810@foxmail.com>
This commit is contained in:
bk-201 2025-07-09 13:55:45 +08:00
parent 98debc2424
commit 7db0d5990a
2 changed files with 28 additions and 12 deletions

View File

@@ -386,19 +386,33 @@ class LoRAModelManager(AdapterModelManager):
self.supports_mm_lora = False self.supports_mm_lora = False
if self.supports_mm_lora: if self.supports_mm_lora:
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping() self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
self.mm_config = model_config.multimodal_config
limit_per_prompt: int = max(
[1] + \
list(self.mm_config.limit_per_prompt.values())
)
# For vision tower
self.mm_punica_wrapper_mapping = { self.mm_punica_wrapper_mapping = {
name: name:
get_punica_wrapper( get_punica_wrapper(
self.info.get_num_mm_encoder_tokens( self.info.get_num_mm_encoder_tokens(
max_num_batched_tokens), max_num_batched_tokens),
max_batches=self.max_num_seqs, # TODO max_batches=self.max_num_seqs * limit_per_prompt,
device=self.device, device=self.device,
max_loras=self.lora_config.max_loras, max_loras=self.lora_config.max_loras,
) )
for name in self.mm_mapping.tower_model for name in self.mm_mapping.tower_model
} }
self.mm_punica_wrapper_mapping[ # For language model
self.mm_mapping.language_model[0]] = self.punica_wrapper self.mm_punica_wrapper_mapping.update(
{
self.mm_mapping.language_model[0]: self.punica_wrapper
}
)
# For other
# TODO
self.is_pooling_model = is_pooling_model(self.model) self.is_pooling_model = is_pooling_model(self.model)
self.packed_modules: dict[str, list[str]] = {} self.packed_modules: dict[str, list[str]] = {}
self.modules: dict[str, BaseLayerWithLoRA] = {} self.modules: dict[str, BaseLayerWithLoRA] = {}
@@ -539,9 +553,7 @@ class LoRAModelManager(AdapterModelManager):
continue continue
# A temporary approach for multimodal models to support LoRA # A temporary approach for multimodal models to support LoRA
# TODO: Remove this restriction # TODO: Remove this restriction
if (self._filter_unsupported_mm_module(module_name) if self._filter_unsupported_mm_module(module_name):
and not self.supports_mm_lora
or self._get_mm_punica_wrapper(module_name) is None):
logger.warning( logger.warning(
"Regarding multimodal models, vLLM currently only supports " "Regarding multimodal models, vLLM currently only supports "
"adding LoRA to language model, %s will be ignored.", "adding LoRA to language model, %s will be ignored.",
@@ -678,10 +690,13 @@ class LoRAModelManager(AdapterModelManager):
be filtered out. be filtered out.
""" """
if self.supports_mm: if self.supports_mm:
module_mapping: MultiModelKeys = self.model.get_mm_mapping() prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
prefix_lst = module_mapping.connector + module_mapping.tower_model if self.supports_mm_lora:
return any(
[module_name.startswith(prefix) for prefix in prefix_lst]) return self._get_mm_punica_wrapper(module_name) is None
else:
return any(
[module_name.startswith(prefix) for prefix in prefix_lst])
return False return False
def _get_mm_punica_wrapper(self, module_name: str) -> PunicaWrapperBase: def _get_mm_punica_wrapper(self, module_name: str) -> PunicaWrapperBase:

View File

@@ -157,10 +157,11 @@ class LoRAModelRunnerMixin:
@contextmanager @contextmanager
def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
num_scheduled_tokens: np.ndarray): num_scheduled_tokens: np.ndarray,
is_mm_input: bool = False):
with self.maybe_setup_dummy_loras( with self.maybe_setup_dummy_loras(
lora_config), self.maybe_select_dummy_loras( lora_config), self.maybe_select_dummy_loras(
lora_config, num_scheduled_tokens): lora_config, num_scheduled_tokens, is_mm_input):
yield yield
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool: