Add dtype=torch.float32

Tri Dao 2025-08-25 11:47:09 -07:00
parent f6e34dd267
commit d6d7cc9860


@@ -185,7 +185,7 @@ class Linear(nn.Module):
         else:
             self.register_parameter("scale", None)
         if bias:
-            self.bias = nn.Parameter(torch.empty(out_features))
+            self.bias = nn.Parameter(torch.empty(out_features, dtype=torch.float32))
         else:
             self.register_parameter("bias", None)
@@ -558,7 +558,7 @@ class Gate(nn.Module):
         self.score_func = args.score_func
         self.route_scale = args.route_scale
         self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
-        self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None
+        self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32)) if self.dim == 7168 else None

     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """