mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-21 01:35:52 +08:00
[feat] add connector support
Signed-off-by: prashanth058 <prashanth.dannamaneni@uipath.com>
This commit is contained in:
parent
5c156c9f09
commit
a69bde7e8f
@ -225,6 +225,21 @@ def qwen25vl_lora_files():
|
||||
return snapshot_download(repo_id="jeeejeee/qwen25-vl-lora-pokemon")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def qwen2vl_language_lora_files():
|
||||
return snapshot_download(repo_id="prashanth058/qwen2vl-flickr-lora-language")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def qwen2vl_vision_tower_connector_lora_files():
|
||||
return snapshot_download(repo_id="prashanth058/qwen2vl-flickr-lora-tower-connector")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def qwen2vl_vision_tower_lora_files():
|
||||
return snapshot_download(repo_id="prashanth058/qwen2vl-flickr-lora-tower")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tinyllama_lora_files():
|
||||
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
|
||||
|
||||
@ -79,7 +79,6 @@ class Qwen2VLTester:
|
||||
lora_request = LoRARequest(str(lora_id), lora_id, self.config.lora_path)
|
||||
outputs = self.llm.generate(inputs, sampling_params, lora_request=lora_request)
|
||||
generated_texts = [output.outputs[0].text.strip() for output in outputs]
|
||||
|
||||
# Validate outputs
|
||||
for generated, expected in zip(generated_texts, expected_outputs):
|
||||
assert expected.startswith(generated), (
|
||||
@ -130,6 +129,22 @@ EXPECTED_OUTPUTS = [
|
||||
"A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.", # noqa: E501
|
||||
]
|
||||
|
||||
EXPECTED_OUTPUTS_LANGUAGE = [
|
||||
"A stop sign is shown in an Asian city, with buildings and a car in the "
|
||||
"background.",
|
||||
"The Tokyo Skytree can be seen behind the pink blossoms of the cherry trees.",
|
||||
]
|
||||
|
||||
EXPECTED_OUTPUTS_VISION = [
|
||||
"A stop sign in front of oriental buildings.",
|
||||
"A tree with pink flowers in front of it and a blue sky behind the flowers.",
|
||||
]
|
||||
|
||||
EXPECTED_OUTPUTS_VISION_NO_CONNECTOR = [
|
||||
"A stop sign is located on the street of a Chinese neighborhood.",
|
||||
"A closeup shot of the Tokyo Skytree with pink flowers in the foreground.",
|
||||
]
|
||||
|
||||
# NOTE - beam search .text contains the whole text
|
||||
EXPECTED_BEAM_SEARCH_OUTPUTS = [
|
||||
[
|
||||
@ -190,3 +205,64 @@ def test_qwen25vl_lora(qwen25vl_lora_files):
|
||||
# Test with different LoRA IDs
|
||||
for lora_id in [1, 2]:
|
||||
tester.run_test(TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS, lora_id=lora_id)
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
current_platform.is_rocm(),
|
||||
reason="Qwen2-VL dependency xformers incompatible with ROCm",
|
||||
)
|
||||
def test_qwen2vl_language_lora(qwen2vl_language_lora_files):
|
||||
"""
|
||||
Test language-only LoRA adapter.
|
||||
"""
|
||||
config = TestConfig(
|
||||
model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_language_lora_files
|
||||
)
|
||||
tester = Qwen2VLTester(config)
|
||||
for lora_id in [1, 2]:
|
||||
tester.run_test(
|
||||
TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS_LANGUAGE, lora_id=lora_id
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
current_platform.is_rocm(),
|
||||
reason="Qwen2-VL dependency xformers incompatible with ROCm",
|
||||
)
|
||||
def test_qwen2vl_vision_lora(qwen2vl_vision_tower_connector_lora_files):
|
||||
"""
|
||||
Test vision tower + connector LoRA adapter.
|
||||
"""
|
||||
config = TestConfig(
|
||||
model_path=QWEN2VL_MODEL_PATH,
|
||||
lora_path=qwen2vl_vision_tower_connector_lora_files,
|
||||
)
|
||||
tester = Qwen2VLTester(config)
|
||||
for lora_id in [1, 2]:
|
||||
tester.run_test(
|
||||
TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS_VISION, lora_id=lora_id
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
current_platform.is_rocm(),
|
||||
reason="Qwen2-VL dependency xformers incompatible with ROCm",
|
||||
)
|
||||
def test_qwen2vl_vision_no_connector_lora(
|
||||
qwen2vl_vision_tower_lora_files,
|
||||
):
|
||||
"""
|
||||
Test vision tower only LoRA adapter.
|
||||
|
||||
"""
|
||||
config = TestConfig(
|
||||
model_path=QWEN2VL_MODEL_PATH,
|
||||
lora_path=qwen2vl_vision_tower_lora_files,
|
||||
)
|
||||
tester = Qwen2VLTester(config)
|
||||
for lora_id in [1, 2]:
|
||||
tester.run_test(
|
||||
TEST_IMAGES,
|
||||
expected_outputs=EXPECTED_OUTPUTS_VISION_NO_CONNECTOR,
|
||||
lora_id=lora_id,
|
||||
)
|
||||
|
||||
@ -17,7 +17,7 @@ from vllm.lora.layers.row_parallel_linear import (
|
||||
RowParallelLinearWithLoRA,
|
||||
RowParallelLinearWithShardedLoRA,
|
||||
)
|
||||
from vllm.lora.layers.utils import LoRAMapping
|
||||
from vllm.lora.layers.utils import LoRAMapping, LoRAMappingType
|
||||
from vllm.lora.layers.vocal_parallel_embedding import VocabParallelEmbeddingWithLoRA
|
||||
|
||||
__all__ = [
|
||||
@ -36,4 +36,5 @@ __all__ = [
|
||||
"RowParallelLinearWithShardedLoRA",
|
||||
"ReplicatedLinearWithLoRA",
|
||||
"LoRAMapping",
|
||||
"LoRAMappingType",
|
||||
]
|
||||
|
||||
@ -63,22 +63,25 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.apply(input_parallel)
|
||||
# Only fuse bias add into GEMM for rank 0 (matches base
|
||||
# RowParallelLinear behavior). This ensures bias will not get
|
||||
# added more than once in TP>1 case and matches the numerical
|
||||
# behavior of the unwrapped layer
|
||||
bias_ = (
|
||||
None
|
||||
if (self.tp_rank > 0 or self.base_layer.skip_bias_add)
|
||||
else self.base_layer.bias
|
||||
)
|
||||
output_parallel = self.apply(input_parallel, bias_)
|
||||
|
||||
if self.base_layer.reduce_results and self.tp_size > 1:
|
||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output_ = output_parallel
|
||||
|
||||
if not self.base_layer.skip_bias_add:
|
||||
output = (
|
||||
output_ + self.base_layer.bias
|
||||
if self.base_layer.bias is not None
|
||||
else output_
|
||||
)
|
||||
output_bias = None
|
||||
else:
|
||||
output = output_
|
||||
output_bias = self.base_layer.bias
|
||||
# Bias was already added by rank 0 in apply(), no need to add again
|
||||
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
|
||||
output = output_
|
||||
|
||||
if not self.base_layer.return_bias:
|
||||
return output
|
||||
|
||||
@ -2,17 +2,24 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class LoRAMappingType(Enum):
|
||||
LANGUAGE = 1
|
||||
TOWER = 2
|
||||
CONNECTOR = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAMapping:
|
||||
index_mapping: tuple[int, ...]
|
||||
prompt_mapping: tuple[int, ...]
|
||||
is_prefill: bool = False
|
||||
is_mm_input: bool = False
|
||||
type: LoRAMappingType = LoRAMappingType.LANGUAGE
|
||||
|
||||
def __post_init__(self):
|
||||
self.index_mapping = tuple(self.index_mapping)
|
||||
|
||||
@ -13,7 +13,7 @@ from torch import nn
|
||||
|
||||
from vllm.config.lora import LoRAConfig, ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
|
||||
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping, LoRAMappingType
|
||||
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
|
||||
from vllm.lora.peft_helper import PEFTHelper
|
||||
from vllm.lora.punica_wrapper import PunicaWrapperBase, get_punica_wrapper
|
||||
@ -374,50 +374,7 @@ class LoRAModelManager:
|
||||
f" {self.model.__class__.__name__}."
|
||||
|
||||
self.packed_modules_mapping = get_packed_modules_mapping(self.model)
|
||||
# Used to indicate whether the model is a multimodal model
|
||||
self.supports_mm: bool = (
|
||||
supports_multimodal(self.model)
|
||||
# In case the model only supports LoRA for
|
||||
# text modules (e.g. ChatGLM)
|
||||
and hasattr(self.model, "get_mm_mapping")
|
||||
)
|
||||
# For v0 compatibility
|
||||
if model_config is not None:
|
||||
self.mm_registry = MULTIMODAL_REGISTRY
|
||||
self.info = self.mm_registry.create_processor(model_config).info
|
||||
self.supports_mm_lora = self.supports_mm and hasattr(
|
||||
self.info, "get_num_mm_encoder_tokens"
|
||||
)
|
||||
else:
|
||||
self.supports_mm_lora = False
|
||||
if self.supports_mm_lora:
|
||||
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
|
||||
self.mm_config = model_config.multimodal_config
|
||||
# limit_per_prompt: int = max(
|
||||
# self.info.get_allowed_mm_limits().values())
|
||||
limit_per_prompt = 5 # TODO
|
||||
|
||||
# For vision tower
|
||||
# max_num_batched_tokens = encoder_budget
|
||||
# max_batches = max_batches * limit_per_prompt
|
||||
self.mm_punica_wrapper_mapping = {
|
||||
name: get_punica_wrapper(
|
||||
self.info.get_num_mm_encoder_tokens(max_num_batched_tokens),
|
||||
max_batches=self.max_num_seqs * limit_per_prompt,
|
||||
device=self.device,
|
||||
max_loras=self.lora_config.max_loras,
|
||||
)
|
||||
for name in self.mm_mapping.tower_model
|
||||
}
|
||||
# For language model
|
||||
self.mm_punica_wrapper_mapping.update(
|
||||
{self.mm_mapping.language_model[0]: self.punica_wrapper}
|
||||
)
|
||||
# TODO Connector is not supported at the moment.
|
||||
self.mm_punica_wrapper_mapping.update(
|
||||
{name: None for name in self.mm_mapping.connector}
|
||||
)
|
||||
|
||||
self._init_multimodal_config(model_config)
|
||||
self.is_pooling_model = is_pooling_model(self.model)
|
||||
self.is_moe_model = is_moe_model(self.model)
|
||||
self.packed_modules: dict[str, list[str]] = {}
|
||||
@ -427,6 +384,72 @@ class LoRAModelManager:
|
||||
self._create_lora_modules()
|
||||
self.model.lora_manager = self
|
||||
|
||||
def _init_multimodal_config(self, model_config):
|
||||
# Used to indicate whether the model is a multimodal model
|
||||
self.supports_mm: bool = (
|
||||
supports_multimodal(self.model)
|
||||
# In case the model only supports LoRA for
|
||||
# text modules (e.g. ChatGLM)
|
||||
and hasattr(self.model, "get_mm_mapping")
|
||||
)
|
||||
# For v0 compatibility
|
||||
self.supports_mm_lora = False
|
||||
if model_config is not None:
|
||||
self.mm_registry = MULTIMODAL_REGISTRY
|
||||
self.info = self.mm_registry.create_processor(model_config).info
|
||||
self.supports_mm_lora = self.supports_mm and hasattr(
|
||||
self.info, "get_num_mm_encoder_tokens"
|
||||
)
|
||||
|
||||
if not self.supports_mm_lora:
|
||||
return
|
||||
|
||||
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
|
||||
self.mm_config = model_config.multimodal_config
|
||||
limit_per_prompt: int = max(self.info.get_allowed_mm_limits().values())
|
||||
|
||||
# For vision tower
|
||||
num_encoder_tokens = self.info.get_num_mm_encoder_tokens(
|
||||
self.max_num_batched_tokens
|
||||
)
|
||||
self.mm_punica_wrapper_mapping = {
|
||||
name: get_punica_wrapper(
|
||||
num_encoder_tokens,
|
||||
max_batches=self.max_num_seqs * limit_per_prompt,
|
||||
device=self.device,
|
||||
max_loras=self.lora_config.max_loras,
|
||||
)
|
||||
for name in self.mm_mapping.tower_model
|
||||
}
|
||||
# For language model
|
||||
self.mm_punica_wrapper_mapping.update(
|
||||
{self.mm_mapping.language_model[0]: self.punica_wrapper}
|
||||
)
|
||||
# Use wrapper for connector if present.
|
||||
if self.mm_mapping.connector:
|
||||
if hasattr(self.info, "get_num_mm_connector_tokens"):
|
||||
connector_tokens = self.info.get_num_mm_connector_tokens(
|
||||
num_encoder_tokens
|
||||
)
|
||||
connector_punica_wrapper = get_punica_wrapper(
|
||||
connector_tokens,
|
||||
max_batches=self.max_num_seqs * limit_per_prompt,
|
||||
device=self.device,
|
||||
max_loras=self.lora_config.max_loras,
|
||||
)
|
||||
self.mm_punica_wrapper_mapping.update(
|
||||
{
|
||||
name: connector_punica_wrapper
|
||||
for name in self.mm_mapping.connector
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Connector LoRA support disabled: model does not implement "
|
||||
"get_num_mm_connector_tokens(). This method is required to "
|
||||
"determine the connector's token budget for LoRA operations."
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._registered_adapters)
|
||||
|
||||
@ -499,35 +522,27 @@ class LoRAModelManager:
|
||||
) # type: ignore
|
||||
|
||||
def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
|
||||
# update lora states
|
||||
if not self.supports_mm_lora:
|
||||
self.punica_wrapper.update_metadata(
|
||||
mapping,
|
||||
self.lora_index_to_id,
|
||||
self.lora_slots + 1,
|
||||
self.vocab_size,
|
||||
self.lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
elif mapping.is_mm_input:
|
||||
self.mm_punica_wrapper_mapping[
|
||||
self.mm_mapping.tower_model[0]
|
||||
].update_metadata(
|
||||
mapping,
|
||||
self.lora_index_to_id,
|
||||
self.lora_slots + 1,
|
||||
self.vocab_size,
|
||||
self.lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
else:
|
||||
self.mm_punica_wrapper_mapping[
|
||||
self.mm_mapping.language_model[0]
|
||||
].update_metadata(
|
||||
mapping,
|
||||
self.lora_index_to_id,
|
||||
self.lora_slots + 1,
|
||||
self.vocab_size,
|
||||
self.lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
# Default to the main language model wrapper
|
||||
target_wrapper = self.punica_wrapper
|
||||
|
||||
if self.supports_mm_lora:
|
||||
if mapping.type == LoRAMappingType.TOWER:
|
||||
target_name = self.mm_mapping.tower_model[0]
|
||||
target_wrapper = self.mm_punica_wrapper_mapping[target_name]
|
||||
elif mapping.type == LoRAMappingType.CONNECTOR:
|
||||
target_name = self.mm_mapping.connector[0]
|
||||
target_wrapper = self.mm_punica_wrapper_mapping[target_name]
|
||||
else:
|
||||
target_name = self.mm_mapping.language_model[0]
|
||||
target_wrapper = self.mm_punica_wrapper_mapping[target_name]
|
||||
|
||||
target_wrapper.update_metadata(
|
||||
mapping,
|
||||
self.lora_index_to_id,
|
||||
self.lora_slots + 1,
|
||||
self.vocab_size,
|
||||
self.lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
|
||||
def remove_all_adapters(self):
|
||||
"""Remove all LoRAModels from the manager."""
|
||||
@ -548,15 +563,6 @@ class LoRAModelManager:
|
||||
continue
|
||||
if not self._match_target_modules(module_name):
|
||||
continue
|
||||
# A temporary approach for multimodal models to support LoRA
|
||||
# TODO: Remove this restriction
|
||||
if self._filter_unsupported_mm_module(module_name):
|
||||
logger.warning(
|
||||
"Regarding multimodal models, vLLM currently only supports "
|
||||
"adding LoRA to language model, %s will be ignored.",
|
||||
module_name,
|
||||
)
|
||||
continue
|
||||
parts = module_name.split(".")[-1]
|
||||
packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
|
||||
new_module = replace_submodule(
|
||||
@ -604,6 +610,7 @@ class LoRAModelManager:
|
||||
if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA):
|
||||
continue
|
||||
self.register_module(module_name, new_module)
|
||||
|
||||
self._register_packed_modules(module_name)
|
||||
# All lora layers share the same punica_wrapper based on reference.
|
||||
if self.supports_mm_lora:
|
||||
|
||||
@ -1040,6 +1040,25 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
|
||||
for modality in ("image", "video")
|
||||
]
|
||||
|
||||
def get_num_mm_encoder_tokens(
|
||||
self,
|
||||
num_image_tokens: int,
|
||||
) -> int:
|
||||
hf_config = self.get_hf_config()
|
||||
vision_config = hf_config.vision_config
|
||||
merge_size = vision_config.spatial_merge_size
|
||||
|
||||
return num_image_tokens * merge_size**2
|
||||
|
||||
def get_num_mm_connector_tokens(
|
||||
self,
|
||||
num_vision_tokens: int,
|
||||
) -> int:
|
||||
hf_config = self.get_hf_config()
|
||||
vision_config = hf_config.vision_config
|
||||
merge_size = vision_config.spatial_merge_size
|
||||
return num_vision_tokens // merge_size**2
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Qwen2_5_VLMultiModalProcessor,
|
||||
|
||||
@ -1094,6 +1094,15 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
return num_image_tokens * merge_size**2
|
||||
|
||||
def get_num_mm_connector_tokens(
|
||||
self,
|
||||
num_vision_tokens: int,
|
||||
) -> int:
|
||||
hf_config = self.get_hf_config()
|
||||
vision_config = hf_config.vision_config
|
||||
merge_size = vision_config.spatial_merge_size
|
||||
return num_vision_tokens // merge_size**2
|
||||
|
||||
|
||||
class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
|
||||
@ -43,6 +43,7 @@ from vllm.distributed.parallel_state import (
|
||||
)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import LoRAMapping, LoRAMappingType
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
@ -1689,7 +1690,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
def _batch_mm_kwargs_from_scheduler(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]:
|
||||
) -> tuple[
|
||||
list[MultiModalKwargsItem],
|
||||
list[tuple[str, PlaceholderRange]],
|
||||
list[str],
|
||||
]:
|
||||
"""Batch multimodal kwargs from scheduled encoder inputs.
|
||||
|
||||
Args:
|
||||
@ -1697,17 +1702,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
inputs.
|
||||
|
||||
Returns:
|
||||
A tuple of (mm_kwargs, req_ids_pos) where:
|
||||
A tuple of (mm_kwargs, mm_hashes_pos, req_ids) where:
|
||||
- mm_kwargs: List of multimodal kwargs items to be batched
|
||||
- mm_hashes_pos: List of (mm_hash, position_info) tuples
|
||||
- req_ids: List of request IDs for each encoder input
|
||||
"""
|
||||
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
||||
if not scheduled_encoder_inputs:
|
||||
return [], []
|
||||
return [], [], []
|
||||
# Batch the multi-modal inputs.
|
||||
mm_kwargs = list[MultiModalKwargsItem]()
|
||||
# list of tuple (mm_hash, position_info)
|
||||
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
|
||||
# list of request IDs for each encoder input
|
||||
req_ids = list[str]()
|
||||
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
||||
req_state = self.requests[req_id]
|
||||
|
||||
@ -1716,13 +1724,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
mm_hash = mm_feature.identifier
|
||||
mm_kwargs.append(mm_feature.data)
|
||||
mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
|
||||
req_ids.append(req_id)
|
||||
|
||||
return mm_kwargs, mm_hashes_pos
|
||||
return mm_kwargs, mm_hashes_pos, req_ids
|
||||
|
||||
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
|
||||
# Batch the multi-modal inputs using the helper method.
|
||||
mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler(
|
||||
scheduler_output
|
||||
mm_kwargs, mm_hashes_pos, encoder_req_ids = (
|
||||
self._batch_mm_kwargs_from_scheduler(scheduler_output)
|
||||
)
|
||||
|
||||
if not mm_kwargs:
|
||||
@ -1739,16 +1748,62 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
encoder_outputs = []
|
||||
|
||||
if self.lora_config and self.supports_mm_lora:
|
||||
mm_tokens = [
|
||||
self.info.get_num_mm_encoder_tokens(pos_info.length)
|
||||
for _, pos_info in mm_hashes_pos
|
||||
]
|
||||
num_scheduled_tokens = np.array(mm_tokens, dtype=np.int32)
|
||||
self.set_active_loras(
|
||||
self.input_batch,
|
||||
num_scheduled_tokens,
|
||||
is_mm_input=True,
|
||||
# Build LoRA mappings independently for encoder inputs
|
||||
# (encoder batch structure is different from main batch)
|
||||
prompt_lora_mapping = []
|
||||
token_lora_mapping = []
|
||||
lora_requests = set()
|
||||
|
||||
for req_id, (_, pos_info) in zip(encoder_req_ids, mm_hashes_pos):
|
||||
req_idx = self.input_batch.req_id_to_index[req_id]
|
||||
lora_id = int(self.input_batch.request_lora_mapping[req_idx])
|
||||
|
||||
num_tokens = self.info.get_num_mm_encoder_tokens(pos_info.length)
|
||||
prompt_lora_mapping.append(lora_id)
|
||||
token_lora_mapping.extend([lora_id] * num_tokens)
|
||||
|
||||
if lora_id > 0:
|
||||
lora_request = self.input_batch.lora_id_to_lora_request.get(lora_id)
|
||||
if lora_request is not None:
|
||||
lora_requests.add(lora_request)
|
||||
|
||||
lora_mapping = LoRAMapping(
|
||||
tuple(token_lora_mapping),
|
||||
tuple(prompt_lora_mapping),
|
||||
is_prefill=True,
|
||||
type=LoRAMappingType.TOWER,
|
||||
)
|
||||
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
|
||||
|
||||
if hasattr(self.info, "get_num_mm_connector_tokens"):
|
||||
num_post_op_tokens = []
|
||||
for _, pos_info in mm_hashes_pos:
|
||||
mm_token_count = self.info.get_num_mm_encoder_tokens(
|
||||
pos_info.length
|
||||
)
|
||||
post_op_count = self.info.get_num_mm_connector_tokens(
|
||||
mm_token_count
|
||||
)
|
||||
num_post_op_tokens.append(post_op_count)
|
||||
|
||||
lora_ids = np.array(
|
||||
self.lora_manager._adapter_manager._last_mapping.prompt_mapping,
|
||||
dtype=np.int32,
|
||||
)
|
||||
post_op_counts_np = np.array(num_post_op_tokens, dtype=np.int32)
|
||||
new_token_indices = lora_ids.repeat(post_op_counts_np)
|
||||
|
||||
connector_mapping = LoRAMapping(
|
||||
index_mapping=tuple(new_token_indices.tolist()),
|
||||
prompt_mapping=self.lora_manager._adapter_manager._last_mapping.prompt_mapping,
|
||||
is_prefill=self.lora_manager._adapter_manager._last_mapping.is_prefill,
|
||||
type=LoRAMappingType.CONNECTOR,
|
||||
)
|
||||
|
||||
self.lora_manager.set_active_adapters(
|
||||
lora_requests,
|
||||
connector_mapping,
|
||||
)
|
||||
|
||||
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
|
||||
mm_kwargs,
|
||||
@ -1898,7 +1953,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
inputs and formats them for the encoder-decoder model forward pass.
|
||||
"""
|
||||
# Batch the multi-modal inputs using the helper method.
|
||||
mm_kwargs, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output)
|
||||
mm_kwargs, _, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output)
|
||||
|
||||
if not mm_kwargs:
|
||||
return {}
|
||||
|
||||
@ -5,6 +5,7 @@ Define LoRA functionality mixin for model runners.
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import TypeAlias
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -13,14 +14,14 @@ import torch.nn as nn
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.lora.layers import LoRAMapping, LoRAMappingType
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
||||
from vllm.model_executor.models import supports_lora, supports_multimodal
|
||||
from vllm.model_executor.models import supports_lora
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
|
||||
from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch
|
||||
|
||||
InputBatch = TPUInputBatch | GPUInputBatch
|
||||
InputBatch: TypeAlias = TPUInputBatch | GPUInputBatch
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -37,12 +38,6 @@ class LoRAModelRunnerMixin:
|
||||
if not supports_lora(model):
|
||||
raise ValueError(f"{model.__class__.__name__} does not support LoRA yet.")
|
||||
|
||||
if supports_multimodal(model):
|
||||
logger.warning(
|
||||
"Regarding multimodal models, vLLM currently "
|
||||
"only supports adding LoRA to language model."
|
||||
)
|
||||
|
||||
# Add LoRA Manager to the Model Runner
|
||||
self.lora_manager = LRUCacheWorkerLoRAManager(
|
||||
vllm_config,
|
||||
@ -57,7 +52,7 @@ class LoRAModelRunnerMixin:
|
||||
prompt_lora_mapping: tuple[int, ...],
|
||||
token_lora_mapping: tuple[int, ...],
|
||||
lora_requests: set[LoRARequest],
|
||||
is_mm_input: bool = False,
|
||||
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
|
||||
) -> None:
|
||||
self._ensure_lora_enabled()
|
||||
|
||||
@ -69,7 +64,7 @@ class LoRAModelRunnerMixin:
|
||||
token_lora_mapping,
|
||||
prompt_lora_mapping,
|
||||
is_prefill=True,
|
||||
is_mm_input=is_mm_input,
|
||||
type=mapping_type,
|
||||
)
|
||||
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
|
||||
|
||||
@ -81,7 +76,7 @@ class LoRAModelRunnerMixin:
|
||||
self,
|
||||
input_batch: InputBatch,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
is_mm_input: bool = False,
|
||||
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
|
||||
) -> None:
|
||||
prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs
|
||||
token_lora_mapping: tuple[int, ...] # of size np.sum(num_scheduled_tokens)
|
||||
@ -90,7 +85,7 @@ class LoRAModelRunnerMixin:
|
||||
input_batch.make_lora_inputs(num_scheduled_tokens)
|
||||
)
|
||||
return self._set_active_loras(
|
||||
prompt_lora_mapping, token_lora_mapping, lora_requests, is_mm_input
|
||||
prompt_lora_mapping, token_lora_mapping, lora_requests, mapping_type
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
@ -134,7 +129,7 @@ class LoRAModelRunnerMixin:
|
||||
self,
|
||||
lora_config: LoRAConfig | None,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
is_mm_input: bool = False,
|
||||
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
|
||||
):
|
||||
if lora_config is None:
|
||||
yield
|
||||
@ -166,7 +161,7 @@ class LoRAModelRunnerMixin:
|
||||
tuple(prompt_lora_mapping),
|
||||
tuple(token_lora_mapping),
|
||||
lora_requests,
|
||||
is_mm_input,
|
||||
mapping_type,
|
||||
)
|
||||
|
||||
yield
|
||||
@ -177,12 +172,12 @@ class LoRAModelRunnerMixin:
|
||||
lora_config: LoRAConfig | None,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
remove_lora: bool = True,
|
||||
is_mm_input: bool = False,
|
||||
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
|
||||
):
|
||||
with (
|
||||
self.maybe_setup_dummy_loras(lora_config, remove_lora),
|
||||
self.maybe_select_dummy_loras(
|
||||
lora_config, num_scheduled_tokens, is_mm_input
|
||||
lora_config, num_scheduled_tokens, mapping_type
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
@ -32,7 +32,8 @@ from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import BaseLayerWithLoRA
|
||||
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMappingType
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.model_executor.model_loader.tpu import TPUModelLoader
|
||||
@ -1422,11 +1423,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self._hidden_states_dtype = out.dtype
|
||||
|
||||
def _set_active_loras(
|
||||
self, prompt_lora_mapping, token_lora_mapping, lora_requests
|
||||
self,
|
||||
prompt_lora_mapping: tuple[int, ...],
|
||||
token_lora_mapping: tuple[int, ...],
|
||||
lora_requests: set[LoRARequest],
|
||||
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
|
||||
) -> None:
|
||||
torch_xla.sync(wait=False) # Captures input updates
|
||||
super()._set_active_loras(
|
||||
prompt_lora_mapping, token_lora_mapping, lora_requests
|
||||
prompt_lora_mapping, token_lora_mapping, lora_requests, mapping_type
|
||||
)
|
||||
torch_xla.sync(wait=False) # Captures metadata updates
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user