cleanup, bugfixes

This commit is contained in:
kijai 2024-10-21 03:24:53 +03:00
parent 7e8a3e4f2a
commit 256a638ee4
3 changed files with 37 additions and 52 deletions

View File

@ -287,25 +287,13 @@ class CogVideoXBlock(nn.Module):
hidden_states, encoder_hidden_states, temb
)
# Motion-guidance Fuser
# Tora Motion-guidance Fuser
if video_flow_feature is not None:
#print(video_flow_feature)
#print("hidden_states.shape", hidden_states.shape)
#print("tora_trajectory.shape", video_flow_feature.shape)
H, W = video_flow_feature.shape[-2:]
T = norm_hidden_states.shape[1] // H // W
h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W).to(torch.float16)
#print("h.dtype", h.dtype)
#video_flow_feature = video_flow_feature.to(h)
#print("video_flow_feature.dtype", video_flow_feature.dtype)
h = fuser(h, video_flow_feature.to(h), T=T)
# if torch.any(torch.isnan(h)):
# #print("hidden_states", h)
# raise ValueError("hidden_states has NaN values")
norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
del h, fuser

View File

@ -421,7 +421,7 @@ class DownloadAndLoadCogVideoModel:
pipe.transformer.fuser_list.load_state_dict(fuser_sd)
for module in transformer.fuser_list:
for param in module.parameters():
param.data = param.data.to(torch.float16).to(device)
param.data = param.data.to(torch.float16)
del fuser_sd
from .tora.traj_module import TrajExtractor
@ -1004,6 +1004,7 @@ class ToraEncodeTrajectory:
CATEGORY = "CogVideoWrapper"
def encode(self, pipeline, width, height, num_frames, coordinates):
check_diffusers_version()
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
generator = torch.Generator(device=device).manual_seed(0)
@ -1011,44 +1012,34 @@ class ToraEncodeTrajectory:
traj_extractor = pipeline["pipe"].traj_extractor
vae = pipeline["pipe"].vae
vae.enable_slicing()
vae._clear_fake_context_parallel_cache()
canvas_width, canvas_height = width, height
#get coordinates from string and convert to compatible range/format (has to be 256x256 for the model)
coordinates = json.loads(coordinates.replace("'", '"'))
coordinates = [(coord['x'], coord['y']) for coord in coordinates]
traj_list_range_256 = scale_traj_list_to_256(coordinates, canvas_width, canvas_height)
traj_list_range_256 = scale_traj_list_to_256(coordinates, width, height)
check_diffusers_version()
vae._clear_fake_context_parallel_cache()
total_num_frames = num_frames
video_flow, points = process_traj(traj_list_range_256, num_frames, (height,width), device=device)
video_flow = rearrange(video_flow, "T H W C -> T C H W")
video_flow = flow_to_image(video_flow).unsqueeze_(0).to(device) # [1 T C H W]
video_flow, points = process_traj(traj_list_range_256, total_num_frames, (height,width), device=device)
video_flow = video_flow.unsqueeze_(0)
tmp = rearrange(video_flow[0], "T H W C -> T C H W")
video_flow = flow_to_image(tmp).unsqueeze_(0).to("cuda") # [1 T C H W]
del tmp
video_flow = (
rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(torch.bfloat16)
rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(vae.dtype)
)
torch.cuda.empty_cache()
video_flow = video_flow.repeat(2, 1, 1, 1, 1).contiguous() # for uncondition
mm.soft_empty_cache()
# VAE encode
if not pipeline["cpu_offloading"]:
vae.to(device)
video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
video_flow = video_flow.permute(0, 2, 1, 3, 4).contiguous()
print("video_flow shape", video_flow.shape)
video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
vae.to(offload_device)
video_flow = rearrange(video_flow, "b t d h w -> b d t h w")
video_flow_features = traj_extractor(video_flow.to(torch.float32))
video_flow_features = torch.stack(video_flow_features)
logging.info(f"video_flow shape: {video_flow.shape}")
return (video_flow_features,)
@ -1293,7 +1284,7 @@ class CogVideoXFunSampler:
else:
context_frames, context_stride, context_overlap = None, None, None
generator= torch.Generator(device="cpu").manual_seed(seed)
generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
autocastcondition = not pipeline["onediff"]
autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
@ -1388,7 +1379,7 @@ class CogVideoXFunVid2VidSampler:
else:
raise ValueError(f"Unknown scheduler: {scheduler}")
generator= torch.Generator(device).manual_seed(seed)
generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
autocastcondition = not pipeline["onediff"]
autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
@ -1635,7 +1626,7 @@ class CogVideoXFunControlSampler:
else:
raise ValueError(f"Unknown scheduler: {scheduler}")
generator=torch.Generator(torch.device("cpu")).manual_seed(seed)
generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
autocastcondition = not pipeline["onediff"]
autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()

View File

@ -513,14 +513,10 @@ class CogVideoXPipeline(VideoSysPipeline):
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
print("padding_shape: ", padding_shape)
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae.dtype)
print(image_cond_latents.shape)
print(image_cond_latents[:, 0, :, :, :].shape)
print(image_cond_latents[:, -1, :, :, :].shape)
image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1)
print("image cond latents shape",image_cond_latents.shape)
logger.info("image cond latents shape: ",image_cond_latents.shape)
else:
logger.info("Only one image conditioning frame received, img2vid")
padding_shape = (
@ -546,15 +542,15 @@ class CogVideoXPipeline(VideoSysPipeline):
# masks
if self.original_mask is not None:
mask = self.original_mask.to(device)
print("self.original_mask: ", self.original_mask.shape)
logger.info("self.original_mask: ", self.original_mask.shape)
mask = F.interpolate(self.original_mask.unsqueeze(1), size=(latents.shape[-2], latents.shape[-1]), mode='bilinear', align_corners=False)
if mask.shape[0] != latents.shape[1]:
mask = mask.unsqueeze(1).repeat(1, latents.shape[1], 16, 1, 1)
else:
mask = mask.unsqueeze(0).repeat(1, 1, 16, 1, 1)
print("latents: ", latents.shape)
print("mask: ", mask.shape)
logger.info(f"latents: {latents.shape}")
logger.info(f"mask: {mask.shape}")
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@ -566,11 +562,11 @@ class CogVideoXPipeline(VideoSysPipeline):
t_tile_overlap = context_overlap
t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(self.vae.dtype)
use_temporal_tiling = True
print("Temporal tiling enabled")
logger.info("Temporal tiling enabled")
elif context_schedule is not None:
if image_cond_latents is not None:
raise NotImplementedError("Context schedule not currently supported with image conditioning")
print(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
logger.info(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap")
use_temporal_tiling = False
use_context_schedule = True
from .cogvideox_fun.context import get_context_scheduler
@ -579,15 +575,17 @@ class CogVideoXPipeline(VideoSysPipeline):
else:
use_temporal_tiling = False
use_context_schedule = False
print("Temporal tiling and context schedule disabled")
logger.info("Temporal tiling and context schedule disabled")
# 7. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
# 9. Controlnet
if video_flow_features is not None and do_classifier_free_guidance:
video_flow_features = video_flow_features.repeat(1, 2, 1, 1, 1).contiguous()
# 9. Controlnet
if controlnet is not None:
self.controlnet = controlnet["control_model"].to(device)
if self.transformer.dtype == torch.float8_e4m3fn:
@ -606,7 +604,7 @@ class CogVideoXPipeline(VideoSysPipeline):
control_frames = controlnet["control_frames"].to(device).to(self.controlnet.dtype).contiguous()
control_frames = torch.cat([control_frames] * 2) if do_classifier_free_guidance else control_frames
control_weights = controlnet["control_weights"]
print("Controlnet enabled with weights: ", control_weights)
logger.info(f"Controlnet enabled with weights: {control_weights}")
control_start = controlnet["control_start"]
control_end = controlnet["control_end"]
else:
@ -786,6 +784,13 @@ class CogVideoXPipeline(VideoSysPipeline):
else:
for c in context_queue:
partial_latent_model_input = latent_model_input[:, c, :, :, :]
if video_flow_features is not None:
if do_classifier_free_guidance:
partial_video_flow_features = video_flow_features[:, c, :, :, :].repeat(1, 2, 1, 1, 1).contiguous()
else:
partial_video_flow_features = video_flow_features[:, c, :, :, :]
else:
partial_video_flow_features = None
# predict noise model_output
noise_pred[:, c, :, :, :] += self.transformer(
@ -793,6 +798,7 @@ class CogVideoXPipeline(VideoSysPipeline):
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
video_flow_features=partial_video_flow_features,
return_dict=False
)[0]