Add bf16 model

This commit is contained in:
kijai 2024-10-24 00:00:38 +03:00
parent 754085eaf5
commit 1cd5409295
2 changed files with 2 additions and 8 deletions

View File

@@ -127,7 +127,6 @@ class AsymmetricAttention(nn.Module):
):
# Pre-norm for visual features
x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size
#print("x in attn", x.dtype, x.device)
# Process visual features
qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x)
@@ -139,13 +138,6 @@ class AsymmetricAttention(nn.Module):
# Process text features
y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
q_y, k_y, v_y = self.run_qkv_y(y) # (B, L, local_heads, head_dim)
#print("y in attn", y.dtype, y.device)
#print(q_y.dtype, q_y.device)
#print(self.q_norm_y.weight.dtype, self.q_norm_y.weight.device)
# self.q_norm_y.weight = self.q_norm_y.weight.to(q_y.dtype)
# self.q_norm_y.bias = self.q_norm_y.bias.to(q_y.dtype)
# self.k_norm_y.weight = self.k_norm_y.weight.to(k_y.dtype)
# self.k_norm_y.bias = self.k_norm_y.bias.to(k_y.dtype)
q_y = self.q_norm_y(q_y)
k_y = self.k_norm_y(k_y)

View File

@@ -48,6 +48,8 @@ class DownloadAndLoadMochiModel:
"model": (
[
"mochi_preview_dit_fp8_e4m3fn.safetensors",
"mochi_preview_dit_bf16.safetensors",
],
{"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/diffusion_models/mochi'", },
),