mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-17 23:34:27 +08:00
[Model] Add Internlm2 LoRA support (#5064)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
98f47f2a40
commit
c83919c7a6
@ -182,7 +182,7 @@ Text Generation
|
|||||||
* - :code:`InternLM2ForCausalLM`
|
* - :code:`InternLM2ForCausalLM`
|
||||||
- InternLM2
|
- InternLM2
|
||||||
- :code:`internlm/internlm2-7b`, :code:`internlm/internlm2-chat-7b`, etc.
|
- :code:`internlm/internlm2-7b`, :code:`internlm/internlm2-chat-7b`, etc.
|
||||||
-
|
- ✅︎
|
||||||
- ✅︎
|
- ✅︎
|
||||||
* - :code:`JAISLMHeadModel`
|
* - :code:`JAISLMHeadModel`
|
||||||
- Jais
|
- Jais
|
||||||
|
|||||||
@ -27,7 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import SupportsPP
|
from .interfaces import SupportsLoRA, SupportsPP
|
||||||
from .utils import (is_pp_missing_parameter,
|
from .utils import (is_pp_missing_parameter,
|
||||||
make_empty_intermediate_tensors_factory, make_layers,
|
make_empty_intermediate_tensors_factory, make_layers,
|
||||||
maybe_prefix)
|
maybe_prefix)
|
||||||
@ -319,7 +319,21 @@ class InternLM2Model(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class InternLM2ForCausalLM(nn.Module, SupportsPP):
|
class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||||
|
packed_modules_mapping = {
|
||||||
|
"wqkv": ["wqkv"],
|
||||||
|
"gate_up_proj": ["w1", "w3"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# LoRA specific attributes
|
||||||
|
supported_lora_modules = [
|
||||||
|
"wqkv",
|
||||||
|
"wo",
|
||||||
|
"gate_up_proj",
|
||||||
|
"w2",
|
||||||
|
]
|
||||||
|
embedding_modules = {}
|
||||||
|
embedding_padding_modules = []
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
*,
|
*,
|
||||||
@ -329,8 +343,12 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
quant_config = vllm_config.quant_config
|
quant_config = vllm_config.quant_config
|
||||||
|
lora_config = vllm_config.lora_config
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
self.lora_config = lora_config
|
||||||
|
|
||||||
self.model = model_type(vllm_config=vllm_config,
|
self.model = model_type(vllm_config=vllm_config,
|
||||||
prefix=maybe_prefix(prefix, "model"))
|
prefix=maybe_prefix(prefix, "model"))
|
||||||
self.output = ParallelLMHead(config.vocab_size,
|
self.output = ParallelLMHead(config.vocab_size,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user