[Core] Modify the initialization parameters of the lora manager (#25249)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2025-09-20 02:01:28 +08:00 committed by GitHub
parent 6c117cff7d
commit 2821986450
10 changed files with 51 additions and 52 deletions

View File

@@ -8,11 +8,12 @@ import torch
 from safetensors.torch import load_file
 from torch import nn
 
+from vllm.config import ModelConfig, VllmConfig
 from vllm.config.lora import LoRAConfig
 from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                               MergedColumnParallelLinearWithLoRA,
                               RowParallelLinearWithLoRA)
-from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
+from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                               LRUCacheLoRAModelManager)
 from vllm.lora.peft_helper import PEFTHelper
@@ -435,10 +436,19 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device,
         target_modules=["layer1.dense1", "dense2"],
         lora_dtype=DEFAULT_DTYPE,
     )
+    model_config = ModelConfig(max_model_len=16)
+    vllm_config = VllmConfig(model_config=model_config,
+                             lora_config=lora_config)
+    vllm_config.scheduler_config.max_num_seqs = 4
+    vllm_config.scheduler_config.max_num_batched_tokens = 2
     worker_adapter_manager = LRUCacheWorkerLoRAManager(
-        4, 2,
-        dummy_model.unpadded_vocab_size - lora_config.lora_extra_vocab_size,
-        lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
+        vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
+    worker_adapter_manager.max_num_seqs = 4
+    worker_adapter_manager.max_num_batched_tokens = 2
     worker_adapter_manager.create_lora_manager(dummy_model)
 
     mapping = LoRAMapping([], [])
@@ -517,10 +527,20 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device,
         max_cpu_loras=4,
         max_loras=4,
         lora_dtype=DEFAULT_DTYPE)
-    worker_adapter_manager = WorkerLoRAManager(
-        4, 2, dummy_model_gate_up.unpadded_vocab_size -
-        lora_config.lora_extra_vocab_size, lora_config, device,
-        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
+    model_config = ModelConfig(max_model_len=16)
+    vllm_config = VllmConfig(model_config=model_config,
+                             lora_config=lora_config)
+    vllm_config.scheduler_config.max_num_seqs = 4
+    vllm_config.scheduler_config.max_num_batched_tokens = 2
+    worker_adapter_manager = WorkerLoRAManager(vllm_config, device,
+                                               EMBEDDING_MODULES,
+                                               EMBEDDING_PADDING_MODULES)
+    worker_adapter_manager.vocab_size = (
+        dummy_model_gate_up.unpadded_vocab_size -
+        lora_config.lora_extra_vocab_size)
     worker_adapter_manager.create_lora_manager(dummy_model_gate_up)
 
     dummy_lora_files = f"{tmp_path}/lora_adapter"
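
Taken together, the two test hunks show the new construction pattern: the scheduler limits and the LoRA settings now travel inside a single VllmConfig instead of being passed as separate positional arguments. A minimal standalone sketch of that pattern follows; the config values, the CPU device, and the empty embedding maps are illustrative assumptions (the real tests pass the module-level EMBEDDING_MODULES and EMBEDDING_PADDING_MODULES constants), not the exact test code:

    import torch

    from vllm.config import ModelConfig, VllmConfig
    from vllm.config.lora import LoRAConfig
    from vllm.lora.worker_manager import WorkerLoRAManager

    # Build the consolidated config that the manager now consumes.
    lora_config = LoRAConfig(max_loras=4, max_cpu_loras=4)
    vllm_config = VllmConfig(model_config=ModelConfig(max_model_len=16),
                             lora_config=lora_config)
    vllm_config.scheduler_config.max_num_seqs = 4
    vllm_config.scheduler_config.max_num_batched_tokens = 2

    # Only the config, the device, and the embedding metadata are passed now.
    manager = WorkerLoRAManager(vllm_config, torch.device("cpu"),
                                embedding_modules={},
                                embedding_padding_modules=[])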

View File

@@ -9,7 +9,7 @@ from typing import Optional, Union
 import torch
 from safetensors.torch import save_file
 
-from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
+from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
 
 
 class DummyLoRAManager:

View File

@@ -14,7 +14,7 @@ from torch import nn
 from vllm.config.lora import LoRAConfig
 from vllm.logger import init_logger
 from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
-from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
+from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.punica_wrapper import get_punica_wrapper
 from vllm.lora.utils import (from_layer, from_layer_logits_processor,

View File

@@ -6,7 +6,7 @@ from typing import Any, Literal, Optional, Union
 import torch
 
-from vllm.config.lora import LoRAConfig
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.lora.models import (LoRAModel, LoRAModelManager,
                               LRUCacheLoRAModelManager, create_lora_manager)
@@ -27,25 +27,26 @@ class WorkerLoRAManager:
 
     def __init__(
         self,
-        max_num_seqs: int,
-        max_num_batched_tokens: int,
-        vocab_size: int,
-        lora_config: LoRAConfig,
+        vllm_config: VllmConfig,
         device: torch.device,
         embedding_modules: dict[str, str],
         embedding_padding_modules: list[str],
         lora_model_cls: type[LoRAModel] = LoRAModel,
-        max_position_embeddings: Optional[int] = None,
     ):
         self._lora_model_cls = lora_model_cls
         self.embedding_modules = embedding_modules
         self.embedding_padding_modules = embedding_padding_modules
         self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
-        self.max_num_seqs = max_num_seqs
-        self.max_num_batched_tokens = max_num_batched_tokens
-        self.vocab_size = vocab_size
-        self.lora_config = lora_config
-        self.max_position_embeddings = max_position_embeddings
+        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
+        self.max_num_batched_tokens = (
+            vllm_config.scheduler_config.max_num_batched_tokens)
+        self.vocab_size = vllm_config.model_config.get_vocab_size()
+        self.lora_config = vllm_config.lora_config
+        # Use get_text_config() in case of multimodal models
+        text_config = vllm_config.model_config.hf_config.get_text_config()
+        self.max_position_embeddings = text_config.max_position_embeddings
         self.device = device
         # Lazily initialized by create_lora_manager.
         self._adapter_manager: LoRAModelManager
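
For callers that previously supplied these values themselves, the hunk above fully determines where each former constructor argument now comes from. A hedged summary of that mapping; the helper function and its name are purely illustrative, not part of the commit:

    from vllm.config import VllmConfig

    def former_lora_manager_args(vllm_config: VllmConfig) -> dict:
        """Where each removed WorkerLoRAManager argument now lives."""
        # get_text_config() keeps multimodal models working, as in the hunk.
        text_config = vllm_config.model_config.hf_config.get_text_config()
        return {
            "max_num_seqs": vllm_config.scheduler_config.max_num_seqs,
            "max_num_batched_tokens":
            vllm_config.scheduler_config.max_num_batched_tokens,
            "vocab_size": vllm_config.model_config.get_vocab_size(),
            "lora_config": vllm_config.lora_config,
            "max_position_embeddings": text_config.max_position_embeddings,
        }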

View File

@@ -107,9 +107,8 @@ class CPUModelRunner(GPUModelRunner):
         self.model = get_model(vllm_config=self.vllm_config)
 
         if self.lora_config:
-            self.model = self.load_lora_model(self.model, self.model_config,
-                                              self.scheduler_config,
-                                              self.lora_config, self.device)
+            self.model = self.load_lora_model(self.model, self.vllm_config,
+                                              self.device)
 
     def get_model(self) -> nn.Module:
         return self.model

View File

@@ -2552,10 +2552,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.model = model_loader.load_model(
                 vllm_config=self.vllm_config, model_config=self.model_config)
         if self.lora_config:
-            self.model = self.load_lora_model(self.model,
-                                              self.model_config,
-                                              self.scheduler_config,
-                                              self.lora_config,
-                                              self.device)
+            self.model = self.load_lora_model(self.model, self.vllm_config,
+                                              self.device)
         if hasattr(self, "drafter"):
             logger.info("Loading drafter model...")

View File

@@ -11,7 +11,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from vllm.config import ModelConfig, SchedulerConfig
+from vllm.config import VllmConfig
 from vllm.config.lora import LoRAConfig
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
@@ -31,9 +31,7 @@ class LoRAModelRunnerMixin:
 
     LORA_WARMUP_RANK = 8
 
-    def load_lora_model(self, model: nn.Module, model_config: ModelConfig,
-                        scheduler_config: SchedulerConfig,
-                        lora_config: LoRAConfig,
+    def load_lora_model(self, model: nn.Module, vllm_config: VllmConfig,
                         device: torch.device) -> nn.Module:
 
         if not supports_lora(model):
@@ -44,19 +42,12 @@ class LoRAModelRunnerMixin:
             logger.warning("Regarding multimodal models, vLLM currently "
                            "only supports adding LoRA to language model.")
 
-        # Use get_text_config() in case of multimodal models
-        text_config = model_config.hf_config.get_text_config()
-
         # Add LoRA Manager to the Model Runner
         self.lora_manager = LRUCacheWorkerLoRAManager(
-            scheduler_config.max_num_seqs,
-            scheduler_config.max_num_batched_tokens,
-            model_config.get_vocab_size(),
-            lora_config,
+            vllm_config,
             device,
             model.embedding_modules,
             model.embedding_padding_modules,
-            max_position_embeddings=text_config.max_position_embeddings,
         )
         return self.lora_manager.create_lora_manager(model)
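
Runners pick this change up by mixing the class in and handing over their VllmConfig. A minimal sketch of that wiring follows; the DummyRunner class, its attributes, and the import path of the mixin are assumptions for illustration, not part of this commit:

    import torch
    import torch.nn as nn

    from vllm.config import VllmConfig
    # Module path assumed; adjust to wherever LoRAModelRunnerMixin lives.
    from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin


    class DummyRunner(LoRAModelRunnerMixin):
        """Illustrative runner wiring LoRA loading through the new signature."""

        def __init__(self, vllm_config: VllmConfig, device: torch.device):
            self.vllm_config = vllm_config
            self.lora_config = vllm_config.lora_config
            self.device = device

        def init_model(self, model: nn.Module) -> nn.Module:
            if self.lora_config:
                # The whole VllmConfig replaces the former model, scheduler
                # and LoRA config arguments.
                model = self.load_lora_model(model, self.vllm_config,
                                             self.device)
            return model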

View File

@@ -1178,9 +1178,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 "or sharding the weights on more chips. "
                 f"See the detailed error: {e}") from e
 
         if self.lora_config is not None:
-            model = self.load_lora_model(model, self.model_config,
-                                         self.scheduler_config,
-                                         self.lora_config, self.device)
+            model = self.load_lora_model(model, self.vllm_config, self.device)
             replace_set_lora(model)
 
     # Sync all pending XLA execution during model initialization and weight

View File

@@ -1078,20 +1078,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                     "Regarding multimodal models, vLLM currently "
                     "only supports adding LoRA to language model.")
 
-            # Use get_text_config() in case of multimodal models
-            text_config = self.model_config.hf_config.get_text_config()
             self.lora_manager = LRUCacheWorkerLoRAManager(
-                self.scheduler_config.max_num_seqs,
-                self.scheduler_config.max_num_batched_tokens,
-                self.vocab_size,
-                self.lora_config,
+                self.vllm_config,
                 self.device,
                 self.model.embedding_modules,
                 self.model.embedding_padding_modules,
-                max_position_embeddings=text_config.
-                max_position_embeddings,
             )
             self.model = self.lora_manager.create_lora_manager(self.model)
 
         time_after_load = time.perf_counter()