Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
Synced 2026-05-04 22:46:47 +08:00

Commit 7b80e61e36 (parent 8457fa7a4d): initial 5B support

nodes.py (40 lines changed)
@@ -16,6 +16,12 @@ class DownloadAndLoadCogVideoModel:
     def INPUT_TYPES(s):
         return {
             "required": {
+                "model": (
+                    [
+                        "THUDM/CogVideoX-2b",
+                        "THUDM/CogVideoX-5b",
+                    ],
+                ),
 
             },
             "optional": {
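The new "model" entry follows ComfyUI's list-input convention: a tuple whose first element is a list of strings renders as a dropdown, and the selected string is passed by name to the node's FUNCTION. A minimal sketch of that convention (the class below is illustrative, not part of the wrapper):

    class ModelPickerSketch:
        @classmethod
        def INPUT_TYPES(s):
            return {
                "required": {
                    # a list of strings becomes a dropdown widget in the UI
                    "model": (["THUDM/CogVideoX-2b", "THUDM/CogVideoX-5b"],),
                }
            }

        RETURN_TYPES = ("STRING",)
        FUNCTION = "pick"
        CATEGORY = "CogVideoWrapper"

        def pick(self, model):
            # receives the selected repo id, e.g. "THUDM/CogVideoX-5b"
            return (model,)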
@@ -35,21 +41,24 @@ class DownloadAndLoadCogVideoModel:
     FUNCTION = "loadmodel"
     CATEGORY = "CogVideoWrapper"
 
-    def loadmodel(self, precision):
+    def loadmodel(self, model, precision):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         mm.soft_empty_cache()
 
         dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
 
-        base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideo2B")
+        if "2b" in model:
+            base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideo2B")
+        elif "5b" in model:
+            base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideoX-5b")
 
         if not os.path.exists(base_path):
             log.info(f"Downloading model to: {base_path}")
             from huggingface_hub import snapshot_download
 
             snapshot_download(
-                repo_id="THUDM/CogVideoX-2b",
+                repo_id=model,
                 ignore_patterns=["*text_encoder*"],
                 local_dir=base_path,
                 local_dir_use_symlinks=False,
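Taken together, the loadmodel changes route the dropdown choice to a per-model local directory and reuse the same string as the Hugging Face repo id. A hedged sketch of that flow (function names are illustrative; the paths and download arguments mirror the diff; the ValueError fallback is an addition for the sketch, not wrapper behavior):

    import os

    def resolve_base_path(model: str, models_dir: str) -> str:
        # substring checks, exactly as in the new loadmodel()
        if "2b" in model:
            return os.path.join(models_dir, "CogVideo", "CogVideo2B")
        elif "5b" in model:
            return os.path.join(models_dir, "CogVideo", "CogVideoX-5b")
        raise ValueError(f"unrecognized model: {model}")

    def ensure_model(model: str, models_dir: str) -> str:
        # download once, skipping the text encoder weights as above
        base_path = resolve_base_path(model, models_dir)
        if not os.path.exists(base_path):
            from huggingface_hub import snapshot_download
            snapshot_download(
                repo_id=model,
                ignore_patterns=["*text_encoder*"],
                local_dir=base_path,
                local_dir_use_symlinks=False,
            )
        return base_path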
@@ -199,14 +208,14 @@ class CogVideoSampler:
                 "negative": ("CONDITIONING", ),
                 "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
                 "width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}),
-                "num_frames": ("INT", {"default": 48, "min": 8, "max": 1024, "step": 8}),
+                "num_frames": ("INT", {"default": 48, "min": 8, "max": 1024, "step": 1}),
                 "fps": ("INT", {"default": 8, "min": 1, "max": 100, "step": 1}),
                 "steps": ("INT", {"default": 25, "min": 1}),
                 "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
                 "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                 "scheduler": (["DDIM", "DPM"],),
-                "t_tile_length": ("INT", {"default": 16, "min": 16, "max": 128, "step": 4}),
-                "t_tile_overlap": ("INT", {"default": 8, "min": 8, "max": 128, "step": 2}),
+                "t_tile_length": ("INT", {"default": 16, "min": 2, "max": 128, "step": 1}),
+                "t_tile_overlap": ("INT", {"default": 8, "min": 2, "max": 128, "step": 1}),
             },
             "optional": {
                 "samples": ("LATENT", ),
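Loosening num_frames from step 8 to step 1 matters because the frame count only reaches the model after temporal compression: in diffusers' CogVideoX pipeline the latent frame count is (num_frames - 1) // 4 + 1, so nearby num_frames values can map to the same latent length. A quick check of that arithmetic (the factor 4 is the CogVideoX VAE's temporal scale):

    def latent_frame_count(num_frames: int, vae_scale_factor_temporal: int = 4) -> int:
        # mirrors diffusers' (num_frames - 1) // temporal_scale + 1
        return (num_frames - 1) // vae_scale_factor_temporal + 1

    assert latent_frame_count(48) == 12   # this node's default
    assert latent_frame_count(49) == 13   # CogVideoX's native 49-frame clips

The relaxed t_tile_length/t_tile_overlap minimums likewise apply to latent frames rather than pixel-space frames, since the tiling code later in this commit slices the latent temporal axis.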
@@ -276,10 +285,10 @@ class CogVideoDecode:
 
     RETURN_TYPES = ("IMAGE",)
     RETURN_NAMES = ("images",)
-    FUNCTION = "process"
+    FUNCTION = "decode"
     CATEGORY = "CogVideoWrapper"
 
-    def process(self, pipeline, samples):
+    def decode(self, pipeline, samples):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         latents = samples["samples"]
@@ -299,19 +308,20 @@
 
         frames = []
         pbar = ProgressBar(num_seconds)
-        for i in range(num_seconds):
-            start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
-            current_frames = vae.decode(latents[:, :, start_frame:end_frame]).sample
-            frames.append(current_frames)
+        # for i in range(num_seconds):
+        #     start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
+        #     current_frames = vae.decode(latents[:, :, start_frame:end_frame]).sample
+        #     frames.append(current_frames)
 
-            pbar.update(1)
-        vae.clear_fake_context_parallel_cache()
+        #     pbar.update(1)
+        frames = vae.decode(latents).sample
         vae.to(offload_device)
         mm.soft_empty_cache()
 
-        frames = torch.cat(frames, dim=2)
+        #frames = torch.cat(frames, dim=2)
         video = pipeline["pipe"].video_processor.postprocess_video(video=frames, output_type="pt")
         video = video[0].permute(0, 2, 3, 1).cpu().float()
+        print(video.min(), video.max())
 
         return (video,)
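The rename to decode and the one-shot vae.decode(latents) replace the per-second chunked loop (and the clear_fake_context_parallel_cache() call that went with it), presumably relying on the newer diffusers VAE to handle the full sequence internally. The tail of the function then converts the pipeline's tensor layout into ComfyUI's IMAGE batch; a hedged shape walk-through, assuming postprocess_video(..., output_type="pt") returns [B, F, C, H, W] as in diffusers' VideoProcessor:

    import torch

    # illustrative output tensor: one 48-frame 480x720 RGB clip
    video = torch.rand(1, 48, 3, 480, 720)
    images = video[0].permute(0, 2, 3, 1).cpu().float()
    assert images.shape == (48, 480, 720, 3)  # ComfyUI IMAGE: [F, H, W, C]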
pipeline_cogvideox.py

@@ -17,6 +17,7 @@ import inspect
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
+import math
 
 from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
@@ -24,11 +25,29 @@ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
 from diffusers.utils import logging
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers.video_processor import VideoProcessor
+from diffusers.models.embeddings import get_3d_rotary_pos_embed
 
 from comfy.utils import ProgressBar
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+    tw = tgt_width
+    th = tgt_height
+    h, w = src
+    r = h / w
+    if r > (th / tw):
+        resize_height = th
+        resize_width = int(round(th / h * w))
+    else:
+        resize_width = tw
+        resize_height = int(round(tw / w * h))
+
+    crop_top = int(round((th - resize_height) / 2.0))
+    crop_left = int(round((tw - resize_width) / 2.0))
+
+    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
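get_resize_crop_region_for_grid mirrors the helper in diffusers' CogVideoX pipeline: it fits the source grid into the target box by aspect ratio, then center-crops. Two worked calls whose values follow directly from the code above:

    # a 30x45 grid against a 45-wide, 30-tall target fills it exactly
    assert get_resize_crop_region_for_grid((30, 45), 45, 30) == ((0, 0), (30, 45))

    # a square 45x45 grid is scaled to 30x30, then centered horizontally
    assert get_resize_crop_region_for_grid((45, 45), 45, 30) == ((0, 8), (30, 38))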
@@ -228,6 +247,46 @@ class CogVideoXPipeline(DiffusionPipeline):
         weights = torch.tensor(t_probs)
         weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1)
         return weights
 
+    def fuse_qkv_projections(self) -> None:
+        r"""Enables fused QKV projections."""
+        self.fusing_transformer = True
+        self.transformer.fuse_qkv_projections()
+
+    def unfuse_qkv_projections(self) -> None:
+        r"""Disable QKV projection fusion if enabled."""
+        if not self.fusing_transformer:
+            logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
+        else:
+            self.transformer.unfuse_qkv_projections()
+            self.fusing_transformer = False
+
+    def _prepare_rotary_positional_embeddings(
+        self,
+        height: int,
+        width: int,
+        num_frames: int,
+        device: torch.device,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+
+        grid_crops_coords = get_resize_crop_region_for_grid(
+            (grid_height, grid_width), base_size_width, base_size_height
+        )
+        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+            embed_dim=self.transformer.config.attention_head_dim,
+            crops_coords=grid_crops_coords,
+            grid_size=(grid_height, grid_width),
+            temporal_size=num_frames,
+            use_real=True,
+        )
+
+        freqs_cos = freqs_cos.to(device=device)
+        freqs_sin = freqs_sin.to(device=device)
+        return freqs_cos, freqs_sin
+
     @property
     def guidance_scale(self):
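_prepare_rotary_positional_embeddings is the piece the 5b checkpoint actually needs: CogVideoX-5b enables use_rotary_positional_embeddings in its transformer config, while 2b does not. With the config values these models ship with (assumed here: VAE spatial scale 8, patch size 2), the grid arithmetic reduces to fixed numbers at the default 480x720 resolution:

    # hedged arithmetic check with assumed CogVideoX config values
    vae_scale_factor_spatial, patch_size = 8, 2
    denom = vae_scale_factor_spatial * patch_size        # 16
    grid_height, grid_width = 480 // denom, 720 // denom
    assert (grid_height, grid_width) == (30, 45)
    # the hard-coded 720/480 base sizes give an identical base grid, so at
    # the default resolution the crop region spans the full patch grid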
@@ -374,6 +433,15 @@ class CogVideoXPipeline(DiffusionPipeline):
         t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(latents.dtype)
         print("latents.shape", latents.shape)
         print("latents.device", latents.device)
 
+
+        # 6.5. Create rotary embeds if required
+        image_rotary_emb = (
+            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+            if self.transformer.config.use_rotary_positional_embeddings
+            else None
+        )
+
         # 7. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         comfy_pbar = ProgressBar(num_inference_steps)
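Note the frame count handed to the rotary builder is latents.size(1): this pipeline keeps latents as [batch, frames, channels, height, width], which is also why the temporal tiling below slices dim 1. A small shape sketch under that assumption:

    import torch

    # illustrative latent tensor: 13 latent frames of 16 channels at 60x90
    # (480x720 divided by the VAE's spatial factor of 8)
    latents = torch.zeros(1, 13, 16, 60, 90)
    num_latent_frames = latents.size(1)   # 13, the temporal_size for RoPE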
@@ -383,94 +451,125 @@
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
-                #temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
-                # =====================================================
-                grid_ts = 0
-                cur_t = 0
-                while cur_t < latents.shape[1]:
-                    cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
-                    grid_ts += 1
-
-                all_t = latents.shape[1]
-                latents_all_list = []
-                # =====================================================
-
-                for t_i in range(grid_ts):
-                    if t_i < grid_ts - 1:
-                        ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
-                    if t_i == grid_ts - 1:
-                        ofs_t = all_t - t_tile_length
-
-                    input_start_t = ofs_t
-                    input_end_t = ofs_t + t_tile_length
-
-                    #latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                    #latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                    latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
-                    latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
-                    latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
-
-                    #t_input = t[None].to(device)
-                    t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-
-                    # predict noise model_output
-                    noise_pred = self.transformer(
-                        hidden_states=latent_model_input_tile,
-                        encoder_hidden_states=prompt_embeds,
-                        timestep=t_input,
-                        return_dict=False,
-                    )[0]
-                    noise_pred = noise_pred.float()
-
-                    if self.do_classifier_free_guidance:
-                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                    # compute the previous noisy sample x_t -> x_t-1
-                    if not isinstance(self.scheduler, CogVideoXDPMScheduler):
-                        latents_tile = self.scheduler.step(noise_pred, t, latents_tile, **extra_step_kwargs, return_dict=False)[0]
-                    else:
-                        raise NotImplementedError("DPM is not supported with temporal tiling")
-                    # else:
-                    #     latents_tile, old_pred_original_sample = self.scheduler.step(
-                    #         noise_pred,
-                    #         old_pred_original_sample,
-                    #         t,
-                    #         t_input[t_i - 1] if t_i > 0 else None,
-                    #         latents_tile,
-                    #         **extra_step_kwargs,
-                    #         return_dict=False,
-                    #     )
-
-                    latents_all_list.append(latents_tile)
-
-                # ==========================================
-                latents_all = torch.zeros(latents.shape, device=latents.device, dtype=latents.dtype)
-                contributors = torch.zeros(latents.shape, device=latents.device, dtype=latents.dtype)
-                # Add each tile contribution to overall latents
-                for t_i in range(grid_ts):
-                    if t_i < grid_ts - 1:
-                        ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
-                    if t_i == grid_ts - 1:
-                        ofs_t = all_t - t_tile_length
-
-                    input_start_t = ofs_t
-                    input_end_t = ofs_t + t_tile_length
-
-                    latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights
-                    contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights
-
-                latents_all /= contributors
-
-                latents = latents_all
-                # ==========================================
-
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-                    comfy_pbar.update(1)
+                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+                    #temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
+                    # =====================================================
+                    grid_ts = 0
+                    cur_t = 0
+                    while cur_t < latents.shape[1]:
+                        cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
+                        grid_ts += 1
+
+                    all_t = latents.shape[1]
+                    latents_all_list = []
+                    # =====================================================
+
+                    for t_i in range(grid_ts):
+                        if t_i < grid_ts - 1:
+                            ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
+                        if t_i == grid_ts - 1:
+                            ofs_t = all_t - t_tile_length
+
+                        input_start_t = ofs_t
+                        input_end_t = ofs_t + t_tile_length
+
+                        #latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                        #latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                        latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
+                        latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
+                        latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
+
+                        #t_input = t[None].to(device)
+                        t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+
+                        # predict noise model_output
+                        noise_pred = self.transformer(
+                            hidden_states=latent_model_input_tile,
+                            encoder_hidden_states=prompt_embeds,
+                            timestep=t_input,
+                            image_rotary_emb=image_rotary_emb,
+                            return_dict=False,
+                        )[0]
+                        noise_pred = noise_pred.float()
+
+                        if self.do_classifier_free_guidance:
+                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                        # compute the previous noisy sample x_t -> x_t-1
+                        latents_tile = self.scheduler.step(noise_pred, t, latents_tile, **extra_step_kwargs, return_dict=False)[0]
+                        latents_all_list.append(latents_tile)
+
+                    # ==========================================
+                    latents_all = torch.zeros(latents.shape, device=latents.device, dtype=latents.dtype)
+                    contributors = torch.zeros(latents.shape, device=latents.device, dtype=latents.dtype)
+                    # Add each tile contribution to overall latents
+                    for t_i in range(grid_ts):
+                        if t_i < grid_ts - 1:
+                            ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
+                        if t_i == grid_ts - 1:
+                            ofs_t = all_t - t_tile_length
+
+                        input_start_t = ofs_t
+                        input_end_t = ofs_t + t_tile_length
+
+                        latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights
+                        contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights
+
+                    latents_all /= contributors
+
+                    latents = latents_all
+
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
+                        comfy_pbar.update(1)
+                    # ==========================================
+                else:
+                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                    timestep = t.expand(latent_model_input.shape[0])
+
+                    # predict noise model_output
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        image_rotary_emb=image_rotary_emb,
+                        return_dict=False,
+                    )[0]
+                    noise_pred = noise_pred.float()
+
+                    self._guidance_scale = 1 + guidance_scale * (
+                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+                    )
+
+                    if do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+                        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                    else:
+                        latents, old_pred_original_sample = self.scheduler.step(
+                            noise_pred,
+                            old_pred_original_sample,
+                            t,
+                            timesteps[i - 1] if i > 0 else None,
+                            latents,
+                            **extra_step_kwargs,
+                            return_dict=False,
+                        )
+                    latents = latents.to(prompt_embeds.dtype)
+
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
+                        comfy_pbar.update(1)
 
         # Offload all models
         self.maybe_free_model_hooks()
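The tiling math above is easiest to sanity-check in isolation: the while-loop counts how many length-L windows with overlap O are needed to cover all_t latent frames, the last window is right-aligned, and the weighted accumulation divided by contributors is an exact partition of unity wherever tiles overlap. The new else branch, by contrast, runs untiled DPM steps with a cosine-ramped guidance scale (which is what `import math` was added for). A self-contained sketch of the tiling arithmetic (the linspace weights are a stand-in for _gaussian_weights):

    import torch

    def tile_offsets(all_t: int, t_tile_length: int, t_tile_overlap: int):
        # same counting loop and offset rules as the pipeline code above
        grid_ts, cur_t = 0, 0
        while cur_t < all_t:
            cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
            grid_ts += 1
        offsets = []
        for t_i in range(grid_ts):
            ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
            if t_i == grid_ts - 1:
                ofs_t = all_t - t_tile_length   # last tile is right-aligned
            offsets.append(ofs_t)
        return offsets

    all_t, L, O = 13, 12, 8
    weights = torch.linspace(0.1, 1.0, L)        # stand-in tile weights
    acc, contrib = torch.zeros(all_t), torch.zeros(all_t)
    for ofs in tile_offsets(all_t, L, O):        # -> [0, 1]
        acc[ofs:ofs + L] += torch.ones(L) * weights   # tile result * weights
        contrib[ofs:ofs + L] += weights
    assert torch.allclose(acc / contrib, torch.ones(all_t))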
requirements.txt

@@ -1,2 +1,2 @@
 huggingface_hub
-diffusers>=0.30.0
+diffusers>=0.30.1