mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-04-30 21:52:19 +08:00
sageattn 2.0.0 options
This commit is contained in:
parent
1ade29084e
commit
2b211b9d1b
@ -46,9 +46,40 @@ except:
|
|||||||
|
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
|
||||||
@torch.compiler.disable()
|
|
||||||
def sageattn_func(query, key, value, attn_mask=None, dropout_p=0.0,is_causal=False):
|
def set_attention_func(attention_mode, heads):
|
||||||
return sageattn(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p,is_causal=is_causal)
|
if attention_mode == "sdpa" or attention_mode == "fused_sdpa":
|
||||||
|
def func(q, k, v, is_causal=False, attn_mask=None):
|
||||||
|
return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal)
|
||||||
|
return func
|
||||||
|
elif attention_mode == "comfy":
|
||||||
|
def func(q, k, v, is_causal=False, attn_mask=None):
|
||||||
|
return optimized_attention(q, k, v, mask=attn_mask, heads=heads, skip_reshape=True)
|
||||||
|
return func
|
||||||
|
|
||||||
|
elif attention_mode == "sageattn" or attention_mode == "fused_sageattn":
|
||||||
|
@torch.compiler.disable()
|
||||||
|
def func(q, k, v, is_causal=False, attn_mask=None):
|
||||||
|
return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
|
||||||
|
return func
|
||||||
|
elif attention_mode == "sageattn_qk_int8_pv_fp16_cuda":
|
||||||
|
from sageattention import sageattn_qk_int8_pv_fp16_cuda
|
||||||
|
@torch.compiler.disable()
|
||||||
|
def func(q, k, v, is_causal=False, attn_mask=None):
|
||||||
|
return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32")
|
||||||
|
return func
|
||||||
|
elif attention_mode == "sageattn_qk_int8_pv_fp16_triton":
|
||||||
|
from sageattention import sageattn_qk_int8_pv_fp16_triton
|
||||||
|
@torch.compiler.disable()
|
||||||
|
def func(q, k, v, is_causal=False, attn_mask=None):
|
||||||
|
return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
|
||||||
|
return func
|
||||||
|
elif attention_mode == "sageattn_qk_int8_pv_fp8_cuda":
|
||||||
|
from sageattention import sageattn_qk_int8_pv_fp8_cuda
|
||||||
|
@torch.compiler.disable()
|
||||||
|
def func(q, k, v, is_causal=False, attn_mask=None):
|
||||||
|
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32")
|
||||||
|
return func
|
||||||
|
|
||||||
def fft(tensor):
|
def fft(tensor):
|
||||||
tensor_fft = torch.fft.fft2(tensor)
|
tensor_fft = torch.fft.fft2(tensor)
|
||||||
@ -67,16 +98,18 @@ def fft(tensor):
|
|||||||
|
|
||||||
return low_freq_fft, high_freq_fft
|
return low_freq_fft, high_freq_fft
|
||||||
|
|
||||||
|
#region Attention
|
||||||
class CogVideoXAttnProcessor2_0:
|
class CogVideoXAttnProcessor2_0:
|
||||||
r"""
|
r"""
|
||||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
||||||
query and key vectors, but does not include spatial normalization.
|
query and key vectors, but does not include spatial normalization.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, attn_func, attention_mode: Optional[str] = None):
|
||||||
if not hasattr(F, "scaled_dot_product_attention"):
|
if not hasattr(F, "scaled_dot_product_attention"):
|
||||||
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||||
|
self.attention_mode = attention_mode
|
||||||
|
self.attn_func = attn_func
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
attn: Attention,
|
attn: Attention,
|
||||||
@ -84,7 +117,6 @@ class CogVideoXAttnProcessor2_0:
|
|||||||
encoder_hidden_states: torch.Tensor,
|
encoder_hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||||
attention_mode: Optional[str] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
text_seq_length = encoder_hidden_states.size(1)
|
text_seq_length = encoder_hidden_states.size(1)
|
||||||
|
|
||||||
@ -101,7 +133,7 @@ class CogVideoXAttnProcessor2_0:
|
|||||||
if attn.to_q.weight.dtype == torch.float16 or attn.to_q.weight.dtype == torch.bfloat16:
|
if attn.to_q.weight.dtype == torch.float16 or attn.to_q.weight.dtype == torch.bfloat16:
|
||||||
hidden_states = hidden_states.to(attn.to_q.weight.dtype)
|
hidden_states = hidden_states.to(attn.to_q.weight.dtype)
|
||||||
|
|
||||||
if attention_mode != "fused_sdpa" or attention_mode != "fused_sageattn":
|
if not "fused" in self.attention_mode:
|
||||||
query = attn.to_q(hidden_states)
|
query = attn.to_q(hidden_states)
|
||||||
key = attn.to_k(hidden_states)
|
key = attn.to_k(hidden_states)
|
||||||
value = attn.to_v(hidden_states)
|
value = attn.to_v(hidden_states)
|
||||||
@ -128,16 +160,10 @@ class CogVideoXAttnProcessor2_0:
|
|||||||
if not attn.is_cross_attention:
|
if not attn.is_cross_attention:
|
||||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||||
|
|
||||||
if attention_mode == "sageattn" or attention_mode == "fused_sageattn":
|
hidden_states = self.attn_func(query, key, value, attn_mask=attention_mask, is_causal=False)
|
||||||
hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
|
|
||||||
|
if self.attention_mode != "comfy":
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
elif attention_mode == "sdpa" or attention_mode == "fused_sdpa":
|
|
||||||
hidden_states = F.scaled_dot_product_attention(
|
|
||||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
|
||||||
)
|
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
|
||||||
elif attention_mode == "comfy":
|
|
||||||
hidden_states = optimized_attention(query, key, value, mask=attention_mask, heads=attn.heads, skip_reshape=True)
|
|
||||||
|
|
||||||
# linear proj
|
# linear proj
|
||||||
hidden_states = attn.to_out[0](hidden_states)
|
hidden_states = attn.to_out[0](hidden_states)
|
||||||
@ -203,13 +229,15 @@ class CogVideoXBlock(nn.Module):
|
|||||||
ff_inner_dim: Optional[int] = None,
|
ff_inner_dim: Optional[int] = None,
|
||||||
ff_bias: bool = True,
|
ff_bias: bool = True,
|
||||||
attention_out_bias: bool = True,
|
attention_out_bias: bool = True,
|
||||||
|
attention_mode: Optional[str] = "sdpa",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# 1. Self Attention
|
# 1. Self Attention
|
||||||
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
||||||
|
|
||||||
|
|
||||||
|
attn_func = set_attention_func(attention_mode, num_attention_heads)
|
||||||
|
|
||||||
self.attn1 = Attention(
|
self.attn1 = Attention(
|
||||||
query_dim=dim,
|
query_dim=dim,
|
||||||
dim_head=attention_head_dim,
|
dim_head=attention_head_dim,
|
||||||
@ -218,7 +246,7 @@ class CogVideoXBlock(nn.Module):
|
|||||||
eps=1e-6,
|
eps=1e-6,
|
||||||
bias=attention_bias,
|
bias=attention_bias,
|
||||||
out_bias=attention_out_bias,
|
out_bias=attention_out_bias,
|
||||||
processor=CogVideoXAttnProcessor2_0(),
|
processor=CogVideoXAttnProcessor2_0(attn_func, attention_mode=attention_mode),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. Feed Forward
|
# 2. Feed Forward
|
||||||
@ -247,7 +275,6 @@ class CogVideoXBlock(nn.Module):
|
|||||||
fastercache_counter=0,
|
fastercache_counter=0,
|
||||||
fastercache_start_step=15,
|
fastercache_start_step=15,
|
||||||
fastercache_device="cuda:0",
|
fastercache_device="cuda:0",
|
||||||
attention_mode="sdpa",
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
#print("hidden_states in block: ", hidden_states.shape) #1.5: torch.Size([2, 3200, 3072]) 10.: torch.Size([2, 6400, 3072])
|
#print("hidden_states in block: ", hidden_states.shape) #1.5: torch.Size([2, 3200, 3072]) 10.: torch.Size([2, 6400, 3072])
|
||||||
text_seq_length = encoder_hidden_states.size(1)
|
text_seq_length = encoder_hidden_states.size(1)
|
||||||
@ -286,7 +313,6 @@ class CogVideoXBlock(nn.Module):
|
|||||||
hidden_states=norm_hidden_states,
|
hidden_states=norm_hidden_states,
|
||||||
encoder_hidden_states=norm_encoder_hidden_states,
|
encoder_hidden_states=norm_encoder_hidden_states,
|
||||||
image_rotary_emb=image_rotary_emb,
|
image_rotary_emb=image_rotary_emb,
|
||||||
attention_mode=attention_mode,
|
|
||||||
)
|
)
|
||||||
if fastercache_counter == fastercache_start_step:
|
if fastercache_counter == fastercache_start_step:
|
||||||
self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
|
self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
|
||||||
@ -298,8 +324,7 @@ class CogVideoXBlock(nn.Module):
|
|||||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
||||||
hidden_states=norm_hidden_states,
|
hidden_states=norm_hidden_states,
|
||||||
encoder_hidden_states=norm_encoder_hidden_states,
|
encoder_hidden_states=norm_encoder_hidden_states,
|
||||||
image_rotary_emb=image_rotary_emb,
|
image_rotary_emb=image_rotary_emb
|
||||||
attention_mode=attention_mode,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = hidden_states + gate_msa * attn_hidden_states
|
hidden_states = hidden_states + gate_msa * attn_hidden_states
|
||||||
@ -408,6 +433,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|||||||
use_rotary_positional_embeddings: bool = False,
|
use_rotary_positional_embeddings: bool = False,
|
||||||
use_learned_positional_embeddings: bool = False,
|
use_learned_positional_embeddings: bool = False,
|
||||||
patch_bias: bool = True,
|
patch_bias: bool = True,
|
||||||
|
attention_mode: Optional[str] = "sdpa",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = num_attention_heads * attention_head_dim
|
inner_dim = num_attention_heads * attention_head_dim
|
||||||
@ -461,6 +487,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|||||||
dropout=dropout,
|
dropout=dropout,
|
||||||
activation_fn=activation_fn,
|
activation_fn=activation_fn,
|
||||||
attention_bias=attention_bias,
|
attention_bias=attention_bias,
|
||||||
|
attention_mode=attention_mode,
|
||||||
norm_elementwise_affine=norm_elementwise_affine,
|
norm_elementwise_affine=norm_elementwise_affine,
|
||||||
norm_eps=norm_eps,
|
norm_eps=norm_eps,
|
||||||
)
|
)
|
||||||
@ -496,73 +523,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|||||||
self.fastercache_hf_step = 30
|
self.fastercache_hf_step = 30
|
||||||
self.fastercache_device = "cuda"
|
self.fastercache_device = "cuda"
|
||||||
self.fastercache_num_blocks_to_cache = len(self.transformer_blocks)
|
self.fastercache_num_blocks_to_cache = len(self.transformer_blocks)
|
||||||
self.attention_mode = "sdpa"
|
self.attention_mode = attention_mode
|
||||||
|
|
||||||
|
|
||||||
def _set_gradient_checkpointing(self, module, value=False):
|
def _set_gradient_checkpointing(self, module, value=False):
|
||||||
self.gradient_checkpointing = value
|
self.gradient_checkpointing = value
|
||||||
|
#region forward
|
||||||
@property
|
|
||||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
|
||||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
|
||||||
r"""
|
|
||||||
Returns:
|
|
||||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
|
||||||
indexed by its weight name.
|
|
||||||
"""
|
|
||||||
# set recursively
|
|
||||||
processors = {}
|
|
||||||
|
|
||||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
|
||||||
if hasattr(module, "get_processor"):
|
|
||||||
processors[f"{name}.processor"] = module.get_processor()
|
|
||||||
|
|
||||||
for sub_name, child in module.named_children():
|
|
||||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
|
||||||
|
|
||||||
return processors
|
|
||||||
|
|
||||||
for name, module in self.named_children():
|
|
||||||
fn_recursive_add_processors(name, module, processors)
|
|
||||||
|
|
||||||
return processors
|
|
||||||
|
|
||||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
|
||||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
|
||||||
r"""
|
|
||||||
Sets the attention processor to use to compute attention.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
|
||||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
|
||||||
for **all** `Attention` layers.
|
|
||||||
|
|
||||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
|
||||||
processor. This is strongly recommended when setting trainable attention processors.
|
|
||||||
|
|
||||||
"""
|
|
||||||
count = len(self.attn_processors.keys())
|
|
||||||
|
|
||||||
if isinstance(processor, dict) and len(processor) != count:
|
|
||||||
raise ValueError(
|
|
||||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
|
||||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
|
||||||
)
|
|
||||||
|
|
||||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
|
||||||
if hasattr(module, "set_processor"):
|
|
||||||
if not isinstance(processor, dict):
|
|
||||||
module.set_processor(processor)
|
|
||||||
else:
|
|
||||||
module.set_processor(processor.pop(f"{name}.processor"))
|
|
||||||
|
|
||||||
for sub_name, child in module.named_children():
|
|
||||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
|
||||||
|
|
||||||
for name, module in self.named_children():
|
|
||||||
fn_recursive_attn_processor(name, module, processor)
|
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@ -624,8 +590,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|||||||
block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
|
block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
|
||||||
fastercache_counter = self.fastercache_counter,
|
fastercache_counter = self.fastercache_counter,
|
||||||
fastercache_start_step = self.fastercache_start_step,
|
fastercache_start_step = self.fastercache_start_step,
|
||||||
fastercache_device = self.fastercache_device,
|
fastercache_device = self.fastercache_device
|
||||||
attention_mode = self.attention_mode
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if (controlnet_states is not None) and (i < len(controlnet_states)):
|
if (controlnet_states is not None) and (i < len(controlnet_states)):
|
||||||
@ -695,8 +660,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|||||||
block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
|
block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
|
||||||
fastercache_counter = self.fastercache_counter,
|
fastercache_counter = self.fastercache_counter,
|
||||||
fastercache_start_step = self.fastercache_start_step,
|
fastercache_start_step = self.fastercache_start_step,
|
||||||
fastercache_device = self.fastercache_device,
|
fastercache_device = self.fastercache_device
|
||||||
attention_mode = self.attention_mode
|
|
||||||
)
|
)
|
||||||
#has_nan = torch.isnan(hidden_states).any()
|
#has_nan = torch.isnan(hidden_states).any()
|
||||||
#if has_nan:
|
#if has_nan:
|
||||||
|
|||||||
@ -124,7 +124,19 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
|
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
|
||||||
"lora": ("COGLORA", {"default": None}),
|
"lora": ("COGLORA", {"default": None}),
|
||||||
"compile_args":("COMPILEARGS", ),
|
"compile_args":("COMPILEARGS", ),
|
||||||
"attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn", "comfy"], {"default": "sdpa"}),
|
"attention_mode": ([
|
||||||
|
"sdpa",
|
||||||
|
"fused_sdpa",
|
||||||
|
"sageattn",
|
||||||
|
"fused_sageattn",
|
||||||
|
"sageattn_qk_int8_pv_fp8_cuda",
|
||||||
|
"sageattn_qk_int8_pv_fp16_cuda",
|
||||||
|
"sageattn_qk_int8_pv_fp16_triton",
|
||||||
|
"fused_sageattn_qk_int8_pv_fp8_cuda",
|
||||||
|
"fused_sageattn_qk_int8_pv_fp16_cuda",
|
||||||
|
"fused_sageattn_qk_int8_pv_fp16_triton",
|
||||||
|
"comfy"
|
||||||
|
], {"default": "sdpa"}),
|
||||||
"load_device": (["main_device", "offload_device"], {"default": "main_device"}),
|
"load_device": (["main_device", "offload_device"], {"default": "main_device"}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -139,11 +151,18 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
enable_sequential_cpu_offload=False, block_edit=None, lora=None, compile_args=None,
|
enable_sequential_cpu_offload=False, block_edit=None, lora=None, compile_args=None,
|
||||||
attention_mode="sdpa", load_device="main_device"):
|
attention_mode="sdpa", load_device="main_device"):
|
||||||
|
|
||||||
|
transformer = None
|
||||||
|
|
||||||
if "sage" in attention_mode:
|
if "sage" in attention_mode:
|
||||||
try:
|
try:
|
||||||
from sageattention import sageattn
|
from sageattention import sageattn
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Can't import SageAttention: {str(e)}")
|
raise ValueError(f"Can't import SageAttention: {str(e)}")
|
||||||
|
if "qk_int8" in attention_mode:
|
||||||
|
try:
|
||||||
|
from sageattention import sageattn_qk_int8_pv_fp16_cuda
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Can't import SageAttention 2.0.0: {str(e)}")
|
||||||
|
|
||||||
if precision == "fp16" and "1.5" in model:
|
if precision == "fp16" and "1.5" in model:
|
||||||
raise ValueError("1.5 models do not currently work in fp16")
|
raise ValueError("1.5 models do not currently work in fp16")
|
||||||
@ -218,7 +237,7 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
local_dir_use_symlinks=False,
|
local_dir_use_symlinks=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
|
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder, attention_mode=attention_mode)
|
||||||
transformer = transformer.to(dtype).to(transformer_load_device)
|
transformer = transformer.to(dtype).to(transformer_load_device)
|
||||||
|
|
||||||
if "1.5" in model:
|
if "1.5" in model:
|
||||||
@ -291,7 +310,6 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
for module in pipe.transformer.modules():
|
for module in pipe.transformer.modules():
|
||||||
if isinstance(module, Attention):
|
if isinstance(module, Attention):
|
||||||
module.fuse_projections(fuse=True)
|
module.fuse_projections(fuse=True)
|
||||||
pipe.transformer.attention_mode = attention_mode
|
|
||||||
|
|
||||||
if compile_args is not None:
|
if compile_args is not None:
|
||||||
pipe.transformer.to(memory_format=torch.channels_last)
|
pipe.transformer.to(memory_format=torch.channels_last)
|
||||||
@ -515,7 +533,7 @@ class DownloadAndLoadCogVideoGGUFModel:
|
|||||||
else:
|
else:
|
||||||
transformer_config["in_channels"] = 16
|
transformer_config["in_channels"] = 16
|
||||||
|
|
||||||
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
|
transformer = CogVideoXTransformer3DModel.from_config(transformer_config, attention_mode=attention_mode)
|
||||||
cast_dtype = vae_dtype
|
cast_dtype = vae_dtype
|
||||||
params_to_keep = {"patch_embed", "pos_embedding", "time_embedding"}
|
params_to_keep = {"patch_embed", "pos_embedding", "time_embedding"}
|
||||||
if "2b" in model:
|
if "2b" in model:
|
||||||
@ -696,7 +714,7 @@ class CogVideoXModelLoader:
|
|||||||
transformer_config["sample_width"] = 300
|
transformer_config["sample_width"] = 300
|
||||||
|
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
|
transformer = CogVideoXTransformer3DModel.from_config(transformer_config, attention_mode=attention_mode)
|
||||||
|
|
||||||
#load weights
|
#load weights
|
||||||
#params_to_keep = {}
|
#params_to_keep = {}
|
||||||
|
|||||||
@ -471,7 +471,25 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
|||||||
|
|
||||||
# 5.5.
|
# 5.5.
|
||||||
if image_cond_latents is not None:
|
if image_cond_latents is not None:
|
||||||
if image_cond_latents.shape[1] == 2:
|
if image_cond_latents.shape[1] == 3:
|
||||||
|
logger.info("More than one image conditioning frame received, interpolating")
|
||||||
|
padding_shape = (
|
||||||
|
batch_size,
|
||||||
|
(latents.shape[1] - 3),
|
||||||
|
self.vae_latent_channels,
|
||||||
|
height // self.vae_scale_factor_spatial,
|
||||||
|
width // self.vae_scale_factor_spatial,
|
||||||
|
)
|
||||||
|
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
|
||||||
|
middle_frame = image_cond_latents[:, 2, :, :, :]
|
||||||
|
image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1)
|
||||||
|
middle_frame_idx = image_cond_latents.shape[1] // 2
|
||||||
|
image_cond_latents = image_cond_latents[:, middle_frame_idx, :, :, :] = middle_frame
|
||||||
|
|
||||||
|
if self.transformer.config.patch_size_t is not None:
|
||||||
|
first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...]
|
||||||
|
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
|
||||||
|
elif image_cond_latents.shape[1] == 2:
|
||||||
logger.info("More than one image conditioning frame received, interpolating")
|
logger.info("More than one image conditioning frame received, interpolating")
|
||||||
padding_shape = (
|
padding_shape = (
|
||||||
batch_size,
|
batch_size,
|
||||||
@ -593,9 +611,25 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
|||||||
counter = torch.zeros_like(latent_model_input)
|
counter = torch.zeros_like(latent_model_input)
|
||||||
noise_pred = torch.zeros_like(latent_model_input)
|
noise_pred = torch.zeros_like(latent_model_input)
|
||||||
|
|
||||||
|
current_step_percentage = i / num_inference_steps
|
||||||
|
|
||||||
if image_cond_latents is not None:
|
if image_cond_latents is not None:
|
||||||
latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
|
if not image_cond_start_percent <= current_step_percentage <= image_cond_end_percent:
|
||||||
latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
|
latent_image_input = torch.zeros_like(latent_model_input)
|
||||||
|
else:
|
||||||
|
latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
|
||||||
|
if fun_mask is not None: #for fun img2vid and interpolation
|
||||||
|
fun_inpaint_mask = torch.cat([fun_mask] * 2) if do_classifier_free_guidance else fun_mask
|
||||||
|
masks_input = torch.cat([fun_inpaint_mask, latent_image_input], dim=2)
|
||||||
|
latent_model_input = torch.cat([latent_model_input, masks_input], dim=2)
|
||||||
|
else:
|
||||||
|
latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
|
||||||
|
else: # for Fun inpaint vid2vid
|
||||||
|
if fun_mask is not None:
|
||||||
|
fun_inpaint_mask = torch.cat([fun_mask] * 2) if do_classifier_free_guidance else fun_mask
|
||||||
|
fun_inpaint_masked_video_latents = torch.cat([fun_masked_video_latents] * 2) if do_classifier_free_guidance else fun_masked_video_latents
|
||||||
|
fun_inpaint_latents = torch.cat([fun_inpaint_mask, fun_inpaint_masked_video_latents], dim=2).to(latents.dtype)
|
||||||
|
latent_model_input = torch.cat([latent_model_input, fun_inpaint_latents], dim=2)
|
||||||
|
|
||||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
timestep = t.expand(latent_model_input.shape[0])
|
timestep = t.expand(latent_model_input.shape[0])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user