mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 23:45:54 +08:00
Fix DeciLM (#2883)
This commit is contained in:
parent
d7afab6d3a
commit
4f2ad11135
@ -28,6 +28,7 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
from vllm.config import LoRAConfig
|
||||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||||
@ -56,10 +57,13 @@ class DeciLMForCausalLM(LlamaForCausalLM):
|
|||||||
self,
|
self,
|
||||||
config: Optional[PretrainedConfig] = None,
|
config: Optional[PretrainedConfig] = None,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
linear_method: Optional[LinearMethodBase] = None,
|
||||||
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
|
config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
|
||||||
delattr(config, "num_key_value_heads_per_layer")
|
delattr(config, "num_key_value_heads_per_layer")
|
||||||
super().__init__(config=config, linear_method=linear_method)
|
super().__init__(config=config,
|
||||||
|
linear_method=linear_method,
|
||||||
|
lora_config=lora_config)
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user