diff --git a/nodes.py b/nodes.py index 59fce7d..2d0a702 100644 --- a/nodes.py +++ b/nodes.py @@ -1116,7 +1116,9 @@ class CogVideoXFunSampler: generator= torch.Generator(device="cpu").manual_seed(seed) - autocastcondition = not pipeline["onediff"] + # Note: if we're on MPS return False, otherwise check for onediff as normal + # TODO: Torch 2.5+ is supposed to add AMP support for MPS ... not available currently. + autocastcondition = not pipeline["onediff"] if not torch.backends.mps.is_available() else False autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext() with autocast_context: video_length = int((video_length - 1) // pipe.vae.config.temporal_compression_ratio * pipe.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1