mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-23 03:44:26 +08:00
correct Tora fuser dtype
I think...
This commit is contained in:
parent
15bf8f51ab
commit
2cc521062f
@ -291,8 +291,7 @@ class CogVideoXBlock(nn.Module):
|
|||||||
if video_flow_feature is not None:
|
if video_flow_feature is not None:
|
||||||
H, W = video_flow_feature.shape[-2:]
|
H, W = video_flow_feature.shape[-2:]
|
||||||
T = norm_hidden_states.shape[1] // H // W
|
T = norm_hidden_states.shape[1] // H // W
|
||||||
|
h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
|
||||||
h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W).to(torch.float16)
|
|
||||||
h = fuser(h, video_flow_feature.to(h), T=T)
|
h = fuser(h, video_flow_feature.to(h), T=T)
|
||||||
norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
|
norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
|
||||||
del h, fuser
|
del h, fuser
|
||||||
|
|||||||
2
nodes.py
2
nodes.py
@ -672,7 +672,7 @@ class DownloadAndLoadToraModel:
|
|||||||
fuser_list.load_state_dict(fuser_sd)
|
fuser_list.load_state_dict(fuser_sd)
|
||||||
for module in fuser_list:
|
for module in fuser_list:
|
||||||
for param in module.parameters():
|
for param in module.parameters():
|
||||||
param.data = param.data.to(torch.float16).to(device)
|
param.data = param.data.to(torch.bfloat16).to(device)
|
||||||
del fuser_sd
|
del fuser_sd
|
||||||
|
|
||||||
traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors")
|
traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user