support cublas_ops with GGUF

Pretty big speed boost, on a 4090 at least. Needs this installed:
https://github.com/aredden/torch-cublas-hgemm
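
For reference, a minimal sketch of probing the optional dependency before flipping the new cublas_ops toggle; it mirrors the import-and-fallback pattern this commit adds in mz_gguf_loader.py. The call at the end assumes cublas_half_matmul takes F.linear-style arguments (input, weight, bias), which is how the new forward() code invokes it:

import torch
import torch.nn.functional as F

# Probe the optional fp16 GEMM backend from torch-cublas-hgemm; fall back to the stock op.
try:
    from cublas_ops import cublas_half_matmul
    linear_ops = cublas_half_matmul
    print("cublas_ops available, the fp16 hgemm path can be enabled")
except ImportError:
    linear_ops = F.linear
    print("cublas_ops not installed, falling back to torch.nn.functional.linear")

if torch.cuda.is_available():
    x = torch.randn(2, 4096, device="cuda", dtype=torch.float16)
    # square weight so the output shape is the same regardless of transpose convention
    w = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    # assumption: the backend accepts F.linear-style arguments, as the commit's dispatch does
    y = linear_ops(x, w, bias=None)
    print(y.shape)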
kijai 2024-10-26 16:42:25 +03:00
parent 0d15c0bd69
commit f29f739707
5 changed files with 48 additions and 68 deletions

View File

@ -15,6 +15,7 @@ from .layers import (
RMSNorm,
TimestepEmbedder,
)
from .mod_rmsnorm import modulated_rmsnorm
from .residual_tanh_gated_rmsnorm import (
residual_tanh_gated_rmsnorm,
@ -140,7 +141,7 @@ class AsymmetricAttention(nn.Module):
# Process visual features
qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x)
assert qkv_x.dtype == torch.bfloat16
#assert qkv_x.dtype == torch.bfloat16
qkv_x = all_to_all_collect_tokens(
qkv_x, self.num_heads
) # (3, B, N, local_h, head_dim)

View File

@ -19,7 +19,7 @@ def apply_rotary_emb_qk_real(
Returns:
torch.Tensor: The input tensor with rotary embeddings applied.
"""
assert xqk.dtype == torch.bfloat16
#assert xqk.dtype == torch.bfloat16
# Split the last dimension into even and odd parts
xqk_even = xqk[..., 0::2]
xqk_odd = xqk[..., 1::2]
@ -30,5 +30,5 @@ def apply_rotary_emb_qk_real(
# Interleave the results back into the original shape
out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
assert out.dtype == torch.bfloat16
#assert out.dtype == torch.bfloat16
return out

View File

@ -1,7 +1,7 @@
import json
from typing import Dict, List, Optional, Union
#temporary patch to fix a bug on Windows
#temporary patch to fix a torch.compile bug on Windows
def patched_write_atomic(
path_: str,
content: Union[str, bytes],
@ -24,7 +24,7 @@ def patched_write_atomic(
write_mode = "w" if isinstance(content, str) else "wb"
with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f:
f.write(content)
shutil.copy2(src=tmp_path, dst=path) #to allow overwriting cache files
shutil.copy2(src=tmp_path, dst=path) #changed to allow overwriting cache files
os.remove(tmp_path)
try:
import torch._inductor.codecache
@ -37,7 +37,7 @@ import torch.nn.functional as F
import torch.utils.data
from einops import rearrange, repeat
from .dit.joint_model.context_parallel import get_cp_rank_size
#from .dit.joint_model.context_parallel import get_cp_rank_size
from tqdm import tqdm
from comfy.utils import ProgressBar, load_torch_file
import comfy.model_management as mm
@ -133,9 +133,11 @@ class T2VSynthMochiModel:
fp8_fastmode: bool = False,
attention_mode: str = "sdpa",
compile_args: Optional[Dict] = None,
cublas_ops: bool = False,
):
super().__init__()
self.device = device
self.weight_dtype = weight_dtype
self.offload_device = offload_device
logging.info("Initializing model...")
@ -170,7 +172,7 @@ class T2VSynthMochiModel:
import importlib
importlib.reload(mz_gguf_loader)
with mz_gguf_loader.quantize_lazy_load():
model = mz_gguf_loader.quantize_load_state_dict(model, dit_sd, device="cpu")
model = mz_gguf_loader.quantize_load_state_dict(model, dit_sd, device="cpu", cublas_ops=cublas_ops)
elif is_accelerate_available:
logging.info("Using accelerate to load and assign model weights to device...")
for name, param in model.named_parameters():
@ -207,51 +209,6 @@ class T2VSynthMochiModel:
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
def get_conditioning(self, prompts, *, zero_last_n_prompts: int):
B = len(prompts)
assert (
0 <= zero_last_n_prompts <= B
), f"zero_last_n_prompts should be between 0 and {B}, got {zero_last_n_prompts}"
tokenize_kwargs = dict(
prompt=prompts,
padding="max_length",
return_tensors="pt",
truncation=True,
)
t5_toks = self.t5_tokenizer(**tokenize_kwargs, max_length=MAX_T5_TOKEN_LENGTH)
caption_input_ids_t5 = t5_toks["input_ids"]
caption_attention_mask_t5 = t5_toks["attention_mask"].bool()
del t5_toks
assert caption_input_ids_t5.shape == (B, MAX_T5_TOKEN_LENGTH)
assert caption_attention_mask_t5.shape == (B, MAX_T5_TOKEN_LENGTH)
if zero_last_n_prompts > 0:
# Zero the last N prompts
caption_input_ids_t5[-zero_last_n_prompts:] = 0
caption_attention_mask_t5[-zero_last_n_prompts:] = False
caption_input_ids_t5 = caption_input_ids_t5.to(self.device, non_blocking=True)
caption_attention_mask_t5 = caption_attention_mask_t5.to(
self.device, non_blocking=True
)
y_mask = [caption_attention_mask_t5]
y_feat = []
self.t5_enc.to(self.device)
y_feat.append(
self.t5_enc(
caption_input_ids_t5, caption_attention_mask_t5
).last_hidden_state.detach().to(torch.float32)
)
self.t5_enc.to(self.offload_device)
# Sometimes returns a tensor, other times a tuple, not sure why
# See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
assert tuple(y_feat[-1].shape) == (B, MAX_T5_TOKEN_LENGTH, 4096)
return dict(y_mask=y_mask, y_feat=y_feat)
def get_packed_indices(self, y_mask, *, lT, lW, lH):
patch_size = 2
N = lT * lH * lW // (patch_size**2)
@ -364,7 +321,11 @@ class T2VSynthMochiModel:
out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
else:
nonlocal sample, sample_null
with torch.autocast(mm.get_autocast_device(self.device), dtype=torch.bfloat16):
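# the cublas_ops hgemm kernels run in fp16, so autocast must use float16 instead of bfloat16 when that backend is active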
if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
autocast_dtype = torch.float16
else:
autocast_dtype = torch.bfloat16
with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
out_cond = self.dit(z, sigma, **sample)
out_uncond = self.dit(z, sigma, **sample_null)
@ -390,11 +351,11 @@ class T2VSynthMochiModel:
z = z + dsigma * pred
comfy_pbar.update(1)
cp_rank, cp_size = get_cp_rank_size()
#cp_rank, cp_size = get_cp_rank_size()
if batch_cfg:
z = z[:B]
z = z.tensor_split(cp_size, dim=2)[cp_rank] # split along temporal dim
self.dit.to(self.offload_device)
#z = z.tensor_split(cp_size, dim=2)[cp_rank] # split along temporal dim
self.dit.to(self.offload_device, non_blocking=True)
samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
logging.info(f"samples shape: {samples.shape}")

View File

@ -2,7 +2,8 @@
import torch
import torch.nn as nn
import gc
import torch.nn.functional as F
class quantize_lazy_load():
@ -18,7 +19,17 @@ class quantize_lazy_load():
self.device.__exit__(exc_type, exc_value, traceback)
def quantize_load_state_dict(model, state_dict, device="cpu"):
def quantize_load_state_dict(model, state_dict, device="cpu", cublas_ops=False):
if cublas_ops:
try:
from cublas_ops import cublas_half_matmul
linear_ops = cublas_half_matmul
print("Using cublas_ops")
except ImportError:
raise ImportError("Install cublas_ops (https://github.com/aredden/torch-cublas-hgemm) to use cublas_ops")
else:
linear_ops = F.linear
pass
quant_keys = []
for key in state_dict.keys():
if key.endswith(".Q4_0_qweight"):
@ -35,12 +46,14 @@ def quantize_load_state_dict(model, state_dict, device="cpu"):
linear=module,
device=device,
qtype=qtype,
linear_ops=linear_ops
)
set_op_by_name(model, name, q_linear)
model.to_empty(device=device)
model.load_state_dict(state_dict, strict=False)
model.to(device)
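# flag the model so the sampler can switch its autocast dtype to fp16 for the cublas path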
if cublas_ops:
setattr(model, "cublas_half_matmul", True)
return model
@ -57,19 +70,16 @@ def set_op_by_name(layer, name, new_module):
else:
setattr(layer, name, new_module)
import torch.nn.functional as F
class WQLinear_GGUF(nn.Module):
def __init__(
self, in_features, out_features, bias, dev, qtype="Q4_0"
self, in_features, out_features, bias, dev, qtype, linear_ops
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.qtype = qtype
self.linear_ops = linear_ops
qweight_shape = quant_shape_to_byte_shape(
(out_features, in_features), qtype
@ -99,6 +109,7 @@ class WQLinear_GGUF(nn.Module):
cls, linear,
device="cpu",
qtype="Q4_0",
linear_ops=F.linear
):
q_linear = cls(
linear.in_features,
@ -106,6 +117,7 @@ class WQLinear_GGUF(nn.Module):
linear.bias is not None,
device,
qtype=qtype,
linear_ops=linear_ops
)
return q_linear
@ -120,6 +132,7 @@ class WQLinear_GGUF(nn.Module):
)
)
@torch.no_grad()
def forward(self, x):
# x = torch.matmul(x, dequantize_blocks_Q4_0(self.qweight))
@ -127,8 +140,11 @@ class WQLinear_GGUF(nn.Module):
x = F.linear(x, dequantize_blocks_Q4_0(
self.Q4_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
elif self.qtype == "Q8_0":
x = F.linear(x, dequantize_blocks_Q8_0(
self.Q8_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
dequant = dequantize_blocks_Q8_0(self.Q8_0_qweight, x.dtype)
#x = F.linear(x, dequant, self.bias.to(x.dtype) if self.bias is not None else None)
x = self.linear_ops(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)
else:
raise ValueError(f"Unknown qtype: {self.qtype}")

View File

@ -72,6 +72,7 @@ class DownloadAndLoadMochiModel:
"optional": {
"trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
"compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
"cublas_ops": ("BOOLEAN", {"tooltip": "tested on 4090, unsure of gpu requirements, enables faster linear ops from'https://github.com/aredden/torch-cublas-hgemm'",}),
},
}
@ -81,7 +82,7 @@ class DownloadAndLoadMochiModel:
CATEGORY = "MochiWrapper"
DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface"
def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None):
def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@ -126,7 +127,8 @@ class DownloadAndLoadMochiModel:
weight_dtype=dtype,
fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
attention_mode=attention_mode,
compile_args=compile_args
compile_args=compile_args,
cublas_ops=cublas_ops
)
with (init_empty_weights() if is_accelerate_available else nullcontext()):
vae = Decoder(