[Misc][LoRA] Support Rank Stabilized LoRA (RSLoRA) (#6909)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
parent: 74fa1d123c
commit: 82c49d3260
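This change replaces the old hard rejection of rsLoRA adapters with actual support: PEFTHelper now precomputes a vllm_lora_scaling_factor once at load time, and every LoRA layer consumes that value instead of deriving alpha / r itself. A minimal sketch of the two scaling rules (illustrative helper names, not vLLM API):

import math

# Conventional LoRA scales the low-rank update B @ A by alpha / r, so the
# effective update shrinks as the rank r grows.
def lora_scaling(lora_alpha: float, r: int) -> float:
    return lora_alpha / r

# Rank-Stabilized LoRA (rsLoRA, https://arxiv.org/abs/2312.03732) divides by
# sqrt(r) instead, keeping the update magnitude stable across ranks.
def rslora_scaling(lora_alpha: float, r: int) -> float:
    return lora_alpha / math.sqrt(r)

# With the values used in the new test (r=8, lora_alpha=16):
assert lora_scaling(16, 8) == 2.0                  # conventional
assert abs(rslora_scaling(16, 8) - 5.6569) < 1e-3  # rank-stabilized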
@@ -1,4 +1,5 @@
 import json
+import math
 import os
 from typing import Dict, List
 
@@ -50,6 +51,18 @@ def test_peft_helper(sql_lora_files):
         "embed_tokens",
         "lm_head",
     ]
+    scaling = peft_helper.lora_alpha / peft_helper.r
+    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
+
+    # test RSLoRA
+    config = dict(r=8,
+                  lora_alpha=16,
+                  target_modules=["gate_proj"],
+                  use_rslora=True)
+    peft_helper = PEFTHelper.from_dict(config)
+
+    scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
+    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
 
     expected_error = "vLLM only supports modules_to_save being None."
     with pytest.raises(ValueError, match=expected_error):
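Worked through with the test's numbers: the first assertion checks the conventional path, 16 / 8 = 2.0; the rsLoRA block then rebuilds the helper with use_rslora=True and expects 16 / sqrt(8) ≈ 5.657. Both comparisons use a 1e-3 tolerance since the factor is stored as a float.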
@@ -60,13 +73,6 @@ def test_peft_helper(sql_lora_files):
             modules_to_save=["lm_head"],
         )
         PEFTHelper.from_dict(config)
-    expected_error = "vLLM does not yet support RSLoRA."
-    with pytest.raises(ValueError, match=expected_error):
-        config = dict(r=8,
-                      lora_alpha=16,
-                      target_modules=["gate_proj"],
-                      use_rslora=True)
-        PEFTHelper.from_dict(config)
 
     expected_error = "vLLM does not yet support DoRA."
     with pytest.raises(ValueError, match=expected_error):
@@ -67,15 +67,9 @@ class LoRALayerWeights:
         peft_helper: PEFTHelper,
         embeddings_tensor: Optional[torch.Tensor] = None,
     ) -> "LoRALayerWeights":
-        return cls(
-            module_name,
-            peft_helper.r,
-            peft_helper.lora_alpha,
-            None,
-            None,
-            None,
-            embeddings_tensor,
-        )
+        return cls(module_name, peft_helper.r, peft_helper.lora_alpha, None,
+                   None, None, embeddings_tensor,
+                   peft_helper.vllm_lora_scaling_factor)
 
     @classmethod
     def create_dummy_lora_weights(
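Because the scaling factor is now precomputed in PEFTHelper, the classmethod above simply forwards peft_helper.vllm_lora_scaling_factor as the final constructor argument; the layer itself no longer needs to know whether the adapter was trained with rsLoRA.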
@@ -173,7 +173,7 @@ class LoRAModel(AdapterModel):
         return cls(lora_model_id,
                    peft_helper.r,
                    loras,
-                   scaling_factor=peft_helper.vllm_scaling_factor)
+                   scaling_factor=peft_helper.vllm_long_context_scaling_factor)
 
     @classmethod
     def from_local_checkpoint(
@@ -4,6 +4,8 @@ import math
 from dataclasses import MISSING, dataclass, field, fields
 from typing import Literal, Optional, Union
 
+from vllm.utils import print_info_once
+
 
 @dataclass
 class PEFTHelper:
@@ -14,21 +16,22 @@ class PEFTHelper:
 
     bias: Literal["none", "all", "lora_only"] = field(default="none")
     modules_to_save: Optional[list[str]] = field(default=None)
+    # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
     use_rslora: bool = field(default=False)
+    # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
     use_dora: bool = field(default=False)
-    # long lora field
+    # long context lora field
     context_length: int = field(default=0)
     # Extra vllm field, start with 'vllm_' to avoid conflict
+    vllm_lora_scaling_factor: float = field(default=1.0)
     vllm_max_position_embeddings: Optional[int] = field(default=False)
-    vllm_scaling_factor: Optional[float] = field(default=None)
+    vllm_long_context_scaling_factor: Optional[float] = field(default=None)
 
     def _validate_features(self):
         error_msg = []
 
         if self.modules_to_save:
             error_msg.append("vLLM only supports modules_to_save being None.")
-        if self.use_rslora:
-            error_msg.append("vLLM does not yet support RSLoRA.")
 
         if self.use_dora:
             error_msg.append("vLLM does not yet support DoRA.")
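Two things happen in this hunk: a new vllm_lora_scaling_factor field (default 1.0) carries the per-adapter LoRA scaling, and the old vllm_scaling_factor is renamed to vllm_long_context_scaling_factor so the two unrelated factors can no longer be confused. The use_rslora and use_dora flags gain comments pointing at their papers, and _validate_features now rejects only DoRA.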
@@ -38,10 +41,15 @@ class PEFTHelper:
 
     def __post_init__(self):
         self._validate_features()
+        if self.use_rslora:
+            print_info_once("Loading LoRA weights trained with rsLoRA.")
+            self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
+        else:
+            self.vllm_lora_scaling_factor = self.lora_alpha / self.r
         if self.context_length:
             if self.vllm_max_position_embeddings is None:
                 self.vllm_max_position_embeddings = self.context_length
-            self.vllm_scaling_factor = float(
+            self.vllm_long_context_scaling_factor = float(
                 math.ceil(self.context_length /
                           self.vllm_max_position_embeddings))
 
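End to end, loading a rsLoRA adapter now behaves like this sketch, which mirrors the new test (the import path is an assumption inferred from the class shown above, not stated in this diff):

import math

from vllm.lora.peft_helper import PEFTHelper  # path assumed, not shown in diff

config = dict(r=8,
              lora_alpha=16,
              target_modules=["gate_proj"],
              use_rslora=True)
peft_helper = PEFTHelper.from_dict(config)

# __post_init__ has already chosen the rsLoRA formula:
assert abs(peft_helper.vllm_lora_scaling_factor - 16 / math.sqrt(8)) < 1e-3

Separately, vllm_long_context_scaling_factor is only populated when the adapter config carries a context_length; for example (illustrative numbers), context_length=16384 against vllm_max_position_embeddings=4096 yields ceil(16384 / 4096) = 4.0.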