diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3eac77275..df2d8e827 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -704,7 +704,7 @@ class ModelPatcher: lowvram_weight = False - potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1)) + potential_offload = max(offload_buffer, module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem)) lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory weight_key = "{}.weight".format(n) @@ -883,7 +883,7 @@ class ModelPatcher: break module_offload_mem, module_mem, n, m, params = unload - potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem + potential_offload = module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem) lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: diff --git a/comfy/ops.py b/comfy/ops.py index 61a2f0754..eae434e68 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -111,22 +111,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if s.bias is not None: bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) - if bias_has_function: - with wf_context: - for f in s.bias_function: - bias = f(bias) + comfy.model_management.sync_stream(device, offload_stream) + + bias_a = bias + weight_a = weight + + if s.bias is not None: + for f in s.bias_function: + bias = f(bias) if weight_has_function or weight.dtype != dtype: - with wf_context: - weight = weight.to(dtype=dtype) - if isinstance(weight, QuantizedTensor): - weight = weight.dequantize() - for f in s.weight_function: - weight = f(weight) + weight = weight.to(dtype=dtype) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + for f in s.weight_function: + weight = f(weight) - 
comfy.model_management.sync_stream(device, offload_stream) if offloadable: - return weight, bias, offload_stream + return weight, bias, (offload_stream, weight_a, bias_a) else: #Legacy function signature return weight, bias @@ -135,13 +137,16 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of def uncast_bias_weight(s, weight, bias, offload_stream): if offload_stream is None: return - if weight is not None: - device = weight.device + os, weight_a, bias_a = offload_stream + if os is None: + return + if weight_a is not None: + device = weight_a.device else: - if bias is None: + if bias_a is None: return - device = bias.device - offload_stream.wait_stream(comfy.model_management.current_stream(device)) + device = bias_a.device + os.wait_stream(comfy.model_management.current_stream(device)) class CastWeightBiasOp: diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 5a5addd01..d89339c3d 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -5,7 +5,7 @@ import inspect from abc import ABC, abstractmethod from collections import Counter from collections.abc import Iterable -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field from enum import Enum from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING from typing_extensions import NotRequired, final @@ -1210,9 +1210,9 @@ class Schema: """Display name of node.""" category: str = "sd" """The category of the node, as per the "Add Node" menu.""" - inputs: list[Input]=None - outputs: list[Output]=None - hidden: list[Hidden]=None + inputs: list[Input] = field(default_factory=list) + outputs: list[Output] = field(default_factory=list) + hidden: list[Hidden] = field(default_factory=list) description: str="" """Node description, shown as a tooltip when hovering over the node.""" is_input_list: bool = False diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling_api.py index 
0a3b447c5..d8949f8ac 100644 --- a/comfy_api_nodes/apis/kling_api.py +++ b/comfy_api_nodes/apis/kling_api.py @@ -46,21 +46,41 @@ class TaskStatusVideoResult(BaseModel): url: str | None = Field(None, description="URL for generated video") -class TaskStatusVideoResults(BaseModel): +class TaskStatusImageResult(BaseModel): + index: int = Field(..., description="Image Number,0-9") + url: str = Field(..., description="URL for generated image") + + +class OmniTaskStatusResults(BaseModel): videos: list[TaskStatusVideoResult] | None = Field(None) + images: list[TaskStatusImageResult] | None = Field(None) -class TaskStatusVideoResponseData(BaseModel): +class OmniTaskStatusResponseData(BaseModel): created_at: int | None = Field(None, description="Task creation time") updated_at: int | None = Field(None, description="Task update time") task_status: str | None = None task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.") task_id: str | None = Field(None, description="Task ID") - task_result: TaskStatusVideoResults | None = Field(None) + task_result: OmniTaskStatusResults | None = Field(None) -class TaskStatusVideoResponse(BaseModel): +class OmniTaskStatusResponse(BaseModel): code: int | None = Field(None, description="Error code") message: str | None = Field(None, description="Error message") request_id: str | None = Field(None, description="Request ID") - data: TaskStatusVideoResponseData | None = Field(None) + data: OmniTaskStatusResponseData | None = Field(None) + + +class OmniImageParamImage(BaseModel): + image: str = Field(...) + + +class OmniProImageRequest(BaseModel): + model_name: str = Field(..., description="kling-image-o1") + resolution: str = Field(..., description="'1k' or '2k'") + aspect_ratio: str | None = Field(...) + prompt: str = Field(...) 
+ mode: str = Field("pro") + n: int | None = Field(1, le=9) + image_list: list[OmniImageParamImage] | None = Field(..., max_length=10) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 850c44db6..6c840dc47 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -6,6 +6,7 @@ For source of truth on the allowed permutations of request fields, please refere import logging import math +import re import torch from typing_extensions import override @@ -49,12 +50,14 @@ from comfy_api_nodes.apis import ( KlingSingleImageEffectModelName, ) from comfy_api_nodes.apis.kling_api import ( + OmniImageParamImage, OmniParamImage, OmniParamVideo, OmniProFirstLastFrameRequest, + OmniProImageRequest, OmniProReferences2VideoRequest, OmniProText2VideoRequest, - TaskStatusVideoResponse, + OmniTaskStatusResponse, ) from comfy_api_nodes.util import ( ApiEndpoint, @@ -210,7 +213,36 @@ VOICES_CONFIG = { } -async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVideoResponse) -> IO.NodeOutput: +def normalize_omni_prompt_references(prompt: str) -> str: + """ + Rewrites Kling Omni-style placeholders used in the app, like: + + @image, @image1, @image2, ... @imageN + @video, @video1, @video2, ... @videoN + + into the API-compatible form: + + <<<image_1>>>, <<<image_2>>>, ... + <<<video_1>>>, <<<video_2>>>, ... + + This is a UX shim for ComfyUI so users can type the same syntax as in the Kling app. + """ + if not prompt: + return prompt + + def _image_repl(match): + return f"<<<image_{match.group('idx') or '1'}>>>" + + def _video_repl(match): + return f"<<<video_{match.group('idx') or '1'}>>>" + + # (?<!\w) avoids matching e.g. "test@image.com" + # (?!\w) makes sure we only match @image / @image<digits> and not @imageFoo + prompt = re.sub(r"(?<!\w)@image(?P<idx>\d*)(?!\w)", _image_repl, prompt) + return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)", _video_repl, prompt) + + +async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput: if response.code: raise RuntimeError( f"Kling request failed. 
Code: {response.code}, Message: {response.message}, Data: {response.data}" @@ -218,8 +250,9 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVi final_response = await poll_op( cls, ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), - response_model=TaskStatusVideoResponse, + response_model=OmniTaskStatusResponse, status_extractor=lambda r: (r.data.task_status if r.data else None), + max_poll_attempts=160, ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) @@ -801,7 +834,7 @@ class OmniProTextToVideoNode(IO.ComfyNode): response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=TaskStatusVideoResponse, + response_model=OmniTaskStatusResponse, data=OmniProText2VideoRequest( model_name=model_name, prompt=prompt, @@ -864,6 +897,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): end_frame: Input.Image | None = None, reference_images: Input.Image | None = None, ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) if end_frame is not None and reference_images is not None: raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.") @@ -895,7 +929,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=TaskStatusVideoResponse, + response_model=OmniTaskStatusResponse, data=OmniProFirstLastFrameRequest( model_name=model_name, prompt=prompt, @@ -950,6 +984,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): duration: int, reference_images: Input.Image, ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) if get_number_of_images(reference_images) > 7: raise ValueError("The maximum number of reference images is 7.") 
@@ -962,7 +997,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=TaskStatusVideoResponse, + response_model=OmniTaskStatusResponse, data=OmniProReferences2VideoRequest( model_name=model_name, prompt=prompt, @@ -1023,6 +1058,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode): keep_original_sound: bool, reference_images: Input.Image | None = None, ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05) validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160) @@ -1045,7 +1081,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode): response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=TaskStatusVideoResponse, + response_model=OmniTaskStatusResponse, data=OmniProReferences2VideoRequest( model_name=model_name, prompt=prompt, @@ -1103,6 +1139,7 @@ class OmniProEditVideoNode(IO.ComfyNode): keep_original_sound: bool, reference_images: Input.Image | None = None, ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) validate_video_duration(video, min_duration=3.0, max_duration=10.05) validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160) @@ -1125,7 +1162,7 @@ class OmniProEditVideoNode(IO.ComfyNode): response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=TaskStatusVideoResponse, + response_model=OmniTaskStatusResponse, data=OmniProReferences2VideoRequest( model_name=model_name, prompt=prompt, @@ -1138,6 +1175,90 @@ class OmniProEditVideoNode(IO.ComfyNode): return await finish_omni_video_task(cls, response) +class 
OmniProImageNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProImageNode", + display_name="Kling Omni Image (Pro)", + category="api node/image/Kling", + description="Create or edit images with the latest model from Kling.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-image-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the image content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("resolution", options=["1K", "2K"]), + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"], + ), + IO.Image.Input( + "reference_images", + tooltip="Up to 10 additional reference images.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + resolution: str, + aspect_ratio: str, + reference_images: Input.Image | None = None, + ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) + validate_string(prompt, min_length=1, max_length=2500) + image_list: list[OmniImageParamImage] = [] + if reference_images is not None: + if get_number_of_images(reference_images) > 10: + raise ValueError("The maximum number of reference images is 10.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + image_list.append(OmniImageParamImage(image=i)) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"), + response_model=OmniTaskStatusResponse, + data=OmniProImageRequest( + model_name=model_name, + 
prompt=prompt, + resolution=resolution.lower(), + aspect_ratio=aspect_ratio, + image_list=image_list if image_list else None, + ), + ) + if response.code: + raise RuntimeError( + f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"), + response_model=OmniTaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url)) + + class KlingCameraControlT2VNode(IO.ComfyNode): """ Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera. @@ -1935,6 +2056,7 @@ class KlingExtension(ComfyExtension): OmniProImageToVideoNode, OmniProVideoToVideoNode, OmniProEditVideoNode, + # OmniProImageNode, # need support from backend ]