Improve configs - LoRAConfig + PromptAdapterConfig (#16980)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-04-24 18:29:34 +01:00 committed by GitHub
parent 0422ce109f
commit 0fa939e2d1
3 changed files with 130 additions and 91 deletions


@@ -31,6 +31,8 @@ DEVICES = ([
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] if current_platform.is_cuda_alike() else ["cpu"])
DEFAULT_DTYPE = torch.get_default_dtype()
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch: pytest.MonkeyPatch):
@@ -125,8 +127,10 @@ def test_replace_submodules(dist_init, dummy_model):
model = dummy_model
manager = LoRAModelManager(
model, 1, 1, 1,
LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
torch.device(DEVICES[0]))
LoRAConfig(max_lora_rank=8,
max_cpu_loras=8,
max_loras=8,
lora_dtype=DEFAULT_DTYPE), torch.device(DEVICES[0]))
model = manager.model
assert isinstance(model.get_submodule("dense1"),
ColumnParallelLinearWithLoRA)
@@ -155,7 +159,8 @@ def test_lora_model_manager(dist_init, dummy_model, device):
2,
LoRAConfig(max_lora_rank=8,
max_cpu_loras=3,
max_loras=2),
max_loras=2,
lora_dtype=DEFAULT_DTYPE),
device=device)
assert all(x is None for x in manager.lora_index_to_id)
assert manager.add_adapter(model_lora1)
@@ -221,7 +226,8 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
2,
LoRAConfig(max_lora_rank=8,
max_cpu_loras=3,
max_loras=2),
max_loras=2,
lora_dtype=DEFAULT_DTYPE),
device=device)
assert all(x is None for x in manager.lora_index_to_id)
assert manager.add_adapter(model_lora1)
@@ -316,7 +322,8 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
2,
LoRAConfig(max_lora_rank=8,
max_cpu_loras=2,
max_loras=2),
max_loras=2,
lora_dtype=DEFAULT_DTYPE),
device=device)
assert all(x is None for x in manager.lora_index_to_id)
@@ -424,7 +431,10 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files, device):
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
lora_config = LoRAConfig(max_lora_rank=8,
max_cpu_loras=4,
max_loras=4,
lora_dtype=DEFAULT_DTYPE)
worker_adapter_manager = LRUCacheWorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, device,
@@ -504,7 +514,10 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files, device):
# Should remove every LoRA not specified in the request.
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
lora_config = LoRAConfig(max_lora_rank=8,
max_cpu_loras=4,
max_loras=4,
lora_dtype=DEFAULT_DTYPE)
worker_adapter_manager = WorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, device,
@@ -600,7 +613,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
2,
LoRAConfig(max_lora_rank=8,
max_cpu_loras=2,
max_loras=2),
max_loras=2,
lora_dtype=DEFAULT_DTYPE),
device=device)
model = manager.model
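Since `lora_dtype` now defaults to the unresolved sentinel `"auto"` (resolved against the base model dtype later), the tests above pin it to the interpreter's default dtype via the new `DEFAULT_DTYPE` constant. A minimal sketch of that pattern, not part of the diff and assuming only `vllm.config.LoRAConfig` as defined in this commit:

```python
import torch

from vllm.config import LoRAConfig

# Captured once at import time, mirroring the DEFAULT_DTYPE constant the test
# module now defines.
DEFAULT_DTYPE = torch.get_default_dtype()

# Passing an explicit dtype keeps the LoRA managers deterministic now that the
# config default is "auto" rather than a concrete torch dtype.
lora_config = LoRAConfig(max_lora_rank=8,
                         max_cpu_loras=8,
                         max_loras=8,
                         lora_dtype=DEFAULT_DTYPE)
assert lora_config.lora_dtype is DEFAULT_DTYPE
```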


@@ -2565,18 +2565,41 @@ class SpeculativeConfig:
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
LoRADType = Literal["auto", "float16", "bfloat16"]
@config
@dataclass
class LoRAConfig:
max_lora_rank: int
max_loras: int
"""Configuration for LoRA."""
max_lora_rank: int = 16
"""Max LoRA rank."""
max_loras: int = 1
"""Max number of LoRAs in a single batch."""
fully_sharded_loras: bool = False
"""By default, only half of the LoRA computation is sharded with tensor
parallelism. Enabling this will use the fully sharded layers. At high
sequence length, max rank or tensor parallel size, this is likely faster.
"""
max_cpu_loras: Optional[int] = None
lora_dtype: Optional[Union[torch.dtype, str]] = None
"""Maximum number of LoRAs to store in CPU memory. Must be >= than
`max_loras`."""
lora_dtype: Union[torch.dtype, LoRADType] = "auto"
"""Data type for LoRA. If auto, will default to base model dtype."""
lora_extra_vocab_size: int = 256
"""Maximum size of extra vocabulary that can be present in a LoRA adapter
(added to the base model vocabulary)."""
# This is a constant.
lora_vocab_padding_size: ClassVar[int] = 256
long_lora_scaling_factors: Optional[tuple[float]] = None
long_lora_scaling_factors: Optional[tuple[float, ...]] = None
"""Specify multiple scaling factors (which can be different from base model
scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters
trained with those scaling factors to be used at the same time. If not
specified, only adapters trained with the base model scaling factor are
allowed."""
bias_enabled: bool = False
"""Enable bias for LoRA adapters."""
def compute_hash(self) -> str:
"""
@@ -2641,12 +2664,19 @@ class LoRAConfig:
"V1 LoRA does not support long LoRA, please use V0.")
@config
@dataclass
class PromptAdapterConfig:
max_prompt_adapters: int
max_prompt_adapter_token: int
max_prompt_adapters: int = 1
"""Max number of PromptAdapters in a batch."""
max_prompt_adapter_token: int = 0
"""Max number of PromptAdapters tokens."""
max_cpu_prompt_adapters: Optional[int] = None
prompt_adapter_dtype: Optional[torch.dtype] = None
"""Maximum number of PromptAdapters to store in CPU memory. Must be >= than
`max_prompt_adapters`."""
prompt_adapter_dtype: Union[torch.dtype, str] = "auto"
"""Data type for PromptAdapter. If auto, will default to base model dtype.
"""
def compute_hash(self) -> str:
"""
@@ -2678,7 +2708,7 @@ class PromptAdapterConfig:
self.max_cpu_prompt_adapters = self.max_prompt_adapters
def verify_with_model_config(self, model_config: ModelConfig):
if self.prompt_adapter_dtype in (None, "auto"):
if self.prompt_adapter_dtype == "auto":
self.prompt_adapter_dtype = model_config.dtype
elif isinstance(self.prompt_adapter_dtype, str):
self.prompt_adapter_dtype = getattr(torch,
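The `verify_with_model_config` change follows from the new `"auto"` default: the sentinel is now the string `"auto"` rather than `None`, and string dtypes are still looked up on the `torch` module. A rough standalone illustration of that resolution logic, distilled from the hunk above with a hypothetical helper name:

```python
import torch


def _resolve_dtype(prompt_adapter_dtype, model_dtype: torch.dtype) -> torch.dtype:
    if prompt_adapter_dtype == "auto":
        # Inherit the base model dtype.
        return model_dtype
    if isinstance(prompt_adapter_dtype, str):
        # e.g. "float16" -> torch.float16
        return getattr(torch, prompt_adapter_dtype)
    # Already a torch.dtype.
    return prompt_adapter_dtype


assert _resolve_dtype("auto", torch.bfloat16) is torch.bfloat16
assert _resolve_dtype("float16", torch.bfloat16) is torch.float16
```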


@@ -7,7 +7,7 @@ import json
import re
import threading
from dataclasses import MISSING, dataclass, fields
from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Type,
from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
TypeVar, Union, cast, get_args, get_origin)
import torch
@@ -192,18 +192,23 @@ class EngineArgs:
get_field(MultiModalConfig, "limit_per_prompt")
mm_processor_kwargs: Optional[Dict[str, Any]] = None
disable_mm_preprocessor_cache: bool = False
# LoRA fields
enable_lora: bool = False
enable_lora_bias: bool = False
max_loras: int = 1
max_lora_rank: int = 16
enable_lora_bias: bool = LoRAConfig.bias_enabled
max_loras: int = LoRAConfig.max_loras
max_lora_rank: int = LoRAConfig.max_lora_rank
fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
long_lora_scaling_factors: Optional[tuple[float, ...]] = \
LoRAConfig.long_lora_scaling_factors
# PromptAdapter fields
enable_prompt_adapter: bool = False
max_prompt_adapters: int = 1
max_prompt_adapter_token: int = 0
fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None
lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
max_cpu_loras: Optional[int] = None
max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
max_prompt_adapter_token: int = \
PromptAdapterConfig.max_prompt_adapter_token
device: Device = DeviceConfig.device
num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
@@ -338,10 +343,21 @@ class EngineArgs:
kwargs[name]["choices"] = choices
choice_type = type(choices[0])
assert all(type(c) is choice_type for c in choices), (
f"All choices must be of the same type. "
"All choices must be of the same type. "
f"Got {choices} with types {[type(c) for c in choices]}"
)
kwargs[name]["type"] = choice_type
elif can_be_type(field_type, tuple):
if is_type_in_union(field_type, tuple):
field_type = get_type_from_union(field_type, tuple)
dtypes = get_args(field_type)
dtype = dtypes[0]
assert all(
d is dtype for d in dtypes if d is not Ellipsis
), ("All non-Ellipsis tuple elements must be of the same "
f"type. Got {dtypes}.")
kwargs[name]["type"] = dtype
kwargs[name]["nargs"] = "+"
elif can_be_type(field_type, int):
kwargs[name]["type"] = optional_int if optional else int
elif can_be_type(field_type, float):
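The new `elif can_be_type(field_type, tuple)` branch is what turns tuple-typed config fields such as `long_lora_scaling_factors: Optional[tuple[float, ...]]` into multi-value CLI options. A simplified sketch of the idea using plain `typing` and `argparse`; `can_be_type`, `is_type_in_union` and `get_type_from_union` referenced above are the existing helpers and are not reproduced here:

```python
import argparse
from typing import Optional, get_args, get_origin

field_type = Optional[tuple[float, ...]]

# Unwrap Optional[...] to reach the tuple type, then inspect its element types.
tuple_type = next(t for t in get_args(field_type) if get_origin(t) is tuple)
dtypes = get_args(tuple_type)   # (float, Ellipsis)
elem_type = dtypes[0]

# One or more values of the element type, mirroring nargs="+" in the hunk.
parser = argparse.ArgumentParser()
parser.add_argument("--long-lora-scaling-factors", type=elem_type, nargs="+")

args = parser.parse_args(["--long-lora-scaling-factors", "4.0", "8.0"])
assert args.long_lora_scaling_factors == [4.0, 8.0]
```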
@@ -685,70 +701,49 @@ class EngineArgs:
'inputs.')
# LoRA related configs
parser.add_argument('--enable-lora',
action='store_true',
help='If True, enable handling of LoRA adapters.')
parser.add_argument('--enable-lora-bias',
action='store_true',
help='If True, enable bias for LoRA adapters.')
parser.add_argument('--max-loras',
type=int,
default=EngineArgs.max_loras,
help='Max number of LoRAs in a single batch.')
parser.add_argument('--max-lora-rank',
type=int,
default=EngineArgs.max_lora_rank,
help='Max LoRA rank.')
parser.add_argument(
'--lora-extra-vocab-size',
type=int,
default=EngineArgs.lora_extra_vocab_size,
help=('Maximum size of extra vocabulary that can be '
'present in a LoRA adapter (added to the base '
'model vocabulary).'))
parser.add_argument(
lora_kwargs = get_kwargs(LoRAConfig)
lora_group = parser.add_argument_group(
title="LoRAConfig",
description=LoRAConfig.__doc__,
)
lora_group.add_argument(
'--enable-lora',
action=argparse.BooleanOptionalAction,
help='If True, enable handling of LoRA adapters.')
lora_group.add_argument('--enable-lora-bias',
**lora_kwargs["bias_enabled"])
lora_group.add_argument('--max-loras', **lora_kwargs["max_loras"])
lora_group.add_argument('--max-lora-rank',
**lora_kwargs["max_lora_rank"])
lora_group.add_argument('--lora-extra-vocab-size',
**lora_kwargs["lora_extra_vocab_size"])
lora_group.add_argument(
'--lora-dtype',
type=str,
default=EngineArgs.lora_dtype,
choices=['auto', 'float16', 'bfloat16'],
help=('Data type for LoRA. If auto, will default to '
'base model dtype.'))
parser.add_argument(
'--long-lora-scaling-factors',
type=optional_str,
default=EngineArgs.long_lora_scaling_factors,
help=('Specify multiple scaling factors (which can '
'be different from base model scaling factor '
'- see eg. Long LoRA) to allow for multiple '
'LoRA adapters trained with those scaling '
'factors to be used at the same time. If not '
'specified, only adapters trained with the '
'base model scaling factor are allowed.'))
parser.add_argument(
'--max-cpu-loras',
type=int,
default=EngineArgs.max_cpu_loras,
help=('Maximum number of LoRAs to store in CPU memory. '
'Must be >= than max_loras.'))
parser.add_argument(
'--fully-sharded-loras',
action='store_true',
help=('By default, only half of the LoRA computation is '
'sharded with tensor parallelism. '
'Enabling this will use the fully sharded layers. '
'At high sequence length, max rank or '
'tensor parallel size, this is likely faster.'))
parser.add_argument('--enable-prompt-adapter',
action='store_true',
help='If True, enable handling of PromptAdapters.')
parser.add_argument('--max-prompt-adapters',
type=int,
default=EngineArgs.max_prompt_adapters,
help='Max number of PromptAdapters in a batch.')
parser.add_argument('--max-prompt-adapter-token',
type=int,
default=EngineArgs.max_prompt_adapter_token,
help='Max number of PromptAdapters tokens')
**lora_kwargs["lora_dtype"],
)
lora_group.add_argument('--long-lora-scaling-factors',
**lora_kwargs["long_lora_scaling_factors"])
lora_group.add_argument('--max-cpu-loras',
**lora_kwargs["max_cpu_loras"])
lora_group.add_argument('--fully-sharded-loras',
**lora_kwargs["fully_sharded_loras"])
# PromptAdapter related configs
prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
prompt_adapter_group = parser.add_argument_group(
title="PromptAdapterConfig",
description=PromptAdapterConfig.__doc__,
)
prompt_adapter_group.add_argument(
'--enable-prompt-adapter',
action=argparse.BooleanOptionalAction,
help='If True, enable handling of PromptAdapters.')
prompt_adapter_group.add_argument(
'--max-prompt-adapters',
**prompt_adapter_kwargs["max_prompt_adapters"])
prompt_adapter_group.add_argument(
'--max-prompt-adapter-token',
**prompt_adapter_kwargs["max_prompt_adapter_token"])
# Device arguments
device_kwargs = get_kwargs(DeviceConfig)