From 2cc521062f035920b899bf6afce5cea7cbf283bb Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 22 Oct 2024 18:40:31 +0300 Subject: [PATCH] correct Tora fuser dtype I think... --- custom_cogvideox_transformer_3d.py | 3 +-- nodes.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py index daac903..7f0cf3b 100644 --- a/custom_cogvideox_transformer_3d.py +++ b/custom_cogvideox_transformer_3d.py @@ -291,8 +291,7 @@ class CogVideoXBlock(nn.Module): if video_flow_feature is not None: H, W = video_flow_feature.shape[-2:] T = norm_hidden_states.shape[1] // H // W - - h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W).to(torch.float16) + h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W) h = fuser(h, video_flow_feature.to(h), T=T) norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T) del h, fuser diff --git a/nodes.py b/nodes.py index 2d90617..b14ecd4 100644 --- a/nodes.py +++ b/nodes.py @@ -672,7 +672,7 @@ class DownloadAndLoadToraModel: fuser_list.load_state_dict(fuser_sd) for module in fuser_list: for param in module.parameters(): - param.data = param.data.to(torch.float16).to(device) + param.data = param.data.to(torch.bfloat16).to(device) del fuser_sd traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors")