diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 46cce598121cd..6afc73d2b04e5 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -385,15 +385,16 @@ class LoRAModelManager(AdapterModelManager):
                 self.info, "get_num_mm_encoder_tokens")
         else:
             self.supports_mm_lora = False
 
-        if self.supports_mm_lora:
+        if self.supports_mm_lora:  # TODO: pass this in from __init__; model_config is then not needed here.
             self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
             self.mm_config = model_config.multimodal_config
 
-            limit_per_prompt: int = max(
-                [1] + \
-                list(self.mm_config.limit_per_prompt.values())
-            )
+            # limit_per_prompt: int = max(
+            #     self.info.get_allowed_mm_limits().values())
+            limit_per_prompt = 5  # WIP: hardcoded placeholder for the vision tower
+            # max_num_batched_tokens = encoder_budget
+            # max_batches = max_batches * limit_per_prompt
 
             self.mm_punica_wrapper_mapping = {
                 name: get_punica_wrapper(
@@ -411,8 +412,13 @@ class LoRAModelManager(AdapterModelManager):
                     self.mm_mapping.language_model[0]: self.punica_wrapper
                 }
             )
-            # For other
-            # TODO
+            # TODO: connector modules are not supported at the moment.
+            self.mm_punica_wrapper_mapping.update(
+                {
+                    name: None
+                    for name in self.mm_mapping.connector
+                }
+            )
 
         self.is_pooling_model = is_pooling_model(self.model)
         self.packed_modules: dict[str, list[str]] = {}
@@ -702,15 +708,17 @@ class LoRAModelManager(AdapterModelManager):
 
     def _get_mm_punica_wrapper(self, module_name: str) -> PunicaWrapperBase:
         """
-        TODO
+        Match the corresponding punica_wrapper based on module_name,
+        and return None if LoRA is not supported for this module.
         """
         if self.supports_mm_lora:
-            for (
-                prefix,
-                punica_wrapper,
-            ) in self.mm_punica_wrapper_mapping.items():
+            # Match against the longest prefix so the most specific module wins.
+            sorted_prefixes = sorted(self.mm_punica_wrapper_mapping.keys(),
+                                     key=len, reverse=True)
+
+            for prefix in sorted_prefixes:
                 if module_name.startswith(prefix):
-                    return punica_wrapper
+                    return self.mm_punica_wrapper_mapping[prefix]
         return None
 
     def _register_packed_modules(self, module_full_name: str) -> None:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 43620a0cc2d67..a8a513fd57aad 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2271,7 +2271,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 mm_counts={
                     dummy_data_modality: 1
                 },
-            ).multi_modal_data
+            )
+            dummy_mm_kwargs = dummy_mm_data.multi_modal_data
+            dummy_mm_token_ids = dummy_mm_data.multi_modal_token_ids
 
             batched_dummy_mm_inputs = MultiModalKwargs.batch(
                 [dummy_mm_kwargs] * max_num_mm_items,
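
For context, a minimal, self-contained sketch of the longest-prefix dispatch that the reworked `_get_mm_punica_wrapper` performs. The mapping values here are illustrative string stand-ins (the real mapping holds punica wrapper instances), and the module names are hypothetical; `"connector"` maps to `None` to mirror the unsupported-connector case in the diff:

```python
from typing import Optional

# Illustrative stand-ins for punica wrappers; the real mapping holds
# PunicaWrapperBase instances (or None for unsupported modules).
mm_punica_wrapper_mapping: dict[str, Optional[str]] = {
    "language_model": "text_wrapper",
    "vision_tower": "vision_wrapper",
    "vision_tower.encoder": "vision_encoder_wrapper",
    "connector": None,  # mirrors the unsupported-connector TODO above
}

def get_mm_punica_wrapper(module_name: str) -> Optional[str]:
    # Longest prefix first, so the most specific entry wins.
    for prefix in sorted(mm_punica_wrapper_mapping, key=len, reverse=True):
        if module_name.startswith(prefix):
            return mm_punica_wrapper_mapping[prefix]
    return None

assert get_mm_punica_wrapper("vision_tower.encoder.layers.0") == "vision_encoder_wrapper"
assert get_mm_punica_wrapper("vision_tower.patch_embed") == "vision_wrapper"
assert get_mm_punica_wrapper("connector.mlp") is None      # explicitly unsupported
assert get_mm_punica_wrapper("some_other_module") is None  # no prefix matched
```

Sorting the prefixes longest-first means a more specific submodule (e.g. an encoder inside the vision tower) shadows its generic parent prefix. Note that "no prefix matched" and "explicitly unsupported" both collapse to `None`, so callers must treat `None` as "no LoRA for this module".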