add vocab padding for LLaMA (support WizardLM) (#411)
This commit is contained in:
    parent c6dfc3cdbe
    commit 7b6ae94059
@@ -187,10 +187,9 @@ class LlamaModel(nn.Module):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

+        vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            perform_initialization=False)
+            vocab_size, config.hidden_size, perform_initialization=False)
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
         ])
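Rounding up to the next multiple of 64 lets checkpoints whose vocabulary size is not divisible by the tensor-parallel degree (such as a WizardLM-style vocab with an added special token) be sharded evenly, and keeps the embedding dimension aligned for efficient kernels. A minimal sketch of the rounding rule used in the hunk above (the helper name is illustrative, not part of the diff):

def pad_vocab_size(vocab_size: int, multiple: int = 64) -> int:
    # Round up to the next multiple of `multiple`, matching
    # ((config.vocab_size + 63) // 64) * 64 in the diff.
    return ((vocab_size + multiple - 1) // multiple) * multiple

# A 32000-token LLaMA vocab is already aligned; a 32001-token vocab
# (e.g. a LLaMA variant with one extra special token) pads to 32064.
assert pad_vocab_size(32000) == 32000
assert pad_vocab_size(32001) == 32064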
@@ -228,8 +227,9 @@ class LlamaForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.model = LlamaModel(config)
+        vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ColumnParallelLinear(config.hidden_size,
-                                            config.vocab_size,
+                                            vocab_size,
                                             bias=False,
                                             gather_output=False,
                                             perform_initialization=False)
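Because lm_head is padded too, the model now emits logits for token ids that do not exist in the tokenizer; downstream code is expected to drop those trailing columns before sampling (that step is not part of this diff). An illustrative slice of what that looks like:

import torch

def trim_padded_logits(logits: torch.Tensor, vocab_size: int) -> torch.Tensor:
    # logits: [num_tokens, padded_vocab_size]; keep only real-token columns.
    return logits[:, :vocab_size]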
@@ -259,6 +259,8 @@ class LlamaForCausalLM(nn.Module):
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

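The world size is fetched here because, under tensor parallelism, the vocab rows of embed_tokens and lm_head are split evenly across ranks, so the loader below recovers the padded vocab size as shard rows times world size. A small worked example with an assumed TP degree of 2:

# Illustrative numbers only; the TP degree of 2 is an assumption.
vocab_size = 32001
padded_vocab_size = ((vocab_size + 63) // 64) * 64           # 32064
tensor_model_parallel_world_size = 2
rows_per_rank = padded_vocab_size // tensor_model_parallel_world_size  # 16032

# This is the relation the loader inverts: param.shape[0] is rows_per_rank,
# so padded_vocab_size == param.shape[0] * tensor_model_parallel_world_size.
assert rows_per_rank * tensor_model_parallel_world_size == padded_vocab_size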
@@ -267,6 +269,17 @@ class LlamaForCausalLM(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue

+            if "embed_tokens" in name or "lm_head" in name:
+                param = state_dict[name]
+                # Consider padding in the vocab size.
+                padded_vocab_size = (param.shape[0] *
+                                     tensor_model_parallel_world_size)
+                num_extra_rows = padded_vocab_size - self.config.vocab_size
+                extra_rows = torch.empty(num_extra_rows,
+                                         loaded_weight.shape[1])
+                extra_rows = extra_rows.to(loaded_weight)
+                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+
             is_attention_weight = False
             for stride_id, att_weight_name in enumerate(
                     ["q_proj", "k_proj", "v_proj"]):
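This branch pads the checkpoint's embed_tokens and lm_head weights up to the padded row count before the usual per-rank slicing. The extra rows are left uninitialized (torch.empty) because the tokenizer never produces ids at or above config.vocab_size, and the corresponding logit columns are expected to be discarded downstream. The same idea as a standalone sketch (function name and signature are illustrative):

import torch

def pad_vocab_weight(loaded_weight: torch.Tensor,
                     padded_vocab_size: int) -> torch.Tensor:
    # loaded_weight: [vocab_size, hidden_size] checkpoint tensor.
    num_extra_rows = padded_vocab_size - loaded_weight.shape[0]
    if num_extra_rows <= 0:
        return loaded_weight
    # Uninitialized filler rows, matched to the checkpoint's dtype/device.
    extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1],
                             dtype=loaded_weight.dtype,
                             device=loaded_weight.device)
    return torch.cat([loaded_weight, extra_rows], dim=0)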