From 6eaccb7353cfe84d77981da726f6d82a8aefd2be Mon Sep 17 00:00:00 2001
From: Yikang Shen
Date: Sun, 12 May 2024 00:27:24 -0400
Subject: [PATCH] [Model] Add support for IBM Granite Code models (#4636)

---
 vllm/model_executor/models/llama.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index f6d7fc8733fce..127e4612b2e40 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -58,15 +58,16 @@ class LlamaMLP(nn.Module):
         intermediate_size: int,
         hidden_act: str,
         quant_config: Optional[QKVParallelLinear] = None,
+        bias: bool = False,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size, [intermediate_size] * 2,
-            bias=False,
+            bias=bias,
             quant_config=quant_config)
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
-                                           bias=False,
+                                           bias=bias,
                                            quant_config=quant_config)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
@@ -209,6 +210,7 @@ class LlamaDecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            bias=getattr(config, "mlp_bias", False),
         )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@@ -348,6 +350,8 @@ class LlamaForCausalLM(nn.Module):
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
         )
+        if config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
 
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,