From 4d08a997b75e0bb6493f431aa8655ce21e611d7b Mon Sep 17 00:00:00 2001 From: Chris Chance Date: Tue, 15 Oct 2024 15:39:24 -0400 Subject: [PATCH] Forgot the other samplers --- nodes.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index 2d0a702..12e0140 100644 --- a/nodes.py +++ b/nodes.py @@ -946,7 +946,9 @@ class CogVideoSampler: padding = torch.zeros((negative.shape[0], target_length - negative.shape[1], negative.shape[2]), device=negative.device) negative = torch.cat((negative, padding), dim=1) - autocastcondition = not pipeline["onediff"] + # Note: if we're on MPS return False, otherwise check for onediff as normal + # TODO: Torch 2.5+ is supposed to add AMP support for MPS ... not available currently. + autocastcondition = not pipeline["onediff"] if not torch.backends.mps.is_available() else False autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext() with autocast_context: latents = pipeline["pipe"]( @@ -1213,7 +1215,9 @@ class CogVideoXFunVid2VidSampler: generator= torch.Generator(device).manual_seed(seed) - autocastcondition = not pipeline["onediff"] + # Note: if we're on MPS return False, otherwise check for onediff as normal + # TODO: Torch 2.5+ is supposed to add AMP support for MPS ... not available currently. + autocastcondition = not pipeline["onediff"] if not torch.backends.mps.is_available() else False autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext() with autocast_context: video_length = int((video_length - 1) // pipe.vae.config.temporal_compression_ratio * pipe.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 @@ -1460,7 +1464,9 @@ class CogVideoXFunControlSampler: generator=torch.Generator(torch.device("cpu")).manual_seed(seed) - autocastcondition = not pipeline["onediff"] + # Note: if we're on MPS return False, otherwise check for onediff as normal + # TODO: Torch 2.5+ is supposed to add AMP support for MPS ... not available currently. + autocastcondition = not pipeline["onediff"] if not torch.backends.mps.is_available() else False autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext() with autocast_context: