mirror of
https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git
synced 2026-01-17 12:14:22 +08:00
65 lines
2.3 KiB
Python
65 lines
2.3 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
import torch
|
|
from typing import Callable, Iterable, Sequence, Union
|
|
|
|
|
|
def checkpoint(
    func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
    inputs: Sequence[torch.Tensor],
    params: Iterable[torch.Tensor],
    flag: bool,
    use_deepspeed: bool = False
):
    """Evaluate ``func(*inputs)``, optionally with gradient checkpointing.

    With checkpointing enabled, intermediate activations are not cached in
    the forward pass; they are recomputed during the backward pass, trading
    extra compute for reduced memory.

    :param func: the function to evaluate.
    :param inputs: the positional arguments passed to ``func``.
    :param params: tensors ``func`` depends on but does not receive as
        explicit arguments; they are handed to the checkpointing function so
        they can receive gradients. Not used on the deepspeed path or when
        ``flag`` is False.
    :param flag: if False, call ``func`` directly without checkpointing.
    :param use_deepspeed: if True (and ``flag`` is True), delegate to
        ``deepspeed.checkpointing.checkpoint``.
    :return: the result of ``func`` (via the checkpoint wrapper if enabled).
    """
    if not flag:
        return func(*inputs)

    if use_deepspeed:
        # Imported lazily so deepspeed is only needed when it is requested.
        import deepspeed
        return deepspeed.checkpointing.checkpoint(func, *inputs)

    # Inputs first, then params; the count tells CheckpointFunction where
    # the explicit arguments end.
    packed = (*inputs, *params)
    return CheckpointFunction.apply(func, len(inputs), *packed)
|
|
|
|
|
|
class CheckpointFunction(torch.autograd.Function):
|
|
@staticmethod
|
|
@torch.amp.custom_fwd(device_type="cuda")
|
|
def forward(ctx, run_function, length, *args):
|
|
ctx.run_function = run_function
|
|
ctx.input_tensors = list(args[:length])
|
|
ctx.input_params = list(args[length:])
|
|
|
|
with torch.no_grad():
|
|
output_tensors = ctx.run_function(*ctx.input_tensors)
|
|
return output_tensors
|
|
|
|
@staticmethod
|
|
@torch.amp.custom_bwd(device_type="cuda")
|
|
def backward(ctx, *output_grads):
|
|
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
|
with torch.enable_grad():
|
|
# Fixes a bug where the first op in run_function modifies the
|
|
# Tensor storage in place, which is not allowed for detach()'d
|
|
# Tensors.
|
|
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
|
|
output_tensors = ctx.run_function(*shallow_copies)
|
|
input_grads = torch.autograd.grad(
|
|
output_tensors,
|
|
ctx.input_tensors + ctx.input_params,
|
|
output_grads,
|
|
allow_unused=True,
|
|
)
|
|
del ctx.input_tensors
|
|
del ctx.input_params
|
|
del output_tensors
|
|
return (None, None) + input_grads
|