Merge remote-tracking branch 'kijai/main'

This commit is contained in:
Phr00t 2024-09-20 16:17:49 -04:00
commit 7c11309ebb

View File

@ -178,7 +178,7 @@ class DownloadAndLoadCogVideoGGUFModel:
), ),
"vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}), "vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}),
"fp8_fastmode": ("BOOLEAN", {"default": False, "tooltip": "only supported on 4090 and later GPUs"}), "fp8_fastmode": ("BOOLEAN", {"default": False, "tooltip": "only supported on 4090 and later GPUs"}),
"compile": (["disabled","onediff","torch"], {"tooltip": "UNTESTED WITH GGUF"}), "load_device": (["main_device", "offload_device"], {"default": "main_device"}),
}, },
} }
@ -187,7 +187,7 @@ class DownloadAndLoadCogVideoGGUFModel:
FUNCTION = "loadmodel" FUNCTION = "loadmodel"
CATEGORY = "CogVideoWrapper" CATEGORY = "CogVideoWrapper"
def loadmodel(self, model, vae_precision, compile, fp8_fastmode): def loadmodel(self, model, vae_precision, fp8_fastmode, load_device):
device = mm.get_torch_device() device = mm.get_torch_device()
offload_device = mm.unet_offload_device() offload_device = mm.unet_offload_device()
mm.soft_empty_cache() mm.soft_empty_cache()
@ -227,21 +227,11 @@ class DownloadAndLoadCogVideoGGUFModel:
transformer.to(torch.float8_e4m3fn) transformer.to(torch.float8_e4m3fn)
transformer = mz_gguf_loader.quantize_load_state_dict(transformer, sd, device="cpu") transformer = mz_gguf_loader.quantize_load_state_dict(transformer, sd, device="cpu")
if load_device == "offload_device":
transformer.to(offload_device)
else:
transformer.to(device) transformer.to(device)
# transformer
# if fp8_transformer == "fastmode":
# if "2b" in model:
# for name, param in transformer.named_parameters():
# if name != "pos_embedding":
# param.data = param.data.to(torch.float8_e4m3fn)
# elif "I2V" in model:
# for name, param in transformer.named_parameters():
# if "patch_embed" not in name:
# param.data = param.data.to(torch.float8_e4m3fn)
# else:
# transformer.to(torch.float8_e4m3fn)
if fp8_fastmode: if fp8_fastmode:
from .fp8_optimization import convert_fp8_linear from .fp8_optimization import convert_fp8_linear
convert_fp8_linear(transformer, vae_dtype) convert_fp8_linear(transformer, vae_dtype)
@ -586,16 +576,16 @@ class CogVideoSampler:
offload_device = mm.unet_offload_device() offload_device = mm.unet_offload_device()
pipe = pipeline["pipe"] pipe = pipeline["pipe"]
dtype = pipeline["dtype"] dtype = pipeline["dtype"]
scheduler_config = pipeline["scheduler_config"]
if not pipeline["cpu_offloading"]: if not pipeline["cpu_offloading"]:
pipe.transformer.to(device) pipe.transformer.to(device)
generator = torch.Generator(device=device).manual_seed(seed) generator = torch.Generator(device=device).manual_seed(seed)
if scheduler == "DDIM" or scheduler == "DDIM_tiled": if scheduler == "DDIM" or scheduler == "DDIM_tiled":
pipe.scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler") pipe.scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)
elif scheduler == "DPM": elif scheduler == "DPM":
pipe.scheduler = CogVideoXDPMScheduler.from_pretrained(base_path, subfolder="scheduler") pipe.scheduler = CogVideoXDPMScheduler.from_config(scheduler_config)
if negative.shape[1] < positive.shape[1]: if negative.shape[1] < positive.shape[1]:
target_length = positive.shape[1] target_length = positive.shape[1]
@ -674,7 +664,7 @@ class CogVideoDecode:
latents = latents.to(vae.dtype) latents = latents.to(vae.dtype)
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / vae.config.scaling_factor * latents latents = 1 / vae.config.scaling_factor * latents
vae._clear_fake_context_parallel_cache()
frames = vae.decode(latents).sample frames = vae.decode(latents).sample
vae.disable_tiling() vae.disable_tiling()
if not pipeline["cpu_offloading"]: if not pipeline["cpu_offloading"]: