support cublas_ops with GGUF

Pretty big speed boost, on a 4090 at least; needs this installed:
https://github.com/aredden/torch-cublas-hgemm
kijai 2024-10-26 16:42:25 +03:00
parent 0d15c0bd69
commit f29f739707
5 changed files with 48 additions and 68 deletions
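For context on the dependency: the loader below swaps torch.nn.functional.linear for cublas_half_matmul from torch-cublas-hgemm, calling it with the same (input, weight, bias=...) shape as F.linear. A minimal sketch of that drop-in usage, assuming a CUDA GPU and the package installed (the selection mirrors what the mz_gguf_loader hunk adds further down):

    import torch
    import torch.nn.functional as F

    # Fall back to F.linear when torch-cublas-hgemm is not available; this mirrors
    # the optional import added in mz_gguf_loader below.
    try:
        from cublas_ops import cublas_half_matmul as fast_linear
    except ImportError:
        fast_linear = F.linear

    x = torch.randn(2, 256, 3072, dtype=torch.float16, device="cuda")
    w = torch.randn(3072, 3072, dtype=torch.float16, device="cuda")
    y = fast_linear(x, w, bias=None)  # same call shape as F.linear(x, w, bias)
    print(y.shape)                    # torch.Size([2, 256, 3072])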


@@ -15,6 +15,7 @@ from .layers import (
     RMSNorm,
     TimestepEmbedder,
 )
 from .mod_rmsnorm import modulated_rmsnorm
 from .residual_tanh_gated_rmsnorm import (
     residual_tanh_gated_rmsnorm,
@@ -140,7 +141,7 @@ class AsymmetricAttention(nn.Module):
         # Process visual features
         qkv_x = self.qkv_x(x)  # (B, M, 3 * dim_x)
-        assert qkv_x.dtype == torch.bfloat16
+        #assert qkv_x.dtype == torch.bfloat16
         qkv_x = all_to_all_collect_tokens(
             qkv_x, self.num_heads
         )  # (3, B, N, local_h, head_dim)


@@ -19,7 +19,7 @@ def apply_rotary_emb_qk_real(
     Returns:
         torch.Tensor: The input tensor with rotary embeddings applied.
     """
-    assert xqk.dtype == torch.bfloat16
+    #assert xqk.dtype == torch.bfloat16
     # Split the last dimension into even and odd parts
     xqk_even = xqk[..., 0::2]
     xqk_odd = xqk[..., 1::2]
@@ -30,5 +30,5 @@ def apply_rotary_emb_qk_real(
     # Interleave the results back into the original shape
     out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
-    assert out.dtype == torch.bfloat16
+    #assert out.dtype == torch.bfloat16
     return out
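The bfloat16 asserts here and in the attention hunk above are commented out because activations may now be float16 when the cublas path autocasts to fp16 (see the sampler hunk below). If a sanity check is still wanted, a permissive variant could look like this (a sketch, not part of the commit):

    import torch

    def assert_half_dtype(t: torch.Tensor) -> None:
        # Accept either half-precision dtype instead of hard-requiring bfloat16.
        assert t.dtype in (torch.float16, torch.bfloat16), f"unexpected dtype: {t.dtype}"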


@@ -1,7 +1,7 @@
 import json
 from typing import Dict, List, Optional, Union
-#temporary patch to fix bug in Windows
+#temporary patch to fix torch compile bug in Windows
 def patched_write_atomic(
     path_: str,
     content: Union[str, bytes],
@@ -24,7 +24,7 @@ def patched_write_atomic(
     write_mode = "w" if isinstance(content, str) else "wb"
     with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f:
         f.write(content)
-    shutil.copy2(src=tmp_path, dst=path) #to allow overwriting cache files
+    shutil.copy2(src=tmp_path, dst=path) #changed to allow overwriting cache files
     os.remove(tmp_path)
 try:
     import torch._inductor.codecache
@@ -37,7 +37,7 @@ import torch.nn.functional as F
 import torch.utils.data
 from einops import rearrange, repeat
-from .dit.joint_model.context_parallel import get_cp_rank_size
+#from .dit.joint_model.context_parallel import get_cp_rank_size
 from tqdm import tqdm
 from comfy.utils import ProgressBar, load_torch_file
 import comfy.model_management as mm
@@ -133,9 +133,11 @@ class T2VSynthMochiModel:
         fp8_fastmode: bool = False,
         attention_mode: str = "sdpa",
         compile_args: Optional[Dict] = None,
+        cublas_ops: Optional[bool] = False,
     ):
         super().__init__()
         self.device = device
+        self.weight_dtype = weight_dtype
         self.offload_device = offload_device
         logging.info("Initializing model...")
@@ -170,7 +172,7 @@ class T2VSynthMochiModel:
            import importlib
            importlib.reload(mz_gguf_loader)
            with mz_gguf_loader.quantize_lazy_load():
-               model = mz_gguf_loader.quantize_load_state_dict(model, dit_sd, device="cpu")
+               model = mz_gguf_loader.quantize_load_state_dict(model, dit_sd, device="cpu", cublas_ops=cublas_ops)
        elif is_accelerate_available:
            logging.info("Using accelerate to load and assign model weights to device...")
            for name, param in model.named_parameters():
@@ -207,51 +209,6 @@ class T2VSynthMochiModel:
         self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
         self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
-    def get_conditioning(self, prompts, *, zero_last_n_prompts: int):
-        B = len(prompts)
-        assert (
-            0 <= zero_last_n_prompts <= B
-        ), f"zero_last_n_prompts should be between 0 and {B}, got {zero_last_n_prompts}"
-        tokenize_kwargs = dict(
-            prompt=prompts,
-            padding="max_length",
-            return_tensors="pt",
-            truncation=True,
-        )
-        t5_toks = self.t5_tokenizer(**tokenize_kwargs, max_length=MAX_T5_TOKEN_LENGTH)
-        caption_input_ids_t5 = t5_toks["input_ids"]
-        caption_attention_mask_t5 = t5_toks["attention_mask"].bool()
-        del t5_toks
-        assert caption_input_ids_t5.shape == (B, MAX_T5_TOKEN_LENGTH)
-        assert caption_attention_mask_t5.shape == (B, MAX_T5_TOKEN_LENGTH)
-        if zero_last_n_prompts > 0:
-            # Zero the last N prompts
-            caption_input_ids_t5[-zero_last_n_prompts:] = 0
-            caption_attention_mask_t5[-zero_last_n_prompts:] = False
-        caption_input_ids_t5 = caption_input_ids_t5.to(self.device, non_blocking=True)
-        caption_attention_mask_t5 = caption_attention_mask_t5.to(
-            self.device, non_blocking=True
-        )
-        y_mask = [caption_attention_mask_t5]
-        y_feat = []
-        self.t5_enc.to(self.device)
-        y_feat.append(
-            self.t5_enc(
-                caption_input_ids_t5, caption_attention_mask_t5
-            ).last_hidden_state.detach().to(torch.float32)
-        )
-        self.t5_enc.to(self.offload_device)
-        # Sometimes returns a tensor, othertimes a tuple, not sure why
-        # See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
-        assert tuple(y_feat[-1].shape) == (B, MAX_T5_TOKEN_LENGTH, 4096)
-        return dict(y_mask=y_mask, y_feat=y_feat)
     def get_packed_indices(self, y_mask, *, lT, lW, lH):
         patch_size = 2
         N = lT * lH * lW // (patch_size**2)
@@ -364,7 +321,11 @@ class T2VSynthMochiModel:
                out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
            else:
                nonlocal sample, sample_null
-               with torch.autocast(mm.get_autocast_device(self.device), dtype=torch.bfloat16):
+               if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
+                   autocast_dtype = torch.float16
+               else:
+                   autocast_dtype = torch.bfloat16
+               with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
                    out_cond = self.dit(z, sigma, **sample)
                    out_uncond = self.dit(z, sigma, **sample_null)
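The change above picks the autocast dtype from a flag the GGUF loader sets on the model: fp16 when the cublas kernels are in use, bfloat16 otherwise. The same check as a small helper, assuming only the `cublas_half_matmul` attribute that `quantize_load_state_dict` attaches below:

    import torch

    def pick_autocast_dtype(dit) -> torch.dtype:
        # quantize_load_state_dict() sets `cublas_half_matmul = True` on the model
        # when cublas_ops is enabled; those kernels run in float16.
        if getattr(dit, "cublas_half_matmul", False):
            return torch.float16
        return torch.bfloat16

    # usage: with torch.autocast("cuda", dtype=pick_autocast_dtype(dit)): ...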
@@ -390,11 +351,11 @@ class T2VSynthMochiModel:
            z = z + dsigma * pred
            comfy_pbar.update(1)
-        cp_rank, cp_size = get_cp_rank_size()
+        #cp_rank, cp_size = get_cp_rank_size()
        if batch_cfg:
            z = z[:B]
-        z = z.tensor_split(cp_size, dim=2)[cp_rank]  # split along temporal dim
+        #z = z.tensor_split(cp_size, dim=2)[cp_rank]  # split along temporal dim
-        self.dit.to(self.offload_device)
+        self.dit.to(self.offload_device, non_blocking=True)
        samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
        logging.info(f"samples shape: {samples.shape}")


@@ -2,7 +2,8 @@
 import torch
 import torch.nn as nn
-import gc
+import torch.nn.functional as F
 class quantize_lazy_load():
@@ -18,7 +19,17 @@ class quantize_lazy_load():
         self.device.__exit__(exc_type, exc_value, traceback)
-def quantize_load_state_dict(model, state_dict, device="cpu"):
+def quantize_load_state_dict(model, state_dict, device="cpu", cublas_ops=False):
+    if cublas_ops:
+        try:
+            from cublas_ops import cublas_half_matmul
+            linear_ops = cublas_half_matmul
+            print("Using cublas_ops")
+        except:
+            raise ImportError("Install cublas_ops (https://github.com/aredden/torch-cublas-hgemm) to use cublas_ops")
+    else:
+        linear_ops = F.linear
+        pass
     quant_keys = []
     for key in state_dict.keys():
         if key.endswith(".Q4_0_qweight"):
@@ -35,12 +46,14 @@ def quantize_load_state_dict(model, state_dict, device="cpu"):
                 linear=module,
                 device=device,
                 qtype=qtype,
+                linear_ops=linear_ops
             )
             set_op_by_name(model, name, q_linear)
     model.to_empty(device=device)
     model.load_state_dict(state_dict, strict=False)
-    model.to(device)
+    if linear_ops == cublas_half_matmul:
+        setattr(model, "cublas_half_matmul", True)
     return model
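Condensed, the loader change does three things: pick a linear op once, hand it to every quantized WQLinear_GGUF layer, and tag the model so the sampler can switch autocast to fp16. A sketch of just the selection-and-tag part (helper name hypothetical; the cublas_ops import and error message are as in the hunk above):

    import torch.nn.functional as F

    def select_linear_ops(cublas_ops: bool):
        """Return (linear_op, uses_cublas) for the quantized layers."""
        if cublas_ops:
            try:
                from cublas_ops import cublas_half_matmul
            except ImportError as e:
                raise ImportError(
                    "Install cublas_ops (https://github.com/aredden/torch-cublas-hgemm) to use cublas_ops"
                ) from e
            return cublas_half_matmul, True
        return F.linear, False

    # linear_ops, uses_cublas = select_linear_ops(cublas_ops)
    # ... build WQLinear_GGUF(..., linear_ops=linear_ops) for each quantized layer ...
    # if uses_cublas:
    #     model.cublas_half_matmul = True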
@@ -57,19 +70,16 @@ def set_op_by_name(layer, name, new_module):
     else:
         setattr(layer, name, new_module)
-import torch.nn.functional as F
 class WQLinear_GGUF(nn.Module):
     def __init__(
-        self, in_features, out_features, bias, dev, qtype="Q4_0"
+        self, in_features, out_features, bias, dev, qtype, linear_ops
     ):
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
         self.qtype = qtype
+        self.linear_ops = linear_ops
         qweight_shape = quant_shape_to_byte_shape(
             (out_features, in_features), qtype
@@ -99,6 +109,7 @@ class WQLinear_GGUF(nn.Module):
         cls, linear,
         device="cpu",
         qtype="Q4_0",
+        linear_ops=F.linear
     ):
         q_linear = cls(
             linear.in_features,
@@ -106,6 +117,7 @@ class WQLinear_GGUF(nn.Module):
             linear.bias is not None,
             device,
             qtype=qtype,
+            linear_ops=linear_ops
         )
         return q_linear
@@ -120,6 +132,7 @@ class WQLinear_GGUF(nn.Module):
             )
         )
     @torch.no_grad()
     def forward(self, x):
         # x = torch.matmul(x, dequantize_blocks_Q4_0(self.qweight))
@@ -127,8 +140,11 @@ class WQLinear_GGUF(nn.Module):
             x = F.linear(x, dequantize_blocks_Q4_0(
                 self.Q4_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
         elif self.qtype == "Q8_0":
-            x = F.linear(x, dequantize_blocks_Q8_0(
-                self.Q8_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
+            dequant = dequantize_blocks_Q8_0(self.Q8_0_qweight, x.dtype)
+            #x = F.linear(x, dequant, self.bias.to(x.dtype) if self.bias is not None else None)
+            x = self.linear_ops(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)
         else:
             raise ValueError(f"Unknown qtype: {self.qtype}")
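Since `linear_ops` is called exactly like `F.linear` in the Q8_0 branch above, a quick equivalence-and-timing check is one way to validate the claimed speedup on a given GPU (a rough sketch assuming CUDA and torch-cublas-hgemm installed; not a rigorous benchmark):

    import time
    import torch
    import torch.nn.functional as F
    from cublas_ops import cublas_half_matmul

    x = torch.randn(8, 256, 3072, dtype=torch.float16, device="cuda")
    w = torch.randn(3072, 3072, dtype=torch.float16, device="cuda")

    # Outputs should agree up to fp16 rounding.
    diff = (F.linear(x, w) - cublas_half_matmul(x, w, bias=None)).abs().max().item()
    print("max abs diff:", diff)

    def bench(fn, iters=100):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(iters):
            fn(x, w, bias=None)
        torch.cuda.synchronize()
        return (time.perf_counter() - t0) / iters

    print("F.linear          :", bench(F.linear))
    print("cublas_half_matmul:", bench(cublas_half_matmul))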


@@ -72,6 +72,7 @@ class DownloadAndLoadMochiModel:
             "optional": {
                 "trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
                 "compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
+                "cublas_ops": ("BOOLEAN", {"tooltip": "Tested on 4090, unsure of GPU requirements; enables faster linear ops from 'https://github.com/aredden/torch-cublas-hgemm'",}),
             },
         }
@@ -81,7 +82,7 @@ class DownloadAndLoadMochiModel:
     CATEGORY = "MochiWrapper"
     DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface"
-    def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None):
+    def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
@@ -126,7 +127,8 @@ class DownloadAndLoadMochiModel:
             weight_dtype=dtype,
             fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
             attention_mode=attention_mode,
-            compile_args=compile_args
+            compile_args=compile_args,
+            cublas_ops=cublas_ops
         )
         with (init_empty_weights() if is_accelerate_available else nullcontext()):
             vae = Decoder(