mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
cleanup

parent 49e42c5dc6
commit 7f3a768d95
nodes.py: 16 changed lines
@@ -121,20 +121,20 @@ class CogVideoPABConfig:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {
-            "spatial_broadcast": ("BOOLEAN", {"default": True, "tooltip": "Enable Spatial PAB"}),
+            "spatial_broadcast": ("BOOLEAN", {"default": True, "tooltip": "Enable Spatial PAB, highest impact"}),
             "spatial_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
             "spatial_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
-            "spatial_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
-            "temporal_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Temporal PAB"}),
+            "spatial_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
+            "temporal_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Temporal PAB, medium impact"}),
             "temporal_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
             "temporal_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
-            "temporal_range": ("INT", {"default": 4, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
-            "cross_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Cross Attention PAB"}),
+            "temporal_range": ("INT", {"default": 4, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
+            "cross_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Cross Attention PAB, low impact"}),
             "cross_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
             "cross_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
-            "cross_range": ("INT", {"default": 6, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
+            "cross_range": ("INT", {"default": 6, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),

-            "steps": ("INT", {"default": 50, "min": 0, "max": 1000, "tooltip": "Steps"} ),
+            "steps": ("INT", {"default": 50, "min": 0, "max": 1000, "tooltip": "Should match the sampling steps"} ),
             }
         }
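For readers unfamiliar with the widget format above: in ComfyUI, each entry of the INPUT_TYPES "required" dict maps an input name to a (type, options) tuple, and the options (default, min, max, tooltip) drive the UI widget. The following is a minimal, hypothetical node written in the same pattern; the class and output names are made up for illustration and this is not the wrapper's actual implementation.

# Hypothetical ComfyUI node illustrating the INPUT_TYPES pattern used by
# CogVideoPABConfig above. Only two of the PAB inputs are shown.
class ExamplePABConfig:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            # (type, options): options set the widget default, bounds and tooltip
            "spatial_broadcast": ("BOOLEAN", {"default": True, "tooltip": "Enable Spatial PAB"}),
            "spatial_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"}),
        }}

    RETURN_TYPES = ("PAB_CONFIG",)   # made-up socket type for this sketch
    RETURN_NAMES = ("pab_config",)
    FUNCTION = "config"
    CATEGORY = "Examples"

    def config(self, spatial_broadcast, spatial_range):
        # Collect the widget values into a single output tuple that downstream
        # nodes can consume; the real node builds a proper PAB config object.
        return ({"spatial_broadcast": spatial_broadcast, "spatial_range": spatial_range},)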
@@ -142,7 +142,7 @@ class CogVideoPABConfig:
     RETURN_NAMES = ("pab_config", )
     FUNCTION = "config"
     CATEGORY = "CogVideoWrapper"
-    DESCRIPTION = "EXPERIMENTAL:Pyramid Attention Broadcast (PAB) speeds up inference by mitigating redundant attention computation"
+    DESCRIPTION = "EXPERIMENTAL:Pyramid Attention Broadcast (PAB) speeds up inference by mitigating redundant attention computation. Increases memory use"

     def config(self, spatial_broadcast, spatial_threshold_start, spatial_threshold_end, spatial_range,
                temporal_broadcast, temporal_threshold_start, temporal_threshold_end, temporal_range,
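To make the threshold and range inputs concrete: PAB reuses ("broadcasts") a cached attention output for a few consecutive denoising steps inside the timestep window bounded by threshold_end and threshold_start, recomputing it only every `range` steps, which is where both the speedup and the extra memory for cached tensors come from. The sketch below only illustrates that scheduling idea; should_reuse_attention and the cache variable are hypothetical names, not the wrapper's or VideoSys' actual API.

# Illustrative sketch of the broadcast decision in Pyramid Attention Broadcast.
def should_reuse_attention(timestep, steps_since_compute,
                           threshold_start=850, threshold_end=100, broadcast_range=2):
    # Reuse only inside the configured timestep window...
    in_window = threshold_end < timestep < threshold_start
    # ...and only for up to `broadcast_range` steps after a real computation.
    return in_window and steps_since_compute < broadcast_range

cache, steps_since_compute = None, 0
for timestep in range(980, -1, -20):  # dummy denoising schedule, high -> low
    if cache is not None and should_reuse_attention(timestep, steps_since_compute):
        attn_out = cache                      # broadcast: reuse cached attention
        steps_since_compute += 1
    else:
        attn_out = f"attention@{timestep}"    # stand-in for the real attention pass
        cache, steps_since_compute = attn_out, 0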
@ -1,32 +0,0 @@
|
||||
import logging
|
||||
|
||||
import torch.distributed as dist
|
||||
from rich.logging import RichHandler
|
||||
|
||||
|
||||
def create_logger():
|
||||
"""
|
||||
Create a logger that writes to a log file and stdout.
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
return logger
|
||||
|
||||
|
||||
def init_dist_logger():
|
||||
"""
|
||||
Update the logger to write to a log file.
|
||||
"""
|
||||
global logger
|
||||
if dist.get_rank() == 0:
|
||||
logger = logging.getLogger(__name__)
|
||||
handler = RichHandler(show_path=False, markup=True, rich_tracebacks=True)
|
||||
formatter = logging.Formatter("VideoSys - %(levelname)s: %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(logging.INFO)
|
||||
else: # dummy logger (does nothing)
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.addHandler(logging.NullHandler())
|
||||
|
||||
|
||||
logger = create_logger()
|
||||
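The removed logging helpers follow the common rank-0-only pattern for multi-GPU runs: every process gets a logger, but only rank 0 attaches a real handler, so log lines are not duplicated across ranks. A minimal usage sketch under that assumption (dist_logging is a placeholder module name, since the deleted file's path is not shown):

import torch.distributed as dist
import dist_logging  # placeholder name for the module deleted above

dist.init_process_group(backend="nccl")  # assumes a torchrun-style distributed launch
dist_logging.init_dist_logger()          # rank 0 gets a RichHandler, others a NullHandler
dist_logging.logger.info("only rank 0 prints this")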
(deleted file)
@@ -1,92 +0,0 @@
-import os
-import random
-
-import imageio
-import numpy as np
-import torch
-import torch.distributed as dist
-from omegaconf import DictConfig, ListConfig, OmegaConf
-
-
-def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
-    """
-    Set requires_grad flag for all parameters in a model.
-    """
-    for p in model.parameters():
-        p.requires_grad = flag
-
-
-def set_seed(seed, dp_rank=None):
-    if seed == -1:
-        seed = random.randint(0, 1000000)
-
-    if dp_rank is not None:
-        seed = torch.tensor(seed, dtype=torch.int64).cuda()
-        if dist.get_world_size() > 1:
-            dist.broadcast(seed, 0)
-        seed = seed + dp_rank
-
-    seed = int(seed)
-    random.seed(seed)
-    os.environ["PYTHONHASHSEED"] = str(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-
-
-def str_to_dtype(x: str):
-    if x == "fp32":
-        return torch.float32
-    elif x == "fp16":
-        return torch.float16
-    elif x == "bf16":
-        return torch.bfloat16
-    else:
-        raise RuntimeError(f"Only fp32, fp16 and bf16 are supported, but got {x}")
-
-
-def batch_func(func, *args):
-    """
-    Apply a function to each element of a batch.
-    """
-    batch = []
-    for arg in args:
-        if isinstance(arg, torch.Tensor) and arg.shape[0] == 2:
-            batch.append(func(arg))
-        else:
-            batch.append(arg)
-
-    return batch
-
-
-def merge_args(args1, args2):
-    """
-    Merge two argparse Namespace objects.
-    """
-    if args2 is None:
-        return args1
-
-    for k in args2._content.keys():
-        if k in args1.__dict__:
-            v = getattr(args2, k)
-            if isinstance(v, ListConfig) or isinstance(v, DictConfig):
-                v = OmegaConf.to_object(v)
-            setattr(args1, k, v)
-        else:
-            raise RuntimeError(f"Unknown argument {k}")
-
-    return args1
-
-
-def all_exists(paths):
-    return all(os.path.exists(path) for path in paths)
-
-
-def save_video(video, output_path, fps):
-    """
-    Save a video to disk.
-    """
-    if dist.is_initialized() and dist.get_rank() != 0:
-        return
-    os.makedirs(os.path.dirname(output_path), exist_ok=True)
-    imageio.mimwrite(output_path, video, fps=fps)
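One behavior worth noting in the removed set_seed: when dp_rank is provided, the base seed is broadcast from rank 0 and then offset by the data-parallel rank, so each rank seeds its RNGs differently but reproducibly. A single-process sketch of that offsetting idea (offset_seed is an illustrative name, not part of the removed file, and no process group is needed here):

import random
import numpy as np
import torch

def offset_seed(base_seed: int, dp_rank: int) -> int:
    # Same idea as the removed set_seed: a shared base seed shifted per rank so
    # data-parallel workers draw different yet reproducible random streams.
    if base_seed == -1:
        base_seed = random.randint(0, 1000000)
    return base_seed + dp_rank

for rank in range(4):
    seed = offset_seed(42, rank)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    print(f"rank {rank} -> seed {seed}")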