This commit is contained in:
kijai 2024-08-27 22:43:57 +03:00
parent 44a8305fcb
commit 1e356fa905
2 changed files with 13 additions and 17 deletions

View File

@@ -59,7 +59,7 @@ class DownloadAndLoadCogVideoModel:
snapshot_download(
repo_id=model,
ignore_patterns=["*text_encoder*"],
ignore_patterns=["*text_encoder*", "*tokenizer*"],
local_dir=base_path,
local_dir_use_symlinks=False,
)

View File

@@ -285,16 +285,16 @@ class CogVideoXPipeline(DiffusionPipeline):
temporal_size=num_frames,
use_real=True,
)
freqs_cos = freqs_cos.view(num_frames, grid_height * grid_width, -1)
freqs_sin = freqs_sin.view(num_frames, grid_height * grid_width, -1)
if start_frame is not None:
freqs_cos = freqs_cos.view(num_frames, grid_height * grid_width, -1)
freqs_sin = freqs_sin.view(num_frames, grid_height * grid_width, -1)
freqs_cos = freqs_cos[start_frame:end_frame]
freqs_sin = freqs_sin[start_frame:end_frame]
freqs_cos = freqs_cos.view(-1, freqs_cos.shape[-1])
freqs_sin = freqs_sin.view(-1, freqs_sin.shape[-1])
freqs_cos = freqs_cos.view(-1, freqs_cos.shape[-1])
freqs_sin = freqs_sin.view(-1, freqs_sin.shape[-1])
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
@@ -444,12 +444,12 @@ class CogVideoXPipeline(DiffusionPipeline):
print("latents.device", latents.device)
# # 6.5. Create rotary embeds if required
# image_rotary_emb = (
# self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
# if self.transformer.config.use_rotary_positional_embeddings
# else None
# )
# 6.5. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -541,11 +541,7 @@ class CogVideoXPipeline(DiffusionPipeline):
comfy_pbar.update(1)
# ==========================================
else:
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)