update from upstream, ofs embeds

This commit is contained in:
kijai 2024-11-11 18:53:12 +02:00
parent 5f1a917b93
commit 6931576916
3 changed files with 26 additions and 16 deletions

View File

@ -425,9 +425,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
self.ofs_proj = None
self.ofs_embedding = None
if ofs_embed_dim:
self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
self.ofs_embedding = TimestepEmbedding(ofs_embed_dim, ofs_embed_dim, timestep_activation_fn) # same as time embeddings, for ofs
# 3. Define spatio-temporal transformers blocks
@ -547,6 +549,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
encoder_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor],
timestep_cond: Optional[torch.Tensor] = None,
ofs: Optional[Union[int, float, torch.LongTensor]] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
controlnet_states: torch.Tensor = None,
controlnet_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
@ -563,26 +566,21 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
if self.ofs_embedding is not None: #1.5 I2V
emb_ofs = self.ofs_embedding(emb, timestep_cond)
emb = emb + emb_ofs
ofs_emb = self.ofs_proj(ofs)
ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
ofs_emb = self.ofs_embedding(ofs_emb)
emb = emb + ofs_emb
# 2. Patch embedding
p = self.config.patch_size
p_t = self.config.patch_size_t
# We know that the hidden states height and width will always be divisible by patch_size.
# But, the number of frames may not be divisible by patch_size_t. So, we pad with the beginning frames.
# if p_t is not None:
# remaining_frames = 0 if num_frames % 2 == 0 else 1
# first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1)
# hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1)
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
hidden_states = self.embedding_dropout(hidden_states)
text_seq_length = encoder_hidden_states.shape[1]
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
@ -639,7 +637,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
)
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
output = output[:, remaining_frames:]
(bb, tt, cc, hh, ww) = output.shape
cond = rearrange(output, "B T C H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
@ -711,7 +708,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
)
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
#output = output[:, remaining_frames:]
if self.fastercache_counter >= self.fastercache_start_step + 1:
(bb, tt, cc, hh, ww) = output.shape

View File

@ -346,7 +346,7 @@ class CogVideoImageEncode:
"image": ("IMAGE", ),
},
"optional": {
"chunk_size": ("INT", {"default": 16, "min": 4}),
"chunk_size": ("INT", {"default": 16, "min": 4, "tooltip": "How many images to encode at once, lower values use less memory"}),
"enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
"mask": ("MASK", ),
"noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Augment image with noise"}),
@ -806,6 +806,7 @@ class CogVideoSampler:
"controlnet": ("COGVIDECONTROLNET",),
"tora_trajectory": ("TORAFEATURES", ),
"fastercache": ("FASTERCACHEARGS", ),
#"sigmas": ("SIGMAS", ),
}
}
@ -879,6 +880,9 @@ class CogVideoSampler:
cfg = [cfg for _ in range(steps)]
else:
assert len(cfg) == steps, "Length of cfg list must match number of steps"
# if sigmas is not None:
# sigma_list = sigmas.tolist()
autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
autocast_context = torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocastcondition else nullcontext()
@ -889,6 +893,7 @@ class CogVideoSampler:
width = width,
num_frames = num_frames,
guidance_scale=cfg,
#sigmas=sigma_list if sigmas is not None else None,
latents=samples["samples"] if samples is not None else None,
image_cond_latents=image_cond_latents["samples"] if image_cond_latents is not None else None,
denoise_strength=denoise_strength,

View File

@ -369,6 +369,7 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
denoise_strength: float = 1.0,
sigmas: Optional[List[float]] = None,
num_videos_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@ -429,7 +430,7 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
num_videos_per_prompt = 1
@ -460,7 +461,10 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
prompt_embeds = prompt_embeds.to(self.vae.dtype)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
if sigmas is None:
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
else:
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, sigmas=sigmas, device=device)
self._num_timesteps = len(timesteps)
# 5. Prepare latents.
@ -499,7 +503,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
freenoise=freenoise,
)
latents = latents.to(self.vae.dtype)
#print("latents", latents.shape)
# 5.5.
if image_cond_latents is not None:
@ -579,6 +582,9 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
if self.transformer.config.use_rotary_positional_embeddings
else None
)
# 7.6. Create ofs embeds if required
ofs_emb = None if self.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0)
if tora is not None and do_classifier_free_guidance:
video_flow_features = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()
@ -617,6 +623,8 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
for param in module.parameters():
param.data = param.data.to(device)
logger.info(f"Sampling {num_frames} frames in {latent_frames} latent frames at {width}x{height} with {num_inference_steps} inference steps")
# 9. Denoising loop
comfy_pbar = ProgressBar(len(timesteps))
with self.progress_bar(total=len(timesteps)) as progress_bar:
@ -873,6 +881,7 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
ofs=ofs_emb,
return_dict=False,
controlnet_states=controlnet_states,
controlnet_weights=control_weights,