[feat] add connector support

Signed-off-by: prashanth058 <prashanth.dannamaneni@uipath.com>
prashanth058 2025-11-20 15:04:33 +00:00
parent 5c156c9f09
commit a69bde7e8f
11 changed files with 325 additions and 133 deletions

View File

@@ -225,6 +225,21 @@ def qwen25vl_lora_files():
    return snapshot_download(repo_id="jeeejeee/qwen25-vl-lora-pokemon")


@pytest.fixture(scope="session")
def qwen2vl_language_lora_files():
    return snapshot_download(repo_id="prashanth058/qwen2vl-flickr-lora-language")


@pytest.fixture(scope="session")
def qwen2vl_vision_tower_connector_lora_files():
    return snapshot_download(repo_id="prashanth058/qwen2vl-flickr-lora-tower-connector")


@pytest.fixture(scope="session")
def qwen2vl_vision_tower_lora_files():
    return snapshot_download(repo_id="prashanth058/qwen2vl-flickr-lora-tower")


@pytest.fixture(scope="session")
def tinyllama_lora_files():
    return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")

View File

@@ -79,7 +79,6 @@ class Qwen2VLTester:
        lora_request = LoRARequest(str(lora_id), lora_id, self.config.lora_path)
        outputs = self.llm.generate(inputs, sampling_params, lora_request=lora_request)
        generated_texts = [output.outputs[0].text.strip() for output in outputs]
        # Validate outputs
        for generated, expected in zip(generated_texts, expected_outputs):
            assert expected.startswith(generated), (
@@ -130,6 +129,22 @@ EXPECTED_OUTPUTS = [
    "A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.",  # noqa: E501
]

EXPECTED_OUTPUTS_LANGUAGE = [
    "A stop sign is shown in an Asian city, with buildings and a car in the "
    "background.",
    "The Tokyo Skytree can be seen behind the pink blossoms of the cherry trees.",
]

EXPECTED_OUTPUTS_VISION = [
    "A stop sign in front of oriental buildings.",
    "A tree with pink flowers in front of it and a blue sky behind the flowers.",
]

EXPECTED_OUTPUTS_VISION_NO_CONNECTOR = [
    "A stop sign is located on the street of a Chinese neighborhood.",
    "A closeup shot of the Tokyo Skytree with pink flowers in the foreground.",
]

# NOTE - beam search .text contains the whole text
EXPECTED_BEAM_SEARCH_OUTPUTS = [
    [
@@ -190,3 +205,64 @@ def test_qwen25vl_lora(qwen25vl_lora_files):
    # Test with different LoRA IDs
    for lora_id in [1, 2]:
        tester.run_test(TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS, lora_id=lora_id)


@pytest.mark.xfail(
    current_platform.is_rocm(),
    reason="Qwen2-VL dependency xformers incompatible with ROCm",
)
def test_qwen2vl_language_lora(qwen2vl_language_lora_files):
    """
    Test language-only LoRA adapter.
    """
    config = TestConfig(
        model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_language_lora_files
    )
    tester = Qwen2VLTester(config)

    for lora_id in [1, 2]:
        tester.run_test(
            TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS_LANGUAGE, lora_id=lora_id
        )


@pytest.mark.xfail(
    current_platform.is_rocm(),
    reason="Qwen2-VL dependency xformers incompatible with ROCm",
)
def test_qwen2vl_vision_lora(qwen2vl_vision_tower_connector_lora_files):
    """
    Test vision tower + connector LoRA adapter.
    """
    config = TestConfig(
        model_path=QWEN2VL_MODEL_PATH,
        lora_path=qwen2vl_vision_tower_connector_lora_files,
    )
    tester = Qwen2VLTester(config)

    for lora_id in [1, 2]:
        tester.run_test(
            TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS_VISION, lora_id=lora_id
        )


@pytest.mark.xfail(
    current_platform.is_rocm(),
    reason="Qwen2-VL dependency xformers incompatible with ROCm",
)
def test_qwen2vl_vision_no_connector_lora(
    qwen2vl_vision_tower_lora_files,
):
    """
    Test vision tower only LoRA adapter.
    """
    config = TestConfig(
        model_path=QWEN2VL_MODEL_PATH,
        lora_path=qwen2vl_vision_tower_lora_files,
    )
    tester = Qwen2VLTester(config)

    for lora_id in [1, 2]:
        tester.run_test(
            TEST_IMAGES,
            expected_outputs=EXPECTED_OUTPUTS_VISION_NO_CONNECTOR,
            lora_id=lora_id,
        )
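
For context, a minimal sketch (not part of the diff) of driving one of these adapters through vLLM's public API. The model name, LoRA limits, and the chat template below are illustrative assumptions; the test file wraps the same calls in Qwen2VLTester/TestConfig:

from huggingface_hub import snapshot_download
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest

lora_path = snapshot_download(repo_id="prashanth058/qwen2vl-flickr-lora-tower-connector")

llm = LLM(
    model="Qwen/Qwen2-VL-2B-Instruct",  # assumed stand-in for QWEN2VL_MODEL_PATH
    enable_lora=True,
    max_lora_rank=16,                   # assumed; must cover the adapter's rank
    limit_mm_per_prompt={"image": 1},
)
prompt = (
    "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    "Describe this image.<|im_end|>\n<|im_start|>assistant\n"
)
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
    },
    SamplingParams(temperature=0, max_tokens=64),
    lora_request=LoRARequest("tower-connector", 1, lora_path),
)
print(outputs[0].outputs[0].text)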

View File

@@ -17,7 +17,7 @@ from vllm.lora.layers.row_parallel_linear import (
    RowParallelLinearWithLoRA,
    RowParallelLinearWithShardedLoRA,
)
from vllm.lora.layers.utils import LoRAMapping, LoRAMappingType
from vllm.lora.layers.vocal_parallel_embedding import VocabParallelEmbeddingWithLoRA

__all__ = [
@@ -36,4 +36,5 @@ __all__ = [
    "RowParallelLinearWithShardedLoRA",
    "ReplicatedLinearWithLoRA",
    "LoRAMapping",
    "LoRAMappingType",
]

View File

@@ -63,22 +63,25 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
        input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
-       output_parallel = self.apply(input_parallel)
        # Only fuse bias add into GEMM for rank 0 (matches base
        # RowParallelLinear behavior). This ensures bias will not get
        # added more than once in TP>1 case and matches the numerical
        # behavior of the unwrapped layer
        bias_ = (
            None
            if (self.tp_rank > 0 or self.base_layer.skip_bias_add)
            else self.base_layer.bias
        )
        output_parallel = self.apply(input_parallel, bias_)
        if self.base_layer.reduce_results and self.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

-       if not self.base_layer.skip_bias_add:
-           output = (
-               output_ + self.base_layer.bias
-               if self.base_layer.bias is not None
-               else output_
-           )
-           output_bias = None
-       else:
-           output = output_
-           output_bias = self.base_layer.bias
        # Bias was already added by rank 0 in apply(), no need to add again
        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
        output = output_

        if not self.base_layer.return_bias:
            return output
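
Why the bias must be fused on rank 0 only: with tensor parallelism the row-parallel GEMM produces partial sums that are added together afterwards, so a bias fused on every rank would be counted tp_size times. A small self-contained sketch (plain torch, no vLLM) of that failure mode:

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)        # full activation
w = torch.randn(8, 3)        # full weight
b = torch.randn(3)           # bias

x0, x1 = x[:, :4], x[:, 4:]  # activation shards for two "ranks"
w0, w1 = w[:4], w[4:]        # matching weight shards

reference = x @ w + b

# Wrong: both ranks fuse the bias, so the reduction adds it twice.
wrong = (x0 @ w0 + b) + (x1 @ w1 + b)

# What the patch does: only rank 0 fuses the bias before the reduction.
right = (x0 @ w0 + b) + (x1 @ w1)

print(torch.allclose(reference, right))  # True
print(torch.allclose(reference, wrong))  # False (off by exactly +b)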

View File

@@ -2,17 +2,24 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from enum import Enum

import torch
import torch.nn as nn


class LoRAMappingType(Enum):
    LANGUAGE = 1
    TOWER = 2
    CONNECTOR = 3


@dataclass
class LoRAMapping:
    index_mapping: tuple[int, ...]
    prompt_mapping: tuple[int, ...]
    is_prefill: bool = False
-   is_mm_input: bool = False
    type: LoRAMappingType = LoRAMappingType.LANGUAGE

    def __post_init__(self):
        self.index_mapping = tuple(self.index_mapping)
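
How the new field is meant to be populated (values below are made up; the real mappings are built by the model runner): index_mapping carries one adapter id per token, prompt_mapping one per prompt, and type tells the manager which punica wrapper the mapping is for.

from vllm.lora.layers import LoRAMapping, LoRAMappingType

# Hypothetical example: two multimodal prompts using adapters 1 and 2,
# each contributing four vision-tower tokens.
tower_mapping = LoRAMapping(
    index_mapping=(1, 1, 1, 1, 2, 2, 2, 2),
    prompt_mapping=(1, 2),
    is_prefill=True,
    type=LoRAMappingType.TOWER,
)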

View File

@@ -13,7 +13,7 @@ from torch import nn

from vllm.config.lora import LoRAConfig, ModelConfig
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping, LoRAMappingType
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import PunicaWrapperBase, get_punica_wrapper
@@ -374,50 +374,7 @@
            f" {self.model.__class__.__name__}."
        self.packed_modules_mapping = get_packed_modules_mapping(self.model)
-       # Used to indicate whether the model is a multimodal model
-       self.supports_mm: bool = (
-           supports_multimodal(self.model)
-           # In case the model only supports LoRA for
-           # text modules (e.g. ChatGLM)
-           and hasattr(self.model, "get_mm_mapping")
-       )
-       # For v0 compatibility
-       if model_config is not None:
-           self.mm_registry = MULTIMODAL_REGISTRY
-           self.info = self.mm_registry.create_processor(model_config).info
-           self.supports_mm_lora = self.supports_mm and hasattr(
-               self.info, "get_num_mm_encoder_tokens"
-           )
-       else:
-           self.supports_mm_lora = False
-       if self.supports_mm_lora:
-           self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
-           self.mm_config = model_config.multimodal_config
-           # limit_per_prompt: int = max(
-           #     self.info.get_allowed_mm_limits().values())
-           limit_per_prompt = 5  # TODO
-           # For vision tower
-           # max_num_batched_tokens = encoder_budget
-           # max_batches = max_batches * limit_per_prompt
-           self.mm_punica_wrapper_mapping = {
-               name: get_punica_wrapper(
-                   self.info.get_num_mm_encoder_tokens(max_num_batched_tokens),
-                   max_batches=self.max_num_seqs * limit_per_prompt,
-                   device=self.device,
-                   max_loras=self.lora_config.max_loras,
-               )
-               for name in self.mm_mapping.tower_model
-           }
-           # For language model
-           self.mm_punica_wrapper_mapping.update(
-               {self.mm_mapping.language_model[0]: self.punica_wrapper}
-           )
-           # TODO Connector is not supported at the moment.
-           self.mm_punica_wrapper_mapping.update(
-               {name: None for name in self.mm_mapping.connector}
-           )
        self._init_multimodal_config(model_config)
        self.is_pooling_model = is_pooling_model(self.model)
        self.is_moe_model = is_moe_model(self.model)
        self.packed_modules: dict[str, list[str]] = {}
@@ -427,6 +384,72 @@
        self._create_lora_modules()
        self.model.lora_manager = self

    def _init_multimodal_config(self, model_config):
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping")
        )
        # For v0 compatibility
        self.supports_mm_lora = False
        if model_config is not None:
            self.mm_registry = MULTIMODAL_REGISTRY
            self.info = self.mm_registry.create_processor(model_config).info
            self.supports_mm_lora = self.supports_mm and hasattr(
                self.info, "get_num_mm_encoder_tokens"
            )
        if not self.supports_mm_lora:
            return

        self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
        self.mm_config = model_config.multimodal_config
        limit_per_prompt: int = max(self.info.get_allowed_mm_limits().values())
        # For vision tower
        num_encoder_tokens = self.info.get_num_mm_encoder_tokens(
            self.max_num_batched_tokens
        )
        self.mm_punica_wrapper_mapping = {
            name: get_punica_wrapper(
                num_encoder_tokens,
                max_batches=self.max_num_seqs * limit_per_prompt,
                device=self.device,
                max_loras=self.lora_config.max_loras,
            )
            for name in self.mm_mapping.tower_model
        }
        # For language model
        self.mm_punica_wrapper_mapping.update(
            {self.mm_mapping.language_model[0]: self.punica_wrapper}
        )
        # Use wrapper for connector if present.
        if self.mm_mapping.connector:
            if hasattr(self.info, "get_num_mm_connector_tokens"):
                connector_tokens = self.info.get_num_mm_connector_tokens(
                    num_encoder_tokens
                )
                connector_punica_wrapper = get_punica_wrapper(
                    connector_tokens,
                    max_batches=self.max_num_seqs * limit_per_prompt,
                    device=self.device,
                    max_loras=self.lora_config.max_loras,
                )
                self.mm_punica_wrapper_mapping.update(
                    {
                        name: connector_punica_wrapper
                        for name in self.mm_mapping.connector
                    }
                )
            else:
                logger.warning_once(
                    "Connector LoRA support disabled: model does not implement "
                    "get_num_mm_connector_tokens(). This method is required to "
                    "determine the connector's token budget for LoRA operations."
                )

    def __len__(self) -> int:
        return len(self._registered_adapters)
@@ -499,35 +522,27 @@
            )  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
-       # update lora states
-       if not self.supports_mm_lora:
-           self.punica_wrapper.update_metadata(
-               mapping,
-               self.lora_index_to_id,
-               self.lora_slots + 1,
-               self.vocab_size,
-               self.lora_config.lora_extra_vocab_size,
-           )
-       elif mapping.is_mm_input:
-           self.mm_punica_wrapper_mapping[
-               self.mm_mapping.tower_model[0]
-           ].update_metadata(
-               mapping,
-               self.lora_index_to_id,
-               self.lora_slots + 1,
-               self.vocab_size,
-               self.lora_config.lora_extra_vocab_size,
-           )
-       else:
-           self.mm_punica_wrapper_mapping[
-               self.mm_mapping.language_model[0]
-           ].update_metadata(
-               mapping,
-               self.lora_index_to_id,
-               self.lora_slots + 1,
-               self.vocab_size,
-               self.lora_config.lora_extra_vocab_size,
-           )
        # Default to the main language model wrapper
        target_wrapper = self.punica_wrapper

        if self.supports_mm_lora:
            if mapping.type == LoRAMappingType.TOWER:
                target_name = self.mm_mapping.tower_model[0]
                target_wrapper = self.mm_punica_wrapper_mapping[target_name]
            elif mapping.type == LoRAMappingType.CONNECTOR:
                target_name = self.mm_mapping.connector[0]
                target_wrapper = self.mm_punica_wrapper_mapping[target_name]
            else:
                target_name = self.mm_mapping.language_model[0]
                target_wrapper = self.mm_punica_wrapper_mapping[target_name]

        target_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
            self.lora_config.lora_extra_vocab_size,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager."""
@@ -548,15 +563,6 @@
                continue
            if not self._match_target_modules(module_name):
                continue
-           # A temporary approach for multimodal models to support LoRA
-           # TODO: Remove this restriction
-           if self._filter_unsupported_mm_module(module_name):
-               logger.warning(
-                   "Regarding multimodal models, vLLM currently only supports "
-                   "adding LoRA to language model, %s will be ignored.",
-                   module_name,
-               )
-               continue
            parts = module_name.split(".")[-1]
            packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
            new_module = replace_submodule(
@@ -604,6 +610,7 @@
            if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            if self.supports_mm_lora:
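
Rough sketch of the routing this file now implements: each model component gets its own punica wrapper sized for that component's token budget, and _set_adapter_mapping picks the wrapper from LoRAMapping.type. The component keys below are illustrative stand-ins (the real names come from MultiModelKeys):

from vllm.lora.layers import LoRAMappingType

# Stand-ins for per-component PunicaWrapper objects.
wrapper_by_component = {
    "tower": "tower_wrapper",          # sized via get_num_mm_encoder_tokens()
    "language": "language_wrapper",    # the regular decoder wrapper
    "connector": "connector_wrapper",  # sized via get_num_mm_connector_tokens()
}

def pick_wrapper(mapping_type: LoRAMappingType) -> str:
    # Mirrors the branch order in _set_adapter_mapping above.
    if mapping_type is LoRAMappingType.TOWER:
        return wrapper_by_component["tower"]
    if mapping_type is LoRAMappingType.CONNECTOR:
        return wrapper_by_component["connector"]
    return wrapper_by_component["language"]

print(pick_wrapper(LoRAMappingType.CONNECTOR))  # connector_wrapper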

View File

@@ -1040,6 +1040,25 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
            for modality in ("image", "video")
        ]

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        merge_size = vision_config.spatial_merge_size
        return num_image_tokens * merge_size**2

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        merge_size = vision_config.spatial_merge_size
        return num_vision_tokens // merge_size**2


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2_5_VLMultiModalProcessor,
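
Quick sanity check of the token accounting these two helpers imply. With the usual Qwen2/2.5-VL spatial_merge_size of 2, every placeholder token the language model sees corresponds to merge_size**2 vision-tower patches, and the merger/connector maps them back down (numbers below are illustrative):

merge_size = 2            # vision_config.spatial_merge_size
num_image_tokens = 324    # placeholder tokens for one image (example value)

tower_tokens = num_image_tokens * merge_size**2    # 1296 tokens enter the vision tower
connector_tokens = tower_tokens // merge_size**2   # 324 tokens leave the connector

assert connector_tokens == num_image_tokens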

View File

@@ -1094,6 +1094,15 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
        return num_image_tokens * merge_size**2

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        merge_size = vision_config.spatial_merge_size
        return num_vision_tokens // merge_size**2


class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:

View File

@@ -43,6 +43,7 @@ from vllm.distributed.parallel_state import (
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping, LoRAMappingType
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@@ -1689,7 +1690,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
    def _batch_mm_kwargs_from_scheduler(
        self,
        scheduler_output: "SchedulerOutput",
-   ) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]:
    ) -> tuple[
        list[MultiModalKwargsItem],
        list[tuple[str, PlaceholderRange]],
        list[str],
    ]:
        """Batch multimodal kwargs from scheduled encoder inputs.

        Args:
@@ -1697,17 +1702,20 @@
                inputs.

        Returns:
-           A tuple of (mm_kwargs, req_ids_pos) where:
            A tuple of (mm_kwargs, mm_hashes_pos, req_ids) where:
            - mm_kwargs: List of multimodal kwargs items to be batched
            - mm_hashes_pos: List of (mm_hash, position_info) tuples
            - req_ids: List of request IDs for each encoder input
        """
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
-           return [], []
            return [], [], []

        # Batch the multi-modal inputs.
        mm_kwargs = list[MultiModalKwargsItem]()
        # list of tuple (mm_hash, position_info)
        mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
        # list of request IDs for each encoder input
        req_ids = list[str]()

        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]
@@ -1716,13 +1724,14 @@
                mm_hash = mm_feature.identifier
                mm_kwargs.append(mm_feature.data)
                mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
                req_ids.append(req_id)

-       return mm_kwargs, mm_hashes_pos
        return mm_kwargs, mm_hashes_pos, req_ids

    def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
        # Batch the multi-modal inputs using the helper method.
-       mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler(
-           scheduler_output
-       )
        mm_kwargs, mm_hashes_pos, encoder_req_ids = (
            self._batch_mm_kwargs_from_scheduler(scheduler_output)
        )

        if not mm_kwargs:
@@ -1739,16 +1748,62 @@
        encoder_outputs = []

        if self.lora_config and self.supports_mm_lora:
-           mm_tokens = [
-               self.info.get_num_mm_encoder_tokens(pos_info.length)
-               for _, pos_info in mm_hashes_pos
-           ]
-           num_scheduled_tokens = np.array(mm_tokens, dtype=np.int32)
-           self.set_active_loras(
-               self.input_batch,
-               num_scheduled_tokens,
-               is_mm_input=True,
-           )
            # Build LoRA mappings independently for encoder inputs
            # (encoder batch structure is different from main batch)
            prompt_lora_mapping = []
            token_lora_mapping = []
            lora_requests = set()

            for req_id, (_, pos_info) in zip(encoder_req_ids, mm_hashes_pos):
                req_idx = self.input_batch.req_id_to_index[req_id]
                lora_id = int(self.input_batch.request_lora_mapping[req_idx])
                num_tokens = self.info.get_num_mm_encoder_tokens(pos_info.length)

                prompt_lora_mapping.append(lora_id)
                token_lora_mapping.extend([lora_id] * num_tokens)

                if lora_id > 0:
                    lora_request = self.input_batch.lora_id_to_lora_request.get(lora_id)
                    if lora_request is not None:
                        lora_requests.add(lora_request)

            lora_mapping = LoRAMapping(
                tuple(token_lora_mapping),
                tuple(prompt_lora_mapping),
                is_prefill=True,
                type=LoRAMappingType.TOWER,
            )
            self.lora_manager.set_active_adapters(lora_requests, lora_mapping)

            if hasattr(self.info, "get_num_mm_connector_tokens"):
                num_post_op_tokens = []
                for _, pos_info in mm_hashes_pos:
                    mm_token_count = self.info.get_num_mm_encoder_tokens(
                        pos_info.length
                    )
                    post_op_count = self.info.get_num_mm_connector_tokens(
                        mm_token_count
                    )
                    num_post_op_tokens.append(post_op_count)

                lora_ids = np.array(
                    self.lora_manager._adapter_manager._last_mapping.prompt_mapping,
                    dtype=np.int32,
                )
                post_op_counts_np = np.array(num_post_op_tokens, dtype=np.int32)
                new_token_indices = lora_ids.repeat(post_op_counts_np)

                connector_mapping = LoRAMapping(
                    index_mapping=tuple(new_token_indices.tolist()),
                    prompt_mapping=self.lora_manager._adapter_manager._last_mapping.prompt_mapping,
                    is_prefill=self.lora_manager._adapter_manager._last_mapping.is_prefill,
                    type=LoRAMappingType.CONNECTOR,
                )
                self.lora_manager.set_active_adapters(
                    lora_requests,
                    connector_mapping,
                )

        for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
            mm_kwargs,
@@ -1898,7 +1953,7 @@
        inputs and formats them for the encoder-decoder model forward pass.
        """
        # Batch the multi-modal inputs using the helper method.
-       mm_kwargs, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output)
        mm_kwargs, _, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output)

        if not mm_kwargs:
            return {}
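
Sketch of how the connector-level index_mapping above is derived: the prompt-level adapter ids are repeated once per connector token of each scheduled encoder item, which is what lora_ids.repeat(post_op_counts_np) does. Counts below are made up and assume merge_size 2:

import numpy as np

prompt_mapping = np.array([1, 2], dtype=np.int32)     # one LoRA id per encoder item
tower_tokens = np.array([1296, 784], dtype=np.int32)  # encoder tokens per item
connector_tokens = tower_tokens // 4                  # 324 and 196 after merging

# Token-level mapping for the connector pass.
index_mapping = prompt_mapping.repeat(connector_tokens)

print(index_mapping.shape)                    # (520,)
print(index_mapping[:3], index_mapping[-3:])  # [1 1 1] [2 2 2]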

View File

@@ -5,6 +5,7 @@ Define LoRA functionality mixin for model runners.
"""

from contextlib import contextmanager
from typing import TypeAlias

import numpy as np
import torch
@@ -13,14 +14,14 @@ import torch.nn as nn

from vllm.config import ModelConfig, VllmConfig
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping, LoRAMappingType
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
-from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.model_executor.models import supports_lora
from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch

InputBatch: TypeAlias = TPUInputBatch | GPUInputBatch

logger = init_logger(__name__)
@@ -37,12 +38,6 @@ class LoRAModelRunnerMixin:
        if not supports_lora(model):
            raise ValueError(f"{model.__class__.__name__} does not support LoRA yet.")

-       if supports_multimodal(model):
-           logger.warning(
-               "Regarding multimodal models, vLLM currently "
-               "only supports adding LoRA to language model."
-           )

        # Add LoRA Manager to the Model Runner
        self.lora_manager = LRUCacheWorkerLoRAManager(
            vllm_config,
@@ -57,7 +52,7 @@
        prompt_lora_mapping: tuple[int, ...],
        token_lora_mapping: tuple[int, ...],
        lora_requests: set[LoRARequest],
-       is_mm_input: bool = False,
        mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
    ) -> None:
        self._ensure_lora_enabled()
@@ -69,7 +64,7 @@
            token_lora_mapping,
            prompt_lora_mapping,
            is_prefill=True,
-           is_mm_input=is_mm_input,
            type=mapping_type,
        )
        self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
@@ -81,7 +76,7 @@
        self,
        input_batch: InputBatch,
        num_scheduled_tokens: np.ndarray,
-       is_mm_input: bool = False,
        mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
    ) -> None:
        prompt_lora_mapping: tuple[int, ...]  # of size input_batch.num_reqs
        token_lora_mapping: tuple[int, ...]  # of size np.sum(num_scheduled_tokens)
@@ -90,7 +85,7 @@
            input_batch.make_lora_inputs(num_scheduled_tokens)
        )
        return self._set_active_loras(
-           prompt_lora_mapping, token_lora_mapping, lora_requests, is_mm_input
            prompt_lora_mapping, token_lora_mapping, lora_requests, mapping_type
        )

    @contextmanager
@@ -134,7 +129,7 @@
        self,
        lora_config: LoRAConfig | None,
        num_scheduled_tokens: np.ndarray,
-       is_mm_input: bool = False,
        mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
    ):
        if lora_config is None:
            yield
@@ -166,7 +161,7 @@
            tuple(prompt_lora_mapping),
            tuple(token_lora_mapping),
            lora_requests,
-           is_mm_input,
            mapping_type,
        )

        yield
@@ -177,12 +172,12 @@
        lora_config: LoRAConfig | None,
        num_scheduled_tokens: np.ndarray,
        remove_lora: bool = True,
-       is_mm_input: bool = False,
        mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
    ):
        with (
            self.maybe_setup_dummy_loras(lora_config, remove_lora),
            self.maybe_select_dummy_loras(
-               lora_config, num_scheduled_tokens, is_mm_input
                lora_config, num_scheduled_tokens, mapping_type
            ),
        ):
            yield

View File

@@ -32,7 +32,8 @@ from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMappingType
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.tpu import TPUModelLoader
@@ -1422,11 +1423,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        self._hidden_states_dtype = out.dtype

    def _set_active_loras(
-       self, prompt_lora_mapping, token_lora_mapping, lora_requests
        self,
        prompt_lora_mapping: tuple[int, ...],
        token_lora_mapping: tuple[int, ...],
        lora_requests: set[LoRARequest],
        mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
    ) -> None:
        torch_xla.sync(wait=False)  # Captures input updates
        super()._set_active_loras(
-           prompt_lora_mapping, token_lora_mapping, lora_requests
            prompt_lora_mapping, token_lora_mapping, lora_requests, mapping_type
        )
        torch_xla.sync(wait=False)  # Captures metadata updates