From d6d7cc986094939fe3f71207d169fc721b0701e3 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 25 Aug 2025 11:47:09 -0700 Subject: [PATCH] Add dtype=torch.float32 --- inference/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/inference/model.py b/inference/model.py index c143e97..aaaed5e 100644 --- a/inference/model.py +++ b/inference/model.py @@ -185,7 +185,7 @@ class Linear(nn.Module): else: self.register_parameter("scale", None) if bias: - self.bias = nn.Parameter(torch.empty(out_features)) + self.bias = nn.Parameter(torch.empty(out_features, dtype=torch.float32)) else: self.register_parameter("bias", None) @@ -558,7 +558,7 @@ class Gate(nn.Module): self.score_func = args.score_func self.route_scale = args.route_scale self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim)) - self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None + self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32)) if self.dim == 7168 else None def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """