[Perf] Optimize memory peak during EAGLE model loading. (#24585)

Signed-off-by: Chen Ding <candy.dc@alibaba-inc.com>
This commit is contained in:
Chen Ding 2025-09-19 11:31:16 +08:00 committed by GitHub
parent 6d8246aaff
commit 1a0a04dae9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 26 additions and 26 deletions

View File

@@ -229,14 +229,15 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load checkpoint weights into the EAGLE DeepseekV3 draft model.

    Checkpoint weight names lack the ``model.`` prefix this wrapper uses,
    so every name except ``lm_head`` is remapped. The remapping is applied
    lazily with ``map`` while streaming, so no intermediate dict of all
    tensors is materialized — this keeps peak memory low during loading.
    """

    def transform(inputs):
        # Remap one (name, tensor) pair; lm_head keeps its original name.
        name, loaded_weight = inputs
        if "lm_head" not in name:
            name = "model." + name
        return name, loaded_weight

    loader = AutoWeightsLoader(
        self,
        skip_prefixes=None,
    )
    # Lazily transform the weight stream instead of building a dict first.
    loader.load_weights(map(transform, weights))

View File

@@ -205,23 +205,21 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> None:
    """Load checkpoint weights into the EAGLE Llama4 draft model.

    Each streamed (name, tensor) pair has its Q/K projection weights
    permuted for the rotary-embedding layout, then every name except
    ``lm_head`` gains the ``model.`` prefix this wrapper uses. The
    transformation is applied lazily via ``map`` so no intermediate
    list/dict of all tensors is materialized, lowering peak memory.
    """

    def transform(inputs):
        name, loaded_weight = inputs
        # Permute Q/K weights to match the rotary embedding layout.
        name, weight = self.permute_qk_weight_for_rotary(
            name, loaded_weight)
        if "lm_head" not in name:
            name = "model." + name
        return name, weight

    loader = AutoWeightsLoader(
        self,
        # lm_head is tied with target model (Llama4ForCausalLM)
        skip_prefixes=(["lm_head."]),
    )
    # Lazily transform the weight stream instead of materializing it.
    loader.load_weights(map(transform, weights))
def get_input_embeddings(
self,

View File

@@ -158,14 +158,15 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
return self.model(input_ids, positions, hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load checkpoint weights into the EAGLE Llama draft model.

    Checkpoint weight names lack the ``model.`` prefix this wrapper uses,
    so every name except ``lm_head`` is remapped. The remapping is applied
    lazily with ``map`` while streaming, so no intermediate dict of all
    tensors is materialized — this keeps peak memory low during loading.
    """

    def transform(inputs):
        # Remap one (name, tensor) pair; lm_head keeps its original name.
        name, loaded_weight = inputs
        if "lm_head" not in name:
            name = "model." + name
        return name, loaded_weight

    loader = AutoWeightsLoader(
        self,
        skip_prefixes=None,
    )
    # Lazily transform the weight stream instead of building a dict first.
    loader.load_weights(map(transform, weights))