[V0 deprecation] Remove long context LoRA (#21169)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2025-07-19 17:15:41 +08:00 committed by GitHub
parent cf8cc32674
commit 1eaff27815
13 changed files with 35 additions and 301 deletions

View File

@@ -221,11 +221,6 @@ def phi2_lora_files():
     return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora")
-@pytest.fixture(scope="session")
-def long_context_lora_files_16k_1():
-    return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1")
 @pytest.fixture
 def llama_2_7b_engine_extra_embeddings():
     cleanup_dist_env_and_memory(shutdown_ray=True)

View File

@@ -38,8 +38,8 @@ ERROR_CASES = [
 ]
-def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path):
-    peft_helper = PEFTHelper.from_local_dir(long_context_lora_files_16k_1,
+def test_peft_helper_pass(sql_lora_files, tmp_path):
+    peft_helper = PEFTHelper.from_local_dir(sql_lora_files,
                                             max_position_embeddings=4096)
     lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
     peft_helper.validate_legal(lora_config)
@@ -56,15 +56,12 @@ def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path):
         "embed_tokens",
         "lm_head",
     ]
-    assert peft_helper.context_length == 16384
     assert peft_helper.vllm_max_position_embeddings == 4096
-    assert peft_helper.vllm_long_context_scaling_factor == float(
-        math.ceil(peft_helper.context_length /
-                  peft_helper.vllm_max_position_embeddings))
     # test RSLoRA
     rslora_config = dict(use_rslora=True)
     test_dir = tmp_path / "test_rslora"
-    shutil.copytree(long_context_lora_files_16k_1, test_dir)
+    shutil.copytree(sql_lora_files, test_dir)
     # Load and modify configuration
     config_path = test_dir / "adapter_config.json"
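For reference, the call pattern the updated test exercises looks roughly like the sketch below; it reuses the phi-2 test adapter named earlier in this diff (the repo behind the sql_lora_files fixture is not shown here) and assumes Hugging Face Hub access:

    from huggingface_hub import snapshot_download
    from vllm.lora.peft_helper import PEFTHelper

    # Download a small test adapter (repo id taken from the conftest hunk above).
    lora_files = snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora")
    # Parse its adapter_config.json and derive the vLLM-side fields.
    helper = PEFTHelper.from_local_dir(lora_files, max_position_embeddings=4096)
    print(helper.r, helper.vllm_lora_scaling_factor)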

View File

@@ -3014,12 +3014,7 @@ class LoRAConfig:
     (added to the base model vocabulary)."""
     lora_vocab_padding_size: ClassVar[int] = current_platform\
         .get_lora_vocab_padding_size()
-    long_lora_scaling_factors: Optional[tuple[float, ...]] = None
-    """Specify multiple scaling factors (which can be different from base model
-    scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters
-    trained with those scaling factors to be used at the same time. If not
-    specified, only adapters trained with the base model scaling factor are
-    allowed."""
     default_mm_loras: Optional[dict[str, str]] = None
     """Dictionary mapping specific modalities to LoRA model paths; this field
     is only applicable to multimodal models and should be leveraged when a
@@ -3052,7 +3047,6 @@ class LoRAConfig:
         factors.append(self.lora_dtype)
         factors.append(self.lora_extra_vocab_size)
         factors.append(self.lora_vocab_padding_size)
-        factors.append(self.long_lora_scaling_factors)
         factors.append(self.bias_enabled)
         hash_str = hashlib.md5(str(factors).encode(),
                                usedforsecurity=False).hexdigest()
@@ -3091,11 +3085,6 @@
         elif isinstance(self.lora_dtype, str):
             self.lora_dtype = getattr(torch, self.lora_dtype)
-    def verify_lora_support(self):
-        if self.long_lora_scaling_factors is not None and envs.VLLM_USE_V1:
-            raise ValueError(
-                "V1 LoRA does not support long LoRA, please use V0.")
 @config
 @dataclass(config=ConfigDict(arbitrary_types_allowed=True))
@@ -4564,7 +4553,6 @@ class VllmConfig:
         if self.lora_config is not None:
             self.lora_config.verify_with_cache_config(self.cache_config)
             self.lora_config.verify_with_model_config(self.model_config)
-            self.lora_config.verify_lora_support()
         if self.prompt_adapter_config is not None:
             self.prompt_adapter_config.verify_with_model_config(
                 self.model_config)
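The compute_hash factor list touched above just stringifies the remaining LoRA knobs and hashes them; a minimal standalone sketch of that pattern, with illustrative values rather than vLLM defaults:

    import hashlib

    # Stand-in values for the fields still hashed after this change.
    factors = [
        "torch.bfloat16",  # lora_dtype
        256,               # lora_extra_vocab_size
        256,               # lora_vocab_padding_size
        False,             # bias_enabled
    ]
    # Same call pattern as in the diff: md5 over the stringified list.
    hash_str = hashlib.md5(str(factors).encode(),
                           usedforsecurity=False).hexdigest()
    print(hash_str)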

View File

@@ -358,8 +358,6 @@ class EngineArgs:
     max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
     lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
     lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
-    long_lora_scaling_factors: Optional[tuple[float, ...]] = \
-        LoRAConfig.long_lora_scaling_factors
     # PromptAdapter fields
     enable_prompt_adapter: bool = False
     max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
@@ -723,8 +721,6 @@ class EngineArgs:
             "--lora-dtype",
             **lora_kwargs["lora_dtype"],
         )
-        lora_group.add_argument("--long-lora-scaling-factors",
-                                **lora_kwargs["long_lora_scaling_factors"])
         lora_group.add_argument("--max-cpu-loras",
                                 **lora_kwargs["max_cpu_loras"])
         lora_group.add_argument("--fully-sharded-loras",
@@ -1245,7 +1241,6 @@ class EngineArgs:
             default_mm_loras=self.default_mm_loras,
             fully_sharded_loras=self.fully_sharded_loras,
             lora_extra_vocab_size=self.lora_extra_vocab_size,
-            long_lora_scaling_factors=self.long_lora_scaling_factors,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
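With the long-LoRA flag gone, the LoRA surface of EngineArgs reduces to the fields visible in this hunk; a hedged sketch of programmatic use (the model name and the concrete values are placeholders, not taken from this diff):

    from vllm.engine.arg_utils import EngineArgs

    # Only LoRA-related fields that still exist after this commit are set here.
    engine_args = EngineArgs(
        model="facebook/opt-125m",  # placeholder model
        enable_lora=True,
        max_cpu_loras=4,
        fully_sharded_loras=False,
        lora_dtype="auto",
        lora_extra_vocab_size=256,
    )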

View File

@@ -28,8 +28,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 # yapf: enable
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.rotary_embedding import (
-    LinearScalingRotaryEmbedding, RotaryEmbedding)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.platforms import current_platform
@@ -1193,91 +1191,3 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
     ) -> bool:
         # Special handling for the LogitsProcessor.
         return False
-class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA):
-    """Implements RoPE-scaled embeddings with linear scaling for
-    multiple LoRA adapters with a specialized kernel.
-    Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding
-    which can handle multi lora adapters in a specialized kernel.
-    """
-    def __init__(self, base_layer: RotaryEmbedding) -> None:
-        super().__init__()
-        self.base_layer = base_layer
-    @property
-    def scaling_factors(self):
-        return self.base_layer.scaling_factors
-    @property
-    def rotary_dim(self):
-        return self.base_layer.rotary_dim
-    def create_lora_weights(
-        self,
-        max_loras: int,
-        lora_config: LoRAConfig,
-        model_config: Optional[PretrainedConfig] = None,
-    ) -> None:
-        scaling_factors = (list(lora_config.long_lora_scaling_factors)
-                           if lora_config.long_lora_scaling_factors else [])
-        base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
-            self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
-        scaling_factors = sorted(
-            list(set([base_scaling_factor] + scaling_factors)))
-        self.base_layer = LinearScalingRotaryEmbedding(
-            self.base_layer.head_size,
-            self.base_layer.rotary_dim,
-            self.base_layer.max_position_embeddings,
-            self.base_layer.base,
-            self.base_layer.is_neox_style,
-            scaling_factors,
-            self.base_layer.dtype,
-        )
-    def reset_lora(self, index: int):
-        ...
-    def set_lora(
-        self,
-        index: int,
-        lora_a: torch.Tensor,
-        lora_b: torch.Tensor,
-        embeddings_tensor: Optional[torch.Tensor],
-        bias: Optional[torch.Tensor] = None,
-    ):
-        ...
-    def forward(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        return self.base_layer(
-            positions,
-            query,
-            key,
-            offsets=self.punica_wrapper.long_lora_indices,
-        )
-    @property
-    def scaling_factor_to_offset(self) -> dict[float, int]:
-        return self.base_layer.scaling_factor_to_offset
-    @classmethod
-    def can_replace_layer(
-        cls,
-        source_layer: nn.Module,
-        lora_config: LoRAConfig,
-        packed_modules_list: list,
-        model_config: Optional[PretrainedConfig],
-    ) -> bool:
-        """Returns True if the layer can be replaced by this LoRA layer."""
-        return (type(source_layer) is LinearScalingRotaryEmbedding
-                or type(source_layer) is RotaryEmbedding)
-    def extra_repr(self) -> str:
-        return self.base_layer.extra_repr()

View File

@@ -4,7 +4,6 @@
 import math
 import os
 from collections.abc import Sequence
-from dataclasses import dataclass, field
 from typing import Any, Callable, Optional, Union
 import regex as re
@@ -19,9 +18,7 @@ from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
                                         remove_adapter, set_adapter_mapping)
 from vllm.config import LoRAConfig
 from vllm.logger import init_logger
-from vllm.lora.layers import (BaseLayerWithLoRA,
-                              LinearScalingRotaryEmbeddingWithLoRA,
-                              LoRAMapping)
+from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.punica_wrapper import get_punica_wrapper
@@ -43,18 +40,6 @@ logger = init_logger(__name__)
 _GLOBAL_LORA_ID = 0
-@dataclass
-class LongContextLoRAContext:
-    """Context for lora adapters that support long context."""
-    # The scaling factors to support long context lora fine tuned models.
-    scaling_factors: list[float]
-    # dimension to apply rotary embedding.
-    rot_dim: int
-    # offsets to the sin_cos_cache for each lora_id loaded.
-    # This value is dynamically modified.
-    offsets_by_lora_id: dict[int, int] = field(default_factory=dict)
 def get_lora_id():
     global _GLOBAL_LORA_ID
     _GLOBAL_LORA_ID += 1
@@ -80,20 +65,16 @@ class LoRAModel(AdapterModel):
         lora_model_id: int,
         rank: int,
         loras: dict[str, LoRALayerWeights],
-        scaling_factor: Optional[float] = None,
     ) -> None:
         """
         Args:
             lora_model_id: The integer id for the lora model.
             rank: lora rank.
             loras: module name -> weights for lora-replaced layers.
-            scaling_factor: Scaling factor to support long context lora model.
-                None if the lora is not tuned for long context support.
         """
         self.id = lora_model_id
-        # Scaling factor for long context lora model. None if it is not
-        # fine tuned for the long context.
-        self.scaling_factor = scaling_factor
         assert (
             lora_model_id
             > 0), f"a valid lora id should be greater than 0, got {self.id}"
@@ -192,10 +173,7 @@ class LoRAModel(AdapterModel):
         for lora in loras.values():
             lora.optimize()
-        return cls(lora_model_id,
-                   peft_helper.r,
-                   loras,
-                   scaling_factor=peft_helper.vllm_long_context_scaling_factor)
+        return cls(lora_model_id, peft_helper.r, loras)
     @classmethod
     def from_local_checkpoint(
@@ -360,24 +338,17 @@ class LoRAModelManager(AdapterModelManager):
         self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
         self.lora_index_to_id: list[Optional[int]] = [None] * self.lora_slots
         self.vocab_size = vocab_size
-        self.long_lora_context: Optional[LongContextLoRAContext] = None
         self.punica_wrapper = get_punica_wrapper(
             max_num_batched_tokens,
             max_batches=self.max_num_seqs,
             device=self.device,
             max_loras=self.lora_config.max_loras)
-        # Scaling factor -> offset to the sin_cos_cache to it.
-        # Used for long context lora.
-        self.scaling_factor_to_offset: dict[float, int] = {}
         super().__init__(model)
         self.supported_lora_modules = get_supported_lora_modules(self.model)
         assert self.supported_lora_modules, "No supported LoRA modules found in"
         f" {self.model.__class__.__name__}."
-        if lora_config.long_lora_scaling_factors:
-            # We need to replace rotary emb layer to do batch computation
-            # for long lora.
-            self.supported_lora_modules.append("rotary_emb")
         self.packed_modules_mapping = get_packed_modules_mapping(self.model)
         # Used to indicate whether the model is a multimodal model
@@ -454,25 +425,9 @@
         except ValueError:
             pass
-    def _set_long_lora_context(self, lora: LoRAModel):
-        if self.long_lora_context is None:
-            return
-        if lora.scaling_factor is None:
-            return
-        if (lora.scaling_factor not in self.scaling_factor_to_offset):
-            raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
-                             " has not been initialized.")
-        offsets = self.scaling_factor_to_offset.get(lora.scaling_factor)
-        if offsets:
-            self.long_lora_context.offsets_by_lora_id[lora.id] = offsets
     def _add_adapter(self, lora: LoRAModel):
         self._create_merged_loras_inplace(lora)
         self._registered_adapters[lora.id] = lora
-        self._set_long_lora_context(lora)
     def pin_adapter(self, lora_id: int) -> bool:
         """Pin a LoRAModel in the manager cache."""
@@ -488,7 +443,6 @@
             self.lora_slots + 1,
             self.vocab_size,
             self.lora_config.lora_extra_vocab_size,
-            self.long_lora_context,
         )
     def remove_all_adapters(self):
@@ -528,13 +482,6 @@
                     from_layer(module, self.lora_slots, self.lora_config,
                                packed_moduled_lst, self.model.config))
-                # LinearScalingRotaryEmbeddingWithLoRA is used to handle
-                # long context lora. Register relevant metadata.
-                if isinstance(new_module, LinearScalingRotaryEmbeddingWithLoRA):
-                    self.long_lora_context = LongContextLoRAContext(
-                        new_module.scaling_factors, new_module.rotary_dim)
-                    self.scaling_factor_to_offset = \
-                        new_module.scaling_factor_to_offset
                 # (yard1): TODO make this more robust
                 if "lm_head" in module_name:
                     logits_processor_module_name = 'logits_processor'
@@ -574,15 +521,13 @@
             self,
             lora_id: int,
             rank: int,
-            scaling_factor: Optional[float],
             embedding_modules: Optional[dict[str, str]] = None) -> LoRAModel:
         """Create zero-initialized LoRAModel for warmup."""
-        model = LoRAModel(lora_id, rank, {}, scaling_factor)
+        model = LoRAModel(lora_id, rank, {})
         for module_name, module in self.model.named_modules():
             bias_enabled = self.lora_config.bias_enabled
             if (not self._match_target_modules(module_name)
                     or not isinstance(module, BaseLayerWithLoRA)
-                    or isinstance(module, LinearScalingRotaryEmbeddingWithLoRA)
                     or self._filter_unsupported_mm_module(module_name)):
                 continue
             parts = module_name.split(".")
@@ -723,11 +668,8 @@
             self._deactivate_adapter)
     def add_adapter(self, adapter: LoRAModel) -> bool:
-        logger.debug(
-            "Adding lora. Model id: %d, "
-            "int id: %d, "
-            "scaling factor: %s", adapter.id, adapter.id,
-            adapter.scaling_factor)
+        logger.debug("Adding lora. Model id: %d, "
+                     "int id: %d", adapter.id, adapter.id)
         return add_adapter(adapter, self._registered_adapters, self.capacity,
                            self._add_adapter)
@@ -772,10 +714,8 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
     def add_adapter(self, lora: LoRAModel) -> bool:
         """Add a LoRAModel to the manager."""
-        logger.debug(
-            "Adding lora. Model id: %d, "
-            "int id: %d, "
-            "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
+        logger.debug("Adding lora. Model id: %d, "
+                     "int id: %d", lora.id, lora.id)
         if lora.id not in self._registered_adapters:
             self._add_adapter(lora)
             was_added = True
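The constructor change above means a LoRAModel is now built from just an id, a rank and the per-module weights; a hedged sketch of the new call shape, mirroring the empty-weights warmup construction in create_dummy_lora:

    from vllm.lora.models import LoRAModel

    # scaling_factor is gone from the signature; only id, rank and the
    # module-name -> weights mapping remain.
    dummy = LoRAModel(lora_model_id=1, rank=8, loras={})
    assert dummy.id == 1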

View File

@@ -35,12 +35,9 @@ class PEFTHelper:
     use_rslora: bool = field(default=False)
     # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
     use_dora: bool = field(default=False)
-    # long context lora field
-    context_length: int = field(default=0)
     # Extra vllm field, start with 'vllm_' to avoid conflict
     vllm_lora_scaling_factor: float = field(default=1.0)
     vllm_max_position_embeddings: Optional[int] = field(default=False)
-    vllm_long_context_scaling_factor: Optional[float] = field(default=None)
     def _validate_features(self) -> list[str]:
         """
@@ -59,12 +56,6 @@
             self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
         else:
             self.vllm_lora_scaling_factor = self.lora_alpha / self.r
-        if self.context_length:
-            if self.vllm_max_position_embeddings is None:
-                self.vllm_max_position_embeddings = self.context_length
-            self.vllm_long_context_scaling_factor = float(
-                math.ceil(self.context_length /
-                          self.vllm_max_position_embeddings))
     @classmethod
     def from_dict(cls, config_dict: dict) -> "PEFTHelper":
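The scaling logic that remains above reduces to two formulas, lora_alpha / sqrt(r) with use_rslora and lora_alpha / r otherwise; a small worked example with illustrative values:

    import math

    lora_alpha, r = 16, 8
    # use_rslora=True: alpha / sqrt(r)
    print(lora_alpha / math.sqrt(r))  # ~5.66
    # plain LoRA: alpha / r
    print(lora_alpha / r)             # 2.0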

View File

@@ -17,7 +17,6 @@ from .utils import compute_meta, convert_mapping
 if TYPE_CHECKING:
     # avoid circuit import
     from vllm.lora.layers import LoRAMapping
-    from vllm.lora.models import LongContextLoRAContext
 class PunicaWrapperABC(ABC):
@@ -33,7 +32,6 @@
         max_loras: int,
         vocab_size: int,
         extra_vocab_size: int,
-        long_lora_context: Optional["LongContextLoRAContext"] = None,
         **kwargs,
     ) -> None:
         """
@@ -144,14 +142,11 @@ class PunicaWrapperBase(PunicaWrapperABC):
                                               max_num_batched_tokens,
                                               dtype=torch.long,
                                               device=device)
-        self._long_lora_indices = torch.empty(max_num_batched_tokens,
-                                              dtype=torch.long,
-                                              device=device)
-        # 5 is the number of indices tensors.
+        # 4 is the number of indices tensors.
         # base_indices, sampler_indices, sampler_indices_padded,
-        # embeddings_indices,long_lora_indices
-        self.indices_len: list[Optional[int]] = [None] * 5
+        # embeddings_indices
+        self.indices_len: list[Optional[int]] = [None] * 4
         # these attributes are the information required for sgmv kernel
         self._seq_start_locs = torch.empty(max_batches,
                                            dtype=torch.long,
@@ -176,14 +171,12 @@ class PunicaWrapperBase(PunicaWrapperABC):
         max_loras: int,
         vocab_size: int,
         extra_vocab_size: int,
-        long_lora_context: Optional["LongContextLoRAContext"] = None,
     ):
         (
             base_indices,
             sampler_indices,
             sampler_indices_padded,
             embeddings_indices,
-            long_lora_offsets_tensor,
             indices_len,
         ) = convert_mapping(
             mapping,
@@ -192,7 +185,6 @@ class PunicaWrapperBase(PunicaWrapperABC):
             vocab_size,
             extra_vocab_size,
             self.device,
-            long_lora_context,
         )
         self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
         self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
@@ -201,11 +193,7 @@
         self._embeddings_indices[:embeddings_indices.
                                  shape[0], :embeddings_indices.shape[1]].copy_(
                                      embeddings_indices)
-        if long_lora_offsets_tensor is not None:
-            self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
-                long_lora_offsets_tensor)
-        else:
-            self._long_lora_indices.zero_()
         self.indices_len[:] = indices_len
     def _update_prefill_metadata(self,
@@ -312,28 +300,13 @@
         embeddings_indices_len = self.indices_len[3]
         return self._embeddings_indices[:, :embeddings_indices_len]
-    @property
-    def long_lora_indices(self) -> torch.Tensor:
-        """
-        This property provides access to the indices used for long context
-        lora, specifically for LinearScalingRotaryEmbeddingWithLoRA.
-        """
-        long_lora_len = self.indices_len[4]
-        return self._long_lora_indices[:long_lora_len]
-    def update_metadata(
-            self,
-            mapping: "LoRAMapping",
-            lora_index_to_id: list[Optional[int]],
-            max_loras: int,
-            vocab_size: int,
-            extra_vocab_size: int,
-            long_lora_context: Optional["LongContextLoRAContext"] = None,
-            **kwargs):
+    def update_metadata(self, mapping: "LoRAMapping",
+                        lora_index_to_id: list[Optional[int]], max_loras: int,
+                        vocab_size: int, extra_vocab_size: int, **kwargs):
         self._update_base_metadata(mapping, lora_index_to_id, max_loras,
-                                   vocab_size, extra_vocab_size,
-                                   long_lora_context)
+                                   vocab_size, extra_vocab_size)
         if mapping.is_prefill:
             # Update metadata required for prefill-related operators.
             self._update_prefill_metadata(self.token_lora_indices)

View File

@@ -7,7 +7,7 @@ Punica: Multi-Tenant LoRA Serving.
 https://arxiv.org/abs/2310.18547
 """
-from typing import TYPE_CHECKING, Optional, Union, final
+from typing import Optional, Union, final
 import torch
@@ -21,10 +21,6 @@ if HAS_TRITON:
 from .punica_base import PunicaWrapperBase
-if TYPE_CHECKING:
-    # avoid circuit import
-    from vllm.lora.models import LongContextLoRAContext
 @final
 class PunicaWrapperGPU(PunicaWrapperBase):
@@ -55,20 +51,13 @@
                                               max_num_prompts,
                                               device=device)
-    def update_metadata(
-            self,
-            mapping: LoRAMapping,
-            lora_index_to_id: list[Optional[int]],
-            max_loras: int,
-            vocab_size: int,
-            extra_vocab_size: int,
-            long_lora_context: Optional["LongContextLoRAContext"] = None,
-            **kwargs):
+    def update_metadata(self, mapping: LoRAMapping,
+                        lora_index_to_id: list[Optional[int]], max_loras: int,
+                        vocab_size: int, extra_vocab_size: int, **kwargs):
         self.is_prefill = mapping.is_prefill
         self._update_base_metadata(mapping, lora_index_to_id, max_loras,
-                                   vocab_size, extra_vocab_size,
-                                   long_lora_context)
+                                   vocab_size, extra_vocab_size)
         # Prepare cuda kernel metadata tensors
         self.token_mapping_meta.prepare_tensors(self.token_lora_indices)

View File

@@ -14,7 +14,6 @@ from vllm.lora.punica_wrapper.utils import convert_mapping
 if TYPE_CHECKING:
     # avoid circuit import
     from vllm.lora.layers import LoRAMapping
-    from vllm.lora.models import LongContextLoRAContext
 from .punica_base import PunicaWrapperBase
@@ -45,7 +44,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded,
                                                True)
         torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True)
-        torch.ops.xla.dynamo_set_buffer_donor_(self._long_lora_indices, True)
         torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch,
                                                True)
@@ -323,7 +321,6 @@
         max_loras: int,
         vocab_size: int,
         extra_vocab_size: int,
-        long_lora_context: Optional["LongContextLoRAContext"] = None,
     ):
         # Make sure we don't accidentally collect outside operations
         xm.mark_step()
@@ -339,7 +336,6 @@
             sampler_indices,
             sampler_indices_padded,
             embeddings_indices,
-            long_lora_offsets_tensor,
             indices_len,
         ) = convert_mapping(
             mapping,
@@ -348,7 +344,6 @@
             vocab_size,
             extra_vocab_size,
             "cpu",
-            long_lora_context,
         )
         self._token_lora_indices = self._pad_to_shape(
             base_indices, self._token_lora_indices.shape,
@@ -362,15 +357,6 @@
         self._embeddings_indices = self._pad_to_shape(
             embeddings_indices, self._embeddings_indices.shape,
             dims=2).to(self.device)
-        if long_lora_offsets_tensor is not None:
-            self._long_lora_indices = self._pad_to_shape(
-                long_lora_offsets_tensor,
-                self._long_lora_indices.shape,
-                dims=1).to(self.device)
-        else:
-            zeroed = torch.zeros_like(self._long_lora_indices.cpu(),
-                                      dtype=torch.int32)
-            self._long_lora_indices = zeroed.to(self.device)
         self.indices_len[:] = indices_len
     def _update_prefill_metadata(self,

View File

@@ -8,7 +8,6 @@ import torch
 if TYPE_CHECKING:
     # avoid circuit import
     from vllm.lora.layers import LoRAMapping
-    from vllm.lora.models import LongContextLoRAContext
 def compute_meta(
@@ -49,9 +48,7 @@ def convert_mapping(
     vocab_size: int,
     extra_vocab_size: int,
     device: torch.device,
-    long_lora_context: Optional["LongContextLoRAContext"] = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
-           Optional[torch.Tensor], list[int]]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[int]]:
     """Converts LoRAMapping to index tensors.
     Args:
@@ -60,7 +57,6 @@
         max_loras: Maximum number of LoRAs.
         vocab_size: Model vocab size.
         extra_vocab_size: Extra vocab size each LoRA can have.
-        long_lora_context: Passed if there are long context lora in a batch.
     Returns:
         A tuple of tensors:
@@ -78,21 +74,14 @@
             requests to embedding indices. First row is for embeddings
             added by the LoRAs, second row is for the LoRA.lora_a
             embeddings.
-        long_lora_indices: Tensor of shape [batch_size] mapping
-            requests to RoPE offsets and rot dims for long LoRAs.
-            None if long context lora doesn't exist.
         indices_len: List of lengths of the above tensors. It contains
             (base_indices, sampler_indices, sampler_indices_padded,
-            embeddings_indices, long_lora_indices).
+            embeddings_indices).
     """
     index_mapping_indices: list[int] = list(mapping.index_mapping).copy()
     embedding_indices = index_mapping_indices.copy()
     lora_indices = index_mapping_indices.copy()
-    long_lora_offsets: Optional[torch.Tensor] = None
-    if long_lora_context:
-        long_lora_offsets = torch.zeros(len(index_mapping_indices),
-                                        device=device,
-                                        dtype=torch.long)
     prompt_mapping: list[int] = [
         lora_index_to_id.index(x) if x > 0 else -1
         for x in mapping.prompt_mapping
@@ -104,20 +93,13 @@
                     if index_mapping_indices[i] > 0 else -1)
         embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
         lora_indices[i] = lora_idx
-        if long_lora_context:
-            assert long_lora_offsets is not None
-            lora_offset: int = long_lora_context.offsets_by_lora_id.get(
-                index_mapping_indices[i], 0)
-            long_lora_offsets[i] = lora_offset
     indices_list: list[Union[list[int], torch.Tensor]] = [
         index_mapping_indices,
         lora_indices,
        embedding_indices,
     ]
-    if long_lora_context:
-        assert long_lora_offsets is not None
-        indices_list.append(long_lora_offsets)
     indices = torch.tensor(indices_list, dtype=torch.long, device=device)
     prompt_mapping_tensor = torch.tensor(prompt_mapping,
                                          dtype=torch.long,
@@ -136,11 +118,7 @@
     sampler_indices_padded = torch.arange(
         0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
             sampler_indices_padded * len(sampler_indices_padded))
-    long_lora_indices = None
-    long_lora_indices_len: Optional[int] = None
-    if long_lora_context:
-        long_lora_indices = indices[3]
-        long_lora_indices_len = long_lora_indices.shape[-1]
     # Contain length of indices tensors. Used to index into each tensor.
     indices_len = [
         base_indices.shape[-1],
@@ -148,17 +126,11 @@
         sampler_indices.shape[-1],
         sampler_indices_padded.shape[-1],
         embeddings_indices.shape[-1],
     ]
-    if long_lora_indices_len is not None:
-        indices_len.append(long_lora_indices_len)
-    else:
-        # If long_lora doesn't exist,append None
-        indices_len.append(None)
     return (
         base_indices,
         sampler_indices,
         sampler_indices_padded,
         embeddings_indices,
-        long_lora_indices,
         indices_len,
     )
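After this change convert_mapping returns four index tensors plus their lengths; a self-contained sketch of the trimmed indices_len bookkeeping, using toy stand-in tensors rather than real mapping data:

    import torch

    # Toy stand-ins for the four index tensors named in the return tuple above.
    base_indices = torch.tensor([0, 0, 1])
    sampler_indices = torch.tensor([0, 1])
    sampler_indices_padded = torch.tensor([0, 1])
    embeddings_indices = torch.tensor([[0, 0, 1], [0, 0, 1]])

    # indices_len now holds exactly one length per tensor; the fifth
    # long-LoRA slot is gone.
    indices_len = [
        base_indices.shape[-1],
        sampler_indices.shape[-1],
        sampler_indices_padded.shape[-1],
        embeddings_indices.shape[-1],
    ]
    assert len(indices_len) == 4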

View File

@@ -22,7 +22,6 @@ from vllm.lora.fully_sharded_layers import (
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
-                              LinearScalingRotaryEmbeddingWithLoRA,
                               LogitsProcessorWithLoRA,
                               MergedColumnParallelLinearWithLoRA,
                               MergedQKVParallelLinearWithLoRA,
@@ -56,7 +55,6 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
     MergedColumnParallelLinearWithShardedLoRA,
     MergedQKVParallelLinearWithShardedLoRA,
     RowParallelLinearWithShardedLoRA,
-    LinearScalingRotaryEmbeddingWithLoRA,
 }

View File

@@ -154,7 +154,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
                     lora_request.lora_int_id)
             else:
                 dummy_lora = self._adapter_manager.create_dummy_lora(
-                    lora_request.lora_int_id, rank, 1, self.embedding_modules)
+                    lora_request.lora_int_id, rank, self.embedding_modules)
                 if self._cached_dummy_lora is None:
                     self._cached_dummy_lora = dummy_lora
             return self._adapter_manager.add_adapter(dummy_lora)