[Core] Modify the initialization parameters of the lora manager (#25249)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent 6c117cff7d
commit 2821986450
@@ -8,11 +8,12 @@ import torch
 from safetensors.torch import load_file
 from torch import nn
 
+from vllm.config import ModelConfig, VllmConfig
 from vllm.config.lora import LoRAConfig
 from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                               MergedColumnParallelLinearWithLoRA,
                               RowParallelLinearWithLoRA)
-from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
+from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                               LRUCacheLoRAModelManager)
 from vllm.lora.peft_helper import PEFTHelper
@@ -435,10 +436,19 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device,
         target_modules=["layer1.dense1", "dense2"],
         lora_dtype=DEFAULT_DTYPE,
     )
+
+    model_config = ModelConfig(max_model_len=16)
+    vllm_config = VllmConfig(model_config=model_config,
+                             lora_config=lora_config)
+
+    vllm_config.scheduler_config.max_num_seqs = 4
+    vllm_config.scheduler_config.max_num_batched_tokens = 2
     worker_adapter_manager = LRUCacheWorkerLoRAManager(
-        4, 2,
-        dummy_model.unpadded_vocab_size - lora_config.lora_extra_vocab_size,
-        lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
+        vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
+
+    worker_adapter_manager.max_num_seqs = 4
+    worker_adapter_manager.max_num_batched_tokens = 2
+
     worker_adapter_manager.create_lora_manager(dummy_model)
 
     mapping = LoRAMapping([], [])
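Note: the test now drives the manager entirely through a VllmConfig. Below is a minimal standalone sketch of that plumbing; it is an illustration rather than part of the diff, it assumes vLLM is installed and that the default model's Hugging Face config can be fetched, and it passes empty placeholder maps where the tests use their EMBEDDING_MODULES / EMBEDDING_PADDING_MODULES fixtures.

import torch

from vllm.config import ModelConfig, VllmConfig
from vllm.config.lora import LoRAConfig
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager

# Build the single VllmConfig the manager now consumes.
lora_config = LoRAConfig(max_loras=4, max_cpu_loras=4,
                         lora_dtype=torch.float16)
model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)

# Scheduler limits are read from the config instead of positional arguments.
vllm_config.scheduler_config.max_num_seqs = 4
vllm_config.scheduler_config.max_num_batched_tokens = 2

# Empty placeholder maps stand in for the test fixtures.
manager = LRUCacheWorkerLoRAManager(vllm_config, torch.device("cpu"), {}, [])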
@@ -517,10 +527,20 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device,
         max_cpu_loras=4,
         max_loras=4,
         lora_dtype=DEFAULT_DTYPE)
-    worker_adapter_manager = WorkerLoRAManager(
-        4, 2, dummy_model_gate_up.unpadded_vocab_size -
-        lora_config.lora_extra_vocab_size, lora_config, device,
-        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
+
+    model_config = ModelConfig(max_model_len=16)
+    vllm_config = VllmConfig(model_config=model_config,
+                             lora_config=lora_config)
+
+    vllm_config.scheduler_config.max_num_seqs = 4
+    vllm_config.scheduler_config.max_num_batched_tokens = 2
+
+    worker_adapter_manager = WorkerLoRAManager(vllm_config, device,
+                                               EMBEDDING_MODULES,
+                                               EMBEDDING_PADDING_MODULES)
+    worker_adapter_manager.vocab_size = (
+        dummy_model_gate_up.unpadded_vocab_size -
+        lora_config.lora_extra_vocab_size)
     worker_adapter_manager.create_lora_manager(dummy_model_gate_up)
 
     dummy_lora_files = f"{tmp_path}/lora_adapter"
@@ -9,7 +9,7 @@ from typing import Optional, Union
 import torch
 from safetensors.torch import save_file
 
-from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
+from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
 
 
 class DummyLoRAManager:
@@ -14,7 +14,7 @@ from torch import nn
 from vllm.config.lora import LoRAConfig
 from vllm.logger import init_logger
 from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
-from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
+from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.punica_wrapper import get_punica_wrapper
 from vllm.lora.utils import (from_layer, from_layer_logits_processor,
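These hunks only change the import path of the LoRA weight classes (vllm.lora.lora becomes vllm.lora.lora_weights), so downstream code needs nothing more than an updated import, for example:

# Old import path (removed in this commit):
# from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights

# New import path:
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights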
@@ -6,7 +6,7 @@ from typing import Any, Literal, Optional, Union
 
 import torch
 
-from vllm.config.lora import LoRAConfig
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.lora.models import (LoRAModel, LoRAModelManager,
                               LRUCacheLoRAModelManager, create_lora_manager)
@@ -27,25 +27,26 @@ class WorkerLoRAManager:
 
     def __init__(
         self,
-        max_num_seqs: int,
-        max_num_batched_tokens: int,
-        vocab_size: int,
-        lora_config: LoRAConfig,
+        vllm_config: VllmConfig,
         device: torch.device,
         embedding_modules: dict[str, str],
         embedding_padding_modules: list[str],
        lora_model_cls: type[LoRAModel] = LoRAModel,
-        max_position_embeddings: Optional[int] = None,
     ):
         self._lora_model_cls = lora_model_cls
         self.embedding_modules = embedding_modules
         self.embedding_padding_modules = embedding_padding_modules
         self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
-        self.max_num_seqs = max_num_seqs
-        self.max_num_batched_tokens = max_num_batched_tokens
-        self.vocab_size = vocab_size
-        self.lora_config = lora_config
-        self.max_position_embeddings = max_position_embeddings
+        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
+        self.max_num_batched_tokens = (
+            vllm_config.scheduler_config.max_num_batched_tokens)
+        self.vocab_size = vllm_config.model_config.get_vocab_size()
+        self.lora_config = vllm_config.lora_config
+
+        # Use get_text_config() in case of multimodal models
+        text_config = vllm_config.model_config.hf_config.get_text_config()
+
+        self.max_position_embeddings = text_config.max_position_embeddings
         self.device = device
         # Lazily initialized by create_lora_manager.
         self._adapter_manager: LoRAModelManager
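In other words, the five former constructor arguments (max_num_seqs, max_num_batched_tokens, vocab_size, lora_config, max_position_embeddings) are now derived inside __init__ from the single VllmConfig. A sketch of the equivalent lookups follows; it is illustrative only, mirrors the new __init__ body above, and assumes the default model's Hugging Face config can be fetched.

from vllm.config import ModelConfig, VllmConfig
from vllm.config.lora import LoRAConfig

vllm_config = VllmConfig(model_config=ModelConfig(max_model_len=16),
                         lora_config=LoRAConfig())

# What WorkerLoRAManager.__init__ now reads from the config:
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
vocab_size = vllm_config.model_config.get_vocab_size()
lora_config = vllm_config.lora_config

# get_text_config() covers multimodal models whose language-model config
# is nested inside the top-level Hugging Face config.
text_config = vllm_config.model_config.hf_config.get_text_config()
max_position_embeddings = text_config.max_position_embeddings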
@@ -107,9 +107,8 @@ class CPUModelRunner(GPUModelRunner):
         self.model = get_model(vllm_config=self.vllm_config)
 
         if self.lora_config:
-            self.model = self.load_lora_model(self.model, self.model_config,
-                                              self.scheduler_config,
-                                              self.lora_config, self.device)
+            self.model = self.load_lora_model(self.model, self.vllm_config,
+                                              self.device)
 
     def get_model(self) -> nn.Module:
         return self.model
@@ -2552,10 +2552,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.model = model_loader.load_model(
                 vllm_config=self.vllm_config, model_config=self.model_config)
             if self.lora_config:
-                self.model = self.load_lora_model(self.model,
-                                                  self.model_config,
-                                                  self.scheduler_config,
-                                                  self.lora_config,
+                self.model = self.load_lora_model(self.model, self.vllm_config,
                                                   self.device)
             if hasattr(self, "drafter"):
                 logger.info("Loading drafter model...")
@@ -11,7 +11,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from vllm.config import ModelConfig, SchedulerConfig
+from vllm.config import VllmConfig
 from vllm.config.lora import LoRAConfig
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
@@ -31,9 +31,7 @@ class LoRAModelRunnerMixin:
 
     LORA_WARMUP_RANK = 8
 
-    def load_lora_model(self, model: nn.Module, model_config: ModelConfig,
-                        scheduler_config: SchedulerConfig,
-                        lora_config: LoRAConfig,
+    def load_lora_model(self, model: nn.Module, vllm_config: VllmConfig,
                         device: torch.device) -> nn.Module:
 
         if not supports_lora(model):
@@ -44,19 +42,12 @@ class LoRAModelRunnerMixin:
             logger.warning("Regarding multimodal models, vLLM currently "
                            "only supports adding LoRA to language model.")
 
-        # Use get_text_config() in case of multimodal models
-        text_config = model_config.hf_config.get_text_config()
-
         # Add LoRA Manager to the Model Runner
         self.lora_manager = LRUCacheWorkerLoRAManager(
-            scheduler_config.max_num_seqs,
-            scheduler_config.max_num_batched_tokens,
-            model_config.get_vocab_size(),
-            lora_config,
+            vllm_config,
             device,
             model.embedding_modules,
             model.embedding_padding_modules,
-            max_position_embeddings=text_config.max_position_embeddings,
         )
         return self.lora_manager.create_lora_manager(model)
 
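Call sites shrink accordingly: wherever a model runner previously threaded model_config, scheduler_config and lora_config through load_lora_model, it now passes its VllmConfig and device. An illustrative before/after follows; `runner` is a hypothetical stand-in for the GPU/CPU/TPU model runners touched elsewhere in this diff, not a real object.

# Before: individual config objects threaded through.
# runner.model = runner.load_lora_model(runner.model, runner.model_config,
#                                       runner.scheduler_config,
#                                       runner.lora_config, runner.device)

# After: the whole VllmConfig plus the target device.
runner.model = runner.load_lora_model(runner.model, runner.vllm_config,
                                      runner.device)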
@@ -1178,9 +1178,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 "or sharding the weights on more chips. "
                 f"See the detailed error: {e}") from e
         if self.lora_config is not None:
-            model = self.load_lora_model(model, self.model_config,
-                                         self.scheduler_config,
-                                         self.lora_config, self.device)
+            model = self.load_lora_model(model, self.vllm_config, self.device)
         replace_set_lora(model)
 
         # Sync all pending XLA execution during model initialization and weight
@@ -1078,20 +1078,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                     "Regarding multimodal models, vLLM currently "
                     "only supports adding LoRA to language model.")
 
-                # Use get_text_config() in case of multimodal models
-                text_config = self.model_config.hf_config.get_text_config()
 
                 self.lora_manager = LRUCacheWorkerLoRAManager(
-                    self.scheduler_config.max_num_seqs,
-                    self.scheduler_config.max_num_batched_tokens,
-                    self.vocab_size,
-                    self.lora_config,
+                    self.vllm_config,
                     self.device,
                     self.model.embedding_modules,
                     self.model.embedding_padding_modules,
-                    max_position_embeddings=text_config.
-                    max_position_embeddings,
                 )
 
                 self.model = self.lora_manager.create_lora_manager(self.model)
         time_after_load = time.perf_counter()