[V0 deprecation] Remove long context LoRA (#21169)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent cf8cc32674
commit 1eaff27815
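For context, a minimal usage sketch (not part of this commit) of the LoRA serving path that remains after the removal; the base model name and adapter path are placeholders, and the commented-out argument is the one this commit deletes.

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Standard LoRA serving path, unchanged by this commit.
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",  # placeholder base model
    enable_lora=True,
    max_lora_rank=16,
    # long_lora_scaling_factors=(4.0, 8.0)  <- removed: EngineArgs no longer
    # accepts this field and --long-lora-scaling-factors is gone from the CLI.
)

outputs = llm.generate(
    "List all users who signed up this week.",
    SamplingParams(max_tokens=64),
    # LoRARequest(name, int_id, path); the path here is a placeholder.
    lora_request=LoRARequest("sql-adapter", 1, "/path/to/sql-lora"),
)
print(outputs[0].outputs[0].text)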
@@ -221,11 +221,6 @@ def phi2_lora_files():
     return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora")
 
 
-@pytest.fixture(scope="session")
-def long_context_lora_files_16k_1():
-    return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1")
-
-
 @pytest.fixture
 def llama_2_7b_engine_extra_embeddings():
     cleanup_dist_env_and_memory(shutdown_ray=True)
@@ -38,8 +38,8 @@ ERROR_CASES = [
 ]
 
 
-def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path):
-    peft_helper = PEFTHelper.from_local_dir(long_context_lora_files_16k_1,
+def test_peft_helper_pass(sql_lora_files, tmp_path):
+    peft_helper = PEFTHelper.from_local_dir(sql_lora_files,
                                             max_position_embeddings=4096)
     lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
     peft_helper.validate_legal(lora_config)
@@ -56,15 +56,12 @@ def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path):
         "embed_tokens",
         "lm_head",
     ]
-    assert peft_helper.context_length == 16384
     assert peft_helper.vllm_max_position_embeddings == 4096
-    assert peft_helper.vllm_long_context_scaling_factor == float(
-        math.ceil(peft_helper.context_length /
-                  peft_helper.vllm_max_position_embeddings))
     # test RSLoRA
     rslora_config = dict(use_rslora=True)
     test_dir = tmp_path / "test_rslora"
-    shutil.copytree(long_context_lora_files_16k_1, test_dir)
+    shutil.copytree(sql_lora_files, test_dir)
 
     # Load and modify configuration
     config_path = test_dir / "adapter_config.json"
@@ -3014,12 +3014,7 @@ class LoRAConfig:
     (added to the base model vocabulary)."""
     lora_vocab_padding_size: ClassVar[int] = current_platform\
         .get_lora_vocab_padding_size()
-    long_lora_scaling_factors: Optional[tuple[float, ...]] = None
-    """Specify multiple scaling factors (which can be different from base model
-    scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters
-    trained with those scaling factors to be used at the same time. If not
-    specified, only adapters trained with the base model scaling factor are
-    allowed."""
     default_mm_loras: Optional[dict[str, str]] = None
     """Dictionary mapping specific modalities to LoRA model paths; this field
     is only applicable to multimodal models and should be leveraged when a
@@ -3052,7 +3047,6 @@ class LoRAConfig:
         factors.append(self.lora_dtype)
         factors.append(self.lora_extra_vocab_size)
         factors.append(self.lora_vocab_padding_size)
-        factors.append(self.long_lora_scaling_factors)
         factors.append(self.bias_enabled)
         hash_str = hashlib.md5(str(factors).encode(),
                                usedforsecurity=False).hexdigest()
@@ -3091,11 +3085,6 @@ class LoRAConfig:
         elif isinstance(self.lora_dtype, str):
             self.lora_dtype = getattr(torch, self.lora_dtype)
 
-    def verify_lora_support(self):
-        if self.long_lora_scaling_factors is not None and envs.VLLM_USE_V1:
-            raise ValueError(
-                "V1 LoRA does not support long LoRA, please use V0.")
-
 
 @config
 @dataclass(config=ConfigDict(arbitrary_types_allowed=True))
@@ -4564,7 +4553,6 @@ class VllmConfig:
         if self.lora_config is not None:
             self.lora_config.verify_with_cache_config(self.cache_config)
             self.lora_config.verify_with_model_config(self.model_config)
-            self.lora_config.verify_lora_support()
         if self.prompt_adapter_config is not None:
             self.prompt_adapter_config.verify_with_model_config(
                 self.model_config)
@@ -358,8 +358,6 @@ class EngineArgs:
     max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
     lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
     lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
-    long_lora_scaling_factors: Optional[tuple[float, ...]] = \
-        LoRAConfig.long_lora_scaling_factors
     # PromptAdapter fields
     enable_prompt_adapter: bool = False
     max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
@@ -723,8 +721,6 @@ class EngineArgs:
             "--lora-dtype",
             **lora_kwargs["lora_dtype"],
         )
-        lora_group.add_argument("--long-lora-scaling-factors",
-                                **lora_kwargs["long_lora_scaling_factors"])
         lora_group.add_argument("--max-cpu-loras",
                                 **lora_kwargs["max_cpu_loras"])
         lora_group.add_argument("--fully-sharded-loras",
@@ -1245,7 +1241,6 @@ class EngineArgs:
             default_mm_loras=self.default_mm_loras,
             fully_sharded_loras=self.fully_sharded_loras,
             lora_extra_vocab_size=self.lora_extra_vocab_size,
-            long_lora_scaling_factors=self.long_lora_scaling_factors,
             lora_dtype=self.lora_dtype,
             max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
             and self.max_cpu_loras > 0 else None) if self.enable_lora else None
@@ -28,8 +28,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 # yapf: enable
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.rotary_embedding import (
-    LinearScalingRotaryEmbedding, RotaryEmbedding)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.platforms import current_platform
@@ -1193,91 +1191,3 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
     ) -> bool:
         # Special handling for the LogitsProcessor.
         return False
-
-
-class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA):
-    """Implements RoPE-scaled embeddings with linear scaling for
-    multiple LoRA adapters with a specialized kernel.
-
-    Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding
-    which can handle multi lora adapters in a specialized kernel.
-    """
-
-    def __init__(self, base_layer: RotaryEmbedding) -> None:
-        super().__init__()
-        self.base_layer = base_layer
-
-    @property
-    def scaling_factors(self):
-        return self.base_layer.scaling_factors
-
-    @property
-    def rotary_dim(self):
-        return self.base_layer.rotary_dim
-
-    def create_lora_weights(
-        self,
-        max_loras: int,
-        lora_config: LoRAConfig,
-        model_config: Optional[PretrainedConfig] = None,
-    ) -> None:
-        scaling_factors = (list(lora_config.long_lora_scaling_factors)
-                           if lora_config.long_lora_scaling_factors else [])
-        base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
-            self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
-        scaling_factors = sorted(
-            list(set([base_scaling_factor] + scaling_factors)))
-        self.base_layer = LinearScalingRotaryEmbedding(
-            self.base_layer.head_size,
-            self.base_layer.rotary_dim,
-            self.base_layer.max_position_embeddings,
-            self.base_layer.base,
-            self.base_layer.is_neox_style,
-            scaling_factors,
-            self.base_layer.dtype,
-        )
-
-    def reset_lora(self, index: int):
-        ...
-
-    def set_lora(
-        self,
-        index: int,
-        lora_a: torch.Tensor,
-        lora_b: torch.Tensor,
-        embeddings_tensor: Optional[torch.Tensor],
-        bias: Optional[torch.Tensor] = None,
-    ):
-        ...
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        return self.base_layer(
-            positions,
-            query,
-            key,
-            offsets=self.punica_wrapper.long_lora_indices,
-        )
-
-    @property
-    def scaling_factor_to_offset(self) -> dict[float, int]:
-        return self.base_layer.scaling_factor_to_offset
-
-    @classmethod
-    def can_replace_layer(
-        cls,
-        source_layer: nn.Module,
-        lora_config: LoRAConfig,
-        packed_modules_list: list,
-        model_config: Optional[PretrainedConfig],
-    ) -> bool:
-        """Returns True if the layer can be replaced by this LoRA layer."""
-        return (type(source_layer) is LinearScalingRotaryEmbedding
-                or type(source_layer) is RotaryEmbedding)
-
-    def extra_repr(self) -> str:
-        return self.base_layer.extra_repr()
@@ -4,7 +4,6 @@
 import math
 import os
 from collections.abc import Sequence
-from dataclasses import dataclass, field
 from typing import Any, Callable, Optional, Union
 
 import regex as re
@@ -19,9 +18,7 @@ from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
                                         remove_adapter, set_adapter_mapping)
 from vllm.config import LoRAConfig
 from vllm.logger import init_logger
-from vllm.lora.layers import (BaseLayerWithLoRA,
-                              LinearScalingRotaryEmbeddingWithLoRA,
-                              LoRAMapping)
+from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.punica_wrapper import get_punica_wrapper
@@ -43,18 +40,6 @@ logger = init_logger(__name__)
 _GLOBAL_LORA_ID = 0
 
 
-@dataclass
-class LongContextLoRAContext:
-    """Context for lora adapters that support long context."""
-    # The scaling factors to support long context lora fine tuned models.
-    scaling_factors: list[float]
-    # dimension to apply rotary embedding.
-    rot_dim: int
-    # offsets to the sin_cos_cache for each lora_id loaded.
-    # This value is dynamically modified.
-    offsets_by_lora_id: dict[int, int] = field(default_factory=dict)
-
-
 def get_lora_id():
     global _GLOBAL_LORA_ID
     _GLOBAL_LORA_ID += 1
@@ -80,20 +65,16 @@ class LoRAModel(AdapterModel):
         lora_model_id: int,
         rank: int,
         loras: dict[str, LoRALayerWeights],
-        scaling_factor: Optional[float] = None,
     ) -> None:
         """
         Args:
             lora_model_id: The integer id for the lora model.
             rank: lora rank.
             loras: module name -> weights for lora-replaced layers.
-            scaling_factor: Scaling factor to support long context lora model.
-                None if the lora is not tuned for long context support.
         """
         self.id = lora_model_id
-        # Scaling factor for long context lora model. None if it is not
-        # fine tuned for the long context.
-        self.scaling_factor = scaling_factor
         assert (
             lora_model_id
             > 0), f"a valid lora id should be greater than 0, got {self.id}"
@@ -192,10 +173,7 @@ class LoRAModel(AdapterModel):
         for lora in loras.values():
             lora.optimize()
 
-        return cls(lora_model_id,
-                   peft_helper.r,
-                   loras,
-                   scaling_factor=peft_helper.vllm_long_context_scaling_factor)
+        return cls(lora_model_id, peft_helper.r, loras)
 
     @classmethod
     def from_local_checkpoint(
@@ -360,24 +338,17 @@ class LoRAModelManager(AdapterModelManager):
         self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
         self.lora_index_to_id: list[Optional[int]] = [None] * self.lora_slots
         self.vocab_size = vocab_size
-        self.long_lora_context: Optional[LongContextLoRAContext] = None
         self.punica_wrapper = get_punica_wrapper(
             max_num_batched_tokens,
             max_batches=self.max_num_seqs,
             device=self.device,
             max_loras=self.lora_config.max_loras)
-        # Scaling factor -> offset to the sin_cos_cache to it.
-        # Used for long context lora.
-        self.scaling_factor_to_offset: dict[float, int] = {}
         super().__init__(model)
 
         self.supported_lora_modules = get_supported_lora_modules(self.model)
         assert self.supported_lora_modules, "No supported LoRA modules found in"
         f" {self.model.__class__.__name__}."
-        if lora_config.long_lora_scaling_factors:
-            # We need to replace rotary emb layer to do batch computation
-            # for long lora.
-            self.supported_lora_modules.append("rotary_emb")
 
         self.packed_modules_mapping = get_packed_modules_mapping(self.model)
         # Used to indicate whether the model is a multimodal model
@@ -454,25 +425,9 @@ class LoRAModelManager(AdapterModelManager):
         except ValueError:
             pass
 
-    def _set_long_lora_context(self, lora: LoRAModel):
-        if self.long_lora_context is None:
-            return
-
-        if lora.scaling_factor is None:
-            return
-
-        if (lora.scaling_factor not in self.scaling_factor_to_offset):
-            raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
-                             " has not been initialized.")
-
-        offsets = self.scaling_factor_to_offset.get(lora.scaling_factor)
-        if offsets:
-            self.long_lora_context.offsets_by_lora_id[lora.id] = offsets
-
     def _add_adapter(self, lora: LoRAModel):
         self._create_merged_loras_inplace(lora)
         self._registered_adapters[lora.id] = lora
-        self._set_long_lora_context(lora)
 
     def pin_adapter(self, lora_id: int) -> bool:
         """Pin a LoRAModel in the manager cache."""
@@ -488,7 +443,6 @@ class LoRAModelManager(AdapterModelManager):
             self.lora_slots + 1,
             self.vocab_size,
             self.lora_config.lora_extra_vocab_size,
-            self.long_lora_context,
         )
 
     def remove_all_adapters(self):
@@ -528,13 +482,6 @@ class LoRAModelManager(AdapterModelManager):
                 from_layer(module, self.lora_slots, self.lora_config,
                            packed_moduled_lst, self.model.config))
 
-            # LinearScalingRotaryEmbeddingWithLoRA is used to handle
-            # long context lora. Register relevant metadata.
-            if isinstance(new_module, LinearScalingRotaryEmbeddingWithLoRA):
-                self.long_lora_context = LongContextLoRAContext(
-                    new_module.scaling_factors, new_module.rotary_dim)
-                self.scaling_factor_to_offset = \
-                    new_module.scaling_factor_to_offset
             # (yard1): TODO make this more robust
             if "lm_head" in module_name:
                 logits_processor_module_name = 'logits_processor'
@@ -574,15 +521,13 @@ class LoRAModelManager(AdapterModelManager):
             self,
             lora_id: int,
             rank: int,
-            scaling_factor: Optional[float],
             embedding_modules: Optional[dict[str, str]] = None) -> LoRAModel:
         """Create zero-initialized LoRAModel for warmup."""
-        model = LoRAModel(lora_id, rank, {}, scaling_factor)
+        model = LoRAModel(lora_id, rank, {})
         for module_name, module in self.model.named_modules():
             bias_enabled = self.lora_config.bias_enabled
             if (not self._match_target_modules(module_name)
                     or not isinstance(module, BaseLayerWithLoRA)
-                    or isinstance(module, LinearScalingRotaryEmbeddingWithLoRA)
                     or self._filter_unsupported_mm_module(module_name)):
                 continue
             parts = module_name.split(".")
@@ -723,11 +668,8 @@ class LoRAModelManager(AdapterModelManager):
                                   self._deactivate_adapter)
 
     def add_adapter(self, adapter: LoRAModel) -> bool:
-        logger.debug(
-            "Adding lora. Model id: %d, "
-            "int id: %d, "
-            "scaling factor: %s", adapter.id, adapter.id,
-            adapter.scaling_factor)
+        logger.debug("Adding lora. Model id: %d, "
+                     "int id: %d", adapter.id, adapter.id)
         return add_adapter(adapter, self._registered_adapters, self.capacity,
                            self._add_adapter)
 
@@ -772,10 +714,8 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
 
     def add_adapter(self, lora: LoRAModel) -> bool:
         """Add a LoRAModel to the manager."""
-        logger.debug(
-            "Adding lora. Model id: %d, "
-            "int id: %d, "
-            "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
+        logger.debug("Adding lora. Model id: %d, "
+                     "int id: %d", lora.id, lora.id)
         if lora.id not in self._registered_adapters:
             self._add_adapter(lora)
             was_added = True
@@ -35,12 +35,9 @@ class PEFTHelper:
     use_rslora: bool = field(default=False)
     # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
    use_dora: bool = field(default=False)
-    # long context lora field
-    context_length: int = field(default=0)
     # Extra vllm field, start with 'vllm_' to avoid conflict
     vllm_lora_scaling_factor: float = field(default=1.0)
     vllm_max_position_embeddings: Optional[int] = field(default=False)
-    vllm_long_context_scaling_factor: Optional[float] = field(default=None)
 
     def _validate_features(self) -> list[str]:
         """
@@ -59,12 +56,6 @@ class PEFTHelper:
             self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
         else:
             self.vllm_lora_scaling_factor = self.lora_alpha / self.r
-        if self.context_length:
-            if self.vllm_max_position_embeddings is None:
-                self.vllm_max_position_embeddings = self.context_length
-            self.vllm_long_context_scaling_factor = float(
-                math.ceil(self.context_length /
-                          self.vllm_max_position_embeddings))
 
     @classmethod
     def from_dict(cls, config_dict: dict) -> "PEFTHelper":
@@ -17,7 +17,6 @@ from .utils import compute_meta, convert_mapping
 if TYPE_CHECKING:
     # avoid circuit import
     from vllm.lora.layers import LoRAMapping
-    from vllm.lora.models import LongContextLoRAContext
 
 
 class PunicaWrapperABC(ABC):
@@ -33,7 +32,6 @@ class PunicaWrapperABC(ABC):
         max_loras: int,
         vocab_size: int,
         extra_vocab_size: int,
-        long_lora_context: Optional["LongContextLoRAContext"] = None,
         **kwargs,
     ) -> None:
         """
@@ -144,14 +142,11 @@ class PunicaWrapperBase(PunicaWrapperABC):
                                               max_num_batched_tokens,
                                               dtype=torch.long,
                                               device=device)
-        self._long_lora_indices = torch.empty(max_num_batched_tokens,
-                                              dtype=torch.long,
-                                              device=device)
 
-        # 5 is the number of indices tensors.
+        # 4 is the number of indices tensors.
         # base_indices, sampler_indices, sampler_indices_padded,
-        # embeddings_indices,long_lora_indices
-        self.indices_len: list[Optional[int]] = [None] * 5
+        # embeddings_indices
+        self.indices_len: list[Optional[int]] = [None] * 4
         # these attributes are the information required for sgmv kernel
         self._seq_start_locs = torch.empty(max_batches,
                                            dtype=torch.long,
@@ -176,14 +171,12 @@ class PunicaWrapperBase(PunicaWrapperABC):
             max_loras: int,
             vocab_size: int,
             extra_vocab_size: int,
-            long_lora_context: Optional["LongContextLoRAContext"] = None,
     ):
         (
             base_indices,
             sampler_indices,
             sampler_indices_padded,
             embeddings_indices,
-            long_lora_offsets_tensor,
             indices_len,
         ) = convert_mapping(
             mapping,
@@ -192,7 +185,6 @@ class PunicaWrapperBase(PunicaWrapperABC):
             vocab_size,
             extra_vocab_size,
             self.device,
-            long_lora_context,
         )
         self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
         self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
@@ -201,11 +193,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
         self._embeddings_indices[:embeddings_indices.
                                  shape[0], :embeddings_indices.shape[1]].copy_(
                                      embeddings_indices)
-        if long_lora_offsets_tensor is not None:
-            self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
-                long_lora_offsets_tensor)
-        else:
-            self._long_lora_indices.zero_()
         self.indices_len[:] = indices_len
 
     def _update_prefill_metadata(self,
@@ -312,28 +300,13 @@ class PunicaWrapperBase(PunicaWrapperABC):
         embeddings_indices_len = self.indices_len[3]
         return self._embeddings_indices[:, :embeddings_indices_len]
 
-    @property
-    def long_lora_indices(self) -> torch.Tensor:
-        """
-        This property provides access to the indices used for long context
-        lora, specifically for LinearScalingRotaryEmbeddingWithLoRA.
-        """
-        long_lora_len = self.indices_len[4]
-        return self._long_lora_indices[:long_lora_len]
-
-    def update_metadata(
-            self,
-            mapping: "LoRAMapping",
-            lora_index_to_id: list[Optional[int]],
-            max_loras: int,
-            vocab_size: int,
-            extra_vocab_size: int,
-            long_lora_context: Optional["LongContextLoRAContext"] = None,
-            **kwargs):
+    def update_metadata(self, mapping: "LoRAMapping",
+                        lora_index_to_id: list[Optional[int]], max_loras: int,
+                        vocab_size: int, extra_vocab_size: int, **kwargs):
 
         self._update_base_metadata(mapping, lora_index_to_id, max_loras,
-                                   vocab_size, extra_vocab_size,
-                                   long_lora_context)
+                                   vocab_size, extra_vocab_size)
         if mapping.is_prefill:
             # Update metadata required for prefill-related operators.
             self._update_prefill_metadata(self.token_lora_indices)
@@ -7,7 +7,7 @@ Punica: Multi-Tenant LoRA Serving.
 https://arxiv.org/abs/2310.18547
 """
 
-from typing import TYPE_CHECKING, Optional, Union, final
+from typing import Optional, Union, final
 
 import torch
 
@@ -21,10 +21,6 @@ if HAS_TRITON:
 
 from .punica_base import PunicaWrapperBase
 
-if TYPE_CHECKING:
-    # avoid circuit import
-    from vllm.lora.models import LongContextLoRAContext
-
 
 @final
 class PunicaWrapperGPU(PunicaWrapperBase):
@@ -55,20 +51,13 @@ class PunicaWrapperGPU(PunicaWrapperBase):
                                            max_num_prompts,
                                            device=device)
 
-    def update_metadata(
-            self,
-            mapping: LoRAMapping,
-            lora_index_to_id: list[Optional[int]],
-            max_loras: int,
-            vocab_size: int,
-            extra_vocab_size: int,
-            long_lora_context: Optional["LongContextLoRAContext"] = None,
-            **kwargs):
+    def update_metadata(self, mapping: LoRAMapping,
+                        lora_index_to_id: list[Optional[int]], max_loras: int,
+                        vocab_size: int, extra_vocab_size: int, **kwargs):
 
         self.is_prefill = mapping.is_prefill
         self._update_base_metadata(mapping, lora_index_to_id, max_loras,
-                                   vocab_size, extra_vocab_size,
-                                   long_lora_context)
+                                   vocab_size, extra_vocab_size)
 
         # Prepare cuda kernel metadata tensors
         self.token_mapping_meta.prepare_tensors(self.token_lora_indices)
@@ -14,7 +14,6 @@ from vllm.lora.punica_wrapper.utils import convert_mapping
 if TYPE_CHECKING:
     # avoid circuit import
     from vllm.lora.layers import LoRAMapping
-    from vllm.lora.models import LongContextLoRAContext
 
 from .punica_base import PunicaWrapperBase
 
@@ -45,7 +44,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded,
                                                True)
         torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True)
-        torch.ops.xla.dynamo_set_buffer_donor_(self._long_lora_indices, True)
         torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch,
                                                True)
 
@@ -323,7 +321,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         max_loras: int,
         vocab_size: int,
         extra_vocab_size: int,
-        long_lora_context: Optional["LongContextLoRAContext"] = None,
     ):
         # Make sure we don't accidentally collect outside operations
         xm.mark_step()
@@ -339,7 +336,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
             sampler_indices,
             sampler_indices_padded,
             embeddings_indices,
-            long_lora_offsets_tensor,
             indices_len,
         ) = convert_mapping(
             mapping,
@@ -348,7 +344,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
             vocab_size,
             extra_vocab_size,
             "cpu",
-            long_lora_context,
         )
         self._token_lora_indices = self._pad_to_shape(
             base_indices, self._token_lora_indices.shape,
@@ -362,15 +357,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         self._embeddings_indices = self._pad_to_shape(
             embeddings_indices, self._embeddings_indices.shape,
             dims=2).to(self.device)
-        if long_lora_offsets_tensor is not None:
-            self._long_lora_indices = self._pad_to_shape(
-                long_lora_offsets_tensor,
-                self._long_lora_indices.shape,
-                dims=1).to(self.device)
-        else:
-            zeroed = torch.zeros_like(self._long_lora_indices.cpu(),
-                                      dtype=torch.int32)
-            self._long_lora_indices = zeroed.to(self.device)
         self.indices_len[:] = indices_len
 
     def _update_prefill_metadata(self,
@@ -8,7 +8,6 @@ import torch
 if TYPE_CHECKING:
     # avoid circuit import
     from vllm.lora.layers import LoRAMapping
-    from vllm.lora.models import LongContextLoRAContext
 
 
 def compute_meta(
@@ -49,9 +48,7 @@ def convert_mapping(
     vocab_size: int,
     extra_vocab_size: int,
     device: torch.device,
-    long_lora_context: Optional["LongContextLoRAContext"] = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
-           Optional[torch.Tensor], list[int]]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[int]]:
     """Converts LoRAMapping to index tensors.
 
     Args:
@@ -60,7 +57,6 @@ def convert_mapping(
         max_loras: Maximum number of LoRAs.
         vocab_size: Model vocab size.
         extra_vocab_size: Extra vocab size each LoRA can have.
-        long_lora_context: Passed if there are long context lora in a batch.
 
     Returns:
         A tuple of tensors:
@@ -78,21 +74,14 @@ def convert_mapping(
             requests to embedding indices. First row is for embeddings
             added by the LoRAs, second row is for the LoRA.lora_a
             embeddings.
-        long_lora_indices: Tensor of shape [batch_size] mapping
-            requests to RoPE offsets and rot dims for long LoRAs.
-            None if long context lora doesn't exist.
         indices_len: List of lengths of the above tensors. It contains
             (base_indices, sampler_indices, sampler_indices_padded,
-            embeddings_indices, long_lora_indices).
+            embeddings_indices).
     """
     index_mapping_indices: list[int] = list(mapping.index_mapping).copy()
     embedding_indices = index_mapping_indices.copy()
     lora_indices = index_mapping_indices.copy()
-    long_lora_offsets: Optional[torch.Tensor] = None
-    if long_lora_context:
-        long_lora_offsets = torch.zeros(len(index_mapping_indices),
-                                        device=device,
-                                        dtype=torch.long)
     prompt_mapping: list[int] = [
         lora_index_to_id.index(x) if x > 0 else -1
         for x in mapping.prompt_mapping
@@ -104,20 +93,13 @@ def convert_mapping(
                     if index_mapping_indices[i] > 0 else -1)
         embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
         lora_indices[i] = lora_idx
-        if long_lora_context:
-            assert long_lora_offsets is not None
-            lora_offset: int = long_lora_context.offsets_by_lora_id.get(
-                index_mapping_indices[i], 0)
-            long_lora_offsets[i] = lora_offset
 
     indices_list: list[Union[list[int], torch.Tensor]] = [
         index_mapping_indices,
         lora_indices,
         embedding_indices,
     ]
-    if long_lora_context:
-        assert long_lora_offsets is not None
-        indices_list.append(long_lora_offsets)
     indices = torch.tensor(indices_list, dtype=torch.long, device=device)
     prompt_mapping_tensor = torch.tensor(prompt_mapping,
                                          dtype=torch.long,
@@ -136,11 +118,7 @@ def convert_mapping(
     sampler_indices_padded = torch.arange(
         0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
             sampler_indices_padded * len(sampler_indices_padded))
-    long_lora_indices = None
-    long_lora_indices_len: Optional[int] = None
-    if long_lora_context:
-        long_lora_indices = indices[3]
-        long_lora_indices_len = long_lora_indices.shape[-1]
     # Contain length of indices tensors. Used to index into each tensor.
     indices_len = [
         base_indices.shape[-1],
@@ -148,17 +126,11 @@ def convert_mapping(
         sampler_indices_padded.shape[-1],
         embeddings_indices.shape[-1],
     ]
-    if long_lora_indices_len is not None:
-        indices_len.append(long_lora_indices_len)
-    else:
-        # If long_lora doesn't exist,append None
-        indices_len.append(None)
 
     return (
         base_indices,
         sampler_indices,
         sampler_indices_padded,
         embeddings_indices,
-        long_lora_indices,
         indices_len,
     )
@@ -22,7 +22,6 @@ from vllm.lora.fully_sharded_layers import (
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
-                              LinearScalingRotaryEmbeddingWithLoRA,
                               LogitsProcessorWithLoRA,
                               MergedColumnParallelLinearWithLoRA,
                               MergedQKVParallelLinearWithLoRA,
@@ -56,7 +55,6 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
     MergedColumnParallelLinearWithShardedLoRA,
     MergedQKVParallelLinearWithShardedLoRA,
     RowParallelLinearWithShardedLoRA,
-    LinearScalingRotaryEmbeddingWithLoRA,
 }
 
 
@@ -154,7 +154,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
                     lora_request.lora_int_id)
             else:
                 dummy_lora = self._adapter_manager.create_dummy_lora(
-                    lora_request.lora_int_id, rank, 1, self.embedding_modules)
+                    lora_request.lora_int_id, rank, self.embedding_modules)
                 if self._cached_dummy_lora is None:
                     self._cached_dummy_lora = dummy_lora
                 return self._adapter_manager.add_adapter(dummy_lora)