correct Tora fuser dtype

I think...
This commit is contained in:
kijai 2024-10-22 18:40:31 +03:00
parent 15bf8f51ab
commit 2cc521062f
2 changed files with 2 additions and 3 deletions

View File

@ -291,8 +291,7 @@ class CogVideoXBlock(nn.Module):
if video_flow_feature is not None:
H, W = video_flow_feature.shape[-2:]
T = norm_hidden_states.shape[1] // H // W
h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W).to(torch.float16)
h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
h = fuser(h, video_flow_feature.to(h), T=T)
norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
del h, fuser

View File

@ -672,7 +672,7 @@ class DownloadAndLoadToraModel:
fuser_list.load_state_dict(fuser_sd)
for module in fuser_list:
for param in module.parameters():
param.data = param.data.to(torch.float16).to(device)
param.data = param.data.to(torch.bfloat16).to(device)
del fuser_sd
traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors")