padding fix

kijai 2024-11-11 17:29:57 +02:00
parent 43bc7fb4fc
commit 5f1a917b93
3 changed files with 23 additions and 9 deletions

View File

@@ -573,10 +573,10 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         p_t = self.config.patch_size_t
         # We know that the hidden states height and width will always be divisible by patch_size.
         # But, the number of frames may not be divisible by patch_size_t. So, we pad with the beginning frames.
-        if p_t is not None:
-            remaining_frames = 0 if num_frames % 2 == 0 else 1
-            first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1)
-            hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1)
+        # if p_t is not None:
+        #     remaining_frames = 0 if num_frames % 2 == 0 else 1
+        #     first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1)
+        #     hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1)
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
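For reference, a minimal standalone sketch of what the now commented-out padding did: when the latent frame count is odd, frame 0 is repeated once so the temporal length becomes divisible by patch_size_t. The tensor sizes below are made up for illustration, not taken from the repo.

    import torch

    hidden_states = torch.randn(1, 13, 16, 60, 90)   # (B, T, C, H, W), T odd
    num_frames = hidden_states.shape[1]

    remaining_frames = 0 if num_frames % 2 == 0 else 1           # 13 is odd -> pad one frame
    first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1)
    hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1)
    print(hidden_states.shape)                                    # torch.Size([1, 14, 16, 60, 90])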
@@ -711,7 +711,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                 batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
             )
             output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
-            output = output[:, remaining_frames:]
+            #output = output[:, remaining_frames:]
         if self.fastercache_counter >= self.fastercache_start_step + 1:
             (bb, tt, cc, hh, ww) = output.shape
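The unpatchify step above folds the patch tokens back into a (B, T, C, H, W) video tensor; with the slicing commented out here, trimming the extra leading frames is deferred to decode time. A rough shape check of the reshape/permute, with sizes that are assumed rather than taken from the repo:

    import torch

    B, T, C, H, W = 1, 14, 16, 60, 90            # T already a multiple of p_t
    p, p_t = 2, 2
    hidden_states = torch.randn(B, (T // p_t) * (H // p) * (W // p), C * p_t * p * p)

    output = hidden_states.reshape(B, (T + p_t - 1) // p_t, H // p, W // p, -1, p_t, p, p)
    output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
    print(output.shape)                          # torch.Size([1, 14, 16, 60, 90])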

View File

@@ -1004,7 +1004,7 @@ class CogVideoDecode:
            vae._clear_fake_context_parallel_cache()
        except:
            pass
-        frames = vae.decode(latents).sample
+        frames = vae.decode(latents[:, :, pipeline["pipe"].additional_frames:]).sample
        vae.disable_tiling()
        if not pipeline["cpu_offloading"]:
            vae.to(offload_device)
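Here the decode node drops the latent frames that were only generated to satisfy patch_size_t alignment before handing the latents to the VAE. A small sketch of the slicing (additional_frames is set by the pipeline; the sizes are illustrative):

    import torch

    additional_frames = 1
    latents = torch.randn(1, 16, 14, 60, 90)      # (B, C, T_latent, H, W)
    latents = latents[:, :, additional_frames:]   # drop the alignment frames
    print(latents.shape)                          # torch.Size([1, 16, 13, 60, 90])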

View File

@@ -434,6 +434,8 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
         width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
         num_videos_per_prompt = 1
+        self.num_frames = num_frames
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             height,
@@ -463,6 +465,14 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
         # 5. Prepare latents.
         latent_channels = self.vae.config.latent_channels
         latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+        # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
+        patch_size_t = self.transformer.config.patch_size_t
+        self.additional_frames = 0
+        if patch_size_t is not None and latent_frames % patch_size_t != 0:
+            self.additional_frames = patch_size_t - latent_frames % patch_size_t
+            num_frames += self.additional_frames * self.vae_scale_factor_temporal
         #if latents is None and num_frames == t_tile_length:
         #    num_frames += 1
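Worked through with assumed values (49 requested frames, temporal VAE factor 4, patch_size_t of 2 for CogVideoX 1.5), the new padding amounts to one extra latent frame, i.e. four extra pixel-space frames:

    num_frames = 49
    vae_scale_factor_temporal = 4
    patch_size_t = 2

    latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1   # 13, not divisible by 2
    additional_frames = patch_size_t - latent_frames % patch_size_t     # 1
    num_frames += additional_frames * vae_scale_factor_temporal         # 53 -> 14 latent frames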
@@ -503,8 +513,12 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
                    width // self.vae_scale_factor_spatial,
                )
                latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae.dtype)
-                image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1)
+                # Select the first frame along the second dimension
+                if self.transformer.config.patch_size_t is not None:
+                    first_frame = image_cond_latents[:, : image_latents.size(1) % self.transformer.config.patch_size_t, ...]
+                    image_cond_latents = torch.cat([first_frame, image_latents], dim=1)
+                logger.info(f"image cond latents shape: {image_cond_latents.shape}")
            else:
                logger.info("Only one image conditioning frame received, img2vid")
@@ -597,8 +611,8 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
        if tora is not None:
            trajectory_length = tora["video_flow_features"].shape[1]
            logger.info(f"Tora trajectory length: {trajectory_length}")
-            if trajectory_length != latents.shape[1]:
-                raise ValueError(f"Tora trajectory length {trajectory_length} does not match inpaint_latents count {latents.shape[2]}")
+            #if trajectory_length != latents.shape[1]:
+            #    raise ValueError(f"Tora trajectory length {trajectory_length} does not match inpaint_latents count {latents.shape[2]}")
            for module in self.transformer.fuser_list:
                for param in module.parameters():
                    param.data = param.data.to(device)