mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-09 05:54:24 +08:00
Merge 5c6fcbda91babbefc6d907a7febe3c0450f05f74 into bed12674a1d2c4bfdfbdd098686390f807996c90
This commit is contained in:
commit
96fb67be5f
366
comfy_extras/nodes_dype.py
Normal file
366
comfy_extras/nodes_dype.py
Normal file
@ -0,0 +1,366 @@
|
||||
# adapted from https://github.com/guyyariv/DyPE
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
def find_correction_factor(num_rotations, dim, base, max_position_embeddings):
|
||||
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
|
||||
2 * math.log(base)
|
||||
) # Inverse dim formula to find number of rotations
|
||||
|
||||
|
||||
def find_correction_range(low_ratio, high_ratio, dim, base, ori_max_pe_len):
|
||||
"""Find the correction range for NTK-by-parts interpolation"""
|
||||
low = np.floor(find_correction_factor(low_ratio, dim, base, ori_max_pe_len))
|
||||
high = np.ceil(find_correction_factor(high_ratio, dim, base, ori_max_pe_len))
|
||||
return max(low, 0), min(high, dim - 1) # Clamp values just in case
|
||||
|
||||
|
||||
def linear_ramp_mask(min, max, dim):
|
||||
if min == max:
|
||||
max += 0.001 # Prevent singularity
|
||||
|
||||
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
|
||||
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||
return ramp_func
|
||||
|
||||
|
||||
def find_newbase_ntk(dim, base, scale):
|
||||
"""Calculate the new base for NTK-aware scaling"""
|
||||
return base * (scale ** (dim / (dim - 2)))
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(
|
||||
dim: int,
|
||||
pos: np.ndarray | int,
|
||||
theta: float = 10000.0,
|
||||
use_real=False,
|
||||
linear_factor=1.0,
|
||||
ntk_factor=1.0,
|
||||
repeat_interleave_real=True,
|
||||
freqs_dtype=torch.float32,
|
||||
yarn=False,
|
||||
max_pe_len=None,
|
||||
ori_max_pe_len=64,
|
||||
dype=False,
|
||||
current_timestep=1.0,
|
||||
):
|
||||
"""
|
||||
Precompute the frequency tensor for complex exponentials with RoPE.
|
||||
Supports YARN interpolation for vision transformers.
|
||||
|
||||
Args:
|
||||
dim (`int`):
|
||||
Dimension of the frequency tensor.
|
||||
pos (`np.ndarray` or `int`):
|
||||
Position indices for the frequency tensor. [S] or scalar.
|
||||
theta (`float`, *optional*, defaults to 10000.0):
|
||||
Scaling factor for frequency computation.
|
||||
use_real (`bool`, *optional*, defaults to False):
|
||||
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
linear_factor (`float`, *optional*, defaults to 1.0):
|
||||
Scaling factor for linear interpolation.
|
||||
ntk_factor (`float`, *optional*, defaults to 1.0):
|
||||
Scaling factor for NTK-Aware RoPE.
|
||||
repeat_interleave_real (`bool`, *optional*, defaults to True):
|
||||
If True and use_real, real and imaginary parts are interleaved with themselves to reach dim.
|
||||
Otherwise, they are concatenated.
|
||||
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
|
||||
Data type of the frequency tensor.
|
||||
yarn (`bool`, *optional*, defaults to False):
|
||||
If True, use YARN interpolation combining NTK, linear, and base methods.
|
||||
max_pe_len (`int`, *optional*):
|
||||
Maximum position encoding length (current patches for vision models).
|
||||
ori_max_pe_len (`int`, *optional*, defaults to 64):
|
||||
Original maximum position encoding length (base patches for vision models).
|
||||
dype (`bool`, *optional*, defaults to False):
|
||||
If True, enable DyPE (Dynamic Position Encoding) with timestep-aware scaling.
|
||||
current_timestep (`float`, *optional*, defaults to 1.0):
|
||||
Current timestep for DyPE, normalized to [0, 1] where 1 is pure noise.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
|
||||
If use_real=True, returns tuple of (cos, sin) tensors.
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
|
||||
if isinstance(pos, int):
|
||||
pos = torch.arange(pos)
|
||||
if isinstance(pos, np.ndarray):
|
||||
pos = torch.from_numpy(pos)
|
||||
|
||||
device = pos.device
|
||||
|
||||
if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
|
||||
if not isinstance(max_pe_len, torch.Tensor):
|
||||
max_pe_len = torch.tensor(max_pe_len, dtype=freqs_dtype, device=device)
|
||||
|
||||
scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)
|
||||
|
||||
beta_0 = 1.25
|
||||
beta_1 = 0.75
|
||||
gamma_0 = 16
|
||||
gamma_1 = 2
|
||||
|
||||
freqs_base = 1.0 / (
|
||||
theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)
|
||||
)
|
||||
|
||||
freqs_linear = 1.0 / torch.einsum(
|
||||
"..., f -> ... f",
|
||||
scale,
|
||||
(
|
||||
theta
|
||||
** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)
|
||||
),
|
||||
)
|
||||
|
||||
new_base = find_newbase_ntk(dim, theta, scale)
|
||||
if new_base.dim() > 0:
|
||||
new_base = new_base.view(-1, 1)
|
||||
freqs_ntk = 1.0 / torch.pow(
|
||||
new_base, (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)
|
||||
)
|
||||
if freqs_ntk.dim() > 1:
|
||||
freqs_ntk = freqs_ntk.squeeze()
|
||||
|
||||
if dype:
|
||||
beta_0 = beta_0 ** (2.0 * (current_timestep**2.0))
|
||||
beta_1 = beta_1 ** (2.0 * (current_timestep**2.0))
|
||||
|
||||
low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
|
||||
low = max(0, low)
|
||||
high = min(dim // 2, high)
|
||||
|
||||
freqs_mask = 1 - linear_ramp_mask(low, high, dim // 2).to(device).to(
|
||||
freqs_dtype
|
||||
)
|
||||
freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask
|
||||
|
||||
if dype:
|
||||
gamma_0 = gamma_0 ** (2.0 * (current_timestep**2.0))
|
||||
gamma_1 = gamma_1 ** (2.0 * (current_timestep**2.0))
|
||||
|
||||
low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
|
||||
low = max(0, low)
|
||||
high = min(dim // 2, high)
|
||||
|
||||
freqs_mask = 1 - linear_ramp_mask(low, high, dim // 2).to(device).to(
|
||||
freqs_dtype
|
||||
)
|
||||
freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask
|
||||
|
||||
else:
|
||||
theta_ntk = theta * ntk_factor
|
||||
freqs = (
|
||||
1.0
|
||||
/ (
|
||||
theta_ntk
|
||||
** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)
|
||||
)
|
||||
/ linear_factor
|
||||
)
|
||||
|
||||
freqs = pos.unsqueeze(-1) * freqs
|
||||
|
||||
is_npu = freqs.device.type == "npu"
|
||||
if is_npu:
|
||||
freqs = freqs.float()
|
||||
|
||||
if use_real and repeat_interleave_real:
|
||||
freqs_cos = (
|
||||
freqs.cos()
|
||||
.repeat_interleave(2, dim=-1, output_size=freqs.shape[-1] * 2)
|
||||
.float()
|
||||
)
|
||||
freqs_sin = (
|
||||
freqs.sin()
|
||||
.repeat_interleave(2, dim=-1, output_size=freqs.shape[-1] * 2)
|
||||
.float()
|
||||
)
|
||||
|
||||
if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
|
||||
mscale = torch.where(
|
||||
scale <= 1.0, torch.tensor(1.0), 0.1 * torch.log(scale) + 1.0
|
||||
).to(scale)
|
||||
freqs_cos = freqs_cos * mscale
|
||||
freqs_sin = freqs_sin * mscale
|
||||
|
||||
return freqs_cos, freqs_sin
|
||||
elif use_real:
|
||||
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()
|
||||
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()
|
||||
return freqs_cos, freqs_sin
|
||||
else:
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
class FluxPosEmbed(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
theta: int,
|
||||
axes_dim: list[int],
|
||||
method: str = "yarn",
|
||||
dype: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
self.base_resolution = 1024
|
||||
self.base_patches = (self.base_resolution // 8) // 2
|
||||
self.method = method
|
||||
self.dype = dype if method != "base" else False
|
||||
self.current_timestep = 1.0
|
||||
|
||||
def set_timestep(self, timestep: float):
|
||||
"""Set current timestep for DyPE. Timestep normalized to [0, 1] where 1 is pure noise."""
|
||||
self.current_timestep = timestep
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
cos_out = []
|
||||
sin_out = []
|
||||
pos = ids.float()
|
||||
freqs_dtype = torch.bfloat16 if ids.device.type == "cuda" else torch.float32
|
||||
|
||||
for i in range(n_axes):
|
||||
axis_dim = self.axes_dim[i]
|
||||
axis_pos = pos[..., i]
|
||||
|
||||
common_kwargs = {
|
||||
"dim": axis_dim,
|
||||
"pos": axis_pos,
|
||||
"theta": self.theta,
|
||||
"repeat_interleave_real": True,
|
||||
"use_real": True,
|
||||
"freqs_dtype": freqs_dtype,
|
||||
}
|
||||
|
||||
if i > 0:
|
||||
max_pos = axis_pos.max().item()
|
||||
current_patches = max_pos + 1
|
||||
|
||||
if self.method == "yarn" and current_patches > self.base_patches:
|
||||
max_pe_len = torch.tensor(
|
||||
current_patches, dtype=freqs_dtype, device=pos.device
|
||||
)
|
||||
cos, sin = get_1d_rotary_pos_embed(
|
||||
**common_kwargs,
|
||||
yarn=True,
|
||||
max_pe_len=max_pe_len,
|
||||
ori_max_pe_len=self.base_patches,
|
||||
dype=self.dype,
|
||||
current_timestep=self.current_timestep,
|
||||
)
|
||||
|
||||
elif self.method == "ntk" and current_patches > self.base_patches:
|
||||
base_ntk = (current_patches / self.base_patches) ** (
|
||||
self.axes_dim[i] / (self.axes_dim[i] - 2)
|
||||
)
|
||||
ntk_factor = (
|
||||
base_ntk ** (2.0 * (self.current_timestep**2.0))
|
||||
if self.dype
|
||||
else base_ntk
|
||||
)
|
||||
ntk_factor = max(1.0, ntk_factor)
|
||||
|
||||
cos, sin = get_1d_rotary_pos_embed(
|
||||
**common_kwargs, ntk_factor=ntk_factor
|
||||
)
|
||||
|
||||
else:
|
||||
cos, sin = get_1d_rotary_pos_embed(**common_kwargs)
|
||||
else:
|
||||
cos, sin = get_1d_rotary_pos_embed(**common_kwargs)
|
||||
|
||||
cos_out.append(cos)
|
||||
sin_out.append(sin)
|
||||
|
||||
emb_parts = []
|
||||
for cos, sin in zip(cos_out, sin_out):
|
||||
cos_reshaped = cos.view(*cos.shape[:-1], -1, 2)[..., :1]
|
||||
sin_reshaped = sin.view(*sin.shape[:-1], -1, 2)[..., :1]
|
||||
row1 = torch.cat([cos_reshaped, -sin_reshaped], dim=-1)
|
||||
row2 = torch.cat([sin_reshaped, cos_reshaped], dim=-1)
|
||||
matrix = torch.stack([row1, row2], dim=-2)
|
||||
emb_parts.append(matrix)
|
||||
|
||||
emb = torch.cat(emb_parts, dim=-3)
|
||||
return emb.unsqueeze(1).to(ids.device)
|
||||
|
||||
|
||||
def apply_dype_flux(model: ModelPatcher, method: str) -> ModelPatcher:
|
||||
if getattr(model.model, "_dype", None) == method:
|
||||
return model
|
||||
|
||||
m = model.clone()
|
||||
m.model._dype = method
|
||||
|
||||
_pe_embedder = m.model.diffusion_model.pe_embedder
|
||||
_theta, _axes_dim = _pe_embedder.theta, _pe_embedder.axes_dim
|
||||
|
||||
pos_embedder = FluxPosEmbed(_theta, _axes_dim, method, dype=True)
|
||||
m.add_object_patch("diffusion_model.pe_embedder", pos_embedder)
|
||||
|
||||
sigma_max = m.model.model_sampling.sigma_max.item()
|
||||
|
||||
def dype_wrapper_function(model_function, args_dict):
|
||||
timestep_tensor = args_dict.get("timestep")
|
||||
if timestep_tensor is not None and timestep_tensor.numel() > 0:
|
||||
current_sigma = timestep_tensor.flatten()[0].item()
|
||||
|
||||
if sigma_max > 0:
|
||||
normalized_timestep = min(max(current_sigma / sigma_max, 0.0), 1.0)
|
||||
pos_embedder.set_timestep(normalized_timestep)
|
||||
|
||||
input_x, c = args_dict.get("input"), args_dict.get("c", {})
|
||||
return model_function(input_x, args_dict.get("timestep"), **c)
|
||||
|
||||
m.set_model_unet_function_wrapper(dype_wrapper_function)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
class DyPEPatchModelFlux(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DyPEPatchModelFlux",
|
||||
display_name="DyPE Patch Model (Flux)",
|
||||
category="_for_testing",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Combo.Input(
|
||||
"method",
|
||||
options=["yarn", "ntk", "base"],
|
||||
default="yarn",
|
||||
),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: ModelPatcher, method: str) -> io.NodeOutput:
|
||||
m = apply_dype_flux(model, method)
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
class DyPEExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
DyPEPatchModelFlux,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint():
|
||||
return DyPEExtension()
|
||||
Loading…
x
Reference in New Issue
Block a user