Signed-off-by: bk-201 <joy25810@foxmail.com>
This commit is contained in:
bk-201 2025-10-13 02:14:36 +00:00
parent 236e0fe9fd
commit cb1a6f074a
8 changed files with 132 additions and 118 deletions

View File

@ -17,7 +17,6 @@ aiohttp==3.13.0
# aiohttp-cors # aiohttp-cors
# datasets # datasets
# fsspec # fsspec
# gpt-oss
# lm-eval # lm-eval
# ray # ray
aiohttp-cors==0.8.1 aiohttp-cors==0.8.1
@ -45,9 +44,7 @@ argcomplete==3.5.1
arrow==1.3.0 arrow==1.3.0
# via isoduration # via isoduration
async-timeout==5.0.1 async-timeout==5.0.1
# via # via redis
# aiohttp
# redis
attrs==24.2.0 attrs==24.2.0
# via # via
# aiohttp # aiohttp
@ -108,8 +105,6 @@ chardet==5.2.0
# via mbstrdecoder # via mbstrdecoder
charset-normalizer==3.4.0 charset-normalizer==3.4.0
# via requests # via requests
chz==0.3.0
# via gpt-oss
click==8.1.7 click==8.1.7
# via # via
# black # black
@ -180,9 +175,7 @@ distlib==0.3.9
dnspython==2.7.0 dnspython==2.7.0
# via email-validator # via email-validator
docker==7.1.0 docker==7.1.0
# via # via mlflow
# gpt-oss
# mlflow
docopt==0.6.2 docopt==0.6.2
# via num2words # via num2words
docstring-parser==0.17.0 docstring-parser==0.17.0
@ -208,9 +201,7 @@ eval-type-backport==0.2.2
evaluate==0.4.3 evaluate==0.4.3
# via lm-eval # via lm-eval
fastapi==0.116.1 fastapi==0.116.1
# via # via mlflow-skinny
# gpt-oss
# mlflow-skinny
fastparquet==2024.11.0 fastparquet==2024.11.0
# via genai-perf # via genai-perf
fastrlock==0.8.2 fastrlock==0.8.2
@ -285,8 +276,6 @@ google-resumable-media==2.7.2
# via google-cloud-storage # via google-cloud-storage
googleapis-common-protos==1.70.0 googleapis-common-protos==1.70.0
# via google-api-core # via google-api-core
gpt-oss==0.0.8
# via -r requirements/test.in
graphene==3.4.3 graphene==3.4.3
# via mlflow # via mlflow
graphql-core==3.2.6 graphql-core==3.2.6
@ -314,8 +303,6 @@ hf-xet==1.1.7
# via huggingface-hub # via huggingface-hub
hiredis==3.0.0 hiredis==3.0.0
# via tensorizer # via tensorizer
html2text==2025.4.15
# via gpt-oss
httpcore==1.0.6 httpcore==1.0.6
# via httpx # via httpx
httpx==0.27.2 httpx==0.27.2
@ -450,7 +437,6 @@ lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b772215
lxml==5.3.0 lxml==5.3.0
# via # via
# blobfile # blobfile
# gpt-oss
# sacrebleu # sacrebleu
mako==1.3.10 mako==1.3.10
# via alembic # via alembic
@ -620,8 +606,6 @@ omegaconf==2.3.0
# lightning # lightning
open-clip-torch==2.32.0 open-clip-torch==2.32.0
# via -r requirements/test.in # via -r requirements/test.in
openai-harmony==0.0.4
# via gpt-oss
opencensus==0.11.4 opencensus==0.11.4
# via ray # via ray
opencensus-context==0.1.3 opencensus-context==0.1.3
@ -793,12 +777,10 @@ pydantic==2.12.0
# albumentations # albumentations
# datamodel-code-generator # datamodel-code-generator
# fastapi # fastapi
# gpt-oss
# lightly # lightly
# mistral-common # mistral-common
# mlflow-skinny # mlflow-skinny
# mteb # mteb
# openai-harmony
# pydantic-extra-types # pydantic-extra-types
# ray # ray
pydantic-core==2.41.1 pydantic-core==2.41.1
@ -929,7 +911,6 @@ requests==2.32.3
# evaluate # evaluate
# google-api-core # google-api-core
# google-cloud-storage # google-cloud-storage
# gpt-oss
# huggingface-hub # huggingface-hub
# lightly # lightly
# lm-eval # lm-eval
@ -1072,8 +1053,6 @@ starlette-testclient==0.4.1
# via schemathesis # via schemathesis
statsmodels==0.14.4 statsmodels==0.14.4
# via genai-perf # via genai-perf
structlog==25.4.0
# via gpt-oss
sympy==1.13.3 sympy==1.13.3
# via # via
# einx # einx
@ -1088,15 +1067,12 @@ tcolorpy==0.1.6
# via pytablewriter # via pytablewriter
tenacity==9.1.2 tenacity==9.1.2
# via # via
# gpt-oss
# lm-eval # lm-eval
# plotly # plotly
tensorboardx==2.6.4 tensorboardx==2.6.4
# via lightning # via lightning
tensorizer==2.10.1 tensorizer==2.10.1
# via -r requirements/test.in # via -r requirements/test.in
termcolor==3.1.0
# via gpt-oss
terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e
# via -r requirements/test.in # via -r requirements/test.in
threadpoolctl==3.5.0 threadpoolctl==3.5.0
@ -1107,7 +1083,6 @@ tifffile==2025.3.30
# terratorch # terratorch
tiktoken==0.12.0 tiktoken==0.12.0
# via # via
# gpt-oss
# lm-eval # lm-eval
# mistral-common # mistral-common
timm==1.0.17 timm==1.0.17
@ -1121,12 +1096,9 @@ tokenizers==0.22.0
# via # via
# -r requirements/test.in # -r requirements/test.in
# transformers # transformers
toml==0.10.2
# via datamodel-code-generator
tomli==2.2.1 tomli==2.2.1
# via # via
# black # coverage
# pytest
# schemathesis # schemathesis
tomli-w==1.2.0 tomli-w==1.2.0
# via schemathesis # via schemathesis
@ -1235,7 +1207,6 @@ typing-extensions==4.15.0
# aiosignal # aiosignal
# albumentations # albumentations
# alembic # alembic
# chz
# fastapi # fastapi
# graphene # graphene
# huggingface-hub # huggingface-hub
@ -1275,9 +1246,7 @@ urllib3==2.2.3
# responses # responses
# tritonclient # tritonclient
uvicorn==0.35.0 uvicorn==0.35.0
# via # via mlflow-skinny
# gpt-oss
# mlflow-skinny
vector-quantize-pytorch==1.21.2 vector-quantize-pytorch==1.21.2
# via -r requirements/test.in # via -r requirements/test.in
virtualenv==20.31.2 virtualenv==20.31.2

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.layers.base import BaseLayerWithLoRA, PunicaWrapperBase from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.lora.layers.column_parallel_linear import ( from vllm.lora.layers.column_parallel_linear import (
ColumnParallelLinearWithLoRA, ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithShardedLoRA, ColumnParallelLinearWithShardedLoRA,
@ -36,5 +36,4 @@ __all__ = [
"RowParallelLinearWithShardedLoRA", "RowParallelLinearWithShardedLoRA",
"ReplicatedLinearWithLoRA", "ReplicatedLinearWithLoRA",
"LoRAMapping", "LoRAMapping",
"PunicaWrapperBase",
] ]

View File

@ -124,6 +124,9 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
) -> torch.Tensor: ) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias) output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
# Store original shape for later reshaping
original_shape = output.shape if output.ndim == 3 else None
# In transformers backend, x and output have extra batch dimension like # In transformers backend, x and output have extra batch dimension like
# (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim), # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
# therefore we need to flatten the batch dimensions. # therefore we need to flatten the batch dimensions.
@ -137,6 +140,10 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
if not current_platform.can_update_inplace(): if not current_platform.can_update_inplace():
output = lora_output output = lora_output
# Restore original shape if it was flattened
if original_shape is not None:
output = output.reshape(original_shape)
return output return output
@property @property

View File

@ -12,6 +12,7 @@ class LoRAMapping:
index_mapping: tuple[int, ...] index_mapping: tuple[int, ...]
prompt_mapping: tuple[int, ...] prompt_mapping: tuple[int, ...]
is_prefill: bool = False is_prefill: bool = False
is_mm_input: bool = False
def __post_init__(self): def __post_init__(self):
self.index_mapping = tuple(self.index_mapping) self.index_mapping = tuple(self.index_mapping)

View File

@ -12,10 +12,10 @@ from torch import nn
from vllm.config.lora import LoRAConfig, ModelConfig from vllm.config.lora import LoRAConfig, ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping, PunicaWrapperBase from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper from vllm.lora.punica_wrapper import PunicaWrapperBase, get_punica_wrapper
from vllm.lora.utils import ( from vllm.lora.utils import (
from_layer, from_layer,
from_layer_logits_processor, from_layer_logits_processor,
@ -30,8 +30,8 @@ from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model from vllm.model_executor.models.interfaces import is_pooling_model
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.model_executor.utils import get_packed_modules_mapping from vllm.model_executor.utils import get_packed_modules_mapping
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
from vllm.utils.cache import LRUCache from vllm.utils.cache import LRUCache
@ -378,17 +378,18 @@ class LoRAModelManager:
supports_multimodal(self.model) supports_multimodal(self.model)
# In case the model only supports LoRA for # In case the model only supports LoRA for
# text modules (e.g. ChatGLM) # text modules (e.g. ChatGLM)
and hasattr(self.model, "get_mm_mapping")) and hasattr(self.model, "get_mm_mapping")
)
# For v0 compatibility # For v0 compatibility
if model_config is not None: if model_config is not None:
self.mm_registry = MULTIMODAL_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY
self.info = self.mm_registry.create_processor( self.info = self.mm_registry.create_processor(model_config).info
model_config, disable_cache=True).info
self.supports_mm_lora = self.supports_mm and hasattr( self.supports_mm_lora = self.supports_mm and hasattr(
self.info, "get_num_mm_encoder_tokens") self.info, "get_num_mm_encoder_tokens"
)
else: else:
self.supports_mm_lora = False self.supports_mm_lora = False
if self.supports_mm_lora: # 从init传进来就可以了不需要model_config了 if self.supports_mm_lora: # 从init传进来就可以了不需要model_config了
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping() self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
self.mm_config = model_config.multimodal_config self.mm_config = model_config.multimodal_config
# limit_per_prompt: int = max( # limit_per_prompt: int = max(
@ -399,10 +400,8 @@ class LoRAModelManager:
# max_num_batched_tokens = encoder_budget # max_num_batched_tokens = encoder_budget
# max_batches = max_batches * limit_per_prompt # max_batches = max_batches * limit_per_prompt
self.mm_punica_wrapper_mapping = { self.mm_punica_wrapper_mapping = {
name: name: get_punica_wrapper(
get_punica_wrapper( self.info.get_num_mm_encoder_tokens(max_num_batched_tokens),
self.info.get_num_mm_encoder_tokens(
max_num_batched_tokens),
max_batches=self.max_num_seqs * limit_per_prompt, max_batches=self.max_num_seqs * limit_per_prompt,
device=self.device, device=self.device,
max_loras=self.lora_config.max_loras, max_loras=self.lora_config.max_loras,
@ -411,16 +410,11 @@ class LoRAModelManager:
} }
# For language model # For language model
self.mm_punica_wrapper_mapping.update( self.mm_punica_wrapper_mapping.update(
{ {self.mm_mapping.language_model[0]: self.punica_wrapper}
self.mm_mapping.language_model[0]: self.punica_wrapper
}
) )
# TODO Connector is not supported at the moment. # TODO Connector is not supported at the moment.
self.mm_punica_wrapper_mapping.update( self.mm_punica_wrapper_mapping.update(
{ {name: None for name in self.mm_mapping.connector}
name: None
for name in self.mm_mapping.connector
}
) )
self.is_pooling_model = is_pooling_model(self.model) self.is_pooling_model = is_pooling_model(self.model)
@ -512,28 +506,27 @@ class LoRAModelManager:
self.lora_slots + 1, self.lora_slots + 1,
self.vocab_size, self.vocab_size,
self.lora_config.lora_extra_vocab_size, self.lora_config.lora_extra_vocab_size,
self.long_lora_context,
) )
elif mapping.is_mm_input: elif mapping.is_mm_input:
self.mm_punica_wrapper_mapping[ self.mm_punica_wrapper_mapping[
self.mm_mapping.tower_model[0]].update_metadata( self.mm_mapping.tower_model[0]
mapping, ].update_metadata(
self.lora_index_to_id, mapping,
self.lora_slots + 1, self.lora_index_to_id,
self.vocab_size, self.lora_slots + 1,
self.lora_config.lora_extra_vocab_size, self.vocab_size,
self.long_lora_context, self.lora_config.lora_extra_vocab_size,
) )
else: else:
self.mm_punica_wrapper_mapping[ self.mm_punica_wrapper_mapping[
self.mm_mapping.language_model[0]].update_metadata( self.mm_mapping.language_model[0]
mapping, ].update_metadata(
self.lora_index_to_id, mapping,
self.lora_slots + 1, self.lora_index_to_id,
self.vocab_size, self.lora_slots + 1,
self.lora_config.lora_extra_vocab_size, self.vocab_size,
self.long_lora_context, self.lora_config.lora_extra_vocab_size,
) )
def remove_all_adapters(self): def remove_all_adapters(self):
"""Remove all LoRAModels from the manager.""" """Remove all LoRAModels from the manager."""
@ -613,8 +606,7 @@ class LoRAModelManager:
self._register_packed_modules(module_name) self._register_packed_modules(module_name)
# All lora layers share the same punica_wrapper based on reference. # All lora layers share the same punica_wrapper based on reference.
if self.supports_mm_lora: if self.supports_mm_lora:
new_module.set_mapping( new_module.set_mapping(self._get_mm_punica_wrapper(module_name))
self._get_mm_punica_wrapper(module_name))
else: else:
new_module.set_mapping(self.punica_wrapper) new_module.set_mapping(self.punica_wrapper)
@ -711,22 +703,23 @@ class LoRAModelManager:
if self.supports_mm: if self.supports_mm:
prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
if self.supports_mm_lora: if self.supports_mm_lora:
return self._get_mm_punica_wrapper(module_name) is None return self._get_mm_punica_wrapper(module_name) is None
else: else:
return any( return any([module_name.startswith(prefix) for prefix in prefix_lst])
[module_name.startswith(prefix) for prefix in prefix_lst])
return False return False
def _get_mm_punica_wrapper(self, module_name: str) -> PunicaWrapperBase: def _get_mm_punica_wrapper(self, module_name: str) -> Optional[PunicaWrapperBase]:
""" """
Match the corresponding punica_wrapper based on module_name, Match the corresponding punica_wrapper based on module_name,
and return None if lora is not supported for this module. and return None if lora is not supported for this module.
""" """
if self.supports_mm_lora: if self.supports_mm_lora:
# Ensure matching by the longest prefix. # Ensure matching by the longest prefix.
sorted_prefixes = sorted(self.mm_punica_wrapper_mapping.keys(), sorted_prefixes = sorted(
key=lambda x: len(x), reverse=True) self.mm_punica_wrapper_mapping.keys(),
key=lambda x: len(x),
reverse=True,
)
for prefix in sorted_prefixes: for prefix in sorted_prefixes:
if module_name.startswith(prefix): if module_name.startswith(prefix):
@ -834,12 +827,25 @@ class LoRALRUCache(AdapterLRUCache[LoRAModel]):
class LRUCacheLoRAModelManager(LoRAModelManager): class LRUCacheLoRAModelManager(LoRAModelManager):
"""A model manager that manages multiple LoRAs with LRU cache.""" """A model manager that manages multiple LoRAs with LRU cache."""
def __init__(self, model: nn.Module, max_num_seqs: int, def __init__(
max_num_batched_tokens: int, vocab_size: int, self,
lora_config: LoRAConfig, model_config: ModelConfig, model: nn.Module,
device: torch.device): max_num_seqs: int,
super().__init__(model, max_num_seqs, max_num_batched_tokens, max_num_batched_tokens: int,
vocab_size, lora_config, model_config, device) vocab_size: int,
lora_config: LoRAConfig,
model_config: ModelConfig,
device: torch.device,
):
super().__init__(
model,
max_num_seqs,
max_num_batched_tokens,
vocab_size,
lora_config,
model_config,
device,
)
self._registered_adapters: LoRALRUCache = LoRALRUCache( self._registered_adapters: LoRALRUCache = LoRALRUCache(
self.capacity, self.deactivate_adapter self.capacity, self.deactivate_adapter
) )
@ -906,15 +912,16 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
def create_lora_manager( def create_lora_manager(
model: nn.Module, model: nn.Module,
max_num_seqs: int, max_num_seqs: int,
max_num_batched_tokens: int, max_num_batched_tokens: int,
vocab_size: int, vocab_size: int,
lora_config: LoRAConfig, lora_config: LoRAConfig,
model_config: ModelConfig, model_config: ModelConfig,
device: torch.device, device: torch.device,
lora_manager_cls: type[LoRAModelManager] = LoRAModelManager, lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
**kwargs) -> LoRAModelManager: **kwargs,
) -> LoRAModelManager:
"""Create a LoRA adapter for a given model.""" """Create a LoRA adapter for a given model."""
if not isinstance(model, SupportsLoRA): if not isinstance(model, SupportsLoRA):
raise ValueError(f"Model {type(model)} is not supported for LoRA.") raise ValueError(f"Model {type(model)} is not supported for LoRA.")

View File

@ -6,7 +6,7 @@ from typing import Any, Literal, Optional, Union
import torch import torch
from vllm.config import VllmConfig, ModelConfig from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.models import ( from vllm.lora.models import (
LoRAModel, LoRAModel,
@ -71,6 +71,7 @@ class WorkerLoRAManager:
def create_lora_manager( def create_lora_manager(
self, self,
model: torch.nn.Module, model: torch.nn.Module,
model_config: Optional[ModelConfig] = None,
) -> Any: ) -> Any:
lora_manager = create_lora_manager( lora_manager = create_lora_manager(
model, model,
@ -80,6 +81,7 @@ class WorkerLoRAManager:
lora_config=self.lora_config, lora_config=self.lora_config,
device=self.device, device=self.device,
lora_manager_cls=self._manager_cls, lora_manager_cls=self._manager_cls,
model_config=model_config,
) )
self._adapter_manager = lora_manager self._adapter_manager = lora_manager
return lora_manager.model return lora_manager.model

View File

@ -512,6 +512,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
) )
# Multimodal LoRA support
if self.supports_mm_inputs:
self.info = self.mm_registry.create_processor(self.model_config).info
self.supports_mm_lora = hasattr(self.info, "get_num_mm_encoder_tokens")
else:
self.supports_mm_lora = False
def reset_mm_cache(self) -> None: def reset_mm_cache(self) -> None:
if self.mm_budget: if self.mm_budget:
self.mm_budget.reset_cache() self.mm_budget.reset_cache()
@ -571,15 +578,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) )
return model_kwargs return model_kwargs
# Multimodal LoRA support
if self.is_multimodal_model:
self.info = self.mm_registry.create_processor(
self.model_config, disable_cache=True).info
self.supports_mm_lora = hasattr(self.info,
"get_num_mm_encoder_tokens")
else:
self.supports_mm_lora = False
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
""" """
Update the order of requests in the batch based on the attention Update the order of requests in the batch based on the attention
@ -1751,6 +1749,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# encoder outputs. # encoder outputs.
model = cast(SupportsMultiModal, self.model) model = cast(SupportsMultiModal, self.model)
encoder_outputs = [] encoder_outputs = []
if self.lora_config and self.supports_mm_lora:
mm_tokens = [
self.info.get_num_mm_encoder_tokens(pos_info.length)
for _, pos_info in mm_hashes_pos
]
num_scheduled_tokens = np.array(mm_tokens, dtype=np.int32)
self.set_active_loras(
self.input_batch,
num_scheduled_tokens,
is_mm_input=True,
)
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs, mm_kwargs,
device=self.device, device=self.device,
@ -2903,7 +2914,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) )
if self.lora_config: if self.lora_config:
self.model = self.load_lora_model( self.model = self.load_lora_model(
self.model, self.vllm_config, self.device self.model, self.vllm_config, self.device, self.model_config
) )
if hasattr(self, "drafter"): if hasattr(self, "drafter"):
logger.info("Loading drafter model...") logger.info("Loading drafter model...")

View File

@ -11,7 +11,7 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.config.lora import LoRAConfig from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping from vllm.lora.layers import LoRAMapping
@ -29,7 +29,11 @@ logger = init_logger(__name__)
# Defined as a mixin for GPUModelRunner # Defined as a mixin for GPUModelRunner
class LoRAModelRunnerMixin: class LoRAModelRunnerMixin:
def load_lora_model( def load_lora_model(
self, model: nn.Module, vllm_config: VllmConfig, device: torch.device self,
model: nn.Module,
vllm_config: VllmConfig,
device: torch.device,
model_config: ModelConfig = None,
) -> nn.Module: ) -> nn.Module:
if not supports_lora(model): if not supports_lora(model):
raise ValueError(f"{model.__class__.__name__} does not support LoRA yet.") raise ValueError(f"{model.__class__.__name__} does not support LoRA yet.")
@ -54,7 +58,7 @@ class LoRAModelRunnerMixin:
prompt_lora_mapping: tuple[int, ...], prompt_lora_mapping: tuple[int, ...],
token_lora_mapping: tuple[int, ...], token_lora_mapping: tuple[int, ...],
lora_requests: set[LoRARequest], lora_requests: set[LoRARequest],
is_mm_input: bool = False is_mm_input: bool = False,
) -> None: ) -> None:
self._ensure_lora_enabled() self._ensure_lora_enabled()
@ -63,7 +67,10 @@ class LoRAModelRunnerMixin:
# On cuda platforms we use the same kernels for prefill and # On cuda platforms we use the same kernels for prefill and
# decode and this flag is generally ignored. # decode and this flag is generally ignored.
lora_mapping = LoRAMapping( lora_mapping = LoRAMapping(
token_lora_mapping, prompt_lora_mapping, is_prefill=True, is_mm_input=is_mm_input token_lora_mapping,
prompt_lora_mapping,
is_prefill=True,
is_mm_input=is_mm_input,
) )
self.lora_manager.set_active_adapters(lora_requests, lora_mapping) self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
@ -72,7 +79,10 @@ class LoRAModelRunnerMixin:
raise RuntimeError("LoRA is not enabled. Use --enable-lora to enable LoRA.") raise RuntimeError("LoRA is not enabled. Use --enable-lora to enable LoRA.")
def set_active_loras( def set_active_loras(
self, input_batch: InputBatch, num_scheduled_tokens: np.ndarray, is_mm_input: bool = False self,
input_batch: InputBatch,
num_scheduled_tokens: np.ndarray,
is_mm_input: bool = False,
) -> None: ) -> None:
prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs
token_lora_mapping: tuple[int, ...] # of size np.sum(num_scheduled_tokens) token_lora_mapping: tuple[int, ...] # of size np.sum(num_scheduled_tokens)
@ -122,7 +132,10 @@ class LoRAModelRunnerMixin:
@contextmanager @contextmanager
def maybe_select_dummy_loras( def maybe_select_dummy_loras(
self, lora_config: Optional[LoRAConfig], num_scheduled_tokens: np.ndarray, is_mm_input: bool = False self,
lora_config: Optional[LoRAConfig],
num_scheduled_tokens: np.ndarray,
is_mm_input: bool = False,
): ):
if lora_config is None: if lora_config is None:
yield yield
@ -151,7 +164,10 @@ class LoRAModelRunnerMixin:
} }
self._set_active_loras( self._set_active_loras(
tuple(prompt_lora_mapping), tuple(token_lora_mapping), lora_requests, is_mm_input tuple(prompt_lora_mapping),
tuple(token_lora_mapping),
lora_requests,
is_mm_input,
) )
yield yield
@ -162,11 +178,13 @@ class LoRAModelRunnerMixin:
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
num_scheduled_tokens: np.ndarray, num_scheduled_tokens: np.ndarray,
remove_lora: bool = True, remove_lora: bool = True,
is_mm_input: bool = False is_mm_input: bool = False,
): ):
with ( with (
self.maybe_setup_dummy_loras(lora_config, remove_lora), self.maybe_setup_dummy_loras(lora_config, remove_lora),
self.maybe_select_dummy_loras(lora_config, num_scheduled_tokens, is_mm_input), self.maybe_select_dummy_loras(
lora_config, num_scheduled_tokens, is_mm_input
),
): ):
yield yield