[Bugfix] Fix argument order in compute_ffn_output calls (hidden_states before layer index)

This commit is contained in:
i-yuanyukun 2025-12-18 17:03:15 +08:00
parent 26ddfa299c
commit 8276320a8a

View File

@ -351,7 +351,7 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin):
hidden_states, dim=0
)
ffn_output = self.model.compute_ffn_output(
current_layer_idx, gathered_hidden_states
gathered_hidden_states, current_layer_idx
)
# Extract the output corresponding to current rank
@ -361,7 +361,7 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin):
else:
# Single TP case
rank_ffn_output = self.model.compute_ffn_output(
current_layer_idx, hidden_states
hidden_states, current_layer_idx
)
return rank_ffn_output