diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index b9e6b3af3309d..2d65008074e01 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -386,19 +386,30 @@ class LoRAModelManager(AdapterModelManager):
             self.supports_mm_lora = False
         if self.supports_mm_lora:
             self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
+            self.mm_config = model_config.multimodal_config
+            # Upper bound on the number of multimodal items of any single
+            # modality in one prompt; the floor of 1 covers text-only configs.
+            limit_per_prompt: int = max(
+                [1, *self.mm_config.limit_per_prompt.values()])
+
+            # For the vision tower modules: each sequence may carry up to
+            # limit_per_prompt encoder inputs, so scale the batch capacity.
             self.mm_punica_wrapper_mapping = {
                 name: get_punica_wrapper(
                     self.info.get_num_mm_encoder_tokens(
                         max_num_batched_tokens),
-                    max_batches=self.max_num_seqs,  # TODO
+                    max_batches=self.max_num_seqs * limit_per_prompt,
                     device=self.device,
                     max_loras=self.lora_config.max_loras,
                 )
                 for name in self.mm_mapping.tower_model
             }
-            self.mm_punica_wrapper_mapping[
-                self.mm_mapping.language_model[0]] = self.punica_wrapper
+            # For the language model: reuse the base punica wrapper.
+            self.mm_punica_wrapper_mapping.update(
+                {self.mm_mapping.language_model[0]: self.punica_wrapper})
+            # TODO: handle connector and other multimodal modules.
+
         self.is_pooling_model = is_pooling_model(self.model)
 
         self.packed_modules: dict[str, list[str]] = {}
         self.modules: dict[str, BaseLayerWithLoRA] = {}
@@ -539,9 +550,7 @@ class LoRAModelManager(AdapterModelManager):
                 continue
             # A temporary approach for multimodal models to support LoRA
             # TODO: Remove this restriction
-            if (self._filter_unsupported_mm_module(module_name)
-                    and not self.supports_mm_lora
-                    or self._get_mm_punica_wrapper(module_name) is None):
+            if self._filter_unsupported_mm_module(module_name):
                 logger.warning(
                     "Regarding multimodal models, vLLM currently only supports "
                     "adding LoRA to language model, %s will be ignored.",
@@ -678,10 +687,15 @@ class LoRAModelManager(AdapterModelManager):
            be filtered out.
         """
         if self.supports_mm:
-            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
-            prefix_lst = module_mapping.connector + module_mapping.tower_model
-            return any(
-                [module_name.startswith(prefix) for prefix in prefix_lst])
+            if self.supports_mm_lora:
+                # Modules without a dedicated punica wrapper are unsupported.
+                return self._get_mm_punica_wrapper(module_name) is None
+            # self.mm_mapping is only populated when supports_mm_lora is set,
+            # so query the model directly on this path.
+            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
+            prefix_lst = module_mapping.connector + module_mapping.tower_model
+            return any(
+                module_name.startswith(prefix) for prefix in prefix_lst)
         return False
 
     def _get_mm_punica_wrapper(self, module_name: str) -> PunicaWrapperBase:
diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py
index e9a2ac4792e81..0634e274717e0 100644
--- a/vllm/v1/worker/lora_model_runner_mixin.py
+++ b/vllm/v1/worker/lora_model_runner_mixin.py
@@ -157,10 +157,11 @@ class LoRAModelRunnerMixin:
 
     @contextmanager
     def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
-                                  num_scheduled_tokens: np.ndarray):
+                                  num_scheduled_tokens: np.ndarray,
+                                  is_mm_input: bool = False):
         with self.maybe_setup_dummy_loras(
                 lora_config), self.maybe_select_dummy_loras(
-                lora_config, num_scheduled_tokens):
+                lora_config, num_scheduled_tokens, is_mm_input):
             yield
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
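
Reviewer sketch, not part of the patch: the tower-wrapper sizing above scales punica batch capacity by the per-prompt multimodal limit. A minimal standalone illustration, with hypothetical values standing in for MultiModalConfig.limit_per_prompt and the scheduler's max_num_seqs:

# Sketch only: mirrors the capacity computation added in LoRAModelManager.
# The dict below is a hypothetical stand-in for
# MultiModalConfig.limit_per_prompt; max_num_seqs is likewise made up.
limit_per_prompt_cfg = {"image": 4, "video": 2}
max_num_seqs = 8

# Floor of 1 keeps text-only configurations valid.
limit_per_prompt = max([1, *limit_per_prompt_cfg.values()])

# Each sequence may carry up to limit_per_prompt encoder inputs, so the
# tower punica wrapper must hold max_num_seqs * limit_per_prompt batches.
max_batches = max_num_seqs * limit_per_prompt
assert max_batches == 32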
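
Similarly, a hedged sketch of how a model runner could drive the new is_mm_input flag during LoRA profiling; `runner` is a hypothetical object mixing in LoRAModelRunnerMixin, and the dummy forward pass is elided:

import numpy as np

def profile_mm_lora_dummy_run(runner) -> None:
    # Hypothetical driver; assumes the runner exposes its LoRAConfig as
    # runner.lora_config.
    num_scheduled_tokens = np.array([16, 16], dtype=np.int32)
    # Per the patch, is_mm_input is forwarded to maybe_select_dummy_loras,
    # exercising the multimodal punica wrappers instead of the
    # language-model one.
    with runner.maybe_dummy_run_with_lora(runner.lora_config,
                                          num_scheduled_tokens,
                                          is_mm_input=True):
        pass  # the encoder dummy forward pass would run here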