Flux 2 (#10879)

parent 015a0599d0
commit 6b573ae0cb
@@ -178,6 +178,15 @@ class Flux(SD3):
     def process_out(self, latent):
         return (latent / self.scale_factor) + self.shift_factor

+
+class Flux2(LatentFormat):
+    latent_channels = 128
+
+    def process_in(self, latent):
+        return latent
+
+    def process_out(self, latent):
+        return latent
+
 class Mochi(LatentFormat):
     latent_channels = 12
     latent_dimensions = 3

@@ -48,11 +48,11 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
     return embedding


 class MLPEmbedder(nn.Module):
-    def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
+    def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=None, operations=None):
         super().__init__()
-        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
+        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
         self.silu = nn.SiLU()
-        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
+        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=bias, dtype=dtype, device=device)

     def forward(self, x: Tensor) -> Tensor:
         return self.out_layer(self.silu(self.in_layer(x)))

@@ -80,14 +80,14 @@ class QKNorm(torch.nn.Module):


 class SelfAttention(nn.Module):
-    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
+    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, dtype=None, device=None, operations=None):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads

         self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
         self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
-        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
+        self.proj = operations.Linear(dim, dim, bias=proj_bias, dtype=dtype, device=device)


 @dataclass

@@ -98,11 +98,11 @@ class ModulationOut:


 class Modulation(nn.Module):
-    def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
+    def __init__(self, dim: int, double: bool, bias=True, dtype=None, device=None, operations=None):
         super().__init__()
         self.is_double = double
         self.multiplier = 6 if double else 3
-        self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
+        self.lin = operations.Linear(dim, self.multiplier * dim, bias=bias, dtype=dtype, device=device)

     def forward(self, vec: Tensor) -> tuple:
         if vec.ndim == 2:
@@ -129,8 +129,18 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
     return tensor


+class SiLUActivation(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.gate_fn = nn.SiLU()
+
+    def forward(self, x: Tensor) -> Tensor:
+        x1, x2 = x.chunk(2, dim=-1)
+        return self.gate_fn(x1) * x2
+
+
 class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None):
         super().__init__()

         mlp_hidden_dim = int(hidden_size * mlp_ratio)
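Note: the SiLUActivation module added above is a SwiGLU-style gate — the mlp_in projection is widened to twice the MLP width, the result is split in half, and the SiLU of one half gates the other. A minimal illustration with made-up sizes (not the real Flux 2 widths):

    import torch
    import torch.nn.functional as F

    mlp_hidden = 192                        # hypothetical MLP width
    x = torch.randn(2, 16, mlp_hidden * 2)  # output of the widened mlp_in Linear
    x1, x2 = x.chunk(2, dim=-1)             # same split as SiLUActivation.forward
    y = F.silu(x1) * x2                     # shape (2, 16, mlp_hidden), fed to mlp_out
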
@@ -142,27 +152,44 @@ class DoubleStreamBlock(nn.Module):
             self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)

         self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
+        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)

         self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.img_mlp = nn.Sequential(
-            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-            nn.GELU(approximate="tanh"),
-            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-        )
+        if mlp_silu_act:
+            self.img_mlp = nn.Sequential(
+                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
+                SiLUActivation(),
+                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
+            )
+        else:
+            self.img_mlp = nn.Sequential(
+                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+                nn.GELU(approximate="tanh"),
+                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+            )

         if self.modulation:
             self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)

         self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
+        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)

         self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.txt_mlp = nn.Sequential(
-            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-            nn.GELU(approximate="tanh"),
-            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-        )
+        if mlp_silu_act:
+            self.txt_mlp = nn.Sequential(
+                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
+                SiLUActivation(),
+                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
+            )
+        else:
+            self.txt_mlp = nn.Sequential(
+                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+                nn.GELU(approximate="tanh"),
+                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+            )

         self.flipped_img_txt = flipped_img_txt

     def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
@@ -246,6 +273,8 @@ class SingleStreamBlock(nn.Module):
         mlp_ratio: float = 4.0,
         qk_scale: float = None,
         modulation=True,
+        mlp_silu_act=False,
+        bias=True,
         dtype=None,
         device=None,
         operations=None

@@ -257,17 +286,24 @@ class SingleStreamBlock(nn.Module):
         self.scale = qk_scale or head_dim**-0.5

         self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+
+        self.mlp_hidden_dim_first = self.mlp_hidden_dim
+        if mlp_silu_act:
+            self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
+            self.mlp_act = SiLUActivation()
+        else:
+            self.mlp_act = nn.GELU(approximate="tanh")
+
         # qkv and mlp_in
-        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
+        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
         # proj and mlp_out
-        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
+        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=bias, dtype=dtype, device=device)

         self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)

         self.hidden_size = hidden_size
         self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

-        self.mlp_act = nn.GELU(approximate="tanh")
         if modulation:
             self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
         else:

@@ -279,7 +315,7 @@ class SingleStreamBlock(nn.Module):
         else:
             mod = vec

-        qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+        qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)

         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         del qkv

@@ -298,11 +334,11 @@ class SingleStreamBlock(nn.Module):


 class LastLayer(nn.Module):
-    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, bias=True, dtype=None, device=None, operations=None):
         super().__init__()
         self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
-        self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
+        self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=bias, dtype=dtype, device=device)
+        self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=bias, dtype=dtype, device=device))

     def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
         if vec.ndim == 2:

@@ -15,6 +15,7 @@ from .layers import (
     MLPEmbedder,
     SingleStreamBlock,
     timestep_embedding,
+    Modulation
 )

 @dataclass

@@ -33,6 +34,11 @@ class FluxParams:
     patch_size: int
     qkv_bias: bool
     guidance_embed: bool
+    global_modulation: bool = False
+    mlp_silu_act: bool = False
+    ops_bias: bool = True
+    default_ref_method: str = "offset"
+    ref_index_scale: float = 1.0


 class Flux(nn.Module):

@@ -58,13 +64,17 @@ class Flux(nn.Module):
         self.hidden_size = params.hidden_size
         self.num_heads = params.num_heads
         self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
-        self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
-        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
-        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+        self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
+        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
+        if params.vec_in_dim is not None:
+            self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+        else:
+            self.vector_in = None
+
         self.guidance_in = (
-            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
+            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
         )
-        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
+        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)

         self.double_blocks = nn.ModuleList(
             [

@@ -73,6 +83,9 @@ class Flux(nn.Module):
                     self.num_heads,
                     mlp_ratio=params.mlp_ratio,
                     qkv_bias=params.qkv_bias,
+                    modulation=params.global_modulation is False,
+                    mlp_silu_act=params.mlp_silu_act,
+                    proj_bias=params.ops_bias,
                     dtype=dtype, device=device, operations=operations
                 )
                 for _ in range(params.depth)

@@ -81,13 +94,30 @@ class Flux(nn.Module):

         self.single_blocks = nn.ModuleList(
             [
-                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
+                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
                 for _ in range(params.depth_single_blocks)
             ]
         )

         if final_layer:
-            self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
+            self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
+
+        if params.global_modulation:
+            self.double_stream_modulation_img = Modulation(
+                self.hidden_size,
+                double=True,
+                bias=False,
+                dtype=dtype, device=device, operations=operations
+            )
+            self.double_stream_modulation_txt = Modulation(
+                self.hidden_size,
+                double=True,
+                bias=False,
+                dtype=dtype, device=device, operations=operations
+            )
+            self.single_stream_modulation = Modulation(
+                self.hidden_size, double=False, bias=False, dtype=dtype, device=device, operations=operations
+            )

     def forward_orig(
         self,

@@ -103,9 +133,6 @@ class Flux(nn.Module):
         attn_mask: Tensor = None,
     ) -> Tensor:

-        if y is None:
-            y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
-
         patches = transformer_options.get("patches", {})
         patches_replace = transformer_options.get("patches_replace", {})
         if img.ndim != 3 or txt.ndim != 3:

@@ -118,9 +145,17 @@ class Flux(nn.Module):
             if guidance is not None:
                 vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))

-        vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+        if self.vector_in is not None:
+            if y is None:
+                y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
+            vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+
         txt = self.txt_in(txt)

+        vec_orig = vec
+        if self.params.global_modulation:
+            vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
+
         if "post_input" in patches:
             for p in patches["post_input"]:
                 out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})

@@ -177,6 +212,9 @@ class Flux(nn.Module):

         img = torch.cat((txt, img), 1)

+        if self.params.global_modulation:
+            vec, _ = self.single_stream_modulation(vec_orig)
+
         for i, block in enumerate(self.single_blocks):
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):

@@ -207,7 +245,7 @@ class Flux(nn.Module):

         img = img[:, txt.shape[1] :, ...]

-        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
+        img = self.final_layer(img, vec_orig)  # (N, T, patch_size ** 2 * out_channels)
         return img

     def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):

@@ -234,10 +272,10 @@ class Flux(nn.Module):
             h_offset += rope_options.get("shift_y", 0.0)
             w_offset += rope_options.get("shift_x", 0.0)

-        img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
+        img_ids = torch.zeros((steps_h, steps_w, len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
         img_ids[:, :, 0] = img_ids[:, :, 1] + index
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=torch.float32).unsqueeze(1)
+        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=torch.float32).unsqueeze(0)
         return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)

     def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):

@@ -259,10 +297,10 @@ class Flux(nn.Module):
             h = 0
             w = 0
             index = 0
-            ref_latents_method = kwargs.get("ref_latents_method", "offset")
+            ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
             for ref in ref_latents:
                 if ref_latents_method == "index":
-                    index += 1
+                    index += self.params.ref_index_scale
                     h_offset = 0
                     w_offset = 0
                 elif ref_latents_method == "uxo":

@@ -286,7 +324,11 @@ class Flux(nn.Module):
                 img = torch.cat([img, kontext], dim=1)
                 img_ids = torch.cat([img_ids, kontext_ids], dim=1)

-        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+        txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
+
+        if len(self.params.axes_dim) == 4:  # Flux 2
+            txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
+
         out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
         out = out[:, :img_tokens]
-        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
+        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]

@@ -9,6 +9,8 @@ from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistri
 from comfy.ldm.util import get_obj_from_str, instantiate_from_config
 from comfy.ldm.modules.ema import LitEma
 import comfy.ops
+from einops import rearrange
+import comfy.model_management

 class DiagonalGaussianRegularizer(torch.nn.Module):
     def __init__(self, sample: bool = False):

@@ -179,6 +181,21 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
             self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
         self.embed_dim = embed_dim

+        if ddconfig.get("batch_norm_latent", False):
+            self.bn_eps = 1e-4
+            self.bn_momentum = 0.1
+            self.ps = [2, 2]
+            self.bn = torch.nn.BatchNorm2d(math.prod(self.ps) * ddconfig["z_channels"],
+                                           eps=self.bn_eps,
+                                           momentum=self.bn_momentum,
+                                           affine=False,
+                                           track_running_stats=True,
+                                           )
+            self.bn.eval()
+        else:
+            self.bn = None
+
     def get_autoencoder_params(self) -> list:
         params = super().get_autoencoder_params()
         return params

@@ -201,11 +218,36 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
             z = torch.cat(z, 0)

         z, reg_log = self.regularization(z)
+
+        if self.bn is not None:
+            z = rearrange(z,
+                          "... c (i pi) (j pj) -> ... (c pi pj) i j",
+                          pi=self.ps[0],
+                          pj=self.ps[1],
+                          )
+
+            z = torch.nn.functional.batch_norm(z,
+                                               comfy.model_management.cast_to(self.bn.running_mean, dtype=z.dtype, device=z.device),
+                                               comfy.model_management.cast_to(self.bn.running_var, dtype=z.dtype, device=z.device),
+                                               momentum=self.bn_momentum,
+                                               eps=self.bn_eps)
+
         if return_reg_log:
             return z, reg_log
         return z

     def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
+        if self.bn is not None:
+            s = torch.sqrt(comfy.model_management.cast_to(self.bn.running_var.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + self.bn_eps)
+            m = comfy.model_management.cast_to(self.bn.running_mean.view(1, -1, 1, 1), dtype=z.dtype, device=z.device)
+            z = z * s + m
+            z = rearrange(
+                z,
+                "... (c pi pj) i j -> ... c (i pi) (j pj)",
+                pi=self.ps[0],
+                pj=self.ps[1],
+            )
+
         if self.max_batch_size is None:
             dec = self.post_quant_conv(z)
             dec = self.decoder(dec, **decoder_kwargs)
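Note: the batch-norm path added above works on a 2x2 pixel-shuffled view of the latent, which is also why the VAE loader further down multiplies latent_channels by 4 and the down/upscale ratios by 2. A rough shape sketch (sizes are illustrative only):

    import torch
    from einops import rearrange

    z = torch.randn(1, 32, 64, 64)   # hypothetical z_channels=32 latent
    packed = rearrange(z, "... c (i pi) (j pj) -> ... (c pi pj) i j", pi=2, pj=2)
    print(packed.shape)              # torch.Size([1, 128, 32, 32])
    unpacked = rearrange(packed, "... (c pi pj) i j -> ... c (i pi) (j pj)", pi=2, pj=2)
    print(torch.equal(z, unpacked))  # True: the packing is lossless
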
@@ -898,12 +898,13 @@ class Flux(BaseModel):
         attention_mask = kwargs.get("attention_mask", None)
         if attention_mask is not None:
             shape = kwargs["noise"].shape
-            mask_ref_size = kwargs["attention_mask_img_shape"]
-            # the model will pad to the patch size, and then divide
-            # essentially dividing and rounding up
-            (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
-            attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
-            out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+            mask_ref_size = kwargs.get("attention_mask_img_shape", None)
+            if mask_ref_size is not None:
+                # the model will pad to the patch size, and then divide
+                # essentially dividing and rounding up
+                (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
+                attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
+                out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)

         guidance = kwargs.get("guidance", 3.5)
         if guidance is not None:

@@ -928,6 +929,16 @@ class Flux(BaseModel):
             out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
         return out

+class Flux2(Flux):
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            target_text_len = 512
+            if cross_attn.shape[1] < target_text_len:
+                cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0))
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+        return out
+
 class GenmoMochi(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):

@@ -200,26 +200,54 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):

     if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
         dit_config = {}
-        dit_config["image_model"] = "flux"
+        if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
+            dit_config["image_model"] = "flux2"
+            dit_config["axes_dim"] = [32, 32, 32, 32]
+            dit_config["num_heads"] = 48
+            dit_config["mlp_ratio"] = 3.0
+            dit_config["theta"] = 2000
+            dit_config["out_channels"] = 128
+            dit_config["global_modulation"] = True
+            dit_config["vec_in_dim"] = None
+            dit_config["mlp_silu_act"] = True
+            dit_config["qkv_bias"] = False
+            dit_config["ops_bias"] = False
+            dit_config["default_ref_method"] = "index"
+            dit_config["ref_index_scale"] = 10.0
+            patch_size = 1
+        else:
+            dit_config["image_model"] = "flux"
+            dit_config["axes_dim"] = [16, 56, 56]
+            dit_config["num_heads"] = 24
+            dit_config["mlp_ratio"] = 4.0
+            dit_config["theta"] = 10000
+            dit_config["out_channels"] = 16
+            dit_config["qkv_bias"] = True
+            patch_size = 2
+
         dit_config["in_channels"] = 16
-        patch_size = 2
+        dit_config["hidden_size"] = 3072
+        dit_config["context_in_dim"] = 4096
+
         dit_config["patch_size"] = patch_size
         in_key = "{}img_in.weight".format(key_prefix)
         if in_key in state_dict_keys:
-            dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
-        dit_config["out_channels"] = 16
+            w = state_dict[in_key]
+            dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size)
+            dit_config["hidden_size"] = w.shape[0]
+
+        txt_in_key = "{}txt_in.weight".format(key_prefix)
+        if txt_in_key in state_dict_keys:
+            w = state_dict[txt_in_key]
+            dit_config["context_in_dim"] = w.shape[1]
+            dit_config["hidden_size"] = w.shape[0]
+
         vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
         if vec_in_key in state_dict_keys:
             dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
-        dit_config["context_in_dim"] = 4096
-        dit_config["hidden_size"] = 3072
-        dit_config["mlp_ratio"] = 4.0
-        dit_config["num_heads"] = 24
         dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
         dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
-        dit_config["axes_dim"] = [16, 56, 56]
-        dit_config["theta"] = 10000
-        dit_config["qkv_bias"] = True
         if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
             dit_config["image_model"] = "chroma"
             dit_config["in_channels"] = 64

comfy/sd.py

@@ -356,7 +356,7 @@ class VAE:

             self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
             self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
-        elif sd['decoder.conv_in.weight'].shape[1] == 32:
+        elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5:
             ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
             self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
             self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]

@@ -382,6 +382,17 @@ class VAE:
             self.upscale_ratio = 4

             self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+            if 'decoder.post_quant_conv.weight' in sd:
+                sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."})
+
+            if 'bn.running_mean' in sd:
+                ddconfig["batch_norm_latent"] = True
+                self.downscale_ratio *= 2
+                self.upscale_ratio *= 2
+                self.latent_channels *= 4
+                old_memory_used_decode = self.memory_used_decode
+                self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0
+
             if 'post_quant_conv.weight' in sd:
                 self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
             else:

@@ -940,6 +951,8 @@ class TEModel(Enum):
     QWEN25_7B = 11
     BYT5_SMALL_GLYPH = 12
     GEMMA_3_4B = 13
+    MISTRAL3_24B = 14
+    MISTRAL3_24B_PRUNED_FLUX2 = 15

 def detect_te_model(sd):
     if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:

@@ -972,6 +985,13 @@ def detect_te_model(sd):
         if weight.shape[0] == 512:
             return TEModel.QWEN25_7B
     if "model.layers.0.post_attention_layernorm.weight" in sd:
+        weight = sd['model.layers.0.post_attention_layernorm.weight']
+        if weight.shape[0] == 5120:
+            if "model.layers.39.post_attention_layernorm.weight" in sd:
+                return TEModel.MISTRAL3_24B
+            else:
+                return TEModel.MISTRAL3_24B_PRUNED_FLUX2
+
         return TEModel.LLAMA3_8
     return None

@@ -1086,6 +1106,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             else:
                 clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
                 clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
+        elif te_model == TEModel.MISTRAL3_24B or te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2:
+            clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
+            clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
+            tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
         else:
             # clip_l
             if clip_type == CLIPType.SD3:

@@ -741,6 +741,37 @@ class FluxSchnell(Flux):
         out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
         return out

+class Flux2(Flux):
+    unet_config = {
+        "image_model": "flux2",
+    }
+
+    sampling_settings = {
+        "shift": 2.02,
+    }
+
+    unet_extra_config = {}
+    latent_format = latent_formats.Flux2
+
+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+
+    vae_key_prefix = ["vae."]
+    text_encoder_key_prefix = ["text_encoders."]
+
+    def __init__(self, unet_config):
+        super().__init__(unet_config)
+        self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.Flux2(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        return None  # TODO
+        pref = self.text_encoder_key_prefix[0]
+        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
+
 class GenmoMochi(supported_models_base.BASE):
     unet_config = {
         "image_model": "mochi_preview",

@@ -1422,6 +1453,7 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo):
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))

-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2]


 models += [SVD_img2vid]

@@ -1,10 +1,13 @@
 from comfy import sd1_clip
 import comfy.text_encoders.t5
 import comfy.text_encoders.sd3_clip
+import comfy.text_encoders.llama
 import comfy.model_management
-from transformers import T5TokenizerFast
+from transformers import T5TokenizerFast, LlamaTokenizerFast
 import torch
 import os
+import json
+import base64

 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):

@@ -68,3 +71,105 @@ def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
                 model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
             super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
     return FluxClipModel_
+
+
+def load_mistral_tokenizer(data):
+    if torch.is_tensor(data):
+        data = data.numpy().tobytes()
+
+    try:
+        from transformers.integrations.mistral import MistralConverter
+    except ModuleNotFoundError:
+        from transformers.models.pixtral.convert_pixtral_weights_to_hf import MistralConverter
+
+    mistral_vocab = json.loads(data)
+
+    special_tokens = {}
+    vocab = {}
+
+    max_vocab = mistral_vocab["config"]["default_vocab_size"]
+
+    for w in mistral_vocab["vocab"]:
+        r = w["rank"]
+        if r >= max_vocab:
+            continue
+
+        vocab[base64.b64decode(w["token_bytes"])] = r
+
+    for w in mistral_vocab["special_tokens"]:
+        if "token_bytes" in w:
+            special_tokens[base64.b64decode(w["token_bytes"])] = w["rank"]
+        else:
+            special_tokens[w["token_str"]] = w["rank"]
+
+    all_special = []
+    for v in special_tokens:
+        all_special.append(v)
+
+    special_tokens.update(vocab)
+    vocab = special_tokens
+    return {"tokenizer_object": MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(), "legacy": False}
+
+
+class MistralTokenizerClass:
+    @staticmethod
+    def from_pretrained(path, **kwargs):
+        return LlamaTokenizerFast(**kwargs)
+
+
+class Mistral3Tokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        self.tekken_data = tokenizer_data.get("tekken_model", None)
+        super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)

+    def state_dict(self):
+        return {"tekken_model": self.tekken_data}
+
+
+class Flux2Tokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="mistral3_24b", tokenizer=Mistral3Tokenizer)
+        self.llama_template = '[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]{}[/INST]'
+
+    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
+        if llama_template is None:
+            llama_text = self.llama_template.format(text)
+        else:
+            llama_text = llama_template.format(text)
+
+        tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
+        return tokens
+
+
+class Mistral3_24BModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+        textmodel_json_config = {}
+        num_layers = model_options.get("num_layers", None)
+        if num_layers is not None:
+            textmodel_json_config["num_hidden_layers"] = num_layers
+            if num_layers < 40:
+                textmodel_json_config["final_norm"] = False
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 1, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Mistral3Small24B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+
+class Flux2TEModel(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}, name="mistral3_24b", clip_model=Mistral3_24BModel):
+        super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
+
+    def encode_token_weights(self, token_weight_pairs):
+        out, pooled, extra = super().encode_token_weights(token_weight_pairs)
+
+        out = torch.stack((out[:, 10], out[:, 20], out[:, 30]), dim=1)
+        out = out.movedim(1, 2)
+        out = out.reshape(out.shape[0], out.shape[1], -1)
+        return out, pooled, extra
+
+
+def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False):
+    class Flux2TEModel_(Flux2TEModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["scaled_fp8"] = llama_scaled_fp8
+            if dtype_llama is not None:
+                dtype = dtype_llama
+            if llama_quantization_metadata is not None:
+                model_options["quantization_metadata"] = llama_quantization_metadata
+            if pruned:
+                model_options = model_options.copy()
+                model_options["num_layers"] = 30
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return Flux2TEModel_
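Note: Flux2TEModel.encode_token_weights above keeps three intermediate hidden states (layers 10, 20 and 30) and concatenates them along the feature axis. A shape-only sketch, assuming the per-layer outputs are stacked as [batch, layer, tokens, hidden]:

    import torch

    out = torch.randn(1, 41, 512, 5120)  # assumed layout; 5120 is the Mistral 3 hidden size used in this commit
    picked = torch.stack((out[:, 10], out[:, 20], out[:, 30]), dim=1)  # [1, 3, 512, 5120]
    picked = picked.movedim(1, 2)                                      # [1, 512, 3, 5120]
    flat = picked.reshape(picked.shape[0], picked.shape[1], -1)        # [1, 512, 3 * 5120]
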
@@ -34,6 +34,28 @@ class Llama2Config:
     rope_scale = None
     final_norm: bool = True

+@dataclass
+class Mistral3Small24BConfig:
+    vocab_size: int = 131072
+    hidden_size: int = 5120
+    intermediate_size: int = 32768
+    num_hidden_layers: int = 40
+    num_attention_heads: int = 32
+    num_key_value_heads: int = 8
+    max_position_embeddings: int = 8192
+    rms_norm_eps: float = 1e-5
+    rope_theta: float = 1000000000.0
+    transformer_type: str = "llama"
+    head_dim = 128
+    rms_norm_add = False
+    mlp_activation = "silu"
+    qkv_bias = False
+    rope_dims = None
+    q_norm = None
+    k_norm = None
+    rope_scale = None
+    final_norm: bool = True
+
 @dataclass
 class Qwen25_3BConfig:
     vocab_size: int = 151936

@@ -465,6 +487,15 @@ class Llama2(BaseLlama, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype

+
+class Mistral3Small24B(BaseLlama, torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = Mistral3Small24BConfig(**config_dict)
+        self.num_layers = config.num_hidden_layers
+
+        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        self.dtype = dtype
+
 class Qwen25_3B(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()

@@ -2,7 +2,10 @@ import node_helpers
 import comfy.utils
 from typing_extensions import override
 from comfy_api.latest import ComfyExtension, io
+import comfy.model_management
+import torch
+import math
+import nodes

 class CLIPTextEncodeFlux(io.ComfyNode):
     @classmethod

@@ -30,6 +33,27 @@ class CLIPTextEncodeFlux(io.ComfyNode):

     encode = execute  # TODO: remove

+class EmptyFlux2LatentImage(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="EmptyFlux2LatentImage",
+            display_name="Empty Flux 2 Latent",
+            category="latent",
+            inputs=[
+                io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
+        latent = torch.zeros([batch_size, 128, height // 16, width // 16], device=comfy.model_management.intermediate_device())
+        return io.NodeOutput({"samples": latent})
+
 class FluxGuidance(io.ComfyNode):
     @classmethod
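Note: with its default inputs, EmptyFlux2LatentImage above allocates a latent matching the 128-channel, 16x-downscaled Flux2 latent format added earlier in this commit; a minimal check:

    import torch

    latent = torch.zeros([1, 128, 1024 // 16, 1024 // 16])  # width=height=1024, batch_size=1
    print(latent.shape)                                      # torch.Size([1, 128, 64, 64])
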
@@ -154,6 +178,58 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
     append = execute  # TODO: remove


+def generalized_time_snr_shift(t, mu: float, sigma: float):
+    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+
+def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
+    a1, b1 = 8.73809524e-05, 1.89833333
+    a2, b2 = 0.00016927, 0.45666666
+
+    if image_seq_len > 4300:
+        mu = a2 * image_seq_len + b2
+        return float(mu)
+
+    m_200 = a2 * image_seq_len + b2
+    m_10 = a1 * image_seq_len + b1
+
+    a = (m_200 - m_10) / 190.0
+    b = m_200 - 200.0 * a
+    mu = a * num_steps + b
+
+    return float(mu)
+
+
+def get_schedule(num_steps: int, image_seq_len: int) -> list[float]:
+    mu = compute_empirical_mu(image_seq_len, num_steps)
+    timesteps = torch.linspace(1, 0, num_steps + 1)
+    timesteps = generalized_time_snr_shift(timesteps, mu, 1.0)
+    return timesteps
+
+
+class Flux2Scheduler(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="Flux2Scheduler",
+            category="sampling/custom_sampling/schedulers",
+            inputs=[
+                io.Int.Input("steps", default=20, min=1, max=4096),
+                io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1),
+                io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1),
+            ],
+            outputs=[
+                io.Sigmas.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, steps, width, height) -> io.NodeOutput:
+        seq_len = (width * height / (16 * 16))
+        sigmas = get_schedule(steps, round(seq_len))
+        return io.NodeOutput(sigmas)
+
+
 class FluxExtension(ComfyExtension):
     @override
     async def get_node_list(self) -> list[type[io.ComfyNode]]:
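Note: the schedule helpers above shift a linear 1-to-0 time grid by an empirically fitted mu that depends on the image token count and the step count; Flux2Scheduler simply passes width*height/256 as the token count. An illustrative call, assuming the functions above are in scope:

    steps, width, height = 20, 1024, 1024
    image_seq_len = round(width * height / (16 * 16))  # 4096 image tokens
    sigmas = get_schedule(steps, image_seq_len)        # 21 values, from 1.0 down to 0.0
    print(sigmas[0].item(), sigmas[-1].item())
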
@@ -163,6 +239,8 @@ class FluxExtension(ComfyExtension):
             FluxDisableGuidance,
             FluxKontextImageScale,
             FluxKontextMultiReferenceLatentMethod,
+            EmptyFlux2LatentImage,
+            Flux2Scheduler,
         ]


nodes.py

@@ -929,7 +929,7 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image"], ),
+                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2"], ),
                               },
                 "optional": {
                               "device": (["default", "cpu"], {"advanced": True}),