Compare commits

...

22 Commits

Author SHA1 Message Date
Maifee Ul Asad
1b9b33166d
Merge 7ec165673553b968d71942c50c76fe34c95bbaee into fd271dedfde6e192a1f1a025521070876e89e04a 2025-12-08 15:59:15 +06:00
Alexander Piskun
fd271dedfd
[API Nodes] add support for seedance-1-0-pro-fast model (#10947)
* feat(api-nodes): add support for seedance-1-0-pro-fast model

* feat(api-nodes): add support for seedream-4.5 model
2025-12-08 01:33:46 -08:00
Alexander Piskun
c3c6313fc7
Added "system_prompt" input to Gemini nodes (#11177) 2025-12-08 01:28:17 -08:00
Alexander Piskun
85c4b4ae26
chore: replace imports of deprecated V1 classes (#11127) 2025-12-08 01:27:02 -08:00
ComfyUI Wiki
058f084371
Update workflow templates to v0.7.51 (#11150)
* chore: update workflow templates to v0.7.50

* Update template to 0.7.51
2025-12-08 01:22:51 -08:00
Alexander Piskun
ec7f65187d
chore(comfy_api): replace absolute imports with relative (#11145) 2025-12-08 01:21:41 -08:00
Maifee Ul Asad
7ec1656735
Merge branch 'comfyanonymous:master' into offloader-maifee 2025-12-07 15:32:22 +06:00
Maifee Ul Asad
cee75f301a
Merge branch 'comfyanonymous:master' into offloader-maifee 2025-11-27 08:47:41 +06:00
Maifee Ul Asad
1a59686ca8
Merge branch 'comfyanonymous:master' into offloader-maifee 2025-11-25 22:09:53 +06:00
Maifee Ul Asad
6d96d26795
Merge branch 'comfyanonymous:master' into offloader-maifee 2025-11-25 22:08:51 +06:00
Maifee Ul Asad
e07a32c9b8
Merge branch 'master' into offloader-maifee 2025-11-09 17:25:46 +06:00
Maifee Ul Asad
a19f0a88e4 refactor(gds): add show_stats option; moved GDS initialization to dedicated file; 2025-10-12 00:54:04 +06:00
Maifee Ul Asad
64811809a0 docs: updated GDS environment limitations;
- works only on linux and nvidia
2025-10-11 21:51:24 +06:00
Maifee Ul Asad
529109083a docs: added GDS docs and deps documentation; 2025-10-10 22:23:13 +06:00
Maifee Ul Asad
a7be9f6fc3 review: remove GDS-related dependencies from requirements.txt 2025-10-10 22:08:58 +06:00
Maifee Ul Asad
6075c44ec8 feat(gds): add GDS-related dependencies to requirements.txt 2025-10-08 14:43:55 +06:00
Maifee Ul Asad
154b73835a feat(gds): implement GPUDirect Storage initialization based on CLI arguments 2025-10-08 14:40:59 +06:00
Maifee Ul Asad
862e7784f4 feat(gds): add nodes_gds.py to built-in extra nodes initialization 2025-10-08 14:40:41 +06:00
Maifee Ul Asad
f6b6636bf3 feat(gds): implement GDS loading fallback in load_torch_file function;
- need to work with tensorflow and other formats
 - afaik, almost all models shared now is in torch format
 - converting types should not be that big of a deal
2025-10-08 14:40:29 +06:00
Maifee Ul Asad
83b00df3f0 feat(gds): add GPUDirect Storage options for SSD-to-GPU model loading; 2025-10-08 14:39:22 +06:00
Maifee Ul Asad
5f24eb699c feat(gds): implement GPUDirect Storage for efficient model loading 2025-10-08 14:38:49 +06:00
Maifee Ul Asad
fab0954077 feat(gds): add GPUDirect Storage support for model loading and prefetching
- limited to NVIDIA GPUs only
2025-10-08 14:38:38 +06:00
24 changed files with 1162 additions and 296 deletions

View File

@ -399,6 +399,14 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w
> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
<br/><br/>If you use a container, note that the volume mount `-v` can be a relative path so `... -v ".\:/openssl-certs" ...` would create the key & cert files in the current directory of your command prompt or powershell terminal.
## How to run heavy workflows on a mid-range GPU (NVIDIA, Linux only)?
Use the `--enable-gds` flag to activate NVIDIA [GPUDirect Storage](https://docs.nvidia.com/gpudirect-storage/) (GDS), which allows data to be transferred directly between SSDs and GPUs. This eliminates traditional CPU-mediated data paths, significantly reducing I/O latency and CPU overhead. System RAM is still used, alongside the SSD, as a cache to further improve performance.
This feature has been tested only with NVIDIA GPUs on Linux-based systems.
Requires: `cupy-cuda12x>=12.0.0`, `pynvml>=11.4.1`, `cudf>=23.0.0`, `numba>=0.57.0`, `nvidia-ml-py>=12.0.0`.
## Support and dev channel
[Discord](https://comfy.org/discord): Try the #help or #feedback channels.

View File

@ -147,6 +147,17 @@ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
# GPUDirect Storage (GDS) arguments.
# All GDS flags are registered on a dedicated "gds" argument group so that
# --help lists them together under their own heading.
# NOTE(review): --enable-gds and --disable-gds are independent store_true flags;
# nothing here prevents both from being passed at once — confirm how the
# consumer of `args` resolves that case.
gds_group = parser.add_argument_group('gds', 'GPUDirect Storage options for direct SSD-to-GPU model loading')
gds_group.add_argument("--enable-gds", action="store_true", help="Enable GPUDirect Storage for direct SSD-to-GPU model loading (requires CUDA 11.4+, cuFile).")
gds_group.add_argument("--disable-gds", action="store_true", help="Explicitly disable GPUDirect Storage.")
# Size-related options are expressed in megabytes.
gds_group.add_argument("--gds-min-file-size", type=int, default=100, help="Minimum file size in MB to use GDS (default: 100MB).")
gds_group.add_argument("--gds-chunk-size", type=int, default=64, help="GDS transfer chunk size in MB (default: 64MB).")
gds_group.add_argument("--gds-streams", type=int, default=4, help="Number of CUDA streams for GDS operations (default: 4).")
gds_group.add_argument("--gds-prefetch", action="store_true", help="Enable GDS prefetching for better performance.")
gds_group.add_argument("--gds-no-fallback", action="store_true", help="Disable fallback to CPU loading if GDS fails.")
gds_group.add_argument("--gds-stats", action="store_true", help="Print GDS statistics on exit.")
class PerformanceFeature(enum.Enum):
Fp16Accumulation = "fp16_accumulation"
Fp8MatrixMultiplication = "fp8_matrix_mult"

494
comfy/gds_loader.py Normal file
View File

@ -0,0 +1,494 @@
# copyright 2025 Maifee Ul Asad @ github.com/maifeeulasad
# copyright under GNU GENERAL PUBLIC LICENSE, Version 3, 29 June 2007
"""
GPUDirect Storage (GDS) integration for ComfyUI.

Provides SSD-to-GPU model loading that keeps system RAM and the CPU off the
main data path as far as possible. Some CPU/RAM usage remains, mostly for
safetensors header parsing and small staging buffers.

This module provides GPUDirect Storage functionality to load models from
NVMe SSDs into GPU memory.
"""
import os
import logging
import torch
import time
from typing import Optional, Dict, Any, Union
from pathlib import Path
import safetensors
import gc
import mmap
from dataclasses import dataclass
try:
import cupy
import cupy.cuda.runtime as cuda_runtime
CUPY_AVAILABLE = True
except ImportError:
CUPY_AVAILABLE = False
logging.warning("CuPy not available. GDS will use fallback mode.")
try:
import cudf # RAPIDS for GPU dataframes
RAPIDS_AVAILABLE = True
except ImportError:
RAPIDS_AVAILABLE = False
# Probe for NVML (GPU monitoring). This must never prevent the module from
# importing: pynvml may be missing entirely, or it may be installed on a
# machine where nvmlInit() fails (e.g. no NVIDIA driver / no GPU), in which
# case it raises an NVML error rather than ImportError.
try:
    import pynvml
    pynvml.nvmlInit()
    NVML_AVAILABLE = True
except ImportError:
    NVML_AVAILABLE = False
    logging.warning("NVIDIA-ML-Py not available. GPU monitoring disabled.")
except Exception as e:
    NVML_AVAILABLE = False
    logging.warning(f"NVML initialization failed: {e}. GPU monitoring disabled.")
@dataclass
class GDSConfig:
    """Tunable settings for the GPUDirect Storage loader."""

    # Master switch; when False the loader never takes the GDS path.
    enabled: bool = True
    # Files smaller than this many megabytes are loaded with the standard path.
    min_file_size_mb: int = 100
    # Transfer chunk size, in megabytes.
    chunk_size_mb: int = 64
    # Pre-allocate pinned host buffers when GDS is initialised.
    use_pinned_memory: bool = True
    # Allow prefetch_model() to do its work.
    prefetch_enabled: bool = True
    compression_aware: bool = True
    # Number of CUDA streams created when GDS is initialised.
    max_concurrent_streams: int = 4
    # On a GDS failure, retry with the standard loader instead of raising.
    fallback_to_cpu: bool = True
    # Print loader statistics when the process exits.
    show_stats: bool = False
class GDSError(Exception):
    """Raised when a GPUDirect Storage operation fails."""
class GPUDirectStorage:
    """
    GPUDirect Storage loader for ComfyUI model files.

    For each file the loader decides between a cuFile-based path and a
    standard fallback path, and keeps running counters describing which path
    was taken and how long the GDS path took.
    """

    # safetensors dtype tag -> torch dtype. Tags missing from this table are
    # rejected rather than guessed: reinterpreting unknown data as float32
    # would silently produce corrupt tensors.
    _SAFETENSORS_DTYPES = {
        'F64': torch.float64,
        'F32': torch.float32,
        'F16': torch.float16,
        'BF16': torch.bfloat16,
        'I8': torch.int8,
        'I16': torch.int16,
        'I32': torch.int32,
        'I64': torch.int64,
        'U8': torch.uint8,
        'BOOL': torch.bool,
    }

    def __init__(self, config: Optional["GDSConfig"] = None):
        """
        Args:
            config: Loader settings; a default GDSConfig is used when omitted.
        """
        self.config = config or GDSConfig()
        # Index of the target CUDA device, or None when CUDA is unavailable.
        self.device = torch.cuda.current_device() if torch.cuda.is_available() else None
        self.cuda_streams = []
        self.pinned_buffers = {}
        self.stats = {
            'gds_loads': 0,
            'fallback_loads': 0,
            'total_bytes_gds': 0,
            'total_time_gds': 0.0,
            'avg_bandwidth_gbps': 0.0
        }

        # Initialize GDS if available
        self._gds_available = self._check_gds_availability()
        if self._gds_available:
            self._init_gds()
        else:
            logging.warning("GDS not available, using fallback methods")

    def _check_gds_availability(self) -> bool:
        """Return True when CUDA, CuPy and cuFile all appear usable."""
        if not torch.cuda.is_available():
            return False

        if not CUPY_AVAILABLE:
            return False

        # Check for GPUDirect Storage support
        try:
            # Check CUDA version (GDS requires CUDA 11.4+)
            cuda_version = torch.version.cuda
            if cuda_version:
                major, minor = map(int, cuda_version.split('.')[:2])
                if major < 11 or (major == 11 and minor < 4):
                    logging.warning(f"CUDA {cuda_version} detected. GDS requires CUDA 11.4+")
                    return False

            # Check if cuFile is available (part of CUDA toolkit)
            # NOTE(review): assumes the installed CuPy exposes
            # `cupy.cuda.cufile` with an `initialize()` function - confirm
            # against the targeted CuPy release.
            try:
                import cupy.cuda.cufile as cufile
                # Try to initialize cuFile
                cufile.initialize()
                return True
            except (ImportError, RuntimeError) as e:
                logging.warning(f"cuFile not available: {e}")
                return False

        except Exception as e:
            # Anything unexpected (e.g. an unparsable version string) means
            # "not available" rather than a crash at construction time.
            logging.warning(f"GDS availability check failed: {e}")
            return False

    def _init_gds(self):
        """Create the CUDA streams and optional pinned buffers used by GDS."""
        try:
            # Create CUDA streams for async operations
            for _ in range(self.config.max_concurrent_streams):
                self.cuda_streams.append(torch.cuda.Stream())

            # Pre-allocate pinned memory buffers
            if self.config.use_pinned_memory:
                self._allocate_pinned_buffers()

            logging.info(f"GDS initialized with {len(self.cuda_streams)} streams")

        except Exception as e:
            logging.error(f"Failed to initialize GDS: {e}")
            # Any initialisation failure disables the GDS path entirely.
            self._gds_available = False

    def _allocate_pinned_buffers(self):
        """Pre-allocate pinned host buffers of 16-256 MB, keyed by size in MB."""
        if not CUPY_AVAILABLE:
            return
        try:
            for size_mb in (16, 32, 64, 128, 256):
                size_bytes = size_mb * 1024 * 1024
                self.pinned_buffers[size_mb] = cupy.cuda.alloc_pinned_memory(size_bytes)
        except Exception as e:
            # Best effort: buffers allocated before the failure are kept.
            logging.warning(f"Failed to allocate pinned buffers: {e}")

    def _get_file_size(self, file_path: str) -> int:
        """Get file size in bytes"""
        return os.path.getsize(file_path)

    def _should_use_gds(self, file_path: str) -> bool:
        """Return True when GDS is usable, enabled, and the file is large enough."""
        if not self._gds_available or not self.config.enabled:
            return False

        file_size_mb = self._get_file_size(file_path) / (1024 * 1024)
        return file_size_mb >= self.config.min_file_size_mb

    def _load_with_gds(self, file_path: str) -> Dict[str, torch.Tensor]:
        """
        Load a model through the GDS path, dispatching on the file extension.

        On failure the standard loader is used when `config.fallback_to_cpu`
        is set; otherwise a GDSError chained to the original error is raised.
        The elapsed time is always added to `stats['total_time_gds']`,
        including time spent in the fallback.
        """
        start_time = time.time()

        try:
            if file_path.lower().endswith(('.safetensors', '.sft')):
                return self._load_safetensors_gds(file_path)
            else:
                return self._load_pytorch_gds(file_path)

        except Exception as e:
            logging.error(f"GDS loading failed for {file_path}: {e}")
            if self.config.fallback_to_cpu:
                logging.info("Falling back to CPU loading")
                self.stats['fallback_loads'] += 1
                return self._load_fallback(file_path)
            else:
                raise GDSError(f"GDS loading failed: {e}") from e

        finally:
            load_time = time.time() - start_time
            self.stats['total_time_gds'] += load_time

    def _load_safetensors_gds(self, file_path: str) -> Dict[str, torch.Tensor]:
        """
        Load a safetensors file, allocating every tensor directly on the GPU.

        The header (8-byte little-endian length followed by JSON) is parsed
        on the CPU; tensor payloads are then copied into pre-allocated CUDA
        tensors.

        Raises:
            GDSError: if a tensor uses a dtype tag this loader does not know.
            Exception: any other failure is logged and re-raised.
        """
        import json

        try:
            # NOTE(review): assumes `cupy.cuda.cufile` provides
            # `CuFileManager` and `copy_from_file` with these signatures -
            # confirm against the targeted CuPy release.
            import cupy.cuda.cufile as cufile

            # Open file with cuFile for direct GPU loading
            with cufile.CuFileManager():
                # Memory-map the file for efficient access
                with open(file_path, 'rb') as f:
                    with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mmapped_file:

                        # Parse safetensors header
                        header_size = int.from_bytes(mmapped_file[:8], 'little')
                        header_bytes = mmapped_file[8:8 + header_size]
                        header = json.loads(header_bytes.decode('utf-8'))

                        # Tensor offsets in the header are relative to the
                        # first byte after the header.
                        data_offset = 8 + header_size
                        tensors = {}

                        for name, info in header.items():
                            if name == "__metadata__":
                                continue

                            dtype = self._SAFETENSORS_DTYPES.get(info['dtype'])
                            if dtype is None:
                                raise GDSError(
                                    f"Unsupported safetensors dtype {info['dtype']!r} for tensor {name!r}"
                                )

                            shape = info['shape']
                            start_offset = data_offset + info['data_offsets'][0]
                            end_offset = data_offset + info['data_offsets'][1]

                            # Direct GPU allocation
                            tensor = torch.empty(shape, dtype=dtype, device=f'cuda:{self.device}')

                            # Use cuFile for direct transfer
                            tensor_bytes = end_offset - start_offset

                            # Direct file-to-GPU transfer
                            cufile.copy_from_file(
                                tensor.data_ptr(),
                                mmapped_file[start_offset:end_offset],
                                tensor_bytes
                            )

                            tensors[name] = tensor

            self.stats['gds_loads'] += 1
            self.stats['total_bytes_gds'] += self._get_file_size(file_path)

            return tensors

        except Exception as e:
            logging.error(f"GDS safetensors loading failed: {e}")
            raise

    def _load_pytorch_gds(self, file_path: str) -> Dict[str, torch.Tensor]:
        """
        Load a PyTorch checkpoint onto the target GPU.

        torch.load has to deserialise the file itself, so the data is not
        streamed through cuFile here. The load is still counted in the GDS
        statistics, as before.

        The file is handed to torch.load untouched: an earlier revision read
        the file into a staging buffer first, which left the handle
        positioned past the data and made torch.load fail.
        """
        try:
            file_size = self._get_file_size(file_path)

            state_dict = torch.load(
                file_path,
                map_location=f'cuda:{self.device}',
                weights_only=True
            )

            self.stats['gds_loads'] += 1
            self.stats['total_bytes_gds'] += file_size

            return state_dict

        except Exception as e:
            logging.error(f"GDS PyTorch loading failed: {e}")
            raise

    def _load_fallback(self, file_path: str) -> Dict[str, torch.Tensor]:
        """Load with the stock safetensors / torch loaders onto the target GPU."""
        if file_path.lower().endswith(('.safetensors', '.sft')):
            # Use safetensors with device parameter
            with safetensors.safe_open(file_path, framework="pt", device=f'cuda:{self.device}') as f:
                return {k: f.get_tensor(k) for k in f.keys()}
        else:
            # Standard PyTorch loading
            return torch.load(file_path, map_location=f'cuda:{self.device}', weights_only=True)

    def load_model(self, file_path: str, device: Optional[torch.device] = None) -> Dict[str, torch.Tensor]:
        """
        Main entry point for loading models with GDS

        Args:
            file_path: Path to the model file
            device: Target device (if None, uses current CUDA device)

        Returns:
            Dictionary of tensors loaded directly to GPU
        """
        if device is not None and device.type == 'cuda':
            # An index-less "cuda" device means "the current device", which
            # is not necessarily GPU 0.
            if device.index is not None:
                self.device = device.index
            else:
                self.device = torch.cuda.current_device()

        if self._should_use_gds(file_path):
            logging.info(f"Loading {file_path} with GDS")
            return self._load_with_gds(file_path)
        else:
            logging.info(f"Loading {file_path} with standard method")
            self.stats['fallback_loads'] += 1
            return self._load_fallback(file_path)

    def prefetch_model(self, file_path: str) -> bool:
        """
        Warm storage caches for a model file by reading its first and last MB.

        Args:
            file_path: Path to the model file

        Returns:
            True if the reads succeeded; False when prefetching is disabled,
            GDS is unavailable, or the reads failed.
        """
        if not self.config.prefetch_enabled or not self._gds_available:
            return False

        try:
            file_size = self._get_file_size(file_path)
            logging.info(f"Prefetching {file_path} ({file_size // (1024*1024)} MB)")

            with open(file_path, 'rb') as f:
                f.read(1024 * 1024)  # First 1MB
                # Seek relative to the end; clamp so small files do not seek
                # before the start.
                f.seek(-min(1024 * 1024, file_size), 2)  # Last 1MB
                f.read()

            return True

        except Exception as e:
            logging.warning(f"Prefetch failed for {file_path}: {e}")
            return False

    def get_stats(self) -> Dict[str, Any]:
        """
        Return the loader counters plus derived figures.

        `avg_bandwidth_gbps` is refreshed (in GiB per second of accumulated
        GDS time) whenever both the byte and time totals are non-zero.
        """
        total_loads = self.stats['gds_loads'] + self.stats['fallback_loads']

        if self.stats['total_time_gds'] > 0 and self.stats['total_bytes_gds'] > 0:
            bandwidth_gbps = (self.stats['total_bytes_gds'] / (1024**3)) / self.stats['total_time_gds']
            self.stats['avg_bandwidth_gbps'] = bandwidth_gbps

        return {
            **self.stats,
            'total_loads': total_loads,
            # max(1, ...) avoids dividing by zero before the first load.
            'gds_usage_percent': (self.stats['gds_loads'] / max(1, total_loads)) * 100,
            'gds_available': self._gds_available,
            'config': self.config.__dict__
        }

    def cleanup(self):
        """
        Release streams and pinned buffers held by this instance.

        Safe to call repeatedly and on a partially constructed instance.
        """
        try:
            # Wait for outstanding work, then drop the streams.
            streams = getattr(self, 'cuda_streams', None)
            if streams:
                for stream in streams:
                    stream.synchronize()
                streams.clear()

            # CuPy has no explicit "free pinned memory" call; dropping the
            # last reference to each buffer is what releases it.
            buffers = getattr(self, 'pinned_buffers', None)
            if buffers:
                buffers.clear()

            # Force garbage collection
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        except Exception as e:
            logging.warning(f"GDS cleanup failed: {e}")

    def __del__(self):
        """Destructor to ensure cleanup"""
        try:
            self.cleanup()
        except Exception:
            # During interpreter shutdown module globals may already be gone;
            # a destructor has no caller to report to.
            pass
# Process-wide loader, created on first use.
_gds_instance: Optional[GPUDirectStorage] = None


def get_gds_instance(config: Optional[GDSConfig] = None) -> GPUDirectStorage:
    """
    Return the shared GPUDirectStorage, building it on the first call.

    `config` is only consulted when the instance does not exist yet.
    """
    global _gds_instance
    if _gds_instance is not None:
        return _gds_instance
    _gds_instance = GPUDirectStorage(config)
    return _gds_instance
def load_torch_file_gds(ckpt: str, safe_load: bool = False, device: Optional[torch.device] = None) -> Dict[str, torch.Tensor]:
    """
    GDS-enabled replacement for comfy.utils.load_torch_file

    Args:
        ckpt: Path to checkpoint file
        safe_load: Whether to use safe loading (only forwarded to the
            standard loader on fallback)
        device: Target device

    Returns:
        Dictionary of loaded tensors

    Raises:
        Exception: when GDS loading fails for a CUDA target. The error is
            re-raised instead of falling back here, because
            comfy.utils.load_torch_file sends CUDA loads back into this
            function and would recurse without bound; that caller catches
            the error and continues with its own standard path.
    """
    gds = get_gds_instance()

    try:
        # Load with GDS
        return gds.load_model(ckpt, device)

    except Exception as e:
        if device is not None and device.type == 'cuda':
            logging.error(f"GDS loading failed: {e}")
            raise

        logging.error(f"GDS loading failed, falling back to standard method: {e}")
        # Non-CUDA targets are not routed back here, so the standard loader
        # can be called safely.
        import comfy.utils
        return comfy.utils.load_torch_file(ckpt, safe_load=safe_load, device=device)
def prefetch_model_gds(file_path: str) -> bool:
    """Ask the shared GDS loader to prefetch `file_path`; True on success."""
    return get_gds_instance().prefetch_model(file_path)
def get_gds_stats() -> Dict[str, Any]:
    """Return the statistics dictionary of the shared GDS loader."""
    return get_gds_instance().get_stats()
def configure_gds(config: GDSConfig):
    """Replace the shared GDS loader with a new one built from `config`."""
    global _gds_instance
    replacement = GPUDirectStorage(config)
    _gds_instance = replacement
def init_gds(config: GDSConfig):
    """
    Set up the shared GDS loader and, if requested, an exit-time stats report.

    Failures are logged and swallowed so that a broken GDS setup never
    prevents start-up.

    Args:
        config: GDSConfig object with initialization parameters
    """
    try:
        configure_gds(config)
        logging.info(f"GDS initialized: enabled={config.enabled}, min_size={config.min_file_size_mb}MB, streams={config.max_concurrent_streams}")

        # Nothing more to do unless a stats report was asked for.
        if not getattr(config, 'show_stats', False):
            return

        import atexit

        def print_gds_stats():
            stats = get_gds_stats()
            report = (
                "=== GDS Statistics ===",
                f"Total loads: {stats['total_loads']}",
                f"GDS loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)",
                f"Fallback loads: {stats['fallback_loads']}",
                f"Total bytes via GDS: {stats['total_bytes_gds'] / (1024**3):.2f} GB",
                f"Average bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s",
                "===================",
            )
            for line in report:
                logging.info(line)

        atexit.register(print_gds_stats)

    except ImportError as e:
        logging.warning(f"GDS initialization failed - missing dependencies: {e}")
    except Exception as e:
        logging.error(f"GDS initialization failed: {e}")

View File

@ -56,6 +56,18 @@ else:
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
# Try GDS loading first if available and device is GPU
if device is not None and device.type == 'cuda':
try:
from . import gds_loader
gds_result = gds_loader.load_torch_file_gds(ckpt, safe_load=safe_load, device=device)
if return_metadata:
# For GDS, we return empty metadata for now (can be enhanced)
return (gds_result, {})
return gds_result
except Exception as e:
logging.debug(f"GDS loading failed, using fallback: {e}")
if device is None:
device = torch.device("cpu")
metadata = None

View File

@ -5,9 +5,9 @@ from typing import Type, TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from ._input_impl import VideoFromFile, VideoFromComponents
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from . import _io_public as io
from . import _ui_public as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
@ -80,7 +80,7 @@ class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
Called when an extension is loaded.
This should be used to initialize any global resources neeeded by the extension.
This should be used to initialize any global resources needed by the extension.
"""
@abstractmethod

View File

@ -4,7 +4,7 @@ from fractions import Fraction
from typing import Optional, Union, IO
import io
import av
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
from .._util import VideoContainer, VideoCodec, VideoComponents
class VideoInput(ABC):
"""

View File

@ -3,14 +3,14 @@ from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
from comfy_api.latest._input import AudioInput, VideoInput
from .._input import AudioInput, VideoInput
import av
import io
import json
import numpy as np
import math
import torch
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
from .._util import VideoContainer, VideoCodec, VideoComponents
def container_to_output_format(container_format: str | None) -> str | None:

View File

@ -26,7 +26,7 @@ if TYPE_CHECKING:
from comfy_api.input import VideoInput
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
from comfy_api.latest._resources import Resources, ResourcesLocal
from ._resources import Resources, ResourcesLocal
from comfy_execution.graph_utils import ExecutionBlocker
from ._util import MESH, VOXEL

View File

@ -22,7 +22,7 @@ import folder_paths
# used for image preview
from comfy.cli_args import args
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
from ._io import ComfyNode, FolderType, Image, _UIOutput
class SavedResult(dict):

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass
from enum import Enum
from fractions import Fraction
from typing import Optional
from comfy_api.latest._input import ImageInput, AudioInput
from .._input import ImageInput, AudioInput
class VideoCodec(str, Enum):
AUTO = "auto"

View File

@ -0,0 +1,144 @@
from typing import Literal
from pydantic import BaseModel, Field
class Text2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
response_format: str | None = Field("url")
size: str | None = Field(None)
seed: int | None = Field(0, ge=0, le=2147483647)
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
watermark: bool | None = Field(True)
class Image2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
response_format: str | None = Field("url")
image: str = Field(..., description="Base64 encoded string or image URL")
size: str | None = Field("adaptive")
seed: int | None = Field(..., ge=0, le=2147483647)
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
watermark: bool | None = Field(True)
class Seedream4Options(BaseModel):
max_images: int = Field(15)
class Seedream4TaskCreationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
response_format: str = Field("url")
image: list[str] | None = Field(None, description="Image URLs")
size: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
sequential_image_generation: str = Field("disabled")
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
watermark: bool = Field(True)
class ImageTaskCreationResponse(BaseModel):
model: str = Field(...)
created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
data: list = Field([], description="Contains information about the generated image(s).")
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
class TaskTextContent(BaseModel):
type: str = Field("text")
text: str = Field(...)
class TaskImageContentUrl(BaseModel):
url: str = Field(...)
class TaskImageContent(BaseModel):
type: str = Field("image_url")
image_url: TaskImageContentUrl = Field(...)
role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
class Text2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
content: list[TaskTextContent] = Field(..., min_length=1)
class Image2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2)
class TaskCreationResponse(BaseModel):
id: str = Field(...)
class TaskStatusError(BaseModel):
code: str = Field(...)
message: str = Field(...)
class TaskStatusResult(BaseModel):
video_url: str = Field(...)
class TaskStatusResponse(BaseModel):
id: str = Field(...)
model: str = Field(...)
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
error: TaskStatusError | None = Field(None)
content: TaskStatusResult | None = Field(None)
RECOMMENDED_PRESETS = [
("1024x1024 (1:1)", 1024, 1024),
("864x1152 (3:4)", 864, 1152),
("1152x864 (4:3)", 1152, 864),
("1280x720 (16:9)", 1280, 720),
("720x1280 (9:16)", 720, 1280),
("832x1248 (2:3)", 832, 1248),
("1248x832 (3:2)", 1248, 832),
("1512x648 (21:9)", 1512, 648),
("2048x2048 (1:1)", 2048, 2048),
("Custom", None, None),
]
RECOMMENDED_PRESETS_SEEDREAM_4 = [
("2048x2048 (1:1)", 2048, 2048),
("2304x1728 (4:3)", 2304, 1728),
("1728x2304 (3:4)", 1728, 2304),
("2560x1440 (16:9)", 2560, 1440),
("1440x2560 (9:16)", 1440, 2560),
("2496x1664 (3:2)", 2496, 1664),
("1664x2496 (2:3)", 1664, 2496),
("3024x1296 (21:9)", 3024, 1296),
("4096x4096 (1:1)", 4096, 4096),
("Custom", None, None),
]
# Expected execution times for video tasks, keyed by model id and then by
# output resolution. The values are given for a 10-second video duration.
# NOTE(review): the unit is presumably seconds — confirm against the caller.
VIDEO_TASKS_EXECUTION_TIME = {
"seedance-1-0-lite-t2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-lite-i2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-pro-250528": {
"480p": 70,
"720p": 85,
"1080p": 115,
},
"seedance-1-0-pro-fast-251015": {
"480p": 50,
"720p": 65,
"1080p": 100,
},
}

View File

@ -84,15 +84,7 @@ class GeminiSystemInstructionContent(BaseModel):
description="A list of ordered parts that make up a single message. "
"Different parts may have different IANA MIME types.",
)
role: GeminiRole = Field(
...,
description="The identity of the entity that creates the message. "
"The following values are supported: "
"user: This indicates that the message is sent by a real person, typically a user-generated message. "
"model: This indicates that the message is generated by the model. "
"The model value is used to insert messages from model into the conversation during multi-turn conversations. "
"For non-multi-turn conversations, this field can be left blank or unset.",
)
role: GeminiRole | None = Field(..., description="The role field of systemInstruction may be ignored.")
class GeminiFunctionDeclaration(BaseModel):

View File

@ -85,7 +85,7 @@ class Response1(BaseModel):
raiMediaFilteredReasons: Optional[list[str]] = Field(
None, description='Reasons why media was filtered by responsible AI policies'
)
videos: Optional[list[Video]] = None
videos: Optional[list[Video]] = Field(None)
class VeoGenVidPollResponse(BaseModel):

View File

@ -1,13 +1,27 @@
import logging
import math
from enum import Enum
from typing import Literal, Optional, Union
import torch
from pydantic import BaseModel, Field
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bytedance_api import (
RECOMMENDED_PRESETS,
RECOMMENDED_PRESETS_SEEDREAM_4,
VIDEO_TASKS_EXECUTION_TIME,
Image2ImageTaskCreationRequest,
Image2VideoTaskCreationRequest,
ImageTaskCreationResponse,
Seedream4Options,
Seedream4TaskCreationRequest,
TaskCreationResponse,
TaskImageContent,
TaskImageContentUrl,
TaskStatusResponse,
TaskTextContent,
Text2ImageTaskCreationRequest,
Text2VideoTaskCreationRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_image_tensor,
@ -29,162 +43,6 @@ BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
class Text2ImageModelName(str, Enum):
seedream_3 = "seedream-3-0-t2i-250415"
class Image2ImageModelName(str, Enum):
seededit_3 = "seededit-3-0-i2i-250628"
class Text2VideoModelName(str, Enum):
seedance_1_pro = "seedance-1-0-pro-250528"
seedance_1_lite = "seedance-1-0-lite-t2v-250428"
class Image2VideoModelName(str, Enum):
"""note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757"""
seedance_1_pro = "seedance-1-0-pro-250528"
seedance_1_lite = "seedance-1-0-lite-i2v-250428"
class Text2ImageTaskCreationRequest(BaseModel):
model: Text2ImageModelName = Text2ImageModelName.seedream_3
prompt: str = Field(...)
response_format: Optional[str] = Field("url")
size: Optional[str] = Field(None)
seed: Optional[int] = Field(0, ge=0, le=2147483647)
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
watermark: Optional[bool] = Field(True)
class Image2ImageTaskCreationRequest(BaseModel):
model: Image2ImageModelName = Image2ImageModelName.seededit_3
prompt: str = Field(...)
response_format: Optional[str] = Field("url")
image: str = Field(..., description="Base64 encoded string or image URL")
size: Optional[str] = Field("adaptive")
seed: Optional[int] = Field(..., ge=0, le=2147483647)
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
watermark: Optional[bool] = Field(True)
class Seedream4Options(BaseModel):
max_images: int = Field(15)
class Seedream4TaskCreationRequest(BaseModel):
model: str = Field("seedream-4-0-250828")
prompt: str = Field(...)
response_format: str = Field("url")
image: Optional[list[str]] = Field(None, description="Image URLs")
size: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
sequential_image_generation: str = Field("disabled")
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
watermark: bool = Field(True)
class ImageTaskCreationResponse(BaseModel):
model: str = Field(...)
created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
data: list = Field([], description="Contains information about the generated image(s).")
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
class TaskTextContent(BaseModel):
type: str = Field("text")
text: str = Field(...)
class TaskImageContentUrl(BaseModel):
url: str = Field(...)
class TaskImageContent(BaseModel):
type: str = Field("image_url")
image_url: TaskImageContentUrl = Field(...)
role: Optional[Literal["first_frame", "last_frame", "reference_image"]] = Field(None)
class Text2VideoTaskCreationRequest(BaseModel):
model: Text2VideoModelName = Text2VideoModelName.seedance_1_pro
content: list[TaskTextContent] = Field(..., min_length=1)
class Image2VideoTaskCreationRequest(BaseModel):
    """Request payload for creating an image-to-video task."""

    model: Image2VideoModelName = Image2VideoModelName.seedance_1_pro
    # Mixed list of text and image items; at least two entries are required.
    content: list[Union[TaskTextContent, TaskImageContent]] = Field(..., min_length=2)
class TaskCreationResponse(BaseModel):
    """Response to a task creation request; carries the new task's `id`."""

    id: str
class TaskStatusError(BaseModel):
    """Error details (`code` and `message`) attached to a task status."""

    code: str
    message: str
class TaskStatusResult(BaseModel):
    """Result payload of a task status; holds the generated `video_url`."""

    video_url: str
class TaskStatusResponse(BaseModel):
    """Status of a previously created task.

    `error` and `content` are both optional and default to None.
    """

    id: str
    model: str
    status: Literal["queued", "running", "cancelled", "succeeded", "failed"]
    error: Optional[TaskStatusError] = None
    content: Optional[TaskStatusResult] = None
# Image size presets as (label, width, height) tuples.
# The label embeds the pixel size and aspect ratio; the trailing "Custom"
# entry carries no dimensions (None, None).
# NOTE(review): presumably "Custom" means the caller supplies width/height
# itself — confirm against the nodes that consume this list.
RECOMMENDED_PRESETS = [
    ("1024x1024 (1:1)", 1024, 1024),
    ("864x1152 (3:4)", 864, 1152),
    ("1152x864 (4:3)", 1152, 864),
    ("1280x720 (16:9)", 1280, 720),
    ("720x1280 (9:16)", 720, 1280),
    ("832x1248 (2:3)", 832, 1248),
    ("1248x832 (3:2)", 1248, 832),
    ("1512x648 (21:9)", 1512, 648),
    ("2048x2048 (1:1)", 2048, 2048),
    ("Custom", None, None),
]
# Image size presets for the Seedream 4 node as (label, width, height) tuples.
# The first entry's label is the default `size_preset` of the Seedream node's
# execute(); the trailing "Custom" entry carries no dimensions (None, None).
RECOMMENDED_PRESETS_SEEDREAM_4 = [
    ("2048x2048 (1:1)", 2048, 2048),
    ("2304x1728 (4:3)", 2304, 1728),
    ("1728x2304 (3:4)", 1728, 2304),
    ("2560x1440 (16:9)", 2560, 1440),
    ("1440x2560 (9:16)", 1440, 2560),
    ("2496x1664 (3:2)", 2496, 1664),
    ("1664x2496 (2:3)", 1664, 2496),
    ("3024x1296 (21:9)", 3024, 1296),
    ("4096x4096 (1:1)", 4096, 4096),
    ("Custom", None, None),
]
# The times in this dictionary are given for a 10-second video duration.
# Outer key: model name; inner key: output resolution.
# NOTE(review): values are presumably seconds — confirm against the caller
# that turns them into an estimated duration.
VIDEO_TASKS_EXECUTION_TIME = {
    "seedance-1-0-lite-t2v-250428": {
        "480p": 40,
        "720p": 60,
        "1080p": 90,
    },
    "seedance-1-0-lite-i2v-250428": {
        "480p": 40,
        "720p": 60,
        "1080p": 90,
    },
    "seedance-1-0-pro-250528": {
        "480p": 70,
        "720p": 85,
        "1080p": 115,
    },
}
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
if response.error:
error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
@ -194,13 +52,6 @@ def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
return response.data[0]["url"]
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
    """Extract the video URL from a task status response.

    Returns:
        `response.content.video_url` when the response has a truthy
        `content` attribute, otherwise None.
    """
    result = getattr(response, "content", None)
    if not result:
        return None
    return result.video_url
class ByteDanceImageNode(IO.ComfyNode):
@classmethod
@ -211,12 +62,7 @@ class ByteDanceImageNode(IO.ComfyNode):
category="api node/image/ByteDance",
description="Generate images using ByteDance models via api based on prompt",
inputs=[
IO.Combo.Input(
"model",
options=Text2ImageModelName,
default=Text2ImageModelName.seedream_3,
tooltip="Model name",
),
IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]),
IO.String.Input(
"prompt",
multiline=True,
@ -335,12 +181,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
category="api node/image/ByteDance",
description="Edit images using ByteDance models via api based on prompt",
inputs=[
IO.Combo.Input(
"model",
options=Image2ImageModelName,
default=Image2ImageModelName.seededit_3,
tooltip="Model name",
),
IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]),
IO.Image.Input(
"image",
tooltip="The base image to edit",
@ -394,7 +235,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
async def execute(
cls,
model: str,
image: torch.Tensor,
image: Input.Image,
prompt: str,
seed: int,
guidance_scale: float,
@ -434,7 +275,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=["seedream-4-0-250828"],
options=["seedream-4-5-251128", "seedream-4-0-250828"],
tooltip="Model name",
),
IO.String.Input(
@ -459,7 +300,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
default=2048,
min=1024,
max=4096,
step=64,
step=8,
tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`",
optional=True,
),
@ -468,7 +309,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
default=2048,
min=1024,
max=4096,
step=64,
step=8,
tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`",
optional=True,
),
@ -532,7 +373,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
image: torch.Tensor = None,
image: Input.Image | None = None,
size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0],
width: int = 2048,
height: int = 2048,
@ -555,6 +396,18 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
raise ValueError(
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels."
)
out_num_pixels = w * h
mp_provided = out_num_pixels / 1_000_000.0
if "seedream-4-5" in model and out_num_pixels < 3686400:
raise ValueError(
f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, "
f"but {mp_provided:.2f}MP provided."
)
if "seedream-4-0" in model and out_num_pixels < 921600:
raise ValueError(
f"Minimum image resolution that the selected model can generate is 0.92MP, "
f"but {mp_provided:.2f}MP provided."
)
n_input_images = get_number_of_images(image) if image is not None else 0
if n_input_images > 10:
raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.")
@ -607,9 +460,8 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=Text2VideoModelName,
default=Text2VideoModelName.seedance_1_pro,
tooltip="Model name",
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
default="seedance-1-0-pro-fast-251015",
),
IO.String.Input(
"prompt",
@ -714,9 +566,8 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=Image2VideoModelName,
default=Image2VideoModelName.seedance_1_pro,
tooltip="Model name",
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
default="seedance-1-0-pro-fast-251015",
),
IO.String.Input(
"prompt",
@ -787,7 +638,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
image: torch.Tensor,
image: Input.Image,
resolution: str,
aspect_ratio: str,
duration: int,
@ -833,9 +684,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=[model.value for model in Image2VideoModelName],
default=Image2VideoModelName.seedance_1_lite.value,
tooltip="Model name",
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
default="seedance-1-0-lite-i2v-250428",
),
IO.String.Input(
"prompt",
@ -910,8 +760,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
first_frame: Input.Image,
last_frame: Input.Image,
resolution: str,
aspect_ratio: str,
duration: int,
@ -968,9 +818,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=[Image2VideoModelName.seedance_1_lite.value],
default=Image2VideoModelName.seedance_1_lite.value,
tooltip="Model name",
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
default="seedance-1-0-lite-i2v-250428",
),
IO.String.Input(
"prompt",
@ -1034,7 +883,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
images: torch.Tensor,
images: Input.Image,
resolution: str,
aspect_ratio: str,
duration: int,
@ -1069,8 +918,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
async def process_video_task(
cls: type[IO.ComfyNode],
payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
estimated_duration: Optional[int],
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
estimated_duration: int | None,
) -> IO.NodeOutput:
initial_response = await sync_op(
cls,
@ -1085,7 +934,7 @@ async def process_video_task(
estimated_duration=estimated_duration,
response_model=TaskStatusResponse,
)
return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response)))
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
def raise_if_text_params(prompt: str, text_params: list[str]) -> None:

View File

@ -13,8 +13,7 @@ import torch
from typing_extensions import override
import folder_paths
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api.util import VideoCodec, VideoContainer
from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.gemini_api import (
GeminiContent,
GeminiFileData,
@ -27,6 +26,8 @@ from comfy_api_nodes.apis.gemini_api import (
GeminiMimeType,
GeminiPart,
GeminiRole,
GeminiSystemInstructionContent,
GeminiTextPart,
Modality,
)
from comfy_api_nodes.util import (
@ -43,6 +44,14 @@ from comfy_api_nodes.util import (
# Path prefix for Gemini requests; the model name is appended as the last segment.
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024  # 20 MB
# Default value of the `system_prompt` input on the Gemini image nodes.
# It instructs the model to always answer with an image.
GEMINI_IMAGE_SYS_PROMPT = (
    "You are an expert image-generation engine. You must ALWAYS produce an image.\n"
    "Interpret all user input—regardless of "
    "format, intent, or abstraction—as literal visual directives for image composition.\n"
    "If a prompt is conversational or lacks specific visual details, "
    "you must creatively invent a concrete visual scenario that depicts the concept.\n"
    "Prioritize generating the visual representation above any text, formatting, or conversational requests."
)
class GeminiModel(str, Enum):
@ -68,7 +77,7 @@ class GeminiImageModel(str, Enum):
async def create_image_parts(
cls: type[IO.ComfyNode],
images: torch.Tensor,
images: Input.Image,
image_limit: int = 0,
) -> list[GeminiPart]:
image_parts: list[GeminiPart] = []
@ -154,8 +163,8 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts])
def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
image_tensors: list[torch.Tensor] = []
def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/png")
for part in parts:
image_data = base64.b64decode(part.inlineData.data)
@ -277,6 +286,13 @@ class GeminiNode(IO.ComfyNode):
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
IO.String.Input(
"system_prompt",
multiline=True,
default="",
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
),
],
outputs=[
IO.String.Output(),
@ -293,7 +309,9 @@ class GeminiNode(IO.ComfyNode):
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
"""Convert video input to Gemini API compatible parts."""
base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264)
base_64_string = video_to_base64_string(
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
)
return [
GeminiPart(
inlineData=GeminiInlineData(
@ -343,10 +361,11 @@ class GeminiNode(IO.ComfyNode):
prompt: str,
model: str,
seed: int,
images: torch.Tensor | None = None,
images: Input.Image | None = None,
audio: Input.Audio | None = None,
video: Input.Video | None = None,
files: list[GeminiPart] | None = None,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
@ -363,7 +382,10 @@ class GeminiNode(IO.ComfyNode):
if files is not None:
parts.extend(files)
# Create response
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
@ -373,7 +395,8 @@ class GeminiNode(IO.ComfyNode):
role=GeminiRole.user,
parts=parts,
)
]
],
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
@ -523,6 +546,13 @@ class GeminiImage(IO.ComfyNode):
"'IMAGE+TEXT' to return both the generated image and a text response.",
optional=True,
),
IO.String.Input(
"system_prompt",
multiline=True,
default=GEMINI_IMAGE_SYS_PROMPT,
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
),
],
outputs=[
IO.Image.Output(),
@ -542,10 +572,11 @@ class GeminiImage(IO.ComfyNode):
prompt: str,
model: str,
seed: int,
images: torch.Tensor | None = None,
images: Input.Image | None = None,
files: list[GeminiPart] | None = None,
aspect_ratio: str = "auto",
response_modalities: str = "IMAGE+TEXT",
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
@ -559,6 +590,10 @@ class GeminiImage(IO.ComfyNode):
if files is not None:
parts.extend(files)
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
@ -570,6 +605,7 @@ class GeminiImage(IO.ComfyNode):
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=None if aspect_ratio == "auto" else image_config,
),
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
@ -640,6 +676,13 @@ class GeminiImage2(IO.ComfyNode):
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
IO.String.Input(
"system_prompt",
multiline=True,
default=GEMINI_IMAGE_SYS_PROMPT,
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
),
],
outputs=[
IO.Image.Output(),
@ -662,8 +705,9 @@ class GeminiImage2(IO.ComfyNode):
aspect_ratio: str,
resolution: str,
response_modalities: str,
images: torch.Tensor | None = None,
images: Input.Image | None = None,
files: list[GeminiPart] | None = None,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
@ -679,6 +723,10 @@ class GeminiImage2(IO.ComfyNode):
if aspect_ratio != "auto":
image_config.aspectRatio = aspect_ratio
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op(
cls,
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
@ -690,6 +738,7 @@ class GeminiImage2(IO.ComfyNode):
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=image_config,
),
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,

View File

@ -1,12 +1,9 @@
from io import BytesIO
from typing import Optional
import torch
from pydantic import BaseModel, Field
from typing_extensions import override
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.util import (
ApiEndpoint,
get_number_of_images,
@ -26,9 +23,9 @@ class ExecuteTaskRequest(BaseModel):
model: str = Field(...)
duration: int = Field(...)
resolution: str = Field(...)
fps: Optional[int] = Field(25)
generate_audio: Optional[bool] = Field(True)
image_uri: Optional[str] = Field(None)
fps: int | None = Field(25)
generate_audio: bool | None = Field(True)
image_uri: str | None = Field(None)
class TextToVideoNode(IO.ComfyNode):
@ -103,7 +100,7 @@ class TextToVideoNode(IO.ComfyNode):
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
class ImageToVideoNode(IO.ComfyNode):
@ -153,7 +150,7 @@ class ImageToVideoNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
image: torch.Tensor,
image: Input.Image,
model: str,
prompt: str,
duration: int,
@ -183,7 +180,7 @@ class ImageToVideoNode(IO.ComfyNode):
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
class LtxvApiExtension(ComfyExtension):

View File

@ -1,11 +1,8 @@
import logging
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.input import VideoInput
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis import (
MoonvalleyPromptResponse,
MoonvalleyTextToVideoInferenceParams,
@ -61,7 +58,7 @@ def validate_task_creation_response(response) -> None:
raise RuntimeError(error_msg)
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
def validate_video_to_video_input(video: Input.Video) -> Input.Video:
"""
Validates and processes video input for Moonvalley Video-to-Video generation.
@ -82,7 +79,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput:
return _validate_and_trim_duration(video)
def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
def _get_video_dimensions(video: Input.Video) -> tuple[int, int]:
"""Extracts video dimensions with error handling."""
try:
return video.get_dimensions()
@ -106,7 +103,7 @@ def _validate_video_dimensions(width: int, height: int) -> None:
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
def _validate_and_trim_duration(video: Input.Video) -> Input.Video:
"""Validates video duration and trims to 5 seconds if needed."""
duration = video.get_duration()
_validate_minimum_duration(duration)
@ -119,7 +116,7 @@ def _validate_minimum_duration(duration: float) -> None:
raise ValueError("Input video must be at least 5 seconds long.")
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video:
"""Trims video to 5 seconds if longer."""
if duration > 5:
return trim_video(video, 5)
@ -241,7 +238,7 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
image: torch.Tensor,
image: Input.Image,
prompt: str,
negative_prompt: str,
resolution: str,
@ -362,9 +359,9 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
prompt: str,
negative_prompt: str,
seed: int,
video: Optional[VideoInput] = None,
video: Input.Video | None = None,
control_type: str = "Motion Transfer",
motion_intensity: Optional[int] = 100,
motion_intensity: int | None = 100,
steps=33,
prompt_adherence=4.5,
) -> IO.NodeOutput:

View File

@ -11,12 +11,11 @@ User Guides:
"""
from typing import Union, Optional
from typing_extensions import override
from enum import Enum
import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.apis import (
RunwayImageToVideoRequest,
RunwayImageToVideoResponse,
@ -44,8 +43,6 @@ from comfy_api_nodes.util import (
sync_op,
poll_op,
)
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, IO
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
@ -80,7 +77,7 @@ class RunwayGen3aAspectRatio(str, Enum):
field_1280_768 = "1280:768"
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
"""Returns the video URL from the task status response if it exists."""
if hasattr(response, "output") and len(response.output) > 0:
return response.output[0]
@ -89,13 +86,13 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
def extract_progress_from_task_status(
response: TaskStatusResponse,
) -> Union[float, None]:
) -> float | None:
if hasattr(response, "progress") and response.progress is not None:
return response.progress * 100
return None
def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
"""Returns the image URL from the task status response if it exists."""
if hasattr(response, "output") and len(response.output) > 0:
return response.output[0]
@ -103,7 +100,7 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
async def get_response(
cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None
cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return await poll_op(
@ -119,8 +116,8 @@ async def get_response(
async def generate_video(
cls: type[IO.ComfyNode],
request: RunwayImageToVideoRequest,
estimated_duration: Optional[int] = None,
) -> VideoFromFile:
estimated_duration: int | None = None,
) -> InputImpl.VideoFromFile:
initial_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
@ -193,7 +190,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
async def execute(
cls,
prompt: str,
start_frame: torch.Tensor,
start_frame: Input.Image,
duration: str,
ratio: str,
seed: int,
@ -283,7 +280,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
async def execute(
cls,
prompt: str,
start_frame: torch.Tensor,
start_frame: Input.Image,
duration: str,
ratio: str,
seed: int,
@ -381,8 +378,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
async def execute(
cls,
prompt: str,
start_frame: torch.Tensor,
end_frame: torch.Tensor,
start_frame: Input.Image,
end_frame: Input.Image,
duration: str,
ratio: str,
seed: int,
@ -467,7 +464,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
cls,
prompt: str,
ratio: str,
reference_image: Optional[torch.Tensor] = None,
reference_image: Input.Image | None = None,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1)

View File

@ -1,11 +1,9 @@
import base64
from io import BytesIO
import torch
from typing_extensions import override
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.apis.veo_api import (
VeoGenVidPollRequest,
VeoGenVidPollResponse,
@ -232,7 +230,7 @@ class VeoVideoGenerationNode(IO.ComfyNode):
# Check if video is provided as base64 or URL
if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if hasattr(video, "gcsUri") and video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
@ -431,8 +429,8 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
aspect_ratio: str,
duration: int,
seed: int,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
first_frame: Input.Image,
last_frame: Input.Image,
model: str,
generate_audio: bool,
):
@ -493,7 +491,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
if response.videos:
video = response.videos[0]
if video.bytesBase64Encoded:
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
raise Exception("Video returned but no data or URL was provided")

293
comfy_extras/nodes_gds.py Normal file
View File

@ -0,0 +1,293 @@
# copyright 2025 Maifee Ul Asad @ github.com/maifeeulasad
# copyright under GNU GENERAL PUBLIC LICENSE, Version 3, 29 June 2007
"""
Enhanced model loading nodes with GPUDirect Storage support
"""
import logging
import time
import asyncio
from typing import Optional, Dict, Any
import torch
import folder_paths
import comfy.sd
import comfy.utils
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
class CheckpointLoaderGDS(ComfyNodeABC):
    """
    Checkpoint loader node with optional GPUDirect Storage (GDS) prefetching.

    The checkpoint itself is always loaded through
    `comfy.sd.load_checkpoint_guess_config`; this node additionally tries a
    GDS prefetch when requested and reports which loader GDS would pick.
    NOTE(review): whether the standard load path is transparently routed
    through GDS is not visible from this node — confirm in `comfy.gds_loader`.
    """

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {
                    "tooltip": "The name of the checkpoint (model) to load with GDS optimization."
                }),
            },
            "optional": {
                "prefetch": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Prefetch model to GPU cache for faster loading."
                }),
                "use_gds": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Use GPUDirect Storage if available."
                }),
                "target_device": (["auto", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cpu"], {
                    "default": "auto",
                    "tooltip": "Target device for model loading."
                })
            }
        }

    RETURN_TYPES = ("MODEL", "CLIP", "VAE", "STRING")
    RETURN_NAMES = ("model", "clip", "vae", "load_info")
    OUTPUT_TOOLTIPS = (
        "The model used for denoising latents.",
        "The CLIP model used for encoding text prompts.",
        "The VAE model used for encoding and decoding images to and from latent space.",
        "Loading information and statistics."
    )
    FUNCTION = "load_checkpoint_gds"
    CATEGORY = "loaders/advanced"
    DESCRIPTION = "Enhanced checkpoint loader with GPUDirect Storage support for direct SSD-to-GPU loading."
    EXPERIMENTAL = True

    def load_checkpoint_gds(self, ckpt_name: str, prefetch: bool = False, use_gds: bool = True, target_device: str = "auto"):
        """Load a checkpoint and return it together with a textual load report.

        Args:
            ckpt_name: Checkpoint file name, resolved in the "checkpoints" folder.
            prefetch: Try a GDS prefetch first (only when `use_gds` is also set).
            use_gds: Allow the GDS prefetch and the GDS loader check.
            target_device: "auto" or a torch device string.

        Returns:
            Tuple `(model, clip, vae, load_info)`; `load_info` is a
            multi-line string with loader, time and device.

        Raises:
            RuntimeError: if anything in the load path fails; the original
                exception is attached as `__cause__`.
        """
        start_time = time.time()
        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)

        # "auto" leaves the device unset; every other choice is a torch device string.
        # NOTE(review): `device` only gates the GDS check and is echoed in the
        # report; it is never passed to the actual load call below.
        device = None if target_device == "auto" else torch.device(target_device)

        load_info = {
            "file": ckpt_name,
            "path": ckpt_path,
            "target_device": str(device) if device else "auto",
            "gds_enabled": use_gds,
            "prefetch_used": prefetch
        }

        try:
            # Prefetch is best-effort: a failure is recorded and logged, never fatal.
            if prefetch and use_gds:
                try:
                    from comfy.gds_loader import prefetch_model_gds
                    prefetch_success = prefetch_model_gds(ckpt_path)
                    load_info["prefetch_success"] = prefetch_success
                    if prefetch_success:
                        logging.info(f"Prefetched {ckpt_name} to GPU cache")
                except Exception as e:
                    logging.warning(f"Prefetch failed for {ckpt_name}: {e}")
                    load_info["prefetch_error"] = str(e)

            # Determine which loader to report. With the default "auto" device
            # this branch is skipped and "Standard" is reported.
            if use_gds and device and device.type == 'cuda':
                try:
                    from comfy.gds_loader import get_gds_instance
                    gds = get_gds_instance()
                    # Check if GDS should be used for this file
                    if gds._should_use_gds(ckpt_path):
                        load_info["loader_used"] = "GDS"
                        logging.info(f"Loading {ckpt_name} with GDS")
                    else:
                        load_info["loader_used"] = "Standard"
                        logging.info(f"Loading {ckpt_name} with standard method (file too small for GDS)")
                except Exception as e:
                    logging.warning(f"GDS check failed, using standard loading: {e}")
                    load_info["loader_used"] = "Standard (GDS failed)"
            else:
                load_info["loader_used"] = "Standard"

            # Load the actual checkpoint
            out = comfy.sd.load_checkpoint_guess_config(
                ckpt_path,
                output_vae=True,
                output_clip=True,
                embedding_directory=folder_paths.get_folder_paths("embeddings")
            )

            load_time = time.time() - start_time
            load_info["load_time_seconds"] = round(load_time, 3)

            # Format load info as string
            info_lines = [
                f"Loaded: {ckpt_name}",
                f"Method: {load_info['loader_used']}",
                f"Time: {load_info['load_time_seconds']}s",
                f"Device: {load_info['target_device']}",
            ]
            if "prefetch_success" in load_info:
                # Both branches used to render an empty string, so the line carried no information.
                info_lines.append(f"Prefetch: {'succeeded' if load_info['prefetch_success'] else 'failed'}")
            info_str = "\n".join(info_lines)

            logging.info(f"Checkpoint loaded: {ckpt_name} in {load_time:.3f}s using {load_info['loader_used']}")
            return (*out[:3], info_str)
        except Exception as e:
            error_str = f"Failed to load: {ckpt_name}\nError: {str(e)}"
            logging.error(f"Checkpoint loading failed: {e}")
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(error_str) from e
class ModelPrefetcher(ComfyNodeABC):
    """
    Output node that asks the GDS loader to prefetch a list of checkpoints
    and returns a per-checkpoint text report.
    """

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        names_spec = ("STRING", {
            "multiline": True,
            "default": "",
            "tooltip": "List of checkpoint names to prefetch (one per line)."
        })
        enabled_spec = ("BOOLEAN", {
            "default": True,
            "tooltip": "Enable/disable prefetching."
        })
        return {
            "required": {
                "checkpoint_names": names_spec,
                "prefetch_enabled": enabled_spec,
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("prefetch_report",)
    OUTPUT_TOOLTIPS = ("Report of prefetch operations.",)
    FUNCTION = "prefetch_models"
    CATEGORY = "loaders/advanced"
    DESCRIPTION = "Prefetch multiple models to GPU cache for faster loading."
    OUTPUT_NODE = True

    def prefetch_models(self, checkpoint_names: str, prefetch_enabled: bool = True):
        """Prefetch every checkpoint named in `checkpoint_names` (one per line).

        Returns a one-element tuple holding the report string. Errors for a
        single checkpoint are recorded in the report (message truncated to
        50 characters) and do not stop the remaining ones.
        """
        if not prefetch_enabled:
            return ("Prefetching disabled",)

        # Blank lines and surrounding whitespace are ignored.
        stripped = (raw.strip() for raw in checkpoint_names.split('\n'))
        requested = [entry for entry in stripped if entry]
        if not requested:
            return ("No checkpoints specified for prefetching",)

        try:
            from comfy.gds_loader import prefetch_model_gds
        except ImportError:
            return ("GDS not available for prefetching",)

        report_lines = []
        ok_count = 0
        for entry in requested:
            try:
                path = folder_paths.get_full_path_or_raise("checkpoints", entry)
                prefetched = prefetch_model_gds(path)
            except Exception as exc:
                report_lines.append(f"{entry} (error: {str(exc)[:50]})")
                continue
            if prefetched:
                ok_count += 1
                report_lines.append(f"{entry}")
            else:
                report_lines.append(f"{entry} (prefetch failed)")

        header = f"Prefetch Report ({ok_count}/{len(requested)} successful):\n"
        return (header + "\n".join(report_lines),)
class GDSStats(ComfyNodeABC):
    """
    Output node that renders the statistics returned by
    `comfy.gds_loader.get_gds_stats()` as a text report.
    """

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        return {
            "required": {
                "refresh": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Refresh statistics."
                })
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("stats_report",)
    OUTPUT_TOOLTIPS = ("GDS statistics and performance report.",)
    FUNCTION = "get_stats"
    CATEGORY = "utils/advanced"
    DESCRIPTION = "Display GPUDirect Storage statistics and performance metrics."
    OUTPUT_NODE = True

    def get_stats(self, refresh: bool = False):
        """Build the GDS statistics report.

        Args:
            refresh: Not read by this method.
                NOTE(review): presumably exists so toggling it re-triggers
                the node — confirm.

        Returns:
            One-element tuple with the report. If the GDS module cannot be
            imported or the stats cannot be read, the tuple holds a short
            explanatory message instead; no exception is raised.
        """
        try:
            from comfy.gds_loader import get_gds_stats
            stats = get_gds_stats()

            # The availability line used to render an empty string in both branches.
            parts = [
                "=== GPUDirect Storage Statistics ===\n\n",
                # Availability
                f"GDS Available: {'Yes' if stats['gds_available'] else 'No'}\n",
                # Usage statistics
                f"Total Loads: {stats['total_loads']}\n",
                f"GDS Loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)\n",
                f"Fallback Loads: {stats['fallback_loads']}\n\n",
            ]

            # Performance metrics are only meaningful once GDS has moved data.
            if stats['total_bytes_gds'] > 0:
                gb_transferred = stats['total_bytes_gds'] / (1024**3)
                parts.append(f"Data Transferred: {gb_transferred:.2f} GB\n")
                parts.append(f"Average Bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s\n")
                parts.append(f"Total GDS Time: {stats['total_time_gds']:.2f}s\n\n")

            # Configuration
            config = stats.get('config', {})
            if config:
                parts.append("Configuration:\n")
                parts.append(f"- Enabled: {config.get('enabled', 'Unknown')}\n")
                parts.append(f"- Min File Size: {config.get('min_file_size_mb', 'Unknown')} MB\n")
                parts.append(f"- Chunk Size: {config.get('chunk_size_mb', 'Unknown')} MB\n")
                parts.append(f"- Max Streams: {config.get('max_concurrent_streams', 'Unknown')}\n")
                parts.append(f"- Prefetch: {config.get('prefetch_enabled', 'Unknown')}\n")
                parts.append(f"- Fallback: {config.get('fallback_to_cpu', 'Unknown')}\n")

            return ("".join(parts),)
        except ImportError:
            return ("GDS module not available",)
        except Exception as e:
            # Best-effort node: report the failure as text rather than raising.
            return (f"Error retrieving GDS stats: {str(e)}",)
# Node mappings
# Node identifier -> node class defined in this module.
# NOTE(review): presumably picked up by the node loader when this module is
# imported — confirm.
NODE_CLASS_MAPPINGS = {
    "CheckpointLoaderGDS": CheckpointLoaderGDS,
    "ModelPrefetcher": ModelPrefetcher,
    "GDSStats": GDSStats,
}
# Node identifier -> human-readable display name; keys mirror NODE_CLASS_MAPPINGS.
NODE_DISPLAY_NAME_MAPPINGS = {
    "CheckpointLoaderGDS": "Load Checkpoint (GDS)",
    "ModelPrefetcher": "Model Prefetcher",
    "GDSStats": "GDS Statistics",
}

View File

@ -8,10 +8,7 @@ import json
from typing import Optional
from typing_extensions import override
from fractions import Fraction
from comfy_api.input import AudioInput, ImageInput, VideoInput
from comfy_api.input_impl import VideoFromComponents, VideoFromFile
from comfy_api.util import VideoCodec, VideoComponents, VideoContainer
from comfy_api.latest import ComfyExtension, io, ui
from comfy_api.latest import ComfyExtension, io, ui, Input, InputImpl, Types
from comfy.cli_args import args
class SaveWEBM(io.ComfyNode):
@ -28,7 +25,6 @@ class SaveWEBM(io.ComfyNode):
io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@ -79,16 +75,15 @@ class SaveVideo(io.ComfyNode):
inputs=[
io.Video.Input("video", tooltip="The video to save."),
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def execute(cls, video: VideoInput, filename_prefix, format: str, codec) -> io.NodeOutput:
def execute(cls, video: Input.Video, filename_prefix, format: str, codec) -> io.NodeOutput:
width, height = video.get_dimensions()
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix,
@ -105,10 +100,10 @@ class SaveVideo(io.ComfyNode):
metadata["prompt"] = cls.hidden.prompt
if len(metadata) > 0:
saved_metadata = metadata
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
video.save_to(
os.path.join(full_output_folder, file),
format=VideoContainer(format),
format=Types.VideoContainer(format),
codec=codec,
metadata=saved_metadata
)
@ -135,9 +130,9 @@ class CreateVideo(io.ComfyNode):
)
@classmethod
def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput:
def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput:
return io.NodeOutput(
VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
)
class GetVideoComponents(io.ComfyNode):
@ -159,11 +154,11 @@ class GetVideoComponents(io.ComfyNode):
)
@classmethod
def execute(cls, video: VideoInput) -> io.NodeOutput:
def execute(cls, video: Input.Video) -> io.NodeOutput:
components = video.get_components()
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
class LoadVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -185,7 +180,7 @@ class LoadVideo(io.ComfyNode):
@classmethod
def execute(cls, file) -> io.NodeOutput:
video_path = folder_paths.get_annotated_filepath(file)
return io.NodeOutput(VideoFromFile(video_path))
return io.NodeOutput(InputImpl.VideoFromFile(video_path))
@classmethod
def fingerprint_inputs(s, file):

29
main.py
View File

@ -185,6 +185,35 @@ import comfyui_version
import app.logger
import hook_breaker_ac10a0
# Initialize GPUDirect Storage if enabled
def init_gds():
    """Initialize GPUDirect Storage (GDS) from the parsed CLI arguments.

    Reads the module-level ``args`` namespace and behaves as follows:

    * ``args.disable_gds`` truthy: log that GDS was disabled and return.
    * ``args.enable_gds`` missing or falsy: return without doing anything.
    * otherwise: build a ``GDSConfig`` from the ``gds_*`` options (using the
      defaults below for any option that is absent) and hand it to
      ``comfy.gds_loader.init_gds``.

    Every option is read with ``getattr(..., default)`` so the function keeps
    working when ``args`` does not define the GDS options at all.

    Returns:
        None.

    Raises:
        ImportError: if ``comfy.gds_loader`` cannot be imported once GDS has
            been requested (the import is deliberately deferred until then).
    """
    if getattr(args, 'disable_gds', False):
        logging.info("GDS explicitly disabled via --disable-gds")
        return

    # Only an explicit enable request triggers initialization. A separate
    # "are any GDS options present?" check is unnecessary: a missing
    # `enable_gds` attribute is treated as "not requested" right here.
    if not getattr(args, 'enable_gds', False):
        return

    # Deferred import: the GDS loader is only needed when GDS was requested.
    from comfy.gds_loader import GDSConfig, init_gds as gds_init

    config = GDSConfig(
        # `enable_gds` is already truthy at this point, so the `or` never
        # falls through to `gds_prefetch`; the expression is kept so the same
        # value is passed on.
        enabled=getattr(args, 'enable_gds', False) or getattr(args, 'gds_prefetch', False),
        min_file_size_mb=getattr(args, 'gds_min_file_size', 100),
        chunk_size_mb=getattr(args, 'gds_chunk_size', 64),
        max_concurrent_streams=getattr(args, 'gds_streams', 4),
        prefetch_enabled=getattr(args, 'gds_prefetch', True),
        # The CLI option is phrased negatively ("no fallback"); invert it.
        fallback_to_cpu=not getattr(args, 'gds_no_fallback', False),
        show_stats=getattr(args, 'gds_stats', False),
    )
    gds_init(config)
# Initialize GDS: run the setup as soon as execution reaches this statement.
# init_gds() returns without initializing anything unless args.enable_gds is
# set and args.disable_gds is not.
init_gds()
def cuda_malloc_warning():
device = comfy.model_management.get_torch_device()
device_name = comfy.model_management.get_torch_device_name(device)

View File

@ -2354,6 +2354,7 @@ async def init_builtin_extra_nodes():
"nodes_model_patch.py",
"nodes_easycache.py",
"nodes_audio_encoder.py",
"nodes_gds.py",
"nodes_rope.py",
"nodes_logic.py",
"nodes_nop.py",

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.33.10
comfyui-workflow-templates==0.7.25
comfyui-workflow-templates==0.7.51
comfyui-embedded-docs==0.3.1
torch
torchsde
@ -26,4 +26,4 @@ av>=14.2.0
kornia>=0.7.1
spandrel
pydantic~=2.0
pydantic-settings~=2.0
pydantic-settings~=2.0