Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git, synced 2025-12-09 04:44:22 +08:00
commit 6931576916 (parent 5f1a917b93)
update from upstream, ofs embeds
@@ -425,9 +425,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

        self.ofs_proj = None
        self.ofs_embedding = None

        if ofs_embed_dim:
            self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
            self.ofs_embedding = TimestepEmbedding(ofs_embed_dim, ofs_embed_dim, timestep_activation_fn)  # same as time embeddings, for ofs

        # 3. Define spatio-temporal transformers blocks
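The ofs ("offset") path added here mirrors the timestep embedding: a sinusoidal projection (Timesteps) followed by a small MLP (TimestepEmbedding), whose output is later added to the time embedding. A minimal standalone sketch of that flow, assuming diffusers' embedding classes and a hypothetical ofs_embed_dim of 512:

    import torch
    from diffusers.models.embeddings import Timesteps, TimestepEmbedding

    ofs_embed_dim = 512                                  # hypothetical size, for illustration only
    ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
    ofs_embedding = TimestepEmbedding(ofs_embed_dim, ofs_embed_dim, act_fn="silu")

    ofs = torch.full((1,), 2.0)                          # the scalar the 1.5 I2V pipeline passes in
    ofs_emb = ofs_embedding(ofs_proj(ofs))               # sinusoidal projection, then two linear layers
    print(ofs_emb.shape)                                 # torch.Size([1, 512])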
@@ -547,6 +549,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        ofs: Optional[Union[int, float, torch.LongTensor]] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        controlnet_states: torch.Tensor = None,
        controlnet_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
@@ -563,26 +566,21 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        if self.ofs_embedding is not None:  # 1.5 I2V
            emb_ofs = self.ofs_embedding(emb, timestep_cond)
            emb = emb + emb_ofs
            ofs_emb = self.ofs_proj(ofs)
            ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
            ofs_emb = self.ofs_embedding(ofs_emb)
            emb = emb + ofs_emb

        # 2. Patch embedding
        p = self.config.patch_size
        p_t = self.config.patch_size_t
        # We know that the hidden states height and width will always be divisible by patch_size.
        # But, the number of frames may not be divisible by patch_size_t. So, we pad with the beginning frames.
        # if p_t is not None:
        #     remaining_frames = 0 if num_frames % 2 == 0 else 1
        #     first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1)
        #     hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1)

        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
        hidden_states = self.embedding_dropout(hidden_states)

        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]
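In CogVideoX the patch embed prepends the text tokens to the flattened video patch tokens, which is why the joint sequence is split back apart with text_seq_length right after. A rough standalone sketch of that layout (all sizes are assumed, for illustration only):

    import torch

    text_tokens = torch.randn(1, 226, 3072)      # prompt embeddings (assumed T5 length 226)
    video_tokens = torch.randn(1, 17550, 3072)   # patchified latent video tokens (assumed)
    hidden_states = torch.cat([text_tokens, video_tokens], dim=1)

    text_seq_length = text_tokens.shape[1]
    encoder_hidden_states = hidden_states[:, :text_seq_length]   # back to the text part
    hidden_states = hidden_states[:, text_seq_length:]           # back to the video part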
@@ -639,7 +637,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
            batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
        )
        output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
        output = output[:, remaining_frames:]

        (bb, tt, cc, hh, ww) = output.shape
        cond = rearrange(output, "B T C H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
@@ -711,7 +708,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
            batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
        )
        output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
        #output = output[:, remaining_frames:]

        if self.fastercache_counter >= self.fastercache_start_step + 1:
            (bb, tt, cc, hh, ww) = output.shape
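Both hunks above un-patchify the token sequence back into a video tensor: reshape into (B, T', H', W', C, p_t, p, p), permute the patch axes next to their temporal/spatial axes, then flatten back to (B, frames, C, height, width) and drop any padded leading frames. A toy-sized sketch (all sizes assumed) that checks the shape round-trip:

    import torch

    B, C = 1, 16
    num_frames, height, width = 13, 60, 90        # latent-space sizes, assumed for illustration
    p, p_t = 2, 2
    T = (num_frames + p_t - 1) // p_t             # 7 temporal patches after padding to a multiple of p_t

    tokens = torch.randn(B, T * (height // p) * (width // p), C * p_t * p * p)
    out = tokens.reshape(B, T, height // p, width // p, -1, p_t, p, p)
    out = out.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
    print(out.shape)                              # torch.Size([1, 14, 16, 60, 90])
    out = out[:, (T * p_t - num_frames):]         # drop the padded leading frame -> 13 frames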
nodes.py (7 changes)
@@ -346,7 +346,7 @@ class CogVideoImageEncode:
                "image": ("IMAGE", ),
            },
            "optional": {
                "chunk_size": ("INT", {"default": 16, "min": 4}),
                "chunk_size": ("INT", {"default": 16, "min": 4, "tooltip": "How many images to encode at once, lower values use less memory"}),
                "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
                "mask": ("MASK", ),
                "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Augment image with noise"}),
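The new chunk_size tooltip describes the node's batching during VAE encoding. A hedged, hypothetical sketch of the idea (not the node's actual code; the helper name is made up): slice the input so only chunk_size images go through the encoder at once, trading speed for lower peak memory.

    import torch

    def encode_in_chunks(vae, images: torch.Tensor, chunk_size: int = 16) -> torch.Tensor:
        """Hypothetical helper: push `images` (N, C, H, W) through a diffusers-style VAE
        in slices of `chunk_size` so peak memory stays bounded."""
        latents = []
        for start in range(0, images.shape[0], chunk_size):
            chunk = images[start:start + chunk_size]
            latents.append(vae.encode(chunk).latent_dist.sample())
        return torch.cat(latents, dim=0)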
@@ -806,6 +806,7 @@ class CogVideoSampler:
                "controlnet": ("COGVIDECONTROLNET",),
                "tora_trajectory": ("TORAFEATURES", ),
                "fastercache": ("FASTERCACHEARGS", ),
                #"sigmas": ("SIGMAS", ),
            }
        }
@@ -879,6 +880,9 @@ class CogVideoSampler:
            cfg = [cfg for _ in range(steps)]
        else:
            assert len(cfg) == steps, "Length of cfg list must match number of steps"

        # if sigmas is not None:
        #     sigma_list = sigmas.tolist()

        autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
        autocast_context = torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocastcondition else nullcontext()
@@ -889,6 +893,7 @@ class CogVideoSampler:
                width = width,
                num_frames = num_frames,
                guidance_scale=cfg,
                #sigmas=sigma_list if sigmas is not None else None,
                latents=samples["samples"] if samples is not None else None,
                image_cond_latents=image_cond_latents["samples"] if image_cond_latents is not None else None,
                denoise_strength=denoise_strength,
@@ -369,6 +369,7 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
        timesteps: Optional[List[int]] = None,
        guidance_scale: float = 6,
        denoise_strength: float = 1.0,
        sigmas: Optional[List[float]] = None,
        num_videos_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -429,7 +430,7 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
        """

        height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
        width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
        num_videos_per_prompt = 1
@@ -460,7 +461,10 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
        prompt_embeds = prompt_embeds.to(self.vae.dtype)

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        if sigmas is None:
            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        else:
            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, sigmas=sigmas, device=device)
        self._num_timesteps = len(timesteps)

        # 5. Prepare latents.
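The new branch lets a caller hand the scheduler an explicit sigma schedule instead of only a step count. A hedged, standalone sketch of the two paths, assuming the pipeline's retrieve_timesteps behaves like the diffusers helper of the same name in a recent diffusers release (the values are made up):

    from diffusers import CogVideoXDDIMScheduler
    from diffusers.pipelines.cogvideo.pipeline_cogvideox import retrieve_timesteps

    scheduler = CogVideoXDDIMScheduler()

    # default path: derive the timestep grid from a step count
    timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=25, device="cpu")

    # custom path: pass an explicit sigma schedule instead
    # (only valid if scheduler.set_timesteps accepts a `sigmas` argument; otherwise it raises ValueError)
    # timesteps, num_inference_steps = retrieve_timesteps(scheduler, sigmas=[0.9, 0.6, 0.3, 0.0], device="cpu")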
@@ -499,7 +503,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
            freenoise=freenoise,
        )
        latents = latents.to(self.vae.dtype)
        #print("latents", latents.shape)

        # 5.5.
        if image_cond_latents is not None:
@@ -579,6 +582,9 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
            if self.transformer.config.use_rotary_positional_embeddings
            else None
        )
        # 7.6. Create ofs embeds if required
        ofs_emb = None if self.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0)

        if tora is not None and do_classifier_free_guidance:
            video_flow_features = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()
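new_full builds the one-element ofs tensor with the latents' dtype and device; 2.0 appears to be the constant the CogVideoX 1.5 I2V transformer is conditioned with. A tiny sketch with an assumed latent shape:

    import torch

    latents = torch.zeros(1, 13, 16, 60, 90, dtype=torch.bfloat16)   # assumed latent shape, for illustration
    ofs_emb = latents.new_full((1,), fill_value=2.0)
    print(ofs_emb, ofs_emb.dtype)   # tensor([2.], dtype=torch.bfloat16) torch.bfloat16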
@@ -617,6 +623,8 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
                for param in module.parameters():
                    param.data = param.data.to(device)

        logger.info(f"Sampling {num_frames} frames in {latent_frames} latent frames at {width}x{height} with {num_inference_steps} inference steps")

        # 9. Denoising loop
        comfy_pbar = ProgressBar(len(timesteps))
        with self.progress_bar(total=len(timesteps)) as progress_bar:
@@ -873,6 +881,7 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
                    encoder_hidden_states=prompt_embeds,
                    timestep=timestep,
                    image_rotary_emb=image_rotary_emb,
                    ofs=ofs_emb,
                    return_dict=False,
                    controlnet_states=controlnet_states,
                    controlnet_weights=control_weights,