mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-08 20:34:23 +08:00
correct Tora fuser dtype
I think...
This commit is contained in:
parent
15bf8f51ab
commit
2cc521062f
@ -291,8 +291,7 @@ class CogVideoXBlock(nn.Module):
|
||||
if video_flow_feature is not None:
|
||||
H, W = video_flow_feature.shape[-2:]
|
||||
T = norm_hidden_states.shape[1] // H // W
|
||||
|
||||
h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W).to(torch.float16)
|
||||
h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
|
||||
h = fuser(h, video_flow_feature.to(h), T=T)
|
||||
norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
|
||||
del h, fuser
|
||||
|
||||
2
nodes.py
2
nodes.py
@ -672,7 +672,7 @@ class DownloadAndLoadToraModel:
|
||||
fuser_list.load_state_dict(fuser_sd)
|
||||
for module in fuser_list:
|
||||
for param in module.parameters():
|
||||
param.data = param.data.to(torch.float16).to(device)
|
||||
param.data = param.data.to(torch.bfloat16).to(device)
|
||||
del fuser_sd
|
||||
|
||||
traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user