mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 23:45:54 +08:00
Fix DeciLM (#2883)
This commit is contained in:
parent
d7afab6d3a
commit
4f2ad11135
@ -28,6 +28,7 @@ from typing import Optional
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
@ -56,10 +57,13 @@ class DeciLMForCausalLM(LlamaForCausalLM):
|
||||
self,
|
||||
config: Optional[PretrainedConfig] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
) -> None:
|
||||
config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
|
||||
delattr(config, "num_key_value_heads_per_layer")
|
||||
super().__init__(config=config, linear_method=linear_method)
|
||||
super().__init__(config=config,
|
||||
linear_method=linear_method,
|
||||
lora_config=lora_config)
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user