mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 17:24:28 +08:00
[Model][LoRA]LoRA support added for glm-4v (#10418)
Signed-off-by: B-201 <Joy25810@foxmail.com>
This commit is contained in:
parent
01aae1cc68
commit
5be4e52b65
@ -30,6 +30,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
ParallelLMHead, VocabParallelEmbedding)
|
ParallelLMHead, VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
|
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
|
||||||
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import MultiModalData, MultiModalKwargs
|
from vllm.multimodal.inputs import MultiModalData, MultiModalKwargs
|
||||||
@ -574,25 +575,8 @@ class ChatGLMModel(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
|
class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP,
|
||||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
|
SupportsMultiModal):
|
||||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
|
|
||||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
|
|
||||||
class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
|
|
||||||
SupportsMultiModal):
|
|
||||||
packed_modules_mapping = {
|
|
||||||
"query_key_value": ["query_key_value"],
|
|
||||||
"dense_h_to_4h": ["dense_h_to_4h"]
|
|
||||||
}
|
|
||||||
# LoRA specific attributes
|
|
||||||
supported_lora_modules = [
|
|
||||||
"query_key_value",
|
|
||||||
"dense",
|
|
||||||
"dense_h_to_4h",
|
|
||||||
"dense_4h_to_h",
|
|
||||||
]
|
|
||||||
embedding_modules = {}
|
|
||||||
embedding_padding_modules = []
|
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -692,3 +676,79 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
|
|||||||
weight_loader(param, combined_weight)
|
weight_loader(param, combined_weight)
|
||||||
loaded_params.add(combined_name)
|
loaded_params.add(combined_name)
|
||||||
return loaded_params
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGLM(ChatGLMBaseModel):
|
||||||
|
packed_modules_mapping = {
|
||||||
|
"query_key_value": ["query_key_value"],
|
||||||
|
"dense_h_to_4h": ["dense_h_to_4h"]
|
||||||
|
}
|
||||||
|
# LoRA specific attributes
|
||||||
|
supported_lora_modules = [
|
||||||
|
"query_key_value",
|
||||||
|
"dense",
|
||||||
|
"dense_h_to_4h",
|
||||||
|
"dense_4h_to_h",
|
||||||
|
]
|
||||||
|
|
||||||
|
embedding_modules = {}
|
||||||
|
embedding_padding_modules = []
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGLMV(ChatGLMBaseModel):
|
||||||
|
packed_modules_mapping = {
|
||||||
|
"query_key_value": ["query_key_value"],
|
||||||
|
"dense_h_to_4h": ["dense_h_to_4h"],
|
||||||
|
"merged_proj": ["gate_proj", "dense_h_to_4h"]
|
||||||
|
}
|
||||||
|
# LoRA specific attributes
|
||||||
|
supported_lora_modules = [
|
||||||
|
"query_key_value",
|
||||||
|
"dense",
|
||||||
|
"dense_h_to_4h",
|
||||||
|
"dense_4h_to_h",
|
||||||
|
# vision
|
||||||
|
"fc1",
|
||||||
|
"fc2",
|
||||||
|
"merged_proj",
|
||||||
|
"linear_proj"
|
||||||
|
]
|
||||||
|
|
||||||
|
embedding_modules = {}
|
||||||
|
embedding_padding_modules = []
|
||||||
|
|
||||||
|
def get_mm_mapping(self) -> MultiModelKeys:
|
||||||
|
"""
|
||||||
|
Get the module prefix in multimodal models
|
||||||
|
"""
|
||||||
|
return MultiModelKeys.from_string_field(
|
||||||
|
language_model="transformer.encoder",
|
||||||
|
connector="transformer.vision.linear_proj",
|
||||||
|
tower_model="transformer.vision.transformer")
|
||||||
|
|
||||||
|
|
||||||
|
@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
|
||||||
|
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
|
||||||
|
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
|
||||||
|
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
|
||||||
|
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
|
||||||
|
SupportsMultiModal):
|
||||||
|
# Ensure that the LoRA support check passes when the class is not
|
||||||
|
# initialized, but set all these attributes to empty.
|
||||||
|
packed_modules_mapping = {}
|
||||||
|
supported_lora_modules = []
|
||||||
|
embedding_modules = {}
|
||||||
|
embedding_padding_modules = []
|
||||||
|
|
||||||
|
def __new__(
|
||||||
|
cls,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
# Initialize VL
|
||||||
|
if hasattr(config, "visual"):
|
||||||
|
return ChatGLM(vllm_config=vllm_config, prefix=prefix)
|
||||||
|
# Initialize LLM
|
||||||
|
else:
|
||||||
|
return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user