From a69bde7e8fb7036eb7ecbb58593e5233d1c08723 Mon Sep 17 00:00:00 2001
From: prashanth058
Date: Thu, 20 Nov 2025 15:04:33 +0000
Subject: [PATCH] [feat] add multimodal connector LoRA support

Route LoRA adapters to the vision tower, connector, and language model
independently: LoRAMapping now carries an explicit LoRAMappingType, each
multimodal submodule gets its own punica wrapper, and Qwen2-VL /
Qwen2.5-VL expose the token-budget hooks needed by the connector path.

Signed-off-by: prashanth058
---
 tests/lora/conftest.py                    |  15 ++
 tests/lora/test_qwen2vl.py                |  78 +++++++++-
 vllm/lora/layers/__init__.py              |   3 +-
 vllm/lora/layers/row_parallel_linear.py   |  25 ++--
 vllm/lora/layers/utils.py                 |   9 +-
 vllm/lora/models.py                       | 173 +++++++++++-----------
 vllm/model_executor/models/qwen2_5_vl.py  |  19 +++
 vllm/model_executor/models/qwen2_vl.py    |   9 ++
 vllm/v1/worker/gpu_model_runner.py        |  87 +++++++++--
 vllm/v1/worker/lora_model_runner_mixin.py |  29 ++--
 vllm/v1/worker/tpu_model_runner.py        |  11 +-
 11 files changed, 325 insertions(+), 133 deletions(-)

diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py
index f805a74a4dba8..e5328cde3a046 100644
--- a/tests/lora/conftest.py
+++ b/tests/lora/conftest.py
@@ -225,6 +225,21 @@ def qwen25vl_lora_files():
     return snapshot_download(repo_id="jeeejeee/qwen25-vl-lora-pokemon")
 
 
+@pytest.fixture(scope="session")
+def qwen2vl_language_lora_files():
+    return snapshot_download(repo_id="prashanth058/qwen2vl-flickr-lora-language")
+
+
+@pytest.fixture(scope="session")
+def qwen2vl_vision_tower_connector_lora_files():
+    return snapshot_download(repo_id="prashanth058/qwen2vl-flickr-lora-tower-connector")
+
+
+@pytest.fixture(scope="session")
+def qwen2vl_vision_tower_lora_files():
+    return snapshot_download(repo_id="prashanth058/qwen2vl-flickr-lora-tower")
+
+
 @pytest.fixture(scope="session")
 def tinyllama_lora_files():
     return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
diff --git a/tests/lora/test_qwen2vl.py b/tests/lora/test_qwen2vl.py
index 1800ca107a426..a323bd642b7c2 100644
--- a/tests/lora/test_qwen2vl.py
+++ b/tests/lora/test_qwen2vl.py
@@ -79,7 +79,6 @@ class Qwen2VLTester:
         lora_request = LoRARequest(str(lora_id), lora_id, self.config.lora_path)
         outputs = self.llm.generate(inputs, sampling_params, lora_request=lora_request)
         generated_texts = [output.outputs[0].text.strip() for output in outputs]
-        # Validate outputs
         for generated, expected in zip(generated_texts, expected_outputs):
             assert expected.startswith(generated), (
@@ -130,6 +129,22 @@ EXPECTED_OUTPUTS = [
     "A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.",  # noqa: E501
 ]
 
+EXPECTED_OUTPUTS_LANGUAGE = [
+    "A stop sign is shown in an Asian city, with buildings and a car in the "
+    "background.",
+    "The Tokyo Skytree can be seen behind the pink blossoms of the cherry trees.",
+]
+
+EXPECTED_OUTPUTS_VISION = [
+    "A stop sign in front of oriental buildings.",
+    "A tree with pink flowers in front of it and a blue sky behind the flowers.",
+]
+
+EXPECTED_OUTPUTS_VISION_NO_CONNECTOR = [
+    "A stop sign is located on the street of a Chinese neighborhood.",
+    "A closeup shot of the Tokyo Skytree with pink flowers in the foreground.",
+]
+
 # NOTE - beam search .text contains the whole text
 EXPECTED_BEAM_SEARCH_OUTPUTS = [
     [
@@ -190,3 +205,64 @@ def test_qwen25vl_lora(qwen25vl_lora_files):
     # Test with different LoRA IDs
     for lora_id in [1, 2]:
         tester.run_test(TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS, lora_id=lora_id)
+
+
+@pytest.mark.xfail(
+    current_platform.is_rocm(),
+    reason="Qwen2-VL dependency xformers incompatible with ROCm",
+)
+def test_qwen2vl_language_lora(qwen2vl_language_lora_files):
+    """
+    Test language-only LoRA adapter.
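+
+    The adapter targets only the language-model (decoder) modules; the
+    vision tower and connector keep their base weights.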
+    """
+    config = TestConfig(
+        model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_language_lora_files
+    )
+    tester = Qwen2VLTester(config)
+    for lora_id in [1, 2]:
+        tester.run_test(
+            TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS_LANGUAGE, lora_id=lora_id
+        )
+
+
+@pytest.mark.xfail(
+    current_platform.is_rocm(),
+    reason="Qwen2-VL dependency xformers incompatible with ROCm",
+)
+def test_qwen2vl_vision_lora(qwen2vl_vision_tower_connector_lora_files):
+    """
+    Test vision tower + connector LoRA adapter.
+    """
+    config = TestConfig(
+        model_path=QWEN2VL_MODEL_PATH,
+        lora_path=qwen2vl_vision_tower_connector_lora_files,
+    )
+    tester = Qwen2VLTester(config)
+    for lora_id in [1, 2]:
+        tester.run_test(
+            TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS_VISION, lora_id=lora_id
+        )
+
+
+@pytest.mark.xfail(
+    current_platform.is_rocm(),
+    reason="Qwen2-VL dependency xformers incompatible with ROCm",
+)
+def test_qwen2vl_vision_no_connector_lora(
+    qwen2vl_vision_tower_lora_files,
+):
+    """
+    Test vision-tower-only LoRA adapter.
+    """
+    config = TestConfig(
+        model_path=QWEN2VL_MODEL_PATH,
+        lora_path=qwen2vl_vision_tower_lora_files,
+    )
+    tester = Qwen2VLTester(config)
+    for lora_id in [1, 2]:
+        tester.run_test(
+            TEST_IMAGES,
+            expected_outputs=EXPECTED_OUTPUTS_VISION_NO_CONNECTOR,
+            lora_id=lora_id,
+        )
diff --git a/vllm/lora/layers/__init__.py b/vllm/lora/layers/__init__.py
index 4915ef85f4f73..80dc5b382031e 100644
--- a/vllm/lora/layers/__init__.py
+++ b/vllm/lora/layers/__init__.py
@@ -17,7 +17,7 @@ from vllm.lora.layers.row_parallel_linear import (
     RowParallelLinearWithLoRA,
     RowParallelLinearWithShardedLoRA,
 )
-from vllm.lora.layers.utils import LoRAMapping
+from vllm.lora.layers.utils import LoRAMapping, LoRAMappingType
 from vllm.lora.layers.vocal_parallel_embedding import VocabParallelEmbeddingWithLoRA
 
 __all__ = [
@@ -36,4 +36,5 @@ __all__ = [
     "RowParallelLinearWithShardedLoRA",
     "ReplicatedLinearWithLoRA",
     "LoRAMapping",
+    "LoRAMappingType",
 ]
diff --git a/vllm/lora/layers/row_parallel_linear.py b/vllm/lora/layers/row_parallel_linear.py
index 2ef1bd98fc612..d74e403ca39c3 100644
--- a/vllm/lora/layers/row_parallel_linear.py
+++ b/vllm/lora/layers/row_parallel_linear.py
@@ -63,22 +63,25 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
         input_parallel = splitted_input[self.tp_rank].contiguous()
 
         # Matrix multiply.
-        output_parallel = self.apply(input_parallel)
+        # Only fuse the bias add into the GEMM on rank 0 (matching the base
+        # RowParallelLinear behavior). This ensures the bias is not added
+        # more than once in the TP > 1 case and matches the numerical
+        # behavior of the unwrapped layer.
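+        # E.g. with TP=2, rank 0 computes x0 @ W0 + b and rank 1 computes
+        # x1 @ W1; the all-reduce then yields x @ W + b, with b added once.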
+        bias_ = (
+            None
+            if (self.tp_rank > 0 or self.base_layer.skip_bias_add)
+            else self.base_layer.bias
+        )
+        output_parallel = self.apply(input_parallel, bias_)
+
         if self.base_layer.reduce_results and self.tp_size > 1:
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output_ = output_parallel
 
-        if not self.base_layer.skip_bias_add:
-            output = (
-                output_ + self.base_layer.bias
-                if self.base_layer.bias is not None
-                else output_
-            )
-            output_bias = None
-        else:
-            output = output_
-            output_bias = self.base_layer.bias
+        # Bias was already added by rank 0 in apply(); do not add it again.
+        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
+        output = output_
 
         if not self.base_layer.return_bias:
             return output
diff --git a/vllm/lora/layers/utils.py b/vllm/lora/layers/utils.py
index 002dc934636b9..3f89f77b663c8 100644
--- a/vllm/lora/layers/utils.py
+++ b/vllm/lora/layers/utils.py
@@ -2,17 +2,24 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from dataclasses import dataclass
+from enum import Enum
 
 import torch
 import torch.nn as nn
 
 
+class LoRAMappingType(Enum):
+    LANGUAGE = 1
+    TOWER = 2
+    CONNECTOR = 3
+
+
 @dataclass
 class LoRAMapping:
     index_mapping: tuple[int, ...]
     prompt_mapping: tuple[int, ...]
     is_prefill: bool = False
-    is_mm_input: bool = False
+    type: LoRAMappingType = LoRAMappingType.LANGUAGE
 
     def __post_init__(self):
         self.index_mapping = tuple(self.index_mapping)
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 0de2b4ceec9bf..0c536d8ea192e 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -13,7 +13,7 @@ from torch import nn
 
 from vllm.config.lora import LoRAConfig, ModelConfig
 from vllm.logger import init_logger
-from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
+from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping, LoRAMappingType
 from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.punica_wrapper import PunicaWrapperBase, get_punica_wrapper
@@ -374,50 +374,7 @@ class LoRAModelManager:
             f" {self.model.__class__.__name__}."
         self.packed_modules_mapping = get_packed_modules_mapping(self.model)
 
-        # Used to indicate whether the model is a multimodal model
-        self.supports_mm: bool = (
-            supports_multimodal(self.model)
-            # In case the model only supports LoRA for
-            # text modules (e.g. ChatGLM)
-            and hasattr(self.model, "get_mm_mapping")
-        )
-        # For v0 compatibility
-        if model_config is not None:
-            self.mm_registry = MULTIMODAL_REGISTRY
-            self.info = self.mm_registry.create_processor(model_config).info
-            self.supports_mm_lora = self.supports_mm and hasattr(
-                self.info, "get_num_mm_encoder_tokens"
-            )
-        else:
-            self.supports_mm_lora = False
-        if self.supports_mm_lora:
-            self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
-            self.mm_config = model_config.multimodal_config
-            # limit_per_prompt: int = max(
-            #     self.info.get_allowed_mm_limits().values())
-            limit_per_prompt = 5  # TODO
-
-            # For vision tower
-            # max_num_batched_tokens = encoder_budget
-            # max_batches = max_batches * limit_per_prompt
-            self.mm_punica_wrapper_mapping = {
-                name: get_punica_wrapper(
-                    self.info.get_num_mm_encoder_tokens(max_num_batched_tokens),
-                    max_batches=self.max_num_seqs * limit_per_prompt,
-                    device=self.device,
-                    max_loras=self.lora_config.max_loras,
-                )
-                for name in self.mm_mapping.tower_model
-            }
-            # For language model
-            self.mm_punica_wrapper_mapping.update(
-                {self.mm_mapping.language_model[0]: self.punica_wrapper}
-            )
-            # TODO Connector is not supported at the moment.
-            self.mm_punica_wrapper_mapping.update(
-                {name: None for name in self.mm_mapping.connector}
-            )
-
+        self._init_multimodal_config(model_config)
         self.is_pooling_model = is_pooling_model(self.model)
         self.is_moe_model = is_moe_model(self.model)
         self.packed_modules: dict[str, list[str]] = {}
@@ -427,6 +384,72 @@ class LoRAModelManager:
         self._create_lora_modules()
         self.model.lora_manager = self
 
+    def _init_multimodal_config(self, model_config):
+        # Used to indicate whether the model is a multimodal model
+        self.supports_mm: bool = (
+            supports_multimodal(self.model)
+            # In case the model only supports LoRA for
+            # text modules (e.g. ChatGLM)
+            and hasattr(self.model, "get_mm_mapping")
+        )
+        # For v0 compatibility
+        self.supports_mm_lora = False
+        if model_config is not None:
+            self.mm_registry = MULTIMODAL_REGISTRY
+            self.info = self.mm_registry.create_processor(model_config).info
+            self.supports_mm_lora = self.supports_mm and hasattr(
+                self.info, "get_num_mm_encoder_tokens"
+            )
+
+        if not self.supports_mm_lora:
+            return
+
+        self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
+        self.mm_config = model_config.multimodal_config
+        limit_per_prompt: int = max(self.info.get_allowed_mm_limits().values())
+
+        # For vision tower
+        num_encoder_tokens = self.info.get_num_mm_encoder_tokens(
+            self.max_num_batched_tokens
+        )
+        self.mm_punica_wrapper_mapping = {
+            name: get_punica_wrapper(
+                num_encoder_tokens,
+                max_batches=self.max_num_seqs * limit_per_prompt,
+                device=self.device,
+                max_loras=self.lora_config.max_loras,
+            )
+            for name in self.mm_mapping.tower_model
+        }
+        # For language model
+        self.mm_punica_wrapper_mapping.update(
+            {self.mm_mapping.language_model[0]: self.punica_wrapper}
+        )
+        # Use a dedicated wrapper for the connector, if present.
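+        # The connector consumes the vision tower's output, so its token
+        # budget is derived from the encoder budget (for Qwen2-VL the
+        # merger shrinks it by spatial_merge_size**2).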
+        if self.mm_mapping.connector:
+            if hasattr(self.info, "get_num_mm_connector_tokens"):
+                connector_tokens = self.info.get_num_mm_connector_tokens(
+                    num_encoder_tokens
+                )
+                connector_punica_wrapper = get_punica_wrapper(
+                    connector_tokens,
+                    max_batches=self.max_num_seqs * limit_per_prompt,
+                    device=self.device,
+                    max_loras=self.lora_config.max_loras,
+                )
+                self.mm_punica_wrapper_mapping.update(
+                    {
+                        name: connector_punica_wrapper
+                        for name in self.mm_mapping.connector
+                    }
+                )
+            else:
+                logger.warning_once(
+                    "Connector LoRA support disabled: model does not implement "
+                    "get_num_mm_connector_tokens(). This method is required to "
+                    "determine the connector's token budget for LoRA operations."
+                )
+
     def __len__(self) -> int:
         return len(self._registered_adapters)
 
@@ -499,35 +522,27 @@ class LoRAModelManager:
         )  # type: ignore
 
     def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
-        # update lora states
-        if not self.supports_mm_lora:
-            self.punica_wrapper.update_metadata(
-                mapping,
-                self.lora_index_to_id,
-                self.lora_slots + 1,
-                self.vocab_size,
-                self.lora_config.lora_extra_vocab_size,
-            )
-        elif mapping.is_mm_input:
-            self.mm_punica_wrapper_mapping[
-                self.mm_mapping.tower_model[0]
-            ].update_metadata(
-                mapping,
-                self.lora_index_to_id,
-                self.lora_slots + 1,
-                self.vocab_size,
-                self.lora_config.lora_extra_vocab_size,
-            )
-        else:
-            self.mm_punica_wrapper_mapping[
-                self.mm_mapping.language_model[0]
-            ].update_metadata(
-                mapping,
-                self.lora_index_to_id,
-                self.lora_slots + 1,
-                self.vocab_size,
-                self.lora_config.lora_extra_vocab_size,
-            )
+        # Default to the main language-model wrapper.
+        target_wrapper = self.punica_wrapper
+
+        if self.supports_mm_lora:
+            if mapping.type == LoRAMappingType.TOWER:
+                target_name = self.mm_mapping.tower_model[0]
+            elif mapping.type == LoRAMappingType.CONNECTOR:
+                target_name = self.mm_mapping.connector[0]
+            else:
+                target_name = self.mm_mapping.language_model[0]
+            target_wrapper = self.mm_punica_wrapper_mapping[target_name]
+
+        target_wrapper.update_metadata(
+            mapping,
+            self.lora_index_to_id,
+            self.lora_slots + 1,
+            self.vocab_size,
+            self.lora_config.lora_extra_vocab_size,
+        )
 
     def remove_all_adapters(self):
         """Remove all LoRAModels from the manager."""
@@ -548,15 +563,6 @@ class LoRAModelManager:
                 continue
             if not self._match_target_modules(module_name):
                 continue
-            # A temporary approach for multimodal models to support LoRA
-            # TODO: Remove this restriction
-            if self._filter_unsupported_mm_module(module_name):
-                logger.warning(
-                    "Regarding multimodal models, vLLM currently only supports "
-                    "adding LoRA to language model, %s will be ignored.",
-                    module_name,
-                )
-                continue
             parts = module_name.split(".")[-1]
             packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
             new_module = replace_submodule(
@@ -604,6 +610,7 @@ class LoRAModelManager:
             if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA):
                 continue
             self.register_module(module_name, new_module)
+            self._register_packed_modules(module_name)
 
         # All lora layers share the same punica_wrapper based on reference.
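+        # For multimodal models, tower/connector layers instead reference
+        # the dedicated wrappers built in _init_multimodal_config.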
         if self.supports_mm_lora:
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 3f205307cb225..dd06431b54eea 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -1040,6 +1040,25 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
             for modality in ("image", "video")
         ]
 
+    def get_num_mm_encoder_tokens(
+        self,
+        num_image_tokens: int,
+    ) -> int:
+        hf_config = self.get_hf_config()
+        vision_config = hf_config.vision_config
+        merge_size = vision_config.spatial_merge_size
+
+        return num_image_tokens * merge_size**2
+
+    def get_num_mm_connector_tokens(
+        self,
+        num_vision_tokens: int,
+    ) -> int:
+        hf_config = self.get_hf_config()
+        vision_config = hf_config.vision_config
+        merge_size = vision_config.spatial_merge_size
+        return num_vision_tokens // merge_size**2
+
 
 @MULTIMODAL_REGISTRY.register_processor(
     Qwen2_5_VLMultiModalProcessor,
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 004eacc1b4b79..287a55a66c6bc 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -1094,6 +1094,15 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
 
         return num_image_tokens * merge_size**2
 
+    def get_num_mm_connector_tokens(
+        self,
+        num_vision_tokens: int,
+    ) -> int:
+        hf_config = self.get_hf_config()
+        vision_config = hf_config.vision_config
+        merge_size = vision_config.spatial_merge_size
+        return num_vision_tokens // merge_size**2
+
 
 class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
     def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index b776870763956..052b85d5cb336 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -43,6 +43,7 @@ from vllm.distributed.parallel_state import (
 )
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.logger import init_logger
+from vllm.lora.layers import LoRAMapping, LoRAMappingType
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@@ -1689,7 +1690,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def _batch_mm_kwargs_from_scheduler(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]:
+    ) -> tuple[
+        list[MultiModalKwargsItem],
+        list[tuple[str, PlaceholderRange]],
+        list[str],
+    ]:
         """Batch multimodal kwargs from scheduled encoder inputs.
 
         Args:
@@ -1697,17 +1702,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 inputs.
 
         Returns:
-            A tuple of (mm_kwargs, req_ids_pos) where:
+            A tuple of (mm_kwargs, mm_hashes_pos, req_ids) where:
             - mm_kwargs: List of multimodal kwargs items to be batched
             - mm_hashes_pos: List of (mm_hash, position_info) tuples
+            - req_ids: List of request IDs for each encoder input
         """
         scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
         if not scheduled_encoder_inputs:
-            return [], []
+            return [], [], []
 
         # Batch the multi-modal inputs.
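+        # req_ids is tracked per encoder input (not per request) so that
+        # _execute_mm_encoder can attribute each input's tokens to a LoRA.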
         mm_kwargs = list[MultiModalKwargsItem]()
         # list of tuple (mm_hash, position_info)
         mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
+        # list of request IDs for each encoder input
+        req_ids = list[str]()
 
         for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
             req_state = self.requests[req_id]
@@ -1716,13 +1724,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 mm_hash = mm_feature.identifier
                 mm_kwargs.append(mm_feature.data)
                 mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
+                req_ids.append(req_id)
 
-        return mm_kwargs, mm_hashes_pos
+        return mm_kwargs, mm_hashes_pos, req_ids
 
     def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
         # Batch the multi-modal inputs using the helper method.
-        mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler(
-            scheduler_output
+        mm_kwargs, mm_hashes_pos, encoder_req_ids = (
+            self._batch_mm_kwargs_from_scheduler(scheduler_output)
         )
 
         if not mm_kwargs:
@@ -1739,16 +1748,62 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         encoder_outputs = []
 
         if self.lora_config and self.supports_mm_lora:
-            mm_tokens = [
-                self.info.get_num_mm_encoder_tokens(pos_info.length)
-                for _, pos_info in mm_hashes_pos
-            ]
-            num_scheduled_tokens = np.array(mm_tokens, dtype=np.int32)
-            self.set_active_loras(
-                self.input_batch,
-                num_scheduled_tokens,
-                is_mm_input=True,
-            )
+            # Build LoRA mappings independently for the encoder inputs,
+            # since the encoder batch structure differs from the main batch.
+            prompt_lora_mapping = []
+            token_lora_mapping = []
+            lora_requests = set()
+
+            for req_id, (_, pos_info) in zip(encoder_req_ids, mm_hashes_pos):
+                req_idx = self.input_batch.req_id_to_index[req_id]
+                lora_id = int(self.input_batch.request_lora_mapping[req_idx])
+
+                num_tokens = self.info.get_num_mm_encoder_tokens(pos_info.length)
+                prompt_lora_mapping.append(lora_id)
+                token_lora_mapping.extend([lora_id] * num_tokens)
+
+                if lora_id > 0:
+                    lora_request = self.input_batch.lora_id_to_lora_request.get(lora_id)
+                    if lora_request is not None:
+                        lora_requests.add(lora_request)
+
+            lora_mapping = LoRAMapping(
+                tuple(token_lora_mapping),
+                tuple(prompt_lora_mapping),
+                is_prefill=True,
+                type=LoRAMappingType.TOWER,
+            )
+            self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
+
+            if hasattr(self.info, "get_num_mm_connector_tokens"):
+                # The connector sees fewer tokens than the tower, so expand
+                # each request's LoRA id over its post-merge token count.
+                num_post_op_tokens = []
+                for _, pos_info in mm_hashes_pos:
+                    mm_token_count = self.info.get_num_mm_encoder_tokens(
+                        pos_info.length
+                    )
+                    post_op_count = self.info.get_num_mm_connector_tokens(
+                        mm_token_count
+                    )
+                    num_post_op_tokens.append(post_op_count)
+
+                last_mapping = self.lora_manager._adapter_manager._last_mapping
+                lora_ids = np.array(last_mapping.prompt_mapping, dtype=np.int32)
+                post_op_counts_np = np.array(num_post_op_tokens, dtype=np.int32)
+                new_token_indices = lora_ids.repeat(post_op_counts_np)
+
+                connector_mapping = LoRAMapping(
+                    index_mapping=tuple(new_token_indices.tolist()),
+                    prompt_mapping=last_mapping.prompt_mapping,
+                    is_prefill=last_mapping.is_prefill,
+                    type=LoRAMappingType.CONNECTOR,
+                )
+
+                self.lora_manager.set_active_adapters(
+                    lora_requests,
+                    connector_mapping,
+                )
 
         for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
             mm_kwargs,
@@ -1898,7 +1953,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         inputs and formats them for the encoder-decoder model forward pass.
         """
         # Batch the multi-modal inputs using the helper method.
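+        # Only the kwargs matter here; the per-input request IDs and
+        # hash/position pairs are not needed for encoder-decoder inputs.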
-        mm_kwargs, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output)
+        mm_kwargs, _, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output)
 
         if not mm_kwargs:
             return {}
diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py
index 31094dcbda124..98f2825e483a9 100644
--- a/vllm/v1/worker/lora_model_runner_mixin.py
+++ b/vllm/v1/worker/lora_model_runner_mixin.py
@@ -5,6 +5,7 @@ Define LoRA functionality mixin for model runners.
 """
 
 from contextlib import contextmanager
+from typing import TypeAlias
 
 import numpy as np
 import torch
@@ -13,14 +14,14 @@ import torch.nn as nn
 from vllm.config import ModelConfig, VllmConfig
 from vllm.config.lora import LoRAConfig
 from vllm.logger import init_logger
-from vllm.lora.layers import LoRAMapping
+from vllm.lora.layers import LoRAMapping, LoRAMappingType
 from vllm.lora.request import LoRARequest
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
-from vllm.model_executor.models import supports_lora, supports_multimodal
+from vllm.model_executor.models import supports_lora
 from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
 from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch
 
-InputBatch = TPUInputBatch | GPUInputBatch
+InputBatch: TypeAlias = TPUInputBatch | GPUInputBatch
 
 logger = init_logger(__name__)
 
@@ -37,12 +38,6 @@ class LoRAModelRunnerMixin:
         if not supports_lora(model):
             raise ValueError(f"{model.__class__.__name__} does not support LoRA yet.")
 
-        if supports_multimodal(model):
-            logger.warning(
-                "Regarding multimodal models, vLLM currently "
-                "only supports adding LoRA to language model."
-            )
-
         # Add LoRA Manager to the Model Runner
         self.lora_manager = LRUCacheWorkerLoRAManager(
             vllm_config,
@@ -57,7 +52,7 @@ class LoRAModelRunnerMixin:
         prompt_lora_mapping: tuple[int, ...],
         token_lora_mapping: tuple[int, ...],
         lora_requests: set[LoRARequest],
-        is_mm_input: bool = False,
+        mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
     ) -> None:
         self._ensure_lora_enabled()
 
@@ -69,7 +64,7 @@ class LoRAModelRunnerMixin:
             token_lora_mapping,
             prompt_lora_mapping,
             is_prefill=True,
-            is_mm_input=is_mm_input,
+            type=mapping_type,
         )
         self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
 
@@ -81,7 +76,7 @@ class LoRAModelRunnerMixin:
         self,
         input_batch: InputBatch,
         num_scheduled_tokens: np.ndarray,
-        is_mm_input: bool = False,
+        mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
     ) -> None:
         prompt_lora_mapping: tuple[int, ...]  # of size input_batch.num_reqs
         token_lora_mapping: tuple[int, ...]  # of size np.sum(num_scheduled_tokens)
         lora_requests: set[LoRARequest]
 
         prompt_lora_mapping, token_lora_mapping, lora_requests = (
             input_batch.make_lora_inputs(num_scheduled_tokens)
         )
         return self._set_active_loras(
-            prompt_lora_mapping, token_lora_mapping, lora_requests, is_mm_input
+            prompt_lora_mapping, token_lora_mapping, lora_requests, mapping_type
         )
 
     @contextmanager
@@ -134,7 +129,7 @@ class LoRAModelRunnerMixin:
         self,
         lora_config: LoRAConfig | None,
         num_scheduled_tokens: np.ndarray,
-        is_mm_input: bool = False,
+        mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
     ):
         if lora_config is None:
             yield
@@ -166,7 +161,7 @@ class LoRAModelRunnerMixin:
             tuple(prompt_lora_mapping),
             tuple(token_lora_mapping),
             lora_requests,
-            is_mm_input,
+            mapping_type,
         )
 
         yield
@@ -177,12 +172,12 @@ class LoRAModelRunnerMixin:
         lora_config: LoRAConfig | None,
         num_scheduled_tokens: np.ndarray,
         remove_lora: bool = True,
-        is_mm_input: bool = False,
+        mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
     ):
         with (
             self.maybe_setup_dummy_loras(lora_config, remove_lora),
             self.maybe_select_dummy_loras(
-                lora_config, num_scheduled_tokens, is_mm_input
+                lora_config, num_scheduled_tokens, mapping_type
             ),
         ):
             yield
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 6fd71259dbcbf..88cd19ba3935d 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -32,7 +32,8 @@ from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_
 from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
-from vllm.lora.layers import BaseLayerWithLoRA
+from vllm.lora.layers import BaseLayerWithLoRA, LoRAMappingType
+from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.tpu import TPUModelLoader
@@ -1422,11 +1423,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self._hidden_states_dtype = out.dtype
 
     def _set_active_loras(
-        self, prompt_lora_mapping, token_lora_mapping, lora_requests
+        self,
+        prompt_lora_mapping: tuple[int, ...],
+        token_lora_mapping: tuple[int, ...],
+        lora_requests: set[LoRARequest],
+        mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
     ) -> None:
         torch_xla.sync(wait=False)  # Captures input updates
         super()._set_active_loras(
-            prompt_lora_mapping, token_lora_mapping, lora_requests
+            prompt_lora_mapping, token_lora_mapping, lora_requests, mapping_type
         )
         torch_xla.sync(wait=False)  # Captures metadata updates