some experimental optimizations

This commit is contained in:
kijai 2024-11-16 17:32:31 +02:00
parent 75e98906a3
commit bececf0189
3 changed files with 98 additions and 15 deletions

View File

@ -93,9 +93,14 @@ class CogVideoXAttnProcessor2_0:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
query = attn.to_q(hidden_states) if attention_mode != "fused_sdpa" or attention_mode != "fused_sageattn":
key = attn.to_k(hidden_states) query = attn.to_q(hidden_states)
value = attn.to_v(hidden_states) key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
else:
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
inner_dim = key.shape[-1] inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads head_dim = inner_dim // attn.heads
@ -240,13 +245,14 @@ class CogVideoXBlock(nn.Module):
fastercache_start_step=15, fastercache_start_step=15,
fastercache_device="cuda:0", fastercache_device="cuda:0",
) -> torch.Tensor: ) -> torch.Tensor:
#print("hidden_states in block: ", hidden_states.shape) #1.5: torch.Size([2, 3200, 3072]) 10.: torch.Size([2, 6400, 3072])
text_seq_length = encoder_hidden_states.size(1) text_seq_length = encoder_hidden_states.size(1)
# norm & modulate # norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb hidden_states, encoder_hidden_states, temb
) )
#print("norm_hidden_states in block: ", norm_hidden_states.shape) #torch.Size([2, 3200, 3072])
# Tora Motion-guidance Fuser # Tora Motion-guidance Fuser
if video_flow_feature is not None: if video_flow_feature is not None:
@ -587,13 +593,17 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
# 2. Patch embedding # 2. Patch embedding
p = self.config.patch_size p = self.config.patch_size
p_t = self.config.patch_size_t p_t = self.config.patch_size_t
#print("hidden_states before patch_embedding", hidden_states.shape) #torch.Size([2, 4, 16, 60, 90])
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
#print("hidden_states after patch_embedding", hidden_states.shape) #1.5: torch.Size([2, 2926, 3072]) #1.0: torch.Size([2, 5626, 3072])
hidden_states = self.embedding_dropout(hidden_states) hidden_states = self.embedding_dropout(hidden_states)
text_seq_length = encoder_hidden_states.shape[1] text_seq_length = encoder_hidden_states.shape[1]
encoder_hidden_states = hidden_states[:, :text_seq_length] encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:] hidden_states = hidden_states[:, text_seq_length:]
#print("hidden_states after split", hidden_states.shape) #1.5: torch.Size([2, 2700, 3072]) #1.0: torch.Size([2, 5400, 3072])
if self.use_fastercache: if self.use_fastercache:
self.fastercache_counter+=1 self.fastercache_counter+=1

View File

@ -89,13 +89,13 @@ class DownloadAndLoadCogVideoModel:
"precision": (["fp16", "fp32", "bf16"], "precision": (["fp16", "fp32", "bf16"],
{"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"} {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
), ),
"fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}), "fp8_transformer": (['disabled', 'enabled', 'fastmode', 'torchao_fp8dq', "torchao_fp6"], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}),
"compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}), "compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}), "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}), "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
"lora": ("COGLORA", {"default": None}), "lora": ("COGLORA", {"default": None}),
"compile_args":("COMPILEARGS", ), "compile_args":("COMPILEARGS", ),
"attention_mode": (["sdpa", "sageattn"], {"default": "sdpa"}), "attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn"], {"default": "sdpa"}),
"load_device": (["main_device", "offload_device"], {"default": "main_device"}), "load_device": (["main_device", "offload_device"], {"default": "main_device"}),
} }
} }
@ -111,10 +111,11 @@ class DownloadAndLoadCogVideoModel:
attention_mode="sdpa", load_device="main_device"): attention_mode="sdpa", load_device="main_device"):
if precision == "fp16" and "1.5" in model: if precision == "fp16" and "1.5" in model:
raise ValueError("1.5 models do not work in fp16") raise ValueError("1.5 models do not currently work in fp16")
device = mm.get_torch_device() device = mm.get_torch_device()
offload_device = mm.unet_offload_device() offload_device = mm.unet_offload_device()
manual_offloading = True
transformer_load_device = device if load_device == "main_device" else offload_device transformer_load_device = device if load_device == "main_device" else offload_device
mm.soft_empty_cache() mm.soft_empty_cache()
@ -189,7 +190,6 @@ class DownloadAndLoadCogVideoModel:
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder) transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
transformer = transformer.to(dtype).to(transformer_load_device) transformer = transformer.to(dtype).to(transformer_load_device)
transformer.attention_mode = attention_mode
if "1.5" in model: if "1.5" in model:
transformer.config.sample_height = 300 transformer.config.sample_height = 300
@ -202,7 +202,7 @@ class DownloadAndLoadCogVideoModel:
scheduler_config = json.load(f) scheduler_config = json.load(f)
scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config) scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)
#VAE # VAE
if "Fun" in model: if "Fun" in model:
vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device) vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
if "Pose" in model: if "Pose" in model:
@ -263,13 +263,75 @@ class DownloadAndLoadCogVideoModel:
if "1.5" in model: if "1.5" in model:
params_to_keep.update({"ff"}) #otherwise NaNs params_to_keep.update({"ff"}) #otherwise NaNs
convert_fp8_linear(pipe.transformer, dtype, params_to_keep=params_to_keep) convert_fp8_linear(pipe.transformer, dtype, params_to_keep=params_to_keep)
elif "torchao" in fp8_transformer:
try:
from torchao.quantization import (
quantize_,
fpx_weight_only,
float8_dynamic_activation_float8_weight
)
except:
raise ImportError("torchao is not installed, please install torchao to use fp8dq")
def filter_fn(module: nn.Module, fqn: str) -> bool:
target_submodules = {'attn1', 'ff'} # avoid norm layers, 1.5 at least won't work with quantized norm1 #todo: test other models
if any(sub in fqn for sub in target_submodules):
return isinstance(module, nn.Linear)
return False
if "fp6" in fp8_transformer: #slower for some reason on 4090
quant_func = fpx_weight_only(3, 2)
elif "fp8dq" in fp8_transformer: #very fast on 4090 when compiled
quant_func = float8_dynamic_activation_float8_weight()
for i, block in enumerate(pipe.transformer.transformer_blocks):
if "CogVideoXBlock" in str(block):
quantize_(block, quant_func, filter_fn=filter_fn)
manual_offloading = False # to disable manual .to(device) calls
if enable_sequential_cpu_offload: if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload() pipe.enable_sequential_cpu_offload()
manual_offloading = False
# CogVideoXBlock(
# (norm1): CogVideoXLayerNormZero(
# (silu): SiLU()
# (linear): Linear(in_features=512, out_features=18432, bias=True)
# (norm): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
# )
# (attn1): Attention(
# (norm_q): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
# (norm_k): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
# (to_q): Linear(in_features=3072, out_features=3072, bias=True)
# (to_k): Linear(in_features=3072, out_features=3072, bias=True)
# (to_v): Linear(in_features=3072, out_features=3072, bias=True)
# (to_out): ModuleList(
# (0): Linear(in_features=3072, out_features=3072, bias=True)
# (1): Dropout(p=0.0, inplace=False)
# )
# )
# (norm2): CogVideoXLayerNormZero(
# (silu): SiLU()
# (linear): Linear(in_features=512, out_features=18432, bias=True)
# (norm): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
# )
# (ff): FeedForward(
# (net): ModuleList(
# (0): GELU(
# (proj): Linear(in_features=3072, out_features=12288, bias=True)
# )
# (1): Dropout(p=0.0, inplace=False)
# (2): Linear(in_features=12288, out_features=3072, bias=True)
# (3): Dropout(p=0.0, inplace=False)
# )
# )
# )
# compilation # compilation
if compile == "torch": if compile == "torch":
pipe.transformer.to(memory_format=torch.channels_last) #pipe.transformer.to(memory_format=torch.channels_last)
if compile_args is not None: if compile_args is not None:
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"] torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
for i, block in enumerate(pipe.transformer.transformer_blocks): for i, block in enumerate(pipe.transformer.transformer_blocks):
@ -279,7 +341,16 @@ class DownloadAndLoadCogVideoModel:
for i, block in enumerate(pipe.transformer.transformer_blocks): for i, block in enumerate(pipe.transformer.transformer_blocks):
if "CogVideoXBlock" in str(block): if "CogVideoXBlock" in str(block):
pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor") pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
transformer.attention_mode = attention_mode
if "fused" in attention_mode:
from diffusers.models.attention import Attention
transformer.fuse_qkv_projections = True
for module in transformer.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
elif compile == "onediff": elif compile == "onediff":
from onediffx import compile_pipe from onediffx import compile_pipe
os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1' os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
@ -298,8 +369,9 @@ class DownloadAndLoadCogVideoModel:
"base_path": base_path, "base_path": base_path,
"onediff": True if compile == "onediff" else False, "onediff": True if compile == "onediff" else False,
"cpu_offloading": enable_sequential_cpu_offload, "cpu_offloading": enable_sequential_cpu_offload,
"manual_offloading": manual_offloading,
"scheduler_config": scheduler_config, "scheduler_config": scheduler_config,
"model_name": model "model_name": model,
} }
return (pipeline,) return (pipeline,)
@ -515,7 +587,8 @@ class DownloadAndLoadCogVideoGGUFModel:
"onediff": False, "onediff": False,
"cpu_offloading": enable_sequential_cpu_offload, "cpu_offloading": enable_sequential_cpu_offload,
"scheduler_config": scheduler_config, "scheduler_config": scheduler_config,
"model_name": model "model_name": model,
"manual_offloading": True,
} }
return (pipeline,) return (pipeline,)

View File

@ -819,7 +819,7 @@ class CogVideoSampler:
dtype = pipeline["dtype"] dtype = pipeline["dtype"]
scheduler_config = pipeline["scheduler_config"] scheduler_config = pipeline["scheduler_config"]
if not pipeline["cpu_offloading"]: if not pipeline["cpu_offloading"] and pipeline["manual_offloading"]:
pipe.transformer.to(device) pipe.transformer.to(device)
generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed) generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
@ -890,7 +890,7 @@ class CogVideoSampler:
controlnet=controlnet, controlnet=controlnet,
tora=tora_trajectory if tora_trajectory is not None else None, tora=tora_trajectory if tora_trajectory is not None else None,
) )
if not pipeline["cpu_offloading"]: if not pipeline["cpu_offloading"] and pipeline["manual_offloading"]:
pipe.transformer.to(offload_device) pipe.transformer.to(offload_device)
if fastercache is not None: if fastercache is not None: