Add dtype=torch.float32

Tri Dao 2025-08-25 11:47:09 -07:00
parent f6e34dd267
commit d6d7cc9860


@@ -185,7 +185,7 @@ class Linear(nn.Module):
         else:
             self.register_parameter("scale", None)
         if bias:
-            self.bias = nn.Parameter(torch.empty(out_features))
+            self.bias = nn.Parameter(torch.empty(out_features, dtype=torch.float32))
         else:
             self.register_parameter("bias", None)
@@ -558,7 +558,7 @@ class Gate(nn.Module):
         self.score_func = args.score_func
         self.route_scale = args.route_scale
         self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
-        self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None
+        self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32)) if self.dim == 7168 else None

     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """