[Misc][LoRA] Support Rank Stabilized LoRA (RSLoRA) (#6909)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
parent 74fa1d123c
commit 82c49d3260
@@ -1,4 +1,5 @@
 import json
+import math
 import os
 from typing import Dict, List
@@ -50,6 +51,18 @@ def test_peft_helper(sql_lora_files):
         "embed_tokens",
         "lm_head",
     ]
+    scaling = peft_helper.lora_alpha / peft_helper.r
+    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
+
+    # test RSLoRA
+    config = dict(r=8,
+                  lora_alpha=16,
+                  target_modules=["gate_proj"],
+                  use_rslora=True)
+    peft_helper = PEFTHelper.from_dict(config)
+
+    scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
+    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
+
     expected_error = "vLLM only supports modules_to_save being None."
     with pytest.raises(ValueError, match=expected_error):
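Note: the two assertions above capture the whole behavioral difference. Standard LoRA scales the low-rank update by lora_alpha / r, while rsLoRA scales it by lora_alpha / sqrt(r) so the update's magnitude stays roughly constant as the rank grows. A minimal standalone sketch of the two factors for the values used in the test (the helper function is illustrative, not vLLM API):

```python
import math

def lora_scaling(lora_alpha: int, r: int, use_rslora: bool = False) -> float:
    # rsLoRA (https://arxiv.org/abs/2312.03732) divides by sqrt(r) instead of r.
    return lora_alpha / math.sqrt(r) if use_rslora else lora_alpha / r

print(lora_scaling(16, 8))                   # 2.0    -> plain LoRA
print(lora_scaling(16, 8, use_rslora=True))  # ~5.657 -> rsLoRA
```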
@@ -60,13 +73,6 @@ def test_peft_helper(sql_lora_files):
             modules_to_save=["lm_head"],
         )
         PEFTHelper.from_dict(config)
-    expected_error = "vLLM does not yet support RSLoRA."
-    with pytest.raises(ValueError, match=expected_error):
-        config = dict(r=8,
-                      lora_alpha=16,
-                      target_modules=["gate_proj"],
-                      use_rslora=True)
-        PEFTHelper.from_dict(config)
 
     expected_error = "vLLM does not yet support DoRA."
     with pytest.raises(ValueError, match=expected_error):
@@ -67,15 +67,9 @@ class LoRALayerWeights:
         peft_helper: PEFTHelper,
         embeddings_tensor: Optional[torch.Tensor] = None,
     ) -> "LoRALayerWeights":
-        return cls(
-            module_name,
-            peft_helper.r,
-            peft_helper.lora_alpha,
-            None,
-            None,
-            None,
-            embeddings_tensor,
-        )
+        return cls(module_name, peft_helper.r, peft_helper.lora_alpha, None,
+                   None, None, embeddings_tensor,
+                   peft_helper.vllm_lora_scaling_factor)
 
     @classmethod
     def create_dummy_lora_weights(
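The rewritten constructor call threads peft_helper.vllm_lora_scaling_factor into LoRALayerWeights, which is where the factor ultimately multiplies the low-rank delta. A minimal sketch of how such a scaling factor enters a LoRA forward pass, assuming the standard (x @ A^T) @ B^T formulation; this is an illustration, not vLLM's actual kernel path:

```python
import torch

def lora_forward(x: torch.Tensor, base_weight: torch.Tensor,
                 lora_a: torch.Tensor, lora_b: torch.Tensor,
                 scaling: float) -> torch.Tensor:
    # y = x @ W^T + scaling * (x @ A^T) @ B^T; rsLoRA only changes `scaling`.
    return x @ base_weight.T + scaling * ((x @ lora_a.T) @ lora_b.T)

# Shapes: W is (out, in), A is (r, in), B is (out, r).
x = torch.randn(2, 64)
y = lora_forward(x, torch.randn(128, 64), torch.randn(8, 64),
                 torch.randn(128, 8), scaling=16 / 8**0.5)
```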
@@ -173,7 +173,7 @@ class LoRAModel(AdapterModel):
         return cls(lora_model_id,
                    peft_helper.r,
                    loras,
-                   scaling_factor=peft_helper.vllm_scaling_factor)
+                   scaling_factor=peft_helper.vllm_long_context_scaling_factor)
 
     @classmethod
     def from_local_checkpoint(
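Here vllm_scaling_factor is renamed to vllm_long_context_scaling_factor, separating the long-context position-scaling factor from the new per-adapter weight-scaling factor introduced above. As __post_init__ below shows, its value is ceil(context_length / max_position_embeddings); a quick illustration with assumed example numbers:

```python
import math

context_length = 16384          # adapter's trained context length (example value)
max_position_embeddings = 4096  # base model's native context (example value)
print(float(math.ceil(context_length / max_position_embeddings)))  # 4.0
```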
@@ -4,6 +4,8 @@
 import math
 from dataclasses import MISSING, dataclass, field, fields
 from typing import Literal, Optional, Union
 
+from vllm.utils import print_info_once
+
 
 @dataclass
 class PEFTHelper:
@@ -14,21 +16,22 @@ class PEFTHelper:
 
     bias: Literal["none", "all", "lora_only"] = field(default="none")
     modules_to_save: Optional[list[str]] = field(default=None)
+    # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
     use_rslora: bool = field(default=False)
+    # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
     use_dora: bool = field(default=False)
-    # long lora field
+    # long context lora field
     context_length: int = field(default=0)
     # Extra vllm field, start with 'vllm_' to avoid conflict
+    vllm_lora_scaling_factor: float = field(default=1.0)
     vllm_max_position_embeddings: Optional[int] = field(default=False)
-    vllm_scaling_factor: Optional[float] = field(default=None)
+    vllm_long_context_scaling_factor: Optional[float] = field(default=None)
 
     def _validate_features(self):
         error_msg = []
 
         if self.modules_to_save:
             error_msg.append("vLLM only supports modules_to_save being None.")
-        if self.use_rslora:
-            error_msg.append("vLLM does not yet support RSLoRA.")
-
         if self.use_dora:
             error_msg.append("vLLM does not yet support DoRA.")
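The paper linked in the new use_rslora comment motivates the square root: with the conventional alpha / r factor, the norm of the update B @ A @ x decays as the rank grows, while alpha / sqrt(r) keeps it roughly rank-independent. A quick numeric check of that claim (random nonzero B purely for illustration; real LoRA initializes B to zero):

```python
import math
import torch

torch.manual_seed(0)
d = 256
x = torch.randn(d)
for r in (8, 64, 512):
    lora_a = torch.randn(r, d) / math.sqrt(d)  # Kaiming-style init for A
    lora_b = torch.randn(d, r)                 # nonzero B for illustration only
    delta = lora_b @ (lora_a @ x)
    print(f"r={r:4d}  lora={(delta / r).norm().item():7.2f}  "
          f"rslora={(delta / math.sqrt(r)).norm().item():7.2f}")
# The lora column shrinks as r grows; the rslora column stays roughly constant.
```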
@@ -38,10 +41,15 @@ class PEFTHelper:
 
     def __post_init__(self):
         self._validate_features()
+        if self.use_rslora:
+            print_info_once("Loading LoRA weights trained with rsLoRA.")
+            self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
+        else:
+            self.vllm_lora_scaling_factor = self.lora_alpha / self.r
         if self.context_length:
             if self.vllm_max_position_embeddings is None:
                 self.vllm_max_position_embeddings = self.context_length
-            self.vllm_scaling_factor = float(
+            self.vllm_long_context_scaling_factor = float(
                 math.ceil(self.context_length /
                           self.vllm_max_position_embeddings))
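End to end, a PEFT adapter_config.json that sets "use_rslora": true now loads with the sqrt scaling instead of being rejected at validation. A hedged usage sketch, assuming PEFTHelper lives at vllm.lora.peft_helper as in the vLLM source tree; r=64 is an arbitrary example value:

```python
import math

from vllm.lora.peft_helper import PEFTHelper

config = dict(r=64, lora_alpha=16, target_modules=["q_proj", "v_proj"],
              use_rslora=True)
helper = PEFTHelper.from_dict(config)

# 16 / sqrt(64) = 2.0, versus 16 / 64 = 0.25 under plain LoRA scaling.
assert math.isclose(helper.vllm_lora_scaling_factor, 2.0)
```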