[Bugfix] compute ffn output param order

This commit is contained in:
i-yuanyukun 2025-12-18 17:03:15 +08:00
parent 26ddfa299c
commit 8276320a8a

View File

@ -351,7 +351,7 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin):
hidden_states, dim=0 hidden_states, dim=0
) )
ffn_output = self.model.compute_ffn_output( ffn_output = self.model.compute_ffn_output(
current_layer_idx, gathered_hidden_states gathered_hidden_states, current_layer_idx
) )
# Extract the output corresponding to current rank # Extract the output corresponding to current rank
@ -361,7 +361,7 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin):
else: else:
# Single TP case # Single TP case
rank_ffn_output = self.model.compute_ffn_output( rank_ffn_output = self.model.compute_ffn_output(
current_layer_idx, hidden_states hidden_states, current_layer_idx
) )
return rank_ffn_output return rank_ffn_output