mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-05-19 00:26:58 +08:00
Merge remote-tracking branch 'kijai/main'
This commit is contained in:
commit
7c11309ebb
30
nodes.py
30
nodes.py
@ -178,7 +178,7 @@ class DownloadAndLoadCogVideoGGUFModel:
|
|||||||
),
|
),
|
||||||
"vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}),
|
"vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}),
|
||||||
"fp8_fastmode": ("BOOLEAN", {"default": False, "tooltip": "only supported on 4090 and later GPUs"}),
|
"fp8_fastmode": ("BOOLEAN", {"default": False, "tooltip": "only supported on 4090 and later GPUs"}),
|
||||||
"compile": (["disabled","onediff","torch"], {"tooltip": "UNTESTED WITH GGUF"}),
|
"load_device": (["main_device", "offload_device"], {"default": "main_device"}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -187,7 +187,7 @@ class DownloadAndLoadCogVideoGGUFModel:
|
|||||||
FUNCTION = "loadmodel"
|
FUNCTION = "loadmodel"
|
||||||
CATEGORY = "CogVideoWrapper"
|
CATEGORY = "CogVideoWrapper"
|
||||||
|
|
||||||
def loadmodel(self, model, vae_precision, compile, fp8_fastmode):
|
def loadmodel(self, model, vae_precision, fp8_fastmode, load_device):
|
||||||
device = mm.get_torch_device()
|
device = mm.get_torch_device()
|
||||||
offload_device = mm.unet_offload_device()
|
offload_device = mm.unet_offload_device()
|
||||||
mm.soft_empty_cache()
|
mm.soft_empty_cache()
|
||||||
@ -227,20 +227,10 @@ class DownloadAndLoadCogVideoGGUFModel:
|
|||||||
|
|
||||||
transformer.to(torch.float8_e4m3fn)
|
transformer.to(torch.float8_e4m3fn)
|
||||||
transformer = mz_gguf_loader.quantize_load_state_dict(transformer, sd, device="cpu")
|
transformer = mz_gguf_loader.quantize_load_state_dict(transformer, sd, device="cpu")
|
||||||
transformer.to(device)
|
if load_device == "offload_device":
|
||||||
|
transformer.to(offload_device)
|
||||||
# transformer
|
else:
|
||||||
# if fp8_transformer == "fastmode":
|
transformer.to(device)
|
||||||
# if "2b" in model:
|
|
||||||
# for name, param in transformer.named_parameters():
|
|
||||||
# if name != "pos_embedding":
|
|
||||||
# param.data = param.data.to(torch.float8_e4m3fn)
|
|
||||||
# elif "I2V" in model:
|
|
||||||
# for name, param in transformer.named_parameters():
|
|
||||||
# if "patch_embed" not in name:
|
|
||||||
# param.data = param.data.to(torch.float8_e4m3fn)
|
|
||||||
# else:
|
|
||||||
# transformer.to(torch.float8_e4m3fn)
|
|
||||||
|
|
||||||
if fp8_fastmode:
|
if fp8_fastmode:
|
||||||
from .fp8_optimization import convert_fp8_linear
|
from .fp8_optimization import convert_fp8_linear
|
||||||
@ -586,16 +576,16 @@ class CogVideoSampler:
|
|||||||
offload_device = mm.unet_offload_device()
|
offload_device = mm.unet_offload_device()
|
||||||
pipe = pipeline["pipe"]
|
pipe = pipeline["pipe"]
|
||||||
dtype = pipeline["dtype"]
|
dtype = pipeline["dtype"]
|
||||||
|
scheduler_config = pipeline["scheduler_config"]
|
||||||
|
|
||||||
if not pipeline["cpu_offloading"]:
|
if not pipeline["cpu_offloading"]:
|
||||||
pipe.transformer.to(device)
|
pipe.transformer.to(device)
|
||||||
generator = torch.Generator(device=device).manual_seed(seed)
|
generator = torch.Generator(device=device).manual_seed(seed)
|
||||||
|
|
||||||
if scheduler == "DDIM" or scheduler == "DDIM_tiled":
|
if scheduler == "DDIM" or scheduler == "DDIM_tiled":
|
||||||
pipe.scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
|
pipe.scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)
|
||||||
elif scheduler == "DPM":
|
elif scheduler == "DPM":
|
||||||
pipe.scheduler = CogVideoXDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
|
pipe.scheduler = CogVideoXDPMScheduler.from_config(scheduler_config)
|
||||||
|
|
||||||
if negative.shape[1] < positive.shape[1]:
|
if negative.shape[1] < positive.shape[1]:
|
||||||
target_length = positive.shape[1]
|
target_length = positive.shape[1]
|
||||||
@ -674,7 +664,7 @@ class CogVideoDecode:
|
|||||||
latents = latents.to(vae.dtype)
|
latents = latents.to(vae.dtype)
|
||||||
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
||||||
latents = 1 / vae.config.scaling_factor * latents
|
latents = 1 / vae.config.scaling_factor * latents
|
||||||
|
vae._clear_fake_context_parallel_cache()
|
||||||
frames = vae.decode(latents).sample
|
frames = vae.decode(latents).sample
|
||||||
vae.disable_tiling()
|
vae.disable_tiling()
|
||||||
if not pipeline["cpu_offloading"]:
|
if not pipeline["cpu_offloading"]:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user