Fix DeciLM (#2883)

Philipp Moritz 2024-02-14 22:29:57 -08:00 committed by GitHub
parent d7afab6d3a
commit 4f2ad11135


@@ -28,6 +28,7 @@ from typing import Optional
 import torch
 from transformers import PretrainedConfig
 
+from vllm.config import LoRAConfig
 from vllm.model_executor.layers.linear import LinearMethodBase
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.weight_utils import (default_weight_loader,
@@ -56,10 +57,13 @@ class DeciLMForCausalLM(LlamaForCausalLM):
         self,
         config: Optional[PretrainedConfig] = None,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
         delattr(config, "num_key_value_heads_per_layer")
-        super().__init__(config=config, linear_method=linear_method)
+        super().__init__(config=config,
+                         linear_method=linear_method,
+                         lora_config=lora_config)
 
     def load_weights(self,
                      model_name_or_path: str,
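
With lora_config now accepted and forwarded to the Llama parent class, DeciLMForCausalLM can be constructed through the same LoRA-enabled engine path as LlamaForCausalLM. The snippet below is a minimal usage sketch, not part of the commit: it assumes a vLLM build that includes this patch and LoRA support (v0.3.x), and the model name, adapter name, and adapter path are placeholders.

# Hedged sketch: running a DeciLM checkpoint with LoRA enabled.
# "Deci/DeciLM-7B", "my-adapter", and "/path/to/lora_adapter" are placeholder
# assumptions, not values taken from this commit.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# enable_lora builds the engine with a LoRAConfig, which this patch lets
# DeciLMForCausalLM pass through to LlamaForCausalLM.__init__.
llm = LLM(model="Deci/DeciLM-7B", trust_remote_code=True, enable_lora=True)

outputs = llm.generate(
    ["Explain grouped-query attention in one sentence."],
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("my-adapter", 1, "/path/to/lora_adapter"),
)
print(outputs[0].outputs[0].text)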