support transformer at fp8

kijai 2024-08-27 19:13:17 +03:00
parent 46fe99f102
commit 65e9e995e8
3 changed files with 41 additions and 42 deletions
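The change follows a common fp8 weight-storage recipe: keep the transformer's parameters in torch.float8_e4m3fn (roughly half the VRAM of fp16/bf16 storage) and run inference under torch.autocast, which upcasts the fp8 weights to a supported compute dtype for each eligible op. A minimal, self-contained sketch of that recipe, using a toy module and made-up shapes rather than the repo's code:

```python
import torch
import torch.nn as nn

# Toy stand-in for the CogVideoX transformer; module and shapes are invented.
transformer = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))

fp8_transformer = True                      # mirrors the new node toggle
base_dtype = torch.bfloat16                 # the loader's "precision" choice
weight_dtype = torch.float8_e4m3fn if fp8_transformer else base_dtype

# Weights are only *stored* as 8-bit floats; there is no fp8 matmul kernel here,
# so calling the module outside autocast would raise a dtype error.
transformer = transformer.to(weight_dtype)

if torch.cuda.is_available():
    transformer = transformer.cuda()
    x = torch.randn(1, 64, device="cuda")
    # autocast casts the fp8 parameters up to bf16 per eligible op,
    # so compute runs in bf16 while storage stays fp8.
    with torch.autocast("cuda", dtype=torch.bfloat16):
        y = transformer(x)
    print(y.dtype)  # torch.bfloat16
```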

View File

@@ -818,7 +818,7 @@
"widgets_values": {
"frame_rate": 8,
"loop_count": 0,
"filename_prefix": "AnimateDiff",
"filename_prefix": "CogVideoX_vid2vid",
"format": "video/nvenc_h264-mp4",
"pix_fmt": "yuv420p",
"bitrate": 10,

View File

@@ -25,15 +25,11 @@ class DownloadAndLoadCogVideoModel:
},
"optional": {
"precision": (
[
"fp16",
"fp32",
"bf16",
],
{"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"},
"precision": (["fp16", "fp32", "bf16"],
{"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
),
},
"fp8_transformer": ("BOOLEAN", {"default": False, "tooltip": "cast the transformer to torch.float8_e4m3fn"}),
}
}
RETURN_TYPES = ("COGVIDEOPIPE",)
@@ -41,12 +37,16 @@ class DownloadAndLoadCogVideoModel:
FUNCTION = "loadmodel"
CATEGORY = "CogVideoWrapper"
- def loadmodel(self, model, precision):
+ def loadmodel(self, model, precision, fp8_transformer):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
mm.soft_empty_cache()
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
+ if fp8_transformer:
+     transformer_dtype = torch.float8_e4m3fn
+ else:
+     transformer_dtype = dtype
if "2b" in model:
base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideo2B")
@@ -63,7 +63,7 @@ class DownloadAndLoadCogVideoModel:
local_dir=base_path,
local_dir_use_symlinks=False,
)
- transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(dtype).to(offload_device)
+ transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(transformer_dtype).to(offload_device)
vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
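One practical caveat this diff does not guard against: torch.float8_e4m3fn only exists in newer PyTorch builds (it was added around 2.1), so older environments will fail on the cast above. A hypothetical guard a loader could add (variable name and message are illustrative, not from the repo):

```python
import torch

fp8_transformer = True  # stand-in for the node's new toggle

# Hypothetical version guard; torch.float8_e4m3fn is absent on older PyTorch.
if fp8_transformer and not hasattr(torch, "float8_e4m3fn"):
    raise RuntimeError("fp8_transformer needs a PyTorch build with float8 dtypes (2.1+)")
```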
@@ -247,22 +247,22 @@ class CogVideoSampler:
pipe.scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
elif scheduler == "DPM":
pipe.scheduler = CogVideoXDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
- latents = pipeline["pipe"](
-     num_inference_steps=steps,
-     height = height,
-     width = width,
-     num_frames = num_frames,
-     t_tile_length = t_tile_length,
-     t_tile_overlap = t_tile_overlap,
-     guidance_scale=cfg,
-     latents=samples["samples"] if samples is not None else None,
-     denoise_strength=denoise_strength,
-     prompt_embeds=positive.to(dtype).to(device),
-     negative_prompt_embeds=negative.to(dtype).to(device),
-     generator=generator,
-     device=device
- )
+ with torch.autocast(mm.get_autocast_device(device)):
+     latents = pipeline["pipe"](
+         num_inference_steps=steps,
+         height = height,
+         width = width,
+         num_frames = num_frames,
+         t_tile_length = t_tile_length,
+         t_tile_overlap = t_tile_overlap,
+         guidance_scale=cfg,
+         latents=samples["samples"] if samples is not None else None,
+         denoise_strength=denoise_strength,
+         prompt_embeds=positive.to(dtype).to(device),
+         negative_prompt_embeds=negative.to(dtype).to(device),
+         generator=generator,
+         device=device
+     )
pipe.transformer.to(offload_device)
mm.soft_empty_cache()
print(latents.shape)
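Wrapping the whole sampler call in torch.autocast is what makes the fp8 storage usable: there is no float8 matmul path here, so autocast's per-op casting upcasts the fp8 parameters (and fp32 activations) to fp16/bf16 for each linear and attention op, while mm.get_autocast_device(device) supplies the device-type string torch.autocast expects. A small illustration of that casting behavior, assuming a CUDA device and a float8-capable PyTorch (not taken from the repo):

```python
import torch
import torch.nn.functional as F

w = torch.randn(8, 8, device="cuda").to(torch.float8_e4m3fn)  # fp8 "weight"
x = torch.randn(2, 8, device="cuda")                          # fp32 activation

# F.linear(x, w)  # outside autocast this fails: no fp8 matmul kernel is available
with torch.autocast("cuda", dtype=torch.float16):
    y = F.linear(x, w)  # both inputs are cast to fp16 for this op
print(y.dtype)  # torch.float16
```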
@@ -297,7 +297,7 @@ class CogVideoDecode:
tile_overlap_factor_height=1 / 12,
tile_overlap_factor_width=1 / 12,
)
+ latents = latents.to(vae.dtype)
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / vae.config.scaling_factor * latents

View File

@@ -170,13 +170,13 @@ class CogVideoXPipeline(DiffusionPipeline):
)
if latents is None:
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
else:
latents = latents.to(device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
latent_timestep = timesteps[:1]
- noise = randn_tensor(shape, generator=generator, device=device, dtype=latents.dtype)
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
frames_needed = noise.shape[1]
current_frames = latents.shape[1]
@@ -400,6 +400,7 @@ class CogVideoXPipeline(DiffusionPipeline):
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_embeds = prompt_embeds.to(self.transformer.dtype)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
@@ -425,6 +426,7 @@ class CogVideoXPipeline(DiffusionPipeline):
num_inference_steps,
latents
)
+ latents = latents.to(self.transformer.dtype)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
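The pipeline edits above are dtype bookkeeping around the fp8 transformer: initial latents and noise are now drawn in the VAE's dtype, while prompt embeds and latents are matched to the transformer's storage dtype before the denoising loop; the hunk below then hands latents back to the scheduler in the VAE dtype. A compressed sketch of that round-trip, with assumed dtypes and an assumed latent shape:

```python
import torch

transformer_dtype = torch.float8_e4m3fn   # fp8 storage (assumed)
vae_dtype = torch.bfloat16                # loader "precision" choice (assumed)

latents = torch.randn(1, 13, 16, 60, 90, dtype=vae_dtype)  # sampled in VAE dtype
latents = latents.to(transformer_dtype)   # matched to the transformer before the loop
# ... transformer forward runs under autocast and yields an fp16/bf16 noise_pred ...
latents = latents.to(vae_dtype)           # upcast again for scheduler math and decoding
```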
@@ -551,18 +553,15 @@ class CogVideoXPipeline(DiffusionPipeline):
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
- if not isinstance(self.scheduler, CogVideoXDPMScheduler):
-     latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
- else:
-     latents, old_pred_original_sample = self.scheduler.step(
-         noise_pred,
-         old_pred_original_sample,
-         t,
-         timesteps[i - 1] if i > 0 else None,
-         latents,
-         **extra_step_kwargs,
-         return_dict=False,
-     )
+ latents, old_pred_original_sample = self.scheduler.step(
+     noise_pred,
+     old_pred_original_sample,
+     t,
+     timesteps[i - 1] if i > 0 else None,
+     latents.to(self.vae.dtype),
+     **extra_step_kwargs,
+     return_dict=False,
+ )
latents = latents.to(prompt_embeds.dtype)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):