mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 19:57:07 +08:00
[model] Reduce medusa weight (#10454)
Signed-off-by: skylee-01 <497627264@qq.com>
This commit is contained in:
parent
ed701ca963
commit
343041c4c4
@@ -61,14 +61,25 @@ class Medusa(nn.Module):
|
|||||||
self.truncated_vocab_size = config.truncated_vocab_size
|
self.truncated_vocab_size = config.truncated_vocab_size
|
||||||
self.unpadded_vocab_size = self.truncated_vocab_size
|
self.unpadded_vocab_size = self.truncated_vocab_size
|
||||||
|
|
||||||
self.lm_heads = nn.ModuleList([
|
if getattr(config, "original_lm_head", False):
|
||||||
ParallelLMHead(
|
self.lm_head = ParallelLMHead(
|
||||||
self.unpadded_vocab_size,
|
self.unpadded_vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
org_num_embeddings=self.truncated_vocab_size,
|
org_num_embeddings=self.truncated_vocab_size,
|
||||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||||
) for _ in range(self.config.num_heads)
|
)
|
||||||
])
|
self.lm_heads = [
|
||||||
|
self.lm_head for _ in range(self.config.num_heads)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
self.lm_heads = nn.ModuleList([
|
||||||
|
ParallelLMHead(
|
||||||
|
self.unpadded_vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
org_num_embeddings=self.truncated_vocab_size,
|
||||||
|
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||||
|
) for _ in range(self.config.num_heads)
|
||||||
|
])
|
||||||
|
|
||||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||||
@@ -172,6 +183,9 @@ class Medusa(nn.Module):
|
|||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
elif name in params_dict:
|
elif name in params_dict:
|
||||||
weights_map[name] = loaded_weight
|
weights_map[name] = loaded_weight
|
||||||
|
elif (getattr(self.config, "original_lm_head", False)
|
||||||
|
and name == "lm_heads.0.weight"):
|
||||||
|
weights_map["lm_head.weight"] = loaded_weight
|
||||||
|
|
||||||
for name, loaded_weight in weights_map.items():
|
for name, loaded_weight in weights_map.items():
|
||||||
if "lm_head" in name and self.token_map is not None and\
|
if "lm_head" in name and self.token_map is not None and\
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user