[Model] Add Internlm2 LoRA support (#5064)
Signed-off-by: Isotr0py <2037008807@qq.com>
commit c83919c7a6
parent 98f47f2a40
docs/source/models/supported_models.rst

@@ -182,7 +182,7 @@ Text Generation
   * - :code:`InternLM2ForCausalLM`
     - InternLM2
     - :code:`internlm/internlm2-7b`, :code:`internlm/internlm2-chat-7b`, etc.
-    -
+    - ✅︎
     - ✅︎
   * - :code:`JAISLMHeadModel`
     - Jais
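The new ✅︎ in the LoRA column means InternLM2 checkpoints can now take adapters through vLLM's standard LoRA path. A minimal sketch of offline use, assuming a locally trained adapter (the adapter name and path below are placeholders, not part of this commit):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# enable_lora switches on the LoRA machinery this commit wires up
# for InternLM2ForCausalLM.
llm = LLM(model="internlm/internlm2-chat-7b",
          trust_remote_code=True,
          enable_lora=True)

# "demo-adapter" and the path are placeholders for a real adapter.
outputs = llm.generate(
    ["Give me a short introduction to InternLM2."],
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("demo-adapter", 1, "/path/to/internlm2-lora"),
)
print(outputs[0].outputs[0].text)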
vllm/model_executor/models/internlm2.py

@@ -27,7 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -319,7 +319,21 @@ class InternLM2Model(nn.Module):
         return hidden_states


-class InternLM2ForCausalLM(nn.Module, SupportsPP):
+class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
+    packed_modules_mapping = {
+        "wqkv": ["wqkv"],
+        "gate_up_proj": ["w1", "w3"],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "wqkv",
+        "wo",
+        "gate_up_proj",
+        "w2",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []

     def __init__(self,
                  *,
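For context on the mapping above: InternLM2 stores the q/k/v projections pre-fused in a single wqkv weight, so wqkv maps to itself, while the MLP's w1 (gate) and w3 (up) are merged into gate_up_proj at load time; w2 is the down projection and wo the attention output. A PEFT adapter trained against the HF checkpoint would therefore target the unfused names, roughly as below (an illustrative sketch, not taken from this commit; rank and alpha are arbitrary):

from peft import LoraConfig

# Illustrative: these HF-side target names line up with the modules
# listed in supported_lora_modules once packed_modules_mapping folds
# w1/w3 into gate_up_proj.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["wqkv", "wo", "w1", "w2", "w3"],
    task_type="CAUSAL_LM",
)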
@@ -329,8 +343,12 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP):
         super().__init__()
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+
         self.config = config
         self.quant_config = quant_config
+        self.lora_config = lora_config
+
         self.model = model_type(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
         self.output = ParallelLMHead(config.vocab_size,
||||
Loading…
x
Reference in New Issue
Block a user