Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2025-12-11 05:54:29 +08:00)
cleanup, bugfixes

commit 256a638ee4 (parent 7e8a3e4f2a)
@@ -287,25 +287,13 @@ class CogVideoXBlock(nn.Module):
             hidden_states, encoder_hidden_states, temb
         )
 
-        # Motion-guidance Fuser
+        # Tora Motion-guidance Fuser
         if video_flow_feature is not None:
-            #print(video_flow_feature)
-            #print("hidden_states.shape", hidden_states.shape)
-            #print("tora_trajectory.shape", video_flow_feature.shape)
-
             H, W = video_flow_feature.shape[-2:]
             T = norm_hidden_states.shape[1] // H // W
 
             h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W).to(torch.float16)
-            #print("h.dtype", h.dtype)
-
-            #video_flow_feature = video_flow_feature.to(h)
-            #print("video_flow_feature.dtype", video_flow_feature.dtype)
-
             h = fuser(h, video_flow_feature.to(h), T=T)
-            # if torch.any(torch.isnan(h)):
-            #     #print("hidden_states", h)
-            #     raise ValueError("hidden_states has NaN values")
             norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
             del h, fuser
 
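For orientation, here is a minimal standalone sketch of the token reshaping around the fuser call in the hunk above: tokens laid out as "B (T H W) C" are unpacked into per-frame "C H W" maps and packed back afterwards. Shapes are invented for illustration and the fuser itself is omitted.

```python
# Sketch of the "B (T H W) C" <-> "(B T) C H W" round trip used around the fuser above.
# Shapes are made up; only the einops patterns match the diff.
import torch
from einops import rearrange

B, T, H, W, C = 1, 4, 8, 8, 64
norm_hidden_states = torch.randn(B, T * H * W, C)      # tokens laid out as (T H W) per batch

h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
assert h.shape == (B * T, C, H, W)                      # per-frame feature maps

# ... a per-frame module (the fuser in the diff) would process h here ...

restored = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
assert torch.allclose(restored, norm_hidden_states)     # lossless round trip
```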
nodes.py (43 lines changed)

@@ -421,7 +421,7 @@ class DownloadAndLoadCogVideoModel:
             pipe.transformer.fuser_list.load_state_dict(fuser_sd)
             for module in transformer.fuser_list:
                 for param in module.parameters():
-                    param.data = param.data.to(torch.float16).to(device)
+                    param.data = param.data.to(torch.float16)
             del fuser_sd
 
             from .tora.traj_module import TrajExtractor
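A small sketch of the in-place parameter cast the changed line performs, on stand-in nn.Linear modules rather than the actual Tora fuser: the dtype becomes fp16 while each parameter stays on whatever device it already occupies.

```python
# Sketch: cast every parameter of a module list to fp16 in place, leaving devices untouched.
# The nn.Linear modules are stand-ins for illustration only.
import torch
import torch.nn as nn

fuser_list = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])

for module in fuser_list:
    for param in module.parameters():
        param.data = param.data.to(torch.float16)   # dtype changes, device stays as-is

print(next(fuser_list[0].parameters()).dtype)        # torch.float16
```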
@@ -1004,6 +1004,7 @@ class ToraEncodeTrajectory:
     CATEGORY = "CogVideoWrapper"
 
     def encode(self, pipeline, width, height, num_frames, coordinates):
+        check_diffusers_version()
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         generator = torch.Generator(device=device).manual_seed(0)
@@ -1011,44 +1012,34 @@ class ToraEncodeTrajectory:
         traj_extractor = pipeline["pipe"].traj_extractor
         vae = pipeline["pipe"].vae
         vae.enable_slicing()
+        vae._clear_fake_context_parallel_cache()
 
-        canvas_width, canvas_height = width, height
+        #get coordinates from string and convert to compatible range/format (has to be 256x256 for the model)
         coordinates = json.loads(coordinates.replace("'", '"'))
         coordinates = [(coord['x'], coord['y']) for coord in coordinates]
-        traj_list_range_256 = scale_traj_list_to_256(coordinates, canvas_width, canvas_height)
+        traj_list_range_256 = scale_traj_list_to_256(coordinates, width, height)
 
-        check_diffusers_version()
-        vae._clear_fake_context_parallel_cache()
-        total_num_frames = num_frames
-
-        video_flow, points = process_traj(traj_list_range_256, total_num_frames, (height,width), device=device)
-        video_flow = video_flow.unsqueeze_(0)
-
-        tmp = rearrange(video_flow[0], "T H W C -> T C H W")
-        video_flow = flow_to_image(tmp).unsqueeze_(0).to("cuda") # [1 T C H W]
-
-        del tmp
+        video_flow, points = process_traj(traj_list_range_256, num_frames, (height,width), device=device)
+        video_flow = rearrange(video_flow, "T H W C -> T C H W")
+        video_flow = flow_to_image(video_flow).unsqueeze_(0).to(device) # [1 T C H W]
         video_flow = (
-            rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(torch.bfloat16)
+            rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(vae.dtype)
         )
-        torch.cuda.empty_cache()
-        video_flow = video_flow.repeat(2, 1, 1, 1, 1).contiguous() # for uncondition
+        mm.soft_empty_cache()
 
+        # VAE encode
         if not pipeline["cpu_offloading"]:
             vae.to(device)
 
-        video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
-        video_flow = video_flow.permute(0, 2, 1, 3, 4).contiguous()
-        print("video_flow shape", video_flow.shape)
-
+        video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
         vae.to(offload_device)
 
-        video_flow = rearrange(video_flow, "b t d h w -> b d t h w")
         video_flow_features = traj_extractor(video_flow.to(torch.float32))
         video_flow_features = torch.stack(video_flow_features)
 
+        logging.info(f"video_flow shape: {video_flow.shape}")
 
         return (video_flow_features,)
 
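The reworked node above turns a trajectory into an optical-flow visualisation and normalises it for the VAE. Below is a hedged sketch of just the normalisation step on a dummy flow field (torchvision is assumed for flow_to_image); the real node builds the flow with process_traj and then VAE-encodes the result.

```python
# Sketch of the flow-visualisation -> VAE-input normalisation step, on a dummy flow field.
# Shapes are illustrative only.
import torch
from einops import rearrange
from torchvision.utils import flow_to_image

T, H, W = 4, 32, 32
flow = torch.randn(T, 2, H, W)                  # per-frame 2-channel optical flow

video_flow = flow_to_image(flow).unsqueeze(0)   # uint8 RGB image, shape [1, T, 3, H, W]
video_flow = (
    rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous()
)                                                # scaled to [-1, 1], channels-first over time
print(video_flow.shape, video_flow.min().item() >= -1, video_flow.max().item() <= 1)
```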
@@ -1293,7 +1284,7 @@ class CogVideoXFunSampler:
         else:
             context_frames, context_stride, context_overlap = None, None, None
 
-        generator= torch.Generator(device="cpu").manual_seed(seed)
+        generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
 
         autocastcondition = not pipeline["onediff"]
         autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
@@ -1388,7 +1379,7 @@ class CogVideoXFunVid2VidSampler:
         else:
             raise ValueError(f"Unknown scheduler: {scheduler}")
 
-        generator= torch.Generator(device).manual_seed(seed)
+        generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
 
         autocastcondition = not pipeline["onediff"]
         autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
@@ -1635,7 +1626,7 @@ class CogVideoXFunControlSampler:
         else:
             raise ValueError(f"Unknown scheduler: {scheduler}")
 
-        generator=torch.Generator(torch.device("cpu")).manual_seed(seed)
+        generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
 
         autocastcondition = not pipeline["onediff"]
         autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
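All three sampler hunks converge on the same generator construction. A tiny sketch, assuming standard PyTorch behaviour, of why a CPU-device generator with a fixed seed yields identical noise on every run regardless of where the model later executes:

```python
# Sketch: a CPU-device generator with a fixed seed reproduces the same noise every time.
import torch

seed = 42
generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
noise_a = torch.randn(2, 3, generator=generator)

generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
noise_b = torch.randn(2, 3, generator=generator)

assert torch.equal(noise_a, noise_b)   # identical draws from the same seed
```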
@@ -513,14 +513,10 @@ class CogVideoXPipeline(VideoSysPipeline):
                     height // self.vae_scale_factor_spatial,
                     width // self.vae_scale_factor_spatial,
                 )
-                print("padding_shape: ", padding_shape)
                 latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae.dtype)
-                print(image_cond_latents.shape)
-                print(image_cond_latents[:, 0, :, :, :].shape)
-                print(image_cond_latents[:, -1, :, :, :].shape)
 
                 image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1)
-                print("image cond latents shape",image_cond_latents.shape)
+                logger.info(f"image cond latents shape: {image_cond_latents.shape}")
             else:
                 logger.info("Only one image conditioning frame received, img2vid")
                 padding_shape = (
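A toy sketch of the start/end-frame padding built in the hunk above: zero latents are inserted between the first and last conditioning latents along the frame axis. Shapes are made up.

```python
# Toy sketch: keep the first and last conditioning latents and zero-fill the frames between.
import torch

B, T, C, H, W = 1, 13, 16, 60, 90                      # made-up latent shapes
image_cond_latents = torch.randn(B, 2, C, H, W)         # only first and last frames provided

latent_padding = torch.zeros((B, T - 2, C, H, W), dtype=image_cond_latents.dtype)
image_cond_latents = torch.cat(
    [image_cond_latents[:, 0].unsqueeze(1), latent_padding, image_cond_latents[:, -1].unsqueeze(1)],
    dim=1,
)
print(image_cond_latents.shape)                         # torch.Size([1, 13, 16, 60, 90])
```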
@@ -546,15 +542,15 @@ class CogVideoXPipeline(VideoSysPipeline):
         # masks
         if self.original_mask is not None:
             mask = self.original_mask.to(device)
-            print("self.original_mask: ", self.original_mask.shape)
+            logger.info(f"self.original_mask: {self.original_mask.shape}")
 
             mask = F.interpolate(self.original_mask.unsqueeze(1), size=(latents.shape[-2], latents.shape[-1]), mode='bilinear', align_corners=False)
             if mask.shape[0] != latents.shape[1]:
                 mask = mask.unsqueeze(1).repeat(1, latents.shape[1], 16, 1, 1)
             else:
                 mask = mask.unsqueeze(0).repeat(1, 1, 16, 1, 1)
-            print("latents: ", latents.shape)
-            print("mask: ", mask.shape)
+            logger.info(f"latents: {latents.shape}")
+            logger.info(f"mask: {mask.shape}")
 
         # 7. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
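A sketch of the mask-preparation pattern above with invented shapes: the mask is resized to the latent resolution with bilinear interpolation, then repeated over latent frames and channels (only the first branch of the shape check is shown).

```python
# Sketch: resize a frame mask to the latent spatial size, then broadcast it over
# latent frames and channels. Shapes are illustrative only.
import torch
import torch.nn.functional as F

latents = torch.randn(1, 13, 16, 60, 90)                # (B, T, C, H, W), made-up
original_mask = torch.rand(1, 480, 720)                  # one mask frame at image resolution

mask = F.interpolate(
    original_mask.unsqueeze(1),                          # add a channel dim -> (1, 1, 480, 720)
    size=(latents.shape[-2], latents.shape[-1]),
    mode="bilinear",
    align_corners=False,
)                                                        # -> (1, 1, 60, 90)
mask = mask.unsqueeze(1).repeat(1, latents.shape[1], 16, 1, 1)   # -> (1, 13, 16, 60, 90)
print(mask.shape == latents.shape)                       # True
```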
@@ -566,11 +562,11 @@ class CogVideoXPipeline(VideoSysPipeline):
             t_tile_overlap = context_overlap
             t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
             use_temporal_tiling = True
-            print("Temporal tiling enabled")
+            logger.info("Temporal tiling enabled")
         elif context_schedule is not None:
             if image_cond_latents is not None:
                 raise NotImplementedError("Context schedule not currently supported with image conditioning")
-            print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
+            logger.info(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
             use_temporal_tiling = False
             use_context_schedule = True
             from .cogvideox_fun.context import get_context_scheduler
@@ -579,15 +575,17 @@ class CogVideoXPipeline(VideoSysPipeline):
         else:
             use_temporal_tiling = False
             use_context_schedule = False
-            print("Temporal tiling and context schedule disabled")
+            logger.info("Temporal tiling and context schedule disabled")
         # 7. Create rotary embeds if required
         image_rotary_emb = (
             self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
             if self.transformer.config.use_rotary_positional_embeddings
             else None
         )
+        if video_flow_features is not None and do_classifier_free_guidance:
+            video_flow_features = video_flow_features.repeat(1, 2, 1, 1, 1).contiguous()
 
         # 9. Controlnet
         if controlnet is not None:
             self.controlnet = controlnet["control_model"].to(device)
             if self.transformer.dtype == torch.float8_e4m3fn:
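The added lines duplicate the Tora flow features when classifier-free guidance doubles the batch. A toy illustration of the repeat, with an assumed feature layout rather than the exact one:

```python
# Toy illustration of duplicating conditioning features for classifier-free guidance,
# matching the repeat(1, 2, 1, 1, 1) pattern above. The feature layout is assumed.
import torch

video_flow_features = torch.randn(30, 1, 226, 60, 90)   # made-up (layers, batch, ...) layout
do_classifier_free_guidance = True

if do_classifier_free_guidance:
    # latents are stacked [uncond, cond] along the batch dim, so the features must match
    video_flow_features = video_flow_features.repeat(1, 2, 1, 1, 1).contiguous()

print(video_flow_features.shape[1])   # 2
```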
@@ -606,7 +604,7 @@ class CogVideoXPipeline(VideoSysPipeline):
             control_frames = controlnet["control_frames"].to(device).to(self.controlnet.dtype).contiguous()
             control_frames = torch.cat([control_frames] * 2) if do_classifier_free_guidance else control_frames
             control_weights = controlnet["control_weights"]
-            print("Controlnet enabled with weights: ", control_weights)
+            logger.info(f"Controlnet enabled with weights: {control_weights}")
             control_start = controlnet["control_start"]
             control_end = controlnet["control_end"]
         else:
@@ -786,6 +784,13 @@ class CogVideoXPipeline(VideoSysPipeline):
                 else:
                     for c in context_queue:
                         partial_latent_model_input = latent_model_input[:, c, :, :, :]
+                        if video_flow_features is not None:
+                            if do_classifier_free_guidance:
+                                partial_video_flow_features = video_flow_features[:, c, :, :, :].repeat(1, 2, 1, 1, 1).contiguous()
+                            else:
+                                partial_video_flow_features = video_flow_features[:, c, :, :, :]
+                        else:
+                            partial_video_flow_features = None
 
                         # predict noise model_output
                         noise_pred[:, c, :, :, :] += self.transformer(
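A rough sketch of the context-window loop this hunk extends: frame indices are processed in overlapping windows, per-window features are sliced out, and each window's prediction is accumulated into the full-length buffer. The transformer is replaced by a dummy, shapes are assumed, and the blending of overlaps handled elsewhere in the pipeline is not shown.

```python
# Sketch of the overlapping context-window pattern, with a dummy model and assumed shapes.
import torch

B, T, C, H, W = 1, 13, 16, 60, 90
latent_model_input = torch.randn(B, T, C, H, W)
video_flow_features = torch.randn(30, T, 128, H, W)      # assumed per-frame feature layout
noise_pred = torch.zeros_like(latent_model_input)

context_queue = [list(range(0, 8)), list(range(5, 13))]  # overlapping frame windows

def dummy_transformer(latents, flow):                    # stand-in for self.transformer(...)
    return torch.zeros_like(latents)

for c in context_queue:
    partial_latent_model_input = latent_model_input[:, c]
    partial_video_flow_features = video_flow_features[:, c] if video_flow_features is not None else None
    noise_pred[:, c] += dummy_transformer(partial_latent_model_input, partial_video_flow_features)
```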
@@ -793,6 +798,7 @@ class CogVideoXPipeline(VideoSysPipeline):
                             encoder_hidden_states=prompt_embeds,
                             timestep=timestep,
                             image_rotary_emb=image_rotary_emb,
+                            video_flow_features=partial_video_flow_features,
                             return_dict=False
                         )[0]
 