diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index bf592bb..61b632f 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -315,16 +315,17 @@ class T2VSynthMochiModel:
 
         def model_fn(*, z, sigma, cfg_scale):
             self.dit.to(self.device)
+            if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
+                autocast_dtype = torch.float16
+            else:
+                autocast_dtype = torch.bfloat16
             if batch_cfg:
-                with torch.autocast(mm.get_autocast_device(self.device), dtype=torch.bfloat16):
+                with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
                     out = self.dit(z, sigma, **sample_batched)
                 out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
             else:
                 nonlocal sample, sample_null
-                if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
-                    autocast_dtype = torch.float16
-                else:
-                    autocast_dtype = torch.bfloat16
+
                 with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
                     out_cond = self.dit(z, sigma, **sample)
                     out_uncond = self.dit(z, sigma, **sample_null)
diff --git a/mz_gguf_loader.py b/mz_gguf_loader.py
index 1b66c18..8182896 100644
--- a/mz_gguf_loader.py
+++ b/mz_gguf_loader.py
@@ -52,8 +52,14 @@ def quantize_load_state_dict(model, state_dict, device="cpu", cublas_ops=False):
 
     model.to_empty(device=device)
     model.load_state_dict(state_dict, strict=False)
-    if linear_ops == cublas_half_matmul:
-        setattr(model, "cublas_half_matmul", True)
+    try:
+        if linear_ops == cublas_half_matmul:
+            setattr(model, "cublas_half_matmul", True)
+        else:
+            setattr(model, "cublas_half_matmul", False)
+    except NameError:
+        # cublas_half_matmul was never imported; default to the bfloat16 path.
+        setattr(model, "cublas_half_matmul", False)
     return model