[V0 deprecation] Remove long context LoRA (#21169)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2025-07-19 17:15:41 +08:00 committed by GitHub
parent cf8cc32674
commit 1eaff27815
13 changed files with 35 additions and 301 deletions

View File

@@ -221,11 +221,6 @@ def phi2_lora_files():
return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora")
@pytest.fixture(scope="session")
def long_context_lora_files_16k_1():
return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1")
@pytest.fixture
def llama_2_7b_engine_extra_embeddings():
cleanup_dist_env_and_memory(shutdown_ray=True)

View File

@@ -38,8 +38,8 @@ ERROR_CASES = [
]
def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path):
peft_helper = PEFTHelper.from_local_dir(long_context_lora_files_16k_1,
def test_peft_helper_pass(sql_lora_files, tmp_path):
peft_helper = PEFTHelper.from_local_dir(sql_lora_files,
max_position_embeddings=4096)
lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
peft_helper.validate_legal(lora_config)
@@ -56,15 +56,12 @@ def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path):
"embed_tokens",
"lm_head",
]
assert peft_helper.context_length == 16384
assert peft_helper.vllm_max_position_embeddings == 4096
assert peft_helper.vllm_long_context_scaling_factor == float(
math.ceil(peft_helper.context_length /
peft_helper.vllm_max_position_embeddings))
# test RSLoRA
rslora_config = dict(use_rslora=True)
test_dir = tmp_path / "test_rslora"
shutil.copytree(long_context_lora_files_16k_1, test_dir)
shutil.copytree(sql_lora_files, test_dir)
# Load and modify configuration
config_path = test_dir / "adapter_config.json"
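A hedged sketch of the RSLoRA round-trip this test now performs against the standard SQL LoRA checkpoint (the adapter path and tmp directory below are stand-ins, not values from the diff; the expected value follows the lora_alpha / sqrt(r) rule in the peft_helper.py hunk further down):

```python
import json
import math
import shutil
from pathlib import Path

from vllm.config import LoRAConfig
from vllm.lora.peft_helper import PEFTHelper

adapter_dir = Path("/path/to/sql_lora_adapter")  # stand-in for sql_lora_files
test_dir = Path("/tmp/test_rslora")               # stand-in for tmp_path / "test_rslora"
shutil.copytree(adapter_dir, test_dir)

# Flip the adapter config to RSLoRA, mirroring what the test does.
config_path = test_dir / "adapter_config.json"
config = json.loads(config_path.read_text())
config["use_rslora"] = True
config_path.write_text(json.dumps(config))

peft_helper = PEFTHelper.from_local_dir(test_dir, max_position_embeddings=4096)
peft_helper.validate_legal(LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2))

# RSLoRA scales by lora_alpha / sqrt(r) instead of lora_alpha / r.
expected = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
assert math.isclose(peft_helper.vllm_lora_scaling_factor, expected)
```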

View File

@@ -3014,12 +3014,7 @@ class LoRAConfig:
(added to the base model vocabulary)."""
lora_vocab_padding_size: ClassVar[int] = current_platform\
.get_lora_vocab_padding_size()
long_lora_scaling_factors: Optional[tuple[float, ...]] = None
"""Specify multiple scaling factors (which can be different from base model
scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters
trained with those scaling factors to be used at the same time. If not
specified, only adapters trained with the base model scaling factor are
allowed."""
default_mm_loras: Optional[dict[str, str]] = None
"""Dictionary mapping specific modalities to LoRA model paths; this field
is only applicable to multimodal models and should be leveraged when a
@@ -3052,7 +3047,6 @@ class LoRAConfig:
factors.append(self.lora_dtype)
factors.append(self.lora_extra_vocab_size)
factors.append(self.lora_vocab_padding_size)
factors.append(self.long_lora_scaling_factors)
factors.append(self.bias_enabled)
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
@@ -3091,11 +3085,6 @@ class LoRAConfig:
elif isinstance(self.lora_dtype, str):
self.lora_dtype = getattr(torch, self.lora_dtype)
def verify_lora_support(self):
if self.long_lora_scaling_factors is not None and envs.VLLM_USE_V1:
raise ValueError(
"V1 LoRA does not support long LoRA, please use V0.")
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
@@ -4564,7 +4553,6 @@ class VllmConfig:
if self.lora_config is not None:
self.lora_config.verify_with_cache_config(self.cache_config)
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_lora_support()
if self.prompt_adapter_config is not None:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
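A minimal sketch (not part of the diff) of the trimmed config surface from user code: the long-context field is removed outright rather than ignored, so there is nothing left to pass, hash, or verify for multiple RoPE scaling factors.

```python
from vllm.config import LoRAConfig

# The surviving LoRA knobs are unchanged by this commit.
lora_config = LoRAConfig(max_lora_rank=16, max_loras=2, max_cpu_loras=4)

# The long-context field no longer exists on the config object.
assert not hasattr(lora_config, "long_lora_scaling_factors")
```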

View File

@@ -358,8 +358,6 @@ class EngineArgs:
max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
long_lora_scaling_factors: Optional[tuple[float, ...]] = \
LoRAConfig.long_lora_scaling_factors
# PromptAdapter fields
enable_prompt_adapter: bool = False
max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
@@ -723,8 +721,6 @@ class EngineArgs:
"--lora-dtype",
**lora_kwargs["lora_dtype"],
)
lora_group.add_argument("--long-lora-scaling-factors",
**lora_kwargs["long_lora_scaling_factors"])
lora_group.add_argument("--max-cpu-loras",
**lora_kwargs["max_cpu_loras"])
lora_group.add_argument("--fully-sharded-loras",
@@ -1245,7 +1241,6 @@ class EngineArgs:
default_mm_loras=self.default_mm_loras,
fully_sharded_loras=self.fully_sharded_loras,
lora_extra_vocab_size=self.lora_extra_vocab_size,
long_lora_scaling_factors=self.long_lora_scaling_factors,
lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None
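With `--long-lora-scaling-factors` gone from the CLI and `EngineArgs`, enabling LoRA comes down to the remaining flags. A hedged sketch using the offline `LLM` entrypoint; the model name and adapter path are placeholders, not values from this commit:

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder base model; any LoRA-capable model works here.
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    enable_lora=True,
    max_loras=2,
    max_lora_rank=16,
)

outputs = llm.generate(
    "Write a SQL query that counts users per country.",
    SamplingParams(max_tokens=64),
    lora_request=LoRARequest("sql_lora", 1, "/path/to/sql_lora_adapter"),
)
print(outputs[0].outputs[0].text)
```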

View File

@@ -28,8 +28,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
# yapf: enable
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import (
LinearScalingRotaryEmbedding, RotaryEmbedding)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.platforms import current_platform
@@ -1193,91 +1191,3 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
) -> bool:
# Special handling for the LogitsProcessor.
return False
class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA):
"""Implements RoPE-scaled embeddings with linear scaling for
multiple LoRA adapters with a specialized kernel.
Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding
which can handle multi lora adapters in a specialized kernel.
"""
def __init__(self, base_layer: RotaryEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
@property
def scaling_factors(self):
return self.base_layer.scaling_factors
@property
def rotary_dim(self):
return self.base_layer.rotary_dim
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
scaling_factors = (list(lora_config.long_lora_scaling_factors)
if lora_config.long_lora_scaling_factors else [])
base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
scaling_factors = sorted(
list(set([base_scaling_factor] + scaling_factors)))
self.base_layer = LinearScalingRotaryEmbedding(
self.base_layer.head_size,
self.base_layer.rotary_dim,
self.base_layer.max_position_embeddings,
self.base_layer.base,
self.base_layer.is_neox_style,
scaling_factors,
self.base_layer.dtype,
)
def reset_lora(self, index: int):
...
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
...
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.base_layer(
positions,
query,
key,
offsets=self.punica_wrapper.long_lora_indices,
)
@property
def scaling_factor_to_offset(self) -> dict[float, int]:
return self.base_layer.scaling_factor_to_offset
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
return (type(source_layer) is LinearScalingRotaryEmbedding
or type(source_layer) is RotaryEmbedding)
def extra_repr(self) -> str:
return self.base_layer.extra_repr()

View File

@@ -4,7 +4,6 @@
import math
import os
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any, Callable, Optional, Union
import regex as re
@@ -19,9 +18,7 @@ from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
remove_adapter, set_adapter_mapping)
from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import (BaseLayerWithLoRA,
LinearScalingRotaryEmbeddingWithLoRA,
LoRAMapping)
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper
@@ -43,18 +40,6 @@ logger = init_logger(__name__)
_GLOBAL_LORA_ID = 0
@dataclass
class LongContextLoRAContext:
"""Context for lora adapters that support long context."""
# The scaling factors to support long context lora fine tuned models.
scaling_factors: list[float]
# dimension to apply rotary embedding.
rot_dim: int
# offsets to the sin_cos_cache for each lora_id loaded.
# This value is dynamically modified.
offsets_by_lora_id: dict[int, int] = field(default_factory=dict)
def get_lora_id():
global _GLOBAL_LORA_ID
_GLOBAL_LORA_ID += 1
@@ -80,20 +65,16 @@ class LoRAModel(AdapterModel):
lora_model_id: int,
rank: int,
loras: dict[str, LoRALayerWeights],
scaling_factor: Optional[float] = None,
) -> None:
"""
Args:
lora_model_id: The integer id for the lora model.
rank: lora rank.
loras: module name -> weights for lora-replaced layers.
scaling_factor: Scaling factor to support long context lora model.
None if the lora is not tuned for long context support.
"""
self.id = lora_model_id
# Scaling factor for long context lora model. None if it is not
# fine tuned for the long context.
self.scaling_factor = scaling_factor
assert (
lora_model_id
> 0), f"a valid lora id should be greater than 0, got {self.id}"
@@ -192,10 +173,7 @@ class LoRAModel(AdapterModel):
for lora in loras.values():
lora.optimize()
return cls(lora_model_id,
peft_helper.r,
loras,
scaling_factor=peft_helper.vllm_long_context_scaling_factor)
return cls(lora_model_id, peft_helper.r, loras)
@classmethod
def from_local_checkpoint(
@@ -360,24 +338,17 @@ class LoRAModelManager(AdapterModelManager):
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
self.lora_index_to_id: list[Optional[int]] = [None] * self.lora_slots
self.vocab_size = vocab_size
self.long_lora_context: Optional[LongContextLoRAContext] = None
self.punica_wrapper = get_punica_wrapper(
max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device,
max_loras=self.lora_config.max_loras)
# Scaling factor -> offset to the sin_cos_cache to it.
# Used for long context lora.
self.scaling_factor_to_offset: dict[float, int] = {}
super().__init__(model)
self.supported_lora_modules = get_supported_lora_modules(self.model)
assert self.supported_lora_modules, "No supported LoRA modules found in"
f" {self.model.__class__.__name__}."
if lora_config.long_lora_scaling_factors:
# We need to replace rotary emb layer to do batch computation
# for long lora.
self.supported_lora_modules.append("rotary_emb")
self.packed_modules_mapping = get_packed_modules_mapping(self.model)
# Used to indicate whether the model is a multimodal model
@@ -454,25 +425,9 @@ class LoRAModelManager(AdapterModelManager):
except ValueError:
pass
def _set_long_lora_context(self, lora: LoRAModel):
if self.long_lora_context is None:
return
if lora.scaling_factor is None:
return
if (lora.scaling_factor not in self.scaling_factor_to_offset):
raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
" has not been initialized.")
offsets = self.scaling_factor_to_offset.get(lora.scaling_factor)
if offsets:
self.long_lora_context.offsets_by_lora_id[lora.id] = offsets
def _add_adapter(self, lora: LoRAModel):
self._create_merged_loras_inplace(lora)
self._registered_adapters[lora.id] = lora
self._set_long_lora_context(lora)
def pin_adapter(self, lora_id: int) -> bool:
"""Pin a LoRAModel in the manager cache."""
@@ -488,7 +443,6 @@ class LoRAModelManager(AdapterModelManager):
self.lora_slots + 1,
self.vocab_size,
self.lora_config.lora_extra_vocab_size,
self.long_lora_context,
)
def remove_all_adapters(self):
@@ -528,13 +482,6 @@ class LoRAModelManager(AdapterModelManager):
from_layer(module, self.lora_slots, self.lora_config,
packed_moduled_lst, self.model.config))
# LinearScalingRotaryEmbeddingWithLoRA is used to handle
# long context lora. Register relevant metadata.
if isinstance(new_module, LinearScalingRotaryEmbeddingWithLoRA):
self.long_lora_context = LongContextLoRAContext(
new_module.scaling_factors, new_module.rotary_dim)
self.scaling_factor_to_offset = \
new_module.scaling_factor_to_offset
# (yard1): TODO make this more robust
if "lm_head" in module_name:
logits_processor_module_name = 'logits_processor'
@@ -574,15 +521,13 @@ class LoRAModelManager(AdapterModelManager):
self,
lora_id: int,
rank: int,
scaling_factor: Optional[float],
embedding_modules: Optional[dict[str, str]] = None) -> LoRAModel:
"""Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {}, scaling_factor)
model = LoRAModel(lora_id, rank, {})
for module_name, module in self.model.named_modules():
bias_enabled = self.lora_config.bias_enabled
if (not self._match_target_modules(module_name)
or not isinstance(module, BaseLayerWithLoRA)
or isinstance(module, LinearScalingRotaryEmbeddingWithLoRA)
or self._filter_unsupported_mm_module(module_name)):
continue
parts = module_name.split(".")
@@ -723,11 +668,8 @@ class LoRAModelManager(AdapterModelManager):
self._deactivate_adapter)
def add_adapter(self, adapter: LoRAModel) -> bool:
logger.debug(
"Adding lora. Model id: %d, "
"int id: %d, "
"scaling factor: %s", adapter.id, adapter.id,
adapter.scaling_factor)
logger.debug("Adding lora. Model id: %d, "
"int id: %d", adapter.id, adapter.id)
return add_adapter(adapter, self._registered_adapters, self.capacity,
self._add_adapter)
@@ -772,10 +714,8 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
def add_adapter(self, lora: LoRAModel) -> bool:
"""Add a LoRAModel to the manager."""
logger.debug(
"Adding lora. Model id: %d, "
"int id: %d, "
"scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
logger.debug("Adding lora. Model id: %d, "
"int id: %d", lora.id, lora.id)
if lora.id not in self._registered_adapters:
self._add_adapter(lora)
was_added = True
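The net effect on `LoRAModel` construction, sketched outside the diff: the `scaling_factor` argument and the attribute backing it are gone, so a model is just an id, a rank, and a name-to-weights mapping.

```python
from vllm.lora.models import LoRAModel

# An empty weight dict is enough for a placeholder (e.g. warmup) adapter.
dummy = LoRAModel(lora_model_id=1, rank=8, loras={})
assert dummy.id == 1
assert not hasattr(dummy, "scaling_factor")
```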

View File

@@ -35,12 +35,9 @@ class PEFTHelper:
use_rslora: bool = field(default=False)
# True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
use_dora: bool = field(default=False)
# long context lora field
context_length: int = field(default=0)
# Extra vllm field, start with 'vllm_' to avoid conflict
vllm_lora_scaling_factor: float = field(default=1.0)
vllm_max_position_embeddings: Optional[int] = field(default=False)
vllm_long_context_scaling_factor: Optional[float] = field(default=None)
def _validate_features(self) -> list[str]:
"""
@@ -59,12 +56,6 @@ class PEFTHelper:
self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
else:
self.vllm_lora_scaling_factor = self.lora_alpha / self.r
if self.context_length:
if self.vllm_max_position_embeddings is None:
self.vllm_max_position_embeddings = self.context_length
self.vllm_long_context_scaling_factor = float(
math.ceil(self.context_length /
self.vllm_max_position_embeddings))
@classmethod
def from_dict(cls, config_dict: dict) -> "PEFTHelper":
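What remains of the scaling logic in `PEFTHelper` is the plain LoRA/RSLoRA rule; a worked example with illustrative numbers (not taken from any particular adapter):

```python
import math

lora_alpha, r = 16, 8

standard_scaling = lora_alpha / r            # plain LoRA: 2.0
rslora_scaling = lora_alpha / math.sqrt(r)   # use_rslora=True: ~5.657

print(standard_scaling, rslora_scaling)
```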

View File

@@ -17,7 +17,6 @@ from .utils import compute_meta, convert_mapping
if TYPE_CHECKING:
# avoid circuit import
from vllm.lora.layers import LoRAMapping
from vllm.lora.models import LongContextLoRAContext
class PunicaWrapperABC(ABC):
@@ -33,7 +32,6 @@ class PunicaWrapperABC(ABC):
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional["LongContextLoRAContext"] = None,
**kwargs,
) -> None:
"""
@@ -144,14 +142,11 @@ class PunicaWrapperBase(PunicaWrapperABC):
max_num_batched_tokens,
dtype=torch.long,
device=device)
self._long_lora_indices = torch.empty(max_num_batched_tokens,
dtype=torch.long,
device=device)
# 5 is the number of indices tensors.
# 4 is the number of indices tensors.
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices,long_lora_indices
self.indices_len: list[Optional[int]] = [None] * 5
# embeddings_indices
self.indices_len: list[Optional[int]] = [None] * 4
# these attributes are the information required for sgmv kernel
self._seq_start_locs = torch.empty(max_batches,
dtype=torch.long,
@@ -176,14 +171,12 @@ class PunicaWrapperBase(PunicaWrapperABC):
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional["LongContextLoRAContext"] = None,
):
(
base_indices,
sampler_indices,
sampler_indices_padded,
embeddings_indices,
long_lora_offsets_tensor,
indices_len,
) = convert_mapping(
mapping,
@@ -192,7 +185,6 @@ class PunicaWrapperBase(PunicaWrapperABC):
vocab_size,
extra_vocab_size,
self.device,
long_lora_context,
)
self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
@@ -201,11 +193,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
self._embeddings_indices[:embeddings_indices.
shape[0], :embeddings_indices.shape[1]].copy_(
embeddings_indices)
if long_lora_offsets_tensor is not None:
self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
long_lora_offsets_tensor)
else:
self._long_lora_indices.zero_()
self.indices_len[:] = indices_len
def _update_prefill_metadata(self,
@@ -312,28 +300,13 @@ class PunicaWrapperBase(PunicaWrapperABC):
embeddings_indices_len = self.indices_len[3]
return self._embeddings_indices[:, :embeddings_indices_len]
@property
def long_lora_indices(self) -> torch.Tensor:
"""
This property provides access to the indices used for long context
lora, specifically for LinearScalingRotaryEmbeddingWithLoRA.
"""
long_lora_len = self.indices_len[4]
return self._long_lora_indices[:long_lora_len]
def update_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: list[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional["LongContextLoRAContext"] = None,
**kwargs):
def update_metadata(self, mapping: "LoRAMapping",
lora_index_to_id: list[Optional[int]], max_loras: int,
vocab_size: int, extra_vocab_size: int, **kwargs):
self._update_base_metadata(mapping, lora_index_to_id, max_loras,
vocab_size, extra_vocab_size,
long_lora_context)
vocab_size, extra_vocab_size)
if mapping.is_prefill:
# Update metadata required for prefill-related operators.
self._update_prefill_metadata(self.token_lora_indices)

View File

@@ -7,7 +7,7 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import TYPE_CHECKING, Optional, Union, final
from typing import Optional, Union, final
import torch
@@ -21,10 +21,6 @@ if HAS_TRITON:
from .punica_base import PunicaWrapperBase
if TYPE_CHECKING:
# avoid circuit import
from vllm.lora.models import LongContextLoRAContext
@final
class PunicaWrapperGPU(PunicaWrapperBase):
@@ -55,20 +51,13 @@ class PunicaWrapperGPU(PunicaWrapperBase):
max_num_prompts,
device=device)
def update_metadata(
self,
mapping: LoRAMapping,
lora_index_to_id: list[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional["LongContextLoRAContext"] = None,
**kwargs):
def update_metadata(self, mapping: LoRAMapping,
lora_index_to_id: list[Optional[int]], max_loras: int,
vocab_size: int, extra_vocab_size: int, **kwargs):
self.is_prefill = mapping.is_prefill
self._update_base_metadata(mapping, lora_index_to_id, max_loras,
vocab_size, extra_vocab_size,
long_lora_context)
vocab_size, extra_vocab_size)
# Prepare cuda kernel metadata tensors
self.token_mapping_meta.prepare_tensors(self.token_lora_indices)

View File

@@ -14,7 +14,6 @@ from vllm.lora.punica_wrapper.utils import convert_mapping
if TYPE_CHECKING:
# avoid circuit import
from vllm.lora.layers import LoRAMapping
from vllm.lora.models import LongContextLoRAContext
from .punica_base import PunicaWrapperBase
@@ -45,7 +44,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded,
True)
torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._long_lora_indices, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch,
True)
@@ -323,7 +321,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional["LongContextLoRAContext"] = None,
):
# Make sure we don't accidentally collect outside operations
xm.mark_step()
@@ -339,7 +336,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
sampler_indices,
sampler_indices_padded,
embeddings_indices,
long_lora_offsets_tensor,
indices_len,
) = convert_mapping(
mapping,
@@ -348,7 +344,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
vocab_size,
extra_vocab_size,
"cpu",
long_lora_context,
)
self._token_lora_indices = self._pad_to_shape(
base_indices, self._token_lora_indices.shape,
@@ -362,15 +357,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
self._embeddings_indices = self._pad_to_shape(
embeddings_indices, self._embeddings_indices.shape,
dims=2).to(self.device)
if long_lora_offsets_tensor is not None:
self._long_lora_indices = self._pad_to_shape(
long_lora_offsets_tensor,
self._long_lora_indices.shape,
dims=1).to(self.device)
else:
zeroed = torch.zeros_like(self._long_lora_indices.cpu(),
dtype=torch.int32)
self._long_lora_indices = zeroed.to(self.device)
self.indices_len[:] = indices_len
def _update_prefill_metadata(self,

View File

@@ -8,7 +8,6 @@ import torch
if TYPE_CHECKING:
# avoid circuit import
from vllm.lora.layers import LoRAMapping
from vllm.lora.models import LongContextLoRAContext
def compute_meta(
@@ -49,9 +48,7 @@ def convert_mapping(
vocab_size: int,
extra_vocab_size: int,
device: torch.device,
long_lora_context: Optional["LongContextLoRAContext"] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], list[int]]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[int]]:
"""Converts LoRAMapping to index tensors.
Args:
@@ -60,7 +57,6 @@
max_loras: Maximum number of LoRAs.
vocab_size: Model vocab size.
extra_vocab_size: Extra vocab size each LoRA can have.
long_lora_context: Passed if there are long context lora in a batch.
Returns:
A tuple of tensors:
@@ -78,21 +74,14 @@
requests to embedding indices. First row is for embeddings
added by the LoRAs, second row is for the LoRA.lora_a
embeddings.
long_lora_indices: Tensor of shape [batch_size] mapping
requests to RoPE offsets and rot dims for long LoRAs.
None if long context lora doesn't exist.
indices_len: List of lengths of the above tensors. It contains
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_indices).
embeddings_indices).
"""
index_mapping_indices: list[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
long_lora_offsets: Optional[torch.Tensor] = None
if long_lora_context:
long_lora_offsets = torch.zeros(len(index_mapping_indices),
device=device,
dtype=torch.long)
prompt_mapping: list[int] = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
@@ -104,20 +93,13 @@
if index_mapping_indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
lora_indices[i] = lora_idx
if long_lora_context:
assert long_lora_offsets is not None
lora_offset: int = long_lora_context.offsets_by_lora_id.get(
index_mapping_indices[i], 0)
long_lora_offsets[i] = lora_offset
indices_list: list[Union[list[int], torch.Tensor]] = [
index_mapping_indices,
lora_indices,
embedding_indices,
]
if long_lora_context:
assert long_lora_offsets is not None
indices_list.append(long_lora_offsets)
indices = torch.tensor(indices_list, dtype=torch.long, device=device)
prompt_mapping_tensor = torch.tensor(prompt_mapping,
dtype=torch.long,
@@ -136,11 +118,7 @@
sampler_indices_padded = torch.arange(
0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
sampler_indices_padded * len(sampler_indices_padded))
long_lora_indices = None
long_lora_indices_len: Optional[int] = None
if long_lora_context:
long_lora_indices = indices[3]
long_lora_indices_len = long_lora_indices.shape[-1]
# Contain length of indices tensors. Used to index into each tensor.
indices_len = [
base_indices.shape[-1],
@@ -148,17 +126,11 @@
sampler_indices_padded.shape[-1],
embeddings_indices.shape[-1],
]
if long_lora_indices_len is not None:
indices_len.append(long_lora_indices_len)
else:
# If long_lora doesn't exist,append None
indices_len.append(None)
return (
base_indices,
sampler_indices,
sampler_indices_padded,
embeddings_indices,
long_lora_indices,
indices_len,
)
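To see the simplified contract end to end, a hedged, self-contained sketch of calling `convert_mapping` after this change. It assumes `LoRAMapping`'s dataclass fields are `(index_mapping, prompt_mapping, is_prefill)` and uses a toy batch; the point is that the returned tuple and `indices_len` no longer carry a long-LoRA entry.

```python
import torch

from vllm.lora.layers import LoRAMapping
from vllm.lora.punica_wrapper.utils import convert_mapping

# Toy batch: two tokens routed to LoRA id 1, one token to the base model.
mapping = LoRAMapping(index_mapping=(1, 1, 0),
                      prompt_mapping=(1, 0),
                      is_prefill=True)
lora_index_to_id = [1, None, None, None]

(base_indices, sampler_indices, sampler_indices_padded,
 embeddings_indices, indices_len) = convert_mapping(
     mapping,
     lora_index_to_id,
     max_loras=4,
     vocab_size=32000,
     extra_vocab_size=256,
     device=torch.device("cpu"),
 )

# Four index tensors and four lengths: the long_lora slot is gone.
assert len(indices_len) == 4
```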

View File

@@ -22,7 +22,6 @@ from vllm.lora.fully_sharded_layers import (
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
LinearScalingRotaryEmbeddingWithLoRA,
LogitsProcessorWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLoRA,
@@ -56,7 +55,6 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA,
LinearScalingRotaryEmbeddingWithLoRA,
}

View File

@@ -154,7 +154,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
lora_request.lora_int_id)
else:
dummy_lora = self._adapter_manager.create_dummy_lora(
lora_request.lora_int_id, rank, 1, self.embedding_modules)
lora_request.lora_int_id, rank, self.embedding_modules)
if self._cached_dummy_lora is None:
self._cached_dummy_lora = dummy_lora
return self._adapter_manager.add_adapter(dummy_lora)