make torch compile work better

Author: kijai
Date: 2024-10-26 02:33:29 +03:00
Commit: dcca095743 (parent 25f16462aa)
3 changed files with 35 additions and 37 deletions

File 1 of 3

@@ -52,7 +52,7 @@ class CogVideoXAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-    @torch.compiler.disable()
+
     def __call__(
         self,
         attn: Attention,
@@ -126,7 +126,7 @@ class FusedCogVideoXAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-    @torch.compiler.disable()
+
     def __call__(
         self,
         attn: Attention,
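
Note: torch.compiler.disable() excludes the decorated callable from Dynamo tracing, so any compiled caller hits a graph break there and the attention math falls back to eager execution. Removing the decorator lets scaled_dot_product_attention be captured and fused by Inductor. A minimal sketch with a toy module (not code from this repo):

import torch
import torch.nn.functional as F

class ToyAttn(torch.nn.Module):
    # Re-adding this decorator would reintroduce the graph break this commit
    # removes: the body would run eagerly inside an otherwise compiled caller.
    # @torch.compiler.disable()
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)

attn = torch.compile(ToyAttn(), fullgraph=False, backend="inductor")
q = k = v = torch.randn(1, 8, 16, 64)
out = attn(q, k, v)  # SDPA is now part of the compiled graph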

File 2 of 3

@ -408,7 +408,11 @@ class DownloadAndLoadCogVideoModel:
if compile == "torch": if compile == "torch":
torch._dynamo.config.suppress_errors = True torch._dynamo.config.suppress_errors = True
pipe.transformer.to(memory_format=torch.channels_last) pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=False, backend="inductor") #pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=False, backend="inductor")
for i, block in enumerate(pipe.transformer.transformer_blocks):
if "CogVideoXBlock" in str(block):
print(block)
pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
elif compile == "onediff": elif compile == "onediff":
from onediffx import compile_pipe from onediffx import compile_pipe
os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1' os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
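
Note: this hunk is the core of the change: instead of one torch.compile call over the whole transformer, each CogVideoXBlock is compiled in place, so a graph break or recompile in one block no longer splits or invalidates the graph for the entire model. A rough sketch of the pattern with toy modules (names assumed, not from the repo):

import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(32, 32)

    def forward(self, x):
        return self.lin(x).relu()

blocks = nn.ModuleList([Block() for _ in range(4)])

# In-place reassignment swaps each eager block for an OptimizedModule wrapper;
# parameters are shared with the originals, so state_dict and .to() keep working.
for i, block in enumerate(blocks):
    blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")

x = torch.randn(2, 32)
for block in blocks:
    x = block(x)

dynamic=False pins each block's graph to static shapes, so changing resolution or frame count triggers a per-block recompile rather than a dynamic-shape graph.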
@@ -559,10 +563,8 @@ class DownloadAndLoadCogVideoGGUFModel:
             convert_fp8_linear(transformer, vae_dtype)
 
         # compilation
-        if compile == "torch":
-            torch._dynamo.config.suppress_errors = True
-            pipe.transformer.to(memory_format=torch.channels_last)
-            pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=False, backend="inductor")
+        for i, block in enumerate(transformer.transformer_blocks):
+            transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
 
         with open(scheduler_path) as f:
             scheduler_config = json.load(f)
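
Note: the GGUF path applies the same per-block pattern, dropping the channels_last conversion and the whole-model compile. torch.compile wraps a module rather than copying it, which is why it composes with the convert_fp8_linear step above; a toy demonstration of the wrapper behavior (assumed internals, verified against current PyTorch):

import torch
import torch.nn as nn

m = nn.Linear(4, 4)
cm = torch.compile(m, backend="inductor")

assert cm._orig_mod is m        # OptimizedModule keeps a reference, not a copy
assert cm.weight is m.weight    # parameters are shared with the eager module

with torch.no_grad():
    m.weight.zero_()            # later weight edits show through the wrapper
print(cm(torch.randn(1, 4)))    # runs with the updated weights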

File 3 of 3

@@ -9,7 +9,7 @@
 # --------------------------------------------------------
 from typing import Any, Dict, Optional, Tuple, Union
-
+from einops import rearrange
 import torch
 import torch.nn.functional as F
 from diffusers.configuration_utils import ConfigMixin, register_to_config
@@ -42,7 +42,7 @@ class CogVideoXAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-    @torch.compiler.disable()
+
     def __call__(
         self,
         attn: Attention,
@@ -134,7 +134,7 @@ class FusedCogVideoXAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-    @torch.compiler.disable()
+
     def __call__(
        self,
        attn: Attention,
@@ -286,7 +286,7 @@ class CogVideoXBlock(nn.Module):
         self.attn_count = 0
         self.last_attn = None
         self.block_idx = block_idx
-    @torch.compiler.disable()
+    #@torch.compiler.disable()
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -294,6 +294,8 @@ class CogVideoXBlock(nn.Module):
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         timestep=None,
+        video_flow_feature: Optional[torch.Tensor] = None,
+        fuser=None,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
@@ -301,7 +303,14 @@ class CogVideoXBlock(nn.Module):
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
             hidden_states, encoder_hidden_states, temb
         )
-
+        # Tora Motion-guidance Fuser
+        if video_flow_feature is not None:
+            H, W = video_flow_feature.shape[-2:]
+            T = norm_hidden_states.shape[1] // H // W
+            h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
+            h = fuser(h, video_flow_feature.to(h), T=T)
+            norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
+            del h, fuser
 
         # attention
         if enable_pab():
             broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx)
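
Note: the Tora hook reshapes the (B, T*H*W, C) vision-token sequence into per-frame (B*T, C, H, W) maps, lets the fuser inject the optical-flow features, then flattens back. A shape-only sketch of that round trip (the additive fuser here is a stand-in; the real module comes from fuser_list below):

import torch
from einops import rearrange

B, T, H, W, C = 1, 4, 6, 8, 16
norm_hidden_states = torch.randn(B, T * H * W, C)
video_flow_feature = torch.randn(B * T, C, H, W)  # assumed layout

h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
h = h + video_flow_feature.to(h)  # stand-in for fuser(h, flow, T=T)
norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
assert norm_hidden_states.shape == (B, T * H * W, C)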
@@ -494,6 +503,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         self.gradient_checkpointing = False
 
+        self.fuser_list = None
+
         # parallel
         #self.parallel_manager = None
@@ -524,6 +535,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         return_dict: bool = True,
         controlnet_states: torch.Tensor = None,
         controlnet_weights: Optional[Union[float, int, list, torch.FloatTensor]] = 1.0,
+        video_flow_features: Optional[torch.Tensor] = None,
     ):
         # if self.parallel_manager.cp_size > 1:
         #     (
@@ -574,30 +586,14 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         # 4. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    hidden_states,
-                    encoder_hidden_states,
-                    emb,
-                    image_rotary_emb,
-                    **ckpt_kwargs,
-                )
-            else:
-                hidden_states, encoder_hidden_states = block(
-                    hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    temb=emb,
-                    image_rotary_emb=image_rotary_emb,
-                    timestep=timesteps if enable_pab() else None,
-                )
+            hidden_states, encoder_hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=emb,
+                image_rotary_emb=image_rotary_emb,
+                timestep=timesteps if enable_pab() else None,
+                video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
+                fuser=self.fuser_list[i] if self.fuser_list is not None else None,
+            )
 
             if (controlnet_states is not None) and (i < len(controlnet_states)):
                 controlnet_states_block = controlnet_states[i]
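
Note: with the gradient-checkpointing branch gone, the call site is a single straight-line block call, which keeps the compiled block graphs simple. The two new kwargs default to None, and Dynamo guards on whether an input is None, so the plain (non-Tora) path hits the same specialization on every step. A toy end-to-end sketch of the combination (assumed shapes and names, not the repo's modules):

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def forward(self, hidden_states, encoder_hidden_states,
                video_flow_feature=None, fuser=None):
        if video_flow_feature is not None:  # specialized away when always None
            hidden_states = fuser(hidden_states, video_flow_feature)
        return hidden_states, encoder_hidden_states

blocks = nn.ModuleList([ToyBlock() for _ in range(2)])
for i, block in enumerate(blocks):
    blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")

x, ctx = torch.randn(1, 16, 32), torch.randn(1, 4, 32)
for i, block in enumerate(blocks):
    x, ctx = block(hidden_states=x, encoder_hidden_states=ctx,
                   video_flow_feature=None, fuser=None)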