Mirror of https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git
Synced 2026-01-23 18:44:26 +08:00
Add ConsistencyFlowMatchEulerDiscreteScheduler

parent 77dbdea8b8
commit 1ecc8e3195
@@ -45,6 +45,7 @@ import comfy.model_management as mm
logger = logging.getLogger(__name__)

from .schedulers import FlowMatchEulerDiscreteScheduler, ConsistencyFlowMatchEulerDiscreteScheduler


def retrieve_timesteps(
    scheduler,
@@ -154,6 +155,7 @@ class Hunyuan3DDiTPipeline:
        compile_args=None,
        attention_mode="sdpa",
        cublas_ops=False,
        scheduler="FlowMatchEulerDiscreteScheduler",
        **kwargs,
    ):

@@ -224,7 +226,13 @@ class Hunyuan3DDiTPipeline:
            set_module_tensor_to_device(conditioner, name, device=offload_device, dtype=dtype, value=ckpt['conditioner'][name])

        image_processor = instantiate_from_config(config['image_processor'])
        scheduler = instantiate_from_config(config['scheduler'])

        if scheduler == "FlowMatchEulerDiscreteScheduler":
            scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
        elif scheduler == "ConsistencyFlowMatchEulerDiscreteScheduler":
            scheduler = ConsistencyFlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, pcm_timesteps=100)

        #scheduler = instantiate_from_config(config['scheduler'])

        if compile_args is not None:
            torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
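As a hedged aside (not part of the commit): the same name-to-instance selection recurs in nodes.py further down, and it could equally be written as a small dispatch table so the default and the constructor arguments live in one place. The class names and arguments below are taken from the diff above; the helper name and everything else are illustrative assumptions.

```python
# Illustrative sketch only, not code from this commit.
from typing import Callable, Dict

from .schedulers import (  # assumed relative import, mirroring the hunk above
    FlowMatchEulerDiscreteScheduler,
    ConsistencyFlowMatchEulerDiscreteScheduler,
)

SCHEDULER_FACTORIES: Dict[str, Callable] = {
    "FlowMatchEulerDiscreteScheduler":
        lambda: FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000),
    "ConsistencyFlowMatchEulerDiscreteScheduler":
        lambda: ConsistencyFlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, pcm_timesteps=100),
}

def build_scheduler(name: str):
    # Fall back to the flow-match Euler default when the name is unknown.
    return SCHEDULER_FACTORIES.get(name, SCHEDULER_FACTORIES["FlowMatchEulerDiscreteScheduler"])()
```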
@@ -12,6 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the repsective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
@@ -305,3 +319,162 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):

    def __len__(self):
        return self.config.num_train_timesteps


@dataclass
class ConsistencyFlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    prev_sample: torch.FloatTensor
    pred_original_sample: torch.FloatTensor


class ConsistencyFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        pcm_timesteps: int = 50,
    ):
        sigmas = np.linspace(0, 1, num_train_timesteps)
        step_ratio = num_train_timesteps // pcm_timesteps

        euler_timesteps = (np.arange(1, pcm_timesteps) * step_ratio).round().astype(np.int64) - 1
        euler_timesteps = np.asarray([0] + euler_timesteps.tolist())

        self.euler_timesteps = euler_timesteps
        self.sigmas = sigmas[self.euler_timesteps]
        self.sigmas = torch.from_numpy((self.sigmas.copy()))
        self.timesteps = self.sigmas * num_train_timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def _sigma_to_t(self, sigma):
        return sigma * self.config.num_train_timesteps

    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps if num_inference_steps is not None else len(sigmas)
        inference_indices = np.linspace(
            0, self.config.pcm_timesteps, num=self.num_inference_steps, endpoint=False
        )
        inference_indices = np.floor(inference_indices).astype(np.int64)
        inference_indices = torch.from_numpy(inference_indices).long()

        self.sigmas_ = self.sigmas[inference_indices]
        timesteps = self.sigmas_ * self.config.num_train_timesteps
        self.timesteps = timesteps.to(device=device)
        self.sigmas_ = torch.cat(
            [self.sigmas_, torch.ones(1, device=self.sigmas_.device)]
        )

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[ConsistencyFlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sample = sample.to(torch.float32)

        sigma = self.sigmas_[self.step_index]
        sigma_next = self.sigmas_[self.step_index + 1]

        prev_sample = sample + (sigma_next - sigma) * model_output
        prev_sample = prev_sample.to(model_output.dtype)

        pred_original_sample = sample + (1.0 - sigma) * model_output
        pred_original_sample = pred_original_sample.to(model_output.dtype)

        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return ConsistencyFlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample,
                                                                pred_original_sample=pred_original_sample)

    def __len__(self):
        return self.config.num_train_timesteps
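A hedged usage sketch (not from the commit) of how a sampling loop would drive the scheduler defined above: set_timesteps() picks the sigma/timestep subset and each step() call performs one Euler update while also exposing the one-jump consistency estimate. Only the scheduler API comes from the code above; the dummy velocity function and latent shape are placeholders.

```python
import torch

# Assumes ConsistencyFlowMatchEulerDiscreteScheduler (defined above) is in scope.
scheduler = ConsistencyFlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, pcm_timesteps=100)
scheduler.set_timesteps(num_inference_steps=8, device="cpu")

sample = torch.randn(1, 8, 64)                      # placeholder latent tensor
dummy_velocity = lambda x, t: torch.zeros_like(x)   # stands in for the DiT's flow prediction

for t in scheduler.timesteps:
    model_output = dummy_velocity(sample, t)
    out = scheduler.step(model_output, t, sample)   # returns prev_sample and pred_original_sample
    sample = out.prev_sample                        # Euler update: sample + (sigma_next - sigma) * v
    # out.pred_original_sample is the one-jump estimate: old sample + (1 - sigma) * v
```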
nodes.py (22 changed lines)
@@ -12,6 +12,8 @@ from tqdm import tqdm
from .hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline, FaceReducer, FloaterRemover, DegenerateFaceRemover
from .hy3dgen.texgen.hunyuanpaint.unet.modules import UNet2DConditionModel, UNet2p5DConditionModel
from .hy3dgen.texgen.hunyuanpaint.pipeline import HunyuanPaintPipeline
from .hy3dgen.shapegen.schedulers import FlowMatchEulerDiscreteScheduler, ConsistencyFlowMatchEulerDiscreteScheduler

from diffusers import AutoencoderKL
from diffusers.schedulers import (
@@ -1028,6 +1030,7 @@ class Hy3DGenerateMesh:
            },
            "optional": {
                "mask": ("MASK", ),
                "scheduler": (["FlowMatchEulerDiscreteScheduler", "ConsistencyFlowMatchEulerDiscreteScheduler"],),
            }
        }

@@ -1036,7 +1039,7 @@ class Hy3DGenerateMesh:
    FUNCTION = "process"
    CATEGORY = "Hunyuan3DWrapper"

    def process(self, pipeline, image, steps, guidance_scale, seed, mask=None, front=None, back=None, left=None, right=None):
    def process(self, pipeline, image, steps, guidance_scale, seed, mask=None, front=None, back=None, left=None, right=None, scheduler="FlowMatchEulerDiscreteScheduler"):

        mm.unload_all_models()
        mm.soft_empty_cache()
@@ -1052,6 +1055,13 @@ class Hy3DGenerateMesh:
            if mask.shape[2] != image.shape[2] or mask.shape[3] != image.shape[3]:
                mask = F.interpolate(mask, size=(image.shape[2], image.shape[3]), mode='nearest')

        if scheduler == "FlowMatchEulerDiscreteScheduler":
            scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
        elif scheduler == "ConsistencyFlowMatchEulerDiscreteScheduler":
            scheduler = ConsistencyFlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, pcm_timesteps=100)

        pipeline.scheduler = scheduler

        pipeline.to(device)

        try:
@@ -1090,7 +1100,8 @@ class Hy3DGenerateMeshMultiView():
                "front": ("IMAGE", ),
                "left": ("IMAGE", ),
                "right": ("IMAGE", ),
                "back": ("IMAGE", ),
                "back": ("IMAGE", ),
                "scheduler": (["FlowMatchEulerDiscreteScheduler", "ConsistencyFlowMatchEulerDiscreteScheduler"],),
            }
        }

@@ -1099,7 +1110,7 @@ class Hy3DGenerateMeshMultiView():
    FUNCTION = "process"
    CATEGORY = "Hunyuan3DWrapper"

    def process(self, pipeline, steps, guidance_scale, seed, mask=None, front=None, back=None, left=None, right=None):
    def process(self, pipeline, steps, guidance_scale, seed, mask=None, front=None, back=None, left=None, right=None, scheduler="FlowMatchEulerDiscreteScheduler"):

        mm.unload_all_models()
        mm.soft_empty_cache()
@@ -1125,6 +1136,11 @@ class Hy3DGenerateMeshMultiView():
            'back': back
        }

        if scheduler == "FlowMatchEulerDiscreteScheduler":
            scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
        elif scheduler == "ConsistencyFlowMatchEulerDiscreteScheduler":
            scheduler = ConsistencyFlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, pcm_timesteps=100)

        try:
            torch.cuda.reset_peak_memory_stats(device)
        except:
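Both nodes hard-code pcm_timesteps=100 for the consistency scheduler, and the node's steps input presumably reaches the scheduler via set_timesteps. For orientation, the snippet below only re-derives the index math from __init__ and set_timesteps shown earlier to show which sigmas a short run visits; the step count of 8 is an arbitrary example, not something in the commit.

```python
import numpy as np

# Re-derive the consistency scheduler's sigma grid (num_train_timesteps=1000, pcm_timesteps=100)
# and the subset an 8-step run would visit, mirroring __init__ and set_timesteps above.
num_train_timesteps, pcm_timesteps, num_inference_steps = 1000, 100, 8

sigmas = np.linspace(0, 1, num_train_timesteps)
step_ratio = num_train_timesteps // pcm_timesteps
euler_timesteps = (np.arange(1, pcm_timesteps) * step_ratio).round().astype(np.int64) - 1
euler_timesteps = np.asarray([0] + euler_timesteps.tolist())   # 100 anchor indices into the 1000-step grid
pcm_sigmas = sigmas[euler_timesteps]                           # the 100 "pcm" sigmas

inference_indices = np.floor(
    np.linspace(0, pcm_timesteps, num=num_inference_steps, endpoint=False)
).astype(np.int64)                                             # e.g. [0, 12, 25, 37, 50, 62, 75, 87]
print(pcm_sigmas[inference_indices])                           # sigmas visited by an 8-step run
print(pcm_sigmas[inference_indices] * num_train_timesteps)     # the matching timestep values
```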