mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-08 21:44:33 +08:00
Compare commits
31 Commits
1c8d120635
...
0ba6498de3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0ba6498de3 | ||
|
|
fd271dedfd | ||
|
|
c3c6313fc7 | ||
|
|
85c4b4ae26 | ||
|
|
058f084371 | ||
|
|
ec7f65187d | ||
|
|
3ade599515 | ||
|
|
335dec0dba | ||
|
|
c36323fe53 | ||
|
|
64a984177f | ||
|
|
af4d412e67 | ||
|
|
b4d3f4e567 | ||
|
|
fb099a40b2 | ||
|
|
7237c36fba | ||
|
|
3ae78a4804 | ||
|
|
d53e62913d | ||
|
|
8d62661a9f | ||
|
|
6bfce54652 | ||
|
|
25063f25cc | ||
|
|
897ffeb5d2 | ||
|
|
f5d53f2a6b | ||
|
|
4cbc1a626c | ||
|
|
99dc95960a | ||
|
|
7842a5c805 | ||
|
|
d0dce6b90e | ||
|
|
9c5022e7e3 | ||
|
|
57567bde4e | ||
|
|
00c069dd1c | ||
|
|
6f6db12bbe | ||
|
|
460ce7f77b | ||
|
|
efe83f5a36 |
@ -62,6 +62,8 @@ class WanSelfAttention(nn.Module):
|
||||
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
||||
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
||||
"""
|
||||
patches = transformer_options.get("patches", {})
|
||||
|
||||
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
||||
|
||||
def qkv_fn_q(x):
|
||||
@ -86,6 +88,10 @@ class WanSelfAttention(nn.Module):
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
if "attn1_patch" in patches:
|
||||
for p in patches["attn1_patch"]:
|
||||
x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
|
||||
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
@ -225,6 +231,8 @@ class WanAttentionBlock(nn.Module):
|
||||
"""
|
||||
# assert e.dtype == torch.float32
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
|
||||
if e.ndim < 4:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
|
||||
else:
|
||||
@ -242,6 +250,11 @@ class WanAttentionBlock(nn.Module):
|
||||
|
||||
# cross-attention & ffn
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
|
||||
if "attn2_patch" in patches:
|
||||
for p in patches["attn2_patch"]:
|
||||
x = p({"x": x, "transformer_options": transformer_options})
|
||||
|
||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||
return x
|
||||
@ -488,7 +501,7 @@ class WanModel(torch.nn.Module):
|
||||
self.blocks = nn.ModuleList([
|
||||
wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
|
||||
for _ in range(num_layers)
|
||||
for i in range(num_layers)
|
||||
])
|
||||
|
||||
# head
|
||||
@ -541,6 +554,7 @@ class WanModel(torch.nn.Module):
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
transformer_options["grid_sizes"] = grid_sizes
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# time embeddings
|
||||
@ -569,6 +583,7 @@ class WanModel(torch.nn.Module):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_idx"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@ -735,6 +750,7 @@ class VaceWanModel(WanModel):
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
transformer_options["grid_sizes"] = grid_sizes
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# time embeddings
|
||||
@ -764,6 +780,7 @@ class VaceWanModel(WanModel):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_idx"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
|
||||
500
comfy/ldm/wan/model_multitalk.py
Normal file
500
comfy/ldm/wan/model_multitalk.py
Normal file
@ -0,0 +1,500 @@
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
import comfy
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, split_num=8):
|
||||
scale = 1.0 / visual_q.shape[-1] ** 0.5
|
||||
visual_q = visual_q.transpose(1, 2) * scale
|
||||
|
||||
B, H, x_seqlens, K = visual_q.shape
|
||||
|
||||
x_ref_attn_maps = []
|
||||
for class_idx, ref_target_mask in enumerate(ref_target_masks):
|
||||
ref_target_mask = ref_target_mask.view(1, 1, 1, -1)
|
||||
|
||||
x_ref_attnmap = torch.zeros(B, H, x_seqlens, device=visual_q.device, dtype=visual_q.dtype)
|
||||
chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens)
|
||||
|
||||
for i in range(0, x_seqlens, chunk_size):
|
||||
end_i = min(i + chunk_size, x_seqlens)
|
||||
|
||||
attn_chunk = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1) # B, H, chunk, ref_seqlens
|
||||
|
||||
# Apply softmax
|
||||
attn_max = attn_chunk.max(dim=-1, keepdim=True).values
|
||||
attn_chunk = (attn_chunk - attn_max).exp()
|
||||
attn_sum = attn_chunk.sum(dim=-1, keepdim=True)
|
||||
attn_chunk = attn_chunk / (attn_sum + 1e-8)
|
||||
|
||||
# Apply mask and sum
|
||||
masked_attn = attn_chunk * ref_target_mask
|
||||
x_ref_attnmap[:, :, i:end_i] = masked_attn.sum(-1) / (ref_target_mask.sum() + 1e-8)
|
||||
|
||||
del attn_chunk, masked_attn
|
||||
|
||||
# Average across heads
|
||||
x_ref_attnmap = x_ref_attnmap.mean(dim=1) # B, x_seqlens
|
||||
x_ref_attn_maps.append(x_ref_attnmap)
|
||||
|
||||
del visual_q, ref_k
|
||||
|
||||
return torch.cat(x_ref_attn_maps, dim=0)
|
||||
|
||||
def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2):
|
||||
"""Args:
|
||||
query (torch.tensor): B M H K
|
||||
key (torch.tensor): B M H K
|
||||
shape (tuple): (N_t, N_h, N_w)
|
||||
ref_target_masks: [B, N_h * N_w]
|
||||
"""
|
||||
|
||||
N_t, N_h, N_w = shape
|
||||
|
||||
x_seqlens = N_h * N_w
|
||||
ref_k = ref_k[:, :x_seqlens]
|
||||
_, seq_lens, heads, _ = visual_q.shape
|
||||
class_num, _ = ref_target_masks.shape
|
||||
x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q)
|
||||
|
||||
split_chunk = heads // split_num
|
||||
|
||||
for i in range(split_num):
|
||||
x_ref_attn_maps_perhead = calculate_x_ref_attn_map(
|
||||
visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :],
|
||||
ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :],
|
||||
ref_target_masks
|
||||
)
|
||||
x_ref_attn_maps += x_ref_attn_maps_perhead
|
||||
|
||||
return x_ref_attn_maps / split_num
|
||||
|
||||
|
||||
def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):
|
||||
source_min, source_max = source_range
|
||||
new_min, new_max = target_range
|
||||
normalized = (column - source_min) / (source_max - source_min + epsilon)
|
||||
scaled = normalized * (new_max - new_min) + new_min
|
||||
return scaled
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
||||
x1, x2 = x.unbind(dim=-1)
|
||||
x = torch.stack((-x2, x1), dim=-1)
|
||||
return rearrange(x, "... d r -> ... (d r)")
|
||||
|
||||
|
||||
def get_audio_embeds(encoded_audio, audio_start, audio_end):
|
||||
audio_embs = []
|
||||
human_num = len(encoded_audio)
|
||||
audio_frames = encoded_audio[0].shape[0]
|
||||
|
||||
indices = (torch.arange(4 + 1) - 2) * 1
|
||||
|
||||
for human_idx in range(human_num):
|
||||
if audio_end > audio_frames: # in case of not enough audio for current window, pad with first audio frame as that's most likely silence
|
||||
pad_len = audio_end - audio_frames
|
||||
pad_shape = list(encoded_audio[human_idx].shape)
|
||||
pad_shape[0] = pad_len
|
||||
pad_tensor = encoded_audio[human_idx][:1].repeat(pad_len, *([1] * (encoded_audio[human_idx].dim() - 1)))
|
||||
encoded_audio_in = torch.cat([encoded_audio[human_idx], pad_tensor], dim=0)
|
||||
else:
|
||||
encoded_audio_in = encoded_audio[human_idx]
|
||||
center_indices = torch.arange(audio_start, audio_end, 1).unsqueeze(1) + indices.unsqueeze(0)
|
||||
center_indices = torch.clamp(center_indices, min=0, max=encoded_audio_in.shape[0] - 1)
|
||||
audio_emb = encoded_audio_in[center_indices].unsqueeze(0)
|
||||
audio_embs.append(audio_emb)
|
||||
|
||||
return torch.cat(audio_embs, dim=0)
|
||||
|
||||
|
||||
def project_audio_features(audio_proj, encoded_audio, audio_start, audio_end):
|
||||
audio_embs = get_audio_embeds(encoded_audio, audio_start, audio_end)
|
||||
|
||||
first_frame_audio_emb_s = audio_embs[:, :1, ...]
|
||||
latter_frame_audio_emb = audio_embs[:, 1:, ...]
|
||||
latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=4)
|
||||
|
||||
middle_index = audio_proj.seq_len // 2
|
||||
|
||||
latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...]
|
||||
latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||
latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
|
||||
latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||
latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...]
|
||||
latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||
latter_frame_audio_emb_s = torch.cat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2)
|
||||
|
||||
audio_emb = audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s)
|
||||
audio_emb = torch.cat(audio_emb.split(1), dim=2)
|
||||
|
||||
return audio_emb
|
||||
|
||||
|
||||
class RotaryPositionalEmbedding1D(torch.nn.Module):
|
||||
def __init__(self,
|
||||
head_dim,
|
||||
):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.base = 10000
|
||||
|
||||
def precompute_freqs_cis_1d(self, pos_indices):
|
||||
freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
|
||||
freqs = freqs.to(pos_indices.device)
|
||||
freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs)
|
||||
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
|
||||
return freqs
|
||||
|
||||
def forward(self, x, pos_indices):
|
||||
freqs_cis = self.precompute_freqs_cis_1d(pos_indices)
|
||||
|
||||
x_ = x.float()
|
||||
|
||||
freqs_cis = freqs_cis.float().to(x.device)
|
||||
cos, sin = freqs_cis.cos(), freqs_cis.sin()
|
||||
cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
|
||||
x_ = (x_ * cos) + (rotate_half(x_) * sin)
|
||||
|
||||
return x_.type_as(x)
|
||||
|
||||
class SingleStreamAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
encoder_hidden_states_dim: int,
|
||||
num_heads: int,
|
||||
qkv_bias: bool,
|
||||
device=None, dtype=None, operations=None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.encoder_hidden_states_dim = encoder_hidden_states_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.q_linear = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.kv_linear = operations.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None) -> torch.Tensor:
|
||||
N_t, N_h, N_w = shape
|
||||
|
||||
expected_tokens = N_t * N_h * N_w
|
||||
actual_tokens = x.shape[1]
|
||||
x_extra = None
|
||||
|
||||
if actual_tokens != expected_tokens:
|
||||
x_extra = x[:, -N_h * N_w:, :]
|
||||
x = x[:, :-N_h * N_w, :]
|
||||
N_t = N_t - 1
|
||||
|
||||
B = x.shape[0]
|
||||
S = N_h * N_w
|
||||
x = x.view(B * N_t, S, self.dim)
|
||||
|
||||
# get q for hidden_state
|
||||
q = self.q_linear(x).view(B * N_t, S, self.num_heads, self.head_dim)
|
||||
|
||||
# get kv from encoder_hidden_states # shape: (B, N, num_heads, head_dim)
|
||||
kv = self.kv_linear(encoder_hidden_states)
|
||||
encoder_k, encoder_v = kv.view(B * N_t, encoder_hidden_states.shape[1], 2, self.num_heads, self.head_dim).unbind(2)
|
||||
|
||||
#print("q.shape", q.shape) #torch.Size([21, 1024, 40, 128])
|
||||
x = optimized_attention(
|
||||
q.transpose(1, 2),
|
||||
encoder_k.transpose(1, 2),
|
||||
encoder_v.transpose(1, 2),
|
||||
heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2)
|
||||
|
||||
# linear transform
|
||||
x = self.proj(x.reshape(B * N_t, S, self.dim))
|
||||
x = x.view(B, N_t * S, self.dim)
|
||||
|
||||
if x_extra is not None:
|
||||
x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)
|
||||
|
||||
return x
|
||||
|
||||
class SingleStreamMultiAttention(SingleStreamAttention):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
encoder_hidden_states_dim: int,
|
||||
num_heads: int,
|
||||
qkv_bias: bool,
|
||||
class_range: int = 24,
|
||||
class_interval: int = 4,
|
||||
device=None, dtype=None, operations=None
|
||||
) -> None:
|
||||
super().__init__(
|
||||
dim=dim,
|
||||
encoder_hidden_states_dim=encoder_hidden_states_dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations
|
||||
)
|
||||
|
||||
# Rotary-embedding layout parameters
|
||||
self.class_interval = class_interval
|
||||
self.class_range = class_range
|
||||
self.max_humans = self.class_range // self.class_interval
|
||||
|
||||
# Constant bucket used for background tokens
|
||||
self.rope_bak = int(self.class_range // 2)
|
||||
|
||||
self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
shape=None,
|
||||
x_ref_attn_map=None
|
||||
) -> torch.Tensor:
|
||||
encoder_hidden_states = encoder_hidden_states.squeeze(0).to(x.device)
|
||||
human_num = x_ref_attn_map.shape[0] if x_ref_attn_map is not None else 1
|
||||
# Single-speaker fall-through
|
||||
if human_num <= 1:
|
||||
return super().forward(x, encoder_hidden_states, shape)
|
||||
|
||||
N_t, N_h, N_w = shape
|
||||
|
||||
x_extra = None
|
||||
if x.shape[0] * N_t != encoder_hidden_states.shape[0]:
|
||||
x_extra = x[:, -N_h * N_w:, :]
|
||||
x = x[:, :-N_h * N_w, :]
|
||||
N_t = N_t - 1
|
||||
x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)
|
||||
|
||||
# Query projection
|
||||
B, N, C = x.shape
|
||||
q = self.q_linear(x)
|
||||
q = q.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
||||
|
||||
# Use `class_range` logic for 2 speakers
|
||||
rope_h1 = (0, self.class_interval)
|
||||
rope_h2 = (self.class_range - self.class_interval, self.class_range)
|
||||
rope_bak = int(self.class_range // 2)
|
||||
|
||||
# Normalize and scale attention maps for each speaker
|
||||
max_values = x_ref_attn_map.max(1).values[:, None, None]
|
||||
min_values = x_ref_attn_map.min(1).values[:, None, None]
|
||||
max_min_values = torch.cat([max_values, min_values], dim=2)
|
||||
|
||||
human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min()
|
||||
human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min()
|
||||
|
||||
human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), rope_h1)
|
||||
human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), rope_h2)
|
||||
back = torch.full((x_ref_attn_map.size(1),), rope_bak, dtype=human1.dtype, device=human1.device)
|
||||
|
||||
# Token-wise speaker dominance
|
||||
max_indices = x_ref_attn_map.argmax(dim=0)
|
||||
normalized_map = torch.stack([human1, human2, back], dim=1)
|
||||
normalized_pos = normalized_map[torch.arange(x_ref_attn_map.size(1)), max_indices]
|
||||
|
||||
# Apply rotary to Q
|
||||
q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
|
||||
q = self.rope_1d(q, normalized_pos)
|
||||
q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
|
||||
|
||||
# Keys / Values
|
||||
_, N_a, _ = encoder_hidden_states.shape
|
||||
encoder_kv = self.kv_linear(encoder_hidden_states)
|
||||
encoder_kv = encoder_kv.view(B, N_a, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
encoder_k, encoder_v = encoder_kv.unbind(0)
|
||||
|
||||
# Rotary for keys – assign centre of each speaker bucket to its context tokens
|
||||
per_frame = torch.zeros(N_a, dtype=encoder_k.dtype, device=encoder_k.device)
|
||||
per_frame[: per_frame.size(0) // 2] = (rope_h1[0] + rope_h1[1]) / 2
|
||||
per_frame[per_frame.size(0) // 2 :] = (rope_h2[0] + rope_h2[1]) / 2
|
||||
encoder_pos = torch.cat([per_frame] * N_t, dim=0)
|
||||
|
||||
encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
|
||||
encoder_k = self.rope_1d(encoder_k, encoder_pos)
|
||||
encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
|
||||
|
||||
# Final attention
|
||||
q = rearrange(q, "B H M K -> B M H K")
|
||||
encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
|
||||
encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
|
||||
|
||||
x = optimized_attention(
|
||||
q.transpose(1, 2),
|
||||
encoder_k.transpose(1, 2),
|
||||
encoder_v.transpose(1, 2),
|
||||
heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2)
|
||||
|
||||
# Linear projection
|
||||
x = x.reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
|
||||
# Restore original layout
|
||||
x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)
|
||||
if x_extra is not None:
|
||||
x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class MultiTalkAudioProjModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
seq_len: int = 5,
|
||||
seq_len_vf: int = 12,
|
||||
blocks: int = 12,
|
||||
channels: int = 768,
|
||||
intermediate_dim: int = 512,
|
||||
out_dim: int = 768,
|
||||
context_tokens: int = 32,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.seq_len = seq_len
|
||||
self.blocks = blocks
|
||||
self.channels = channels
|
||||
self.input_dim = seq_len * blocks * channels
|
||||
self.input_dim_vf = seq_len_vf * blocks * channels
|
||||
self.intermediate_dim = intermediate_dim
|
||||
self.context_tokens = context_tokens
|
||||
self.out_dim = out_dim
|
||||
|
||||
# define multiple linear layers
|
||||
self.proj1 = operations.Linear(self.input_dim, intermediate_dim, device=device, dtype=dtype)
|
||||
self.proj1_vf = operations.Linear(self.input_dim_vf, intermediate_dim, device=device, dtype=dtype)
|
||||
self.proj2 = operations.Linear(intermediate_dim, intermediate_dim, device=device, dtype=dtype)
|
||||
self.proj3 = operations.Linear(intermediate_dim, context_tokens * out_dim, device=device, dtype=dtype)
|
||||
self.norm = operations.LayerNorm(out_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, audio_embeds, audio_embeds_vf):
|
||||
video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
|
||||
B, _, _, S, C = audio_embeds.shape
|
||||
|
||||
# process audio of first frame
|
||||
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
|
||||
batch_size, window_size, blocks, channels = audio_embeds.shape
|
||||
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
|
||||
|
||||
# process audio of latter frame
|
||||
audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
|
||||
batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape
|
||||
audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf)
|
||||
|
||||
# first projection
|
||||
audio_embeds = torch.relu(self.proj1(audio_embeds))
|
||||
audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf))
|
||||
audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B)
|
||||
audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B)
|
||||
audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1)
|
||||
batch_size_c, N_t, C_a = audio_embeds_c.shape
|
||||
audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a)
|
||||
|
||||
# second projection
|
||||
audio_embeds_c = torch.relu(self.proj2(audio_embeds_c))
|
||||
|
||||
context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.out_dim)
|
||||
|
||||
# normalization and reshape
|
||||
context_tokens = self.norm(context_tokens)
|
||||
context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
|
||||
|
||||
return context_tokens
|
||||
|
||||
|
||||
class WanMultiTalkAttentionBlock(torch.nn.Module):
|
||||
def __init__(self, in_dim=5120, out_dim=768, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.audio_cross_attn = SingleStreamMultiAttention(in_dim, out_dim, num_heads=40, qkv_bias=True, device=device, dtype=dtype, operations=operations)
|
||||
self.norm_x = operations.LayerNorm(in_dim, device=device, dtype=dtype, elementwise_affine=True)
|
||||
|
||||
|
||||
class MultiTalkGetAttnMapPatch:
|
||||
def __init__(self, ref_target_masks=None):
|
||||
self.ref_target_masks = ref_target_masks
|
||||
|
||||
def __call__(self, kwargs):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
x = kwargs["x"]
|
||||
|
||||
if self.ref_target_masks is not None:
|
||||
x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device))
|
||||
transformer_options["x_ref_attn_map"] = x_ref_attn_map
|
||||
return x
|
||||
|
||||
|
||||
class MultiTalkCrossAttnPatch:
|
||||
def __init__(self, model_patch, audio_scale=1.0, ref_target_masks=None):
|
||||
self.model_patch = model_patch
|
||||
self.audio_scale = audio_scale
|
||||
self.ref_target_masks = ref_target_masks
|
||||
|
||||
def __call__(self, kwargs):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
block_idx = transformer_options.get("block_idx", None)
|
||||
x = kwargs["x"]
|
||||
if block_idx is None:
|
||||
return torch.zeros_like(x)
|
||||
|
||||
audio_embeds = transformer_options.get("audio_embeds")
|
||||
x_ref_attn_map = transformer_options.pop("x_ref_attn_map", None)
|
||||
|
||||
norm_x = self.model_patch.model.blocks[block_idx].norm_x(x)
|
||||
x_audio = self.model_patch.model.blocks[block_idx].audio_cross_attn(
|
||||
norm_x, audio_embeds.to(x.dtype),
|
||||
shape=transformer_options["grid_sizes"],
|
||||
x_ref_attn_map=x_ref_attn_map
|
||||
)
|
||||
x = x + x_audio * self.audio_scale
|
||||
return x
|
||||
|
||||
def models(self):
|
||||
return [self.model_patch]
|
||||
|
||||
class MultiTalkApplyModelWrapper:
|
||||
def __init__(self, init_latents):
|
||||
self.init_latents = init_latents
|
||||
|
||||
def __call__(self, executor, x, *args, **kwargs):
|
||||
x[:, :, :self.init_latents.shape[2]] = self.init_latents.to(x)
|
||||
samples = executor(x, *args, **kwargs)
|
||||
return samples
|
||||
|
||||
|
||||
class InfiniteTalkOuterSampleWrapper:
|
||||
def __init__(self, motion_frames_latent, model_patch, is_extend=False):
|
||||
self.motion_frames_latent = motion_frames_latent
|
||||
self.model_patch = model_patch
|
||||
self.is_extend = is_extend
|
||||
|
||||
def __call__(self, executor, *args, **kwargs):
|
||||
model_patcher = executor.class_obj.model_patcher
|
||||
model_options = executor.class_obj.model_options
|
||||
process_latent_in = model_patcher.model.process_latent_in
|
||||
|
||||
# for InfiniteTalk, model input first latent(s) need to always be replaced on every step
|
||||
if self.motion_frames_latent is not None:
|
||||
wrappers = model_options["transformer_options"]["wrappers"]
|
||||
w = wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {})
|
||||
w["MultiTalk_apply_model"] = [MultiTalkApplyModelWrapper(process_latent_in(self.motion_frames_latent))]
|
||||
|
||||
# run the sampling process
|
||||
result = executor(*args, **kwargs)
|
||||
|
||||
# insert motion frames before decoding
|
||||
if self.is_extend:
|
||||
overlap = self.motion_frames_latent.shape[2]
|
||||
result = torch.cat([self.motion_frames_latent.to(result), result[:, :, overlap:]], dim=2)
|
||||
|
||||
return result
|
||||
|
||||
def to(self, device_or_dtype):
|
||||
if isinstance(device_or_dtype, torch.device):
|
||||
if self.motion_frames_latent is not None:
|
||||
self.motion_frames_latent = self.motion_frames_latent.to(device_or_dtype)
|
||||
return self
|
||||
@ -5,9 +5,9 @@ from typing import Type, TYPE_CHECKING
|
||||
from comfy_api.internal import ComfyAPIBase
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
||||
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
||||
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from ._input_impl import VideoFromFile, VideoFromComponents
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
||||
from . import _io_public as io
|
||||
from . import _ui_public as ui
|
||||
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
||||
@ -80,7 +80,7 @@ class ComfyExtension(ABC):
|
||||
async def on_load(self) -> None:
|
||||
"""
|
||||
Called when an extension is loaded.
|
||||
This should be used to initialize any global resources neeeded by the extension.
|
||||
This should be used to initialize any global resources needed by the extension.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@ -4,7 +4,7 @@ from fractions import Fraction
|
||||
from typing import Optional, Union, IO
|
||||
import io
|
||||
import av
|
||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||
from .._util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
class VideoInput(ABC):
|
||||
"""
|
||||
|
||||
@ -3,14 +3,14 @@ from av.container import InputContainer
|
||||
from av.subtitles.stream import SubtitleStream
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.latest._input import AudioInput, VideoInput
|
||||
from .._input import AudioInput, VideoInput
|
||||
import av
|
||||
import io
|
||||
import json
|
||||
import numpy as np
|
||||
import math
|
||||
import torch
|
||||
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
|
||||
from .._util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
|
||||
def container_to_output_format(container_format: str | None) -> str | None:
|
||||
|
||||
@ -26,7 +26,7 @@ if TYPE_CHECKING:
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_api.latest._resources import Resources, ResourcesLocal
|
||||
from ._resources import Resources, ResourcesLocal
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from ._util import MESH, VOXEL
|
||||
|
||||
@ -763,7 +763,7 @@ class AnyType(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="MODEL_PATCH")
|
||||
class MODEL_PATCH(ComfyTypeIO):
|
||||
class ModelPatch(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="AUDIO_ENCODER")
|
||||
@ -1863,6 +1863,7 @@ __all__ = [
|
||||
"ControlNet",
|
||||
"Vae",
|
||||
"Model",
|
||||
"ModelPatch",
|
||||
"ClipVision",
|
||||
"ClipVisionOutput",
|
||||
"AudioEncoder",
|
||||
|
||||
@ -22,7 +22,7 @@ import folder_paths
|
||||
|
||||
# used for image preview
|
||||
from comfy.cli_args import args
|
||||
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
|
||||
from ._io import ComfyNode, FolderType, Image, _UIOutput
|
||||
|
||||
|
||||
class SavedResult(dict):
|
||||
|
||||
@ -3,7 +3,7 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.latest._input import ImageInput, AudioInput
|
||||
from .._input import ImageInput, AudioInput
|
||||
|
||||
class VideoCodec(str, Enum):
|
||||
AUTO = "auto"
|
||||
|
||||
144
comfy_api_nodes/apis/bytedance_api.py
Normal file
144
comfy_api_nodes/apis/bytedance_api.py
Normal file
@ -0,0 +1,144 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Text2ImageTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
response_format: str | None = Field("url")
|
||||
size: str | None = Field(None)
|
||||
seed: int | None = Field(0, ge=0, le=2147483647)
|
||||
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
|
||||
watermark: bool | None = Field(True)
|
||||
|
||||
|
||||
class Image2ImageTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
response_format: str | None = Field("url")
|
||||
image: str = Field(..., description="Base64 encoded string or image URL")
|
||||
size: str | None = Field("adaptive")
|
||||
seed: int | None = Field(..., ge=0, le=2147483647)
|
||||
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
|
||||
watermark: bool | None = Field(True)
|
||||
|
||||
|
||||
class Seedream4Options(BaseModel):
|
||||
max_images: int = Field(15)
|
||||
|
||||
|
||||
class Seedream4TaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
response_format: str = Field("url")
|
||||
image: list[str] | None = Field(None, description="Image URLs")
|
||||
size: str = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
sequential_image_generation: str = Field("disabled")
|
||||
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
|
||||
watermark: bool = Field(True)
|
||||
|
||||
|
||||
class ImageTaskCreationResponse(BaseModel):
|
||||
model: str = Field(...)
|
||||
created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
|
||||
data: list = Field([], description="Contains information about the generated image(s).")
|
||||
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
|
||||
|
||||
|
||||
class TaskTextContent(BaseModel):
|
||||
type: str = Field("text")
|
||||
text: str = Field(...)
|
||||
|
||||
|
||||
class TaskImageContentUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class TaskImageContent(BaseModel):
|
||||
type: str = Field("image_url")
|
||||
image_url: TaskImageContentUrl = Field(...)
|
||||
role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
|
||||
|
||||
|
||||
class Text2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
content: list[TaskTextContent] = Field(..., min_length=1)
|
||||
|
||||
|
||||
class Image2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2)
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusError(BaseModel):
|
||||
code: str = Field(...)
|
||||
message: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusResult(BaseModel):
|
||||
video_url: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
model: str = Field(...)
|
||||
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
|
||||
error: TaskStatusError | None = Field(None)
|
||||
content: TaskStatusResult | None = Field(None)
|
||||
|
||||
|
||||
RECOMMENDED_PRESETS = [
|
||||
("1024x1024 (1:1)", 1024, 1024),
|
||||
("864x1152 (3:4)", 864, 1152),
|
||||
("1152x864 (4:3)", 1152, 864),
|
||||
("1280x720 (16:9)", 1280, 720),
|
||||
("720x1280 (9:16)", 720, 1280),
|
||||
("832x1248 (2:3)", 832, 1248),
|
||||
("1248x832 (3:2)", 1248, 832),
|
||||
("1512x648 (21:9)", 1512, 648),
|
||||
("2048x2048 (1:1)", 2048, 2048),
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4 = [
|
||||
("2048x2048 (1:1)", 2048, 2048),
|
||||
("2304x1728 (4:3)", 2304, 1728),
|
||||
("1728x2304 (3:4)", 1728, 2304),
|
||||
("2560x1440 (16:9)", 2560, 1440),
|
||||
("1440x2560 (9:16)", 1440, 2560),
|
||||
("2496x1664 (3:2)", 2496, 1664),
|
||||
("1664x2496 (2:3)", 1664, 2496),
|
||||
("3024x1296 (21:9)", 3024, 1296),
|
||||
("4096x4096 (1:1)", 4096, 4096),
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
# The time in this dictionary are given for 10 seconds duration.
|
||||
VIDEO_TASKS_EXECUTION_TIME = {
|
||||
"seedance-1-0-lite-t2v-250428": {
|
||||
"480p": 40,
|
||||
"720p": 60,
|
||||
"1080p": 90,
|
||||
},
|
||||
"seedance-1-0-lite-i2v-250428": {
|
||||
"480p": 40,
|
||||
"720p": 60,
|
||||
"1080p": 90,
|
||||
},
|
||||
"seedance-1-0-pro-250528": {
|
||||
"480p": 70,
|
||||
"720p": 85,
|
||||
"1080p": 115,
|
||||
},
|
||||
"seedance-1-0-pro-fast-251015": {
|
||||
"480p": 50,
|
||||
"720p": 65,
|
||||
"1080p": 100,
|
||||
},
|
||||
}
|
||||
@ -84,15 +84,7 @@ class GeminiSystemInstructionContent(BaseModel):
|
||||
description="A list of ordered parts that make up a single message. "
|
||||
"Different parts may have different IANA MIME types.",
|
||||
)
|
||||
role: GeminiRole = Field(
|
||||
...,
|
||||
description="The identity of the entity that creates the message. "
|
||||
"The following values are supported: "
|
||||
"user: This indicates that the message is sent by a real person, typically a user-generated message. "
|
||||
"model: This indicates that the message is generated by the model. "
|
||||
"The model value is used to insert messages from model into the conversation during multi-turn conversations. "
|
||||
"For non-multi-turn conversations, this field can be left blank or unset.",
|
||||
)
|
||||
role: GeminiRole | None = Field(..., description="The role field of systemInstruction may be ignored.")
|
||||
|
||||
|
||||
class GeminiFunctionDeclaration(BaseModel):
|
||||
|
||||
@ -85,7 +85,7 @@ class Response1(BaseModel):
|
||||
raiMediaFilteredReasons: Optional[list[str]] = Field(
|
||||
None, description='Reasons why media was filtered by responsible AI policies'
|
||||
)
|
||||
videos: Optional[list[Video]] = None
|
||||
videos: Optional[list[Video]] = Field(None)
|
||||
|
||||
|
||||
class VeoGenVidPollResponse(BaseModel):
|
||||
|
||||
@ -1,13 +1,27 @@
|
||||
import logging
|
||||
import math
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bytedance_api import (
|
||||
RECOMMENDED_PRESETS,
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4,
|
||||
VIDEO_TASKS_EXECUTION_TIME,
|
||||
Image2ImageTaskCreationRequest,
|
||||
Image2VideoTaskCreationRequest,
|
||||
ImageTaskCreationResponse,
|
||||
Seedream4Options,
|
||||
Seedream4TaskCreationRequest,
|
||||
TaskCreationResponse,
|
||||
TaskImageContent,
|
||||
TaskImageContentUrl,
|
||||
TaskStatusResponse,
|
||||
TaskTextContent,
|
||||
Text2ImageTaskCreationRequest,
|
||||
Text2VideoTaskCreationRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
@ -29,162 +43,6 @@ BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
|
||||
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
|
||||
|
||||
|
||||
class Text2ImageModelName(str, Enum):
|
||||
seedream_3 = "seedream-3-0-t2i-250415"
|
||||
|
||||
|
||||
class Image2ImageModelName(str, Enum):
|
||||
seededit_3 = "seededit-3-0-i2i-250628"
|
||||
|
||||
|
||||
class Text2VideoModelName(str, Enum):
|
||||
seedance_1_pro = "seedance-1-0-pro-250528"
|
||||
seedance_1_lite = "seedance-1-0-lite-t2v-250428"
|
||||
|
||||
|
||||
class Image2VideoModelName(str, Enum):
|
||||
"""note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757"""
|
||||
|
||||
seedance_1_pro = "seedance-1-0-pro-250528"
|
||||
seedance_1_lite = "seedance-1-0-lite-i2v-250428"
|
||||
|
||||
|
||||
class Text2ImageTaskCreationRequest(BaseModel):
|
||||
model: Text2ImageModelName = Text2ImageModelName.seedream_3
|
||||
prompt: str = Field(...)
|
||||
response_format: Optional[str] = Field("url")
|
||||
size: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(0, ge=0, le=2147483647)
|
||||
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
|
||||
watermark: Optional[bool] = Field(True)
|
||||
|
||||
|
||||
class Image2ImageTaskCreationRequest(BaseModel):
|
||||
model: Image2ImageModelName = Image2ImageModelName.seededit_3
|
||||
prompt: str = Field(...)
|
||||
response_format: Optional[str] = Field("url")
|
||||
image: str = Field(..., description="Base64 encoded string or image URL")
|
||||
size: Optional[str] = Field("adaptive")
|
||||
seed: Optional[int] = Field(..., ge=0, le=2147483647)
|
||||
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
|
||||
watermark: Optional[bool] = Field(True)
|
||||
|
||||
|
||||
class Seedream4Options(BaseModel):
|
||||
max_images: int = Field(15)
|
||||
|
||||
|
||||
class Seedream4TaskCreationRequest(BaseModel):
|
||||
model: str = Field("seedream-4-0-250828")
|
||||
prompt: str = Field(...)
|
||||
response_format: str = Field("url")
|
||||
image: Optional[list[str]] = Field(None, description="Image URLs")
|
||||
size: str = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
sequential_image_generation: str = Field("disabled")
|
||||
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
|
||||
watermark: bool = Field(True)
|
||||
|
||||
|
||||
class ImageTaskCreationResponse(BaseModel):
|
||||
model: str = Field(...)
|
||||
created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
|
||||
data: list = Field([], description="Contains information about the generated image(s).")
|
||||
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
|
||||
|
||||
|
||||
class TaskTextContent(BaseModel):
|
||||
type: str = Field("text")
|
||||
text: str = Field(...)
|
||||
|
||||
|
||||
class TaskImageContentUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class TaskImageContent(BaseModel):
|
||||
type: str = Field("image_url")
|
||||
image_url: TaskImageContentUrl = Field(...)
|
||||
role: Optional[Literal["first_frame", "last_frame", "reference_image"]] = Field(None)
|
||||
|
||||
|
||||
class Text2VideoTaskCreationRequest(BaseModel):
|
||||
model: Text2VideoModelName = Text2VideoModelName.seedance_1_pro
|
||||
content: list[TaskTextContent] = Field(..., min_length=1)
|
||||
|
||||
|
||||
class Image2VideoTaskCreationRequest(BaseModel):
|
||||
model: Image2VideoModelName = Image2VideoModelName.seedance_1_pro
|
||||
content: list[Union[TaskTextContent, TaskImageContent]] = Field(..., min_length=2)
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusError(BaseModel):
|
||||
code: str = Field(...)
|
||||
message: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusResult(BaseModel):
|
||||
video_url: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
model: str = Field(...)
|
||||
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
|
||||
error: Optional[TaskStatusError] = Field(None)
|
||||
content: Optional[TaskStatusResult] = Field(None)
|
||||
|
||||
|
||||
RECOMMENDED_PRESETS = [
|
||||
("1024x1024 (1:1)", 1024, 1024),
|
||||
("864x1152 (3:4)", 864, 1152),
|
||||
("1152x864 (4:3)", 1152, 864),
|
||||
("1280x720 (16:9)", 1280, 720),
|
||||
("720x1280 (9:16)", 720, 1280),
|
||||
("832x1248 (2:3)", 832, 1248),
|
||||
("1248x832 (3:2)", 1248, 832),
|
||||
("1512x648 (21:9)", 1512, 648),
|
||||
("2048x2048 (1:1)", 2048, 2048),
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4 = [
|
||||
("2048x2048 (1:1)", 2048, 2048),
|
||||
("2304x1728 (4:3)", 2304, 1728),
|
||||
("1728x2304 (3:4)", 1728, 2304),
|
||||
("2560x1440 (16:9)", 2560, 1440),
|
||||
("1440x2560 (9:16)", 1440, 2560),
|
||||
("2496x1664 (3:2)", 2496, 1664),
|
||||
("1664x2496 (2:3)", 1664, 2496),
|
||||
("3024x1296 (21:9)", 3024, 1296),
|
||||
("4096x4096 (1:1)", 4096, 4096),
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
# The time in this dictionary are given for 10 seconds duration.
|
||||
VIDEO_TASKS_EXECUTION_TIME = {
|
||||
"seedance-1-0-lite-t2v-250428": {
|
||||
"480p": 40,
|
||||
"720p": 60,
|
||||
"1080p": 90,
|
||||
},
|
||||
"seedance-1-0-lite-i2v-250428": {
|
||||
"480p": 40,
|
||||
"720p": 60,
|
||||
"1080p": 90,
|
||||
},
|
||||
"seedance-1-0-pro-250528": {
|
||||
"480p": 70,
|
||||
"720p": 85,
|
||||
"1080p": 115,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
||||
if response.error:
|
||||
error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
|
||||
@ -194,13 +52,6 @@ def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
||||
return response.data[0]["url"]
|
||||
|
||||
|
||||
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
||||
"""Returns the video URL from the task status response if it exists."""
|
||||
if hasattr(response, "content") and response.content:
|
||||
return response.content.video_url
|
||||
return None
|
||||
|
||||
|
||||
class ByteDanceImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@ -211,12 +62,7 @@ class ByteDanceImageNode(IO.ComfyNode):
|
||||
category="api node/image/ByteDance",
|
||||
description="Generate images using ByteDance models via api based on prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Text2ImageModelName,
|
||||
default=Text2ImageModelName.seedream_3,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
@ -335,12 +181,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
||||
category="api node/image/ByteDance",
|
||||
description="Edit images using ByteDance models via api based on prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Image2ImageModelName,
|
||||
default=Image2ImageModelName.seededit_3,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The base image to edit",
|
||||
@ -394,7 +235,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
seed: int,
|
||||
guidance_scale: float,
|
||||
@ -434,7 +275,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["seedream-4-0-250828"],
|
||||
options=["seedream-4-5-251128", "seedream-4-0-250828"],
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.String.Input(
|
||||
@ -459,7 +300,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
default=2048,
|
||||
min=1024,
|
||||
max=4096,
|
||||
step=64,
|
||||
step=8,
|
||||
tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`",
|
||||
optional=True,
|
||||
),
|
||||
@ -468,7 +309,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
default=2048,
|
||||
min=1024,
|
||||
max=4096,
|
||||
step=64,
|
||||
step=8,
|
||||
tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`",
|
||||
optional=True,
|
||||
),
|
||||
@ -532,7 +373,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
image: torch.Tensor = None,
|
||||
image: Input.Image | None = None,
|
||||
size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0],
|
||||
width: int = 2048,
|
||||
height: int = 2048,
|
||||
@ -555,6 +396,18 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
raise ValueError(
|
||||
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels."
|
||||
)
|
||||
out_num_pixels = w * h
|
||||
mp_provided = out_num_pixels / 1_000_000.0
|
||||
if "seedream-4-5" in model and out_num_pixels < 3686400:
|
||||
raise ValueError(
|
||||
f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, "
|
||||
f"but {mp_provided:.2f}MP provided."
|
||||
)
|
||||
if "seedream-4-0" in model and out_num_pixels < 921600:
|
||||
raise ValueError(
|
||||
f"Minimum image resolution that the selected model can generate is 0.92MP, "
|
||||
f"but {mp_provided:.2f}MP provided."
|
||||
)
|
||||
n_input_images = get_number_of_images(image) if image is not None else 0
|
||||
if n_input_images > 10:
|
||||
raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.")
|
||||
@ -607,9 +460,8 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Text2VideoModelName,
|
||||
default=Text2VideoModelName.seedance_1_pro,
|
||||
tooltip="Model name",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
|
||||
default="seedance-1-0-pro-fast-251015",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -714,9 +566,8 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Image2VideoModelName,
|
||||
default=Image2VideoModelName.seedance_1_pro,
|
||||
tooltip="Model name",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
|
||||
default="seedance-1-0-pro-fast-251015",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -787,7 +638,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
@ -833,9 +684,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in Image2VideoModelName],
|
||||
default=Image2VideoModelName.seedance_1_lite.value,
|
||||
tooltip="Model name",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
|
||||
default="seedance-1-0-lite-i2v-250428",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -910,8 +760,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
first_frame: torch.Tensor,
|
||||
last_frame: torch.Tensor,
|
||||
first_frame: Input.Image,
|
||||
last_frame: Input.Image,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
@ -968,9 +818,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=[Image2VideoModelName.seedance_1_lite.value],
|
||||
default=Image2VideoModelName.seedance_1_lite.value,
|
||||
tooltip="Model name",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
|
||||
default="seedance-1-0-lite-i2v-250428",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -1034,7 +883,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
images: torch.Tensor,
|
||||
images: Input.Image,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
@ -1069,8 +918,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
|
||||
async def process_video_task(
|
||||
cls: type[IO.ComfyNode],
|
||||
payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
|
||||
estimated_duration: Optional[int],
|
||||
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
|
||||
estimated_duration: int | None,
|
||||
) -> IO.NodeOutput:
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
@ -1085,7 +934,7 @@ async def process_video_task(
|
||||
estimated_duration=estimated_duration,
|
||||
response_model=TaskStatusResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response)))
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
|
||||
def raise_if_text_params(prompt: str, text_params: list[str]) -> None:
|
||||
|
||||
@ -13,8 +13,7 @@ import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import folder_paths
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api.util import VideoCodec, VideoContainer
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
from comfy_api_nodes.apis.gemini_api import (
|
||||
GeminiContent,
|
||||
GeminiFileData,
|
||||
@ -27,6 +26,8 @@ from comfy_api_nodes.apis.gemini_api import (
|
||||
GeminiMimeType,
|
||||
GeminiPart,
|
||||
GeminiRole,
|
||||
GeminiSystemInstructionContent,
|
||||
GeminiTextPart,
|
||||
Modality,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
@ -43,6 +44,14 @@ from comfy_api_nodes.util import (
|
||||
|
||||
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
|
||||
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
|
||||
GEMINI_IMAGE_SYS_PROMPT = (
|
||||
"You are an expert image-generation engine. You must ALWAYS produce an image.\n"
|
||||
"Interpret all user input—regardless of "
|
||||
"format, intent, or abstraction—as literal visual directives for image composition.\n"
|
||||
"If a prompt is conversational or lacks specific visual details, "
|
||||
"you must creatively invent a concrete visual scenario that depicts the concept.\n"
|
||||
"Prioritize generating the visual representation above any text, formatting, or conversational requests."
|
||||
)
|
||||
|
||||
|
||||
class GeminiModel(str, Enum):
|
||||
@ -68,7 +77,7 @@ class GeminiImageModel(str, Enum):
|
||||
|
||||
async def create_image_parts(
|
||||
cls: type[IO.ComfyNode],
|
||||
images: torch.Tensor,
|
||||
images: Input.Image,
|
||||
image_limit: int = 0,
|
||||
) -> list[GeminiPart]:
|
||||
image_parts: list[GeminiPart] = []
|
||||
@ -154,8 +163,8 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
|
||||
return "\n".join([part.text for part in parts])
|
||||
|
||||
|
||||
def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
|
||||
image_tensors: list[torch.Tensor] = []
|
||||
def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
|
||||
image_tensors: list[Input.Image] = []
|
||||
parts = get_parts_by_type(response, "image/png")
|
||||
for part in parts:
|
||||
image_data = base64.b64decode(part.inlineData.data)
|
||||
@ -277,6 +286,13 @@ class GeminiNode(IO.ComfyNode):
|
||||
tooltip="Optional file(s) to use as context for the model. "
|
||||
"Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
optional=True,
|
||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(),
|
||||
@ -293,7 +309,9 @@ class GeminiNode(IO.ComfyNode):
|
||||
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
|
||||
"""Convert video input to Gemini API compatible parts."""
|
||||
|
||||
base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||
base_64_string = video_to_base64_string(
|
||||
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
|
||||
)
|
||||
return [
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
@ -343,10 +361,11 @@ class GeminiNode(IO.ComfyNode):
|
||||
prompt: str,
|
||||
model: str,
|
||||
seed: int,
|
||||
images: torch.Tensor | None = None,
|
||||
images: Input.Image | None = None,
|
||||
audio: Input.Audio | None = None,
|
||||
video: Input.Video | None = None,
|
||||
files: list[GeminiPart] | None = None,
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
|
||||
@ -363,7 +382,10 @@ class GeminiNode(IO.ComfyNode):
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
# Create response
|
||||
gemini_system_prompt = None
|
||||
if system_prompt:
|
||||
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
@ -373,7 +395,8 @@ class GeminiNode(IO.ComfyNode):
|
||||
role=GeminiRole.user,
|
||||
parts=parts,
|
||||
)
|
||||
]
|
||||
],
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
@ -523,6 +546,13 @@ class GeminiImage(IO.ComfyNode):
|
||||
"'IMAGE+TEXT' to return both the generated image and a text response.",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default=GEMINI_IMAGE_SYS_PROMPT,
|
||||
optional=True,
|
||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
@ -542,10 +572,11 @@ class GeminiImage(IO.ComfyNode):
|
||||
prompt: str,
|
||||
model: str,
|
||||
seed: int,
|
||||
images: torch.Tensor | None = None,
|
||||
images: Input.Image | None = None,
|
||||
files: list[GeminiPart] | None = None,
|
||||
aspect_ratio: str = "auto",
|
||||
response_modalities: str = "IMAGE+TEXT",
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||
@ -559,6 +590,10 @@ class GeminiImage(IO.ComfyNode):
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
gemini_system_prompt = None
|
||||
if system_prompt:
|
||||
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
@ -570,6 +605,7 @@ class GeminiImage(IO.ComfyNode):
|
||||
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||
imageConfig=None if aspect_ratio == "auto" else image_config,
|
||||
),
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
@ -640,6 +676,13 @@ class GeminiImage2(IO.ComfyNode):
|
||||
tooltip="Optional file(s) to use as context for the model. "
|
||||
"Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default=GEMINI_IMAGE_SYS_PROMPT,
|
||||
optional=True,
|
||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
@ -662,8 +705,9 @@ class GeminiImage2(IO.ComfyNode):
|
||||
aspect_ratio: str,
|
||||
resolution: str,
|
||||
response_modalities: str,
|
||||
images: torch.Tensor | None = None,
|
||||
images: Input.Image | None = None,
|
||||
files: list[GeminiPart] | None = None,
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
|
||||
@ -679,6 +723,10 @@ class GeminiImage2(IO.ComfyNode):
|
||||
if aspect_ratio != "auto":
|
||||
image_config.aspectRatio = aspect_ratio
|
||||
|
||||
gemini_system_prompt = None
|
||||
if system_prompt:
|
||||
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
@ -690,6 +738,7 @@ class GeminiImage2(IO.ComfyNode):
|
||||
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||
imageConfig=image_config,
|
||||
),
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
|
||||
@ -1,12 +1,9 @@
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
get_number_of_images,
|
||||
@ -26,9 +23,9 @@ class ExecuteTaskRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
duration: int = Field(...)
|
||||
resolution: str = Field(...)
|
||||
fps: Optional[int] = Field(25)
|
||||
generate_audio: Optional[bool] = Field(True)
|
||||
image_uri: Optional[str] = Field(None)
|
||||
fps: int | None = Field(25)
|
||||
generate_audio: bool | None = Field(True)
|
||||
image_uri: str | None = Field(None)
|
||||
|
||||
|
||||
class TextToVideoNode(IO.ComfyNode):
|
||||
@ -103,7 +100,7 @@ class TextToVideoNode(IO.ComfyNode):
|
||||
as_binary=True,
|
||||
max_retries=1,
|
||||
)
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
|
||||
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
|
||||
|
||||
|
||||
class ImageToVideoNode(IO.ComfyNode):
|
||||
@ -153,7 +150,7 @@ class ImageToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
model: str,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
@ -183,7 +180,7 @@ class ImageToVideoNode(IO.ComfyNode):
|
||||
as_binary=True,
|
||||
max_retries=1,
|
||||
)
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
|
||||
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
|
||||
|
||||
|
||||
class LtxvApiExtension(ComfyExtension):
|
||||
|
||||
@ -1,11 +1,8 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis import (
|
||||
MoonvalleyPromptResponse,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
@ -61,7 +58,7 @@ def validate_task_creation_response(response) -> None:
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
|
||||
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
|
||||
def validate_video_to_video_input(video: Input.Video) -> Input.Video:
|
||||
"""
|
||||
Validates and processes video input for Moonvalley Video-to-Video generation.
|
||||
|
||||
@ -82,7 +79,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput:
|
||||
return _validate_and_trim_duration(video)
|
||||
|
||||
|
||||
def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
|
||||
def _get_video_dimensions(video: Input.Video) -> tuple[int, int]:
|
||||
"""Extracts video dimensions with error handling."""
|
||||
try:
|
||||
return video.get_dimensions()
|
||||
@ -106,7 +103,7 @@ def _validate_video_dimensions(width: int, height: int) -> None:
|
||||
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
|
||||
|
||||
|
||||
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
||||
def _validate_and_trim_duration(video: Input.Video) -> Input.Video:
|
||||
"""Validates video duration and trims to 5 seconds if needed."""
|
||||
duration = video.get_duration()
|
||||
_validate_minimum_duration(duration)
|
||||
@ -119,7 +116,7 @@ def _validate_minimum_duration(duration: float) -> None:
|
||||
raise ValueError("Input video must be at least 5 seconds long.")
|
||||
|
||||
|
||||
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
|
||||
def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video:
|
||||
"""Trims video to 5 seconds if longer."""
|
||||
if duration > 5:
|
||||
return trim_video(video, 5)
|
||||
@ -241,7 +238,7 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
resolution: str,
|
||||
@ -362,9 +359,9 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
video: Optional[VideoInput] = None,
|
||||
video: Input.Video | None = None,
|
||||
control_type: str = "Motion Transfer",
|
||||
motion_intensity: Optional[int] = 100,
|
||||
motion_intensity: int | None = 100,
|
||||
steps=33,
|
||||
prompt_adherence=4.5,
|
||||
) -> IO.NodeOutput:
|
||||
|
||||
@ -11,12 +11,11 @@ User Guides:
|
||||
|
||||
"""
|
||||
|
||||
from typing import Union, Optional
|
||||
from typing_extensions import override
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
|
||||
from comfy_api_nodes.apis import (
|
||||
RunwayImageToVideoRequest,
|
||||
RunwayImageToVideoResponse,
|
||||
@ -44,8 +43,6 @@ from comfy_api_nodes.util import (
|
||||
sync_op,
|
||||
poll_op,
|
||||
)
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
|
||||
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
||||
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
||||
@ -80,7 +77,7 @@ class RunwayGen3aAspectRatio(str, Enum):
|
||||
field_1280_768 = "1280:768"
|
||||
|
||||
|
||||
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
||||
def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
"""Returns the video URL from the task status response if it exists."""
|
||||
if hasattr(response, "output") and len(response.output) > 0:
|
||||
return response.output[0]
|
||||
@ -89,13 +86,13 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
|
||||
|
||||
def extract_progress_from_task_status(
|
||||
response: TaskStatusResponse,
|
||||
) -> Union[float, None]:
|
||||
) -> float | None:
|
||||
if hasattr(response, "progress") and response.progress is not None:
|
||||
return response.progress * 100
|
||||
return None
|
||||
|
||||
|
||||
def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
||||
def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
"""Returns the image URL from the task status response if it exists."""
|
||||
if hasattr(response, "output") and len(response.output) > 0:
|
||||
return response.output[0]
|
||||
@ -103,7 +100,7 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
|
||||
|
||||
|
||||
async def get_response(
|
||||
cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None
|
||||
cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
|
||||
) -> TaskStatusResponse:
|
||||
"""Poll the task status until it is finished then get the response."""
|
||||
return await poll_op(
|
||||
@ -119,8 +116,8 @@ async def get_response(
|
||||
async def generate_video(
|
||||
cls: type[IO.ComfyNode],
|
||||
request: RunwayImageToVideoRequest,
|
||||
estimated_duration: Optional[int] = None,
|
||||
) -> VideoFromFile:
|
||||
estimated_duration: int | None = None,
|
||||
) -> InputImpl.VideoFromFile:
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
||||
@ -193,7 +190,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
start_frame: Input.Image,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
@ -283,7 +280,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
start_frame: Input.Image,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
@ -381,8 +378,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
end_frame: torch.Tensor,
|
||||
start_frame: Input.Image,
|
||||
end_frame: Input.Image,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
@ -467,7 +464,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
|
||||
cls,
|
||||
prompt: str,
|
||||
ratio: str,
|
||||
reference_image: Optional[torch.Tensor] = None,
|
||||
reference_image: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1)
|
||||
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
|
||||
from comfy_api_nodes.apis.veo_api import (
|
||||
VeoGenVidPollRequest,
|
||||
VeoGenVidPollResponse,
|
||||
@ -232,7 +230,7 @@ class VeoVideoGenerationNode(IO.ComfyNode):
|
||||
|
||||
# Check if video is provided as base64 or URL
|
||||
if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||
|
||||
if hasattr(video, "gcsUri") and video.gcsUri:
|
||||
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
|
||||
@ -431,8 +429,8 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
first_frame: torch.Tensor,
|
||||
last_frame: torch.Tensor,
|
||||
first_frame: Input.Image,
|
||||
last_frame: Input.Image,
|
||||
model: str,
|
||||
generate_audio: bool,
|
||||
):
|
||||
@ -493,7 +491,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
if response.videos:
|
||||
video = response.videos[0]
|
||||
if video.bytesBase64Encoded:
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||
if video.gcsUri:
|
||||
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
|
||||
raise Exception("Video returned but no data or URL was provided")
|
||||
|
||||
@ -244,6 +244,14 @@ class ModelPatchLoader:
|
||||
elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
|
||||
sd = z_image_convert(sd)
|
||||
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||
elif "audio_proj.proj1.weight" in sd:
|
||||
model = MultiTalkModelPatch(
|
||||
audio_window=5, context_tokens=32, vae_scale=4,
|
||||
in_dim=sd["blocks.0.audio_cross_attn.proj.weight"].shape[0],
|
||||
intermediate_dim=sd["audio_proj.proj1.weight"].shape[0],
|
||||
out_dim=sd["audio_proj.norm.weight"].shape[0],
|
||||
device=comfy.model_management.unet_offload_device(),
|
||||
operations=comfy.ops.manual_cast)
|
||||
|
||||
model.load_state_dict(sd)
|
||||
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||
@ -435,6 +443,40 @@ class USOStyleReference:
|
||||
return (model_patched,)
|
||||
|
||||
|
||||
from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel
|
||||
|
||||
class MultiTalkModelPatch(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
audio_window: int = 5,
|
||||
intermediate_dim: int = 512,
|
||||
in_dim: int = 5120,
|
||||
out_dim: int = 768,
|
||||
context_tokens: int = 32,
|
||||
vae_scale: int = 4,
|
||||
num_layers: int = 40,
|
||||
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.audio_proj = MultiTalkAudioProjModel(
|
||||
seq_len=audio_window,
|
||||
seq_len_vf=audio_window+vae_scale-1,
|
||||
intermediate_dim=intermediate_dim,
|
||||
out_dim=out_dim,
|
||||
context_tokens=context_tokens,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations
|
||||
)
|
||||
self.blocks = torch.nn.ModuleList(
|
||||
[
|
||||
WanMultiTalkAttentionBlock(in_dim, out_dim, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelPatchLoader": ModelPatchLoader,
|
||||
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
|
||||
|
||||
@ -8,10 +8,7 @@ import json
|
||||
from typing import Optional
|
||||
from typing_extensions import override
|
||||
from fractions import Fraction
|
||||
from comfy_api.input import AudioInput, ImageInput, VideoInput
|
||||
from comfy_api.input_impl import VideoFromComponents, VideoFromFile
|
||||
from comfy_api.util import VideoCodec, VideoComponents, VideoContainer
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
from comfy_api.latest import ComfyExtension, io, ui, Input, InputImpl, Types
|
||||
from comfy.cli_args import args
|
||||
|
||||
class SaveWEBM(io.ComfyNode):
|
||||
@ -28,7 +25,6 @@ class SaveWEBM(io.ComfyNode):
|
||||
io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
|
||||
io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."),
|
||||
],
|
||||
outputs=[],
|
||||
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
@ -79,16 +75,15 @@ class SaveVideo(io.ComfyNode):
|
||||
inputs=[
|
||||
io.Video.Input("video", tooltip="The video to save."),
|
||||
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
|
||||
io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
|
||||
io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
|
||||
io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
|
||||
io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
|
||||
],
|
||||
outputs=[],
|
||||
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video: VideoInput, filename_prefix, format: str, codec) -> io.NodeOutput:
|
||||
def execute(cls, video: Input.Video, filename_prefix, format: str, codec) -> io.NodeOutput:
|
||||
width, height = video.get_dimensions()
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
|
||||
filename_prefix,
|
||||
@ -105,10 +100,10 @@ class SaveVideo(io.ComfyNode):
|
||||
metadata["prompt"] = cls.hidden.prompt
|
||||
if len(metadata) > 0:
|
||||
saved_metadata = metadata
|
||||
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
|
||||
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
|
||||
video.save_to(
|
||||
os.path.join(full_output_folder, file),
|
||||
format=VideoContainer(format),
|
||||
format=Types.VideoContainer(format),
|
||||
codec=codec,
|
||||
metadata=saved_metadata
|
||||
)
|
||||
@ -135,9 +130,9 @@ class CreateVideo(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput:
|
||||
def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput:
|
||||
return io.NodeOutput(
|
||||
VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
|
||||
InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
|
||||
)
|
||||
|
||||
class GetVideoComponents(io.ComfyNode):
|
||||
@ -159,11 +154,11 @@ class GetVideoComponents(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video: VideoInput) -> io.NodeOutput:
|
||||
def execute(cls, video: Input.Video) -> io.NodeOutput:
|
||||
components = video.get_components()
|
||||
|
||||
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
|
||||
|
||||
|
||||
class LoadVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -185,7 +180,7 @@ class LoadVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, file) -> io.NodeOutput:
|
||||
video_path = folder_paths.get_annotated_filepath(file)
|
||||
return io.NodeOutput(VideoFromFile(video_path))
|
||||
return io.NodeOutput(InputImpl.VideoFromFile(video_path))
|
||||
|
||||
@classmethod
|
||||
def fingerprint_inputs(s, file):
|
||||
|
||||
@ -11,6 +11,7 @@ import numpy as np
|
||||
from typing import Tuple
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import logging
|
||||
|
||||
class WanImageToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
@ -1288,6 +1289,155 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
|
||||
return io.NodeOutput(out_latent)
|
||||
|
||||
|
||||
from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, MultiTalkGetAttnMapPatch, project_audio_features
|
||||
class WanInfiniteTalkToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanInfiniteTalkToVideo",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.ModelPatch.Input("model_patch"),
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||
io.Image.Input("start_image", optional=True),
|
||||
io.AudioEncoderOutput.Input("audio_encoder_output_1"),
|
||||
io.AudioEncoderOutput.Input("audio_encoder_output_2", optional=True),
|
||||
io.Mask.Input("mask_1", optional=True, tooltip="Mask for the first speaker, required if using two audio inputs."),
|
||||
io.Mask.Input("mask_2", optional=True, tooltip="Mask for the second speaker, required if using two audio inputs."),
|
||||
io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context."),
|
||||
io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01),
|
||||
io.Image.Input("previous_frames", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(display_name="model"),
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
io.Int.Output(display_name="trim_image"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count,
|
||||
start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None) -> io.NodeOutput:
|
||||
|
||||
if previous_frames is not None and previous_frames.shape[0] < motion_frame_count:
|
||||
raise ValueError("Not enough previous frames provided.")
|
||||
if audio_encoder_output_2 is not None:
|
||||
if mask_1 is None or mask_2 is None:
|
||||
raise ValueError("Masks must be provided if two audio encoder outputs are used.")
|
||||
|
||||
ref_masks = None
|
||||
if mask_1 is not None and mask_2 is not None:
|
||||
if audio_encoder_output_2 is None:
|
||||
raise ValueError("Second audio encoder output must be provided if two masks are used.")
|
||||
ref_masks = torch.cat([mask_1, mask_2])
|
||||
|
||||
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
|
||||
image[:start_image.shape[0]] = start_image
|
||||
|
||||
concat_latent_image = vae.encode(image[:, :, :, :3])
|
||||
concat_mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
||||
concat_mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
|
||||
|
||||
if clip_vision_output is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
||||
|
||||
model_patched = model.clone()
|
||||
|
||||
encoded_audio_list = []
|
||||
seq_lengths = []
|
||||
|
||||
for audio_encoder_output in [audio_encoder_output_1, audio_encoder_output_2]:
|
||||
if audio_encoder_output is None:
|
||||
continue
|
||||
all_layers = audio_encoder_output["encoded_audio_all_layers"]
|
||||
encoded_audio = torch.stack(all_layers, dim=0).squeeze(1)[1:] # shape: [num_layers, T, 512]
|
||||
encoded_audio = linear_interpolation(encoded_audio, input_fps=50, output_fps=25).movedim(0, 1) # shape: [T, num_layers, 512]
|
||||
encoded_audio_list.append(encoded_audio)
|
||||
seq_lengths.append(encoded_audio.shape[0])
|
||||
|
||||
# Pad / combine depending on multi_audio_type
|
||||
multi_audio_type = "add"
|
||||
if len(encoded_audio_list) > 1:
|
||||
if multi_audio_type == "para":
|
||||
max_len = max(seq_lengths)
|
||||
padded = []
|
||||
for emb in encoded_audio_list:
|
||||
if emb.shape[0] < max_len:
|
||||
pad = torch.zeros(max_len - emb.shape[0], *emb.shape[1:], dtype=emb.dtype)
|
||||
emb = torch.cat([emb, pad], dim=0)
|
||||
padded.append(emb)
|
||||
encoded_audio_list = padded
|
||||
elif multi_audio_type == "add":
|
||||
total_len = sum(seq_lengths)
|
||||
full_list = []
|
||||
offset = 0
|
||||
for emb, seq_len in zip(encoded_audio_list, seq_lengths):
|
||||
full = torch.zeros(total_len, *emb.shape[1:], dtype=emb.dtype)
|
||||
full[offset:offset+seq_len] = emb
|
||||
full_list.append(full)
|
||||
offset += seq_len
|
||||
encoded_audio_list = full_list
|
||||
|
||||
token_ref_target_masks = None
|
||||
if ref_masks is not None:
|
||||
token_ref_target_masks = torch.nn.functional.interpolate(
|
||||
ref_masks.unsqueeze(0), size=(latent.shape[-2] // 2, latent.shape[-1] // 2), mode='nearest')[0]
|
||||
token_ref_target_masks = (token_ref_target_masks > 0).view(token_ref_target_masks.shape[0], -1)
|
||||
|
||||
# when extending from previous frames
|
||||
if previous_frames is not None:
|
||||
motion_frames = comfy.utils.common_upscale(previous_frames[-motion_frame_count:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
frame_offset = previous_frames.shape[0] - motion_frame_count
|
||||
|
||||
audio_start = frame_offset
|
||||
audio_end = audio_start + length
|
||||
logging.info(f"InfiniteTalk: Processing audio frames {audio_start} - {audio_end}")
|
||||
|
||||
motion_frames_latent = vae.encode(motion_frames[:, :, :, :3])
|
||||
trim_image = motion_frame_count
|
||||
else:
|
||||
audio_start = trim_image = 0
|
||||
audio_end = length
|
||||
motion_frames_latent = concat_latent_image[:, :, :1]
|
||||
|
||||
audio_embed = project_audio_features(model_patch.model.audio_proj, encoded_audio_list, audio_start, audio_end).to(model_patched.model_dtype())
|
||||
model_patched.model_options["transformer_options"]["audio_embeds"] = audio_embed
|
||||
|
||||
# add outer sample wrapper
|
||||
model_patched.add_wrapper_with_key(
|
||||
comfy.patcher_extension.WrappersMP.OUTER_SAMPLE,
|
||||
"infinite_talk_outer_sample",
|
||||
InfiniteTalkOuterSampleWrapper(
|
||||
motion_frames_latent,
|
||||
model_patch,
|
||||
is_extend=previous_frames is not None,
|
||||
))
|
||||
# add cross-attention patch
|
||||
model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale), "attn2_patch")
|
||||
if token_ref_target_masks is not None:
|
||||
model_patched.set_model_patch(MultiTalkGetAttnMapPatch(token_ref_target_masks), "attn1_patch")
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
|
||||
|
||||
|
||||
class WanExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@ -1307,6 +1457,7 @@ class WanExtension(ComfyExtension):
|
||||
WanHuMoImageToVideo,
|
||||
WanAnimateToVideo,
|
||||
Wan22ImageToVideoLatent,
|
||||
WanInfiniteTalkToVideo,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> WanExtension:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.33.10
|
||||
comfyui-workflow-templates==0.7.25
|
||||
comfyui-workflow-templates==0.7.51
|
||||
comfyui-embedded-docs==0.3.1
|
||||
torch
|
||||
torchsde
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user