Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2025-12-09 04:44:22 +08:00)
cleanup, bugfixes
parent 7e8a3e4f2a
commit 256a638ee4
@@ -287,25 +287,13 @@ class CogVideoXBlock(nn.Module):
             hidden_states, encoder_hidden_states, temb
         )
 
-        # Motion-guidance Fuser
+        # Tora Motion-guidance Fuser
         if video_flow_feature is not None:
-            #print(video_flow_feature)
-            #print("hidden_states.shape", hidden_states.shape)
-            #print("tora_trajectory.shape", video_flow_feature.shape)
             H, W = video_flow_feature.shape[-2:]
             T = norm_hidden_states.shape[1] // H // W
 
             h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W).to(torch.float16)
-            #print("h.dtype", h.dtype)
-            #video_flow_feature = video_flow_feature.to(h)
-            #print("video_flow_feature.dtype", video_flow_feature.dtype)
             h = fuser(h, video_flow_feature.to(h), T=T)
-            # if torch.any(torch.isnan(h)):
-            #     #print("hidden_states", h)
-            #     raise ValueError("hidden_states has NaN values")
             norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
             del h, fuser
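The hunk above folds the transformer's flattened video tokens back into per-frame feature maps before the Tora fuser and unfolds them again afterwards. A minimal, self-contained sketch of that round trip with made-up sizes (only torch and einops are assumed; the fuser itself is elided):

import torch
from einops import rearrange

B, T, H, W, C = 1, 4, 30, 45, 64          # illustrative sizes only
tokens = torch.randn(B, T * H * W, C)     # transformer tokens: B (T H W) C

# fold tokens into per-frame feature maps for the conv-style fuser
h = rearrange(tokens, "B (T H W) C -> (B T) C H W", H=H, W=W)
assert h.shape == (B * T, C, H, W)

# ... fuser(h, video_flow_feature, T=T) would run here ...

# unfold back into the token layout the transformer block expects
tokens_back = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
assert torch.allclose(tokens, tokens_back)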
nodes.py (43 lines changed)
@@ -421,7 +421,7 @@ class DownloadAndLoadCogVideoModel:
             pipe.transformer.fuser_list.load_state_dict(fuser_sd)
             for module in transformer.fuser_list:
                 for param in module.parameters():
-                    param.data = param.data.to(torch.float16).to(device)
+                    param.data = param.data.to(torch.float16)
             del fuser_sd
 
             from .tora.traj_module import TrajExtractor
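The changed line keeps the fuser weights in fp16 but no longer forces them onto the compute device at load time. A minimal sketch of casting a module's parameters in place; the module here is a generic placeholder, not the repo's actual fuser:

import torch
import torch.nn as nn

# placeholder stand-in for a fuser block; the real module comes from the Tora code
fuser = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1), nn.SiLU())

for param in fuser.parameters():
    # cast the underlying storage, leaving device placement to the model manager
    param.data = param.data.to(torch.float16)

print(next(fuser.parameters()).dtype)  # torch.float16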
@@ -1004,6 +1004,7 @@ class ToraEncodeTrajectory:
     CATEGORY = "CogVideoWrapper"
 
     def encode(self, pipeline, width, height, num_frames, coordinates):
         check_diffusers_version()
         device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
         generator = torch.Generator(device=device).manual_seed(0)
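The new offload_device follows the usual ComfyUI pattern of pulling a heavy module onto the compute device only for the duration of the work and then parking it again. A rough sketch of that pattern, assuming a ComfyUI environment where comfy.model_management is importable and model is any nn.Module (function name is illustrative):

import torch
import comfy.model_management as mm

def encode_with_offload(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    device = mm.get_torch_device()             # usually the GPU
    offload_device = mm.unet_offload_device()  # usually the CPU

    model.to(device)                           # bring weights in for the heavy step
    try:
        with torch.no_grad():
            out = model(x.to(device))
    finally:
        model.to(offload_device)               # park the weights again
        mm.soft_empty_cache()                  # let ComfyUI reclaim VRAM
    return out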
@@ -1011,44 +1012,34 @@ class ToraEncodeTrajectory:
         traj_extractor = pipeline["pipe"].traj_extractor
         vae = pipeline["pipe"].vae
         vae.enable_slicing()
+        vae._clear_fake_context_parallel_cache()
 
-        canvas_width, canvas_height = width, height
         #get coordinates from string and convert to compatible range/format (has to be 256x256 for the model)
         coordinates = json.loads(coordinates.replace("'", '"'))
         coordinates = [(coord['x'], coord['y']) for coord in coordinates]
 
-        traj_list_range_256 = scale_traj_list_to_256(coordinates, canvas_width, canvas_height)
+        traj_list_range_256 = scale_traj_list_to_256(coordinates, width, height)
 
-        check_diffusers_version()
-        vae._clear_fake_context_parallel_cache()
-        total_num_frames = num_frames
-        video_flow, points = process_traj(traj_list_range_256, total_num_frames, (height,width), device=device)
-        video_flow = video_flow.unsqueeze_(0)
-        tmp = rearrange(video_flow[0], "T H W C -> T C H W")
-        video_flow = flow_to_image(tmp).unsqueeze_(0).to("cuda") # [1 T C H W]
-        del tmp
+        video_flow, points = process_traj(traj_list_range_256, num_frames, (height,width), device=device)
+        video_flow = rearrange(video_flow, "T H W C -> T C H W")
+        video_flow = flow_to_image(video_flow).unsqueeze_(0).to(device) # [1 T C H W]
         video_flow = (
-            rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(torch.bfloat16)
+            rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(vae.dtype)
         )
-        torch.cuda.empty_cache()
         video_flow = video_flow.repeat(2, 1, 1, 1, 1).contiguous() # for uncondition
+        mm.soft_empty_cache()
 
         # VAE encode
         if not pipeline["cpu_offloading"]:
             vae.to(device)
-        video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
-        video_flow = video_flow.permute(0, 2, 1, 3, 4).contiguous()
-        print("video_flow shape", video_flow.shape)
+        video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
+        vae.to(offload_device)
 
+        video_flow = rearrange(video_flow, "b t d h w -> b d t h w")
         video_flow_features = traj_extractor(video_flow.to(torch.float32))
         video_flow_features = torch.stack(video_flow_features)
 
+        logging.info(f"video_flow shape: {video_flow.shape}")
 
         return (video_flow_features,)
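The cleaned-up node turns the trajectory flow maps into RGB flow images, normalizes them into the VAE's expected [-1, 1] range, and matches the VAE dtype instead of hard-coding bfloat16. A standalone sketch of just that normalization step with dummy data; flow_to_image is torchvision's utility, everything else here is illustrative:

import torch
from einops import rearrange
from torchvision.utils import flow_to_image

T, H, W = 9, 60, 90                        # illustrative frame count and size
flow = torch.randn(T, 2, H, W)             # per-frame optical flow, (T, 2, H, W)

rgb = flow_to_image(flow)                  # uint8 RGB flow visualisation, (T, 3, H, W)
video_flow = rgb.unsqueeze(0)              # add batch dim -> (1, T, 3, H, W)

vae_dtype = torch.bfloat16                 # stand-in for vae.dtype
video_flow = rearrange(
    video_flow / 255.0 * 2 - 1,            # [0, 255] -> [-1, 1]
    "B T C H W -> B C T H W",
).contiguous().to(vae_dtype)

print(video_flow.shape, video_flow.dtype)  # torch.Size([1, 3, 9, 60, 90]) torch.bfloat16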
@@ -1293,7 +1284,7 @@ class CogVideoXFunSampler:
         else:
             context_frames, context_stride, context_overlap = None, None, None
 
-        generator= torch.Generator(device="cpu").manual_seed(seed)
+        generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
 
         autocastcondition = not pipeline["onediff"]
         autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
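The same one-line change is applied to the vid2vid and control samplers below: the generator is pinned to the CPU so a given seed draws the same noise regardless of which GPU backend is active. A tiny sketch of why that is useful, pure torch with no wrapper code assumed:

import torch

seed = 42

# CPU generators are reproducible for a given seed, independent of the GPU backend
g1 = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
g2 = torch.Generator(device=torch.device("cpu")).manual_seed(seed)

a = torch.randn(2, 3, generator=g1)
b = torch.randn(2, 3, generator=g2)
assert torch.equal(a, b)   # same seed, same noise

# a CUDA generator with the same seed draws from a different RNG stream,
# so results would not match the CPU sequence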
@@ -1388,7 +1379,7 @@ class CogVideoXFunVid2VidSampler:
         else:
             raise ValueError(f"Unknown scheduler: {scheduler}")
 
-        generator= torch.Generator(device).manual_seed(seed)
+        generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
 
         autocastcondition = not pipeline["onediff"]
         autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
@@ -1635,7 +1626,7 @@ class CogVideoXFunControlSampler:
         else:
             raise ValueError(f"Unknown scheduler: {scheduler}")
 
-        generator=torch.Generator(torch.device("cpu")).manual_seed(seed)
+        generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
 
         autocastcondition = not pipeline["onediff"]
         autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
@@ -513,14 +513,10 @@ class CogVideoXPipeline(VideoSysPipeline):
                     height // self.vae_scale_factor_spatial,
                     width // self.vae_scale_factor_spatial,
                 )
-                print("padding_shape: ", padding_shape)
                latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae.dtype)
-                print(image_cond_latents.shape)
-                print(image_cond_latents[:, 0, :, :, :].shape)
-                print(image_cond_latents[:, -1, :, :, :].shape)
 
                image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1)
-                print("image cond latents shape",image_cond_latents.shape)
+                logger.info("image cond latents shape: ",image_cond_latents.shape)
            else:
                logger.info("Only one image conditioning frame received, img2vid")
                padding_shape = (
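The branch above keeps only the first and last image latents and fills the frames in between with zeros. A self-contained sketch of that padding-and-concat; names mirror the diff but the latent sizes are made up:

import torch

B, T, C, H, W = 1, 13, 16, 60, 90                 # illustrative latent shape
image_cond_latents = torch.randn(B, 2, C, H, W)   # first and last frame latents

padding_shape = (B, T - 2, C, H, W)
latent_padding = torch.zeros(padding_shape, dtype=image_cond_latents.dtype)

image_cond_latents = torch.cat(
    [
        image_cond_latents[:, 0, :, :, :].unsqueeze(1),   # first frame
        latent_padding,                                   # zero frames in between
        image_cond_latents[:, -1, :, :, :].unsqueeze(1),  # last frame
    ],
    dim=1,
)
print(image_cond_latents.shape)  # torch.Size([1, 13, 16, 60, 90])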
@@ -546,15 +542,15 @@ class CogVideoXPipeline(VideoSysPipeline):
        # masks
        if self.original_mask is not None:
            mask = self.original_mask.to(device)
-            print("self.original_mask: ", self.original_mask.shape)
+            logger.info("self.original_mask: ", self.original_mask.shape)
 
            mask = F.interpolate(self.original_mask.unsqueeze(1), size=(latents.shape[-2], latents.shape[-1]), mode='bilinear', align_corners=False)
            if mask.shape[0] != latents.shape[1]:
                mask = mask.unsqueeze(1).repeat(1, latents.shape[1], 16, 1, 1)
            else:
                mask = mask.unsqueeze(0).repeat(1, 1, 16, 1, 1)
-            print("latents: ", latents.shape)
-            print("mask: ", mask.shape)
+            logger.info(f"latents: {latents.shape}")
+            logger.info(f"mask: {mask.shape}")
 
        # 7. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
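The mask path resizes the user mask to the latent resolution and then broadcasts it over the latent frame and channel dimensions. A minimal sketch of that resize/repeat, assuming a single-channel mask and the [B, T, 16, h, w] latent layout used in the diff (all shapes illustrative):

import torch
import torch.nn.functional as F

latents = torch.randn(1, 13, 16, 60, 90)     # B, T, C, h, w (illustrative)
original_mask = torch.rand(1, 480, 720)      # B, H, W mask in image space

# resize the mask to the latent spatial size
mask = F.interpolate(
    original_mask.unsqueeze(1),              # -> B, 1, H, W
    size=(latents.shape[-2], latents.shape[-1]),
    mode="bilinear",
    align_corners=False,
)                                            # -> B, 1, h, w

# broadcast over the latent frame and channel dims: B, T, 16, h, w
mask = mask.unsqueeze(1).repeat(1, latents.shape[1], 16, 1, 1)
print(mask.shape)  # torch.Size([1, 13, 16, 60, 90])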
@@ -566,11 +562,11 @@ class CogVideoXPipeline(VideoSysPipeline):
            t_tile_overlap = context_overlap
            t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
            use_temporal_tiling = True
-            print("Temporal tiling enabled")
+            logger.info("Temporal tiling enabled")
        elif context_schedule is not None:
            if image_cond_latents is not None:
                raise NotImplementedError("Context schedule not currently supported with image conditioning")
-            print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
+            logger.info(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
            use_temporal_tiling = False
            use_context_schedule = True
            from .cogvideox_fun.context import get_context_scheduler
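Temporal tiling denoises the video in overlapping chunks and blends the overlap with a Gaussian weight profile (self._gaussian_weights in the pipeline). The sketch below is a hypothetical version of such a profile, not the pipeline's actual implementation, just to show the shape of the idea:

import numpy as np
import torch

def gaussian_tile_weights(t_tile_length: int, t_batch_size: int = 1) -> torch.Tensor:
    """Hypothetical bell-shaped weights over a temporal tile, peaking at its centre."""
    midpoint = (t_tile_length - 1) / 2
    var = 0.01
    probs = [
        np.exp(-(t - midpoint) ** 2 / (t_tile_length * t_tile_length) / (2 * var))
        for t in range(t_tile_length)
    ]
    weights = torch.tensor(probs, dtype=torch.float32)
    # one weight per latent frame in the tile, repeated for the batch
    return weights.unsqueeze(0).repeat(t_batch_size, 1)

w = gaussian_tile_weights(t_tile_length=13)
print(w.shape)  # torch.Size([1, 13])
# overlapping tiles are averaged with these weights so seams between chunks are softened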
@@ -579,15 +575,17 @@ class CogVideoXPipeline(VideoSysPipeline):
        else:
            use_temporal_tiling = False
            use_context_schedule = False
-            print("Temporal tiling and context schedule disabled")
+            logger.info("Temporal tiling and context schedule disabled")
        # 7. Create rotary embeds if required
        image_rotary_emb = (
            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
            if self.transformer.config.use_rotary_positional_embeddings
            else None
        )
 
-        # 9. Controlnet
+        if video_flow_features is not None and do_classifier_free_guidance:
+            video_flow_features = video_flow_features.repeat(1, 2, 1, 1, 1).contiguous()
+
+        # 9. Controlnet
        if controlnet is not None:
            self.controlnet = controlnet["control_model"].to(device)
            if self.transformer.dtype == torch.float8_e4m3fn:
@@ -606,7 +604,7 @@ class CogVideoXPipeline(VideoSysPipeline):
            control_frames = controlnet["control_frames"].to(device).to(self.controlnet.dtype).contiguous()
            control_frames = torch.cat([control_frames] * 2) if do_classifier_free_guidance else control_frames
            control_weights = controlnet["control_weights"]
-            print("Controlnet enabled with weights: ", control_weights)
+            logger.info(f"Controlnet enabled with weights: {control_weights}")
            control_start = controlnet["control_start"]
            control_end = controlnet["control_end"]
        else:
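Both the Tora flow features (duplicated a hunk above) and the controlnet frames are doubled along a batch axis when classifier-free guidance is active, so the conditional and unconditional passes run in one batch. A tiny sketch of the two duplication idioms used here, with dummy tensors and illustrative shapes:

import torch

do_classifier_free_guidance = True

control_frames = torch.randn(1, 16, 13, 60, 90)        # B, C, T, h, w (illustrative)
video_flow_features = torch.randn(42, 1, 13, 60, 90)   # layers, B, T, h, w (illustrative)

# idiom 1: concatenate a copy along dim 0
control_frames = (
    torch.cat([control_frames] * 2) if do_classifier_free_guidance else control_frames
)

# idiom 2: repeat along the batch dim (dim 1 here) and keep memory contiguous
if do_classifier_free_guidance:
    video_flow_features = video_flow_features.repeat(1, 2, 1, 1, 1).contiguous()

print(control_frames.shape[0], video_flow_features.shape[1])  # 2 2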
@@ -786,6 +784,13 @@ class CogVideoXPipeline(VideoSysPipeline):
                else:
                    for c in context_queue:
                        partial_latent_model_input = latent_model_input[:, c, :, :, :]
+                        if video_flow_features is not None:
+                            if do_classifier_free_guidance:
+                                partial_video_flow_features = video_flow_features[:, c, :, :, :].repeat(1, 2, 1, 1, 1).contiguous()
+                            else:
+                                partial_video_flow_features = video_flow_features[:, c, :, :, :]
+                        else:
+                            partial_video_flow_features = None
 
                        # predict noise model_output
                        noise_pred[:, c, :, :, :] += self.transformer(
@@ -793,6 +798,7 @@ class CogVideoXPipeline(VideoSysPipeline):
                            encoder_hidden_states=prompt_embeds,
                            timestep=timestep,
                            image_rotary_emb=image_rotary_emb,
+                            video_flow_features=partial_video_flow_features,
                            return_dict=False
                        )[0]
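With a context schedule, each denoising step runs the transformer only on a window of frame indices c and accumulates the prediction into the matching slice of noise_pred; the new code now slices the flow features with the same indices. A reduced sketch of that slice-predict-accumulate loop, with a stand-in model instead of the transformer and simple averaging over overlaps (the pipeline's actual blending may differ):

import torch

B, T, C, H, W = 2, 13, 16, 30, 45                 # illustrative latent shape
latent_model_input = torch.randn(B, T, C, H, W)
video_flow_features = torch.randn(B, T, C, H, W)  # same layout for simplicity

# overlapping frame windows such as a context scheduler would produce (hard-coded here)
context_queue = [list(range(0, 8)), list(range(5, 13))]

noise_pred = torch.zeros_like(latent_model_input)
counter = torch.zeros(1, T, 1, 1, 1)

def fake_transformer(x, flow):        # stand-in for self.transformer(...)
    return x * 0.5 + flow * 0.0

for c in context_queue:
    partial_latents = latent_model_input[:, c, :, :, :]
    partial_flow = video_flow_features[:, c, :, :, :]
    noise_pred[:, c, :, :, :] += fake_transformer(partial_latents, partial_flow)
    counter[:, c, :, :, :] += 1

noise_pred = noise_pred / counter     # average where windows overlap
print(noise_pred.shape)               # torch.Size([2, 13, 16, 30, 45])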