Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-24 15:35:01 +08:00)
[Perf] Optimize memory peak during EAGLE model loading. (#24585)
Signed-off-by: Chen Ding <candy.dc@alibaba-inc.com>
This commit is contained in:
parent
6d8246aaff
commit
1a0a04dae9
@@ -229,14 +229,15 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load checkpoint weights into this EAGLE draft model.

    Streams each (name, tensor) pair through a lazy ``map`` so the full
    weight dict is never materialized at once — this is the memory-peak
    optimization: the old code collected everything into ``model_weights``
    before loading, roughly doubling peak memory during load.

    Args:
        weights: Iterable of (parameter name, tensor) pairs from the
            checkpoint.
    """

    def transform(inputs):
        # Rename non-lm_head params to live under the "model." prefix so
        # they match this wrapper's module hierarchy.
        name, loaded_weight = inputs
        if "lm_head" not in name:
            name = "model." + name
        return name, loaded_weight

    loader = AutoWeightsLoader(
        self,
        skip_prefixes=None,
    )

    # map() is lazy: each weight is renamed and consumed one at a time.
    loader.load_weights(map(transform, weights))
|
||||
|
||||
@@ -205,23 +205,21 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> None:
    """Load checkpoint weights into this EAGLE Llama4 draft model.

    Streams each (name, tensor) pair through a lazy ``map`` instead of
    materializing an intermediate list/dict of all weights — the old code
    built both a permuted list and a renamed dict, roughly doubling peak
    memory during load.

    Args:
        weights: Iterable of (parameter name, tensor) pairs from the
            checkpoint.
    """

    def transform(inputs):
        name, loaded_weight = inputs
        # Permute Q/K projection weights into the layout expected by the
        # rotary-embedding implementation (may also rewrite the name).
        name, weight = self.permute_qk_weight_for_rotary(
            name, loaded_weight)
        # Re-home non-lm_head params under the "model." prefix to match
        # this wrapper's module hierarchy.
        if "lm_head" not in name:
            name = "model." + name
        return name, weight

    loader = AutoWeightsLoader(
        self,
        # lm_head is tied with target model (Llama4ForCausalLM)
        skip_prefixes=(["lm_head."]),
    )

    # map() is lazy: each weight is permuted, renamed, and consumed one
    # at a time.
    loader.load_weights(map(transform, weights))
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
|
||||
@@ -158,14 +158,15 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
|
||||
return self.model(input_ids, positions, hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load checkpoint weights into this EAGLE Llama draft model.

    Streams each (name, tensor) pair through a lazy ``map`` so no
    intermediate dict of all weights is built — the old code materialized
    every tensor in ``model_weights`` before loading, roughly doubling
    peak memory during load.

    Args:
        weights: Iterable of (parameter name, tensor) pairs from the
            checkpoint.
    """

    def transform(inputs):
        # Rename non-lm_head params to live under the "model." prefix so
        # they match this wrapper's module hierarchy.
        name, loaded_weight = inputs
        if "lm_head" not in name:
            name = "model." + name
        return name, loaded_weight

    loader = AutoWeightsLoader(
        self,
        skip_prefixes=None,
    )

    # map() is lazy: each weight is renamed and consumed one at a time.
    loader.load_weights(map(transform, weights))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user