Commit bececf0189 ("some experimental optimizations")
Parent: 75e98906a3
Repository: https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
@@ -93,9 +93,14 @@ class CogVideoXAttnProcessor2_0:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
 
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
+        if attention_mode != "fused_sdpa" or attention_mode != "fused_sageattn":
+            query = attn.to_q(hidden_states)
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+        else:
+            qkv = attn.to_qkv(hidden_states)
+            split_size = qkv.shape[-1] // 3
+            query, key, value = torch.split(qkv, split_size, dim=-1)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
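For reference, a minimal sketch of what the new fused branch above does with the packed projection: one matmul produces Q, K and V concatenated along the last dimension, and torch.split recovers the three tensors. The module name to_qkv matches the diff, but the dimensions below are illustrative assumptions.

    import torch
    import torch.nn as nn

    batch, seq_len, dim = 2, 3200, 3072               # shapes similar to the debug prints in this commit
    to_qkv = nn.Linear(dim, 3 * dim)                  # stand-in for the fused projection
    hidden_states = torch.randn(batch, seq_len, dim)

    qkv = to_qkv(hidden_states)                       # single matmul instead of three
    split_size = qkv.shape[-1] // 3
    query, key, value = torch.split(qkv, split_size, dim=-1)
    print(query.shape, key.shape, value.shape)        # each (2, 3200, 3072)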
@@ -240,13 +245,14 @@ class CogVideoXBlock(nn.Module):
         fastercache_start_step=15,
         fastercache_device="cuda:0",
     ) -> torch.Tensor:
-
+        #print("hidden_states in block: ", hidden_states.shape) #1.5: torch.Size([2, 3200, 3072]) 10.: torch.Size([2, 6400, 3072])
         text_seq_length = encoder_hidden_states.size(1)
 
         # norm & modulate
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
             hidden_states, encoder_hidden_states, temb
         )
+        #print("norm_hidden_states in block: ", norm_hidden_states.shape) #torch.Size([2, 3200, 3072])
 
         # Tora Motion-guidance Fuser
         if video_flow_feature is not None:
@@ -587,13 +593,17 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         # 2. Patch embedding
         p = self.config.patch_size
         p_t = self.config.patch_size_t
 
+        #print("hidden_states before patch_embedding", hidden_states.shape) #torch.Size([2, 4, 16, 60, 90])
+
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+        #print("hidden_states after patch_embedding", hidden_states.shape) #1.5: torch.Size([2, 2926, 3072]) #1.0: torch.Size([2, 5626, 3072])
         hidden_states = self.embedding_dropout(hidden_states)
 
         text_seq_length = encoder_hidden_states.shape[1]
         encoder_hidden_states = hidden_states[:, :text_seq_length]
         hidden_states = hidden_states[:, text_seq_length:]
+        #print("hidden_states after split", hidden_states.shape) #1.5: torch.Size([2, 2700, 3072]) #1.0: torch.Size([2, 5400, 3072])
 
         if self.use_fastercache:
             self.fastercache_counter+=1
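The sequence lengths in those debug prints follow from the patch sizes: the patch embedding turns a latent of shape (batch, frames, channels, height, width) into (frames / p_t) * (height / p) * (width / p) video tokens and prepends the text tokens. A quick check, assuming p = 2, 226 text tokens, and p_t = 2 for 1.5 versus 1 for 1.0 (assumed values that reproduce the printed shapes):

    frames, height, width = 4, 60, 90                                    # from the "before patch_embedding" print
    p, text_tokens = 2, 226                                              # assumed spatial patch size and text length

    video_tokens_15 = (frames // 2) * (height // p) * (width // p)       # p_t = 2 -> 2700
    video_tokens_10 = frames * (height // p) * (width // p)              # p_t = 1 -> 5400
    print(video_tokens_15 + text_tokens, video_tokens_10 + text_tokens)  # 2926, 5626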
@@ -89,13 +89,13 @@ class DownloadAndLoadCogVideoModel:
                 "precision": (["fp16", "fp32", "bf16"],
                     {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
                 ),
-                "fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}),
+                "fp8_transformer": (['disabled', 'enabled', 'fastmode', 'torchao_fp8dq', "torchao_fp6"], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}),
                 "compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
                 "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
                 "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
                 "lora": ("COGLORA", {"default": None}),
                 "compile_args":("COMPILEARGS", ),
-                "attention_mode": (["sdpa", "sageattn"], {"default": "sdpa"}),
+                "attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn"], {"default": "sdpa"}),
                 "load_device": (["main_device", "offload_device"], {"default": "main_device"}),
             }
         }
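The new dropdown entries are plain strings that the loader dispatches on by substring later in this commit; a condensed sketch of that routing (the variable values here are just examples):

    fp8_transformer = "torchao_fp8dq"                 # e.g. chosen in the node UI
    attention_mode = "fused_sageattn"

    if "torchao" in fp8_transformer:                  # torchao_fp8dq / torchao_fp6
        weight_only_fp6 = "fp6" in fp8_transformer    # else fp8 dynamic activation quantization
    if "fused" in attention_mode:                     # fused_sdpa / fused_sageattn
        fuse_qkv = True                               # fuse to_q/to_k/to_v into a single to_qkv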
@@ -111,10 +111,11 @@ class DownloadAndLoadCogVideoModel:
                   attention_mode="sdpa", load_device="main_device"):
 
         if precision == "fp16" and "1.5" in model:
-            raise ValueError("1.5 models do not work in fp16")
+            raise ValueError("1.5 models do not currently work in fp16")
 
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
+        manual_offloading = True
         transformer_load_device = device if load_device == "main_device" else offload_device
         mm.soft_empty_cache()
 
@@ -189,7 +190,6 @@ class DownloadAndLoadCogVideoModel:
         transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
 
         transformer = transformer.to(dtype).to(transformer_load_device)
-        transformer.attention_mode = attention_mode
 
         if "1.5" in model:
             transformer.config.sample_height = 300
@@ -202,7 +202,7 @@ class DownloadAndLoadCogVideoModel:
             scheduler_config = json.load(f)
         scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)
 
-        #VAE
+        # VAE
         if "Fun" in model:
             vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
             if "Pose" in model:
@@ -263,13 +263,75 @@ class DownloadAndLoadCogVideoModel:
             if "1.5" in model:
                 params_to_keep.update({"ff"}) #otherwise NaNs
             convert_fp8_linear(pipe.transformer, dtype, params_to_keep=params_to_keep)
 
+        elif "torchao" in fp8_transformer:
+            try:
+                from torchao.quantization import (
+                    quantize_,
+                    fpx_weight_only,
+                    float8_dynamic_activation_float8_weight
+                )
+            except:
+                raise ImportError("torchao is not installed, please install torchao to use fp8dq")
+
+            def filter_fn(module: nn.Module, fqn: str) -> bool:
+                target_submodules = {'attn1', 'ff'} # avoid norm layers, 1.5 at least won't work with quantized norm1 #todo: test other models
+                if any(sub in fqn for sub in target_submodules):
+                    return isinstance(module, nn.Linear)
+                return False
+
+            if "fp6" in fp8_transformer: #slower for some reason on 4090
+                quant_func = fpx_weight_only(3, 2)
+            elif "fp8dq" in fp8_transformer: #very fast on 4090 when compiled
+                quant_func = float8_dynamic_activation_float8_weight()
+
+            for i, block in enumerate(pipe.transformer.transformer_blocks):
+                if "CogVideoXBlock" in str(block):
+                    quantize_(block, quant_func, filter_fn=filter_fn)
+
+            manual_offloading = False # to disable manual .to(device) calls
+
         if enable_sequential_cpu_offload:
             pipe.enable_sequential_cpu_offload()
+            manual_offloading = False
+
+        # CogVideoXBlock(
+        #   (norm1): CogVideoXLayerNormZero(
+        #     (silu): SiLU()
+        #     (linear): Linear(in_features=512, out_features=18432, bias=True)
+        #     (norm): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
+        #   )
+        #   (attn1): Attention(
+        #     (norm_q): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
+        #     (norm_k): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
+        #     (to_q): Linear(in_features=3072, out_features=3072, bias=True)
+        #     (to_k): Linear(in_features=3072, out_features=3072, bias=True)
+        #     (to_v): Linear(in_features=3072, out_features=3072, bias=True)
+        #     (to_out): ModuleList(
+        #       (0): Linear(in_features=3072, out_features=3072, bias=True)
+        #       (1): Dropout(p=0.0, inplace=False)
+        #     )
+        #   )
+        #   (norm2): CogVideoXLayerNormZero(
+        #     (silu): SiLU()
+        #     (linear): Linear(in_features=512, out_features=18432, bias=True)
+        #     (norm): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
+        #   )
+        #   (ff): FeedForward(
+        #     (net): ModuleList(
+        #       (0): GELU(
+        #         (proj): Linear(in_features=3072, out_features=12288, bias=True)
+        #       )
+        #       (1): Dropout(p=0.0, inplace=False)
+        #       (2): Linear(in_features=12288, out_features=3072, bias=True)
+        #       (3): Dropout(p=0.0, inplace=False)
+        #     )
+        #   )
+        # )
 
         # compilation
         if compile == "torch":
-            pipe.transformer.to(memory_format=torch.channels_last)
+            #pipe.transformer.to(memory_format=torch.channels_last)
             if compile_args is not None:
                 torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
                 for i, block in enumerate(pipe.transformer.transformer_blocks):
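As a standalone illustration of the torchao path added above: quantize_ swaps the weights of matching nn.Linear modules in place, and the filter_fn keeps everything except the attention and feed-forward linears in the original dtype (the in-diff comment notes that quantized norm layers break the 1.5 model). The block below uses a toy module rather than the real CogVideoXBlock and assumes torchao is installed and an fp8-capable NVIDIA GPU is available.

    import torch
    import torch.nn as nn
    from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

    class ToyBlock(nn.Module):
        def __init__(self, dim=3072):
            super().__init__()
            self.attn1 = nn.Linear(dim, dim)
            self.ff = nn.Linear(dim, dim)
            self.norm1 = nn.LayerNorm(dim)    # excluded by the filter, mirroring the diff

    def filter_fn(module: nn.Module, fqn: str) -> bool:
        # quantize only Linear layers whose qualified name mentions attn1 or ff
        return isinstance(module, nn.Linear) and any(s in fqn for s in ("attn1", "ff"))

    block = ToyBlock().to(torch.bfloat16).to("cuda")  # fp8 dynamic quant needs a recent NVIDIA GPU
    quantize_(block, float8_dynamic_activation_float8_weight(), filter_fn=filter_fn)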
@@ -279,7 +341,16 @@ class DownloadAndLoadCogVideoModel:
             for i, block in enumerate(pipe.transformer.transformer_blocks):
                 if "CogVideoXBlock" in str(block):
                     pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
 
+            transformer.attention_mode = attention_mode
+
+            if "fused" in attention_mode:
+                from diffusers.models.attention import Attention
+                transformer.fuse_qkv_projections = True
+                for module in transformer.modules():
+                    if isinstance(module, Attention):
+                        module.fuse_projections(fuse=True)
+
         elif compile == "onediff":
             from onediffx import compile_pipe
             os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
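The fused attention setup above uses diffusers' Attention.fuse_projections(), which concatenates the existing to_q/to_k/to_v weights into a single to_qkv linear that the fused_sdpa / fused_sageattn processor branch (first hunk of this commit) then splits. A minimal standalone sketch; the head count and dimensions are assumptions chosen to match CogVideoX-sized tensors:

    import torch
    from diffusers.models.attention import Attention

    attn = Attention(query_dim=3072, heads=48, dim_head=64)
    attn.fuse_projections(fuse=True)          # creates attn.to_qkv from to_q/to_k/to_v

    x = torch.randn(2, 3200, 3072)
    qkv = attn.to_qkv(x)                      # one projection; split into Q/K/V downstream
    print(qkv.shape)                          # (2, 3200, 9216)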
@@ -298,8 +369,9 @@ class DownloadAndLoadCogVideoModel:
             "base_path": base_path,
             "onediff": True if compile == "onediff" else False,
             "cpu_offloading": enable_sequential_cpu_offload,
+            "manual_offloading": manual_offloading,
             "scheduler_config": scheduler_config,
-            "model_name": model
+            "model_name": model,
         }
 
         return (pipeline,)
@@ -515,7 +587,8 @@ class DownloadAndLoadCogVideoGGUFModel:
             "onediff": False,
             "cpu_offloading": enable_sequential_cpu_offload,
             "scheduler_config": scheduler_config,
-            "model_name": model
+            "model_name": model,
+            "manual_offloading": True,
         }
 
         return (pipeline,)
nodes.py (4 changes)
@@ -819,7 +819,7 @@ class CogVideoSampler:
         dtype = pipeline["dtype"]
         scheduler_config = pipeline["scheduler_config"]
 
-        if not pipeline["cpu_offloading"]:
+        if not pipeline["cpu_offloading"] and pipeline["manual_offloading"]:
             pipe.transformer.to(device)
         generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
 
@@ -890,7 +890,7 @@ class CogVideoSampler:
             controlnet=controlnet,
             tora=tora_trajectory if tora_trajectory is not None else None,
         )
-        if not pipeline["cpu_offloading"]:
+        if not pipeline["cpu_offloading"] and pipeline["manual_offloading"]:
             pipe.transformer.to(offload_device)
 
         if fastercache is not None:
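The two sampler changes above exist because torchao-quantized transformers (and sequentially CPU-offloaded ones) should not be moved between devices by hand, so the loader now records manual_offloading in the pipeline dict and the sampler checks it before calling .to(). A condensed, runnable sketch of that gate with stand-in objects:

    import torch
    import torch.nn as nn

    device, offload_device = torch.device("cpu"), torch.device("cpu")   # stand-ins for the real devices
    transformer = nn.Linear(8, 8)                                       # stand-in for pipe.transformer
    pipeline = {"cpu_offloading": False, "manual_offloading": True}     # defaults written by the loader

    if not pipeline["cpu_offloading"] and pipeline["manual_offloading"]:
        transformer.to(device)            # move weights in for sampling
    # ... sampling would run here ...
    if not pipeline["cpu_offloading"] and pipeline["manual_offloading"]:
        transformer.to(offload_device)    # hand the memory back afterwards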