support cublas_ops with GGUF
Pretty big speed boost, at least on a 4090. Requires this to be installed: https://github.com/aredden/torch-cublas-hgemm
parent 0d15c0bd69
commit f29f739707
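In short: when the new cublas_ops option is enabled, the GGUF loader binds each quantized linear layer to cublas_half_matmul from the package above instead of torch.nn.functional.linear, and the sampling autocast switches from bfloat16 to float16. Below is a minimal sketch of that backend selection, assuming only what the diff shows; the helper name pick_linear_backend is hypothetical, since the real logic lives inline in quantize_load_state_dict.

import torch.nn.functional as F

def pick_linear_backend(use_cublas_ops: bool):
    # Hypothetical helper mirroring the selection quantize_load_state_dict performs.
    if use_cublas_ops:
        try:
            # Provided by https://github.com/aredden/torch-cublas-hgemm
            from cublas_ops import cublas_half_matmul
        except ImportError as e:
            raise ImportError(
                "Install cublas_ops (https://github.com/aredden/torch-cublas-hgemm) to use cublas_ops"
            ) from e
        return cublas_half_matmul
    # Default path: plain PyTorch linear on the dequantized GGUF weight
    return F.linear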
@@ -15,6 +15,7 @@ from .layers import (
    RMSNorm,
    TimestepEmbedder,
)

from .mod_rmsnorm import modulated_rmsnorm
from .residual_tanh_gated_rmsnorm import (
    residual_tanh_gated_rmsnorm,
@@ -140,7 +141,7 @@ class AsymmetricAttention(nn.Module):

        # Process visual features
        qkv_x = self.qkv_x(x)  # (B, M, 3 * dim_x)
        assert qkv_x.dtype == torch.bfloat16
        #assert qkv_x.dtype == torch.bfloat16
        qkv_x = all_to_all_collect_tokens(
            qkv_x, self.num_heads
        )  # (3, B, N, local_h, head_dim)
@@ -19,7 +19,7 @@ def apply_rotary_emb_qk_real(
    Returns:
        torch.Tensor: The input tensor with rotary embeddings applied.
    """
    assert xqk.dtype == torch.bfloat16
    #assert xqk.dtype == torch.bfloat16
    # Split the last dimension into even and odd parts
    xqk_even = xqk[..., 0::2]
    xqk_odd = xqk[..., 1::2]
@@ -30,5 +30,5 @@ def apply_rotary_emb_qk_real(

    # Interleave the results back into the original shape
    out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
    assert out.dtype == torch.bfloat16
    #assert out.dtype == torch.bfloat16
    return out
@@ -1,7 +1,7 @@
import json
from typing import Dict, List, Optional, Union

#temporary patch to fix bug in Windows
#temporary patch to fix torch compile bug in Windows
def patched_write_atomic(
    path_: str,
    content: Union[str, bytes],
@@ -24,7 +24,7 @@ def patched_write_atomic(
    write_mode = "w" if isinstance(content, str) else "wb"
    with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f:
        f.write(content)
    shutil.copy2(src=tmp_path, dst=path) #to allow overwriting cache files
    shutil.copy2(src=tmp_path, dst=path) #changed to allow overwriting cache files
    os.remove(tmp_path)
try:
    import torch._inductor.codecache
@@ -37,7 +37,7 @@ import torch.nn.functional as F
import torch.utils.data
from einops import rearrange, repeat

from .dit.joint_model.context_parallel import get_cp_rank_size
#from .dit.joint_model.context_parallel import get_cp_rank_size
from tqdm import tqdm
from comfy.utils import ProgressBar, load_torch_file
import comfy.model_management as mm
@@ -133,9 +133,11 @@ class T2VSynthMochiModel:
        fp8_fastmode: bool = False,
        attention_mode: str = "sdpa",
        compile_args: Optional[Dict] = None,
        cublas_ops: Optional[bool] = False,
    ):
        super().__init__()
        self.device = device
        self.weight_dtype = weight_dtype
        self.offload_device = offload_device

        logging.info("Initializing model...")
@@ -170,7 +172,7 @@ class T2VSynthMochiModel:
            import importlib
            importlib.reload(mz_gguf_loader)
            with mz_gguf_loader.quantize_lazy_load():
                model = mz_gguf_loader.quantize_load_state_dict(model, dit_sd, device="cpu")
                model = mz_gguf_loader.quantize_load_state_dict(model, dit_sd, device="cpu", cublas_ops=cublas_ops)
        elif is_accelerate_available:
            logging.info("Using accelerate to load and assign model weights to device...")
            for name, param in model.named_parameters():
@@ -207,51 +209,6 @@ class T2VSynthMochiModel:
        self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
        self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)

    def get_conditioning(self, prompts, *, zero_last_n_prompts: int):
        B = len(prompts)
        assert (
            0 <= zero_last_n_prompts <= B
        ), f"zero_last_n_prompts should be between 0 and {B}, got {zero_last_n_prompts}"
        tokenize_kwargs = dict(
            prompt=prompts,
            padding="max_length",
            return_tensors="pt",
            truncation=True,
        )

        t5_toks = self.t5_tokenizer(**tokenize_kwargs, max_length=MAX_T5_TOKEN_LENGTH)
        caption_input_ids_t5 = t5_toks["input_ids"]
        caption_attention_mask_t5 = t5_toks["attention_mask"].bool()
        del t5_toks

        assert caption_input_ids_t5.shape == (B, MAX_T5_TOKEN_LENGTH)
        assert caption_attention_mask_t5.shape == (B, MAX_T5_TOKEN_LENGTH)

        if zero_last_n_prompts > 0:
            # Zero the last N prompts
            caption_input_ids_t5[-zero_last_n_prompts:] = 0
            caption_attention_mask_t5[-zero_last_n_prompts:] = False

        caption_input_ids_t5 = caption_input_ids_t5.to(self.device, non_blocking=True)
        caption_attention_mask_t5 = caption_attention_mask_t5.to(
            self.device, non_blocking=True
        )

        y_mask = [caption_attention_mask_t5]
        y_feat = []

        self.t5_enc.to(self.device)
        y_feat.append(
            self.t5_enc(
                caption_input_ids_t5, caption_attention_mask_t5
            ).last_hidden_state.detach().to(torch.float32)
        )
        self.t5_enc.to(self.offload_device)
        # Sometimes returns a tensor, othertimes a tuple, not sure why
        # See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
        assert tuple(y_feat[-1].shape) == (B, MAX_T5_TOKEN_LENGTH, 4096)
        return dict(y_mask=y_mask, y_feat=y_feat)

    def get_packed_indices(self, y_mask, *, lT, lW, lH):
        patch_size = 2
        N = lT * lH * lW // (patch_size**2)
@@ -364,7 +321,11 @@ class T2VSynthMochiModel:
                out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
            else:
                nonlocal sample, sample_null
                with torch.autocast(mm.get_autocast_device(self.device), dtype=torch.bfloat16):
                if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
                    autocast_dtype = torch.float16
                else:
                    autocast_dtype = torch.bfloat16
                with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
                    out_cond = self.dit(z, sigma, **sample)
                    out_uncond = self.dit(z, sigma, **sample_null)

@@ -390,11 +351,11 @@ class T2VSynthMochiModel:
            z = z + dsigma * pred
            comfy_pbar.update(1)

        cp_rank, cp_size = get_cp_rank_size()
        #cp_rank, cp_size = get_cp_rank_size()
        if batch_cfg:
            z = z[:B]
        z = z.tensor_split(cp_size, dim=2)[cp_rank] # split along temporal dim
        self.dit.to(self.offload_device)
        #z = z.tensor_split(cp_size, dim=2)[cp_rank] # split along temporal dim
        self.dit.to(self.offload_device, non_blocking=True)

        samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
        logging.info(f"samples shape: {samples.shape}")
@@ -2,7 +2,8 @@

import torch
import torch.nn as nn
import gc
import torch.nn.functional as F


class quantize_lazy_load():
@@ -18,7 +19,17 @@
        self.device.__exit__(exc_type, exc_value, traceback)


def quantize_load_state_dict(model, state_dict, device="cpu"):
def quantize_load_state_dict(model, state_dict, device="cpu", cublas_ops=False):
    if cublas_ops:
        try:
            from cublas_ops import cublas_half_matmul
            linear_ops = cublas_half_matmul
            print("Using cublas_ops")
        except:
            raise ImportError("Install cublas_ops (https://github.com/aredden/torch-cublas-hgemm) to use cublas_ops")
    else:
        linear_ops = F.linear
        pass
    quant_keys = []
    for key in state_dict.keys():
        if key.endswith(".Q4_0_qweight"):
@@ -35,12 +46,14 @@ def quantize_load_state_dict(model, state_dict, device="cpu"):
            linear=module,
            device=device,
            qtype=qtype,
            linear_ops=linear_ops
        )
        set_op_by_name(model, name, q_linear)

    model.to_empty(device=device)
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    if linear_ops == cublas_half_matmul:
        setattr(model, "cublas_half_matmul", True)
    return model
@@ -57,19 +70,16 @@ def set_op_by_name(layer, name, new_module):
    else:
        setattr(layer, name, new_module)


import torch.nn.functional as F


class WQLinear_GGUF(nn.Module):
    def __init__(
        self, in_features, out_features, bias, dev, qtype="Q4_0"
        self, in_features, out_features, bias, dev, qtype, linear_ops
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.qtype = qtype
        self.linear_ops = linear_ops

        qweight_shape = quant_shape_to_byte_shape(
            (out_features, in_features), qtype
@@ -99,6 +109,7 @@ class WQLinear_GGUF(nn.Module):
        cls, linear,
        device="cpu",
        qtype="Q4_0",
        linear_ops=F.linear
    ):
        q_linear = cls(
            linear.in_features,
@@ -106,6 +117,7 @@ class WQLinear_GGUF(nn.Module):
            linear.bias is not None,
            device,
            qtype=qtype,
            linear_ops=linear_ops
        )
        return q_linear

@@ -120,6 +132,7 @@ class WQLinear_GGUF(nn.Module):
            )
        )


    @torch.no_grad()
    def forward(self, x):
        # x = torch.matmul(x, dequantize_blocks_Q4_0(self.qweight))
@@ -127,8 +140,11 @@ class WQLinear_GGUF(nn.Module):
            x = F.linear(x, dequantize_blocks_Q4_0(
                self.Q4_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
        elif self.qtype == "Q8_0":
            x = F.linear(x, dequantize_blocks_Q8_0(
                self.Q8_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
            dequant = dequantize_blocks_Q8_0(self.Q8_0_qweight, x.dtype)

            #x = F.linear(x, dequant, self.bias.to(x.dtype) if self.bias is not None else None)
            x = self.linear_ops(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)

        else:
            raise ValueError(f"Unknown qtype: {self.qtype}")
nodes.py
@@ -72,6 +72,7 @@ class DownloadAndLoadMochiModel:
            "optional": {
                "trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
                "compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
                "cublas_ops": ("BOOLEAN", {"tooltip": "tested on 4090, unsure of gpu requirements, enables faster linear ops from'https://github.com/aredden/torch-cublas-hgemm'",}),
            },
        }
@@ -81,7 +82,7 @@ class DownloadAndLoadMochiModel:
    CATEGORY = "MochiWrapper"
    DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface"

    def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None):
    def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False):

        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
@@ -126,7 +127,8 @@ class DownloadAndLoadMochiModel:
            weight_dtype=dtype,
            fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
            attention_mode=attention_mode,
            compile_args=compile_args
            compile_args=compile_args,
            cublas_ops=cublas_ops
        )
        with (init_empty_weights() if is_accelerate_available else nullcontext()):
            vae = Decoder(
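Usage note: enabling the cublas_ops toggle on the DownloadAndLoadMochiModel node passes cublas_ops=True through loadmodel into T2VSynthMochiModel, which loads the GGUF weights with quantize_load_state_dict(..., cublas_ops=True). The quantized linears then run their dequantized matmuls through cublas_half_matmul, and the sampler autocasts to float16 instead of bfloat16 (keyed off the cublas_half_matmul attribute set on the model).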