Compare commits

...

30 Commits

Author SHA1 Message Date
Comfy Org PR Bot
650e716dda
Bump comfyui-frontend-package to 1.35.9 (#11470)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-12-23 21:29:41 -08:00
comfyanonymous
e4c61d7555 ComfyUI v0.6.0 2025-12-23 20:50:02 -05:00
ComfyUI Wiki
22ff1bbfcb
chore: update workflow templates to v0.7.63 (#11482) 2025-12-23 20:48:45 -05:00
Alexander Piskun
f4f44bb807
api-nodes: use new custom endpoint for Nano Banana (#11311) 2025-12-23 12:10:27 -08:00
comfyanonymous
33aa808713
Make denoised output on custom sampler nodes work with nested tensors. (#11471) 2025-12-22 16:43:24 -05:00
ComfyUI Wiki
eb0e10aec4
Update workflow templates to v0.7.62 (#11467) 2025-12-22 16:02:41 -05:00
Alexander Piskun
c176b214cc
extend possible duration range for Kling O1 StartEndFrame node (#11451) 2025-12-21 22:44:49 -08:00
comfyanonymous
91bf6b6aa3
Add node to create empty latents for qwen image layered model. (#11460) 2025-12-21 19:59:40 -05:00
comfyanonymous
807538fe6c
Core release process. (#11447) 2025-12-20 20:02:02 -05:00
Alexander Piskun
bbb11e2608
fix(api-nodes): Topaz 4k video upscaling (#11438) 2025-12-20 08:48:28 -08:00
Alexander Piskun
0899012ad6
chore(api-nodes): by default set Watermark generation to False (#11437) 2025-12-19 22:24:37 -08:00
comfyanonymous
fb478f679a
Only apply gemma quant config to gemma model for newbie. (#11436) 2025-12-20 01:02:43 -05:00
woctordho
4c432c11ed
Implement Jina CLIP v2 and NewBie dual CLIP (#11415)
* Implement Jina CLIP v2

* Support quantized Gemma in NewBie dual CLIP
2025-12-20 00:57:22 -05:00
comfyanonymous
31e961736a
Fix issue with batches and newbie. (#11435) 2025-12-20 00:23:51 -05:00
rattus
767ee30f21
ZImageFunControlNet: Fix mask concatenation in --gpu-only (#11421)
This operation trades in latents which in --gpu-only may be out of the GPU
The two VAE results will follow the --gpu-only defined behaviour so follow
the inpaint image device when calculating the mask in this path.
2025-12-20 00:22:17 -05:00
comfyanonymous
3ab9748903
Disable prompt weights on newbie te. (#11434) 2025-12-20 00:19:47 -05:00
woctordho
0aa7fa464e
Implement sliding attention in Gemma3 (#11409) 2025-12-20 00:16:46 -05:00
drozbay
514c24d756
Fix error from logging line (#11423)
Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com>
2025-12-19 20:22:45 -08:00
comfyanonymous
809ce68749
Support nested tensor denoise masks. (#11431) 2025-12-19 19:59:25 -05:00
BradPepersAMD
cc4ddba1b6
Allow enabling use of MIOpen by setting COMFYUI_ENABLE_MIOPEN=1 as an env var (#11366) 2025-12-19 17:01:50 -05:00
Dr.Lt.Data
8376ff6831
bump comfyui_manager version to the 4.0.3b7 (#11422) 2025-12-19 10:41:56 -08:00
Alexander Piskun
5b4d0664c8
add Flux2MaxImage API Node (#11420) 2025-12-19 10:02:49 -08:00
comfyanonymous
894802b0f9
Add LatentCutToBatch node. (#11411) 2025-12-18 22:21:40 -05:00
comfyanonymous
28eaab608b
Diffusion model part of Qwen Image Layered. (#11408)
Only thing missing after this is some nodes to make using it easier.
2025-12-18 20:21:14 -05:00
comfyanonymous
6a2678ac65
Trim/pad channels in VAE code. (#11406) 2025-12-18 18:22:38 -05:00
comfyanonymous
e4fb3a3572
Support loading Wan/Qwen VAEs with different in/out channels. (#11405) 2025-12-18 17:45:33 -05:00
ComfyUI Wiki
e8ebbe668e
chore: update workflow templates to v0.7.60 (#11403) 2025-12-18 17:09:29 -05:00
ric-yu
1ca89b810e
Add unified jobs API with /api/jobs endpoints (#11054)
* feat: create a /jobs api to return queue and history jobs

* update unused vars

* include priority

* create jobs helper file

* fix ruff

* update how we set error message

* include execution error in both responses

* rename error -> failed, fix output shape

* re-use queue and history functions

* set workflow id

* allow srot by exec duration

* fix tests

* send priority and remove error msg

* use ws messages to get start and end times

* revert main.py fully

* refactor: move all /jobs business logic to jobs.py

* fix failing test

* remove some tests

* fix non dict nodes

* address comments

* filter by workflow id and remove null fields

* add clearer typing - remove get("..") or ..

* refactor query params to top get_job(s) doc, add remove_sensitive_from_queue

* add brief comment explaining why we skip animated

* comment that format field is for frontend backward compatibility

* fix whitespace

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
Co-authored-by: guill <jacob.e.segal@gmail.com>
2025-12-17 21:44:31 -08:00
comfyanonymous
bf7dc63bd6
skip_load_model -> force_full_load (#11390)
This should be a bit more clear and less prone to potential breakage if the
logic of the load models changes a bit.
2025-12-17 23:29:32 -05:00
Kohaku-Blueleaf
86dbb89fc9
Resolution bucketing and Trainer implementation refactoring (#11117) 2025-12-17 22:15:27 -05:00
40 changed files with 2276 additions and 381 deletions

View File

@ -119,6 +119,9 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly chang
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
- Releases a new stable version (e.g., v0.7.0) roughly every week.
- Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release.
- Minor versions will be used for releases off the master branch.
- Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
- Serves as the foundation for the desktop release

View File

@ -143,7 +143,7 @@ class IndexListContextHandler(ContextHandlerABC):
# if multiple conds, split based on primary region
if self.split_conds_to_windows and len(cond_in) > 1:
region = window.get_region_index(len(cond_in))
logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
logging.info(f"Splitting conds to windows; using region {region} for window {window.index_list[0]}-{window.index_list[-1]} with center ratio {window.center_ratio:.3f}")
cond_in = [cond_in[region]]
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
for actual_cond in cond_in:

View File

@ -625,7 +625,7 @@ class NextDiT(nn.Module):
if pooled is not None:
pooled = self.clip_text_pooled_proj(pooled)
else:
pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))

View File

@ -61,7 +61,7 @@ def apply_rotary_emb(x, freqs_cis):
class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
self.timestep_embedder = TimestepEmbedding(
@ -72,9 +72,19 @@ class QwenTimestepProjEmbeddings(nn.Module):
operations=operations
)
def forward(self, timestep, hidden_states):
self.use_additional_t_cond = use_additional_t_cond
if self.use_additional_t_cond:
self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype)
def forward(self, timestep, hidden_states, addition_t_cond=None):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
if self.use_additional_t_cond:
if addition_t_cond is None:
addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long)
timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype)
return timesteps_emb
@ -320,11 +330,11 @@ class QwenImageTransformer2DModel(nn.Module):
num_attention_heads: int = 24,
joint_attention_dim: int = 3584,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
default_ref_method="index",
image_model=None,
final_layer=True,
use_additional_t_cond=False,
dtype=None,
device=None,
operations=None,
@ -342,6 +352,7 @@ class QwenImageTransformer2DModel(nn.Module):
self.time_text_embed = QwenTimestepProjEmbeddings(
embedding_dim=self.inner_dim,
pooled_projection_dim=pooled_projection_dim,
use_additional_t_cond=use_additional_t_cond,
dtype=dtype,
device=device,
operations=operations
@ -375,27 +386,33 @@ class QwenImageTransformer2DModel(nn.Module):
patch_size = self.patch_size
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
t_len = t
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device)
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
if t_len > 1:
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1)
else:
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2)
return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape
def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs)
def _forward(
self,
@ -403,8 +420,8 @@ class QwenImageTransformer2DModel(nn.Module):
timesteps,
context,
attention_mask=None,
guidance: torch.Tensor = None,
ref_latents=None,
additional_t_cond=None,
transformer_options={},
control=None,
**kwargs
@ -423,12 +440,17 @@ class QwenImageTransformer2DModel(nn.Module):
index = 0
ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
negative_ref_method = ref_method == "negative_index"
timestep_zero = ref_method == "index_timestep_zero"
for ref in ref_latents:
if index_ref_method:
index += 1
h_offset = 0
w_offset = 0
elif negative_ref_method:
index -= 1
h_offset = 0
w_offset = 0
else:
index = 1
h_offset = 0
@ -458,14 +480,7 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
if guidance is not None:
guidance = guidance * 1000
temb = (
self.time_text_embed(timestep, hidden_states)
if guidance is None
else self.time_text_embed(timestep, guidance, hidden_states)
)
temb = self.time_text_embed(timestep, hidden_states, additional_t_cond)
patches_replace = transformer_options.get("patches_replace", {})
patches = transformer_options.get("patches", {})
@ -513,6 +528,6 @@ class QwenImageTransformer2DModel(nn.Module):
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]

View File

@ -227,6 +227,7 @@ class Encoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
input_channels=3,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
@ -245,7 +246,7 @@ class Encoder3d(nn.Module):
scale = 1.0
# init block
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
self.conv1 = CausalConv3d(input_channels, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
@ -331,6 +332,7 @@ class Decoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
output_channels=3,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
@ -378,7 +380,7 @@ class Decoder3d(nn.Module):
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1))
CausalConv3d(out_dim, output_channels, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
@ -449,6 +451,7 @@ class WanVAE(nn.Module):
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
image_channels=3,
dropout=0.0):
super().__init__()
self.dim = dim
@ -460,11 +463,11 @@ class WanVAE(nn.Module):
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
self.encoder = Encoder3d(dim, z_dim * 2, image_channels, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def encode(self, x):

View File

@ -1110,7 +1110,7 @@ class Lumina2(BaseModel):
if 'num_tokens' not in out:
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
clip_text_pooled = kwargs["pooled_output"] # Newbie
clip_text_pooled = kwargs.get("pooled_output", None) # NewBie
if clip_text_pooled is not None:
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)

View File

@ -430,8 +430,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["rope_theta"] = 10000.0
dit_config["ffn_dim_multiplier"] = 4.0
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
if ctd_weight is not None:
if ctd_weight is not None: # NewBie
dit_config["clip_text_dim"] = ctd_weight.shape[0]
# NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI
elif dit_config["dim"] == 3840: # Z image
dit_config["n_heads"] = 30
dit_config["n_kv_heads"] = 30
@ -620,6 +621,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
dit_config["default_ref_method"] = "index_timestep_zero"
if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered
dit_config["use_additional_t_cond"] = True
dit_config["default_ref_method"] = "negative_index"
return dit_config
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5

View File

@ -26,6 +26,7 @@ import importlib
import platform
import weakref
import gc
import os
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@ -333,11 +334,13 @@ except:
SUPPORT_FP8_OPS = args.supports_fp8_compute
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
try:
if is_amd():
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")

View File

@ -122,20 +122,20 @@ def estimate_memory(model, noise_shape, conds):
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
return memory_required, minimum_memory_required
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_prepare_sampling,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
)
return executor.execute(model, noise_shape, conds, model_options=model_options)
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
real_model = model.model
return real_model, conds, models

View File

@ -984,9 +984,6 @@ class CFGGuider:
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
if denoise_mask is not None:
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
noise = noise.to(device)
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
@ -1013,6 +1010,24 @@ class CFGGuider:
else:
latent_shapes = [latent_image.shape]
if denoise_mask is not None:
if denoise_mask.is_nested:
denoise_masks = denoise_mask.unbind()
denoise_masks = denoise_masks[:len(latent_shapes)]
else:
denoise_masks = [denoise_mask]
for i in range(len(denoise_masks), len(latent_shapes)):
denoise_masks.append(torch.ones(latent_shapes[i]))
for i in range(len(denoise_masks)):
denoise_masks[i] = comfy.sampler_helpers.prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device)
if len(denoise_masks) > 1:
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
else:
denoise_mask = denoise_masks[0]
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))

View File

@ -55,6 +55,8 @@ import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
import comfy.text_encoders.ovis
import comfy.text_encoders.kandinsky5
import comfy.text_encoders.jina_clip_2
import comfy.text_encoders.newbie
import comfy.model_patcher
import comfy.lora
@ -321,6 +323,7 @@ class VAE:
self.latent_channels = 4
self.latent_dim = 2
self.output_channels = 3
self.pad_channel_value = None
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
@ -435,6 +438,7 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
self.latent_channels = 64
self.output_channels = 2
self.pad_channel_value = "replicate"
self.upscale_ratio = 2048
self.downscale_ratio = 2048
self.latent_dim = 1
@ -546,7 +550,9 @@ class VAE:
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.output_channels = sd["encoder.conv1.weight"].shape[1]
self.pad_channel_value = 1.0
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
@ -582,6 +588,7 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
self.latent_channels = 8
self.output_channels = 2
self.pad_channel_value = "replicate"
self.upscale_ratio = 4096
self.downscale_ratio = 4096
self.latent_dim = 2
@ -690,9 +697,7 @@ class VAE:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
def vae_encode_crop_pixels(self, pixels):
if not self.crop_input:
return pixels
if self.crop_input:
downscale_ratio = self.spacial_compression_encode()
dims = pixels.shape[1:-1]
@ -701,6 +706,19 @@ class VAE:
x_offset = (dims[d] % downscale_ratio) // 2
if x != dims[d]:
pixels = pixels.narrow(d + 1, x_offset, x)
if pixels.shape[-1] > self.output_channels:
pixels = pixels[..., :self.output_channels]
elif pixels.shape[-1] < self.output_channels:
if self.pad_channel_value is not None:
if isinstance(self.pad_channel_value, str):
mode = self.pad_channel_value
value = None
else:
mode = "constant"
value = self.pad_channel_value
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
return pixels
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
@ -992,6 +1010,7 @@ class CLIPType(Enum):
OVIS = 21
KANDINSKY5 = 22
KANDINSKY5_IMAGE = 23
NEWBIE = 24
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@ -1022,6 +1041,7 @@ class TEModel(Enum):
MISTRAL3_24B_PRUNED_FLUX2 = 15
QWEN3_4B = 16
QWEN3_2B = 17
JINA_CLIP_2 = 18
def detect_te_model(sd):
@ -1031,6 +1051,8 @@ def detect_te_model(sd):
return TEModel.CLIP_H
if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
return TEModel.CLIP_L
if "model.encoder.layers.0.mixer.Wqkv.weight" in sd:
return TEModel.JINA_CLIP_2
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
if weight.shape[-1] == 4096:
@ -1191,6 +1213,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif te_model == TEModel.QWEN3_2B:
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
elif te_model == TEModel.JINA_CLIP_2:
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
else:
# clip_l
if clip_type == CLIPType.SD3:
@ -1246,6 +1271,17 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
elif clip_type == CLIPType.NEWBIE:
clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer
if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]:
clip_data_gemma = clip_data[0]
clip_data_jina = clip_data[1]
else:
clip_data_gemma = clip_data[1]
clip_data_jina = clip_data[0]
tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer

View File

@ -466,7 +466,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
return embed_out
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
@ -513,6 +513,8 @@ class SDTokenizer:
self.embedding_size = embedding_size
self.embedding_key = embedding_key
self.disable_weights = disable_weights
def _try_get_embedding(self, embedding_name:str):
'''
Takes a potential embedding name and tries to retrieve it.
@ -547,7 +549,7 @@ class SDTokenizer:
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
text = escape_important(text)
if kwargs.get("disable_weights", False):
if kwargs.get("disable_weights", self.disable_weights):
parsed_weights = [(text, 1.0)]
else:
parsed_weights = token_weights(text, 1.0)

View File

@ -0,0 +1,219 @@
# Jina CLIP v2 and Jina Embeddings v3 both use their modified XLM-RoBERTa architecture. Reference implementation:
# Jina CLIP v2 (both text and vision): https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/modeling_clip.py
# Jina XLM-RoBERTa (text only): http://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/2b6bc3f30750b3a9648fe9b63448c09920efe9be/modeling_xlm_roberta.py
from dataclasses import dataclass
import torch
from torch import nn as nn
from torch.nn import functional as F
import comfy.model_management
import comfy.ops
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
# The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192
super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2")
# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json
@dataclass
class XLMRobertaConfig:
vocab_size: int = 250002
type_vocab_size: int = 1
hidden_size: int = 1024
num_hidden_layers: int = 24
num_attention_heads: int = 16
rotary_emb_base: float = 20000.0
intermediate_size: int = 4096
hidden_act: str = "gelu"
hidden_dropout_prob: float = 0.1
attention_probs_dropout_prob: float = 0.1
layer_norm_eps: float = 1e-05
bos_token_id: int = 0
eos_token_id: int = 2
pad_token_id: int = 1
class XLMRobertaEmbeddings(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
embed_dim = config.hidden_size
self.word_embeddings = ops.Embedding(config.vocab_size, embed_dim, padding_idx=config.pad_token_id, device=device, dtype=dtype)
self.token_type_embeddings = ops.Embedding(config.type_vocab_size, embed_dim, device=device, dtype=dtype)
def forward(self, input_ids=None, embeddings=None):
if input_ids is not None and embeddings is None:
embeddings = self.word_embeddings(input_ids)
if embeddings is not None:
token_type_ids = torch.zeros(embeddings.shape[1], device=embeddings.device, dtype=torch.int32)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = embeddings + token_type_embeddings
return embeddings
class RotaryEmbedding(nn.Module):
def __init__(self, dim, base, device=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
if seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=torch.float32)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
emb = torch.cat((freqs, freqs), dim=-1)
self._cos_cached = emb.cos().to(dtype)
self._sin_cached = emb.sin().to(dtype)
def forward(self, q, k):
batch, seqlen, heads, head_dim = q.shape
self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype)
cos = self._cos_cached[:seqlen].view(1, seqlen, 1, head_dim)
sin = self._sin_cached[:seqlen].view(1, seqlen, 1, head_dim)
def rotate_half(x):
size = x.shape[-1] // 2
x1, x2 = x[..., :size], x[..., size:]
return torch.cat((-x2, x1), dim=-1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class MHA(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = embed_dim // config.num_attention_heads
self.rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_emb_base, device=device)
self.Wqkv = ops.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)
self.out_proj = ops.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
def forward(self, x, mask=None, optimized_attention=None):
qkv = self.Wqkv(x)
batch_size, seq_len, _ = qkv.shape
qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
q, k, v = qkv.unbind(2)
q, k = self.rotary_emb(q, k)
# NHD -> HND
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True)
return self.out_proj(out)
class MLP(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.fc1 = ops.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype)
self.activation = F.gelu
self.fc2 = ops.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype)
def forward(self, x):
x = self.fc1(x)
x = self.activation(x)
x = self.fc2(x)
return x
class Block(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.mixer = MHA(config, device=device, dtype=dtype, ops=ops)
self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
self.norm1 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
self.dropout2 = nn.Dropout(config.hidden_dropout_prob)
self.norm2 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
def forward(self, hidden_states, mask=None, optimized_attention=None):
mixer_out = self.mixer(hidden_states, mask=mask, optimized_attention=optimized_attention)
hidden_states = self.norm1(self.dropout1(mixer_out) + hidden_states)
mlp_out = self.mlp(hidden_states)
hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states)
return hidden_states
class XLMRobertaEncoder(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask=None):
optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True)
for layer in self.layers:
hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention)
return hidden_states
class XLMRobertaModel_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.embeddings = XLMRobertaEmbeddings(config, device=device, dtype=dtype, ops=ops)
self.emb_ln = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops)
def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
x = self.embeddings(input_ids=input_ids, embeddings=embeds)
x = self.emb_ln(x)
x = self.emb_drop(x)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, 1, attention_mask.shape[-1]))
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
sequence_output = self.encoder(x, attention_mask=mask)
# Mean pool, see https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/hf_model.py
pooled_output = None
if attention_mask is None:
pooled_output = sequence_output.mean(dim=1)
else:
attention_mask = attention_mask.to(sequence_output.dtype)
pooled_output = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True)
# Intermediate output is not yet implemented, use None for placeholder
return sequence_output, None, pooled_output
class XLMRobertaModel(nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.config = XLMRobertaConfig(**config_dict)
self.model = XLMRobertaModel_(self.config, device=device, dtype=dtype, ops=operations)
self.num_layers = self.config.num_hidden_layers
def get_input_embeddings(self):
return self.model.embeddings.word_embeddings
def set_input_embeddings(self, embeddings):
self.model.embeddings.word_embeddings = embeddings
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)
class JinaClip2TextModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options)

View File

@ -3,7 +3,6 @@ import torch.nn as nn
from dataclasses import dataclass
from typing import Optional, Any
import math
import logging
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
@ -177,7 +176,7 @@ class Gemma3_4B_Config:
num_key_value_heads: int = 4
max_position_embeddings: int = 131072
rms_norm_eps: float = 1e-6
rope_theta = [10000.0, 1000000.0]
rope_theta = [1000000.0, 10000.0]
transformer_type: str = "gemma3"
head_dim = 256
rms_norm_add = True
@ -186,8 +185,8 @@ class Gemma3_4B_Config:
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
sliding_attention = [False, False, False, False, False, 1024]
rope_scale = [1.0, 8.0]
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
rope_scale = [8.0, 1.0]
final_norm: bool = True
class RMSNorm(nn.Module):
@ -370,7 +369,7 @@ class TransformerBlockGemma2(nn.Module):
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
if config.sliding_attention is not None:
self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
else:
self.sliding_attention = False
@ -387,7 +386,12 @@ class TransformerBlockGemma2(nn.Module):
if self.transformer_type == 'gemma3':
if self.sliding_attention:
if x.shape[1] > self.sliding_attention:
logging.warning("Warning: sliding attention not implemented, results may be incorrect")
sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
sliding_mask.tril_(diagonal=-self.sliding_attention)
if attention_mask is not None:
attention_mask = attention_mask + sliding_mask
else:
attention_mask = sliding_mask
freqs_cis = freqs_cis[1]
else:
freqs_cis = freqs_cis[0]

View File

@ -14,7 +14,7 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer):
class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
@ -33,6 +33,11 @@ class Gemma2_2BModel(sd1_clip.SDClipModel):
class Gemma3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class LuminaModel(sd1_clip.SD1ClipModel):

View File

@ -0,0 +1,62 @@
import torch
import comfy.model_management
import comfy.text_encoders.jina_clip_2
import comfy.text_encoders.lumina2
class NewBieTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]})
self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]})
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs)
out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
def untokenize(self, token_weight_pair):
raise NotImplementedError
def state_dict(self):
return {}
class NewBieTEModel(torch.nn.Module):
def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}):
super().__init__()
dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device)
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options)
self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options)
self.dtypes = {dtype, dtype_gemma}
def set_clip_options(self, options):
self.gemma.set_clip_options(options)
self.jina.set_clip_options(options)
def reset_clip_options(self):
self.gemma.reset_clip_options()
self.jina.reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_gemma = token_weight_pairs["gemma"]
token_weight_pairs_jina = token_weight_pairs["jina"]
gemma_out, gemma_pooled, gemma_extra = self.gemma.encode_token_weights(token_weight_pairs_gemma)
jina_out, jina_pooled, jina_extra = self.jina.encode_token_weights(token_weight_pairs_jina)
return gemma_out, jina_pooled, gemma_extra
def load_sd(self, sd):
if "model.layers.0.self_attn.q_norm.weight" in sd:
return self.gemma.load_sd(sd)
else:
return self.jina.load_sd(sd)
def te(dtype_llama=None, llama_quantization_metadata=None):
class NewBieTEModel_(NewBieTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return NewBieTEModel_

View File

@ -10,7 +10,7 @@ class Text2ImageTaskCreationRequest(BaseModel):
size: str | None = Field(None)
seed: int | None = Field(0, ge=0, le=2147483647)
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
watermark: bool | None = Field(True)
watermark: bool | None = Field(False)
class Image2ImageTaskCreationRequest(BaseModel):
@ -21,7 +21,7 @@ class Image2ImageTaskCreationRequest(BaseModel):
size: str | None = Field("adaptive")
seed: int | None = Field(..., ge=0, le=2147483647)
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
watermark: bool | None = Field(True)
watermark: bool | None = Field(False)
class Seedream4Options(BaseModel):
@ -37,7 +37,7 @@ class Seedream4TaskCreationRequest(BaseModel):
seed: int = Field(..., ge=0, le=2147483647)
sequential_image_generation: str = Field("disabled")
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
watermark: bool = Field(True)
watermark: bool = Field(False)
class ImageTaskCreationResponse(BaseModel):

View File

@ -133,6 +133,7 @@ class GeminiImageGenerateContentRequest(BaseModel):
systemInstruction: GeminiSystemInstructionContent | None = Field(None)
tools: list[GeminiTool] | None = Field(None)
videoMetadata: GeminiVideoMetadata | None = Field(None)
uploadImagesToStorage: bool = Field(True)
class GeminiGenerateContentRequest(BaseModel):

View File

@ -1,10 +1,8 @@
from inspect import cleandoc
import torch
from pydantic import BaseModel
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bfl_api import (
BFLFluxExpandImageRequest,
BFLFluxFillImageRequest,
@ -28,7 +26,7 @@ from comfy_api_nodes.util import (
)
def convert_mask_to_image(mask: torch.Tensor):
def convert_mask_to_image(mask: Input.Image):
"""
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
"""
@ -38,9 +36,6 @@ def convert_mask_to_image(mask: torch.Tensor):
class FluxProUltraImageNode(IO.ComfyNode):
"""
Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
@ -48,7 +43,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
node_id="FluxProUltraImageNode",
display_name="Flux 1.1 [pro] Ultra Image",
category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""),
description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.",
inputs=[
IO.String.Input(
"prompt",
@ -117,7 +112,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
prompt_upsampling: bool = False,
raw: bool = False,
seed: int = 0,
image_prompt: torch.Tensor | None = None,
image_prompt: Input.Image | None = None,
image_prompt_strength: float = 0.1,
) -> IO.NodeOutput:
if image_prompt is None:
@ -155,9 +150,6 @@ class FluxProUltraImageNode(IO.ComfyNode):
class FluxKontextProImageNode(IO.ComfyNode):
"""
Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
@ -165,7 +157,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
node_id=cls.NODE_ID,
display_name=cls.DISPLAY_NAME,
category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""),
description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.",
inputs=[
IO.String.Input(
"prompt",
@ -231,7 +223,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
aspect_ratio: str,
guidance: float,
steps: int,
input_image: torch.Tensor | None = None,
input_image: Input.Image | None = None,
seed=0,
prompt_upsampling=False,
) -> IO.NodeOutput:
@ -271,20 +263,14 @@ class FluxKontextProImageNode(IO.ComfyNode):
class FluxKontextMaxImageNode(FluxKontextProImageNode):
"""
Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio.
"""
DESCRIPTION = cleandoc(__doc__ or "")
DESCRIPTION = "Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio."
BFL_PATH = "/proxy/bfl/flux-kontext-max/generate"
NODE_ID = "FluxKontextMaxImageNode"
DISPLAY_NAME = "Flux.1 Kontext [max] Image"
class FluxProExpandNode(IO.ComfyNode):
"""
Outpaints image based on prompt.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
@ -292,7 +278,7 @@ class FluxProExpandNode(IO.ComfyNode):
node_id="FluxProExpandNode",
display_name="Flux.1 Expand Image",
category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""),
description="Outpaints image based on prompt.",
inputs=[
IO.Image.Input("image"),
IO.String.Input(
@ -371,7 +357,7 @@ class FluxProExpandNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
image: torch.Tensor,
image: Input.Image,
prompt: str,
prompt_upsampling: bool,
top: int,
@ -418,9 +404,6 @@ class FluxProExpandNode(IO.ComfyNode):
class FluxProFillNode(IO.ComfyNode):
"""
Inpaints image based on mask and prompt.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
@ -428,7 +411,7 @@ class FluxProFillNode(IO.ComfyNode):
node_id="FluxProFillNode",
display_name="Flux.1 Fill Image",
category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""),
description="Inpaints image based on mask and prompt.",
inputs=[
IO.Image.Input("image"),
IO.Mask.Input("mask"),
@ -480,8 +463,8 @@ class FluxProFillNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
image: torch.Tensor,
mask: torch.Tensor,
image: Input.Image,
mask: Input.Image,
prompt: str,
prompt_upsampling: bool,
steps: int,
@ -525,11 +508,15 @@ class FluxProFillNode(IO.ComfyNode):
class Flux2ProImageNode(IO.ComfyNode):
NODE_ID = "Flux2ProImageNode"
DISPLAY_NAME = "Flux.2 [pro] Image"
API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate"
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Flux2ProImageNode",
display_name="Flux.2 [pro] Image",
node_id=cls.NODE_ID,
display_name=cls.DISPLAY_NAME,
category="api node/image/BFL",
description="Generates images synchronously based on prompt and resolution.",
inputs=[
@ -563,12 +550,11 @@ class Flux2ProImageNode(IO.ComfyNode):
),
IO.Boolean.Input(
"prompt_upsampling",
default=False,
default=True,
tooltip="Whether to perform upsampling on the prompt. "
"If active, automatically modifies the prompt for more creative generation, "
"but results are nondeterministic (same seed will not produce exactly the same result).",
"If active, automatically modifies the prompt for more creative generation.",
),
IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."),
IO.Image.Input("images", optional=True, tooltip="Up to 9 images to be used as references."),
],
outputs=[IO.Image.Output()],
hidden=[
@ -587,7 +573,7 @@ class Flux2ProImageNode(IO.ComfyNode):
height: int,
seed: int,
prompt_upsampling: bool,
images: torch.Tensor | None = None,
images: Input.Image | None = None,
) -> IO.NodeOutput:
reference_images = {}
if images is not None:
@ -598,7 +584,7 @@ class Flux2ProImageNode(IO.ComfyNode):
reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048)
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"),
ApiEndpoint(path=cls.API_ENDPOINT, method="POST"),
response_model=BFLFluxProGenerateResponse,
data=Flux2ProGenerateRequest(
prompt=prompt,
@ -632,6 +618,13 @@ class Flux2ProImageNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class Flux2MaxImageNode(Flux2ProImageNode):
NODE_ID = "Flux2MaxImageNode"
DISPLAY_NAME = "Flux.2 [max] Image"
API_ENDPOINT = "/proxy/bfl/flux-2-max/generate"
class BFLExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -642,6 +635,7 @@ class BFLExtension(ComfyExtension):
FluxProExpandNode,
FluxProFillNode,
Flux2ProImageNode,
Flux2MaxImageNode,
]

View File

@ -112,7 +112,7 @@ class ByteDanceImageNode(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip='Whether to add an "AI generated" watermark to the image',
optional=True,
),
@ -215,7 +215,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip='Whether to add an "AI generated" watermark to the image',
optional=True,
),
@ -346,7 +346,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip='Whether to add an "AI generated" watermark to the image.',
optional=True,
),
@ -380,7 +380,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
sequential_image_generation: str = "disabled",
max_images: int = 1,
seed: int = 0,
watermark: bool = True,
watermark: bool = False,
fail_on_partial: bool = True,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
@ -507,7 +507,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True,
),
@ -617,7 +617,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True,
),
@ -739,7 +739,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True,
),
@ -862,7 +862,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True,
),

View File

@ -34,6 +34,7 @@ from comfy_api_nodes.util import (
ApiEndpoint,
audio_to_base64_string,
bytesio_to_image_tensor,
download_url_to_image_tensor,
get_number_of_images,
sync_op,
tensor_to_base64_string,
@ -141,9 +142,11 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
)
parts = []
for part in response.candidates[0].content.parts:
if part_type == "text" and hasattr(part, "text") and part.text:
if part_type == "text" and part.text:
parts.append(part)
elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type:
elif part.inlineData and part.inlineData.mimeType == part_type:
parts.append(part)
elif part.fileData and part.fileData.mimeType == part_type:
parts.append(part)
# Skip parts that don't match the requested type
return parts
@ -163,12 +166,15 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts])
def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/png")
for part in parts:
if part.inlineData:
image_data = base64.b64decode(part.inlineData.data)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
else:
returned_image = await download_url_to_image_tensor(part.fileData.fileUri)
image_tensors.append(returned_image)
if len(image_tensors) == 0:
return torch.zeros((1, 1024, 1024, 4))
@ -596,7 +602,7 @@ class GeminiImage(IO.ComfyNode):
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"),
data=GeminiImageGenerateContentRequest(
contents=[
GeminiContent(role=GeminiRole.user, parts=parts),
@ -610,7 +616,7 @@ class GeminiImage(IO.ComfyNode):
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
class GeminiImage2(IO.ComfyNode):
@ -729,7 +735,7 @@ class GeminiImage2(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"),
data=GeminiImageGenerateContentRequest(
contents=[
GeminiContent(role=GeminiRole.user, parts=parts),
@ -743,7 +749,7 @@ class GeminiImage2(IO.ComfyNode):
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
class GeminiExtension(ComfyExtension):

View File

@ -858,7 +858,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
tooltip="A text prompt describing the video content. "
"This can include both positive and negative descriptions.",
),
IO.Combo.Input("duration", options=["5", "10"]),
IO.Int.Input("duration", default=5, min=3, max=10, display_mode=IO.NumberDisplay.slider),
IO.Image.Input("first_frame"),
IO.Image.Input(
"end_frame",
@ -897,6 +897,10 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
validate_string(prompt, min_length=1, max_length=2500)
if end_frame is not None and reference_images is not None:
raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
if duration not in (5, 10) and end_frame is None and reference_images is None:
raise ValueError(
"Duration is only supported for 5 or 10 seconds if there is no end frame or reference images."
)
validate_image_dimensions(first_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
image_list: list[OmniParamImage] = [

View File

@ -23,10 +23,6 @@ UPSCALER_MODELS_MAP = {
"Starlight (Astra) Fast": "slf-1",
"Starlight (Astra) Creative": "slc-1",
}
UPSCALER_VALUES_MAP = {
"FullHD (1080p)": 1920,
"4K (2160p)": 3840,
}
class TopazImageEnhance(IO.ComfyNode):
@ -214,7 +210,7 @@ class TopazVideoEnhance(IO.ComfyNode):
IO.Video.Input("video"),
IO.Boolean.Input("upscaler_enabled", default=True),
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
IO.Combo.Input("upscaler_resolution", options=list(UPSCALER_VALUES_MAP.keys())),
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Combo.Input(
"upscaler_creativity",
options=["low", "middle", "high"],
@ -306,8 +302,33 @@ class TopazVideoEnhance(IO.ComfyNode):
target_frame_rate = src_frame_rate
filters = []
if upscaler_enabled:
target_width = UPSCALER_VALUES_MAP[upscaler_resolution]
target_height = UPSCALER_VALUES_MAP[upscaler_resolution]
if "1080p" in upscaler_resolution:
target_pixel_p = 1080
max_long_side = 1920
else:
target_pixel_p = 2160
max_long_side = 3840
ar = src_width / src_height
if src_width >= src_height:
# Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width
target_height = target_pixel_p
target_width = int(target_height * ar)
# Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs)
if target_width > max_long_side:
target_width = max_long_side
target_height = int(target_width / ar)
else:
# Portrait; Attempt to set width to target (e.g., 2160), calculate height
target_width = target_pixel_p
target_height = int(target_width / ar)
# Check if height exceeds standard bounds
if target_height > max_long_side:
target_height = max_long_side
target_width = int(target_height * ar)
if target_width % 2 != 0:
target_width += 1
if target_height % 2 != 0:
target_height += 1
filters.append(
topaz_api.VideoEnhancementFilter(
model=UPSCALER_MODELS_MAP[upscaler_model],

View File

@ -46,14 +46,14 @@ class Txt2ImageParametersField(BaseModel):
n: int = Field(1, description="Number of images to generate.") # we support only value=1
seed: int = Field(..., ge=0, le=2147483647)
prompt_extend: bool = Field(True)
watermark: bool = Field(True)
watermark: bool = Field(False)
class Image2ImageParametersField(BaseModel):
size: str | None = Field(None)
n: int = Field(1, description="Number of images to generate.") # we support only value=1
seed: int = Field(..., ge=0, le=2147483647)
watermark: bool = Field(True)
watermark: bool = Field(False)
class Text2VideoParametersField(BaseModel):
@ -61,7 +61,7 @@ class Text2VideoParametersField(BaseModel):
seed: int = Field(..., ge=0, le=2147483647)
duration: int = Field(5, ge=5, le=15)
prompt_extend: bool = Field(True)
watermark: bool = Field(True)
watermark: bool = Field(False)
audio: bool = Field(False, description="Whether to generate audio automatically.")
shot_type: str = Field("single")
@ -71,7 +71,7 @@ class Image2VideoParametersField(BaseModel):
seed: int = Field(..., ge=0, le=2147483647)
duration: int = Field(5, ge=5, le=15)
prompt_extend: bool = Field(True)
watermark: bool = Field(True)
watermark: bool = Field(False)
audio: bool = Field(False, description="Whether to generate audio automatically.")
shot_type: str = Field("single")
@ -208,7 +208,7 @@ class WanTextToImageApi(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip="Whether to add an AI-generated watermark to the result.",
optional=True,
),
@ -234,7 +234,7 @@ class WanTextToImageApi(IO.ComfyNode):
height: int = 1024,
seed: int = 0,
prompt_extend: bool = True,
watermark: bool = True,
watermark: bool = False,
):
initial_response = await sync_op(
cls,
@ -327,7 +327,7 @@ class WanImageToImageApi(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip="Whether to add an AI-generated watermark to the result.",
optional=True,
),
@ -353,7 +353,7 @@ class WanImageToImageApi(IO.ComfyNode):
# width: int = 1024,
# height: int = 1024,
seed: int = 0,
watermark: bool = True,
watermark: bool = False,
):
n_images = get_number_of_images(image)
if n_images not in (1, 2):
@ -476,7 +476,7 @@ class WanTextToVideoApi(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip="Whether to add an AI-generated watermark to the result.",
optional=True,
),
@ -512,7 +512,7 @@ class WanTextToVideoApi(IO.ComfyNode):
seed: int = 0,
generate_audio: bool = False,
prompt_extend: bool = True,
watermark: bool = True,
watermark: bool = False,
shot_type: str = "single",
):
if "480p" in size and model == "wan2.6-t2v":
@ -637,7 +637,7 @@ class WanImageToVideoApi(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip="Whether to add an AI-generated watermark to the result.",
optional=True,
),
@ -674,7 +674,7 @@ class WanImageToVideoApi(IO.ComfyNode):
seed: int = 0,
generate_audio: bool = False,
prompt_extend: bool = True,
watermark: bool = True,
watermark: bool = False,
shot_type: str = "single",
):
if get_number_of_images(image) != 1:

291
comfy_execution/jobs.py Normal file
View File

@ -0,0 +1,291 @@
"""
Job utilities for the /api/jobs endpoint.
Provides normalization and helper functions for job status tracking.
"""
from typing import Optional
from comfy_api.internal import prune_dict
class JobStatus:
"""Job status constants."""
PENDING = 'pending'
IN_PROGRESS = 'in_progress'
COMPLETED = 'completed'
FAILED = 'failed'
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED]
# Media types that can be previewed in the frontend
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'})
# 3D file extensions for preview fallback (no dedicated media_type exists)
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'})
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
"""Extract create_time and workflow_id from extra_data.
Returns:
tuple: (create_time, workflow_id)
"""
create_time = extra_data.get('create_time')
extra_pnginfo = extra_data.get('extra_pnginfo', {})
workflow_id = extra_pnginfo.get('workflow', {}).get('id')
return create_time, workflow_id
def is_previewable(media_type: str, item: dict) -> bool:
"""
Check if an output item is previewable.
Matches frontend logic in ComfyUI_frontend/src/stores/queueStore.ts
Maintains backwards compatibility with existing logic.
Priority:
1. media_type is 'images', 'video', or 'audio'
2. format field starts with 'video/' or 'audio/'
3. filename has a 3D extension (.obj, .fbx, .gltf, .glb)
"""
if media_type in PREVIEWABLE_MEDIA_TYPES:
return True
# Check format field (MIME type).
# Maintains backwards compatibility with how custom node outputs are handled in the frontend.
fmt = item.get('format', '')
if fmt and (fmt.startswith('video/') or fmt.startswith('audio/')):
return True
# Check for 3D files by extension
filename = item.get('filename', '').lower()
if any(filename.endswith(ext) for ext in THREE_D_EXTENSIONS):
return True
return False
def normalize_queue_item(item: tuple, status: str) -> dict:
"""Convert queue item tuple to unified job dict.
Expects item with sensitive data already removed (5 elements).
"""
priority, prompt_id, _, extra_data, _ = item
create_time, workflow_id = _extract_job_metadata(extra_data)
return prune_dict({
'id': prompt_id,
'status': status,
'priority': priority,
'create_time': create_time,
'outputs_count': 0,
'workflow_id': workflow_id,
})
def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: bool = False) -> dict:
"""Convert history item dict to unified job dict.
History items have sensitive data already removed (prompt tuple has 5 elements).
"""
prompt_tuple = history_item['prompt']
priority, _, prompt, extra_data, _ = prompt_tuple
create_time, workflow_id = _extract_job_metadata(extra_data)
status_info = history_item.get('status', {})
status_str = status_info.get('status_str') if status_info else None
if status_str == 'success':
status = JobStatus.COMPLETED
elif status_str == 'error':
status = JobStatus.FAILED
else:
status = JobStatus.COMPLETED
outputs = history_item.get('outputs', {})
outputs_count, preview_output = get_outputs_summary(outputs)
execution_error = None
execution_start_time = None
execution_end_time = None
if status_info:
messages = status_info.get('messages', [])
for entry in messages:
if isinstance(entry, (list, tuple)) and len(entry) >= 2:
event_name, event_data = entry[0], entry[1]
if isinstance(event_data, dict):
if event_name == 'execution_start':
execution_start_time = event_data.get('timestamp')
elif event_name in ('execution_success', 'execution_error', 'execution_interrupted'):
execution_end_time = event_data.get('timestamp')
if event_name == 'execution_error':
execution_error = event_data
job = prune_dict({
'id': prompt_id,
'status': status,
'priority': priority,
'create_time': create_time,
'execution_start_time': execution_start_time,
'execution_end_time': execution_end_time,
'execution_error': execution_error,
'outputs_count': outputs_count,
'preview_output': preview_output,
'workflow_id': workflow_id,
})
if include_outputs:
job['outputs'] = outputs
job['execution_status'] = status_info
job['workflow'] = {
'prompt': prompt,
'extra_data': extra_data,
}
return job
def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
"""
Count outputs and find preview in a single pass.
Returns (outputs_count, preview_output).
Preview priority (matching frontend):
1. type="output" with previewable media
2. Any previewable media
"""
count = 0
preview_output = None
fallback_preview = None
for node_id, node_outputs in outputs.items():
if not isinstance(node_outputs, dict):
continue
for media_type, items in node_outputs.items():
# 'animated' is a boolean flag, not actual output items
if media_type == 'animated' or not isinstance(items, list):
continue
for item in items:
if not isinstance(item, dict):
continue
count += 1
if preview_output is None and is_previewable(media_type, item):
enriched = {
**item,
'nodeId': node_id,
'mediaType': media_type
}
if item.get('type') == 'output':
preview_output = enriched
elif fallback_preview is None:
fallback_preview = enriched
return count, preview_output or fallback_preview
def apply_sorting(jobs: list[dict], sort_by: str, sort_order: str) -> list[dict]:
"""Sort jobs list by specified field and order."""
reverse = (sort_order == 'desc')
if sort_by == 'execution_duration':
def get_sort_key(job):
start = job.get('execution_start_time', 0)
end = job.get('execution_end_time', 0)
return end - start if end and start else 0
else:
def get_sort_key(job):
return job.get('create_time', 0)
return sorted(jobs, key=get_sort_key, reverse=reverse)
def get_job(prompt_id: str, running: list, queued: list, history: dict) -> Optional[dict]:
"""
Get a single job by prompt_id from history or queue.
Args:
prompt_id: The prompt ID to look up
running: List of currently running queue items
queued: List of pending queue items
history: Dict of history items keyed by prompt_id
Returns:
Job dict with full details, or None if not found
"""
if prompt_id in history:
return normalize_history_item(prompt_id, history[prompt_id], include_outputs=True)
for item in running:
if item[1] == prompt_id:
return normalize_queue_item(item, JobStatus.IN_PROGRESS)
for item in queued:
if item[1] == prompt_id:
return normalize_queue_item(item, JobStatus.PENDING)
return None
def get_all_jobs(
running: list,
queued: list,
history: dict,
status_filter: Optional[list[str]] = None,
workflow_id: Optional[str] = None,
sort_by: str = "created_at",
sort_order: str = "desc",
limit: Optional[int] = None,
offset: int = 0
) -> tuple[list[dict], int]:
"""
Get all jobs (running, pending, completed) with filtering and sorting.
Args:
running: List of currently running queue items
queued: List of pending queue items
history: Dict of history items keyed by prompt_id
status_filter: List of statuses to include (from JobStatus.ALL)
workflow_id: Filter by workflow ID
sort_by: Field to sort by ('created_at', 'execution_duration')
sort_order: 'asc' or 'desc'
limit: Maximum number of items to return
offset: Number of items to skip
Returns:
tuple: (jobs_list, total_count)
"""
jobs = []
if status_filter is None:
status_filter = JobStatus.ALL
if JobStatus.IN_PROGRESS in status_filter:
for item in running:
jobs.append(normalize_queue_item(item, JobStatus.IN_PROGRESS))
if JobStatus.PENDING in status_filter:
for item in queued:
jobs.append(normalize_queue_item(item, JobStatus.PENDING))
include_completed = JobStatus.COMPLETED in status_filter
include_failed = JobStatus.FAILED in status_filter
if include_completed or include_failed:
for prompt_id, history_item in history.items():
is_failed = history_item.get('status', {}).get('status_str') == 'error'
if (is_failed and include_failed) or (not is_failed and include_completed):
jobs.append(normalize_history_item(prompt_id, history_item))
if workflow_id:
jobs = [j for j in jobs if j.get('workflow_id') == workflow_id]
jobs = apply_sorting(jobs, sort_by, sort_order)
total_count = len(jobs)
if offset > 0:
jobs = jobs[offset:]
if limit is not None:
jobs = jobs[:limit]
return (jobs, total_count)

View File

@ -760,8 +760,12 @@ class SamplerCustom(io.ComfyNode):
out = latent.copy()
out["samples"] = samples
if "x0" in x0_output:
x0_out = model.model.process_latent_out(x0_output["x0"].cpu())
if samples.is_nested:
latent_shapes = [x.shape for x in samples.unbind()]
x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes))
out_denoised = latent.copy()
out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu())
out_denoised["samples"] = x0_out
else:
out_denoised = out
return io.NodeOutput(out, out_denoised)
@ -948,8 +952,12 @@ class SamplerCustomAdvanced(io.ComfyNode):
out = latent.copy()
out["samples"] = samples
if "x0" in x0_output:
x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
if samples.is_nested:
latent_shapes = [x.shape for x in samples.unbind()]
x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes))
out_denoised = latent.copy()
out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
out_denoised["samples"] = x0_out
else:
out_denoised = out
return io.NodeOutput(out, out_denoised)

View File

@ -1125,6 +1125,99 @@ class MergeTextListsNode(TextProcessingNode):
# ========== Training Dataset Nodes ==========
class ResolutionBucket(io.ComfyNode):
"""Bucket latents and conditions by resolution for efficient batch training."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ResolutionBucket",
display_name="Resolution Bucket",
category="dataset",
is_experimental=True,
is_input_list=True,
inputs=[
io.Latent.Input(
"latents",
tooltip="List of latent dicts to bucket by resolution.",
),
io.Conditioning.Input(
"conditioning",
tooltip="List of conditioning lists (must match latents length).",
),
],
outputs=[
io.Latent.Output(
display_name="latents",
is_output_list=True,
tooltip="List of batched latent dicts, one per resolution bucket.",
),
io.Conditioning.Output(
display_name="conditioning",
is_output_list=True,
tooltip="List of condition lists, one per resolution bucket.",
),
],
)
@classmethod
def execute(cls, latents, conditioning):
# latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1
# conditioning: list[list[cond]]
# Validate lengths match
if len(latents) != len(conditioning):
raise ValueError(
f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})."
)
# Flatten latents and conditions to individual samples
flat_latents = [] # list of (C, H, W) tensors
flat_conditions = [] # list of condition lists
for latent_dict, cond in zip(latents, conditioning):
samples = latent_dict["samples"] # (B, C, H, W)
batch_size = samples.shape[0]
# cond is a list of conditions with length == batch_size
for i in range(batch_size):
flat_latents.append(samples[i]) # (C, H, W)
flat_conditions.append(cond[i]) # single condition
# Group by resolution (H, W)
buckets = {} # (H, W) -> {"latents": list, "conditions": list}
for latent, cond in zip(flat_latents, flat_conditions):
# latent shape is (..., H, W) (B, C, H, W) or (B, T, C, H ,W)
h, w = latent.shape[-2], latent.shape[-1]
key = (h, w)
if key not in buckets:
buckets[key] = {"latents": [], "conditions": []}
buckets[key]["latents"].append(latent)
buckets[key]["conditions"].append(cond)
# Convert buckets to output format
output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, ..., H, W)
output_conditions = [] # list[list[cond]] where each inner list has Bi conditions
for (h, w), bucket_data in buckets.items():
# Stack latents into batch: list of (..., H, W) -> (Bi, ..., H, W)
stacked_latents = torch.stack(bucket_data["latents"], dim=0)
output_latents.append({"samples": stacked_latents})
# Conditions stay as list of condition lists
output_conditions.append(bucket_data["conditions"])
logging.info(
f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples"
)
logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples")
return io.NodeOutput(output_latents, output_conditions)
class MakeTrainingDataset(io.ComfyNode):
"""Encode images with VAE and texts with CLIP to create a training dataset."""
@ -1373,7 +1466,7 @@ class LoadTrainingDataset(io.ComfyNode):
shard_path = os.path.join(dataset_dir, shard_file)
with open(shard_path, "rb") as f:
shard_data = torch.load(f, weights_only=True)
shard_data = torch.load(f)
all_latents.extend(shard_data["latents"])
all_conditioning.extend(shard_data["conditioning"])
@ -1425,6 +1518,7 @@ class DatasetExtension(ComfyExtension):
MakeTrainingDataset,
SaveTrainingDataset,
LoadTrainingDataset,
ResolutionBucket,
]

View File

@ -5,6 +5,7 @@ import nodes
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import logging
import math
def reshape_latent_to(target_shape, latent, repeat_batch=True):
if latent.shape[1:] != target_shape[1:]:
@ -207,6 +208,47 @@ class LatentCut(io.ComfyNode):
samples_out["samples"] = torch.narrow(s1, dim, index, amount)
return io.NodeOutput(samples_out)
class LatentCutToBatch(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LatentCutToBatch",
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
io.Combo.Input("dim", options=["t", "x", "y"]),
io.Int.Input("slice_size", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod
def execute(cls, samples, dim, slice_size) -> io.NodeOutput:
samples_out = samples.copy()
s1 = samples["samples"]
if "x" in dim:
dim = s1.ndim - 1
elif "y" in dim:
dim = s1.ndim - 2
elif "t" in dim:
dim = s1.ndim - 3
if dim < 2:
return io.NodeOutput(samples)
s = s1.movedim(dim, 1)
if s.shape[1] < slice_size:
slice_size = s.shape[1]
elif s.shape[1] % slice_size != 0:
s = s[:, :math.floor(s.shape[1] / slice_size) * slice_size]
new_shape = [-1, slice_size] + list(s.shape[2:])
samples_out["samples"] = s.reshape(new_shape).movedim(1, dim)
return io.NodeOutput(samples_out)
class LatentBatch(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -435,6 +477,7 @@ class LatentExtension(ComfyExtension):
LatentInterpolate,
LatentConcat,
LatentCut,
LatentCutToBatch,
LatentBatch,
LatentBatchSeedBehavior,
LatentApplyOperation,

View File

@ -348,7 +348,7 @@ class ZImageControlPatch:
if self.mask is None:
mask_ = torch.zeros_like(inpaint_image_latent)[:, :1]
else:
mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")
mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True).to(device=inpaint_image_latent.device), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")
if latent_image is None:
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5))

View File

@ -221,6 +221,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
io.Image.Input("image"),
io.Combo.Input("upscale_method", options=cls.upscale_methods),
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
io.Int.Input("resolution_steps", default=1, min=1, max=256),
],
outputs=[
io.Image.Output(),
@ -228,15 +229,15 @@ class ImageScaleToTotalPixels(io.ComfyNode):
)
@classmethod
def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput:
def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.NodeOutput:
samples = image.movedim(-1,1)
total = int(megapixels * 1024 * 1024)
total = megapixels * 1024 * 1024
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
width = round(samples.shape[3] * scale_by / resolution_steps) * resolution_steps
height = round(samples.shape[2] * scale_by / resolution_steps) * resolution_steps
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
s = comfy.utils.common_upscale(samples, int(width), int(height), upscale_method, "disabled")
s = s.movedim(1,-1)
return io.NodeOutput(s)

View File

@ -3,7 +3,9 @@ import comfy.utils
import math
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import comfy.model_management
import torch
import nodes
class TextEncodeQwenImageEdit(io.ComfyNode):
@classmethod
@ -104,12 +106,37 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode):
return io.NodeOutput(conditioning)
class EmptyQwenImageLayeredLatentImage(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EmptyQwenImageLayeredLatentImage",
display_name="Empty Qwen Image Layered Latent",
category="latent/qwen",
inputs=[
io.Int.Input("width", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("layers", default=3, min=0, max=nodes.MAX_RESOLUTION, step=1),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod
def execute(cls, width, height, layers, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, layers + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent})
class QwenExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TextEncodeQwenImageEdit,
TextEncodeQwenImageEditPlus,
EmptyQwenImageLayeredLatentImage,
]

View File

@ -10,6 +10,7 @@ from PIL import Image, ImageDraw, ImageFont
from typing_extensions import override
import comfy.samplers
import comfy.sampler_helpers
import comfy.sd
import comfy.utils
import comfy.model_management
@ -21,6 +22,68 @@ from comfy_api.latest import ComfyExtension, io, ui
from comfy.utils import ProgressBar
class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
"""
CFGGuider with modifications for training specific logic
"""
def outer_sample(
self,
noise,
latent_image,
sampler,
sigmas,
denoise_mask=None,
callback=None,
disable_pbar=False,
seed=None,
latent_shapes=None,
):
self.inner_model, self.conds, self.loaded_models = (
comfy.sampler_helpers.prepare_sampling(
self.model_patcher,
noise.shape,
self.conds,
self.model_options,
force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded
)
)
device = self.model_patcher.load_device
if denoise_mask is not None:
denoise_mask = comfy.sampler_helpers.prepare_mask(
denoise_mask, noise.shape, device
)
noise = noise.to(device)
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
comfy.samplers.cast_to_load_options(
self.model_options, device=device, dtype=self.model_patcher.model_dtype()
)
try:
self.model_patcher.pre_run()
output = self.inner_sample(
noise,
latent_image,
device,
sampler,
sigmas,
denoise_mask,
callback,
disable_pbar,
seed,
latent_shapes=latent_shapes,
)
finally:
self.model_patcher.cleanup()
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model
del self.loaded_models
return output
def make_batch_extra_option_dict(d, indicies, full_size=None):
new_dict = {}
for k, v in d.items():
@ -65,6 +128,7 @@ class TrainSampler(comfy.samplers.Sampler):
seed=0,
training_dtype=torch.bfloat16,
real_dataset=None,
bucket_latents=None,
):
self.loss_fn = loss_fn
self.optimizer = optimizer
@ -75,6 +139,28 @@ class TrainSampler(comfy.samplers.Sampler):
self.seed = seed
self.training_dtype = training_dtype
self.real_dataset: list[torch.Tensor] | None = real_dataset
# Bucket mode data
self.bucket_latents: list[torch.Tensor] | None = (
bucket_latents # list of (Bi, C, Hi, Wi)
)
# Precompute bucket offsets and weights for sampling
if bucket_latents is not None:
self._init_bucket_data(bucket_latents)
else:
self.bucket_offsets = None
self.bucket_weights = None
self.num_images = None
def _init_bucket_data(self, bucket_latents):
"""Initialize bucket offsets and weights for sampling."""
self.bucket_offsets = [0]
bucket_sizes = []
for lat in bucket_latents:
bucket_sizes.append(lat.shape[0])
self.bucket_offsets.append(self.bucket_offsets[-1] + lat.shape[0])
self.num_images = self.bucket_offsets[-1]
# Weights for sampling buckets proportional to their size
self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)
def fwd_bwd(
self,
@ -115,47 +201,59 @@ class TrainSampler(comfy.samplers.Sampler):
bwd_loss.backward()
return loss
def sample(
self,
model_wrap,
sigmas,
extra_args,
callback,
noise,
latent_image=None,
denoise_mask=None,
disable_pbar=False,
):
model_wrap.conds = process_cond_list(model_wrap.conds)
cond = model_wrap.conds["positive"]
dataset_size = sigmas.size(0)
torch.cuda.empty_cache()
ui_pbar = ProgressBar(self.total_steps)
for i in (
pbar := trange(
self.total_steps,
desc="Training LoRA",
smoothing=0.01,
disable=not comfy.utils.PROGRESS_BAR_ENABLED,
)
):
noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(
self.seed + i * 1000
)
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
if self.real_dataset is None:
batch_latent = torch.stack([latent_image[i] for i in indicies])
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
batch_latent.device
)
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
"""Generate random sigma values for a batch."""
batch_sigmas = [
model_wrap.inner_model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
)
for _ in range(min(self.batch_size, dataset_size))
for _ in range(batch_size)
]
batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
return torch.tensor(batch_sigmas).to(device)
def _train_step_bucket_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, pbar):
"""Execute one training step in bucket mode."""
# Sample bucket (weighted by size), then sample batch from bucket
bucket_idx = torch.multinomial(self.bucket_weights, 1).item()
bucket_latent = self.bucket_latents[bucket_idx] # (Bi, C, Hi, Wi)
bucket_size = bucket_latent.shape[0]
bucket_offset = self.bucket_offsets[bucket_idx]
# Sample indices from this bucket (use all if bucket_size < batch_size)
actual_batch_size = min(self.batch_size, bucket_size)
relative_indices = torch.randperm(bucket_size)[:actual_batch_size].tolist()
# Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index)
absolute_indices = [bucket_offset + idx for idx in relative_indices]
batch_latent = bucket_latent[relative_indices].to(latent_image) # (actual_batch_size, C, H, W)
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
batch_latent.device
)
batch_sigmas = self._generate_batch_sigmas(model_wrap, actual_batch_size, batch_latent.device)
loss = self.fwd_bwd(
model_wrap,
batch_sigmas,
batch_noise,
batch_latent,
cond, # Use flattened cond with absolute indices
absolute_indices,
extra_args,
self.num_images,
bwd=True,
)
if self.loss_callback:
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx})
def _train_step_standard_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
"""Execute one training step in standard (non-bucket, non-multi-res) mode."""
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
batch_latent = torch.stack([latent_image[i] for i in indicies])
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
batch_latent.device
)
batch_sigmas = self._generate_batch_sigmas(model_wrap, min(self.batch_size, dataset_size), batch_latent.device)
loss = self.fwd_bwd(
model_wrap,
@ -171,7 +269,10 @@ class TrainSampler(comfy.samplers.Sampler):
if self.loss_callback:
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
else:
def _train_step_multires_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
"""Execute one training step in multi-resolution mode (real_dataset is set)."""
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
total_loss = 0
for index in indicies:
single_latent = self.real_dataset[index].to(latent_image)
@ -202,6 +303,41 @@ class TrainSampler(comfy.samplers.Sampler):
self.loss_callback(total_loss.item())
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
def sample(
self,
model_wrap,
sigmas,
extra_args,
callback,
noise,
latent_image=None,
denoise_mask=None,
disable_pbar=False,
):
model_wrap.conds = process_cond_list(model_wrap.conds)
cond = model_wrap.conds["positive"]
dataset_size = sigmas.size(0)
torch.cuda.empty_cache()
ui_pbar = ProgressBar(self.total_steps)
for i in (
pbar := trange(
self.total_steps,
desc="Training LoRA",
smoothing=0.01,
disable=not comfy.utils.PROGRESS_BAR_ENABLED,
)
):
noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(
self.seed + i * 1000
)
if self.bucket_latents is not None:
self._train_step_bucket_mode(model_wrap, cond, extra_args, noisegen, latent_image, pbar)
elif self.real_dataset is None:
self._train_step_standard_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
else:
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
if (i + 1) % self.grad_acc == 0:
self.optimizer.step()
self.optimizer.zero_grad()
@ -283,6 +419,364 @@ def unpatch(m):
del m.org_forward
def _process_latents_bucket_mode(latents):
"""Process latents for bucket mode training.
Args:
latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi)
Returns:
list of latent tensors
"""
bucket_latents = []
for latent_dict in latents:
bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi)
return bucket_latents
def _process_latents_standard_mode(latents):
"""Process latents for standard (non-bucket) mode training.
Args:
latents: list of latent dicts or single latent dict
Returns:
Processed latents (tensor or list of tensors)
"""
if len(latents) == 1:
return latents[0]["samples"] # Single latent dict
latent_list = []
for latent in latents:
latent = latent["samples"]
bs = latent.shape[0]
if bs != 1:
for sub_latent in latent:
latent_list.append(sub_latent[None])
else:
latent_list.append(latent)
return latent_list
def _process_conditioning(positive):
"""Process conditioning - either single list or list of lists.
Args:
positive: list of conditioning
Returns:
Flattened conditioning list
"""
if len(positive) == 1:
return positive[0] # Single conditioning list
# Multiple conditioning lists - flatten
flat_positive = []
for cond in positive:
if isinstance(cond, list):
flat_positive.extend(cond)
else:
flat_positive.append(cond)
return flat_positive
def _prepare_latents_and_count(latents, dtype, bucket_mode):
"""Convert latents to dtype and compute image counts.
Args:
latents: Latents (tensor, list of tensors, or bucket list)
dtype: Target dtype
bucket_mode: Whether bucket mode is enabled
Returns:
tuple: (processed_latents, num_images, multi_res)
"""
if bucket_mode:
# In bucket mode, latents is list of tensors (Bi, C, Hi, Wi)
latents = [t.to(dtype) for t in latents]
num_buckets = len(latents)
num_images = sum(t.shape[0] for t in latents)
multi_res = False # Not using multi_res path in bucket mode
logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
for i, lat in enumerate(latents):
logging.info(f" Bucket {i}: shape {lat.shape}")
return latents, num_images, multi_res
# Non-bucket mode
if isinstance(latents, list):
all_shapes = set()
latents = [t.to(dtype) for t in latents]
for latent in latents:
all_shapes.add(latent.shape)
logging.info(f"Latent shapes: {all_shapes}")
if len(all_shapes) > 1:
multi_res = True
else:
multi_res = False
latents = torch.cat(latents, dim=0)
num_images = len(latents)
elif isinstance(latents, torch.Tensor):
latents = latents.to(dtype)
num_images = latents.shape[0]
multi_res = False
else:
logging.error(f"Invalid latents type: {type(latents)}")
num_images = 0
multi_res = False
return latents, num_images, multi_res
def _validate_and_expand_conditioning(positive, num_images, bucket_mode):
"""Validate conditioning count matches image count, expand if needed.
Args:
positive: Conditioning list
num_images: Number of images
bucket_mode: Whether bucket mode is enabled
Returns:
Validated/expanded conditioning list
Raises:
ValueError: If conditioning count doesn't match image count
"""
if bucket_mode:
return positive # Skip validation in bucket mode
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
return positive * num_images
elif len(positive) != num_images:
raise ValueError(
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
)
return positive
def _load_existing_lora(existing_lora):
"""Load existing LoRA weights if provided.
Args:
existing_lora: LoRA filename or "[None]"
Returns:
tuple: (existing_weights dict, existing_steps int)
"""
if existing_lora == "[None]":
return {}, 0
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
existing_weights = {}
if lora_path:
existing_weights = comfy.utils.load_torch_file(lora_path)
return existing_weights, existing_steps
def _create_weight_adapter(
module, module_name, existing_weights, algorithm, lora_dtype, rank
):
"""Create a weight adapter for a module with weight.
Args:
module: The module to create adapter for
module_name: Name of the module
existing_weights: Dict of existing LoRA weights
algorithm: Algorithm name for new adapters
lora_dtype: dtype for LoRA weights
rank: Rank for new LoRA adapters
Returns:
tuple: (train_adapter, lora_params dict)
"""
key = f"{module_name}.weight"
shape = module.weight.shape
lora_params = {}
if len(shape) >= 2:
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
# Try to load existing adapter
existing_adapter = None
for adapter_cls in adapters:
existing_adapter = adapter_cls.load(
module_name, existing_weights, alpha, dora_scale
)
if existing_adapter is not None:
break
if existing_adapter is None:
adapter_cls = adapter_maps[algorithm]
if existing_adapter is not None:
train_adapter = existing_adapter.to_train().to(lora_dtype)
else:
# Use LoRA with alpha=1.0 by default
train_adapter = adapter_cls.create_train(
module.weight, rank=rank, alpha=1.0
).to(lora_dtype)
for name, parameter in train_adapter.named_parameters():
lora_params[f"{module_name}.{name}"] = parameter
return train_adapter.train().requires_grad_(True), lora_params
else:
# 1D weight - use BiasDiff
diff = torch.nn.Parameter(
torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True)
)
diff_module = BiasDiff(diff).train().requires_grad_(True)
lora_params[f"{module_name}.diff"] = diff
return diff_module, lora_params
def _create_bias_adapter(module, module_name, lora_dtype):
"""Create a bias adapter for a module with bias.
Args:
module: The module with bias
module_name: Name of the module
lora_dtype: dtype for LoRA weights
Returns:
tuple: (bias_module, lora_params dict)
"""
bias = torch.nn.Parameter(
torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True)
)
bias_module = BiasDiff(bias).train().requires_grad_(True)
lora_params = {f"{module_name}.diff_b": bias}
return bias_module, lora_params
def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank):
"""Setup all LoRA adapters on the model.
Args:
mp: Model patcher
existing_weights: Dict of existing LoRA weights
algorithm: Algorithm name for new adapters
lora_dtype: dtype for LoRA weights
rank: Rank for new LoRA adapters
Returns:
tuple: (lora_sd dict, all_weight_adapters list)
"""
lora_sd = {}
all_weight_adapters = []
for n, m in mp.model.named_modules():
if hasattr(m, "weight_function"):
if m.weight is not None:
adapter, params = _create_weight_adapter(
m, n, existing_weights, algorithm, lora_dtype, rank
)
lora_sd.update(params)
key = f"{n}.weight"
mp.add_weight_wrapper(key, adapter)
all_weight_adapters.append(adapter)
if hasattr(m, "bias") and m.bias is not None:
bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype)
lora_sd.update(bias_params)
key = f"{n}.bias"
mp.add_weight_wrapper(key, bias_adapter)
all_weight_adapters.append(bias_adapter)
return lora_sd, all_weight_adapters
def _create_optimizer(optimizer_name, parameters, learning_rate):
"""Create optimizer based on name.
Args:
optimizer_name: Name of optimizer ("Adam", "AdamW", "SGD", "RMSprop")
parameters: Parameters to optimize
learning_rate: Learning rate
Returns:
Optimizer instance
"""
if optimizer_name == "Adam":
return torch.optim.Adam(parameters, lr=learning_rate)
elif optimizer_name == "AdamW":
return torch.optim.AdamW(parameters, lr=learning_rate)
elif optimizer_name == "SGD":
return torch.optim.SGD(parameters, lr=learning_rate)
elif optimizer_name == "RMSprop":
return torch.optim.RMSprop(parameters, lr=learning_rate)
def _create_loss_function(loss_function_name):
"""Create loss function based on name.
Args:
loss_function_name: Name of loss function ("MSE", "L1", "Huber", "SmoothL1")
Returns:
Loss function instance
"""
if loss_function_name == "MSE":
return torch.nn.MSELoss()
elif loss_function_name == "L1":
return torch.nn.L1Loss()
elif loss_function_name == "Huber":
return torch.nn.HuberLoss()
elif loss_function_name == "SmoothL1":
return torch.nn.SmoothL1Loss()
def _run_training_loop(
guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res
):
"""Execute the training loop.
Args:
guider: The guider object
train_sampler: The training sampler
latents: Latent tensors
num_images: Number of images
seed: Random seed
bucket_mode: Whether bucket mode is enabled
multi_res: Whether multi-resolution mode is enabled
"""
sigmas = torch.tensor(range(num_images))
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
if bucket_mode:
# Use first bucket's first latent as dummy for guider
dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1)
guider.sample(
noise.generate_noise({"samples": dummy_latent}),
dummy_latent,
train_sampler,
sigmas,
seed=noise.seed,
)
elif multi_res:
# use first latent as dummy latent if multi_res
latents = latents[0].repeat(num_images, 1, 1, 1)
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
seed=noise.seed,
)
else:
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
seed=noise.seed,
)
class TrainLoraNode(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -385,6 +879,11 @@ class TrainLoraNode(io.ComfyNode):
default="[None]",
tooltip="The existing LoRA to append to. Set to None for new LoRA.",
),
io.Boolean.Input(
"bucket_mode",
default=False,
tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
),
],
outputs=[
io.Model.Output(
@ -419,6 +918,7 @@ class TrainLoraNode(io.ComfyNode):
algorithm,
gradient_checkpointing,
existing_lora,
bucket_mode,
):
# Extract scalars from lists (due to is_input_list=True)
model = model[0]
@ -427,181 +927,92 @@ class TrainLoraNode(io.ComfyNode):
grad_accumulation_steps = grad_accumulation_steps[0]
learning_rate = learning_rate[0]
rank = rank[0]
optimizer = optimizer[0]
loss_function = loss_function[0]
optimizer_name = optimizer[0]
loss_function_name = loss_function[0]
seed = seed[0]
training_dtype = training_dtype[0]
lora_dtype = lora_dtype[0]
algorithm = algorithm[0]
gradient_checkpointing = gradient_checkpointing[0]
existing_lora = existing_lora[0]
bucket_mode = bucket_mode[0]
# Handle latents - either single dict or list of dicts
if len(latents) == 1:
latents = latents[0]["samples"] # Single latent dict
# Process latents based on mode
if bucket_mode:
latents = _process_latents_bucket_mode(latents)
else:
latent_list = []
for latent in latents:
latent = latent["samples"]
bs = latent.shape[0]
if bs != 1:
for sub_latent in latent:
latent_list.append(sub_latent[None])
else:
latent_list.append(latent)
latents = latent_list
latents = _process_latents_standard_mode(latents)
# Handle conditioning - either single list or list of lists
if len(positive) == 1:
positive = positive[0] # Single conditioning list
else:
# Multiple conditioning lists - flatten
flat_positive = []
for cond in positive:
if isinstance(cond, list):
flat_positive.extend(cond)
else:
flat_positive.append(cond)
positive = flat_positive
# Process conditioning
positive = _process_conditioning(positive)
# Setup model and dtype
mp = model.clone()
dtype = node_helpers.string_to_torch_dtype(training_dtype)
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
# latents here can be list of different size latent or one large batch
if isinstance(latents, list):
all_shapes = set()
latents = [t.to(dtype) for t in latents]
for latent in latents:
all_shapes.add(latent.shape)
logging.info(f"Latent shapes: {all_shapes}")
if len(all_shapes) > 1:
multi_res = True
else:
multi_res = False
latents = torch.cat(latents, dim=0)
num_images = len(latents)
elif isinstance(latents, torch.Tensor):
latents = latents.to(dtype)
num_images = latents.shape[0]
else:
logging.error(f"Invalid latents type: {type(latents)}")
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
positive = positive * num_images
elif len(positive) != num_images:
raise ValueError(
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
# Prepare latents and compute counts
latents, num_images, multi_res = _prepare_latents_and_count(
latents, dtype, bucket_mode
)
# Validate and expand conditioning
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
with torch.inference_mode(False):
lora_sd = {}
generator = torch.Generator()
generator.manual_seed(seed)
# Setup models for training
mp.model.requires_grad_(False)
# Load existing LoRA weights if provided
existing_weights = {}
existing_steps = 0
if existing_lora != "[None]":
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
if lora_path:
existing_weights = comfy.utils.load_torch_file(lora_path)
existing_weights, existing_steps = _load_existing_lora(existing_lora)
all_weight_adapters = []
for n, m in mp.model.named_modules():
if hasattr(m, "weight_function"):
if m.weight is not None:
key = "{}.weight".format(n)
shape = m.weight.shape
if len(shape) >= 2:
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
for adapter_cls in adapters:
existing_adapter = adapter_cls.load(
n, existing_weights, alpha, dora_scale
# Setup LoRA adapters
lora_sd, all_weight_adapters = _setup_lora_adapters(
mp, existing_weights, algorithm, lora_dtype, rank
)
if existing_adapter is not None:
break
else:
existing_adapter = None
adapter_cls = adapter_maps[algorithm]
if existing_adapter is not None:
train_adapter = existing_adapter.to_train().to(
lora_dtype
# Create optimizer and loss function
optimizer = _create_optimizer(
optimizer_name, lora_sd.values(), learning_rate
)
else:
# Use LoRA with alpha=1.0 by default
train_adapter = adapter_cls.create_train(
m.weight, rank=rank, alpha=1.0
).to(lora_dtype)
for name, parameter in train_adapter.named_parameters():
lora_sd[f"{n}.{name}"] = parameter
criterion = _create_loss_function(loss_function_name)
mp.add_weight_wrapper(key, train_adapter)
all_weight_adapters.append(train_adapter)
else:
diff = torch.nn.Parameter(
torch.zeros(
m.weight.shape, dtype=lora_dtype, requires_grad=True
)
)
diff_module = BiasDiff(diff)
mp.add_weight_wrapper(key, BiasDiff(diff))
all_weight_adapters.append(diff_module)
lora_sd["{}.diff".format(n)] = diff
if hasattr(m, "bias") and m.bias is not None:
key = "{}.bias".format(n)
bias = torch.nn.Parameter(
torch.zeros(
m.bias.shape, dtype=lora_dtype, requires_grad=True
)
)
bias_module = BiasDiff(bias)
lora_sd["{}.diff_b".format(n)] = bias
mp.add_weight_wrapper(key, BiasDiff(bias))
all_weight_adapters.append(bias_module)
if optimizer == "Adam":
optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
elif optimizer == "AdamW":
optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
elif optimizer == "SGD":
optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
elif optimizer == "RMSprop":
optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
# Setup loss function based on selection
if loss_function == "MSE":
criterion = torch.nn.MSELoss()
elif loss_function == "L1":
criterion = torch.nn.L1Loss()
elif loss_function == "Huber":
criterion = torch.nn.HuberLoss()
elif loss_function == "SmoothL1":
criterion = torch.nn.SmoothL1Loss()
# setup models
# Setup gradient checkpointing
if gradient_checkpointing:
for m in find_all_highest_child_module_with_forward(
mp.model.diffusion_model
):
patch(m)
mp.model.requires_grad_(False)
torch.cuda.empty_cache()
# With force_full_load=False we should be able to have offloading
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
comfy.model_management.load_models_gpu(
[mp], memory_required=1e20, force_full_load=True
)
torch.cuda.empty_cache()
# Setup sampler and guider like in test script
# Setup loss tracking
loss_map = {"loss": []}
def loss_callback(loss):
loss_map["loss"].append(loss)
# Create sampler
if bucket_mode:
train_sampler = TrainSampler(
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
grad_acc=grad_accumulation_steps,
total_steps=steps * grad_accumulation_steps,
seed=seed,
training_dtype=dtype,
bucket_latents=latents,
)
else:
train_sampler = TrainSampler(
criterion,
optimizer,
@ -613,29 +1024,28 @@ class TrainLoraNode(io.ComfyNode):
training_dtype=dtype,
real_dataset=latents if multi_res else None,
)
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
guider.set_conds(positive) # Set conditioning from input
# Training loop
# Setup guider
guider = TrainGuider(mp)
guider.set_conds(positive)
# Run training loop
try:
# Generate dummy sigmas and noise
sigmas = torch.tensor(range(num_images))
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
if multi_res:
# use first latent as dummy latent if multi_res
latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1)))
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
_run_training_loop(
guider,
train_sampler,
sigmas,
seed=noise.seed,
latents,
num_images,
seed,
bucket_mode,
multi_res,
)
finally:
for m in mp.model.modules():
unpatch(m)
del train_sampler, optimizer
# Finalize adapters
for adapter in all_weight_adapters:
adapter.requires_grad_(False)
@ -645,7 +1055,7 @@ class TrainLoraNode(io.ComfyNode):
return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
class LoraModelLoader(io.ComfyNode):
class LoraModelLoader(io.ComfyNode):#
@classmethod
def define_schema(cls):
return io.Schema(

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.5.1"
__version__ = "0.6.0"

View File

@ -1 +1 @@
comfyui_manager==4.0.3b5
comfyui_manager==4.0.3b7

View File

@ -343,7 +343,7 @@ class VAEEncode:
CATEGORY = "latent"
def encode(self, vae, pixels):
t = vae.encode(pixels[:,:,:,:3])
t = vae.encode(pixels)
return ({"samples":t}, )
class VAEEncodeTiled:
@ -361,7 +361,7 @@ class VAEEncodeTiled:
CATEGORY = "_for_testing"
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
return ({"samples": t}, )
class VAEEncodeForInpaint:
@ -970,7 +970,7 @@ class DualCLIPLoader:
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ),
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "newbie"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@ -980,7 +980,7 @@ class DualCLIPLoader:
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small"
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small\nnewbie: gemma-3-4b-it, jina clip v2"
def load_clip(self, clip_name1, clip_name2, type, device="default"):
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.5.1"
version = "0.6.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.34.9
comfyui-workflow-templates==0.7.59
comfyui-frontend-package==1.35.9
comfyui-workflow-templates==0.7.63
comfyui-embedded-docs==0.3.1
torch
torchsde

135
server.py
View File

@ -7,6 +7,7 @@ import time
import nodes
import folder_paths
import execution
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs
import uuid
import urllib
import json
@ -47,6 +48,12 @@ from middleware.cache_middleware import cache_control
if args.enable_manager:
import comfyui_manager
def _remove_sensitive_from_queue(queue: list) -> list:
"""Remove sensitive data (index 5) from queue item tuples."""
return [item[:5] for item in queue]
async def send_socket_catch_exception(function, message):
try:
await function(message)
@ -694,6 +701,129 @@ class PromptServer():
out[node_class] = node_info(node_class)
return web.json_response(out)
@routes.get("/api/jobs")
async def get_jobs(request):
"""List all jobs with filtering, sorting, and pagination.
Query parameters:
status: Filter by status (comma-separated): pending, in_progress, completed, failed
workflow_id: Filter by workflow ID
sort_by: Sort field: created_at (default), execution_duration
sort_order: Sort direction: asc, desc (default)
limit: Max items to return (positive integer)
offset: Items to skip (non-negative integer, default 0)
"""
query = request.rel_url.query
status_param = query.get('status')
workflow_id = query.get('workflow_id')
sort_by = query.get('sort_by', 'created_at').lower()
sort_order = query.get('sort_order', 'desc').lower()
status_filter = None
if status_param:
status_filter = [s.strip().lower() for s in status_param.split(',') if s.strip()]
invalid_statuses = [s for s in status_filter if s not in JobStatus.ALL]
if invalid_statuses:
return web.json_response(
{"error": f"Invalid status value(s): {', '.join(invalid_statuses)}. Valid values: {', '.join(JobStatus.ALL)}"},
status=400
)
if sort_by not in {'created_at', 'execution_duration'}:
return web.json_response(
{"error": "sort_by must be 'created_at' or 'execution_duration'"},
status=400
)
if sort_order not in {'asc', 'desc'}:
return web.json_response(
{"error": "sort_order must be 'asc' or 'desc'"},
status=400
)
limit = None
# If limit is provided, validate that it is a positive integer, else continue without a limit
if 'limit' in query:
try:
limit = int(query.get('limit'))
if limit <= 0:
return web.json_response(
{"error": "limit must be a positive integer"},
status=400
)
except (ValueError, TypeError):
return web.json_response(
{"error": "limit must be an integer"},
status=400
)
offset = 0
if 'offset' in query:
try:
offset = int(query.get('offset'))
if offset < 0:
offset = 0
except (ValueError, TypeError):
return web.json_response(
{"error": "offset must be an integer"},
status=400
)
running, queued = self.prompt_queue.get_current_queue_volatile()
history = self.prompt_queue.get_history()
running = _remove_sensitive_from_queue(running)
queued = _remove_sensitive_from_queue(queued)
jobs, total = get_all_jobs(
running, queued, history,
status_filter=status_filter,
workflow_id=workflow_id,
sort_by=sort_by,
sort_order=sort_order,
limit=limit,
offset=offset
)
has_more = (offset + len(jobs)) < total
return web.json_response({
'jobs': jobs,
'pagination': {
'offset': offset,
'limit': limit,
'total': total,
'has_more': has_more
}
})
@routes.get("/api/jobs/{job_id}")
async def get_job_by_id(request):
"""Get a single job by ID."""
job_id = request.match_info.get("job_id", None)
if not job_id:
return web.json_response(
{"error": "job_id is required"},
status=400
)
running, queued = self.prompt_queue.get_current_queue_volatile()
history = self.prompt_queue.get_history(prompt_id=job_id)
running = _remove_sensitive_from_queue(running)
queued = _remove_sensitive_from_queue(queued)
job = get_job(job_id, running, queued, history)
if job is None:
return web.json_response(
{"error": "Job not found"},
status=404
)
return web.json_response(job)
@routes.get("/history")
async def get_history(request):
max_items = request.rel_url.query.get("max_items", None)
@ -717,9 +847,8 @@ class PromptServer():
async def get_queue(request):
queue_info = {}
current_queue = self.prompt_queue.get_current_queue_volatile()
remove_sensitive = lambda queue: [x[:5] for x in queue]
queue_info['queue_running'] = remove_sensitive(current_queue[0])
queue_info['queue_pending'] = remove_sensitive(current_queue[1])
queue_info['queue_running'] = _remove_sensitive_from_queue(current_queue[0])
queue_info['queue_pending'] = _remove_sensitive_from_queue(current_queue[1])
return web.json_response(queue_info)
@routes.post("/prompt")

View File

@ -99,6 +99,37 @@ class ComfyClient:
with urllib.request.urlopen(url) as response:
return json.loads(response.read())
def get_jobs(self, status=None, limit=None, offset=None, sort_by=None, sort_order=None):
url = "http://{}/api/jobs".format(self.server_address)
params = {}
if status is not None:
params["status"] = status
if limit is not None:
params["limit"] = limit
if offset is not None:
params["offset"] = offset
if sort_by is not None:
params["sort_by"] = sort_by
if sort_order is not None:
params["sort_order"] = sort_order
if params:
url_values = urllib.parse.urlencode(params)
url = "{}?{}".format(url, url_values)
with urllib.request.urlopen(url) as response:
return json.loads(response.read())
def get_job(self, job_id):
url = "http://{}/api/jobs/{}".format(self.server_address, job_id)
try:
with urllib.request.urlopen(url) as response:
return json.loads(response.read())
except urllib.error.HTTPError as e:
if e.code == 404:
return None
raise
def set_test_name(self, name):
self.test_name = name
@ -877,3 +908,106 @@ class TestExecution:
result = client.get_all_history(max_items=5, offset=len(all_history) - 1)
assert len(result) <= 1, "Should return at most 1 item when offset is near end"
# Jobs API tests
def test_jobs_api_job_structure(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test that job objects have required fields"""
self._create_history_item(client, builder)
jobs_response = client.get_jobs(status="completed", limit=1)
assert len(jobs_response["jobs"]) > 0, "Should have at least one job"
job = jobs_response["jobs"][0]
assert "id" in job, "Job should have id"
assert "status" in job, "Job should have status"
assert "create_time" in job, "Job should have create_time"
assert "outputs_count" in job, "Job should have outputs_count"
assert "preview_output" in job, "Job should have preview_output"
def test_jobs_api_preview_output_structure(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test that preview_output has correct structure"""
self._create_history_item(client, builder)
jobs_response = client.get_jobs(status="completed", limit=1)
job = jobs_response["jobs"][0]
if job["preview_output"] is not None:
preview = job["preview_output"]
assert "filename" in preview, "Preview should have filename"
assert "nodeId" in preview, "Preview should have nodeId"
assert "mediaType" in preview, "Preview should have mediaType"
def test_jobs_api_pagination(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test jobs API pagination"""
for _ in range(5):
self._create_history_item(client, builder)
first_page = client.get_jobs(limit=2, offset=0)
second_page = client.get_jobs(limit=2, offset=2)
assert len(first_page["jobs"]) <= 2, "First page should have at most 2 jobs"
assert len(second_page["jobs"]) <= 2, "Second page should have at most 2 jobs"
first_ids = {j["id"] for j in first_page["jobs"]}
second_ids = {j["id"] for j in second_page["jobs"]}
assert first_ids.isdisjoint(second_ids), "Pages should have different jobs"
def test_jobs_api_sorting(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test jobs API sorting"""
for _ in range(3):
self._create_history_item(client, builder)
desc_jobs = client.get_jobs(sort_order="desc")
asc_jobs = client.get_jobs(sort_order="asc")
if len(desc_jobs["jobs"]) >= 2:
desc_times = [j["create_time"] for j in desc_jobs["jobs"] if j["create_time"]]
asc_times = [j["create_time"] for j in asc_jobs["jobs"] if j["create_time"]]
if len(desc_times) >= 2:
assert desc_times == sorted(desc_times, reverse=True), "Desc should be newest first"
if len(asc_times) >= 2:
assert asc_times == sorted(asc_times), "Asc should be oldest first"
def test_jobs_api_status_filter(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test jobs API status filtering"""
self._create_history_item(client, builder)
completed_jobs = client.get_jobs(status="completed")
assert len(completed_jobs["jobs"]) > 0, "Should have completed jobs from history"
for job in completed_jobs["jobs"]:
assert job["status"] == "completed", "Should only return completed jobs"
# Pending jobs are transient - just verify filter doesn't error
pending_jobs = client.get_jobs(status="pending")
for job in pending_jobs["jobs"]:
assert job["status"] == "pending", "Should only return pending jobs"
def test_get_job_by_id(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test getting a single job by ID"""
result = self._create_history_item(client, builder)
prompt_id = result.get_prompt_id()
job = client.get_job(prompt_id)
assert job is not None, "Should find the job"
assert job["id"] == prompt_id, "Job ID should match"
assert "outputs" in job, "Single job should include outputs"
def test_get_job_not_found(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test getting a non-existent job returns 404"""
job = client.get_job("nonexistent-job-id")
assert job is None, "Non-existent job should return None"

View File

@ -0,0 +1,361 @@
"""Unit tests for comfy_execution/jobs.py"""
from comfy_execution.jobs import (
JobStatus,
is_previewable,
normalize_queue_item,
normalize_history_item,
get_outputs_summary,
apply_sorting,
)
class TestJobStatus:
"""Test JobStatus constants."""
def test_status_values(self):
"""Status constants should have expected string values."""
assert JobStatus.PENDING == 'pending'
assert JobStatus.IN_PROGRESS == 'in_progress'
assert JobStatus.COMPLETED == 'completed'
assert JobStatus.FAILED == 'failed'
def test_all_contains_all_statuses(self):
"""ALL should contain all status values."""
assert JobStatus.PENDING in JobStatus.ALL
assert JobStatus.IN_PROGRESS in JobStatus.ALL
assert JobStatus.COMPLETED in JobStatus.ALL
assert JobStatus.FAILED in JobStatus.ALL
assert len(JobStatus.ALL) == 4
class TestIsPreviewable:
"""Unit tests for is_previewable()"""
def test_previewable_media_types(self):
"""Images, video, audio media types should be previewable."""
for media_type in ['images', 'video', 'audio']:
assert is_previewable(media_type, {}) is True
def test_non_previewable_media_types(self):
"""Other media types should not be previewable."""
for media_type in ['latents', 'text', 'metadata', 'files']:
assert is_previewable(media_type, {}) is False
def test_3d_extensions_previewable(self):
"""3D file extensions should be previewable regardless of media_type."""
for ext in ['.obj', '.fbx', '.gltf', '.glb']:
item = {'filename': f'model{ext}'}
assert is_previewable('files', item) is True
def test_3d_extensions_case_insensitive(self):
"""3D extension check should be case insensitive."""
item = {'filename': 'MODEL.GLB'}
assert is_previewable('files', item) is True
def test_video_format_previewable(self):
"""Items with video/ format should be previewable."""
item = {'format': 'video/mp4'}
assert is_previewable('files', item) is True
def test_audio_format_previewable(self):
"""Items with audio/ format should be previewable."""
item = {'format': 'audio/wav'}
assert is_previewable('files', item) is True
def test_other_format_not_previewable(self):
"""Items with other format should not be previewable."""
item = {'format': 'application/json'}
assert is_previewable('files', item) is False
class TestGetOutputsSummary:
"""Unit tests for get_outputs_summary()"""
def test_empty_outputs(self):
"""Empty outputs should return 0 count and None preview."""
count, preview = get_outputs_summary({})
assert count == 0
assert preview is None
def test_counts_across_multiple_nodes(self):
"""Outputs from multiple nodes should all be counted."""
outputs = {
'node1': {'images': [{'filename': 'a.png', 'type': 'output'}]},
'node2': {'images': [{'filename': 'b.png', 'type': 'output'}]},
'node3': {'images': [
{'filename': 'c.png', 'type': 'output'},
{'filename': 'd.png', 'type': 'output'}
]}
}
count, preview = get_outputs_summary(outputs)
assert count == 4
def test_skips_animated_key_and_non_list_values(self):
"""The 'animated' key and non-list values should be skipped."""
outputs = {
'node1': {
'images': [{'filename': 'test.png', 'type': 'output'}],
'animated': [True], # Should skip due to key name
'metadata': 'string', # Should skip due to non-list
'count': 42 # Should skip due to non-list
}
}
count, preview = get_outputs_summary(outputs)
assert count == 1
def test_preview_prefers_type_output(self):
"""Items with type='output' should be preferred for preview."""
outputs = {
'node1': {
'images': [
{'filename': 'temp.png', 'type': 'temp'},
{'filename': 'output.png', 'type': 'output'}
]
}
}
count, preview = get_outputs_summary(outputs)
assert count == 2
assert preview['filename'] == 'output.png'
def test_preview_fallback_when_no_output_type(self):
"""If no type='output', should use first previewable."""
outputs = {
'node1': {
'images': [
{'filename': 'temp1.png', 'type': 'temp'},
{'filename': 'temp2.png', 'type': 'temp'}
]
}
}
count, preview = get_outputs_summary(outputs)
assert preview['filename'] == 'temp1.png'
def test_non_previewable_media_types_counted_but_no_preview(self):
"""Non-previewable media types should be counted but not used as preview."""
outputs = {
'node1': {
'latents': [
{'filename': 'latent1.safetensors'},
{'filename': 'latent2.safetensors'}
]
}
}
count, preview = get_outputs_summary(outputs)
assert count == 2
assert preview is None
def test_previewable_media_types(self):
"""Images, video, and audio media types should be previewable."""
for media_type in ['images', 'video', 'audio']:
outputs = {
'node1': {
media_type: [{'filename': 'test.file', 'type': 'output'}]
}
}
count, preview = get_outputs_summary(outputs)
assert preview is not None, f"{media_type} should be previewable"
def test_3d_files_previewable(self):
"""3D file extensions should be previewable."""
for ext in ['.obj', '.fbx', '.gltf', '.glb']:
outputs = {
'node1': {
'files': [{'filename': f'model{ext}', 'type': 'output'}]
}
}
count, preview = get_outputs_summary(outputs)
assert preview is not None, f"3D file {ext} should be previewable"
def test_format_mime_type_previewable(self):
"""Files with video/ or audio/ format should be previewable."""
for fmt in ['video/x-custom', 'audio/x-custom']:
outputs = {
'node1': {
'files': [{'filename': 'file.custom', 'format': fmt, 'type': 'output'}]
}
}
count, preview = get_outputs_summary(outputs)
assert preview is not None, f"Format {fmt} should be previewable"
def test_preview_enriched_with_node_metadata(self):
"""Preview should include nodeId, mediaType, and original fields."""
outputs = {
'node123': {
'images': [{'filename': 'test.png', 'type': 'output', 'subfolder': 'outputs'}]
}
}
count, preview = get_outputs_summary(outputs)
assert preview['nodeId'] == 'node123'
assert preview['mediaType'] == 'images'
assert preview['subfolder'] == 'outputs'
class TestApplySorting:
"""Unit tests for apply_sorting()"""
def test_sort_by_create_time_desc(self):
"""Default sort by create_time descending."""
jobs = [
{'id': 'a', 'create_time': 100},
{'id': 'b', 'create_time': 300},
{'id': 'c', 'create_time': 200},
]
result = apply_sorting(jobs, 'created_at', 'desc')
assert [j['id'] for j in result] == ['b', 'c', 'a']
def test_sort_by_create_time_asc(self):
"""Sort by create_time ascending."""
jobs = [
{'id': 'a', 'create_time': 100},
{'id': 'b', 'create_time': 300},
{'id': 'c', 'create_time': 200},
]
result = apply_sorting(jobs, 'created_at', 'asc')
assert [j['id'] for j in result] == ['a', 'c', 'b']
def test_sort_by_execution_duration(self):
"""Sort by execution_duration should order by duration."""
jobs = [
{'id': 'a', 'create_time': 100, 'execution_start_time': 100, 'execution_end_time': 5100}, # 5s
{'id': 'b', 'create_time': 300, 'execution_start_time': 300, 'execution_end_time': 1300}, # 1s
{'id': 'c', 'create_time': 200, 'execution_start_time': 200, 'execution_end_time': 3200}, # 3s
]
result = apply_sorting(jobs, 'execution_duration', 'desc')
assert [j['id'] for j in result] == ['a', 'c', 'b']
def test_sort_with_none_values(self):
"""Jobs with None values should sort as 0."""
jobs = [
{'id': 'a', 'create_time': 100, 'execution_start_time': 100, 'execution_end_time': 5100},
{'id': 'b', 'create_time': 300, 'execution_start_time': None, 'execution_end_time': None},
{'id': 'c', 'create_time': 200, 'execution_start_time': 200, 'execution_end_time': 3200},
]
result = apply_sorting(jobs, 'execution_duration', 'asc')
assert result[0]['id'] == 'b' # None treated as 0, comes first
class TestNormalizeQueueItem:
"""Unit tests for normalize_queue_item()"""
def test_basic_normalization(self):
"""Queue item should be normalized to job dict."""
item = (
10, # priority/number
'prompt-123', # prompt_id
{'nodes': {}}, # prompt
{
'create_time': 1234567890,
'extra_pnginfo': {'workflow': {'id': 'workflow-abc'}}
}, # extra_data
['node1'], # outputs_to_execute
)
job = normalize_queue_item(item, JobStatus.PENDING)
assert job['id'] == 'prompt-123'
assert job['status'] == 'pending'
assert job['priority'] == 10
assert job['create_time'] == 1234567890
assert 'execution_start_time' not in job
assert 'execution_end_time' not in job
assert 'execution_error' not in job
assert 'preview_output' not in job
assert job['outputs_count'] == 0
assert job['workflow_id'] == 'workflow-abc'
class TestNormalizeHistoryItem:
"""Unit tests for normalize_history_item()"""
def test_completed_job(self):
"""Completed history item should have correct status and times from messages."""
history_item = {
'prompt': (
5, # priority
'prompt-456',
{'nodes': {}},
{
'create_time': 1234567890000,
'extra_pnginfo': {'workflow': {'id': 'workflow-xyz'}}
},
['node1'],
),
'status': {
'status_str': 'success',
'completed': True,
'messages': [
('execution_start', {'prompt_id': 'prompt-456', 'timestamp': 1234567890500}),
('execution_success', {'prompt_id': 'prompt-456', 'timestamp': 1234567893000}),
]
},
'outputs': {},
}
job = normalize_history_item('prompt-456', history_item)
assert job['id'] == 'prompt-456'
assert job['status'] == 'completed'
assert job['priority'] == 5
assert job['execution_start_time'] == 1234567890500
assert job['execution_end_time'] == 1234567893000
assert job['workflow_id'] == 'workflow-xyz'
def test_failed_job(self):
"""Failed history item should have failed status and error from messages."""
history_item = {
'prompt': (
5,
'prompt-789',
{'nodes': {}},
{'create_time': 1234567890000},
['node1'],
),
'status': {
'status_str': 'error',
'completed': False,
'messages': [
('execution_start', {'prompt_id': 'prompt-789', 'timestamp': 1234567890500}),
('execution_error', {
'prompt_id': 'prompt-789',
'node_id': '5',
'node_type': 'KSampler',
'exception_message': 'CUDA out of memory',
'exception_type': 'RuntimeError',
'traceback': ['Traceback...', 'RuntimeError: CUDA out of memory'],
'timestamp': 1234567891000,
})
]
},
'outputs': {},
}
job = normalize_history_item('prompt-789', history_item)
assert job['status'] == 'failed'
assert job['execution_start_time'] == 1234567890500
assert job['execution_end_time'] == 1234567891000
assert job['execution_error']['node_id'] == '5'
assert job['execution_error']['node_type'] == 'KSampler'
assert job['execution_error']['exception_message'] == 'CUDA out of memory'
def test_include_outputs(self):
"""When include_outputs=True, should include full output data."""
history_item = {
'prompt': (
5,
'prompt-123',
{'nodes': {'1': {}}},
{'create_time': 1234567890, 'client_id': 'abc'},
['node1'],
),
'status': {'status_str': 'success', 'completed': True, 'messages': []},
'outputs': {'node1': {'images': [{'filename': 'test.png'}]}},
}
job = normalize_history_item('prompt-123', history_item, include_outputs=True)
assert 'outputs' in job
assert 'workflow' in job
assert 'execution_status' in job
assert job['outputs'] == {'node1': {'images': [{'filename': 'test.png'}]}}
assert job['workflow'] == {
'prompt': {'nodes': {'1': {}}},
'extra_data': {'create_time': 1234567890, 'client_id': 'abc'},
}