From ecdc8697d53919a9178bf53ef327a110582db8ea Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 3 Dec 2025 19:49:28 -0800 Subject: [PATCH 1/6] Qwen Image Lora training fix from #11090 (#11094) --- comfy_extras/nodes_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index cb24ab709..19b8baaf4 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -623,7 +623,7 @@ class TrainLoraNode(io.ComfyNode): noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) if multi_res: # use first latent as dummy latent if multi_res - latents = latents[0].repeat(num_images, 1, 1, 1) + latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1))) guider.sample( noise.generate_noise({"samples": latents}), latents, From ea17add3c62197b10fd0b71d9169d339adc55c47 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 3 Dec 2025 20:15:15 -0800 Subject: [PATCH 2/6] Fix case where text encoders where running on the CPU instead of GPU. (#11095) --- comfy/sd.py | 2 ++ comfy/sd1_clip.py | 9 ++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index f9e5efab5..734bd2845 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -193,6 +193,7 @@ class CLIP: self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model() + self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) all_hooks.reset() self.patcher.patch_hooks(None) if show_pbar: @@ -240,6 +241,7 @@ class CLIP: self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model() + self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) o = self.cond_stage_model.encode_token_weights(tokens) cond, pooled = o[:2] if return_dict: diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 0fc9ab3db..503a51843 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -147,6 +147,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer_norm_hidden_state = layer_norm_hidden_state self.return_projected_pooled = return_projected_pooled self.return_attention_masks = return_attention_masks + self.execution_device = None if layer == "hidden": assert layer_idx is not None @@ -163,6 +164,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def set_clip_options(self, options): layer_idx = options.get("layer", self.layer_idx) self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) + self.execution_device = options.get("execution_device", self.execution_device) if isinstance(self.layer, list) or self.layer == "all": pass elif layer_idx is None or abs(layer_idx) > self.num_layers: @@ -175,6 +177,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer = self.options_default[0] self.layer_idx = self.options_default[1] self.return_projected_pooled = self.options_default[2] + self.execution_device = None def process_tokens(self, tokens, device): end_token = self.special_tokens.get("end", None) @@ -258,7 +261,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info def forward(self, tokens): - device = self.transformer.get_input_embeddings().weight.device + if self.execution_device is None: + device = self.transformer.get_input_embeddings().weight.device + else: + device = self.execution_device + embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device) attention_mask_model = None From 6be85c7920224b45bbc6417e00147815e78c12a9 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 4 Dec 2025 14:28:44 +1000 Subject: [PATCH 3/6] mp: use look-ahead actuals for stream offload VRAM calculation (#11096) TIL that the WAN TE has a 2GB weight followed by 16MB as the next size down. This means that team 8GB VRAM would fully offload the TE in async offload mode as it just multiplied this giant size my the num streams. Do the more complex logic of summing up the upcoming to-load weight sizes to avoid triple counting this massive weight. partial unload does the converse of recording the NS most recent unloads as they go. --- comfy/model_patcher.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index df2d8e827..3dcac3eef 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -699,12 +699,12 @@ class ModelPatcher: offloaded = [] offload_buffer = 0 loading.sort(reverse=True) - for x in loading: + for i, x in enumerate(loading): module_offload_mem, module_mem, n, m, params = x lowvram_weight = False - potential_offload = max(offload_buffer, module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem)) + potential_offload = max(offload_buffer, module_offload_mem + sum([ x1[1] for x1 in loading[i+1:i+1+comfy.model_management.NUM_STREAMS]])) lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory weight_key = "{}.weight".format(n) @@ -876,14 +876,18 @@ class ModelPatcher: patch_counter = 0 unload_list = self._load_list() unload_list.sort() + offload_buffer = self.model.model_offload_buffer_memory + if len(unload_list) > 0: + NS = comfy.model_management.NUM_STREAMS + offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS for unload in unload_list: if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed: break module_offload_mem, module_mem, n, m, params = unload - potential_offload = module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem) + potential_offload = module_offload_mem + sum(offload_weight_factor) lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: @@ -935,6 +939,8 @@ class ModelPatcher: m.comfy_patched_weights = False memory_freed += module_mem offload_buffer = max(offload_buffer, potential_offload) + offload_weight_factor.append(module_mem) + offload_weight_factor.pop(0) logging.debug("freed {}".format(n)) for param in params: From f4bdf5f8302ef10db99644a8672e614ddb29c473 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 5 Dec 2025 03:50:04 +1000 Subject: [PATCH 4/6] sd: revise hy VAE VRAM (#11105) This was recently collapsed down to rolling VAE through temporal. Clamp The time dimension. --- comfy/sd.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 734bd2845..fe4dd65f8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -483,8 +483,10 @@ class VAE: self.latent_dim = 3 self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) - self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype) - self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype) + #This is likely to significantly over-estimate with single image or low frame counts as the + #implementation is able to completely skip caching. Rework if used as an image only VAE + self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (1400 * min(9, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] elif "decoder.unpatcher3d.wavelets" in sd: self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8) From 9bc893c5bbd2838bdd15ebd40e3a3e548ce3e4f0 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 5 Dec 2025 03:50:36 +1000 Subject: [PATCH 5/6] sd: bump HY1.5 VAE estimate (#11107) Im able to push vram above estimate on partial unload. Bump the estimate. This is experimentally determined with a 720P and 480P datapoint calibrating for 24GB VRAM total. --- comfy/sd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index fe4dd65f8..03bdb33d5 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -471,7 +471,7 @@ class VAE: decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig}) self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) - self.memory_used_decode = lambda shape, dtype: (2800 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) elif "decoder.conv_in.conv.weight" in sd: ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig["conv3d"] = True From 3c8456223c5f6a41af7d99219b391c8c58acb552 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 5 Dec 2025 00:05:28 +0200 Subject: [PATCH 6/6] [API Nodes]: fixes and refactor (#11104) * chore(api-nodes): applied ruff's pyupgrade(python3.10) to api-nodes client's to folder * chore(api-nodes): add validate_video_frame_count function from LTX PR * chore(api-nodes): replace deprecated V1 imports * fix(api-nodes): the types returned by the "poll_op" function are now correct. --- comfy_api_nodes/util/__init__.py | 2 + comfy_api_nodes/util/_helpers.py | 14 +-- comfy_api_nodes/util/client.py | 145 ++++++++++++----------- comfy_api_nodes/util/conversions.py | 21 ++-- comfy_api_nodes/util/download_helpers.py | 20 ++-- comfy_api_nodes/util/request_logger.py | 2 - comfy_api_nodes/util/upload_helpers.py | 16 ++- comfy_api_nodes/util/validation_utils.py | 61 ++++++---- 8 files changed, 146 insertions(+), 135 deletions(-) diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index 80292fb3c..4cc22abfb 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -47,6 +47,7 @@ from .validation_utils import ( validate_string, validate_video_dimensions, validate_video_duration, + validate_video_frame_count, ) __all__ = [ @@ -94,6 +95,7 @@ __all__ = [ "validate_string", "validate_video_dimensions", "validate_video_duration", + "validate_video_frame_count", # Misc functions "get_fs_object_size", ] diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py index 328fe5227..491e6b6a8 100644 --- a/comfy_api_nodes/util/_helpers.py +++ b/comfy_api_nodes/util/_helpers.py @@ -2,8 +2,8 @@ import asyncio import contextlib import os import time +from collections.abc import Callable from io import BytesIO -from typing import Callable, Optional, Union from comfy.cli_args import args from comfy.model_management import processing_interrupted @@ -35,12 +35,12 @@ def default_base_url() -> str: async def sleep_with_interrupt( seconds: float, - node_cls: Optional[type[IO.ComfyNode]], - label: Optional[str] = None, - start_ts: Optional[float] = None, - estimated_total: Optional[int] = None, + node_cls: type[IO.ComfyNode] | None, + label: str | None = None, + start_ts: float | None = None, + estimated_total: int | None = None, *, - display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None, + display_callback: Callable[[type[IO.ComfyNode], str, int, int | None], None] | None = None, ): """ Sleep in 1s slices while: @@ -65,7 +65,7 @@ def mimetype_to_extension(mime_type: str) -> str: return mime_type.split("/")[-1].lower() -def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int: +def get_fs_object_size(path_or_object: str | BytesIO) -> int: if isinstance(path_or_object, str): return os.path.getsize(path_or_object) return len(path_or_object.getvalue()) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index bf01d7d36..bf37cba5f 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -4,10 +4,11 @@ import json import logging import time import uuid +from collections.abc import Callable, Iterable from dataclasses import dataclass from enum import Enum from io import BytesIO -from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union +from typing import Any, Literal, TypeVar from urllib.parse import urljoin, urlparse import aiohttp @@ -37,8 +38,8 @@ class ApiEndpoint: path: str, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET", *, - query_params: Optional[dict[str, Any]] = None, - headers: Optional[dict[str, str]] = None, + query_params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, ): self.path = path self.method = method @@ -52,18 +53,18 @@ class _RequestConfig: endpoint: ApiEndpoint timeout: float content_type: str - data: Optional[dict[str, Any]] - files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] - multipart_parser: Optional[Callable] + data: dict[str, Any] | None + files: dict[str, Any] | list[tuple[str, Any]] | None + multipart_parser: Callable | None max_retries: int retry_delay: float retry_backoff: float wait_label: str = "Waiting" monitor_progress: bool = True - estimated_total: Optional[int] = None - final_label_on_success: Optional[str] = "Completed" - progress_origin_ts: Optional[float] = None - price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None + estimated_total: int | None = None + final_label_on_success: str | None = "Completed" + progress_origin_ts: float | None = None + price_extractor: Callable[[dict[str, Any]], float | None] | None = None @dataclass @@ -71,10 +72,10 @@ class _PollUIState: started: float status_label: str = "Queued" is_queued: bool = True - price: Optional[float] = None - estimated_duration: Optional[int] = None + price: float | None = None + estimated_duration: int | None = None base_processing_elapsed: float = 0.0 # sum of completed active intervals - active_since: Optional[float] = None # start time of current active interval (None if queued) + active_since: float | None = None # start time of current active interval (None if queued) _RETRY_STATUS = {408, 429, 500, 502, 503, 504} @@ -87,20 +88,20 @@ async def sync_op( cls: type[IO.ComfyNode], endpoint: ApiEndpoint, *, - response_model: Type[M], - price_extractor: Optional[Callable[[M], Optional[float]]] = None, - data: Optional[BaseModel] = None, - files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, + response_model: type[M], + price_extractor: Callable[[M | Any], float | None] | None = None, + data: BaseModel | None = None, + files: dict[str, Any] | list[tuple[str, Any]] | None = None, content_type: str = "application/json", timeout: float = 3600.0, - multipart_parser: Optional[Callable] = None, + multipart_parser: Callable | None = None, max_retries: int = 3, retry_delay: float = 1.0, retry_backoff: float = 2.0, wait_label: str = "Waiting for server", - estimated_duration: Optional[int] = None, - final_label_on_success: Optional[str] = "Completed", - progress_origin_ts: Optional[float] = None, + estimated_duration: int | None = None, + final_label_on_success: str | None = "Completed", + progress_origin_ts: float | None = None, monitor_progress: bool = True, ) -> M: raw = await sync_op_raw( @@ -131,22 +132,22 @@ async def poll_op( cls: type[IO.ComfyNode], poll_endpoint: ApiEndpoint, *, - response_model: Type[M], - status_extractor: Callable[[M], Optional[Union[str, int]]], - progress_extractor: Optional[Callable[[M], Optional[int]]] = None, - price_extractor: Optional[Callable[[M], Optional[float]]] = None, - completed_statuses: Optional[list[Union[str, int]]] = None, - failed_statuses: Optional[list[Union[str, int]]] = None, - queued_statuses: Optional[list[Union[str, int]]] = None, - data: Optional[BaseModel] = None, + response_model: type[M], + status_extractor: Callable[[M | Any], str | int | None], + progress_extractor: Callable[[M | Any], int | None] | None = None, + price_extractor: Callable[[M | Any], float | None] | None = None, + completed_statuses: list[str | int] | None = None, + failed_statuses: list[str | int] | None = None, + queued_statuses: list[str | int] | None = None, + data: BaseModel | None = None, poll_interval: float = 5.0, max_poll_attempts: int = 120, timeout_per_poll: float = 120.0, max_retries_per_poll: int = 3, retry_delay_per_poll: float = 1.0, retry_backoff_per_poll: float = 2.0, - estimated_duration: Optional[int] = None, - cancel_endpoint: Optional[ApiEndpoint] = None, + estimated_duration: int | None = None, + cancel_endpoint: ApiEndpoint | None = None, cancel_timeout: float = 10.0, ) -> M: raw = await poll_op_raw( @@ -178,22 +179,22 @@ async def sync_op_raw( cls: type[IO.ComfyNode], endpoint: ApiEndpoint, *, - price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None, - data: Optional[Union[dict[str, Any], BaseModel]] = None, - files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, + price_extractor: Callable[[dict[str, Any]], float | None] | None = None, + data: dict[str, Any] | BaseModel | None = None, + files: dict[str, Any] | list[tuple[str, Any]] | None = None, content_type: str = "application/json", timeout: float = 3600.0, - multipart_parser: Optional[Callable] = None, + multipart_parser: Callable | None = None, max_retries: int = 3, retry_delay: float = 1.0, retry_backoff: float = 2.0, wait_label: str = "Waiting for server", - estimated_duration: Optional[int] = None, + estimated_duration: int | None = None, as_binary: bool = False, - final_label_on_success: Optional[str] = "Completed", - progress_origin_ts: Optional[float] = None, + final_label_on_success: str | None = "Completed", + progress_origin_ts: float | None = None, monitor_progress: bool = True, -) -> Union[dict[str, Any], bytes]: +) -> dict[str, Any] | bytes: """ Make a single network request. - If as_binary=False (default): returns JSON dict (or {'_raw': ''} if non-JSON). @@ -229,21 +230,21 @@ async def poll_op_raw( cls: type[IO.ComfyNode], poll_endpoint: ApiEndpoint, *, - status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]], - progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None, - price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None, - completed_statuses: Optional[list[Union[str, int]]] = None, - failed_statuses: Optional[list[Union[str, int]]] = None, - queued_statuses: Optional[list[Union[str, int]]] = None, - data: Optional[Union[dict[str, Any], BaseModel]] = None, + status_extractor: Callable[[dict[str, Any]], str | int | None], + progress_extractor: Callable[[dict[str, Any]], int | None] | None = None, + price_extractor: Callable[[dict[str, Any]], float | None] | None = None, + completed_statuses: list[str | int] | None = None, + failed_statuses: list[str | int] | None = None, + queued_statuses: list[str | int] | None = None, + data: dict[str, Any] | BaseModel | None = None, poll_interval: float = 5.0, max_poll_attempts: int = 120, timeout_per_poll: float = 120.0, max_retries_per_poll: int = 3, retry_delay_per_poll: float = 1.0, retry_backoff_per_poll: float = 2.0, - estimated_duration: Optional[int] = None, - cancel_endpoint: Optional[ApiEndpoint] = None, + estimated_duration: int | None = None, + cancel_endpoint: ApiEndpoint | None = None, cancel_timeout: float = 10.0, ) -> dict[str, Any]: """ @@ -261,7 +262,7 @@ async def poll_op_raw( consumed_attempts = 0 # counts only non-queued polls progress_bar = utils.ProgressBar(100) if progress_extractor else None - last_progress: Optional[int] = None + last_progress: int | None = None state = _PollUIState(started=started, estimated_duration=estimated_duration) stop_ticker = asyncio.Event() @@ -420,10 +421,10 @@ async def poll_op_raw( def _display_text( node_cls: type[IO.ComfyNode], - text: Optional[str], + text: str | None, *, - status: Optional[Union[str, int]] = None, - price: Optional[float] = None, + status: str | int | None = None, + price: float | None = None, ) -> None: display_lines: list[str] = [] if status: @@ -440,13 +441,13 @@ def _display_text( def _display_time_progress( node_cls: type[IO.ComfyNode], - status: Optional[Union[str, int]], + status: str | int | None, elapsed_seconds: int, - estimated_total: Optional[int] = None, + estimated_total: int | None = None, *, - price: Optional[float] = None, - is_queued: Optional[bool] = None, - processing_elapsed_seconds: Optional[int] = None, + price: float | None = None, + is_queued: bool | None = None, + processing_elapsed_seconds: int | None = None, ) -> None: if estimated_total is not None and estimated_total > 0 and is_queued is False: pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds @@ -488,7 +489,7 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]: raise ValueError("files tuple must be (filename, file[, content_type])") -def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]: +def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]: params = dict(endpoint_params or {}) if method.upper() == "GET" and data: for k, v in data.items(): @@ -534,9 +535,9 @@ def _generate_operation_id(method: str, path: str, attempt: int) -> str: def _snapshot_request_body_for_logging( content_type: str, method: str, - data: Optional[dict[str, Any]], - files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]], -) -> Optional[Union[dict[str, Any], str]]: + data: dict[str, Any] | None, + files: dict[str, Any] | list[tuple[str, Any]] | None, +) -> dict[str, Any] | str | None: if method.upper() == "GET": return None if content_type == "multipart/form-data": @@ -586,13 +587,13 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): attempt = 0 delay = cfg.retry_delay operation_succeeded: bool = False - final_elapsed_seconds: Optional[int] = None - extracted_price: Optional[float] = None + final_elapsed_seconds: int | None = None + extracted_price: float | None = None while True: attempt += 1 stop_event = asyncio.Event() - monitor_task: Optional[asyncio.Task] = None - sess: Optional[aiohttp.ClientSession] = None + monitor_task: asyncio.Task | None = None + sess: aiohttp.ClientSession | None = None operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt) logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt) @@ -887,7 +888,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): ) -def _validate_or_raise(response_model: Type[M], payload: Any) -> M: +def _validate_or_raise(response_model: type[M], payload: Any) -> M: try: return response_model.model_validate(payload) except Exception as e: @@ -902,9 +903,9 @@ def _validate_or_raise(response_model: Type[M], payload: Any) -> M: def _wrap_model_extractor( - response_model: Type[M], - extractor: Optional[Callable[[M], Any]], -) -> Optional[Callable[[dict[str, Any]], Any]]: + response_model: type[M], + extractor: Callable[[M], Any] | None, +) -> Callable[[dict[str, Any]], Any] | None: """Wrap a typed extractor so it can be used by the dict-based poller. Validates the dict into `response_model` before invoking `extractor`. Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating @@ -929,10 +930,10 @@ def _wrap_model_extractor( return _wrapped -def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]: +def _normalize_statuses(values: Iterable[str | int] | None) -> set[str | int]: if not values: return set() - out: set[Union[str, int]] = set() + out: set[str | int] = set() for v in values: nv = _normalize_status_value(v) if nv is not None: @@ -940,7 +941,7 @@ def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Unio return out -def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]: +def _normalize_status_value(val: str | int | None) -> str | int | None: if isinstance(val, str): return val.strip().lower() return val diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 971dc57de..c57457580 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -4,7 +4,6 @@ import math import mimetypes import uuid from io import BytesIO -from typing import Optional import av import numpy as np @@ -12,8 +11,7 @@ import torch from PIL import Image from comfy.utils import common_upscale -from comfy_api.latest import Input, InputImpl -from comfy_api.util import VideoCodec, VideoContainer +from comfy_api.latest import Input, InputImpl, Types from ._helpers import mimetype_to_extension @@ -57,7 +55,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to def tensor_to_bytesio( image: torch.Tensor, - name: Optional[str] = None, + name: str | None = None, total_pixels: int = 2048 * 2048, mime_type: str = "image/png", ) -> BytesIO: @@ -177,8 +175,8 @@ def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", co def video_to_base64_string( video: Input.Video, - container_format: VideoContainer = None, - codec: VideoCodec = None + container_format: Types.VideoContainer | None = None, + codec: Types.VideoCodec | None = None, ) -> str: """ Converts a video input to a base64 string. @@ -189,12 +187,11 @@ def video_to_base64_string( codec: Optional codec to use (defaults to video.codec if available) """ video_bytes_io = BytesIO() - - # Use provided format/codec if specified, otherwise use video's own if available - format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4) - codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264) - - video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use) + video.save_to( + video_bytes_io, + format=container_format or getattr(video, "container", Types.VideoContainer.MP4), + codec=codec or getattr(video, "codec", Types.VideoCodec.H264), + ) video_bytes_io.seek(0) return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8") diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index 14207dc68..3e0d0352d 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -3,15 +3,15 @@ import contextlib import uuid from io import BytesIO from pathlib import Path -from typing import IO, Optional, Union +from typing import IO from urllib.parse import urljoin, urlparse import aiohttp import torch from aiohttp.client_exceptions import ClientError, ContentTypeError -from comfy_api.input_impl import VideoFromFile from comfy_api.latest import IO as COMFY_IO +from comfy_api.latest import InputImpl from . import request_logger from ._helpers import ( @@ -29,9 +29,9 @@ _RETRY_STATUS = {408, 429, 500, 502, 503, 504} async def download_url_to_bytesio( url: str, - dest: Optional[Union[BytesIO, IO[bytes], str, Path]], + dest: BytesIO | IO[bytes] | str | Path | None, *, - timeout: Optional[float] = None, + timeout: float | None = None, max_retries: int = 5, retry_delay: float = 1.0, retry_backoff: float = 2.0, @@ -71,10 +71,10 @@ async def download_url_to_bytesio( is_path_sink = isinstance(dest, (str, Path)) fhandle = None - session: Optional[aiohttp.ClientSession] = None - stop_evt: Optional[asyncio.Event] = None - monitor_task: Optional[asyncio.Task] = None - req_task: Optional[asyncio.Task] = None + session: aiohttp.ClientSession | None = None + stop_evt: asyncio.Event | None = None + monitor_task: asyncio.Task | None = None + req_task: asyncio.Task | None = None try: with contextlib.suppress(Exception): @@ -234,11 +234,11 @@ async def download_url_to_video_output( timeout: float = None, max_retries: int = 5, cls: type[COMFY_IO.ComfyNode] = None, -) -> VideoFromFile: +) -> InputImpl.VideoFromFile: """Downloads a video from a URL and returns a `VIDEO` output.""" result = BytesIO() await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls) - return VideoFromFile(result) + return InputImpl.VideoFromFile(result) async def download_url_as_bytesio( diff --git a/comfy_api_nodes/util/request_logger.py b/comfy_api_nodes/util/request_logger.py index ac52e2eab..e0cb4428d 100644 --- a/comfy_api_nodes/util/request_logger.py +++ b/comfy_api_nodes/util/request_logger.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import datetime import hashlib import json diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 0532bea9a..b8d33f4d1 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -4,15 +4,13 @@ import logging import time import uuid from io import BytesIO -from typing import Optional from urllib.parse import urlparse import aiohttp import torch from pydantic import BaseModel, Field -from comfy_api.latest import IO, Input -from comfy_api.util import VideoCodec, VideoContainer +from comfy_api.latest import IO, Input, Types from . import request_logger from ._helpers import is_processing_interrupted, sleep_with_interrupt @@ -32,7 +30,7 @@ from .conversions import ( class UploadRequest(BaseModel): file_name: str = Field(..., description="Filename to upload") - content_type: Optional[str] = Field( + content_type: str | None = Field( None, description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.", ) @@ -56,7 +54,7 @@ async def upload_images_to_comfyapi( Uploads images to ComfyUI API and returns download URLs. To upload multiple images, stack them in the batch dimension first. """ - # if batch, try to upload each file if max_images is greater than 0 + # if batched, try to upload each file if max_images is greater than 0 download_urls: list[str] = [] is_batch = len(image.shape) > 3 batch_len = image.shape[0] if is_batch else 1 @@ -100,9 +98,9 @@ async def upload_video_to_comfyapi( cls: type[IO.ComfyNode], video: Input.Video, *, - container: VideoContainer = VideoContainer.MP4, - codec: VideoCodec = VideoCodec.H264, - max_duration: Optional[int] = None, + container: Types.VideoContainer = Types.VideoContainer.MP4, + codec: Types.VideoCodec = Types.VideoCodec.H264, + max_duration: int | None = None, wait_label: str | None = "Uploading", ) -> str: """ @@ -220,7 +218,7 @@ async def upload_file( return monitor_task = asyncio.create_task(_monitor()) - sess: Optional[aiohttp.ClientSession] = None + sess: aiohttp.ClientSession | None = None try: try: request_logger.log_request_response( diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py index ec7006aed..f01edea96 100644 --- a/comfy_api_nodes/util/validation_utils.py +++ b/comfy_api_nodes/util/validation_utils.py @@ -1,9 +1,7 @@ import logging -from typing import Optional import torch -from comfy_api.input.video_types import VideoInput from comfy_api.latest import Input @@ -18,10 +16,10 @@ def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]: def validate_image_dimensions( image: torch.Tensor, - min_width: Optional[int] = None, - max_width: Optional[int] = None, - min_height: Optional[int] = None, - max_height: Optional[int] = None, + min_width: int | None = None, + max_width: int | None = None, + min_height: int | None = None, + max_height: int | None = None, ): height, width = get_image_dimensions(image) @@ -37,8 +35,8 @@ def validate_image_dimensions( def validate_image_aspect_ratio( image: torch.Tensor, - min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4) - max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1) + min_ratio: tuple[float, float] | None = None, # e.g. (1, 4) + max_ratio: tuple[float, float] | None = None, # e.g. (4, 1) *, strict: bool = True, # True -> (min, max); False -> [min, max] ) -> float: @@ -54,8 +52,8 @@ def validate_image_aspect_ratio( def validate_images_aspect_ratio_closeness( first_image: torch.Tensor, second_image: torch.Tensor, - min_rel: float, # e.g. 0.8 - max_rel: float, # e.g. 1.25 + min_rel: float, # e.g. 0.8 + max_rel: float, # e.g. 1.25 *, strict: bool = False, # True -> (min, max); False -> [min, max] ) -> float: @@ -84,8 +82,8 @@ def validate_images_aspect_ratio_closeness( def validate_aspect_ratio_string( aspect_ratio: str, - min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4) - max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1) + min_ratio: tuple[float, float] | None = None, # e.g. (1, 4) + max_ratio: tuple[float, float] | None = None, # e.g. (4, 1) *, strict: bool = False, # True -> (min, max); False -> [min, max] ) -> float: @@ -97,10 +95,10 @@ def validate_aspect_ratio_string( def validate_video_dimensions( video: Input.Video, - min_width: Optional[int] = None, - max_width: Optional[int] = None, - min_height: Optional[int] = None, - max_height: Optional[int] = None, + min_width: int | None = None, + max_width: int | None = None, + min_height: int | None = None, + max_height: int | None = None, ): try: width, height = video.get_dimensions() @@ -120,8 +118,8 @@ def validate_video_dimensions( def validate_video_duration( video: Input.Video, - min_duration: Optional[float] = None, - max_duration: Optional[float] = None, + min_duration: float | None = None, + max_duration: float | None = None, ): try: duration = video.get_duration() @@ -136,6 +134,23 @@ def validate_video_duration( raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s") +def validate_video_frame_count( + video: Input.Video, + min_frame_count: int | None = None, + max_frame_count: int | None = None, +): + try: + frame_count = video.get_frame_count() + except Exception as e: + logging.error("Error getting frame count of video: %s", e) + return + + if min_frame_count is not None and min_frame_count > frame_count: + raise ValueError(f"Video frame count must be at least {min_frame_count}, got {frame_count}") + if max_frame_count is not None and frame_count > max_frame_count: + raise ValueError(f"Video frame count must be at most {max_frame_count}, got {frame_count}") + + def get_number_of_images(images): if isinstance(images, torch.Tensor): return images.shape[0] if images.ndim >= 4 else 1 @@ -144,8 +159,8 @@ def get_number_of_images(images): def validate_audio_duration( audio: Input.Audio, - min_duration: Optional[float] = None, - max_duration: Optional[float] = None, + min_duration: float | None = None, + max_duration: float | None = None, ) -> None: sr = int(audio["sample_rate"]) dur = int(audio["waveform"].shape[-1]) / sr @@ -177,7 +192,7 @@ def validate_string( ) -def validate_container_format_is_mp4(video: VideoInput) -> None: +def validate_container_format_is_mp4(video: Input.Video) -> None: """Validates video container format is MP4.""" container_format = video.get_container_format() if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: @@ -194,8 +209,8 @@ def _ratio_from_tuple(r: tuple[float, float]) -> float: def _assert_ratio_bounds( ar: float, *, - min_ratio: Optional[tuple[float, float]] = None, - max_ratio: Optional[tuple[float, float]] = None, + min_ratio: tuple[float, float] | None = None, + max_ratio: tuple[float, float] | None = None, strict: bool = True, ) -> None: """Validate a numeric aspect ratio against optional min/max ratio bounds."""