support transformer at fp8

kijai 2024-08-27 19:13:17 +03:00
parent 46fe99f102
commit 65e9e995e8
3 changed files with 41 additions and 42 deletions


@@ -818,7 +818,7 @@
   "widgets_values": {
     "frame_rate": 8,
     "loop_count": 0,
-    "filename_prefix": "AnimateDiff",
+    "filename_prefix": "CogVideoX_vid2vid",
     "format": "video/nvenc_h264-mp4",
     "pix_fmt": "yuv420p",
     "bitrate": 10,


@@ -25,15 +25,11 @@ class DownloadAndLoadCogVideoModel:
             },
             "optional": {
-                "precision": (
-                    [
-                        "fp16",
-                        "fp32",
-                        "bf16",
-                    ],
-                    {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"},
+                "precision": (["fp16", "fp32", "bf16"],
+                    {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
                 ),
-            },
+                "fp8_transformer": ("BOOLEAN", {"default": False, "tooltip": "cast the transformer to torch.float8_e4m3fn"}),
+            }
         }

     RETURN_TYPES = ("COGVIDEOPIPE",)
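The new optional fp8_transformer boolean only makes sense on a PyTorch build that ships the float8 storage dtype (added around PyTorch 2.1). A minimal sanity check, not part of the node itself:

import torch

# Sketch only: the fp8_transformer toggle presumes torch.float8_e4m3fn exists.
print("float8_e4m3fn available:", hasattr(torch, "float8_e4m3fn"))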
@@ -41,12 +37,16 @@ class DownloadAndLoadCogVideoModel:
     FUNCTION = "loadmodel"
     CATEGORY = "CogVideoWrapper"

-    def loadmodel(self, model, precision):
+    def loadmodel(self, model, precision, fp8_transformer):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         mm.soft_empty_cache()

         dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
+        if fp8_transformer:
+            transformer_dtype = torch.float8_e4m3fn
+        else:
+            transformer_dtype = dtype

         if "2b" in model:
             base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideo2B")
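The point of picking torch.float8_e4m3fn as the storage dtype is that it halves the transformer's weight memory relative to fp16/bf16. A minimal sketch of that effect, using a stand-in nn.Linear rather than the CogVideoX transformer:

import torch
import torch.nn as nn

# Stand-in layer, not the CogVideoX transformer; assumes PyTorch >= 2.1 for float8.
layer = nn.Linear(4096, 4096, bias=False)

bf16_bytes = layer.weight.to(torch.bfloat16).element_size() * layer.weight.numel()
fp8_bytes = layer.weight.to(torch.float8_e4m3fn).element_size() * layer.weight.numel()

print(f"bf16 weight: {bf16_bytes / 1e6:.1f} MB")  # ~33.6 MB
print(f"fp8  weight: {fp8_bytes / 1e6:.1f} MB")   # ~16.8 MB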
@@ -63,7 +63,7 @@
                 local_dir=base_path,
                 local_dir_use_symlinks=False,
             )

-        transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(dtype).to(offload_device)
+        transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(transformer_dtype).to(offload_device)
         vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
         scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
@@ -247,7 +247,7 @@ class CogVideoSampler:
             pipe.scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
         elif scheduler == "DPM":
             pipe.scheduler = CogVideoXDPMScheduler.from_pretrained(base_path, subfolder="scheduler")

-        latents = pipeline["pipe"](
-            num_inference_steps=steps,
-            height = height,
+        with torch.autocast(mm.get_autocast_device(device)):
+            latents = pipeline["pipe"](
+                num_inference_steps=steps,
+                height = height,
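Wrapping the sampling call in torch.autocast is what lets the denoising math run in a regular mixed-precision compute dtype even when the transformer's weights are stored at reduced precision; whether autocast upcasts float8 inputs can vary across PyTorch versions. A minimal sketch of the pattern with a hypothetical callable, assuming a CUDA device and bf16 autocast (mm.get_autocast_device, as used here, resolves the device type string):

import torch

# Sketch only, not the sampler's code: run a model call under autocast so the
# surrounding ops compute in bf16. "cuda" and bf16 are assumptions; substitute
# whatever device/dtype your setup actually uses.
def sample_with_autocast(call, *args, **kwargs):
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        return call(*args, **kwargs)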
@@ -297,7 +297,7 @@ class CogVideoDecode:
             tile_overlap_factor_height=1 / 12,
             tile_overlap_factor_width=1 / 12,
         )
+        latents = latents.to(vae.dtype)
         latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
         latents = 1 / vae.config.scaling_factor * latents
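The added latents.to(vae.dtype) exists because a module's inputs must match its parameter dtype; after a reduced-precision denoising pass, the latents may arrive in a different dtype than the VAE expects. A tiny illustration with a stand-in decoder head, not AutoencoderKLCogVideoX:

import torch
import torch.nn as nn

# Stand-in decoder head, not the CogVideoX VAE: its weights live in bf16,
# so incoming latents in another dtype have to be cast before the forward pass.
decoder_head = nn.Linear(16, 3).to(torch.bfloat16)
latents = torch.randn(4, 16, dtype=torch.float16)

frames = decoder_head(latents.to(torch.bfloat16))  # mirrors latents.to(vae.dtype)
print(frames.dtype)  # torch.bfloat16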


@@ -170,13 +170,13 @@ class CogVideoXPipeline(DiffusionPipeline):
         )
         if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
         else:
             latents = latents.to(device)

         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
         latent_timestep = timesteps[:1]

-        noise = randn_tensor(shape, generator=generator, device=device, dtype=latents.dtype)
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)

         frames_needed = noise.shape[1]
         current_frames = latents.shape[1]
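Initial latents and noise are now drawn in self.vae.dtype rather than whatever dtype the incoming latents carry; one reason this matters is that random sampling kernels are not provided for float8 storage dtypes, so noise has to be drawn in a regular float dtype and cast later if needed. A small check, again only a sketch:

import torch

# Sketch only: drawing noise directly in float8 is expected to fail, so the
# pipeline samples in the VAE's dtype (e.g. bf16) instead.
try:
    torch.randn(4, dtype=torch.float8_e4m3fn)
except RuntimeError as err:
    print("randn in float8 not supported:", err)

noise = torch.randn(4, dtype=torch.bfloat16)
print(noise.dtype)  # torch.bfloat16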
@@ -400,6 +400,7 @@ class CogVideoXPipeline(DiffusionPipeline):
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+        prompt_embeds = prompt_embeds.to(self.transformer.dtype)

         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
@@ -425,6 +426,7 @@ class CogVideoXPipeline(DiffusionPipeline):
             num_inference_steps,
             latents
         )
+        latents = latents.to(self.transformer.dtype)

         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -551,15 +553,12 @@ class CogVideoXPipeline(DiffusionPipeline):
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
-                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
-                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-                else:
-                    latents, old_pred_original_sample = self.scheduler.step(
-                        noise_pred,
-                        old_pred_original_sample,
-                        t,
-                        timesteps[i - 1] if i > 0 else None,
-                        latents,
-                        **extra_step_kwargs,
-                        return_dict=False,
-                    )
+                latents, old_pred_original_sample = self.scheduler.step(
+                    noise_pred,
+                    old_pred_original_sample,
+                    t,
+                    timesteps[i - 1] if i > 0 else None,
+                    latents.to(self.vae.dtype),
+                    **extra_step_kwargs,
+                    return_dict=False,
+                )
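Taken together, the pipeline-side edits keep a consistent dtype hand-off around the fp8 transformer: prompt embeddings and latents are cast to transformer.dtype before the denoising loop, while noise creation, the scheduler step, and decoding stay in vae.dtype. A schematic sketch of that flow with stand-in dtypes (the real pipeline reads them from the loaded modules):

import torch

# Stand-in dtypes; the wrapper reads these from the loaded transformer and VAE.
transformer_dtype = torch.float8_e4m3fn  # storage dtype when fp8_transformer is enabled
vae_dtype = torch.bfloat16

latents = torch.randn(1, 13, 16, 60, 90, dtype=vae_dtype)  # prepared in vae.dtype
latents = latents.to(transformer_dtype)                     # handed to the transformer
latents = latents.to(vae_dtype)                             # back to vae.dtype for scheduler.step / decode
print(latents.dtype)  # torch.bfloat16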