mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-08 20:34:23 +08:00

commit 1c3aff9000
parent 643bbc18c1

some cleanup
@@ -60,7 +60,7 @@ class CogVideoLoraSelect:
         cog_loras_list.append(cog_lora)
-        print(cog_loras_list)
         return (cog_loras_list,)
+
 #region DownloadAndLoadCogVideoModel
 class DownloadAndLoadCogVideoModel:
     @classmethod
     def INPUT_TYPES(s):
@@ -259,12 +259,9 @@ class DownloadAndLoadCogVideoModel:
         if fuse:
             pipe.fuse_lora(lora_scale=1 / lora_rank, components=["transformer"])
-
-

         if enable_sequential_cpu_offload:
             pipe.enable_sequential_cpu_offload()
-

         # compilation
         if compile == "torch":
             pipe.transformer.to(memory_format=torch.channels_last)
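As a note on the fuse/offload path above: both calls are standard diffusers pipeline API. A minimal sketch of the same sequence, assuming a stock CogVideoXPipeline and a hypothetical LoRA path; `lora_rank` is an assumption that must match the trained adapter:

```python
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)

# Hypothetical LoRA file; the node wires this in from its lora input.
pipe.load_lora_weights("loras/example_cogvideox_lora.safetensors", adapter_name="example")

lora_rank = 128  # assumption: the rank the adapter was trained with
# Same convention as the hunk above: scale fused weights by 1/rank,
# fusing into the transformer component only.
pipe.fuse_lora(lora_scale=1 / lora_rank, components=["transformer"])

# Sequential offload moves submodules to GPU one at a time: lowest VRAM, slowest.
pipe.enable_sequential_cpu_offload()
```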
@@ -277,8 +274,6 @@ class DownloadAndLoadCogVideoModel:
             for i, block in enumerate(pipe.transformer.transformer_blocks):
                 if "CogVideoXBlock" in str(block):
                     pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
-
-

         elif compile == "onediff":
             from onediffx import compile_pipe
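The "torch" branch compiles each transformer block individually rather than calling torch.compile on the whole transformer, so graph breaks stay local to one block. A sketch of the pattern, assuming any module exposing a `transformer_blocks` ModuleList:

```python
import torch

def compile_cogvideox_blocks(transformer: torch.nn.Module) -> None:
    """Compile eligible blocks in place, mirroring the hunk above."""
    transformer.to(memory_format=torch.channels_last)
    for i, block in enumerate(transformer.transformer_blocks):
        # Only CogVideoXBlock instances are swapped for compiled versions.
        if "CogVideoXBlock" in str(block):
            transformer.transformer_blocks[i] = torch.compile(
                block, fullgraph=False, dynamic=False, backend="inductor"
            )
```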
@@ -303,7 +298,7 @@ class DownloadAndLoadCogVideoModel:
         }
-
         return (pipeline,)
+
 #region GGUF
 class DownloadAndLoadCogVideoGGUFModel:
     @classmethod
     def INPUT_TYPES(s):
@@ -483,7 +478,7 @@ class DownloadAndLoadCogVideoGGUFModel:
         }
-
         return (pipeline,)
+
 #region Tora
 class DownloadAndLoadToraModel:
     @classmethod
     def INPUT_TYPES(s):
@@ -591,7 +586,7 @@ class DownloadAndLoadToraModel:
         }
-
         return (toramodel,)
+
 #region controlnet
 class DownloadAndLoadCogVideoControlNet:
     @classmethod
     def INPUT_TYPES(s):
nodes.py (7 changed lines)
@@ -816,7 +816,12 @@ class CogVideoSampler:
         base_path = pipeline["base_path"]

         assert "fun" not in base_path.lower(), "'Fun' models not supported in 'CogVideoSampler', use the 'CogVideoXFunSampler'"
-        assert ("I2V" not in pipeline.get("model_name","") or num_frames == 49 or context_options is not None), "I2V model can only do 49 frames"
+        assert (
+            "I2V" not in pipeline.get("model_name", "") or
+            "1.5" in pipeline.get("model_name", "") or
+            num_frames == 49 or
+            context_options is not None
+        ), "1.0 I2V model can only do 49 frames"

         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
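The widened assert encodes one rule: only 1.0-generation I2V checkpoints are pinned to 49 frames, while 1.5 models and context-windowed sampling are exempt. The same predicate as a standalone check (a sketch; the model-name substrings follow the hunk, the example names are assumptions):

```python
def check_i2v_frames(model_name: str, num_frames: int, context_options=None) -> None:
    if ("I2V" in model_name
            and "1.5" not in model_name
            and num_frames != 49
            and context_options is None):
        raise ValueError("1.0 I2V model can only do 49 frames")

check_i2v_frames("CogVideoX-5b-I2V", 49)      # ok: 1.0 I2V at exactly 49 frames
check_i2v_frames("CogVideoX1.5-5B-I2V", 81)   # ok: 1.5 models are exempt
```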
@@ -317,18 +317,17 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         width: int,
         num_frames: int,
         device: torch.device,
-        start_frame: int = None,
-        end_frame: int = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

         p = self.transformer.config.patch_size
         p_t = self.transformer.config.patch_size_t or 1

         base_size_width = self.transformer.config.sample_width // p
         base_size_height = self.transformer.config.sample_height // p
         base_num_frames = (num_frames + p_t - 1) // p_t
-
         grid_crops_coords = get_resize_crop_region_for_grid(
             (grid_height, grid_width), base_size_width, base_size_height
         )
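For a concrete feel of the grid arithmetic above, here is the same computation with plugged-in numbers (the dimensions are illustrative assumptions, not values from the commit):

```python
vae_scale_factor_spatial, p, p_t = 8, 2, 2   # assumed config values
height, width, num_frames = 480, 720, 13     # num_frames counts latent frames

grid_height = height // (vae_scale_factor_spatial * p)   # 480 // 16 = 30 patch rows
grid_width = width // (vae_scale_factor_spatial * p)     # 720 // 16 = 45 patch cols
base_num_frames = (num_frames + p_t - 1) // p_t          # ceil(13 / 2) = 7 temporal positions
```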
@@ -336,19 +335,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
             embed_dim=self.transformer.config.attention_head_dim,
             crops_coords=grid_crops_coords,
             grid_size=(grid_height, grid_width),
-            temporal_size=base_num_frames,
-            use_real=True,
+            temporal_size=base_num_frames
         )

-        if start_frame is not None:
-            freqs_cos = freqs_cos.view(num_frames, grid_height * grid_width, -1)
-            freqs_sin = freqs_sin.view(num_frames, grid_height * grid_width, -1)
-
-            freqs_cos = freqs_cos[start_frame:end_frame]
-            freqs_sin = freqs_sin[start_frame:end_frame]
-
-            freqs_cos = freqs_cos.view(-1, freqs_cos.shape[-1])
-            freqs_sin = freqs_sin.view(-1, freqs_sin.shape[-1])
-
         freqs_cos = freqs_cos.to(device=device)
         freqs_sin = freqs_sin.to(device=device)
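The deleted branch implemented a frame-window variant of RoPE: reshape the flat (frames·tokens, dim) tables to (frames, tokens, dim), slice a contiguous frame range, and flatten back. A generic sketch of that reshape-slice-flatten trick, with assumed sizes:

```python
import torch

num_frames, tokens, dim = 13, 30 * 45, 64          # assumed: one 30x45 patch grid per frame
freqs_cos = torch.randn(num_frames * tokens, dim)  # flat table, one row per patch token

start_frame, end_frame = 4, 9                      # keep a 5-frame window
windowed = (
    freqs_cos.view(num_frames, tokens, -1)         # (frames, tokens, dim)
    [start_frame:end_frame]                        # slice the frame axis
    .reshape(-1, dim)                              # back to (5 * tokens, dim)
)
assert windowed.shape == (5 * tokens, dim)
```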
@@ -535,13 +523,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
-
-        # 6.5. Create rotary embeds if required
-        image_rotary_emb = (
-            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
-            if self.transformer.config.use_rotary_positional_embeddings
-            else None
-        )

         # masks
         if self.original_mask is not None:
             mask = self.original_mask.to(device)
@@ -579,7 +560,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
             use_temporal_tiling = False
             use_context_schedule = False
             logger.info("Temporal tiling and context schedule disabled")
-        # 7. Create rotary embeds if required
+        # 8.5. Create rotary embeds if required
         image_rotary_emb = (
             self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
             if self.transformer.config.use_rotary_positional_embeddings
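Both rotary hunks use the same construction idiom: build the (cos, sin) tables once per call, keyed on the latent frame count, and only when the transformer config enables RoPE. A condensed sketch, where `make_rope_tables` is a hypothetical stand-in for the pipeline's private `_prepare_rotary_positional_embeddings`:

```python
import torch

def make_rope_tables(height: int, width: int, num_frames: int):
    # Hypothetical stub; the real helper returns per-token cos/sin frequency tables.
    return torch.zeros(num_frames, 64), torch.zeros(num_frames, 64)

latents = torch.randn(1, 13, 16, 60, 90)  # assumed (batch, frames, channels, h, w)
use_rope = True                           # transformer.config.use_rotary_positional_embeddings

image_rotary_emb = (
    make_rope_tables(480, 720, latents.size(1)) if use_rope else None
)
```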
@@ -882,7 +863,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                     controlnet_states=controlnet_states,
                     controlnet_weights=control_weights,
                     video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
-
                 )[0]
                 noise_pred = noise_pred.float()
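The Tora gate in the call above injects flow features only while the current step falls inside a [start_percent, end_percent] window of the denoising schedule. A self-contained sketch of that gating (step count and window are assumptions):

```python
video_flow_features = object()  # stand-in for precomputed Tora flow features
tora = {"start_percent": 0.0, "end_percent": 0.6}  # inject during first 60% of steps
num_steps = 50

for i in range(num_steps):
    current_step_percentage = i / num_steps
    in_window = tora["start_percent"] <= current_step_percentage <= tora["end_percent"]
    # `features` would be passed to the transformer as video_flow_features
    features = video_flow_features if (tora is not None and in_window) else None
```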