Compare commits

...

10 Commits

Author SHA1 Message Date
Kohaku-Blueleaf
9f018ddb3f
Merge 37139daa98c354c80065464e6d53c09fac8cd157 into fd271dedfde6e192a1f1a025521070876e89e04a 2025-12-08 15:37:19 +02:00
Alexander Piskun
fd271dedfd
[API Nodes] add support for seedance-1-0-pro-fast model (#10947)
* feat(api-nodes): add support for seedance-1-0-pro-fast model

* feat(api-nodes): add support for seedream-4.5 model
2025-12-08 01:33:46 -08:00
Alexander Piskun
c3c6313fc7
Added "system_prompt" input to Gemini nodes (#11177) 2025-12-08 01:28:17 -08:00
Alexander Piskun
85c4b4ae26
chore: replace imports of deprecated V1 classes (#11127) 2025-12-08 01:27:02 -08:00
ComfyUI Wiki
058f084371
Update workflow templates to v0.7.51 (#11150)
* chore: update workflow templates to v0.7.50

* Update template to 0.7.51
2025-12-08 01:22:51 -08:00
Alexander Piskun
ec7f65187d
chore(comfy_api): replace absolute imports with relative (#11145) 2025-12-08 01:21:41 -08:00
Kohaku-Blueleaf
37139daa98 Merge branch 'master' into resolution-bucket 2025-12-05 17:26:15 +08:00
Kohaku-Blueleaf
4004af3290 Custom guider for correct offloading behavior 2025-12-05 17:24:55 +08:00
Kohaku-Blueleaf
bf573e94a2 Refactoring with better layout for maintainability 2025-12-01 23:53:20 +08:00
Kohaku-Blueleaf
7a93c55a9f Add resolution bucketing 2025-12-01 23:31:26 +08:00
20 changed files with 1143 additions and 521 deletions

View File

@ -122,20 +122,21 @@ def estimate_memory(model, noise_shape, conds):
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
return memory_required, minimum_memory_required
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, skip_load_model=False):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_prepare_sampling,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
)
return executor.execute(model, noise_shape, conds, model_options=model_options)
return executor.execute(model, noise_shape, conds, model_options=model_options, skip_load_model=skip_load_model)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, skip_load_model=False):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
models_list = [model] if not skip_load_model else []
comfy.model_management.load_models_gpu(models_list + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
real_model = model.model
return real_model, conds, models
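
For context, the new skip_load_model flag lets a caller that manages model loading itself (such as the TrainGuider added further below) exclude the main model from load_models_gpu while the additional models are still loaded. A minimal sketch of such a call, assuming a ModelPatcher named model_patcher and the signatures shown in this hunk:

# Sketch only: a caller that handles loading elsewhere skips the main model here.
real_model, conds, loaded_models = comfy.sampler_helpers.prepare_sampling(
    model_patcher,
    noise.shape,
    conds,
    model_options,
    skip_load_model=True,  # only the additional models reach load_models_gpu
)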

View File

@ -5,9 +5,9 @@ from typing import Type, TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from ._input_impl import VideoFromFile, VideoFromComponents
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from . import _io_public as io
from . import _ui_public as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
@ -80,7 +80,7 @@ class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
Called when an extension is loaded.
This should be used to initialize any global resources neeeded by the extension.
This should be used to initialize any global resources needed by the extension.
"""
@abstractmethod

View File

@ -4,7 +4,7 @@ from fractions import Fraction
from typing import Optional, Union, IO
import io
import av
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
from .._util import VideoContainer, VideoCodec, VideoComponents
class VideoInput(ABC):
"""

View File

@ -3,14 +3,14 @@ from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
from comfy_api.latest._input import AudioInput, VideoInput
from .._input import AudioInput, VideoInput
import av
import io
import json
import numpy as np
import math
import torch
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
from .._util import VideoContainer, VideoCodec, VideoComponents
def container_to_output_format(container_format: str | None) -> str | None:

View File

@ -26,7 +26,7 @@ if TYPE_CHECKING:
from comfy_api.input import VideoInput
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
from comfy_api.latest._resources import Resources, ResourcesLocal
from ._resources import Resources, ResourcesLocal
from comfy_execution.graph_utils import ExecutionBlocker
from ._util import MESH, VOXEL

View File

@ -22,7 +22,7 @@ import folder_paths
# used for image preview
from comfy.cli_args import args
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
from ._io import ComfyNode, FolderType, Image, _UIOutput
class SavedResult(dict):

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass
from enum import Enum
from fractions import Fraction
from typing import Optional
from comfy_api.latest._input import ImageInput, AudioInput
from .._input import ImageInput, AudioInput
class VideoCodec(str, Enum):
AUTO = "auto"

View File

@ -0,0 +1,144 @@
from typing import Literal
from pydantic import BaseModel, Field
class Text2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
response_format: str | None = Field("url")
size: str | None = Field(None)
seed: int | None = Field(0, ge=0, le=2147483647)
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
watermark: bool | None = Field(True)
class Image2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
response_format: str | None = Field("url")
image: str = Field(..., description="Base64 encoded string or image URL")
size: str | None = Field("adaptive")
seed: int | None = Field(..., ge=0, le=2147483647)
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
watermark: bool | None = Field(True)
class Seedream4Options(BaseModel):
max_images: int = Field(15)
class Seedream4TaskCreationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
response_format: str = Field("url")
image: list[str] | None = Field(None, description="Image URLs")
size: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
sequential_image_generation: str = Field("disabled")
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
watermark: bool = Field(True)
class ImageTaskCreationResponse(BaseModel):
model: str = Field(...)
created: int = Field(..., description="Unix timestamp (in seconds) indicating when the request was created.")
data: list = Field([], description="Contains information about the generated image(s).")
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
class TaskTextContent(BaseModel):
type: str = Field("text")
text: str = Field(...)
class TaskImageContentUrl(BaseModel):
url: str = Field(...)
class TaskImageContent(BaseModel):
type: str = Field("image_url")
image_url: TaskImageContentUrl = Field(...)
role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
class Text2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
content: list[TaskTextContent] = Field(..., min_length=1)
class Image2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2)
class TaskCreationResponse(BaseModel):
id: str = Field(...)
class TaskStatusError(BaseModel):
code: str = Field(...)
message: str = Field(...)
class TaskStatusResult(BaseModel):
video_url: str = Field(...)
class TaskStatusResponse(BaseModel):
id: str = Field(...)
model: str = Field(...)
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
error: TaskStatusError | None = Field(None)
content: TaskStatusResult | None = Field(None)
RECOMMENDED_PRESETS = [
("1024x1024 (1:1)", 1024, 1024),
("864x1152 (3:4)", 864, 1152),
("1152x864 (4:3)", 1152, 864),
("1280x720 (16:9)", 1280, 720),
("720x1280 (9:16)", 720, 1280),
("832x1248 (2:3)", 832, 1248),
("1248x832 (3:2)", 1248, 832),
("1512x648 (21:9)", 1512, 648),
("2048x2048 (1:1)", 2048, 2048),
("Custom", None, None),
]
RECOMMENDED_PRESETS_SEEDREAM_4 = [
("2048x2048 (1:1)", 2048, 2048),
("2304x1728 (4:3)", 2304, 1728),
("1728x2304 (3:4)", 1728, 2304),
("2560x1440 (16:9)", 2560, 1440),
("1440x2560 (9:16)", 1440, 2560),
("2496x1664 (3:2)", 2496, 1664),
("1664x2496 (2:3)", 1664, 2496),
("3024x1296 (21:9)", 3024, 1296),
("4096x4096 (1:1)", 4096, 4096),
("Custom", None, None),
]
# The times in this dictionary are given for a 10-second duration.
VIDEO_TASKS_EXECUTION_TIME = {
"seedance-1-0-lite-t2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-lite-i2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-pro-250528": {
"480p": 70,
"720p": 85,
"1080p": 115,
},
"seedance-1-0-pro-fast-251015": {
"480p": 50,
"720p": 65,
"1080p": 100,
},
}
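
Since the times above are baselines for a 10-second clip, a per-request estimate can scale linearly with the requested duration. A small sketch (the function name is hypothetical; only VIDEO_TASKS_EXECUTION_TIME comes from this file):

def estimate_video_task_seconds(model: str, resolution: str, duration: int) -> int | None:
    # Hypothetical helper: scale the 10-second baseline linearly with duration.
    base = VIDEO_TASKS_EXECUTION_TIME.get(model, {}).get(resolution)
    if base is None:
        return None
    return round(base * duration / 10)

# estimate_video_task_seconds("seedance-1-0-pro-fast-251015", "720p", 4) -> 26  (65 * 4 / 10)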

View File

@ -84,15 +84,7 @@ class GeminiSystemInstructionContent(BaseModel):
description="A list of ordered parts that make up a single message. "
"Different parts may have different IANA MIME types.",
)
role: GeminiRole = Field(
...,
description="The identity of the entity that creates the message. "
"The following values are supported: "
"user: This indicates that the message is sent by a real person, typically a user-generated message. "
"model: This indicates that the message is generated by the model. "
"The model value is used to insert messages from model into the conversation during multi-turn conversations. "
"For non-multi-turn conversations, this field can be left blank or unset.",
)
role: GeminiRole | None = Field(..., description="The role field of systemInstruction may be ignored.")
class GeminiFunctionDeclaration(BaseModel):

View File

@ -85,7 +85,7 @@ class Response1(BaseModel):
raiMediaFilteredReasons: Optional[list[str]] = Field(
None, description='Reasons why media was filtered by responsible AI policies'
)
videos: Optional[list[Video]] = None
videos: Optional[list[Video]] = Field(None)
class VeoGenVidPollResponse(BaseModel):

View File

@ -1,13 +1,27 @@
import logging
import math
from enum import Enum
from typing import Literal, Optional, Union
import torch
from pydantic import BaseModel, Field
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bytedance_api import (
RECOMMENDED_PRESETS,
RECOMMENDED_PRESETS_SEEDREAM_4,
VIDEO_TASKS_EXECUTION_TIME,
Image2ImageTaskCreationRequest,
Image2VideoTaskCreationRequest,
ImageTaskCreationResponse,
Seedream4Options,
Seedream4TaskCreationRequest,
TaskCreationResponse,
TaskImageContent,
TaskImageContentUrl,
TaskStatusResponse,
TaskTextContent,
Text2ImageTaskCreationRequest,
Text2VideoTaskCreationRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_image_tensor,
@ -29,162 +43,6 @@ BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
class Text2ImageModelName(str, Enum):
seedream_3 = "seedream-3-0-t2i-250415"
class Image2ImageModelName(str, Enum):
seededit_3 = "seededit-3-0-i2i-250628"
class Text2VideoModelName(str, Enum):
seedance_1_pro = "seedance-1-0-pro-250528"
seedance_1_lite = "seedance-1-0-lite-t2v-250428"
class Image2VideoModelName(str, Enum):
"""note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757"""
seedance_1_pro = "seedance-1-0-pro-250528"
seedance_1_lite = "seedance-1-0-lite-i2v-250428"
class Text2ImageTaskCreationRequest(BaseModel):
model: Text2ImageModelName = Text2ImageModelName.seedream_3
prompt: str = Field(...)
response_format: Optional[str] = Field("url")
size: Optional[str] = Field(None)
seed: Optional[int] = Field(0, ge=0, le=2147483647)
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
watermark: Optional[bool] = Field(True)
class Image2ImageTaskCreationRequest(BaseModel):
model: Image2ImageModelName = Image2ImageModelName.seededit_3
prompt: str = Field(...)
response_format: Optional[str] = Field("url")
image: str = Field(..., description="Base64 encoded string or image URL")
size: Optional[str] = Field("adaptive")
seed: Optional[int] = Field(..., ge=0, le=2147483647)
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
watermark: Optional[bool] = Field(True)
class Seedream4Options(BaseModel):
max_images: int = Field(15)
class Seedream4TaskCreationRequest(BaseModel):
model: str = Field("seedream-4-0-250828")
prompt: str = Field(...)
response_format: str = Field("url")
image: Optional[list[str]] = Field(None, description="Image URLs")
size: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
sequential_image_generation: str = Field("disabled")
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
watermark: bool = Field(True)
class ImageTaskCreationResponse(BaseModel):
model: str = Field(...)
created: int = Field(..., description="Unix timestamp (in seconds) indicating when the request was created.")
data: list = Field([], description="Contains information about the generated image(s).")
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
class TaskTextContent(BaseModel):
type: str = Field("text")
text: str = Field(...)
class TaskImageContentUrl(BaseModel):
url: str = Field(...)
class TaskImageContent(BaseModel):
type: str = Field("image_url")
image_url: TaskImageContentUrl = Field(...)
role: Optional[Literal["first_frame", "last_frame", "reference_image"]] = Field(None)
class Text2VideoTaskCreationRequest(BaseModel):
model: Text2VideoModelName = Text2VideoModelName.seedance_1_pro
content: list[TaskTextContent] = Field(..., min_length=1)
class Image2VideoTaskCreationRequest(BaseModel):
model: Image2VideoModelName = Image2VideoModelName.seedance_1_pro
content: list[Union[TaskTextContent, TaskImageContent]] = Field(..., min_length=2)
class TaskCreationResponse(BaseModel):
id: str = Field(...)
class TaskStatusError(BaseModel):
code: str = Field(...)
message: str = Field(...)
class TaskStatusResult(BaseModel):
video_url: str = Field(...)
class TaskStatusResponse(BaseModel):
id: str = Field(...)
model: str = Field(...)
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
error: Optional[TaskStatusError] = Field(None)
content: Optional[TaskStatusResult] = Field(None)
RECOMMENDED_PRESETS = [
("1024x1024 (1:1)", 1024, 1024),
("864x1152 (3:4)", 864, 1152),
("1152x864 (4:3)", 1152, 864),
("1280x720 (16:9)", 1280, 720),
("720x1280 (9:16)", 720, 1280),
("832x1248 (2:3)", 832, 1248),
("1248x832 (3:2)", 1248, 832),
("1512x648 (21:9)", 1512, 648),
("2048x2048 (1:1)", 2048, 2048),
("Custom", None, None),
]
RECOMMENDED_PRESETS_SEEDREAM_4 = [
("2048x2048 (1:1)", 2048, 2048),
("2304x1728 (4:3)", 2304, 1728),
("1728x2304 (3:4)", 1728, 2304),
("2560x1440 (16:9)", 2560, 1440),
("1440x2560 (9:16)", 1440, 2560),
("2496x1664 (3:2)", 2496, 1664),
("1664x2496 (2:3)", 1664, 2496),
("3024x1296 (21:9)", 3024, 1296),
("4096x4096 (1:1)", 4096, 4096),
("Custom", None, None),
]
# The times in this dictionary are given for a 10-second duration.
VIDEO_TASKS_EXECUTION_TIME = {
"seedance-1-0-lite-t2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-lite-i2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-pro-250528": {
"480p": 70,
"720p": 85,
"1080p": 115,
},
}
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
if response.error:
error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
@ -194,13 +52,6 @@ def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
return response.data[0]["url"]
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
"""Returns the video URL from the task status response if it exists."""
if hasattr(response, "content") and response.content:
return response.content.video_url
return None
class ByteDanceImageNode(IO.ComfyNode):
@classmethod
@ -211,12 +62,7 @@ class ByteDanceImageNode(IO.ComfyNode):
category="api node/image/ByteDance",
description="Generate images using ByteDance models via api based on prompt",
inputs=[
IO.Combo.Input(
"model",
options=Text2ImageModelName,
default=Text2ImageModelName.seedream_3,
tooltip="Model name",
),
IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]),
IO.String.Input(
"prompt",
multiline=True,
@ -335,12 +181,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
category="api node/image/ByteDance",
description="Edit images using ByteDance models via api based on prompt",
inputs=[
IO.Combo.Input(
"model",
options=Image2ImageModelName,
default=Image2ImageModelName.seededit_3,
tooltip="Model name",
),
IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]),
IO.Image.Input(
"image",
tooltip="The base image to edit",
@ -394,7 +235,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
async def execute(
cls,
model: str,
image: torch.Tensor,
image: Input.Image,
prompt: str,
seed: int,
guidance_scale: float,
@ -434,7 +275,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=["seedream-4-0-250828"],
options=["seedream-4-5-251128", "seedream-4-0-250828"],
tooltip="Model name",
),
IO.String.Input(
@ -459,7 +300,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
default=2048,
min=1024,
max=4096,
step=64,
step=8,
tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`",
optional=True,
),
@ -468,7 +309,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
default=2048,
min=1024,
max=4096,
step=64,
step=8,
tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`",
optional=True,
),
@ -532,7 +373,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
image: torch.Tensor = None,
image: Input.Image | None = None,
size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0],
width: int = 2048,
height: int = 2048,
@ -555,6 +396,18 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
raise ValueError(
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels."
)
out_num_pixels = w * h
mp_provided = out_num_pixels / 1_000_000.0
if "seedream-4-5" in model and out_num_pixels < 3686400:
raise ValueError(
f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, "
f"but {mp_provided:.2f}MP provided."
)
if "seedream-4-0" in model and out_num_pixels < 921600:
raise ValueError(
f"Minimum image resolution that the selected model can generate is 0.92MP, "
f"but {mp_provided:.2f}MP provided."
)
n_input_images = get_number_of_images(image) if image is not None else 0
if n_input_images > 10:
raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.")
@ -607,9 +460,8 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=Text2VideoModelName,
default=Text2VideoModelName.seedance_1_pro,
tooltip="Model name",
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
default="seedance-1-0-pro-fast-251015",
),
IO.String.Input(
"prompt",
@ -714,9 +566,8 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=Image2VideoModelName,
default=Image2VideoModelName.seedance_1_pro,
tooltip="Model name",
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
default="seedance-1-0-pro-fast-251015",
),
IO.String.Input(
"prompt",
@ -787,7 +638,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
image: torch.Tensor,
image: Input.Image,
resolution: str,
aspect_ratio: str,
duration: int,
@ -833,9 +684,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=[model.value for model in Image2VideoModelName],
default=Image2VideoModelName.seedance_1_lite.value,
tooltip="Model name",
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
default="seedance-1-0-lite-i2v-250428",
),
IO.String.Input(
"prompt",
@ -910,8 +760,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
first_frame: Input.Image,
last_frame: Input.Image,
resolution: str,
aspect_ratio: str,
duration: int,
@ -968,9 +818,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=[Image2VideoModelName.seedance_1_lite.value],
default=Image2VideoModelName.seedance_1_lite.value,
tooltip="Model name",
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
default="seedance-1-0-lite-i2v-250428",
),
IO.String.Input(
"prompt",
@ -1034,7 +883,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
images: torch.Tensor,
images: Input.Image,
resolution: str,
aspect_ratio: str,
duration: int,
@ -1069,8 +918,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
async def process_video_task(
cls: type[IO.ComfyNode],
payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
estimated_duration: Optional[int],
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
estimated_duration: int | None,
) -> IO.NodeOutput:
initial_response = await sync_op(
cls,
@ -1085,7 +934,7 @@ async def process_video_task(
estimated_duration=estimated_duration,
response_model=TaskStatusResponse,
)
return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response)))
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
def raise_if_text_params(prompt: str, text_params: list[str]) -> None:
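
The two pixel floors in the ByteDanceSeedreamNode hunk above map to familiar resolutions: 3,686,400 px is 1920x1920 (3.6864 MP, reported as 3.68MP) and 921,600 px is 1280x720 (0.9216 MP, reported as 0.92MP). A standalone sketch of the same check, with constants named here only for illustration:

SEEDREAM_45_MIN_PIXELS = 3_686_400  # 1920 * 1920
SEEDREAM_40_MIN_PIXELS = 921_600    # 1280 * 720

def check_seedream_size(model: str, w: int, h: int) -> None:
    out_num_pixels = w * h
    if "seedream-4-5" in model and out_num_pixels < SEEDREAM_45_MIN_PIXELS:
        raise ValueError(f"{out_num_pixels / 1e6:.2f}MP is below the 3.68MP minimum")
    if "seedream-4-0" in model and out_num_pixels < SEEDREAM_40_MIN_PIXELS:
        raise ValueError(f"{out_num_pixels / 1e6:.2f}MP is below the 0.92MP minimum")

check_seedream_size("seedream-4-5-251128", 2048, 2048)  # 4.19 MP, passes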

View File

@ -13,8 +13,7 @@ import torch
from typing_extensions import override
import folder_paths
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api.util import VideoCodec, VideoContainer
from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.gemini_api import (
GeminiContent,
GeminiFileData,
@ -27,6 +26,8 @@ from comfy_api_nodes.apis.gemini_api import (
GeminiMimeType,
GeminiPart,
GeminiRole,
GeminiSystemInstructionContent,
GeminiTextPart,
Modality,
)
from comfy_api_nodes.util import (
@ -43,6 +44,14 @@ from comfy_api_nodes.util import (
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
GEMINI_IMAGE_SYS_PROMPT = (
"You are an expert image-generation engine. You must ALWAYS produce an image.\n"
"Interpret all user input—regardless of "
"format, intent, or abstraction—as literal visual directives for image composition.\n"
"If a prompt is conversational or lacks specific visual details, "
"you must creatively invent a concrete visual scenario that depicts the concept.\n"
"Prioritize generating the visual representation above any text, formatting, or conversational requests."
)
class GeminiModel(str, Enum):
@ -68,7 +77,7 @@ class GeminiImageModel(str, Enum):
async def create_image_parts(
cls: type[IO.ComfyNode],
images: torch.Tensor,
images: Input.Image,
image_limit: int = 0,
) -> list[GeminiPart]:
image_parts: list[GeminiPart] = []
@ -154,8 +163,8 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts])
def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
image_tensors: list[torch.Tensor] = []
def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/png")
for part in parts:
image_data = base64.b64decode(part.inlineData.data)
@ -277,6 +286,13 @@ class GeminiNode(IO.ComfyNode):
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
IO.String.Input(
"system_prompt",
multiline=True,
default="",
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
),
],
outputs=[
IO.String.Output(),
@ -293,7 +309,9 @@ class GeminiNode(IO.ComfyNode):
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
"""Convert video input to Gemini API compatible parts."""
base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264)
base_64_string = video_to_base64_string(
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
)
return [
GeminiPart(
inlineData=GeminiInlineData(
@ -343,10 +361,11 @@ class GeminiNode(IO.ComfyNode):
prompt: str,
model: str,
seed: int,
images: torch.Tensor | None = None,
images: Input.Image | None = None,
audio: Input.Audio | None = None,
video: Input.Video | None = None,
files: list[GeminiPart] | None = None,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
@ -363,7 +382,10 @@ class GeminiNode(IO.ComfyNode):
if files is not None:
parts.extend(files)
# Create response
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
@ -373,7 +395,8 @@ class GeminiNode(IO.ComfyNode):
role=GeminiRole.user,
parts=parts,
)
]
],
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
@ -523,6 +546,13 @@ class GeminiImage(IO.ComfyNode):
"'IMAGE+TEXT' to return both the generated image and a text response.",
optional=True,
),
IO.String.Input(
"system_prompt",
multiline=True,
default=GEMINI_IMAGE_SYS_PROMPT,
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
),
],
outputs=[
IO.Image.Output(),
@ -542,10 +572,11 @@ class GeminiImage(IO.ComfyNode):
prompt: str,
model: str,
seed: int,
images: torch.Tensor | None = None,
images: Input.Image | None = None,
files: list[GeminiPart] | None = None,
aspect_ratio: str = "auto",
response_modalities: str = "IMAGE+TEXT",
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
@ -559,6 +590,10 @@ class GeminiImage(IO.ComfyNode):
if files is not None:
parts.extend(files)
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
@ -570,6 +605,7 @@ class GeminiImage(IO.ComfyNode):
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=None if aspect_ratio == "auto" else image_config,
),
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
@ -640,6 +676,13 @@ class GeminiImage2(IO.ComfyNode):
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
IO.String.Input(
"system_prompt",
multiline=True,
default=GEMINI_IMAGE_SYS_PROMPT,
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
),
],
outputs=[
IO.Image.Output(),
@ -662,8 +705,9 @@ class GeminiImage2(IO.ComfyNode):
aspect_ratio: str,
resolution: str,
response_modalities: str,
images: torch.Tensor | None = None,
images: Input.Image | None = None,
files: list[GeminiPart] | None = None,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
@ -679,6 +723,10 @@ class GeminiImage2(IO.ComfyNode):
if aspect_ratio != "auto":
image_config.aspectRatio = aspect_ratio
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op(
cls,
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
@ -690,6 +738,7 @@ class GeminiImage2(IO.ComfyNode):
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=image_config,
),
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
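
Across all three Gemini nodes patched above, a non-empty system_prompt becomes a systemInstruction on the request. Assuming the Pydantic models serialize under the field names shown, the resulting request body looks roughly like:

# Approximate serialized shape (field names assumed from the models above):
# {
#   "contents": [{"role": "user", "parts": [{"text": "<prompt>"}]}],
#   "systemInstruction": {"parts": [{"text": "<system_prompt>"}], "role": null}
# }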

View File

@ -1,12 +1,9 @@
from io import BytesIO
from typing import Optional
import torch
from pydantic import BaseModel, Field
from typing_extensions import override
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.util import (
ApiEndpoint,
get_number_of_images,
@ -26,9 +23,9 @@ class ExecuteTaskRequest(BaseModel):
model: str = Field(...)
duration: int = Field(...)
resolution: str = Field(...)
fps: Optional[int] = Field(25)
generate_audio: Optional[bool] = Field(True)
image_uri: Optional[str] = Field(None)
fps: int | None = Field(25)
generate_audio: bool | None = Field(True)
image_uri: str | None = Field(None)
class TextToVideoNode(IO.ComfyNode):
@ -103,7 +100,7 @@ class TextToVideoNode(IO.ComfyNode):
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
class ImageToVideoNode(IO.ComfyNode):
@ -153,7 +150,7 @@ class ImageToVideoNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
image: torch.Tensor,
image: Input.Image,
model: str,
prompt: str,
duration: int,
@ -183,7 +180,7 @@ class ImageToVideoNode(IO.ComfyNode):
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
class LtxvApiExtension(ComfyExtension):

View File

@ -1,11 +1,8 @@
import logging
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.input import VideoInput
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis import (
MoonvalleyPromptResponse,
MoonvalleyTextToVideoInferenceParams,
@ -61,7 +58,7 @@ def validate_task_creation_response(response) -> None:
raise RuntimeError(error_msg)
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
def validate_video_to_video_input(video: Input.Video) -> Input.Video:
"""
Validates and processes video input for Moonvalley Video-to-Video generation.
@ -82,7 +79,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput:
return _validate_and_trim_duration(video)
def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
def _get_video_dimensions(video: Input.Video) -> tuple[int, int]:
"""Extracts video dimensions with error handling."""
try:
return video.get_dimensions()
@ -106,7 +103,7 @@ def _validate_video_dimensions(width: int, height: int) -> None:
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
def _validate_and_trim_duration(video: Input.Video) -> Input.Video:
"""Validates video duration and trims to 5 seconds if needed."""
duration = video.get_duration()
_validate_minimum_duration(duration)
@ -119,7 +116,7 @@ def _validate_minimum_duration(duration: float) -> None:
raise ValueError("Input video must be at least 5 seconds long.")
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video:
"""Trims video to 5 seconds if longer."""
if duration > 5:
return trim_video(video, 5)
@ -241,7 +238,7 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
image: torch.Tensor,
image: Input.Image,
prompt: str,
negative_prompt: str,
resolution: str,
@ -362,9 +359,9 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
prompt: str,
negative_prompt: str,
seed: int,
video: Optional[VideoInput] = None,
video: Input.Video | None = None,
control_type: str = "Motion Transfer",
motion_intensity: Optional[int] = 100,
motion_intensity: int | None = 100,
steps=33,
prompt_adherence=4.5,
) -> IO.NodeOutput:

View File

@ -11,12 +11,11 @@ User Guides:
"""
from typing import Union, Optional
from typing_extensions import override
from enum import Enum
import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.apis import (
RunwayImageToVideoRequest,
RunwayImageToVideoResponse,
@ -44,8 +43,6 @@ from comfy_api_nodes.util import (
sync_op,
poll_op,
)
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, IO
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
@ -80,7 +77,7 @@ class RunwayGen3aAspectRatio(str, Enum):
field_1280_768 = "1280:768"
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
"""Returns the video URL from the task status response if it exists."""
if hasattr(response, "output") and len(response.output) > 0:
return response.output[0]
@ -89,13 +86,13 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
def extract_progress_from_task_status(
response: TaskStatusResponse,
) -> Union[float, None]:
) -> float | None:
if hasattr(response, "progress") and response.progress is not None:
return response.progress * 100
return None
def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
"""Returns the image URL from the task status response if it exists."""
if hasattr(response, "output") and len(response.output) > 0:
return response.output[0]
@ -103,7 +100,7 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
async def get_response(
cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None
cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return await poll_op(
@ -119,8 +116,8 @@ async def get_response(
async def generate_video(
cls: type[IO.ComfyNode],
request: RunwayImageToVideoRequest,
estimated_duration: Optional[int] = None,
) -> VideoFromFile:
estimated_duration: int | None = None,
) -> InputImpl.VideoFromFile:
initial_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
@ -193,7 +190,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
async def execute(
cls,
prompt: str,
start_frame: torch.Tensor,
start_frame: Input.Image,
duration: str,
ratio: str,
seed: int,
@ -283,7 +280,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
async def execute(
cls,
prompt: str,
start_frame: torch.Tensor,
start_frame: Input.Image,
duration: str,
ratio: str,
seed: int,
@ -381,8 +378,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
async def execute(
cls,
prompt: str,
start_frame: torch.Tensor,
end_frame: torch.Tensor,
start_frame: Input.Image,
end_frame: Input.Image,
duration: str,
ratio: str,
seed: int,
@ -467,7 +464,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
cls,
prompt: str,
ratio: str,
reference_image: Optional[torch.Tensor] = None,
reference_image: Input.Image | None = None,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1)

View File

@ -1,11 +1,9 @@
import base64
from io import BytesIO
import torch
from typing_extensions import override
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.apis.veo_api import (
VeoGenVidPollRequest,
VeoGenVidPollResponse,
@ -232,7 +230,7 @@ class VeoVideoGenerationNode(IO.ComfyNode):
# Check if video is provided as base64 or URL
if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if hasattr(video, "gcsUri") and video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
@ -431,8 +429,8 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
aspect_ratio: str,
duration: int,
seed: int,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
first_frame: Input.Image,
last_frame: Input.Image,
model: str,
generate_audio: bool,
):
@ -493,7 +491,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
if response.videos:
video = response.videos[0]
if video.bytesBase64Encoded:
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
raise Exception("Video returned but no data or URL was provided")

View File

@ -1,5 +1,6 @@
import logging
import os
import math
import json
import numpy as np
@ -623,6 +624,79 @@ class TextProcessingNode(io.ComfyNode):
# ========== Image Transform Nodes ==========
class ResizeImagesToSameSizeNode(ImageProcessingNode):
node_id = "ResizeImagesToSameSize"
display_name = "Resize Images to Same Size"
description = "Resize all images to the same width and height."
extra_inputs = [
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Target width."),
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Target height."),
io.Combo.Input(
"mode",
options=["stretch", "crop_center", "pad"],
default="stretch",
tooltip="Resize mode.",
),
]
@classmethod
def _process(cls, image, width, height, mode):
img = tensor_to_pil(image)
if mode == "stretch":
img = img.resize((width, height), Image.Resampling.LANCZOS)
elif mode == "crop_center":
left = max(0, (img.width - width) // 2)
top = max(0, (img.height - height) // 2)
right = min(img.width, left + width)
bottom = min(img.height, top + height)
img = img.crop((left, top, right, bottom))
if img.width != width or img.height != height:
img = img.resize((width, height), Image.Resampling.LANCZOS)
elif mode == "pad":
img.thumbnail((width, height), Image.Resampling.LANCZOS)
new_img = Image.new("RGB", (width, height), (0, 0, 0))
paste_x = (width - img.width) // 2
paste_y = (height - img.height) // 2
new_img.paste(img, (paste_x, paste_y))
img = new_img
return pil_to_tensor(img)
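
For intuition, the three modes treat the same input very differently; a worked example with an assumed 800x600 source and a 512x512 target:

from PIL import Image

img = Image.new("RGB", (800, 600))
# stretch: direct resize, aspect ratio distorted
assert img.resize((512, 512)).size == (512, 512)
# crop_center: left=(800-512)//2=144, top=(600-512)//2=44, an exact 512x512 center crop
assert img.crop((144, 44, 144 + 512, 44 + 512)).size == (512, 512)
# pad: thumbnail fits within the box -> 512x384, pasted at ((512-512)//2, (512-384)//2) = (0, 64)
thumb = img.copy()
thumb.thumbnail((512, 512), Image.Resampling.LANCZOS)
assert thumb.size == (512, 384)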
class ResizeImagesToPixelCountNode(ImageProcessingNode):
node_id = "ResizeImagesToPixelCount"
display_name = "Resize Images to Pixel Count"
description = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio."
extra_inputs = [
io.Int.Input(
"pixel_count",
default=512 * 512,
min=1,
max=8192 * 8192,
tooltip="Target pixel count.",
),
io.Int.Input(
"steps",
default=64,
min=1,
max=128,
tooltip="The stepping for resize width/height.",
),
]
@classmethod
def _process(cls, image, pixel_count, steps):
img = tensor_to_pil(image)
w, h = img.size
pixel_count_ratio = math.sqrt(pixel_count / (w * h))
new_w = int(w * pixel_count_ratio / steps) * steps
new_h = int(h * pixel_count_ratio / steps) * steps
logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}")
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
return pil_to_tensor(img)
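
A worked example of the arithmetic above, with assumed inputs (1024x768 source, 512*512 target, steps of 64):

import math

w, h, pixel_count, steps = 1024, 768, 512 * 512, 64
ratio = math.sqrt(pixel_count / (w * h))  # sqrt(262144 / 786432) ≈ 0.577
new_w = int(w * ratio / steps) * steps    # int(591.2 / 64) * 64 = 576
new_h = int(h * ratio / steps) * steps    # int(443.4 / 64) * 64 = 384
assert (new_w, new_h) == (576, 384)       # 221,184 px, just under target due to flooring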
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
node_id = "ResizeImagesByShorterEdge"
display_name = "Resize Images by Shorter Edge"
@ -727,6 +801,29 @@ class RandomCropImagesNode(ImageProcessingNode):
return pil_to_tensor(img)
class FlipImagesNode(ImageProcessingNode):
node_id = "FlipImages"
display_name = "Flip Images"
description = "Flip all images horizontally or vertically."
extra_inputs = [
io.Combo.Input(
"direction",
options=["horizontal", "vertical"],
default="horizontal",
tooltip="Flip direction.",
),
]
@classmethod
def _process(cls, image, direction):
img = tensor_to_pil(image)
if direction == "horizontal":
img = img.transpose(Image.FLIP_LEFT_RIGHT)
else:
img = img.transpose(Image.FLIP_TOP_BOTTOM)
return pil_to_tensor(img)
class NormalizeImagesNode(ImageProcessingNode):
node_id = "NormalizeImages"
display_name = "Normalize Images"
@ -1125,6 +1222,99 @@ class MergeTextListsNode(TextProcessingNode):
# ========== Training Dataset Nodes ==========
class ResolutionBucket(io.ComfyNode):
"""Bucket latents and conditions by resolution for efficient batch training."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ResolutionBucket",
display_name="Resolution Bucket",
category="dataset",
is_experimental=True,
is_input_list=True,
inputs=[
io.Latent.Input(
"latents",
tooltip="List of latent dicts to bucket by resolution.",
),
io.Conditioning.Input(
"conditioning",
tooltip="List of conditioning lists (must match latents length).",
),
],
outputs=[
io.Latent.Output(
display_name="latents",
is_output_list=True,
tooltip="List of batched latent dicts, one per resolution bucket.",
),
io.Conditioning.Output(
display_name="conditioning",
is_output_list=True,
tooltip="List of condition lists, one per resolution bucket.",
),
],
)
@classmethod
def execute(cls, latents, conditioning):
# latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1
# conditioning: list[list[cond]]
# Validate lengths match
if len(latents) != len(conditioning):
raise ValueError(
f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})."
)
# Flatten latents and conditions to individual samples
flat_latents = [] # list of (C, H, W) tensors
flat_conditions = [] # list of condition lists
for latent_dict, cond in zip(latents, conditioning):
samples = latent_dict["samples"] # (B, C, H, W)
batch_size = samples.shape[0]
# cond is a list of conditions with length == batch_size
for i in range(batch_size):
flat_latents.append(samples[i]) # (C, H, W)
flat_conditions.append(cond[i]) # single condition
# Group by resolution (H, W)
buckets = {} # (H, W) -> {"latents": list, "conditions": list}
for latent, cond in zip(flat_latents, flat_conditions):
# latent shape is (C, H, W)
h, w = latent.shape[1], latent.shape[2]
key = (h, w)
if key not in buckets:
buckets[key] = {"latents": [], "conditions": []}
buckets[key]["latents"].append(latent)
buckets[key]["conditions"].append(cond)
# Convert buckets to output format
output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, C, H, W)
output_conditions = [] # list[list[cond]] where each inner list has Bi conditions
for (h, w), bucket_data in buckets.items():
# Stack latents into batch: list of (C, H, W) -> (Bi, C, H, W)
stacked_latents = torch.stack(bucket_data["latents"], dim=0)
output_latents.append({"samples": stacked_latents})
# Conditions stay as list of condition lists
output_conditions.append(bucket_data["conditions"])
logging.info(
f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples"
)
logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples")
return io.NodeOutput(output_latents, output_conditions)
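
To illustrate the bucketing contract with hypothetical shapes: three single-sample latents, two of them 64x64 and one 32x48, come out as two buckets, a (2, C, 64, 64) batch and a (1, C, 32, 48) batch, with conditioning grouped the same way. The grouping step in isolation:

import torch

flat = [torch.zeros(4, 64, 64), torch.zeros(4, 64, 64), torch.zeros(4, 32, 48)]
conds = ["c0", "c1", "c2"]  # placeholder conditions, one per sample

buckets = {}
for lat, cond in zip(flat, conds):
    key = tuple(lat.shape[1:])  # (H, W)
    buckets.setdefault(key, {"latents": [], "conditions": []})
    buckets[key]["latents"].append(lat)
    buckets[key]["conditions"].append(cond)

assert torch.stack(buckets[(64, 64)]["latents"]).shape == (2, 4, 64, 64)
assert buckets[(32, 48)]["conditions"] == ["c2"]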
class MakeTrainingDataset(io.ComfyNode):
"""Encode images with VAE and texts with CLIP to create a training dataset."""
@ -1373,7 +1563,7 @@ class LoadTrainingDataset(io.ComfyNode):
shard_path = os.path.join(dataset_dir, shard_file)
with open(shard_path, "rb") as f:
shard_data = torch.load(f, weights_only=True)
shard_data = torch.load(f)
all_latents.extend(shard_data["latents"])
all_conditioning.extend(shard_data["conditioning"])
@ -1399,10 +1589,13 @@ class DatasetExtension(ComfyExtension):
SaveImageDataSetToFolderNode,
SaveImageTextDataSetToFolderNode,
# Image transform nodes
ResizeImagesToSameSizeNode,
ResizeImagesToPixelCountNode,
ResizeImagesByShorterEdgeNode,
ResizeImagesByLongerEdgeNode,
CenterCropImagesNode,
RandomCropImagesNode,
FlipImagesNode,
NormalizeImagesNode,
AdjustBrightnessNode,
AdjustContrastNode,
@ -1425,6 +1618,7 @@ class DatasetExtension(ComfyExtension):
MakeTrainingDataset,
SaveTrainingDataset,
LoadTrainingDataset,
ResolutionBucket,
]

View File

@ -10,6 +10,7 @@ from PIL import Image, ImageDraw, ImageFont
from typing_extensions import override
import comfy.samplers
import comfy.sampler_helpers
import comfy.sd
import comfy.utils
import comfy.model_management
@ -21,6 +22,68 @@ from comfy_api.latest import ComfyExtension, io, ui
from comfy.utils import ProgressBar
class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
"""
CFGGuider with modifications for training-specific logic
"""
def outer_sample(
self,
noise,
latent_image,
sampler,
sigmas,
denoise_mask=None,
callback=None,
disable_pbar=False,
seed=None,
latent_shapes=None,
):
self.inner_model, self.conds, self.loaded_models = (
comfy.sampler_helpers.prepare_sampling(
self.model_patcher,
noise.shape,
self.conds,
self.model_options,
skip_load_model=True, # skip load model as we manage it in TrainLoraNode.execute()
)
)
device = self.model_patcher.load_device
if denoise_mask is not None:
denoise_mask = comfy.sampler_helpers.prepare_mask(
denoise_mask, noise.shape, device
)
noise = noise.to(device)
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
comfy.samplers.cast_to_load_options(
self.model_options, device=device, dtype=self.model_patcher.model_dtype()
)
try:
self.model_patcher.pre_run()
output = self.inner_sample(
noise,
latent_image,
device,
sampler,
sigmas,
denoise_mask,
callback,
disable_pbar,
seed,
latent_shapes=latent_shapes,
)
finally:
self.model_patcher.cleanup()
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model
del self.loaded_models
return output
def make_batch_extra_option_dict(d, indicies, full_size=None):
new_dict = {}
for k, v in d.items():
@ -65,6 +128,7 @@ class TrainSampler(comfy.samplers.Sampler):
seed=0,
training_dtype=torch.bfloat16,
real_dataset=None,
bucket_latents=None,
):
self.loss_fn = loss_fn
self.optimizer = optimizer
@ -75,6 +139,28 @@ class TrainSampler(comfy.samplers.Sampler):
self.seed = seed
self.training_dtype = training_dtype
self.real_dataset: list[torch.Tensor] | None = real_dataset
# Bucket mode data
self.bucket_latents: list[torch.Tensor] | None = (
bucket_latents # list of (Bi, C, Hi, Wi)
)
# Precompute bucket offsets and weights for sampling
if bucket_latents is not None:
self._init_bucket_data(bucket_latents)
else:
self.bucket_offsets = None
self.bucket_weights = None
self.num_images = None
def _init_bucket_data(self, bucket_latents):
"""Initialize bucket offsets and weights for sampling."""
self.bucket_offsets = [0]
bucket_sizes = []
for lat in bucket_latents:
bucket_sizes.append(lat.shape[0])
self.bucket_offsets.append(self.bucket_offsets[-1] + lat.shape[0])
self.num_images = self.bucket_offsets[-1]
# Weights for sampling buckets proportional to their size
self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)
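
For example, with assumed bucket latents of shapes (4, C, 64, 64) and (2, C, 32, 48): bucket_offsets becomes [0, 4, 6], num_images is 6, and bucket_weights is tensor([4., 2.]), so the first bucket is drawn with probability 4/6 and its samples map to absolute indices 0-3:

import torch

bucket_latents = [torch.zeros(4, 4, 64, 64), torch.zeros(2, 4, 32, 48)]
offsets, sizes = [0], []
for lat in bucket_latents:
    sizes.append(lat.shape[0])
    offsets.append(offsets[-1] + lat.shape[0])
assert offsets == [0, 4, 6]                         # last entry is num_images
weights = torch.tensor(sizes, dtype=torch.float32)  # tensor([4., 2.])
bucket_idx = torch.multinomial(weights, 1).item()   # 0 with probability 4/6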
def fwd_bwd(
self,
@ -115,6 +201,108 @@ class TrainSampler(comfy.samplers.Sampler):
bwd_loss.backward()
return loss
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
"""Generate random sigma values for a batch."""
batch_sigmas = [
model_wrap.inner_model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
)
for _ in range(batch_size)
]
return torch.tensor(batch_sigmas).to(device)
def _train_step_bucket_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, pbar):
"""Execute one training step in bucket mode."""
# Sample bucket (weighted by size), then sample batch from bucket
bucket_idx = torch.multinomial(self.bucket_weights, 1).item()
bucket_latent = self.bucket_latents[bucket_idx] # (Bi, C, Hi, Wi)
bucket_size = bucket_latent.shape[0]
bucket_offset = self.bucket_offsets[bucket_idx]
# Sample indices from this bucket (use all if bucket_size < batch_size)
actual_batch_size = min(self.batch_size, bucket_size)
relative_indices = torch.randperm(bucket_size)[:actual_batch_size].tolist()
# Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index)
absolute_indices = [bucket_offset + idx for idx in relative_indices]
batch_latent = bucket_latent[relative_indices].to(latent_image) # (actual_batch_size, C, H, W)
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
batch_latent.device
)
batch_sigmas = self._generate_batch_sigmas(model_wrap, actual_batch_size, batch_latent.device)
loss = self.fwd_bwd(
model_wrap,
batch_sigmas,
batch_noise,
batch_latent,
cond, # Use flattened cond with absolute indices
absolute_indices,
extra_args,
self.num_images,
bwd=True,
)
if self.loss_callback:
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx})
def _train_step_standard_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
"""Execute one training step in standard (non-bucket, non-multi-res) mode."""
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
batch_latent = torch.stack([latent_image[i] for i in indicies])
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
batch_latent.device
)
batch_sigmas = self._generate_batch_sigmas(model_wrap, min(self.batch_size, dataset_size), batch_latent.device)
loss = self.fwd_bwd(
model_wrap,
batch_sigmas,
batch_noise,
batch_latent,
cond,
indicies,
extra_args,
dataset_size,
bwd=True,
)
if self.loss_callback:
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
def _train_step_multires_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
"""Execute one training step in multi-resolution mode (real_dataset is set)."""
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
total_loss = 0
for index in indicies:
single_latent = self.real_dataset[index].to(latent_image)
batch_noise = noisegen.generate_noise(
{"samples": single_latent}
).to(single_latent.device)
batch_sigmas = (
model_wrap.inner_model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
)
)
batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
loss = self.fwd_bwd(
model_wrap,
batch_sigmas,
batch_noise,
single_latent,
cond,
[index],
extra_args,
dataset_size,
bwd=False,
)
total_loss += loss
total_loss = total_loss / self.grad_acc / len(indicies)
total_loss.backward()
if self.loss_callback:
self.loss_callback(total_loss.item())
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
def sample(
self,
model_wrap,
@ -142,65 +330,13 @@ class TrainSampler(comfy.samplers.Sampler):
noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(
self.seed + i * 1000
)
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
if self.real_dataset is None:
batch_latent = torch.stack([latent_image[i] for i in indicies])
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
batch_latent.device
)
batch_sigmas = [
model_wrap.inner_model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
)
for _ in range(min(self.batch_size, dataset_size))
]
batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
loss = self.fwd_bwd(
model_wrap,
batch_sigmas,
batch_noise,
batch_latent,
cond,
indicies,
extra_args,
dataset_size,
bwd=True,
)
if self.loss_callback:
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
if self.bucket_latents is not None:
self._train_step_bucket_mode(model_wrap, cond, extra_args, noisegen, latent_image, pbar)
elif self.real_dataset is None:
self._train_step_standard_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
else:
total_loss = 0
for index in indicies:
single_latent = self.real_dataset[index].to(latent_image)
batch_noise = noisegen.generate_noise(
{"samples": single_latent}
).to(single_latent.device)
batch_sigmas = (
model_wrap.inner_model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
)
)
batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
loss = self.fwd_bwd(
model_wrap,
batch_sigmas,
batch_noise,
single_latent,
cond,
[index],
extra_args,
dataset_size,
bwd=False,
)
total_loss += loss
total_loss = total_loss / self.grad_acc / len(indicies)
total_loss.backward()
if self.loss_callback:
self.loss_callback(total_loss.item())
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
if (i + 1) % self.grad_acc == 0:
self.optimizer.step()
@ -283,6 +419,364 @@ def unpatch(m):
del m.org_forward
def _process_latents_bucket_mode(latents):
"""Process latents for bucket mode training.
Args:
latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi)
Returns:
list of latent tensors
"""
bucket_latents = []
for latent_dict in latents:
bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi)
return bucket_latents
def _process_latents_standard_mode(latents):
"""Process latents for standard (non-bucket) mode training.
Args:
latents: list of latent dicts or single latent dict
Returns:
Processed latents (tensor or list of tensors)
"""
if len(latents) == 1:
return latents[0]["samples"] # Single latent dict
latent_list = []
for latent in latents:
latent = latent["samples"]
bs = latent.shape[0]
if bs != 1:
for sub_latent in latent:
latent_list.append(sub_latent[None])
else:
latent_list.append(latent)
return latent_list
def _process_conditioning(positive):
"""Process conditioning - either single list or list of lists.
Args:
positive: list of conditioning
Returns:
Flattened conditioning list
"""
if len(positive) == 1:
return positive[0] # Single conditioning list
# Multiple conditioning lists - flatten
flat_positive = []
for cond in positive:
if isinstance(cond, list):
flat_positive.extend(cond)
else:
flat_positive.append(cond)
return flat_positive
def _prepare_latents_and_count(latents, dtype, bucket_mode):
"""Convert latents to dtype and compute image counts.
Args:
latents: Latents (tensor, list of tensors, or bucket list)
dtype: Target dtype
bucket_mode: Whether bucket mode is enabled
Returns:
tuple: (processed_latents, num_images, multi_res)
"""
if bucket_mode:
# In bucket mode, latents is list of tensors (Bi, C, Hi, Wi)
latents = [t.to(dtype) for t in latents]
num_buckets = len(latents)
num_images = sum(t.shape[0] for t in latents)
multi_res = False # Not using multi_res path in bucket mode
logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
for i, lat in enumerate(latents):
logging.info(f" Bucket {i}: shape {lat.shape}")
return latents, num_images, multi_res
# Non-bucket mode
if isinstance(latents, list):
all_shapes = set()
latents = [t.to(dtype) for t in latents]
for latent in latents:
all_shapes.add(latent.shape)
logging.info(f"Latent shapes: {all_shapes}")
if len(all_shapes) > 1:
multi_res = True
else:
multi_res = False
latents = torch.cat(latents, dim=0)
num_images = len(latents)
elif isinstance(latents, torch.Tensor):
latents = latents.to(dtype)
num_images = latents.shape[0]
multi_res = False
else:
logging.error(f"Invalid latents type: {type(latents)}")
num_images = 0
multi_res = False
return latents, num_images, multi_res
def _validate_and_expand_conditioning(positive, num_images, bucket_mode):
"""Validate conditioning count matches image count, expand if needed.
Args:
positive: Conditioning list
num_images: Number of images
bucket_mode: Whether bucket mode is enabled
Returns:
Validated/expanded conditioning list
Raises:
ValueError: If conditioning count doesn't match image count
"""
if bucket_mode:
return positive # Skip validation in bucket mode
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
return positive * num_images
elif len(positive) != num_images:
raise ValueError(
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
)
return positive
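# Editor's sketch: one caption is broadcast to every image; any other
# mismatch raises, and bucket mode bypasses the check entirely.
assert len(_validate_and_expand_conditioning(["c"], 4, bucket_mode=False)) == 4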
def _load_existing_lora(existing_lora):
"""Load existing LoRA weights if provided.
Args:
existing_lora: LoRA filename or "[None]"
Returns:
tuple: (existing_weights dict, existing_steps int)
"""
if existing_lora == "[None]":
return {}, 0
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
existing_weights = {}
if lora_path:
existing_weights = comfy.utils.load_torch_file(lora_path)
return existing_weights, existing_steps
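# Editor's sketch of the step-count parse for the expected filename shape
# (the name below is hypothetical): the token before "_steps_" is the count.
_name = "trained_lora_10_steps_20250225_203716"
assert int(_name.split("_steps_")[0].split("_")[-1]) == 10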
def _create_weight_adapter(
module, module_name, existing_weights, algorithm, lora_dtype, rank
):
"""Create a weight adapter for a module with weight.
Args:
module: The module to create adapter for
module_name: Name of the module
existing_weights: Dict of existing LoRA weights
algorithm: Algorithm name for new adapters
lora_dtype: dtype for LoRA weights
rank: Rank for new LoRA adapters
Returns:
tuple: (train_adapter, lora_params dict)
"""
key = f"{module_name}.weight"
shape = module.weight.shape
lora_params = {}
if len(shape) >= 2:
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
# Try to load existing adapter
existing_adapter = None
for adapter_cls in adapters:
existing_adapter = adapter_cls.load(
module_name, existing_weights, alpha, dora_scale
)
if existing_adapter is not None:
break
if existing_adapter is None:
adapter_cls = adapter_maps[algorithm]
if existing_adapter is not None:
train_adapter = existing_adapter.to_train().to(lora_dtype)
else:
# Use LoRA with alpha=1.0 by default
train_adapter = adapter_cls.create_train(
module.weight, rank=rank, alpha=1.0
).to(lora_dtype)
for name, parameter in train_adapter.named_parameters():
lora_params[f"{module_name}.{name}"] = parameter
return train_adapter, lora_params
else:
# 1D weight - use BiasDiff
diff = torch.nn.Parameter(
torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True)
)
diff_module = BiasDiff(diff)
lora_params[f"{module_name}.diff"] = diff
return diff_module, lora_params
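# Editor's sketch: adapter choice follows weight rank. 2D+ weights
# (matrices) get a trainable low-rank adapter; 1D weights (e.g. norm
# scales) get a full-size BiasDiff delta.
assert len(torch.nn.Linear(8, 8).weight.shape) == 2   # adapter path
assert len(torch.nn.LayerNorm(8).weight.shape) == 1   # BiasDiff path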
def _create_bias_adapter(module, module_name, lora_dtype):
"""Create a bias adapter for a module with bias.
Args:
module: The module with bias
module_name: Name of the module
lora_dtype: dtype for LoRA weights
Returns:
tuple: (bias_module, lora_params dict)
"""
bias = torch.nn.Parameter(
torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True)
)
bias_module = BiasDiff(bias)
lora_params = {f"{module_name}.diff_b": bias}
return bias_module, lora_params
def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank):
"""Setup all LoRA adapters on the model.
Args:
mp: Model patcher
existing_weights: Dict of existing LoRA weights
algorithm: Algorithm name for new adapters
lora_dtype: dtype for LoRA weights
rank: Rank for new LoRA adapters
Returns:
tuple: (lora_sd dict, all_weight_adapters list)
"""
lora_sd = {}
all_weight_adapters = []
for n, m in mp.model.named_modules():
if hasattr(m, "weight_function"):
if m.weight is not None:
adapter, params = _create_weight_adapter(
m, n, existing_weights, algorithm, lora_dtype, rank
)
lora_sd.update(params)
key = f"{n}.weight"
mp.add_weight_wrapper(key, adapter)
all_weight_adapters.append(adapter)
if hasattr(m, "bias") and m.bias is not None:
bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype)
lora_sd.update(bias_params)
key = f"{n}.bias"
mp.add_weight_wrapper(key, bias_adapter)
all_weight_adapters.append(bias_adapter)
return lora_sd, all_weight_adapters
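# Editor's sketch of the resulting state-dict layout for one patched
# module named "out_proj" (hypothetical): one entry per adapter parameter,
# keyed "out_proj.<param_name>" (exact names depend on the adapter class),
# plus "out_proj.diff_b" when the module carries a bias.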
def _create_optimizer(optimizer_name, parameters, learning_rate):
"""Create optimizer based on name.
Args:
optimizer_name: Name of optimizer ("Adam", "AdamW", "SGD", "RMSprop")
parameters: Parameters to optimize
learning_rate: Learning rate
Returns:
Optimizer instance
"""
if optimizer_name == "Adam":
return torch.optim.Adam(parameters, lr=learning_rate)
elif optimizer_name == "AdamW":
return torch.optim.AdamW(parameters, lr=learning_rate)
elif optimizer_name == "SGD":
return torch.optim.SGD(parameters, lr=learning_rate)
elif optimizer_name == "RMSprop":
return torch.optim.RMSprop(parameters, lr=learning_rate)
def _create_loss_function(loss_function_name):
"""Create loss function based on name.
Args:
loss_function_name: Name of loss function ("MSE", "L1", "Huber", "SmoothL1")
Returns:
Loss function instance
"""
if loss_function_name == "MSE":
return torch.nn.MSELoss()
elif loss_function_name == "L1":
return torch.nn.L1Loss()
elif loss_function_name == "Huber":
return torch.nn.HuberLoss()
elif loss_function_name == "SmoothL1":
return torch.nn.SmoothL1Loss()
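# Editor's design note (a sketch, not the node's code): a dict-based
# dispatch keeps the supported names in one place for both factories.
_LOSS_CLASSES = {
    "MSE": torch.nn.MSELoss,
    "L1": torch.nn.L1Loss,
    "Huber": torch.nn.HuberLoss,
    "SmoothL1": torch.nn.SmoothL1Loss,
}

def _create_loss_function_alt(name):  # hypothetical alternative
    try:
        return _LOSS_CLASSES[name]()
    except KeyError:
        raise ValueError(f"Unknown loss function: {name}") from None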
def _run_training_loop(
guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res
):
"""Execute the training loop.
Args:
guider: The guider object
train_sampler: The training sampler
latents: Latent tensors
num_images: Number of images
seed: Random seed
bucket_mode: Whether bucket mode is enabled
multi_res: Whether multi-resolution mode is enabled
"""
sigmas = torch.arange(num_images)
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
if bucket_mode:
# Use first bucket's first latent as dummy for guider
dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1)
guider.sample(
noise.generate_noise({"samples": dummy_latent}),
dummy_latent,
train_sampler,
sigmas,
seed=noise.seed,
)
elif multi_res:
# use first latent as dummy latent if multi_res
latents = latents[0].repeat(num_images, 1, 1, 1)
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
seed=noise.seed,
)
else:
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
seed=noise.seed,
)
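# Editor's sketch: the guider only ever sees a dummy batch of a single
# resolution; the real per-bucket (or per-sample) latents live inside the
# sampler. With hypothetical buckets of shapes (4, 4, 64, 64) and
# (2, 4, 96, 64), num_images == 6 and the dummy expands like this:
_b0 = torch.zeros(4, 4, 64, 64)
assert _b0[:1].repeat(6, 1, 1, 1).shape == (6, 4, 64, 64)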
class TrainLoraNode(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -385,6 +879,11 @@ class TrainLoraNode(io.ComfyNode):
default="[None]",
tooltip="The existing LoRA to append to. Set to None for new LoRA.",
),
io.Boolean.Input(
"bucket_mode",
default=False,
tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
),
],
outputs=[
io.Model.Output(
@ -419,6 +918,7 @@ class TrainLoraNode(io.ComfyNode):
algorithm,
gradient_checkpointing,
existing_lora,
bucket_mode,
):
# Extract scalars from lists (due to is_input_list=True)
model = model[0]
@ -427,215 +927,124 @@ class TrainLoraNode(io.ComfyNode):
grad_accumulation_steps = grad_accumulation_steps[0]
learning_rate = learning_rate[0]
rank = rank[0]
optimizer = optimizer[0]
loss_function = loss_function[0]
optimizer_name = optimizer[0]
loss_function_name = loss_function[0]
seed = seed[0]
training_dtype = training_dtype[0]
lora_dtype = lora_dtype[0]
algorithm = algorithm[0]
gradient_checkpointing = gradient_checkpointing[0]
existing_lora = existing_lora[0]
bucket_mode = bucket_mode[0]
# Handle latents - either single dict or list of dicts
if len(latents) == 1:
latents = latents[0]["samples"] # Single latent dict
# Process latents based on mode
if bucket_mode:
latents = _process_latents_bucket_mode(latents)
else:
latent_list = []
for latent in latents:
latent = latent["samples"]
bs = latent.shape[0]
if bs != 1:
for sub_latent in latent:
latent_list.append(sub_latent[None])
else:
latent_list.append(latent)
latents = latent_list
latents = _process_latents_standard_mode(latents)
# Handle conditioning - either single list or list of lists
if len(positive) == 1:
positive = positive[0] # Single conditioning list
else:
# Multiple conditioning lists - flatten
flat_positive = []
for cond in positive:
if isinstance(cond, list):
flat_positive.extend(cond)
else:
flat_positive.append(cond)
positive = flat_positive
# Process conditioning
positive = _process_conditioning(positive)
# Setup model and dtype
mp = model.clone()
dtype = node_helpers.string_to_torch_dtype(training_dtype)
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
# latents here can be a list of different-sized latents or one large batch
if isinstance(latents, list):
all_shapes = set()
latents = [t.to(dtype) for t in latents]
for latent in latents:
all_shapes.add(latent.shape)
logging.info(f"Latent shapes: {all_shapes}")
if len(all_shapes) > 1:
multi_res = True
else:
multi_res = False
latents = torch.cat(latents, dim=0)
num_images = len(latents)
elif isinstance(latents, torch.Tensor):
latents = latents.to(dtype)
num_images = latents.shape[0]
else:
logging.error(f"Invalid latents type: {type(latents)}")
# Prepare latents and compute counts
latents, num_images, multi_res = _prepare_latents_and_count(
latents, dtype, bucket_mode
)
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
positive = positive * num_images
elif len(positive) != num_images:
raise ValueError(
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
)
# Validate and expand conditioning
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
with torch.inference_mode(False):
lora_sd = {}
generator = torch.Generator()
generator.manual_seed(seed)
# Load existing LoRA weights if provided
existing_weights = {}
existing_steps = 0
if existing_lora != "[None]":
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
if lora_path:
existing_weights = comfy.utils.load_torch_file(lora_path)
existing_weights, existing_steps = _load_existing_lora(existing_lora)
all_weight_adapters = []
for n, m in mp.model.named_modules():
if hasattr(m, "weight_function"):
if m.weight is not None:
key = "{}.weight".format(n)
shape = m.weight.shape
if len(shape) >= 2:
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
for adapter_cls in adapters:
existing_adapter = adapter_cls.load(
n, existing_weights, alpha, dora_scale
)
if existing_adapter is not None:
break
else:
existing_adapter = None
adapter_cls = adapter_maps[algorithm]
# Setup LoRA adapters
lora_sd, all_weight_adapters = _setup_lora_adapters(
mp, existing_weights, algorithm, lora_dtype, rank
)
if existing_adapter is not None:
train_adapter = existing_adapter.to_train().to(
lora_dtype
)
else:
# Use LoRA with alpha=1.0 by default
train_adapter = adapter_cls.create_train(
m.weight, rank=rank, alpha=1.0
).to(lora_dtype)
for name, parameter in train_adapter.named_parameters():
lora_sd[f"{n}.{name}"] = parameter
# Create optimizer and loss function
optimizer = _create_optimizer(
optimizer_name, lora_sd.values(), learning_rate
)
criterion = _create_loss_function(loss_function_name)
mp.add_weight_wrapper(key, train_adapter)
all_weight_adapters.append(train_adapter)
else:
diff = torch.nn.Parameter(
torch.zeros(
m.weight.shape, dtype=lora_dtype, requires_grad=True
)
)
diff_module = BiasDiff(diff)
mp.add_weight_wrapper(key, BiasDiff(diff))
all_weight_adapters.append(diff_module)
lora_sd["{}.diff".format(n)] = diff
if hasattr(m, "bias") and m.bias is not None:
key = "{}.bias".format(n)
bias = torch.nn.Parameter(
torch.zeros(
m.bias.shape, dtype=lora_dtype, requires_grad=True
)
)
bias_module = BiasDiff(bias)
lora_sd["{}.diff_b".format(n)] = bias
mp.add_weight_wrapper(key, BiasDiff(bias))
all_weight_adapters.append(bias_module)
if optimizer == "Adam":
optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
elif optimizer == "AdamW":
optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
elif optimizer == "SGD":
optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
elif optimizer == "RMSprop":
optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
# Setup loss function based on selection
if loss_function == "MSE":
criterion = torch.nn.MSELoss()
elif loss_function == "L1":
criterion = torch.nn.L1Loss()
elif loss_function == "Huber":
criterion = torch.nn.HuberLoss()
elif loss_function == "SmoothL1":
criterion = torch.nn.SmoothL1Loss()
# setup models
# Setup gradient checkpointing
if gradient_checkpointing:
for m in find_all_highest_child_module_with_forward(
mp.model.diffusion_model
):
patch(m)
# Setup models for training
mp.model.requires_grad_(False)
torch.cuda.empty_cache()
# With force_full_load=False we could support offloading, but offloading
# during training would require custom autograd hooks for the forward and
# backward passes.
comfy.model_management.load_models_gpu(
[mp], memory_required=1e20, force_full_load=True
)
torch.cuda.empty_cache()
# Setup sampler and guider like in test script
# Setup loss tracking
loss_map = {"loss": []}
def loss_callback(loss):
loss_map["loss"].append(loss)
train_sampler = TrainSampler(
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
grad_acc=grad_accumulation_steps,
total_steps=steps * grad_accumulation_steps,
seed=seed,
training_dtype=dtype,
real_dataset=latents if multi_res else None,
)
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
guider.set_conds(positive) # Set conditioning from input
# Create sampler
if bucket_mode:
train_sampler = TrainSampler(
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
grad_acc=grad_accumulation_steps,
total_steps=steps * grad_accumulation_steps,
seed=seed,
training_dtype=dtype,
bucket_latents=latents,
)
else:
train_sampler = TrainSampler(
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
grad_acc=grad_accumulation_steps,
total_steps=steps * grad_accumulation_steps,
seed=seed,
training_dtype=dtype,
real_dataset=latents if multi_res else None,
)
# Training loop
# Setup guider
guider = TrainGuider(mp)
guider.set_conds(positive)
# Run training loop
try:
# Generate dummy sigmas and noise
sigmas = torch.tensor(range(num_images))
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
if multi_res:
# use first latent as dummy latent if multi_res
latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1)))
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
_run_training_loop(
guider,
train_sampler,
sigmas,
seed=noise.seed,
latents,
num_images,
seed,
bucket_mode,
multi_res,
)
finally:
for m in mp.model.modules():
unpatch(m)
del train_sampler, optimizer
# Finalize adapters
for adapter in all_weight_adapters:
adapter.requires_grad_(False)
@ -645,7 +1054,7 @@ class TrainLoraNode(io.ComfyNode):
return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
class LoraModelLoader(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(


@ -8,10 +8,7 @@ import json
from typing import Optional
from typing_extensions import override
from fractions import Fraction
from comfy_api.input import AudioInput, ImageInput, VideoInput
from comfy_api.input_impl import VideoFromComponents, VideoFromFile
from comfy_api.util import VideoCodec, VideoComponents, VideoContainer
from comfy_api.latest import ComfyExtension, io, ui
from comfy_api.latest import ComfyExtension, io, ui, Input, InputImpl, Types
from comfy.cli_args import args
class SaveWEBM(io.ComfyNode):
@ -28,7 +25,6 @@ class SaveWEBM(io.ComfyNode):
io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@ -79,16 +75,15 @@ class SaveVideo(io.ComfyNode):
inputs=[
io.Video.Input("video", tooltip="The video to save."),
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def execute(cls, video: VideoInput, filename_prefix, format: str, codec) -> io.NodeOutput:
def execute(cls, video: Input.Video, filename_prefix, format: str, codec) -> io.NodeOutput:
width, height = video.get_dimensions()
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix,
@ -105,10 +100,10 @@ class SaveVideo(io.ComfyNode):
metadata["prompt"] = cls.hidden.prompt
if len(metadata) > 0:
saved_metadata = metadata
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
video.save_to(
os.path.join(full_output_folder, file),
format=VideoContainer(format),
format=Types.VideoContainer(format),
codec=codec,
metadata=saved_metadata
)
@ -135,9 +130,9 @@ class CreateVideo(io.ComfyNode):
)
@classmethod
def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput:
def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput:
return io.NodeOutput(
VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
)
class GetVideoComponents(io.ComfyNode):
@ -159,11 +154,11 @@ class GetVideoComponents(io.ComfyNode):
)
@classmethod
def execute(cls, video: VideoInput) -> io.NodeOutput:
def execute(cls, video: Input.Video) -> io.NodeOutput:
components = video.get_components()
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
class LoadVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -185,7 +180,7 @@ class LoadVideo(io.ComfyNode):
@classmethod
def execute(cls, file) -> io.NodeOutput:
video_path = folder_paths.get_annotated_filepath(file)
return io.NodeOutput(VideoFromFile(video_path))
return io.NodeOutput(InputImpl.VideoFromFile(video_path))
@classmethod
def fingerprint_inputs(s, file):


@ -1,5 +1,5 @@
comfyui-frontend-package==1.33.10
comfyui-workflow-templates==0.7.25
comfyui-workflow-templates==0.7.51
comfyui-embedded-docs==0.3.1
torch
torchsde