sequential_cpu_offload for the VRAM deprived

kijai 2024-08-29 17:19:36 +03:00
parent 7b4488b8ec
commit 0b19b54916


@@ -33,6 +33,7 @@ class DownloadAndLoadCogVideoModel:
                 ),
                 "fp8_transformer": ("BOOLEAN", {"default": False, "tooltip": "cast the transformer to torch.float8_e4m3fn"}),
                 "compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
+                "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reduces memory usage but slows down inference"}),
             }
         }
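In diffusers, enable_sequential_cpu_offload() keeps every submodule in system RAM and streams each one onto the GPU only for its own forward pass, which is why the tooltip promises a large memory saving at the cost of speed. A minimal standalone sketch of the same toggle (the checkpoint id is an assumption, not part of this commit; requires accelerate):

    import torch
    from diffusers import CogVideoXPipeline

    # Placeholder checkpoint id; any CogVideoX layout behaves the same way.
    pipe = CogVideoXPipeline.from_pretrained(
        "THUDM/CogVideoX-2b", torch_dtype=torch.float16
    )
    # Submodules stay on CPU and are moved to the GPU one at a time per
    # forward pass, trading inference speed for a much smaller VRAM peak.
    pipe.enable_sequential_cpu_offload()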
@@ -41,7 +42,7 @@ class DownloadAndLoadCogVideoModel:
     FUNCTION = "loadmodel"
     CATEGORY = "CogVideoWrapper"

-    def loadmodel(self, model, precision, fp8_transformer, compile="disabled"):
+    def loadmodel(self, model, precision, fp8_transformer, compile="disabled", enable_sequential_cpu_offload=False):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         mm.soft_empty_cache()
@@ -69,9 +70,12 @@ class DownloadAndLoadCogVideoModel:
         )
         transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(transformer_dtype).to(offload_device)
         vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)

         scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")

         pipe = CogVideoXPipeline(vae, transformer, scheduler)
+        if enable_sequential_cpu_offload:
+            pipe.enable_sequential_cpu_offload()

         if compile == "torch":
             torch._dynamo.config.suppress_errors = True
@@ -93,7 +97,8 @@ class DownloadAndLoadCogVideoModel:
             "pipe": pipe,
             "dtype": dtype,
             "base_path": base_path,
-            "onediff": True if compile == "onediff" else False
+            "onediff": True if compile == "onediff" else False,
+            "cpu_offloading": enable_sequential_cpu_offload
         }
         return (pipeline,)
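Storing the toggle as "cpu_offloading" lets downstream nodes decide whether to manage device placement themselves. The consumer-side pattern, as the sampler and decoder hunks below use it, is roughly (a sketch; device and offload_device come from the ComfyUI model-management calls above):

    pipe = pipeline["pipe"]
    # Move modules by hand only when sequential offload is inactive;
    # otherwise accelerate's hooks own device placement.
    if not pipeline["cpu_offloading"]:
        pipe.transformer.to(device)
    # ... run sampling ...
    if not pipeline["cpu_offloading"]:
        pipe.transformer.to(offload_device)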
@@ -171,6 +176,7 @@ class CogVideoImageEncode:
             },
             "optional": {
                 "chunk_size": ("INT", {"default": 16, "min": 1}),
+                "enable_vae_slicing": ("BOOLEAN", {"default": True, "tooltip": "VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes."}),
             },
         }
@@ -179,11 +185,15 @@ class CogVideoImageEncode:
     FUNCTION = "encode"
     CATEGORY = "CogVideoWrapper"

-    def encode(self, pipeline, image, chunk_size=16):
+    def encode(self, pipeline, image, chunk_size=16, enable_vae_slicing=True):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         generator = torch.Generator(device=device).manual_seed(0)
         vae = pipeline["pipe"].vae
+        if enable_vae_slicing:
+            vae.enable_slicing()
+        else:
+            vae.disable_slicing()
         vae.to(device)

         input_image = image.clone() * 2.0 - 1.0
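enable_slicing()/disable_slicing() are the standard diffusers VAE toggles: with slicing on, encode and decode loop over the batch one slice at a time instead of in a single pass, capping peak memory. A rough sketch of the encode path the toggle affects (the latent_dist/scaling-factor steps are generic diffusers patterns assumed here, not lifted from this commit):

    # input_image is assumed normalized to [-1, 1], matching the
    # `image.clone() * 2.0 - 1.0` line above.
    vae.enable_slicing()  # encode per batch slice
    latent_dist = vae.encode(input_image).latent_dist
    latents = latent_dist.sample(generator) * vae.config.scaling_factor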
@@ -264,7 +274,8 @@ class CogVideoSampler:
         dtype = pipeline["dtype"]
         base_path = pipeline["base_path"]

-        pipe.transformer.to(device)
+        if not pipeline["cpu_offloading"]:
+            pipe.transformer.to(device)
         generator = torch.Generator(device=device).manual_seed(seed)

         if scheduler == "DDIM":
@@ -290,7 +301,8 @@ class CogVideoSampler:
             generator=generator,
             device=device
         )
-        pipe.transformer.to(offload_device)
+        if not pipeline["cpu_offloading"]:
+            pipe.transformer.to(offload_device)
         mm.soft_empty_cache()

         print(latents.shape)
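The guard matters because enable_sequential_cpu_offload() installs accelerate hooks that own device placement; an explicit pipe.transformer.to(device) or .to(offload_device) would fight those hooks, so the manual moves are skipped whenever offloading is active.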
@@ -309,6 +321,7 @@ class CogVideoDecode:
                 "tile_sample_min_width": ("INT", {"default": 96, "min": 16, "max": 2048, "step": 8}),
                 "tile_overlap_factor_height": ("FLOAT", {"default": 0.083, "min": 0.0, "max": 1.0, "step": 0.001}),
                 "tile_overlap_factor_width": ("FLOAT", {"default": 0.083, "min": 0.0, "max": 1.0, "step": 0.001}),
+                "enable_vae_slicing": ("BOOLEAN", {"default": True, "tooltip": "VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes."}),
             }
         }
@@ -317,12 +330,17 @@ class CogVideoDecode:
     FUNCTION = "decode"
     CATEGORY = "CogVideoWrapper"

-    def decode(self, pipeline, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width):
+    def decode(self, pipeline, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width, enable_vae_slicing=True):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         latents = samples["samples"]
         vae = pipeline["pipe"].vae
-        vae.to(device)
+        if enable_vae_slicing:
+            vae.enable_slicing()
+        else:
+            vae.disable_slicing()
+        if not pipeline["cpu_offloading"]:
+            vae.to(device)
         if enable_vae_tiling:
             vae.enable_tiling(
                 tile_sample_min_height=tile_sample_min_height,
@@ -335,7 +353,8 @@ class CogVideoDecode:
         latents = 1 / vae.config.scaling_factor * latents
         frames = vae.decode(latents).sample

-        vae.to(offload_device)
+        if not pipeline["cpu_offloading"]:
+            vae.to(offload_device)
         mm.soft_empty_cache()

         video = pipeline["pipe"].video_processor.postprocess_video(video=frames, output_type="pt")
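For reference, the tiled decode the node configures can be reproduced standalone. This sketch uses the node's default tile values and assumes latents is the sampler output, as in the hunk above:

    # Tiling decodes each frame in overlapping spatial tiles so the VAE
    # never holds a full-resolution frame at once; values are the node
    # defaults shown above.
    vae.enable_tiling(
        tile_sample_min_height=96,
        tile_sample_min_width=96,
        tile_overlap_factor_height=0.083,
        tile_overlap_factor_width=0.083,
    )
    latents = 1 / vae.config.scaling_factor * latents
    frames = vae.decode(latents).sample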