Jukka Seppänen 2024-09-23 03:10:45 +03:00
parent 49e42c5dc6
commit 7f3a768d95
3 changed files with 8 additions and 132 deletions

View File

@@ -121,20 +121,20 @@ class CogVideoPABConfig:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
-           "spatial_broadcast": ("BOOLEAN", {"default": True, "tooltip": "Enable Spatial PAB"}),
+           "spatial_broadcast": ("BOOLEAN", {"default": True, "tooltip": "Enable Spatial PAB, highest impact"}),
            "spatial_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
            "spatial_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
-           "spatial_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
-           "temporal_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Temporal PAB"}),
+           "spatial_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
+           "temporal_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Temporal PAB, medium impact"}),
            "temporal_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
            "temporal_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
-           "temporal_range": ("INT", {"default": 4, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
-           "cross_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Cross Attention PAB"}),
+           "temporal_range": ("INT", {"default": 4, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
+           "cross_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Cross Attention PAB, low impact"}),
            "cross_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
            "cross_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
-           "cross_range": ("INT", {"default": 6, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
+           "cross_range": ("INT", {"default": 6, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range, higher values are faster but quality may suffer"} ),
-           "steps": ("INT", {"default": 50, "min": 0, "max": 1000, "tooltip": "Steps"} ),
+           "steps": ("INT", {"default": 50, "min": 0, "max": 1000, "tooltip": "Should match the sampling steps"} ),
            }
        }
@@ -142,7 +142,7 @@ class CogVideoPABConfig:
    RETURN_NAMES = ("pab_config", )
    FUNCTION = "config"
    CATEGORY = "CogVideoWrapper"
-   DESCRIPTION = "EXPERIMENTAL: Pyramid Attention Broadcast (PAB) speeds up inference by mitigating redundant attention computation"
+   DESCRIPTION = "EXPERIMENTAL: Pyramid Attention Broadcast (PAB) speeds up inference by mitigating redundant attention computation. Increases memory use"
    def config(self, spatial_broadcast, spatial_threshold_start, spatial_threshold_end, spatial_range,
               temporal_broadcast, temporal_threshold_start, temporal_threshold_end, temporal_range,
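To clarify how the thresholds and range settings interact, here is a minimal, hypothetical sketch of PAB-style gating; it is not the wrapper's actual implementation, and every name in it is illustrative. Inside the start/end timestep window the cached attention output is rebroadcast, and only every range-th call recomputes it:

# Hypothetical sketch of PAB-style gating (illustrative, not the node's real code).
def should_broadcast(timestep, call_count, start=850, end=100, broadcast_range=2):
    # Reuse the cache only inside the timestep window, and refresh it every
    # `broadcast_range`-th call so approximation errors do not accumulate.
    inside_window = end < timestep < start
    return inside_window and call_count % broadcast_range != 0

cached = None
for count, t in enumerate(range(999, -1, -25)):  # toy denoising schedule
    if should_broadcast(t, count) and cached is not None:
        out = cached        # broadcast: skip the redundant attention computation
    else:
        out = ("attn", t)   # stand-in for a real attention forward pass
        cached = out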

View File

@@ -1,32 +0,0 @@
import logging

import torch.distributed as dist
from rich.logging import RichHandler


def create_logger():
    """
    Create a logger that writes to a log file and stdout.
    """
    logger = logging.getLogger(__name__)
    return logger


def init_dist_logger():
    """
    Update the logger to write to a log file.
    """
    global logger
    if dist.get_rank() == 0:
        logger = logging.getLogger(__name__)
        handler = RichHandler(show_path=False, markup=True, rich_tracebacks=True)
        formatter = logging.Formatter("VideoSys - %(levelname)s: %(message)s")
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    else:  # dummy logger (does nothing)
        logger = logging.getLogger(__name__)
        logger.addHandler(logging.NullHandler())


logger = create_logger()
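For reference, a single-process sketch of the handler setup that init_dist_logger() performs on rank 0; it assumes the rich package is installed, and the logger name here is arbitrary:

import logging
from rich.logging import RichHandler

log = logging.getLogger("videosys")
handler = RichHandler(show_path=False, markup=True, rich_tracebacks=True)
handler.setFormatter(logging.Formatter("VideoSys - %(levelname)s: %(message)s"))
log.addHandler(handler)
log.setLevel(logging.INFO)
log.info("hello")  # on non-zero ranks the module attaches a NullHandler instead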

View File

@@ -1,92 +0,0 @@
import os
import random

import imageio
import numpy as np
import torch
import torch.distributed as dist
from omegaconf import DictConfig, ListConfig, OmegaConf


def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
    """
    Set requires_grad flag for all parameters in a model.
    """
    for p in model.parameters():
        p.requires_grad = flag


def set_seed(seed, dp_rank=None):
    if seed == -1:
        seed = random.randint(0, 1000000)
    if dp_rank is not None:
        seed = torch.tensor(seed, dtype=torch.int64).cuda()
        if dist.get_world_size() > 1:
            dist.broadcast(seed, 0)
        seed = seed + dp_rank
        seed = int(seed)
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def str_to_dtype(x: str):
    if x == "fp32":
        return torch.float32
    elif x == "fp16":
        return torch.float16
    elif x == "bf16":
        return torch.bfloat16
    else:
        raise RuntimeError(f"Only fp32, fp16 and bf16 are supported, but got {x}")


def batch_func(func, *args):
    """
    Apply a function to each element of a batch.
    """
    batch = []
    for arg in args:
        if isinstance(arg, torch.Tensor) and arg.shape[0] == 2:
            batch.append(func(arg))
        else:
            batch.append(arg)
    return batch


def merge_args(args1, args2):
    """
    Merge two argparse Namespace objects.
    """
    if args2 is None:
        return args1
    for k in args2._content.keys():
        if k in args1.__dict__:
            v = getattr(args2, k)
            if isinstance(v, ListConfig) or isinstance(v, DictConfig):
                v = OmegaConf.to_object(v)
            setattr(args1, k, v)
        else:
            raise RuntimeError(f"Unknown argument {k}")
    return args1


def all_exists(paths):
    return all(os.path.exists(path) for path in paths)


def save_video(video, output_path, fps):
    """
    Save a video to disk.
    """
    if dist.is_initialized() and dist.get_rank() != 0:
        return
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    imageio.mimwrite(output_path, video, fps=fps)
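A short usage sketch, assuming the definitions above are in scope; no torch.distributed process group is initialized here, so set_seed's dp_rank branch is skipped:

import torch

set_seed(42)                                   # seeds python, numpy and torch RNGs
assert str_to_dtype("bf16") is torch.bfloat16  # unknown strings raise RuntimeError
# batch_func only transforms tensors whose leading (batch) dimension is 2:
doubled, untouched = batch_func(lambda x: x * 2, torch.ones(2, 4), "not a tensor")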