diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index d2637ed..730fef9 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -287,25 +287,13 @@ class CogVideoXBlock(nn.Module):
             hidden_states, encoder_hidden_states, temb
         )
 
-        # Motion-guidance Fuser
+        # Tora Motion-guidance Fuser
         if video_flow_feature is not None:
-            #print(video_flow_feature)
-            #print("hidden_states.shape", hidden_states.shape)
-            #print("tora_trajectory.shape", video_flow_feature.shape)
-
             H, W = video_flow_feature.shape[-2:]
             T = norm_hidden_states.shape[1] // H // W
             h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W).to(torch.float16)
-            #print("h.dtype", h.dtype)
-
-            #video_flow_feature = video_flow_feature.to(h)
-            #print("video_flow_feature.dtype", video_flow_feature.dtype)
-
             h = fuser(h, video_flow_feature.to(h), T=T)
-            # if torch.any(torch.isnan(h)):
-            #     #print("hidden_states", h)
-            #     raise ValueError("hidden_states has NaN values")
             norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
             del h, fuser
diff --git a/nodes.py b/nodes.py
index aa0e07b..9119a83 100644
--- a/nodes.py
+++ b/nodes.py
@@ -421,7 +421,7 @@ class DownloadAndLoadCogVideoModel:
                 pipe.transformer.fuser_list.load_state_dict(fuser_sd)
                 for module in transformer.fuser_list:
                     for param in module.parameters():
-                        param.data = param.data.to(torch.float16).to(device)
+                        param.data = param.data.to(torch.float16)
                 del fuser_sd
 
                 from .tora.traj_module import TrajExtractor
@@ -1004,6 +1004,7 @@ class ToraEncodeTrajectory:
     CATEGORY = "CogVideoWrapper"
 
     def encode(self, pipeline, width, height, num_frames, coordinates):
+        check_diffusers_version()
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        generator = torch.Generator(device=device).manual_seed(0)
@@ -1011,44 +1012,34 @@ class ToraEncodeTrajectory:
         traj_extractor = pipeline["pipe"].traj_extractor
         vae = pipeline["pipe"].vae
         vae.enable_slicing()
+        vae._clear_fake_context_parallel_cache()
 
-        canvas_width, canvas_height = width, height
+        #get coordinates from string and convert to compatible range/format (has to be 256x256 for the model)
         coordinates = json.loads(coordinates.replace("'", '"'))
         coordinates = [(coord['x'], coord['y']) for coord in coordinates]
-
-        traj_list_range_256 = scale_traj_list_to_256(coordinates, canvas_width, canvas_height)
+        traj_list_range_256 = scale_traj_list_to_256(coordinates, width, height)
 
-        check_diffusers_version()
-        vae._clear_fake_context_parallel_cache()
-
-        total_num_frames = num_frames
+        video_flow, points = process_traj(traj_list_range_256, num_frames, (height,width), device=device)
+        video_flow = rearrange(video_flow, "T H W C -> T C H W")
+        video_flow = flow_to_image(video_flow).unsqueeze_(0).to(device)  # [1 T C H W]
 
-        video_flow, points = process_traj(traj_list_range_256, total_num_frames, (height,width), device=device)
-        video_flow = video_flow.unsqueeze_(0)
-
-        tmp = rearrange(video_flow[0], "T H W C -> T C H W")
-        video_flow = flow_to_image(tmp).unsqueeze_(0).to("cuda")  # [1 T C H W]
-
-        del tmp
         video_flow = (
-            rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(torch.bfloat16)
+            rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(vae.dtype)
         )
-        torch.cuda.empty_cache()
-        video_flow = video_flow.repeat(2, 1, 1, 1, 1).contiguous()  # for uncondition
+        mm.soft_empty_cache()
 
+        # VAE encode
         if not pipeline["cpu_offloading"]:
             vae.to(device)
-
-        video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
-        video_flow = video_flow.permute(0, 2, 1, 3, 4).contiguous()
-        print("video_flow shape", video_flow.shape)
+        video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
         vae.to(offload_device)
-        video_flow = rearrange(video_flow, "b t d h w -> b d t h w")
 
         video_flow_features = traj_extractor(video_flow.to(torch.float32))
         video_flow_features = torch.stack(video_flow_features)
 
+        logging.info(f"video_flow shape: {video_flow.shape}")
+
         return (video_flow_features,)
@@ -1293,7 +1284,7 @@ class CogVideoXFunSampler:
         else:
             context_frames, context_stride, context_overlap = None, None, None
 
-        generator= torch.Generator(device="cpu").manual_seed(seed)
+        generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
 
         autocastcondition = not pipeline["onediff"]
         autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
@@ -1388,7 +1379,7 @@ class CogVideoXFunVid2VidSampler:
         else:
             raise ValueError(f"Unknown scheduler: {scheduler}")
 
-        generator= torch.Generator(device).manual_seed(seed)
+        generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
 
         autocastcondition = not pipeline["onediff"]
         autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
@@ -1635,7 +1626,7 @@ class CogVideoXFunControlSampler:
         else:
             raise ValueError(f"Unknown scheduler: {scheduler}")
 
-        generator=torch.Generator(torch.device("cpu")).manual_seed(seed)
+        generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
 
         autocastcondition = not pipeline["onediff"]
         autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index ceff8b2..5039b31 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -513,14 +513,10 @@ class CogVideoXPipeline(VideoSysPipeline):
                     height // self.vae_scale_factor_spatial,
                     width // self.vae_scale_factor_spatial,
                 )
-                print("padding_shape: ", padding_shape)
                 latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae.dtype)
-                print(image_cond_latents.shape)
-                print(image_cond_latents[:, 0, :, :, :].shape)
-                print(image_cond_latents[:, -1, :, :, :].shape)
                 image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1)
-                print("image cond latents shape",image_cond_latents.shape)
+                logger.info(f"image cond latents shape: {image_cond_latents.shape}")
             else:
                 logger.info("Only one image conditioning frame received, img2vid")
                 padding_shape = (
@@ -546,15 +542,15 @@ class CogVideoXPipeline(VideoSysPipeline):
         # masks
         if self.original_mask is not None:
             mask = self.original_mask.to(device)
-            print("self.original_mask: ", self.original_mask.shape)
+            logger.info(f"self.original_mask: {self.original_mask.shape}")
             mask = F.interpolate(self.original_mask.unsqueeze(1), size=(latents.shape[-2], latents.shape[-1]), mode='bilinear', align_corners=False)
             if mask.shape[0] != latents.shape[1]:
                 mask = mask.unsqueeze(1).repeat(1, latents.shape[1], 16, 1, 1)
             else:
                 mask = mask.unsqueeze(0).repeat(1, 1, 16, 1, 1)
-            print("latents: ", latents.shape)
-            print("mask: ", mask.shape)
+            logger.info(f"latents: {latents.shape}")
+            logger.info(f"mask: {mask.shape}")
 
         # 7. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -566,11 +562,11 @@
             t_tile_overlap = context_overlap
             t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
             use_temporal_tiling = True
-            print("Temporal tiling enabled")
+            logger.info("Temporal tiling enabled")
         elif context_schedule is not None:
             if image_cond_latents is not None:
                 raise NotImplementedError("Context schedule not currently supported with image conditioning")
-            print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
+            logger.info(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
             use_temporal_tiling = False
             use_context_schedule = True
             from .cogvideox_fun.context import get_context_scheduler
@@ -579,15 +575,17 @@
             context = get_context_scheduler(context_schedule)
         else:
             use_temporal_tiling = False
             use_context_schedule = False
-            print("Temporal tiling and context schedule disabled")
+            logger.info("Temporal tiling and context schedule disabled")
 
         # 7. Create rotary embeds if required
         image_rotary_emb = (
             self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
             if self.transformer.config.use_rotary_positional_embeddings
             else None
         )
-        # 9. Controlnet
+        if video_flow_features is not None and do_classifier_free_guidance:
+            video_flow_features = video_flow_features.repeat(1, 2, 1, 1, 1).contiguous()
+        # 9. Controlnet
         if controlnet is not None:
             self.controlnet = controlnet["control_model"].to(device)
             if self.transformer.dtype == torch.float8_e4m3fn:
@@ -606,7 +604,7 @@
             control_frames = controlnet["control_frames"].to(device).to(self.controlnet.dtype).contiguous()
             control_frames = torch.cat([control_frames] * 2) if do_classifier_free_guidance else control_frames
             control_weights = controlnet["control_weights"]
-            print("Controlnet enabled with weights: ", control_weights)
+            logger.info(f"Controlnet enabled with weights: {control_weights}")
             control_start = controlnet["control_start"]
             control_end = controlnet["control_end"]
         else:
@@ -786,6 +784,13 @@ class CogVideoXPipeline(VideoSysPipeline):
                     else:
                         for c in context_queue:
                             partial_latent_model_input = latent_model_input[:, c, :, :, :]
+                            if video_flow_features is not None:
+                                if do_classifier_free_guidance:
+                                    partial_video_flow_features = video_flow_features[:, c, :, :, :].repeat(1, 2, 1, 1, 1).contiguous()
+                                else:
+                                    partial_video_flow_features = video_flow_features[:, c, :, :, :]
+                            else:
+                                partial_video_flow_features = None
 
                             # predict noise model_output
                             noise_pred[:, c, :, :, :] += self.transformer(
@@ -793,6 +798,7 @@
                                 encoder_hidden_states=prompt_embeds,
                                 timestep=timestep,
                                 image_rotary_emb=image_rotary_emb,
+                                video_flow_features=partial_video_flow_features,
                                 return_dict=False
                             )[0]
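For reference, the reshaping contract that the Tora Motion-guidance Fuser hunk relies on can be exercised in isolation. The snippet below is an illustrative sketch only: the fuser is a hypothetical 1x1-projection stand-in for the repository's learned fuser_list modules, and every tensor size is an arbitrary toy value. It merely shows that the "B (T H W) C" token layout survives the round trip through per-frame "(B T) C H W" feature maps, which is the invariant the block depends on when injecting flow features.

# Standalone sketch of the reshape round-trip around the Tora Motion-guidance Fuser.
# The fuser below is a made-up stand-in, not the repository's learned module.
import torch
from einops import rearrange

B, T, H, W, C = 1, 4, 30, 45, 64  # assumed toy sizes, not the model's real config
norm_hidden_states = torch.randn(B, T * H * W, C)
video_flow_feature = torch.randn(B * T, 16, H, W)  # hypothetical flow-feature channels

def fuser(h, flow, T):
    # stand-in: project the flow features to h's channel count with a 1x1 conv and add them
    proj = torch.nn.Conv2d(flow.shape[1], h.shape[1], kernel_size=1)
    return h + proj(flow)

# token sequence -> per-frame feature maps, fuse, then back to the token layout
h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
h = fuser(h, video_flow_feature.to(h), T=T)
norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
print(norm_hidden_states.shape)  # torch.Size([1, 5400, 64])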