diff --git a/inference/model.py b/inference/model.py index c143e97..7539a68 100644 --- a/inference/model.py +++ b/inference/model.py @@ -558,7 +558,7 @@ class Gate(nn.Module): self.score_func = args.score_func self.route_scale = args.route_scale self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim)) - self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None + self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32)) if self.dim == 7168 else None def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """