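"""Native EasyCache/LazyCache implementation.

Both nodes estimate how much the model output would change at each sampling step and,
while the accumulated estimate stays under a reuse threshold, skip the model call and
reuse a cached residual instead.
"""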
from __future__ import annotations
from typing import TYPE_CHECKING, Union

import logging

import torch

import comfy.model_patcher
import comfy.patcher_extension
from comfy_api.latest import io, ComfyExtension

if TYPE_CHECKING:
    from uuid import UUID


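# EasyCache wraps the diffusion model forward pass. Each step it measures how much the
# model input changed relative to the previous step, scales that by the last observed
# input->output transformation rate to approximate how much the output *would* change,
# and accumulates that estimate. While the accumulated estimate stays below
# reuse_threshold, the forward pass is skipped and the cached per-cond residual
# (output - x) is re-applied to the current input instead.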
def easycache_forward_wrapper(executor, *args, **kwargs):
    # get transformer_options from args, falling back to kwargs
    transformer_options: dict = args[-1]
    if not isinstance(transformer_options, dict):
        transformer_options = kwargs.get("transformer_options")
        if not transformer_options:
            transformer_options = args[-2]
    easycache: EasyCacheHolder = transformer_options["easycache"]
    # track only the noise channels; some models (e.g. WanVideo I2V) concatenate extra
    # conditioning channels onto the model input, which would otherwise mismatch the cache
    x: torch.Tensor = args[0][:, :easycache.output_channels]
    sigmas = transformer_options["sigmas"]
    uuids = transformer_options["uuids"]
    if sigmas is not None and easycache.is_past_end_timestep(sigmas):
        return executor(*args, **kwargs)
    # prepare next x_prev
    has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
    next_x_prev = x
    input_change = None
    do_easycache = easycache.should_do_easycache(sigmas)
    if do_easycache:
        easycache.check_metadata(x)
        # if first cond marked this step for skipping, skip it and use appropriate cached values
        if easycache.skip_current_step:
            if easycache.verbose:
                logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
            return easycache.apply_cache_diff(x, uuids)
        if easycache.initial_step:
            easycache.first_cond_uuid = uuids[0]
            has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
            easycache.initial_step = False
        if has_first_cond_uuid:
            if easycache.has_x_prev_subsampled():
                input_change = (easycache.subsample(x, uuids, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
                if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
                    approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
                    easycache.cumulative_change_rate += approx_output_change_rate
                    if easycache.cumulative_change_rate < easycache.reuse_threshold:
                        if easycache.verbose:
                            logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
                        # other conds should also skip this step, and instead use their cached values
                        easycache.skip_current_step = True
                        return easycache.apply_cache_diff(x, uuids)
                    else:
                        if easycache.verbose:
                            logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
                        easycache.cumulative_change_rate = 0.0

    output: torch.Tensor = executor(*args, **kwargs)
    if has_first_cond_uuid and easycache.has_output_prev_norm():
        output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
        if easycache.verbose:
            output_change_rate = output_change / easycache.output_prev_norm
            easycache.output_change_rates.append(output_change_rate.item())
            if easycache.has_relative_transformation_rate():
                approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
                easycache.approx_output_change_rates.append(approx_output_change_rate.item())
                logging.info(f"EasyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
        if input_change is not None:
            easycache.relative_transformation_rate = output_change / input_change
        if easycache.verbose:
            logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}")
    # TODO: allow cache_diff to be offloaded
    easycache.update_cache_diff(output, next_x_prev, uuids)
    if has_first_cond_uuid:
        easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids)
        easycache.output_prev_subsampled = easycache.subsample(output, uuids)
        easycache.output_prev_norm = output.flatten().abs().mean()
        if easycache.verbose:
            logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
    return output


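# LazyCache applies the same change-rate heuristic one level up, at PREDICT_NOISE:
# it sees the full cond batch in a single call, so it needs no per-cond uuid bookkeeping.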
def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
    # get values from args
    timestep: torch.Tensor = args[1]
    model_options: dict = args[2]
    easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
    if easycache.is_past_end_timestep(timestep):
        return executor(*args, **kwargs)
    # prepare next x_prev; track only the noise channels (see easycache_forward_wrapper)
    x: torch.Tensor = args[0][:, :easycache.output_channels]
    next_x_prev = x
    input_change = None
    do_easycache = easycache.should_do_easycache(timestep)
    if do_easycache:
        easycache.check_metadata(x)
        if easycache.has_x_prev_subsampled():
            input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
            if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
                approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
                easycache.cumulative_change_rate += approx_output_change_rate
                if easycache.cumulative_change_rate < easycache.reuse_threshold:
                    if easycache.verbose:
                        logging.info(f"LazyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
                    return easycache.apply_cache_diff(x)
                else:
                    if easycache.verbose:
                        logging.info(f"LazyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
                    easycache.cumulative_change_rate = 0.0
    output: torch.Tensor = executor(*args, **kwargs)
    if easycache.has_output_prev_norm():
        output_change = (easycache.subsample(output, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
        if easycache.verbose:
            output_change_rate = output_change / easycache.output_prev_norm
            easycache.output_change_rates.append(output_change_rate.item())
            if easycache.has_relative_transformation_rate():
                approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
                easycache.approx_output_change_rates.append(approx_output_change_rate.item())
                logging.info(f"LazyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
        if input_change is not None:
            easycache.relative_transformation_rate = output_change / input_change
        if easycache.verbose:
            logging.info(f"LazyCache [verbose] - output_change_rate: {output_change_rate}")
    # TODO: allow cache_diff to be offloaded
    easycache.update_cache_diff(output, next_x_prev)
    easycache.x_prev_subsampled = easycache.subsample(next_x_prev)
    easycache.output_prev_subsampled = easycache.subsample(output)
    easycache.output_prev_norm = output.flatten().abs().mean()
    if easycache.verbose:
        logging.info(f"LazyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
    return output


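# calc_cond_batch runs once per sampling step, so this is the spot to clear the
# per-step skip flag that the first cond's forward pass may set for the others.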
def easycache_calc_cond_batch_wrapper(executor, *args, **kwargs):
    model_options = args[-1]
    easycache: EasyCacheHolder = model_options["transformer_options"]["easycache"]
    easycache.skip_current_step = False
    # TODO: check if first_cond_uuid is active at this timestep; otherwise, EasyCache needs to be partially reset
    return executor(*args, **kwargs)


def easycache_sample_wrapper(executor, *args, **kwargs):
    """
    This OUTER_SAMPLE wrapper makes sure easycache is prepped for the current run, and all memory usage is cleared at the end.
    """
    guider = executor.class_obj
    orig_model_options = guider.model_options
    try:
        guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options)
        # clone and prepare timesteps
        guider.model_options["transformer_options"]["easycache"] = guider.model_options["transformer_options"]["easycache"].clone().prepare_timesteps(guider.model_patcher.model.model_sampling)
        easycache: Union[EasyCacheHolder, LazyCacheHolder] = guider.model_options["transformer_options"]["easycache"]
        logging.info(f"{easycache.name} enabled - threshold: {easycache.reuse_threshold}, start_percent: {easycache.start_percent}, end_percent: {easycache.end_percent}")
        return executor(*args, **kwargs)
    finally:
        easycache = guider.model_options["transformer_options"]["easycache"]
        output_change_rates = easycache.output_change_rates
        approx_output_change_rates = easycache.approx_output_change_rates
        if easycache.verbose:
            logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}")
            logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}")
        total_steps = len(args[3]) - 1
        # catch division by zero for log statement; sucks to crash after all sampling is done
        try:
            speedup = total_steps / (total_steps - easycache.total_steps_skipped)
        except ZeroDivisionError:
            speedup = 1.0
        logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({speedup:.2f}x speedup).")
        easycache.reset()
        guider.model_options = orig_model_options


class EasyCacheHolder:
    def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool = False, output_channels: int = None):
        self.name = "EasyCache"
        self.reuse_threshold = reuse_threshold
        self.start_percent = start_percent
        self.end_percent = end_percent
        self.subsample_factor = subsample_factor
        self.offload_cache_diff = offload_cache_diff
        self.verbose = verbose
        # timestep values
        self.start_t = 0.0
        self.end_t = 0.0
        # control values
        self.relative_transformation_rate: float = None
        self.cumulative_change_rate = 0.0
        self.initial_step = True
        self.skip_current_step = False
        # cache values
        self.first_cond_uuid = None
        self.x_prev_subsampled: torch.Tensor = None
        self.output_prev_subsampled: torch.Tensor = None
        self.output_prev_norm: torch.Tensor = None
        self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {}
        self.output_change_rates = []
        self.approx_output_change_rates = []
        self.total_steps_skipped = 0
        # how to deal with mismatched dims
        self.allow_mismatch = True
        self.cut_from_start = True
        self.state_metadata = None
        self.output_channels = output_channels
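    # start_t/end_t are sigmas (larger sigma = earlier in sampling): caching activates
    # once the current sigma falls to start_t, and shuts off once it falls to end_t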
    def is_past_end_timestep(self, timestep: torch.Tensor) -> bool:
        return not (timestep[0] > self.end_t).item()

    def should_do_easycache(self, timestep: torch.Tensor) -> bool:
        return (timestep[0] <= self.start_t).item()

    def has_x_prev_subsampled(self) -> bool:
        return self.x_prev_subsampled is not None

    def has_output_prev_subsampled(self) -> bool:
        return self.output_prev_subsampled is not None

    def has_output_prev_norm(self) -> bool:
        return self.output_prev_norm is not None

    def has_relative_transformation_rate(self) -> bool:
        return self.relative_transformation_rate is not None

    def prepare_timesteps(self, model_sampling):
        self.start_t = model_sampling.percent_to_sigma(self.start_percent)
        self.end_t = model_sampling.percent_to_sigma(self.end_percent)
        return self
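    # subsample slices out the first cond's portion of the batch and strides over the
    # last two (spatial) dims; with subsample_factor=8, a [B, C, H, W] tensor keeps every
    # 8th row and column - a cheap proxy for measuring step-to-step change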
    def subsample(self, x: torch.Tensor, uuids: list[UUID], clone: bool = True) -> torch.Tensor:
        batch_offset = x.shape[0] // len(uuids)
        uuid_idx = uuids.index(self.first_cond_uuid)
        if self.subsample_factor > 1:
            to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ..., ::self.subsample_factor, ::self.subsample_factor]
        else:
            to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ...]
        if clone:
            return to_return.clone()
        return to_return
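    # on a skipped step, apply_cache_diff stands in for the model forward pass by
    # adding each cond's cached residual onto its slice of the current input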
    def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
        if self.first_cond_uuid in uuids:
            self.total_steps_skipped += 1
        batch_offset = x.shape[0] // len(uuids)
        for i, uuid in enumerate(uuids):
            # slice out only what is relevant to this cond
            batch_slice = [slice(i*batch_offset, (i+1)*batch_offset)]
            # if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
            if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
                if not self.allow_mismatch:
                    raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
                slicing = []
                # the batch dim is already covered by batch_slice, so skip it
                skip_this_dim = True
                for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape):
                    if skip_this_dim:
                        skip_this_dim = False
                        continue
                    if dim_u != dim_x:
                        if self.cut_from_start:
                            slicing.append(slice(dim_x-dim_u, None))
                        else:
                            slicing.append(slice(None, dim_u))
                    else:
                        slicing.append(slice(None))
                batch_slice = batch_slice + slicing
            x[tuple(batch_slice)] += self.uuid_cache_diffs[uuid].to(x.device)
        return x
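    # the cache stores, per cond, the residual between model output and input x;
    # apply_cache_diff adds it back onto x on skipped steps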
    def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
        # if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
        if output.shape[1:] != x.shape[1:]:
            if not self.allow_mismatch:
                raise ValueError(f"Output dims {output.shape} don't match x dims {x.shape} - this is no good")
            slicing = []
            skip_dim = True
            for dim_o, dim_x in zip(output.shape, x.shape):
                if not skip_dim and dim_o != dim_x:
                    if self.cut_from_start:
                        slicing.append(slice(dim_x-dim_o, None))
                    else:
                        slicing.append(slice(None, dim_o))
                else:
                    slicing.append(slice(None))
                skip_dim = False
            x = x[tuple(slicing)]
        diff = output - x
        batch_offset = diff.shape[0] // len(uuids)
        for i, uuid in enumerate(uuids):
            self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
    def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
        return self.first_cond_uuid in uuids
    def check_metadata(self, x: torch.Tensor) -> bool:
        metadata = (x.device, x.dtype, x.shape[1:])
        if self.state_metadata is None:
            self.state_metadata = metadata
            return True
        if metadata == self.state_metadata:
            return True
        logging.warning(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
        self.reset()
        return False
    def reset(self):
        # None (not 0.0) so has_relative_transformation_rate() stays False until a
        # fresh rate is measured, matching the initial state set in __init__
        self.relative_transformation_rate = None
        self.cumulative_change_rate = 0.0
        self.initial_step = True
        self.skip_current_step = False
        self.output_change_rates = []
        self.approx_output_change_rates = []
        self.first_cond_uuid = None
        del self.x_prev_subsampled
        self.x_prev_subsampled = None
        del self.output_prev_subsampled
        self.output_prev_subsampled = None
        del self.output_prev_norm
        self.output_prev_norm = None
        del self.uuid_cache_diffs
        self.uuid_cache_diffs = {}
        self.total_steps_skipped = 0
        self.state_metadata = None
        return self
    def clone(self):
        return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)


class EasyCacheNode(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="EasyCache",
            display_name="EasyCache",
            description="Native EasyCache implementation.",
            category="advanced/debug/model",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The model to add EasyCache to."),
                io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."),
                io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of EasyCache."),
                io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of EasyCache."),
                io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
            ],
            outputs=[
                io.Model.Output(tooltip="The model with EasyCache."),
            ],
        )

    @classmethod
    def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
        model = model.clone()
        model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper)
        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper)
        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper)
        return io.NodeOutput(model)


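# LazyCacheHolder mirrors EasyCacheHolder, but keeps a single whole-batch cache_diff
# instead of per-cond residuals keyed by uuid - simpler, and compatible with anything
# that goes through predict_noise.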
class LazyCacheHolder:
    def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool = False, output_channels: int = None):
        self.name = "LazyCache"
        self.reuse_threshold = reuse_threshold
        self.start_percent = start_percent
        self.end_percent = end_percent
        self.subsample_factor = subsample_factor
        self.offload_cache_diff = offload_cache_diff
        self.verbose = verbose
        # timestep values
        self.start_t = 0.0
        self.end_t = 0.0
        # control values
        self.relative_transformation_rate: float = None
        self.cumulative_change_rate = 0.0
        self.initial_step = True
        # cache values
        self.x_prev_subsampled: torch.Tensor = None
        self.output_prev_subsampled: torch.Tensor = None
        self.output_prev_norm: torch.Tensor = None
        self.cache_diff: torch.Tensor = None
        self.output_change_rates = []
        self.approx_output_change_rates = []
        self.total_steps_skipped = 0
        self.state_metadata = None
        self.output_channels = output_channels
    def has_cache_diff(self) -> bool:
        return self.cache_diff is not None

    def is_past_end_timestep(self, timestep: torch.Tensor) -> bool:
        return not (timestep[0] > self.end_t).item()

    def should_do_easycache(self, timestep: torch.Tensor) -> bool:
        return (timestep[0] <= self.start_t).item()

    def has_x_prev_subsampled(self) -> bool:
        return self.x_prev_subsampled is not None

    def has_output_prev_subsampled(self) -> bool:
        return self.output_prev_subsampled is not None

    def has_output_prev_norm(self) -> bool:
        return self.output_prev_norm is not None

    def has_relative_transformation_rate(self) -> bool:
        return self.relative_transformation_rate is not None

    def prepare_timesteps(self, model_sampling):
        self.start_t = model_sampling.percent_to_sigma(self.start_percent)
        self.end_t = model_sampling.percent_to_sigma(self.end_percent)
        return self
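    # same strided subsampling as EasyCacheHolder.subsample, but over the whole batch,
    # since LazyCache does not track conds separately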
    def subsample(self, x: torch.Tensor, clone: bool = True) -> torch.Tensor:
        if self.subsample_factor > 1:
            to_return = x[..., ::self.subsample_factor, ::self.subsample_factor]
            if clone:
                return to_return.clone()
            return to_return
        if clone:
            return x.clone()
        return x
    def apply_cache_diff(self, x: torch.Tensor):
        self.total_steps_skipped += 1
        return x + self.cache_diff.to(x.device)
    def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor):
        self.cache_diff = output - x
    def check_metadata(self, x: torch.Tensor) -> bool:
        metadata = (x.device, x.dtype, x.shape)
        if self.state_metadata is None:
            self.state_metadata = metadata
            return True
        if metadata == self.state_metadata:
            return True
        logging.warning(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
        self.reset()
        return False
    def reset(self):
        # None (not 0.0) so has_relative_transformation_rate() stays False until a
        # fresh rate is measured, matching the initial state set in __init__
        self.relative_transformation_rate = None
        self.cumulative_change_rate = 0.0
        self.initial_step = True
        self.output_change_rates = []
        self.approx_output_change_rates = []
        del self.cache_diff
        self.cache_diff = None
        del self.x_prev_subsampled
        self.x_prev_subsampled = None
        del self.output_prev_subsampled
        self.output_prev_subsampled = None
        del self.output_prev_norm
        self.output_prev_norm = None
        self.total_steps_skipped = 0
        self.state_metadata = None
        return self
    def clone(self):
        return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)


class LazyCacheNode(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="LazyCache",
            display_name="LazyCache",
            description="A homebrew variant of EasyCache that is even 'easier' to implement. It generally performs worse than EasyCache, but wins in some rare cases and is universally compatible with everything in ComfyUI.",
            category="advanced/debug/model",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The model to add LazyCache to."),
                io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."),
                io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of LazyCache."),
                io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of LazyCache."),
                io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
            ],
            outputs=[
                io.Model.Output(tooltip="The model with LazyCache."),
            ],
        )

    @classmethod
    def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
        model = model.clone()
        model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper)
        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper)
        return io.NodeOutput(model)


class EasyCacheExtension(ComfyExtension):
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            EasyCacheNode,
            LazyCacheNode,
        ]


def comfy_entrypoint():
    return EasyCacheExtension()