kijai 2024-10-26 17:57:02 +03:00
parent ddfb3a6bf2
commit c5c136cb11
2 changed files with 14 additions and 7 deletions


@@ -315,16 +315,17 @@ class T2VSynthMochiModel:
         def model_fn(*, z, sigma, cfg_scale):
             self.dit.to(self.device)
+            if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
+                autocast_dtype = torch.float16
+            else:
+                autocast_dtype = torch.bfloat16
             if batch_cfg:
-                with torch.autocast(mm.get_autocast_device(self.device), dtype=torch.bfloat16):
+                with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
                     out = self.dit(z, sigma, **sample_batched)
                     out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
             else:
                 nonlocal sample, sample_null
-                if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
-                    autocast_dtype = torch.float16
-                else:
-                    autocast_dtype = torch.bfloat16
                 with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
                     out_cond = self.dit(z, sigma, **sample)
                     out_uncond = self.dit(z, sigma, **sample_null)
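
The hunk above hoists the autocast dtype selection out of the else branch so that both the batched-CFG path and the separate cond/uncond path run in float16 whenever the DiT was loaded with cuBLAS half-precision matmul ops; previously the batched path was hard-coded to bfloat16. Below is a minimal, runnable sketch of that pattern, not taken from the repository: DummyDiT, its cublas_half_matmul attribute, and run_model are hypothetical stand-ins.

import torch
import torch.nn as nn

class DummyDiT(nn.Module):
    """Hypothetical stand-in for the DiT; in the real code the flag is set by the quantized loader."""
    def __init__(self, cublas_half_matmul=False):
        super().__init__()
        self.cublas_half_matmul = cublas_half_matmul
        self.proj = nn.Linear(16, 16)

    def forward(self, z, sigma):
        return self.proj(z) * sigma

def run_model(dit, z, sigma, device_type="cpu"):
    # Same check as the diff: float16 only when cuBLAS half-matmul ops are active,
    # otherwise fall back to bfloat16.
    if hasattr(dit, "cublas_half_matmul") and dit.cublas_half_matmul:
        autocast_dtype = torch.float16
    else:
        autocast_dtype = torch.bfloat16
    with torch.autocast(device_type, dtype=autocast_dtype):
        return dit(z, sigma)

if __name__ == "__main__":
    out = run_model(DummyDiT(cublas_half_matmul=False), torch.randn(2, 16), 0.5)
    print(out.shape, out.dtype)  # torch.Size([2, 16]) torch.bfloat16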


@@ -52,8 +52,14 @@ def quantize_load_state_dict(model, state_dict, device="cpu", cublas_ops=False):
     model.to_empty(device=device)
     model.load_state_dict(state_dict, strict=False)
-    if linear_ops == cublas_half_matmul:
-        setattr(model, "cublas_half_matmul", True)
+    try:
+        if linear_ops == cublas_half_matmul:
+            setattr(model, "cublas_half_matmul", True)
+        else:
+            setattr(model, "cublas_half_matmul", False)
+    except:
+        setattr(model, "cublas_half_matmul", False)
+        pass
     return model
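
The second hunk wraps the flag assignment in try/except so that a comparison against a name that was never bound in this call (for example when the cuBLAS ops path is not taken) degrades to cublas_half_matmul = False instead of raising. A minimal sketch of that guard, not taken from the repository; fake_loader and use_cublas are hypothetical.

import torch.nn as nn

def fake_loader(model, use_cublas):
    if use_cublas:
        # In the real loader these would be the selected linear op classes.
        cublas_half_matmul = nn.Linear
        linear_ops = nn.Linear
    # When use_cublas is False, neither local name exists, so the comparison
    # below raises NameError and the except branch applies the safe default.
    try:
        if linear_ops == cublas_half_matmul:
            setattr(model, "cublas_half_matmul", True)
        else:
            setattr(model, "cublas_half_matmul", False)
    except Exception:
        setattr(model, "cublas_half_matmul", False)
    return model

if __name__ == "__main__":
    print(fake_loader(nn.Module(), use_cublas=True).cublas_half_matmul)   # True
    print(fake_loader(nn.Module(), use_cublas=False).cublas_half_matmul)  # False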