[feat] add connector LoRA support for multimodal models

Signed-off-by: prashanth058 <prashanth.dannamaneni@uipath.com>
prashanth058 2025-11-20 15:04:33 +00:00
parent 5c156c9f09
commit a69bde7e8f
11 changed files with 325 additions and 133 deletions

View File

@@ -225,6 +225,21 @@ def qwen25vl_lora_files():
return snapshot_download(repo_id="jeeejeee/qwen25-vl-lora-pokemon")
@pytest.fixture(scope="session")
def qwen2vl_language_lora_files():
return snapshot_download(repo_id="prashanth058/qwen2vl-flickr-lora-language")
@pytest.fixture(scope="session")
def qwen2vl_vision_tower_connector_lora_files():
return snapshot_download(repo_id="prashanth058/qwen2vl-flickr-lora-tower-connector")
@pytest.fixture(scope="session")
def qwen2vl_vision_tower_lora_files():
return snapshot_download(repo_id="prashanth058/qwen2vl-flickr-lora-tower")
@pytest.fixture(scope="session")
def tinyllama_lora_files():
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")

View File

@@ -79,7 +79,6 @@ class Qwen2VLTester:
lora_request = LoRARequest(str(lora_id), lora_id, self.config.lora_path)
outputs = self.llm.generate(inputs, sampling_params, lora_request=lora_request)
generated_texts = [output.outputs[0].text.strip() for output in outputs]
# Validate outputs
for generated, expected in zip(generated_texts, expected_outputs):
assert expected.startswith(generated), (
@@ -130,6 +129,22 @@ EXPECTED_OUTPUTS = [
"A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.", # noqa: E501
]
EXPECTED_OUTPUTS_LANGUAGE = [
"A stop sign is shown in an Asian city, with buildings and a car in the "
"background.",
"The Tokyo Skytree can be seen behind the pink blossoms of the cherry trees.",
]
EXPECTED_OUTPUTS_VISION = [
"A stop sign in front of oriental buildings.",
"A tree with pink flowers in front of it and a blue sky behind the flowers.",
]
EXPECTED_OUTPUTS_VISION_NO_CONNECTOR = [
"A stop sign is located on the street of a Chinese neighborhood.",
"A closeup shot of the Tokyo Skytree with pink flowers in the foreground.",
]
# NOTE: with beam search, .text contains the whole generated text
EXPECTED_BEAM_SEARCH_OUTPUTS = [
[
@@ -190,3 +205,64 @@ def test_qwen25vl_lora(qwen25vl_lora_files):
# Test with different LoRA IDs
for lora_id in [1, 2]:
tester.run_test(TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS, lora_id=lora_id)
@pytest.mark.xfail(
current_platform.is_rocm(),
reason="Qwen2-VL dependency xformers incompatible with ROCm",
)
def test_qwen2vl_language_lora(qwen2vl_language_lora_files):
"""
Test language-only LoRA adapter.
"""
config = TestConfig(
model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_language_lora_files
)
tester = Qwen2VLTester(config)
for lora_id in [1, 2]:
tester.run_test(
TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS_LANGUAGE, lora_id=lora_id
)
@pytest.mark.xfail(
current_platform.is_rocm(),
reason="Qwen2-VL dependency xformers incompatible with ROCm",
)
def test_qwen2vl_vision_lora(qwen2vl_vision_tower_connector_lora_files):
"""
Test vision tower + connector LoRA adapter.
"""
config = TestConfig(
model_path=QWEN2VL_MODEL_PATH,
lora_path=qwen2vl_vision_tower_connector_lora_files,
)
tester = Qwen2VLTester(config)
for lora_id in [1, 2]:
tester.run_test(
TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS_VISION, lora_id=lora_id
)
@pytest.mark.xfail(
current_platform.is_rocm(),
reason="Qwen2-VL dependency xformers incompatible with ROCm",
)
def test_qwen2vl_vision_no_connector_lora(
qwen2vl_vision_tower_lora_files,
):
"""
Test vision tower only LoRA adapter.
"""
config = TestConfig(
model_path=QWEN2VL_MODEL_PATH,
lora_path=qwen2vl_vision_tower_lora_files,
)
tester = Qwen2VLTester(config)
for lora_id in [1, 2]:
tester.run_test(
TEST_IMAGES,
expected_outputs=EXPECTED_OUTPUTS_VISION_NO_CONNECTOR,
lora_id=lora_id,
)

View File

@@ -17,7 +17,7 @@ from vllm.lora.layers.row_parallel_linear import (
RowParallelLinearWithLoRA,
RowParallelLinearWithShardedLoRA,
)
from vllm.lora.layers.utils import LoRAMapping
from vllm.lora.layers.utils import LoRAMapping, LoRAMappingType
from vllm.lora.layers.vocal_parallel_embedding import VocabParallelEmbeddingWithLoRA
__all__ = [
@@ -36,4 +36,5 @@ __all__ = [
"RowParallelLinearWithShardedLoRA",
"ReplicatedLinearWithLoRA",
"LoRAMapping",
"LoRAMappingType",
]

View File

@@ -63,22 +63,25 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
input_parallel = splitted_input[self.tp_rank].contiguous()
# Matrix multiply.
output_parallel = self.apply(input_parallel)
# Only fuse the bias add into the GEMM on rank 0 (matches base
# RowParallelLinear behavior). This ensures the bias will not get
# added more than once in the TP>1 case and matches the numerical
# behavior of the unwrapped layer.
bias_ = (
None
if (self.tp_rank > 0 or self.base_layer.skip_bias_add)
else self.base_layer.bias
)
output_parallel = self.apply(input_parallel, bias_)
if self.base_layer.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output_ = output_parallel
if not self.base_layer.skip_bias_add:
output = (
output_ + self.base_layer.bias
if self.base_layer.bias is not None
else output_
)
output_bias = None
else:
output = output_
output_bias = self.base_layer.bias
# Bias was already added by rank 0 in apply(), no need to add again
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
output = output_
if not self.base_layer.return_bias:
return output
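Note (annotation, not part of the diff): the comment above encodes a real tensor-parallel pitfall. A standalone numeric sketch with toy shapes, not vLLM code, showing why the bias may only be fused on rank 0 when reduce_results=True:

```python
import torch

tp_size, in_dim, out_dim = 2, 4, 2
shard = in_dim // tp_size
x = torch.ones(1, in_dim)                                  # full input, split column-wise
w = [torch.ones(shard, out_dim) for _ in range(tp_size)]   # per-rank weight shards
bias = torch.full((out_dim,), 0.5)

# Bias fused on every rank: the all-reduce (a sum) counts it tp_size times.
wrong = sum(
    x[:, r * shard : (r + 1) * shard] @ w[r] + bias for r in range(tp_size)
)
# Bias fused on rank 0 only, as in the change above.
right = sum(
    x[:, r * shard : (r + 1) * shard] @ w[r] + (bias if r == 0 else 0)
    for r in range(tp_size)
)
assert torch.allclose(wrong, right + (tp_size - 1) * bias)
```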

View File

@@ -2,17 +2,24 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from enum import Enum
import torch
import torch.nn as nn
class LoRAMappingType(Enum):
LANGUAGE = 1
TOWER = 2
CONNECTOR = 3
@dataclass
class LoRAMapping:
index_mapping: tuple[int, ...]
prompt_mapping: tuple[int, ...]
is_prefill: bool = False
is_mm_input: bool = False
type: LoRAMappingType = LoRAMappingType.LANGUAGE
def __post_init__(self):
self.index_mapping = tuple(self.index_mapping)
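Note (annotation, not part of the diff): an illustrative construction with made-up values, two prefill requests of 3 and 2 tokens served by LoRA slots 1 and 2; the new type field tells the manager which punica wrapper should consume the mapping:

```python
from vllm.lora.layers import LoRAMapping, LoRAMappingType

mapping = LoRAMapping(
    index_mapping=(1, 1, 1, 2, 2),  # one LoRA id per token
    prompt_mapping=(1, 2),          # one LoRA id per request
    is_prefill=True,
    type=LoRAMappingType.TOWER,
)
```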

View File

@@ -13,7 +13,7 @@ from torch import nn
from vllm.config.lora import LoRAConfig, ModelConfig
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping, LoRAMappingType
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import PunicaWrapperBase, get_punica_wrapper
@@ -374,50 +374,7 @@ class LoRAModelManager:
f" {self.model.__class__.__name__}."
self.packed_modules_mapping = get_packed_modules_mapping(self.model)
# Used to indicate whether the model is a multimodal model
self.supports_mm: bool = (
supports_multimodal(self.model)
# In case the model only supports LoRA for
# text modules (e.g. ChatGLM)
and hasattr(self.model, "get_mm_mapping")
)
# For v0 compatibility
if model_config is not None:
self.mm_registry = MULTIMODAL_REGISTRY
self.info = self.mm_registry.create_processor(model_config).info
self.supports_mm_lora = self.supports_mm and hasattr(
self.info, "get_num_mm_encoder_tokens"
)
else:
self.supports_mm_lora = False
if self.supports_mm_lora:
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
self.mm_config = model_config.multimodal_config
# limit_per_prompt: int = max(
# self.info.get_allowed_mm_limits().values())
limit_per_prompt = 5 # TODO
# For vision tower
# max_num_batched_tokens = encoder_budget
# max_batches = max_batches * limit_per_prompt
self.mm_punica_wrapper_mapping = {
name: get_punica_wrapper(
self.info.get_num_mm_encoder_tokens(max_num_batched_tokens),
max_batches=self.max_num_seqs * limit_per_prompt,
device=self.device,
max_loras=self.lora_config.max_loras,
)
for name in self.mm_mapping.tower_model
}
# For language model
self.mm_punica_wrapper_mapping.update(
{self.mm_mapping.language_model[0]: self.punica_wrapper}
)
# TODO Connector is not supported at the moment.
self.mm_punica_wrapper_mapping.update(
{name: None for name in self.mm_mapping.connector}
)
self._init_multimodal_config(model_config)
self.is_pooling_model = is_pooling_model(self.model)
self.is_moe_model = is_moe_model(self.model)
self.packed_modules: dict[str, list[str]] = {}
@@ -427,6 +384,72 @@ class LoRAModelManager:
self._create_lora_modules()
self.model.lora_manager = self
def _init_multimodal_config(self, model_config):
# Used to indicate whether the model is a multimodal model
self.supports_mm: bool = (
supports_multimodal(self.model)
# In case the model only supports LoRA for
# text modules (e.g. ChatGLM)
and hasattr(self.model, "get_mm_mapping")
)
# For v0 compatibility
self.supports_mm_lora = False
if model_config is not None:
self.mm_registry = MULTIMODAL_REGISTRY
self.info = self.mm_registry.create_processor(model_config).info
self.supports_mm_lora = self.supports_mm and hasattr(
self.info, "get_num_mm_encoder_tokens"
)
if not self.supports_mm_lora:
return
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
self.mm_config = model_config.multimodal_config
limit_per_prompt: int = max(self.info.get_allowed_mm_limits().values())
# For vision tower
num_encoder_tokens = self.info.get_num_mm_encoder_tokens(
self.max_num_batched_tokens
)
self.mm_punica_wrapper_mapping = {
name: get_punica_wrapper(
num_encoder_tokens,
max_batches=self.max_num_seqs * limit_per_prompt,
device=self.device,
max_loras=self.lora_config.max_loras,
)
for name in self.mm_mapping.tower_model
}
# For language model
self.mm_punica_wrapper_mapping.update(
{self.mm_mapping.language_model[0]: self.punica_wrapper}
)
# Use wrapper for connector if present.
if self.mm_mapping.connector:
if hasattr(self.info, "get_num_mm_connector_tokens"):
connector_tokens = self.info.get_num_mm_connector_tokens(
num_encoder_tokens
)
connector_punica_wrapper = get_punica_wrapper(
connector_tokens,
max_batches=self.max_num_seqs * limit_per_prompt,
device=self.device,
max_loras=self.lora_config.max_loras,
)
self.mm_punica_wrapper_mapping.update(
{
name: connector_punica_wrapper
for name in self.mm_mapping.connector
}
)
else:
logger.warning_once(
"Connector LoRA support disabled: model does not implement "
"get_num_mm_connector_tokens(). This method is required to "
"determine the connector's token budget for LoRA operations."
)
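Note (annotation, not part of the diff): connector LoRA is opt-in per model. A self-contained sketch with stand-in objects, not vLLM classes, of the rule implemented above:

```python
class _InfoWithConnector:
    # Stand-in for a processor info object that opts into connector LoRA.
    def get_num_mm_connector_tokens(self, num_encoder_tokens: int) -> int:
        return num_encoder_tokens // 4  # e.g. merge_size**2 == 4

def connector_token_budget(info, num_encoder_tokens: int) -> int | None:
    # Mirrors the hasattr() gate above: no method, no connector wrapper.
    if hasattr(info, "get_num_mm_connector_tokens"):
        return info.get_num_mm_connector_tokens(num_encoder_tokens)
    return None  # connector LoRA support disabled

assert connector_token_budget(_InfoWithConnector(), 1024) == 256
assert connector_token_budget(object(), 1024) is None
```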
def __len__(self) -> int:
return len(self._registered_adapters)
@@ -499,35 +522,27 @@
) # type: ignore
def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
# update lora states
if not self.supports_mm_lora:
self.punica_wrapper.update_metadata(
mapping,
self.lora_index_to_id,
self.lora_slots + 1,
self.vocab_size,
self.lora_config.lora_extra_vocab_size,
)
elif mapping.is_mm_input:
self.mm_punica_wrapper_mapping[
self.mm_mapping.tower_model[0]
].update_metadata(
mapping,
self.lora_index_to_id,
self.lora_slots + 1,
self.vocab_size,
self.lora_config.lora_extra_vocab_size,
)
else:
self.mm_punica_wrapper_mapping[
self.mm_mapping.language_model[0]
].update_metadata(
mapping,
self.lora_index_to_id,
self.lora_slots + 1,
self.vocab_size,
self.lora_config.lora_extra_vocab_size,
)
# Default to the main language model wrapper
target_wrapper = self.punica_wrapper
if self.supports_mm_lora:
if mapping.type == LoRAMappingType.TOWER:
target_name = self.mm_mapping.tower_model[0]
target_wrapper = self.mm_punica_wrapper_mapping[target_name]
elif mapping.type == LoRAMappingType.CONNECTOR:
target_name = self.mm_mapping.connector[0]
target_wrapper = self.mm_punica_wrapper_mapping[target_name]
else:
target_name = self.mm_mapping.language_model[0]
target_wrapper = self.mm_punica_wrapper_mapping[target_name]
target_wrapper.update_metadata(
mapping,
self.lora_index_to_id,
self.lora_slots + 1,
self.vocab_size,
self.lora_config.lora_extra_vocab_size,
)
def remove_all_adapters(self):
"""Remove all LoRAModels from the manager."""
@@ -548,15 +563,6 @@
continue
if not self._match_target_modules(module_name):
continue
# A temporary approach for multimodal models to support LoRA
# TODO: Remove this restriction
if self._filter_unsupported_mm_module(module_name):
logger.warning(
"Regarding multimodal models, vLLM currently only supports "
"adding LoRA to language model, %s will be ignored.",
module_name,
)
continue
parts = module_name.split(".")[-1]
packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
new_module = replace_submodule(
@@ -604,6 +610,7 @@
if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA):
continue
self.register_module(module_name, new_module)
self._register_packed_modules(module_name)
# All lora layers share the same punica_wrapper based on reference.
if self.supports_mm_lora:

View File

@@ -1040,6 +1040,25 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
for modality in ("image", "video")
]
def get_num_mm_encoder_tokens(
self,
num_image_tokens: int,
) -> int:
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size
return num_image_tokens * merge_size**2
def get_num_mm_connector_tokens(
self,
num_vision_tokens: int,
) -> int:
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size
return num_vision_tokens // merge_size**2
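Note (annotation, not part of the diff): a worked example of the two token budgets, assuming spatial_merge_size = 2 as in the Qwen2-VL family. The vision tower runs on pre-merge patch tokens while the connector (patch merger) emits the post-merge tokens the language model consumes, so the two methods differ by merge_size**2 and round-trip exactly:

```python
merge_size = 2
num_image_tokens = 256                              # LM-side placeholder tokens
encoder_tokens = num_image_tokens * merge_size**2   # 1024 pre-merge tower tokens
connector_tokens = encoder_tokens // merge_size**2  # back to 256 merged tokens
assert connector_tokens == num_image_tokens
```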
@MULTIMODAL_REGISTRY.register_processor(
Qwen2_5_VLMultiModalProcessor,

View File

@@ -1094,6 +1094,15 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
return num_image_tokens * merge_size**2
def get_num_mm_connector_tokens(
self,
num_vision_tokens: int,
) -> int:
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size
return num_vision_tokens // merge_size**2
class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:

View File

@@ -43,6 +43,7 @@ from vllm.distributed.parallel_state import (
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping, LoRAMappingType
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@@ -1689,7 +1690,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _batch_mm_kwargs_from_scheduler(
self,
scheduler_output: "SchedulerOutput",
) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]:
) -> tuple[
list[MultiModalKwargsItem],
list[tuple[str, PlaceholderRange]],
list[str],
]:
"""Batch multimodal kwargs from scheduled encoder inputs.
Args:
@@ -1697,17 +1702,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
inputs.
Returns:
A tuple of (mm_kwargs, req_ids_pos) where:
A tuple of (mm_kwargs, mm_hashes_pos, req_ids) where:
- mm_kwargs: List of multimodal kwargs items to be batched
- mm_hashes_pos: List of (mm_hash, position_info) tuples
- req_ids: List of request IDs for each encoder input
"""
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
return [], []
return [], [], []
# Batch the multi-modal inputs.
mm_kwargs = list[MultiModalKwargsItem]()
# list of tuple (mm_hash, position_info)
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
# list of request IDs for each encoder input
req_ids = list[str]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
@@ -1716,13 +1724,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_hash = mm_feature.identifier
mm_kwargs.append(mm_feature.data)
mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
req_ids.append(req_id)
return mm_kwargs, mm_hashes_pos
return mm_kwargs, mm_hashes_pos, req_ids
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
# Batch the multi-modal inputs using the helper method.
mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler(
scheduler_output
mm_kwargs, mm_hashes_pos, encoder_req_ids = (
self._batch_mm_kwargs_from_scheduler(scheduler_output)
)
if not mm_kwargs:
@@ -1739,16 +1748,62 @@
encoder_outputs = []
if self.lora_config and self.supports_mm_lora:
mm_tokens = [
self.info.get_num_mm_encoder_tokens(pos_info.length)
for _, pos_info in mm_hashes_pos
]
num_scheduled_tokens = np.array(mm_tokens, dtype=np.int32)
self.set_active_loras(
self.input_batch,
num_scheduled_tokens,
is_mm_input=True,
# Build LoRA mappings independently for encoder inputs
# (encoder batch structure is different from main batch)
prompt_lora_mapping = []
token_lora_mapping = []
lora_requests = set()
for req_id, (_, pos_info) in zip(encoder_req_ids, mm_hashes_pos):
req_idx = self.input_batch.req_id_to_index[req_id]
lora_id = int(self.input_batch.request_lora_mapping[req_idx])
num_tokens = self.info.get_num_mm_encoder_tokens(pos_info.length)
prompt_lora_mapping.append(lora_id)
token_lora_mapping.extend([lora_id] * num_tokens)
if lora_id > 0:
lora_request = self.input_batch.lora_id_to_lora_request.get(lora_id)
if lora_request is not None:
lora_requests.add(lora_request)
lora_mapping = LoRAMapping(
tuple(token_lora_mapping),
tuple(prompt_lora_mapping),
is_prefill=True,
type=LoRAMappingType.TOWER,
)
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
if hasattr(self.info, "get_num_mm_connector_tokens"):
num_post_op_tokens = []
for _, pos_info in mm_hashes_pos:
mm_token_count = self.info.get_num_mm_encoder_tokens(
pos_info.length
)
post_op_count = self.info.get_num_mm_connector_tokens(
mm_token_count
)
num_post_op_tokens.append(post_op_count)
lora_ids = np.array(
self.lora_manager._adapter_manager._last_mapping.prompt_mapping,
dtype=np.int32,
)
post_op_counts_np = np.array(num_post_op_tokens, dtype=np.int32)
new_token_indices = lora_ids.repeat(post_op_counts_np)
connector_mapping = LoRAMapping(
index_mapping=tuple(new_token_indices.tolist()),
prompt_mapping=self.lora_manager._adapter_manager._last_mapping.prompt_mapping,
is_prefill=self.lora_manager._adapter_manager._last_mapping.is_prefill,
type=LoRAMappingType.CONNECTOR,
)
self.lora_manager.set_active_adapters(
lora_requests,
connector_mapping,
)
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
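Note (annotation, not part of the diff): the connector mapping above expands per-input LoRA ids into per-token indices with numpy's repeat. A standalone sketch with made-up numbers:

```python
import numpy as np

lora_ids = np.array([1, 2], dtype=np.int32)        # one id per encoder input
post_op_counts = np.array([3, 2], dtype=np.int32)  # connector tokens per input
token_indices = lora_ids.repeat(post_op_counts)    # element-wise repetition
assert token_indices.tolist() == [1, 1, 1, 2, 2]
```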
@@ -1898,7 +1953,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
inputs and formats them for the encoder-decoder model forward pass.
"""
# Batch the multi-modal inputs using the helper method.
mm_kwargs, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output)
mm_kwargs, _, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output)
if not mm_kwargs:
return {}

View File

@@ -5,6 +5,7 @@ Define LoRA functionality mixin for model runners.
"""
from contextlib import contextmanager
from typing import TypeAlias
import numpy as np
import torch
@@ -13,14 +14,14 @@ import torch.nn as nn
from vllm.config import ModelConfig, VllmConfig
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.layers import LoRAMapping, LoRAMappingType
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.model_executor.models import supports_lora
from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch
InputBatch = TPUInputBatch | GPUInputBatch
InputBatch: TypeAlias = TPUInputBatch | GPUInputBatch
logger = init_logger(__name__)
@@ -37,12 +38,6 @@ class LoRAModelRunnerMixin:
if not supports_lora(model):
raise ValueError(f"{model.__class__.__name__} does not support LoRA yet.")
if supports_multimodal(model):
logger.warning(
"Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model."
)
# Add LoRA Manager to the Model Runner
self.lora_manager = LRUCacheWorkerLoRAManager(
vllm_config,
@@ -57,7 +52,7 @@ class LoRAModelRunnerMixin:
prompt_lora_mapping: tuple[int, ...],
token_lora_mapping: tuple[int, ...],
lora_requests: set[LoRARequest],
is_mm_input: bool = False,
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
) -> None:
self._ensure_lora_enabled()
@@ -69,7 +64,7 @@ class LoRAModelRunnerMixin:
token_lora_mapping,
prompt_lora_mapping,
is_prefill=True,
is_mm_input=is_mm_input,
type=mapping_type,
)
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
@@ -81,7 +76,7 @@ class LoRAModelRunnerMixin:
self,
input_batch: InputBatch,
num_scheduled_tokens: np.ndarray,
is_mm_input: bool = False,
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
) -> None:
prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs
token_lora_mapping: tuple[int, ...] # of size np.sum(num_scheduled_tokens)
@@ -90,7 +85,7 @@ class LoRAModelRunnerMixin:
input_batch.make_lora_inputs(num_scheduled_tokens)
)
return self._set_active_loras(
prompt_lora_mapping, token_lora_mapping, lora_requests, is_mm_input
prompt_lora_mapping, token_lora_mapping, lora_requests, mapping_type
)
@contextmanager
@ -134,7 +129,7 @@ class LoRAModelRunnerMixin:
self,
lora_config: LoRAConfig | None,
num_scheduled_tokens: np.ndarray,
is_mm_input: bool = False,
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
):
if lora_config is None:
yield
@@ -166,7 +161,7 @@ class LoRAModelRunnerMixin:
tuple(prompt_lora_mapping),
tuple(token_lora_mapping),
lora_requests,
is_mm_input,
mapping_type,
)
yield
@@ -177,12 +172,12 @@ class LoRAModelRunnerMixin:
lora_config: LoRAConfig | None,
num_scheduled_tokens: np.ndarray,
remove_lora: bool = True,
is_mm_input: bool = False,
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
):
with (
self.maybe_setup_dummy_loras(lora_config, remove_lora),
self.maybe_select_dummy_loras(
lora_config, num_scheduled_tokens, is_mm_input
lora_config, num_scheduled_tokens, mapping_type
),
):
yield
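Note (annotation, not part of the diff): a hypothetical call site; the names follow the diff, but runner and lora_config are assumed to exist. It shows how a dummy run can now target a specific mapping type instead of the old is_mm_input flag:

```python
import numpy as np

with runner.maybe_dummy_run_with_lora(
    lora_config,
    num_scheduled_tokens=np.array([8, 8], dtype=np.int32),
    mapping_type=LoRAMappingType.TOWER,
):
    ...  # dummy encoder forward under tower LoRA mappings
```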

View File

@@ -32,7 +32,8 @@ from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMappingType
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.tpu import TPUModelLoader
@@ -1422,11 +1423,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self._hidden_states_dtype = out.dtype
def _set_active_loras(
self, prompt_lora_mapping, token_lora_mapping, lora_requests
self,
prompt_lora_mapping: tuple[int, ...],
token_lora_mapping: tuple[int, ...],
lora_requests: set[LoRARequest],
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
) -> None:
torch_xla.sync(wait=False) # Captures input updates
super()._set_active_loras(
prompt_lora_mapping, token_lora_mapping, lora_requests
prompt_lora_mapping, token_lora_mapping, lora_requests, mapping_type
)
torch_xla.sync(wait=False) # Captures metadata updates