mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-09 04:44:22 +08:00

allow limiting blocks to cache

parent 0a121dba53
commit dac6a2a3ac
@@ -235,6 +235,7 @@ class CogVideoXBlock(nn.Module):
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         video_flow_feature: Optional[torch.Tensor] = None,
         fuser=None,
+        block_use_fastercache=False,
         fastercache_counter=0,
         fastercache_start_step=15,
         fastercache_device="cuda:0",
@@ -257,6 +258,7 @@ class CogVideoXBlock(nn.Module):
         del h, fuser
 
         #region fastercache
+        if block_use_fastercache:
             B = norm_hidden_states.shape[0]
             if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B:
                 attn_hidden_states = (
@@ -282,6 +284,13 @@ class CogVideoXBlock(nn.Module):
             elif fastercache_counter > fastercache_start_step:
                 self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
                 self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
+        else:
+            attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+                hidden_states=norm_hidden_states,
+                encoder_hidden_states=norm_encoder_hidden_states,
+                image_rotary_emb=image_rotary_emb,
+                attention_mode=self.attention_mode,
+            )
 
         hidden_states = hidden_states + gate_msa * attn_hidden_states
         encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
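A minimal standalone sketch of the reuse cadence encoded by the counter check above: after a warm-up of fastercache_start_step + 3 steps, attention is recomputed on every third step and the cached states are reused on the other two. should_reuse_cache is a hypothetical helper, not a function from this repo.

    # Hedged sketch of the FasterCache cadence; not part of the wrapper.
    def should_reuse_cache(counter: int, start_step: int = 15) -> bool:
        # Reuse cached attention only after warm-up, and never on steps
        # divisible by 3, which recompute and refresh the cache.
        return counter >= start_step + 3 and counter % 3 != 0

    # With start_step=15: steps 19, 20, 22, 23 reuse; 18, 21, 24 recompute.
    assert [s for s in range(15, 25) if should_reuse_cache(s)] == [19, 20, 22, 23]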
@@ -477,6 +486,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         self.fastercache_lf_step = 40
         self.fastercache_hf_step = 30
         self.fastercache_device = "cuda"
+        self.fastercache_num_blocks_to_cache = len(self.transformer_blocks)
         self.attention_mode = attention_mode
 
     def _set_gradient_checkpointing(self, module, value=False):
@@ -597,7 +607,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     image_rotary_emb=image_rotary_emb,
                     video_flow_feature=video_flow_features[i][:1] if video_flow_features is not None else None,
                     fuser = self.fuser_list[i] if self.fuser_list is not None else None,
+                    block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
                     fastercache_counter = self.fastercache_counter,
                     fastercache_start_step = self.fastercache_start_step,
                     fastercache_device = self.fastercache_device
                 )
 
@@ -665,7 +677,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     image_rotary_emb=image_rotary_emb,
                     video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
                     fuser = self.fuser_list[i] if self.fuser_list is not None else None,
+                    block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
                     fastercache_counter = self.fastercache_counter,
                     fastercache_start_step = self.fastercache_start_step,
                     fastercache_device = self.fastercache_device
                 )
         #has_nan = torch.isnan(hidden_states).any()
 
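Both call sites gate caching per block with i <= self.fastercache_num_blocks_to_cache, so only the leading blocks take the cache path. A hedged standalone sketch of that gating; the names below are stand-ins for self.fastercache_num_blocks_to_cache and self.transformer_blocks.

    num_blocks_to_cache = 10
    blocks = list(range(42))  # the 5b model has 42 transformer blocks

    cached = [i for i, _block in enumerate(blocks) if i <= num_blocks_to_cache]
    # Note the <=: indices 0..10 (11 blocks) take the cache path,
    # while blocks 11..41 always recompute attention.
    assert cached == list(range(11))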
nodes.py (47 lines changed)
@@ -452,6 +452,8 @@ class CogVideoImageInterpolationEncode:
             "optional": {
                 "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
                 "mask": ("MASK", ),
+                "vae_override" : ("VAE", {"default": None, "tooltip": "Override the VAE model in the pipeline"}),
+
             },
         }
 
@@ -460,14 +462,21 @@ class CogVideoImageInterpolationEncode:
     FUNCTION = "encode"
     CATEGORY = "CogVideoWrapper"
 
-    def encode(self, pipeline, start_image, end_image, chunk_size=8, enable_tiling=False, mask=None):
+    def encode(self, pipeline, start_image, end_image, enable_tiling=False, mask=None, vae_override=None):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         generator = torch.Generator(device=device).manual_seed(0)
 
         B, H, W, C = start_image.shape
 
-        vae = pipeline["pipe"].vae
+        vae = pipeline["pipe"].vae if vae_override is None else vae_override
         vae.enable_slicing()
+        model_name = pipeline.get("model_name", "")
+
+        if ("1.5" in model_name or "1_5" in model_name):
+            vae_scaling_factor = 1 / vae.config.scaling_factor
+        else:
+            vae_scaling_factor = vae.config.scaling_factor
+        vae.enable_slicing()
 
         if enable_tiling:
@@ -500,8 +509,8 @@ class CogVideoImageInterpolationEncode:
         latents_list = []
 
         # Encode the chunk of images
-        start_latents = vae.encode(start_image).latent_dist.sample(generator) * vae.config.scaling_factor
-        end_latents = vae.encode(end_image).latent_dist.sample(generator) * vae.config.scaling_factor
+        start_latents = vae.encode(start_image).latent_dist.sample(generator) * vae_scaling_factor
+        end_latents = vae.encode(end_image).latent_dist.sample(generator) * vae_scaling_factor
 
         start_latents = start_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
         end_latents = end_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
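The branch added above matters because the 1.5 VAE expects latents scaled by the reciprocal of its configured scaling factor, while older checkpoints multiply by it directly. A hedged sketch of the selection; the numeric value is illustrative, not read from a real VAE config.

    class FakeVAEConfig:
        scaling_factor = 0.7  # made-up value for illustration only

    def pick_vae_scaling_factor(model_name: str, config=FakeVAEConfig) -> float:
        if "1.5" in model_name or "1_5" in model_name:
            return 1 / config.scaling_factor  # 1.5 VAE uses the reciprocal
        return config.scaling_factor

    assert pick_vae_scaling_factor("CogVideoX_1_5_5b") == 1 / 0.7
    assert pick_vae_scaling_factor("CogVideoX-5b-I2V") == 0.7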
@@ -769,6 +778,7 @@ class CogVideoXFasterCache:
                 "hf_step": ("INT", {"default": 30, "min": 0, "max": 1024, "step": 1}),
                 "lf_step": ("INT", {"default": 40, "min": 0, "max": 1024, "step": 1}),
                 "cache_device": (["main_device", "offload_device", "cuda:1"], {"default": "main_device", "tooltip": "The device to use for the cache, main_device is on GPU and uses a lot of VRAM"}),
+                "num_blocks_to_cache": ("INT", {"default": 42, "min": 0, "max": 1024, "step": 1, "tooltip": "Number of transformer blocks to cache, 5b model has 42 blocks, tradeoff between speed and memory"}),
             },
         }
 
@@ -777,7 +787,7 @@ class CogVideoXFasterCache:
     FUNCTION = "args"
     CATEGORY = "CogVideoWrapper"
 
-    def args(self, start_step, hf_step, lf_step, cache_device):
+    def args(self, start_step, hf_step, lf_step, cache_device, num_blocks_to_cache):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         if cache_device == "cuda:1":
@@ -786,7 +796,8 @@ class CogVideoXFasterCache:
             "start_step" : start_step,
             "hf_step" : hf_step,
             "lf_step" : lf_step,
-            "cache_device" : device if cache_device != "offload_device" else offload_device
+            "cache_device" : device if cache_device != "offload_device" else offload_device,
+            "num_blocks_to_cache" : num_blocks_to_cache,
         }
         return (fastercache,)
 
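The node returns a plain dict, so downstream consumers can rely on these keys. A sketch of its shape with illustrative values; at runtime cache_device is a torch device picked from the dropdown, not a string.

    # Keys taken from the diff above; values are illustrative only.
    fastercache = {
        "start_step": 15,           # mirrors the block default fastercache_start_step=15
        "hf_step": 30,
        "lf_step": 40,
        "cache_device": "cuda:0",   # really a torch.device chosen at runtime
        "num_blocks_to_cache": 42,  # the key this commit adds
    }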
@@ -832,20 +843,25 @@ class CogVideoSampler:
         mm.soft_empty_cache()
 
         base_path = pipeline["base_path"]
+        model_name = pipeline.get("model_name", "")
+        supports_image_conds = True if "I2V" in model_name or "interpolation" in model_name.lower() else False
 
         assert "fun" not in base_path.lower(), "'Fun' models not supported in 'CogVideoSampler', use the 'CogVideoXFunSampler'"
         assert (
-            "I2V" not in pipeline.get("model_name", "") or
-            "1.5" in pipeline.get("model_name", "") or
-            "1_5" in pipeline.get("model_name", "") or
+            "I2V" not in model_name or
+            "1.5" in model_name or
+            "1_5" in model_name or
             num_frames == 49 or
             context_options is not None
         ), "1.0 I2V model can only do 49 frames"
         if image_cond_latents is not None:
             assert image_cond_latents["samples"].shape[0] == 1, "Image condition latents must be a single latent"
-            assert "I2V" in pipeline.get("model_name", ""), "Image condition latents only supported for I2V models"
+            assert supports_image_conds, "Image condition latents only supported for I2V and Interpolation models"
+            if "I2V" in model_name:
+                assert image_cond_latents["samples"].shape[1] == 1, "I2V model only supports single image condition latent"
+            elif "interpolation" in model_name.lower():
+                assert image_cond_latents["samples"].shape[1] == 2, "Interpolation model needs two image condition latents"
         else:
-            assert "I2V" not in pipeline.get("model_name", ""), "Image condition latents required for I2V models"
+            assert not supports_image_conds, "Image condition latents required for I2V models"
 
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
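The new asserts read the frame dimension of the conditioning latents, which are laid out as B, T, C, H, W per the permute comments earlier in this diff: one frame for I2V, two (start and end) for interpolation. A hedged sketch with dummy tensors; the channel and spatial sizes are placeholders, not values the node checks.

    import torch

    i2v_cond = torch.zeros(1, 1, 16, 60, 90)     # one conditioning frame
    interp_cond = torch.zeros(1, 2, 16, 60, 90)  # start and end frames

    assert i2v_cond.shape[1] == 1      # what the I2V branch requires
    assert interp_cond.shape[1] == 2   # what the interpolation branch requires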
@@ -885,6 +901,8 @@ class CogVideoSampler:
                 pipe.transformer.fastercache_lf_step = fastercache["lf_step"]
                 pipe.transformer.fastercache_hf_step = fastercache["hf_step"]
                 pipe.transformer.fastercache_device = fastercache["cache_device"]
+                pipe.transformer.fastercache_num_blocks_to_cache = fastercache["num_blocks_to_cache"]
+                log.info(f"FasterCache enabled for {pipe.transformer.fastercache_num_blocks_to_cache} blocks out of {len(pipe.transformer.transformer_blocks)}")
             else:
                 pipe.transformer.use_fastercache = False
                 pipe.transformer.fastercache_counter = 0
@@ -1001,6 +1019,8 @@ class CogVideoDecode:
         latents = samples["samples"]
         vae = pipeline["pipe"].vae if vae_override is None else vae_override
 
+        additional_frames = getattr(pipeline["pipe"], "additional_frames", 0)
+
         vae.enable_slicing()
 
         if not pipeline["cpu_offloading"]:
@@ -1024,7 +1044,8 @@ class CogVideoDecode:
                 vae._clear_fake_context_parallel_cache()
         except:
             pass
-        frames = vae.decode(latents[:, :, pipeline["pipe"].additional_frames:]).sample
+
+        frames = vae.decode(latents[:, :, additional_frames:]).sample
         vae.disable_tiling()
         if not pipeline["cpu_offloading"]:
             vae.to(offload_device)
 
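Reading additional_frames through getattr with a default of 0 keeps the decode node working when the pipeline never set that attribute, where the old direct access would raise. A minimal sketch; Pipe is a hypothetical stand-in for pipeline["pipe"].

    class Pipe:
        pass  # additional_frames was never set

    pipe = Pipe()
    additional_frames = getattr(pipe, "additional_frames", 0)
    assert additional_frames == 0
    # Direct access (pipe.additional_frames) would raise AttributeError;
    # with the default, latents[:, :, 0:] keeps every frame.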