mirror of https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-08 21:44:33 +08:00

Compare commits: de9ceb36af ... 9f018ddb3f (10 commits)

| SHA1 |
|---|
| 9f018ddb3f |
| fd271dedfd |
| c3c6313fc7 |
| 85c4b4ae26 |
| 058f084371 |
| ec7f65187d |
| 37139daa98 |
| 4004af3290 |
| bf573e94a2 |
| 7a93c55a9f |
@@ -122,20 +122,21 @@ def estimate_memory(model, noise_shape, conds):
     minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
     return memory_required, minimum_memory_required

-def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
+def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, skip_load_model=False):
     executor = comfy.patcher_extension.WrapperExecutor.new_executor(
         _prepare_sampling,
         comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
     )
-    return executor.execute(model, noise_shape, conds, model_options=model_options)
+    return executor.execute(model, noise_shape, conds, model_options=model_options, skip_load_model=skip_load_model)

-def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
+def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, skip_load_model=False):
     real_model: BaseModel = None
     models, inference_memory = get_additional_models(conds, model.model_dtype())
     models += get_additional_models_from_model_options(model_options)
     models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
     memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
-    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
+    models_list = [model] if not skip_load_model else []
+    comfy.model_management.load_models_gpu(models_list + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
     real_model = model.model

     return real_model, conds, models
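For orientation (not part of the diff): a minimal sketch of how a caller that manages model loading itself might use the new `skip_load_model` flag. The `model`, `noise`, `conds`, and `model_options` names are assumed to come from an existing ComfyUI sampling context; the `TrainGuider` added later in this commit range passes the flag the same way.

```python
# Minimal sketch, assuming a ComfyUI environment: prepare sampling while letting
# the caller (e.g. a training loop) handle loading the main model itself.
import comfy.sampler_helpers

def prepare_without_main_model(model, noise, conds, model_options=None):
    # With skip_load_model=True only the additional models (controlnets, etc.)
    # are handed to load_models_gpu; the main ModelPatcher stays the caller's job.
    return comfy.sampler_helpers.prepare_sampling(
        model, noise.shape, conds, model_options, skip_load_model=True
    )
```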
@@ -5,9 +5,9 @@ from typing import Type, TYPE_CHECKING
 from comfy_api.internal import ComfyAPIBase
 from comfy_api.internal.singleton import ProxiedSingleton
 from comfy_api.internal.async_to_sync import create_sync_class
-from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
-from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
-from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
+from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
+from ._input_impl import VideoFromFile, VideoFromComponents
+from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
 from . import _io_public as io
 from . import _ui_public as ui
 # from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
@@ -80,7 +80,7 @@ class ComfyExtension(ABC):
     async def on_load(self) -> None:
        """
        Called when an extension is loaded.
-       This should be used to initialize any global resources neeeded by the extension.
+       This should be used to initialize any global resources needed by the extension.
        """

    @abstractmethod
@ -4,7 +4,7 @@ from fractions import Fraction
|
||||
from typing import Optional, Union, IO
|
||||
import io
|
||||
import av
|
||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||
from .._util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
class VideoInput(ABC):
|
||||
"""
|
||||
|
||||
@ -3,14 +3,14 @@ from av.container import InputContainer
|
||||
from av.subtitles.stream import SubtitleStream
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.latest._input import AudioInput, VideoInput
|
||||
from .._input import AudioInput, VideoInput
|
||||
import av
|
||||
import io
|
||||
import json
|
||||
import numpy as np
|
||||
import math
|
||||
import torch
|
||||
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
|
||||
from .._util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
|
||||
def container_to_output_format(container_format: str | None) -> str | None:
|
||||
|
||||
@ -26,7 +26,7 @@ if TYPE_CHECKING:
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_api.latest._resources import Resources, ResourcesLocal
|
||||
from ._resources import Resources, ResourcesLocal
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from ._util import MESH, VOXEL
|
||||
|
||||
|
||||
@ -22,7 +22,7 @@ import folder_paths
|
||||
|
||||
# used for image preview
|
||||
from comfy.cli_args import args
|
||||
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
|
||||
from ._io import ComfyNode, FolderType, Image, _UIOutput
|
||||
|
||||
|
||||
class SavedResult(dict):
|
||||
|
||||
@ -3,7 +3,7 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.latest._input import ImageInput, AudioInput
|
||||
from .._input import ImageInput, AudioInput
|
||||
|
||||
class VideoCodec(str, Enum):
|
||||
AUTO = "auto"
|
||||
|
||||
comfy_api_nodes/apis/bytedance_api.py (new file, 144 lines added)
@ -0,0 +1,144 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Text2ImageTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
response_format: str | None = Field("url")
|
||||
size: str | None = Field(None)
|
||||
seed: int | None = Field(0, ge=0, le=2147483647)
|
||||
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
|
||||
watermark: bool | None = Field(True)
|
||||
|
||||
|
||||
class Image2ImageTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
response_format: str | None = Field("url")
|
||||
image: str = Field(..., description="Base64 encoded string or image URL")
|
||||
size: str | None = Field("adaptive")
|
||||
seed: int | None = Field(..., ge=0, le=2147483647)
|
||||
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
|
||||
watermark: bool | None = Field(True)
|
||||
|
||||
|
||||
class Seedream4Options(BaseModel):
|
||||
max_images: int = Field(15)
|
||||
|
||||
|
||||
class Seedream4TaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
response_format: str = Field("url")
|
||||
image: list[str] | None = Field(None, description="Image URLs")
|
||||
size: str = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
sequential_image_generation: str = Field("disabled")
|
||||
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
|
||||
watermark: bool = Field(True)
|
||||
|
||||
|
||||
class ImageTaskCreationResponse(BaseModel):
|
||||
model: str = Field(...)
|
||||
created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
|
||||
data: list = Field([], description="Contains information about the generated image(s).")
|
||||
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
|
||||
|
||||
|
||||
class TaskTextContent(BaseModel):
|
||||
type: str = Field("text")
|
||||
text: str = Field(...)
|
||||
|
||||
|
||||
class TaskImageContentUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class TaskImageContent(BaseModel):
|
||||
type: str = Field("image_url")
|
||||
image_url: TaskImageContentUrl = Field(...)
|
||||
role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
|
||||
|
||||
|
||||
class Text2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
content: list[TaskTextContent] = Field(..., min_length=1)
|
||||
|
||||
|
||||
class Image2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2)
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusError(BaseModel):
|
||||
code: str = Field(...)
|
||||
message: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusResult(BaseModel):
|
||||
video_url: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
model: str = Field(...)
|
||||
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
|
||||
error: TaskStatusError | None = Field(None)
|
||||
content: TaskStatusResult | None = Field(None)
|
||||
|
||||
|
||||
RECOMMENDED_PRESETS = [
|
||||
("1024x1024 (1:1)", 1024, 1024),
|
||||
("864x1152 (3:4)", 864, 1152),
|
||||
("1152x864 (4:3)", 1152, 864),
|
||||
("1280x720 (16:9)", 1280, 720),
|
||||
("720x1280 (9:16)", 720, 1280),
|
||||
("832x1248 (2:3)", 832, 1248),
|
||||
("1248x832 (3:2)", 1248, 832),
|
||||
("1512x648 (21:9)", 1512, 648),
|
||||
("2048x2048 (1:1)", 2048, 2048),
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4 = [
|
||||
("2048x2048 (1:1)", 2048, 2048),
|
||||
("2304x1728 (4:3)", 2304, 1728),
|
||||
("1728x2304 (3:4)", 1728, 2304),
|
||||
("2560x1440 (16:9)", 2560, 1440),
|
||||
("1440x2560 (9:16)", 1440, 2560),
|
||||
("2496x1664 (3:2)", 2496, 1664),
|
||||
("1664x2496 (2:3)", 1664, 2496),
|
||||
("3024x1296 (21:9)", 3024, 1296),
|
||||
("4096x4096 (1:1)", 4096, 4096),
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
# The times in this dictionary are given for a duration of 10 seconds.
|
||||
VIDEO_TASKS_EXECUTION_TIME = {
|
||||
"seedance-1-0-lite-t2v-250428": {
|
||||
"480p": 40,
|
||||
"720p": 60,
|
||||
"1080p": 90,
|
||||
},
|
||||
"seedance-1-0-lite-i2v-250428": {
|
||||
"480p": 40,
|
||||
"720p": 60,
|
||||
"1080p": 90,
|
||||
},
|
||||
"seedance-1-0-pro-250528": {
|
||||
"480p": 70,
|
||||
"720p": 85,
|
||||
"1080p": 115,
|
||||
},
|
||||
"seedance-1-0-pro-fast-251015": {
|
||||
"480p": 50,
|
||||
"720p": 65,
|
||||
"1080p": 100,
|
||||
},
|
||||
}
|
||||
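Usage sketch (illustrative, not part of the diff): instantiating one of the new request models and serializing it for the HTTP call. The prompt, size, and seed values are placeholders, and `model_dump_json` assumes pydantic v2.

```python
# Illustrative only: build a Seedream 4 text-to-image request from the new
# pydantic models and dump it to JSON for the API call.
from comfy_api_nodes.apis.bytedance_api import (
    Seedream4Options,
    Seedream4TaskCreationRequest,
)

payload = Seedream4TaskCreationRequest(
    model="seedream-4-0-250828",
    prompt="a lighthouse at dusk, volumetric fog",
    size="2048x2048",
    seed=42,
    sequential_image_generation="disabled",
    sequential_image_generation_options=Seedream4Options(max_images=1),
    watermark=True,
)
# pydantic v2 API; on v1 this would be payload.json(exclude_none=True)
print(payload.model_dump_json(exclude_none=True))
```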
@ -84,15 +84,7 @@ class GeminiSystemInstructionContent(BaseModel):
|
||||
description="A list of ordered parts that make up a single message. "
|
||||
"Different parts may have different IANA MIME types.",
|
||||
)
|
||||
role: GeminiRole = Field(
|
||||
...,
|
||||
description="The identity of the entity that creates the message. "
|
||||
"The following values are supported: "
|
||||
"user: This indicates that the message is sent by a real person, typically a user-generated message. "
|
||||
"model: This indicates that the message is generated by the model. "
|
||||
"The model value is used to insert messages from model into the conversation during multi-turn conversations. "
|
||||
"For non-multi-turn conversations, this field can be left blank or unset.",
|
||||
)
|
||||
role: GeminiRole | None = Field(..., description="The role field of systemInstruction may be ignored.")
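Illustrative sketch of what the relaxed field allows, mirroring how the updated nodes later in this commit range construct the system instruction; the instruction text is a placeholder.

```python
# Mirrors the construction used by the updated Gemini nodes: role may now be
# left as None for a systemInstruction content block.
from comfy_api_nodes.apis.gemini_api import (
    GeminiSystemInstructionContent,
    GeminiTextPart,
)

system_instruction = GeminiSystemInstructionContent(
    parts=[GeminiTextPart(text="You are an expert image-generation engine.")],
    role=None,  # per the new field description, the API may ignore this role
)
```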
|
||||
|
||||
|
||||
class GeminiFunctionDeclaration(BaseModel):
|
||||
|
||||
@ -85,7 +85,7 @@ class Response1(BaseModel):
|
||||
raiMediaFilteredReasons: Optional[list[str]] = Field(
|
||||
None, description='Reasons why media was filtered by responsible AI policies'
|
||||
)
|
||||
videos: Optional[list[Video]] = None
|
||||
videos: Optional[list[Video]] = Field(None)
|
||||
|
||||
|
||||
class VeoGenVidPollResponse(BaseModel):
|
||||
|
||||
@ -1,13 +1,27 @@
|
||||
import logging
|
||||
import math
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bytedance_api import (
|
||||
RECOMMENDED_PRESETS,
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4,
|
||||
VIDEO_TASKS_EXECUTION_TIME,
|
||||
Image2ImageTaskCreationRequest,
|
||||
Image2VideoTaskCreationRequest,
|
||||
ImageTaskCreationResponse,
|
||||
Seedream4Options,
|
||||
Seedream4TaskCreationRequest,
|
||||
TaskCreationResponse,
|
||||
TaskImageContent,
|
||||
TaskImageContentUrl,
|
||||
TaskStatusResponse,
|
||||
TaskTextContent,
|
||||
Text2ImageTaskCreationRequest,
|
||||
Text2VideoTaskCreationRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
@ -29,162 +43,6 @@ BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
|
||||
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
|
||||
|
||||
|
||||
class Text2ImageModelName(str, Enum):
|
||||
seedream_3 = "seedream-3-0-t2i-250415"
|
||||
|
||||
|
||||
class Image2ImageModelName(str, Enum):
|
||||
seededit_3 = "seededit-3-0-i2i-250628"
|
||||
|
||||
|
||||
class Text2VideoModelName(str, Enum):
|
||||
seedance_1_pro = "seedance-1-0-pro-250528"
|
||||
seedance_1_lite = "seedance-1-0-lite-t2v-250428"
|
||||
|
||||
|
||||
class Image2VideoModelName(str, Enum):
|
||||
"""note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757"""
|
||||
|
||||
seedance_1_pro = "seedance-1-0-pro-250528"
|
||||
seedance_1_lite = "seedance-1-0-lite-i2v-250428"
|
||||
|
||||
|
||||
class Text2ImageTaskCreationRequest(BaseModel):
|
||||
model: Text2ImageModelName = Text2ImageModelName.seedream_3
|
||||
prompt: str = Field(...)
|
||||
response_format: Optional[str] = Field("url")
|
||||
size: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(0, ge=0, le=2147483647)
|
||||
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
|
||||
watermark: Optional[bool] = Field(True)
|
||||
|
||||
|
||||
class Image2ImageTaskCreationRequest(BaseModel):
|
||||
model: Image2ImageModelName = Image2ImageModelName.seededit_3
|
||||
prompt: str = Field(...)
|
||||
response_format: Optional[str] = Field("url")
|
||||
image: str = Field(..., description="Base64 encoded string or image URL")
|
||||
size: Optional[str] = Field("adaptive")
|
||||
seed: Optional[int] = Field(..., ge=0, le=2147483647)
|
||||
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
|
||||
watermark: Optional[bool] = Field(True)
|
||||
|
||||
|
||||
class Seedream4Options(BaseModel):
|
||||
max_images: int = Field(15)
|
||||
|
||||
|
||||
class Seedream4TaskCreationRequest(BaseModel):
|
||||
model: str = Field("seedream-4-0-250828")
|
||||
prompt: str = Field(...)
|
||||
response_format: str = Field("url")
|
||||
image: Optional[list[str]] = Field(None, description="Image URLs")
|
||||
size: str = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
sequential_image_generation: str = Field("disabled")
|
||||
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
|
||||
watermark: bool = Field(True)
|
||||
|
||||
|
||||
class ImageTaskCreationResponse(BaseModel):
|
||||
model: str = Field(...)
|
||||
created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
|
||||
data: list = Field([], description="Contains information about the generated image(s).")
|
||||
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
|
||||
|
||||
|
||||
class TaskTextContent(BaseModel):
|
||||
type: str = Field("text")
|
||||
text: str = Field(...)
|
||||
|
||||
|
||||
class TaskImageContentUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class TaskImageContent(BaseModel):
|
||||
type: str = Field("image_url")
|
||||
image_url: TaskImageContentUrl = Field(...)
|
||||
role: Optional[Literal["first_frame", "last_frame", "reference_image"]] = Field(None)
|
||||
|
||||
|
||||
class Text2VideoTaskCreationRequest(BaseModel):
|
||||
model: Text2VideoModelName = Text2VideoModelName.seedance_1_pro
|
||||
content: list[TaskTextContent] = Field(..., min_length=1)
|
||||
|
||||
|
||||
class Image2VideoTaskCreationRequest(BaseModel):
|
||||
model: Image2VideoModelName = Image2VideoModelName.seedance_1_pro
|
||||
content: list[Union[TaskTextContent, TaskImageContent]] = Field(..., min_length=2)
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusError(BaseModel):
|
||||
code: str = Field(...)
|
||||
message: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusResult(BaseModel):
|
||||
video_url: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
model: str = Field(...)
|
||||
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
|
||||
error: Optional[TaskStatusError] = Field(None)
|
||||
content: Optional[TaskStatusResult] = Field(None)
|
||||
|
||||
|
||||
RECOMMENDED_PRESETS = [
|
||||
("1024x1024 (1:1)", 1024, 1024),
|
||||
("864x1152 (3:4)", 864, 1152),
|
||||
("1152x864 (4:3)", 1152, 864),
|
||||
("1280x720 (16:9)", 1280, 720),
|
||||
("720x1280 (9:16)", 720, 1280),
|
||||
("832x1248 (2:3)", 832, 1248),
|
||||
("1248x832 (3:2)", 1248, 832),
|
||||
("1512x648 (21:9)", 1512, 648),
|
||||
("2048x2048 (1:1)", 2048, 2048),
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4 = [
|
||||
("2048x2048 (1:1)", 2048, 2048),
|
||||
("2304x1728 (4:3)", 2304, 1728),
|
||||
("1728x2304 (3:4)", 1728, 2304),
|
||||
("2560x1440 (16:9)", 2560, 1440),
|
||||
("1440x2560 (9:16)", 1440, 2560),
|
||||
("2496x1664 (3:2)", 2496, 1664),
|
||||
("1664x2496 (2:3)", 1664, 2496),
|
||||
("3024x1296 (21:9)", 3024, 1296),
|
||||
("4096x4096 (1:1)", 4096, 4096),
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
# The times in this dictionary are given for a duration of 10 seconds.
|
||||
VIDEO_TASKS_EXECUTION_TIME = {
|
||||
"seedance-1-0-lite-t2v-250428": {
|
||||
"480p": 40,
|
||||
"720p": 60,
|
||||
"1080p": 90,
|
||||
},
|
||||
"seedance-1-0-lite-i2v-250428": {
|
||||
"480p": 40,
|
||||
"720p": 60,
|
||||
"1080p": 90,
|
||||
},
|
||||
"seedance-1-0-pro-250528": {
|
||||
"480p": 70,
|
||||
"720p": 85,
|
||||
"1080p": 115,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
||||
if response.error:
|
||||
error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
|
||||
@ -194,13 +52,6 @@ def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
||||
return response.data[0]["url"]
|
||||
|
||||
|
||||
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
||||
"""Returns the video URL from the task status response if it exists."""
|
||||
if hasattr(response, "content") and response.content:
|
||||
return response.content.video_url
|
||||
return None
|
||||
|
||||
|
||||
class ByteDanceImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@ -211,12 +62,7 @@ class ByteDanceImageNode(IO.ComfyNode):
|
||||
category="api node/image/ByteDance",
|
||||
description="Generate images using ByteDance models via api based on prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Text2ImageModelName,
|
||||
default=Text2ImageModelName.seedream_3,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
@ -335,12 +181,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
||||
category="api node/image/ByteDance",
|
||||
description="Edit images using ByteDance models via api based on prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Image2ImageModelName,
|
||||
default=Image2ImageModelName.seededit_3,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The base image to edit",
|
||||
@ -394,7 +235,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
seed: int,
|
||||
guidance_scale: float,
|
||||
@ -434,7 +275,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["seedream-4-0-250828"],
|
||||
options=["seedream-4-5-251128", "seedream-4-0-250828"],
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.String.Input(
|
||||
@ -459,7 +300,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
default=2048,
|
||||
min=1024,
|
||||
max=4096,
|
||||
step=64,
|
||||
step=8,
|
||||
tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`",
|
||||
optional=True,
|
||||
),
|
||||
@ -468,7 +309,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
default=2048,
|
||||
min=1024,
|
||||
max=4096,
|
||||
step=64,
|
||||
step=8,
|
||||
tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`",
|
||||
optional=True,
|
||||
),
|
||||
@ -532,7 +373,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
image: torch.Tensor = None,
|
||||
image: Input.Image | None = None,
|
||||
size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0],
|
||||
width: int = 2048,
|
||||
height: int = 2048,
|
||||
@ -555,6 +396,18 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
raise ValueError(
|
||||
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels."
|
||||
)
|
||||
out_num_pixels = w * h
|
||||
mp_provided = out_num_pixels / 1_000_000.0
|
||||
if "seedream-4-5" in model and out_num_pixels < 3686400:
|
||||
raise ValueError(
|
||||
f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, "
|
||||
f"but {mp_provided:.2f}MP provided."
|
||||
)
|
||||
if "seedream-4-0" in model and out_num_pixels < 921600:
|
||||
raise ValueError(
|
||||
f"Minimum image resolution that the selected model can generate is 0.92MP, "
|
||||
f"but {mp_provided:.2f}MP provided."
|
||||
)
|
||||
n_input_images = get_number_of_images(image) if image is not None else 0
|
||||
if n_input_images > 10:
|
||||
raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.")
|
||||
@ -607,9 +460,8 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Text2VideoModelName,
|
||||
default=Text2VideoModelName.seedance_1_pro,
|
||||
tooltip="Model name",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
|
||||
default="seedance-1-0-pro-fast-251015",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -714,9 +566,8 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Image2VideoModelName,
|
||||
default=Image2VideoModelName.seedance_1_pro,
|
||||
tooltip="Model name",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
|
||||
default="seedance-1-0-pro-fast-251015",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -787,7 +638,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
@ -833,9 +684,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in Image2VideoModelName],
|
||||
default=Image2VideoModelName.seedance_1_lite.value,
|
||||
tooltip="Model name",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
|
||||
default="seedance-1-0-lite-i2v-250428",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -910,8 +760,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
first_frame: torch.Tensor,
|
||||
last_frame: torch.Tensor,
|
||||
first_frame: Input.Image,
|
||||
last_frame: Input.Image,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
@ -968,9 +818,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=[Image2VideoModelName.seedance_1_lite.value],
|
||||
default=Image2VideoModelName.seedance_1_lite.value,
|
||||
tooltip="Model name",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
|
||||
default="seedance-1-0-lite-i2v-250428",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -1034,7 +883,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
images: torch.Tensor,
|
||||
images: Input.Image,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
@ -1069,8 +918,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
|
||||
async def process_video_task(
|
||||
cls: type[IO.ComfyNode],
|
||||
payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
|
||||
estimated_duration: Optional[int],
|
||||
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
|
||||
estimated_duration: int | None,
|
||||
) -> IO.NodeOutput:
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
@ -1085,7 +934,7 @@ async def process_video_task(
|
||||
estimated_duration=estimated_duration,
|
||||
response_model=TaskStatusResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response)))
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
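Usage sketch (illustrative): an image-to-video payload built from the request models added earlier in this commit range; the image URL is a placeholder.

```python
# Illustrative only: the content list for an image-to-video task mixes the new
# TaskTextContent / TaskImageContent models (min_length=2).
from comfy_api_nodes.apis.bytedance_api import (
    Image2VideoTaskCreationRequest,
    TaskImageContent,
    TaskImageContentUrl,
    TaskTextContent,
)

payload = Image2VideoTaskCreationRequest(
    model="seedance-1-0-lite-i2v-250428",
    content=[
        TaskTextContent(text="slow push-in on the subject, golden hour"),
        TaskImageContent(
            image_url=TaskImageContentUrl(url="https://example.com/first_frame.png"),
            role="first_frame",
        ),
    ],
)
```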
|
||||
|
||||
|
||||
def raise_if_text_params(prompt: str, text_params: list[str]) -> None:
|
||||
|
||||
@ -13,8 +13,7 @@ import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import folder_paths
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api.util import VideoCodec, VideoContainer
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
from comfy_api_nodes.apis.gemini_api import (
|
||||
GeminiContent,
|
||||
GeminiFileData,
|
||||
@ -27,6 +26,8 @@ from comfy_api_nodes.apis.gemini_api import (
|
||||
GeminiMimeType,
|
||||
GeminiPart,
|
||||
GeminiRole,
|
||||
GeminiSystemInstructionContent,
|
||||
GeminiTextPart,
|
||||
Modality,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
@ -43,6 +44,14 @@ from comfy_api_nodes.util import (
|
||||
|
||||
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
|
||||
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
|
||||
GEMINI_IMAGE_SYS_PROMPT = (
|
||||
"You are an expert image-generation engine. You must ALWAYS produce an image.\n"
|
||||
"Interpret all user input—regardless of "
|
||||
"format, intent, or abstraction—as literal visual directives for image composition.\n"
|
||||
"If a prompt is conversational or lacks specific visual details, "
|
||||
"you must creatively invent a concrete visual scenario that depicts the concept.\n"
|
||||
"Prioritize generating the visual representation above any text, formatting, or conversational requests."
|
||||
)
|
||||
|
||||
|
||||
class GeminiModel(str, Enum):
|
||||
@ -68,7 +77,7 @@ class GeminiImageModel(str, Enum):
|
||||
|
||||
async def create_image_parts(
|
||||
cls: type[IO.ComfyNode],
|
||||
images: torch.Tensor,
|
||||
images: Input.Image,
|
||||
image_limit: int = 0,
|
||||
) -> list[GeminiPart]:
|
||||
image_parts: list[GeminiPart] = []
|
||||
@ -154,8 +163,8 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
|
||||
return "\n".join([part.text for part in parts])
|
||||
|
||||
|
||||
def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
|
||||
image_tensors: list[torch.Tensor] = []
|
||||
def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
|
||||
image_tensors: list[Input.Image] = []
|
||||
parts = get_parts_by_type(response, "image/png")
|
||||
for part in parts:
|
||||
image_data = base64.b64decode(part.inlineData.data)
|
||||
@ -277,6 +286,13 @@ class GeminiNode(IO.ComfyNode):
|
||||
tooltip="Optional file(s) to use as context for the model. "
|
||||
"Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
optional=True,
|
||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(),
|
||||
@ -293,7 +309,9 @@ class GeminiNode(IO.ComfyNode):
|
||||
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
|
||||
"""Convert video input to Gemini API compatible parts."""
|
||||
|
||||
base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||
base_64_string = video_to_base64_string(
|
||||
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
|
||||
)
|
||||
return [
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
@ -343,10 +361,11 @@ class GeminiNode(IO.ComfyNode):
|
||||
prompt: str,
|
||||
model: str,
|
||||
seed: int,
|
||||
images: torch.Tensor | None = None,
|
||||
images: Input.Image | None = None,
|
||||
audio: Input.Audio | None = None,
|
||||
video: Input.Video | None = None,
|
||||
files: list[GeminiPart] | None = None,
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
|
||||
@ -363,7 +382,10 @@ class GeminiNode(IO.ComfyNode):
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
# Create response
|
||||
gemini_system_prompt = None
|
||||
if system_prompt:
|
||||
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
@ -373,7 +395,8 @@ class GeminiNode(IO.ComfyNode):
|
||||
role=GeminiRole.user,
|
||||
parts=parts,
|
||||
)
|
||||
]
|
||||
],
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
@ -523,6 +546,13 @@ class GeminiImage(IO.ComfyNode):
|
||||
"'IMAGE+TEXT' to return both the generated image and a text response.",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default=GEMINI_IMAGE_SYS_PROMPT,
|
||||
optional=True,
|
||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
@ -542,10 +572,11 @@ class GeminiImage(IO.ComfyNode):
|
||||
prompt: str,
|
||||
model: str,
|
||||
seed: int,
|
||||
images: torch.Tensor | None = None,
|
||||
images: Input.Image | None = None,
|
||||
files: list[GeminiPart] | None = None,
|
||||
aspect_ratio: str = "auto",
|
||||
response_modalities: str = "IMAGE+TEXT",
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||
@ -559,6 +590,10 @@ class GeminiImage(IO.ComfyNode):
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
gemini_system_prompt = None
|
||||
if system_prompt:
|
||||
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
@ -570,6 +605,7 @@ class GeminiImage(IO.ComfyNode):
|
||||
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||
imageConfig=None if aspect_ratio == "auto" else image_config,
|
||||
),
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
@ -640,6 +676,13 @@ class GeminiImage2(IO.ComfyNode):
|
||||
tooltip="Optional file(s) to use as context for the model. "
|
||||
"Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default=GEMINI_IMAGE_SYS_PROMPT,
|
||||
optional=True,
|
||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
@ -662,8 +705,9 @@ class GeminiImage2(IO.ComfyNode):
|
||||
aspect_ratio: str,
|
||||
resolution: str,
|
||||
response_modalities: str,
|
||||
images: torch.Tensor | None = None,
|
||||
images: Input.Image | None = None,
|
||||
files: list[GeminiPart] | None = None,
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
|
||||
@ -679,6 +723,10 @@ class GeminiImage2(IO.ComfyNode):
|
||||
if aspect_ratio != "auto":
|
||||
image_config.aspectRatio = aspect_ratio
|
||||
|
||||
gemini_system_prompt = None
|
||||
if system_prompt:
|
||||
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
@ -690,6 +738,7 @@ class GeminiImage2(IO.ComfyNode):
|
||||
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||
imageConfig=image_config,
|
||||
),
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
|
||||
@ -1,12 +1,9 @@
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
get_number_of_images,
|
||||
@ -26,9 +23,9 @@ class ExecuteTaskRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
duration: int = Field(...)
|
||||
resolution: str = Field(...)
|
||||
fps: Optional[int] = Field(25)
|
||||
generate_audio: Optional[bool] = Field(True)
|
||||
image_uri: Optional[str] = Field(None)
|
||||
fps: int | None = Field(25)
|
||||
generate_audio: bool | None = Field(True)
|
||||
image_uri: str | None = Field(None)
|
||||
|
||||
|
||||
class TextToVideoNode(IO.ComfyNode):
|
||||
@ -103,7 +100,7 @@ class TextToVideoNode(IO.ComfyNode):
|
||||
as_binary=True,
|
||||
max_retries=1,
|
||||
)
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
|
||||
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
|
||||
|
||||
|
||||
class ImageToVideoNode(IO.ComfyNode):
|
||||
@ -153,7 +150,7 @@ class ImageToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
model: str,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
@ -183,7 +180,7 @@ class ImageToVideoNode(IO.ComfyNode):
|
||||
as_binary=True,
|
||||
max_retries=1,
|
||||
)
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
|
||||
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
|
||||
|
||||
|
||||
class LtxvApiExtension(ComfyExtension):
|
||||
|
||||
@ -1,11 +1,8 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis import (
|
||||
MoonvalleyPromptResponse,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
@ -61,7 +58,7 @@ def validate_task_creation_response(response) -> None:
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
|
||||
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
|
||||
def validate_video_to_video_input(video: Input.Video) -> Input.Video:
|
||||
"""
|
||||
Validates and processes video input for Moonvalley Video-to-Video generation.
|
||||
|
||||
@ -82,7 +79,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput:
|
||||
return _validate_and_trim_duration(video)
|
||||
|
||||
|
||||
def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
|
||||
def _get_video_dimensions(video: Input.Video) -> tuple[int, int]:
|
||||
"""Extracts video dimensions with error handling."""
|
||||
try:
|
||||
return video.get_dimensions()
|
||||
@ -106,7 +103,7 @@ def _validate_video_dimensions(width: int, height: int) -> None:
|
||||
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
|
||||
|
||||
|
||||
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
||||
def _validate_and_trim_duration(video: Input.Video) -> Input.Video:
|
||||
"""Validates video duration and trims to 5 seconds if needed."""
|
||||
duration = video.get_duration()
|
||||
_validate_minimum_duration(duration)
|
||||
@ -119,7 +116,7 @@ def _validate_minimum_duration(duration: float) -> None:
|
||||
raise ValueError("Input video must be at least 5 seconds long.")
|
||||
|
||||
|
||||
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
|
||||
def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video:
|
||||
"""Trims video to 5 seconds if longer."""
|
||||
if duration > 5:
|
||||
return trim_video(video, 5)
|
||||
@ -241,7 +238,7 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
resolution: str,
|
||||
@ -362,9 +359,9 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
video: Optional[VideoInput] = None,
|
||||
video: Input.Video | None = None,
|
||||
control_type: str = "Motion Transfer",
|
||||
motion_intensity: Optional[int] = 100,
|
||||
motion_intensity: int | None = 100,
|
||||
steps=33,
|
||||
prompt_adherence=4.5,
|
||||
) -> IO.NodeOutput:
|
||||
|
||||
@ -11,12 +11,11 @@ User Guides:
|
||||
|
||||
"""
|
||||
|
||||
from typing import Union, Optional
|
||||
from typing_extensions import override
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
|
||||
from comfy_api_nodes.apis import (
|
||||
RunwayImageToVideoRequest,
|
||||
RunwayImageToVideoResponse,
|
||||
@ -44,8 +43,6 @@ from comfy_api_nodes.util import (
|
||||
sync_op,
|
||||
poll_op,
|
||||
)
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
|
||||
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
||||
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
||||
@ -80,7 +77,7 @@ class RunwayGen3aAspectRatio(str, Enum):
|
||||
field_1280_768 = "1280:768"
|
||||
|
||||
|
||||
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
||||
def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
"""Returns the video URL from the task status response if it exists."""
|
||||
if hasattr(response, "output") and len(response.output) > 0:
|
||||
return response.output[0]
|
||||
@ -89,13 +86,13 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
|
||||
|
||||
def extract_progress_from_task_status(
|
||||
response: TaskStatusResponse,
|
||||
) -> Union[float, None]:
|
||||
) -> float | None:
|
||||
if hasattr(response, "progress") and response.progress is not None:
|
||||
return response.progress * 100
|
||||
return None
|
||||
|
||||
|
||||
def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
||||
def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
"""Returns the image URL from the task status response if it exists."""
|
||||
if hasattr(response, "output") and len(response.output) > 0:
|
||||
return response.output[0]
|
||||
@ -103,7 +100,7 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
|
||||
|
||||
|
||||
async def get_response(
|
||||
cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None
|
||||
cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
|
||||
) -> TaskStatusResponse:
|
||||
"""Poll the task status until it is finished then get the response."""
|
||||
return await poll_op(
|
||||
@ -119,8 +116,8 @@ async def get_response(
|
||||
async def generate_video(
|
||||
cls: type[IO.ComfyNode],
|
||||
request: RunwayImageToVideoRequest,
|
||||
estimated_duration: Optional[int] = None,
|
||||
) -> VideoFromFile:
|
||||
estimated_duration: int | None = None,
|
||||
) -> InputImpl.VideoFromFile:
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
||||
@ -193,7 +190,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
start_frame: Input.Image,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
@ -283,7 +280,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
start_frame: Input.Image,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
@ -381,8 +378,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
end_frame: torch.Tensor,
|
||||
start_frame: Input.Image,
|
||||
end_frame: Input.Image,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
@ -467,7 +464,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
|
||||
cls,
|
||||
prompt: str,
|
||||
ratio: str,
|
||||
reference_image: Optional[torch.Tensor] = None,
|
||||
reference_image: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1)
|
||||
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
|
||||
from comfy_api_nodes.apis.veo_api import (
|
||||
VeoGenVidPollRequest,
|
||||
VeoGenVidPollResponse,
|
||||
@ -232,7 +230,7 @@ class VeoVideoGenerationNode(IO.ComfyNode):
|
||||
|
||||
# Check if video is provided as base64 or URL
|
||||
if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||
|
||||
if hasattr(video, "gcsUri") and video.gcsUri:
|
||||
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
|
||||
@ -431,8 +429,8 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
first_frame: torch.Tensor,
|
||||
last_frame: torch.Tensor,
|
||||
first_frame: Input.Image,
|
||||
last_frame: Input.Image,
|
||||
model: str,
|
||||
generate_audio: bool,
|
||||
):
|
||||
@ -493,7 +491,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
if response.videos:
|
||||
video = response.videos[0]
|
||||
if video.bytesBase64Encoded:
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||
if video.gcsUri:
|
||||
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
|
||||
raise Exception("Video returned but no data or URL was provided")
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
import math
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
@ -623,6 +624,79 @@ class TextProcessingNode(io.ComfyNode):
|
||||
# ========== Image Transform Nodes ==========
|
||||
|
||||
|
||||
class ResizeImagesToSameSizeNode(ImageProcessingNode):
|
||||
node_id = "ResizeImagesToSameSize"
|
||||
display_name = "Resize Images to Same Size"
|
||||
description = "Resize all images to the same width and height."
|
||||
extra_inputs = [
|
||||
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Target width."),
|
||||
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Target height."),
|
||||
io.Combo.Input(
|
||||
"mode",
|
||||
options=["stretch", "crop_center", "pad"],
|
||||
default="stretch",
|
||||
tooltip="Resize mode.",
|
||||
),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _process(cls, image, width, height, mode):
|
||||
img = tensor_to_pil(image)
|
||||
|
||||
if mode == "stretch":
|
||||
img = img.resize((width, height), Image.Resampling.LANCZOS)
|
||||
elif mode == "crop_center":
|
||||
left = max(0, (img.width - width) // 2)
|
||||
top = max(0, (img.height - height) // 2)
|
||||
right = min(img.width, left + width)
|
||||
bottom = min(img.height, top + height)
|
||||
img = img.crop((left, top, right, bottom))
|
||||
if img.width != width or img.height != height:
|
||||
img = img.resize((width, height), Image.Resampling.LANCZOS)
|
||||
elif mode == "pad":
|
||||
img.thumbnail((width, height), Image.Resampling.LANCZOS)
|
||||
new_img = Image.new("RGB", (width, height), (0, 0, 0))
|
||||
paste_x = (width - img.width) // 2
|
||||
paste_y = (height - img.height) // 2
|
||||
new_img.paste(img, (paste_x, paste_y))
|
||||
img = new_img
|
||||
|
||||
return pil_to_tensor(img)
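Worked numbers for the `pad` branch above, using an assumed 1920x1080 input and a 1024x1024 target:

```python
# thumbnail() keeps the aspect ratio, then the result is centered on a black canvas.
from PIL import Image

img = Image.new("RGB", (1920, 1080))
img.thumbnail((1024, 1024), Image.Resampling.LANCZOS)   # -> 1024x576
canvas = Image.new("RGB", (1024, 1024), (0, 0, 0))
paste_x = (1024 - img.width) // 2                        # 0
paste_y = (1024 - img.height) // 2                       # 224
canvas.paste(img, (paste_x, paste_y))
print(img.size, paste_x, paste_y)                        # (1024, 576) 0 224
```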
|
||||
|
||||
|
||||
class ResizeImagesToPixelCountNode(ImageProcessingNode):
|
||||
node_id = "ResizeImagesToPixelCount"
|
||||
display_name = "Resize Images to Pixel Count"
|
||||
description = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio."
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
"pixel_count",
|
||||
default=512 * 512,
|
||||
min=1,
|
||||
max=8192 * 8192,
|
||||
tooltip="Target pixel count.",
|
||||
),
|
||||
io.Int.Input(
|
||||
"steps",
|
||||
default=64,
|
||||
min=1,
|
||||
max=128,
|
||||
tooltip="The stepping for resize width/height.",
|
||||
),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _process(cls, image, pixel_count, steps):
|
||||
img = tensor_to_pil(image)
|
||||
w, h = img.size
|
||||
pixel_count_ratio = math.sqrt(pixel_count / (w * h))
|
||||
new_w = int(w * pixel_count_ratio / steps) * steps
|
||||
new_h = int(h * pixel_count_ratio / steps) * steps
|
||||
logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}")
|
||||
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
||||
return pil_to_tensor(img)
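Worked example of the resize arithmetic above, with an assumed 4000x3000 input:

```python
# Target 1024*1024 total pixels with steps=64; width and height are floored to
# multiples of the step, so the result lands slightly under the target count.
import math

w, h, pixel_count, steps = 4000, 3000, 1024 * 1024, 64
ratio = math.sqrt(pixel_count / (w * h))     # ~0.2956
new_w = int(w * ratio / steps) * steps       # int(18.47...) * 64 = 1152
new_h = int(h * ratio / steps) * steps       # int(13.85...) * 64 = 832
print(new_w, new_h, new_w * new_h)           # 1152 832 958464
```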
|
||||
|
||||
|
||||
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
|
||||
node_id = "ResizeImagesByShorterEdge"
|
||||
display_name = "Resize Images by Shorter Edge"
|
||||
@ -727,6 +801,29 @@ class RandomCropImagesNode(ImageProcessingNode):
|
||||
return pil_to_tensor(img)
|
||||
|
||||
|
||||
class FlipImagesNode(ImageProcessingNode):
|
||||
node_id = "FlipImages"
|
||||
display_name = "Flip Images"
|
||||
description = "Flip all images horizontally or vertically."
|
||||
extra_inputs = [
|
||||
io.Combo.Input(
|
||||
"direction",
|
||||
options=["horizontal", "vertical"],
|
||||
default="horizontal",
|
||||
tooltip="Flip direction.",
|
||||
),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _process(cls, image, direction):
|
||||
img = tensor_to_pil(image)
|
||||
if direction == "horizontal":
|
||||
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
else:
|
||||
img = img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||
return pil_to_tensor(img)
|
||||
|
||||
|
||||
class NormalizeImagesNode(ImageProcessingNode):
|
||||
node_id = "NormalizeImages"
|
||||
display_name = "Normalize Images"
|
||||
@ -1125,6 +1222,99 @@ class MergeTextListsNode(TextProcessingNode):
|
||||
# ========== Training Dataset Nodes ==========
|
||||
|
||||
|
||||
class ResolutionBucket(io.ComfyNode):
|
||||
"""Bucket latents and conditions by resolution for efficient batch training."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ResolutionBucket",
|
||||
display_name="Resolution Bucket",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
is_input_list=True,
|
||||
inputs=[
|
||||
io.Latent.Input(
|
||||
"latents",
|
||||
tooltip="List of latent dicts to bucket by resolution.",
|
||||
),
|
||||
io.Conditioning.Input(
|
||||
"conditioning",
|
||||
tooltip="List of conditioning lists (must match latents length).",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(
|
||||
display_name="latents",
|
||||
is_output_list=True,
|
||||
tooltip="List of batched latent dicts, one per resolution bucket.",
|
||||
),
|
||||
io.Conditioning.Output(
|
||||
display_name="conditioning",
|
||||
is_output_list=True,
|
||||
tooltip="List of condition lists, one per resolution bucket.",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, latents, conditioning):
|
||||
# latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1
|
||||
# conditioning: list[list[cond]]
|
||||
|
||||
# Validate lengths match
|
||||
if len(latents) != len(conditioning):
|
||||
raise ValueError(
|
||||
f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})."
|
||||
)
|
||||
|
||||
# Flatten latents and conditions to individual samples
|
||||
flat_latents = [] # list of (C, H, W) tensors
|
||||
flat_conditions = [] # list of condition lists
|
||||
|
||||
for latent_dict, cond in zip(latents, conditioning):
|
||||
samples = latent_dict["samples"] # (B, C, H, W)
|
||||
batch_size = samples.shape[0]
|
||||
|
||||
# cond is a list of conditions with length == batch_size
|
||||
for i in range(batch_size):
|
||||
flat_latents.append(samples[i]) # (C, H, W)
|
||||
flat_conditions.append(cond[i]) # single condition
|
||||
|
||||
# Group by resolution (H, W)
|
||||
buckets = {} # (H, W) -> {"latents": list, "conditions": list}
|
||||
|
||||
for latent, cond in zip(flat_latents, flat_conditions):
|
||||
# latent shape is (C, H, W)
|
||||
h, w = latent.shape[1], latent.shape[2]
|
||||
key = (h, w)
|
||||
|
||||
if key not in buckets:
|
||||
buckets[key] = {"latents": [], "conditions": []}
|
||||
|
||||
buckets[key]["latents"].append(latent)
|
||||
buckets[key]["conditions"].append(cond)
|
||||
|
||||
# Convert buckets to output format
|
||||
output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, C, H, W)
|
||||
output_conditions = [] # list[list[cond]] where each inner list has Bi conditions
|
||||
|
||||
for (h, w), bucket_data in buckets.items():
|
||||
# Stack latents into batch: list of (C, H, W) -> (Bi, C, H, W)
|
||||
stacked_latents = torch.stack(bucket_data["latents"], dim=0)
|
||||
output_latents.append({"samples": stacked_latents})
|
||||
|
||||
# Conditions stay as list of condition lists
|
||||
output_conditions.append(bucket_data["conditions"])
|
||||
|
||||
logging.info(
|
||||
f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples"
|
||||
)
|
||||
|
||||
logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples")
|
||||
return io.NodeOutput(output_latents, output_conditions)
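Standalone sketch of the bucketing idea (assumed latent shapes, not the node itself): group per-sample latents by spatial size and stack each group into one batch.

```python
import torch

samples = [torch.zeros(4, 64, 64), torch.zeros(4, 64, 64), torch.zeros(4, 96, 64)]
buckets: dict[tuple[int, int], list[torch.Tensor]] = {}
for lat in samples:
    # key on (H, W) so only same-resolution latents end up in one batch
    buckets.setdefault((lat.shape[1], lat.shape[2]), []).append(lat)

batched = {hw: torch.stack(group) for hw, group in buckets.items()}
for (h, w), batch in batched.items():
    print(f"bucket {h}x{w}: batch shape {tuple(batch.shape)}")
# bucket 64x64: batch shape (2, 4, 64, 64)
# bucket 96x64: batch shape (1, 4, 96, 64)
```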
|
||||
|
||||
|
||||
class MakeTrainingDataset(io.ComfyNode):
|
||||
"""Encode images with VAE and texts with CLIP to create a training dataset."""
|
||||
|
||||
@ -1373,7 +1563,7 @@ class LoadTrainingDataset(io.ComfyNode):
|
||||
shard_path = os.path.join(dataset_dir, shard_file)
|
||||
|
||||
with open(shard_path, "rb") as f:
|
||||
shard_data = torch.load(f, weights_only=True)
|
||||
shard_data = torch.load(f)
|
||||
|
||||
all_latents.extend(shard_data["latents"])
|
||||
all_conditioning.extend(shard_data["conditioning"])
|
||||
@ -1399,10 +1589,13 @@ class DatasetExtension(ComfyExtension):
|
||||
SaveImageDataSetToFolderNode,
|
||||
SaveImageTextDataSetToFolderNode,
|
||||
# Image transform nodes
|
||||
ResizeImagesToSameSizeNode,
|
||||
ResizeImagesToPixelCountNode,
|
||||
ResizeImagesByShorterEdgeNode,
|
||||
ResizeImagesByLongerEdgeNode,
|
||||
CenterCropImagesNode,
|
||||
RandomCropImagesNode,
|
||||
FlipImagesNode,
|
||||
NormalizeImagesNode,
|
||||
AdjustBrightnessNode,
|
||||
AdjustContrastNode,
|
||||
@ -1425,6 +1618,7 @@ class DatasetExtension(ComfyExtension):
|
||||
MakeTrainingDataset,
|
||||
SaveTrainingDataset,
|
||||
LoadTrainingDataset,
|
||||
ResolutionBucket,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ from PIL import Image, ImageDraw, ImageFont
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.samplers
|
||||
import comfy.sampler_helpers
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
@ -21,6 +22,68 @@ from comfy_api.latest import ComfyExtension, io, ui
|
||||
from comfy.utils import ProgressBar
|
||||
|
||||
|
||||
class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
|
||||
"""
|
||||
CFGGuider with modifications for training specific logic
|
||||
"""
|
||||
def outer_sample(
|
||||
self,
|
||||
noise,
|
||||
latent_image,
|
||||
sampler,
|
||||
sigmas,
|
||||
denoise_mask=None,
|
||||
callback=None,
|
||||
disable_pbar=False,
|
||||
seed=None,
|
||||
latent_shapes=None,
|
||||
):
|
||||
self.inner_model, self.conds, self.loaded_models = (
|
||||
comfy.sampler_helpers.prepare_sampling(
|
||||
self.model_patcher,
|
||||
noise.shape,
|
||||
self.conds,
|
||||
self.model_options,
|
||||
skip_load_model=True, # skip load model as we manage it in TrainLoraNode.execute()
|
||||
)
|
||||
)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
if denoise_mask is not None:
|
||||
denoise_mask = comfy.sampler_helpers.prepare_mask(
|
||||
denoise_mask, noise.shape, device
|
||||
)
|
||||
|
||||
noise = noise.to(device)
|
||||
latent_image = latent_image.to(device)
|
||||
sigmas = sigmas.to(device)
|
||||
comfy.samplers.cast_to_load_options(
|
||||
self.model_options, device=device, dtype=self.model_patcher.model_dtype()
|
||||
)
|
||||
|
||||
try:
|
||||
self.model_patcher.pre_run()
|
||||
output = self.inner_sample(
|
||||
noise,
|
||||
latent_image,
|
||||
device,
|
||||
sampler,
|
||||
sigmas,
|
||||
denoise_mask,
|
||||
callback,
|
||||
disable_pbar,
|
||||
seed,
|
||||
latent_shapes=latent_shapes,
|
||||
)
|
||||
finally:
|
||||
self.model_patcher.cleanup()
|
||||
|
||||
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||
del self.inner_model
|
||||
del self.loaded_models
|
||||
return output
|
||||
|
||||
|
||||
def make_batch_extra_option_dict(d, indicies, full_size=None):
|
||||
new_dict = {}
|
||||
for k, v in d.items():
|
||||
@ -65,6 +128,7 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
seed=0,
|
||||
training_dtype=torch.bfloat16,
|
||||
real_dataset=None,
|
||||
bucket_latents=None,
|
||||
):
|
||||
self.loss_fn = loss_fn
|
||||
self.optimizer = optimizer
|
||||
@ -75,6 +139,28 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
self.seed = seed
|
||||
self.training_dtype = training_dtype
|
||||
self.real_dataset: list[torch.Tensor] | None = real_dataset
|
||||
# Bucket mode data
|
||||
self.bucket_latents: list[torch.Tensor] | None = (
|
||||
bucket_latents # list of (Bi, C, Hi, Wi)
|
||||
)
|
||||
# Precompute bucket offsets and weights for sampling
|
||||
if bucket_latents is not None:
|
||||
self._init_bucket_data(bucket_latents)
|
||||
else:
|
||||
self.bucket_offsets = None
|
||||
self.bucket_weights = None
|
||||
self.num_images = None
|
||||
|
||||
def _init_bucket_data(self, bucket_latents):
|
||||
"""Initialize bucket offsets and weights for sampling."""
|
||||
self.bucket_offsets = [0]
|
||||
bucket_sizes = []
|
||||
for lat in bucket_latents:
|
||||
bucket_sizes.append(lat.shape[0])
|
||||
self.bucket_offsets.append(self.bucket_offsets[-1] + lat.shape[0])
|
||||
self.num_images = self.bucket_offsets[-1]
|
||||
# Weights for sampling buckets proportional to their size
|
||||
self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)
|
||||
|
||||
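To make the offset/weight bookkeeping above concrete, here is a minimal standalone sketch (illustrative bucket sizes, not part of the commit) of how a size-proportional bucket draw maps a relative index inside a bucket to an absolute index over all images:

import torch

# Illustrative: three buckets holding 4, 2 and 6 latents respectively.
bucket_sizes = [4, 2, 6]
bucket_offsets = [0]
for size in bucket_sizes:
    bucket_offsets.append(bucket_offsets[-1] + size)      # [0, 4, 6, 12]
bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)

# Draw a bucket with probability proportional to its size, then map a
# relative index inside that bucket to an absolute index over all images.
bucket_idx = torch.multinomial(bucket_weights, 1).item()
relative_idx = torch.randint(bucket_sizes[bucket_idx], (1,)).item()
absolute_idx = bucket_offsets[bucket_idx] + relative_idx   # index into the flattened conditioning list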
def fwd_bwd(
|
||||
self,
|
||||
@ -115,6 +201,108 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
bwd_loss.backward()
|
||||
return loss
|
||||
|
||||
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
|
||||
"""Generate random sigma values for a batch."""
|
||||
batch_sigmas = [
|
||||
model_wrap.inner_model.model_sampling.percent_to_sigma(
|
||||
torch.rand((1,)).item()
|
||||
)
|
||||
for _ in range(batch_size)
|
||||
]
|
||||
return torch.tensor(batch_sigmas).to(device)
|
||||
|
||||
def _train_step_bucket_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, pbar):
|
||||
"""Execute one training step in bucket mode."""
|
||||
# Sample bucket (weighted by size), then sample batch from bucket
|
||||
bucket_idx = torch.multinomial(self.bucket_weights, 1).item()
|
||||
bucket_latent = self.bucket_latents[bucket_idx] # (Bi, C, Hi, Wi)
|
||||
bucket_size = bucket_latent.shape[0]
|
||||
bucket_offset = self.bucket_offsets[bucket_idx]
|
||||
|
||||
# Sample indices from this bucket (use all if bucket_size < batch_size)
|
||||
actual_batch_size = min(self.batch_size, bucket_size)
|
||||
relative_indices = torch.randperm(bucket_size)[:actual_batch_size].tolist()
|
||||
# Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index)
|
||||
absolute_indices = [bucket_offset + idx for idx in relative_indices]
|
||||
|
||||
batch_latent = bucket_latent[relative_indices].to(latent_image) # (actual_batch_size, C, H, W)
|
||||
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
|
||||
batch_latent.device
|
||||
)
|
||||
batch_sigmas = self._generate_batch_sigmas(model_wrap, actual_batch_size, batch_latent.device)
|
||||
|
||||
loss = self.fwd_bwd(
|
||||
model_wrap,
|
||||
batch_sigmas,
|
||||
batch_noise,
|
||||
batch_latent,
|
||||
cond, # Use flattened cond with absolute indices
|
||||
absolute_indices,
|
||||
extra_args,
|
||||
self.num_images,
|
||||
bwd=True,
|
||||
)
|
||||
if self.loss_callback:
|
||||
self.loss_callback(loss.item())
|
||||
pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx})
|
||||
|
||||
def _train_step_standard_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
|
||||
"""Execute one training step in standard (non-bucket, non-multi-res) mode."""
|
||||
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
|
||||
batch_latent = torch.stack([latent_image[i] for i in indicies])
|
||||
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
|
||||
batch_latent.device
|
||||
)
|
||||
batch_sigmas = self._generate_batch_sigmas(model_wrap, min(self.batch_size, dataset_size), batch_latent.device)
|
||||
|
||||
loss = self.fwd_bwd(
|
||||
model_wrap,
|
||||
batch_sigmas,
|
||||
batch_noise,
|
||||
batch_latent,
|
||||
cond,
|
||||
indicies,
|
||||
extra_args,
|
||||
dataset_size,
|
||||
bwd=True,
|
||||
)
|
||||
if self.loss_callback:
|
||||
self.loss_callback(loss.item())
|
||||
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
|
||||
|
||||
def _train_step_multires_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
|
||||
"""Execute one training step in multi-resolution mode (real_dataset is set)."""
|
||||
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
|
||||
total_loss = 0
|
||||
for index in indicies:
|
||||
single_latent = self.real_dataset[index].to(latent_image)
|
||||
batch_noise = noisegen.generate_noise(
|
||||
{"samples": single_latent}
|
||||
).to(single_latent.device)
|
||||
batch_sigmas = (
|
||||
model_wrap.inner_model.model_sampling.percent_to_sigma(
|
||||
torch.rand((1,)).item()
|
||||
)
|
||||
)
|
||||
batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
|
||||
loss = self.fwd_bwd(
|
||||
model_wrap,
|
||||
batch_sigmas,
|
||||
batch_noise,
|
||||
single_latent,
|
||||
cond,
|
||||
[index],
|
||||
extra_args,
|
||||
dataset_size,
|
||||
bwd=False,
|
||||
)
|
||||
total_loss += loss
|
||||
total_loss = total_loss / self.grad_acc / len(indicies)
|
||||
total_loss.backward()
|
||||
if self.loss_callback:
|
||||
self.loss_callback(total_loss.item())
|
||||
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
|
||||
|
||||
def sample(
|
||||
self,
|
||||
model_wrap,
|
||||
@ -142,65 +330,13 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(
|
||||
self.seed + i * 1000
|
||||
)
|
||||
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
|
||||
|
||||
if self.real_dataset is None:
|
||||
batch_latent = torch.stack([latent_image[i] for i in indicies])
|
||||
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
|
||||
batch_latent.device
|
||||
)
|
||||
batch_sigmas = [
|
||||
model_wrap.inner_model.model_sampling.percent_to_sigma(
|
||||
torch.rand((1,)).item()
|
||||
)
|
||||
for _ in range(min(self.batch_size, dataset_size))
|
||||
]
|
||||
batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
|
||||
|
||||
loss = self.fwd_bwd(
|
||||
model_wrap,
|
||||
batch_sigmas,
|
||||
batch_noise,
|
||||
batch_latent,
|
||||
cond,
|
||||
indicies,
|
||||
extra_args,
|
||||
dataset_size,
|
||||
bwd=True,
|
||||
)
|
||||
if self.loss_callback:
|
||||
self.loss_callback(loss.item())
|
||||
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
|
||||
if self.bucket_latents is not None:
|
||||
self._train_step_bucket_mode(model_wrap, cond, extra_args, noisegen, latent_image, pbar)
|
||||
elif self.real_dataset is None:
|
||||
self._train_step_standard_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
|
||||
else:
|
||||
total_loss = 0
|
||||
for index in indicies:
|
||||
single_latent = self.real_dataset[index].to(latent_image)
|
||||
batch_noise = noisegen.generate_noise(
|
||||
{"samples": single_latent}
|
||||
).to(single_latent.device)
|
||||
batch_sigmas = (
|
||||
model_wrap.inner_model.model_sampling.percent_to_sigma(
|
||||
torch.rand((1,)).item()
|
||||
)
|
||||
)
|
||||
batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
|
||||
loss = self.fwd_bwd(
|
||||
model_wrap,
|
||||
batch_sigmas,
|
||||
batch_noise,
|
||||
single_latent,
|
||||
cond,
|
||||
[index],
|
||||
extra_args,
|
||||
dataset_size,
|
||||
bwd=False,
|
||||
)
|
||||
total_loss += loss
|
||||
total_loss = total_loss / self.grad_acc / len(indicies)
|
||||
total_loss.backward()
|
||||
if self.loss_callback:
|
||||
self.loss_callback(total_loss.item())
|
||||
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
|
||||
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
|
||||
|
||||
if (i + 1) % self.grad_acc == 0:
|
||||
self.optimizer.step()
|
||||
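The optimizer step above fires only once every grad_acc iterations, so gradients accumulate across the intermediate backward passes. A minimal sketch of the same pattern in isolation (toy model and data, assumed purely for illustration):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = torch.nn.MSELoss()
grad_acc = 4

for step in range(8):
    x, y = torch.randn(2, 4), torch.randn(2, 1)
    loss = loss_fn(model(x), y) / grad_acc   # scale so the accumulated gradient matches one larger batch
    loss.backward()                          # gradients add up across iterations
    if (step + 1) % grad_acc == 0:
        optimizer.step()
        optimizer.zero_grad()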
@ -283,6 +419,364 @@ def unpatch(m):
|
||||
del m.org_forward
|
||||
|
||||
|
||||
def _process_latents_bucket_mode(latents):
|
||||
"""Process latents for bucket mode training.
|
||||
|
||||
Args:
|
||||
latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi)
|
||||
|
||||
Returns:
|
||||
List of latent tensors, one (Bi, C, Hi, Wi) tensor per bucket
|
||||
"""
|
||||
bucket_latents = []
|
||||
for latent_dict in latents:
|
||||
bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi)
|
||||
return bucket_latents
|
||||
|
||||
|
||||
def _process_latents_standard_mode(latents):
|
||||
"""Process latents for standard (non-bucket) mode training.
|
||||
|
||||
Args:
|
||||
latents: list of latent dicts or single latent dict
|
||||
|
||||
Returns:
|
||||
Processed latents (tensor or list of tensors)
|
||||
"""
|
||||
if len(latents) == 1:
|
||||
return latents[0]["samples"] # Single latent dict
|
||||
|
||||
latent_list = []
|
||||
for latent in latents:
|
||||
latent = latent["samples"]
|
||||
bs = latent.shape[0]
|
||||
if bs != 1:
|
||||
for sub_latent in latent:
|
||||
latent_list.append(sub_latent[None])
|
||||
else:
|
||||
latent_list.append(latent)
|
||||
return latent_list
|
||||
|
||||
|
||||
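As a concrete illustration of the flattening above (shapes invented for the example), a batched latent entry is split into per-image tensors while singleton entries pass through unchanged:

import torch

latents = [
    {"samples": torch.zeros(3, 4, 8, 8)},    # a batch of three images
    {"samples": torch.zeros(1, 4, 16, 16)},  # a single image
]
flat = []
for entry in latents:
    samples = entry["samples"]
    if samples.shape[0] != 1:
        flat.extend(s[None] for s in samples)  # split the batch, keep a leading batch dim of 1
    else:
        flat.append(samples)
print([tuple(t.shape) for t in flat])
# [(1, 4, 8, 8), (1, 4, 8, 8), (1, 4, 8, 8), (1, 4, 16, 16)]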
def _process_conditioning(positive):
|
||||
"""Process conditioning - either single list or list of lists.
|
||||
|
||||
Args:
|
||||
positive: list of conditioning
|
||||
|
||||
Returns:
|
||||
Flattened conditioning list
|
||||
"""
|
||||
if len(positive) == 1:
|
||||
return positive[0] # Single conditioning list
|
||||
|
||||
# Multiple conditioning lists - flatten
|
||||
flat_positive = []
|
||||
for cond in positive:
|
||||
if isinstance(cond, list):
|
||||
flat_positive.extend(cond)
|
||||
else:
|
||||
flat_positive.append(cond)
|
||||
return flat_positive
|
||||
|
||||
|
||||
def _prepare_latents_and_count(latents, dtype, bucket_mode):
|
||||
"""Convert latents to dtype and compute image counts.
|
||||
|
||||
Args:
|
||||
latents: Latents (tensor, list of tensors, or bucket list)
|
||||
dtype: Target dtype
|
||||
bucket_mode: Whether bucket mode is enabled
|
||||
|
||||
Returns:
|
||||
tuple: (processed_latents, num_images, multi_res)
|
||||
"""
|
||||
if bucket_mode:
|
||||
# In bucket mode, latents is list of tensors (Bi, C, Hi, Wi)
|
||||
latents = [t.to(dtype) for t in latents]
|
||||
num_buckets = len(latents)
|
||||
num_images = sum(t.shape[0] for t in latents)
|
||||
multi_res = False # Not using multi_res path in bucket mode
|
||||
|
||||
logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
|
||||
for i, lat in enumerate(latents):
|
||||
logging.info(f" Bucket {i}: shape {lat.shape}")
|
||||
return latents, num_images, multi_res
|
||||
|
||||
# Non-bucket mode
|
||||
if isinstance(latents, list):
|
||||
all_shapes = set()
|
||||
latents = [t.to(dtype) for t in latents]
|
||||
for latent in latents:
|
||||
all_shapes.add(latent.shape)
|
||||
logging.info(f"Latent shapes: {all_shapes}")
|
||||
if len(all_shapes) > 1:
|
||||
multi_res = True
|
||||
else:
|
||||
multi_res = False
|
||||
latents = torch.cat(latents, dim=0)
|
||||
num_images = len(latents)
|
||||
elif isinstance(latents, torch.Tensor):
|
||||
latents = latents.to(dtype)
|
||||
num_images = latents.shape[0]
|
||||
multi_res = False
|
||||
else:
|
||||
logging.error(f"Invalid latents type: {type(latents)}")
|
||||
num_images = 0
|
||||
multi_res = False
|
||||
|
||||
return latents, num_images, multi_res
|
||||
|
||||
|
||||
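For intuition, a small illustrative trace (shapes invented) of the two non-bucket outcomes: uniform shapes are concatenated into one batch, while mixed shapes keep the list and enable the multi-resolution path:

import torch

uniform = [torch.zeros(1, 4, 8, 8) for _ in range(3)]
mixed = [torch.zeros(1, 4, 8, 8), torch.zeros(1, 4, 16, 16)]

print(len({t.shape for t in uniform}) > 1)   # False -> torch.cat(uniform, dim=0), 3 images, multi_res False
print(len({t.shape for t in mixed}) > 1)     # True  -> keep the list, per-sample training path, multi_res True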
def _validate_and_expand_conditioning(positive, num_images, bucket_mode):
|
||||
"""Validate conditioning count matches image count, expand if needed.
|
||||
|
||||
Args:
|
||||
positive: Conditioning list
|
||||
num_images: Number of images
|
||||
bucket_mode: Whether bucket mode is enabled
|
||||
|
||||
Returns:
|
||||
Validated/expanded conditioning list
|
||||
|
||||
Raises:
|
||||
ValueError: If conditioning count doesn't match image count
|
||||
"""
|
||||
if bucket_mode:
|
||||
return positive # Skip validation in bucket mode
|
||||
|
||||
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
|
||||
if len(positive) == 1 and num_images > 1:
|
||||
return positive * num_images
|
||||
elif len(positive) != num_images:
|
||||
raise ValueError(
|
||||
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
|
||||
)
|
||||
return positive
|
||||
|
||||
|
||||
def _load_existing_lora(existing_lora):
|
||||
"""Load existing LoRA weights if provided.
|
||||
|
||||
Args:
|
||||
existing_lora: LoRA filename or "[None]"
|
||||
|
||||
Returns:
|
||||
tuple: (existing_weights dict, existing_steps int)
|
||||
"""
|
||||
if existing_lora == "[None]":
|
||||
return {}, 0
|
||||
|
||||
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
|
||||
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
|
||||
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
|
||||
existing_weights = {}
|
||||
if lora_path:
|
||||
existing_weights = comfy.utils.load_torch_file(lora_path)
|
||||
return existing_weights, existing_steps
|
||||
|
||||
|
||||
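The step count is recovered purely from the filename convention mentioned in the comment above; for example (hypothetical filename):

existing_lora = "trained_lora_10_steps_20250225_203716.safetensors"  # hypothetical name
steps = int(existing_lora.split("_steps_")[0].split("_")[-1])        # "trained_lora_10" -> "10"
print(steps)  # 10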
def _create_weight_adapter(
|
||||
module, module_name, existing_weights, algorithm, lora_dtype, rank
|
||||
):
|
||||
"""Create a weight adapter for a module with weight.
|
||||
|
||||
Args:
|
||||
module: The module to create an adapter for
|
||||
module_name: Name of the module
|
||||
existing_weights: Dict of existing LoRA weights
|
||||
algorithm: Algorithm name for new adapters
|
||||
lora_dtype: dtype for LoRA weights
|
||||
rank: Rank for new LoRA adapters
|
||||
|
||||
Returns:
|
||||
tuple: (train_adapter, lora_params dict)
|
||||
"""
|
||||
key = f"{module_name}.weight"
|
||||
shape = module.weight.shape
|
||||
lora_params = {}
|
||||
|
||||
if len(shape) >= 2:
|
||||
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
|
||||
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
|
||||
|
||||
# Try to load existing adapter
|
||||
existing_adapter = None
|
||||
for adapter_cls in adapters:
|
||||
existing_adapter = adapter_cls.load(
|
||||
module_name, existing_weights, alpha, dora_scale
|
||||
)
|
||||
if existing_adapter is not None:
|
||||
break
|
||||
|
||||
if existing_adapter is None:
|
||||
adapter_cls = adapter_maps[algorithm]
|
||||
|
||||
if existing_adapter is not None:
|
||||
train_adapter = existing_adapter.to_train().to(lora_dtype)
|
||||
else:
|
||||
# Use LoRA with alpha=1.0 by default
|
||||
train_adapter = adapter_cls.create_train(
|
||||
module.weight, rank=rank, alpha=1.0
|
||||
).to(lora_dtype)
|
||||
|
||||
for name, parameter in train_adapter.named_parameters():
|
||||
lora_params[f"{module_name}.{name}"] = parameter
|
||||
|
||||
return train_adapter, lora_params
|
||||
else:
|
||||
# 1D weight - use BiasDiff
|
||||
diff = torch.nn.Parameter(
|
||||
torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True)
|
||||
)
|
||||
diff_module = BiasDiff(diff)
|
||||
lora_params[f"{module_name}.diff"] = diff
|
||||
return diff_module, lora_params
|
||||
|
||||
|
||||
def _create_bias_adapter(module, module_name, lora_dtype):
|
||||
"""Create a bias adapter for a module with bias.
|
||||
|
||||
Args:
|
||||
module: The module with bias
|
||||
module_name: Name of the module
|
||||
lora_dtype: dtype for LoRA weights
|
||||
|
||||
Returns:
|
||||
tuple: (bias_module, lora_params dict)
|
||||
"""
|
||||
bias = torch.nn.Parameter(
|
||||
torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True)
|
||||
)
|
||||
bias_module = BiasDiff(bias)
|
||||
lora_params = {f"{module_name}.diff_b": bias}
|
||||
return bias_module, lora_params
|
||||
|
||||
|
||||
def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank):
|
||||
"""Setup all LoRA adapters on the model.
|
||||
|
||||
Args:
|
||||
mp: Model patcher
|
||||
existing_weights: Dict of existing LoRA weights
|
||||
algorithm: Algorithm name for new adapters
|
||||
lora_dtype: dtype for LoRA weights
|
||||
rank: Rank for new LoRA adapters
|
||||
|
||||
Returns:
|
||||
tuple: (lora_sd dict, all_weight_adapters list)
|
||||
"""
|
||||
lora_sd = {}
|
||||
all_weight_adapters = []
|
||||
|
||||
for n, m in mp.model.named_modules():
|
||||
if hasattr(m, "weight_function"):
|
||||
if m.weight is not None:
|
||||
adapter, params = _create_weight_adapter(
|
||||
m, n, existing_weights, algorithm, lora_dtype, rank
|
||||
)
|
||||
lora_sd.update(params)
|
||||
key = f"{n}.weight"
|
||||
mp.add_weight_wrapper(key, adapter)
|
||||
all_weight_adapters.append(adapter)
|
||||
|
||||
if hasattr(m, "bias") and m.bias is not None:
|
||||
bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype)
|
||||
lora_sd.update(bias_params)
|
||||
key = f"{n}.bias"
|
||||
mp.add_weight_wrapper(key, bias_adapter)
|
||||
all_weight_adapters.append(bias_adapter)
|
||||
|
||||
return lora_sd, all_weight_adapters
|
||||
|
||||
|
||||
def _create_optimizer(optimizer_name, parameters, learning_rate):
|
||||
"""Create optimizer based on name.
|
||||
|
||||
Args:
|
||||
optimizer_name: Name of optimizer ("Adam", "AdamW", "SGD", "RMSprop")
|
||||
parameters: Parameters to optimize
|
||||
learning_rate: Learning rate
|
||||
|
||||
Returns:
|
||||
Optimizer instance
|
||||
"""
|
||||
if optimizer_name == "Adam":
|
||||
return torch.optim.Adam(parameters, lr=learning_rate)
|
||||
elif optimizer_name == "AdamW":
|
||||
return torch.optim.AdamW(parameters, lr=learning_rate)
|
||||
elif optimizer_name == "SGD":
|
||||
return torch.optim.SGD(parameters, lr=learning_rate)
|
||||
elif optimizer_name == "RMSprop":
|
||||
return torch.optim.RMSprop(parameters, lr=learning_rate)
|
||||
|
||||
|
||||
def _create_loss_function(loss_function_name):
|
||||
"""Create loss function based on name.
|
||||
|
||||
Args:
|
||||
loss_function_name: Name of loss function ("MSE", "L1", "Huber", "SmoothL1")
|
||||
|
||||
Returns:
|
||||
Loss function instance
|
||||
"""
|
||||
if loss_function_name == "MSE":
|
||||
return torch.nn.MSELoss()
|
||||
elif loss_function_name == "L1":
|
||||
return torch.nn.L1Loss()
|
||||
elif loss_function_name == "Huber":
|
||||
return torch.nn.HuberLoss()
|
||||
elif loss_function_name == "SmoothL1":
|
||||
return torch.nn.SmoothL1Loss()
|
||||
|
||||
|
||||
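Both factory helpers above silently return None for an unrecognized name. A dict-based variant (an alternative sketch, not the code in this commit) would make that case explicit:

import torch

_OPTIMIZERS = {
    "Adam": torch.optim.Adam,
    "AdamW": torch.optim.AdamW,
    "SGD": torch.optim.SGD,
    "RMSprop": torch.optim.RMSprop,
}

def create_optimizer(name, parameters, learning_rate):
    try:
        return _OPTIMIZERS[name](parameters, lr=learning_rate)
    except KeyError:
        raise ValueError(f"Unknown optimizer: {name!r}") from None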
def _run_training_loop(
|
||||
guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res
|
||||
):
|
||||
"""Execute the training loop.
|
||||
|
||||
Args:
|
||||
guider: The guider object
|
||||
train_sampler: The training sampler
|
||||
latents: Latent tensors
|
||||
num_images: Number of images
|
||||
seed: Random seed
|
||||
bucket_mode: Whether bucket mode is enabled
|
||||
multi_res: Whether multi-resolution mode is enabled
|
||||
"""
|
||||
sigmas = torch.tensor(range(num_images))
|
||||
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
|
||||
|
||||
if bucket_mode:
|
||||
# Use first bucket's first latent as dummy for guider
|
||||
dummy_latent = latents[0][:1].repeat((num_images,) + (1,) * (latents[0].ndim - 1))  # keep extra dims for non-4D latents
|
||||
guider.sample(
|
||||
noise.generate_noise({"samples": dummy_latent}),
|
||||
dummy_latent,
|
||||
train_sampler,
|
||||
sigmas,
|
||||
seed=noise.seed,
|
||||
)
|
||||
elif multi_res:
|
||||
# use first latent as dummy latent if multi_res
|
||||
latents = latents[0].repeat((num_images,) + (1,) * (latents[0].ndim - 1))
|
||||
guider.sample(
|
||||
noise.generate_noise({"samples": latents}),
|
||||
latents,
|
||||
train_sampler,
|
||||
sigmas,
|
||||
seed=noise.seed,
|
||||
)
|
||||
else:
|
||||
guider.sample(
|
||||
noise.generate_noise({"samples": latents}),
|
||||
latents,
|
||||
train_sampler,
|
||||
sigmas,
|
||||
seed=noise.seed,
|
||||
)
|
||||
|
||||
|
||||
class TrainLoraNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -385,6 +879,11 @@ class TrainLoraNode(io.ComfyNode):
|
||||
default="[None]",
|
||||
tooltip="The existing LoRA to append to. Set to None for new LoRA.",
|
||||
),
|
||||
io.Boolean.Input(
|
||||
"bucket_mode",
|
||||
default=False,
|
||||
tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(
|
||||
@ -419,6 +918,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
algorithm,
|
||||
gradient_checkpointing,
|
||||
existing_lora,
|
||||
bucket_mode,
|
||||
):
|
||||
# Extract scalars from lists (due to is_input_list=True)
|
||||
model = model[0]
|
||||
@ -427,215 +927,124 @@ class TrainLoraNode(io.ComfyNode):
|
||||
grad_accumulation_steps = grad_accumulation_steps[0]
|
||||
learning_rate = learning_rate[0]
|
||||
rank = rank[0]
|
||||
optimizer = optimizer[0]
|
||||
loss_function = loss_function[0]
|
||||
optimizer_name = optimizer[0]
|
||||
loss_function_name = loss_function[0]
|
||||
seed = seed[0]
|
||||
training_dtype = training_dtype[0]
|
||||
lora_dtype = lora_dtype[0]
|
||||
algorithm = algorithm[0]
|
||||
gradient_checkpointing = gradient_checkpointing[0]
|
||||
existing_lora = existing_lora[0]
|
||||
bucket_mode = bucket_mode[0]
|
||||
|
||||
# Handle latents - either single dict or list of dicts
|
||||
if len(latents) == 1:
|
||||
latents = latents[0]["samples"] # Single latent dict
|
||||
# Process latents based on mode
|
||||
if bucket_mode:
|
||||
latents = _process_latents_bucket_mode(latents)
|
||||
else:
|
||||
latent_list = []
|
||||
for latent in latents:
|
||||
latent = latent["samples"]
|
||||
bs = latent.shape[0]
|
||||
if bs != 1:
|
||||
for sub_latent in latent:
|
||||
latent_list.append(sub_latent[None])
|
||||
else:
|
||||
latent_list.append(latent)
|
||||
latents = latent_list
|
||||
latents = _process_latents_standard_mode(latents)
|
||||
|
||||
# Handle conditioning - either single list or list of lists
|
||||
if len(positive) == 1:
|
||||
positive = positive[0] # Single conditioning list
|
||||
else:
|
||||
# Multiple conditioning lists - flatten
|
||||
flat_positive = []
|
||||
for cond in positive:
|
||||
if isinstance(cond, list):
|
||||
flat_positive.extend(cond)
|
||||
else:
|
||||
flat_positive.append(cond)
|
||||
positive = flat_positive
|
||||
# Process conditioning
|
||||
positive = _process_conditioning(positive)
|
||||
|
||||
# Setup model and dtype
|
||||
mp = model.clone()
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
|
||||
# latents here can be list of different size latent or one large batch
|
||||
if isinstance(latents, list):
|
||||
all_shapes = set()
|
||||
latents = [t.to(dtype) for t in latents]
|
||||
for latent in latents:
|
||||
all_shapes.add(latent.shape)
|
||||
logging.info(f"Latent shapes: {all_shapes}")
|
||||
if len(all_shapes) > 1:
|
||||
multi_res = True
|
||||
else:
|
||||
multi_res = False
|
||||
latents = torch.cat(latents, dim=0)
|
||||
num_images = len(latents)
|
||||
elif isinstance(latents, torch.Tensor):
|
||||
latents = latents.to(dtype)
|
||||
num_images = latents.shape[0]
|
||||
else:
|
||||
logging.error(f"Invalid latents type: {type(latents)}")
|
||||
# Prepare latents and compute counts
|
||||
latents, num_images, multi_res = _prepare_latents_and_count(
|
||||
latents, dtype, bucket_mode
|
||||
)
|
||||
|
||||
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
|
||||
if len(positive) == 1 and num_images > 1:
|
||||
positive = positive * num_images
|
||||
elif len(positive) != num_images:
|
||||
raise ValueError(
|
||||
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
|
||||
)
|
||||
# Validate and expand conditioning
|
||||
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
|
||||
|
||||
with torch.inference_mode(False):
|
||||
lora_sd = {}
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(seed)
|
||||
|
||||
# Load existing LoRA weights if provided
|
||||
existing_weights = {}
|
||||
existing_steps = 0
|
||||
if existing_lora != "[None]":
|
||||
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
|
||||
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
|
||||
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
|
||||
if lora_path:
|
||||
existing_weights = comfy.utils.load_torch_file(lora_path)
|
||||
existing_weights, existing_steps = _load_existing_lora(existing_lora)
|
||||
|
||||
all_weight_adapters = []
|
||||
for n, m in mp.model.named_modules():
|
||||
if hasattr(m, "weight_function"):
|
||||
if m.weight is not None:
|
||||
key = "{}.weight".format(n)
|
||||
shape = m.weight.shape
|
||||
if len(shape) >= 2:
|
||||
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
|
||||
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
|
||||
for adapter_cls in adapters:
|
||||
existing_adapter = adapter_cls.load(
|
||||
n, existing_weights, alpha, dora_scale
|
||||
)
|
||||
if existing_adapter is not None:
|
||||
break
|
||||
else:
|
||||
existing_adapter = None
|
||||
adapter_cls = adapter_maps[algorithm]
|
||||
# Setup LoRA adapters
|
||||
lora_sd, all_weight_adapters = _setup_lora_adapters(
|
||||
mp, existing_weights, algorithm, lora_dtype, rank
|
||||
)
|
||||
|
||||
if existing_adapter is not None:
|
||||
train_adapter = existing_adapter.to_train().to(
|
||||
lora_dtype
|
||||
)
|
||||
else:
|
||||
# Use LoRA with alpha=1.0 by default
|
||||
train_adapter = adapter_cls.create_train(
|
||||
m.weight, rank=rank, alpha=1.0
|
||||
).to(lora_dtype)
|
||||
for name, parameter in train_adapter.named_parameters():
|
||||
lora_sd[f"{n}.{name}"] = parameter
|
||||
# Create optimizer and loss function
|
||||
optimizer = _create_optimizer(
|
||||
optimizer_name, lora_sd.values(), learning_rate
|
||||
)
|
||||
criterion = _create_loss_function(loss_function_name)
|
||||
|
||||
mp.add_weight_wrapper(key, train_adapter)
|
||||
all_weight_adapters.append(train_adapter)
|
||||
else:
|
||||
diff = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
m.weight.shape, dtype=lora_dtype, requires_grad=True
|
||||
)
|
||||
)
|
||||
diff_module = BiasDiff(diff)
|
||||
mp.add_weight_wrapper(key, BiasDiff(diff))
|
||||
all_weight_adapters.append(diff_module)
|
||||
lora_sd["{}.diff".format(n)] = diff
|
||||
if hasattr(m, "bias") and m.bias is not None:
|
||||
key = "{}.bias".format(n)
|
||||
bias = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
m.bias.shape, dtype=lora_dtype, requires_grad=True
|
||||
)
|
||||
)
|
||||
bias_module = BiasDiff(bias)
|
||||
lora_sd["{}.diff_b".format(n)] = bias
|
||||
mp.add_weight_wrapper(key, BiasDiff(bias))
|
||||
all_weight_adapters.append(bias_module)
|
||||
|
||||
if optimizer == "Adam":
|
||||
optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
|
||||
elif optimizer == "AdamW":
|
||||
optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
|
||||
elif optimizer == "SGD":
|
||||
optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
|
||||
elif optimizer == "RMSprop":
|
||||
optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
|
||||
|
||||
# Setup loss function based on selection
|
||||
if loss_function == "MSE":
|
||||
criterion = torch.nn.MSELoss()
|
||||
elif loss_function == "L1":
|
||||
criterion = torch.nn.L1Loss()
|
||||
elif loss_function == "Huber":
|
||||
criterion = torch.nn.HuberLoss()
|
||||
elif loss_function == "SmoothL1":
|
||||
criterion = torch.nn.SmoothL1Loss()
|
||||
|
||||
# setup models
|
||||
# Setup gradient checkpointing
|
||||
if gradient_checkpointing:
|
||||
for m in find_all_highest_child_module_with_forward(
|
||||
mp.model.diffusion_model
|
||||
):
|
||||
patch(m)
|
||||
|
||||
# Setup models for training
|
||||
mp.model.requires_grad_(False)
|
||||
torch.cuda.empty_cache()
|
||||
# With force_full_load=False we should be able to offload,
|
||||
# but offloading during training would require custom autograd hooks for fwd/bwd
|
||||
comfy.model_management.load_models_gpu(
|
||||
[mp], memory_required=1e20, force_full_load=True
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Setup sampler and guider like in test script
|
||||
# Setup loss tracking
|
||||
loss_map = {"loss": []}
|
||||
|
||||
def loss_callback(loss):
|
||||
loss_map["loss"].append(loss)
|
||||
|
||||
train_sampler = TrainSampler(
|
||||
criterion,
|
||||
optimizer,
|
||||
loss_callback=loss_callback,
|
||||
batch_size=batch_size,
|
||||
grad_acc=grad_accumulation_steps,
|
||||
total_steps=steps * grad_accumulation_steps,
|
||||
seed=seed,
|
||||
training_dtype=dtype,
|
||||
real_dataset=latents if multi_res else None,
|
||||
)
|
||||
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
|
||||
guider.set_conds(positive) # Set conditioning from input
|
||||
# Create sampler
|
||||
if bucket_mode:
|
||||
train_sampler = TrainSampler(
|
||||
criterion,
|
||||
optimizer,
|
||||
loss_callback=loss_callback,
|
||||
batch_size=batch_size,
|
||||
grad_acc=grad_accumulation_steps,
|
||||
total_steps=steps * grad_accumulation_steps,
|
||||
seed=seed,
|
||||
training_dtype=dtype,
|
||||
bucket_latents=latents,
|
||||
)
|
||||
else:
|
||||
train_sampler = TrainSampler(
|
||||
criterion,
|
||||
optimizer,
|
||||
loss_callback=loss_callback,
|
||||
batch_size=batch_size,
|
||||
grad_acc=grad_accumulation_steps,
|
||||
total_steps=steps * grad_accumulation_steps,
|
||||
seed=seed,
|
||||
training_dtype=dtype,
|
||||
real_dataset=latents if multi_res else None,
|
||||
)
|
||||
|
||||
# Training loop
|
||||
# Setup guider
|
||||
guider = TrainGuider(mp)
|
||||
guider.set_conds(positive)
|
||||
|
||||
# Run training loop
|
||||
try:
|
||||
# Generate dummy sigmas and noise
|
||||
sigmas = torch.tensor(range(num_images))
|
||||
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
|
||||
if multi_res:
|
||||
# use first latent as dummy latent if multi_res
|
||||
latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1)))
|
||||
guider.sample(
|
||||
noise.generate_noise({"samples": latents}),
|
||||
latents,
|
||||
_run_training_loop(
|
||||
guider,
|
||||
train_sampler,
|
||||
sigmas,
|
||||
seed=noise.seed,
|
||||
latents,
|
||||
num_images,
|
||||
seed,
|
||||
bucket_mode,
|
||||
multi_res,
|
||||
)
|
||||
finally:
|
||||
for m in mp.model.modules():
|
||||
unpatch(m)
|
||||
del train_sampler, optimizer
|
||||
|
||||
# Finalize adapters
|
||||
for adapter in all_weight_adapters:
|
||||
adapter.requires_grad_(False)
|
||||
|
||||
@ -645,7 +1054,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
|
||||
|
||||
|
||||
class LoraModelLoader(io.ComfyNode):
|
||||
class LoraModelLoader(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
|
||||
@ -8,10 +8,7 @@ import json
|
||||
from typing import Optional
|
||||
from typing_extensions import override
|
||||
from fractions import Fraction
|
||||
from comfy_api.input import AudioInput, ImageInput, VideoInput
|
||||
from comfy_api.input_impl import VideoFromComponents, VideoFromFile
|
||||
from comfy_api.util import VideoCodec, VideoComponents, VideoContainer
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
from comfy_api.latest import ComfyExtension, io, ui, Input, InputImpl, Types
|
||||
from comfy.cli_args import args
|
||||
|
||||
class SaveWEBM(io.ComfyNode):
|
||||
@ -28,7 +25,6 @@ class SaveWEBM(io.ComfyNode):
|
||||
io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
|
||||
io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."),
|
||||
],
|
||||
outputs=[],
|
||||
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
@ -79,16 +75,15 @@ class SaveVideo(io.ComfyNode):
|
||||
inputs=[
|
||||
io.Video.Input("video", tooltip="The video to save."),
|
||||
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
|
||||
io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
|
||||
io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
|
||||
io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
|
||||
io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
|
||||
],
|
||||
outputs=[],
|
||||
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video: VideoInput, filename_prefix, format: str, codec) -> io.NodeOutput:
|
||||
def execute(cls, video: Input.Video, filename_prefix, format: str, codec) -> io.NodeOutput:
|
||||
width, height = video.get_dimensions()
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
|
||||
filename_prefix,
|
||||
@ -105,10 +100,10 @@ class SaveVideo(io.ComfyNode):
|
||||
metadata["prompt"] = cls.hidden.prompt
|
||||
if len(metadata) > 0:
|
||||
saved_metadata = metadata
|
||||
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
|
||||
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
|
||||
video.save_to(
|
||||
os.path.join(full_output_folder, file),
|
||||
format=VideoContainer(format),
|
||||
format=Types.VideoContainer(format),
|
||||
codec=codec,
|
||||
metadata=saved_metadata
|
||||
)
|
||||
@ -135,9 +130,9 @@ class CreateVideo(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput:
|
||||
def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput:
|
||||
return io.NodeOutput(
|
||||
VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
|
||||
InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
|
||||
)
|
||||
|
||||
class GetVideoComponents(io.ComfyNode):
|
||||
@ -159,11 +154,11 @@ class GetVideoComponents(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video: VideoInput) -> io.NodeOutput:
|
||||
def execute(cls, video: Input.Video) -> io.NodeOutput:
|
||||
components = video.get_components()
|
||||
|
||||
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
|
||||
|
||||
|
||||
class LoadVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -185,7 +180,7 @@ class LoadVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, file) -> io.NodeOutput:
|
||||
video_path = folder_paths.get_annotated_filepath(file)
|
||||
return io.NodeOutput(VideoFromFile(video_path))
|
||||
return io.NodeOutput(InputImpl.VideoFromFile(video_path))
|
||||
|
||||
@classmethod
|
||||
def fingerprint_inputs(s, file):
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.33.10
|
||||
comfyui-workflow-templates==0.7.25
|
||||
comfyui-workflow-templates==0.7.51
|
||||
comfyui-embedded-docs==0.3.1
|
||||
torch
|
||||
torchsde
|
||||
|
||||