Merge branch 'master' into v3-improvements

Jedrzej Kosinski 2025-12-03 11:07:10 -08:00
commit 203a4e9b46
5 changed files with 183 additions and 36 deletions

View File

@@ -704,7 +704,7 @@ class ModelPatcher:
             lowvram_weight = False
 
-            potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
+            potential_offload = max(offload_buffer, module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem))
             lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
             weight_key = "{}.weight".format(n)
@@ -883,7 +883,7 @@ class ModelPatcher:
                 break
             module_offload_mem, module_mem, n, m, params = unload
-            potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
+            potential_offload = module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem)
             lowvram_possible = hasattr(m, "comfy_cast_weights")
 
             if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
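
For intuition about the revised reservation formula (applied in both hunks above), here is a rough, self-contained illustration; the sizes below are hypothetical and NUM_STREAMS is hard-coded rather than taken from comfy.model_management:

```python
# Hypothetical sizes, for illustration only -- not values from the diff.
NUM_STREAMS = 2
module_mem = 100          # MB held by this module's weights
module_offload_mem = 120  # MB the module needs while offloading
offload_buffer = 0

# Old estimate: scale the per-module offload cost by the stream count.
old_estimate = max(offload_buffer, module_offload_mem * (NUM_STREAMS + 1))         # 360
# New estimate: one offload cost plus one module-sized buffer per stream.
new_estimate = max(offload_buffer, module_offload_mem + NUM_STREAMS * module_mem)  # 320
print(old_estimate, new_estimate)
```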

View File

@@ -111,22 +111,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     if s.bias is not None:
         bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
-        if bias_has_function:
-            with wf_context:
-                for f in s.bias_function:
-                    bias = f(bias)
+
+    comfy.model_management.sync_stream(device, offload_stream)
+
+    bias_a = bias
+    weight_a = weight
+
+    if s.bias is not None:
+        for f in s.bias_function:
+            bias = f(bias)
 
     if weight_has_function or weight.dtype != dtype:
-        with wf_context:
-            weight = weight.to(dtype=dtype)
-            if isinstance(weight, QuantizedTensor):
-                weight = weight.dequantize()
-            for f in s.weight_function:
-                weight = f(weight)
+        weight = weight.to(dtype=dtype)
+        if isinstance(weight, QuantizedTensor):
+            weight = weight.dequantize()
+        for f in s.weight_function:
+            weight = f(weight)
 
-    comfy.model_management.sync_stream(device, offload_stream)
     if offloadable:
-        return weight, bias, offload_stream
+        return weight, bias, (offload_stream, weight_a, bias_a)
     else:
         #Legacy function signature
         return weight, bias
@@ -135,13 +137,16 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
 
 def uncast_bias_weight(s, weight, bias, offload_stream):
     if offload_stream is None:
         return
-    if weight is not None:
-        device = weight.device
+    os, weight_a, bias_a = offload_stream
+    if os is None:
+        return
+    if weight_a is not None:
+        device = weight_a.device
     else:
-        if bias is None:
+        if bias_a is None:
             return
-        device = bias.device
-    offload_stream.wait_stream(comfy.model_management.current_stream(device))
+        device = bias_a.device
+    os.wait_stream(comfy.model_management.current_stream(device))
 
 class CastWeightBiasOp:
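
With this change, the `offloadable=True` path returns a 3-tuple instead of a bare stream: the pre-function tensors (`weight_a`, `bias_a`) ride along so `uncast_bias_weight` can still resolve a device after `weight_function`/`bias_function` have replaced `weight` and `bias`. A minimal sketch of the calling pattern, assuming the two functions from the diff are in scope (the linear op is illustrative, not repo code):

```python
import torch

# Sketch only: how an op's forward might consume the new 3-tuple contract.
def forward_comfy_cast_weights(self, input):
    weight, bias, offload_ctx = cast_bias_weight(self, input, offloadable=True)
    out = torch.nn.functional.linear(input, weight, bias)
    # offload_ctx == (offload_stream, weight_a, bias_a); uncast unpacks it and
    # uses weight_a/bias_a only to pick the device for the stream wait.
    uncast_bias_weight(self, weight, bias, offload_ctx)
    return out
```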

View File

@@ -5,7 +5,7 @@ import inspect
 from abc import ABC, abstractmethod
 from collections import Counter
 from collections.abc import Iterable
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, field
 from enum import Enum
 from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING
 from typing_extensions import NotRequired, final
@@ -1210,9 +1210,9 @@ class Schema:
     """Display name of node."""
     category: str = "sd"
     """The category of the node, as per the "Add Node" menu."""
-    inputs: list[Input]=None
-    outputs: list[Output]=None
-    hidden: list[Hidden]=None
+    inputs: list[Input] = field(default_factory=list)
+    outputs: list[Output] = field(default_factory=list)
+    hidden: list[Hidden] = field(default_factory=list)
     description: str=""
     """Node description, shown as a tooltip when hovering over the node."""
     is_input_list: bool = False
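
Replacing `= None` with `field(default_factory=list)` gives each `Schema` its own fresh list and removes the need for None checks; a bare mutable default is not even allowed for dataclasses. A standalone illustration (`Demo` is a made-up class, not repo code):

```python
from dataclasses import dataclass, field

@dataclass
class Demo:
    # `items: list[int] = []` would raise ValueError at class-definition time;
    # default_factory constructs a new list per instance instead.
    items: list[int] = field(default_factory=list)

a, b = Demo(), Demo()
a.items.append(1)
print(b.items)  # [] -- instances do not share state, and no None checks are needed
```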

View File

@@ -46,21 +46,41 @@ class TaskStatusVideoResult(BaseModel):
     url: str | None = Field(None, description="URL for generated video")
 
 
-class TaskStatusVideoResults(BaseModel):
+class TaskStatusImageResult(BaseModel):
+    index: int = Field(..., description="Image Number (0-9)")
+    url: str = Field(..., description="URL for generated image")
+
+
+class OmniTaskStatusResults(BaseModel):
     videos: list[TaskStatusVideoResult] | None = Field(None)
+    images: list[TaskStatusImageResult] | None = Field(None)
 
 
-class TaskStatusVideoResponseData(BaseModel):
+class OmniTaskStatusResponseData(BaseModel):
     created_at: int | None = Field(None, description="Task creation time")
     updated_at: int | None = Field(None, description="Task update time")
     task_status: str | None = None
     task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
     task_id: str | None = Field(None, description="Task ID")
-    task_result: TaskStatusVideoResults | None = Field(None)
+    task_result: OmniTaskStatusResults | None = Field(None)
 
 
-class TaskStatusVideoResponse(BaseModel):
+class OmniTaskStatusResponse(BaseModel):
     code: int | None = Field(None, description="Error code")
     message: str | None = Field(None, description="Error message")
     request_id: str | None = Field(None, description="Request ID")
-    data: TaskStatusVideoResponseData | None = Field(None)
+    data: OmniTaskStatusResponseData | None = Field(None)
+
+
+class OmniImageParamImage(BaseModel):
+    image: str = Field(...)
+
+
+class OmniProImageRequest(BaseModel):
+    model_name: str = Field(..., description="kling-image-o1")
+    resolution: str = Field(..., description="'1k' or '2k'")
+    aspect_ratio: str | None = Field(...)
+    prompt: str = Field(...)
+    mode: str = Field("pro")
+    n: int | None = Field(1, le=9)
+    image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
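
As a sanity check on the new request shape, a minimal sketch of constructing and serializing the model above (illustrative values; assumes the pydantic v2 `model_dump` API, consistent with `max_length` on a list field):

```python
req = OmniProImageRequest(
    model_name="kling-image-o1",
    resolution="1k",
    aspect_ratio="16:9",
    prompt="A city street at dusk, styled after <<<image_1>>>",
    image_list=[OmniImageParamImage(image="https://example.com/ref.png")],  # hypothetical URL
)
# mode defaults to "pro" and n to 1; exclude_none drops unset optional fields.
print(req.model_dump(exclude_none=True))
```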

View File

@@ -6,6 +6,7 @@ For source of truth on the allowed permutations of request fields, please refere
 import logging
 import math
+import re
 import torch
 from typing_extensions import override
@@ -49,12 +50,14 @@ from comfy_api_nodes.apis import (
     KlingSingleImageEffectModelName,
 )
 from comfy_api_nodes.apis.kling_api import (
+    OmniImageParamImage,
     OmniParamImage,
     OmniParamVideo,
     OmniProFirstLastFrameRequest,
+    OmniProImageRequest,
     OmniProReferences2VideoRequest,
     OmniProText2VideoRequest,
-    TaskStatusVideoResponse,
+    OmniTaskStatusResponse,
 )
 from comfy_api_nodes.util import (
     ApiEndpoint,
@@ -210,7 +213,36 @@ VOICES_CONFIG = {
 }
 
 
-async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVideoResponse) -> IO.NodeOutput:
+def normalize_omni_prompt_references(prompt: str) -> str:
+    """
+    Rewrites Kling Omni-style placeholders used in the app, like:
+        @image, @image1, @image2, ... @imageN
+        @video, @video1, @video2, ... @videoN
+
+    into the API-compatible form:
+        <<<image_1>>>, <<<image_2>>>, ...
+        <<<video_1>>>, <<<video_2>>>, ...
+
+    This is a UX shim for ComfyUI so users can type the same syntax as in the Kling app.
+    """
+    if not prompt:
+        return prompt
+
+    def _image_repl(match):
+        return f"<<<image_{match.group('idx') or '1'}>>>"
+
+    def _video_repl(match):
+        return f"<<<video_{match.group('idx') or '1'}>>>"
+
+    # (?<!\w) avoids matching e.g. "test@image.com"
+    # (?!\w) makes sure we only match @image / @image<digits> and not @imageFoo
+    prompt = re.sub(r"(?<!\w)@image(?P<idx>\d*)(?!\w)", _image_repl, prompt)
+    return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)", _video_repl, prompt)
+
+
+async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput:
     if response.code:
         raise RuntimeError(
             f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
@@ -218,8 +250,9 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVi
     final_response = await poll_op(
         cls,
         ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
-        response_model=TaskStatusVideoResponse,
+        response_model=OmniTaskStatusResponse,
         status_extractor=lambda r: (r.data.task_status if r.data else None),
+        max_poll_attempts=160,
     )
     return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
@@ -801,7 +834,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
-            response_model=TaskStatusVideoResponse,
+            response_model=OmniTaskStatusResponse,
             data=OmniProText2VideoRequest(
                 model_name=model_name,
                 prompt=prompt,
@@ -864,6 +897,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
         end_frame: Input.Image | None = None,
         reference_images: Input.Image | None = None,
     ) -> IO.NodeOutput:
+        prompt = normalize_omni_prompt_references(prompt)
         validate_string(prompt, min_length=1, max_length=2500)
         if end_frame is not None and reference_images is not None:
             raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
@@ -895,7 +929,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
-            response_model=TaskStatusVideoResponse,
+            response_model=OmniTaskStatusResponse,
             data=OmniProFirstLastFrameRequest(
                 model_name=model_name,
                 prompt=prompt,
@@ -950,6 +984,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
         duration: int,
         reference_images: Input.Image,
     ) -> IO.NodeOutput:
+        prompt = normalize_omni_prompt_references(prompt)
         validate_string(prompt, min_length=1, max_length=2500)
         if get_number_of_images(reference_images) > 7:
             raise ValueError("The maximum number of reference images is 7.")
@@ -962,7 +997,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
-            response_model=TaskStatusVideoResponse,
+            response_model=OmniTaskStatusResponse,
             data=OmniProReferences2VideoRequest(
                 model_name=model_name,
                 prompt=prompt,
@@ -1023,6 +1058,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
         keep_original_sound: bool,
         reference_images: Input.Image | None = None,
     ) -> IO.NodeOutput:
+        prompt = normalize_omni_prompt_references(prompt)
         validate_string(prompt, min_length=1, max_length=2500)
         validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05)
         validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160)
@@ -1045,7 +1081,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
-            response_model=TaskStatusVideoResponse,
+            response_model=OmniTaskStatusResponse,
             data=OmniProReferences2VideoRequest(
                 model_name=model_name,
                 prompt=prompt,
@@ -1103,6 +1139,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
         keep_original_sound: bool,
         reference_images: Input.Image | None = None,
     ) -> IO.NodeOutput:
+        prompt = normalize_omni_prompt_references(prompt)
         validate_string(prompt, min_length=1, max_length=2500)
         validate_video_duration(video, min_duration=3.0, max_duration=10.05)
         validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160)
@@ -1125,7 +1162,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
-            response_model=TaskStatusVideoResponse,
+            response_model=OmniTaskStatusResponse,
             data=OmniProReferences2VideoRequest(
                 model_name=model_name,
                 prompt=prompt,
@@ -1138,6 +1175,90 @@ class OmniProEditVideoNode(IO.ComfyNode):
         return await finish_omni_video_task(cls, response)
 
 
+class OmniProImageNode(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls) -> IO.Schema:
+        return IO.Schema(
+            node_id="KlingOmniProImageNode",
+            display_name="Kling Omni Image (Pro)",
+            category="api node/image/Kling",
+            description="Create or edit images with the latest model from Kling.",
+            inputs=[
+                IO.Combo.Input("model_name", options=["kling-image-o1"]),
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    tooltip="A text prompt describing the image content. "
+                            "This can include both positive and negative descriptions.",
+                ),
+                IO.Combo.Input("resolution", options=["1K", "2K"]),
+                IO.Combo.Input(
+                    "aspect_ratio",
+                    options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"],
+                ),
+                IO.Image.Input(
+                    "reference_images",
+                    tooltip="Up to 10 additional reference images.",
+                    optional=True,
+                ),
+            ],
+            outputs=[
+                IO.Image.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model_name: str,
+        prompt: str,
+        resolution: str,
+        aspect_ratio: str,
+        reference_images: Input.Image | None = None,
+    ) -> IO.NodeOutput:
+        prompt = normalize_omni_prompt_references(prompt)
+        validate_string(prompt, min_length=1, max_length=2500)
+        image_list: list[OmniImageParamImage] = []
+        if reference_images is not None:
+            if get_number_of_images(reference_images) > 10:
+                raise ValueError("The maximum number of reference images is 10.")
+            for i in reference_images:
+                validate_image_dimensions(i, min_width=300, min_height=300)
+                validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+            for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
+                image_list.append(OmniImageParamImage(image=i))
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
+            response_model=OmniTaskStatusResponse,
+            data=OmniProImageRequest(
+                model_name=model_name,
+                prompt=prompt,
+                resolution=resolution.lower(),
+                aspect_ratio=aspect_ratio,
+                image_list=image_list if image_list else None,
+            ),
+        )
+        if response.code:
+            raise RuntimeError(
+                f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+            )
+        final_response = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
+            response_model=OmniTaskStatusResponse,
+            status_extractor=lambda r: (r.data.task_status if r.data else None),
+        )
+        return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
+
+
 class KlingCameraControlT2VNode(IO.ComfyNode):
     """
     Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera.
@@ -1935,6 +2056,7 @@ class KlingExtension(ComfyExtension):
             OmniProImageToVideoNode,
             OmniProVideoToVideoNode,
             OmniProEditVideoNode,
+            # OmniProImageNode,  # need support from backend
         ]