allow limiting blocks to cache

This commit is contained in:
kijai 2024-11-12 08:42:01 +02:00
parent 0a121dba53
commit dac6a2a3ac
2 changed files with 67 additions and 32 deletions

View File

@ -235,6 +235,7 @@ class CogVideoXBlock(nn.Module):
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
video_flow_feature: Optional[torch.Tensor] = None,
fuser=None,
block_use_fastercache=False,
fastercache_counter=0,
fastercache_start_step=15,
fastercache_device="cuda:0",
@ -257,18 +258,32 @@ class CogVideoXBlock(nn.Module):
del h, fuser
#region fastercache
B = norm_hidden_states.shape[0]
if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B:
attn_hidden_states = (
self.cached_hidden_states[1][:B] +
(self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B])
* 0.3
).to(norm_hidden_states.device, non_blocking=True)
attn_encoder_hidden_states = (
self.cached_encoder_hidden_states[1][:B] +
(self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B])
* 0.3
).to(norm_hidden_states.device, non_blocking=True)
if block_use_fastercache:
B = norm_hidden_states.shape[0]
if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B:
attn_hidden_states = (
self.cached_hidden_states[1][:B] +
(self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B])
* 0.3
).to(norm_hidden_states.device, non_blocking=True)
attn_encoder_hidden_states = (
self.cached_encoder_hidden_states[1][:B] +
(self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B])
* 0.3
).to(norm_hidden_states.device, non_blocking=True)
else:
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mode=self.attention_mode,
)
if fastercache_counter == fastercache_start_step:
self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)]
elif fastercache_counter > fastercache_start_step:
self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
else:
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
@ -276,12 +291,6 @@ class CogVideoXBlock(nn.Module):
image_rotary_emb=image_rotary_emb,
attention_mode=self.attention_mode,
)
if fastercache_counter == fastercache_start_step:
self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)]
elif fastercache_counter > fastercache_start_step:
self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
hidden_states = hidden_states + gate_msa * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
@ -477,6 +486,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.fastercache_lf_step = 40
self.fastercache_hf_step = 30
self.fastercache_device = "cuda"
self.fastercache_num_blocks_to_cache = len(self.transformer_blocks)
self.attention_mode = attention_mode
def _set_gradient_checkpointing(self, module, value=False):
@ -577,7 +587,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
# 2. Patch embedding
p = self.config.patch_size
p_t = self.config.patch_size_t
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
hidden_states = self.embedding_dropout(hidden_states)
@ -597,7 +607,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
image_rotary_emb=image_rotary_emb,
video_flow_feature=video_flow_features[i][:1] if video_flow_features is not None else None,
fuser = self.fuser_list[i] if self.fuser_list is not None else None,
block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
fastercache_counter = self.fastercache_counter,
fastercache_start_step = self.fastercache_start_step,
fastercache_device = self.fastercache_device
)
@ -665,7 +677,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
image_rotary_emb=image_rotary_emb,
video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
fuser = self.fuser_list[i] if self.fuser_list is not None else None,
block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
fastercache_counter = self.fastercache_counter,
fastercache_start_step = self.fastercache_start_step,
fastercache_device = self.fastercache_device
)
#has_nan = torch.isnan(hidden_states).any()

View File

@ -452,6 +452,8 @@ class CogVideoImageInterpolationEncode:
"optional": {
"enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
"mask": ("MASK", ),
"vae_override" : ("VAE", {"default": None, "tooltip": "Override the VAE model in the pipeline"}),
},
}
@ -460,14 +462,21 @@ class CogVideoImageInterpolationEncode:
FUNCTION = "encode"
CATEGORY = "CogVideoWrapper"
def encode(self, pipeline, start_image, end_image, chunk_size=8, enable_tiling=False, mask=None):
def encode(self, pipeline, start_image, end_image, enable_tiling=False, mask=None, vae_override=None):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
generator = torch.Generator(device=device).manual_seed(0)
B, H, W, C = start_image.shape
vae = pipeline["pipe"].vae
vae = pipeline["pipe"].vae if vae_override is None else vae_override
vae.enable_slicing()
model_name = pipeline.get("model_name", "")
if ("1.5" in model_name or "1_5" in model_name):
vae_scaling_factor = 1 / vae.config.scaling_factor
else:
vae_scaling_factor = vae.config.scaling_factor
vae.enable_slicing()
if enable_tiling:
@ -500,8 +509,8 @@ class CogVideoImageInterpolationEncode:
latents_list = []
# Encode the chunk of images
start_latents = vae.encode(start_image).latent_dist.sample(generator) * vae.config.scaling_factor
end_latents = vae.encode(end_image).latent_dist.sample(generator) * vae.config.scaling_factor
start_latents = vae.encode(start_image).latent_dist.sample(generator) * vae_scaling_factor
end_latents = vae.encode(end_image).latent_dist.sample(generator) * vae_scaling_factor
start_latents = start_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
end_latents = end_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W
@ -769,6 +778,7 @@ class CogVideoXFasterCache:
"hf_step": ("INT", {"default": 30, "min": 0, "max": 1024, "step": 1}),
"lf_step": ("INT", {"default": 40, "min": 0, "max": 1024, "step": 1}),
"cache_device": (["main_device", "offload_device", "cuda:1"], {"default": "main_device", "tooltip": "The device to use for the cache, main_device is on GPU and uses a lot of VRAM"}),
"num_blocks_to_cache": ("INT", {"default": 42, "min": 0, "max": 1024, "step": 1, "tooltip": "Number of transformer blocks to cache, 5b model has 42 blocks, tradeoff between speed and memory"}),
},
}
@ -777,7 +787,7 @@ class CogVideoXFasterCache:
FUNCTION = "args"
CATEGORY = "CogVideoWrapper"
def args(self, start_step, hf_step, lf_step, cache_device):
def args(self, start_step, hf_step, lf_step, cache_device, num_blocks_to_cache):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
if cache_device == "cuda:1":
@ -786,7 +796,8 @@ class CogVideoXFasterCache:
"start_step" : start_step,
"hf_step" : hf_step,
"lf_step" : lf_step,
"cache_device" : device if cache_device != "offload_device" else offload_device
"cache_device" : device if cache_device != "offload_device" else offload_device,
"num_blocks_to_cache" : num_blocks_to_cache,
}
return (fastercache,)
@ -832,20 +843,25 @@ class CogVideoSampler:
mm.soft_empty_cache()
base_path = pipeline["base_path"]
model_name = pipeline.get("model_name", "")
supports_image_conds = True if "I2V" in model_name or "interpolation" in model_name.lower() else False
assert "fun" not in base_path.lower(), "'Fun' models not supported in 'CogVideoSampler', use the 'CogVideoXFunSampler'"
assert (
"I2V" not in pipeline.get("model_name", "") or
"1.5" in pipeline.get("model_name", "") or
"1_5" in pipeline.get("model_name", "") or
"I2V" not in model_name or
"1.5" in model_name or
"1_5" in model_name or
num_frames == 49 or
context_options is not None
), "1.0 I2V model can only do 49 frames"
if image_cond_latents is not None:
assert image_cond_latents["samples"].shape[0] == 1, "Image condition latents must be a single latent"
assert "I2V" in pipeline.get("model_name", ""), "Image condition latents only supported for I2V models"
assert supports_image_conds, "Image condition latents only supported for I2V and Interpolation models"
if "I2V" in model_name:
assert image_cond_latents["samples"].shape[1] == 1, "I2V model only supports single image condition latent"
elif "interpolation" in model_name.lower():
assert image_cond_latents["samples"].shape[1] == 2, "Interpolation model needs two image condition latents"
else:
assert "I2V" not in pipeline.get("model_name", ""), "Image condition latents required for I2V models"
assert not supports_image_conds, "Image condition latents required for I2V models"
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@ -885,6 +901,8 @@ class CogVideoSampler:
pipe.transformer.fastercache_lf_step = fastercache["lf_step"]
pipe.transformer.fastercache_hf_step = fastercache["hf_step"]
pipe.transformer.fastercache_device = fastercache["cache_device"]
pipe.transformer.fastercache_num_blocks_to_cache = fastercache["num_blocks_to_cache"]
log.info(f"FasterCache enabled for {pipe.transformer.fastercache_num_blocks_to_cache} blocks out of {len(pipe.transformer.transformer_blocks)}")
else:
pipe.transformer.use_fastercache = False
pipe.transformer.fastercache_counter = 0
@ -1001,6 +1019,8 @@ class CogVideoDecode:
latents = samples["samples"]
vae = pipeline["pipe"].vae if vae_override is None else vae_override
additional_frames = getattr(pipeline["pipe"], "additional_frames", 0)
vae.enable_slicing()
if not pipeline["cpu_offloading"]:
@ -1024,7 +1044,8 @@ class CogVideoDecode:
vae._clear_fake_context_parallel_cache()
except:
pass
frames = vae.decode(latents[:, :, pipeline["pipe"].additional_frames:]).sample
frames = vae.decode(latents[:, :, additional_frames:]).sample
vae.disable_tiling()
if not pipeline["cpu_offloading"]:
vae.to(offload_device)