allow limiting blocks to cache

kijai committed 2024-11-12 08:42:01 +02:00
parent 0a121dba53
commit dac6a2a3ac
2 changed files with 67 additions and 32 deletions
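In short: FasterCache previously kept cached attention states in every transformer block. This commit threads a per-block `block_use_fastercache` flag (default `False`) through `CogVideoXBlock.forward`, adds a `fastercache_num_blocks_to_cache` attribute on the transformer that defaults to `len(self.transformer_blocks)`, and exposes a matching `num_blocks_to_cache` input on the `CogVideoXFasterCache` node, so caching can be limited to the first N blocks to trade speed for memory. It also passes `fastercache_start_step` through the block calls and bundles a few smaller changes: a `vae_override` input and a 1.5-model VAE scaling fix in `CogVideoImageInterpolationEncode`, stricter image-conditioning checks in `CogVideoSampler`, and a safer `additional_frames` lookup in `CogVideoDecode`. A minimal sketch of the new gate, assuming the block loop has the usual enumerate shape (the loop itself is not shown in this diff):

    num_blocks_to_cache = 10   # from the CogVideoXFasterCache node
    for i in range(42):        # the 5b model has 42 blocks, per the node tooltip
        block_use_fastercache = i <= num_blocks_to_cache  # inclusive: caches blocks 0..10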

Changed file 1 of 2

@@ -235,6 +235,7 @@ class CogVideoXBlock(nn.Module):
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         video_flow_feature: Optional[torch.Tensor] = None,
         fuser=None,
+        block_use_fastercache=False,
         fastercache_counter=0,
         fastercache_start_step=15,
         fastercache_device="cuda:0",
@@ -257,18 +258,32 @@
             del h, fuser
         #region fastercache
-        B = norm_hidden_states.shape[0]
-        if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B:
-            attn_hidden_states = (
-                self.cached_hidden_states[1][:B] +
-                (self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B])
-                * 0.3
-                ).to(norm_hidden_states.device, non_blocking=True)
-            attn_encoder_hidden_states = (
-                self.cached_encoder_hidden_states[1][:B] +
-                (self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B])
-                * 0.3
-                ).to(norm_hidden_states.device, non_blocking=True)
+        if block_use_fastercache:
+            B = norm_hidden_states.shape[0]
+            if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B:
+                attn_hidden_states = (
+                    self.cached_hidden_states[1][:B] +
+                    (self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B])
+                    * 0.3
+                    ).to(norm_hidden_states.device, non_blocking=True)
+                attn_encoder_hidden_states = (
+                    self.cached_encoder_hidden_states[1][:B] +
+                    (self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B])
+                    * 0.3
+                    ).to(norm_hidden_states.device, non_blocking=True)
+            else:
+                attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+                    hidden_states=norm_hidden_states,
+                    encoder_hidden_states=norm_encoder_hidden_states,
+                    image_rotary_emb=image_rotary_emb,
+                    attention_mode=self.attention_mode,
+                )
+                if fastercache_counter == fastercache_start_step:
+                    self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
+                    self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)]
+                elif fastercache_counter > fastercache_start_step:
+                    self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
+                    self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
         else:
             attn_hidden_states, attn_encoder_hidden_states = self.attn1(
                 hidden_states=norm_hidden_states,
@@ -276,12 +291,6 @@ class CogVideoXBlock(nn.Module):
                 image_rotary_emb=image_rotary_emb,
                 attention_mode=self.attention_mode,
             )
-            if fastercache_counter == fastercache_start_step:
-                self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
-                self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)]
-            elif fastercache_counter > fastercache_start_step:
-                self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
-                self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
 
         hidden_states = hidden_states + gate_msa * attn_hidden_states
         encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
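Only the gating is new in the two hunks above; the reuse rule itself is unchanged and simply moves inside the `if block_use_fastercache:` branch, together with the cache writes the second hunk deletes from their old location. For reference, a self-contained sketch of that rule, with illustrative tensor shapes (the `0.3` extrapolation weight and `[:B]` slicing are taken from the diff):

    import torch

    # stand-ins for self.cached_hidden_states[0] (older) and [1] (newer)
    cache_old = torch.randn(2, 226, 3072)
    cache_new = torch.randn(2, 226, 3072)
    B = 2

    # once the counter reaches start_step + 3, every step whose counter is not
    # a multiple of 3 skips attn1 and extrapolates from the two cached states:
    attn_hidden_states = cache_new[:B] + (cache_new[:B] - cache_old[:B]) * 0.3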
@@ -477,6 +486,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         self.fastercache_lf_step = 40
         self.fastercache_hf_step = 30
         self.fastercache_device = "cuda"
+        self.fastercache_num_blocks_to_cache = len(self.transformer_blocks)
         self.attention_mode = attention_mode
 
     def _set_gradient_checkpointing(self, module, value=False):
@@ -577,7 +587,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         # 2. Patch embedding
         p = self.config.patch_size
         p_t = self.config.patch_size_t
 
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
         hidden_states = self.embedding_dropout(hidden_states)
@@ -597,7 +607,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     image_rotary_emb=image_rotary_emb,
                     video_flow_feature=video_flow_features[i][:1] if video_flow_features is not None else None,
                     fuser = self.fuser_list[i] if self.fuser_list is not None else None,
+                    block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
                     fastercache_counter = self.fastercache_counter,
+                    fastercache_start_step = self.fastercache_start_step,
                     fastercache_device = self.fastercache_device
                 )
@@ -665,7 +677,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     image_rotary_emb=image_rotary_emb,
                     video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
                     fuser = self.fuser_list[i] if self.fuser_list is not None else None,
+                    block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
                     fastercache_counter = self.fastercache_counter,
+                    fastercache_start_step = self.fastercache_start_step,
                     fastercache_device = self.fastercache_device
                 )
 
             #has_nan = torch.isnan(hidden_states).any()
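One detail worth noting: the gate `i <= self.fastercache_num_blocks_to_cache` is inclusive, so a setting of N caches blocks 0 through N, i.e. N + 1 blocks. The constructor default of `len(self.transformer_blocks)` therefore keeps every block cached (matching the previous behavior), as does the node default of 42 for the 42-block 5b model.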

Changed file 2 of 2

@@ -452,6 +452,8 @@ class CogVideoImageInterpolationEncode:
             "optional": {
                 "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
                 "mask": ("MASK", ),
+                "vae_override" : ("VAE", {"default": None, "tooltip": "Override the VAE model in the pipeline"}),
             },
         }
@@ -460,14 +462,21 @@ class CogVideoImageInterpolationEncode:
     FUNCTION = "encode"
     CATEGORY = "CogVideoWrapper"
 
-    def encode(self, pipeline, start_image, end_image, chunk_size=8, enable_tiling=False, mask=None):
+    def encode(self, pipeline, start_image, end_image, enable_tiling=False, mask=None, vae_override=None):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         generator = torch.Generator(device=device).manual_seed(0)
 
         B, H, W, C = start_image.shape
 
-        vae = pipeline["pipe"].vae
+        vae = pipeline["pipe"].vae if vae_override is None else vae_override
+
+        model_name = pipeline.get("model_name", "")
+        if ("1.5" in model_name or "1_5" in model_name):
+            vae_scaling_factor = 1 / vae.config.scaling_factor
+        else:
+            vae_scaling_factor = vae.config.scaling_factor
+
         vae.enable_slicing()
 
         if enable_tiling:
@@ -500,8 +509,8 @@ class CogVideoImageInterpolationEncode:
         latents_list = []
         # Encode the chunk of images
-        start_latents = vae.encode(start_image).latent_dist.sample(generator) * vae.config.scaling_factor
-        end_latents = vae.encode(end_image).latent_dist.sample(generator) * vae.config.scaling_factor
+        start_latents = vae.encode(start_image).latent_dist.sample(generator) * vae_scaling_factor
+        end_latents = vae.encode(end_image).latent_dist.sample(generator) * vae_scaling_factor
 
         start_latents = start_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
         end_latents = end_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
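The 1.5-series branch above uses the reciprocal of the VAE scaling factor, so latents from this encode path are divided rather than multiplied by it, while the two `vae.encode` calls apply `vae_scaling_factor` uniformly. A hedged sketch of the selection, with an illustrative factor value (in practice it comes from `vae.config.scaling_factor`):

    is_1_5_model = True    # i.e. "1.5" or "1_5" appears in the model name
    scaling_factor = 0.7   # illustrative stand-in for vae.config.scaling_factor
    vae_scaling_factor = 1 / scaling_factor if is_1_5_model else scaling_factor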
@@ -769,6 +778,7 @@ class CogVideoXFasterCache:
                 "hf_step": ("INT", {"default": 30, "min": 0, "max": 1024, "step": 1}),
                 "lf_step": ("INT", {"default": 40, "min": 0, "max": 1024, "step": 1}),
                 "cache_device": (["main_device", "offload_device", "cuda:1"], {"default": "main_device", "tooltip": "The device to use for the cache, main_device is on GPU and uses a lot of VRAM"}),
+                "num_blocks_to_cache": ("INT", {"default": 42, "min": 0, "max": 1024, "step": 1, "tooltip": "Number of transformer blocks to cache, 5b model has 42 blocks, tradeoff between speed and memory"}),
             },
         }
@@ -777,7 +787,7 @@ class CogVideoXFasterCache:
     FUNCTION = "args"
     CATEGORY = "CogVideoWrapper"
 
-    def args(self, start_step, hf_step, lf_step, cache_device):
+    def args(self, start_step, hf_step, lf_step, cache_device, num_blocks_to_cache):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         if cache_device == "cuda:1":
@@ -786,7 +796,8 @@ class CogVideoXFasterCache:
             "start_step" : start_step,
             "hf_step" : hf_step,
             "lf_step" : lf_step,
-            "cache_device" : device if cache_device != "offload_device" else offload_device
+            "cache_device" : device if cache_device != "offload_device" else offload_device,
+            "num_blocks_to_cache" : num_blocks_to_cache,
         }
         return (fastercache,)
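Put together, the node now returns a dict shaped like the sketch below. The `hf_step`, `lf_step`, and `num_blocks_to_cache` defaults come from the input definitions above; `start_step`'s default of 15 is assumed from the block signature in the other file, and `cache_device` is a resolved torch device rather than a string in practice. The sampler hunk further below copies each key onto the transformer.

    fastercache = {
        "start_step": 15,
        "hf_step": 30,
        "lf_step": 40,
        "cache_device": "cuda",  # placeholder; actually a torch device object
        "num_blocks_to_cache": 42,
    }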
@@ -832,20 +843,25 @@ class CogVideoSampler:
         mm.soft_empty_cache()
 
         base_path = pipeline["base_path"]
+        model_name = pipeline.get("model_name", "")
+        supports_image_conds = True if "I2V" in model_name or "interpolation" in model_name.lower() else False
         assert "fun" not in base_path.lower(), "'Fun' models not supported in 'CogVideoSampler', use the 'CogVideoXFunSampler'"
         assert (
-            "I2V" not in pipeline.get("model_name", "") or
-            "1.5" in pipeline.get("model_name", "") or
-            "1_5" in pipeline.get("model_name", "") or
+            "I2V" not in model_name or
+            "1.5" in model_name or
+            "1_5" in model_name or
             num_frames == 49 or
             context_options is not None
         ), "1.0 I2V model can only do 49 frames"
 
         if image_cond_latents is not None:
-            assert image_cond_latents["samples"].shape[0] == 1, "Image condition latents must be a single latent"
-            assert "I2V" in pipeline.get("model_name", ""), "Image condition latents only supported for I2V models"
+            assert supports_image_conds, "Image condition latents only supported for I2V and Interpolation models"
+            if "I2V" in model_name:
+                assert image_cond_latents["samples"].shape[1] == 1, "I2V model only supports single image condition latent"
+            elif "interpolation" in model_name.lower():
+                assert image_cond_latents["samples"].shape[1] == 2, "Interpolation model needs two image condition latents"
         else:
-            assert "I2V" not in pipeline.get("model_name", ""), "Image condition latents required for I2V models"
+            assert not supports_image_conds, "Image condition latents required for I2V models"
 
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
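Note that the conditioning checks above now index dimension 1 instead of dimension 0: with the `B, T, C, H, W` latent layout noted in the permute comments earlier in this file, that validates the number of conditioning frames rather than the batch size, so I2V models must receive exactly one frame and interpolation models exactly two (start and end).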
@@ -885,6 +901,8 @@ class CogVideoSampler:
             pipe.transformer.fastercache_lf_step = fastercache["lf_step"]
             pipe.transformer.fastercache_hf_step = fastercache["hf_step"]
             pipe.transformer.fastercache_device = fastercache["cache_device"]
+            pipe.transformer.fastercache_num_blocks_to_cache = fastercache["num_blocks_to_cache"]
+            log.info(f"FasterCache enabled for {pipe.transformer.fastercache_num_blocks_to_cache} blocks out of {len(pipe.transformer.transformer_blocks)}")
         else:
             pipe.transformer.use_fastercache = False
             pipe.transformer.fastercache_counter = 0
@@ -1001,6 +1019,8 @@ class CogVideoDecode:
         latents = samples["samples"]
         vae = pipeline["pipe"].vae if vae_override is None else vae_override
+        additional_frames = getattr(pipeline["pipe"], "additional_frames", 0)
+
         vae.enable_slicing()
 
         if not pipeline["cpu_offloading"]:
@@ -1024,7 +1044,8 @@ class CogVideoDecode:
                 vae._clear_fake_context_parallel_cache()
             except:
                 pass
-        frames = vae.decode(latents[:, :, pipeline["pipe"].additional_frames:]).sample
+
+        frames = vae.decode(latents[:, :, additional_frames:]).sample
         vae.disable_tiling()
         if not pipeline["cpu_offloading"]:
             vae.to(offload_device)
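Reading `additional_frames` once via `getattr` with a default of 0 guards against pipelines that never set the attribute (the old inline `pipeline["pipe"].additional_frames` would raise `AttributeError` in that case), and a zero offset leaves the latents untouched. A minimal check of that equivalence:

    import torch

    latents = torch.randn(1, 16, 13, 60, 90)  # illustrative latent tensor
    additional_frames = getattr(object(), "additional_frames", 0)  # attribute absent -> 0
    assert torch.equal(latents[:, :, additional_frames:], latents)  # zero offset is a no-op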