[Misc][LoRA] Support Rank Stabilized LoRA (RSLoRA) (#6909)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
John Giorgi 2024-12-31 01:15:58 -05:00 committed by GitHub
parent 74fa1d123c
commit 82c49d3260
4 changed files with 30 additions and 22 deletions

View File

@@ -1,4 +1,5 @@
 import json
+import math
 import os
 from typing import Dict, List
@@ -50,6 +51,18 @@ def test_peft_helper(sql_lora_files):
         "embed_tokens",
         "lm_head",
     ]
+    scaling = peft_helper.lora_alpha / peft_helper.r
+    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
+
+    # test RSLoRA
+    config = dict(r=8,
+                  lora_alpha=16,
+                  target_modules=["gate_proj"],
+                  use_rslora=True)
+    peft_helper = PEFTHelper.from_dict(config)
+    scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
+    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
+
     expected_error = "vLLM only supports modules_to_save being None."
     with pytest.raises(ValueError, match=expected_error):
@@ -60,13 +73,6 @@ def test_peft_helper(sql_lora_files):
             modules_to_save=["lm_head"],
         )
         PEFTHelper.from_dict(config)
-    expected_error = "vLLM does not yet support RSLoRA."
-    with pytest.raises(ValueError, match=expected_error):
-        config = dict(r=8,
-                      lora_alpha=16,
-                      target_modules=["gate_proj"],
-                      use_rslora=True)
-        PEFTHelper.from_dict(config)
     expected_error = "vLLM does not yet support DoRA."
     with pytest.raises(ValueError, match=expected_error):
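For reference, the two assertions above encode different scaling rules, and they diverge as the rank grows. A minimal standalone sketch (plain Python, reusing the r=8, lora_alpha=16 values from the test config; not vLLM code):

```python
import math

r, lora_alpha = 8, 16

standard_scaling = lora_alpha / r            # conventional LoRA: 16 / 8 = 2.0
rslora_scaling = lora_alpha / math.sqrt(r)   # rank-stabilized LoRA: 16 / sqrt(8) ≈ 5.657

print(standard_scaling, rslora_scaling)
```

The sqrt(r) denominator is the point of rsLoRA: it keeps the magnitude of the low-rank update roughly stable as r increases, which is why the helper now computes a factor for `use_rslora` instead of rejecting it.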

View File

@@ -67,15 +67,9 @@ class LoRALayerWeights:
         peft_helper: PEFTHelper,
         embeddings_tensor: Optional[torch.Tensor] = None,
     ) -> "LoRALayerWeights":
-        return cls(
-            module_name,
-            peft_helper.r,
-            peft_helper.lora_alpha,
-            None,
-            None,
-            None,
-            embeddings_tensor,
-        )
+        return cls(module_name, peft_helper.r, peft_helper.lora_alpha, None,
+                   None, None, embeddings_tensor,
+                   peft_helper.vllm_lora_scaling_factor)

     @classmethod
     def create_dummy_lora_weights(
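For context, the scaling factor that `from_config` now forwards to `LoRALayerWeights` is the scalar that multiplies the low-rank update in the usual LoRA formulation. A minimal sketch of that math (plain PyTorch with illustrative shapes; this is not vLLM's fused LoRA kernel path):

```python
import torch

hidden_size, rank, out_size = 64, 8, 64
scaling = 16 / 8  # lora_alpha / r, or lora_alpha / sqrt(r) when use_rslora is set

x = torch.randn(2, hidden_size)               # input activations
base_w = torch.randn(out_size, hidden_size)   # frozen base weight
lora_a = torch.randn(rank, hidden_size)       # LoRA A (down-projection)
lora_b = torch.randn(out_size, rank)          # LoRA B (up-projection)

# y = x @ W^T + scaling * (x @ A^T) @ B^T
y = x @ base_w.t() + scaling * (x @ lora_a.t()) @ lora_b.t()
```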

View File

@@ -173,7 +173,7 @@ class LoRAModel(AdapterModel):
         return cls(lora_model_id,
                    peft_helper.r,
                    loras,
-                   scaling_factor=peft_helper.vllm_scaling_factor)
+                   scaling_factor=peft_helper.vllm_long_context_scaling_factor)

     @classmethod
     def from_local_checkpoint(
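Note that the renamed argument only carries the long-context LoRA factor; it is separate from the new per-adapter `vllm_lora_scaling_factor`. As the `PEFTHelper` change below shows, it is simply `ceil(context_length / max_position_embeddings)`. A quick sketch with hypothetical numbers:

```python
import math

context_length = 32768           # hypothetical long-context adapter
max_position_embeddings = 4096   # hypothetical base-model position limit

long_context_scaling_factor = float(
    math.ceil(context_length / max_position_embeddings))  # 8.0
```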

View File

@@ -4,6 +4,8 @@ import math
 from dataclasses import MISSING, dataclass, field, fields
 from typing import Literal, Optional, Union

+from vllm.utils import print_info_once

 @dataclass
 class PEFTHelper:
@@ -14,21 +16,22 @@ class PEFTHelper:
     bias: Literal["none", "all", "lora_only"] = field(default="none")
     modules_to_save: Optional[list[str]] = field(default=None)
+    # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
     use_rslora: bool = field(default=False)
+    # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
     use_dora: bool = field(default=False)
-    # long lora field
+    # long context lora field
     context_length: int = field(default=0)
     # Extra vllm field, start with 'vllm_' to avoid conflict
+    vllm_lora_scaling_factor: float = field(default=1.0)
     vllm_max_position_embeddings: Optional[int] = field(default=False)
-    vllm_scaling_factor: Optional[float] = field(default=None)
+    vllm_long_context_scaling_factor: Optional[float] = field(default=None)

     def _validate_features(self):
         error_msg = []
         if self.modules_to_save:
             error_msg.append("vLLM only supports modules_to_save being None.")
-        if self.use_rslora:
-            error_msg.append("vLLM does not yet support RSLoRA.")
         if self.use_dora:
             error_msg.append("vLLM does not yet support DoRA.")
@@ -38,10 +41,15 @@ class PEFTHelper:

     def __post_init__(self):
         self._validate_features()
+        if self.use_rslora:
+            print_info_once("Loading LoRA weights trained with rsLoRA.")
+            self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
+        else:
+            self.vllm_lora_scaling_factor = self.lora_alpha / self.r
         if self.context_length:
             if self.vllm_max_position_embeddings is None:
                 self.vllm_max_position_embeddings = self.context_length
-            self.vllm_scaling_factor = float(
+            self.vllm_long_context_scaling_factor = float(
                 math.ceil(self.context_length /
                           self.vllm_max_position_embeddings))
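Putting the pieces together, a hedged usage sketch of the new behavior: constructing a `PEFTHelper` from an adapter config dict, as the updated test does, and reading back the precomputed factor. The config values are illustrative, and the import path `vllm.lora.peft_helper` is assumed here:

```python
import math

from vllm.lora.peft_helper import PEFTHelper  # assumed module path

# Hypothetical rsLoRA adapter config (same shape as the dicts in the test above)
config = dict(r=16,
              lora_alpha=32,
              target_modules=["q_proj", "v_proj"],
              use_rslora=True)

helper = PEFTHelper.from_dict(config)
# With use_rslora=True, __post_init__ selects lora_alpha / sqrt(r) = 32 / 4 = 8.0
assert abs(helper.vllm_lora_scaling_factor - 32 / math.sqrt(16)) < 1e-6
```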