[V0 deprecation] Remove long context LoRA (#21169)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent cf8cc32674
commit 1eaff27815
@@ -221,11 +221,6 @@ def phi2_lora_files():
     return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora")


-@pytest.fixture(scope="session")
-def long_context_lora_files_16k_1():
-    return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1")
-
-
 @pytest.fixture
 def llama_2_7b_engine_extra_embeddings():
     cleanup_dist_env_and_memory(shutdown_ray=True)
@@ -38,8 +38,8 @@ ERROR_CASES = [
 ]


-def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path):
-    peft_helper = PEFTHelper.from_local_dir(long_context_lora_files_16k_1,
+def test_peft_helper_pass(sql_lora_files, tmp_path):
+    peft_helper = PEFTHelper.from_local_dir(sql_lora_files,
                                             max_position_embeddings=4096)
     lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
     peft_helper.validate_legal(lora_config)
@@ -56,15 +56,12 @@ def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path):
         "embed_tokens",
         "lm_head",
     ]
-    assert peft_helper.context_length == 16384
     assert peft_helper.vllm_max_position_embeddings == 4096
-    assert peft_helper.vllm_long_context_scaling_factor == float(
-        math.ceil(peft_helper.context_length /
-                  peft_helper.vllm_max_position_embeddings))

     # test RSLoRA
     rslora_config = dict(use_rslora=True)
     test_dir = tmp_path / "test_rslora"
-    shutil.copytree(long_context_lora_files_16k_1, test_dir)
+    shutil.copytree(sql_lora_files, test_dir)

     # Load and modify configuration
     config_path = test_dir / "adapter_config.json"
@@ -3014,12 +3014,7 @@ class LoRAConfig:
     (added to the base model vocabulary)."""
     lora_vocab_padding_size: ClassVar[int] = current_platform\
         .get_lora_vocab_padding_size()
-    long_lora_scaling_factors: Optional[tuple[float, ...]] = None
-    """Specify multiple scaling factors (which can be different from base model
-    scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters
-    trained with those scaling factors to be used at the same time. If not
-    specified, only adapters trained with the base model scaling factor are
-    allowed."""

     default_mm_loras: Optional[dict[str, str]] = None
     """Dictionary mapping specific modalities to LoRA model paths; this field
     is only applicable to multimodal models and should be leveraged when a
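
Side note on what this hunk removes: `long_lora_scaling_factors` was a public `LoRAConfig` field, so downstream code can probe for its absence instead of pinning versions. A minimal sketch, assuming vLLM is importable and that `LoRAConfig` continues to expose `__dataclass_fields__`; nothing below is part of this commit.

# Minimal sketch: detect whether the installed vLLM still carries the removed
# V0 long-context LoRA field (assumption: LoRAConfig remains a dataclass).
from vllm.config import LoRAConfig

has_long_lora = "long_lora_scaling_factors" in getattr(
    LoRAConfig, "__dataclass_fields__", {})
print("long-context LoRA config field present:", has_long_lora)
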
@@ -3052,7 +3047,6 @@ class LoRAConfig:
         factors.append(self.lora_dtype)
         factors.append(self.lora_extra_vocab_size)
         factors.append(self.lora_vocab_padding_size)
-        factors.append(self.long_lora_scaling_factors)
         factors.append(self.bias_enabled)
         hash_str = hashlib.md5(str(factors).encode(),
                                usedforsecurity=False).hexdigest()
@@ -3091,11 +3085,6 @@ class LoRAConfig:
         elif isinstance(self.lora_dtype, str):
             self.lora_dtype = getattr(torch, self.lora_dtype)

-    def verify_lora_support(self):
-        if self.long_lora_scaling_factors is not None and envs.VLLM_USE_V1:
-            raise ValueError(
-                "V1 LoRA does not support long LoRA, please use V0.")
-

 @config
 @dataclass(config=ConfigDict(arbitrary_types_allowed=True))
@@ -4564,7 +4553,6 @@ class VllmConfig:
         if self.lora_config is not None:
             self.lora_config.verify_with_cache_config(self.cache_config)
             self.lora_config.verify_with_model_config(self.model_config)
-            self.lora_config.verify_lora_support()
         if self.prompt_adapter_config is not None:
             self.prompt_adapter_config.verify_with_model_config(
                 self.model_config)
@@ -358,8 +358,6 @@ class EngineArgs:
     max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
     lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
     lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
-    long_lora_scaling_factors: Optional[tuple[float, ...]] = \
-        LoRAConfig.long_lora_scaling_factors
     # PromptAdapter fields
     enable_prompt_adapter: bool = False
     max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
@@ -723,8 +721,6 @@ class EngineArgs:
             "--lora-dtype",
             **lora_kwargs["lora_dtype"],
         )
-        lora_group.add_argument("--long-lora-scaling-factors",
-                                **lora_kwargs["long_lora_scaling_factors"])
         lora_group.add_argument("--max-cpu-loras",
                                 **lora_kwargs["max_cpu_loras"])
         lora_group.add_argument("--fully-sharded-loras",
@@ -1245,7 +1241,6 @@ class EngineArgs:
             default_mm_loras=self.default_mm_loras,
             fully_sharded_loras=self.fully_sharded_loras,
             lora_extra_vocab_size=self.lora_extra_vocab_size,
-            long_lora_scaling_factors=self.long_lora_scaling_factors,
             lora_dtype=self.lora_dtype,
             max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
             and self.max_cpu_loras > 0 else None) if self.enable_lora else None
@@ -28,8 +28,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 # yapf: enable
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.rotary_embedding import (
-    LinearScalingRotaryEmbedding, RotaryEmbedding)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.platforms import current_platform
@@ -1193,91 +1191,3 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
     ) -> bool:
         # Special handling for the LogitsProcessor.
         return False
-
-
-class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA):
-    """Implements RoPE-scaled embeddings with linear scaling for
-    multiple LoRA adapters with a specialized kernel.
-
-    Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding
-    which can handle multi lora adapters in a specialized kernel.
-    """
-
-    def __init__(self, base_layer: RotaryEmbedding) -> None:
-        super().__init__()
-        self.base_layer = base_layer
-
-    @property
-    def scaling_factors(self):
-        return self.base_layer.scaling_factors
-
-    @property
-    def rotary_dim(self):
-        return self.base_layer.rotary_dim
-
-    def create_lora_weights(
-        self,
-        max_loras: int,
-        lora_config: LoRAConfig,
-        model_config: Optional[PretrainedConfig] = None,
-    ) -> None:
-        scaling_factors = (list(lora_config.long_lora_scaling_factors)
-                           if lora_config.long_lora_scaling_factors else [])
-        base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
-            self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
-        scaling_factors = sorted(
-            list(set([base_scaling_factor] + scaling_factors)))
-        self.base_layer = LinearScalingRotaryEmbedding(
-            self.base_layer.head_size,
-            self.base_layer.rotary_dim,
-            self.base_layer.max_position_embeddings,
-            self.base_layer.base,
-            self.base_layer.is_neox_style,
-            scaling_factors,
-            self.base_layer.dtype,
-        )
-
-    def reset_lora(self, index: int):
-        ...
-
-    def set_lora(
-        self,
-        index: int,
-        lora_a: torch.Tensor,
-        lora_b: torch.Tensor,
-        embeddings_tensor: Optional[torch.Tensor],
-        bias: Optional[torch.Tensor] = None,
-    ):
-        ...
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        return self.base_layer(
-            positions,
-            query,
-            key,
-            offsets=self.punica_wrapper.long_lora_indices,
-        )
-
-    @property
-    def scaling_factor_to_offset(self) -> dict[float, int]:
-        return self.base_layer.scaling_factor_to_offset
-
-    @classmethod
-    def can_replace_layer(
-        cls,
-        source_layer: nn.Module,
-        lora_config: LoRAConfig,
-        packed_modules_list: list,
-        model_config: Optional[PretrainedConfig],
-    ) -> bool:
-        """Returns True if the layer can be replaced by this LoRA layer."""
-        return (type(source_layer) is LinearScalingRotaryEmbedding
-                or type(source_layer) is RotaryEmbedding)
-
-    def extra_repr(self) -> str:
-        return self.base_layer.extra_repr()
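
The class deleted above wrapped vLLM's linear RoPE scaling so that several adapters with different scaling factors could share one kernel, each factor getting its own offset into a concatenated sin/cos cache. A minimal, self-contained sketch of that offset bookkeeping; the function name, cache layout, and segment sizing are illustrative assumptions, not vLLM's actual implementation:

# Minimal sketch: map each linear-RoPE scaling factor to an offset into a
# hypothetical concatenated sin/cos cache, mirroring the bookkeeping the
# removed LinearScalingRotaryEmbeddingWithLoRA relied on. Illustrative only.
import math


def scaling_factor_to_offset(scaling_factors: list[float],
                             max_position_embeddings: int) -> dict[float, int]:
    """Each factor s owns a cache segment of ceil(s * max_position) rows;
    its offset is the total length of the segments that precede it."""
    offsets: dict[float, int] = {}
    cursor = 0
    for factor in sorted(set(scaling_factors)):
        offsets[factor] = cursor
        cursor += math.ceil(factor * max_position_embeddings)
    return offsets


if __name__ == "__main__":
    # e.g. a base model trained to 4096 positions, adapters scaled 1x and 4x.
    print(scaling_factor_to_offset([1.0, 4.0], 4096))
    # -> {1.0: 0, 4.0: 4096}
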
@@ -4,7 +4,6 @@
 import math
 import os
 from collections.abc import Sequence
-from dataclasses import dataclass, field
 from typing import Any, Callable, Optional, Union

 import regex as re
@@ -19,9 +18,7 @@ from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
                                         remove_adapter, set_adapter_mapping)
 from vllm.config import LoRAConfig
 from vllm.logger import init_logger
-from vllm.lora.layers import (BaseLayerWithLoRA,
-                              LinearScalingRotaryEmbeddingWithLoRA,
-                              LoRAMapping)
+from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.punica_wrapper import get_punica_wrapper
@@ -43,18 +40,6 @@ logger = init_logger(__name__)
 _GLOBAL_LORA_ID = 0


-@dataclass
-class LongContextLoRAContext:
-    """Context for lora adapters that support long context."""
-    # The scaling factors to support long context lora fine tuned models.
-    scaling_factors: list[float]
-    # dimension to apply rotary embedding.
-    rot_dim: int
-    # offsets to the sin_cos_cache for each lora_id loaded.
-    # This value is dynamically modified.
-    offsets_by_lora_id: dict[int, int] = field(default_factory=dict)
-
-
 def get_lora_id():
     global _GLOBAL_LORA_ID
     _GLOBAL_LORA_ID += 1
@@ -80,20 +65,16 @@ class LoRAModel(AdapterModel):
         lora_model_id: int,
         rank: int,
         loras: dict[str, LoRALayerWeights],
-        scaling_factor: Optional[float] = None,
     ) -> None:
         """
         Args:
             lora_model_id: The integer id for the lora model.
             rank: lora rank.
             loras: module name -> weights for lora-replaced layers.
-            scaling_factor: Scaling factor to support long context lora model.
-                None if the lora is not tuned for long context support.

         """
         self.id = lora_model_id
-        # Scaling factor for long context lora model. None if it is not
-        # fine tuned for the long context.
-        self.scaling_factor = scaling_factor

         assert (
             lora_model_id
             > 0), f"a valid lora id should be greater than 0, got {self.id}"
@@ -192,10 +173,7 @@ class LoRAModel(AdapterModel):
         for lora in loras.values():
             lora.optimize()

-        return cls(lora_model_id,
-                   peft_helper.r,
-                   loras,
-                   scaling_factor=peft_helper.vllm_long_context_scaling_factor)
+        return cls(lora_model_id, peft_helper.r, loras)

     @classmethod
     def from_local_checkpoint(
@@ -360,24 +338,17 @@ class LoRAModelManager(AdapterModelManager):
         self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
         self.lora_index_to_id: list[Optional[int]] = [None] * self.lora_slots
         self.vocab_size = vocab_size
-        self.long_lora_context: Optional[LongContextLoRAContext] = None
         self.punica_wrapper = get_punica_wrapper(
             max_num_batched_tokens,
             max_batches=self.max_num_seqs,
             device=self.device,
             max_loras=self.lora_config.max_loras)
-        # Scaling factor -> offset to the sin_cos_cache to it.
-        # Used for long context lora.
-        self.scaling_factor_to_offset: dict[float, int] = {}

         super().__init__(model)

         self.supported_lora_modules = get_supported_lora_modules(self.model)
         assert self.supported_lora_modules, "No supported LoRA modules found in"
         f" {self.model.__class__.__name__}."
-        if lora_config.long_lora_scaling_factors:
-            # We need to replace rotary emb layer to do batch computation
-            # for long lora.
-            self.supported_lora_modules.append("rotary_emb")

         self.packed_modules_mapping = get_packed_modules_mapping(self.model)
         # Used to indicate whether the model is a multimodal model
@@ -454,25 +425,9 @@ class LoRAModelManager(AdapterModelManager):
             except ValueError:
                 pass

-    def _set_long_lora_context(self, lora: LoRAModel):
-        if self.long_lora_context is None:
-            return
-
-        if lora.scaling_factor is None:
-            return
-
-        if (lora.scaling_factor not in self.scaling_factor_to_offset):
-            raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
-                             " has not been initialized.")
-
-        offsets = self.scaling_factor_to_offset.get(lora.scaling_factor)
-        if offsets:
-            self.long_lora_context.offsets_by_lora_id[lora.id] = offsets
-
     def _add_adapter(self, lora: LoRAModel):
         self._create_merged_loras_inplace(lora)
         self._registered_adapters[lora.id] = lora
-        self._set_long_lora_context(lora)

     def pin_adapter(self, lora_id: int) -> bool:
         """Pin a LoRAModel in the manager cache."""
@@ -488,7 +443,6 @@ class LoRAModelManager(AdapterModelManager):
             self.lora_slots + 1,
             self.vocab_size,
             self.lora_config.lora_extra_vocab_size,
-            self.long_lora_context,
         )

     def remove_all_adapters(self):
@@ -528,13 +482,6 @@ class LoRAModelManager(AdapterModelManager):
                 from_layer(module, self.lora_slots, self.lora_config,
                            packed_moduled_lst, self.model.config))

-            # LinearScalingRotaryEmbeddingWithLoRA is used to handle
-            # long context lora. Register relevant metadata.
-            if isinstance(new_module, LinearScalingRotaryEmbeddingWithLoRA):
-                self.long_lora_context = LongContextLoRAContext(
-                    new_module.scaling_factors, new_module.rotary_dim)
-                self.scaling_factor_to_offset = \
-                    new_module.scaling_factor_to_offset
             # (yard1): TODO make this more robust
             if "lm_head" in module_name:
                 logits_processor_module_name = 'logits_processor'
@@ -574,15 +521,13 @@ class LoRAModelManager(AdapterModelManager):
             self,
             lora_id: int,
             rank: int,
-            scaling_factor: Optional[float],
             embedding_modules: Optional[dict[str, str]] = None) -> LoRAModel:
         """Create zero-initialized LoRAModel for warmup."""
-        model = LoRAModel(lora_id, rank, {}, scaling_factor)
+        model = LoRAModel(lora_id, rank, {})
         for module_name, module in self.model.named_modules():
             bias_enabled = self.lora_config.bias_enabled
             if (not self._match_target_modules(module_name)
                     or not isinstance(module, BaseLayerWithLoRA)
-                    or isinstance(module, LinearScalingRotaryEmbeddingWithLoRA)
                     or self._filter_unsupported_mm_module(module_name)):
                 continue
             parts = module_name.split(".")
@@ -723,11 +668,8 @@ class LoRAModelManager(AdapterModelManager):
                                   self._deactivate_adapter)

     def add_adapter(self, adapter: LoRAModel) -> bool:
-        logger.debug(
-            "Adding lora. Model id: %d, "
-            "int id: %d, "
-            "scaling factor: %s", adapter.id, adapter.id,
-            adapter.scaling_factor)
+        logger.debug("Adding lora. Model id: %d, "
+                     "int id: %d", adapter.id, adapter.id)
         return add_adapter(adapter, self._registered_adapters, self.capacity,
                            self._add_adapter)
@@ -772,10 +714,8 @@ class LRUCacheLoRAModelManager(LoRAModelManager):

     def add_adapter(self, lora: LoRAModel) -> bool:
         """Add a LoRAModel to the manager."""
-        logger.debug(
-            "Adding lora. Model id: %d, "
-            "int id: %d, "
-            "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
+        logger.debug("Adding lora. Model id: %d, "
+                     "int id: %d", lora.id, lora.id)
         if lora.id not in self._registered_adapters:
             self._add_adapter(lora)
             was_added = True
@@ -35,12 +35,9 @@ class PEFTHelper:
     use_rslora: bool = field(default=False)
     # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
     use_dora: bool = field(default=False)
-    # long context lora field
-    context_length: int = field(default=0)
     # Extra vllm field, start with 'vllm_' to avoid conflict
     vllm_lora_scaling_factor: float = field(default=1.0)
     vllm_max_position_embeddings: Optional[int] = field(default=False)
-    vllm_long_context_scaling_factor: Optional[float] = field(default=None)

     def _validate_features(self) -> list[str]:
         """
@@ -59,12 +56,6 @@ class PEFTHelper:
             self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
         else:
             self.vllm_lora_scaling_factor = self.lora_alpha / self.r
-        if self.context_length:
-            if self.vllm_max_position_embeddings is None:
-                self.vllm_max_position_embeddings = self.context_length
-            self.vllm_long_context_scaling_factor = float(
-                math.ceil(self.context_length /
-                          self.vllm_max_position_embeddings))

     @classmethod
     def from_dict(cls, config_dict: dict) -> "PEFTHelper":
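
Worked example of the arithmetic the deleted branch performed: an adapter advertising a 16k context on a 4k base model yielded a long-context scaling factor of 4.0, which is the value the removed test assertion checked. Standalone sketch, not vLLM code; the variable names simply mirror the fields above.

import math

context_length = 16384                 # from the adapter's config
vllm_max_position_embeddings = 4096    # base model position limit
scaling = float(math.ceil(context_length / vllm_max_position_embeddings))
assert scaling == 4.0
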
@@ -17,7 +17,6 @@ from .utils import compute_meta, convert_mapping
 if TYPE_CHECKING:
     # avoid circuit import
     from vllm.lora.layers import LoRAMapping
-    from vllm.lora.models import LongContextLoRAContext


 class PunicaWrapperABC(ABC):
@@ -33,7 +32,6 @@ class PunicaWrapperABC(ABC):
         max_loras: int,
         vocab_size: int,
         extra_vocab_size: int,
-        long_lora_context: Optional["LongContextLoRAContext"] = None,
         **kwargs,
     ) -> None:
         """
@@ -144,14 +142,11 @@ class PunicaWrapperBase(PunicaWrapperABC):
                                               max_num_batched_tokens,
                                               dtype=torch.long,
                                               device=device)
-        self._long_lora_indices = torch.empty(max_num_batched_tokens,
-                                              dtype=torch.long,
-                                              device=device)

-        # 5 is the number of indices tensors.
+        # 4 is the number of indices tensors.
         # base_indices, sampler_indices, sampler_indices_padded,
-        # embeddings_indices,long_lora_indices
-        self.indices_len: list[Optional[int]] = [None] * 5
+        # embeddings_indices
+        self.indices_len: list[Optional[int]] = [None] * 4
         # these attributes are the information required for sgmv kernel
         self._seq_start_locs = torch.empty(max_batches,
                                            dtype=torch.long,
@@ -176,14 +171,12 @@ class PunicaWrapperBase(PunicaWrapperABC):
         max_loras: int,
         vocab_size: int,
         extra_vocab_size: int,
-        long_lora_context: Optional["LongContextLoRAContext"] = None,
     ):
         (
             base_indices,
             sampler_indices,
             sampler_indices_padded,
             embeddings_indices,
-            long_lora_offsets_tensor,
             indices_len,
         ) = convert_mapping(
             mapping,
@@ -192,7 +185,6 @@ class PunicaWrapperBase(PunicaWrapperABC):
             vocab_size,
             extra_vocab_size,
             self.device,
-            long_lora_context,
         )
         self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
         self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
@@ -201,11 +193,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
         self._embeddings_indices[:embeddings_indices.
                                  shape[0], :embeddings_indices.shape[1]].copy_(
                                      embeddings_indices)
-        if long_lora_offsets_tensor is not None:
-            self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
-                long_lora_offsets_tensor)
-        else:
-            self._long_lora_indices.zero_()

         self.indices_len[:] = indices_len

     def _update_prefill_metadata(self,
@@ -312,28 +300,13 @@ class PunicaWrapperBase(PunicaWrapperABC):
         embeddings_indices_len = self.indices_len[3]
         return self._embeddings_indices[:, :embeddings_indices_len]

-    @property
-    def long_lora_indices(self) -> torch.Tensor:
-        """
-        This property provides access to the indices used for long context
-        lora, specifically for LinearScalingRotaryEmbeddingWithLoRA.
-        """
-        long_lora_len = self.indices_len[4]
-        return self._long_lora_indices[:long_lora_len]
-
-    def update_metadata(
-            self,
-            mapping: "LoRAMapping",
-            lora_index_to_id: list[Optional[int]],
-            max_loras: int,
-            vocab_size: int,
-            extra_vocab_size: int,
-            long_lora_context: Optional["LongContextLoRAContext"] = None,
-            **kwargs):
+    def update_metadata(self, mapping: "LoRAMapping",
+                        lora_index_to_id: list[Optional[int]], max_loras: int,
+                        vocab_size: int, extra_vocab_size: int, **kwargs):

         self._update_base_metadata(mapping, lora_index_to_id, max_loras,
-                                   vocab_size, extra_vocab_size,
-                                   long_lora_context)
+                                   vocab_size, extra_vocab_size)

         if mapping.is_prefill:
             # Update metadata required for prefill-related operators.
             self._update_prefill_metadata(self.token_lora_indices)
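
After this change, callers of update_metadata pass only the mapping and vocabulary metadata; there is no long_lora_context keyword left to thread through. A hedged sketch of what a call site looks like under the trimmed signature; the stub below only mirrors the new parameter list and prints its inputs, it is not vLLM code, and the argument values are illustrative.

from typing import Optional


def update_metadata(mapping: object,
                    lora_index_to_id: list[Optional[int]],
                    max_loras: int,
                    vocab_size: int,
                    extra_vocab_size: int,
                    **kwargs) -> None:
    # Real implementations forward these to _update_base_metadata(...)
    # without a trailing long_lora_context argument.
    print(mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size)


update_metadata(mapping=None,
                lora_index_to_id=[None, 1],
                max_loras=2,
                vocab_size=32000,
                extra_vocab_size=256)
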
@@ -7,7 +7,7 @@ Punica: Multi-Tenant LoRA Serving.
 https://arxiv.org/abs/2310.18547
 """

-from typing import TYPE_CHECKING, Optional, Union, final
+from typing import Optional, Union, final

 import torch
@@ -21,10 +21,6 @@ if HAS_TRITON:

 from .punica_base import PunicaWrapperBase

-if TYPE_CHECKING:
-    # avoid circuit import
-    from vllm.lora.models import LongContextLoRAContext
-

 @final
 class PunicaWrapperGPU(PunicaWrapperBase):
@@ -55,20 +51,13 @@ class PunicaWrapperGPU(PunicaWrapperBase):
                                                  max_num_prompts,
                                                  device=device)

-    def update_metadata(
-            self,
-            mapping: LoRAMapping,
-            lora_index_to_id: list[Optional[int]],
-            max_loras: int,
-            vocab_size: int,
-            extra_vocab_size: int,
-            long_lora_context: Optional["LongContextLoRAContext"] = None,
-            **kwargs):
+    def update_metadata(self, mapping: LoRAMapping,
+                        lora_index_to_id: list[Optional[int]], max_loras: int,
+                        vocab_size: int, extra_vocab_size: int, **kwargs):

         self.is_prefill = mapping.is_prefill
         self._update_base_metadata(mapping, lora_index_to_id, max_loras,
-                                   vocab_size, extra_vocab_size,
-                                   long_lora_context)
+                                   vocab_size, extra_vocab_size)

         # Prepare cuda kernel metadata tensors
         self.token_mapping_meta.prepare_tensors(self.token_lora_indices)
@@ -14,7 +14,6 @@ from vllm.lora.punica_wrapper.utils import convert_mapping
 if TYPE_CHECKING:
     # avoid circuit import
     from vllm.lora.layers import LoRAMapping
-    from vllm.lora.models import LongContextLoRAContext

 from .punica_base import PunicaWrapperBase
@@ -45,7 +44,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded,
                                                True)
         torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True)
-        torch.ops.xla.dynamo_set_buffer_donor_(self._long_lora_indices, True)
         torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch,
                                                True)
@@ -323,7 +321,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         max_loras: int,
         vocab_size: int,
         extra_vocab_size: int,
-        long_lora_context: Optional["LongContextLoRAContext"] = None,
     ):
         # Make sure we don't accidentally collect outside operations
         xm.mark_step()
@@ -339,7 +336,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
             sampler_indices,
             sampler_indices_padded,
             embeddings_indices,
-            long_lora_offsets_tensor,
             indices_len,
         ) = convert_mapping(
             mapping,
@@ -348,7 +344,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
             vocab_size,
             extra_vocab_size,
             "cpu",
-            long_lora_context,
         )
         self._token_lora_indices = self._pad_to_shape(
             base_indices, self._token_lora_indices.shape,
@@ -362,15 +357,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         self._embeddings_indices = self._pad_to_shape(
             embeddings_indices, self._embeddings_indices.shape,
             dims=2).to(self.device)
-        if long_lora_offsets_tensor is not None:
-            self._long_lora_indices = self._pad_to_shape(
-                long_lora_offsets_tensor,
-                self._long_lora_indices.shape,
-                dims=1).to(self.device)
-        else:
-            zeroed = torch.zeros_like(self._long_lora_indices.cpu(),
-                                      dtype=torch.int32)
-            self._long_lora_indices = zeroed.to(self.device)
         self.indices_len[:] = indices_len

     def _update_prefill_metadata(self,
@@ -8,7 +8,6 @@ import torch
 if TYPE_CHECKING:
     # avoid circuit import
     from vllm.lora.layers import LoRAMapping
-    from vllm.lora.models import LongContextLoRAContext


 def compute_meta(
@@ -49,9 +48,7 @@ def convert_mapping(
     vocab_size: int,
     extra_vocab_size: int,
     device: torch.device,
-    long_lora_context: Optional["LongContextLoRAContext"] = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
-           Optional[torch.Tensor], list[int]]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[int]]:
     """Converts LoRAMapping to index tensors.

     Args:
@@ -60,7 +57,6 @@ def convert_mapping(
         max_loras: Maximum number of LoRAs.
         vocab_size: Model vocab size.
         extra_vocab_size: Extra vocab size each LoRA can have.
-        long_lora_context: Passed if there are long context lora in a batch.

     Returns:
         A tuple of tensors:
@@ -78,21 +74,14 @@ def convert_mapping(
             requests to embedding indices. First row is for embeddings
             added by the LoRAs, second row is for the LoRA.lora_a
             embeddings.
-        long_lora_indices: Tensor of shape [batch_size] mapping
-            requests to RoPE offsets and rot dims for long LoRAs.
-            None if long context lora doesn't exist.
         indices_len: List of lengths of the above tensors. It contains
             (base_indices, sampler_indices, sampler_indices_padded,
-            embeddings_indices, long_lora_indices).
+            embeddings_indices).
     """
     index_mapping_indices: list[int] = list(mapping.index_mapping).copy()
     embedding_indices = index_mapping_indices.copy()
     lora_indices = index_mapping_indices.copy()
-    long_lora_offsets: Optional[torch.Tensor] = None
-    if long_lora_context:
-        long_lora_offsets = torch.zeros(len(index_mapping_indices),
-                                        device=device,
-                                        dtype=torch.long)

     prompt_mapping: list[int] = [
         lora_index_to_id.index(x) if x > 0 else -1
         for x in mapping.prompt_mapping
@@ -104,20 +93,13 @@ def convert_mapping(
                     if index_mapping_indices[i] > 0 else -1)
         embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
         lora_indices[i] = lora_idx
-        if long_lora_context:
-            assert long_lora_offsets is not None
-            lora_offset: int = long_lora_context.offsets_by_lora_id.get(
-                index_mapping_indices[i], 0)
-            long_lora_offsets[i] = lora_offset

     indices_list: list[Union[list[int], torch.Tensor]] = [
         index_mapping_indices,
         lora_indices,
         embedding_indices,
     ]
-    if long_lora_context:
-        assert long_lora_offsets is not None
-        indices_list.append(long_lora_offsets)

     indices = torch.tensor(indices_list, dtype=torch.long, device=device)
     prompt_mapping_tensor = torch.tensor(prompt_mapping,
                                          dtype=torch.long,
@@ -136,11 +118,7 @@ def convert_mapping(
     sampler_indices_padded = torch.arange(
         0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
             sampler_indices_padded * len(sampler_indices_padded))
-    long_lora_indices = None
-    long_lora_indices_len: Optional[int] = None
-    if long_lora_context:
-        long_lora_indices = indices[3]
-        long_lora_indices_len = long_lora_indices.shape[-1]

     # Contain length of indices tensors. Used to index into each tensor.
     indices_len = [
         base_indices.shape[-1],
@@ -148,17 +126,11 @@ def convert_mapping(
         sampler_indices_padded.shape[-1],
         embeddings_indices.shape[-1],
     ]
-    if long_lora_indices_len is not None:
-        indices_len.append(long_lora_indices_len)
-    else:
-        # If long_lora doesn't exist,append None
-        indices_len.append(None)

     return (
         base_indices,
         sampler_indices,
         sampler_indices_padded,
         embeddings_indices,
-        long_lora_indices,
         indices_len,
     )
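
With the hunks above, convert_mapping returns a flat five-element tuple (four index tensors plus indices_len) instead of six elements with an optional long-LoRA tensor. A hedged sketch of how a caller unpacks the new shape; the stub uses plain lists as stand-ins for the torch tensors the real function returns, and the values are made up.

def convert_mapping_stub():
    # Stand-in return values mirroring the new 5-tuple shape only.
    base_indices = [0, 0, 1]
    sampler_indices = [0, 1]
    sampler_indices_padded = [0, 1]
    embeddings_indices = [[0, 0, 1], [0, 0, 1]]
    indices_len = [3, 2, 2, 3]          # one length per tensor above
    return (base_indices, sampler_indices, sampler_indices_padded,
            embeddings_indices, indices_len)


(base_indices, sampler_indices, sampler_indices_padded, embeddings_indices,
 indices_len) = convert_mapping_stub()
assert len(indices_len) == 4            # was 5 when long-LoRA indices existed
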
@@ -22,7 +22,6 @@ from vllm.lora.fully_sharded_layers import (
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
-                              LinearScalingRotaryEmbeddingWithLoRA,
                               LogitsProcessorWithLoRA,
                               MergedColumnParallelLinearWithLoRA,
                               MergedQKVParallelLinearWithLoRA,
@@ -56,7 +55,6 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
     MergedColumnParallelLinearWithShardedLoRA,
     MergedQKVParallelLinearWithShardedLoRA,
     RowParallelLinearWithShardedLoRA,
-    LinearScalingRotaryEmbeddingWithLoRA,
 }
@@ -154,7 +154,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
                     lora_request.lora_int_id)
             else:
                 dummy_lora = self._adapter_manager.create_dummy_lora(
-                    lora_request.lora_int_id, rank, 1, self.embedding_modules)
+                    lora_request.lora_int_id, rank, self.embedding_modules)
             if self._cached_dummy_lora is None:
                 self._cached_dummy_lora = dummy_lora
             return self._adapter_manager.add_adapter(dummy_lora)
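
create_dummy_lora loses its scaling_factor positional, so the worker now passes only the id, rank, and embedding modules. A hedged sketch of the updated call shape; DummyAdapterManager is a placeholder standing in for the real adapter manager, and the argument values are illustrative.

class DummyAdapterManager:

    def create_dummy_lora(self, lora_int_id: int, rank: int,
                          embedding_modules: dict[str, str]):
        # The real manager builds a zero-initialized LoRAModel(lora_id, rank, {}).
        return (lora_int_id, rank, embedding_modules)


manager = DummyAdapterManager()
dummy = manager.create_dummy_lora(1, 8, {})  # embedding modules omitted here
print(dummy)
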