# 2024-10-30 21:30:03 +02:00
#
# 64 lines
# 2.2 KiB
# Python

class PABConfig:
    """Plain container for broadcast settings.

    "PAB" presumably stands for Pyramid Attention Broadcast — confirm against
    the code that consumes this config.

    Cross, spatial and temporal attention are each described by three values:

    * ``*_broadcast``: an on/off flag.
    * ``*_threshold``: a list (presumably a timestep window; the consumer
      defines its meaning).
    * ``*_range``: an integer.

    The MLP has its own on/off flag plus two optional config dicts. This class
    only stores what it is given. It performs no validation. Lists and dicts
    are kept by reference, not copied.

    Args:
        steps: Step count (presumably the number of sampling steps — confirm).
        cross_broadcast: Enable flag for cross-attention broadcasting.
        cross_threshold: Threshold list for cross attention, or None.
        cross_range: Range value for cross attention, or None.
        spatial_broadcast: Enable flag for spatial-attention broadcasting.
        spatial_threshold: Threshold list for spatial attention, or None.
        spatial_range: Range value for spatial attention, or None.
        temporal_broadcast: Enable flag for temporal-attention broadcasting.
        temporal_threshold: Threshold list for temporal attention, or None.
        temporal_range: Range value for temporal attention, or None.
        mlp_broadcast: Enable flag for MLP broadcasting.
        mlp_spatial_broadcast_config: Optional dict for spatial MLP broadcasting.
        mlp_temporal_broadcast_config: Optional dict for temporal MLP broadcasting.
    """

    def __init__(
        self,
        steps: int,
        cross_broadcast: bool = False,
        cross_threshold: list = None,
        cross_range: int = None,
        spatial_broadcast: bool = False,
        spatial_threshold: list = None,
        spatial_range: int = None,
        temporal_broadcast: bool = False,
        temporal_threshold: list = None,
        temporal_range: int = None,
        mlp_broadcast: bool = False,
        mlp_spatial_broadcast_config: dict = None,
        mlp_temporal_broadcast_config: dict = None,
    ):
        self.steps = steps

        # Cross-attention settings (targets are bound left to right).
        self.cross_broadcast, self.cross_threshold, self.cross_range = (
            cross_broadcast,
            cross_threshold,
            cross_range,
        )

        # Spatial-attention settings.
        self.spatial_broadcast, self.spatial_threshold, self.spatial_range = (
            spatial_broadcast,
            spatial_threshold,
            spatial_range,
        )

        # Temporal-attention settings.
        self.temporal_broadcast, self.temporal_threshold, self.temporal_range = (
            temporal_broadcast,
            temporal_threshold,
            temporal_range,
        )

        # MLP settings.
        self.mlp_broadcast = mlp_broadcast
        self.mlp_spatial_broadcast_config, self.mlp_temporal_broadcast_config = (
            mlp_spatial_broadcast_config,
            mlp_temporal_broadcast_config,
        )

        # Fresh, empty dicts per instance. Nothing in this class fills them;
        # presumably they cache MLP outputs and are populated elsewhere — confirm.
        self.mlp_temporal_outputs, self.mlp_spatial_outputs = {}, {}
class CogVideoXPABConfig(PABConfig):
    """``PABConfig`` preset with the default values used for CogVideoX.

    Defaults:

    * 50 steps.
    * Spatial broadcasting enabled with range 2.
    * Temporal broadcasting disabled with range 4.
    * Cross broadcasting disabled with range 6.
    * Every threshold is ``[100, 850]``.

    MLP broadcasting is not exposed here. The base-class defaults apply:
    ``mlp_broadcast=False`` and both MLP config dicts set to ``None``.

    Args:
        steps: Step count passed to ``PABConfig``.
        spatial_broadcast: Enable flag for spatial-attention broadcasting.
        spatial_threshold: Threshold list for spatial attention. ``None``
            (the default) means a fresh ``[100, 850]``.
        spatial_range: Range value for spatial attention.
        temporal_broadcast: Enable flag for temporal-attention broadcasting.
        temporal_threshold: Threshold list for temporal attention. ``None``
            (the default) means a fresh ``[100, 850]``.
        temporal_range: Range value for temporal attention.
        cross_broadcast: Enable flag for cross-attention broadcasting.
        cross_threshold: Threshold list for cross attention. ``None``
            (the default) means a fresh ``[100, 850]``.
        cross_range: Range value for cross attention.
    """

    def __init__(
        self,
        steps: int = 50,
        spatial_broadcast: bool = True,
        spatial_threshold: list = None,
        spatial_range: int = 2,
        temporal_broadcast: bool = False,
        temporal_threshold: list = None,
        temporal_range: int = 4,
        cross_broadcast: bool = False,
        cross_threshold: list = None,
        cross_range: int = 6,
    ):
        # Build each default threshold list per call. A list literal used as a
        # default argument is created once and shared by every instance.
        # PABConfig stores the list by reference, so mutating one instance's
        # threshold would otherwise change the defaults for all later instances.
        spatial_threshold = [100, 850] if spatial_threshold is None else spatial_threshold
        temporal_threshold = [100, 850] if temporal_threshold is None else temporal_threshold
        cross_threshold = [100, 850] if cross_threshold is None else cross_threshold

        super().__init__(
            steps=steps,
            spatial_broadcast=spatial_broadcast,
            spatial_threshold=spatial_threshold,
            spatial_range=spatial_range,
            temporal_broadcast=temporal_broadcast,
            temporal_threshold=temporal_threshold,
            temporal_range=temporal_range,
            cross_broadcast=cross_broadcast,
            cross_threshold=cross_threshold,
            cross_range=cross_range,
        )