fp32 gate bias

This commit is contained in:
Xingkai Yu 2025-08-26 17:39:07 +08:00 committed by GitHub
parent f6e34dd267
commit 4592be48c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -558,7 +558,7 @@ class Gate(nn.Module):
self.score_func = args.score_func
self.route_scale = args.route_scale
self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
- self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None
+ self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32)) if self.dim == 7168 else None
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""