Add ConsistencyFlowMatchEulerDiscreteScheduler

This commit is contained in:
kijai 2025-03-19 13:30:02 +02:00
parent 77dbdea8b8
commit 1ecc8e3195
3 changed files with 201 additions and 4 deletions

View File

@ -45,6 +45,7 @@ import comfy.model_management as mm
logger = logging.getLogger(__name__)
from .schedulers import FlowMatchEulerDiscreteScheduler, ConsistencyFlowMatchEulerDiscreteScheduler
def retrieve_timesteps(
scheduler,
@ -154,6 +155,7 @@ class Hunyuan3DDiTPipeline:
compile_args=None,
attention_mode="sdpa",
cublas_ops=False,
scheduler="FlowMatchEulerDiscreteScheduler",
**kwargs,
):
@ -224,7 +226,13 @@ class Hunyuan3DDiTPipeline:
set_module_tensor_to_device(conditioner, name, device=offload_device, dtype=dtype, value=ckpt['conditioner'][name])
image_processor = instantiate_from_config(config['image_processor'])
scheduler = instantiate_from_config(config['scheduler'])
if scheduler == "FlowMatchEulerDiscreteScheduler":
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
elif scheduler == "ConsistencyFlowMatchEulerDiscreteScheduler":
scheduler = ConsistencyFlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, pcm_timesteps=100)
#scheduler = instantiate_from_config(config['scheduler'])
if compile_args is not None:
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]

View File

@ -12,6 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
@ -305,3 +319,162 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
def __len__(self):
return self.config.num_train_timesteps
@dataclass
class ConsistencyFlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Result container returned by `ConsistencyFlowMatchEulerDiscreteScheduler.step`.

    Args:
        prev_sample (`torch.FloatTensor`):
            The sample after one Euler update; feed it to the next denoising step.
        pred_original_sample (`torch.FloatTensor`):
            The current sample extrapolated along `model_output` to `sigma = 1`.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: torch.FloatTensor
class ConsistencyFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler flow-matching scheduler that only visits a coarse grid of `pcm_timesteps` sigmas.

    The full sigma range `linspace(0, 1, num_train_timesteps)` is subsampled once at construction time;
    `set_timesteps` then picks `num_inference_steps` entries from that coarse grid. Sigmas are ascending and
    a final `1.0` is appended, so every `step` moves the sample towards `sigma = 1`.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            Number of points in the full sigma grid; also the factor that converts a sigma into a timestep.
        pcm_timesteps (`int`, defaults to 50):
            Number of entries kept in the coarse grid. Must satisfy `1 <= pcm_timesteps <= num_train_timesteps`.
            NOTE(review): "pcm" presumably stands for phased consistency model -- confirm against the checkpoint.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        pcm_timesteps: int = 50,
    ):
        # Without this guard pcm_timesteps == 0 divides by zero and pcm_timesteps > num_train_timesteps
        # yields step_ratio == 0, i.e. a silently degenerate table of negative indices.
        if not 1 <= pcm_timesteps <= num_train_timesteps:
            raise ValueError(
                f"`pcm_timesteps` must be in [1, num_train_timesteps]; got pcm_timesteps={pcm_timesteps}, "
                f"num_train_timesteps={num_train_timesteps}."
            )

        sigmas = np.linspace(0, 1, num_train_timesteps)
        step_ratio = num_train_timesteps // pcm_timesteps

        # Coarse-grid indices into the full sigma table: 0, step_ratio - 1, 2 * step_ratio - 1, ...
        euler_timesteps = (np.arange(1, pcm_timesteps) * step_ratio).round().astype(np.int64) - 1
        euler_timesteps = np.asarray([0] + euler_timesteps.tolist())

        self.euler_timesteps = euler_timesteps
        self.sigmas = sigmas[self.euler_timesteps]
        self.sigmas = torch.from_numpy((self.sigmas.copy()))
        self.timesteps = self.sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Note that `set_timesteps` resets the begin index to `None`, so call this afterwards.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def _sigma_to_t(self, sigma):
        """Convert a sigma into a timestep by scaling with `num_train_timesteps`."""
        return sigma * self.config.num_train_timesteps

    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            sigmas (`List[float]`, *optional*):
                Only its length is used, as a fallback for `num_inference_steps`; the values themselves are
                ignored because the schedule always comes from the coarse grid built in `__init__`.

        Raises:
            ValueError: If neither `num_inference_steps` nor `sigmas` is provided.
        """
        if num_inference_steps is None:
            if sigmas is None:
                raise ValueError("Either `num_inference_steps` or `sigmas` must be provided.")
            num_inference_steps = len(sigmas)
        self.num_inference_steps = num_inference_steps

        # Evenly spaced positions in [0, pcm_timesteps), floored to valid indices into the coarse grid.
        inference_indices = np.linspace(
            0, self.config.pcm_timesteps, num=self.num_inference_steps, endpoint=False
        )
        inference_indices = np.floor(inference_indices).astype(np.int64)
        inference_indices = torch.from_numpy(inference_indices).long()

        self.sigmas_ = self.sigmas[inference_indices]
        timesteps = self.sigmas_ * self.config.num_train_timesteps
        self.timesteps = timesteps.to(device=device)

        # Append the terminal sigma so `step` can always read `sigmas_[step_index + 1]`.
        # The dtype is matched explicitly instead of relying on `torch.cat` type promotion.
        self.sigmas_ = torch.cat(
            [self.sigmas_, torch.ones(1, dtype=self.sigmas_.dtype, device=self.sigmas_.device)]
        )

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Return the position of `timestep` inside `schedule_timesteps` (defaults to `self.timesteps`).

        The lookup is an exact equality match, so `timestep` should be one of the scheduler's own values.
        Indexing fails with an `IndexError` when there is no match.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        """Initialise the step counter from `begin_index` if set, otherwise by looking up `timestep`."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                # Compare on the device the schedule lives on.
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[ConsistencyFlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
        Advance `sample` by one Euler step along `model_output`.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output of the model for the current sample.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep; must be one of `scheduler.timesteps`, not an integer loop index.
            sample (`torch.FloatTensor`):
                The current sample. It is upcast to float32 for the update.
            generator (`torch.Generator`, *optional*):
                Unused; the update is deterministic.
            return_dict (`bool`, defaults to `True`):
                Whether to return a `ConsistencyFlowMatchEulerDiscreteSchedulerOutput` or a plain tuple.

        Returns:
            `ConsistencyFlowMatchEulerDiscreteSchedulerOutput` if `return_dict` is `True`, otherwise a
            one-element tuple holding only `prev_sample`.

        Raises:
            ValueError: If `timestep` is an `int`, `torch.IntTensor` or `torch.LongTensor`.
        """
        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `ConsistencyFlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Do the update in float32, then cast the results back to the model's dtype.
        sample = sample.to(torch.float32)

        sigma = self.sigmas_[self.step_index]
        sigma_next = self.sigmas_[self.step_index + 1]

        prev_sample = sample + (sigma_next - sigma) * model_output
        prev_sample = prev_sample.to(model_output.dtype)

        # Same Euler update, but jumping straight to the terminal sigma of 1.
        pred_original_sample = sample + (1.0 - sigma) * model_output
        pred_original_sample = pred_original_sample.to(model_output.dtype)

        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return ConsistencyFlowMatchEulerDiscreteSchedulerOutput(
            prev_sample=prev_sample,
            pred_original_sample=pred_original_sample,
        )

    def __len__(self):
        """Return the number of training timesteps from the config."""
        return self.config.num_train_timesteps

View File

@ -12,6 +12,8 @@ from tqdm import tqdm
from .hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline, FaceReducer, FloaterRemover, DegenerateFaceRemover
from .hy3dgen.texgen.hunyuanpaint.unet.modules import UNet2DConditionModel, UNet2p5DConditionModel
from .hy3dgen.texgen.hunyuanpaint.pipeline import HunyuanPaintPipeline
from .hy3dgen.shapegen.schedulers import FlowMatchEulerDiscreteScheduler, ConsistencyFlowMatchEulerDiscreteScheduler
from diffusers import AutoencoderKL
from diffusers.schedulers import (
@ -1028,6 +1030,7 @@ class Hy3DGenerateMesh:
},
"optional": {
"mask": ("MASK", ),
"scheduler": (["FlowMatchEulerDiscreteScheduler", "ConsistencyFlowMatchEulerDiscreteScheduler"],),
}
}
@ -1036,7 +1039,7 @@ class Hy3DGenerateMesh:
FUNCTION = "process"
CATEGORY = "Hunyuan3DWrapper"
def process(self, pipeline, image, steps, guidance_scale, seed, mask=None, front=None, back=None, left=None, right=None):
def process(self, pipeline, image, steps, guidance_scale, seed, mask=None, front=None, back=None, left=None, right=None, scheduler="FlowMatchEulerDiscreteScheduler"):
mm.unload_all_models()
mm.soft_empty_cache()
@ -1052,6 +1055,13 @@ class Hy3DGenerateMesh:
if mask.shape[2] != image.shape[2] or mask.shape[3] != image.shape[3]:
mask = F.interpolate(mask, size=(image.shape[2], image.shape[3]), mode='nearest')
if scheduler == "FlowMatchEulerDiscreteScheduler":
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
elif scheduler == "ConsistencyFlowMatchEulerDiscreteScheduler":
scheduler = ConsistencyFlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, pcm_timesteps=100)
pipeline.scheduler = scheduler
pipeline.to(device)
try:
@ -1090,7 +1100,8 @@ class Hy3DGenerateMeshMultiView():
"front": ("IMAGE", ),
"left": ("IMAGE", ),
"right": ("IMAGE", ),
"back": ("IMAGE", ),
"back": ("IMAGE", ),
"scheduler": (["FlowMatchEulerDiscreteScheduler", "ConsistencyFlowMatchEulerDiscreteScheduler"],),
}
}
@ -1099,7 +1110,7 @@ class Hy3DGenerateMeshMultiView():
FUNCTION = "process"
CATEGORY = "Hunyuan3DWrapper"
def process(self, pipeline, steps, guidance_scale, seed, mask=None, front=None, back=None, left=None, right=None):
def process(self, pipeline, steps, guidance_scale, seed, mask=None, front=None, back=None, left=None, right=None, scheduler="FlowMatchEulerDiscreteScheduler"):
mm.unload_all_models()
mm.soft_empty_cache()
@ -1125,6 +1136,11 @@ class Hy3DGenerateMeshMultiView():
'back': back
}
if scheduler == "FlowMatchEulerDiscreteScheduler":
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
elif scheduler == "ConsistencyFlowMatchEulerDiscreteScheduler":
scheduler = ConsistencyFlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, pcm_timesteps=100)
try:
torch.cuda.reset_peak_memory_stats(device)
except: