From 7b6ae94059a7c7f097f32987489fb03b02cbf613 Mon Sep 17 00:00:00 2001
From: panda
Date: Fri, 14 Jul 2023 11:56:22 +0800
Subject: [PATCH] add vocab padding for LLama(Support WizardLM) (#411)

---
 vllm/model_executor/models/llama.py | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index c623b4148ffb..65800207f6e7 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -187,10 +187,9 @@ class LlamaModel(nn.Module):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
+        vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            perform_initialization=False)
+            vocab_size, config.hidden_size, perform_initialization=False)
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
         ])
@@ -228,8 +227,9 @@ class LlamaForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.model = LlamaModel(config)
+        vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ColumnParallelLinear(config.hidden_size,
-                                            config.vocab_size,
+                                            vocab_size,
                                             bias=False,
                                             gather_output=False,
                                             perform_initialization=False)
@@ -259,6 +259,8 @@ class LlamaForCausalLM(nn.Module):
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
@@ -267,6 +269,17 @@ class LlamaForCausalLM(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue
 
+            if "embed_tokens" in name or "lm_head" in name:
+                param = state_dict[name]
+                # Consider padding in the vocab size.
+                padded_vocab_size = (param.shape[0] *
+                                     tensor_model_parallel_world_size)
+                num_extra_rows = padded_vocab_size - self.config.vocab_size
+                extra_rows = torch.empty(num_extra_rows,
+                                         loaded_weight.shape[1])
+                extra_rows = extra_rows.to(loaded_weight)
+                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+
             is_attention_weight = False
             for stride_id, att_weight_name in enumerate(
                 ["q_proj", "k_proj", "v_proj"]):
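
Below is a minimal standalone sketch of the padding scheme this patch applies: the vocabulary size is rounded up to the next multiple of 64, and the checkpoint's embed_tokens / lm_head weights are padded with extra rows so their shape matches the padded vocab. The helper names (pad_vocab_size, pad_vocab_weight) and the example sizes are illustrative only, not vLLM APIs.

# Illustrative sketch, not part of the patch or of vLLM.
import torch


def pad_vocab_size(vocab_size: int, multiple: int = 64) -> int:
    # Round vocab_size up to the nearest multiple, e.g. 32001 -> 32064.
    # Equivalent to ((vocab_size + 63) // 64) * 64 used in the patch.
    return ((vocab_size + multiple - 1) // multiple) * multiple


def pad_vocab_weight(weight: torch.Tensor,
                     padded_vocab_size: int) -> torch.Tensor:
    # Append uninitialized rows so the weight has padded_vocab_size rows;
    # the extra rows are never selected by valid token ids.
    num_extra_rows = padded_vocab_size - weight.shape[0]
    extra_rows = torch.empty(num_extra_rows, weight.shape[1]).to(weight)
    return torch.cat([weight, extra_rows], dim=0)


if __name__ == "__main__":
    # Example values only: a WizardLM-style checkpoint with 32001 tokens.
    vocab_size, hidden_size = 32001, 4096
    lm_head = torch.randn(vocab_size, hidden_size)
    padded = pad_vocab_weight(lm_head, pad_vocab_size(vocab_size))
    print(padded.shape)  # torch.Size([32064, 4096])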