Merge branch 'master' into dr-support-pip-cm

This commit is contained in:
Dr.Lt.Data 2025-05-19 06:04:10 +09:00
commit 9ac185456f
50 changed files with 2760 additions and 684 deletions

View File

@ -110,7 +110,6 @@ ComfyUI follows a weekly release cycle every Friday, with three interconnected r
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)** 2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
- Builds a new release using the latest stable core version - Builds a new release using the latest stable core version
- Version numbers match the core release (e.g., Desktop v1.7.0 uses Core v1.7.0)
3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)** 3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**
- Weekly frontend updates are merged into the core repository - Weekly frontend updates are merged into the core repository
@ -302,7 +301,7 @@ For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 pyt
### AMD ROCm Tips ### AMD ROCm Tips
You can enable experimental memory efficient attention on pytorch 2.5 in ComfyUI on RDNA3 and potentially other AMD GPUs using this command: You can enable experimental memory efficient attention on recent pytorch in ComfyUI on some AMD GPUs using this command, it should already be enabled by default on RDNA3. If this improves speed for you on latest pytorch on your GPU please report it so that I can enable it by default.
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention``` ```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```

View File

@ -235,7 +235,7 @@ class ComfyNodeABC(ABC):
DEPRECATED: bool DEPRECATED: bool
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node.""" """Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
API_NODE: Optional[bool] API_NODE: Optional[bool]
"""Flags a node as an API node.""" """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
@classmethod @classmethod
@abstractmethod @abstractmethod

View File

@ -8,11 +8,7 @@ from typing import Callable, Tuple, List
import numpy as np import numpy as np
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.utils import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
# from diffusers.models.modeling_utils import ModelMixin
# from diffusers.loaders import FromOriginalModelMixin
# from diffusers.configuration_utils import ConfigMixin, register_to_config
from .music_log_mel import LogMelSpectrogram from .music_log_mel import LogMelSpectrogram
@ -259,7 +255,7 @@ class ResBlock1(torch.nn.Module):
self.convs1 = nn.ModuleList( self.convs1 = nn.ModuleList(
[ [
weight_norm( torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d( ops.Conv1d(
channels, channels,
channels, channels,
@ -269,7 +265,7 @@ class ResBlock1(torch.nn.Module):
padding=get_padding(kernel_size, dilation[0]), padding=get_padding(kernel_size, dilation[0]),
) )
), ),
weight_norm( torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d( ops.Conv1d(
channels, channels,
channels, channels,
@ -279,7 +275,7 @@ class ResBlock1(torch.nn.Module):
padding=get_padding(kernel_size, dilation[1]), padding=get_padding(kernel_size, dilation[1]),
) )
), ),
weight_norm( torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d( ops.Conv1d(
channels, channels,
channels, channels,
@ -294,7 +290,7 @@ class ResBlock1(torch.nn.Module):
self.convs2 = nn.ModuleList( self.convs2 = nn.ModuleList(
[ [
weight_norm( torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d( ops.Conv1d(
channels, channels,
channels, channels,
@ -304,7 +300,7 @@ class ResBlock1(torch.nn.Module):
padding=get_padding(kernel_size, 1), padding=get_padding(kernel_size, 1),
) )
), ),
weight_norm( torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d( ops.Conv1d(
channels, channels,
channels, channels,
@ -314,7 +310,7 @@ class ResBlock1(torch.nn.Module):
padding=get_padding(kernel_size, 1), padding=get_padding(kernel_size, 1),
) )
), ),
weight_norm( torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d( ops.Conv1d(
channels, channels,
channels, channels,
@ -366,7 +362,7 @@ class HiFiGANGenerator(nn.Module):
prod(upsample_rates) == hop_length prod(upsample_rates) == hop_length
), f"hop_length must be {prod(upsample_rates)}" ), f"hop_length must be {prod(upsample_rates)}"
self.conv_pre = weight_norm( self.conv_pre = torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d( ops.Conv1d(
num_mels, num_mels,
upsample_initial_channel, upsample_initial_channel,
@ -386,7 +382,7 @@ class HiFiGANGenerator(nn.Module):
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
c_cur = upsample_initial_channel // (2 ** (i + 1)) c_cur = upsample_initial_channel // (2 ** (i + 1))
self.ups.append( self.ups.append(
weight_norm( torch.nn.utils.parametrizations.weight_norm(
ops.ConvTranspose1d( ops.ConvTranspose1d(
upsample_initial_channel // (2**i), upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)), upsample_initial_channel // (2 ** (i + 1)),
@ -421,7 +417,7 @@ class HiFiGANGenerator(nn.Module):
self.resblocks.append(ResBlock1(ch, k, d)) self.resblocks.append(ResBlock1(ch, k, d))
self.activation_post = post_activation() self.activation_post = post_activation()
self.conv_post = weight_norm( self.conv_post = torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d( ops.Conv1d(
ch, ch,
1, 1,

View File

@ -75,16 +75,10 @@ class SnakeBeta(nn.Module):
return x return x
def WNConv1d(*args, **kwargs): def WNConv1d(*args, **kwargs):
try:
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs)) return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
except:
return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
def WNConvTranspose1d(*args, **kwargs): def WNConvTranspose1d(*args, **kwargs):
try:
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
except:
return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module: def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
if activation == "elu": if activation == "elu":

View File

@ -228,6 +228,7 @@ class HunyuanVideo(nn.Module):
y: Tensor, y: Tensor,
guidance: Tensor = None, guidance: Tensor = None,
guiding_frame_index=None, guiding_frame_index=None,
ref_latent=None,
control=None, control=None,
transformer_options={}, transformer_options={},
) -> Tensor: ) -> Tensor:
@ -238,6 +239,14 @@ class HunyuanVideo(nn.Module):
img = self.img_in(img) img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype)) vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
if ref_latent is not None:
ref_latent_ids = self.img_ids(ref_latent)
ref_latent = self.img_in(ref_latent)
img = torch.cat([ref_latent, img], dim=-2)
ref_latent_ids[..., 0] = -1
ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1])
img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2)
if guiding_frame_index is not None: if guiding_frame_index is not None:
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0)) token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
vec_ = self.vector_in(y[:, :self.params.vec_in_dim]) vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
@ -313,6 +322,8 @@ class HunyuanVideo(nn.Module):
img[:, : img_len] += add img[:, : img_len] += add
img = img[:, : img_len] img = img[:, : img_len]
if ref_latent is not None:
img = img[:, ref_latent.shape[1]:]
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels) img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
@ -324,7 +335,7 @@ class HunyuanVideo(nn.Module):
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
return img return img
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs): def img_ids(self, x):
bs, c, t, h, w = x.shape bs, c, t, h, w = x.shape
patch_size = self.patch_size patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@ -334,7 +345,11 @@ class HunyuanVideo(nn.Module):
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1) img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1) img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1) img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs) return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
img_ids = self.img_ids(x)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options) out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
return out return out

View File

@ -247,6 +247,60 @@ class VaceWanAttentionBlock(WanAttentionBlock):
return c_skip, c return c_skip, c
class WanCamAdapter(nn.Module):
def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1, operation_settings={}):
super(WanCamAdapter, self).__init__()
# Pixel Unshuffle: reduce spatial dimensions by a factor of 8
self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)
# Convolution: reduce spatial dimensions by a factor
# of 2 (without overlap)
self.conv = operation_settings.get("operations").Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
# Residual blocks for feature extraction
self.residual_blocks = nn.Sequential(
*[WanCamResidualBlock(out_dim, operation_settings = operation_settings) for _ in range(num_residual_blocks)]
)
def forward(self, x):
# Reshape to merge the frame dimension into batch
bs, c, f, h, w = x.size()
x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
# Pixel Unshuffle operation
x_unshuffled = self.pixel_unshuffle(x)
# Convolution operation
x_conv = self.conv(x_unshuffled)
# Feature extraction with residual blocks
out = self.residual_blocks(x_conv)
# Reshape to restore original bf dimension
out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
# Permute dimensions to reorder (if needed), e.g., swap channels and feature frames
out = out.permute(0, 2, 1, 3, 4)
return out
class WanCamResidualBlock(nn.Module):
def __init__(self, dim, operation_settings={}):
super(WanCamResidualBlock, self).__init__()
self.conv1 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.relu = nn.ReLU(inplace=True)
self.conv2 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, x):
residual = x
out = self.relu(self.conv1(x))
out = self.conv2(out)
out += residual
return out
class Head(nn.Module): class Head(nn.Module):
def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}): def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
@ -637,3 +691,92 @@ class VaceWanModel(WanModel):
# unpatchify # unpatchify
x = self.unpatchify(x, grid_sizes) x = self.unpatchify(x, grid_sizes)
return x return x
class CameraWanModel(WanModel):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
def __init__(self,
model_type='camera',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
flf_pos_embed_token_number=None,
image_model=None,
in_dim_control_adapter=24,
device=None,
dtype=None,
operations=None,
):
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
def forward_orig(
self,
x,
t,
context,
clip_fea=None,
freqs=None,
camera_conditions = None,
transformer_options={},
**kwargs,
):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
if self.control_adapter is not None and camera_conditions is not None:
x_camera = self.control_adapter(camera_conditions).to(x.dtype)
x = x + x_camera
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
# time embeddings
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
# context
context = self.text_embedding(context)
context_img_len = None
if clip_fea is not None:
if self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x

View File

@ -286,6 +286,12 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format
if isinstance(model, comfy.model_base.ACEStep):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official ACE step lora format
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
return key_map return key_map

View File

@ -924,6 +924,10 @@ class HunyuanVideo(BaseModel):
if guiding_frame_index is not None: if guiding_frame_index is not None:
out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index])) out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
ref_latent = kwargs.get("ref_latent", None)
if ref_latent is not None:
out['ref_latent'] = comfy.conds.CONDRegular(self.process_latent_in(ref_latent))
return out return out
def scale_latent_inpaint(self, latent_image, **kwargs): def scale_latent_inpaint(self, latent_image, **kwargs):
@ -1075,6 +1079,17 @@ class WAN21_Vace(WAN21):
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength) out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
return out return out
class WAN21_Camera(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
camera_conditions = kwargs.get("camera_conditions", None)
if camera_conditions is not None:
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
return out
class Hunyuan3Dv2(BaseModel): class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):

View File

@ -361,6 +361,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "vace" dit_config["model_type"] = "vace"
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1] dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.') dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "camera"
else: else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v" dit_config["model_type"] = "i2v"

View File

@ -308,10 +308,10 @@ def fp8_linear(self, input):
if scale_input is None: if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32) scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input) input = torch.clamp(input, min=-448, max=448, out=input)
input = input.reshape(-1, input_shape[2]).to(dtype) input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
else: else:
scale_input = scale_input.to(input.device) scale_input = scale_input.to(input.device)
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype) input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
if bias is not None: if bias is not None:
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)

View File

@ -30,7 +30,7 @@ if RMSNorm is None:
def __init__( def __init__(
self, self,
normalized_shape, normalized_shape,
eps=None, eps=1e-6,
elementwise_affine=True, elementwise_affine=True,
device=None, device=None,
dtype=None, dtype=None,

View File

@ -451,7 +451,7 @@ class VAE:
self.latent_dim = 2 self.latent_dim = 2
self.process_output = lambda audio: audio self.process_output = lambda audio: audio
self.process_input = lambda audio: audio self.process_input = lambda audio: audio
self.working_dtypes = [torch.bfloat16, torch.float32] self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.disable_offload = True self.disable_offload = True
self.extra_1d_channel = 16 self.extra_1d_channel = 16
else: else:

View File

@ -992,6 +992,16 @@ class WAN21_FunControl2V(WAN21_T2V):
out = model_base.WAN21(self, image_to_video=False, device=device) out = model_base.WAN21(self, image_to_video=False, device=device)
return out return out
class WAN21_Camera(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "camera",
"in_dim": 32,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
return out
class WAN21_Vace(WAN21_T2V): class WAN21_Vace(WAN21_T2V):
unet_config = { unet_config = {
"image_model": "wan2.1", "image_model": "wan2.1",
@ -1129,6 +1139,6 @@ class ACEStep(supported_models_base.BASE):
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model) return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep] models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
models += [SVD_img2vid] models += [SVD_img2vid]

View File

@ -78,8 +78,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args) pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
else: else:
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle) pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
if "global_step" in pl_sd:
logging.debug(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd: if "state_dict" in pl_sd:
sd = pl_sd["state_dict"] sd = pl_sd["state_dict"]
else: else:

View File

@ -43,3 +43,13 @@ class VideoInput(ABC):
components = self.get_components() components = self.get_components()
return components.images.shape[2], components.images.shape[1] return components.images.shape[2], components.images.shape[1]
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
Returns:
Duration in seconds
"""
components = self.get_components()
frame_count = components.images.shape[0]
return float(frame_count / components.frame_rate)

View File

@ -80,6 +80,38 @@ class VideoFromFile(VideoInput):
return stream.width, stream.height return stream.width, stream.height
raise ValueError(f"No video stream found in file '{self.__file}'") raise ValueError(f"No video stream found in file '{self.__file}'")
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
Returns:
Duration in seconds
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
if container.duration is not None:
return float(container.duration / av.time_base)
# Fallback: calculate from frame count and frame rate
video_stream = next(
(s for s in container.streams if s.type == "video"), None
)
if video_stream and video_stream.frames and video_stream.average_rate:
return float(video_stream.frames / video_stream.average_rate)
# Last resort: decode frames to count them
if video_stream and video_stream.average_rate:
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count > 0:
return float(frame_count / video_stream.average_rate)
raise ValueError(f"Could not determine duration for file '{self.__file}'")
def get_components_internal(self, container: InputContainer) -> VideoComponents: def get_components_internal(self, container: InputContainer) -> VideoComponents:
# Get video frames # Get video frames
frames = [] frames = []

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import io import io
import logging import logging
from typing import Optional from typing import Optional, Union
from comfy.utils import common_upscale from comfy.utils import common_upscale
from comfy_api.input_impl import VideoFromFile from comfy_api.input_impl import VideoFromFile
from comfy_api.util import VideoContainer, VideoCodec from comfy_api.util import VideoContainer, VideoCodec
@ -14,6 +15,7 @@ from comfy_api_nodes.apis.client import (
UploadRequest, UploadRequest,
UploadResponse, UploadResponse,
) )
from server import PromptServer
import numpy as np import numpy as np
@ -59,7 +61,9 @@ def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
return s return s
def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor: def validate_and_cast_response(
response, timeout: int = None, node_id: Union[str, None] = None
) -> torch.Tensor:
"""Validates and casts a response to a torch.Tensor. """Validates and casts a response to a torch.Tensor.
Args: Args:
@ -93,6 +97,10 @@ def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor:
img = Image.open(io.BytesIO(img_data)) img = Image.open(io.BytesIO(img_data))
elif image_url: elif image_url:
if node_id:
PromptServer.instance.send_progress_text(
f"Result URL: {image_url}", node_id
)
img_response = requests.get(image_url, timeout=timeout) img_response = requests.get(image_url, timeout=timeout)
if img_response.status_code != 200: if img_response.status_code != 200:
raise ValueError("Failed to download the image") raise ValueError("Failed to download the image")
@ -314,7 +322,7 @@ def upload_file_to_comfyapi(
file_bytes_io: BytesIO, file_bytes_io: BytesIO,
filename: str, filename: str,
upload_mime_type: str, upload_mime_type: str,
auth_token: Optional[str] = None, auth_kwargs: Optional[dict[str,str]] = None,
) -> str: ) -> str:
""" """
Uploads a single file to ComfyUI API and returns its download URL. Uploads a single file to ComfyUI API and returns its download URL.
@ -323,7 +331,7 @@ def upload_file_to_comfyapi(
file_bytes_io: BytesIO object containing the file data. file_bytes_io: BytesIO object containing the file data.
filename: The filename of the file. filename: The filename of the file.
upload_mime_type: MIME type of the file. upload_mime_type: MIME type of the file.
auth_token: Optional authentication token. auth_kwargs: Optional authentication token(s).
Returns: Returns:
The download URL for the uploaded file. The download URL for the uploaded file.
@ -337,7 +345,7 @@ def upload_file_to_comfyapi(
response_model=UploadResponse, response_model=UploadResponse,
), ),
request=request_object, request=request_object,
auth_token=auth_token, auth_kwargs=auth_kwargs,
) )
response: UploadResponse = operation.execute() response: UploadResponse = operation.execute()
@ -351,7 +359,7 @@ def upload_file_to_comfyapi(
def upload_video_to_comfyapi( def upload_video_to_comfyapi(
video: VideoInput, video: VideoInput,
auth_token: Optional[str] = None, auth_kwargs: Optional[dict[str,str]] = None,
container: VideoContainer = VideoContainer.MP4, container: VideoContainer = VideoContainer.MP4,
codec: VideoCodec = VideoCodec.H264, codec: VideoCodec = VideoCodec.H264,
max_duration: Optional[int] = None, max_duration: Optional[int] = None,
@ -362,7 +370,7 @@ def upload_video_to_comfyapi(
Args: Args:
video: VideoInput object (Comfy VIDEO type). video: VideoInput object (Comfy VIDEO type).
auth_token: Optional authentication token. auth_kwargs: Optional authentication token(s).
container: The video container format to use (default: MP4). container: The video container format to use (default: MP4).
codec: The video codec to use (default: H264). codec: The video codec to use (default: H264).
max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised. max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised.
@ -390,7 +398,7 @@ def upload_video_to_comfyapi(
video_bytes_io.seek(0) video_bytes_io.seek(0)
return upload_file_to_comfyapi( return upload_file_to_comfyapi(
video_bytes_io, filename, upload_mime_type, auth_token video_bytes_io, filename, upload_mime_type, auth_kwargs
) )
@ -453,7 +461,7 @@ def audio_ndarray_to_bytesio(
def upload_audio_to_comfyapi( def upload_audio_to_comfyapi(
audio: AudioInput, audio: AudioInput,
auth_token: Optional[str] = None, auth_kwargs: Optional[dict[str,str]] = None,
container_format: str = "mp4", container_format: str = "mp4",
codec_name: str = "aac", codec_name: str = "aac",
mime_type: str = "audio/mp4", mime_type: str = "audio/mp4",
@ -465,7 +473,7 @@ def upload_audio_to_comfyapi(
Args: Args:
audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate) audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate)
auth_token: Optional authentication token. auth_kwargs: Optional authentication token(s).
Returns: Returns:
The download URL for the uploaded audio file. The download URL for the uploaded audio file.
@ -477,11 +485,11 @@ def upload_audio_to_comfyapi(
audio_data_np, sample_rate, container_format, codec_name audio_data_np, sample_rate, container_format, codec_name
) )
return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_token) return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
def upload_images_to_comfyapi( def upload_images_to_comfyapi(
image: torch.Tensor, max_images=8, auth_token=None, mime_type: Optional[str] = None image: torch.Tensor, max_images=8, auth_kwargs: Optional[dict[str,str]] = None, mime_type: Optional[str] = None
) -> list[str]: ) -> list[str]:
""" """
Uploads images to ComfyUI API and returns download URLs. Uploads images to ComfyUI API and returns download URLs.
@ -490,7 +498,7 @@ def upload_images_to_comfyapi(
Args: Args:
image: Input torch.Tensor image. image: Input torch.Tensor image.
max_images: Maximum number of images to upload. max_images: Maximum number of images to upload.
auth_token: Optional authentication token. auth_kwargs: Optional authentication token(s).
mime_type: Optional MIME type for the image. mime_type: Optional MIME type for the image.
""" """
# if batch, try to upload each file if max_images is greater than 0 # if batch, try to upload each file if max_images is greater than 0
@ -521,7 +529,7 @@ def upload_images_to_comfyapi(
response_model=UploadResponse, response_model=UploadResponse,
), ),
request=request_object, request=request_object,
auth_token=auth_token, auth_kwargs=auth_kwargs,
) )
response = operation.execute() response = operation.execute()

View File

@ -20,7 +20,8 @@ Usage Examples:
# 1. Create the API client # 1. Create the API client
api_client = ApiClient( api_client = ApiClient(
base_url="https://api.example.com", base_url="https://api.example.com",
api_key="your_api_key_here", auth_token="your_auth_token_here",
comfy_api_key="your_comfy_api_key_here",
timeout=30.0, timeout=30.0,
verify_ssl=True verify_ssl=True
) )
@ -93,15 +94,19 @@ from __future__ import annotations
import logging import logging
import time import time
import io import io
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable import socket
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
from enum import Enum from enum import Enum
import json import json
import requests import requests
from urllib.parse import urljoin from urllib.parse import urljoin, urlparse
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import uuid # For generating unique operation IDs
from server import PromptServer
from comfy.cli_args import args from comfy.cli_args import args
from comfy import utils from comfy import utils
from . import request_logger
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel)
R = TypeVar("R", bound=BaseModel) R = TypeVar("R", bound=BaseModel)
@ -110,6 +115,21 @@ P = TypeVar("P", bound=BaseModel) # For poll response
PROGRESS_BAR_MAX = 100 PROGRESS_BAR_MAX = 100
class NetworkError(Exception):
"""Base exception for network-related errors with diagnostic information."""
pass
class LocalNetworkError(NetworkError):
"""Exception raised when local network connectivity issues are detected."""
pass
class ApiServerError(NetworkError):
"""Exception raised when the API server is unreachable but internet is working."""
pass
class EmptyRequest(BaseModel): class EmptyRequest(BaseModel):
"""Base class for empty request bodies. """Base class for empty request bodies.
For GET requests, fields will be sent as query parameters.""" For GET requests, fields will be sent as query parameters."""
@ -140,20 +160,36 @@ class HttpMethod(str, Enum):
class ApiClient: class ApiClient:
""" """
Client for making HTTP requests to an API with authentication and error handling. Client for making HTTP requests to an API with authentication, error handling, and retry logic.
""" """
def __init__( def __init__(
self, self,
base_url: str, base_url: str,
api_key: Optional[str] = None, auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
timeout: float = 3600.0, timeout: float = 3600.0,
verify_ssl: bool = True, verify_ssl: bool = True,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
retry_status_codes: Optional[Tuple[int, ...]] = None,
): ):
self.base_url = base_url self.base_url = base_url
self.api_key = api_key self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
self.timeout = timeout self.timeout = timeout
self.verify_ssl = verify_ssl self.verify_ssl = verify_ssl
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
# Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests),
# 500, 502, 503, 504 (Server Errors)
self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504)
def _generate_operation_id(self, path: str) -> str:
"""Generates a unique operation ID for logging."""
return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}"
def _create_json_payload_args( def _create_json_payload_args(
self, self,
@ -201,11 +237,63 @@ class ApiClient:
"""Get headers for API requests, including authentication if available""" """Get headers for API requests, including authentication if available"""
headers = {"Content-Type": "application/json", "Accept": "application/json"} headers = {"Content-Type": "application/json", "Accept": "application/json"}
if self.api_key: if self.auth_token:
headers["Authorization"] = f"Bearer {self.api_key}" headers["Authorization"] = f"Bearer {self.auth_token}"
elif self.comfy_api_key:
headers["X-API-KEY"] = self.comfy_api_key
return headers return headers
def _check_connectivity(self, target_url: str) -> Dict[str, bool]:
"""
Check connectivity to determine if network issues are local or server-related.
Args:
target_url: URL to check connectivity to
Returns:
Dictionary with connectivity status details
"""
results = {
"internet_accessible": False,
"api_accessible": False,
"is_local_issue": False,
"is_api_issue": False
}
# First check basic internet connectivity using a reliable external site
try:
# Use a reliable external domain for checking basic connectivity
check_response = requests.get("https://www.google.com",
timeout=5.0,
verify=self.verify_ssl)
if check_response.status_code < 500:
results["internet_accessible"] = True
except (requests.RequestException, socket.error):
results["internet_accessible"] = False
results["is_local_issue"] = True
return results
# Now check API server connectivity
try:
# Extract domain from the target URL to do a simpler health check
parsed_url = urlparse(target_url)
api_base = f"{parsed_url.scheme}://{parsed_url.netloc}"
# Try to reach the API domain
api_response = requests.get(f"{api_base}/health", timeout=5.0, verify=self.verify_ssl)
if api_response.status_code < 500:
results["api_accessible"] = True
else:
results["api_accessible"] = False
results["is_api_issue"] = True
except requests.RequestException:
results["api_accessible"] = False
# If we can reach the internet but not the API, it's an API issue
results["is_api_issue"] = True
return results
def request( def request(
self, self,
method: str, method: str,
@ -216,9 +304,10 @@ class ApiClient:
headers: Optional[Dict[str, str]] = None, headers: Optional[Dict[str, str]] = None,
content_type: str = "application/json", content_type: str = "application/json",
multipart_parser: Callable = None, multipart_parser: Callable = None,
retry_count: int = 0, # Used internally for tracking retries
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Make an HTTP request to the API Make an HTTP request to the API with automatic retries for transient errors.
Args: Args:
method: HTTP method (GET, POST, etc.) method: HTTP method (GET, POST, etc.)
@ -228,15 +317,18 @@ class ApiClient:
files: Files to upload files: Files to upload
headers: Additional headers headers: Additional headers
content_type: Content type of the request. Defaults to application/json. content_type: Content type of the request. Defaults to application/json.
retry_count: Internal parameter for tracking retries, do not set manually
Returns: Returns:
Parsed JSON response Parsed JSON response
Raises: Raises:
requests.RequestException: If the request fails LocalNetworkError: If local network connectivity issues are detected
ApiServerError: If the API server is unreachable but internet is working
Exception: For other request failures
""" """
url = urljoin(self.base_url, path) url = urljoin(self.base_url, path)
self.check_auth_token(self.api_key) self.check_auth(self.auth_token, self.comfy_api_key)
# Combine default headers with any provided headers # Combine default headers with any provided headers
request_headers = self.get_headers() request_headers = self.get_headers()
if headers: if headers:
@ -260,6 +352,16 @@ class ApiClient:
else: else:
payload_args = self._create_json_payload_args(data, request_headers) payload_args = self._create_json_payload_args(data, request_headers)
operation_id = self._generate_operation_id(path)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=request_headers,
request_params=params,
request_data=data if content_type == "application/json" else "[form-data or other]"
)
try: try:
response = requests.request( response = requests.request(
method=method, method=method,
@ -270,87 +372,365 @@ class ApiClient:
**payload_args, **payload_args,
) )
# Raise exception for error status codes # Check if we should retry based on status code
response.raise_for_status() if (response.status_code in self.retry_status_codes and
except requests.ConnectionError: retry_count < self.max_retries):
raise Exception(
f"Unable to connect to the API server at {self.base_url}. Please check your internet connection or verify the service is available." # Calculate delay with exponential backoff
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning(
f"Request failed with status {response.status_code}. "
f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
) )
except requests.Timeout: time.sleep(delay)
raise Exception( return self.request(
f"Request timed out after {self.timeout} seconds. The server might be experiencing high load or the operation is taking longer than expected." method=method,
path=path,
params=params,
data=data,
files=files,
headers=headers,
content_type=content_type,
multipart_parser=multipart_parser,
retry_count=retry_count + 1,
) )
# Raise exception for error status codes
response.raise_for_status()
# Log successful response
response_content_to_log = response.content
try:
# Attempt to parse JSON for prettier logging, fallback to raw content
response_content_to_log = response.json()
except json.JSONDecodeError:
pass # Keep as bytes/str if not JSON
request_logger.log_request_response(
operation_id=operation_id,
request_method=method, # Pass request details again for context in log
request_url=url,
response_status_code=response.status_code,
response_headers=dict(response.headers),
response_content=response_content_to_log
)
except requests.ConnectionError as e:
error_message = f"ConnectionError: {str(e)}"
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
error_message=error_message
)
# Only perform connectivity check if we've exhausted all retries
if retry_count >= self.max_retries:
# Check connectivity to determine if it's a local or API issue
connectivity = self._check_connectivity(self.base_url)
if connectivity["is_local_issue"]:
raise LocalNetworkError(
"Unable to connect to the API server due to local network issues. "
"Please check your internet connection and try again."
) from e
elif connectivity["is_api_issue"]:
raise ApiServerError(
f"The API server at {self.base_url} is currently unreachable. "
f"The service may be experiencing issues. Please try again later."
) from e
# If we haven't exhausted retries yet, retry the request
if retry_count < self.max_retries:
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning(
f"Connection error: {str(e)}. "
f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
)
time.sleep(delay)
return self.request(
method=method,
path=path,
params=params,
data=data,
files=files,
headers=headers,
content_type=content_type,
multipart_parser=multipart_parser,
retry_count=retry_count + 1,
)
# If we've exhausted retries and didn't identify the specific issue,
# raise a generic exception
final_error_message = (
f"Unable to connect to the API server after {self.max_retries} attempts. "
f"Please check your internet connection or try again later."
)
request_logger.log_request_response( # Log final failure
operation_id=operation_id,
request_method=method, request_url=url,
error_message=final_error_message
)
raise Exception(final_error_message) from e
except requests.Timeout as e:
error_message = f"Timeout: {str(e)}"
request_logger.log_request_response(
operation_id=operation_id,
request_method=method, request_url=url,
error_message=error_message
)
# Retry timeouts if we haven't exhausted retries
if retry_count < self.max_retries:
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning(
f"Request timed out. "
f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
)
time.sleep(delay)
return self.request(
method=method,
path=path,
params=params,
data=data,
files=files,
headers=headers,
content_type=content_type,
multipart_parser=multipart_parser,
retry_count=retry_count + 1,
)
final_error_message = (
f"Request timed out after {self.timeout} seconds and {self.max_retries} retry attempts. "
f"The server might be experiencing high load or the operation is taking longer than expected."
)
request_logger.log_request_response( # Log final failure
operation_id=operation_id,
request_method=method, request_url=url,
error_message=final_error_message
)
raise Exception(final_error_message) from e
except requests.HTTPError as e: except requests.HTTPError as e:
status_code = e.response.status_code if hasattr(e, "response") else None status_code = e.response.status_code if hasattr(e, "response") else None
error_message = f"HTTP Error: {str(e)}" original_error_message = f"HTTP Error: {str(e)}"
error_content_for_log = None
# Try to extract detailed error message from JSON response if hasattr(e, "response") and e.response is not None:
error_content_for_log = e.response.content
try: try:
if hasattr(e, "response") and e.response.content: error_content_for_log = e.response.json()
except json.JSONDecodeError:
pass
# Try to extract detailed error message from JSON response for user display
# but log the full error content.
user_display_error_message = original_error_message
try:
if hasattr(e, "response") and e.response is not None and e.response.content:
error_json = e.response.json() error_json = e.response.json()
if "error" in error_json and "message" in error_json["error"]: if "error" in error_json and "message" in error_json["error"]:
error_message = f"API Error: {error_json['error']['message']}" user_display_error_message = f"API Error: {error_json['error']['message']}"
if "type" in error_json["error"]: if "type" in error_json["error"]:
error_message += f" (Type: {error_json['error']['type']})" user_display_error_message += f" (Type: {error_json['error']['type']})"
elif isinstance(error_json, dict): # Handle cases where error is just a JSON dict
user_display_error_message = f"API Error: {json.dumps(error_json)}"
else: # Non-dict JSON error
user_display_error_message = f"API Error: {str(error_json)}"
except json.JSONDecodeError:
# If not JSON, use the raw content if it's not too long, or a summary
if hasattr(e, "response") and e.response is not None and e.response.content:
raw_content = e.response.content.decode(errors='ignore')
if len(raw_content) < 200: # Arbitrary limit for display
user_display_error_message = f"API Error (raw): {raw_content}"
else: else:
error_message = f"API Error: {error_json}" user_display_error_message = f"API Error (raw, status {status_code})"
except Exception as json_error:
# If we can't parse the JSON, fall back to the original error message request_logger.log_request_response(
logging.debug( operation_id=operation_id,
f"[DEBUG] Failed to parse error response: {str(json_error)}" request_method=method, request_url=url,
response_status_code=status_code,
response_headers=dict(e.response.headers) if hasattr(e, "response") and e.response is not None else None,
response_content=error_content_for_log,
error_message=original_error_message # Log the original exception string as error
) )
logging.debug(f"[DEBUG] API Error: {error_message} (Status: {status_code})") logging.debug(f"[DEBUG] API Error: {user_display_error_message} (Status: {status_code})")
if hasattr(e, "response") and e.response.content: if hasattr(e, "response") and e.response is not None and e.response.content:
logging.debug(f"[DEBUG] Response content: {e.response.content}") logging.debug(f"[DEBUG] Response content: {e.response.content}")
# Retry if the status code is in our retry list and we haven't exhausted retries
if (status_code in self.retry_status_codes and
retry_count < self.max_retries):
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning(
f"HTTP error {status_code}. "
f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
)
time.sleep(delay)
return self.request(
method=method,
path=path,
params=params,
data=data,
files=files,
headers=headers,
content_type=content_type,
multipart_parser=multipart_parser,
retry_count=retry_count + 1,
)
# Specific error messages for common status codes for user display
if status_code == 401: if status_code == 401:
error_message = "Unauthorized: Please login first to use this node." user_display_error_message = "Unauthorized: Please login first to use this node."
if status_code == 402: elif status_code == 402:
error_message = "Payment Required: Please add credits to your account to use this node." user_display_error_message = "Payment Required: Please add credits to your account to use this node."
if status_code == 409: elif status_code == 409:
error_message = "There is a problem with your account. Please contact support@comfy.org. " user_display_error_message = "There is a problem with your account. Please contact support@comfy.org."
if status_code == 429: elif status_code == 429:
error_message = "Rate Limit Exceeded: Please try again later." user_display_error_message = "Rate Limit Exceeded: Please try again later."
raise Exception(error_message) # else, user_display_error_message remains as parsed from response or original HTTPError string
raise Exception(user_display_error_message) # Raise with the user-friendly message
# Parse and return JSON response # Parse and return JSON response
if response.content: if response.content:
return response.json() return response.json()
return {} return {}
def check_auth_token(self, auth_token): def check_auth(self, auth_token, comfy_api_key):
"""Verify that an auth token is present.""" """Verify that an auth token is present or comfy_api_key is present"""
if auth_token is None: if auth_token is None and comfy_api_key is None:
raise Exception("Unauthorized: Please login first to use this node.") raise Exception("Unauthorized: Please login first to use this node.")
return auth_token return auth_token or comfy_api_key
@staticmethod @staticmethod
def upload_file( def upload_file(
upload_url: str, upload_url: str,
file: io.BytesIO | str, file: io.BytesIO | str,
content_type: str | None = None, content_type: str | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
): ):
"""Upload a file to the API. Make sure the file has a filename equal to what the url expects. """Upload a file to the API with retry logic.
Args: Args:
upload_url: The URL to upload to upload_url: The URL to upload to
file: Either a file path string, BytesIO object, or tuple of (file_path, filename) file: Either a file path string, BytesIO object, or tuple of (file_path, filename)
mime_type: Optional mime type to set for the upload content_type: Optional mime type to set for the upload
max_retries: Maximum number of retry attempts
retry_delay: Initial delay between retries in seconds
retry_backoff_factor: Multiplier for the delay after each retry
""" """
headers = {} headers = {}
if content_type: if content_type:
headers["Content-Type"] = content_type headers["Content-Type"] = content_type
# Prepare the file data
if isinstance(file, io.BytesIO): if isinstance(file, io.BytesIO):
file.seek(0) # Ensure we're at the start of the file file.seek(0) # Ensure we're at the start of the file
data = file.read() data = file.read()
return requests.put(upload_url, data=data, headers=headers)
elif isinstance(file, str): elif isinstance(file, str):
with open(file, "rb") as f: with open(file, "rb") as f:
data = f.read() data = f.read()
return requests.put(upload_url, data=data, headers=headers) else:
raise ValueError("File must be either a BytesIO object or a file path string")
# Try the upload with retries
last_exception = None
operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}" # Simplified ID for uploads
# Log initial attempt (without full file data for brevity)
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers,
request_data=f"[File data of type {content_type or 'unknown'}, size {len(data)} bytes]"
)
for retry_attempt in range(max_retries + 1):
try:
response = requests.put(upload_url, data=data, headers=headers)
response.raise_for_status()
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT", request_url=upload_url, # For context
response_status_code=response.status_code,
response_headers=dict(response.headers),
response_content="File uploaded successfully." # Or response.text if available
)
return response
except (requests.ConnectionError, requests.Timeout, requests.HTTPError) as e:
last_exception = e
error_message_for_log = f"{type(e).__name__}: {str(e)}"
response_content_for_log = None
status_code_for_log = None
headers_for_log = None
if hasattr(e, 'response') and e.response is not None:
status_code_for_log = e.response.status_code
headers_for_log = dict(e.response.headers)
try:
response_content_for_log = e.response.json()
except json.JSONDecodeError:
response_content_for_log = e.response.content
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT", request_url=upload_url,
response_status_code=status_code_for_log,
response_headers=headers_for_log,
response_content=response_content_for_log,
error_message=error_message_for_log
)
if retry_attempt < max_retries:
delay = retry_delay * (retry_backoff_factor ** retry_attempt)
logging.warning(
f"File upload failed: {str(e)}. "
f"Retrying in {delay:.2f}s ({retry_attempt + 1}/{max_retries})"
)
time.sleep(delay)
else:
break # Max retries reached
# If we've exhausted all retries, determine the final error type and raise
final_error_message = f"Failed to upload file after {max_retries + 1} attempts. Error: {str(last_exception)}"
try:
# Check basic internet connectivity
check_response = requests.get("https://www.google.com", timeout=5.0, verify=True) # Assuming verify=True is desired
if check_response.status_code >= 500: # Google itself has an issue (rare)
final_error_message = (f"Failed to upload file. Internet connectivity check to Google failed "
f"(status {check_response.status_code}). Original error: {str(last_exception)}")
# Not raising LocalNetworkError here as Google itself might be down.
# If Google is reachable, the issue is likely with the upload server or a more specific local problem
# not caught by a simple Google ping (e.g., DNS for the specific upload URL, firewall).
# The original last_exception is probably most relevant.
except (requests.RequestException, socket.error) as conn_check_exc:
# Could not reach Google, likely a local network issue
final_error_message = (f"Failed to upload file due to network connectivity issues "
f"(cannot reach Google: {str(conn_check_exc)}). "
f"Original upload error: {str(last_exception)}")
request_logger.log_request_response( # Log final failure reason
operation_id=operation_id,
request_method="PUT", request_url=upload_url,
error_message=final_error_message
)
raise LocalNetworkError(final_error_message) from last_exception
request_logger.log_request_response( # Log final failure reason if not LocalNetworkError
operation_id=operation_id,
request_method="PUT", request_url=upload_url,
error_message=final_error_message
)
raise Exception(final_error_message) from last_exception
class ApiEndpoint(Generic[T, R]): class ApiEndpoint(Generic[T, R]):
@ -392,10 +772,15 @@ class SynchronousOperation(Generic[T, R]):
files: Optional[Dict[str, Any]] = None, files: Optional[Dict[str, Any]] = None,
api_base: str | None = None, api_base: str | None = None,
auth_token: Optional[str] = None, auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str,str]] = None,
timeout: float = 604800.0, timeout: float = 604800.0,
verify_ssl: bool = True, verify_ssl: bool = True,
content_type: str = "application/json", content_type: str = "application/json",
multipart_parser: Callable = None, multipart_parser: Callable = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
): ):
self.endpoint = endpoint self.endpoint = endpoint
self.request = request self.request = request
@ -403,21 +788,33 @@ class SynchronousOperation(Generic[T, R]):
self.error = None self.error = None
self.api_base: str = api_base or args.comfy_api_base self.api_base: str = api_base or args.comfy_api_base
self.auth_token = auth_token self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
if auth_kwargs is not None:
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
self.timeout = timeout self.timeout = timeout
self.verify_ssl = verify_ssl self.verify_ssl = verify_ssl
self.files = files self.files = files
self.content_type = content_type self.content_type = content_type
self.multipart_parser = multipart_parser self.multipart_parser = multipart_parser
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
def execute(self, client: Optional[ApiClient] = None) -> R: def execute(self, client: Optional[ApiClient] = None) -> R:
"""Execute the API operation using the provided client or create one""" """Execute the API operation using the provided client or create one with retry support"""
try: try:
# Create client if not provided # Create client if not provided
if client is None: if client is None:
client = ApiClient( client = ApiClient(
base_url=self.api_base, base_url=self.api_base,
api_key=self.auth_token, auth_token=self.auth_token,
comfy_api_key=self.comfy_api_key,
timeout=self.timeout, timeout=self.timeout,
verify_ssl=self.verify_ssl, verify_ssl=self.verify_ssl,
max_retries=self.max_retries,
retry_delay=self.retry_delay,
retry_backoff_factor=self.retry_backoff_factor,
) )
# Convert request model to dict, but use None for EmptyRequest # Convert request model to dict, but use None for EmptyRequest
@ -431,11 +828,6 @@ class SynchronousOperation(Generic[T, R]):
if isinstance(value, Enum): if isinstance(value, Enum):
request_dict[key] = value.value request_dict[key] = value.value
if request_dict:
for key, value in request_dict.items():
if isinstance(value, Enum):
request_dict[key] = value.value
# Debug log for request # Debug log for request
logging.debug( logging.debug(
f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}" f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
@ -443,7 +835,7 @@ class SynchronousOperation(Generic[T, R]):
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}") logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}") logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
# Make the request # Make the request with built-in retry
resp = client.request( resp = client.request(
method=self.endpoint.method.value, method=self.endpoint.method.value,
path=self.endpoint.path, path=self.endpoint.path,
@ -464,8 +856,18 @@ class SynchronousOperation(Generic[T, R]):
# Parse and return the response # Parse and return the response
return self._parse_response(resp) return self._parse_response(resp)
except LocalNetworkError as e:
# Propagate specific network error types
logging.error(f"[ERROR] Local network error: {str(e)}")
raise
except ApiServerError as e:
# Propagate API server errors
logging.error(f"[ERROR] API server error: {str(e)}")
raise
except Exception as e: except Exception as e:
logging.error(f"[DEBUG] API Exception: {str(e)}") logging.error(f"[ERROR] API Exception: {str(e)}")
raise Exception(str(e)) raise Exception(str(e))
def _parse_response(self, resp): def _parse_response(self, resp):
@ -499,22 +901,42 @@ class PollingOperation(Generic[T, R]):
failed_statuses: list, failed_statuses: list,
status_extractor: Callable[[R], str], status_extractor: Callable[[R], str],
progress_extractor: Callable[[R], float] = None, progress_extractor: Callable[[R], float] = None,
result_url_extractor: Callable[[R], str] = None,
request: Optional[T] = None, request: Optional[T] = None,
api_base: str | None = None, api_base: str | None = None,
auth_token: Optional[str] = None, auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str,str]] = None,
poll_interval: float = 5.0, poll_interval: float = 5.0,
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
max_retries: int = 3, # Max retries per individual API call
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
estimated_duration: Optional[float] = None,
node_id: Optional[str] = None,
): ):
self.poll_endpoint = poll_endpoint self.poll_endpoint = poll_endpoint
self.request = request self.request = request
self.api_base: str = api_base or args.comfy_api_base self.api_base: str = api_base or args.comfy_api_base
self.auth_token = auth_token self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
if auth_kwargs is not None:
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
self.poll_interval = poll_interval self.poll_interval = poll_interval
self.max_poll_attempts = max_poll_attempts
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
self.estimated_duration = estimated_duration
# Polling configuration # Polling configuration
self.status_extractor = status_extractor or ( self.status_extractor = status_extractor or (
lambda x: getattr(x, "status", None) lambda x: getattr(x, "status", None)
) )
self.progress_extractor = progress_extractor self.progress_extractor = progress_extractor
self.result_url_extractor = result_url_extractor
self.node_id = node_id
self.completed_statuses = completed_statuses self.completed_statuses = completed_statuses
self.failed_statuses = failed_statuses self.failed_statuses = failed_statuses
@ -528,12 +950,48 @@ class PollingOperation(Generic[T, R]):
if client is None: if client is None:
client = ApiClient( client = ApiClient(
base_url=self.api_base, base_url=self.api_base,
api_key=self.auth_token, auth_token=self.auth_token,
comfy_api_key=self.comfy_api_key,
max_retries=self.max_retries,
retry_delay=self.retry_delay,
retry_backoff_factor=self.retry_backoff_factor,
) )
return self._poll_until_complete(client) return self._poll_until_complete(client)
except LocalNetworkError as e:
# Provide clear message for local network issues
raise Exception(
f"Polling failed due to local network issues. Please check your internet connection. "
f"Details: {str(e)}"
) from e
except ApiServerError as e:
# Provide clear message for API server issues
raise Exception(
f"Polling failed due to API server issues. The service may be experiencing problems. "
f"Please try again later. Details: {str(e)}"
) from e
except Exception as e: except Exception as e:
raise Exception(f"Error during polling: {str(e)}") raise Exception(f"Error during polling: {str(e)}")
def _display_text_on_node(self, text: str):
"""Sends text to the client which will be displayed on the node in the UI"""
if not self.node_id:
return
PromptServer.instance.send_progress_text(text, self.node_id)
def _display_time_progress_on_node(self, time_completed: int):
if not self.node_id:
return
if self.estimated_duration is not None:
estimated_time_remaining = max(
0, int(self.estimated_duration) - int(time_completed)
)
message = f"Task in progress: {time_completed:.0f}s (~{estimated_time_remaining:.0f}s remaining)"
else:
message = f"Task in progress: {time_completed:.0f}s"
self._display_text_on_node(message)
def _check_task_status(self, response: R) -> TaskStatus: def _check_task_status(self, response: R) -> TaskStatus:
"""Check task status using the status extractor function""" """Check task status using the status extractor function"""
try: try:
@ -550,10 +1008,13 @@ class PollingOperation(Generic[T, R]):
def _poll_until_complete(self, client: ApiClient) -> R: def _poll_until_complete(self, client: ApiClient) -> R:
"""Poll until the task is complete""" """Poll until the task is complete"""
poll_count = 0 poll_count = 0
consecutive_errors = 0
max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors
if self.progress_extractor: if self.progress_extractor:
progress = utils.ProgressBar(PROGRESS_BAR_MAX) progress = utils.ProgressBar(PROGRESS_BAR_MAX)
while True: while poll_count < self.max_poll_attempts:
try: try:
poll_count += 1 poll_count += 1
logging.debug(f"[DEBUG] Polling attempt #{poll_count}") logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
@ -580,8 +1041,12 @@ class PollingOperation(Generic[T, R]):
data=request_dict, data=request_dict,
) )
# Successfully got a response, reset consecutive error count
consecutive_errors = 0
# Parse response # Parse response
response_obj = self.poll_endpoint.response_model.model_validate(resp) response_obj = self.poll_endpoint.response_model.model_validate(resp)
# Check if task is complete # Check if task is complete
status = self._check_task_status(response_obj) status = self._check_task_status(response_obj)
logging.debug(f"[DEBUG] Task Status: {status}") logging.debug(f"[DEBUG] Task Status: {status}")
@ -593,7 +1058,15 @@ class PollingOperation(Generic[T, R]):
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX) progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
if status == TaskStatus.COMPLETED: if status == TaskStatus.COMPLETED:
logging.debug("[DEBUG] Task completed successfully") message = "Task completed successfully"
if self.result_url_extractor:
result_url = self.result_url_extractor(response_obj)
if result_url:
message = f"Result URL: {result_url}"
else:
message = "Task completed successfully!"
logging.debug(f"[DEBUG] {message}")
self._display_text_on_node(message)
self.final_response = response_obj self.final_response = response_obj
if self.progress_extractor: if self.progress_extractor:
progress.update(100) progress.update(100)
@ -609,8 +1082,43 @@ class PollingOperation(Generic[T, R]):
logging.debug( logging.debug(
f"[DEBUG] Waiting {self.poll_interval} seconds before next poll" f"[DEBUG] Waiting {self.poll_interval} seconds before next poll"
) )
for i in range(int(self.poll_interval)):
time_completed = (poll_count * self.poll_interval) + i
self._display_time_progress_on_node(time_completed)
time.sleep(1)
except (LocalNetworkError, ApiServerError) as e:
# For network-related errors, increment error count and potentially abort
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
raise Exception(
f"Polling aborted after {consecutive_errors} consecutive network errors: {str(e)}"
) from e
# Log the error but continue polling
logging.warning(
f"Network error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
f"Will retry in {self.poll_interval} seconds."
)
time.sleep(self.poll_interval) time.sleep(self.poll_interval)
except Exception as e: except Exception as e:
# For other errors, increment count and potentially abort
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED:
raise Exception(
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
) from e
logging.error(f"[DEBUG] Polling error: {str(e)}") logging.error(f"[DEBUG] Polling error: {str(e)}")
raise Exception(f"Error while polling: {str(e)}") logging.warning(
f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
f"Will retry in {self.poll_interval} seconds."
)
time.sleep(self.poll_interval)
# If we've exhausted all polling attempts
raise Exception(
f"Polling timed out after {poll_count} attempts ({poll_count * self.poll_interval} seconds). "
f"The operation may still be running on the server but is taking longer than expected."
)

View File

@ -0,0 +1,125 @@
import os
import datetime
import json
import logging
import folder_paths
# Get the logger instance
logger = logging.getLogger(__name__)
def get_log_directory():
"""
Ensures the API log directory exists within ComfyUI's temp directory
and returns its path.
"""
base_temp_dir = folder_paths.get_temp_directory()
log_dir = os.path.join(base_temp_dir, "api_logs")
try:
os.makedirs(log_dir, exist_ok=True)
except Exception as e:
logger.error(f"Error creating API log directory {log_dir}: {e}")
# Fallback to base temp directory if sub-directory creation fails
return base_temp_dir
return log_dir
def _format_data_for_logging(data):
"""Helper to format data (dict, str, bytes) for logging."""
if isinstance(data, bytes):
try:
return data.decode('utf-8') # Try to decode as text
except UnicodeDecodeError:
return f"[Binary data of length {len(data)} bytes]"
elif isinstance(data, (dict, list)):
try:
return json.dumps(data, indent=2, ensure_ascii=False)
except TypeError:
return str(data) # Fallback for non-serializable objects
return str(data)
def log_request_response(
operation_id: str,
request_method: str,
request_url: str,
request_headers: dict | None = None,
request_params: dict | None = None,
request_data: any = None,
response_status_code: int | None = None,
response_headers: dict | None = None,
response_content: any = None,
error_message: str | None = None
):
"""
Logs API request and response details to a file in the temp/api_logs directory.
"""
log_dir = get_log_directory()
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"{timestamp}_{operation_id.replace('/', '_').replace(':', '_')}.log"
filepath = os.path.join(log_dir, filename)
log_content = []
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
log_content.append(f"Operation ID: {operation_id}")
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
log_content.append(f"Method: {request_method}")
log_content.append(f"URL: {request_url}")
if request_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
if request_params:
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
if request_data:
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
if response_status_code is not None:
log_content.append(f"Status Code: {response_status_code}")
if response_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
if response_content:
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
if error_message:
log_content.append(f"Error:\n{error_message}")
try:
with open(filepath, "w", encoding="utf-8") as f:
f.write("\n".join(log_content))
logger.debug(f"API log saved to: {filepath}")
except Exception as e:
logger.error(f"Error writing API log to {filepath}: {e}")
if __name__ == '__main__':
# Example usage (for testing the logger directly)
logger.setLevel(logging.DEBUG)
# Mock folder_paths for direct execution if not running within ComfyUI full context
if not hasattr(folder_paths, 'get_temp_directory'):
class MockFolderPaths:
def get_temp_directory(self):
# Create a local temp dir for testing if needed
p = os.path.join(os.path.dirname(__file__), 'temp_test_logs')
os.makedirs(p, exist_ok=True)
return p
folder_paths = MockFolderPaths()
log_request_response(
operation_id="test_operation_get",
request_method="GET",
request_url="https://api.example.com/test",
request_headers={"Authorization": "Bearer testtoken"},
request_params={"param1": "value1"},
response_status_code=200,
response_content={"message": "Success!"}
)
log_request_response(
operation_id="test_operation_post_error",
request_method="POST",
request_url="https://api.example.com/submit",
request_data={"key": "value", "nested": {"num": 123}},
error_message="Connection timed out"
)
log_request_response(
operation_id="test_binary_response",
request_method="GET",
request_url="https://api.example.com/image.png",
response_status_code=200,
response_content=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR...' # Sample binary data
)

View File

@ -1,5 +1,6 @@
import io import io
from inspect import cleandoc from inspect import cleandoc
from typing import Union
from comfy.comfy_types.node_typing import IO, ComfyNodeABC from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api_nodes.apis.bfl_api import ( from comfy_api_nodes.apis.bfl_api import (
BFLStatus, BFLStatus,
@ -30,6 +31,7 @@ import requests
import torch import torch
import base64 import base64
import time import time
from server import PromptServer
def convert_mask_to_image(mask: torch.Tensor): def convert_mask_to_image(mask: torch.Tensor):
@ -42,14 +44,19 @@ def convert_mask_to_image(mask: torch.Tensor):
def handle_bfl_synchronous_operation( def handle_bfl_synchronous_operation(
operation: SynchronousOperation, timeout_bfl_calls=360 operation: SynchronousOperation,
timeout_bfl_calls=360,
node_id: Union[str, None] = None,
): ):
response_api: BFLFluxProGenerateResponse = operation.execute() response_api: BFLFluxProGenerateResponse = operation.execute()
return _poll_until_generated( return _poll_until_generated(
response_api.polling_url, timeout=timeout_bfl_calls response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id
) )
def _poll_until_generated(polling_url: str, timeout=360):
def _poll_until_generated(
polling_url: str, timeout=360, node_id: Union[str, None] = None
):
# used bfl-comfy-nodes to verify code implementation: # used bfl-comfy-nodes to verify code implementation:
# https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main # https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main
start_time = time.time() start_time = time.time()
@ -61,11 +68,21 @@ def _poll_until_generated(polling_url: str, timeout=360):
request = requests.Request(method=HttpMethod.GET, url=polling_url) request = requests.Request(method=HttpMethod.GET, url=polling_url)
# NOTE: should True loop be replaced with checking if workflow has been interrupted? # NOTE: should True loop be replaced with checking if workflow has been interrupted?
while True: while True:
if node_id:
time_elapsed = time.time() - start_time
PromptServer.instance.send_progress_text(
f"Generating ({time_elapsed:.0f}s)", node_id
)
response = requests.Session().send(request.prepare()) response = requests.Session().send(request.prepare())
if response.status_code == 200: if response.status_code == 200:
result = response.json() result = response.json()
if result["status"] == BFLStatus.ready: if result["status"] == BFLStatus.ready:
img_url = result["result"]["sample"] img_url = result["result"]["sample"]
if node_id:
PromptServer.instance.send_progress_text(
f"Result URL: {img_url}", node_id
)
img_response = requests.get(img_url) img_response = requests.get(img_url)
return process_image_response(img_response) return process_image_response(img_response)
elif result["status"] in [ elif result["status"] in [
@ -179,6 +196,8 @@ class FluxProUltraImageNode(ComfyNodeABC):
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -211,7 +230,7 @@ class FluxProUltraImageNode(ComfyNodeABC):
seed=0, seed=0,
image_prompt=None, image_prompt=None,
image_prompt_strength=0.1, image_prompt_strength=0.1,
auth_token=None, unique_id: Union[str, None] = None,
**kwargs, **kwargs,
): ):
if image_prompt is None: if image_prompt is None:
@ -244,9 +263,9 @@ class FluxProUltraImageNode(ComfyNodeABC):
None if image_prompt is None else round(image_prompt_strength, 2) None if image_prompt is None else round(image_prompt_strength, 2)
), ),
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
output_image = handle_bfl_synchronous_operation(operation) output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,) return (output_image,)
@ -319,6 +338,8 @@ class FluxProImageNode(ComfyNodeABC):
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -337,7 +358,7 @@ class FluxProImageNode(ComfyNodeABC):
seed=0, seed=0,
image_prompt=None, image_prompt=None,
# image_prompt_strength=0.1, # image_prompt_strength=0.1,
auth_token=None, unique_id: Union[str, None] = None,
**kwargs, **kwargs,
): ):
image_prompt = ( image_prompt = (
@ -361,9 +382,9 @@ class FluxProImageNode(ComfyNodeABC):
seed=seed, seed=seed,
image_prompt=image_prompt, image_prompt=image_prompt,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
output_image = handle_bfl_synchronous_operation(operation) output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,) return (output_image,)
@ -457,10 +478,11 @@ class FluxProExpandNode(ComfyNodeABC):
}, },
), ),
}, },
"optional": { "optional": {},
},
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -482,7 +504,7 @@ class FluxProExpandNode(ComfyNodeABC):
steps: int, steps: int,
guidance: float, guidance: float,
seed=0, seed=0,
auth_token=None, unique_id: Union[str, None] = None,
**kwargs, **kwargs,
): ):
image = convert_image_to_base64(image) image = convert_image_to_base64(image)
@ -506,9 +528,9 @@ class FluxProExpandNode(ComfyNodeABC):
seed=seed, seed=seed,
image=image, image=image,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
output_image = handle_bfl_synchronous_operation(operation) output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,) return (output_image,)
@ -568,10 +590,11 @@ class FluxProFillNode(ComfyNodeABC):
}, },
), ),
}, },
"optional": { "optional": {},
},
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -590,14 +613,14 @@ class FluxProFillNode(ComfyNodeABC):
steps: int, steps: int,
guidance: float, guidance: float,
seed=0, seed=0,
auth_token=None, unique_id: Union[str, None] = None,
**kwargs, **kwargs,
): ):
# prepare mask # prepare mask
mask = resize_mask_to_image(mask, image) mask = resize_mask_to_image(mask, image)
mask = convert_image_to_base64(convert_mask_to_image(mask)) mask = convert_image_to_base64(convert_mask_to_image(mask))
# make sure image will have alpha channel removed # make sure image will have alpha channel removed
image = convert_image_to_base64(image[:,:,:,:3]) image = convert_image_to_base64(image[:, :, :, :3])
operation = SynchronousOperation( operation = SynchronousOperation(
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
@ -615,9 +638,9 @@ class FluxProFillNode(ComfyNodeABC):
image=image, image=image,
mask=mask, mask=mask,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
output_image = handle_bfl_synchronous_operation(operation) output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,) return (output_image,)
@ -702,10 +725,11 @@ class FluxProCannyNode(ComfyNodeABC):
}, },
), ),
}, },
"optional": { "optional": {},
},
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -726,10 +750,10 @@ class FluxProCannyNode(ComfyNodeABC):
steps: int, steps: int,
guidance: float, guidance: float,
seed=0, seed=0,
auth_token=None, unique_id: Union[str, None] = None,
**kwargs, **kwargs,
): ):
control_image = convert_image_to_base64(control_image[:,:,:,:3]) control_image = convert_image_to_base64(control_image[:, :, :, :3])
preprocessed_image = None preprocessed_image = None
# scale canny threshold between 0-500, to match BFL's API # scale canny threshold between 0-500, to match BFL's API
@ -763,9 +787,9 @@ class FluxProCannyNode(ComfyNodeABC):
canny_high_threshold=canny_high_threshold, canny_high_threshold=canny_high_threshold,
preprocessed_image=preprocessed_image, preprocessed_image=preprocessed_image,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
output_image = handle_bfl_synchronous_operation(operation) output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,) return (output_image,)
@ -830,10 +854,11 @@ class FluxProDepthNode(ComfyNodeABC):
}, },
), ),
}, },
"optional": { "optional": {},
},
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -852,7 +877,7 @@ class FluxProDepthNode(ComfyNodeABC):
steps: int, steps: int,
guidance: float, guidance: float,
seed=0, seed=0,
auth_token=None, unique_id: Union[str, None] = None,
**kwargs, **kwargs,
): ):
control_image = convert_image_to_base64(control_image[:,:,:,:3]) control_image = convert_image_to_base64(control_image[:,:,:,:3])
@ -878,9 +903,9 @@ class FluxProDepthNode(ComfyNodeABC):
control_image=control_image, control_image=control_image,
preprocessed_image=preprocessed_image, preprocessed_image=preprocessed_image,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
output_image = handle_bfl_synchronous_operation(operation) output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,) return (output_image,)

View File

@ -23,6 +23,7 @@ from comfy_api_nodes.apinode_utils import (
bytesio_to_image_tensor, bytesio_to_image_tensor,
resize_mask_to_image, resize_mask_to_image,
) )
from server import PromptServer
V1_V1_RES_MAP = { V1_V1_RES_MAP = {
"Auto":"AUTO", "Auto":"AUTO",
@ -232,6 +233,19 @@ def download_and_process_images(image_urls):
return stacked_tensors return stacked_tensors
def display_image_urls_on_node(image_urls, node_id):
if node_id and image_urls:
if len(image_urls) == 1:
PromptServer.instance.send_progress_text(
f"Generated Image URL:\n{image_urls[0]}", node_id
)
else:
urls_text = "Generated Image URLs:\n" + "\n".join(
f"{i+1}. {url}" for i, url in enumerate(image_urls)
)
PromptServer.instance.send_progress_text(urls_text, node_id)
class IdeogramV1(ComfyNodeABC): class IdeogramV1(ComfyNodeABC):
""" """
Generates images using the Ideogram V1 model. Generates images using the Ideogram V1 model.
@ -301,7 +315,11 @@ class IdeogramV1(ComfyNodeABC):
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"}, {"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
), ),
}, },
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
} }
RETURN_TYPES = (IO.IMAGE,) RETURN_TYPES = (IO.IMAGE,)
@ -319,7 +337,8 @@ class IdeogramV1(ComfyNodeABC):
seed=0, seed=0,
negative_prompt="", negative_prompt="",
num_images=1, num_images=1,
auth_token=None, unique_id=None,
**kwargs,
): ):
# Determine the model based on turbo setting # Determine the model based on turbo setting
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None) aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
@ -345,7 +364,7 @@ class IdeogramV1(ComfyNodeABC):
negative_prompt=negative_prompt if negative_prompt else None, negative_prompt=negative_prompt if negative_prompt else None,
) )
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
response = operation.execute() response = operation.execute()
@ -358,6 +377,7 @@ class IdeogramV1(ComfyNodeABC):
if not image_urls: if not image_urls:
raise Exception("No image URLs were generated in the response") raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, unique_id)
return (download_and_process_images(image_urls),) return (download_and_process_images(image_urls),)
@ -454,7 +474,11 @@ class IdeogramV2(ComfyNodeABC):
# }, # },
#), #),
}, },
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
} }
RETURN_TYPES = (IO.IMAGE,) RETURN_TYPES = (IO.IMAGE,)
@ -475,7 +499,8 @@ class IdeogramV2(ComfyNodeABC):
negative_prompt="", negative_prompt="",
num_images=1, num_images=1,
color_palette="", color_palette="",
auth_token=None, unique_id=None,
**kwargs,
): ):
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None) aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
resolution = V1_V1_RES_MAP.get(resolution, None) resolution = V1_V1_RES_MAP.get(resolution, None)
@ -515,7 +540,7 @@ class IdeogramV2(ComfyNodeABC):
color_palette=color_palette if color_palette else None, color_palette=color_palette if color_palette else None,
) )
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
response = operation.execute() response = operation.execute()
@ -528,6 +553,7 @@ class IdeogramV2(ComfyNodeABC):
if not image_urls: if not image_urls:
raise Exception("No image URLs were generated in the response") raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, unique_id)
return (download_and_process_images(image_urls),) return (download_and_process_images(image_urls),)
class IdeogramV3(ComfyNodeABC): class IdeogramV3(ComfyNodeABC):
@ -614,7 +640,11 @@ class IdeogramV3(ComfyNodeABC):
}, },
), ),
}, },
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
} }
RETURN_TYPES = (IO.IMAGE,) RETURN_TYPES = (IO.IMAGE,)
@ -634,7 +664,8 @@ class IdeogramV3(ComfyNodeABC):
seed=0, seed=0,
num_images=1, num_images=1,
rendering_speed="BALANCED", rendering_speed="BALANCED",
auth_token=None, unique_id=None,
**kwargs,
): ):
# Check if both image and mask are provided for editing mode # Check if both image and mask are provided for editing mode
if image is not None and mask is not None: if image is not None and mask is not None:
@ -698,7 +729,7 @@ class IdeogramV3(ComfyNodeABC):
"mask": mask_binary, "mask": mask_binary,
}, },
content_type="multipart/form-data", content_type="multipart/form-data",
auth_token=auth_token, auth_kwargs=kwargs,
) )
elif image is not None or mask is not None: elif image is not None or mask is not None:
@ -739,7 +770,7 @@ class IdeogramV3(ComfyNodeABC):
response_model=IdeogramGenerateResponse, response_model=IdeogramGenerateResponse,
), ),
request=gen_request, request=gen_request,
auth_token=auth_token, auth_kwargs=kwargs,
) )
# Execute the operation and process response # Execute the operation and process response
@ -753,6 +784,7 @@ class IdeogramV3(ComfyNodeABC):
if not image_urls: if not image_urls:
raise Exception("No image URLs were generated in the response") raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, unique_id)
return (download_and_process_images(image_urls),) return (download_and_process_images(image_urls),)
@ -767,4 +799,3 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"IdeogramV2": "Ideogram V2", "IdeogramV2": "Ideogram V2",
"IdeogramV3": "Ideogram V3", "IdeogramV3": "Ideogram V3",
} }

View File

@ -6,6 +6,7 @@ For source of truth on the allowed permutations of request fields, please refere
from __future__ import annotations from __future__ import annotations
from typing import Optional, TypeVar, Any from typing import Optional, TypeVar, Any
from collections.abc import Callable
import math import math
import logging import logging
@ -86,6 +87,15 @@ MAX_PROMPT_LENGTH_IMAGE_GEN = 500
MAX_NEGATIVE_PROMPT_LENGTH_IMAGE_GEN = 200 MAX_NEGATIVE_PROMPT_LENGTH_IMAGE_GEN = 200
MAX_PROMPT_LENGTH_LIP_SYNC = 120 MAX_PROMPT_LENGTH_LIP_SYNC = 120
# TODO: adjust based on tests
AVERAGE_DURATION_T2V = 319 # 319,
AVERAGE_DURATION_I2V = 164 # 164,
AVERAGE_DURATION_LIP_SYNC = 120
AVERAGE_DURATION_VIRTUAL_TRY_ON = 19 # 19,
AVERAGE_DURATION_IMAGE_GEN = 32
AVERAGE_DURATION_VIDEO_EFFECTS = 320
AVERAGE_DURATION_VIDEO_EXTEND = 320
R = TypeVar("R") R = TypeVar("R")
@ -95,7 +105,13 @@ class KlingApiError(Exception):
pass pass
def poll_until_finished(auth_token: str, api_endpoint: ApiEndpoint[Any, R]) -> R: def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, R],
result_url_extractor: Optional[Callable[[R], str]] = None,
estimated_duration: Optional[int] = None,
node_id: Optional[str] = None,
) -> R:
"""Polls the Kling API endpoint until the task reaches a terminal state, then returns the response.""" """Polls the Kling API endpoint until the task reaches a terminal state, then returns the response."""
return PollingOperation( return PollingOperation(
poll_endpoint=api_endpoint, poll_endpoint=api_endpoint,
@ -108,7 +124,10 @@ def poll_until_finished(auth_token: str, api_endpoint: ApiEndpoint[Any, R]) -> R
if response.data and response.data.task_status if response.data and response.data.task_status
else None else None
), ),
auth_token=auth_token, auth_kwargs=auth_kwargs,
result_url_extractor=result_url_extractor,
estimated_duration=estimated_duration,
node_id=node_id,
).execute() ).execute()
@ -227,7 +246,9 @@ def get_camera_control_input_config(
def get_video_from_response(response) -> KlingVideoResult: def get_video_from_response(response) -> KlingVideoResult:
"""Returns the first video object from the Kling video generation task result.""" """Returns the first video object from the Kling video generation task result.
Will raise an error if the response is not valid.
"""
video = response.data.task_result.videos[0] video = response.data.task_result.videos[0]
logging.info( logging.info(
"Kling task %s succeeded. Video URL: %s", response.data.task_id, video.url "Kling task %s succeeded. Video URL: %s", response.data.task_id, video.url
@ -235,12 +256,37 @@ def get_video_from_response(response) -> KlingVideoResult:
return video return video
def get_video_url_from_response(response) -> Optional[str]:
"""Returns the first video url from the Kling video generation task result.
Will not raise an error if the response is not valid.
"""
if response and is_valid_video_response(response):
return str(get_video_from_response(response).url)
else:
return None
def get_images_from_response(response) -> list[KlingImageResult]: def get_images_from_response(response) -> list[KlingImageResult]:
"""Returns the list of image objects from the Kling image generation task result.
Will raise an error if the response is not valid.
"""
images = response.data.task_result.images images = response.data.task_result.images
logging.info("Kling task %s succeeded. Images: %s", response.data.task_id, images) logging.info("Kling task %s succeeded. Images: %s", response.data.task_id, images)
return images return images
def get_images_urls_from_response(response) -> Optional[str]:
"""Returns the list of image urls from the Kling image generation task result.
Will not raise an error if the response is not valid. If there is only one image, returns the url as a string. If there are multiple images, returns a list of urls.
"""
if response and is_valid_image_response(response):
images = get_images_from_response(response)
image_urls = [str(image.url) for image in images]
return "\n".join(image_urls)
else:
return None
def video_result_to_node_output( def video_result_to_node_output(
video: KlingVideoResult, video: KlingVideoResult,
) -> tuple[VideoFromFile, str, str]: ) -> tuple[VideoFromFile, str, str]:
@ -312,6 +358,7 @@ class KlingCameraControls(KlingNodeBase):
RETURN_TYPES = ("CAMERA_CONTROL",) RETURN_TYPES = ("CAMERA_CONTROL",)
RETURN_NAMES = ("camera_control",) RETURN_NAMES = ("camera_control",)
FUNCTION = "main" FUNCTION = "main"
API_NODE = False # This is just a helper node, it doesn't make an API call
@classmethod @classmethod
def VALIDATE_INPUTS( def VALIDATE_INPUTS(
@ -418,22 +465,31 @@ class KlingTextToVideoNode(KlingNodeBase):
}, },
), ),
}, },
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
} }
RETURN_TYPES = ("VIDEO", "STRING", "STRING") RETURN_TYPES = ("VIDEO", "STRING", "STRING")
RETURN_NAMES = ("VIDEO", "video_id", "duration") RETURN_NAMES = ("VIDEO", "video_id", "duration")
DESCRIPTION = "Kling Text to Video Node" DESCRIPTION = "Kling Text to Video Node"
def get_response(self, task_id: str, auth_token: str) -> KlingText2VideoResponse: def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingText2VideoResponse:
return poll_until_finished( return poll_until_finished(
auth_token, auth_kwargs,
ApiEndpoint( ApiEndpoint(
path=f"{PATH_TEXT_TO_VIDEO}/{task_id}", path=f"{PATH_TEXT_TO_VIDEO}/{task_id}",
method=HttpMethod.GET, method=HttpMethod.GET,
request_model=EmptyRequest, request_model=EmptyRequest,
response_model=KlingText2VideoResponse, response_model=KlingText2VideoResponse,
), ),
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
node_id=node_id,
) )
def api_call( def api_call(
@ -446,7 +502,8 @@ class KlingTextToVideoNode(KlingNodeBase):
camera_control: Optional[KlingCameraControl] = None, camera_control: Optional[KlingCameraControl] = None,
model_name: Optional[str] = None, model_name: Optional[str] = None,
duration: Optional[str] = None, duration: Optional[str] = None,
auth_token: Optional[str] = None, unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile, str, str]: ) -> tuple[VideoFromFile, str, str]:
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
if model_name is None: if model_name is None:
@ -468,14 +525,16 @@ class KlingTextToVideoNode(KlingNodeBase):
aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio), aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
camera_control=camera_control, camera_control=camera_control,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
task_creation_response = initial_operation.execute() task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response) validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id task_id = task_creation_response.data.task_id
final_response = self.get_response(task_id, auth_token) final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_video_result_response(final_response) validate_video_result_response(final_response)
video = get_video_from_response(final_response) video = get_video_from_response(final_response)
@ -522,7 +581,11 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
}, },
), ),
}, },
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
} }
DESCRIPTION = "Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text." DESCRIPTION = "Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text."
@ -534,7 +597,8 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
cfg_scale: float, cfg_scale: float,
aspect_ratio: str, aspect_ratio: str,
camera_control: Optional[KlingCameraControl] = None, camera_control: Optional[KlingCameraControl] = None,
auth_token: Optional[str] = None, unique_id: Optional[str] = None,
**kwargs,
): ):
return super().api_call( return super().api_call(
model_name=KlingVideoGenModelName.kling_v1, model_name=KlingVideoGenModelName.kling_v1,
@ -545,7 +609,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
prompt=prompt, prompt=prompt,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
camera_control=camera_control, camera_control=camera_control,
auth_token=auth_token, **kwargs,
) )
@ -604,22 +668,31 @@ class KlingImage2VideoNode(KlingNodeBase):
enum_type=KlingVideoGenDuration, enum_type=KlingVideoGenDuration,
), ),
}, },
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
} }
RETURN_TYPES = ("VIDEO", "STRING", "STRING") RETURN_TYPES = ("VIDEO", "STRING", "STRING")
RETURN_NAMES = ("VIDEO", "video_id", "duration") RETURN_NAMES = ("VIDEO", "video_id", "duration")
DESCRIPTION = "Kling Image to Video Node" DESCRIPTION = "Kling Image to Video Node"
def get_response(self, task_id: str, auth_token: str) -> KlingImage2VideoResponse: def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingImage2VideoResponse:
return poll_until_finished( return poll_until_finished(
auth_token, auth_kwargs,
ApiEndpoint( ApiEndpoint(
path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}", path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}",
method=HttpMethod.GET, method=HttpMethod.GET,
request_model=KlingImage2VideoRequest, request_model=KlingImage2VideoRequest,
response_model=KlingImage2VideoResponse, response_model=KlingImage2VideoResponse,
), ),
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_I2V,
node_id=node_id,
) )
def api_call( def api_call(
@ -634,7 +707,8 @@ class KlingImage2VideoNode(KlingNodeBase):
duration: str, duration: str,
camera_control: Optional[KlingCameraControl] = None, camera_control: Optional[KlingCameraControl] = None,
end_frame: Optional[torch.Tensor] = None, end_frame: Optional[torch.Tensor] = None,
auth_token: Optional[str] = None, unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]: ) -> tuple[VideoFromFile]:
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V) validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V)
validate_input_image(start_frame) validate_input_image(start_frame)
@ -662,18 +736,19 @@ class KlingImage2VideoNode(KlingNodeBase):
negative_prompt=negative_prompt if negative_prompt else None, negative_prompt=negative_prompt if negative_prompt else None,
cfg_scale=cfg_scale, cfg_scale=cfg_scale,
mode=KlingVideoGenMode(mode), mode=KlingVideoGenMode(mode),
aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
duration=KlingVideoGenDuration(duration), duration=KlingVideoGenDuration(duration),
camera_control=camera_control, camera_control=camera_control,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
task_creation_response = initial_operation.execute() task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response) validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id task_id = task_creation_response.data.task_id
final_response = self.get_response(task_id, auth_token) final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_video_result_response(final_response) validate_video_result_response(final_response)
video = get_video_from_response(final_response) video = get_video_from_response(final_response)
@ -723,7 +798,11 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
}, },
), ),
}, },
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
} }
DESCRIPTION = "Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image." DESCRIPTION = "Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image."
@ -736,7 +815,8 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
cfg_scale: float, cfg_scale: float,
aspect_ratio: str, aspect_ratio: str,
camera_control: KlingCameraControl, camera_control: KlingCameraControl,
auth_token: Optional[str] = None, unique_id: Optional[str] = None,
**kwargs,
): ):
return super().api_call( return super().api_call(
model_name=KlingVideoGenModelName.kling_v1_5, model_name=KlingVideoGenModelName.kling_v1_5,
@ -748,7 +828,8 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
prompt=prompt, prompt=prompt,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
camera_control=camera_control, camera_control=camera_control,
auth_token=auth_token, unique_id=unique_id,
**kwargs,
) )
@ -816,7 +897,11 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
}, },
), ),
}, },
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
} }
DESCRIPTION = "Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last." DESCRIPTION = "Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last."
@ -830,7 +915,8 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
cfg_scale: float, cfg_scale: float,
aspect_ratio: str, aspect_ratio: str,
mode: str, mode: str,
auth_token: Optional[str] = None, unique_id: Optional[str] = None,
**kwargs,
): ):
mode, duration, model_name = KlingStartEndFrameNode.get_mode_string_mapping()[ mode, duration, model_name = KlingStartEndFrameNode.get_mode_string_mapping()[
mode mode
@ -845,7 +931,8 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
aspect_ratio=aspect_ratio, aspect_ratio=aspect_ratio,
duration=duration, duration=duration,
end_frame=end_frame, end_frame=end_frame,
auth_token=auth_token, unique_id=unique_id,
**kwargs,
) )
@ -875,22 +962,31 @@ class KlingVideoExtendNode(KlingNodeBase):
IO.STRING, KlingVideoExtendRequest, "video_id", forceInput=True IO.STRING, KlingVideoExtendRequest, "video_id", forceInput=True
), ),
}, },
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
} }
RETURN_TYPES = ("VIDEO", "STRING", "STRING") RETURN_TYPES = ("VIDEO", "STRING", "STRING")
RETURN_NAMES = ("VIDEO", "video_id", "duration") RETURN_NAMES = ("VIDEO", "video_id", "duration")
DESCRIPTION = "Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes." DESCRIPTION = "Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes."
def get_response(self, task_id: str, auth_token: str) -> KlingVideoExtendResponse: def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingVideoExtendResponse:
return poll_until_finished( return poll_until_finished(
auth_token, auth_kwargs,
ApiEndpoint( ApiEndpoint(
path=f"{PATH_VIDEO_EXTEND}/{task_id}", path=f"{PATH_VIDEO_EXTEND}/{task_id}",
method=HttpMethod.GET, method=HttpMethod.GET,
request_model=EmptyRequest, request_model=EmptyRequest,
response_model=KlingVideoExtendResponse, response_model=KlingVideoExtendResponse,
), ),
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND,
node_id=node_id,
) )
def api_call( def api_call(
@ -899,7 +995,8 @@ class KlingVideoExtendNode(KlingNodeBase):
negative_prompt: str, negative_prompt: str,
cfg_scale: float, cfg_scale: float,
video_id: str, video_id: str,
auth_token: Optional[str] = None, unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile, str, str]: ) -> tuple[VideoFromFile, str, str]:
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
initial_operation = SynchronousOperation( initial_operation = SynchronousOperation(
@ -915,14 +1012,16 @@ class KlingVideoExtendNode(KlingNodeBase):
cfg_scale=cfg_scale, cfg_scale=cfg_scale,
video_id=video_id, video_id=video_id,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
task_creation_response = initial_operation.execute() task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response) validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id task_id = task_creation_response.data.task_id
final_response = self.get_response(task_id, auth_token) final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_video_result_response(final_response) validate_video_result_response(final_response)
video = get_video_from_response(final_response) video = get_video_from_response(final_response)
@ -935,15 +1034,20 @@ class KlingVideoEffectsBase(KlingNodeBase):
RETURN_TYPES = ("VIDEO", "STRING", "STRING") RETURN_TYPES = ("VIDEO", "STRING", "STRING")
RETURN_NAMES = ("VIDEO", "video_id", "duration") RETURN_NAMES = ("VIDEO", "video_id", "duration")
def get_response(self, task_id: str, auth_token: str) -> KlingVideoEffectsResponse: def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingVideoEffectsResponse:
return poll_until_finished( return poll_until_finished(
auth_token, auth_kwargs,
ApiEndpoint( ApiEndpoint(
path=f"{PATH_VIDEO_EFFECTS}/{task_id}", path=f"{PATH_VIDEO_EFFECTS}/{task_id}",
method=HttpMethod.GET, method=HttpMethod.GET,
request_model=EmptyRequest, request_model=EmptyRequest,
response_model=KlingVideoEffectsResponse, response_model=KlingVideoEffectsResponse,
), ),
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS,
node_id=node_id,
) )
def api_call( def api_call(
@ -955,7 +1059,8 @@ class KlingVideoEffectsBase(KlingNodeBase):
image_1: torch.Tensor, image_1: torch.Tensor,
image_2: Optional[torch.Tensor] = None, image_2: Optional[torch.Tensor] = None,
mode: Optional[KlingVideoGenMode] = None, mode: Optional[KlingVideoGenMode] = None,
auth_token: Optional[str] = None, unique_id: Optional[str] = None,
**kwargs,
): ):
if dual_character: if dual_character:
request_input_field = KlingDualCharacterEffectInput( request_input_field = KlingDualCharacterEffectInput(
@ -985,14 +1090,16 @@ class KlingVideoEffectsBase(KlingNodeBase):
effect_scene=effect_scene, effect_scene=effect_scene,
input=request_input_field, input=request_input_field,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
task_creation_response = initial_operation.execute() task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response) validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id task_id = task_creation_response.data.task_id
final_response = self.get_response(task_id, auth_token) final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_video_result_response(final_response) validate_video_result_response(final_response)
video = get_video_from_response(final_response) video = get_video_from_response(final_response)
@ -1033,7 +1140,11 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
enum_type=KlingVideoGenDuration, enum_type=KlingVideoGenDuration,
), ),
}, },
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
} }
DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene. First image will be positioned on left side, second on right side of the composite." DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene. First image will be positioned on left side, second on right side of the composite."
@ -1048,7 +1159,8 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
model_name: KlingCharacterEffectModelName, model_name: KlingCharacterEffectModelName,
mode: KlingVideoGenMode, mode: KlingVideoGenMode,
duration: KlingVideoGenDuration, duration: KlingVideoGenDuration,
auth_token: Optional[str] = None, unique_id: Optional[str] = None,
**kwargs,
): ):
video, _, duration = super().api_call( video, _, duration = super().api_call(
dual_character=True, dual_character=True,
@ -1058,10 +1170,12 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
duration=duration, duration=duration,
image_1=image_left, image_1=image_left,
image_2=image_right, image_2=image_right,
auth_token=auth_token, unique_id=unique_id,
**kwargs,
) )
return video, duration return video, duration
class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase): class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
"""Kling Single Image Video Effect Node""" """Kling Single Image Video Effect Node"""
@ -1094,7 +1208,11 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
enum_type=KlingVideoGenDuration, enum_type=KlingVideoGenDuration,
), ),
}, },
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
} }
DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene." DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene."
@ -1105,7 +1223,8 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
effect_scene: KlingSingleImageEffectsScene, effect_scene: KlingSingleImageEffectsScene,
model_name: KlingSingleImageEffectModelName, model_name: KlingSingleImageEffectModelName,
duration: KlingVideoGenDuration, duration: KlingVideoGenDuration,
auth_token: Optional[str] = None, unique_id: Optional[str] = None,
**kwargs,
): ):
return super().api_call( return super().api_call(
dual_character=False, dual_character=False,
@ -1113,7 +1232,8 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
model_name=model_name, model_name=model_name,
duration=duration, duration=duration,
image_1=image, image_1=image,
auth_token=auth_token, unique_id=unique_id,
**kwargs,
) )
@ -1131,16 +1251,21 @@ class KlingLipSyncBase(KlingNodeBase):
f"Text is too long. Maximum length is {MAX_PROMPT_LENGTH_LIP_SYNC} characters." f"Text is too long. Maximum length is {MAX_PROMPT_LENGTH_LIP_SYNC} characters."
) )
def get_response(self, task_id: str, auth_token: str) -> KlingLipSyncResponse: def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingLipSyncResponse:
"""Polls the Kling API endpoint until the task reaches a terminal state.""" """Polls the Kling API endpoint until the task reaches a terminal state."""
return poll_until_finished( return poll_until_finished(
auth_token, auth_kwargs,
ApiEndpoint( ApiEndpoint(
path=f"{PATH_LIP_SYNC}/{task_id}", path=f"{PATH_LIP_SYNC}/{task_id}",
method=HttpMethod.GET, method=HttpMethod.GET,
request_model=EmptyRequest, request_model=EmptyRequest,
response_model=KlingLipSyncResponse, response_model=KlingLipSyncResponse,
), ),
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_LIP_SYNC,
node_id=node_id,
) )
def api_call( def api_call(
@ -1152,18 +1277,19 @@ class KlingLipSyncBase(KlingNodeBase):
text: Optional[str] = None, text: Optional[str] = None,
voice_speed: Optional[float] = None, voice_speed: Optional[float] = None,
voice_id: Optional[str] = None, voice_id: Optional[str] = None,
auth_token: Optional[str] = None, unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile, str, str]: ) -> tuple[VideoFromFile, str, str]:
if text: if text:
self.validate_text(text) self.validate_text(text)
# Upload video to Comfy API and get download URL # Upload video to Comfy API and get download URL
video_url = upload_video_to_comfyapi(video, auth_token) video_url = upload_video_to_comfyapi(video, auth_kwargs=kwargs)
logging.info("Uploaded video to Comfy API. URL: %s", video_url) logging.info("Uploaded video to Comfy API. URL: %s", video_url)
# Upload the audio file to Comfy API and get download URL # Upload the audio file to Comfy API and get download URL
if audio: if audio:
audio_url = upload_audio_to_comfyapi(audio, auth_token) audio_url = upload_audio_to_comfyapi(audio, auth_kwargs=kwargs)
logging.info("Uploaded audio to Comfy API. URL: %s", audio_url) logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
else: else:
audio_url = None audio_url = None
@ -1187,14 +1313,16 @@ class KlingLipSyncBase(KlingNodeBase):
voice_id=voice_id, voice_id=voice_id,
), ),
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
task_creation_response = initial_operation.execute() task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response) validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id task_id = task_creation_response.data.task_id
final_response = self.get_response(task_id, auth_token) final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_video_result_response(final_response) validate_video_result_response(final_response)
video = get_video_from_response(final_response) video = get_video_from_response(final_response)
@ -1217,7 +1345,11 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
enum_type=KlingLipSyncVoiceLanguage, enum_type=KlingLipSyncVoiceLanguage,
), ),
}, },
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
} }
DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file." DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file."
@ -1227,14 +1359,16 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
video: VideoInput, video: VideoInput,
audio: AudioInput, audio: AudioInput,
voice_language: str, voice_language: str,
auth_token: Optional[str] = None, unique_id: Optional[str] = None,
**kwargs,
): ):
return super().api_call( return super().api_call(
video=video, video=video,
audio=audio, audio=audio,
voice_language=voice_language, voice_language=voice_language,
mode="audio2video", mode="audio2video",
auth_token=auth_token, unique_id=unique_id,
**kwargs,
) )
@ -1323,7 +1457,11 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
IO.FLOAT, KlingLipSyncInputObject, "voice_speed", slider=True IO.FLOAT, KlingLipSyncInputObject, "voice_speed", slider=True
), ),
}, },
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
} }
DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt." DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt."
@ -1334,7 +1472,8 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
text: str, text: str,
voice: str, voice: str,
voice_speed: float, voice_speed: float,
auth_token: Optional[str] = None, unique_id: Optional[str] = None,
**kwargs,
): ):
voice_id, voice_language = KlingLipSyncTextToVideoNode.get_voice_config()[voice] voice_id, voice_language = KlingLipSyncTextToVideoNode.get_voice_config()[voice]
return super().api_call( return super().api_call(
@ -1344,7 +1483,8 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
voice_id=voice_id, voice_id=voice_id,
voice_speed=voice_speed, voice_speed=voice_speed,
mode="text2video", mode="text2video",
auth_token=auth_token, unique_id=unique_id,
**kwargs,
) )
@ -1381,22 +1521,29 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
enum_type=KlingVirtualTryOnModelName, enum_type=KlingVirtualTryOnModelName,
), ),
}, },
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
} }
DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human." DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background."
def get_response( def get_response(
self, task_id: str, auth_token: Optional[str] = None self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingVirtualTryOnResponse: ) -> KlingVirtualTryOnResponse:
return poll_until_finished( return poll_until_finished(
auth_token, auth_kwargs,
ApiEndpoint( ApiEndpoint(
path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}", path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}",
method=HttpMethod.GET, method=HttpMethod.GET,
request_model=EmptyRequest, request_model=EmptyRequest,
response_model=KlingVirtualTryOnResponse, response_model=KlingVirtualTryOnResponse,
), ),
result_url_extractor=get_images_urls_from_response,
estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON,
node_id=node_id,
) )
def api_call( def api_call(
@ -1404,7 +1551,8 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
human_image: torch.Tensor, human_image: torch.Tensor,
cloth_image: torch.Tensor, cloth_image: torch.Tensor,
model_name: KlingVirtualTryOnModelName, model_name: KlingVirtualTryOnModelName,
auth_token: Optional[str] = None, unique_id: Optional[str] = None,
**kwargs,
): ):
initial_operation = SynchronousOperation( initial_operation = SynchronousOperation(
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
@ -1418,14 +1566,16 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
cloth_image=tensor_to_base64_string(cloth_image), cloth_image=tensor_to_base64_string(cloth_image),
model_name=model_name, model_name=model_name,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
task_creation_response = initial_operation.execute() task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response) validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id task_id = task_creation_response.data.task_id
final_response = self.get_response(task_id, auth_token) final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_image_result_response(final_response) validate_image_result_response(final_response)
images = get_images_from_response(final_response) images = get_images_from_response(final_response)
@ -1493,22 +1643,32 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
"optional": { "optional": {
"image": (IO.IMAGE, {}), "image": (IO.IMAGE, {}),
}, },
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
} }
DESCRIPTION = "Kling Image Generation Node. Generate an image from a text prompt with an optional reference image." DESCRIPTION = "Kling Image Generation Node. Generate an image from a text prompt with an optional reference image."
def get_response( def get_response(
self, task_id: str, auth_token: Optional[str] = None self,
task_id: str,
auth_kwargs: Optional[dict[str, str]],
node_id: Optional[str] = None,
) -> KlingImageGenerationsResponse: ) -> KlingImageGenerationsResponse:
return poll_until_finished( return poll_until_finished(
auth_token, auth_kwargs,
ApiEndpoint( ApiEndpoint(
path=f"{PATH_IMAGE_GENERATIONS}/{task_id}", path=f"{PATH_IMAGE_GENERATIONS}/{task_id}",
method=HttpMethod.GET, method=HttpMethod.GET,
request_model=EmptyRequest, request_model=EmptyRequest,
response_model=KlingImageGenerationsResponse, response_model=KlingImageGenerationsResponse,
), ),
result_url_extractor=get_images_urls_from_response,
estimated_duration=AVERAGE_DURATION_IMAGE_GEN,
node_id=node_id,
) )
def api_call( def api_call(
@ -1522,7 +1682,8 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
n: int, n: int,
aspect_ratio: KlingImageGenAspectRatio, aspect_ratio: KlingImageGenAspectRatio,
image: Optional[torch.Tensor] = None, image: Optional[torch.Tensor] = None,
auth_token: Optional[str] = None, unique_id: Optional[str] = None,
**kwargs,
): ):
self.validate_prompt(prompt, negative_prompt) self.validate_prompt(prompt, negative_prompt)
@ -1547,14 +1708,16 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
n=n, n=n,
aspect_ratio=aspect_ratio, aspect_ratio=aspect_ratio,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
task_creation_response = initial_operation.execute() task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response) validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id task_id = task_creation_response.data.task_id
final_response = self.get_response(task_id, auth_token) final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_image_result_response(final_response) validate_image_result_response(final_response)
images = get_images_from_response(final_response) images = get_images_from_response(final_response)

View File

@ -1,4 +1,6 @@
from __future__ import annotations
from inspect import cleandoc from inspect import cleandoc
from typing import Optional
from comfy.comfy_types.node_typing import IO, ComfyNodeABC from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl.video_types import VideoFromFile from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis.luma_api import ( from comfy_api_nodes.apis.luma_api import (
@ -34,11 +36,20 @@ from comfy_api_nodes.apinode_utils import (
process_image_response, process_image_response,
validate_string, validate_string,
) )
from server import PromptServer
import requests import requests
import torch import torch
from io import BytesIO from io import BytesIO
LUMA_T2V_AVERAGE_DURATION = 105
LUMA_I2V_AVERAGE_DURATION = 100
def image_result_url_extractor(response: LumaGeneration):
return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None
def video_result_url_extractor(response: LumaGeneration):
return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None
class LumaReferenceNode(ComfyNodeABC): class LumaReferenceNode(ComfyNodeABC):
""" """
@ -201,6 +212,8 @@ class LumaImageGenerationNode(ComfyNodeABC):
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -214,7 +227,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
image_luma_ref: LumaReferenceChain = None, image_luma_ref: LumaReferenceChain = None,
style_image: torch.Tensor = None, style_image: torch.Tensor = None,
character_image: torch.Tensor = None, character_image: torch.Tensor = None,
auth_token=None, unique_id: str = None,
**kwargs, **kwargs,
): ):
validate_string(prompt, strip_whitespace=True, min_length=3) validate_string(prompt, strip_whitespace=True, min_length=3)
@ -222,19 +235,19 @@ class LumaImageGenerationNode(ComfyNodeABC):
api_image_ref = None api_image_ref = None
if image_luma_ref is not None: if image_luma_ref is not None:
api_image_ref = self._convert_luma_refs( api_image_ref = self._convert_luma_refs(
image_luma_ref, max_refs=4, auth_token=auth_token image_luma_ref, max_refs=4, auth_kwargs=kwargs,
) )
# handle style_luma_ref # handle style_luma_ref
api_style_ref = None api_style_ref = None
if style_image is not None: if style_image is not None:
api_style_ref = self._convert_style_image( api_style_ref = self._convert_style_image(
style_image, weight=style_image_weight, auth_token=auth_token style_image, weight=style_image_weight, auth_kwargs=kwargs,
) )
# handle character_ref images # handle character_ref images
character_ref = None character_ref = None
if character_image is not None: if character_image is not None:
download_urls = upload_images_to_comfyapi( download_urls = upload_images_to_comfyapi(
character_image, max_images=4, auth_token=auth_token character_image, max_images=4, auth_kwargs=kwargs,
) )
character_ref = LumaCharacterRef( character_ref = LumaCharacterRef(
identity0=LumaImageIdentity(images=download_urls) identity0=LumaImageIdentity(images=download_urls)
@ -255,7 +268,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
style_ref=api_style_ref, style_ref=api_style_ref,
character_ref=character_ref, character_ref=character_ref,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
response_api: LumaGeneration = operation.execute() response_api: LumaGeneration = operation.execute()
@ -269,7 +282,9 @@ class LumaImageGenerationNode(ComfyNodeABC):
completed_statuses=[LumaState.completed], completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed], failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state, status_extractor=lambda x: x.state,
auth_token=auth_token, result_url_extractor=image_result_url_extractor,
node_id=unique_id,
auth_kwargs=kwargs,
) )
response_poll = operation.execute() response_poll = operation.execute()
@ -278,13 +293,13 @@ class LumaImageGenerationNode(ComfyNodeABC):
return (img,) return (img,)
def _convert_luma_refs( def _convert_luma_refs(
self, luma_ref: LumaReferenceChain, max_refs: int, auth_token=None self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
): ):
luma_urls = [] luma_urls = []
ref_count = 0 ref_count = 0
for ref in luma_ref.refs: for ref in luma_ref.refs:
download_urls = upload_images_to_comfyapi( download_urls = upload_images_to_comfyapi(
ref.image, max_images=1, auth_token=auth_token ref.image, max_images=1, auth_kwargs=auth_kwargs
) )
luma_urls.append(download_urls[0]) luma_urls.append(download_urls[0])
ref_count += 1 ref_count += 1
@ -293,12 +308,12 @@ class LumaImageGenerationNode(ComfyNodeABC):
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs) return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
def _convert_style_image( def _convert_style_image(
self, style_image: torch.Tensor, weight: float, auth_token=None self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
): ):
chain = LumaReferenceChain( chain = LumaReferenceChain(
first_ref=LumaReference(image=style_image, weight=weight) first_ref=LumaReference(image=style_image, weight=weight)
) )
return self._convert_luma_refs(chain, max_refs=1, auth_token=auth_token) return self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
class LumaImageModifyNode(ComfyNodeABC): class LumaImageModifyNode(ComfyNodeABC):
@ -350,6 +365,8 @@ class LumaImageModifyNode(ComfyNodeABC):
"optional": {}, "optional": {},
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -360,12 +377,12 @@ class LumaImageModifyNode(ComfyNodeABC):
image: torch.Tensor, image: torch.Tensor,
image_weight: float, image_weight: float,
seed, seed,
auth_token=None, unique_id: str = None,
**kwargs, **kwargs,
): ):
# first, upload image # first, upload image
download_urls = upload_images_to_comfyapi( download_urls = upload_images_to_comfyapi(
image, max_images=1, auth_token=auth_token image, max_images=1, auth_kwargs=kwargs,
) )
image_url = download_urls[0] image_url = download_urls[0]
# next, make Luma call with download url provided # next, make Luma call with download url provided
@ -383,7 +400,7 @@ class LumaImageModifyNode(ComfyNodeABC):
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2) url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
), ),
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
response_api: LumaGeneration = operation.execute() response_api: LumaGeneration = operation.execute()
@ -397,7 +414,9 @@ class LumaImageModifyNode(ComfyNodeABC):
completed_statuses=[LumaState.completed], completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed], failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state, status_extractor=lambda x: x.state,
auth_token=auth_token, result_url_extractor=image_result_url_extractor,
node_id=unique_id,
auth_kwargs=kwargs,
) )
response_poll = operation.execute() response_poll = operation.execute()
@ -470,6 +489,8 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -483,7 +504,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
loop: bool, loop: bool,
seed, seed,
luma_concepts: LumaConceptChain = None, luma_concepts: LumaConceptChain = None,
auth_token=None, unique_id: str = None,
**kwargs, **kwargs,
): ):
validate_string(prompt, strip_whitespace=False, min_length=3) validate_string(prompt, strip_whitespace=False, min_length=3)
@ -506,10 +527,13 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
loop=loop, loop=loop,
concepts=luma_concepts.create_api_model() if luma_concepts else None, concepts=luma_concepts.create_api_model() if luma_concepts else None,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
response_api: LumaGeneration = operation.execute() response_api: LumaGeneration = operation.execute()
if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
operation = PollingOperation( operation = PollingOperation(
poll_endpoint=ApiEndpoint( poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}", path=f"/proxy/luma/generations/{response_api.id}",
@ -520,7 +544,10 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
completed_statuses=[LumaState.completed], completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed], failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state, status_extractor=lambda x: x.state,
auth_token=auth_token, result_url_extractor=video_result_url_extractor,
node_id=unique_id,
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
auth_kwargs=kwargs,
) )
response_poll = operation.execute() response_poll = operation.execute()
@ -594,6 +621,8 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -608,14 +637,14 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
first_image: torch.Tensor = None, first_image: torch.Tensor = None,
last_image: torch.Tensor = None, last_image: torch.Tensor = None,
luma_concepts: LumaConceptChain = None, luma_concepts: LumaConceptChain = None,
auth_token=None, unique_id: str = None,
**kwargs, **kwargs,
): ):
if first_image is None and last_image is None: if first_image is None and last_image is None:
raise Exception( raise Exception(
"At least one of first_image and last_image requires an input." "At least one of first_image and last_image requires an input."
) )
keyframes = self._convert_to_keyframes(first_image, last_image, auth_token) keyframes = self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
duration = duration if model != LumaVideoModel.ray_1_6 else None duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None
@ -636,10 +665,13 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
keyframes=keyframes, keyframes=keyframes,
concepts=luma_concepts.create_api_model() if luma_concepts else None, concepts=luma_concepts.create_api_model() if luma_concepts else None,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
response_api: LumaGeneration = operation.execute() response_api: LumaGeneration = operation.execute()
if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
operation = PollingOperation( operation = PollingOperation(
poll_endpoint=ApiEndpoint( poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}", path=f"/proxy/luma/generations/{response_api.id}",
@ -650,7 +682,10 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
completed_statuses=[LumaState.completed], completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed], failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state, status_extractor=lambda x: x.state,
auth_token=auth_token, result_url_extractor=video_result_url_extractor,
node_id=unique_id,
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
auth_kwargs=kwargs,
) )
response_poll = operation.execute() response_poll = operation.execute()
@ -661,7 +696,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
self, self,
first_image: torch.Tensor = None, first_image: torch.Tensor = None,
last_image: torch.Tensor = None, last_image: torch.Tensor = None,
auth_token=None, auth_kwargs: Optional[dict[str,str]] = None,
): ):
if first_image is None and last_image is None: if first_image is None and last_image is None:
return None return None
@ -669,12 +704,12 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
frame1 = None frame1 = None
if first_image is not None: if first_image is not None:
download_urls = upload_images_to_comfyapi( download_urls = upload_images_to_comfyapi(
first_image, max_images=1, auth_token=auth_token first_image, max_images=1, auth_kwargs=auth_kwargs,
) )
frame0 = LumaImageReference(type="image", url=download_urls[0]) frame0 = LumaImageReference(type="image", url=download_urls[0])
if last_image is not None: if last_image is not None:
download_urls = upload_images_to_comfyapi( download_urls = upload_images_to_comfyapi(
last_image, max_images=1, auth_token=auth_token last_image, max_images=1, auth_kwargs=auth_kwargs,
) )
frame1 = LumaImageReference(type="image", url=download_urls[0]) frame1 = LumaImageReference(type="image", url=download_urls[0])
return LumaKeyframes(frame0=frame0, frame1=frame1) return LumaKeyframes(frame0=frame0, frame1=frame1)

View File

@ -1,3 +1,7 @@
from typing import Union
import logging
import torch
from comfy.comfy_types.node_typing import IO from comfy.comfy_types.node_typing import IO
from comfy_api.input_impl.video_types import VideoFromFile from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis import ( from comfy_api_nodes.apis import (
@ -20,16 +24,19 @@ from comfy_api_nodes.apinode_utils import (
upload_images_to_comfyapi, upload_images_to_comfyapi,
validate_string, validate_string,
) )
from server import PromptServer
import torch
import logging
I2V_AVERAGE_DURATION = 114
T2V_AVERAGE_DURATION = 234
class MinimaxTextToVideoNode: class MinimaxTextToVideoNode:
""" """
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API. Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
""" """
AVERAGE_DURATION = T2V_AVERAGE_DURATION
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return { return {
@ -67,6 +74,8 @@ class MinimaxTextToVideoNode:
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -84,7 +93,8 @@ class MinimaxTextToVideoNode:
model="T2V-01", model="T2V-01",
image: torch.Tensor=None, # used for ImageToVideo image: torch.Tensor=None, # used for ImageToVideo
subject: torch.Tensor=None, # used for SubjectToVideo subject: torch.Tensor=None, # used for SubjectToVideo
auth_token=None, unique_id: Union[str, None]=None,
**kwargs,
): ):
''' '''
Function used between MiniMax nodes - supports T2V, I2V, and S2V, based on provided arguments. Function used between MiniMax nodes - supports T2V, I2V, and S2V, based on provided arguments.
@ -94,12 +104,12 @@ class MinimaxTextToVideoNode:
# upload image, if passed in # upload image, if passed in
image_url = None image_url = None
if image is not None: if image is not None:
image_url = upload_images_to_comfyapi(image, max_images=1, auth_token=auth_token)[0] image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)[0]
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model # TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
subject_reference = None subject_reference = None
if subject is not None: if subject is not None:
subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_token=auth_token)[0] subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs)[0]
subject_reference = [SubjectReferenceItem(image=subject_url)] subject_reference = [SubjectReferenceItem(image=subject_url)]
@ -118,7 +128,7 @@ class MinimaxTextToVideoNode:
subject_reference=subject_reference, subject_reference=subject_reference,
prompt_optimizer=None, prompt_optimizer=None,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
response = video_generate_operation.execute() response = video_generate_operation.execute()
@ -137,7 +147,9 @@ class MinimaxTextToVideoNode:
completed_statuses=["Success"], completed_statuses=["Success"],
failed_statuses=["Fail"], failed_statuses=["Fail"],
status_extractor=lambda x: x.status.value, status_extractor=lambda x: x.status.value,
auth_token=auth_token, estimated_duration=self.AVERAGE_DURATION,
node_id=unique_id,
auth_kwargs=kwargs,
) )
task_result = video_generate_operation.execute() task_result = video_generate_operation.execute()
@ -153,7 +165,7 @@ class MinimaxTextToVideoNode:
query_params={"file_id": int(file_id)}, query_params={"file_id": int(file_id)},
), ),
request=EmptyRequest(), request=EmptyRequest(),
auth_token=auth_token, auth_kwargs=kwargs,
) )
file_result = file_retrieve_operation.execute() file_result = file_retrieve_operation.execute()
@ -163,6 +175,12 @@ class MinimaxTextToVideoNode:
f"No video was found in the response. Full response: {file_result.model_dump()}" f"No video was found in the response. Full response: {file_result.model_dump()}"
) )
logging.info(f"Generated video URL: {file_url}") logging.info(f"Generated video URL: {file_url}")
if unique_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, unique_id)
video_io = download_url_to_bytesio(file_url) video_io = download_url_to_bytesio(file_url)
if video_io is None: if video_io is None:
@ -177,6 +195,8 @@ class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API. Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
""" """
AVERAGE_DURATION = I2V_AVERAGE_DURATION
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return { return {
@ -221,6 +241,8 @@ class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -237,6 +259,8 @@ class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API. Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
""" """
AVERAGE_DURATION = T2V_AVERAGE_DURATION
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return { return {
@ -279,6 +303,8 @@ class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }

View File

@ -93,7 +93,11 @@ class OpenAIDalle2(ComfyNodeABC):
}, },
), ),
}, },
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
} }
RETURN_TYPES = (IO.IMAGE,) RETURN_TYPES = (IO.IMAGE,)
@ -110,7 +114,8 @@ class OpenAIDalle2(ComfyNodeABC):
mask=None, mask=None,
n=1, n=1,
size="1024x1024", size="1024x1024",
auth_token=None, unique_id=None,
**kwargs
): ):
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
model = "dall-e-2" model = "dall-e-2"
@ -168,12 +173,12 @@ class OpenAIDalle2(ComfyNodeABC):
else None else None
), ),
content_type=content_type, content_type=content_type,
auth_token=auth_token, auth_kwargs=kwargs,
) )
response = operation.execute() response = operation.execute()
img_tensor = validate_and_cast_response(response) img_tensor = validate_and_cast_response(response, node_id=unique_id)
return (img_tensor,) return (img_tensor,)
@ -236,7 +241,11 @@ class OpenAIDalle3(ComfyNodeABC):
}, },
), ),
}, },
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
} }
RETURN_TYPES = (IO.IMAGE,) RETURN_TYPES = (IO.IMAGE,)
@ -252,7 +261,8 @@ class OpenAIDalle3(ComfyNodeABC):
style="natural", style="natural",
quality="standard", quality="standard",
size="1024x1024", size="1024x1024",
auth_token=None, unique_id=None,
**kwargs
): ):
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
model = "dall-e-3" model = "dall-e-3"
@ -273,12 +283,12 @@ class OpenAIDalle3(ComfyNodeABC):
style=style, style=style,
seed=seed, seed=seed,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
response = operation.execute() response = operation.execute()
img_tensor = validate_and_cast_response(response) img_tensor = validate_and_cast_response(response, node_id=unique_id)
return (img_tensor,) return (img_tensor,)
@ -366,7 +376,11 @@ class OpenAIGPTImage1(ComfyNodeABC):
}, },
), ),
}, },
"hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
} }
RETURN_TYPES = (IO.IMAGE,) RETURN_TYPES = (IO.IMAGE,)
@ -385,7 +399,8 @@ class OpenAIGPTImage1(ComfyNodeABC):
mask=None, mask=None,
n=1, n=1,
size="1024x1024", size="1024x1024",
auth_token=None, unique_id=None,
**kwargs
): ):
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
model = "gpt-image-1" model = "gpt-image-1"
@ -462,12 +477,12 @@ class OpenAIGPTImage1(ComfyNodeABC):
), ),
files=files if files else None, files=files if files else None,
content_type=content_type, content_type=content_type,
auth_token=auth_token, auth_kwargs=kwargs,
) )
response = operation.execute() response = operation.execute()
img_tensor = validate_and_cast_response(response) img_tensor = validate_and_cast_response(response, node_id=unique_id)
return (img_tensor,) return (img_tensor,)

View File

@ -3,6 +3,7 @@ Pika x ComfyUI API Nodes
Pika API docs: https://pika-827374fb.mintlify.app/api-reference Pika API docs: https://pika-827374fb.mintlify.app/api-reference
""" """
from __future__ import annotations
import io import io
from typing import Optional, TypeVar from typing import Optional, TypeVar
@ -120,7 +121,10 @@ class PikaNodeBase(ComfyNodeABC):
RETURN_TYPES = ("VIDEO",) RETURN_TYPES = ("VIDEO",)
def poll_for_task_status( def poll_for_task_status(
self, task_id: str, auth_token: str self,
task_id: str,
auth_kwargs: Optional[dict[str, str]] = None,
node_id: Optional[str] = None,
) -> PikaGenerateResponse: ) -> PikaGenerateResponse:
polling_operation = PollingOperation( polling_operation = PollingOperation(
poll_endpoint=ApiEndpoint( poll_endpoint=ApiEndpoint(
@ -139,20 +143,26 @@ class PikaNodeBase(ComfyNodeABC):
progress_extractor=lambda response: ( progress_extractor=lambda response: (
response.progress if hasattr(response, "progress") else None response.progress if hasattr(response, "progress") else None
), ),
auth_token=auth_token, auth_kwargs=auth_kwargs,
result_url_extractor=lambda response: (
response.url if hasattr(response, "url") else None
),
node_id=node_id,
estimated_duration=60
) )
return polling_operation.execute() return polling_operation.execute()
def execute_task( def execute_task(
self, self,
initial_operation: SynchronousOperation[R, PikaGenerateResponse], initial_operation: SynchronousOperation[R, PikaGenerateResponse],
auth_token: Optional[str] = None, auth_kwargs: Optional[dict[str, str]] = None,
node_id: Optional[str] = None,
) -> tuple[VideoFromFile]: ) -> tuple[VideoFromFile]:
"""Executes the initial operation then polls for the task status until it is completed. """Executes the initial operation then polls for the task status until it is completed.
Args: Args:
initial_operation: The initial operation to execute. initial_operation: The initial operation to execute.
auth_token: The authentication token to use for the API call. auth_kwargs: The authentication token(s) to use for the API call.
Returns: Returns:
A tuple containing the video file as a VIDEO output. A tuple containing the video file as a VIDEO output.
@ -164,7 +174,7 @@ class PikaNodeBase(ComfyNodeABC):
raise PikaApiError(error_msg) raise PikaApiError(error_msg)
task_id = initial_response.video_id task_id = initial_response.video_id
final_response = self.poll_for_task_status(task_id, auth_token) final_response = self.poll_for_task_status(task_id, auth_kwargs)
if not is_valid_video_response(final_response): if not is_valid_video_response(final_response):
error_msg = ( error_msg = (
f"Pika task {task_id} succeeded but no video data found in response." f"Pika task {task_id} succeeded but no video data found in response."
@ -193,6 +203,7 @@ class PikaImageToVideoV2_2(PikaNodeBase):
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
}, },
} }
@ -206,7 +217,8 @@ class PikaImageToVideoV2_2(PikaNodeBase):
seed: int, seed: int,
resolution: str, resolution: str,
duration: int, duration: int,
auth_token: Optional[str] = None, unique_id: str,
**kwargs,
) -> tuple[VideoFromFile]: ) -> tuple[VideoFromFile]:
# Convert image to BytesIO # Convert image to BytesIO
image_bytes_io = tensor_to_bytesio(image) image_bytes_io = tensor_to_bytesio(image)
@ -233,10 +245,10 @@ class PikaImageToVideoV2_2(PikaNodeBase):
request=pika_request_data, request=pika_request_data,
files=pika_files, files=pika_files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_token=auth_token, auth_kwargs=kwargs,
) )
return self.execute_task(initial_operation, auth_token) return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaTextToVideoNodeV2_2(PikaNodeBase): class PikaTextToVideoNodeV2_2(PikaNodeBase):
@ -259,6 +271,8 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -272,7 +286,8 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
resolution: str, resolution: str,
duration: int, duration: int,
aspect_ratio: float, aspect_ratio: float,
auth_token: Optional[str] = None, unique_id: str,
**kwargs,
) -> tuple[VideoFromFile]: ) -> tuple[VideoFromFile]:
initial_operation = SynchronousOperation( initial_operation = SynchronousOperation(
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
@ -289,11 +304,11 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
duration=duration, duration=duration,
aspectRatio=aspect_ratio, aspectRatio=aspect_ratio,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
content_type="application/x-www-form-urlencoded", content_type="application/x-www-form-urlencoded",
) )
return self.execute_task(initial_operation, auth_token) return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaScenesV2_2(PikaNodeBase): class PikaScenesV2_2(PikaNodeBase):
@ -336,6 +351,8 @@ class PikaScenesV2_2(PikaNodeBase):
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -350,12 +367,13 @@ class PikaScenesV2_2(PikaNodeBase):
duration: int, duration: int,
ingredients_mode: str, ingredients_mode: str,
aspect_ratio: float, aspect_ratio: float,
unique_id: str,
image_ingredient_1: Optional[torch.Tensor] = None, image_ingredient_1: Optional[torch.Tensor] = None,
image_ingredient_2: Optional[torch.Tensor] = None, image_ingredient_2: Optional[torch.Tensor] = None,
image_ingredient_3: Optional[torch.Tensor] = None, image_ingredient_3: Optional[torch.Tensor] = None,
image_ingredient_4: Optional[torch.Tensor] = None, image_ingredient_4: Optional[torch.Tensor] = None,
image_ingredient_5: Optional[torch.Tensor] = None, image_ingredient_5: Optional[torch.Tensor] = None,
auth_token: Optional[str] = None, **kwargs,
) -> tuple[VideoFromFile]: ) -> tuple[VideoFromFile]:
# Convert all passed images to BytesIO # Convert all passed images to BytesIO
all_image_bytes_io = [] all_image_bytes_io = []
@ -396,10 +414,10 @@ class PikaScenesV2_2(PikaNodeBase):
request=pika_request_data, request=pika_request_data,
files=pika_files, files=pika_files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_token=auth_token, auth_kwargs=kwargs,
) )
return self.execute_task(initial_operation, auth_token) return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikAdditionsNode(PikaNodeBase): class PikAdditionsNode(PikaNodeBase):
@ -434,6 +452,8 @@ class PikAdditionsNode(PikaNodeBase):
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -446,7 +466,8 @@ class PikAdditionsNode(PikaNodeBase):
prompt_text: str, prompt_text: str,
negative_prompt: str, negative_prompt: str,
seed: int, seed: int,
auth_token: Optional[str] = None, unique_id: str,
**kwargs,
) -> tuple[VideoFromFile]: ) -> tuple[VideoFromFile]:
# Convert video to BytesIO # Convert video to BytesIO
video_bytes_io = io.BytesIO() video_bytes_io = io.BytesIO()
@ -479,10 +500,10 @@ class PikAdditionsNode(PikaNodeBase):
request=pika_request_data, request=pika_request_data,
files=pika_files, files=pika_files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_token=auth_token, auth_kwargs=kwargs,
) )
return self.execute_task(initial_operation, auth_token) return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaSwapsNode(PikaNodeBase): class PikaSwapsNode(PikaNodeBase):
@ -526,6 +547,8 @@ class PikaSwapsNode(PikaNodeBase):
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -540,7 +563,8 @@ class PikaSwapsNode(PikaNodeBase):
prompt_text: str, prompt_text: str,
negative_prompt: str, negative_prompt: str,
seed: int, seed: int,
auth_token: Optional[str] = None, unique_id: str,
**kwargs,
) -> tuple[VideoFromFile]: ) -> tuple[VideoFromFile]:
# Convert video to BytesIO # Convert video to BytesIO
video_bytes_io = io.BytesIO() video_bytes_io = io.BytesIO()
@ -583,10 +607,10 @@ class PikaSwapsNode(PikaNodeBase):
request=pika_request_data, request=pika_request_data,
files=pika_files, files=pika_files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_token=auth_token, auth_kwargs=kwargs,
) )
return self.execute_task(initial_operation, auth_token) return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaffectsNode(PikaNodeBase): class PikaffectsNode(PikaNodeBase):
@ -630,6 +654,8 @@ class PikaffectsNode(PikaNodeBase):
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -642,7 +668,8 @@ class PikaffectsNode(PikaNodeBase):
prompt_text: str, prompt_text: str,
negative_prompt: str, negative_prompt: str,
seed: int, seed: int,
auth_token: Optional[str] = None, unique_id: str,
**kwargs,
) -> tuple[VideoFromFile]: ) -> tuple[VideoFromFile]:
initial_operation = SynchronousOperation( initial_operation = SynchronousOperation(
@ -660,10 +687,10 @@ class PikaffectsNode(PikaNodeBase):
), ),
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")}, files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
content_type="multipart/form-data", content_type="multipart/form-data",
auth_token=auth_token, auth_kwargs=kwargs,
) )
return self.execute_task(initial_operation, auth_token) return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaStartEndFrameNode2_2(PikaNodeBase): class PikaStartEndFrameNode2_2(PikaNodeBase):
@ -681,6 +708,8 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -695,7 +724,8 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
seed: int, seed: int,
resolution: str, resolution: str,
duration: int, duration: int,
auth_token: Optional[str] = None, unique_id: str,
**kwargs,
) -> tuple[VideoFromFile]: ) -> tuple[VideoFromFile]:
pika_files = [ pika_files = [
@ -722,10 +752,10 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
), ),
files=pika_files, files=pika_files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_token=auth_token, auth_kwargs=kwargs,
) )
return self.execute_task(initial_operation, auth_token) return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {

View File

@ -1,5 +1,5 @@
from inspect import cleandoc from inspect import cleandoc
from typing import Optional
from comfy_api_nodes.apis.pixverse_api import ( from comfy_api_nodes.apis.pixverse_api import (
PixverseTextVideoRequest, PixverseTextVideoRequest,
PixverseImageVideoRequest, PixverseImageVideoRequest,
@ -34,11 +34,22 @@ import requests
from io import BytesIO from io import BytesIO
def upload_image_to_pixverse(image: torch.Tensor, auth_token=None): AVERAGE_DURATION_T2V = 32
AVERAGE_DURATION_I2V = 30
AVERAGE_DURATION_T2T = 52
def get_video_url_from_response(
response: PixverseGenerationStatusResponse,
) -> Optional[str]:
if response.Resp is None or response.Resp.url is None:
return None
return str(response.Resp.url)
def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
# first, upload image to Pixverse and get image id to use in actual generation call # first, upload image to Pixverse and get image id to use in actual generation call
files = { files = {"image": tensor_to_bytesio(image)}
"image": tensor_to_bytesio(image)
}
operation = SynchronousOperation( operation = SynchronousOperation(
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
path="/proxy/pixverse/image/upload", path="/proxy/pixverse/image/upload",
@ -49,12 +60,14 @@ def upload_image_to_pixverse(image: torch.Tensor, auth_token=None):
request=EmptyRequest(), request=EmptyRequest(),
files=files, files=files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_token=auth_token, auth_kwargs=auth_kwargs,
) )
response_upload: PixverseImageUploadResponse = operation.execute() response_upload: PixverseImageUploadResponse = operation.execute()
if response_upload.Resp is None: if response_upload.Resp is None:
raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'") raise Exception(
f"PixVerse image upload request failed: '{response_upload.ErrMsg}'"
)
return response_upload.Resp.img_id return response_upload.Resp.img_id
@ -73,7 +86,7 @@ class PixverseTemplateNode:
def INPUT_TYPES(s): def INPUT_TYPES(s):
return { return {
"required": { "required": {
"template": (list(pixverse_templates.keys()), ), "template": (list(pixverse_templates.keys()),),
} }
} }
@ -87,7 +100,7 @@ class PixverseTemplateNode:
class PixverseTextToVideoNode(ComfyNodeABC): class PixverseTextToVideoNode(ComfyNodeABC):
""" """
Generates videos synchronously based on prompt and output_size. Generates videos based on prompt and output_size.
""" """
RETURN_TYPES = (IO.VIDEO,) RETURN_TYPES = (IO.VIDEO,)
@ -108,9 +121,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
"tooltip": "Prompt for the video generation", "tooltip": "Prompt for the video generation",
}, },
), ),
"aspect_ratio": ( "aspect_ratio": ([ratio.value for ratio in PixverseAspectRatio],),
[ratio.value for ratio in PixverseAspectRatio],
),
"quality": ( "quality": (
[resolution.value for resolution in PixverseQuality], [resolution.value for resolution in PixverseQuality],
{ {
@ -143,11 +154,13 @@ class PixverseTextToVideoNode(ComfyNodeABC):
PixverseIO.TEMPLATE, PixverseIO.TEMPLATE,
{ {
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node." "tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
} },
) ),
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -159,9 +172,9 @@ class PixverseTextToVideoNode(ComfyNodeABC):
duration_seconds: int, duration_seconds: int,
motion_mode: str, motion_mode: str,
seed, seed,
negative_prompt: str=None, negative_prompt: str = None,
pixverse_template: int=None, pixverse_template: int = None,
auth_token=None, unique_id: Optional[str] = None,
**kwargs, **kwargs,
): ):
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
@ -190,7 +203,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
template_id=pixverse_template, template_id=pixverse_template,
seed=seed, seed=seed,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
response_api = operation.execute() response_api = operation.execute()
@ -205,19 +218,27 @@ class PixverseTextToVideoNode(ComfyNodeABC):
response_model=PixverseGenerationStatusResponse, response_model=PixverseGenerationStatusResponse,
), ),
completed_statuses=[PixverseStatus.successful], completed_statuses=[PixverseStatus.successful],
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted], failed_statuses=[
PixverseStatus.contents_moderation,
PixverseStatus.failed,
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status, status_extractor=lambda x: x.Resp.status,
auth_token=auth_token, auth_kwargs=kwargs,
node_id=unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
) )
response_poll = operation.execute() response_poll = operation.execute()
vid_response = requests.get(response_poll.Resp.url) vid_response = requests.get(response_poll.Resp.url)
return (VideoFromFile(BytesIO(vid_response.content)),) return (VideoFromFile(BytesIO(vid_response.content)),)
class PixverseImageToVideoNode(ComfyNodeABC): class PixverseImageToVideoNode(ComfyNodeABC):
""" """
Generates videos synchronously based on prompt and output_size. Generates videos based on prompt and output_size.
""" """
RETURN_TYPES = (IO.VIDEO,) RETURN_TYPES = (IO.VIDEO,)
@ -230,9 +251,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
def INPUT_TYPES(s): def INPUT_TYPES(s):
return { return {
"required": { "required": {
"image": ( "image": (IO.IMAGE,),
IO.IMAGE,
),
"prompt": ( "prompt": (
IO.STRING, IO.STRING,
{ {
@ -273,11 +292,13 @@ class PixverseImageToVideoNode(ComfyNodeABC):
PixverseIO.TEMPLATE, PixverseIO.TEMPLATE,
{ {
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node." "tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
} },
) ),
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -289,13 +310,13 @@ class PixverseImageToVideoNode(ComfyNodeABC):
duration_seconds: int, duration_seconds: int,
motion_mode: str, motion_mode: str,
seed, seed,
negative_prompt: str=None, negative_prompt: str = None,
pixverse_template: int=None, pixverse_template: int = None,
auth_token=None, unique_id: Optional[str] = None,
**kwargs, **kwargs,
): ):
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
img_id = upload_image_to_pixverse(image, auth_token=auth_token) img_id = upload_image_to_pixverse(image, auth_kwargs=kwargs)
# 1080p is limited to 5 seconds duration # 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration # only normal motion_mode supported for 1080p or for non-5 second duration
@ -322,7 +343,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
template_id=pixverse_template, template_id=pixverse_template,
seed=seed, seed=seed,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
response_api = operation.execute() response_api = operation.execute()
@ -337,9 +358,16 @@ class PixverseImageToVideoNode(ComfyNodeABC):
response_model=PixverseGenerationStatusResponse, response_model=PixverseGenerationStatusResponse,
), ),
completed_statuses=[PixverseStatus.successful], completed_statuses=[PixverseStatus.successful],
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted], failed_statuses=[
PixverseStatus.contents_moderation,
PixverseStatus.failed,
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status, status_extractor=lambda x: x.Resp.status,
auth_token=auth_token, auth_kwargs=kwargs,
node_id=unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_I2V,
) )
response_poll = operation.execute() response_poll = operation.execute()
@ -349,7 +377,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
class PixverseTransitionVideoNode(ComfyNodeABC): class PixverseTransitionVideoNode(ComfyNodeABC):
""" """
Generates videos synchronously based on prompt and output_size. Generates videos based on prompt and output_size.
""" """
RETURN_TYPES = (IO.VIDEO,) RETURN_TYPES = (IO.VIDEO,)
@ -362,12 +390,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
def INPUT_TYPES(s): def INPUT_TYPES(s):
return { return {
"required": { "required": {
"first_frame": ( "first_frame": (IO.IMAGE,),
IO.IMAGE, "last_frame": (IO.IMAGE,),
),
"last_frame": (
IO.IMAGE,
),
"prompt": ( "prompt": (
IO.STRING, IO.STRING,
{ {
@ -407,6 +431,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -419,13 +445,13 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
duration_seconds: int, duration_seconds: int,
motion_mode: str, motion_mode: str,
seed, seed,
negative_prompt: str=None, negative_prompt: str = None,
auth_token=None, unique_id: Optional[str] = None,
**kwargs, **kwargs,
): ):
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
first_frame_id = upload_image_to_pixverse(first_frame, auth_token=auth_token) first_frame_id = upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
last_frame_id = upload_image_to_pixverse(last_frame, auth_token=auth_token) last_frame_id = upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
# 1080p is limited to 5 seconds duration # 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration # only normal motion_mode supported for 1080p or for non-5 second duration
@ -452,7 +478,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
negative_prompt=negative_prompt if negative_prompt else None, negative_prompt=negative_prompt if negative_prompt else None,
seed=seed, seed=seed,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
response_api = operation.execute() response_api = operation.execute()
@ -467,9 +493,16 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
response_model=PixverseGenerationStatusResponse, response_model=PixverseGenerationStatusResponse,
), ),
completed_statuses=[PixverseStatus.successful], completed_statuses=[PixverseStatus.successful],
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted], failed_statuses=[
PixverseStatus.contents_moderation,
PixverseStatus.failed,
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status, status_extractor=lambda x: x.Resp.status,
auth_token=auth_token, auth_kwargs=kwargs,
node_id=unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
) )
response_poll = operation.execute() response_poll = operation.execute()

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from inspect import cleandoc from inspect import cleandoc
from typing import Optional
from comfy.utils import ProgressBar from comfy.utils import ProgressBar
from comfy_extras.nodes_images import SVG # Added from comfy_extras.nodes_images import SVG # Added
from comfy.comfy_types.node_typing import IO from comfy.comfy_types.node_typing import IO
@ -29,6 +30,8 @@ from comfy_api_nodes.apinode_utils import (
resize_mask_to_image, resize_mask_to_image,
validate_string, validate_string,
) )
from server import PromptServer
import torch import torch
from io import BytesIO from io import BytesIO
from PIL import UnidentifiedImageError from PIL import UnidentifiedImageError
@ -41,7 +44,7 @@ def handle_recraft_file_request(
total_pixels=4096*4096, total_pixels=4096*4096,
timeout=1024, timeout=1024,
request=None, request=None,
auth_token=None auth_kwargs: dict[str,str] = None,
) -> list[BytesIO]: ) -> list[BytesIO]:
""" """
Handle sending common Recraft file-only request to get back file bytes. Handle sending common Recraft file-only request to get back file bytes.
@ -65,7 +68,7 @@ def handle_recraft_file_request(
request=request, request=request,
files=files, files=files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_token=auth_token, auth_kwargs=auth_kwargs,
multipart_parser=recraft_multipart_parser, multipart_parser=recraft_multipart_parser,
) )
response: RecraftImageGenerationResponse = operation.execute() response: RecraftImageGenerationResponse = operation.execute()
@ -387,6 +390,8 @@ class RecraftTextToImageNode:
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -399,7 +404,7 @@ class RecraftTextToImageNode:
recraft_style: RecraftStyle = None, recraft_style: RecraftStyle = None,
negative_prompt: str = None, negative_prompt: str = None,
recraft_controls: RecraftControls = None, recraft_controls: RecraftControls = None,
auth_token=None, unique_id: Optional[str] = None,
**kwargs, **kwargs,
): ):
validate_string(prompt, strip_whitespace=False, max_length=1000) validate_string(prompt, strip_whitespace=False, max_length=1000)
@ -432,12 +437,19 @@ class RecraftTextToImageNode:
style_id=recraft_style.style_id, style_id=recraft_style.style_id,
controls=controls_api, controls=controls_api,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
response: RecraftImageGenerationResponse = operation.execute() response: RecraftImageGenerationResponse = operation.execute()
images = [] images = []
urls = []
for data in response.data: for data in response.data:
with handle_recraft_image_output(): with handle_recraft_image_output():
if unique_id and data.url:
urls.append(data.url)
urls_string = '\n'.join(urls)
PromptServer.instance.send_progress_text(
f"Result URL: {urls_string}", unique_id
)
image = bytesio_to_image_tensor( image = bytesio_to_image_tensor(
download_url_to_bytesio(data.url, timeout=1024) download_url_to_bytesio(data.url, timeout=1024)
) )
@ -522,6 +534,7 @@ class RecraftImageToImageNode:
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
}, },
} }
@ -532,7 +545,6 @@ class RecraftImageToImageNode:
n: int, n: int,
strength: float, strength: float,
seed, seed,
auth_token=None,
recraft_style: RecraftStyle = None, recraft_style: RecraftStyle = None,
negative_prompt: str = None, negative_prompt: str = None,
recraft_controls: RecraftControls = None, recraft_controls: RecraftControls = None,
@ -570,7 +582,7 @@ class RecraftImageToImageNode:
image=image[i], image=image[i],
path="/proxy/recraft/images/imageToImage", path="/proxy/recraft/images/imageToImage",
request=request, request=request,
auth_token=auth_token, auth_kwargs=kwargs,
) )
with handle_recraft_image_output(): with handle_recraft_image_output():
images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
@ -638,6 +650,7 @@ class RecraftImageInpaintingNode:
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
}, },
} }
@ -648,7 +661,6 @@ class RecraftImageInpaintingNode:
prompt: str, prompt: str,
n: int, n: int,
seed, seed,
auth_token=None,
recraft_style: RecraftStyle = None, recraft_style: RecraftStyle = None,
negative_prompt: str = None, negative_prompt: str = None,
**kwargs, **kwargs,
@ -683,7 +695,7 @@ class RecraftImageInpaintingNode:
mask=mask[i:i+1], mask=mask[i:i+1],
path="/proxy/recraft/images/inpaint", path="/proxy/recraft/images/inpaint",
request=request, request=request,
auth_token=auth_token, auth_kwargs=kwargs,
) )
with handle_recraft_image_output(): with handle_recraft_image_output():
images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
@ -762,6 +774,8 @@ class RecraftTextToVectorNode:
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -774,7 +788,7 @@ class RecraftTextToVectorNode:
seed, seed,
negative_prompt: str = None, negative_prompt: str = None,
recraft_controls: RecraftControls = None, recraft_controls: RecraftControls = None,
auth_token=None, unique_id: Optional[str] = None,
**kwargs, **kwargs,
): ):
validate_string(prompt, strip_whitespace=False, max_length=1000) validate_string(prompt, strip_whitespace=False, max_length=1000)
@ -805,11 +819,18 @@ class RecraftTextToVectorNode:
substyle=recraft_style.substyle, substyle=recraft_style.substyle,
controls=controls_api, controls=controls_api,
), ),
auth_token=auth_token, auth_kwargs=kwargs,
) )
response: RecraftImageGenerationResponse = operation.execute() response: RecraftImageGenerationResponse = operation.execute()
svg_data = [] svg_data = []
urls = []
for data in response.data: for data in response.data:
if unique_id and data.url:
urls.append(data.url)
# Print result on each iteration in case of error
PromptServer.instance.send_progress_text(
f"Result URL: {' '.join(urls)}", unique_id
)
svg_data.append(download_url_to_bytesio(data.url, timeout=1024)) svg_data.append(download_url_to_bytesio(data.url, timeout=1024))
return (SVG(svg_data),) return (SVG(svg_data),)
@ -836,13 +857,13 @@ class RecraftVectorizeImageNode:
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
}, },
} }
def api_call( def api_call(
self, self,
image: torch.Tensor, image: torch.Tensor,
auth_token=None,
**kwargs, **kwargs,
): ):
svgs = [] svgs = []
@ -852,7 +873,7 @@ class RecraftVectorizeImageNode:
sub_bytes = handle_recraft_file_request( sub_bytes = handle_recraft_file_request(
image=image[i], image=image[i],
path="/proxy/recraft/images/vectorize", path="/proxy/recraft/images/vectorize",
auth_token=auth_token, auth_kwargs=kwargs,
) )
svgs.append(SVG(sub_bytes)) svgs.append(SVG(sub_bytes))
pbar.update(1) pbar.update(1)
@ -917,6 +938,7 @@ class RecraftReplaceBackgroundNode:
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
}, },
} }
@ -926,7 +948,6 @@ class RecraftReplaceBackgroundNode:
prompt: str, prompt: str,
n: int, n: int,
seed, seed,
auth_token=None,
recraft_style: RecraftStyle = None, recraft_style: RecraftStyle = None,
negative_prompt: str = None, negative_prompt: str = None,
**kwargs, **kwargs,
@ -956,7 +977,7 @@ class RecraftReplaceBackgroundNode:
image=image[i], image=image[i],
path="/proxy/recraft/images/replaceBackground", path="/proxy/recraft/images/replaceBackground",
request=request, request=request,
auth_token=auth_token, auth_kwargs=kwargs,
) )
images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
pbar.update(1) pbar.update(1)
@ -986,13 +1007,13 @@ class RecraftRemoveBackgroundNode:
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
}, },
} }
def api_call( def api_call(
self, self,
image: torch.Tensor, image: torch.Tensor,
auth_token=None,
**kwargs, **kwargs,
): ):
images = [] images = []
@ -1002,7 +1023,7 @@ class RecraftRemoveBackgroundNode:
sub_bytes = handle_recraft_file_request( sub_bytes = handle_recraft_file_request(
image=image[i], image=image[i],
path="/proxy/recraft/images/removeBackground", path="/proxy/recraft/images/removeBackground",
auth_token=auth_token, auth_kwargs=kwargs,
) )
images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
pbar.update(1) pbar.update(1)
@ -1037,13 +1058,13 @@ class RecraftCrispUpscaleNode:
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
}, },
} }
def api_call( def api_call(
self, self,
image: torch.Tensor, image: torch.Tensor,
auth_token=None,
**kwargs, **kwargs,
): ):
images = [] images = []
@ -1053,7 +1074,7 @@ class RecraftCrispUpscaleNode:
sub_bytes = handle_recraft_file_request( sub_bytes = handle_recraft_file_request(
image=image[i], image=image[i],
path=self.RECRAFT_PATH, path=self.RECRAFT_PATH,
auth_token=auth_token, auth_kwargs=kwargs,
) )
images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
pbar.update(1) pbar.update(1)

View File

@ -120,12 +120,13 @@ class StabilityStableImageUltraNode:
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
}, },
} }
def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int, def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None, negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
auth_token=None): **kwargs):
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
# prepare image binary if image present # prepare image binary if image present
image_binary = None image_binary = None
@ -160,7 +161,7 @@ class StabilityStableImageUltraNode:
), ),
files=files, files=files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_token=auth_token, auth_kwargs=kwargs,
) )
response_api = operation.execute() response_api = operation.execute()
@ -252,12 +253,13 @@ class StabilityStableImageSD_3_5Node:
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
}, },
} }
def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float, def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None, negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
auth_token=None): **kwargs):
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
# prepare image binary if image present # prepare image binary if image present
image_binary = None image_binary = None
@ -298,7 +300,7 @@ class StabilityStableImageSD_3_5Node:
), ),
files=files, files=files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_token=auth_token, auth_kwargs=kwargs,
) )
response_api = operation.execute() response_api = operation.execute()
@ -368,11 +370,12 @@ class StabilityUpscaleConservativeNode:
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
}, },
} }
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None, def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
auth_token=None): **kwargs):
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read() image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
@ -398,7 +401,7 @@ class StabilityUpscaleConservativeNode:
), ),
files=files, files=files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_token=auth_token, auth_kwargs=kwargs,
) )
response_api = operation.execute() response_api = operation.execute()
@ -473,11 +476,12 @@ class StabilityUpscaleCreativeNode:
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
}, },
} }
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None, def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
auth_token=None): **kwargs):
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read() image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
@ -506,7 +510,7 @@ class StabilityUpscaleCreativeNode:
), ),
files=files, files=files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_token=auth_token, auth_kwargs=kwargs,
) )
response_api = operation.execute() response_api = operation.execute()
@ -521,7 +525,7 @@ class StabilityUpscaleCreativeNode:
completed_statuses=[StabilityPollStatus.finished], completed_statuses=[StabilityPollStatus.finished],
failed_statuses=[StabilityPollStatus.failed], failed_statuses=[StabilityPollStatus.failed],
status_extractor=lambda x: get_async_dummy_status(x), status_extractor=lambda x: get_async_dummy_status(x),
auth_token=auth_token, auth_kwargs=kwargs,
) )
response_poll: StabilityResultsGetResponse = operation.execute() response_poll: StabilityResultsGetResponse = operation.execute()
@ -555,11 +559,12 @@ class StabilityUpscaleFastNode:
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
}, },
} }
def api_call(self, image: torch.Tensor, def api_call(self, image: torch.Tensor,
auth_token=None): **kwargs):
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read() image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
files = { files = {
@ -576,7 +581,7 @@ class StabilityUpscaleFastNode:
request=EmptyRequest(), request=EmptyRequest(),
files=files, files=files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_token=auth_token, auth_kwargs=kwargs,
) )
response_api = operation.execute() response_api = operation.execute()

View File

@ -3,6 +3,7 @@ import logging
import base64 import base64
import requests import requests
import torch import torch
from typing import Optional
from comfy.comfy_types.node_typing import IO, ComfyNodeABC from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl.video_types import VideoFromFile from comfy_api.input_impl.video_types import VideoFromFile
@ -24,6 +25,8 @@ from comfy_api_nodes.apinode_utils import (
tensor_to_base64_string tensor_to_base64_string
) )
AVERAGE_DURATION_VIDEO_GEN = 32
def convert_image_to_base64(image: torch.Tensor): def convert_image_to_base64(image: torch.Tensor):
if image is None: if image is None:
return None return None
@ -31,6 +34,22 @@ def convert_image_to_base64(image: torch.Tensor):
scaled_image = downscale_image_tensor(image, total_pixels=2048*2048) scaled_image = downscale_image_tensor(image, total_pixels=2048*2048)
return tensor_to_base64_string(scaled_image) return tensor_to_base64_string(scaled_image)
def get_video_url_from_response(poll_response: Veo2GenVidPollResponse) -> Optional[str]:
if (
poll_response.response
and hasattr(poll_response.response, "videos")
and poll_response.response.videos
and len(poll_response.response.videos) > 0
):
video = poll_response.response.videos[0]
else:
return None
if hasattr(video, "gcsUri") and video.gcsUri:
return str(video.gcsUri)
return None
class VeoVideoGenerationNode(ComfyNodeABC): class VeoVideoGenerationNode(ComfyNodeABC):
""" """
Generates videos from text prompts using Google's Veo API. Generates videos from text prompts using Google's Veo API.
@ -114,6 +133,8 @@ class VeoVideoGenerationNode(ComfyNodeABC):
}, },
"hidden": { "hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG", "auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
}, },
} }
@ -133,7 +154,8 @@ class VeoVideoGenerationNode(ComfyNodeABC):
person_generation="ALLOW", person_generation="ALLOW",
seed=0, seed=0,
image=None, image=None,
auth_token=None, unique_id: Optional[str] = None,
**kwargs,
): ):
# Prepare the instances for the request # Prepare the instances for the request
instances = [] instances = []
@ -179,7 +201,7 @@ class VeoVideoGenerationNode(ComfyNodeABC):
instances=instances, instances=instances,
parameters=parameters parameters=parameters
), ),
auth_token=auth_token auth_kwargs=kwargs,
) )
initial_response = initial_operation.execute() initial_response = initial_operation.execute()
@ -213,8 +235,11 @@ class VeoVideoGenerationNode(ComfyNodeABC):
request=Veo2GenVidPollRequest( request=Veo2GenVidPollRequest(
operationName=operation_name operationName=operation_name
), ),
auth_token=auth_token, auth_kwargs=kwargs,
poll_interval=5.0 poll_interval=5.0,
result_url_extractor=get_video_url_from_response,
node_id=unique_id,
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
) )
# Execute the polling operation # Execute the polling operation

76
comfy_extras/nodes_apg.py Normal file
View File

@ -0,0 +1,76 @@
import torch
def project(v0, v1):
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
return v0_parallel, v0_orthogonal
class APG:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
"norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}),
"momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "sampling/custom_sampling"
def patch(self, model, eta, norm_threshold, momentum):
running_avg = 0
prev_sigma = None
def pre_cfg_function(args):
nonlocal running_avg, prev_sigma
if len(args["conds_out"]) == 1: return args["conds_out"]
cond = args["conds_out"][0]
uncond = args["conds_out"][1]
sigma = args["sigma"][0]
cond_scale = args["cond_scale"]
if prev_sigma is not None and sigma > prev_sigma:
running_avg = 0
prev_sigma = sigma
guidance = cond - uncond
if momentum != 0:
if not torch.is_tensor(running_avg):
running_avg = guidance
else:
running_avg = momentum * running_avg + guidance
guidance = running_avg
if norm_threshold > 0:
guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
scale = torch.minimum(
torch.ones_like(guidance_norm),
norm_threshold / guidance_norm
)
guidance = guidance * scale
guidance_parallel, guidance_orthogonal = project(guidance, cond)
modified_guidance = guidance_orthogonal + eta * guidance_parallel
modified_cond = (uncond + modified_guidance) + (cond - uncond) / cond_scale
return [modified_cond, uncond] + args["conds_out"][2:]
m = model.clone()
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
return (m,)
NODE_CLASS_MAPPINGS = {
"APG": APG,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"APG": "Adaptive Projected Guidance",
}

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import av
import torchaudio import torchaudio
import torch import torch
import comfy.model_management import comfy.model_management
@ -7,7 +8,6 @@ import folder_paths
import os import os
import io import io
import json import json
import struct
import random import random
import hashlib import hashlib
import node_helpers import node_helpers
@ -90,60 +90,118 @@ class VAEDecodeAudio:
return ({"waveform": audio, "sample_rate": 44100}, ) return ({"waveform": audio, "sample_rate": 44100}, )
def create_vorbis_comment_block(comment_dict, last_block): def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
vendor_string = b'ComfyUI'
vendor_length = len(vendor_string)
comments = [] filename_prefix += self.prefix_append
for key, value in comment_dict.items(): full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
comment = f"{key}={value}".encode('utf-8') results: list[FileLocator] = []
comments.append(struct.pack('<I', len(comment)) + comment)
user_comment_list_length = len(comments) # Prepare metadata dictionary
user_comments = b''.join(comments) metadata = {}
if not args.disable_metadata:
if prompt is not None:
metadata["prompt"] = json.dumps(prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
comment_data = struct.pack('<I', vendor_length) + vendor_string + struct.pack('<I', user_comment_list_length) + user_comments # Opus supported sample rates
if last_block: OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
id = b'\x84'
else:
id = b'\x04'
comment_block = id + struct.pack('>I', len(comment_data))[1:] + comment_data
return comment_block for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
output_path = os.path.join(full_output_folder, file)
def insert_or_replace_vorbis_comment(flac_io, comment_dict): # Use original sample rate initially
if len(comment_dict) == 0: sample_rate = audio["sample_rate"]
return flac_io
flac_io.seek(4) # Handle Opus sample rate requirements
if format == "opus":
if sample_rate > 48000:
sample_rate = 48000
elif sample_rate not in OPUS_RATES:
# Find the next highest supported rate
for rate in sorted(OPUS_RATES):
if rate > sample_rate:
sample_rate = rate
break
if sample_rate not in OPUS_RATES: # Fallback if still not supported
sample_rate = 48000
blocks = [] # Resample if necessary
last_block = False if sample_rate != audio["sample_rate"]:
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
while not last_block: # Create in-memory WAV buffer
header = flac_io.read(4) wav_buffer = io.BytesIO()
last_block = (header[0] & 0x80) != 0 torchaudio.save(wav_buffer, waveform, sample_rate, format="WAV")
block_type = header[0] & 0x7F wav_buffer.seek(0) # Rewind for reading
block_length = struct.unpack('>I', b'\x00' + header[1:])[0]
block_data = flac_io.read(block_length)
if block_type == 4 or block_type == 1: # Use PyAV to convert and add metadata
pass input_container = av.open(wav_buffer)
else:
header = bytes([(header[0] & (~0x80))]) + header[1:]
blocks.append(header + block_data)
blocks.append(create_vorbis_comment_block(comment_dict, last_block=True)) # Create output with specified format
output_buffer = io.BytesIO()
output_container = av.open(output_buffer, mode='w', format=format)
new_flac_io = io.BytesIO() # Set metadata on the container
new_flac_io.write(b'fLaC') for key, value in metadata.items():
for block in blocks: output_container.metadata[key] = value
new_flac_io.write(block)
new_flac_io.write(flac_io.read()) # Set up the output stream with appropriate properties
return new_flac_io input_container.streams.audio[0]
if format == "opus":
out_stream = output_container.add_stream("libopus", rate=sample_rate)
if quality == "64k":
out_stream.bit_rate = 64000
elif quality == "96k":
out_stream.bit_rate = 96000
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "192k":
out_stream.bit_rate = 192000
elif quality == "320k":
out_stream.bit_rate = 320000
elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
if quality == "V0":
#TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
out_stream.codec_context.qscale = 1
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "320k":
out_stream.bit_rate = 320000
else: #format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate)
# Copy frames from input to output
for frame in input_container.decode(audio=0):
frame.pts = None # Let PyAV handle timestamps
output_container.mux(out_stream.encode(frame))
# Flush encoder
output_container.mux(out_stream.encode(None))
# Close containers
output_container.close()
input_container.close()
# Write the output to file
output_buffer.seek(0)
with open(output_path, 'wb') as f:
f.write(output_buffer.getbuffer())
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
return { "ui": { "audio": results } }
class SaveAudio: class SaveAudio:
def __init__(self): def __init__(self):
self.output_dir = folder_paths.get_output_directory() self.output_dir = folder_paths.get_output_directory()
@ -153,50 +211,70 @@ class SaveAudio:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ), return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"})}, "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
} }
RETURN_TYPES = () RETURN_TYPES = ()
FUNCTION = "save_audio" FUNCTION = "save_flac"
OUTPUT_NODE = True OUTPUT_NODE = True
CATEGORY = "audio" CATEGORY = "audio"
def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
results: list[FileLocator] = []
metadata = {} class SaveAudioMP3:
if not args.disable_metadata: def __init__(self):
if prompt is not None: self.output_dir = folder_paths.get_output_directory()
metadata["prompt"] = json.dumps(prompt) self.type = "output"
if extra_pnginfo is not None: self.prefix_append = ""
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()): @classmethod
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) def INPUT_TYPES(s):
file = f"{filename_with_batch_num}_{counter:05}_.flac" return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
"quality": (["V0", "128k", "320k"], {"default": "V0"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
buff = io.BytesIO() RETURN_TYPES = ()
torchaudio.save(buff, waveform, audio["sample_rate"], format="FLAC") FUNCTION = "save_mp3"
buff = insert_or_replace_vorbis_comment(buff, metadata) OUTPUT_NODE = True
with open(os.path.join(full_output_folder, file), 'wb') as f: CATEGORY = "audio"
f.write(buff.getbuffer())
results.append({ def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
"filename": file, return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
"subfolder": subfolder,
"type": self.type
})
counter += 1
return { "ui": { "audio": results } } class SaveAudioOpus:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
"quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_opus"
OUTPUT_NODE = True
CATEGORY = "audio"
def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
class PreviewAudio(SaveAudio): class PreviewAudio(SaveAudio):
def __init__(self): def __init__(self):
@ -248,7 +326,20 @@ NODE_CLASS_MAPPINGS = {
"VAEEncodeAudio": VAEEncodeAudio, "VAEEncodeAudio": VAEEncodeAudio,
"VAEDecodeAudio": VAEDecodeAudio, "VAEDecodeAudio": VAEDecodeAudio,
"SaveAudio": SaveAudio, "SaveAudio": SaveAudio,
"SaveAudioMP3": SaveAudioMP3,
"SaveAudioOpus": SaveAudioOpus,
"LoadAudio": LoadAudio, "LoadAudio": LoadAudio,
"PreviewAudio": PreviewAudio, "PreviewAudio": PreviewAudio,
"ConditioningStableAudio": ConditioningStableAudio, "ConditioningStableAudio": ConditioningStableAudio,
} }
NODE_DISPLAY_NAME_MAPPINGS = {
"EmptyLatentAudio": "Empty Latent Audio",
"VAEEncodeAudio": "VAE Encode Audio",
"VAEDecodeAudio": "VAE Decode Audio",
"PreviewAudio": "Preview Audio",
"LoadAudio": "Load Audio",
"SaveAudio": "Save Audio (FLAC)",
"SaveAudioMP3": "Save Audio (MP3)",
"SaveAudioOpus": "Save Audio (Opus)",
}

View File

@ -0,0 +1,218 @@
import nodes
import torch
import numpy as np
from einops import rearrange
import comfy.model_management
MAX_RESOLUTION = nodes.MAX_RESOLUTION
CAMERA_DICT = {
"base_T_norm": 1.5,
"base_angle": np.pi/3,
"Static": { "angle":[0., 0., 0.], "T":[0., 0., 0.]},
"Pan Up": { "angle":[0., 0., 0.], "T":[0., -1., 0.]},
"Pan Down": { "angle":[0., 0., 0.], "T":[0.,1.,0.]},
"Pan Left": { "angle":[0., 0., 0.], "T":[-1.,0.,0.]},
"Pan Right": { "angle":[0., 0., 0.], "T": [1.,0.,0.]},
"Zoom In": { "angle":[0., 0., 0.], "T": [0.,0.,2.]},
"Zoom Out": { "angle":[0., 0., 0.], "T": [0.,0.,-2.]},
"Anti Clockwise (ACW)": { "angle": [0., 0., -1.], "T":[0., 0., 0.]},
"ClockWise (CW)": { "angle": [0., 0., 1.], "T":[0., 0., 0.]},
}
def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
def get_relative_pose(cam_params):
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
"""
abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
cam_to_origin = 0
target_cam_c2w = np.array([
[1, 0, 0, 0],
[0, 1, 0, -cam_to_origin],
[0, 0, 1, 0],
[0, 0, 0, 1]
])
abs2rel = target_cam_c2w @ abs_w2cs[0]
ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
ret_poses = np.array(ret_poses, dtype=np.float32)
return ret_poses
"""Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
"""
cam_params = [Camera(cam_param) for cam_param in cam_params]
sample_wh_ratio = width / height
pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
if pose_wh_ratio > sample_wh_ratio:
resized_ori_w = height * pose_wh_ratio
for cam_param in cam_params:
cam_param.fx = resized_ori_w * cam_param.fx / width
else:
resized_ori_h = width / pose_wh_ratio
for cam_param in cam_params:
cam_param.fy = resized_ori_h * cam_param.fy / height
intrinsic = np.asarray([[cam_param.fx * width,
cam_param.fy * height,
cam_param.cx * width,
cam_param.cy * height]
for cam_param in cam_params], dtype=np.float32)
K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
plucker_embedding = plucker_embedding[None]
plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
return plucker_embedding
class Camera(object):
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
"""
def __init__(self, entry):
fx, fy, cx, cy = entry[1:5]
self.fx = fx
self.fy = fy
self.cx = cx
self.cy = cy
c2w_mat = np.array(entry[7:]).reshape(4, 4)
self.c2w_mat = c2w_mat
self.w2c_mat = np.linalg.inv(c2w_mat)
def ray_condition(K, c2w, H, W, device):
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
"""
# c2w: B, V, 4, 4
# K: B, V, 4
B = K.shape[0]
j, i = torch.meshgrid(
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
indexing='ij'
)
i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
zs = torch.ones_like(i) # [B, HxW]
xs = (i - cx) / fx * zs
ys = (j - cy) / fy * zs
zs = zs.expand_as(ys)
directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
rays_o = c2w[..., :3, 3] # B, V, 3
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
# c2w @ dirctions
rays_dxo = torch.cross(rays_o, rays_d)
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
# plucker = plucker.permute(0, 1, 4, 2, 3)
return plucker
def get_camera_motion(angle, T, speed, n=81):
def compute_R_form_rad_angle(angles):
theta_x, theta_y, theta_z = angles
Rx = np.array([[1, 0, 0],
[0, np.cos(theta_x), -np.sin(theta_x)],
[0, np.sin(theta_x), np.cos(theta_x)]])
Ry = np.array([[np.cos(theta_y), 0, np.sin(theta_y)],
[0, 1, 0],
[-np.sin(theta_y), 0, np.cos(theta_y)]])
Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0],
[np.sin(theta_z), np.cos(theta_z), 0],
[0, 0, 1]])
R = np.dot(Rz, np.dot(Ry, Rx))
return R
RT = []
for i in range(n):
_angle = (i/n)*speed*(CAMERA_DICT["base_angle"])*angle
R = compute_R_form_rad_angle(_angle)
_T=(i/n)*speed*(CAMERA_DICT["base_T_norm"])*(T.reshape(3,1))
_RT = np.concatenate([R,_T], axis=1)
RT.append(_RT)
RT = np.stack(RT)
return RT
class WanCameraEmbedding:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"camera_pose":(["Static","Pan Up","Pan Down","Pan Left","Pan Right","Zoom In","Zoom Out","Anti Clockwise (ACW)", "ClockWise (CW)"],{"default":"Static"}),
"width": ("INT", {"default": 832, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": MAX_RESOLUTION, "step": 4}),
},
"optional":{
"speed":("FLOAT",{"default":1.0, "min": 0, "max": 10.0, "step": 0.1}),
"fx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
"fy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
"cx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
"cy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
}
}
RETURN_TYPES = ("WAN_CAMERA_EMBEDDING","INT","INT","INT")
RETURN_NAMES = ("camera_embedding","width","height","length")
FUNCTION = "run"
CATEGORY = "camera"
def run(self, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5):
"""
Use Camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmannet al., 2021)
Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py
"""
motion_list = [camera_pose]
speed = speed
angle = np.array(CAMERA_DICT[motion_list[0]]["angle"])
T = np.array(CAMERA_DICT[motion_list[0]]["T"])
RT = get_camera_motion(angle, T, speed, length)
trajs=[]
for cp in RT.tolist():
traj=[fx,fy,cx,cy,0,0]
traj.extend(cp[0])
traj.extend(cp[1])
traj.extend(cp[2])
traj.extend([0,0,0,1])
trajs.append(traj)
cam_params = np.array([[float(x) for x in pose] for pose in trajs])
cam_params = np.concatenate([np.zeros_like(cam_params[:, :1]), cam_params], 1)
control_camera_video = process_pose_params(cam_params, width=width, height=height)
control_camera_video = control_camera_video.permute([3, 0, 1, 2]).unsqueeze(0).to(device=comfy.model_management.intermediate_device())
control_camera_video = torch.concat(
[
torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
control_camera_video[:, :, 1:]
], dim=2
).transpose(1, 2)
# Reshape, transpose, and view into desired shape
b, f, c, h, w = control_camera_video.shape
control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
return (control_camera_video, width, height, length)
NODE_CLASS_MAPPINGS = {
"WanCameraEmbedding": WanCameraEmbedding,
}

View File

@ -31,6 +31,7 @@ class T5TokenizerOptions:
} }
} }
CATEGORY = "_for_testing/conditioning"
RETURN_TYPES = ("CLIP",) RETURN_TYPES = ("CLIP",)
FUNCTION = "set_options" FUNCTION = "set_options"

View File

@ -77,7 +77,7 @@ class HunyuanImageToVideo:
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), "length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"guidance_type": (["v1 (concat)", "v2 (replace)"], ) "guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], )
}, },
"optional": {"start_image": ("IMAGE", ), "optional": {"start_image": ("IMAGE", ),
}} }}
@ -101,10 +101,12 @@ class HunyuanImageToVideo:
if guidance_type == "v1 (concat)": if guidance_type == "v1 (concat)":
cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask} cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
else: elif guidance_type == "v2 (replace)":
cond = {'guiding_frame_index': 0} cond = {'guiding_frame_index': 0}
latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
out_latent["noise_mask"] = mask out_latent["noise_mask"] = mask
elif guidance_type == "custom":
cond = {"ref_latent": concat_latent_image}
positive = node_helpers.conditioning_set_values(positive, cond) positive = node_helpers.conditioning_set_values(positive, cond)

View File

@ -2,6 +2,10 @@ import nodes
import folder_paths import folder_paths
import os import os
from comfy.comfy_types import IO
from comfy_api.input_impl import VideoFromFile
def normalize_path(path): def normalize_path(path):
return path.replace('\\', '/') return path.replace('\\', '/')
@ -21,8 +25,8 @@ class Load3D():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}} }}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA") RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info") RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info", "recording_video")
FUNCTION = "process" FUNCTION = "process"
EXPERIMENTAL = True EXPERIMENTAL = True
@ -41,7 +45,14 @@ class Load3D():
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path) normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path) lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'] video = None
if image['recording'] != "":
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
video = VideoFromFile(recording_video_path)
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'], video
class Load3DAnimation(): class Load3DAnimation():
@classmethod @classmethod
@ -59,8 +70,8 @@ class Load3DAnimation():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}} }}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA") RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info") RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video")
FUNCTION = "process" FUNCTION = "process"
EXPERIMENTAL = True EXPERIMENTAL = True
@ -77,7 +88,14 @@ class Load3DAnimation():
ignore_image, output_mask = load_image_node.load_image(image=mask_path) ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path) normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
return output_image, output_mask, model_file, normal_image, image['camera_info'] video = None
if image['recording'] != "":
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
video = VideoFromFile(recording_video_path)
return output_image, output_mask, model_file, normal_image, image['camera_info'], video
class Preview3D(): class Preview3D():
@classmethod @classmethod

View File

@ -0,0 +1,322 @@
import re
from comfy.comfy_types.node_typing import IO
class StringConcatenate():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string_a": (IO.STRING, {"multiline": True}),
"string_b": (IO.STRING, {"multiline": True})
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string_a, string_b, **kwargs):
return string_a + string_b,
class StringSubstring():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"start": (IO.INT, {}),
"end": (IO.INT, {}),
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, start, end, **kwargs):
return string[start:end],
class StringLength():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True})
}
}
RETURN_TYPES = (IO.INT,)
RETURN_NAMES = ("length",)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, **kwargs):
length = len(string)
return length,
class CaseConverter():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"mode": (IO.COMBO, {"options": ["UPPERCASE", "lowercase", "Capitalize", "Title Case"]})
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, mode, **kwargs):
if mode == "UPPERCASE":
result = string.upper()
elif mode == "lowercase":
result = string.lower()
elif mode == "Capitalize":
result = string.capitalize()
elif mode == "Title Case":
result = string.title()
else:
result = string
return result,
class StringTrim():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"mode": (IO.COMBO, {"options": ["Both", "Left", "Right"]})
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, mode, **kwargs):
if mode == "Both":
result = string.strip()
elif mode == "Left":
result = string.lstrip()
elif mode == "Right":
result = string.rstrip()
else:
result = string
return result,
class StringReplace():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"find": (IO.STRING, {"multiline": True}),
"replace": (IO.STRING, {"multiline": True})
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, find, replace, **kwargs):
result = string.replace(find, replace)
return result,
class StringContains():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"substring": (IO.STRING, {"multiline": True}),
"case_sensitive": (IO.BOOLEAN, {"default": True})
}
}
RETURN_TYPES = (IO.BOOLEAN,)
RETURN_NAMES = ("contains",)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, substring, case_sensitive, **kwargs):
if case_sensitive:
contains = substring in string
else:
contains = substring.lower() in string.lower()
return contains,
class StringCompare():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string_a": (IO.STRING, {"multiline": True}),
"string_b": (IO.STRING, {"multiline": True}),
"mode": (IO.COMBO, {"options": ["Starts With", "Ends With", "Equal"]}),
"case_sensitive": (IO.BOOLEAN, {"default": True})
}
}
RETURN_TYPES = (IO.BOOLEAN,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string_a, string_b, mode, case_sensitive, **kwargs):
if case_sensitive:
a = string_a
b = string_b
else:
a = string_a.lower()
b = string_b.lower()
if mode == "Equal":
return a == b,
elif mode == "Starts With":
return a.startswith(b),
elif mode == "Ends With":
return a.endswith(b),
class RegexMatch():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"regex_pattern": (IO.STRING, {"multiline": True}),
"case_insensitive": (IO.BOOLEAN, {"default": True}),
"multiline": (IO.BOOLEAN, {"default": False}),
"dotall": (IO.BOOLEAN, {"default": False})
}
}
RETURN_TYPES = (IO.BOOLEAN,)
RETURN_NAMES = ("matches",)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, regex_pattern, case_insensitive, multiline, dotall, **kwargs):
flags = 0
if case_insensitive:
flags |= re.IGNORECASE
if multiline:
flags |= re.MULTILINE
if dotall:
flags |= re.DOTALL
try:
match = re.search(regex_pattern, string, flags)
result = match is not None
except re.error:
result = False
return result,
class RegexExtract():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"regex_pattern": (IO.STRING, {"multiline": True}),
"mode": (IO.COMBO, {"options": ["First Match", "All Matches", "First Group", "All Groups"]}),
"case_insensitive": (IO.BOOLEAN, {"default": True}),
"multiline": (IO.BOOLEAN, {"default": False}),
"dotall": (IO.BOOLEAN, {"default": False}),
"group_index": (IO.INT, {"default": 1, "min": 0, "max": 100})
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index, **kwargs):
join_delimiter = "\n"
flags = 0
if case_insensitive:
flags |= re.IGNORECASE
if multiline:
flags |= re.MULTILINE
if dotall:
flags |= re.DOTALL
try:
if mode == "First Match":
match = re.search(regex_pattern, string, flags)
if match:
result = match.group(0)
else:
result = ""
elif mode == "All Matches":
matches = re.findall(regex_pattern, string, flags)
if matches:
if isinstance(matches[0], tuple):
result = join_delimiter.join([m[0] for m in matches])
else:
result = join_delimiter.join(matches)
else:
result = ""
elif mode == "First Group":
match = re.search(regex_pattern, string, flags)
if match and len(match.groups()) >= group_index:
result = match.group(group_index)
else:
result = ""
elif mode == "All Groups":
matches = re.finditer(regex_pattern, string, flags)
results = []
for match in matches:
if match.groups() and len(match.groups()) >= group_index:
results.append(match.group(group_index))
result = join_delimiter.join(results)
else:
result = ""
except re.error:
result = ""
return result,
NODE_CLASS_MAPPINGS = {
"StringConcatenate": StringConcatenate,
"StringSubstring": StringSubstring,
"StringLength": StringLength,
"CaseConverter": CaseConverter,
"StringTrim": StringTrim,
"StringReplace": StringReplace,
"StringContains": StringContains,
"StringCompare": StringCompare,
"RegexMatch": RegexMatch,
"RegexExtract": RegexExtract
}
NODE_DISPLAY_NAME_MAPPINGS = {
"StringConcatenate": "Concatenate",
"StringSubstring": "Substring",
"StringLength": "Length",
"CaseConverter": "Case Converter",
"StringTrim": "Trim",
"StringReplace": "Replace",
"StringContains": "Contains",
"StringCompare": "Compare",
"RegexMatch": "Regex Match",
"RegexExtract": "Regex Extract"
}

View File

@ -297,6 +297,52 @@ class TrimVideoLatent:
samples_out["samples"] = s1[:, :, trim_amount:] samples_out["samples"] = s1[:, :, trim_amount:]
return (samples_out,) return (samples_out,)
class WanCameraImageToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"camera_conditions": ("WAN_CAMERA_EMBEDDING", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image[:, :, :, :3])
concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
if camera_conditions is not None:
positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions})
negative = node_helpers.conditioning_set_values(negative, {'camera_conditions': camera_conditions})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"WanImageToVideo": WanImageToVideo, "WanImageToVideo": WanImageToVideo,
@ -305,4 +351,5 @@ NODE_CLASS_MAPPINGS = {
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo, "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
"WanVaceToVideo": WanVaceToVideo, "WanVaceToVideo": WanVaceToVideo,
"TrimVideoLatent": TrimVideoLatent, "TrimVideoLatent": TrimVideoLatent,
"WanCameraImageToVideo": WanCameraImageToVideo,
} }

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.3.33" __version__ = "0.3.34"

View File

@ -1,155 +0,0 @@
class Example:
"""
A example node
Class methods
-------------
INPUT_TYPES (dict):
Tell the main program input parameters of nodes.
IS_CHANGED:
optional method to control when the node is re executed.
Attributes
----------
RETURN_TYPES (`tuple`):
The type of each element in the output tuple.
RETURN_NAMES (`tuple`):
Optional: The name of each output in the output tuple.
FUNCTION (`str`):
The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute()
OUTPUT_NODE ([`bool`]):
If this node is an output node that outputs a result/image from the graph. The SaveImage node is an example.
The backend iterates on these output nodes and tries to execute all their parents if their parent graph is properly connected.
Assumed to be False if not present.
CATEGORY (`str`):
The category the node should appear in the UI.
DEPRECATED (`bool`):
Indicates whether the node is deprecated. Deprecated nodes are hidden by default in the UI, but remain
functional in existing workflows that use them.
EXPERIMENTAL (`bool`):
Indicates whether the node is experimental. Experimental nodes are marked as such in the UI and may be subject to
significant changes or removal in future versions. Use with caution in production workflows.
execute(s) -> tuple || None:
The entry point method. The name of this method must be the same as the value of property `FUNCTION`.
For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`.
"""
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
"""
Return a dictionary which contains config for all input fields.
Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT".
Input types "INT", "STRING" or "FLOAT" are special values for fields on the node.
The type can be a list for selection.
Returns: `dict`:
- Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required`
- Value input_fields (`dict`): Contains input fields config:
* Key field_name (`string`): Name of a entry-point method's argument
* Value field_config (`tuple`):
+ First value is a string indicate the type of field or a list for selection.
+ Second value is a config for type "INT", "STRING" or "FLOAT".
"""
return {
"required": {
"image": ("IMAGE",),
"int_field": ("INT", {
"default": 0,
"min": 0, #Minimum value
"max": 4096, #Maximum value
"step": 64, #Slider's step
"display": "number", # Cosmetic only: display as "number" or "slider"
"lazy": True # Will only be evaluated if check_lazy_status requires it
}),
"float_field": ("FLOAT", {
"default": 1.0,
"min": 0.0,
"max": 10.0,
"step": 0.01,
"round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding.
"display": "number",
"lazy": True
}),
"print_to_screen": (["enable", "disable"],),
"string_field": ("STRING", {
"multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
"default": "Hello World!",
"lazy": True
}),
},
}
RETURN_TYPES = ("IMAGE",)
#RETURN_NAMES = ("image_output_name",)
FUNCTION = "test"
#OUTPUT_NODE = False
CATEGORY = "Example"
def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen):
"""
Return a list of input names that need to be evaluated.
This function will be called if there are any lazy inputs which have not yet been
evaluated. As long as you return at least one field which has not yet been evaluated
(and more exist), this function will be called again once the value of the requested
field is available.
Any evaluated inputs will be passed as arguments to this function. Any unevaluated
inputs will have the value None.
"""
if print_to_screen == "enable":
return ["int_field", "float_field", "string_field"]
else:
return []
def test(self, image, string_field, int_field, float_field, print_to_screen):
if print_to_screen == "enable":
print(f"""Your input contains:
string_field aka input text: {string_field}
int_field: {int_field}
float_field: {float_field}
""")
#do some processing on the image, in this example I just invert it
image = 1.0 - image
return (image,)
"""
The node will always be re executed if any of the inputs change but
this method can be used to force the node to execute again even when the inputs don't change.
You can make this node return a number or a string. This value will be compared to the one returned the last time the node was
executed, if it is different the node will be executed again.
This method is used in the core repo for the LoadImage node where they return the image hash as a string, if the image hash
changes between executions the LoadImage node is executed again.
"""
#@classmethod
#def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen):
# return ""
# Set the web directory, any .js file in that directory will be loaded by the frontend as a frontend extension
# WEB_DIRECTORY = "./somejs"
# Add custom API routes, using router
from aiohttp import web
from server import PromptServer
@PromptServer.instance.routes.get("/hello")
async def get_hello(request):
return web.json_response("hello")
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"Example": Example
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"Example": "Example Node"
}

View File

@ -1,44 +0,0 @@
from PIL import Image
import numpy as np
import comfy.utils
import time
#You can use this node to save full size images through the websocket, the
#images will be sent in exactly the same format as the image previews: as
#binary images on the websocket with a 8 byte header indicating the type
#of binary message (first 4 bytes) and the image format (next 4 bytes).
#Note that no metadata will be put in the images saved with this node.
class SaveImageWebsocket:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),}
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "api/image"
def save_images(self, images):
pbar = comfy.utils.ProgressBar(images.shape[0])
step = 0
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pbar.update_absolute(step, images.shape[0], ("PNG", img, None))
step += 1
return {}
@classmethod
def IS_CHANGED(s, images):
return time.time()
NODE_CLASS_MAPPINGS = {
"SaveImageWebsocket": SaveImageWebsocket,
}

View File

@ -146,6 +146,8 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
input_data_all[x] = [unique_id] input_data_all[x] = [unique_id]
if h[x] == "AUTH_TOKEN_COMFY_ORG": if h[x] == "AUTH_TOKEN_COMFY_ORG":
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
return input_data_all, missing_keys return input_data_all, missing_keys
map_node_over_list = None #Don't hook this please map_node_over_list = None #Don't hook this please

View File

@ -1,28 +0,0 @@
import importlib.util
import shutil
import os
import ctypes
import logging
def fix_pytorch_libomp():
"""
Fix PyTorch libomp DLL issue on Windows by copying the correct DLL file if needed.
"""
torch_spec = importlib.util.find_spec("torch")
for folder in torch_spec.submodule_search_locations:
lib_folder = os.path.join(folder, "lib")
test_file = os.path.join(lib_folder, "fbgemm.dll")
dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
if os.path.exists(dest):
break
with open(test_file, "rb") as f:
contents = f.read()
if b"libomp140.x86_64.dll" not in contents:
break
try:
ctypes.cdll.LoadLibrary(test_file)
except FileNotFoundError:
logging.warning("Detected pytorch version with libomp issue, patching.")
shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)

View File

@ -137,13 +137,6 @@ if __name__ == "__main__":
import cuda_malloc import cuda_malloc
if args.windows_standalone_build:
try:
from fix_torch import fix_pytorch_libomp
fix_pytorch_libomp()
except:
pass
import comfy.utils import comfy.utils
import execution import execution

View File

@ -1943,7 +1943,7 @@ class ImagePadForOutpaint:
mask[top:top + d2, left:left + d3] = t mask[top:top + d2, left:left + d3] = t
return (new_image, mask) return (new_image, mask.unsqueeze(0))
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
@ -2270,8 +2270,11 @@ def init_builtin_extra_nodes():
"nodes_optimalsteps.py", "nodes_optimalsteps.py",
"nodes_hidream.py", "nodes_hidream.py",
"nodes_fresca.py", "nodes_fresca.py",
"nodes_apg.py",
"nodes_preview_any.py", "nodes_preview_any.py",
"nodes_ace.py", "nodes_ace.py",
"nodes_string.py",
"nodes_camera_trajectory.py",
] ]
import_failed = [] import_failed = []

View File

@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.3.33" version = "0.3.34"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.9" requires-python = ">=3.9"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.18.9 comfyui-frontend-package==1.19.9
comfyui-workflow-templates==0.1.11 comfyui-workflow-templates==0.1.14
comfyui_manager comfyui_manager
torch torch
torchsde torchsde

View File

@ -101,6 +101,14 @@ prompt_text = """
def queue_prompt(prompt): def queue_prompt(prompt):
p = {"prompt": prompt} p = {"prompt": prompt}
# If the workflow contains API nodes, you can add a Comfy API key to the `extra_data`` field of the payload.
# p["extra_data"] = {
# "api_key_comfy_org": "comfyui-87d01e28d*******************************************************" # replace with real key
# }
# See: https://docs.comfy.org/tutorials/api-nodes/overview
# Generate a key here: https://platform.comfy.org/login
data = json.dumps(p).encode('utf-8') data = json.dumps(p).encode('utf-8')
req = request.Request("http://127.0.0.1:8188/prompt", data=data) req = request.Request("http://127.0.0.1:8188/prompt", data=data)
request.urlopen(req) request.urlopen(req)

View File

@ -32,12 +32,13 @@ from app.frontend_management import FrontendManager
from app.user_manager import UserManager from app.user_manager import UserManager
from app.model_manager import ModelFileManager from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager from app.custom_node_manager import CustomNodeManager
from typing import Optional from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes from api_server.routes.internal.internal_routes import InternalRoutes
class BinaryEventTypes: class BinaryEventTypes:
PREVIEW_IMAGE = 1 PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2 UNENCODED_PREVIEW_IMAGE = 2
TEXT = 3
async def send_socket_catch_exception(function, message): async def send_socket_catch_exception(function, message):
try: try:
@ -878,3 +879,15 @@ class PromptServer():
logging.warning(traceback.format_exc()) logging.warning(traceback.format_exc())
return json_data return json_data
def send_progress_text(
self, text: Union[bytes, bytearray, str], node_id: str, sid=None
):
if isinstance(text, str):
text = text.encode("utf-8")
node_id_bytes = str(node_id).encode("utf-8")
# Pack the node_id length as a 4-byte unsigned integer, followed by the node_id bytes
message = struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text
self.send_sync(BinaryEventTypes.TEXT, message, sid)

View File

@ -0,0 +1,239 @@
import pytest
import torch
import tempfile
import os
import av
import io
from fractions import Fraction
from comfy_api.input_impl.video_types import VideoFromFile, VideoFromComponents
from comfy_api.util.video_types import VideoComponents
from comfy_api.input.basic_types import AudioInput
from av.error import InvalidDataError
EPSILON = 0.0001
@pytest.fixture
def sample_images():
"""3-frame 2x2 RGB video tensor"""
return torch.rand(3, 2, 2, 3)
@pytest.fixture
def sample_audio():
"""Stereo audio with 44.1kHz sample rate"""
return AudioInput(
{
"waveform": torch.rand(1, 2, 1000),
"sample_rate": 44100,
}
)
@pytest.fixture
def video_components(sample_images, sample_audio):
"""VideoComponents with images, audio, and metadata"""
return VideoComponents(
images=sample_images,
audio=sample_audio,
frame_rate=Fraction(30),
metadata={"test": "metadata"},
)
def create_test_video(width=4, height=4, frames=3, fps=30):
"""Helper to create a temporary video file"""
tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
with av.open(tmp.name, mode="w") as container:
stream = container.add_stream("h264", rate=fps)
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
for i in range(frames):
frame = av.VideoFrame.from_ndarray(
torch.ones(height, width, 3, dtype=torch.uint8).numpy() * (i * 85),
format="rgb24",
)
frame = frame.reformat(format="yuv420p")
packet = stream.encode(frame)
container.mux(packet)
# Flush
packet = stream.encode(None)
container.mux(packet)
return tmp.name
@pytest.fixture
def simple_video_file():
"""4x4 video with 3 frames at 30fps"""
file_path = create_test_video()
yield file_path
os.unlink(file_path)
def test_video_from_components_get_duration(video_components):
"""Duration calculated correctly from frame count and frame rate"""
video = VideoFromComponents(video_components)
duration = video.get_duration()
expected_duration = 3.0 / 30.0
assert duration == pytest.approx(expected_duration)
def test_video_from_components_get_duration_different_frame_rates(sample_images):
"""Duration correct for different frame rates including fractional"""
# Test with 60 fps
components_60fps = VideoComponents(images=sample_images, frame_rate=Fraction(60))
video_60fps = VideoFromComponents(components_60fps)
assert video_60fps.get_duration() == pytest.approx(3.0 / 60.0)
# Test with fractional frame rate (23.976fps)
components_frac = VideoComponents(
images=sample_images, frame_rate=Fraction(24000, 1001)
)
video_frac = VideoFromComponents(components_frac)
expected_frac = 3.0 / (24000.0 / 1001.0)
assert video_frac.get_duration() == pytest.approx(expected_frac)
def test_video_from_components_get_duration_empty_video():
"""Duration is zero for empty video"""
empty_components = VideoComponents(
images=torch.zeros(0, 2, 2, 3), frame_rate=Fraction(30)
)
video = VideoFromComponents(empty_components)
assert video.get_duration() == 0.0
def test_video_from_components_get_dimensions(video_components):
"""Dimensions returned correctly from image tensor shape"""
video = VideoFromComponents(video_components)
width, height = video.get_dimensions()
assert width == 2
assert height == 2
def test_video_from_file_get_duration(simple_video_file):
"""Duration extracted from file metadata"""
video = VideoFromFile(simple_video_file)
duration = video.get_duration()
assert duration == pytest.approx(0.1, abs=0.01)
def test_video_from_file_get_dimensions(simple_video_file):
"""Dimensions read from stream without decoding frames"""
video = VideoFromFile(simple_video_file)
width, height = video.get_dimensions()
assert width == 4
assert height == 4
def test_video_from_file_bytesio_input():
"""VideoFromFile works with BytesIO input"""
buffer = io.BytesIO()
with av.open(buffer, mode="w", format="mp4") as container:
stream = container.add_stream("h264", rate=30)
stream.width = 2
stream.height = 2
stream.pix_fmt = "yuv420p"
frame = av.VideoFrame.from_ndarray(
torch.zeros(2, 2, 3, dtype=torch.uint8).numpy(), format="rgb24"
)
frame = frame.reformat(format="yuv420p")
packet = stream.encode(frame)
container.mux(packet)
packet = stream.encode(None)
container.mux(packet)
buffer.seek(0)
video = VideoFromFile(buffer)
assert video.get_dimensions() == (2, 2)
assert video.get_duration() == pytest.approx(1 / 30, abs=0.01)
def test_video_from_file_invalid_file_error():
"""InvalidDataError raised for non-video files"""
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as tmp:
tmp.write(b"not a video file")
tmp.flush()
tmp_name = tmp.name
try:
with pytest.raises(InvalidDataError):
video = VideoFromFile(tmp_name)
video.get_dimensions()
finally:
os.unlink(tmp_name)
def test_video_from_file_audio_only_error():
"""ValueError raised for audio-only files"""
with tempfile.NamedTemporaryFile(suffix=".m4a", delete=False) as tmp:
tmp_name = tmp.name
try:
with av.open(tmp_name, mode="w") as container:
stream = container.add_stream("aac", rate=44100)
stream.sample_rate = 44100
stream.format = "fltp"
audio_data = torch.zeros(1, 1024).numpy()
audio_frame = av.AudioFrame.from_ndarray(
audio_data, format="fltp", layout="mono"
)
audio_frame.sample_rate = 44100
audio_frame.pts = 0
packet = stream.encode(audio_frame)
container.mux(packet)
for packet in stream.encode(None):
container.mux(packet)
with pytest.raises(ValueError, match="No video stream found"):
video = VideoFromFile(tmp_name)
video.get_dimensions()
finally:
os.unlink(tmp_name)
def test_single_frame_video():
"""Single frame video has correct duration"""
components = VideoComponents(
images=torch.rand(1, 10, 10, 3), frame_rate=Fraction(1)
)
video = VideoFromComponents(components)
assert video.get_duration() == 1.0
@pytest.mark.parametrize(
"frame_rate,expected_fps",
[
(Fraction(24000, 1001), 24000 / 1001),
(Fraction(30000, 1001), 30000 / 1001),
(Fraction(25, 1), 25.0),
(Fraction(50, 2), 25.0),
],
)
def test_fractional_frame_rates(frame_rate, expected_fps):
"""Duration calculated correctly for various fractional frame rates"""
components = VideoComponents(images=torch.rand(100, 4, 4, 3), frame_rate=frame_rate)
video = VideoFromComponents(components)
duration = video.get_duration()
expected_duration = 100.0 / expected_fps
assert duration == pytest.approx(expected_duration)
def test_duration_consistency(video_components):
"""get_duration() consistent with manual calculation from components"""
video = VideoFromComponents(video_components)
duration = video.get_duration()
components = video.get_components()
manual_duration = float(components.images.shape[0] / components.frame_rate)
assert duration == pytest.approx(manual_duration)