some cleanup

kijai 2024-11-09 15:15:10 +02:00
parent 643bbc18c1
commit 1c3aff9000
3 changed files with 14 additions and 34 deletions

View File

@@ -60,7 +60,7 @@ class CogVideoLoraSelect:
             cog_loras_list.append(cog_lora)
         print(cog_loras_list)
         return (cog_loras_list,)
+#region DownloadAndLoadCogVideoModel
 class DownloadAndLoadCogVideoModel:
     @classmethod
     def INPUT_TYPES(s):
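Note: the `#region` comments added in this file are editor folding markers (recognized by editors such as VS Code's Python extension); they have no runtime effect. A minimal sketch of the pattern, with a hypothetical class name:

```python
#region ExampleLoader          # folds in editors that support region markers
class ExampleLoader:           # hypothetical class, mirroring the loader nodes
    pass
#endregion                     # conventional closing marker
```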
@@ -259,12 +259,9 @@ class DownloadAndLoadCogVideoModel:
         if fuse:
             pipe.fuse_lora(lora_scale=1 / lora_rank, components=["transformer"])
         if enable_sequential_cpu_offload:
             pipe.enable_sequential_cpu_offload()
         # compilation
         if compile == "torch":
             pipe.transformer.to(memory_format=torch.channels_last)
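For context, this is the fuse-then-offload flow the loader node uses. A minimal sketch of the same steps outside ComfyUI, assuming a standard diffusers `CogVideoXPipeline`; the LoRA repo id and rank are illustrative, and the `1 / lora_rank` scale mirrors the node code above:

```python
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16  # assumed checkpoint
)
pipe.load_lora_weights("some/cogvideox-lora", adapter_name="example")  # hypothetical LoRA
lora_rank = 128  # assumed rank of that LoRA, matching the node's 1/rank scale
pipe.fuse_lora(lora_scale=1 / lora_rank, components=["transformer"])

# Sequential CPU offload moves submodules to the GPU one at a time during
# the forward pass: large VRAM savings at the cost of slower inference.
pipe.enable_sequential_cpu_offload()
```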
@@ -278,8 +275,6 @@ class DownloadAndLoadCogVideoModel:
             if "CogVideoXBlock" in str(block):
                 pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
         elif compile == "onediff":
             from onediffx import compile_pipe
             os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
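The `torch` branch above compiles each `CogVideoXBlock` individually instead of wrapping the whole transformer in one `torch.compile` call; the usual motivation is that a recompilation triggered by a shape or graph change then stays local to a single block. A sketch of the same pattern, assuming `pipe` as in the previous example:

```python
import torch

# Compile transformer blocks one by one, as in the hunk above.
pipe.transformer.to(memory_format=torch.channels_last)
for i, block in enumerate(pipe.transformer.transformer_blocks):
    if "CogVideoXBlock" in str(block):
        pipe.transformer.transformer_blocks[i] = torch.compile(
            block, fullgraph=False, dynamic=False, backend="inductor"
        )
```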
@@ -303,7 +298,7 @@ class DownloadAndLoadCogVideoModel:
         }
         return (pipeline,)
+#region GGUF
 class DownloadAndLoadCogVideoGGUFModel:
     @classmethod
     def INPUT_TYPES(s):
@@ -483,7 +478,7 @@ class DownloadAndLoadCogVideoGGUFModel:
         }
         return (pipeline,)
+#region Tora
 class DownloadAndLoadToraModel:
     @classmethod
     def INPUT_TYPES(s):
@@ -591,7 +586,7 @@ class DownloadAndLoadToraModel:
         }
         return (toramodel,)
+#region controlnet
 class DownloadAndLoadCogVideoControlNet:
     @classmethod
     def INPUT_TYPES(s):

View File

@@ -816,7 +816,12 @@ class CogVideoSampler:
         base_path = pipeline["base_path"]
         assert "fun" not in base_path.lower(), "'Fun' models not supported in 'CogVideoSampler', use the 'CogVideoXFunSampler'"
-        assert ("I2V" not in pipeline.get("model_name","") or num_frames == 49 or context_options is not None), "I2V model can only do 49 frames"
+        assert (
+            "I2V" not in pipeline.get("model_name", "") or
+            "1.5" in pipeline.get("model_name", "") or
+            num_frames == 49 or
+            context_options is not None
+        ), "1.0 I2V model can only do 49 frames"
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
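The relaxed assert encodes the model constraint: the 1.0 I2V checkpoints are pinned to 49 frames, while 1.5 models and context-window scheduling (`context_options`) allow other lengths. The same check as a standalone sketch; the helper name and dict layout are illustrative:

```python
def check_frame_count(pipeline: dict, num_frames: int, context_options=None) -> None:
    """Hypothetical helper mirroring the assert in CogVideoSampler above."""
    model_name = pipeline.get("model_name", "")
    if ("I2V" in model_name and "1.5" not in model_name
            and num_frames != 49 and context_options is None):
        raise ValueError("1.0 I2V model can only do 49 frames")

# e.g. check_frame_count({"model_name": "CogVideoX-5b-I2V"}, num_frames=65) raises
```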

View File

@@ -317,11 +317,10 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         width: int,
         num_frames: int,
         device: torch.device,
-        start_frame: int = None,
-        end_frame: int = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
         p = self.transformer.config.patch_size
         p_t = self.transformer.config.patch_size_t or 1
@@ -336,20 +335,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
             embed_dim=self.transformer.config.attention_head_dim,
             crops_coords=grid_crops_coords,
             grid_size=(grid_height, grid_width),
-            temporal_size=base_num_frames,
-            use_real=True,
+            temporal_size=base_num_frames
         )
-        if start_frame is not None:
-            freqs_cos = freqs_cos.view(num_frames, grid_height * grid_width, -1)
-            freqs_sin = freqs_sin.view(num_frames, grid_height * grid_width, -1)
-            freqs_cos = freqs_cos[start_frame:end_frame]
-            freqs_sin = freqs_sin[start_frame:end_frame]
-            freqs_cos = freqs_cos.view(-1, freqs_cos.shape[-1])
-            freqs_sin = freqs_sin.view(-1, freqs_sin.shape[-1])
         freqs_cos = freqs_cos.to(device=device)
         freqs_sin = freqs_sin.to(device=device)
         return freqs_cos, freqs_sin
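For reference, the grid fed to `get_3d_rotary_pos_embed` is the pixel resolution divided by the VAE spatial scale and the transformer patch size, and the function returns `(cos, sin)` tables when `use_real` is left at its default of `True`, which is why the explicit argument could be dropped above. A standalone sketch with typical CogVideoX values (spatial scale 8, patch size 2, head dim 64; all assumed, not read from a config):

```python
from diffusers.models.embeddings import get_3d_rotary_pos_embed

height, width = 480, 720                       # pixel resolution
latent_frames = 13                             # e.g. (49 - 1) // 4 + 1 after temporal VAE compression
vae_scale_factor_spatial, patch_size = 8, 2    # typical CogVideoX values

grid_height = height // (vae_scale_factor_spatial * patch_size)  # 30
grid_width = width // (vae_scale_factor_spatial * patch_size)    # 45

freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
    embed_dim=64,                                      # attention_head_dim
    crops_coords=((0, 0), (grid_height, grid_width)),  # full-frame crop
    grid_size=(grid_height, grid_width),
    temporal_size=latent_frames,                       # use_real=True by default
)
```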
@@ -535,13 +523,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
-        # 6.5. Create rotary embeds if required
-        image_rotary_emb = (
-            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
-            if self.transformer.config.use_rotary_positional_embeddings
-            else None
-        )
         # masks
         if self.original_mask is not None:
             mask = self.original_mask.to(device)
@@ -579,7 +560,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 use_temporal_tiling = False
                 use_context_schedule = False
                 logger.info("Temporal tiling and context schedule disabled")
-        # 7. Create rotary embeds if required
+        # 8.5. Create rotary embeds if required
         image_rotary_emb = (
             self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
             if self.transformer.config.use_rotary_positional_embeddings
@@ -882,7 +863,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                     controlnet_states=controlnet_states,
                     controlnet_weights=control_weights,
                     video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
                 )[0]
                 noise_pred = noise_pred.float()
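The `video_flow_features` argument in the transformer call above is gated by denoising progress: Tora's trajectory features are injected only while the current step falls inside `[start_percent, end_percent]`. A standalone sketch of that gating; the window values and the feature tensor are placeholders:

```python
import torch

tora = {"start_percent": 0.0, "end_percent": 0.8}   # example window
video_flow_features = torch.zeros(1)                # placeholder features
num_inference_steps = 50

for step in range(num_inference_steps):
    current_step_percentage = step / num_inference_steps
    flow = (
        video_flow_features
        if tora is not None
        and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]
        else None
    )
    # `flow` is what would be passed as video_flow_features to the transformer
```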