update
Signed-off-by: bk-201 <joy25810@foxmail.com>
This commit is contained in:
parent 236e0fe9fd
commit cb1a6f074a
@@ -17,7 +17,6 @@ aiohttp==3.13.0
# aiohttp-cors
# datasets
# fsspec
# gpt-oss
# lm-eval
# ray
aiohttp-cors==0.8.1
@@ -45,9 +44,7 @@ argcomplete==3.5.1
arrow==1.3.0
# via isoduration
async-timeout==5.0.1
# via
# aiohttp
# redis
# via redis
attrs==24.2.0
# via
# aiohttp
@@ -108,8 +105,6 @@ chardet==5.2.0
# via mbstrdecoder
charset-normalizer==3.4.0
# via requests
chz==0.3.0
# via gpt-oss
click==8.1.7
# via
# black
@@ -180,9 +175,7 @@ distlib==0.3.9
dnspython==2.7.0
# via email-validator
docker==7.1.0
# via
# gpt-oss
# mlflow
# via mlflow
docopt==0.6.2
# via num2words
docstring-parser==0.17.0
@@ -208,9 +201,7 @@ eval-type-backport==0.2.2
evaluate==0.4.3
# via lm-eval
fastapi==0.116.1
# via
# gpt-oss
# mlflow-skinny
# via mlflow-skinny
fastparquet==2024.11.0
# via genai-perf
fastrlock==0.8.2
@@ -285,8 +276,6 @@ google-resumable-media==2.7.2
# via google-cloud-storage
googleapis-common-protos==1.70.0
# via google-api-core
gpt-oss==0.0.8
# via -r requirements/test.in
graphene==3.4.3
# via mlflow
graphql-core==3.2.6
@@ -314,8 +303,6 @@ hf-xet==1.1.7
# via huggingface-hub
hiredis==3.0.0
# via tensorizer
html2text==2025.4.15
# via gpt-oss
httpcore==1.0.6
# via httpx
httpx==0.27.2
@@ -450,7 +437,6 @@ lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b772215
lxml==5.3.0
# via
# blobfile
# gpt-oss
# sacrebleu
mako==1.3.10
# via alembic
@@ -620,8 +606,6 @@ omegaconf==2.3.0
# lightning
open-clip-torch==2.32.0
# via -r requirements/test.in
openai-harmony==0.0.4
# via gpt-oss
opencensus==0.11.4
# via ray
opencensus-context==0.1.3
@@ -793,12 +777,10 @@ pydantic==2.12.0
# albumentations
# datamodel-code-generator
# fastapi
# gpt-oss
# lightly
# mistral-common
# mlflow-skinny
# mteb
# openai-harmony
# pydantic-extra-types
# ray
pydantic-core==2.41.1
@@ -929,7 +911,6 @@ requests==2.32.3
# evaluate
# google-api-core
# google-cloud-storage
# gpt-oss
# huggingface-hub
# lightly
# lm-eval
@@ -1072,8 +1053,6 @@ starlette-testclient==0.4.1
# via schemathesis
statsmodels==0.14.4
# via genai-perf
structlog==25.4.0
# via gpt-oss
sympy==1.13.3
# via
# einx
@@ -1088,15 +1067,12 @@ tcolorpy==0.1.6
# via pytablewriter
tenacity==9.1.2
# via
# gpt-oss
# lm-eval
# plotly
tensorboardx==2.6.4
# via lightning
tensorizer==2.10.1
# via -r requirements/test.in
termcolor==3.1.0
# via gpt-oss
terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e
# via -r requirements/test.in
threadpoolctl==3.5.0
@@ -1107,7 +1083,6 @@ tifffile==2025.3.30
# terratorch
tiktoken==0.12.0
# via
# gpt-oss
# lm-eval
# mistral-common
timm==1.0.17
@@ -1121,12 +1096,9 @@ tokenizers==0.22.0
# via
# -r requirements/test.in
# transformers
toml==0.10.2
# via datamodel-code-generator
tomli==2.2.1
# via
# black
# pytest
# coverage
# schemathesis
tomli-w==1.2.0
# via schemathesis
@@ -1235,7 +1207,6 @@ typing-extensions==4.15.0
# aiosignal
# albumentations
# alembic
# chz
# fastapi
# graphene
# huggingface-hub
@@ -1275,9 +1246,7 @@ urllib3==2.2.3
# responses
# tritonclient
uvicorn==0.35.0
# via
# gpt-oss
# mlflow-skinny
# via mlflow-skinny
vector-quantize-pytorch==1.21.2
# via -r requirements/test.in
virtualenv==20.31.2

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.layers.base import BaseLayerWithLoRA, PunicaWrapperBase
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.lora.layers.column_parallel_linear import (
ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithShardedLoRA,
@@ -36,5 +36,4 @@ __all__ = [
"RowParallelLinearWithShardedLoRA",
"ReplicatedLinearWithLoRA",
"LoRAMapping",
"PunicaWrapperBase",
]

@@ -124,6 +124,9 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

# Store original shape for later reshaping
original_shape = output.shape if output.ndim == 3 else None

# In transformers backend, x and output have extra batch dimension like
# (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
# therefore we need to flatten the batch dimensions.
@@ -137,6 +140,10 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
if not current_platform.can_update_inplace():
output = lora_output

# Restore original shape if it was flattened
if original_shape is not None:
output = output.reshape(original_shape)

return output

@property

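For readers of the hunk above, a minimal, self-contained sketch of the flatten-then-restore pattern it introduces; the helper names here (apply_lora_with_reshape, add_lora) are made up for illustration, with the punica wrapper playing the add_lora role:

import torch

def apply_lora_with_reshape(output: torch.Tensor, x: torch.Tensor, add_lora) -> torch.Tensor:
    # Hypothetical helper: `add_lora` stands in for the punica wrapper call,
    # which expects 2-D (num_tokens, hidden_dim) tensors.
    original_shape = output.shape if output.ndim == 3 else None
    if original_shape is not None:
        # e.g. (1, seq_len, hidden_dim) -> (seq_len, hidden_dim)
        output = output.flatten(0, 1)
        x = x.flatten(0, 1)
    output = add_lora(output, x)
    if original_shape is not None:
        # Restore the 3-D shape expected by the transformers backend.
        output = output.reshape(original_shape)
    return output
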
@@ -12,6 +12,7 @@ class LoRAMapping:
index_mapping: tuple[int, ...]
prompt_mapping: tuple[int, ...]
is_prefill: bool = False
is_mm_input: bool = False

def __post_init__(self):
self.index_mapping = tuple(self.index_mapping)

@@ -12,10 +12,10 @@ from torch import nn

from vllm.config.lora import LoRAConfig, ModelConfig
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping, PunicaWrapperBase
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.lora.punica_wrapper import PunicaWrapperBase, get_punica_wrapper
from vllm.lora.utils import (
from_layer,
from_layer_logits_processor,
@@ -30,8 +30,8 @@ from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.model_executor.utils import get_packed_modules_mapping
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils import is_pin_memory_available
from vllm.utils.cache import LRUCache

@@ -378,17 +378,18 @@ class LoRAModelManager:
supports_multimodal(self.model)
# In case the model only supports LoRA for
# text modules (e.g. ChatGLM)
and hasattr(self.model, "get_mm_mapping"))
and hasattr(self.model, "get_mm_mapping")
)
# For v0 compatibility
if model_config is not None:
self.mm_registry = MULTIMODAL_REGISTRY
self.info = self.mm_registry.create_processor(
model_config, disable_cache=True).info
self.info = self.mm_registry.create_processor(model_config).info
self.supports_mm_lora = self.supports_mm and hasattr(
self.info, "get_num_mm_encoder_tokens")
self.info, "get_num_mm_encoder_tokens"
)
else:
self.supports_mm_lora = False
if self.supports_mm_lora:  # could just be passed in from init; model_config is no longer needed
if self.supports_mm_lora:  # could just be passed in from init; model_config is no longer needed
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
self.mm_config = model_config.multimodal_config
# limit_per_prompt: int = max(
@@ -399,10 +400,8 @@ class LoRAModelManager:
# max_num_batched_tokens = encoder_budget
# max_batches = max_batches * limit_per_prompt
self.mm_punica_wrapper_mapping = {
name:
get_punica_wrapper(
self.info.get_num_mm_encoder_tokens(
max_num_batched_tokens),
name: get_punica_wrapper(
self.info.get_num_mm_encoder_tokens(max_num_batched_tokens),
max_batches=self.max_num_seqs * limit_per_prompt,
device=self.device,
max_loras=self.lora_config.max_loras,
@@ -411,16 +410,11 @@ class LoRAModelManager:
}
# For language model
self.mm_punica_wrapper_mapping.update(
{
self.mm_mapping.language_model[0]: self.punica_wrapper
}
{self.mm_mapping.language_model[0]: self.punica_wrapper}
)
# TODO Connector is not supported at the moment.
self.mm_punica_wrapper_mapping.update(
{
name: None
for name in self.mm_mapping.connector
}
{name: None for name in self.mm_mapping.connector}
)

self.is_pooling_model = is_pooling_model(self.model)

@@ -512,28 +506,27 @@ class LoRAModelManager:
self.lora_slots + 1,
self.vocab_size,
self.lora_config.lora_extra_vocab_size,
self.long_lora_context,
)
elif mapping.is_mm_input:
self.mm_punica_wrapper_mapping[
self.mm_mapping.tower_model[0]].update_metadata(
mapping,
self.lora_index_to_id,
self.lora_slots + 1,
self.vocab_size,
self.lora_config.lora_extra_vocab_size,
self.long_lora_context,
)
self.mm_mapping.tower_model[0]
].update_metadata(
mapping,
self.lora_index_to_id,
self.lora_slots + 1,
self.vocab_size,
self.lora_config.lora_extra_vocab_size,
)
else:
self.mm_punica_wrapper_mapping[
self.mm_mapping.language_model[0]].update_metadata(
mapping,
self.lora_index_to_id,
self.lora_slots + 1,
self.vocab_size,
self.lora_config.lora_extra_vocab_size,
self.long_lora_context,
)
self.mm_mapping.language_model[0]
].update_metadata(
mapping,
self.lora_index_to_id,
self.lora_slots + 1,
self.vocab_size,
self.lora_config.lora_extra_vocab_size,
)

def remove_all_adapters(self):
"""Remove all LoRAModels from the manager."""

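As a rough sketch of the routing the reformatted hunk above performs, with a hypothetical free function standing in for the manager method (attribute names follow the diff):

def route_metadata_update(manager, mapping) -> None:
    # Sketch: multimodal inputs go to the tower-model wrapper, everything else
    # to the language-model wrapper registered on the same manager.
    key = (
        manager.mm_mapping.tower_model[0]
        if mapping.is_mm_input
        else manager.mm_mapping.language_model[0]
    )
    manager.mm_punica_wrapper_mapping[key].update_metadata(
        mapping,
        manager.lora_index_to_id,
        manager.lora_slots + 1,
        manager.vocab_size,
        manager.lora_config.lora_extra_vocab_size,
    )
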
@@ -613,8 +606,7 @@ class LoRAModelManager:
self._register_packed_modules(module_name)
# All lora layers share the same punica_wrapper based on reference.
if self.supports_mm_lora:
new_module.set_mapping(
self._get_mm_punica_wrapper(module_name))
new_module.set_mapping(self._get_mm_punica_wrapper(module_name))
else:
new_module.set_mapping(self.punica_wrapper)

@@ -711,22 +703,23 @@ class LoRAModelManager:
if self.supports_mm:
prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
if self.supports_mm_lora:

return self._get_mm_punica_wrapper(module_name) is None
else:
return any(
[module_name.startswith(prefix) for prefix in prefix_lst])
return any([module_name.startswith(prefix) for prefix in prefix_lst])
return False

def _get_mm_punica_wrapper(self, module_name: str) -> PunicaWrapperBase:
def _get_mm_punica_wrapper(self, module_name: str) -> Optional[PunicaWrapperBase]:
"""
Match the corresponding punica_wrapper based on module_name,
Match the corresponding punica_wrapper based on module_name,
and return None if lora is not supported for this module.
"""
if self.supports_mm_lora:
# Ensure matching by the longest prefix.
sorted_prefixes = sorted(self.mm_punica_wrapper_mapping.keys(),
key=lambda x: len(x), reverse=True)
sorted_prefixes = sorted(
self.mm_punica_wrapper_mapping.keys(),
key=lambda x: len(x),
reverse=True,
)

for prefix in sorted_prefixes:
if module_name.startswith(prefix):

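A self-contained sketch of the longest-prefix lookup that _get_mm_punica_wrapper relies on; the example dictionary and module names below are invented for illustration:

from typing import Optional

def match_by_longest_prefix(wrappers: dict[str, object], module_name: str) -> Optional[object]:
    # Sort candidate prefixes longest-first so the most specific match wins.
    for prefix in sorted(wrappers, key=len, reverse=True):
        if module_name.startswith(prefix):
            return wrappers[prefix]
    return None

# Example: "visual.merger" beats the shorter "visual" prefix.
wrappers = {"visual": "tower_wrapper", "visual.merger": "merger_wrapper"}
assert match_by_longest_prefix(wrappers, "visual.merger.mlp.0") == "merger_wrapper"
assert match_by_longest_prefix(wrappers, "model.layers.0") is None
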
@@ -834,12 +827,25 @@ class LoRALRUCache(AdapterLRUCache[LoRAModel]):
class LRUCacheLoRAModelManager(LoRAModelManager):
"""A model manager that manages multiple LoRAs with LRU cache."""

def __init__(self, model: nn.Module, max_num_seqs: int,
max_num_batched_tokens: int, vocab_size: int,
lora_config: LoRAConfig, model_config: ModelConfig,
device: torch.device):
super().__init__(model, max_num_seqs, max_num_batched_tokens,
vocab_size, lora_config, model_config, device)
def __init__(
self,
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
model_config: ModelConfig,
device: torch.device,
):
super().__init__(
model,
max_num_seqs,
max_num_batched_tokens,
vocab_size,
lora_config,
model_config,
device,
)
self._registered_adapters: LoRALRUCache = LoRALRUCache(
self.capacity, self.deactivate_adapter
)
@@ -906,15 +912,16 @@ class LRUCacheLoRAModelManager(LoRAModelManager):


def create_lora_manager(
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
model_config: ModelConfig,
device: torch.device,
lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
**kwargs) -> LoRAModelManager:
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
model_config: ModelConfig,
device: torch.device,
lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
**kwargs,
) -> LoRAModelManager:
"""Create a LoRA adapter for a given model."""
if not isinstance(model, SupportsLoRA):
raise ValueError(f"Model {type(model)} is not supported for LoRA.")

@@ -6,7 +6,7 @@ from typing import Any, Literal, Optional, Union

import torch

from vllm.config import VllmConfig, ModelConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.lora.models import (
LoRAModel,
@@ -71,6 +71,7 @@ class WorkerLoRAManager:
def create_lora_manager(
self,
model: torch.nn.Module,
model_config: Optional[ModelConfig] = None,
) -> Any:
lora_manager = create_lora_manager(
model,
@@ -80,6 +81,7 @@ class WorkerLoRAManager:
lora_config=self.lora_config,
device=self.device,
lora_manager_cls=self._manager_cls,
model_config=model_config,
)
self._adapter_manager = lora_manager
return lora_manager.model

@@ -512,6 +512,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory,
)

# Multimodal LoRA support
if self.supports_mm_inputs:
self.info = self.mm_registry.create_processor(self.model_config).info
self.supports_mm_lora = hasattr(self.info, "get_num_mm_encoder_tokens")
else:
self.supports_mm_lora = False

def reset_mm_cache(self) -> None:
if self.mm_budget:
self.mm_budget.reset_cache()
@@ -571,15 +578,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
return model_kwargs

# Multimodal LoRA support
if self.is_multimodal_model:
self.info = self.mm_registry.create_processor(
self.model_config, disable_cache=True).info
self.supports_mm_lora = hasattr(self.info,
"get_num_mm_encoder_tokens")
else:
self.supports_mm_lora = False

def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
Update the order of requests in the batch based on the attention
@@ -1751,6 +1749,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# encoder outputs.
model = cast(SupportsMultiModal, self.model)
encoder_outputs = []

if self.lora_config and self.supports_mm_lora:
mm_tokens = [
self.info.get_num_mm_encoder_tokens(pos_info.length)
for _, pos_info in mm_hashes_pos
]
num_scheduled_tokens = np.array(mm_tokens, dtype=np.int32)
self.set_active_loras(
self.input_batch,
num_scheduled_tokens,
is_mm_input=True,
)

for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
@@ -2903,7 +2914,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
if self.lora_config:
self.model = self.load_lora_model(
self.model, self.vllm_config, self.device
self.model, self.vllm_config, self.device, self.model_config
)
if hasattr(self, "drafter"):
logger.info("Loading drafter model...")

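A condensed sketch of the encoder-side activation added in the runner hunk above: the LoRA mapping for the encoder pass is sized by encoder tokens per multimodal item rather than by scheduled decoder tokens. The free-standing helper below is hypothetical; get_num_mm_encoder_tokens and set_active_loras follow the diff.

import numpy as np

def activate_encoder_loras(runner, mm_hashes_pos) -> None:
    # One entry per multimodal item: how many encoder tokens it produces.
    mm_tokens = [
        runner.info.get_num_mm_encoder_tokens(pos_info.length)
        for _, pos_info in mm_hashes_pos
    ]
    # Activate LoRAs against the encoder token counts, flagged as mm input.
    runner.set_active_loras(
        runner.input_batch,
        np.array(mm_tokens, dtype=np.int32),
        is_mm_input=True,
    )
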
@@ -11,7 +11,7 @@ import numpy as np
import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
@@ -29,7 +29,11 @@ logger = init_logger(__name__)
# Defined as a mixin for GPUModelRunner
class LoRAModelRunnerMixin:
def load_lora_model(
self, model: nn.Module, vllm_config: VllmConfig, device: torch.device
self,
model: nn.Module,
vllm_config: VllmConfig,
device: torch.device,
model_config: ModelConfig = None,
) -> nn.Module:
if not supports_lora(model):
raise ValueError(f"{model.__class__.__name__} does not support LoRA yet.")
@@ -54,7 +58,7 @@ class LoRAModelRunnerMixin:
prompt_lora_mapping: tuple[int, ...],
token_lora_mapping: tuple[int, ...],
lora_requests: set[LoRARequest],
is_mm_input: bool = False
is_mm_input: bool = False,
) -> None:
self._ensure_lora_enabled()

@@ -63,7 +67,10 @@ class LoRAModelRunnerMixin:
# On cuda platforms we use the same kernels for prefill and
# decode and this flag is generally ignored.
lora_mapping = LoRAMapping(
token_lora_mapping, prompt_lora_mapping, is_prefill=True, is_mm_input=is_mm_input
token_lora_mapping,
prompt_lora_mapping,
is_prefill=True,
is_mm_input=is_mm_input,
)
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)

@@ -72,7 +79,10 @@ class LoRAModelRunnerMixin:
raise RuntimeError("LoRA is not enabled. Use --enable-lora to enable LoRA.")

def set_active_loras(
self, input_batch: InputBatch, num_scheduled_tokens: np.ndarray, is_mm_input: bool = False
self,
input_batch: InputBatch,
num_scheduled_tokens: np.ndarray,
is_mm_input: bool = False,
) -> None:
prompt_lora_mapping: tuple[int, ...]  # of size input_batch.num_reqs
token_lora_mapping: tuple[int, ...]  # of size np.sum(num_scheduled_tokens)
@@ -122,7 +132,10 @@ class LoRAModelRunnerMixin:

@contextmanager
def maybe_select_dummy_loras(
self, lora_config: Optional[LoRAConfig], num_scheduled_tokens: np.ndarray, is_mm_input: bool = False
self,
lora_config: Optional[LoRAConfig],
num_scheduled_tokens: np.ndarray,
is_mm_input: bool = False,
):
if lora_config is None:
yield
@@ -151,7 +164,10 @@ class LoRAModelRunnerMixin:
}

self._set_active_loras(
tuple(prompt_lora_mapping), tuple(token_lora_mapping), lora_requests, is_mm_input
tuple(prompt_lora_mapping),
tuple(token_lora_mapping),
lora_requests,
is_mm_input,
)

yield
@@ -162,11 +178,13 @@ class LoRAModelRunnerMixin:
lora_config: Optional[LoRAConfig],
num_scheduled_tokens: np.ndarray,
remove_lora: bool = True,
is_mm_input: bool = False
is_mm_input: bool = False,
):
with (
self.maybe_setup_dummy_loras(lora_config, remove_lora),
self.maybe_select_dummy_loras(lora_config, num_scheduled_tokens, is_mm_input),
self.maybe_select_dummy_loras(
lora_config, num_scheduled_tokens, is_mm_input
),
):
yield