mirror of
https://git.datalinker.icu/deepseek-ai/DeepSeek-V3.git
synced 2025-12-10 05:15:02 +08:00
Add dtype=torch.float32
This commit is contained in:
parent
f6e34dd267
commit
d6d7cc9860
@ -185,7 +185,7 @@ class Linear(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.register_parameter("scale", None)
|
self.register_parameter("scale", None)
|
||||||
if bias:
|
if bias:
|
||||||
self.bias = nn.Parameter(torch.empty(out_features))
|
self.bias = nn.Parameter(torch.empty(out_features, dtype=torch.float32))
|
||||||
else:
|
else:
|
||||||
self.register_parameter("bias", None)
|
self.register_parameter("bias", None)
|
||||||
|
|
||||||
@ -558,7 +558,7 @@ class Gate(nn.Module):
|
|||||||
self.score_func = args.score_func
|
self.score_func = args.score_func
|
||||||
self.route_scale = args.route_scale
|
self.route_scale = args.route_scale
|
||||||
self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
|
self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
|
||||||
self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None
|
self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32)) if self.dim == 7168 else None
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user