Move forward
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

commit 0642610719
parent d4f39dc38a

@@ -80,6 +80,10 @@ class LoRAModelManager:
             lora_config: the LoRA configuration.
         """
         self.model: SupportsLoRA = model
+        self.supported_lora_modules = get_supported_lora_modules(self.model)
+        assert self.supported_lora_modules, "No supported LoRA modules found in"
+        f" {self.model.__class__.__name__}."
+
         self._registered_adapters: dict[int, LoRAModel] = {}
         # Dict instead of a set for compatibility with LRUCache.
         self._active_adapters: dict[int, None] = {}
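One detail worth flagging in this hunk: the trailing f-string after the assert is a separate, no-op expression statement in Python, so the model class name never becomes part of the assertion message. Below is a minimal standalone sketch of an assertion that keeps the full message; DummyModel and fake_get_supported_lora_modules are stand-ins for illustration, not vLLM's API or the commit's code:

    class DummyModel:
        """Stand-in for a vLLM model class; illustration only."""


    def fake_get_supported_lora_modules(model: object) -> list[str]:
        # Hypothetical stand-in for vLLM's get_supported_lora_modules helper.
        return ["qkv_proj", "o_proj"]


    model = DummyModel()
    supported_lora_modules = fake_get_supported_lora_modules(model)
    # Keeping the f-string inside the assert statement makes the class name
    # part of the failure message; on a separate line it is never evaluated
    # as part of the assert.
    assert supported_lora_modules, (
        f"No supported LoRA modules found in {model.__class__.__name__}."
    )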
@@ -91,30 +95,31 @@ class LoRAModelManager:
         self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
         self.lora_index_to_id: list[int | None] = [None] * self.lora_slots
         self.vocab_size = vocab_size
-        self.punica_wrapper = get_punica_wrapper(
-            max_num_batched_tokens,
-            max_batches=self.max_num_seqs,
-            device=self.device,
-            max_loras=self.lora_config.max_loras,
-        )
-
-        self.supported_lora_modules = get_supported_lora_modules(self.model)
-        assert self.supported_lora_modules, "No supported LoRA modules found in"
-        f" {self.model.__class__.__name__}."
-
         self.packed_modules_mapping = process_packed_modules_mapping(self.model)
-        self._maybe_init_mm(vllm_config)

         self.is_pooling_model = is_pooling_model(self.model)
         self.packed_modules: dict[str, list[str]] = {}
         self.modules: dict[str, BaseLayerWithLoRA] = {}
         # Dict instead of a set for compatibility with LRUCache.
         self._last_mapping: LoRAMapping | None = None
+        self._is_3d_moe_model = is_moe_model(self.model) and self.model.is_3d_moe_weight
+        self._init_punica_wrapper(max_num_batched_tokens, vllm_config)
         self._create_lora_modules()

         self.model.lora_manager = self

-    def _maybe_init_mm(self, vllm_config: VllmConfig):
+    def _init_punica_wrapper(
+        self, max_num_batched_tokens: int, vllm_config: VllmConfig
+    ) -> None:
+        self.punica_wrapper = get_punica_wrapper(
+            max_num_batched_tokens,
+            max_batches=self.max_num_seqs,
+            device=self.device,
+            max_loras=self.lora_config.max_loras,
+        )
+        self._maybe_init_mm(vllm_config)
+
+    def _maybe_init_mm(self, vllm_config: VllmConfig) -> None:
         # Used to indicate whether the model is a multimodal model
         self.supports_mm: bool = (
             supports_multimodal(self.model)
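The bulk of this hunk is a move: the Punica wrapper construction and the multimodal probe leave __init__ and land in the new _init_punica_wrapper helper, which calls _maybe_init_mm once the wrapper exists, while __init__ additionally records _is_3d_moe_model. A minimal sketch of the resulting call order follows; FakePunicaWrapper, SketchLoRAModelManager, and the trimmed-down fields are stand-ins for illustration, not vLLM's real types:

    class FakePunicaWrapper:
        """Stand-in for the object returned by vLLM's get_punica_wrapper."""

        def __init__(
            self,
            max_num_batched_tokens: int,
            max_batches: int,
            device: str,
            max_loras: int,
        ) -> None:
            self.max_num_batched_tokens = max_num_batched_tokens
            self.max_batches = max_batches
            self.device = device
            self.max_loras = max_loras


    class SketchLoRAModelManager:
        """Only the pieces this hunk touches; everything else is omitted."""

        def __init__(
            self,
            max_num_batched_tokens: int,
            max_num_seqs: int,
            device: str,
            max_loras: int,
        ) -> None:
            self.max_num_seqs = max_num_seqs
            self.device = device
            self.max_loras = max_loras
            # __init__ now delegates instead of building the wrapper inline,
            # mirroring the call added in this hunk.
            self._init_punica_wrapper(max_num_batched_tokens)

        def _init_punica_wrapper(self, max_num_batched_tokens: int) -> None:
            self.punica_wrapper = FakePunicaWrapper(
                max_num_batched_tokens,
                max_batches=self.max_num_seqs,
                device=self.device,
                max_loras=self.max_loras,
            )
            # The multimodal probe runs after the wrapper exists, matching the
            # ordering introduced by the diff.
            self._maybe_init_mm()

        def _maybe_init_mm(self) -> None:
            # Placeholder for supports_multimodal(...) and the mm mapping setup.
            self.supports_mm = False


    manager = SketchLoRAModelManager(
        max_num_batched_tokens=512, max_num_seqs=8, device="cpu", max_loras=4
    )
    assert manager.punica_wrapper.max_batches == 8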
@@ -320,7 +325,7 @@ class LoRAModelManager:
         # Default to the main language model wrapper
         target_wrapper = self.punica_wrapper

-        if self.supports_tower_connector_lora:
+        if self.supports_mm and self.supports_tower_connector_lora:
             if mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model:
                 target_name = self.mm_mapping.tower_model[0]
                 target_wrapper = self.mm_punica_wrapper_mapping[target_name]
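This hunk tightens the guard: the tower/connector wrapper lookup now also requires supports_mm, so a non-multimodal model always falls back to the main punica_wrapper even when a TOWER mapping arrives. A small sketch of that dispatch pattern follows; the enum, function, and wrapper registry below are stand-ins under stated assumptions, not vLLM's real types:

    import enum


    class MappingType(enum.Enum):
        # Stand-in for the LoRAMappingType referenced in the diff.
        LANGUAGE = enum.auto()
        TOWER = enum.auto()


    def pick_wrapper(
        mapping_type: MappingType,
        supports_mm: bool,
        supports_tower_connector_lora: bool,
        default_wrapper: str,
        tower_wrappers: dict[str, str],
        tower_model: list[str],
    ) -> str:
        # Default to the main language-model wrapper.
        target = default_wrapper
        # Both flags must hold before consulting the multimodal mapping,
        # mirroring the condition added in this hunk.
        if supports_mm and supports_tower_connector_lora:
            if mapping_type is MappingType.TOWER and tower_model:
                target = tower_wrappers[tower_model[0]]
        return target


    # Non-multimodal model: the default wrapper wins, even for TOWER mappings.
    assert pick_wrapper(
        MappingType.TOWER, False, True, "lm", {"vision": "vision_wrapper"}, ["vision"]
    ) == "lm"
    # Multimodal model with tower/connector LoRA: the per-tower wrapper is used.
    assert pick_wrapper(
        MappingType.TOWER, True, True, "lm", {"vision": "vision_wrapper"}, ["vision"]
    ) == "vision_wrapper"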
@@ -363,7 +368,7 @@ class LoRAModelManager:

             if self._filter_unsupported_mm_module(module_name):
                 logger.warning(
-                    "Regarding %s, vLLM currently only supports adding LoRA to"
+                    "Regarding %s, vLLM currently only supports adding LoRA to"
                     " language model, {module_name} will be ignored.",
                     self.model.__class__.__name__,
                     module_name,
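A side note on the warning touched by this hunk: logger.warning uses %-style lazy formatting, so only %s placeholders are filled from the trailing arguments and a literal "{module_name}" inside the message is emitted verbatim. A tiny standalone sketch of that behaviour, for illustration only and not the commit's code:

    import logging

    logging.basicConfig(level=logging.WARNING)
    logger = logging.getLogger("lora-demo")

    module_name = "vision_tower.layers.0"

    # Only the %s placeholder is substituted; the braces are printed as-is.
    logger.warning(
        "Regarding %s, LoRA is only added to the language model,"
        " {module_name} will be ignored.",
        module_name,
    )
    # The rendered message ends with the literal text
    # "{module_name} will be ignored."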
@@ -427,7 +432,7 @@ class LoRAModelManager:

             self._register_packed_modules(module_name)
             # All lora layers share the same punica_wrapper based on reference.
-            if self.supports_tower_connector_lora:
+            if self.supports_mm and self.supports_tower_connector_lora:
                 new_module.set_mapping(self._get_mm_punica_wrapper(module_name))
             else:
                 new_module.set_mapping(self.punica_wrapper)
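The same supports_mm and supports_tower_connector_lora guard is applied when wiring each LoRA layer: every layer receives a Punica wrapper by reference through set_mapping, either a per-tower wrapper or the shared default one. A minimal sketch of that wiring pattern follows; SketchLoRALayer, wire_layer, and the prefix-based lookup are hypothetical stand-ins, not vLLM's API:

    class SketchLoRALayer:
        """Stand-in for a BaseLayerWithLoRA subclass; illustration only."""

        def __init__(self, name: str) -> None:
            self.name = name
            self.punica_wrapper = None

        def set_mapping(self, punica_wrapper: object) -> None:
            # Layers keep a reference, so later updates to the shared wrapper
            # are visible to every registered layer.
            self.punica_wrapper = punica_wrapper


    def wire_layer(
        layer: SketchLoRALayer,
        module_name: str,
        supports_mm: bool,
        supports_tower_connector_lora: bool,
        default_wrapper: object,
        mm_wrappers: dict[str, object],
    ) -> None:
        if supports_mm and supports_tower_connector_lora:
            # Hypothetical per-module lookup, loosely mirroring
            # _get_mm_punica_wrapper from the diff.
            prefix = module_name.split(".", 1)[0]
            layer.set_mapping(mm_wrappers.get(prefix, default_wrapper))
        else:
            layer.set_mapping(default_wrapper)


    shared = object()
    layer = SketchLoRALayer("model.layers.0.self_attn.qkv_proj")
    wire_layer(layer, layer.name, False, True, shared, {})
    assert layer.punica_wrapper is shared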