mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
some experimental optimizations
parent 75e98906a3
commit bececf0189
@@ -93,9 +93,14 @@ class CogVideoXAttnProcessor2_0:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
 
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
+        if attention_mode != "fused_sdpa" or attention_mode != "fused_sageattn":
+            query = attn.to_q(hidden_states)
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+        else:
+            qkv = attn.to_qkv(hidden_states)
+            split_size = qkv.shape[-1] // 3
+            query, key, value = torch.split(qkv, split_size, dim=-1)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
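For reference, a stand-alone sketch of the fused-projection path introduced above: a single to_qkv matmul replaces the three separate projections, and torch.split recovers query, key and value before the usual head reshape. The dimensions below are illustrative, not the model's.

import torch
import torch.nn as nn

# Illustrative sizes only; the real values come from the transformer config.
batch, seq_len, dim, heads = 2, 320, 768, 12

to_qkv = nn.Linear(dim, dim * 3)            # single fused projection
hidden_states = torch.randn(batch, seq_len, dim)

qkv = to_qkv(hidden_states)                 # (batch, seq, 3 * dim)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)

# the processor then reshapes each tensor into attention heads
head_dim = dim // heads
query = query.view(batch, seq_len, heads, head_dim).transpose(1, 2)
key = key.view(batch, seq_len, heads, head_dim).transpose(1, 2)
value = value.view(batch, seq_len, heads, head_dim).transpose(1, 2)
print(query.shape)                          # torch.Size([2, 12, 320, 64])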
@@ -240,13 +245,14 @@ class CogVideoXBlock(nn.Module):
         fastercache_start_step=15,
         fastercache_device="cuda:0",
     ) -> torch.Tensor:
 
         #print("hidden_states in block: ", hidden_states.shape) #1.5: torch.Size([2, 3200, 3072]) 10.: torch.Size([2, 6400, 3072])
         text_seq_length = encoder_hidden_states.size(1)
 
         # norm & modulate
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
             hidden_states, encoder_hidden_states, temb
         )
         #print("norm_hidden_states in block: ", norm_hidden_states.shape) #torch.Size([2, 3200, 3072])
 
         # Tora Motion-guidance Fuser
         if video_flow_feature is not None:
@@ -587,13 +593,17 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         # 2. Patch embedding
         p = self.config.patch_size
         p_t = self.config.patch_size_t
 
         #print("hidden_states before patch_embedding", hidden_states.shape) #torch.Size([2, 4, 16, 60, 90])
 
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
         #print("hidden_states after patch_embedding", hidden_states.shape) #1.5: torch.Size([2, 2926, 3072]) #1.0: torch.Size([2, 5626, 3072])
         hidden_states = self.embedding_dropout(hidden_states)
 
         text_seq_length = encoder_hidden_states.shape[1]
         encoder_hidden_states = hidden_states[:, :text_seq_length]
         hidden_states = hidden_states[:, text_seq_length:]
         #print("hidden_states after split", hidden_states.shape) #1.5: torch.Size([2, 2700, 3072]) #1.0: torch.Size([2, 5400, 3072])
 
         if self.use_fastercache:
             self.fastercache_counter+=1
 
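The debug prints document the packed sequence layout this section relies on: patch embedding concatenates text and video tokens into one sequence, which is then split back apart by slicing at text_seq_length. A small illustrative sketch; the 226 text tokens are inferred from the printed 1.5 shapes (2926 − 2700).

import torch

# Illustrative shapes taken from the debug prints above (CogVideoX 1.5 case).
batch, text_seq_length, video_seq_length, dim = 2, 226, 2700, 3072
packed = torch.randn(batch, text_seq_length + video_seq_length, dim)  # (2, 2926, 3072)

encoder_hidden_states = packed[:, :text_seq_length]   # (2, 226, 3072)  text tokens
hidden_states = packed[:, text_seq_length:]           # (2, 2700, 3072) video tokens
print(encoder_hidden_states.shape, hidden_states.shape)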
@@ -89,13 +89,13 @@ class DownloadAndLoadCogVideoModel:
                 "precision": (["fp16", "fp32", "bf16"],
                     {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
                 ),
-                "fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}),
+                "fp8_transformer": (['disabled', 'enabled', 'fastmode', 'torchao_fp8dq', "torchao_fp6"], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}),
                 "compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
                 "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
                 "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
                 "lora": ("COGLORA", {"default": None}),
                 "compile_args":("COMPILEARGS", ),
-                "attention_mode": (["sdpa", "sageattn"], {"default": "sdpa"}),
+                "attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn"], {"default": "sdpa"}),
                 "load_device": (["main_device", "offload_device"], {"default": "main_device"}),
             }
         }
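The node gains two torchao-backed quantization choices and two fused attention modes (the fused modes drive the to_qkv path added in the attention processor above). For orientation, a hedged restatement of where each fp8_transformer value leads; the routing itself is the loader code further down in this diff, and the descriptions are paraphrased from the tooltip and that code.

# Hedged restatement of the fp8_transformer routing shown later in this diff.
fp8_transformer = "torchao_fp8dq"   # example selection from the node UI

if fp8_transformer in ("enabled", "fastmode"):
    action = "cast weights to torch.float8_e4m3fn (fastmode needs a recent NVIDIA GPU, torch 2.4.0+, cu124)"
elif fp8_transformer == "torchao_fp6":
    action = "torchao fpx_weight_only(3, 2) on selected Linear layers"
elif fp8_transformer == "torchao_fp8dq":
    action = "torchao float8 dynamic activation/weight quantization on selected Linear layers"
else:
    action = "no fp8 handling"
print(action)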
@@ -111,10 +111,11 @@ class DownloadAndLoadCogVideoModel:
                   attention_mode="sdpa", load_device="main_device"):
 
         if precision == "fp16" and "1.5" in model:
-            raise ValueError("1.5 models do not work in fp16")
+            raise ValueError("1.5 models do not currently work in fp16")
 
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
+        manual_offloading = True
         transformer_load_device = device if load_device == "main_device" else offload_device
         mm.soft_empty_cache()
 
@@ -189,7 +190,6 @@ class DownloadAndLoadCogVideoModel:
             transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
 
         transformer = transformer.to(dtype).to(transformer_load_device)
-        transformer.attention_mode = attention_mode
 
         if "1.5" in model:
             transformer.config.sample_height = 300
@@ -202,7 +202,7 @@ class DownloadAndLoadCogVideoModel:
             scheduler_config = json.load(f)
         scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)
 
-        #VAE
+        # VAE
         if "Fun" in model:
             vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
             if "Pose" in model:
@@ -263,13 +263,75 @@ class DownloadAndLoadCogVideoModel:
             if "1.5" in model:
                 params_to_keep.update({"ff"}) #otherwise NaNs
             convert_fp8_linear(pipe.transformer, dtype, params_to_keep=params_to_keep)
 
+        elif "torchao" in fp8_transformer:
+            try:
+                from torchao.quantization import (
+                    quantize_,
+                    fpx_weight_only,
+                    float8_dynamic_activation_float8_weight
+                )
+            except:
+                raise ImportError("torchao is not installed, please install torchao to use fp8dq")
+
+            def filter_fn(module: nn.Module, fqn: str) -> bool:
+                target_submodules = {'attn1', 'ff'} # avoid norm layers, 1.5 at least won't work with quantized norm1 #todo: test other models
+                if any(sub in fqn for sub in target_submodules):
+                    return isinstance(module, nn.Linear)
+                return False
+
+            if "fp6" in fp8_transformer: #slower for some reason on 4090
+                quant_func = fpx_weight_only(3, 2)
+            elif "fp8dq" in fp8_transformer: #very fast on 4090 when compiled
+                quant_func = float8_dynamic_activation_float8_weight()
+
+            for i, block in enumerate(pipe.transformer.transformer_blocks):
+                if "CogVideoXBlock" in str(block):
+                    quantize_(block, quant_func, filter_fn=filter_fn)
+
+            manual_offloading = False # to disable manual .to(device) calls
+
         if enable_sequential_cpu_offload:
             pipe.enable_sequential_cpu_offload()
+            manual_offloading = False
 
+        # CogVideoXBlock(
+        #   (norm1): CogVideoXLayerNormZero(
+        #     (silu): SiLU()
+        #     (linear): Linear(in_features=512, out_features=18432, bias=True)
+        #     (norm): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
+        #   )
+        #   (attn1): Attention(
+        #     (norm_q): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
+        #     (norm_k): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
+        #     (to_q): Linear(in_features=3072, out_features=3072, bias=True)
+        #     (to_k): Linear(in_features=3072, out_features=3072, bias=True)
+        #     (to_v): Linear(in_features=3072, out_features=3072, bias=True)
+        #     (to_out): ModuleList(
+        #       (0): Linear(in_features=3072, out_features=3072, bias=True)
+        #       (1): Dropout(p=0.0, inplace=False)
+        #     )
+        #   )
+        #   (norm2): CogVideoXLayerNormZero(
+        #     (silu): SiLU()
+        #     (linear): Linear(in_features=512, out_features=18432, bias=True)
+        #     (norm): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
+        #   )
+        #   (ff): FeedForward(
+        #     (net): ModuleList(
+        #       (0): GELU(
+        #         (proj): Linear(in_features=3072, out_features=12288, bias=True)
+        #       )
+        #       (1): Dropout(p=0.0, inplace=False)
+        #       (2): Linear(in_features=12288, out_features=3072, bias=True)
+        #       (3): Dropout(p=0.0, inplace=False)
+        #     )
+        #   )
+        # )
+
         # compilation
         if compile == "torch":
-            pipe.transformer.to(memory_format=torch.channels_last)
+            #pipe.transformer.to(memory_format=torch.channels_last)
             if compile_args is not None:
                 torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
             for i, block in enumerate(pipe.transformer.transformer_blocks):
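The filter_fn above restricts torchao quantization to Linear layers whose fully qualified name contains attn1 or ff, deliberately skipping the norm layers. Below is a self-contained sketch of how such an fqn filter selects modules, using a toy block rather than the real CogVideoXBlock; torchao is only needed for the commented quantize_ call at the end.

import torch.nn as nn

# Toy module tree: only Linear modules whose qualified name contains
# 'attn1' or 'ff' pass the filter; the norm layer is excluded.
class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn1 = nn.ModuleDict({"to_q": nn.Linear(64, 64), "to_k": nn.Linear(64, 64)})
        self.ff = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
        self.norm1 = nn.LayerNorm(64)   # skipped, matching the comment in the diff

def filter_fn(module: nn.Module, fqn: str) -> bool:
    return any(sub in fqn for sub in ("attn1", "ff")) and isinstance(module, nn.Linear)

block = ToyBlock()
selected = [name for name, m in block.named_modules() if filter_fn(m, name)]
print(selected)   # ['attn1.to_q', 'attn1.to_k', 'ff.0', 'ff.2']

# With torchao installed, the same filter drives the in-place quantization:
#   from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
#   quantize_(block, float8_dynamic_activation_float8_weight(), filter_fn=filter_fn)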
@@ -279,7 +341,16 @@ class DownloadAndLoadCogVideoModel:
             for i, block in enumerate(pipe.transformer.transformer_blocks):
                 if "CogVideoXBlock" in str(block):
                     pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
 
+        transformer.attention_mode = attention_mode
+
+        if "fused" in attention_mode:
+            from diffusers.models.attention import Attention
+            transformer.fuse_qkv_projections = True
+            for module in transformer.modules():
+                if isinstance(module, Attention):
+                    module.fuse_projections(fuse=True)
+
         elif compile == "onediff":
             from onediffx import compile_pipe
             os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
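When a fused attention mode is selected, fuse_projections(fuse=True) on each diffusers Attention module is what creates the to_qkv layer consumed by the fused branch in the attention processor above. A reduced-dimension sketch of what that fusion amounts to, namely concatenating the three projection weights into one Linear; sizes are illustrative.

import torch
import torch.nn as nn

dim = 256   # illustrative; the real blocks use 3072
to_q, to_k, to_v = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)

# Build the fused projection by stacking the three weight matrices row-wise.
to_qkv = nn.Linear(dim, dim * 3)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))
    to_qkv.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias], dim=0))

x = torch.randn(1, 16, dim)
fused = to_qkv(x)                                             # one matmul instead of three
print(torch.allclose(to_q(x), fused[..., :dim], atol=1e-5))   # True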
@@ -298,8 +369,9 @@ class DownloadAndLoadCogVideoModel:
             "base_path": base_path,
             "onediff": True if compile == "onediff" else False,
             "cpu_offloading": enable_sequential_cpu_offload,
+            "manual_offloading": manual_offloading,
             "scheduler_config": scheduler_config,
-            "model_name": model
+            "model_name": model,
         }
 
         return (pipeline,)
@@ -515,7 +587,8 @@ class DownloadAndLoadCogVideoGGUFModel:
             "onediff": False,
             "cpu_offloading": enable_sequential_cpu_offload,
             "scheduler_config": scheduler_config,
-            "model_name": model
+            "model_name": model,
+            "manual_offloading": True,
         }
 
         return (pipeline,)
nodes.py (4 changed lines)
@@ -819,7 +819,7 @@ class CogVideoSampler:
         dtype = pipeline["dtype"]
         scheduler_config = pipeline["scheduler_config"]
 
-        if not pipeline["cpu_offloading"]:
+        if not pipeline["cpu_offloading"] and pipeline["manual_offloading"]:
             pipe.transformer.to(device)
         generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
 
@@ -890,7 +890,7 @@ class CogVideoSampler:
             controlnet=controlnet,
             tora=tora_trajectory if tora_trajectory is not None else None,
         )
-        if not pipeline["cpu_offloading"]:
+        if not pipeline["cpu_offloading"] and pipeline["manual_offloading"]:
             pipe.transformer.to(offload_device)
 
         if fastercache is not None:
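The manual_offloading flag stored in the pipeline dict is what the sampler consults here: the transformer is only shuttled between devices when the loader has not delegated placement elsewhere (sequential CPU offload, or the torchao branch that disables manual .to(device) calls). A minimal sketch of that gate with example values standing in for the real pipeline dict and model.

import torch

# Example values; in the node these come from the loader's pipeline dict.
pipeline = {"cpu_offloading": False, "manual_offloading": True}
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
offload_device = torch.device("cpu")
transformer = torch.nn.Linear(8, 8)        # stand-in for pipe.transformer

if not pipeline["cpu_offloading"] and pipeline["manual_offloading"]:
    transformer.to(device)                 # move in for sampling
# ... sampling happens here ...
if not pipeline["cpu_offloading"] and pipeline["manual_offloading"]:
    transformer.to(offload_device)         # move back out afterwards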