Merge branch 'master' into v3-improvements

Jedrzej Kosinski committed 2025-12-03 11:07:10 -08:00
commit 203a4e9b46
5 changed files with 183 additions and 36 deletions

View File

@@ -704,7 +704,7 @@ class ModelPatcher:
             lowvram_weight = False
-            potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
+            potential_offload = max(offload_buffer, module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem))
             lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory

             weight_key = "{}.weight".format(n)

@@ -883,7 +883,7 @@ class ModelPatcher:
                 break
             module_offload_mem, module_mem, n, m, params = unload
-            potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
+            potential_offload = module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem)
             lowvram_possible = hasattr(m, "comfy_cast_weights")

             if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
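
Both hunks make the same change: the worst-case VRAM needed while offloading a module is no longer estimated as (NUM_STREAMS + 1) copies of the module's offload memory, but as the offload memory itself plus one module-sized staging slot per copy stream. A rough sketch of the two formulas with made-up numbers (all values below are hypothetical, not taken from the code):

# Hypothetical values, for illustration only.
NUM_STREAMS = 2           # stands in for comfy.model_management.NUM_STREAMS
module_mem = 100          # VRAM held by the module's weights
module_offload_mem = 300  # VRAM needed while staging the module for offload
offload_buffer = 0

old_estimate = max(offload_buffer, module_offload_mem * (NUM_STREAMS + 1))         # 900
new_estimate = max(offload_buffer, module_offload_mem + NUM_STREAMS * module_mem)  # 500

Under these numbers the new bound is noticeably tighter, which appears to be the point: a smaller potential_offload lets mem_counter + module_mem + potential_offload clear the lowvram_model_memory check for more modules.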

View File

@@ -111,22 +111,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     if s.bias is not None:
         bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
-        if bias_has_function:
-            with wf_context:
-                for f in s.bias_function:
-                    bias = f(bias)
+
+    comfy.model_management.sync_stream(device, offload_stream)
+
+    bias_a = bias
+    weight_a = weight
+
+    if s.bias is not None:
+        for f in s.bias_function:
+            bias = f(bias)

     if weight_has_function or weight.dtype != dtype:
-        with wf_context:
-            weight = weight.to(dtype=dtype)
-            if isinstance(weight, QuantizedTensor):
-                weight = weight.dequantize()
-            for f in s.weight_function:
-                weight = f(weight)
+        weight = weight.to(dtype=dtype)
+        if isinstance(weight, QuantizedTensor):
+            weight = weight.dequantize()
+        for f in s.weight_function:
+            weight = f(weight)

-    comfy.model_management.sync_stream(device, offload_stream)
     if offloadable:
-        return weight, bias, offload_stream
+        return weight, bias, (offload_stream, weight_a, bias_a)
     else:
         #Legacy function signature
         return weight, bias

@@ -135,13 +137,16 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
 def uncast_bias_weight(s, weight, bias, offload_stream):
     if offload_stream is None:
         return
-    if weight is not None:
-        device = weight.device
+    os, weight_a, bias_a = offload_stream
+    if os is None:
+        return
+    if weight_a is not None:
+        device = weight_a.device
     else:
-        if bias is None:
+        if bias_a is None:
             return
-        device = bias.device
+        device = bias_a.device

-    offload_stream.wait_stream(comfy.model_management.current_stream(device))
+    os.wait_stream(comfy.model_management.current_stream(device))

 class CastWeightBiasOp:
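
The signature change is the interesting part here: cast_bias_weight now snapshots weight and bias right after the stream sync (weight_a, bias_a) and returns them inside the offload tuple, presumably so uncast_bias_weight can derive the device from the tensors that were actually copied on the offload stream even if a weight_function later replaced weight with something on another device. A minimal sketch of the pack/unpack contract (the Stream class and all names below are stand-ins, not ComfyUI's API):

class Stream:
    # Stand-in for a CUDA stream; only the one method the sketch needs.
    def wait_stream(self, other):
        print("waiting on", other)

def cast(weight, bias, offload_stream):
    weight_a, bias_a = weight, bias            # snapshot before any transforms
    weight = weight * 2                        # stand-in for s.weight_function
    return weight, bias, (offload_stream, weight_a, bias_a)

def uncast(token):
    if token is None:
        return
    os, weight_a, bias_a = token               # unpack exactly what cast() packed
    if os is None:
        return
    # Pick the stream to wait on from the snapshot tensors, not the transformed ones.
    os.wait_stream("current stream for weight_a.device")

w, b, token = cast(1.0, None, Stream())
uncast(token)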

View File

@@ -5,7 +5,7 @@ import inspect
 from abc import ABC, abstractmethod
 from collections import Counter
 from collections.abc import Iterable
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, field
 from enum import Enum
 from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING
 from typing_extensions import NotRequired, final

@@ -1210,9 +1210,9 @@ class Schema:
     """Display name of node."""
     category: str = "sd"
     """The category of the node, as per the "Add Node" menu."""
-    inputs: list[Input]=None
-    outputs: list[Output]=None
-    hidden: list[Hidden]=None
+    inputs: list[Input] = field(default_factory=list)
+    outputs: list[Output] = field(default_factory=list)
+    hidden: list[Hidden] = field(default_factory=list)
     description: str=""
     """Node description, shown as a tooltip when hovering over the node."""
     is_input_list: bool = False
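
The Schema change is the standard dataclass idiom for list-valued fields: a bare mutable default would be shared across instances (dataclasses actually rejects "inputs: list = []" with a ValueError), and a None default pushes "is None" checks onto every consumer. field(default_factory=list) gives each instance its own fresh list. A self-contained illustration:

from dataclasses import dataclass, field

@dataclass
class Example:
    inputs: list = field(default_factory=list)  # a fresh list per instance

a, b = Example(), Example()
a.inputs.append("image")
print(b.inputs)  # [] -- no cross-instance sharing, and no None guards needed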

View File

@@ -46,21 +46,41 @@ class TaskStatusVideoResult(BaseModel):
     url: str | None = Field(None, description="URL for generated video")

-class TaskStatusVideoResults(BaseModel):
+class TaskStatusImageResult(BaseModel):
+    index: int = Field(..., description="Image Number0-9")
+    url: str = Field(..., description="URL for generated image")
+
+
+class OmniTaskStatusResults(BaseModel):
     videos: list[TaskStatusVideoResult] | None = Field(None)
+    images: list[TaskStatusImageResult] | None = Field(None)

-class TaskStatusVideoResponseData(BaseModel):
+class OmniTaskStatusResponseData(BaseModel):
     created_at: int | None = Field(None, description="Task creation time")
     updated_at: int | None = Field(None, description="Task update time")
     task_status: str | None = None
     task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
     task_id: str | None = Field(None, description="Task ID")
-    task_result: TaskStatusVideoResults | None = Field(None)
+    task_result: OmniTaskStatusResults | None = Field(None)

-class TaskStatusVideoResponse(BaseModel):
+class OmniTaskStatusResponse(BaseModel):
     code: int | None = Field(None, description="Error code")
     message: str | None = Field(None, description="Error message")
     request_id: str | None = Field(None, description="Request ID")
-    data: TaskStatusVideoResponseData | None = Field(None)
+    data: OmniTaskStatusResponseData | None = Field(None)
+
+
+class OmniImageParamImage(BaseModel):
+    image: str = Field(...)
+
+
+class OmniProImageRequest(BaseModel):
+    model_name: str = Field(..., description="kling-image-o1")
+    resolution: str = Field(..., description="'1k' or '2k'")
+    aspect_ratio: str | None = Field(...)
+    prompt: str = Field(...)
+    mode: str = Field("pro")
+    n: int | None = Field(1, le=9)
+    image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
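
Assuming the models above behave like standard pydantic v2 models, a request body for the omni-image endpoint would be built and serialized like this (all field values are illustrative):

# Illustrative payload; values are made up.
req = OmniProImageRequest(
    model_name="kling-image-o1",
    resolution="1k",     # the node lowercases the "1K"/"2K" combo value before sending
    aspect_ratio="16:9",
    prompt="A cat wearing sunglasses, matching the style of <<<image_1>>>",
    image_list=[OmniImageParamImage(image="https://example.com/ref.png")],
)
print(req.model_dump_json(exclude_none=True))
# {"model_name":"kling-image-o1","resolution":"1k","aspect_ratio":"16:9",
#  "prompt":"...","mode":"pro","n":1,"image_list":[{"image":"https://example.com/ref.png"}]}

mode and n fall back to their declared defaults ("pro" and 1), and exclude_none drops any optional fields left unset.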

View File

@@ -6,6 +6,7 @@ For source of truth on the allowed permutations of request fields, please refere
 import logging
 import math
+import re

 import torch
 from typing_extensions import override

@@ -49,12 +50,14 @@ from comfy_api_nodes.apis import (
     KlingSingleImageEffectModelName,
 )
 from comfy_api_nodes.apis.kling_api import (
+    OmniImageParamImage,
     OmniParamImage,
     OmniParamVideo,
     OmniProFirstLastFrameRequest,
+    OmniProImageRequest,
     OmniProReferences2VideoRequest,
     OmniProText2VideoRequest,
-    TaskStatusVideoResponse,
+    OmniTaskStatusResponse,
 )
 from comfy_api_nodes.util import (
     ApiEndpoint,
@@ -210,7 +213,36 @@ VOICES_CONFIG = {
 }

-async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVideoResponse) -> IO.NodeOutput:
+def normalize_omni_prompt_references(prompt: str) -> str:
+    """
+    Rewrites Kling Omni-style placeholders used in the app, like:
+        @image, @image1, @image2, ... @imageN
+        @video, @video1, @video2, ... @videoN
+    into the API-compatible form:
+        <<<image_1>>>, <<<image_2>>>, ...
+        <<<video_1>>>, <<<video_2>>>, ...
+
+    This is a UX shim for ComfyUI so users can type the same syntax as in the Kling app.
+    """
+    if not prompt:
+        return prompt
+
+    def _image_repl(match):
+        return f"<<<image_{match.group('idx') or '1'}>>>"
+
+    def _video_repl(match):
+        return f"<<<video_{match.group('idx') or '1'}>>>"
+
+    # (?<!\w) avoids matching e.g. "test@image.com"
+    # (?!\w) makes sure we only match @image / @image<digits> and not @imageFoo
+    prompt = re.sub(r"(?<!\w)@image(?P<idx>\d*)(?!\w)", _image_repl, prompt)
+    return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)", _video_repl, prompt)
+
+
+async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput:
     if response.code:
         raise RuntimeError(
             f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
@@ -218,8 +250,9 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVi
     final_response = await poll_op(
         cls,
         ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
-        response_model=TaskStatusVideoResponse,
+        response_model=OmniTaskStatusResponse,
         status_extractor=lambda r: (r.data.task_status if r.data else None),
+        max_poll_attempts=160,
     )
     return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
@@ -801,7 +834,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
-            response_model=TaskStatusVideoResponse,
+            response_model=OmniTaskStatusResponse,
             data=OmniProText2VideoRequest(
                 model_name=model_name,
                 prompt=prompt,
@@ -864,6 +897,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
         end_frame: Input.Image | None = None,
         reference_images: Input.Image | None = None,
     ) -> IO.NodeOutput:
+        prompt = normalize_omni_prompt_references(prompt)
         validate_string(prompt, min_length=1, max_length=2500)
         if end_frame is not None and reference_images is not None:
             raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
@@ -895,7 +929,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
-            response_model=TaskStatusVideoResponse,
+            response_model=OmniTaskStatusResponse,
             data=OmniProFirstLastFrameRequest(
                 model_name=model_name,
                 prompt=prompt,
@@ -950,6 +984,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
         duration: int,
         reference_images: Input.Image,
     ) -> IO.NodeOutput:
+        prompt = normalize_omni_prompt_references(prompt)
         validate_string(prompt, min_length=1, max_length=2500)
         if get_number_of_images(reference_images) > 7:
             raise ValueError("The maximum number of reference images is 7.")
@@ -962,7 +997,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
-            response_model=TaskStatusVideoResponse,
+            response_model=OmniTaskStatusResponse,
             data=OmniProReferences2VideoRequest(
                 model_name=model_name,
                 prompt=prompt,
@@ -1023,6 +1058,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
         keep_original_sound: bool,
         reference_images: Input.Image | None = None,
     ) -> IO.NodeOutput:
+        prompt = normalize_omni_prompt_references(prompt)
         validate_string(prompt, min_length=1, max_length=2500)
         validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05)
         validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160)
@@ -1045,7 +1081,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
-            response_model=TaskStatusVideoResponse,
+            response_model=OmniTaskStatusResponse,
             data=OmniProReferences2VideoRequest(
                 model_name=model_name,
                 prompt=prompt,
@@ -1103,6 +1139,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
         keep_original_sound: bool,
         reference_images: Input.Image | None = None,
     ) -> IO.NodeOutput:
+        prompt = normalize_omni_prompt_references(prompt)
         validate_string(prompt, min_length=1, max_length=2500)
         validate_video_duration(video, min_duration=3.0, max_duration=10.05)
         validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160)
@@ -1125,7 +1162,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
-            response_model=TaskStatusVideoResponse,
+            response_model=OmniTaskStatusResponse,
             data=OmniProReferences2VideoRequest(
                 model_name=model_name,
                 prompt=prompt,
@@ -1138,6 +1175,90 @@ class OmniProEditVideoNode(IO.ComfyNode):
         return await finish_omni_video_task(cls, response)


+class OmniProImageNode(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls) -> IO.Schema:
+        return IO.Schema(
+            node_id="KlingOmniProImageNode",
+            display_name="Kling Omni Image (Pro)",
+            category="api node/image/Kling",
+            description="Create or edit images with the latest model from Kling.",
+            inputs=[
+                IO.Combo.Input("model_name", options=["kling-image-o1"]),
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    tooltip="A text prompt describing the image content. "
+                    "This can include both positive and negative descriptions.",
+                ),
+                IO.Combo.Input("resolution", options=["1K", "2K"]),
+                IO.Combo.Input(
+                    "aspect_ratio",
+                    options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"],
+                ),
+                IO.Image.Input(
+                    "reference_images",
+                    tooltip="Up to 10 additional reference images.",
+                    optional=True,
+                ),
+            ],
+            outputs=[
+                IO.Image.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model_name: str,
+        prompt: str,
+        resolution: str,
+        aspect_ratio: str,
+        reference_images: Input.Image | None = None,
+    ) -> IO.NodeOutput:
+        prompt = normalize_omni_prompt_references(prompt)
+        validate_string(prompt, min_length=1, max_length=2500)
+        image_list: list[OmniImageParamImage] = []
+        if reference_images is not None:
+            if get_number_of_images(reference_images) > 10:
+                raise ValueError("The maximum number of reference images is 10.")
+            for i in reference_images:
+                validate_image_dimensions(i, min_width=300, min_height=300)
+                validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+            for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
+                image_list.append(OmniImageParamImage(image=i))
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
+            response_model=OmniTaskStatusResponse,
+            data=OmniProImageRequest(
+                model_name=model_name,
+                prompt=prompt,
+                resolution=resolution.lower(),
+                aspect_ratio=aspect_ratio,
+                image_list=image_list if image_list else None,
+            ),
+        )
+        if response.code:
+            raise RuntimeError(
+                f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+            )
+        final_response = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
+            response_model=OmniTaskStatusResponse,
+            status_extractor=lambda r: (r.data.task_status if r.data else None),
+        )
+        return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
+
+
 class KlingCameraControlT2VNode(IO.ComfyNode):
     """
     Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera.
@@ -1935,6 +2056,7 @@ class KlingExtension(ComfyExtension):
             OmniProImageToVideoNode,
             OmniProVideoToVideoNode,
             OmniProEditVideoNode,
+            # OmniProImageNode,  # need support from backend
         ]