mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-05-13 21:52:22 +08:00
Forgot the other samplers
This commit is contained in:
parent
035a370fa1
commit
4d08a997b7
12
nodes.py
12
nodes.py
@ -946,7 +946,9 @@ class CogVideoSampler:
|
|||||||
padding = torch.zeros((negative.shape[0], target_length - negative.shape[1], negative.shape[2]), device=negative.device)
|
padding = torch.zeros((negative.shape[0], target_length - negative.shape[1], negative.shape[2]), device=negative.device)
|
||||||
negative = torch.cat((negative, padding), dim=1)
|
negative = torch.cat((negative, padding), dim=1)
|
||||||
|
|
||||||
autocastcondition = not pipeline["onediff"]
|
# Note: if we're on MPS return False, otherwise check for onediff as normal
|
||||||
|
# TODO: Torch 2.5+ is supposed to add AMP support for MPS ... not available currently.
|
||||||
|
autocastcondition = not pipeline["onediff"] if not torch.backends.mps.is_available() else False
|
||||||
autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
|
autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
|
||||||
with autocast_context:
|
with autocast_context:
|
||||||
latents = pipeline["pipe"](
|
latents = pipeline["pipe"](
|
||||||
@ -1213,7 +1215,9 @@ class CogVideoXFunVid2VidSampler:
|
|||||||
|
|
||||||
generator= torch.Generator(device).manual_seed(seed)
|
generator= torch.Generator(device).manual_seed(seed)
|
||||||
|
|
||||||
autocastcondition = not pipeline["onediff"]
|
# Note: if we're on MPS return False, otherwise check for onediff as normal
|
||||||
|
# TODO: Torch 2.5+ is supposed to add AMP support for MPS ... not available currently.
|
||||||
|
autocastcondition = not pipeline["onediff"] if not torch.backends.mps.is_available() else False
|
||||||
autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
|
autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
|
||||||
with autocast_context:
|
with autocast_context:
|
||||||
video_length = int((video_length - 1) // pipe.vae.config.temporal_compression_ratio * pipe.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
|
video_length = int((video_length - 1) // pipe.vae.config.temporal_compression_ratio * pipe.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
|
||||||
@ -1460,7 +1464,9 @@ class CogVideoXFunControlSampler:
|
|||||||
|
|
||||||
generator=torch.Generator(torch.device("cpu")).manual_seed(seed)
|
generator=torch.Generator(torch.device("cpu")).manual_seed(seed)
|
||||||
|
|
||||||
autocastcondition = not pipeline["onediff"]
|
# Note: if we're on MPS return False, otherwise check for onediff as normal
|
||||||
|
# TODO: Torch 2.5+ is supposed to add AMP support for MPS ... not available currently.
|
||||||
|
autocastcondition = not pipeline["onediff"] if not torch.backends.mps.is_available() else False
|
||||||
autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
|
autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
|
||||||
with autocast_context:
|
with autocast_context:
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user