Improve configs - LoRAConfig + PromptAdapterConfig (#16980)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Authored by Harry Mellor on 2025-04-24 18:29:34 +01:00; committed by GitHub
parent 0422ce109f
commit 0fa939e2d1
3 changed files with 130 additions and 91 deletions

File 1 of 3 (LoRA manager tests)

@@ -31,6 +31,8 @@ DEVICES = ([
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ] if current_platform.is_cuda_alike() else ["cpu"])

+DEFAULT_DTYPE = torch.get_default_dtype()
+

 @pytest.fixture(scope="function", autouse=True)
 def use_v0_only(monkeypatch: pytest.MonkeyPatch):
@@ -125,8 +127,10 @@ def test_replace_submodules(dist_init, dummy_model):
     model = dummy_model
     manager = LoRAModelManager(
         model, 1, 1, 1,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
-        torch.device(DEVICES[0]))
+        LoRAConfig(max_lora_rank=8,
+                   max_cpu_loras=8,
+                   max_loras=8,
+                   lora_dtype=DEFAULT_DTYPE), torch.device(DEVICES[0]))
     model = manager.model
     assert isinstance(model.get_submodule("dense1"),
                       ColumnParallelLinearWithLoRA)
@@ -155,7 +159,8 @@ def test_lora_model_manager(dist_init, dummy_model, device):
         2,
         LoRAConfig(max_lora_rank=8,
                    max_cpu_loras=3,
-                   max_loras=2),
+                   max_loras=2,
+                   lora_dtype=DEFAULT_DTYPE),
         device=device)
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_adapter(model_lora1)
@@ -221,7 +226,8 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
         2,
         LoRAConfig(max_lora_rank=8,
                    max_cpu_loras=3,
-                   max_loras=2),
+                   max_loras=2,
+                   lora_dtype=DEFAULT_DTYPE),
         device=device)
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_adapter(model_lora1)
@@ -316,7 +322,8 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
         2,
         LoRAConfig(max_lora_rank=8,
                    max_cpu_loras=2,
-                   max_loras=2),
+                   max_loras=2,
+                   lora_dtype=DEFAULT_DTYPE),
         device=device)
     assert all(x is None for x in manager.lora_index_to_id)
@@ -424,7 +431,10 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
 @pytest.mark.parametrize("device", DEVICES)
 def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                           sql_lora_files, device):
-    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
+    lora_config = LoRAConfig(max_lora_rank=8,
+                             max_cpu_loras=4,
+                             max_loras=4,
+                             lora_dtype=DEFAULT_DTYPE)
     worker_adapter_manager = LRUCacheWorkerLoRAManager(
         4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
         lora_config.lora_extra_vocab_size, lora_config, device,
@@ -504,7 +514,10 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
 def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                 sql_lora_files, device):
     # Should remove every LoRA not specified in the request.
-    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
+    lora_config = LoRAConfig(max_lora_rank=8,
+                             max_cpu_loras=4,
+                             max_loras=4,
+                             lora_dtype=DEFAULT_DTYPE)
     worker_adapter_manager = WorkerLoRAManager(
         4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
         lora_config.lora_extra_vocab_size, lora_config, device,
@@ -600,7 +613,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
         2,
         LoRAConfig(max_lora_rank=8,
                    max_cpu_loras=2,
-                   max_loras=2),
+                   max_loras=2,
+                   lora_dtype=DEFAULT_DTYPE),
         device=device)
     model = manager.model
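
A note on the test changes above: every LoRAConfig built in these tests now receives an explicit lora_dtype=DEFAULT_DTYPE. A plausible reading (not stated in the diff itself) is that the field's default changed from None to the string "auto", which is only resolved to a real torch.dtype during model-config verification, so unit tests that construct the config directly must supply a concrete dtype. A minimal sketch of the distinction:

import torch

# DEFAULT_DTYPE mirrors the constant added at the top of the test file.
DEFAULT_DTYPE = torch.get_default_dtype()
print(isinstance(DEFAULT_DTYPE, torch.dtype))  # True: usable directly by LoRA layers
print(isinstance("auto", torch.dtype))         # False: must be resolved first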

File 2 of 3 (LoRAConfig and PromptAdapterConfig definitions)

@@ -2565,18 +2565,41 @@ class SpeculativeConfig:
         return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"


+LoRADType = Literal["auto", "float16", "bfloat16"]
+
+
+@config
 @dataclass
 class LoRAConfig:
-    max_lora_rank: int
-    max_loras: int
+    """Configuration for LoRA."""
+
+    max_lora_rank: int = 16
+    """Max LoRA rank."""
+    max_loras: int = 1
+    """Max number of LoRAs in a single batch."""
     fully_sharded_loras: bool = False
+    """By default, only half of the LoRA computation is sharded with tensor
+    parallelism. Enabling this will use the fully sharded layers. At high
+    sequence length, max rank or tensor parallel size, this is likely faster.
+    """
     max_cpu_loras: Optional[int] = None
+    """Maximum number of LoRAs to store in CPU memory. Must be >= than
+    `max_loras`."""
-    lora_dtype: Optional[Union[torch.dtype, str]] = None
+    lora_dtype: Union[torch.dtype, LoRADType] = "auto"
+    """Data type for LoRA. If auto, will default to base model dtype."""
     lora_extra_vocab_size: int = 256
+    """Maximum size of extra vocabulary that can be present in a LoRA adapter
+    (added to the base model vocabulary)."""
     # This is a constant.
     lora_vocab_padding_size: ClassVar[int] = 256
-    long_lora_scaling_factors: Optional[tuple[float]] = None
+    long_lora_scaling_factors: Optional[tuple[float, ...]] = None
+    """Specify multiple scaling factors (which can be different from base model
+    scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters
+    trained with those scaling factors to be used at the same time. If not
+    specified, only adapters trained with the base model scaling factor are
+    allowed."""
     bias_enabled: bool = False
+    """Enable bias for LoRA adapters."""

     def compute_hash(self) -> str:
         """
@@ -2641,12 +2664,19 @@ class LoRAConfig:
                 "V1 LoRA does not support long LoRA, please use V0.")


+@config
 @dataclass
 class PromptAdapterConfig:
-    max_prompt_adapters: int
-    max_prompt_adapter_token: int
+    max_prompt_adapters: int = 1
+    """Max number of PromptAdapters in a batch."""
+    max_prompt_adapter_token: int = 0
+    """Max number of PromptAdapters tokens."""
     max_cpu_prompt_adapters: Optional[int] = None
+    """Maximum number of PromptAdapters to store in CPU memory. Must be >= than
+    `max_prompt_adapters`."""
-    prompt_adapter_dtype: Optional[torch.dtype] = None
+    prompt_adapter_dtype: Union[torch.dtype, str] = "auto"
+    """Data type for PromptAdapter. If auto, will default to base model dtype.
+    """

     def compute_hash(self) -> str:
         """
@@ -2678,7 +2708,7 @@ class PromptAdapterConfig:
             self.max_cpu_prompt_adapters = self.max_prompt_adapters

     def verify_with_model_config(self, model_config: ModelConfig):
-        if self.prompt_adapter_dtype in (None, "auto"):
+        if self.prompt_adapter_dtype == "auto":
             self.prompt_adapter_dtype = model_config.dtype
         elif isinstance(self.prompt_adapter_dtype, str):
             self.prompt_adapter_dtype = getattr(torch,
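
The verify_with_model_config change above narrows the "auto" handling now that the default is the string "auto" rather than None. A self-contained sketch of the same resolution pattern (the helper name is made up for illustration):

import torch

def resolve_adapter_dtype(requested, model_dtype: torch.dtype) -> torch.dtype:
    if requested == "auto":
        return model_dtype                 # fall back to the base model dtype
    if isinstance(requested, str):
        return getattr(torch, requested)   # e.g. "float16" -> torch.float16
    return requested                       # already a torch.dtype

assert resolve_adapter_dtype("auto", torch.bfloat16) is torch.bfloat16
assert resolve_adapter_dtype("float16", torch.bfloat16) is torch.float16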

File 3 of 3 (engine CLI argument parsing)

@@ -7,7 +7,7 @@ import json
 import re
 import threading
 from dataclasses import MISSING, dataclass, fields
-from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Type,
+from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
                     TypeVar, Union, cast, get_args, get_origin)

 import torch
@@ -192,18 +192,23 @@ class EngineArgs:
         get_field(MultiModalConfig, "limit_per_prompt")
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
     disable_mm_preprocessor_cache: bool = False
+    # LoRA fields
     enable_lora: bool = False
-    enable_lora_bias: bool = False
-    max_loras: int = 1
-    max_lora_rank: int = 16
+    enable_lora_bias: bool = LoRAConfig.bias_enabled
+    max_loras: int = LoRAConfig.max_loras
+    max_lora_rank: int = LoRAConfig.max_lora_rank
+    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
+    max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
+    lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
+    lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
+    long_lora_scaling_factors: Optional[tuple[float, ...]] = \
+        LoRAConfig.long_lora_scaling_factors
+    # PromptAdapter fields
     enable_prompt_adapter: bool = False
-    max_prompt_adapters: int = 1
-    max_prompt_adapter_token: int = 0
-    fully_sharded_loras: bool = False
-    lora_extra_vocab_size: int = 256
-    long_lora_scaling_factors: Optional[Tuple[float]] = None
-    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
-    max_cpu_loras: Optional[int] = None
+    max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
+    max_prompt_adapter_token: int = \
+        PromptAdapterConfig.max_prompt_adapter_token
     device: Device = DeviceConfig.device
     num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
     multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
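
The EngineArgs defaults above are now read straight off the config dataclasses. This works because a dataclass field with a simple default is also an ordinary class attribute, so the default lives in exactly one place. A minimal sketch with stand-in class names (not vLLM's):

from dataclasses import dataclass

@dataclass
class LoRAConfigSketch:
    max_loras: int = 1
    max_lora_rank: int = 16

@dataclass
class EngineArgsSketch:
    # Reuse the config's class attributes as this class's field defaults.
    max_loras: int = LoRAConfigSketch.max_loras
    max_lora_rank: int = LoRAConfigSketch.max_lora_rank

assert EngineArgsSketch().max_lora_rank == LoRAConfigSketch().max_lora_rank == 16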
@@ -338,10 +343,21 @@ class EngineArgs:
                 kwargs[name]["choices"] = choices
                 choice_type = type(choices[0])
                 assert all(type(c) is choice_type for c in choices), (
-                    f"All choices must be of the same type. "
+                    "All choices must be of the same type. "
                     f"Got {choices} with types {[type(c) for c in choices]}"
                 )
                 kwargs[name]["type"] = choice_type
+            elif can_be_type(field_type, tuple):
+                if is_type_in_union(field_type, tuple):
+                    field_type = get_type_from_union(field_type, tuple)
+                dtypes = get_args(field_type)
+                dtype = dtypes[0]
+                assert all(
+                    d is dtype for d in dtypes if d is not Ellipsis
+                ), ("All non-Ellipsis tuple elements must be of the same "
+                    f"type. Got {dtypes}.")
+                kwargs[name]["type"] = dtype
+                kwargs[name]["nargs"] = "+"
             elif can_be_type(field_type, int):
                 kwargs[name]["type"] = optional_int if optional else int
             elif can_be_type(field_type, float):
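
The new tuple branch above inspects the annotation to pick an element type and switches the argument to nargs="+". A small standalone illustration of the same introspection (the flag name is reused only for familiarity; this is not vLLM's parser):

import argparse
from typing import get_args

dtypes = get_args(tuple[float, ...])
print(dtypes)  # (<class 'float'>, Ellipsis)

parser = argparse.ArgumentParser()
parser.add_argument("--long-lora-scaling-factors", type=dtypes[0], nargs="+")
args = parser.parse_args(["--long-lora-scaling-factors", "4.0", "8.0"])
print(args.long_lora_scaling_factors)  # [4.0, 8.0] (a list; conversion to tuple, if any, happens later)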
@@ -685,70 +701,49 @@ class EngineArgs:
                   'inputs.')

         # LoRA related configs
-        parser.add_argument('--enable-lora',
-                            action='store_true',
-                            help='If True, enable handling of LoRA adapters.')
-        parser.add_argument('--enable-lora-bias',
-                            action='store_true',
-                            help='If True, enable bias for LoRA adapters.')
-        parser.add_argument('--max-loras',
-                            type=int,
-                            default=EngineArgs.max_loras,
-                            help='Max number of LoRAs in a single batch.')
-        parser.add_argument('--max-lora-rank',
-                            type=int,
-                            default=EngineArgs.max_lora_rank,
-                            help='Max LoRA rank.')
-        parser.add_argument(
-            '--lora-extra-vocab-size',
-            type=int,
-            default=EngineArgs.lora_extra_vocab_size,
-            help=('Maximum size of extra vocabulary that can be '
-                  'present in a LoRA adapter (added to the base '
-                  'model vocabulary).'))
-        parser.add_argument(
-            '--lora-dtype',
-            type=str,
-            default=EngineArgs.lora_dtype,
-            choices=['auto', 'float16', 'bfloat16'],
-            help=('Data type for LoRA. If auto, will default to '
-                  'base model dtype.'))
-        parser.add_argument(
-            '--long-lora-scaling-factors',
-            type=optional_str,
-            default=EngineArgs.long_lora_scaling_factors,
-            help=('Specify multiple scaling factors (which can '
-                  'be different from base model scaling factor '
-                  '- see eg. Long LoRA) to allow for multiple '
-                  'LoRA adapters trained with those scaling '
-                  'factors to be used at the same time. If not '
-                  'specified, only adapters trained with the '
-                  'base model scaling factor are allowed.'))
-        parser.add_argument(
-            '--max-cpu-loras',
-            type=int,
-            default=EngineArgs.max_cpu_loras,
-            help=('Maximum number of LoRAs to store in CPU memory. '
-                  'Must be >= than max_loras.'))
-        parser.add_argument(
-            '--fully-sharded-loras',
-            action='store_true',
-            help=('By default, only half of the LoRA computation is '
-                  'sharded with tensor parallelism. '
-                  'Enabling this will use the fully sharded layers. '
-                  'At high sequence length, max rank or '
-                  'tensor parallel size, this is likely faster.'))
-        parser.add_argument('--enable-prompt-adapter',
-                            action='store_true',
-                            help='If True, enable handling of PromptAdapters.')
-        parser.add_argument('--max-prompt-adapters',
-                            type=int,
-                            default=EngineArgs.max_prompt_adapters,
-                            help='Max number of PromptAdapters in a batch.')
-        parser.add_argument('--max-prompt-adapter-token',
-                            type=int,
-                            default=EngineArgs.max_prompt_adapter_token,
-                            help='Max number of PromptAdapters tokens')
+        lora_kwargs = get_kwargs(LoRAConfig)
+        lora_group = parser.add_argument_group(
+            title="LoRAConfig",
+            description=LoRAConfig.__doc__,
+        )
+        lora_group.add_argument(
+            '--enable-lora',
+            action=argparse.BooleanOptionalAction,
+            help='If True, enable handling of LoRA adapters.')
+        lora_group.add_argument('--enable-lora-bias',
+                                **lora_kwargs["bias_enabled"])
+        lora_group.add_argument('--max-loras', **lora_kwargs["max_loras"])
+        lora_group.add_argument('--max-lora-rank',
+                                **lora_kwargs["max_lora_rank"])
+        lora_group.add_argument('--lora-extra-vocab-size',
+                                **lora_kwargs["lora_extra_vocab_size"])
+        lora_group.add_argument(
+            '--lora-dtype',
+            **lora_kwargs["lora_dtype"],
+        )
+        lora_group.add_argument('--long-lora-scaling-factors',
+                                **lora_kwargs["long_lora_scaling_factors"])
+        lora_group.add_argument('--max-cpu-loras',
+                                **lora_kwargs["max_cpu_loras"])
+        lora_group.add_argument('--fully-sharded-loras',
+                                **lora_kwargs["fully_sharded_loras"])
+
+        # PromptAdapter related configs
+        prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
+        prompt_adapter_group = parser.add_argument_group(
+            title="PromptAdapterConfig",
+            description=PromptAdapterConfig.__doc__,
+        )
+        prompt_adapter_group.add_argument(
+            '--enable-prompt-adapter',
+            action=argparse.BooleanOptionalAction,
+            help='If True, enable handling of PromptAdapters.')
+        prompt_adapter_group.add_argument(
+            '--max-prompt-adapters',
+            **prompt_adapter_kwargs["max_prompt_adapters"])
+        prompt_adapter_group.add_argument(
+            '--max-prompt-adapter-token',
+            **prompt_adapter_kwargs["max_prompt_adapter_token"])

         # Device arguments
         device_kwargs = get_kwargs(DeviceConfig)
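
For context on the get_kwargs(...) / add_argument_group(...) pattern used above, here is a deliberately simplified, hypothetical stand-in (the real helper also derives help text, choices and types from the dataclass fields, which is what allows the hand-written help strings above to be deleted):

import argparse
from dataclasses import dataclass, fields

@dataclass
class DemoConfig:
    """Demo config group."""
    max_items: int = 1

def demo_get_kwargs(cls):
    # Simplified: hardcodes type=int; defaults come from the dataclass so the
    # CLI and the config class cannot drift apart.
    return {f.name: {"type": int, "default": f.default} for f in fields(cls)}

parser = argparse.ArgumentParser()
demo_group = parser.add_argument_group(title="DemoConfig",
                                       description=DemoConfig.__doc__)
demo_group.add_argument("--max-items", **demo_get_kwargs(DemoConfig)["max_items"])
print(parser.parse_args([]).max_items)  # 1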