some cleanup

This commit is contained in:
kijai 2024-11-09 15:15:10 +02:00
parent 643bbc18c1
commit 1c3aff9000
3 changed files with 14 additions and 34 deletions

View File

@ -60,7 +60,7 @@ class CogVideoLoraSelect:
cog_loras_list.append(cog_lora)
print(cog_loras_list)
return (cog_loras_list,)
#region DownloadAndLoadCogVideoModel
class DownloadAndLoadCogVideoModel:
@classmethod
def INPUT_TYPES(s):
@ -259,12 +259,9 @@ class DownloadAndLoadCogVideoModel:
if fuse:
pipe.fuse_lora(lora_scale=1 / lora_rank, components=["transformer"])
if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()
# compilation
if compile == "torch":
pipe.transformer.to(memory_format=torch.channels_last)
@ -277,8 +274,6 @@ class DownloadAndLoadCogVideoModel:
for i, block in enumerate(pipe.transformer.transformer_blocks):
if "CogVideoXBlock" in str(block):
pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
elif compile == "onediff":
from onediffx import compile_pipe
@ -303,7 +298,7 @@ class DownloadAndLoadCogVideoModel:
}
return (pipeline,)
#region GGUF
class DownloadAndLoadCogVideoGGUFModel:
@classmethod
def INPUT_TYPES(s):
@ -483,7 +478,7 @@ class DownloadAndLoadCogVideoGGUFModel:
}
return (pipeline,)
#region Tora
class DownloadAndLoadToraModel:
@classmethod
def INPUT_TYPES(s):
@ -591,7 +586,7 @@ class DownloadAndLoadToraModel:
}
return (toramodel,)
#region controlnet
class DownloadAndLoadCogVideoControlNet:
@classmethod
def INPUT_TYPES(s):

View File

@ -816,7 +816,12 @@ class CogVideoSampler:
base_path = pipeline["base_path"]
assert "fun" not in base_path.lower(), "'Fun' models not supported in 'CogVideoSampler', use the 'CogVideoXFunSampler'"
assert ("I2V" not in pipeline.get("model_name","") or num_frames == 49 or context_options is not None), "I2V model can only do 49 frames"
assert (
"I2V" not in pipeline.get("model_name", "") or
"1.5" in pipeline.get("model_name", "") or
num_frames == 49 or
context_options is not None
), "1.0 I2V model can only do 49 frames"
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()

View File

@ -317,18 +317,17 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
width: int,
num_frames: int,
device: torch.device,
start_frame: int = None,
end_frame: int = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
p = self.transformer.config.patch_size
p_t = self.transformer.config.patch_size_t or 1
base_size_width = self.transformer.config.sample_width // p
base_size_height = self.transformer.config.sample_height // p
base_num_frames = (num_frames + p_t - 1) // p_t
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
@ -336,19 +335,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=base_num_frames,
use_real=True,
temporal_size=base_num_frames
)
if start_frame is not None:
freqs_cos = freqs_cos.view(num_frames, grid_height * grid_width, -1)
freqs_sin = freqs_sin.view(num_frames, grid_height * grid_width, -1)
freqs_cos = freqs_cos[start_frame:end_frame]
freqs_sin = freqs_sin[start_frame:end_frame]
freqs_cos = freqs_cos.view(-1, freqs_cos.shape[-1])
freqs_sin = freqs_sin.view(-1, freqs_sin.shape[-1])
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
@ -535,13 +523,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 6.5. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
# masks
if self.original_mask is not None:
mask = self.original_mask.to(device)
@ -579,7 +560,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
use_temporal_tiling = False
use_context_schedule = False
logger.info("Temporal tiling and context schedule disabled")
# 7. Create rotary embeds if required
# 8.5. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
@ -882,7 +863,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
controlnet_states=controlnet_states,
controlnet_weights=control_weights,
video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
)[0]
noise_pred = noise_pred.float()