diff --git a/nodes.py b/nodes.py
index 4474d77..dea1dbc 100644
--- a/nodes.py
+++ b/nodes.py
@@ -121,20 +121,20 @@ class CogVideoPABConfig:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {
-            "spatial_broadcast": ("BOOLEAN", {"default": True, "tooltip": "Enable Spatial PAB"}),
+            "spatial_broadcast": ("BOOLEAN", {"default": True, "tooltip": "Enable Spatial PAB, highest impact"}),
             "spatial_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
             "spatial_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
-            "spatial_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
-            "temporal_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Temporal PAB"}),
+            "spatial_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
+            "temporal_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Temporal PAB, medium impact"}),
             "temporal_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
             "temporal_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
-            "temporal_range": ("INT", {"default": 4, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
-            "cross_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Cross Attention PAB"}),
+            "temporal_range": ("INT", {"default": 4, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
+            "cross_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Cross Attention PAB, low impact"}),
             "cross_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
             "cross_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
-            "cross_range": ("INT", {"default": 6, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
+            "cross_range": ("INT", {"default": 6, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
-            "steps": ("INT", {"default": 50, "min": 0, "max": 1000, "tooltip": "Steps"} ),
+            "steps": ("INT", {"default": 50, "min": 0, "max": 1000, "tooltip": "Should match the sampling steps"} ),
             }
         }
@@ -142,7 +142,7 @@ class CogVideoPABConfig:
     RETURN_NAMES = ("pab_config", )
     FUNCTION = "config"
     CATEGORY = "CogVideoWrapper"
-    DESCRIPTION = "EXPERIMENTAL:Pyramid Attention Broadcast (PAB) speeds up inference by mitigating redundant attention computation"
+    DESCRIPTION = "EXPERIMENTAL:Pyramid Attention Broadcast (PAB) speeds up inference by mitigating redundant attention computation. Increases memory use"
 
     def config(self, spatial_broadcast, spatial_threshold_start, spatial_threshold_end, spatial_range,
                temporal_broadcast, temporal_threshold_start, temporal_threshold_end, temporal_range,
diff --git a/videosys/utils/logging.py b/videosys/utils/logging.py
deleted file mode 100644
index 896a4d6..0000000
--- a/videosys/utils/logging.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import logging
-
-import torch.distributed as dist
-from rich.logging import RichHandler
-
-
-def create_logger():
-    """
-    Create a logger that writes to a log file and stdout.
-    """
-    logger = logging.getLogger(__name__)
-    return logger
-
-
-def init_dist_logger():
-    """
-    Update the logger to write to a log file.
-    """
-    global logger
-    if dist.get_rank() == 0:
-        logger = logging.getLogger(__name__)
-        handler = RichHandler(show_path=False, markup=True, rich_tracebacks=True)
-        formatter = logging.Formatter("VideoSys - %(levelname)s: %(message)s")
-        handler.setFormatter(formatter)
-        logger.addHandler(handler)
-        logger.setLevel(logging.INFO)
-    else:  # dummy logger (does nothing)
-        logger = logging.getLogger(__name__)
-        logger.addHandler(logging.NullHandler())
-
-
-logger = create_logger()
diff --git a/videosys/utils/utils.py b/videosys/utils/utils.py
deleted file mode 100644
index 622a36d..0000000
--- a/videosys/utils/utils.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import os
-import random
-
-import imageio
-import numpy as np
-import torch
-import torch.distributed as dist
-from omegaconf import DictConfig, ListConfig, OmegaConf
-
-
-def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
-    """
-    Set requires_grad flag for all parameters in a model.
-    """
-    for p in model.parameters():
-        p.requires_grad = flag
-
-
-def set_seed(seed, dp_rank=None):
-    if seed == -1:
-        seed = random.randint(0, 1000000)
-
-    if dp_rank is not None:
-        seed = torch.tensor(seed, dtype=torch.int64).cuda()
-        if dist.get_world_size() > 1:
-            dist.broadcast(seed, 0)
-        seed = seed + dp_rank
-
-    seed = int(seed)
-    random.seed(seed)
-    os.environ["PYTHONHASHSEED"] = str(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-
-
-def str_to_dtype(x: str):
-    if x == "fp32":
-        return torch.float32
-    elif x == "fp16":
-        return torch.float16
-    elif x == "bf16":
-        return torch.bfloat16
-    else:
-        raise RuntimeError(f"Only fp32, fp16 and bf16 are supported, but got {x}")
-
-
-def batch_func(func, *args):
-    """
-    Apply a function to each element of a batch.
-    """
-    batch = []
-    for arg in args:
-        if isinstance(arg, torch.Tensor) and arg.shape[0] == 2:
-            batch.append(func(arg))
-        else:
-            batch.append(arg)
-
-    return batch
-
-
-def merge_args(args1, args2):
-    """
-    Merge two argparse Namespace objects.
-    """
-    if args2 is None:
-        return args1
-
-    for k in args2._content.keys():
-        if k in args1.__dict__:
-            v = getattr(args2, k)
-            if isinstance(v, ListConfig) or isinstance(v, DictConfig):
-                v = OmegaConf.to_object(v)
-            setattr(args1, k, v)
-        else:
-            raise RuntimeError(f"Unknown argument {k}")
-
-    return args1
-
-
-def all_exists(paths):
-    return all(os.path.exists(path) for path in paths)
-
-
-def save_video(video, output_path, fps):
-    """
-    Save a video to disk.
-    """
-    if dist.is_initialized() and dist.get_rank() != 0:
-        return
-    os.makedirs(os.path.dirname(output_path), exist_ok=True)
-    imageio.mimwrite(output_path, video, fps=fps)