fix VAE scaling (again)

This commit is contained in:
kijai 2024-11-13 15:37:45 +02:00
parent 34b650c785
commit e8a289112f
2 changed files with 6 additions and 9 deletions

View File

@ -350,8 +350,6 @@ class DownloadAndLoadCogVideoGGUFModel:
def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload,
pab_config=None, block_edit=None, compile="disabled", attention_mode="sdpa"):
check_diffusers_version()
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
mm.soft_empty_cache()
@ -597,9 +595,6 @@ class DownloadAndLoadToraModel:
DESCRIPTION = "Downloads and loads the the Tora model from Huggingface to 'ComfyUI/models/CogVideo/CogVideoX-5b-Tora'"
def loadmodel(self, model):
check_diffusers_version()
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
mm.soft_empty_cache()

View File

@ -298,7 +298,7 @@ class CogVideoTextEncode:
embeds = clip.encode_from_tokens(tokens, return_pooled=False, return_dict=False)
if embeds.shape[1] > 226:
if embeds.shape[1] > max_tokens:
raise ValueError(f"Prompt is too long, max tokens supported is {max_tokens} or less, got {embeds.shape[1]}")
embeds *= strength
if force_offload:
@ -371,7 +371,7 @@ class CogVideoImageEncode:
model_name = pipeline.get("model_name", "")
if ("1.5" in model_name or "1_5" in model_name) and image.shape[0] == 1:
vae_scaling_factor = 1 / vae.config.scaling_factor
vae_scaling_factor = 1 #/ vae.config.scaling_factor
else:
vae_scaling_factor = vae.config.scaling_factor
@ -599,16 +599,18 @@ class ToraEncodeTrajectory:
vae.to(device)
video_flow = vae.encode(video_flow).latent_dist.sample(generator) * vae.config.scaling_factor
log.info(f"video_flow shape after encoding: {video_flow.shape}") #torch.Size([1, 16, 4, 80, 80])
if not pipeline["cpu_offloading"]:
vae.to(offload_device)
#print("video_flow shape before traj_extractor: ", video_flow.shape) #torch.Size([1, 16, 4, 80, 80])
video_flow_features = tora_model["traj_extractor"](video_flow.to(torch.float32))
video_flow_features = torch.stack(video_flow_features)
#print("video_flow_features after traj_extractor: ", video_flow_features.shape) #torch.Size([42, 4, 128, 40, 40])
video_flow_features = video_flow_features * strength
log.info(f"video_flow shape: {video_flow.shape}")
tora = {
"video_flow_features" : video_flow_features,