support Tora for Fun -models

Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
commit 5b4819ba65 · parent 5fd4f67b14
@@ -528,6 +528,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
         context_stride: Optional[int] = None,
         context_overlap: Optional[int] = None,
         freenoise: Optional[bool] = True,
+        tora: Optional[dict] = None,
     ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -720,7 +721,13 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
            if self.transformer.config.use_rotary_positional_embeddings
            else None
        )
+        if tora is not None and do_classifier_free_guidance:
+            video_flow_features = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()

+        if tora is not None:
+            for module in self.transformer.fuser_list:
+                for param in module.parameters():
+                    param.data = param.data.to(device)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            # for DPM-solver++
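Note on the hunk above: under classifier-free guidance the pipeline runs the conditional and unconditional samples as one doubled batch, so Tora's precomputed flow features are repeated along their batch dimension to match. A minimal sketch, assuming the features are stacked per transformer block with batch in dim 1 (all shapes here are invented for illustration):

```python
import torch

# Hypothetical shape: [num_blocks, batch, channels, height, width]
video_flow_features = torch.randn(42, 1, 128, 30, 45)

# Dim 1 (batch) is doubled to line up with the CFG-doubled latents.
doubled = video_flow_features.repeat(1, 2, 1, 1, 1).contiguous()
print(doubled.shape)  # torch.Size([42, 2, 128, 30, 45])
```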
@@ -910,6 +917,8 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                        image_rotary_emb=image_rotary_emb,
                        return_dict=False,
                        control_latents=current_control_latents,
+                        video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
+
                    )[0]
                    noise_pred = noise_pred.float()

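The `start_percent`/`end_percent` comparison in the hunk above gates Tora guidance to a window of the denoising schedule. A standalone sketch of that gate, with example values rather than repo defaults:

```python
# Example values; the real ones come from the tora dict built by the loader node.
tora = {"start_percent": 0.0, "end_percent": 0.8}
num_inference_steps = 50

for i in range(num_inference_steps):
    current_step_percentage = i / num_inference_steps
    inject = tora["start_percent"] <= current_step_percentage <= tora["end_percent"]
    # flow features are passed to the transformer only when inject is True
```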
@@ -610,6 +610,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
         context_stride: Optional[int] = None,
         context_overlap: Optional[int] = None,
         freenoise: Optional[bool] = True,
+        tora: Optional[dict] = None,
     ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -889,6 +890,13 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
            if self.transformer.config.use_rotary_positional_embeddings
            else None
        )
+        if tora is not None and do_classifier_free_guidance:
+            video_flow_features = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()
+
+        if tora is not None:
+            for module in self.transformer.fuser_list:
+                for param in module.parameters():
+                    param.data = param.data.to(device)

        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -1061,6 +1069,8 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

+                current_step_percentage = i / num_inference_steps
+
                # predict noise model_output
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
@@ -1069,6 +1079,8 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
                    image_rotary_emb=image_rotary_emb,
                    return_dict=False,
                    inpaint_latents=inpaint_latents,
+                    video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
+
                )[0]
                noise_pred = noise_pred.float()

@@ -34,6 +34,7 @@ from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

+from einops import rearrange
try:
    from sageattention import sageattn
    SAGEATTN_IS_AVAVILABLE = True
@@ -42,6 +43,23 @@ except:
    logger.info("sageattn not found, using sdpa")
    SAGEATTN_IS_AVAVILABLE = False

+def fft(tensor):
+    tensor_fft = torch.fft.fft2(tensor)
+    tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
+    B, C, H, W = tensor.size()
+    radius = min(H, W) // 5
+
+    Y, X = torch.meshgrid(torch.arange(H), torch.arange(W))
+    center_x, center_y = W // 2, H // 2
+    mask = (X - center_x) ** 2 + (Y - center_y) ** 2 <= radius ** 2
+    low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(tensor.device)
+    high_freq_mask = ~low_freq_mask
+
+    low_freq_fft = tensor_fft_shifted * low_freq_mask
+    high_freq_fft = tensor_fft_shifted * high_freq_mask
+
+    return low_freq_fft, high_freq_fft
+
class CogVideoXAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
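The new `fft` helper splits a centered 2D spectrum into complementary low- and high-frequency parts with a circular mask of radius `min(H, W) // 5`. Because the masks are complementary, summing both parts and inverting recovers the input; a quick self-contained check of that property (mirroring the helper, with `indexing="ij"` passed explicitly to avoid the meshgrid deprecation warning):

```python
import torch

x = torch.randn(1, 3, 32, 32)
x_fft = torch.fft.fftshift(torch.fft.fft2(x))  # shift DC to the center

B, C, H, W = x.shape
radius = min(H, W) // 5
Y, X = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
mask = (X - W // 2) ** 2 + (Y - H // 2) ** 2 <= radius ** 2

low, high = x_fft * mask, x_fft * ~mask         # complementary split
recovered = torch.fft.ifft2(torch.fft.ifftshift(low + high)).real
print(torch.allclose(recovered, x, atol=1e-5))  # True
```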
@@ -315,6 +333,11 @@ class CogVideoXBlock(nn.Module):
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        video_flow_feature: Optional[torch.Tensor] = None,
+        fuser=None,
+        fastercache_counter=0,
+        fastercache_start_step=15,
+        fastercache_device="cuda:0",
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

@@ -322,7 +345,39 @@ class CogVideoXBlock(nn.Module):
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )
+        # Tora Motion-guidance Fuser
+        if video_flow_feature is not None:
+            H, W = video_flow_feature.shape[-2:]
+            T = norm_hidden_states.shape[1] // H // W
+            h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
+            h = fuser(h, video_flow_feature.to(h), T=T)
+            norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
+            del h, fuser
+        #fastercache
+        B = norm_hidden_states.shape[0]
+        if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B:
+            attn_hidden_states = (
+                self.cached_hidden_states[1][:B] +
+                (self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B])
+                * 0.3
+            ).to(norm_hidden_states.device, non_blocking=True)
+            attn_encoder_hidden_states = (
+                self.cached_encoder_hidden_states[1][:B] +
+                (self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B])
+                * 0.3
+            ).to(norm_hidden_states.device, non_blocking=True)
+        else:
+            attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+                hidden_states=norm_hidden_states,
+                encoder_hidden_states=norm_encoder_hidden_states,
+                image_rotary_emb=image_rotary_emb,
+            )
+            if fastercache_counter == fastercache_start_step:
+                self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
+                self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)]
+            elif fastercache_counter > fastercache_start_step:
+                self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
+                self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
        # attention
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
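The cache branch above kicks in once `fastercache_counter` passes `fastercache_start_step + 3`, skipping attention on steps where `counter % 3 != 0` and substituting a linear extrapolation from the two most recently computed outputs. The update rule in isolation (shapes invented):

```python
import torch

# cached[0] = older attention output, cached[1] = newer one
cached = [torch.randn(2, 226, 3072), torch.randn(2, 226, 3072)]

# Skipped steps continue the trend between the two cached results.
approx = cached[1] + (cached[1] - cached[0]) * 0.3
```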
@@ -497,6 +552,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):

        self.gradient_checkpointing = False

+        self.fuser_list = None
+
+        self.use_fastercache = False
+        self.fastercache_counter = 0
+        self.fastercache_start_step = 15
+        self.fastercache_lf_step = 40
+        self.fastercache_hf_step = 30
+        self.fastercache_device = "cuda"
+
    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

@@ -609,6 +673,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
        inpaint_latents: Optional[torch.Tensor] = None,
        control_latents: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        video_flow_features: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        batch_size, num_frames, channels, height, width = hidden_states.shape
@@ -649,50 +714,101 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

-        # 4. Transformer blocks
-        for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    hidden_states,
-                    encoder_hidden_states,
-                    emb,
-                    image_rotary_emb,
-                    **ckpt_kwargs,
+        if self.use_fastercache:
+            self.fastercache_counter+=1
+        if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
+            # 4. Transformer blocks
+            for i, block in enumerate(self.transformer_blocks):
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states=hidden_states[:1],
+                    encoder_hidden_states=encoder_hidden_states[:1],
+                    temb=emb[:1],
+                    image_rotary_emb=image_rotary_emb,
+                    video_flow_feature=video_flow_features[i][:1] if video_flow_features is not None else None,
+                    fuser = self.fuser_list[i] if self.fuser_list is not None else None,
+                    fastercache_counter = self.fastercache_counter,
+                    fastercache_device = self.fastercache_device
                )
+
+            if not self.config.use_rotary_positional_embeddings:
+                # CogVideoX-2B
+                hidden_states = self.norm_final(hidden_states)
            else:
+                # CogVideoX-5B
+                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+                hidden_states = self.norm_final(hidden_states)
+                hidden_states = hidden_states[:, text_seq_length:]
+
+            # 5. Final block
+            hidden_states = self.norm_out(hidden_states, temb=emb[:1])
+            hidden_states = self.proj_out(hidden_states)
+
+            # 6. Unpatchify
+            p = self.config.patch_size
+            output = hidden_states.reshape(1, num_frames, height // p, width // p, channels, p, p)
+            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+
+            (bb, tt, cc, hh, ww) = output.shape
+            cond = rearrange(output, "B T C H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+            lf_c, hf_c = fft(cond.float())
+            #lf_step = 40
+            #hf_step = 30
+            if self.fastercache_counter <= self.fastercache_lf_step:
+                self.delta_lf = self.delta_lf * 1.1
+            if self.fastercache_counter >= self.fastercache_hf_step:
+                self.delta_hf = self.delta_hf * 1.1
+
+            new_hf_uc = self.delta_hf + hf_c
+            new_lf_uc = self.delta_lf + lf_c
+
+            combine_uc = new_lf_uc + new_hf_uc
+            combined_fft = torch.fft.ifftshift(combine_uc)
+            recovered_uncond = torch.fft.ifft2(combined_fft).real
+            recovered_uncond = rearrange(recovered_uncond.to(output.dtype), "(B T) C H W -> B T C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
+            output = torch.cat([output, recovered_uncond])
+        else:
+            # 4. Transformer blocks
+            for i, block in enumerate(self.transformer_blocks):
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
+                    video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
+                    fuser = self.fuser_list[i] if self.fuser_list is not None else None,
+                    fastercache_counter = self.fastercache_counter,
+                    fastercache_device = self.fastercache_device
                )

            if not self.config.use_rotary_positional_embeddings:
                # CogVideoX-2B
                hidden_states = self.norm_final(hidden_states)
            else:
                # CogVideoX-5B
                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
                hidden_states = self.norm_final(hidden_states)
                hidden_states = hidden_states[:, text_seq_length:]

            # 5. Final block
            hidden_states = self.norm_out(hidden_states, temb=emb)
            hidden_states = self.proj_out(hidden_states)

+            # 6. Unpatchify
+            p = self.config.patch_size
+            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
+            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+
+            if self.fastercache_counter >= self.fastercache_start_step + 1:
+                (bb, tt, cc, hh, ww) = output.shape
+                cond = rearrange(output[0:1].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
+                uncond = rearrange(output[1:2].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
+
+                lf_c, hf_c = fft(cond)
+                lf_uc, hf_uc = fft(uncond)
+
+                self.delta_lf = lf_uc - lf_c
+                self.delta_hf = hf_uc - hf_c
+
-        # 6. Unpatchify
-        p = self.config.patch_size
-        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
-        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

        if not return_dict:
            return (output,)
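The restructured forward above alternates two paths once caching is active: on most steps (`counter % 5 != 0`) only the conditional sample runs through the blocks and the unconditional output is rebuilt in frequency space from stored deltas; on the remaining full steps both samples run and the low/high-frequency deltas are re-recorded. A self-contained sketch of that record/replay cycle using the same circular split as `fft()` (the `split` helper and all shapes are illustrative):

```python
import torch

def split(t_fft, radius_div=5):
    # circular low/high split, as in the fft() helper above
    B, C, H, W = t_fft.shape
    Y, X = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    mask = (X - W // 2) ** 2 + (Y - H // 2) ** 2 <= (min(H, W) // radius_div) ** 2
    return t_fft * mask, t_fft * ~mask

# Full step: both CFG samples exist, so record the spectral deltas.
cond, uncond = torch.randn(2, 4, 16, 32, 32).unbind(0)
lf_c, hf_c = split(torch.fft.fftshift(torch.fft.fft2(cond)))
lf_uc, hf_uc = split(torch.fft.fftshift(torch.fft.fft2(uncond)))
delta_lf, delta_hf = lf_uc - lf_c, hf_uc - hf_c

# Cached step: only cond is computed; uncond is replayed from the deltas.
lf_c2, hf_c2 = split(torch.fft.fftshift(torch.fft.fft2(cond)))
combined = (delta_lf + lf_c2) + (delta_hf + hf_c2)
recovered_uncond = torch.fft.ifft2(torch.fft.ifftshift(combined)).real
print(torch.allclose(recovered_uncond, uncond, atol=1e-4))  # True (cond unchanged here)
```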
@@ -662,7 +662,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
                    encoder_hidden_states=encoder_hidden_states[:1],
                    temb=emb[:1],
                    image_rotary_emb=image_rotary_emb,
-                    video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
+                    video_flow_feature=video_flow_features[i][:1] if video_flow_features is not None else None,
                    fuser = self.fuser_list[i] if self.fuser_list is not None else None,
                    fastercache_counter = self.fastercache_counter,
                    fastercache_device = self.fastercache_device
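This follow-up hunk tightens the cached-step path: since only the first (conditional) sample is pushed through each block there, the per-block flow features must be sliced to a batch of one as well. Illustrative shapes only:

```python
import torch

block_flow = torch.randn(2, 128, 30, 45)  # hypothetical video_flow_features[i], CFG-doubled
hidden = torch.randn(1, 17550, 3072)      # hypothetical hidden_states[:1] on a cached step
assert block_flow[:1].shape[0] == hidden.shape[0]  # batch dims now agree
```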
nodes.py (21 lines changed)
@@ -1480,6 +1480,8 @@ class CogVideoXFunSampler:
                "opt_empty_latent": ("LATENT",),
                "noise_aug_strength": ("FLOAT", {"default": 0.0563, "min": 0.0, "max": 1.0, "step": 0.001}),
                "context_options": ("COGCONTEXT", ),
+                "tora_trajectory": ("TORAFEATURES", ),
+                "fastercache": ("FASTERCACHEARGS",),
            },
        }

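For readers less familiar with ComfyUI node definitions: the two new entries above declare optional input sockets keyed by custom type tags. A trimmed sketch of the surrounding structure (the required section and other entries are omitted here):

```python
# Abbreviated; only the optional section around the new sockets is shown.
input_types = {
    "optional": {
        "context_options": ("COGCONTEXT",),
        "tora_trajectory": ("TORAFEATURES",),  # flow features plus fuser modules
        "fastercache": ("FASTERCACHEARGS",),   # cache schedule and cache device
    },
}
```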
@@ -1489,7 +1491,7 @@ class CogVideoXFunSampler:
    CATEGORY = "CogVideoWrapper"

    def process(self, pipeline, positive, negative, video_length, base_resolution, seed, steps, cfg, scheduler,
-                start_img=None, end_img=None, opt_empty_latent=None, noise_aug_strength=0.0563, context_options=None):
+                start_img=None, end_img=None, opt_empty_latent=None, noise_aug_strength=0.0563, context_options=None, fastercache=None, tora_trajectory=None):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        pipe = pipeline["pipe"]
@@ -1538,6 +1540,20 @@ class CogVideoXFunSampler:
        else:
            context_frames, context_stride, context_overlap = None, None, None

+        if tora_trajectory is not None:
+            pipe.transformer.fuser_list = tora_trajectory["fuser_list"]
+
+        if fastercache is not None:
+            pipe.transformer.use_fastercache = True
+            pipe.transformer.fastercache_counter = 0
+            pipe.transformer.fastercache_start_step = fastercache["start_step"]
+            pipe.transformer.fastercache_lf_step = fastercache["lf_step"]
+            pipe.transformer.fastercache_hf_step = fastercache["hf_step"]
+            pipe.transformer.fastercache_device = fastercache["cache_device"]
+        else:
+            pipe.transformer.use_fastercache = False
+            pipe.transformer.fastercache_counter = 0
+
        generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)

        autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
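The block above copies five keys from the optional `fastercache` input onto the transformer, and resets the counter when the input is absent. A sketch of the expected payload; the keys come from the diff, the values are examples (the real defaults belong to whichever node emits `FASTERCACHEARGS`):

```python
fastercache = {
    "start_step": 15,    # step at which attention outputs begin to be cached
    "lf_step": 40,       # delta_lf keeps being amplified while counter <= lf_step
    "hf_step": 30,       # delta_hf is amplified once counter >= hf_step
    "cache_device": "cuda",
}
```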
@@ -1564,7 +1580,8 @@ class CogVideoXFunSampler:
                context_frames=context_frames,
                context_stride= context_stride,
                context_overlap= context_overlap,
-                freenoise=context_options["freenoise"] if context_options is not None else None
+                freenoise=context_options["freenoise"] if context_options is not None else None,
+                tora=tora_trajectory if tora_trajectory is not None else None,
            )
            #if not pipeline["cpu_offloading"]:
            #    pipe.transformer.to(offload_device)
@@ -287,11 +287,21 @@ class MGF(nn.Module):
        gamma_flow = self.flow_gamma_spatial(flow)
        beta_flow = self.flow_beta_spatial(flow)
        _, _, hh, wh = beta_flow.shape
-        gamma_flow = rearrange(gamma_flow, "(b f) c h w -> (b h w) c f", f=T)
-        beta_flow = rearrange(beta_flow, "(b f) c h w -> (b h w) c f", f=T)
-        gamma_flow = self.flow_gamma_temporal(gamma_flow)
-        beta_flow = self.flow_beta_temporal(beta_flow)
-        gamma_flow = rearrange(gamma_flow, "(b h w) c f -> (b f) c h w", h=hh, w=wh)
-        beta_flow = rearrange(beta_flow, "(b h w) c f -> (b f) c h w", h=hh, w=wh)
+
+        if gamma_flow.shape[0] == 1: # Check if batch size is 1
+            gamma_flow = rearrange(gamma_flow, "b c h w -> b c (h w)")
+            beta_flow = rearrange(beta_flow, "b c h w -> b c (h w)")
+            gamma_flow = self.flow_gamma_temporal(gamma_flow)
+            beta_flow = self.flow_beta_temporal(beta_flow)
+            gamma_flow = rearrange(gamma_flow, "b c (h w) -> b c h w", h=hh, w=wh)
+            beta_flow = rearrange(beta_flow, "b c (h w) -> b c h w", h=hh, w=wh)
+        else:
+            gamma_flow = rearrange(gamma_flow, "(b f) c h w -> (b h w) c f", f=T)
+            beta_flow = rearrange(beta_flow, "(b f) c h w -> (b h w) c f", f=T)
+            gamma_flow = self.flow_gamma_temporal(gamma_flow)
+            beta_flow = self.flow_beta_temporal(beta_flow)
+            gamma_flow = rearrange(gamma_flow, "(b h w) c f -> (b f) c h w", h=hh, w=wh)
+            beta_flow = rearrange(beta_flow, "(b h w) c f -> (b f) c h w", h=hh, w=wh)
+
        h = h + self.flow_cond_norm(h) * gamma_flow + beta_flow
        return h
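The new `MGF.forward` branch exists because the `(b f) c h w -> (b h w) c f` rearrange degenerates when the folded batch holds a single map: with `b * f == 1` there is no frame axis to unfold, so the temporal convolutions are instead run over flattened pixels. A toy comparison of the two layouts (sizes invented):

```python
import torch
from einops import rearrange

f, c, h, w = 4, 8, 6, 9
x = torch.randn(f, c, h, w)               # "(b f) c h w" with b = 1

# General path: fold spatial positions into batch, leaving a length-f sequence.
seq = rearrange(x, "(b f) c h w -> (b h w) c f", f=f)
print(seq.shape)                           # torch.Size([54, 8, 4])

# Batch-size-1 path from the diff: treat the pixels themselves as the sequence.
single = torch.randn(1, c, h, w)
seq1 = rearrange(single, "b c h w -> b c (h w)")
print(seq1.shape)                          # torch.Size([1, 8, 54])
```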