mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-05-04 15:13:34 +08:00
sequential_cpu_offload for the VRAM-deprived
This commit is contained in:
parent
7b4488b8ec
commit
0b19b54916
35
nodes.py
35
nodes.py
@ -33,6 +33,7 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
),
|
),
|
||||||
"fp8_transformer": ("BOOLEAN", {"default": False, "tooltip": "cast the transformer to torch.float8_e4m3fn"}),
|
"fp8_transformer": ("BOOLEAN", {"default": False, "tooltip": "cast the transformer to torch.float8_e4m3fn"}),
|
||||||
"compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
|
"compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
|
||||||
|
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,7 +42,7 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
FUNCTION = "loadmodel"
|
FUNCTION = "loadmodel"
|
||||||
CATEGORY = "CogVideoWrapper"
|
CATEGORY = "CogVideoWrapper"
|
||||||
|
|
||||||
def loadmodel(self, model, precision, fp8_transformer, compile="disabled"):
|
def loadmodel(self, model, precision, fp8_transformer, compile="disabled", enable_sequential_cpu_offload=False):
|
||||||
device = mm.get_torch_device()
|
device = mm.get_torch_device()
|
||||||
offload_device = mm.unet_offload_device()
|
offload_device = mm.unet_offload_device()
|
||||||
mm.soft_empty_cache()
|
mm.soft_empty_cache()
|
||||||
@ -69,9 +70,12 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
)
|
)
|
||||||
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(transformer_dtype).to(offload_device)
|
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(transformer_dtype).to(offload_device)
|
||||||
vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
|
vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
|
||||||
|
|
||||||
scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
|
scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
|
||||||
|
|
||||||
pipe = CogVideoXPipeline(vae, transformer, scheduler)
|
pipe = CogVideoXPipeline(vae, transformer, scheduler)
|
||||||
|
if enable_sequential_cpu_offload:
|
||||||
|
pipe.enable_sequential_cpu_offload()
|
||||||
|
|
||||||
if compile == "torch":
|
if compile == "torch":
|
||||||
torch._dynamo.config.suppress_errors = True
|
torch._dynamo.config.suppress_errors = True
|
||||||
@ -93,7 +97,8 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
"pipe": pipe,
|
"pipe": pipe,
|
||||||
"dtype": dtype,
|
"dtype": dtype,
|
||||||
"base_path": base_path,
|
"base_path": base_path,
|
||||||
"onediff": True if compile == "onediff" else False
|
"onediff": True if compile == "onediff" else False,
|
||||||
|
"cpu_offloading": enable_sequential_cpu_offload
|
||||||
}
|
}
|
||||||
|
|
||||||
return (pipeline,)
|
return (pipeline,)
|
||||||
@ -171,6 +176,7 @@ class CogVideoImageEncode:
|
|||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"chunk_size": ("INT", {"default": 16, "min": 1}),
|
"chunk_size": ("INT", {"default": 16, "min": 1}),
|
||||||
|
"enable_vae_slicing": ("BOOLEAN", {"default": True, "tooltip": "VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes."}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -179,11 +185,15 @@ class CogVideoImageEncode:
|
|||||||
FUNCTION = "encode"
|
FUNCTION = "encode"
|
||||||
CATEGORY = "CogVideoWrapper"
|
CATEGORY = "CogVideoWrapper"
|
||||||
|
|
||||||
def encode(self, pipeline, image, chunk_size=16):
|
def encode(self, pipeline, image, chunk_size=16, enable_vae_slicing=True):
|
||||||
device = mm.get_torch_device()
|
device = mm.get_torch_device()
|
||||||
offload_device = mm.unet_offload_device()
|
offload_device = mm.unet_offload_device()
|
||||||
generator = torch.Generator(device=device).manual_seed(0)
|
generator = torch.Generator(device=device).manual_seed(0)
|
||||||
vae = pipeline["pipe"].vae
|
vae = pipeline["pipe"].vae
|
||||||
|
if enable_vae_slicing:
|
||||||
|
vae.enable_slicing()
|
||||||
|
else:
|
||||||
|
vae.disable_slicing()
|
||||||
vae.to(device)
|
vae.to(device)
|
||||||
|
|
||||||
input_image = image.clone() * 2.0 - 1.0
|
input_image = image.clone() * 2.0 - 1.0
|
||||||
@ -264,7 +274,8 @@ class CogVideoSampler:
|
|||||||
dtype = pipeline["dtype"]
|
dtype = pipeline["dtype"]
|
||||||
base_path = pipeline["base_path"]
|
base_path = pipeline["base_path"]
|
||||||
|
|
||||||
pipe.transformer.to(device)
|
if not pipeline["cpu_offloading"]:
|
||||||
|
pipe.transformer.to(device)
|
||||||
generator = torch.Generator(device=device).manual_seed(seed)
|
generator = torch.Generator(device=device).manual_seed(seed)
|
||||||
|
|
||||||
if scheduler == "DDIM":
|
if scheduler == "DDIM":
|
||||||
@ -290,7 +301,8 @@ class CogVideoSampler:
|
|||||||
generator=generator,
|
generator=generator,
|
||||||
device=device
|
device=device
|
||||||
)
|
)
|
||||||
pipe.transformer.to(offload_device)
|
if not pipeline["cpu_offloading"]:
|
||||||
|
pipe.transformer.to(offload_device)
|
||||||
mm.soft_empty_cache()
|
mm.soft_empty_cache()
|
||||||
print(latents.shape)
|
print(latents.shape)
|
||||||
|
|
||||||
@ -309,6 +321,7 @@ class CogVideoDecode:
|
|||||||
"tile_sample_min_width": ("INT", {"default": 96, "min": 16, "max": 2048, "step": 8}),
|
"tile_sample_min_width": ("INT", {"default": 96, "min": 16, "max": 2048, "step": 8}),
|
||||||
"tile_overlap_factor_height": ("FLOAT", {"default": 0.083, "min": 0.0, "max": 1.0, "step": 0.001}),
|
"tile_overlap_factor_height": ("FLOAT", {"default": 0.083, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
"tile_overlap_factor_width": ("FLOAT", {"default": 0.083, "min": 0.0, "max": 1.0, "step": 0.001}),
|
"tile_overlap_factor_width": ("FLOAT", {"default": 0.083, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
|
"enable_vae_slicing": ("BOOLEAN", {"default": True, "tooltip": "VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes."}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -317,12 +330,17 @@ class CogVideoDecode:
|
|||||||
FUNCTION = "decode"
|
FUNCTION = "decode"
|
||||||
CATEGORY = "CogVideoWrapper"
|
CATEGORY = "CogVideoWrapper"
|
||||||
|
|
||||||
def decode(self, pipeline, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width):
|
def decode(self, pipeline, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width, enable_vae_slicing=True):
|
||||||
device = mm.get_torch_device()
|
device = mm.get_torch_device()
|
||||||
offload_device = mm.unet_offload_device()
|
offload_device = mm.unet_offload_device()
|
||||||
latents = samples["samples"]
|
latents = samples["samples"]
|
||||||
vae = pipeline["pipe"].vae
|
vae = pipeline["pipe"].vae
|
||||||
vae.to(device)
|
if enable_vae_slicing:
|
||||||
|
vae.enable_slicing()
|
||||||
|
else:
|
||||||
|
vae.disable_slicing()
|
||||||
|
if not pipeline["cpu_offloading"]:
|
||||||
|
vae.to(device)
|
||||||
if enable_vae_tiling:
|
if enable_vae_tiling:
|
||||||
vae.enable_tiling(
|
vae.enable_tiling(
|
||||||
tile_sample_min_height=tile_sample_min_height,
|
tile_sample_min_height=tile_sample_min_height,
|
||||||
@ -335,7 +353,8 @@ class CogVideoDecode:
|
|||||||
latents = 1 / vae.config.scaling_factor * latents
|
latents = 1 / vae.config.scaling_factor * latents
|
||||||
|
|
||||||
frames = vae.decode(latents).sample
|
frames = vae.decode(latents).sample
|
||||||
vae.to(offload_device)
|
if not pipeline["cpu_offloading"]:
|
||||||
|
vae.to(offload_device)
|
||||||
mm.soft_empty_cache()
|
mm.soft_empty_cache()
|
||||||
|
|
||||||
video = pipeline["pipe"].video_processor.postprocess_video(video=frames, output_type="pt")
|
video = pipeline["pipe"].video_processor.postprocess_video(video=frames, output_type="pt")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user