From e7026a7c50f4049ac0e071a3af2d4d991ec1fabf Mon Sep 17 00:00:00 2001 From: bk-201 Date: Thu, 22 May 2025 00:31:56 +0800 Subject: [PATCH] add mm_punica_warpper Signed-off-by: bk-201 --- requirements/test.txt | 22 ++++- vllm/lora/layers.py | 8 ++ vllm/lora/models.py | 97 +++++++++++++++++++---- vllm/lora/worker_manager.py | 4 +- vllm/model_executor/models/idefics3.py | 9 +++ vllm/model_executor/models/qwen2_vl.py | 10 +++ vllm/multimodal/profiling.py | 4 + vllm/v1/worker/gpu_model_runner.py | 55 ++++++++++--- vllm/v1/worker/lora_model_runner_mixin.py | 27 ++++--- 9 files changed, 200 insertions(+), 36 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index 89d477017342e..df3770856022f 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -27,6 +27,10 @@ argcomplete==3.5.1 # via datamodel-code-generator arrow==1.3.0 # via isoduration +async-timeout==5.0.1 + # via + # aiohttp + # redis attrs==24.2.0 # via # aiohttp @@ -129,6 +133,11 @@ eval-type-backport==0.2.2 # via mteb evaluate==0.4.3 # via lm-eval +exceptiongroup==1.3.0 + # via + # anyio + # hypothesis + # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -640,7 +649,6 @@ setuptools==77.0.3 # via # mamba-ssm # pytablewriter - # torch # triton shellingham==1.5.4 # via typer @@ -700,8 +708,13 @@ tokenizers==0.21.1 # via # -r requirements/test.in # transformers +toml==0.10.2 + # via datamodel-code-generator tomli==2.2.1 - # via schemathesis + # via + # black + # pytest + # schemathesis tomli-w==1.2.0 # via schemathesis torch==2.7.0+cu128 @@ -775,13 +788,18 @@ types-python-dateutil==2.9.0.20241206 # via arrow typing-extensions==4.12.2 # via + # anyio + # black + # exceptiongroup # huggingface-hub # librosa # mistral-common # mteb + # multidict # pqdm # pydantic # pydantic-core + # rich # torch # typer tzdata==2024.2 diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 023c8e9c9a864..4a1f860b3bbb8 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -77,6 +77,7 @@ def _not_fully_sharded_can_replace(can_replace): @dataclass class LoRAMapping(AdapterMapping): is_prefill: bool = False + is_mm_input: bool = False class BaseLayerWithLoRA(nn.Module): @@ -410,6 +411,9 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): bias: Optional[torch.Tensor] = None) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + # Store original shape for later reshaping + original_shape = output.shape if output.ndim == 3 else None + # In transformers backend, x and output have extra batch dimension like # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim), # therefore we need to flatten the batch dimensions. @@ -424,6 +428,10 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): if not current_platform.can_update_inplace(): output = lora_output + # Restore original shape if it was flattened + if original_shape is not None: + output = output.reshape(original_shape) + return output @property diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 959fe4a672a6d..9556579ca5a6c 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -17,14 +17,14 @@ from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter, get_adapter, list_adapters, remove_adapter, set_adapter_mapping) -from vllm.config import LoRAConfig +from vllm.config import LoRAConfig, ModelConfig from vllm.logger import init_logger from vllm.lora.layers import (BaseLayerWithLoRA, LinearScalingRotaryEmbeddingWithLoRA, LoRAMapping) from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.peft_helper import PEFTHelper -from vllm.lora.punica_wrapper import get_punica_wrapper +from vllm.lora.punica_wrapper import PunicaWrapperBase, get_punica_wrapper from vllm.lora.utils import (from_layer, from_layer_logits_processor, get_supported_lora_modules, is_regex_target_modules, @@ -33,6 +33,7 @@ from vllm.model_executor.models import SupportsLoRA, supports_multimodal from vllm.model_executor.models.interfaces import is_pooling_model from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.utils import is_pin_memory_available logger = init_logger(__name__) @@ -311,6 +312,7 @@ class LoRAModelManager(AdapterModelManager): max_num_batched_tokens: int, vocab_size: int, lora_config: LoRAConfig, + model_config: Optional[ModelConfig], device: torch.device, ): """Create a LoRAModelManager and adapter for a given model. @@ -357,6 +359,30 @@ class LoRAModelManager(AdapterModelManager): # In case the model only supports LoRA for # text modules (e.g. ChatGLM) and hasattr(self.model, "get_mm_mapping")) + # For v0 compatibility + if model_config is not None: + self.mm_registry = MULTIMODAL_REGISTRY + self.info = self.mm_registry.create_processor( + model_config, disable_cache=True).info + self.supports_mm_lora = self.supports_mm and hasattr( + self.info, "get_num_mm_encoder_tokens") + else: + self.supports_mm_lora = False + if self.supports_mm_lora: + self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping() + self.mm_punica_wrapper_mapping = { + name: + get_punica_wrapper( + self.info.get_num_mm_encoder_tokens( + max_num_batched_tokens), + max_batches=self.max_num_seqs, # TODO + device=self.device, + max_loras=self.lora_config.max_loras, + ) + for name in self.mm_mapping.tower_model + } + self.mm_punica_wrapper_mapping[ + self.mm_mapping.language_model[0]] = self.punica_wrapper self.is_pooling_model = is_pooling_model(self.model) self.packed_modules: dict[str, list[str]] = {} self.modules: dict[str, BaseLayerWithLoRA] = {} @@ -452,14 +478,35 @@ class LoRAModelManager(AdapterModelManager): def _set_adapter_mapping(self, mapping: LoRAMapping) -> None: # update lora states - self.punica_wrapper.update_metadata( - mapping, - self.lora_index_to_id, - self.lora_slots + 1, - self.vocab_size, - self.lora_config.lora_extra_vocab_size, - self.long_lora_context, - ) + if not self.supports_mm_lora: + self.punica_wrapper.update_metadata( + mapping, + self.lora_index_to_id, + self.lora_slots + 1, + self.vocab_size, + self.lora_config.lora_extra_vocab_size, + self.long_lora_context, + ) + elif mapping.is_mm_input: + self.mm_punica_wrapper_mapping[ + self.mm_mapping.tower_model[0]].update_metadata( + mapping, + self.lora_index_to_id, + self.lora_slots + 1, + self.vocab_size, + self.lora_config.lora_extra_vocab_size, + self.long_lora_context, + ) + else: + self.mm_punica_wrapper_mapping[ + self.mm_mapping.language_model[0]].update_metadata( + mapping, + self.lora_index_to_id, + self.lora_slots + 1, + self.vocab_size, + self.lora_config.lora_extra_vocab_size, + self.long_lora_context, + ) def remove_all_adapters(self): """Remove all LoRAModels from the manager.""" @@ -476,7 +523,9 @@ class LoRAModelManager(AdapterModelManager): continue # A temporary approach for multimodal models to support LoRA # TODO: Remove this restriction - if self._filter_unsupported_mm_module(module_name): + if (self._filter_unsupported_mm_module(module_name) + and not self.supports_mm_lora + or self._get_mm_punica_wrapper(module_name) is None): logger.warning( "Regarding multimodal models, vLLM currently only supports " "adding LoRA to language model, %s will be ignored.", @@ -519,7 +568,11 @@ class LoRAModelManager(AdapterModelManager): self.register_module(module_name, new_module) self._register_packed_modules(module_name) # All lora layers share the same punica_wrapper based on reference. - new_module.set_mapping(self.punica_wrapper) + if self.supports_mm_lora: + new_module.set_mapping( + self._get_mm_punica_wrapper(module_name)) + else: + new_module.set_mapping(self.punica_wrapper) def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): assert isinstance(module, BaseLayerWithLoRA) @@ -615,6 +668,19 @@ class LoRAModelManager(AdapterModelManager): [module_name.startswith(prefix) for prefix in prefix_lst]) return False + def _get_mm_punica_wrapper(self, module_name: str) -> PunicaWrapperBase: + """ + TODO + """ + if self.supports_mm_lora: + for ( + prefix, + punica_wrapper, + ) in self.mm_punica_wrapper_mapping.items(): + if module_name.startswith(prefix): + return punica_wrapper + return None + def _register_packed_modules(self, module_full_name: str) -> None: parts = module_full_name.split(".") module_name = parts[-1] @@ -713,9 +779,10 @@ class LRUCacheLoRAModelManager(LoRAModelManager): def __init__(self, model: nn.Module, max_num_seqs: int, max_num_batched_tokens: int, vocab_size: int, - lora_config: LoRAConfig, device: torch.device): + lora_config: LoRAConfig, model_config: ModelConfig, + device: torch.device): super().__init__(model, max_num_seqs, max_num_batched_tokens, - vocab_size, lora_config, device) + vocab_size, lora_config, model_config, device) self._registered_adapters: LoRALRUCache = LoRALRUCache( self.capacity, self.deactivate_adapter) self._active_adapters: LoRALRUCache = LoRALRUCache( @@ -785,6 +852,7 @@ def create_lora_manager( max_num_batched_tokens: int, vocab_size: int, lora_config: LoRAConfig, + model_config: ModelConfig, device: torch.device, lora_manager_cls: type[LoRAModelManager] = LoRAModelManager, **kwargs) -> LoRAModelManager: @@ -797,6 +865,7 @@ def create_lora_manager( max_num_batched_tokens=max_num_batched_tokens, vocab_size=vocab_size, lora_config=lora_config, + model_config=model_config, device=device, **kwargs) return lora_manager diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 8e5bc61066593..016de3cbc0f25 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -10,7 +10,7 @@ from vllm.adapter_commons.utils import (add_adapter_worker, list_adapters_worker, set_active_adapters_worker) from vllm.adapter_commons.worker_manager import AbstractWorkerManager -from vllm.config import LoRAConfig +from vllm.config import LoRAConfig, ModelConfig from vllm.logger import init_logger from vllm.lora.models import (LoRAModel, LoRAModelManager, LRUCacheLoRAModelManager, create_lora_manager) @@ -200,6 +200,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): def create_lora_manager( self, model: torch.nn.Module, + model_config: Optional[ModelConfig] = None, ) -> Any: lora_manager = create_lora_manager( model, @@ -209,6 +210,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): lora_config=self.lora_config, device=self.device, max_num_batched_tokens=self.max_num_batched_tokens, + model_config=model_config, ) self._adapter_manager = lora_manager return lora_manager.model diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index fdb128ef5b541..96602848dd53f 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -279,6 +279,15 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): height=image_processor.size["longest_edge"], ) + def get_num_mm_encoder_tokens( + self, + num_image_tokens: int, + ) -> int: + hf_config = self.get_hf_config() + scale_factor = hf_config.scale_factor + + return num_image_tokens * scale_factor**2 + class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo] ): diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 0ff0836b08975..5da8798620585 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -962,6 +962,16 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): image_processor=None, ) + def get_num_mm_encoder_tokens( + self, + num_image_tokens: int, + ) -> int: + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + merge_size = vision_config.spatial_merge_size + + return num_image_tokens * merge_size**2 + class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]): diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index b5875124c1266..8e6e53413d923 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import copy from abc import ABC from collections.abc import Mapping from dataclasses import dataclass, field @@ -44,6 +45,7 @@ class DummyDecoderData(NamedTuple): prompt_token_ids: list[int] multi_modal_data: MultiModalKwargs multi_modal_placeholders: MultiModalPlaceholderDict + multi_modal_token_ids: list[int] _I = TypeVar("_I", bound=BaseProcessingInfo) @@ -249,6 +251,7 @@ class MultiModalProfiler(Generic[_I]): str(self._get_mm_num_tokens(mm_inputs)), ) + multi_modal_token_ids = copy.deepcopy(prompt_token_ids) if total_len < seq_len: prompt_token_ids.extend([0] * (seq_len - total_len)) @@ -256,6 +259,7 @@ class MultiModalProfiler(Generic[_I]): prompt_token_ids=prompt_token_ids, multi_modal_data=mm_inputs["mm_kwargs"], multi_modal_placeholders=mm_inputs["mm_placeholders"], + multi_modal_token_ids=multi_modal_token_ids, ) def get_mm_max_tokens( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 201796c96ee5c..98d76d6afe68d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -263,6 +263,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() + # Multimodal LoRA support + if self.is_multimodal_model: + self.info = self.mm_registry.create_processor( + self.model_config, disable_cache=True).info + self.supports_mm_lora = hasattr(self.info, + "get_num_mm_encoder_tokens") + else: + self.supports_mm_lora = False + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool: """ Update the order of requests in the batch based on the attention @@ -892,12 +901,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): return # Batch the multi-modal inputs. + mm_tokens = list[int]() mm_inputs = list[MultiModalKwargs]() req_ids_pos = list[tuple[str, int, PlaceholderRange]]() for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): req_state = self.requests[req_id] for mm_input_id in encoder_input_ids: + mm_tokens.append(req_state.mm_positions[mm_input_id].length) mm_inputs.append(req_state.mm_inputs[mm_input_id]) req_ids_pos.append( (req_id, mm_input_id, req_state.mm_positions[mm_input_id])) @@ -911,6 +922,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): # encoder outputs. grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs) + if self.lora_config and self.supports_mm_lora: + mm_tokens = [ + self.info.get_num_mm_encoder_tokens(num_token) + for num_token in mm_tokens + ] + num_scheduled_tokens = np.array(mm_tokens, dtype=np.int32) + self.set_active_loras(self.input_batch, + num_scheduled_tokens, + is_mm_input=True) + encoder_outputs = [] for grouped_mm_inputs in grouped_mm_inputs_list: batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs) @@ -1826,22 +1847,38 @@ class GPUModelRunner(LoRAModelRunnerMixin): encoder_budget, max_num_mm_items, dummy_data_modality) # Create dummy batch of multimodal inputs. - dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data( + dummy_mm_data = self.mm_registry.get_decoder_dummy_data( model_config=self.model_config, seq_len=self.max_num_tokens, - mm_counts={ - dummy_data_modality: 1 - }, - ).multi_modal_data + mm_counts={dummy_data_modality: 1}, + ) + dummy_mm_kwargs = dummy_mm_data.multi_modal_data + dummy_mm_token_ids = dummy_mm_data.multi_modal_token_ids + max_num_mm_items = 1 # temporary batched_dummy_mm_inputs = MultiModalKwargs.batch( - [dummy_mm_kwargs] * max_num_mm_items) + [dummy_mm_kwargs] * max_num_mm_items) # ??? batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs( batched_dummy_mm_inputs, device=self.device) - # Run multimodal encoder. - dummy_encoder_outputs = self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) + if self.supports_mm_lora: + num_scheduled_tokens_list = [ + self.info.get_num_mm_encoder_tokens( + len(dummy_mm_token_ids)) + ] * max_num_mm_items + num_scheduled_tokens = np.array(num_scheduled_tokens_list, + dtype=np.int32) + lora_config = self.lora_config + else: + num_scheduled_tokens = None + lora_config = None + + with self.maybe_dummy_run_with_lora(lora_config, + num_scheduled_tokens, + is_mm_input=True): + # Run multimodal encoder. + dummy_encoder_outputs = self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs) sanity_check_mm_encoder_outputs( dummy_encoder_outputs, diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 3cbab840e9693..41a795d91b6f0 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -50,11 +50,13 @@ class LoRAModelRunnerMixin: model.embedding_padding_modules, max_position_embeddings=text_config.max_position_embeddings, ) - return self.lora_manager.create_lora_manager(model) + return self.lora_manager.create_lora_manager(model, model_config) - def _set_active_loras(self, prompt_lora_mapping: tuple[int, ...], + def _set_active_loras(self, + prompt_lora_mapping: tuple[int, ...], token_lora_mapping: tuple[int, ...], - lora_requests: set[LoRARequest]) -> None: + lora_requests: set[LoRARequest], + is_mm_input: bool = False) -> None: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") @@ -64,11 +66,14 @@ class LoRAModelRunnerMixin: # decode and this flag is generally ignored. lora_mapping = LoRAMapping(token_lora_mapping, prompt_lora_mapping, - is_prefill=True) + is_prefill=True, + is_mm_input=is_mm_input) self.lora_manager.set_active_adapters(lora_requests, lora_mapping) - def set_active_loras(self, input_batch: InputBatch, - num_scheduled_tokens: np.ndarray) -> None: + def set_active_loras(self, + input_batch: InputBatch, + num_scheduled_tokens: np.ndarray, + is_mm_input: bool = False) -> None: prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs token_lora_mapping: tuple[int, @@ -77,11 +82,13 @@ class LoRAModelRunnerMixin: prompt_lora_mapping, token_lora_mapping, lora_requests = \ input_batch.make_lora_inputs(num_scheduled_tokens) return self._set_active_loras(prompt_lora_mapping, token_lora_mapping, - lora_requests) + lora_requests, is_mm_input) @contextmanager - def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, - num_scheduled_tokens: np.ndarray): + def maybe_dummy_run_with_lora(self, + lora_config: LoRAConfig, + num_scheduled_tokens: np.ndarray, + is_mm_input: bool = False): if lora_config is None: yield else: @@ -117,7 +124,7 @@ class LoRAModelRunnerMixin: self._set_active_loras(tuple(prompt_lora_mapping), tuple(token_lora_mapping), - lora_requests) + lora_requests, is_mm_input) yield