Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2025-12-15 16:04:28 +08:00)

Commit 6931576916 (parent 5f1a917b93): update from upstream, ofs embeds
@@ -425,9 +425,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
         self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
 
+        self.ofs_proj = None
         self.ofs_embedding = None
 
         if ofs_embed_dim:
+            self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
             self.ofs_embedding = TimestepEmbedding(ofs_embed_dim, ofs_embed_dim, timestep_activation_fn)  # same as time embeddings, for ofs
 
         # 3. Define spatio-temporal transformers blocks
@@ -547,6 +549,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         encoder_hidden_states: torch.Tensor,
         timestep: Union[int, float, torch.LongTensor],
         timestep_cond: Optional[torch.Tensor] = None,
+        ofs: Optional[Union[int, float, torch.LongTensor]] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         controlnet_states: torch.Tensor = None,
         controlnet_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
@@ -563,26 +566,21 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=hidden_states.dtype)
 
         emb = self.time_embedding(t_emb, timestep_cond)
         if self.ofs_embedding is not None: #1.5 I2V
-            emb_ofs = self.ofs_embedding(emb, timestep_cond)
-            emb = emb + emb_ofs
+            ofs_emb = self.ofs_proj(ofs)
+            ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
+            ofs_emb = self.ofs_embedding(ofs_emb)
+            emb = emb + ofs_emb
 
         # 2. Patch embedding
         p = self.config.patch_size
         p_t = self.config.patch_size_t
-        # We know that the hidden states height and width will always be divisible by patch_size.
-        # But, the number of frames may not be divisible by patch_size_t. So, we pad with the beginning frames.
-        # if p_t is not None:
-        #     remaining_frames = 0 if num_frames % 2 == 0 else 1
-        #     first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1)
-        #     hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1)
-
 
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
         hidden_states = self.embedding_dropout(hidden_states)
 
 
         text_seq_length = encoder_hidden_states.shape[1]
         encoder_hidden_states = hidden_states[:, :text_seq_length]
         hidden_states = hidden_states[:, text_seq_length:]
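Note: the hunks above wire an extra "ofs" conditioning signal (used by the CogVideoX 1.5 I2V checkpoints) through the same sinusoidal-projection-plus-MLP path as the diffusion timestep, then add it onto the time embedding. Below is a minimal, self-contained sketch of that path using diffusers' embedding modules; the dimensions are illustrative and not taken from this repo.

import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

# Illustrative dimensions; the real values come from the transformer config.
inner_dim = 3072
time_embed_dim = 512
ofs_embed_dim = 512   # must equal time_embed_dim for the addition below to work

time_proj = Timesteps(inner_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
time_embedding = TimestepEmbedding(inner_dim, time_embed_dim)
ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
ofs_embedding = TimestepEmbedding(ofs_embed_dim, ofs_embed_dim)

timestep = torch.tensor([999])
ofs = torch.full((1,), 2.0)   # pipeline-side default, see the fill_value=2.0 hunk further down

emb = time_embedding(time_proj(timestep))    # (1, time_embed_dim)
ofs_emb = ofs_embedding(ofs_proj(ofs))       # (1, ofs_embed_dim)
emb = emb + ofs_emb                          # combined conditioning fed to the transformer blocks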
@@ -639,7 +637,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                 batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
             )
             output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
-            output = output[:, remaining_frames:]
 
             (bb, tt, cc, hh, ww) = output.shape
             cond = rearrange(output, "B T C H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
@@ -711,7 +708,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
         )
         output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
-        #output = output[:, remaining_frames:]
 
         if self.fastercache_counter >= self.fastercache_start_step + 1:
             (bb, tt, cc, hh, ww) = output.shape
nodes.py (7 changed lines)
@@ -346,7 +346,7 @@ class CogVideoImageEncode:
                 "image": ("IMAGE", ),
             },
             "optional": {
-                "chunk_size": ("INT", {"default": 16, "min": 4}),
+                "chunk_size": ("INT", {"default": 16, "min": 4, "tooltip": "How many images to encode at once, lower values use less memory"}),
                 "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
                 "mask": ("MASK", ),
                 "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Augment image with noise"}),
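Note: the new chunk_size tooltip describes a memory/throughput trade-off: the batch is pushed through the VAE a few images at a time instead of all at once. A generic illustration of that pattern is sketched below with a hypothetical helper; it is not the node's actual implementation, which also handles masks, tiling and noise augmentation.

import torch

def encode_in_chunks(encode_fn, images: torch.Tensor, chunk_size: int = 16) -> torch.Tensor:
    # encode_fn: any callable mapping a batch of images to latents (e.g. a VAE encode wrapper).
    # Smaller chunk_size lowers peak memory at the cost of more encode calls.
    latents = [encode_fn(images[i:i + chunk_size]) for i in range(0, images.shape[0], chunk_size)]
    return torch.cat(latents, dim=0)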
@@ -806,6 +806,7 @@ class CogVideoSampler:
                 "controlnet": ("COGVIDECONTROLNET",),
                 "tora_trajectory": ("TORAFEATURES", ),
                 "fastercache": ("FASTERCACHEARGS", ),
+                #"sigmas": ("SIGMAS", ),
             }
         }
 
@@ -879,6 +880,9 @@ class CogVideoSampler:
             cfg = [cfg for _ in range(steps)]
         else:
             assert len(cfg) == steps, "Length of cfg list must match number of steps"
 
+        # if sigmas is not None:
+        #     sigma_list = sigmas.tolist()
+
         autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
         autocast_context = torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocastcondition else nullcontext()
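Note: the sampler expects guidance as one value per step: a scalar cfg is broadcast into a list of length steps, and a user-supplied list must already match the step count (hence the assert above). For example, a hypothetical linear guidance ramp could be built like this:

steps = 25
cfg = [1.0 + i * (6.0 - 1.0) / (steps - 1) for i in range(steps)]  # 1.0 -> 6.0 over 25 steps
assert len(cfg) == steps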
@@ -889,6 +893,7 @@ class CogVideoSampler:
                 width = width,
                 num_frames = num_frames,
                 guidance_scale=cfg,
+                #sigmas=sigma_list if sigmas is not None else None,
                 latents=samples["samples"] if samples is not None else None,
                 image_cond_latents=image_cond_latents["samples"] if image_cond_latents is not None else None,
                 denoise_strength=denoise_strength,
@@ -369,6 +369,7 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
         timesteps: Optional[List[int]] = None,
         guidance_scale: float = 6,
         denoise_strength: float = 1.0,
+        sigmas: Optional[List[float]] = None,
         num_videos_per_prompt: int = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -429,7 +430,7 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
         """
 
         height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
         width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
         num_videos_per_prompt = 1
@@ -460,7 +461,10 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
         prompt_embeds = prompt_embeds.to(self.vae.dtype)
 
         # 4. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        if sigmas is None:
+            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        else:
+            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, sigmas=sigmas, device=device)
         self._num_timesteps = len(timesteps)
 
         # 5. Prepare latents.
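Note: the branch above lets an explicit sigma schedule override the scheduler's own spacing. For reference, a condensed sketch of the diffusers-style retrieve_timesteps helper the pipeline relies on is shown below; the real helper does additional argument validation, so treat this as an approximation rather than the exact implementation.

import inspect

def retrieve_timesteps(scheduler, num_inference_steps=None, device=None, timesteps=None, sigmas=None):
    # Custom sigmas take priority over a plain step count; the scheduler must support them.
    if sigmas is not None:
        if "sigmas" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(f"{scheduler.__class__.__name__} does not accept custom sigmas")
        scheduler.set_timesteps(sigmas=sigmas, device=device)
        num_inference_steps = len(scheduler.timesteps)
    elif timesteps is not None:
        scheduler.set_timesteps(timesteps=timesteps, device=device)
        num_inference_steps = len(scheduler.timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device)
    return scheduler.timesteps, num_inference_steps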
@@ -499,7 +503,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
             freenoise=freenoise,
         )
         latents = latents.to(self.vae.dtype)
-        #print("latents", latents.shape)
 
         # 5.5.
         if image_cond_latents is not None:
@@ -579,6 +582,9 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
             if self.transformer.config.use_rotary_positional_embeddings
             else None
         )
+        # 7.6. Create ofs embeds if required
+        ofs_emb = None if self.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0)
+
         if tora is not None and do_classifier_free_guidance:
             video_flow_features = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()
 
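Note: on the pipeline side the ofs conditioning is just a one-element tensor filled with 2.0, created with new_full so it automatically inherits the latents' device and dtype, and it stays None when the model has no ofs_embed_dim. A plain-PyTorch illustration of the new_full semantics, with an illustrative latent shape:

import torch

latents = torch.zeros(2, 13, 16, 60, 90, dtype=torch.bfloat16)  # illustrative (B, F, C, H, W) latents
ofs_emb = latents.new_full((1,), fill_value=2.0)
print(ofs_emb, ofs_emb.dtype)  # tensor([2.], dtype=torch.bfloat16), on the same device as latents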
@@ -617,6 +623,8 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
                 for param in module.parameters():
                     param.data = param.data.to(device)
 
+        logger.info(f"Sampling {num_frames} frames in {latent_frames} latent frames at {width}x{height} with {num_inference_steps} inference steps")
+
         # 9. Denoising loop
         comfy_pbar = ProgressBar(len(timesteps))
         with self.progress_bar(total=len(timesteps)) as progress_bar:
@@ -873,6 +881,7 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
                     encoder_hidden_states=prompt_embeds,
                     timestep=timestep,
                     image_rotary_emb=image_rotary_emb,
+                    ofs=ofs_emb,
                     return_dict=False,
                     controlnet_states=controlnet_states,
                     controlnet_weights=control_weights,