Merge remote-tracking branch 'upstream/master' into multitalk (commit c36323fe53)
@@ -1,5 +1,5 @@
-As of the time of writing this you need this preview driver for best results:
-https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-PREVIEW.html
+As of the time of writing this you need this driver for best results:
+https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html

HOW TO RUN:

@@ -25,3 +25,4 @@ In the ComfyUI directory you will find a file: extra_model_paths.yaml.example

Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.
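As an illustration only (the keys below are assumptions based on the shipped extra_model_paths.yaml.example, not part of this commit), a minimal edited file might look like:

    comfyui:
        base_path: /path/to/existing/comfyui/
        checkpoints: models/checkpoints/
        loras: models/loras/
        vae: models/vae/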
.github/workflows/release-stable-all.yml (vendored, 4 changed lines)
@@ -65,11 +65,11 @@ jobs:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release AMD ROCm 6.4.4"
|
||||
name: "Release AMD ROCm 7.1.1"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "rocm644"
|
||||
cache_tag: "rocm711"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "amd"
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# Admins
|
||||
* @comfyanonymous
|
||||
* @kosinkadink
|
||||
* @guill
|
||||
|
||||
@ -68,6 +68,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
|
||||
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
|
||||
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
|
||||
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
|
||||
- Image Editing Models
|
||||
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
|
||||
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
|
||||
|
||||
@ -59,6 +59,9 @@ class UserManager():
|
||||
user = "default"
|
||||
if args.multi_user and "comfy-user" in request.headers:
|
||||
user = request.headers["comfy-user"]
|
||||
# Block System Users (use same error message to prevent probing)
|
||||
if user.startswith(folder_paths.SYSTEM_USER_PREFIX):
|
||||
raise KeyError("Unknown user: " + user)
|
||||
|
||||
if user not in self.users:
|
||||
raise KeyError("Unknown user: " + user)
|
||||
@ -66,15 +69,16 @@ class UserManager():
|
||||
return user
|
||||
|
||||
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
|
||||
user_directory = folder_paths.get_user_directory()
|
||||
|
||||
if type == "userdata":
|
||||
root_dir = user_directory
|
||||
root_dir = folder_paths.get_user_directory()
|
||||
else:
|
||||
raise KeyError("Unknown filepath type:" + type)
|
||||
|
||||
user = self.get_request_user_id(request)
|
||||
path = user_root = os.path.abspath(os.path.join(root_dir, user))
|
||||
user_root = folder_paths.get_public_user_directory(user)
|
||||
if user_root is None:
|
||||
return None
|
||||
path = user_root
|
||||
|
||||
# prevent leaving /{type}
|
||||
if os.path.commonpath((root_dir, user_root)) != root_dir:
|
||||
@ -101,7 +105,11 @@ class UserManager():
|
||||
name = name.strip()
|
||||
if not name:
|
||||
raise ValueError("username not provided")
|
||||
if name.startswith(folder_paths.SYSTEM_USER_PREFIX):
|
||||
raise ValueError("System User prefix not allowed")
|
||||
user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
|
||||
if user_id.startswith(folder_paths.SYSTEM_USER_PREFIX):
|
||||
raise ValueError("System User prefix not allowed")
|
||||
user_id = user_id + "_" + str(uuid.uuid4())
|
||||
|
||||
self.users[user_id] = name
|
||||
@ -132,7 +140,10 @@ class UserManager():
|
||||
if username in self.users.values():
|
||||
return web.json_response({"error": "Duplicate username."}, status=400)
|
||||
|
||||
user_id = self.add_user(username)
|
||||
try:
|
||||
user_id = self.add_user(username)
|
||||
except ValueError as e:
|
||||
return web.json_response({"error": str(e)}, status=400)
|
||||
return web.json_response(user_id)
|
||||
|
||||
@routes.get("/userdata")
|
||||
@ -424,7 +435,7 @@ class UserManager():
|
||||
return source
|
||||
|
||||
dest = get_user_data_path(request, check_exists=False, param="dest")
|
||||
if not isinstance(source, str):
|
||||
if not isinstance(dest, str):
|
||||
return dest
|
||||
|
||||
overwrite = request.query.get("overwrite", 'true') != "false"
|
||||
|
||||
@ -121,6 +121,12 @@ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force
|
||||
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
|
||||
|
||||
|
||||
parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
|
||||
manager_group = parser.add_mutually_exclusive_group()
|
||||
manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
|
||||
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
|
||||
|
||||
|
||||
vram_group = parser.add_mutually_exclusive_group()
|
||||
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
|
||||
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
||||
@ -131,7 +137,8 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
|
||||
|
||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
|
||||
|
||||
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
|
||||
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||
|
||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||
|
||||
@ -167,6 +174,7 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
|
||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
|
||||
|
||||
|
||||
# The default built-in provider hosted under web/
|
||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||
|
||||
|
||||
@ -431,6 +431,7 @@ class HunyuanVideo(LatentFormat):
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
|
||||
taesd_decoder_name = "taehv"
|
||||
|
||||
class Cosmos1CV8x8x8(LatentFormat):
|
||||
latent_channels = 16
|
||||
@ -494,7 +495,7 @@ class Wan21(LatentFormat):
|
||||
]).view(1, self.latent_channels, 1, 1, 1)
|
||||
|
||||
|
||||
self.taesd_decoder_name = None #TODO
|
||||
self.taesd_decoder_name = "lighttaew2_1"
|
||||
|
||||
def process_in(self, latent):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
@ -565,6 +566,7 @@ class Wan22(Wan21):
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.taesd_decoder_name = "lighttaew2_2"
|
||||
self.latents_mean = torch.tensor([
|
||||
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
|
||||
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
|
||||
@ -719,6 +721,7 @@ class HunyuanVideo15(LatentFormat):
|
||||
latent_channels = 32
|
||||
latent_dimensions = 3
|
||||
scale_factor = 1.03682
|
||||
taesd_decoder_name = "lighttaehy1_5"
|
||||
|
||||
class Hunyuan3Dv2(LatentFormat):
|
||||
latent_channels = 64
|
||||
|
||||
@ -40,7 +40,8 @@ class ChromaParams:
|
||||
out_dim: int
|
||||
hidden_dim: int
|
||||
n_layers: int
|
||||
|
||||
txt_ids_dims: list
|
||||
vec_in_dim: int
|
||||
|
||||
|
||||
|
||||
|
||||
@ -57,6 +57,35 @@ class MLPEmbedder(nn.Module):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.out_layer(self.silu(self.in_layer(x)))
|
||||
|
||||
class YakMLP(nn.Module):
|
||||
def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
|
||||
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
|
||||
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
if yak_mlp:
|
||||
return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
|
||||
if mlp_silu_act:
|
||||
return nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
|
||||
SiLUActivation(),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
|
||||
)
|
||||
else:
|
||||
return nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
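A minimal usage sketch of the build_mlp helper defined above (illustrative only; the _Ops stand-in assumes the comfy.ops Linear layers behave like plain nn.Linear, which is not part of this diff):

    import torch
    import torch.nn as nn

    class _Ops:  # hypothetical stand-in for comfy.ops.disable_weight_init
        Linear = nn.Linear

    # yak_mlp=True selects the gated SiLU variant (YakMLP); the default is the GELU MLP.
    mlp = build_mlp(hidden_size=64, mlp_hidden_dim=256, yak_mlp=True,
                    dtype=torch.float32, device="cpu", operations=_Ops)
    out = mlp(torch.randn(2, 16, 64))  # output keeps the (2, 16, 64) shape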
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
@ -140,7 +169,7 @@ class SiLUActivation(nn.Module):
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
@ -156,18 +185,7 @@ class DoubleStreamBlock(nn.Module):
|
||||
|
||||
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
if mlp_silu_act:
|
||||
self.img_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
|
||||
SiLUActivation(),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
|
||||
)
|
||||
else:
|
||||
self.img_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
if self.modulation:
|
||||
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
@ -177,18 +195,7 @@ class DoubleStreamBlock(nn.Module):
|
||||
|
||||
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
if mlp_silu_act:
|
||||
self.txt_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
|
||||
SiLUActivation(),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
|
||||
)
|
||||
else:
|
||||
self.txt_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
@ -275,6 +282,7 @@ class SingleStreamBlock(nn.Module):
|
||||
modulation=True,
|
||||
mlp_silu_act=False,
|
||||
bias=True,
|
||||
yak_mlp=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
@ -288,12 +296,17 @@ class SingleStreamBlock(nn.Module):
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
|
||||
self.mlp_hidden_dim_first = self.mlp_hidden_dim
|
||||
self.yak_mlp = yak_mlp
|
||||
if mlp_silu_act:
|
||||
self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
|
||||
self.mlp_act = SiLUActivation()
|
||||
else:
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
if self.yak_mlp:
|
||||
self.mlp_hidden_dim_first *= 2
|
||||
self.mlp_act = nn.SiLU()
|
||||
|
||||
# qkv and mlp_in
|
||||
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
|
||||
# proj and mlp_out
|
||||
@ -325,7 +338,10 @@ class SingleStreamBlock(nn.Module):
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
mlp = self.mlp_act(mlp)
|
||||
if self.yak_mlp:
|
||||
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
|
||||
else:
|
||||
mlp = self.mlp_act(mlp)
|
||||
output = self.linear2(torch.cat((attn, mlp), 2))
|
||||
x += apply_mod(output, mod.gate, None, modulation_dims)
|
||||
if x.dtype == torch.float16:
|
||||
|
||||
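For intuition, the yak_mlp branch in SingleStreamBlock above splits the doubled MLP projection into a gate half and a value half (a SwiGLU-style gate). A small standalone sketch, not taken from the codebase:

    import torch
    mlp = torch.randn(2, 16, 512)           # plays the role of mlp_hidden_dim_first = 512
    half = mlp.shape[-1] // 2
    gated = torch.nn.functional.silu(mlp[..., half:]) * mlp[..., :half]
    print(gated.shape)                      # torch.Size([2, 16, 256])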
@ -15,7 +15,8 @@ from .layers import (
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
Modulation
|
||||
Modulation,
|
||||
RMSNorm
|
||||
)
|
||||
|
||||
@dataclass
|
||||
@ -34,11 +35,14 @@ class FluxParams:
|
||||
patch_size: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
txt_ids_dims: list
|
||||
global_modulation: bool = False
|
||||
mlp_silu_act: bool = False
|
||||
ops_bias: bool = True
|
||||
default_ref_method: str = "offset"
|
||||
ref_index_scale: float = 1.0
|
||||
yak_mlp: bool = False
|
||||
txt_norm: bool = False
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
@ -76,6 +80,11 @@ class Flux(nn.Module):
|
||||
)
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
|
||||
|
||||
if params.txt_norm:
|
||||
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.txt_norm = None
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
@ -86,6 +95,7 @@ class Flux(nn.Module):
|
||||
modulation=params.global_modulation is False,
|
||||
mlp_silu_act=params.mlp_silu_act,
|
||||
proj_bias=params.ops_bias,
|
||||
yak_mlp=params.yak_mlp,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
@ -94,7 +104,7 @@ class Flux(nn.Module):
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
@ -150,6 +160,8 @@ class Flux(nn.Module):
|
||||
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
|
||||
if self.txt_norm is not None:
|
||||
txt = self.txt_norm(txt)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
vec_orig = vec
|
||||
@ -171,7 +183,10 @@ class Flux(nn.Module):
|
||||
pe = None
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@ -215,7 +230,10 @@ class Flux(nn.Module):
|
||||
if self.params.global_modulation:
|
||||
vec, _ = self.single_stream_modulation(vec_orig)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@ -326,8 +344,9 @@ class Flux(nn.Module):
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
|
||||
|
||||
if len(self.params.axes_dim) == 4: # Flux 2
|
||||
txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
|
||||
if len(self.params.txt_ids_dims) > 0:
|
||||
for i in self.params.txt_ids_dims:
|
||||
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
|
||||
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
out = out[:, :img_tokens]
|
||||
|
||||
@ -509,7 +509,7 @@ class NextDiT(nn.Module):
|
||||
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
|
||||
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
|
||||
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
|
||||
|
||||
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
|
||||
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
|
||||
@ -517,15 +517,27 @@ class NextDiT(nn.Module):
|
||||
B, C, H, W = x.shape
|
||||
x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
|
||||
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
h_scale = 1.0
|
||||
w_scale = 1.0
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
if rope_options is not None:
|
||||
h_scale = rope_options.get("scale_y", 1.0)
|
||||
w_scale = rope_options.get("scale_x", 1.0)
|
||||
|
||||
h_start = rope_options.get("shift_y", 0.0)
|
||||
w_start = rope_options.get("shift_x", 0.0)
|
||||
|
||||
H_tokens, W_tokens = H // pH, W // pW
|
||||
x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
|
||||
x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
|
||||
x_pos_ids[:, :, 1] = torch.arange(H_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
x_pos_ids[:, :, 2] = torch.arange(W_tokens, dtype=torch.float32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
|
||||
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
|
||||
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
|
||||
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
|
||||
|
||||
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
|
||||
|
||||
@ -313,6 +313,15 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
|
||||
|
||||
if isinstance(model, comfy.model_base.Lumina2):
|
||||
diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
to = diffusers_keys[k]
|
||||
key_lora = k[:-len(".weight")]
|
||||
key_map["diffusion_model.{}".format(key_lora)] = to
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
|
||||
@ -208,12 +208,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["theta"] = 2000
|
||||
dit_config["out_channels"] = 128
|
||||
dit_config["global_modulation"] = True
|
||||
dit_config["vec_in_dim"] = None
|
||||
dit_config["mlp_silu_act"] = True
|
||||
dit_config["qkv_bias"] = False
|
||||
dit_config["ops_bias"] = False
|
||||
dit_config["default_ref_method"] = "index"
|
||||
dit_config["ref_index_scale"] = 10.0
|
||||
dit_config["txt_ids_dims"] = [3]
|
||||
patch_size = 1
|
||||
else:
|
||||
dit_config["image_model"] = "flux"
|
||||
@ -223,6 +223,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["theta"] = 10000
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["qkv_bias"] = True
|
||||
dit_config["txt_ids_dims"] = []
|
||||
patch_size = 2
|
||||
|
||||
dit_config["in_channels"] = 16
|
||||
@ -245,6 +246,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
|
||||
if vec_in_key in state_dict_keys:
|
||||
dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
|
||||
else:
|
||||
dit_config["vec_in_dim"] = None
|
||||
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
@ -270,6 +273,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["nerf_embedder_dtype"] = torch.float32
|
||||
else:
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
||||
dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
|
||||
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
|
||||
dit_config["txt_ids_dims"] = [1, 2]
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
|
||||
|
||||
@ -689,7 +689,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
loaded_memory = loaded_model.model_loaded_memory()
|
||||
current_free_mem = get_free_memory(torch_dev) + loaded_memory
|
||||
|
||||
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||
lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||
lowvram_model_memory = lowvram_model_memory - loaded_memory
|
||||
|
||||
if lowvram_model_memory == 0:
|
||||
@ -1012,9 +1012,18 @@ def force_channels_last():
|
||||
|
||||
|
||||
STREAMS = {}
|
||||
NUM_STREAMS = 1
|
||||
if args.async_offload:
|
||||
NUM_STREAMS = 2
|
||||
NUM_STREAMS = 0
|
||||
if args.async_offload is not None:
|
||||
NUM_STREAMS = args.async_offload
|
||||
else:
|
||||
# Enable by default on Nvidia
|
||||
if is_nvidia():
|
||||
NUM_STREAMS = 2
|
||||
|
||||
if args.disable_async_offload:
|
||||
NUM_STREAMS = 0
|
||||
|
||||
if NUM_STREAMS > 0:
|
||||
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
|
||||
|
||||
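For reference, the flags parsed earlier in this diff would be used roughly like this (assuming the usual main.py entry point; the stream counts are only examples):

    python main.py --async-offload          # const value of 2 streams (also the Nvidia default)
    python main.py --async-offload 4        # request four offload streams
    python main.py --disable-async-offload  # force NUM_STREAMS back to 0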
def current_stream(device):
|
||||
@ -1030,7 +1039,10 @@ def current_stream(device):
|
||||
stream_counters = {}
|
||||
def get_offload_stream(device):
|
||||
stream_counter = stream_counters.get(device, 0)
|
||||
if NUM_STREAMS <= 1:
|
||||
if NUM_STREAMS == 0:
|
||||
return None
|
||||
|
||||
if torch.compiler.is_compiling():
|
||||
return None
|
||||
|
||||
if device in STREAMS:
|
||||
@ -1043,7 +1055,9 @@ def get_offload_stream(device):
|
||||
elif is_device_cuda(device):
|
||||
ss = []
|
||||
for k in range(NUM_STREAMS):
|
||||
ss.append(torch.cuda.Stream(device=device, priority=0))
|
||||
s1 = torch.cuda.Stream(device=device, priority=0)
|
||||
s1.as_context = torch.cuda.stream
|
||||
ss.append(s1)
|
||||
STREAMS[device] = ss
|
||||
s = ss[stream_counter]
|
||||
stream_counters[device] = stream_counter
|
||||
@ -1051,7 +1065,9 @@ def get_offload_stream(device):
|
||||
elif is_device_xpu(device):
|
||||
ss = []
|
||||
for k in range(NUM_STREAMS):
|
||||
ss.append(torch.xpu.Stream(device=device, priority=0))
|
||||
s1 = torch.xpu.Stream(device=device, priority=0)
|
||||
s1.as_context = torch.xpu.stream
|
||||
ss.append(s1)
|
||||
STREAMS[device] = ss
|
||||
s = ss[stream_counter]
|
||||
stream_counters[device] = stream_counter
|
||||
@ -1069,12 +1085,19 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
if stream is not None:
|
||||
with stream:
|
||||
wf_context = stream
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(stream)
|
||||
with wf_context:
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
|
||||
if stream is not None:
|
||||
with stream:
|
||||
wf_context = stream
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(stream)
|
||||
with wf_context:
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
else:
|
||||
|
||||
@ -148,6 +148,15 @@ class LowVramPatch:
|
||||
else:
|
||||
return out
|
||||
|
||||
#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
|
||||
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
|
||||
|
||||
def low_vram_patch_estimate_vram(model, key):
|
||||
weight, set_func, convert_func = get_key_weight(model, key)
|
||||
if weight is None:
|
||||
return 0
|
||||
return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
|
||||
|
||||
def get_key_weight(model, key):
|
||||
set_func = None
|
||||
convert_func = None
|
||||
@ -269,6 +278,9 @@ class ModelPatcher:
|
||||
if not hasattr(self.model, 'current_weight_patches_uuid'):
|
||||
self.model.current_weight_patches_uuid = None
|
||||
|
||||
if not hasattr(self.model, 'model_offload_buffer_memory'):
|
||||
self.model.model_offload_buffer_memory = 0
|
||||
|
||||
def model_size(self):
|
||||
if self.size > 0:
|
||||
return self.size
|
||||
@ -662,7 +674,16 @@ class ModelPatcher:
|
||||
skip = True # skip random weights in non leaf modules
|
||||
break
|
||||
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||
loading.append((comfy.model_management.module_size(m), n, m, params))
|
||||
module_mem = comfy.model_management.module_size(m)
|
||||
module_offload_mem = module_mem
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
if weight_key in self.patches:
|
||||
module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
|
||||
if bias_key in self.patches:
|
||||
module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
|
||||
loading.append((module_offload_mem, module_mem, n, m, params))
|
||||
return loading
|
||||
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||
@ -676,20 +697,22 @@ class ModelPatcher:
|
||||
|
||||
load_completely = []
|
||||
offloaded = []
|
||||
offload_buffer = 0
|
||||
loading.sort(reverse=True)
|
||||
for x in loading:
|
||||
n = x[1]
|
||||
m = x[2]
|
||||
params = x[3]
|
||||
module_mem = x[0]
|
||||
module_offload_mem, module_mem, n, m, params = x
|
||||
|
||||
lowvram_weight = False
|
||||
|
||||
potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
|
||||
lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
|
||||
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
|
||||
if not full_load and hasattr(m, "comfy_cast_weights"):
|
||||
if mem_counter + module_mem >= lowvram_model_memory:
|
||||
if not lowvram_fits:
|
||||
offload_buffer = potential_offload
|
||||
lowvram_weight = True
|
||||
lowvram_counter += 1
|
||||
lowvram_mem_counter += module_mem
|
||||
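A rough worked example of the reservation above (numbers invented for illustration): with NUM_STREAMS = 2, a module whose patched-weight estimate module_offload_mem is 100 MB reserves potential_offload = max(offload_buffer, 100 MB * (2 + 1)) = 300 MB, and the module only stays resident when mem_counter + module_mem + potential_offload still fits under lowvram_model_memory.

    # illustrative arithmetic only, mirroring the expressions in the diff above
    NUM_STREAMS = 2
    module_mem = 80 * 1024**2            # 80 MB of weights
    module_offload_mem = 100 * 1024**2   # weights plus low-vram patch math estimate
    offload_buffer = 0
    potential_offload = max(offload_buffer, module_offload_mem * (NUM_STREAMS + 1))
    print(potential_offload / 1024**2)   # 300.0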
@ -723,9 +746,11 @@ class ModelPatcher:
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
wipe_lowvram_weight(m)
|
||||
|
||||
if full_load or mem_counter + module_mem < lowvram_model_memory:
|
||||
if full_load or lowvram_fits:
|
||||
mem_counter += module_mem
|
||||
load_completely.append((module_mem, n, m, params))
|
||||
else:
|
||||
offload_buffer = potential_offload
|
||||
|
||||
if cast_weight and hasattr(m, "comfy_cast_weights"):
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
@ -766,7 +791,7 @@ class ModelPatcher:
|
||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||
|
||||
if lowvram_counter > 0:
|
||||
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
|
||||
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
|
||||
self.model.model_lowvram = True
|
||||
else:
|
||||
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||
@ -778,6 +803,7 @@ class ModelPatcher:
|
||||
self.model.lowvram_patch_counter += patch_counter
|
||||
self.model.device = device_to
|
||||
self.model.model_loaded_weight_memory = mem_counter
|
||||
self.model.model_offload_buffer_memory = offload_buffer
|
||||
self.model.current_weight_patches_uuid = self.patches_uuid
|
||||
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
|
||||
@ -831,6 +857,7 @@ class ModelPatcher:
|
||||
self.model.to(device_to)
|
||||
self.model.device = device_to
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
self.model.model_offload_buffer_memory = 0
|
||||
|
||||
for m in self.model.modules():
|
||||
if hasattr(m, "comfy_patched_weights"):
|
||||
@ -849,13 +876,14 @@ class ModelPatcher:
|
||||
patch_counter = 0
|
||||
unload_list = self._load_list()
|
||||
unload_list.sort()
|
||||
offload_buffer = self.model.model_offload_buffer_memory
|
||||
|
||||
for unload in unload_list:
|
||||
if memory_to_free < memory_freed:
|
||||
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
|
||||
break
|
||||
module_mem = unload[0]
|
||||
n = unload[1]
|
||||
m = unload[2]
|
||||
params = unload[3]
|
||||
module_offload_mem, module_mem, n, m, params = unload
|
||||
|
||||
potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
|
||||
|
||||
lowvram_possible = hasattr(m, "comfy_cast_weights")
|
||||
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
|
||||
@ -906,15 +934,18 @@ class ModelPatcher:
|
||||
m.comfy_cast_weights = True
|
||||
m.comfy_patched_weights = False
|
||||
memory_freed += module_mem
|
||||
offload_buffer = max(offload_buffer, potential_offload)
|
||||
logging.debug("freed {}".format(n))
|
||||
|
||||
for param in params:
|
||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||
|
||||
|
||||
self.model.model_lowvram = True
|
||||
self.model.lowvram_patch_counter += patch_counter
|
||||
self.model.model_loaded_weight_memory -= memory_freed
|
||||
logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter))
|
||||
self.model.model_offload_buffer_memory = offload_buffer
|
||||
logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
|
||||
return memory_freed
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||
|
||||
@ -95,6 +95,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
|
||||
if offload_stream is not None:
|
||||
wf_context = offload_stream
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(offload_stream)
|
||||
else:
|
||||
wf_context = contextlib.nullcontext()
|
||||
|
||||
|
||||
@ -235,8 +235,8 @@ class QuantizedTensor(torch.Tensor):
|
||||
def is_pinned(self):
|
||||
return self._qdata.is_pinned()
|
||||
|
||||
def is_contiguous(self):
|
||||
return self._qdata.is_contiguous()
|
||||
def is_contiguous(self, *arg, **kwargs):
|
||||
return self._qdata.is_contiguous(*arg, **kwargs)
|
||||
|
||||
# ==============================================================================
|
||||
# Generic Utilities (Layout-Agnostic Operations)
|
||||
@ -425,7 +425,8 @@ class TensorCoreFP8Layout(QuantizedLayout):
|
||||
@staticmethod
|
||||
def dequantize(qdata, scale, orig_dtype, **kwargs):
|
||||
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
|
||||
return plain_tensor * scale
|
||||
plain_tensor.mul_(scale)
|
||||
return plain_tensor
|
||||
|
||||
@classmethod
|
||||
def get_plain_tensors(cls, qtensor):
|
||||
|
||||
comfy/sd.py (47 changed lines)
@@ -53,6 +53,7 @@ import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.z_image
|
||||
import comfy.text_encoders.ovis
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -60,6 +61,8 @@ import comfy.lora_convert
|
||||
import comfy.hooks
|
||||
import comfy.t2i_adapter.adapter
|
||||
import comfy.taesd.taesd
|
||||
import comfy.taesd.taehv
|
||||
import comfy.latent_formats
|
||||
|
||||
import comfy.ldm.flux.redux
|
||||
|
||||
@ -508,13 +511,14 @@ class VAE:
|
||||
self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
|
||||
else: # Wan 2.1 VAE
|
||||
dim = sd["decoder.head.0.gamma"].shape[0]
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
self.upscale_index_formula = (4, 8, 8)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = 16
|
||||
ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
|
||||
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
|
||||
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
@ -584,6 +588,35 @@ class VAE:
|
||||
self.process_input = lambda audio: audio
|
||||
self.working_dtypes = [torch.float32]
|
||||
self.crop_input = False
|
||||
elif "decoder.22.bias" in sd: # taehv, taew and lighttae
|
||||
self.latent_channels = sd["decoder.1.weight"].shape[1]
|
||||
self.latent_dim = 3
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
|
||||
self.upscale_index_formula = (4, 16, 16)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
|
||||
self.downscale_index_formula = (4, 16, 16)
|
||||
if self.latent_channels == 48: # Wan 2.2
|
||||
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling
|
||||
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
|
||||
self.process_output = lambda image: image
|
||||
self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
|
||||
elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
|
||||
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
|
||||
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
|
||||
self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
|
||||
else:
|
||||
if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
|
||||
latent_format=comfy.latent_formats.HunyuanVideo
|
||||
else:
|
||||
latent_format=None # lighttaew2_1 doesn't need scaling
|
||||
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=latent_format)
|
||||
self.process_input = self.process_output = lambda image: image
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
self.upscale_index_formula = (4, 8, 8)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
|
||||
self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
@ -924,6 +957,7 @@ class CLIPType(Enum):
|
||||
QWEN_IMAGE = 18
|
||||
HUNYUAN_IMAGE = 19
|
||||
HUNYUAN_VIDEO_15 = 20
|
||||
OVIS = 21
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
@ -955,6 +989,7 @@ class TEModel(Enum):
|
||||
MISTRAL3_24B = 14
|
||||
MISTRAL3_24B_PRUNED_FLUX2 = 15
|
||||
QWEN3_4B = 16
|
||||
QWEN3_2B = 17
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@ -988,9 +1023,12 @@ def detect_te_model(sd):
|
||||
if weight.shape[0] == 512:
|
||||
return TEModel.QWEN25_7B
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.QWEN3_4B
|
||||
weight = sd['model.layers.0.post_attention_layernorm.weight']
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
if weight.shape[0] == 2560:
|
||||
return TEModel.QWEN3_4B
|
||||
elif weight.shape[0] == 2048:
|
||||
return TEModel.QWEN3_2B
|
||||
if weight.shape[0] == 5120:
|
||||
if "model.layers.39.post_attention_layernorm.weight" in sd:
|
||||
return TEModel.MISTRAL3_24B
|
||||
@ -1118,6 +1156,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif te_model == TEModel.QWEN3_4B:
|
||||
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
|
||||
elif te_model == TEModel.QWEN3_2B:
|
||||
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
|
||||
else:
|
||||
# clip_l
|
||||
if clip_type == CLIPType.SD3:
|
||||
|
||||
comfy/taesd/taehv.py (new file, 171 lines)
@@ -0,0 +1,171 @@
|
||||
# Tiny AutoEncoder for HunyuanVideo and WanVideo https://github.com/madebyollin/taehv
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from tqdm.auto import tqdm
|
||||
from collections import namedtuple, deque
|
||||
|
||||
import comfy.ops
|
||||
operations=comfy.ops.disable_weight_init
|
||||
|
||||
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
|
||||
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
|
||||
|
||||
def conv(n_in, n_out, **kwargs):
|
||||
return operations.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
||||
|
||||
class Clamp(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.tanh(x / 3) * 3
|
||||
|
||||
class MemBlock(nn.Module):
|
||||
def __init__(self, n_in, n_out, act_func):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out))
|
||||
self.skip = operations.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
||||
self.act = act_func
|
||||
def forward(self, x, past):
|
||||
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
|
||||
|
||||
class TPool(nn.Module):
|
||||
def __init__(self, n_f, stride):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.conv = operations.Conv2d(n_f*stride,n_f, 1, bias=False)
|
||||
def forward(self, x):
|
||||
_NT, C, H, W = x.shape
|
||||
return self.conv(x.reshape(-1, self.stride * C, H, W))
|
||||
|
||||
class TGrow(nn.Module):
|
||||
def __init__(self, n_f, stride):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.conv = operations.Conv2d(n_f, n_f*stride, 1, bias=False)
|
||||
def forward(self, x):
|
||||
_NT, C, H, W = x.shape
|
||||
x = self.conv(x)
|
||||
return x.reshape(-1, C, H, W)
|
||||
|
||||
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
|
||||
|
||||
B, T, C, H, W = x.shape
|
||||
if parallel:
|
||||
x = x.reshape(B*T, C, H, W)
|
||||
# parallel over input timesteps, iterate over blocks
|
||||
for b in tqdm(model, disable=not show_progress_bar):
|
||||
if isinstance(b, MemBlock):
|
||||
BT, C, H, W = x.shape
|
||||
T = BT // B
|
||||
_x = x.reshape(B, T, C, H, W)
|
||||
mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
|
||||
x = b(x, mem)
|
||||
else:
|
||||
x = b(x)
|
||||
BT, C, H, W = x.shape
|
||||
T = BT // B
|
||||
x = x.view(B, T, C, H, W)
|
||||
else:
|
||||
out = []
|
||||
work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))])
|
||||
progress_bar = tqdm(range(T), disable=not show_progress_bar)
|
||||
mem = [None] * len(model)
|
||||
while work_queue:
|
||||
xt, i = work_queue.popleft()
|
||||
if i == 0:
|
||||
progress_bar.update(1)
|
||||
if i == len(model):
|
||||
out.append(xt)
|
||||
del xt
|
||||
else:
|
||||
b = model[i]
|
||||
if isinstance(b, MemBlock):
|
||||
if mem[i] is None:
|
||||
xt_new = b(xt, xt * 0)
|
||||
mem[i] = xt.detach().clone()
|
||||
else:
|
||||
xt_new = b(xt, mem[i])
|
||||
mem[i] = xt.detach().clone()
|
||||
del xt
|
||||
work_queue.appendleft(TWorkItem(xt_new, i+1))
|
||||
elif isinstance(b, TPool):
|
||||
if mem[i] is None:
|
||||
mem[i] = []
|
||||
mem[i].append(xt.detach().clone())
|
||||
if len(mem[i]) == b.stride:
|
||||
B, C, H, W = xt.shape
|
||||
xt = b(torch.cat(mem[i], 1).view(B*b.stride, C, H, W))
|
||||
mem[i] = []
|
||||
work_queue.appendleft(TWorkItem(xt, i+1))
|
||||
elif isinstance(b, TGrow):
|
||||
xt = b(xt)
|
||||
NT, C, H, W = xt.shape
|
||||
for xt_next in reversed(xt.view(B, b.stride*C, H, W).chunk(b.stride, 1)):
|
||||
work_queue.appendleft(TWorkItem(xt_next, i+1))
|
||||
del xt
|
||||
else:
|
||||
xt = b(xt)
|
||||
work_queue.appendleft(TWorkItem(xt, i+1))
|
||||
progress_bar.close()
|
||||
x = torch.stack(out, 1)
|
||||
return x
|
||||
|
||||
|
||||
class TAEHV(nn.Module):
|
||||
def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True):
|
||||
super().__init__()
|
||||
self.image_channels = 3
|
||||
self.patch_size = 1
|
||||
self.latent_channels = latent_channels
|
||||
self.parallel = parallel
|
||||
self.latent_format = latent_format
|
||||
self.show_progress_bar = show_progress_bar
|
||||
self.process_in = latent_format().process_in if latent_format is not None else (lambda x: x)
|
||||
self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x)
|
||||
if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
|
||||
self.patch_size = 2
|
||||
if self.latent_channels == 32: # HunyuanVideo1.5
|
||||
act_func = nn.LeakyReLU(0.2, inplace=True)
|
||||
else: # HunyuanVideo, Wan 2.1
|
||||
act_func = nn.ReLU(inplace=True)
|
||||
|
||||
self.encoder = nn.Sequential(
|
||||
conv(self.image_channels*self.patch_size**2, 64), act_func,
|
||||
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
conv(64, self.latent_channels),
|
||||
)
|
||||
n_f = [256, 128, 64, 64]
|
||||
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
|
||||
self.decoder = nn.Sequential(
|
||||
Clamp(), conv(self.latent_channels, n_f[0]), act_func,
|
||||
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
|
||||
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
|
||||
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
|
||||
act_func, conv(n_f[3], self.image_channels*self.patch_size**2),
|
||||
)
|
||||
@property
|
||||
def show_progress_bar(self):
|
||||
return self._show_progress_bar
|
||||
|
||||
@show_progress_bar.setter
|
||||
def show_progress_bar(self, value):
|
||||
self._show_progress_bar = value
|
||||
|
||||
def encode(self, x, **kwargs):
|
||||
if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
|
||||
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
if x.shape[1] % 4 != 0:
|
||||
# pad at end to multiple of 4
|
||||
n_pad = 4 - x.shape[1] % 4
|
||||
padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
|
||||
x = torch.cat([x, padding], 1)
|
||||
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
|
||||
return self.process_out(x)
|
||||
|
||||
def decode(self, x, **kwargs):
|
||||
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
|
||||
if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
|
||||
return x[:, self.frames_to_trim:].movedim(2, 1)
|
||||
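A hedged usage sketch for the new TAEHV decoder (a freshly constructed module has uninitialized weights, so only the shapes are meaningful; real use loads a taehv/lighttae state dict through the VAE loader, which this sketch skips):

    import torch
    tae = TAEHV(latent_channels=48)            # Wan 2.2-style light tae, patch_size becomes 2
    latent = torch.randn(1, 48, 5, 32, 32)     # [B, C, T, H, W]
    with torch.no_grad():
        frames = tae.decode(latent)
    print(frames.shape)                        # roughly [1, 3, 4*5 - 3, 512, 512]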
@ -100,6 +100,28 @@ class Qwen3_4BConfig:
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Ovis25_2BConfig:
|
||||
vocab_size: int = 151936
|
||||
hidden_size: int = 2048
|
||||
intermediate_size: int = 6144
|
||||
num_hidden_layers: int = 28
|
||||
num_attention_heads: int = 16
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 40960
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
q_norm = "gemma3"
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Qwen25_7BVLI_Config:
|
||||
vocab_size: int = 152064
|
||||
@ -542,6 +564,15 @@ class Qwen3_4B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Ovis25_2B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Ovis25_2BConfig(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
comfy/text_encoders/ovis.py (new file, 69 lines)
@@ -0,0 +1,69 @@
|
||||
from transformers import Qwen2Tokenizer
|
||||
import comfy.text_encoders.llama
|
||||
from comfy import sd1_clip
|
||||
import os
|
||||
import torch
|
||||
import numbers
|
||||
|
||||
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen3_2b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=284, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class OvisTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_2b", tokenizer=Qwen3Tokenizer)
|
||||
self.llama_template = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: {}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
|
||||
if llama_template is None:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
|
||||
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||
return tokens
|
||||
|
||||
class Ovis25_2BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ovis25_2B, enable_attention_masks=attention_mask, return_attention_masks=False, zero_out_masked=True, model_options=model_options)
|
||||
|
||||
|
||||
class OvisTEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="qwen3_2b", clip_model=Ovis25_2BModel, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs, template_end=-1):
|
||||
out, pooled = super().encode_token_weights(token_weight_pairs)
|
||||
tok_pairs = token_weight_pairs["qwen3_2b"][0]
|
||||
count_im_start = 0
|
||||
if template_end == -1:
|
||||
for i, v in enumerate(tok_pairs):
|
||||
elem = v[0]
|
||||
if not torch.is_tensor(elem):
|
||||
if isinstance(elem, numbers.Integral):
|
||||
if elem == 4004 and count_im_start < 1:
|
||||
template_end = i
|
||||
count_im_start += 1
|
||||
|
||||
if out.shape[1] > (template_end + 1):
|
||||
if tok_pairs[template_end + 1][0] == 25:
|
||||
template_end += 1
|
||||
|
||||
out = out[:, template_end:]
|
||||
return out, pooled, {}
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
|
||||
class OvisTEModel_(OvisTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return OvisTEModel_
|
||||
@ -179,36 +179,36 @@
|
||||
"special": false
|
||||
},
|
||||
"151665": {
|
||||
"content": "<|img|>",
|
||||
"content": "<tool_response>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
"special": false
|
||||
},
|
||||
"151666": {
|
||||
"content": "<|endofimg|>",
|
||||
"content": "</tool_response>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
"special": false
|
||||
},
|
||||
"151667": {
|
||||
"content": "<|meta|>",
|
||||
"content": "<think>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
"special": false
|
||||
},
|
||||
"151668": {
|
||||
"content": "<|endofmeta|>",
|
||||
"content": "</think>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
"special": false
|
||||
}
|
||||
},
|
||||
"additional_special_tokens": [
|
||||
|
||||
@ -675,6 +675,72 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
|
||||
return key_map
|
||||
|
||||
def z_image_to_diffusers(mmdit_config, output_prefix=""):
|
||||
n_layers = mmdit_config.get("n_layers", 0)
|
||||
hidden_size = mmdit_config.get("dim", 0)
|
||||
n_context_refiner = mmdit_config.get("n_refiner_layers", 2)
|
||||
n_noise_refiner = mmdit_config.get("n_refiner_layers", 2)
|
||||
key_map = {}
|
||||
|
||||
def add_block_keys(prefix_from, prefix_to, has_adaln=True):
|
||||
for end in ("weight", "bias"):
|
||||
k = "{}.attention.".format(prefix_from)
|
||||
qkv = "{}.attention.qkv.{}".format(prefix_to, end)
|
||||
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
|
||||
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
|
||||
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
|
||||
|
||||
block_map = {
|
||||
"attention.norm_q.weight": "attention.q_norm.weight",
|
||||
"attention.norm_k.weight": "attention.k_norm.weight",
|
||||
"attention.to_out.0.weight": "attention.out.weight",
|
||||
"attention.to_out.0.bias": "attention.out.bias",
|
||||
"attention_norm1.weight": "attention_norm1.weight",
|
||||
"attention_norm2.weight": "attention_norm2.weight",
|
||||
"feed_forward.w1.weight": "feed_forward.w1.weight",
|
||||
"feed_forward.w2.weight": "feed_forward.w2.weight",
|
||||
"feed_forward.w3.weight": "feed_forward.w3.weight",
|
||||
"ffn_norm1.weight": "ffn_norm1.weight",
|
||||
"ffn_norm2.weight": "ffn_norm2.weight",
|
||||
}
|
||||
if has_adaln:
|
||||
block_map["adaLN_modulation.0.weight"] = "adaLN_modulation.0.weight"
|
||||
block_map["adaLN_modulation.0.bias"] = "adaLN_modulation.0.bias"
|
||||
for k, v in block_map.items():
|
||||
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, v)
|
||||
|
||||
for i in range(n_layers):
|
||||
add_block_keys("layers.{}".format(i), "{}layers.{}".format(output_prefix, i))
|
||||
|
||||
for i in range(n_context_refiner):
|
||||
add_block_keys("context_refiner.{}".format(i), "{}context_refiner.{}".format(output_prefix, i))
|
||||
|
||||
for i in range(n_noise_refiner):
|
||||
add_block_keys("noise_refiner.{}".format(i), "{}noise_refiner.{}".format(output_prefix, i))
|
||||
|
||||
MAP_BASIC = [
|
||||
("final_layer.linear.weight", "all_final_layer.2-1.linear.weight"),
|
||||
("final_layer.linear.bias", "all_final_layer.2-1.linear.bias"),
|
||||
("final_layer.adaLN_modulation.1.weight", "all_final_layer.2-1.adaLN_modulation.1.weight"),
|
||||
("final_layer.adaLN_modulation.1.bias", "all_final_layer.2-1.adaLN_modulation.1.bias"),
|
||||
("x_embedder.weight", "all_x_embedder.2-1.weight"),
|
||||
("x_embedder.bias", "all_x_embedder.2-1.bias"),
|
||||
("x_pad_token", "x_pad_token"),
|
||||
("cap_embedder.0.weight", "cap_embedder.0.weight"),
|
||||
("cap_embedder.1.weight", "cap_embedder.1.weight"),
|
||||
("cap_embedder.1.bias", "cap_embedder.1.bias"),
|
||||
("cap_pad_token", "cap_pad_token"),
|
||||
("t_embedder.mlp.0.weight", "t_embedder.mlp.0.weight"),
|
||||
("t_embedder.mlp.0.bias", "t_embedder.mlp.0.bias"),
|
||||
("t_embedder.mlp.2.weight", "t_embedder.mlp.2.weight"),
|
||||
("t_embedder.mlp.2.bias", "t_embedder.mlp.2.bias"),
|
||||
]
|
||||
|
||||
for c, diffusers in MAP_BASIC:
|
||||
key_map[diffusers] = "{}{}".format(output_prefix, c)
|
||||
|
||||
return key_map
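For orientation: in the map built above, a plain string value is a one-to-one rename, while a tuple value such as (qkv_key, (0, offset, hidden_size)) means the source tensor fills a slice of a fused destination tensor (the qkv projection). A hedged sketch of how such a map could be consumed during checkpoint conversion; apply_key_map and dest_shapes are hypothetical helpers, not part of this diff:

import torch

def apply_key_map(src_state_dict, key_map, dest_shapes):
    # Hypothetical illustration of consuming the map returned by z_image_to_diffusers.
    out = {}
    for src_key, dest in key_map.items():
        if src_key not in src_state_dict:
            continue
        w = src_state_dict[src_key]
        if isinstance(dest, str):
            out[dest] = w                                    # simple rename
        else:
            dest_key, (dim, offset, size) = dest             # slice of a fused tensor
            if dest_key not in out:
                out[dest_key] = torch.zeros(dest_shapes[dest_key], dtype=w.dtype)
            out[dest_key].narrow(dim, offset, size).copy_(w)
    return out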
|
||||
|
||||
def repeat_to_batch_size(tensor, batch_size, dim=0):
|
||||
if tensor.shape[dim] > batch_size:
|
||||
return tensor.narrow(dim, 0, batch_size)
|
||||
|
||||
@ -13,6 +13,7 @@ from comfy.cli_args import args
|
||||
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
|
||||
"supports_preview_metadata": True,
|
||||
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
|
||||
"extension": {"manager": {"supports_v4": True}},
|
||||
}
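Since max_upload_size is advertised in bytes after the MB conversion above, a client reading these flags would typically convert it back for display. A small illustrative sketch; the helper is hypothetical and not part of this diff:

def describe_flags(flags: dict) -> str:
    # Hypothetical client-side helper for the SERVER_FEATURE_FLAGS payload.
    max_mb = flags.get("max_upload_size", 0) / (1024 * 1024)  # bytes back to MB
    manager_v4 = flags.get("extension", {}).get("manager", {}).get("supports_v4", False)
    return f"max upload {max_mb:.0f} MB, manager v4: {manager_v4}"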
|
||||
|
||||
|
||||
|
||||
@ -336,7 +336,10 @@ class VideoFromComponents(VideoInput):
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||
raise ValueError("Only H264 codec is supported for now")
|
||||
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
|
||||
extra_kwargs = {}
|
||||
if format != VideoContainer.AUTO:
|
||||
extra_kwargs["format"] = format.value
|
||||
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
|
||||
# Add metadata before writing any streams
|
||||
if metadata is not None:
|
||||
for key, value in metadata.items():
|
||||
|
||||
@ -58,8 +58,14 @@ class GeminiInlineData(BaseModel):
|
||||
mimeType: GeminiMimeType | None = Field(None)
|
||||
|
||||
|
||||
class GeminiFileData(BaseModel):
|
||||
fileUri: str | None = Field(None)
|
||||
mimeType: GeminiMimeType | None = Field(None)
|
||||
|
||||
|
||||
class GeminiPart(BaseModel):
|
||||
inlineData: GeminiInlineData | None = Field(None)
|
||||
fileData: GeminiFileData | None = Field(None)
|
||||
text: str | None = Field(None)
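With the new GeminiFileData model, a part can reference an uploaded file by URI instead of carrying inline base64 data. A brief sketch (the URI is a placeholder; GeminiMimeType.image_png is used the same way later in this diff):

part = GeminiPart(
    fileData=GeminiFileData(
        fileUri="https://example.com/uploads/reference.png",  # placeholder URI
        mimeType=GeminiMimeType.image_png,
    )
)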
|
||||
|
||||
|
||||
|
||||
66
comfy_api_nodes/apis/kling_api.py
Normal file
@ -0,0 +1,66 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class OmniProText2VideoRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-video-o1")
|
||||
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
|
||||
duration: str = Field(..., description="'5' or '10'")
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
|
||||
|
||||
class OmniParamImage(BaseModel):
|
||||
image_url: str = Field(...)
|
||||
type: str | None = Field(None, description="Can be 'first_frame' or 'end_frame'")
|
||||
|
||||
|
||||
class OmniParamVideo(BaseModel):
|
||||
video_url: str = Field(...)
|
||||
refer_type: str | None = Field(..., description="Can be 'base' or 'feature'")
|
||||
keep_original_sound: str = Field(..., description="'yes' or 'no'")
|
||||
|
||||
|
||||
class OmniProFirstLastFrameRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-video-o1")
|
||||
image_list: list[OmniParamImage] = Field(..., min_length=1, max_length=7)
|
||||
duration: str = Field(..., description="'5' or '10'")
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
|
||||
|
||||
class OmniProReferences2VideoRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-video-o1")
|
||||
aspect_ratio: str | None = Field(..., description="'16:9', '9:16' or '1:1'")
|
||||
image_list: list[OmniParamImage] | None = Field(
|
||||
None, max_length=7, description="Max length 4 when video is present."
|
||||
)
|
||||
video_list: list[OmniParamVideo] | None = Field(None, max_length=1)
|
||||
duration: str | None = Field(..., description="From 3 to 10.")
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
|
||||
|
||||
class TaskStatusVideoResult(BaseModel):
|
||||
duration: str | None = Field(None, description="Total video duration")
|
||||
id: str | None = Field(None, description="Generated video ID")
|
||||
url: str | None = Field(None, description="URL for generated video")
|
||||
|
||||
|
||||
class TaskStatusVideoResults(BaseModel):
|
||||
videos: list[TaskStatusVideoResult] | None = Field(None)
|
||||
|
||||
|
||||
class TaskStatusVideoResponseData(BaseModel):
|
||||
created_at: int | None = Field(None, description="Task creation time")
|
||||
updated_at: int | None = Field(None, description="Task update time")
|
||||
task_status: str | None = None
|
||||
task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
|
||||
task_id: str | None = Field(None, description="Task ID")
|
||||
task_result: TaskStatusVideoResults | None = Field(None)
|
||||
|
||||
|
||||
class TaskStatusVideoResponse(BaseModel):
|
||||
code: int | None = Field(None, description="Error code")
|
||||
message: str | None = Field(None, description="Error message")
|
||||
request_id: str | None = Field(None, description="Request ID")
|
||||
data: TaskStatusVideoResponseData | None = Field(None)
|
||||
@ -1,34 +1,21 @@
|
||||
from typing import Optional, Union
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Image2(BaseModel):
|
||||
bytesBase64Encoded: str
|
||||
gcsUri: Optional[str] = None
|
||||
mimeType: Optional[str] = None
|
||||
class VeoRequestInstanceImage(BaseModel):
|
||||
bytesBase64Encoded: str | None = Field(None)
|
||||
gcsUri: str | None = Field(None)
|
||||
mimeType: str | None = Field(None)
|
||||
|
||||
|
||||
class Image3(BaseModel):
|
||||
bytesBase64Encoded: Optional[str] = None
|
||||
gcsUri: str
|
||||
mimeType: Optional[str] = None
|
||||
|
||||
|
||||
class Instance1(BaseModel):
|
||||
image: Optional[Union[Image2, Image3]] = Field(
|
||||
None, description='Optional image to guide video generation'
|
||||
)
|
||||
class VeoRequestInstance(BaseModel):
|
||||
image: VeoRequestInstanceImage | None = Field(None)
|
||||
lastFrame: VeoRequestInstanceImage | None = Field(None)
|
||||
prompt: str = Field(..., description='Text description of the video')
|
||||
|
||||
|
||||
class PersonGeneration1(str, Enum):
|
||||
ALLOW = 'ALLOW'
|
||||
BLOCK = 'BLOCK'
|
||||
|
||||
|
||||
class Parameters1(BaseModel):
|
||||
class VeoRequestParameters(BaseModel):
|
||||
aspectRatio: Optional[str] = Field(None, examples=['16:9'])
|
||||
durationSeconds: Optional[int] = None
|
||||
enhancePrompt: Optional[bool] = None
|
||||
@ -37,17 +24,18 @@ class Parameters1(BaseModel):
|
||||
description='Generate audio for the video. Only supported by veo 3 models.',
|
||||
)
|
||||
negativePrompt: Optional[str] = None
|
||||
personGeneration: Optional[PersonGeneration1] = None
|
||||
personGeneration: str | None = Field(None, description="ALLOW or BLOCK")
|
||||
sampleCount: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
storageUri: Optional[str] = Field(
|
||||
None, description='Optional Cloud Storage URI to upload the video'
|
||||
)
|
||||
resolution: str | None = Field(None)
|
||||
|
||||
|
||||
class VeoGenVidRequest(BaseModel):
|
||||
instances: Optional[list[Instance1]] = None
|
||||
parameters: Optional[Parameters1] = None
|
||||
instances: list[VeoRequestInstance] | None = Field(None)
|
||||
parameters: VeoRequestParameters | None = Field(None)
|
||||
|
||||
|
||||
class VeoGenVidResponse(BaseModel):
|
||||
|
||||
@ -4,10 +4,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from typing import Literal
|
||||
@ -20,6 +17,7 @@ from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api.util import VideoCodec, VideoContainer
|
||||
from comfy_api_nodes.apis.gemini_api import (
|
||||
GeminiContent,
|
||||
GeminiFileData,
|
||||
GeminiGenerateContentRequest,
|
||||
GeminiGenerateContentResponse,
|
||||
GeminiImageConfig,
|
||||
@ -38,10 +36,10 @@ from comfy_api_nodes.util import (
|
||||
get_number_of_images,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
upload_images_to_comfyapi,
|
||||
validate_string,
|
||||
video_to_base64_string,
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
|
||||
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
|
||||
@ -68,24 +66,43 @@ class GeminiImageModel(str, Enum):
|
||||
gemini_2_5_flash_image = "gemini-2.5-flash-image"
|
||||
|
||||
|
||||
def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert image tensor input to Gemini API compatible parts.
|
||||
|
||||
Args:
|
||||
image_input: Batch of image tensors from ComfyUI.
|
||||
|
||||
Returns:
|
||||
List of GeminiPart objects containing the encoded images.
|
||||
"""
|
||||
async def create_image_parts(
|
||||
cls: type[IO.ComfyNode],
|
||||
images: torch.Tensor,
|
||||
image_limit: int = 0,
|
||||
) -> list[GeminiPart]:
|
||||
image_parts: list[GeminiPart] = []
|
||||
for image_index in range(image_input.shape[0]):
|
||||
image_as_b64 = tensor_to_base64_string(image_input[image_index].unsqueeze(0))
|
||||
if image_limit < 0:
|
||||
raise ValueError("image_limit must be greater than or equal to 0 when creating Gemini image parts.")
|
||||
total_images = get_number_of_images(images)
|
||||
if total_images <= 0:
|
||||
raise ValueError("No images provided to create_image_parts; at least one image is required.")
|
||||
|
||||
# If image_limit == 0 --> use all images; otherwise clamp to image_limit.
|
||||
effective_max = total_images if image_limit == 0 else min(total_images, image_limit)
|
||||
|
||||
# Number of images we'll send as URLs (fileData)
|
||||
num_url_images = min(effective_max, 10) # Vertex API max number of image links
|
||||
reference_images_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
images,
|
||||
max_images=num_url_images,
|
||||
)
|
||||
for reference_image_url in reference_images_urls:
|
||||
image_parts.append(
|
||||
GeminiPart(
|
||||
fileData=GeminiFileData(
|
||||
mimeType=GeminiMimeType.image_png,
|
||||
fileUri=reference_image_url,
|
||||
)
|
||||
)
|
||||
)
|
||||
for idx in range(num_url_images, effective_max):
|
||||
image_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.image_png,
|
||||
data=image_as_b64,
|
||||
data=tensor_to_base64_string(images[idx]),
|
||||
)
|
||||
)
|
||||
)
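In short, the rewritten helper uploads the first images (up to the ten-link Vertex limit noted above) as fileData URI parts and falls back to inline base64 for any remaining images under the effective limit. Callers now await it and pass the node class, as the call sites below do:

# Call shape used by the Gemini nodes in this diff; image_limit=0 means "use every image".
parts.extend(await create_image_parts(cls, images))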
|
||||
@ -338,8 +355,7 @@ class GeminiNode(IO.ComfyNode):
|
||||
|
||||
# Add other modal parts
|
||||
if images is not None:
|
||||
image_parts = create_image_parts(images)
|
||||
parts.extend(image_parts)
|
||||
parts.extend(await create_image_parts(cls, images))
|
||||
if audio is not None:
|
||||
parts.extend(cls.create_audio_parts(audio))
|
||||
if video is not None:
|
||||
@ -364,29 +380,6 @@ class GeminiNode(IO.ComfyNode):
|
||||
)
|
||||
|
||||
output_text = get_text_from_response(response)
|
||||
if output_text:
|
||||
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
||||
render_spec = {
|
||||
"node_id": cls.hidden.unique_id,
|
||||
"component": "ChatHistoryWidget",
|
||||
"props": {
|
||||
"history": json.dumps(
|
||||
[
|
||||
{
|
||||
"prompt": prompt,
|
||||
"response": output_text,
|
||||
"response_id": str(uuid.uuid4()),
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
]
|
||||
),
|
||||
},
|
||||
}
|
||||
PromptServer.instance.send_sync(
|
||||
"display_component",
|
||||
render_spec,
|
||||
)
|
||||
|
||||
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
|
||||
|
||||
|
||||
@ -562,8 +555,7 @@ class GeminiImage(IO.ComfyNode):
|
||||
image_config = GeminiImageConfig(aspectRatio=aspect_ratio)
|
||||
|
||||
if images is not None:
|
||||
image_parts = create_image_parts(images)
|
||||
parts.extend(image_parts)
|
||||
parts.extend(await create_image_parts(cls, images))
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
@ -582,30 +574,7 @@ class GeminiImage(IO.ComfyNode):
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
|
||||
output_text = get_text_from_response(response)
|
||||
if output_text:
|
||||
render_spec = {
|
||||
"node_id": cls.hidden.unique_id,
|
||||
"component": "ChatHistoryWidget",
|
||||
"props": {
|
||||
"history": json.dumps(
|
||||
[
|
||||
{
|
||||
"prompt": prompt,
|
||||
"response": output_text,
|
||||
"response_id": str(uuid.uuid4()),
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
]
|
||||
),
|
||||
},
|
||||
}
|
||||
PromptServer.instance.send_sync(
|
||||
"display_component",
|
||||
render_spec,
|
||||
)
|
||||
return IO.NodeOutput(get_image_from_response(response), output_text)
|
||||
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
|
||||
|
||||
|
||||
class GeminiImage2(IO.ComfyNode):
|
||||
@ -702,7 +671,7 @@ class GeminiImage2(IO.ComfyNode):
|
||||
if images is not None:
|
||||
if get_number_of_images(images) > 14:
|
||||
raise ValueError("The current maximum number of supported images is 14.")
|
||||
parts.extend(create_image_parts(images))
|
||||
parts.extend(await create_image_parts(cls, images))
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
@ -725,30 +694,7 @@ class GeminiImage2(IO.ComfyNode):
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
|
||||
output_text = get_text_from_response(response)
|
||||
if output_text:
|
||||
render_spec = {
|
||||
"node_id": cls.hidden.unique_id,
|
||||
"component": "ChatHistoryWidget",
|
||||
"props": {
|
||||
"history": json.dumps(
|
||||
[
|
||||
{
|
||||
"prompt": prompt,
|
||||
"response": output_text,
|
||||
"response_id": str(uuid.uuid4()),
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
]
|
||||
),
|
||||
},
|
||||
}
|
||||
PromptServer.instance.send_sync(
|
||||
"display_component",
|
||||
render_spec,
|
||||
)
|
||||
return IO.NodeOutput(get_image_from_response(response), output_text)
|
||||
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
|
||||
|
||||
|
||||
class GeminiExtension(ComfyExtension):
|
||||
|
||||
@ -4,15 +4,13 @@ For source of truth on the allowed permutations of request fields, please refere
|
||||
- [Compatibility Table](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Optional, TypeVar
|
||||
import math
|
||||
import logging
|
||||
|
||||
from typing_extensions import override
|
||||
import math
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
|
||||
from comfy_api_nodes.apis import (
|
||||
KlingCameraControl,
|
||||
KlingCameraConfig,
|
||||
@ -50,25 +48,31 @@ from comfy_api_nodes.apis import (
|
||||
KlingCharacterEffectModelName,
|
||||
KlingSingleImageEffectModelName,
|
||||
)
|
||||
from comfy_api_nodes.apis.kling_api import (
|
||||
OmniParamImage,
|
||||
OmniParamVideo,
|
||||
OmniProFirstLastFrameRequest,
|
||||
OmniProReferences2VideoRequest,
|
||||
OmniProText2VideoRequest,
|
||||
TaskStatusVideoResponse,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
validate_image_dimensions,
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
get_number_of_images,
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
upload_audio_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
validate_video_dimensions,
|
||||
validate_video_duration,
|
||||
tensor_to_base64_string,
|
||||
validate_string,
|
||||
upload_audio_to_comfyapi,
|
||||
download_url_to_image_tensor,
|
||||
upload_video_to_comfyapi,
|
||||
download_url_to_video_output,
|
||||
sync_op,
|
||||
ApiEndpoint,
|
||||
poll_op,
|
||||
)
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.input.basic_types import AudioInput
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
|
||||
KLING_API_VERSION = "v1"
|
||||
PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video"
|
||||
@ -94,8 +98,6 @@ AVERAGE_DURATION_IMAGE_GEN = 32
|
||||
AVERAGE_DURATION_VIDEO_EFFECTS = 320
|
||||
AVERAGE_DURATION_VIDEO_EXTEND = 320
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
MODE_TEXT2VIDEO = {
|
||||
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
|
||||
@ -130,6 +132,8 @@ MODE_START_END_FRAME = {
|
||||
"pro mode / 10s duration / kling-v1-6": ("pro", "10", "kling-v1-6"),
|
||||
"pro mode / 5s duration / kling-v2-1": ("pro", "5", "kling-v2-1"),
|
||||
"pro mode / 10s duration / kling-v2-1": ("pro", "10", "kling-v2-1"),
|
||||
"pro mode / 5s duration / kling-v2-5-turbo": ("pro", "5", "kling-v2-5-turbo"),
|
||||
"pro mode / 10s duration / kling-v2-5-turbo": ("pro", "10", "kling-v2-5-turbo"),
|
||||
}
|
||||
"""
|
||||
Returns a mapping of mode strings to their corresponding (mode, duration, model_name) tuples.
|
||||
@ -206,6 +210,20 @@ VOICES_CONFIG = {
|
||||
}
|
||||
|
||||
|
||||
async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVideoResponse) -> IO.NodeOutput:
|
||||
if response.code:
|
||||
raise RuntimeError(
|
||||
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
|
||||
response_model=TaskStatusVideoResponse,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
|
||||
|
||||
def is_valid_camera_control_configs(configs: list[float]) -> bool:
|
||||
"""Verifies that at least one camera control configuration is non-zero."""
|
||||
return any(not math.isclose(value, 0.0) for value in configs)
|
||||
@ -296,7 +314,7 @@ def get_video_from_response(response) -> KlingVideoResult:
|
||||
return video
|
||||
|
||||
|
||||
def get_video_url_from_response(response) -> Optional[str]:
|
||||
def get_video_url_from_response(response) -> str | None:
|
||||
"""Returns the first video url from the Kling video generation task result.
|
||||
Will not raise an error if the response is not valid.
|
||||
"""
|
||||
@ -315,7 +333,7 @@ def get_images_from_response(response) -> list[KlingImageResult]:
|
||||
return images
|
||||
|
||||
|
||||
def get_images_urls_from_response(response) -> Optional[str]:
|
||||
def get_images_urls_from_response(response) -> str | None:
|
||||
"""Returns the list of image urls from the Kling image generation task result.
|
||||
Will not raise an error if the response is not valid. If there is only one image, returns the url as a string. If there are multiple images, returns a list of urls.
|
||||
"""
|
||||
@ -349,7 +367,7 @@ async def execute_text2video(
|
||||
model_mode: str,
|
||||
duration: str,
|
||||
aspect_ratio: str,
|
||||
camera_control: Optional[KlingCameraControl] = None,
|
||||
camera_control: KlingCameraControl | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
|
||||
task_creation_response = await sync_op(
|
||||
@ -394,8 +412,8 @@ async def execute_image2video(
|
||||
model_mode: str,
|
||||
aspect_ratio: str,
|
||||
duration: str,
|
||||
camera_control: Optional[KlingCameraControl] = None,
|
||||
end_frame: Optional[torch.Tensor] = None,
|
||||
camera_control: KlingCameraControl | None = None,
|
||||
end_frame: torch.Tensor | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V)
|
||||
validate_input_image(start_frame)
|
||||
@ -451,9 +469,9 @@ async def execute_video_effect(
|
||||
model_name: str,
|
||||
duration: KlingVideoGenDuration,
|
||||
image_1: torch.Tensor,
|
||||
image_2: Optional[torch.Tensor] = None,
|
||||
model_mode: Optional[KlingVideoGenMode] = None,
|
||||
) -> tuple[VideoFromFile, str, str]:
|
||||
image_2: torch.Tensor | None = None,
|
||||
model_mode: KlingVideoGenMode | None = None,
|
||||
) -> tuple[InputImpl.VideoFromFile, str, str]:
|
||||
if dual_character:
|
||||
request_input_field = KlingDualCharacterEffectInput(
|
||||
model_name=model_name,
|
||||
@ -499,13 +517,13 @@ async def execute_video_effect(
|
||||
|
||||
async def execute_lipsync(
|
||||
cls: type[IO.ComfyNode],
|
||||
video: VideoInput,
|
||||
audio: Optional[AudioInput] = None,
|
||||
voice_language: Optional[str] = None,
|
||||
model_mode: Optional[str] = None,
|
||||
text: Optional[str] = None,
|
||||
voice_speed: Optional[float] = None,
|
||||
voice_id: Optional[str] = None,
|
||||
video: Input.Video,
|
||||
audio: Input.Audio | None = None,
|
||||
voice_language: str | None = None,
|
||||
model_mode: str | None = None,
|
||||
text: str | None = None,
|
||||
voice_speed: float | None = None,
|
||||
voice_id: str | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
if text:
|
||||
validate_string(text, field_name="Text", max_length=MAX_PROMPT_LENGTH_LIP_SYNC)
|
||||
@ -740,6 +758,386 @@ class KlingTextToVideoNode(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
class OmniProTextToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProTextToVideoNode",
|
||||
display_name="Kling Omni Text to Video (Pro)",
|
||||
category="api node/video/Kling",
|
||||
description="Use text prompts to generate videos with the latest Kling model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-video-o1"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A text prompt describing the video content. "
|
||||
"This can include both positive and negative descriptions.",
|
||||
),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
|
||||
IO.Combo.Input("duration", options=[5, 10]),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
response_model=TaskStatusVideoResponse,
|
||||
data=OmniProText2VideoRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
duration=str(duration),
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
|
||||
|
||||
class OmniProFirstLastFrameNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProFirstLastFrameNode",
|
||||
display_name="Kling Omni First-Last-Frame to Video (Pro)",
|
||||
category="api node/video/Kling",
|
||||
description="Use a start frame, an optional end frame, or reference images with the latest Kling model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-video-o1"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A text prompt describing the video content. "
|
||||
"This can include both positive and negative descriptions.",
|
||||
),
|
||||
IO.Combo.Input("duration", options=["5", "10"]),
|
||||
IO.Image.Input("first_frame"),
|
||||
IO.Image.Input(
|
||||
"end_frame",
|
||||
optional=True,
|
||||
tooltip="An optional end frame for the video. "
|
||||
"This cannot be used simultaneously with 'reference_images'.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"reference_images",
|
||||
optional=True,
|
||||
tooltip="Up to 6 additional reference images.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
first_frame: Input.Image,
|
||||
end_frame: Input.Image | None = None,
|
||||
reference_images: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
if end_frame is not None and reference_images is not None:
|
||||
raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
|
||||
validate_image_dimensions(first_frame, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
|
||||
image_list: list[OmniParamImage] = [
|
||||
OmniParamImage(
|
||||
image_url=(await upload_images_to_comfyapi(cls, first_frame, wait_label="Uploading first frame"))[0],
|
||||
type="first_frame",
|
||||
)
|
||||
]
|
||||
if end_frame is not None:
|
||||
validate_image_dimensions(end_frame, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1))
|
||||
image_list.append(
|
||||
OmniParamImage(
|
||||
image_url=(await upload_images_to_comfyapi(cls, end_frame, wait_label="Uploading end frame"))[0],
|
||||
type="end_frame",
|
||||
)
|
||||
)
|
||||
if reference_images is not None:
|
||||
if get_number_of_images(reference_images) > 6:
|
||||
raise ValueError("The maximum number of reference images allowed is 6.")
|
||||
for i in reference_images:
|
||||
validate_image_dimensions(i, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
|
||||
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"):
|
||||
image_list.append(OmniParamImage(image_url=i))
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
response_model=TaskStatusVideoResponse,
|
||||
data=OmniProFirstLastFrameRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
duration=str(duration),
|
||||
image_list=image_list,
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
|
||||
|
||||
class OmniProImageToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProImageToVideoNode",
|
||||
display_name="Kling Omni Image to Video (Pro)",
|
||||
category="api node/video/Kling",
|
||||
description="Use up to 7 reference images to generate a video with the latest Kling model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-video-o1"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A text prompt describing the video content. "
|
||||
"This can include both positive and negative descriptions.",
|
||||
),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
|
||||
IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider),
|
||||
IO.Image.Input(
|
||||
"reference_images",
|
||||
tooltip="Up to 7 reference images.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
reference_images: Input.Image,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
if get_number_of_images(reference_images) > 7:
|
||||
raise ValueError("The maximum number of reference images is 7.")
|
||||
for i in reference_images:
|
||||
validate_image_dimensions(i, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
|
||||
image_list: list[OmniParamImage] = []
|
||||
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
|
||||
image_list.append(OmniParamImage(image_url=i))
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
response_model=TaskStatusVideoResponse,
|
||||
data=OmniProReferences2VideoRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
duration=str(duration),
|
||||
image_list=image_list,
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
|
||||
|
||||
class OmniProVideoToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProVideoToVideoNode",
|
||||
display_name="Kling Omni Video to Video (Pro)",
|
||||
category="api node/video/Kling",
|
||||
description="Use a video and up to 4 reference images to generate a video with the latest Kling model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-video-o1"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A text prompt describing the video content. "
|
||||
"This can include both positive and negative descriptions.",
|
||||
),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
|
||||
IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider),
|
||||
IO.Video.Input("reference_video", tooltip="Video to use as a reference."),
|
||||
IO.Boolean.Input("keep_original_sound", default=True),
|
||||
IO.Image.Input(
|
||||
"reference_images",
|
||||
tooltip="Up to 4 additional reference images.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
reference_video: Input.Video,
|
||||
keep_original_sound: bool,
|
||||
reference_images: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05)
|
||||
validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160)
|
||||
image_list: list[OmniParamImage] = []
|
||||
if reference_images is not None:
|
||||
if get_number_of_images(reference_images) > 4:
|
||||
raise ValueError("The maximum number of reference images allowed with a video input is 4.")
|
||||
for i in reference_images:
|
||||
validate_image_dimensions(i, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
|
||||
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
|
||||
image_list.append(OmniParamImage(image_url=i))
|
||||
video_list = [
|
||||
OmniParamVideo(
|
||||
video_url=await upload_video_to_comfyapi(cls, reference_video, wait_label="Uploading reference video"),
|
||||
refer_type="feature",
|
||||
keep_original_sound="yes" if keep_original_sound else "no",
|
||||
)
|
||||
]
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
response_model=TaskStatusVideoResponse,
|
||||
data=OmniProReferences2VideoRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
duration=str(duration),
|
||||
image_list=image_list if image_list else None,
|
||||
video_list=video_list,
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
|
||||
|
||||
class OmniProEditVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProEditVideoNode",
|
||||
display_name="Kling Omni Edit Video (Pro)",
|
||||
category="api node/video/Kling",
|
||||
description="Edit an existing video with the latest model from Kling.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-video-o1"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A text prompt describing the video content. "
|
||||
"This can include both positive and negative descriptions.",
|
||||
),
|
||||
IO.Video.Input("video", tooltip="Video for editing. The output video length will be the same."),
|
||||
IO.Boolean.Input("keep_original_sound", default=True),
|
||||
IO.Image.Input(
|
||||
"reference_images",
|
||||
tooltip="Up to 4 additional reference images.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
video: Input.Video,
|
||||
keep_original_sound: bool,
|
||||
reference_images: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
validate_video_duration(video, min_duration=3.0, max_duration=10.05)
|
||||
validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160)
|
||||
image_list: list[OmniParamImage] = []
|
||||
if reference_images is not None:
|
||||
if get_number_of_images(reference_images) > 4:
|
||||
raise ValueError("The maximum number of reference images allowed with a video input is 4.")
|
||||
for i in reference_images:
|
||||
validate_image_dimensions(i, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
|
||||
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
|
||||
image_list.append(OmniParamImage(image_url=i))
|
||||
video_list = [
|
||||
OmniParamVideo(
|
||||
video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading base video"),
|
||||
refer_type="base",
|
||||
keep_original_sound="yes" if keep_original_sound else "no",
|
||||
)
|
||||
]
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
response_model=TaskStatusVideoResponse,
|
||||
data=OmniProReferences2VideoRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
aspect_ratio=None,
|
||||
duration=None,
|
||||
image_list=image_list if image_list else None,
|
||||
video_list=video_list,
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
|
||||
|
||||
class KlingCameraControlT2VNode(IO.ComfyNode):
|
||||
"""
|
||||
Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera.
|
||||
@ -787,7 +1185,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode):
|
||||
negative_prompt: str,
|
||||
cfg_scale: float,
|
||||
aspect_ratio: str,
|
||||
camera_control: Optional[KlingCameraControl] = None,
|
||||
camera_control: KlingCameraControl | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
return await execute_text2video(
|
||||
cls,
|
||||
@ -854,8 +1252,8 @@ class KlingImage2VideoNode(IO.ComfyNode):
|
||||
mode: str,
|
||||
aspect_ratio: str,
|
||||
duration: str,
|
||||
camera_control: Optional[KlingCameraControl] = None,
|
||||
end_frame: Optional[torch.Tensor] = None,
|
||||
camera_control: KlingCameraControl | None = None,
|
||||
end_frame: torch.Tensor | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
return await execute_image2video(
|
||||
cls,
|
||||
@ -965,15 +1363,11 @@ class KlingStartEndFrameNode(IO.ComfyNode):
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
|
||||
IO.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"),
|
||||
IO.Float.Input("cfg_scale", default=0.5, min=0.0, max=1.0),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=[i.value for i in KlingVideoGenAspectRatio],
|
||||
default="16:9",
|
||||
),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
|
||||
IO.Combo.Input(
|
||||
"mode",
|
||||
options=modes,
|
||||
default=modes[2],
|
||||
default=modes[8],
|
||||
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
|
||||
),
|
||||
],
|
||||
@ -1170,7 +1564,10 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
|
||||
category="api node/video/Kling",
|
||||
description="Achieve different special effects when generating a video based on the effect_scene.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1"),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"effect_scene",
|
||||
options=[i.value for i in KlingSingleImageEffectsScene],
|
||||
@ -1254,8 +1651,8 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: VideoInput,
|
||||
audio: AudioInput,
|
||||
video: Input.Video,
|
||||
audio: Input.Audio,
|
||||
voice_language: str,
|
||||
) -> IO.NodeOutput:
|
||||
return await execute_lipsync(
|
||||
@ -1314,7 +1711,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: VideoInput,
|
||||
video: Input.Video,
|
||||
text: str,
|
||||
voice: str,
|
||||
voice_speed: float,
|
||||
@ -1471,7 +1868,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
|
||||
human_fidelity: float,
|
||||
n: int,
|
||||
aspect_ratio: KlingImageGenAspectRatio,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
image: torch.Tensor | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, field_name="prompt", min_length=1, max_length=MAX_PROMPT_LENGTH_IMAGE_GEN)
|
||||
validate_string(negative_prompt, field_name="negative_prompt", max_length=MAX_PROMPT_LENGTH_IMAGE_GEN)
|
||||
@ -1533,6 +1930,11 @@ class KlingExtension(ComfyExtension):
|
||||
KlingImageGenerationNode,
|
||||
KlingSingleImageVideoEffectNode,
|
||||
KlingDualCharacterVideoEffectNode,
|
||||
OmniProTextToVideoNode,
|
||||
OmniProFirstLastFrameNode,
|
||||
OmniProImageToVideoNode,
|
||||
OmniProVideoToVideoNode,
|
||||
OmniProEditVideoNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -1,15 +1,10 @@
|
||||
from io import BytesIO
|
||||
from typing import Optional, Union
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from inspect import cleandoc
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from server import PromptServer
|
||||
import folder_paths
|
||||
import base64
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
@ -587,11 +582,11 @@ class OpenAIChatNode(IO.ComfyNode):
|
||||
def create_input_message_contents(
|
||||
cls,
|
||||
prompt: str,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
files: Optional[list[InputFileContent]] = None,
|
||||
image: torch.Tensor | None = None,
|
||||
files: list[InputFileContent] | None = None,
|
||||
) -> InputMessageContentList:
|
||||
"""Create a list of input message contents from prompt and optional image."""
|
||||
content_list: list[Union[InputContent, InputTextContent, InputImageContent, InputFileContent]] = [
|
||||
content_list: list[InputContent | InputTextContent | InputImageContent | InputFileContent] = [
|
||||
InputTextContent(text=prompt, type="input_text"),
|
||||
]
|
||||
if image is not None:
|
||||
@ -617,9 +612,9 @@ class OpenAIChatNode(IO.ComfyNode):
|
||||
prompt: str,
|
||||
persist_context: bool = False,
|
||||
model: SupportedOpenAIModel = SupportedOpenAIModel.gpt_5.value,
|
||||
images: Optional[torch.Tensor] = None,
|
||||
files: Optional[list[InputFileContent]] = None,
|
||||
advanced_options: Optional[CreateModelResponseProperties] = None,
|
||||
images: torch.Tensor | None = None,
|
||||
files: list[InputFileContent] | None = None,
|
||||
advanced_options: CreateModelResponseProperties | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
|
||||
@ -660,30 +655,7 @@ class OpenAIChatNode(IO.ComfyNode):
|
||||
status_extractor=lambda response: response.status,
|
||||
completed_statuses=["incomplete", "completed"]
|
||||
)
|
||||
output_text = cls.get_text_from_message_content(cls.get_message_content_from_response(result_response))
|
||||
|
||||
# Update history
|
||||
render_spec = {
|
||||
"node_id": cls.hidden.unique_id,
|
||||
"component": "ChatHistoryWidget",
|
||||
"props": {
|
||||
"history": json.dumps(
|
||||
[
|
||||
{
|
||||
"prompt": prompt,
|
||||
"response": output_text,
|
||||
"response_id": str(uuid.uuid4()),
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
]
|
||||
),
|
||||
},
|
||||
}
|
||||
PromptServer.instance.send_sync(
|
||||
"display_component",
|
||||
render_spec,
|
||||
)
|
||||
return IO.NodeOutput(output_text)
|
||||
return IO.NodeOutput(cls.get_text_from_message_content(cls.get_message_content_from_response(result_response)))
|
||||
|
||||
|
||||
class OpenAIInputFiles(IO.ComfyNode):
|
||||
@ -790,8 +762,8 @@ class OpenAIChatConfig(IO.ComfyNode):
|
||||
def execute(
|
||||
cls,
|
||||
truncation: bool,
|
||||
instructions: Optional[str] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
instructions: str | None = None,
|
||||
max_output_tokens: int | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
"""
|
||||
Configure advanced options for the OpenAI Chat Node.
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
@ -10,6 +11,9 @@ from comfy_api_nodes.apis.veo_api import (
|
||||
VeoGenVidPollResponse,
|
||||
VeoGenVidRequest,
|
||||
VeoGenVidResponse,
|
||||
VeoRequestInstance,
|
||||
VeoRequestInstanceImage,
|
||||
VeoRequestParameters,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
@ -346,12 +350,163 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
)
|
||||
|
||||
|
||||
class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Veo3FirstLastFrameNode",
|
||||
display_name="Google Veo 3 First-Last-Frame to Video",
|
||||
category="api node/video/Veo",
|
||||
description="Generate video using prompt and first and last frames.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of the video",
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid in the video",
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["720p", "1080p"]),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "9:16"],
|
||||
default="16:9",
|
||||
tooltip="Aspect ratio of the output video",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=8,
|
||||
min=4,
|
||||
max=8,
|
||||
step=2,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Duration of the output video in seconds",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation",
|
||||
),
|
||||
IO.Image.Input("first_frame", tooltip="Start frame"),
|
||||
IO.Image.Input("last_frame", tooltip="End frame"),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["veo-3.1-generate", "veo-3.1-fast-generate"],
|
||||
default="veo-3.1-fast-generate",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=True,
|
||||
tooltip="Generate audio for the video.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
first_frame: torch.Tensor,
|
||||
last_frame: torch.Tensor,
|
||||
model: str,
|
||||
generate_audio: bool,
|
||||
):
|
||||
model = MODELS_MAP[model]
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
|
||||
response_model=VeoGenVidResponse,
|
||||
data=VeoGenVidRequest(
|
||||
instances=[
|
||||
VeoRequestInstance(
|
||||
prompt=prompt,
|
||||
image=VeoRequestInstanceImage(
|
||||
bytesBase64Encoded=tensor_to_base64_string(first_frame), mimeType="image/png"
|
||||
),
|
||||
lastFrame=VeoRequestInstanceImage(
|
||||
bytesBase64Encoded=tensor_to_base64_string(last_frame), mimeType="image/png"
|
||||
),
|
||||
),
|
||||
],
|
||||
parameters=VeoRequestParameters(
|
||||
aspectRatio=aspect_ratio,
|
||||
personGeneration="ALLOW",
|
||||
durationSeconds=duration,
|
||||
enhancePrompt=True, # cannot be False for Veo3
|
||||
seed=seed,
|
||||
generateAudio=generate_audio,
|
||||
negativePrompt=negative_prompt,
|
||||
resolution=resolution,
|
||||
),
|
||||
),
|
||||
)
|
||||
poll_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
|
||||
response_model=VeoGenVidPollResponse,
|
||||
status_extractor=lambda r: "completed" if r.done else "pending",
|
||||
data=VeoGenVidPollRequest(
|
||||
operationName=initial_response.name,
|
||||
),
|
||||
poll_interval=5.0,
|
||||
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
|
||||
)
|
||||
|
||||
if poll_response.error:
|
||||
raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")
|
||||
|
||||
response = poll_response.response
|
||||
filtered_count = response.raiMediaFilteredCount
|
||||
if filtered_count:
|
||||
reasons = response.raiMediaFilteredReasons or []
|
||||
reason_part = f": {reasons[0]}" if reasons else ""
|
||||
raise Exception(
|
||||
f"Content blocked by Google's Responsible AI filters{reason_part} "
|
||||
f"({filtered_count} video{'s' if filtered_count != 1 else ''} filtered)."
|
||||
)
|
||||
|
||||
if response.videos:
|
||||
video = response.videos[0]
|
||||
if video.bytesBase64Encoded:
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||
if video.gcsUri:
|
||||
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
|
||||
raise Exception("Video returned but no data or URL was provided")
|
||||
raise Exception("Video generation completed but no video was returned")
|
||||
|
||||
|
||||
class VeoExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
VeoVideoGenerationNode,
|
||||
Veo3VideoGenerationNode,
|
||||
Veo3FirstLastFrameNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ import logging
|
||||
import time
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
@ -48,8 +48,9 @@ async def upload_images_to_comfyapi(
|
||||
image: torch.Tensor,
|
||||
*,
|
||||
max_images: int = 8,
|
||||
mime_type: Optional[str] = None,
|
||||
wait_label: Optional[str] = "Uploading",
|
||||
mime_type: str | None = None,
|
||||
wait_label: str | None = "Uploading",
|
||||
show_batch_index: bool = True,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Uploads images to ComfyUI API and returns download URLs.
|
||||
@ -59,11 +60,18 @@ async def upload_images_to_comfyapi(
|
||||
download_urls: list[str] = []
|
||||
is_batch = len(image.shape) > 3
|
||||
batch_len = image.shape[0] if is_batch else 1
|
||||
num_to_upload = min(batch_len, max_images)
|
||||
batch_start_ts = time.monotonic()
|
||||
|
||||
for idx in range(min(batch_len, max_images)):
|
||||
for idx in range(num_to_upload):
|
||||
tensor = image[idx] if is_batch else image
|
||||
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
|
||||
url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, wait_label)
|
||||
|
||||
effective_label = wait_label
|
||||
if wait_label and show_batch_index and num_to_upload > 1:
|
||||
effective_label = f"{wait_label} ({idx + 1}/{num_to_upload})"
|
||||
|
||||
url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts)
|
||||
download_urls.append(url)
|
||||
return download_urls
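With show_batch_index enabled, each image's wait label gains a position suffix so batch uploads report progress sensibly. A quick illustration of the labels produced by the f-string above:

wait_label = "Uploading"
num_to_upload = 3
labels = [f"{wait_label} ({idx + 1}/{num_to_upload})" for idx in range(num_to_upload)]
# -> ["Uploading (1/3)", "Uploading (2/3)", "Uploading (3/3)"]; a single image keeps the plain label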
|
||||
|
||||
@ -95,6 +103,7 @@ async def upload_video_to_comfyapi(
|
||||
container: VideoContainer = VideoContainer.MP4,
|
||||
codec: VideoCodec = VideoCodec.H264,
|
||||
max_duration: Optional[int] = None,
|
||||
wait_label: str | None = "Uploading",
|
||||
) -> str:
|
||||
"""
|
||||
Uploads a single video to ComfyUI API and returns its download URL.
|
||||
@ -119,15 +128,16 @@ async def upload_video_to_comfyapi(
|
||||
video.save_to(video_bytes_io, format=container, codec=codec)
|
||||
video_bytes_io.seek(0)
|
||||
|
||||
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type)
|
||||
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)
|
||||
|
||||
|
||||
async def upload_file_to_comfyapi(
|
||||
cls: type[IO.ComfyNode],
|
||||
file_bytes_io: BytesIO,
|
||||
filename: str,
|
||||
upload_mime_type: Optional[str],
|
||||
wait_label: Optional[str] = "Uploading",
|
||||
upload_mime_type: str | None,
|
||||
wait_label: str | None = "Uploading",
|
||||
progress_origin_ts: float | None = None,
|
||||
) -> str:
|
||||
"""Uploads a single file to ComfyUI API and returns its download URL."""
|
||||
if upload_mime_type is None:
|
||||
@ -148,6 +158,7 @@ async def upload_file_to_comfyapi(
|
||||
file_bytes_io,
|
||||
content_type=upload_mime_type,
|
||||
wait_label=wait_label,
|
||||
progress_origin_ts=progress_origin_ts,
|
||||
)
|
||||
return create_resp.download_url
|
||||
|
||||
@ -155,27 +166,18 @@ async def upload_file_to_comfyapi(
|
||||
async def upload_file(
|
||||
cls: type[IO.ComfyNode],
|
||||
upload_url: str,
|
||||
file: Union[BytesIO, str],
|
||||
file: BytesIO | str,
|
||||
*,
|
||||
content_type: Optional[str] = None,
|
||||
content_type: str | None = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
wait_label: Optional[str] = None,
|
||||
wait_label: str | None = None,
|
||||
progress_origin_ts: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption.
|
||||
|
||||
Args:
|
||||
cls: Node class (provides auth context + UI progress hooks).
|
||||
upload_url: Pre-signed PUT URL.
|
||||
file: BytesIO or path string.
|
||||
content_type: Explicit MIME type. If None, we *suppress* Content-Type.
|
||||
max_retries: Maximum retry attempts.
|
||||
retry_delay: Initial delay in seconds.
|
||||
retry_backoff: Exponential backoff factor.
|
||||
wait_label: Progress label shown in Comfy UI.
|
||||
|
||||
Raises:
|
||||
ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception
|
||||
"""
|
||||
@ -198,7 +200,7 @@ async def upload_file(
|
||||
|
||||
attempt = 0
|
||||
delay = retry_delay
|
||||
start_ts = time.monotonic()
|
||||
start_ts = progress_origin_ts if progress_origin_ts is not None else time.monotonic()
|
||||
op_uuid = uuid.uuid4().hex[:8]
|
||||
while True:
|
||||
attempt += 1
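The retry parameters documented above describe a standard exponential backoff: the wait starts at retry_delay and is multiplied by retry_backoff after each failed attempt, up to max_retries. A hedged sketch of that schedule; try_upload_once stands in for a single PUT attempt and is not a real helper in this file:

import asyncio

async def upload_with_backoff(max_retries=3, retry_delay=1.0, retry_backoff=2.0):
    # Hypothetical illustration of the documented retry behaviour.
    delay = retry_delay
    for attempt in range(1, max_retries + 1):
        if await try_upload_once():          # hypothetical single-attempt helper
            return
        if attempt == max_retries:
            raise RuntimeError("upload failed after retries")
        await asyncio.sleep(delay)           # 1.0s, then 2.0s, then 4.0s ...
        delay *= retry_backoff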
|
||||
|
||||
1432
comfy_extras/nodes_dataset.py
Normal file
File diff suppressed because it is too large
@ -7,6 +7,10 @@ from comfy_api.input_impl import VideoFromFile

from pathlib import Path

from PIL import Image
import numpy as np

import uuid

def normalize_path(path):
    return path.replace('\\', '/')
@ -34,58 +38,6 @@ class Load3D():
            "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
        }}

    RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
    RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info", "recording_video")

    FUNCTION = "process"
    EXPERIMENTAL = True

    CATEGORY = "3d"

    def process(self, model_file, image, **kwargs):
        image_path = folder_paths.get_annotated_filepath(image['image'])
        mask_path = folder_paths.get_annotated_filepath(image['mask'])
        normal_path = folder_paths.get_annotated_filepath(image['normal'])
        lineart_path = folder_paths.get_annotated_filepath(image['lineart'])

        load_image_node = nodes.LoadImage()
        output_image, ignore_mask = load_image_node.load_image(image=image_path)
        ignore_image, output_mask = load_image_node.load_image(image=mask_path)
        normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
        lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)

        video = None

        if image['recording'] != "":
            recording_video_path = folder_paths.get_annotated_filepath(image['recording'])

            video = VideoFromFile(recording_video_path)

        return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'], video

class Load3DAnimation():
    @classmethod
    def INPUT_TYPES(s):
        input_dir = os.path.join(folder_paths.get_input_directory(), "3d")

        os.makedirs(input_dir, exist_ok=True)

        input_path = Path(input_dir)
        base_path = Path(folder_paths.get_input_directory())

        files = [
            normalize_path(str(file_path.relative_to(base_path)))
            for file_path in input_path.rglob("*")
            if file_path.suffix.lower() in {'.gltf', '.glb', '.fbx'}
        ]

        return {"required": {
            "model_file": (sorted(files), {"file_upload": True}),
            "image": ("LOAD_3D_ANIMATION", {}),
            "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
            "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
        }}

    RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
    RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video")

@ -120,7 +72,8 @@ class Preview3D():
            "model_file": ("STRING", {"default": "", "multiline": False}),
        },
        "optional": {
            "camera_info": ("LOAD3D_CAMERA", {})
            "camera_info": ("LOAD3D_CAMERA", {}),
            "bg_image": ("IMAGE", {})
        }}

    OUTPUT_NODE = True
@ -133,50 +86,33 @@ class Preview3D():

    def process(self, model_file, **kwargs):
        camera_info = kwargs.get("camera_info", None)
        bg_image = kwargs.get("bg_image", None)

        bg_image_path = None
        if bg_image is not None:

            img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
            img = Image.fromarray(img_array)

            temp_dir = folder_paths.get_temp_directory()
            filename = f"bg_{uuid.uuid4().hex}.png"
            bg_image_path = os.path.join(temp_dir, filename)
            img.save(bg_image_path, compress_level=1)

            bg_image_path = f"temp/{filename}"

        return {
            "ui": {
                "result": [model_file, camera_info]
            }
        }

class Preview3DAnimation():
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model_file": ("STRING", {"default": "", "multiline": False}),
        },
        "optional": {
            "camera_info": ("LOAD3D_CAMERA", {})
        }}

    OUTPUT_NODE = True
    RETURN_TYPES = ()

    CATEGORY = "3d"

    FUNCTION = "process"
    EXPERIMENTAL = True

    def process(self, model_file, **kwargs):
        camera_info = kwargs.get("camera_info", None)

        return {
            "ui": {
                "result": [model_file, camera_info]
                "result": [model_file, camera_info, bg_image_path]
            }
        }

NODE_CLASS_MAPPINGS = {
    "Load3D": Load3D,
    "Load3DAnimation": Load3DAnimation,
    "Preview3D": Preview3D,
    "Preview3DAnimation": Preview3DAnimation
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "Load3D": "Load 3D",
    "Load3DAnimation": "Load 3D - Animation",
    "Preview3D": "Preview 3D",
    "Preview3DAnimation": "Preview 3D - Animation"
    "Load3D": "Load 3D & Animation",
    "Preview3D": "Preview 3D & Animation",
}

File diff suppressed because it is too large (Load Diff)
@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.75"
__version__ = "0.3.76"

@ -137,6 +137,71 @@ def set_user_directory(user_dir: str) -> None:
    user_directory = user_dir


# System User Protection - Protects system directories from HTTP endpoint access
# System Users are internal-only users that cannot be accessed via HTTP endpoints.
# They use the '__' prefix convention (similar to Python's private member convention).
SYSTEM_USER_PREFIX = "__"


def get_system_user_directory(name: str = "system") -> str:
    """
    Get the path to a System User directory.

    System User directories (prefixed with '__') are only accessible via internal API,
    not through HTTP endpoints. Use this for storing system-internal data that
    should not be exposed to users.

    Args:
        name: System user name (e.g., "system", "cache"). Must be alphanumeric
            with underscores allowed, but cannot start with underscore.

    Returns:
        Absolute path to the system user directory.

    Raises:
        ValueError: If name is empty, invalid, or starts with underscore.

    Example:
        >>> get_system_user_directory("cache")
        '/path/to/user/__cache'
    """
    if not name or not isinstance(name, str):
        raise ValueError("System user name cannot be empty")
    if not name.replace("_", "").isalnum():
        raise ValueError(f"Invalid system user name: '{name}'")
    if name.startswith("_"):
        raise ValueError("System user name should not start with underscore")
    return os.path.join(get_user_directory(), f"{SYSTEM_USER_PREFIX}{name}")


def get_public_user_directory(user_id: str) -> str | None:
    """
    Get the path to a Public User directory for HTTP endpoint access.

    This function provides structural security by returning None for any
    System User (prefixed with '__'). All HTTP endpoints should use this
    function instead of directly constructing user paths.

    Args:
        user_id: User identifier from HTTP request.

    Returns:
        Absolute path to the user directory, or None if user_id is invalid
        or refers to a System User.

    Example:
        >>> get_public_user_directory("default")
        '/path/to/user/default'
        >>> get_public_user_directory("__system")
        None
    """
    if not user_id or not isinstance(user_id, str):
        return None
    if user_id.startswith(SYSTEM_USER_PREFIX):
        return None
    return os.path.join(get_user_directory(), user_id)


#NOTE: used in http server so don't put folders that should not be accessed remotely
def get_directory_by_type(type_name: str) -> str | None:
    if type_name == "output":

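Not part of the diff: a minimal usage sketch for the two helpers added above; the "my_node" name and settings path are illustrative only.

import os
import folder_paths

# Internal-only storage for a custom node (never reachable over HTTP):
cfg_dir = folder_paths.get_system_user_directory("my_node")     # .../user/__my_node
os.makedirs(cfg_dir, exist_ok=True)

# HTTP handlers resolve user paths through the public helper, which
# returns None for any "__"-prefixed System User id:
user_path = folder_paths.get_public_user_directory("default")   # a "__..." id would yield None
if user_path is None:
    pass  # reject the request, e.g. respond 403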
@ -2,17 +2,24 @@ import torch
from PIL import Image
from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
from comfy.sd import VAE
import comfy.model_management
import folder_paths
import comfy.utils
import logging

MAX_PREVIEW_RESOLUTION = args.preview_size
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]

def preview_to_image(latent_image):
    latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1)  # change scale from -1..1 to 0..1
                     .mul(0xFF)  # to 0..255
                     )
def preview_to_image(latent_image, do_scale=True):
    if do_scale:
        latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1)  # change scale from -1..1 to 0..1
                         .mul(0xFF)  # to 0..255
                         )
    else:
        latents_ubyte = (latent_image.clamp(0, 1)
                         .mul(0xFF)  # to 0..255
                         )
    if comfy.model_management.directml_enabled:
        latents_ubyte = latents_ubyte.to(dtype=torch.uint8)
    latents_ubyte = latents_ubyte.to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))
@ -35,6 +42,10 @@ class TAESDPreviewerImpl(LatentPreviewer):
        x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2)
        return preview_to_image(x_sample)

class TAEHVPreviewerImpl(TAESDPreviewerImpl):
    def decode_latent_to_preview(self, x0):
        x_sample = self.taesd.decode(x0[:1, :, :1])[0][0]
        return preview_to_image(x_sample, do_scale=False)

class Latent2RGBPreviewer(LatentPreviewer):
    def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None, latent_rgb_factors_reshape=None):
@ -81,8 +92,13 @@ def get_previewer(device, latent_format):

    if method == LatentPreviewMethod.TAESD:
        if taesd_decoder_path:
            taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
            previewer = TAESDPreviewerImpl(taesd)
            if latent_format.taesd_decoder_name in VIDEO_TAES:
                taesd = VAE(comfy.utils.load_torch_file(taesd_decoder_path))
                taesd.first_stage_model.show_progress_bar = False
                previewer = TAEHVPreviewerImpl(taesd)
            else:
                taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
                previewer = TAESDPreviewerImpl(taesd)
        else:
            logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))

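Not part of the diff: a small numeric check of the two scaling branches in preview_to_image above; the tensors are made up.

import torch

x_signed = torch.tensor([-1.0, 0.0, 1.0])                 # latents previewed in -1..1
print(((x_signed + 1.0) / 2.0).clamp(0, 1).mul(0xFF))     # tensor([  0.0, 127.5, 255.0])

x_unit = torch.tensor([0.0, 0.5, 1.0])                    # video TAE output already in 0..1
print(x_unit.clamp(0, 1).mul(0xFF))                       # tensor([  0.0, 127.5, 255.0])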
main.py (30 changed lines)
@ -15,6 +15,7 @@ from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
from comfy_api import feature_flags


if __name__ == "__main__":
    #NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
    os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
@ -22,6 +23,23 @@ if __name__ == "__main__":

    setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)


def handle_comfyui_manager_unavailable():
    if not args.windows_standalone_build:
        logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n")
    args.enable_manager = False


if args.enable_manager:
    if importlib.util.find_spec("comfyui_manager"):
        import comfyui_manager

        if not comfyui_manager.__file__ or not comfyui_manager.__file__.endswith('__init__.py'):
            handle_comfyui_manager_unavailable()
    else:
        handle_comfyui_manager_unavailable()


def apply_custom_paths():
    # extra model paths
    extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
@ -79,6 +97,11 @@ def execute_prestartup_script():

    for possible_module in possible_modules:
        module_path = os.path.join(custom_node_path, possible_module)

        if args.enable_manager:
            if comfyui_manager.should_be_disabled(module_path):
                continue

        if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__":
            continue

@ -101,6 +124,10 @@ def execute_prestartup_script():
        logging.info("")

apply_custom_paths()

if args.enable_manager:
    comfyui_manager.prestartup()

execute_prestartup_script()


@ -323,6 +350,9 @@ def start_comfyui(asyncio_loop=None):
        asyncio.set_event_loop(asyncio_loop)
    prompt_server = server.PromptServer(asyncio_loop)

    if args.enable_manager and not args.disable_manager_ui:
        comfyui_manager.start()

    hook_breaker_ac10a0.save_functions()
    asyncio_loop.run_until_complete(nodes.init_extra_nodes(
        init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,

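Not part of the diff: the availability check above relies on importlib.util.find_spec from the standard library; a minimal standalone sketch of the same pattern.

import importlib.util

spec = importlib.util.find_spec("comfyui_manager")
if spec is not None:
    import comfyui_manager          # installed as a package
else:
    print("comfyui_manager not installed; manager features stay disabled")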
manager_requirements.txt (new file, 1 line)
@ -0,0 +1 @@
comfyui_manager==4.0.3b3

nodes.py (30 changed lines)
@ -43,6 +43,9 @@ import folder_paths
import latent_preview
import node_helpers

if args.enable_manager:
    import comfyui_manager

def before_node_execution():
    comfy.model_management.throw_exception_if_processing_interrupted()

@ -692,8 +695,10 @@ class LoraLoaderModelOnly(LoraLoader):
        return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)

class VAELoader:
    video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
    image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
    @staticmethod
    def vae_list():
    def vae_list(s):
        vaes = folder_paths.get_filename_list("vae")
        approx_vaes = folder_paths.get_filename_list("vae_approx")
        sdxl_taesd_enc = False
@ -722,6 +727,11 @@ class VAELoader:
                f1_taesd_dec = True
            elif v.startswith("taef1_decoder."):
                f1_taesd_enc = True
            else:
                for tae in s.video_taes:
                    if v.startswith(tae):
                        vaes.append(v)

        if sd1_taesd_dec and sd1_taesd_enc:
            vaes.append("taesd")
        if sdxl_taesd_dec and sdxl_taesd_enc:
@ -765,7 +775,7 @@ class VAELoader:

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "vae_name": (s.vae_list(), )}}
        return {"required": { "vae_name": (s.vae_list(s), )}}
    RETURN_TYPES = ("VAE",)
    FUNCTION = "load_vae"

@ -776,10 +786,13 @@ class VAELoader:
        if vae_name == "pixel_space":
            sd = {}
            sd["pixel_space_vae"] = torch.tensor(1.0)
        elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
        elif vae_name in self.image_taes:
            sd = self.load_taesd(vae_name)
        else:
            vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
            if os.path.splitext(vae_name)[0] in self.video_taes:
                vae_path = folder_paths.get_full_path_or_raise("vae_approx", vae_name)
            else:
                vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
            sd = comfy.utils.load_torch_file(vae_path)
        vae = comfy.sd.VAE(sd=sd)
        vae.throw_exception_if_invalid()
@ -929,7 +942,7 @@ class CLIPLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2"], ),
                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis"], ),
                              },
                "optional": {
                    "device": (["default", "cpu"], {"advanced": True}),
@ -2233,6 +2246,12 @@ async def init_external_custom_nodes():
        if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
            logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
            continue

        if args.enable_manager:
            if comfyui_manager.should_be_disabled(module_path):
                logging.info(f"Blocked by policy: {module_path}")
                continue

        time_before = time.perf_counter()
        success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
        node_import_times.append((time.perf_counter() - time_before, module_path, success))
@ -2278,6 +2297,7 @@ async def init_builtin_extra_nodes():
        "nodes_images.py",
        "nodes_video_model.py",
        "nodes_train.py",
        "nodes_dataset.py",
        "nodes_sag.py",
        "nodes_perpneg.py",
        "nodes_stable3d.py",

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.75"
version = "0.3.76"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

@ -1,5 +1,5 @@
comfyui-frontend-package==1.30.6
comfyui-workflow-templates==0.7.20
comfyui-frontend-package==1.33.10
comfyui-workflow-templates==0.7.25
comfyui-embedded-docs==0.3.1
torch
torchsde

server.py (10 changed lines)
@ -44,6 +44,9 @@ from protocol import BinaryEventTypes
# Import cache control middleware
from middleware.cache_middleware import cache_control

if args.enable_manager:
    import comfyui_manager

async def send_socket_catch_exception(function, message):
    try:
        await function(message)
@ -174,7 +177,7 @@ def create_block_external_middleware():
        else:
            response = await handler(request)

        response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';"
        response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';"
        return response

    return block_external_middleware
@ -212,6 +215,9 @@ class PromptServer():
        if args.disable_api_nodes:
            middlewares.append(create_block_external_middleware())

        if args.enable_manager:
            middlewares.append(comfyui_manager.create_middleware())

        max_upload_size = round(args.max_upload_size * 1024 * 1024)
        self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
        self.sockets = dict()
@ -599,7 +605,7 @@ class PromptServer():

        system_stats = {
            "system": {
                "os": os.name,
                "os": sys.platform,
                "ram_total": ram_total,
                "ram_free": ram_free,
                "comfyui_version": __version__,

tests-unit/app_test/user_manager_system_user_test.py (new file, 193 lines)
@ -0,0 +1,193 @@
"""Tests for System User Protection in user_manager.py

Tests cover:
- get_request_user_id(): 1st defense layer - blocks System Users from HTTP headers
- get_request_user_filepath(): 2nd defense layer - structural blocking via get_public_user_directory()
- add_user(): 3rd defense layer - prevents creation of System User names
- Defense layers integration tests
"""

import pytest
from unittest.mock import MagicMock, patch
import tempfile

import folder_paths
from app.user_manager import UserManager


@pytest.fixture
def mock_user_directory():
    """Create a temporary user directory."""
    with tempfile.TemporaryDirectory() as temp_dir:
        original_dir = folder_paths.get_user_directory()
        folder_paths.set_user_directory(temp_dir)
        yield temp_dir
        folder_paths.set_user_directory(original_dir)


@pytest.fixture
def user_manager(mock_user_directory):
    """Create a UserManager instance for testing."""
    with patch('app.user_manager.args') as mock_args:
        mock_args.multi_user = True
        manager = UserManager()
        # Add a default user for testing
        manager.users = {"default": "default", "test_user_123": "Test User"}
        yield manager


@pytest.fixture
def mock_request():
    """Create a mock request object."""
    request = MagicMock()
    request.headers = {}
    return request


class TestGetRequestUserId:
    """Tests for get_request_user_id() - 1st defense layer.

    Verifies:
    - System Users (__ prefix) in HTTP header are rejected with KeyError
    - Public Users pass through successfully
    """

    def test_system_user_raises_error(self, user_manager, mock_request):
        """Test System User in header raises KeyError."""
        mock_request.headers = {"comfy-user": "__system"}

        with patch('app.user_manager.args') as mock_args:
            mock_args.multi_user = True
            with pytest.raises(KeyError, match="Unknown user"):
                user_manager.get_request_user_id(mock_request)

    def test_system_user_cache_raises_error(self, user_manager, mock_request):
        """Test System User cache raises KeyError."""
        mock_request.headers = {"comfy-user": "__cache"}

        with patch('app.user_manager.args') as mock_args:
            mock_args.multi_user = True
            with pytest.raises(KeyError, match="Unknown user"):
                user_manager.get_request_user_id(mock_request)

    def test_normal_user_works(self, user_manager, mock_request):
        """Test normal user access works."""
        mock_request.headers = {"comfy-user": "default"}

        with patch('app.user_manager.args') as mock_args:
            mock_args.multi_user = True
            user_id = user_manager.get_request_user_id(mock_request)
            assert user_id == "default"

    def test_unknown_user_raises_error(self, user_manager, mock_request):
        """Test unknown user raises KeyError."""
        mock_request.headers = {"comfy-user": "unknown_user"}

        with patch('app.user_manager.args') as mock_args:
            mock_args.multi_user = True
            with pytest.raises(KeyError, match="Unknown user"):
                user_manager.get_request_user_id(mock_request)


class TestGetRequestUserFilepath:
    """Tests for get_request_user_filepath() - 2nd defense layer.

    Verifies:
    - Returns None when get_public_user_directory() returns None (System User)
    - Acts as backup defense if 1st layer is bypassed
    """

    def test_system_user_returns_none(self, user_manager, mock_request, mock_user_directory):
        """Test System User returns None (structural blocking)."""
        # First, we need to mock get_request_user_id to return System User
        # But actually, get_request_user_id will raise KeyError first
        # So we test via get_public_user_directory returning None
        mock_request.headers = {"comfy-user": "default"}

        with patch('app.user_manager.args') as mock_args:
            mock_args.multi_user = True
            # Patch get_public_user_directory to return None for testing
            with patch.object(folder_paths, 'get_public_user_directory', return_value=None):
                result = user_manager.get_request_user_filepath(mock_request, "test.txt")
                assert result is None

    def test_normal_user_gets_path(self, user_manager, mock_request, mock_user_directory):
        """Test normal user gets valid filepath."""
        mock_request.headers = {"comfy-user": "default"}

        with patch('app.user_manager.args') as mock_args:
            mock_args.multi_user = True
            path = user_manager.get_request_user_filepath(mock_request, "test.txt")
            assert path is not None
            assert "default" in path
            assert path.endswith("test.txt")


class TestAddUser:
    """Tests for add_user() - 3rd defense layer (creation-time blocking).

    Verifies:
    - System User name (__ prefix) creation is rejected with ValueError
    - Sanitized usernames that become System User are also rejected
    """

    def test_system_user_prefix_name_raises(self, user_manager):
        """Test System User prefix in name raises ValueError."""
        with pytest.raises(ValueError, match="System User prefix not allowed"):
            user_manager.add_user("__system")

    def test_system_user_prefix_cache_raises(self, user_manager):
        """Test System User cache prefix raises ValueError."""
        with pytest.raises(ValueError, match="System User prefix not allowed"):
            user_manager.add_user("__cache")

    def test_sanitized_system_user_prefix_raises(self, user_manager):
        """Test sanitized name becoming System User prefix raises ValueError (bypass prevention)."""
        # "__test" directly starts with System User prefix
        with pytest.raises(ValueError, match="System User prefix not allowed"):
            user_manager.add_user("__test")

    def test_normal_user_creation(self, user_manager, mock_user_directory):
        """Test normal user creation works."""
        user_id = user_manager.add_user("Normal User")
        assert user_id is not None
        assert not user_id.startswith("__")
        assert "Normal-User" in user_id or "Normal_User" in user_id

    def test_empty_name_raises(self, user_manager):
        """Test empty name raises ValueError."""
        with pytest.raises(ValueError, match="username not provided"):
            user_manager.add_user("")

    def test_whitespace_only_raises(self, user_manager):
        """Test whitespace-only name raises ValueError."""
        with pytest.raises(ValueError, match="username not provided"):
            user_manager.add_user(" ")


class TestDefenseLayers:
    """Integration tests for all three defense layers.

    Verifies:
    - Each defense layer blocks System Users independently
    - System User bypass is impossible through any layer
    """

    def test_layer1_get_request_user_id(self, user_manager, mock_request):
        """Test 1st defense layer blocks System Users."""
        mock_request.headers = {"comfy-user": "__system"}

        with patch('app.user_manager.args') as mock_args:
            mock_args.multi_user = True
            with pytest.raises(KeyError):
                user_manager.get_request_user_id(mock_request)

    def test_layer2_get_public_user_directory(self):
        """Test 2nd defense layer blocks System Users."""
        result = folder_paths.get_public_user_directory("__system")
        assert result is None

    def test_layer3_add_user(self, user_manager):
        """Test 3rd defense layer blocks System User creation."""
        with pytest.raises(ValueError):
            user_manager.add_user("__system")
tests-unit/folder_paths_test/system_user_test.py (new file, 206 lines)
@ -0,0 +1,206 @@
"""Tests for System User Protection in folder_paths.py

Tests cover:
- get_system_user_directory(): Internal API for custom nodes to access System User directories
- get_public_user_directory(): HTTP endpoint access with System User blocking
- Backward compatibility: Existing APIs unchanged
- Security: Path traversal and injection prevention
"""

import pytest
import os
import tempfile

from folder_paths import (
    get_system_user_directory,
    get_public_user_directory,
    get_user_directory,
    set_user_directory,
)


@pytest.fixture(scope="module")
def mock_user_directory():
    """Create a temporary user directory for testing."""
    with tempfile.TemporaryDirectory() as temp_dir:
        original_dir = get_user_directory()
        set_user_directory(temp_dir)
        yield temp_dir
        set_user_directory(original_dir)


class TestGetSystemUserDirectory:
    """Tests for get_system_user_directory() - internal API for System User directories.

    Verifies:
    - Custom nodes can access System User directories via internal API
    - Input validation prevents path traversal attacks
    """

    def test_default_name(self, mock_user_directory):
        """Test default 'system' name."""
        path = get_system_user_directory()
        assert path.endswith("__system")
        assert mock_user_directory in path

    def test_custom_name(self, mock_user_directory):
        """Test custom system user name."""
        path = get_system_user_directory("cache")
        assert path.endswith("__cache")
        assert "__cache" in path

    def test_name_with_underscore(self, mock_user_directory):
        """Test name with underscore in middle."""
        path = get_system_user_directory("my_cache")
        assert "__my_cache" in path

    def test_empty_name_raises(self):
        """Test empty name raises ValueError."""
        with pytest.raises(ValueError, match="cannot be empty"):
            get_system_user_directory("")

    def test_none_name_raises(self):
        """Test None name raises ValueError."""
        with pytest.raises(ValueError, match="cannot be empty"):
            get_system_user_directory(None)

    def test_name_starting_with_underscore_raises(self):
        """Test name starting with underscore raises ValueError."""
        with pytest.raises(ValueError, match="should not start with underscore"):
            get_system_user_directory("_system")

    def test_path_traversal_raises(self):
        """Test path traversal attempt raises ValueError (security)."""
        with pytest.raises(ValueError, match="Invalid system user name"):
            get_system_user_directory("../escape")

    def test_path_traversal_middle_raises(self):
        """Test path traversal in middle raises ValueError (security)."""
        with pytest.raises(ValueError, match="Invalid system user name"):
            get_system_user_directory("system/../other")

    def test_special_chars_raise(self):
        """Test special characters raise ValueError (security)."""
        with pytest.raises(ValueError, match="Invalid system user name"):
            get_system_user_directory("system!")

    def test_returns_absolute_path(self, mock_user_directory):
        """Test returned path is absolute."""
        path = get_system_user_directory("test")
        assert os.path.isabs(path)


class TestGetPublicUserDirectory:
    """Tests for get_public_user_directory() - HTTP endpoint access with System User blocking.

    Verifies:
    - System Users (__ prefix) return None, blocking HTTP access
    - Public Users get valid paths
    - New endpoints using this function are automatically protected
    """

    def test_normal_user(self, mock_user_directory):
        """Test normal user returns valid path."""
        path = get_public_user_directory("default")
        assert path is not None
        assert "default" in path
        assert mock_user_directory in path

    def test_system_user_returns_none(self):
        """Test System User (__ prefix) returns None - blocks HTTP access."""
        assert get_public_user_directory("__system") is None

    def test_system_user_cache_returns_none(self):
        """Test System User cache returns None."""
        assert get_public_user_directory("__cache") is None

    def test_empty_user_returns_none(self):
        """Test empty user returns None."""
        assert get_public_user_directory("") is None

    def test_none_user_returns_none(self):
        """Test None user returns None."""
        assert get_public_user_directory(None) is None

    def test_header_injection_returns_none(self):
        """Test header injection attempt returns None (security)."""
        assert get_public_user_directory("__system\r\nX-Injected: true") is None

    def test_null_byte_injection_returns_none(self):
        """Test null byte injection handling (security)."""
        # Note: startswith check happens before any path operations
        result = get_public_user_directory("user\x00__system")
        # This should return a path since it doesn't start with __
        # The actual security comes from the path not being __*
        assert result is not None or result is None  # Depends on validation

    def test_path_traversal_attempt(self, mock_user_directory):
        """Test path traversal attempt handling."""
        # This function doesn't validate paths, only reserved prefix
        # Path traversal should be handled by the caller
        path = get_public_user_directory("../../../etc/passwd")
        # Returns path but doesn't start with __, so not None
        # Actual path validation happens in user_manager
        assert path is not None or "__" not in "../../../etc/passwd"

    def test_returns_absolute_path(self, mock_user_directory):
        """Test returned path is absolute."""
        path = get_public_user_directory("testuser")
        assert path is not None
        assert os.path.isabs(path)


class TestBackwardCompatibility:
    """Tests for backward compatibility with existing APIs.

    Verifies:
    - get_user_directory() API unchanged
    - Existing user data remains accessible
    """

    def test_get_user_directory_unchanged(self, mock_user_directory):
        """Test get_user_directory() still works as before."""
        user_dir = get_user_directory()
        assert user_dir is not None
        assert os.path.isabs(user_dir)
        assert user_dir == mock_user_directory

    def test_existing_user_accessible(self, mock_user_directory):
        """Test existing users can access their directories."""
        path = get_public_user_directory("default")
        assert path is not None
        assert "default" in path


class TestEdgeCases:
    """Tests for edge cases in System User detection.

    Verifies:
    - Only __ prefix is blocked (not _, not middle __)
    - Bypass attempts are prevented
    """

    def test_prefix_only(self):
        """Test prefix-only string is blocked."""
        assert get_public_user_directory("__") is None

    def test_single_underscore_allowed(self):
        """Test single underscore prefix is allowed (not System User)."""
        path = get_public_user_directory("_system")
        assert path is not None
        assert "_system" in path

    def test_triple_underscore_blocked(self):
        """Test triple underscore is blocked (starts with __)."""
        assert get_public_user_directory("___system") is None

    def test_underscore_in_middle_allowed(self):
        """Test underscore in middle is allowed."""
        path = get_public_user_directory("my__system")
        assert path is not None
        assert "my__system" in path

    def test_leading_space_allowed(self):
        """Test leading space + prefix is allowed (doesn't start with __)."""
        path = get_public_user_directory(" __system")
        assert path is not None
tests-unit/prompt_server_test/system_user_endpoint_test.py (new file, 375 lines)
@ -0,0 +1,375 @@
"""E2E Tests for System User Protection HTTP Endpoints

Tests cover:
- HTTP endpoint blocking: System Users cannot access /userdata (GET, POST, DELETE, move)
- User creation blocking: System User names cannot be created via POST /users
- Backward compatibility: Public Users work as before
- Custom node scenario: Internal API works while HTTP is blocked
- Structural security: get_public_user_directory() provides automatic protection
"""

import pytest
import os
from aiohttp import web
from app.user_manager import UserManager
from unittest.mock import patch
import folder_paths


@pytest.fixture
def mock_user_directory(tmp_path):
    """Create a temporary user directory."""
    original_dir = folder_paths.get_user_directory()
    folder_paths.set_user_directory(str(tmp_path))
    yield tmp_path
    folder_paths.set_user_directory(original_dir)


@pytest.fixture
def user_manager_multi_user(mock_user_directory):
    """Create UserManager in multi-user mode."""
    with patch('app.user_manager.args') as mock_args:
        mock_args.multi_user = True
        um = UserManager()
        # Add test users
        um.users = {"default": "default", "test_user_123": "Test User"}
        yield um


@pytest.fixture
def app_multi_user(user_manager_multi_user):
    """Create app with multi-user mode enabled."""
    app = web.Application()
    routes = web.RouteTableDef()
    user_manager_multi_user.add_routes(routes)
    app.add_routes(routes)
    return app


class TestSystemUserEndpointBlocking:
    """E2E tests for System User blocking on all HTTP endpoints.

    Verifies:
    - GET /userdata blocked for System Users
    - POST /userdata blocked for System Users
    - DELETE /userdata blocked for System Users
    - POST /userdata/.../move/... blocked for System Users
    """

    @pytest.mark.asyncio
    async def test_userdata_get_blocks_system_user(
        self, aiohttp_client, app_multi_user, mock_user_directory
    ):
        """
        GET /userdata with System User header should be blocked.
        """
        # Create test directory for System User (simulating internal creation)
        system_user_dir = mock_user_directory / "__system"
        system_user_dir.mkdir()
        (system_user_dir / "secret.txt").write_text("sensitive data")

        client = await aiohttp_client(app_multi_user)

        with patch('app.user_manager.args') as mock_args:
            mock_args.multi_user = True
            # Attempt to access System User's data via HTTP
            resp = await client.get(
                "/userdata?dir=.",
                headers={"comfy-user": "__system"}
            )

            # Should be blocked (403 Forbidden or similar error)
            assert resp.status in [400, 403, 500], \
                f"System User access should be blocked, got {resp.status}"

    @pytest.mark.asyncio
    async def test_userdata_post_blocks_system_user(
        self, aiohttp_client, app_multi_user, mock_user_directory
    ):
        """
        POST /userdata with System User header should be blocked.
        """
        client = await aiohttp_client(app_multi_user)

        with patch('app.user_manager.args') as mock_args:
            mock_args.multi_user = True
            resp = await client.post(
                "/userdata/test.txt",
                headers={"comfy-user": "__system"},
                data=b"malicious content"
            )

            assert resp.status in [400, 403, 500], \
                f"System User write should be blocked, got {resp.status}"

            # Verify no file was created
            assert not (mock_user_directory / "__system" / "test.txt").exists()

    @pytest.mark.asyncio
    async def test_userdata_delete_blocks_system_user(
        self, aiohttp_client, app_multi_user, mock_user_directory
    ):
        """
        DELETE /userdata with System User header should be blocked.
        """
        # Create a file in System User directory
        system_user_dir = mock_user_directory / "__system"
        system_user_dir.mkdir()
        secret_file = system_user_dir / "secret.txt"
        secret_file.write_text("do not delete")

        client = await aiohttp_client(app_multi_user)

        with patch('app.user_manager.args') as mock_args:
            mock_args.multi_user = True
            resp = await client.delete(
                "/userdata/secret.txt",
                headers={"comfy-user": "__system"}
            )

            assert resp.status in [400, 403, 500], \
                f"System User delete should be blocked, got {resp.status}"

            # Verify file still exists
            assert secret_file.exists()

    @pytest.mark.asyncio
    async def test_v2_userdata_blocks_system_user(
        self, aiohttp_client, app_multi_user, mock_user_directory
    ):
        """
        GET /v2/userdata with System User header should be blocked.
        """
        client = await aiohttp_client(app_multi_user)

        with patch('app.user_manager.args') as mock_args:
            mock_args.multi_user = True
            resp = await client.get(
                "/v2/userdata",
                headers={"comfy-user": "__system"}
            )

            assert resp.status in [400, 403, 500], \
                f"System User v2 access should be blocked, got {resp.status}"

    @pytest.mark.asyncio
    async def test_move_userdata_blocks_system_user(
        self, aiohttp_client, app_multi_user, mock_user_directory
    ):
        """
        POST /userdata/{file}/move/{dest} with System User header should be blocked.
        """
        system_user_dir = mock_user_directory / "__system"
        system_user_dir.mkdir()
        (system_user_dir / "source.txt").write_text("sensitive data")

        client = await aiohttp_client(app_multi_user)

        with patch('app.user_manager.args') as mock_args:
            mock_args.multi_user = True
            resp = await client.post(
                "/userdata/source.txt/move/dest.txt",
                headers={"comfy-user": "__system"}
            )

            assert resp.status in [400, 403, 500], \
                f"System User move should be blocked, got {resp.status}"

            # Verify source file still exists (move was blocked)
            assert (system_user_dir / "source.txt").exists()


class TestSystemUserCreationBlocking:
    """E2E tests for blocking System User name creation via POST /users.

    Verifies:
    - POST /users returns 400 for System User name (not 500)
    """

    @pytest.mark.asyncio
    async def test_post_users_blocks_system_user_name(
        self, aiohttp_client, app_multi_user
    ):
        """POST /users with System User name should return 400 Bad Request."""
        client = await aiohttp_client(app_multi_user)

        resp = await client.post(
            "/users",
            json={"username": "__system"}
        )

        assert resp.status == 400, \
            f"System User creation should return 400, got {resp.status}"

    @pytest.mark.asyncio
    async def test_post_users_blocks_system_user_prefix_variations(
        self, aiohttp_client, app_multi_user
    ):
        """POST /users with any System User prefix variation should return 400 Bad Request."""
        client = await aiohttp_client(app_multi_user)

        system_user_names = ["__system", "__cache", "__config", "__anything"]

        for name in system_user_names:
            resp = await client.post("/users", json={"username": name})
            assert resp.status == 400, \
                f"System User name '{name}' should return 400, got {resp.status}"


class TestPublicUserStillWorks:
    """E2E tests for backward compatibility - Public Users should work as before.

    Verifies:
    - Public Users can access their data via HTTP
    - Public Users can create files via HTTP
    """

    @pytest.mark.asyncio
    async def test_public_user_can_access_userdata(
        self, aiohttp_client, app_multi_user, mock_user_directory
    ):
        """
        Public Users should still be able to access their data.
        """
        # Create test directory for Public User
        user_dir = mock_user_directory / "default"
        user_dir.mkdir()
        test_dir = user_dir / "workflows"
        test_dir.mkdir()
        (test_dir / "test.json").write_text('{"test": true}')

        client = await aiohttp_client(app_multi_user)

        with patch('app.user_manager.args') as mock_args:
            mock_args.multi_user = True
            resp = await client.get(
                "/userdata?dir=workflows",
                headers={"comfy-user": "default"}
            )

            assert resp.status == 200
            data = await resp.json()
            assert "test.json" in data

    @pytest.mark.asyncio
    async def test_public_user_can_create_files(
        self, aiohttp_client, app_multi_user, mock_user_directory
    ):
        """
        Public Users should still be able to create files.
        """
        # Create user directory
        user_dir = mock_user_directory / "default"
        user_dir.mkdir()

        client = await aiohttp_client(app_multi_user)

        with patch('app.user_manager.args') as mock_args:
            mock_args.multi_user = True
            resp = await client.post(
                "/userdata/newfile.txt",
                headers={"comfy-user": "default"},
                data=b"user content"
            )

            assert resp.status == 200
            assert (user_dir / "newfile.txt").exists()


class TestCustomNodeScenario:
    """Tests for custom node use case: internal API access vs HTTP blocking.

    Verifies:
    - Internal API (get_system_user_directory) works for custom nodes
    - HTTP endpoint cannot access data created via internal API
    """

    def test_internal_api_can_access_system_user(self, mock_user_directory):
        """
        Internal API (get_system_user_directory) should work for custom nodes.
        """
        # Custom node uses internal API
        system_path = folder_paths.get_system_user_directory("mynode_config")

        assert system_path is not None
        assert "__mynode_config" in system_path

        # Can create and write to System User directory
        os.makedirs(system_path, exist_ok=True)
        config_file = os.path.join(system_path, "settings.json")
        with open(config_file, "w") as f:
            f.write('{"api_key": "secret"}')

        assert os.path.exists(config_file)

    @pytest.mark.asyncio
    async def test_http_cannot_access_internal_data(
        self, aiohttp_client, app_multi_user, mock_user_directory
    ):
        """
        HTTP endpoint cannot access data created via internal API.
        """
        # Custom node creates data via internal API
        system_path = folder_paths.get_system_user_directory("mynode_config")
        os.makedirs(system_path, exist_ok=True)
        with open(os.path.join(system_path, "secret.json"), "w") as f:
            f.write('{"api_key": "secret"}')

        client = await aiohttp_client(app_multi_user)

        # Attacker tries to access via HTTP
        with patch('app.user_manager.args') as mock_args:
            mock_args.multi_user = True
            resp = await client.get(
                "/userdata/secret.json",
                headers={"comfy-user": "__mynode_config"}
            )

            # Should be blocked
            assert resp.status in [400, 403, 500]


class TestStructuralSecurity:
    """Tests for structural security pattern.

    Verifies:
    - get_public_user_directory() automatically blocks System Users
    - New endpoints using this function are automatically protected
    """

    def test_get_public_user_directory_blocks_system_user(self):
        """
        Any code using get_public_user_directory() is automatically protected.
        """
        # This is the structural security - any new endpoint using this function
        # will automatically block System Users
        assert folder_paths.get_public_user_directory("__system") is None
        assert folder_paths.get_public_user_directory("__cache") is None
        assert folder_paths.get_public_user_directory("__anything") is None

        # Public Users work
        assert folder_paths.get_public_user_directory("default") is not None
        assert folder_paths.get_public_user_directory("user123") is not None

    def test_structural_security_pattern(self, mock_user_directory):
        """
        Demonstrate the structural security pattern for new endpoints.

        Any new endpoint should follow this pattern:
        1. Get user from request
        2. Use get_public_user_directory() - automatically blocks System Users
        3. If None, return error
        """
        def new_endpoint_handler(user_id: str) -> str | None:
            """Example of how new endpoints should be implemented."""
            user_path = folder_paths.get_public_user_directory(user_id)
            if user_path is None:
                return None  # Blocked
            return user_path

        # System Users are automatically blocked
        assert new_endpoint_handler("__system") is None
        assert new_endpoint_handler("__secret") is None

        # Public Users work
        assert new_endpoint_handler("default") is not None