[Misc] Embedding model support LoRA (#14935)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent f863ffc965
commit db7c8ca910
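This change teaches vLLM's LoRA manager to recognize pooling (embedding) models and to match adapter module names that lack the leading 'model.' prefix. As a rough offline-serving sketch of what that enables, assuming the public LLM/LoRARequest API and an embed() entry point that accepts a lora_request; the model name, adapter path, and flags are illustrative placeholders, not values taken from this commit:

# Sketch only: model, adapter path, and flags are illustrative assumptions.
from vllm import LLM
from vllm.lora.request import LoRARequest

llm = LLM(model="intfloat/e5-mistral-7b-instruct",  # any supported embedding model
          task="embed",          # run the model as a pooling/embedding model
          enable_lora=True,
          max_lora_rank=16)

adapter = LoRARequest("embed_adapter", 1, "/path/to/lora_adapter")  # hypothetical path
outs = llm.embed(["what is the capital of France?"], lora_request=adapter)
print(len(outs[0].outputs.embedding))  # embedding dimensionality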
vllm/lora/models.py

@@ -30,6 +30,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor,
                              is_regex_target_modules,
                              parse_fine_tuned_lora_name, replace_submodule)
 from vllm.model_executor.models import SupportsLoRA, supports_multimodal
+from vllm.model_executor.models.interfaces import is_pooling_model
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
 from vllm.utils import is_pin_memory_available
@@ -104,6 +105,9 @@ class LoRAModel(AdapterModel):
         """Get LoRA for a given module by name"""
         return self.loras.get(module_name, None)
 
+    def check_lora_name(self, lora_name: str) -> bool:
+        return lora_name in self.loras
+
     # (yard1): TODO see if we can derive target_embedding_padding automatically
     @classmethod
     def from_lora_tensors(
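The new LoRAModel.check_lora_name() is a plain membership test over the adapter's loras dict, giving callers a cheap existence check before get_lora(). A self-contained illustration with a stand-in class (not the real LoRAModel, whose values are LoRALayerWeights objects; the module names are hypothetical):

# Stand-in for LoRAModel: loras maps module names to adapter weights.
class _AdapterStub:
    def __init__(self, loras: dict):
        self.loras = loras

    def check_lora_name(self, lora_name: str) -> bool:
        return lora_name in self.loras  # same test the commit adds

    def get_lora(self, module_name: str):
        return self.loras.get(module_name, None)

stub = _AdapterStub({"layers.0.self_attn.qkv_proj": "lora A/B placeholder"})
assert stub.check_lora_name("layers.0.self_attn.qkv_proj")
assert not stub.check_lora_name("model.layers.0.self_attn.qkv_proj")  # prefixed name misses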
@@ -335,6 +339,7 @@ class LoRAModelManager(AdapterModelManager):
         # Used for long context lora.
         self.scaling_factor_to_offset: Dict[float, int] = {}
         super().__init__(model)
+
         self.supported_lora_modules = get_supported_lora_modules(self.model)
         assert self.supported_lora_modules, "No supported LoRA modules found in"
         f"{self.model.__class__.__name__}."
@@ -350,6 +355,7 @@ class LoRAModelManager(AdapterModelManager):
             # In case the model only supports LoRA for
             # text modules (e.g. ChatGLM)
             and hasattr(self.model, "get_mm_mapping"))
+        self.is_pooling_model = is_pooling_model(self.model)
         self.packed_modules: Dict[str, List[str]] = {}
         self.modules: Dict[str, BaseLayerWithLoRA] = {}
         # Dict instead of a Set for compatibility with LRUCache.
@@ -389,7 +395,7 @@ class LoRAModelManager(AdapterModelManager):
                      lora_model.id, index)
         self.lora_index_to_id[index] = lora_model.id
         for module_name, module in self.modules.items():
-            module_lora = lora_model.get_lora(module_name)
+            module_lora = self._get_lora_layer_weights(lora_model, module_name)
             if module_lora:
                 module_lora.optimize()
                 # Bias is not explicitly enabled with the flag enable_lora_bias.
@@ -626,7 +632,7 @@ class LoRAModelManager(AdapterModelManager):
             replaced_module: Set[str] = set()
             has_replacement = False
             for r in new_module_names:
-                lora = lora_model.get_lora(r)
+                lora = self._get_lora_layer_weights(lora_model, r)
                 replacement_loras.append(lora)
                 if lora:
                     has_replacement = True
@@ -637,12 +643,34 @@ class LoRAModelManager(AdapterModelManager):
                 if replacement_loras[i]:
                     continue
                 replacement_loras[i] = None
+            # HACK Temporary solution for the pooling model.
+            if self.is_pooling_model and not lora_model.check_lora_name(
+                    module_name):
+                replaced_module_name = module_name.replace("model.", "")
+                if lora_model.check_lora_name(replaced_module_name):
+                    module_name = replaced_module_name
             lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                 replacement_loras)
             # Remove the modules that have been replaced.
             for module in replaced_module:
                 lora_model.loras.pop(module, None)
 
+    def _get_lora_layer_weights(
+            self, lora_model: LoRAModel,
+            module_name: str) -> Optional[LoRALayerWeights]:
+        org_module_name = module_name
+        if self.is_pooling_model and not lora_model.check_lora_name(
+                module_name):
+            # If it's a pooling model and the layer name is not found,
+            # remove the prefix 'model.' and search again.
+            module_name = module_name.replace("model.", "")
+            if lora_model.check_lora_name(module_name):
+                org_module_name = module_name
+                logger.info_once(
+                    "For the pooling model, successfully loaded the LoRA "
+                    "weights after removing the prefix 'model.'.")
+        return lora_model.get_lora(org_module_name)
+
     def deactivate_adapter(self, adapter_id: int) -> bool:
         return deactivate_adapter(adapter_id, self._active_adapters,
                                   self._deactivate_adapter)
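The new _get_lora_layer_weights() centralizes the lookup used by activate_adapter() and _create_merged_loras_inplace(): vLLM's module names for pooling models typically carry a leading 'model.' (the backbone is nested under a model attribute), while adapters trained on the bare backbone usually do not, so the manager retries the lookup with the prefix stripped. A standalone sketch of that rule, with hypothetical module names and placeholder weights (the real method works on LoRALayerWeights and logs once on success; this sketch also limits the replace to the first occurrence, which the diff does not):

from typing import Dict, Optional

def resolve_lora_weights(loras: Dict[str, object], module_name: str,
                         is_pooling_model: bool) -> Optional[object]:
    # Try the full vLLM module name first.
    if module_name in loras:
        return loras[module_name]
    # For pooling models, retry without the leading 'model.' prefix.
    if is_pooling_model:
        stripped = module_name.replace("model.", "", 1)
        if stripped in loras:
            return loras[stripped]
    return None

adapter = {"layers.0.self_attn.qkv_proj": "lora A/B placeholder"}
assert resolve_lora_weights(adapter, "model.layers.0.self_attn.qkv_proj",
                            is_pooling_model=True) == "lora A/B placeholder"
assert resolve_lora_weights(adapter, "model.lm_head", is_pooling_model=True) is None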