mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-08 20:34:23 +08:00
cleanup
This commit is contained in:
parent
49e42c5dc6
commit
7f3a768d95
16
nodes.py
16
nodes.py
@ -121,20 +121,20 @@ class CogVideoPABConfig:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {
|
return {"required": {
|
||||||
"spatial_broadcast": ("BOOLEAN", {"default": True, "tooltip": "Enable Spatial PAB"}),
|
"spatial_broadcast": ("BOOLEAN", {"default": True, "tooltip": "Enable Spatial PAB, highest impact"}),
|
||||||
"spatial_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
|
"spatial_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
|
||||||
"spatial_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
|
"spatial_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
|
||||||
"spatial_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
|
"spatial_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
|
||||||
"temporal_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Temporal PAB"}),
|
"temporal_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Temporal PAB, medium impact"}),
|
||||||
"temporal_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
|
"temporal_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
|
||||||
"temporal_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
|
"temporal_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
|
||||||
"temporal_range": ("INT", {"default": 4, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
|
"temporal_range": ("INT", {"default": 4, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
|
||||||
"cross_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Cross Attention PAB"}),
|
"cross_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Cross Attention PAB, low impact"}),
|
||||||
"cross_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
|
"cross_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
|
||||||
"cross_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
|
"cross_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
|
||||||
"cross_range": ("INT", {"default": 6, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
|
"cross_range": ("INT", {"default": 6, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
|
||||||
|
|
||||||
"steps": ("INT", {"default": 50, "min": 0, "max": 1000, "tooltip": "Steps"} ),
|
"steps": ("INT", {"default": 50, "min": 0, "max": 1000, "tooltip": "Should match the sampling steps"} ),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -142,7 +142,7 @@ class CogVideoPABConfig:
|
|||||||
RETURN_NAMES = ("pab_config", )
|
RETURN_NAMES = ("pab_config", )
|
||||||
FUNCTION = "config"
|
FUNCTION = "config"
|
||||||
CATEGORY = "CogVideoWrapper"
|
CATEGORY = "CogVideoWrapper"
|
||||||
DESCRIPTION = "EXPERIMENTAL:Pyramid Attention Broadcast (PAB) speeds up inference by mitigating redundant attention computation"
|
DESCRIPTION = "EXPERIMENTAL:Pyramid Attention Broadcast (PAB) speeds up inference by mitigating redundant attention computation. Increases memory use"
|
||||||
|
|
||||||
def config(self, spatial_broadcast, spatial_threshold_start, spatial_threshold_end, spatial_range,
|
def config(self, spatial_broadcast, spatial_threshold_start, spatial_threshold_end, spatial_range,
|
||||||
temporal_broadcast, temporal_threshold_start, temporal_threshold_end, temporal_range,
|
temporal_broadcast, temporal_threshold_start, temporal_threshold_end, temporal_range,
|
||||||
|
|||||||
@ -1,32 +0,0 @@
|
|||||||
import logging
|
|
||||||
|
|
||||||
import torch.distributed as dist
|
|
||||||
from rich.logging import RichHandler
|
|
||||||
|
|
||||||
|
|
||||||
def create_logger():
|
|
||||||
"""
|
|
||||||
Create a logger that writes to a log file and stdout.
|
|
||||||
"""
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
return logger
|
|
||||||
|
|
||||||
|
|
||||||
def init_dist_logger():
|
|
||||||
"""
|
|
||||||
Update the logger to write to a log file.
|
|
||||||
"""
|
|
||||||
global logger
|
|
||||||
if dist.get_rank() == 0:
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
handler = RichHandler(show_path=False, markup=True, rich_tracebacks=True)
|
|
||||||
formatter = logging.Formatter("VideoSys - %(levelname)s: %(message)s")
|
|
||||||
handler.setFormatter(formatter)
|
|
||||||
logger.addHandler(handler)
|
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
else: # dummy logger (does nothing)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
logger.addHandler(logging.NullHandler())
|
|
||||||
|
|
||||||
|
|
||||||
logger = create_logger()
|
|
||||||
@ -1,92 +0,0 @@
|
|||||||
import os
|
|
||||||
import random
|
|
||||||
|
|
||||||
import imageio
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from omegaconf import DictConfig, ListConfig, OmegaConf
|
|
||||||
|
|
||||||
|
|
||||||
def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
|
|
||||||
"""
|
|
||||||
Set requires_grad flag for all parameters in a model.
|
|
||||||
"""
|
|
||||||
for p in model.parameters():
|
|
||||||
p.requires_grad = flag
|
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed, dp_rank=None):
|
|
||||||
if seed == -1:
|
|
||||||
seed = random.randint(0, 1000000)
|
|
||||||
|
|
||||||
if dp_rank is not None:
|
|
||||||
seed = torch.tensor(seed, dtype=torch.int64).cuda()
|
|
||||||
if dist.get_world_size() > 1:
|
|
||||||
dist.broadcast(seed, 0)
|
|
||||||
seed = seed + dp_rank
|
|
||||||
|
|
||||||
seed = int(seed)
|
|
||||||
random.seed(seed)
|
|
||||||
os.environ["PYTHONHASHSEED"] = str(seed)
|
|
||||||
np.random.seed(seed)
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
torch.cuda.manual_seed(seed)
|
|
||||||
|
|
||||||
|
|
||||||
def str_to_dtype(x: str):
|
|
||||||
if x == "fp32":
|
|
||||||
return torch.float32
|
|
||||||
elif x == "fp16":
|
|
||||||
return torch.float16
|
|
||||||
elif x == "bf16":
|
|
||||||
return torch.bfloat16
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"Only fp32, fp16 and bf16 are supported, but got {x}")
|
|
||||||
|
|
||||||
|
|
||||||
def batch_func(func, *args):
|
|
||||||
"""
|
|
||||||
Apply a function to each element of a batch.
|
|
||||||
"""
|
|
||||||
batch = []
|
|
||||||
for arg in args:
|
|
||||||
if isinstance(arg, torch.Tensor) and arg.shape[0] == 2:
|
|
||||||
batch.append(func(arg))
|
|
||||||
else:
|
|
||||||
batch.append(arg)
|
|
||||||
|
|
||||||
return batch
|
|
||||||
|
|
||||||
|
|
||||||
def merge_args(args1, args2):
|
|
||||||
"""
|
|
||||||
Merge two argparse Namespace objects.
|
|
||||||
"""
|
|
||||||
if args2 is None:
|
|
||||||
return args1
|
|
||||||
|
|
||||||
for k in args2._content.keys():
|
|
||||||
if k in args1.__dict__:
|
|
||||||
v = getattr(args2, k)
|
|
||||||
if isinstance(v, ListConfig) or isinstance(v, DictConfig):
|
|
||||||
v = OmegaConf.to_object(v)
|
|
||||||
setattr(args1, k, v)
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"Unknown argument {k}")
|
|
||||||
|
|
||||||
return args1
|
|
||||||
|
|
||||||
|
|
||||||
def all_exists(paths):
|
|
||||||
return all(os.path.exists(path) for path in paths)
|
|
||||||
|
|
||||||
|
|
||||||
def save_video(video, output_path, fps):
|
|
||||||
"""
|
|
||||||
Save a video to disk.
|
|
||||||
"""
|
|
||||||
if dist.is_initialized() and dist.get_rank() != 0:
|
|
||||||
return
|
|
||||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
|
||||||
imageio.mimwrite(output_path, video, fps=fps)
|
|
||||||
Loading…
x
Reference in New Issue
Block a user