mirror of
https://git.datalinker.icu/deepseek-ai/DeepSeek-V3.git
synced 2025-12-08 20:34:32 +08:00
Add dtype=torch.float32 to the bias parameters of Linear and Gate
This commit is contained in:
parent
f6e34dd267
commit
d6d7cc9860
@@ -185,7 +185,7 @@ class Linear(nn.Module):
|
||||
else:
|
||||
self.register_parameter("scale", None)
|
||||
if bias:
|
||||
- self.bias = nn.Parameter(torch.empty(out_features))
|
||||
+ self.bias = nn.Parameter(torch.empty(out_features, dtype=torch.float32))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
@@ -558,7 +558,7 @@ class Gate(nn.Module):
|
||||
self.score_func = args.score_func
|
||||
self.route_scale = args.route_scale
|
||||
self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
|
||||
- self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None
|
||||
+ self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32)) if self.dim == 7168 else None
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user