import cv2
import math
import numpy as np
from PIL import Image

import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import normalize, resize
from transformers import T5EncoderModel, T5Tokenizer
from typing import List, Optional, Tuple, Union
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid


def tensor_to_pil(src_img_tensor):
    """Convert a CxHxW image tensor with values in [0, 255] to a PIL image."""
    img = src_img_tensor.clone().detach()
    if img.dtype == torch.bfloat16:
        img = img.to(torch.float32)
    img = img.cpu().numpy()
    img = np.transpose(img, (1, 2, 0))
    img = img.astype(np.uint8)
    pil_image = Image.fromarray(img)
    return pil_image


def _get_t5_prompt_embeds(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

    return prompt_embeds


def encode_prompt(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    prompt_embeds = _get_t5_prompt_embeds(
        tokenizer,
        text_encoder,
        prompt=prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
        text_input_ids=text_input_ids,
    )
    return prompt_embeds


def compute_prompt_embeddings(
    tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
):
    if requires_grad:
        prompt_embeds = encode_prompt(
            tokenizer,
            text_encoder,
            prompt,
            num_videos_per_prompt=1,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )
    else:
        with torch.no_grad():
            prompt_embeds = encode_prompt(
                tokenizer,
                text_encoder,
                prompt,
                num_videos_per_prompt=1,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )
    return prompt_embeds
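

# Usage sketch (illustrative, not part of the original module): computing embeddings with
# the T5 components of a CogVideoX checkpoint. The checkpoint id, prompt, and dtype below
# are assumptions for the example, not values taken from this file.
#
#   tokenizer = T5Tokenizer.from_pretrained("THUDM/CogVideoX-5b", subfolder="tokenizer")
#   text_encoder = T5EncoderModel.from_pretrained("THUDM/CogVideoX-5b", subfolder="text_encoder")
#   prompt_embeds = compute_prompt_embeddings(
#       tokenizer, text_encoder, "a panda playing guitar by a river",
#       max_sequence_length=226, device=torch.device("cpu"), dtype=torch.float32,
#   )  # -> tensor of shape [1, 226, text_encoder hidden size]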


def prepare_rotary_positional_embeddings(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int = 8,
    patch_size: int = 2,
    attention_head_dim: int = 64,
    device: Optional[torch.device] = None,
    base_height: int = 480,
    base_width: int = 720,
) -> Tuple[torch.Tensor, torch.Tensor]:
    grid_height = height // (vae_scale_factor_spatial * patch_size)
    grid_width = width // (vae_scale_factor_spatial * patch_size)
    base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
    base_size_height = base_height // (vae_scale_factor_spatial * patch_size)

    grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
    freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
        embed_dim=attention_head_dim,
        crops_coords=grid_crops_coords,
        grid_size=(grid_height, grid_width),
        temporal_size=num_frames,
    )

    freqs_cos = freqs_cos.to(device=device)
    freqs_sin = freqs_sin.to(device=device)
    return freqs_cos, freqs_sin
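

# Usage sketch (illustrative, not part of the original module): RoPE frequencies for a
# 480x720 CogVideoX sample. `num_frames` is the number of latent frames seen by the
# transformer; 13 assumes a 49-frame video with the usual 4x temporal compression plus
# the first frame, so match it to your actual latent tensor.
#
#   freqs_cos, freqs_sin = prepare_rotary_positional_embeddings(
#       height=480, width=720, num_frames=13, device=torch.device("cpu")
#   )
#   # one (cos, sin) pair covering the full frames x height x width token grid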


def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Numpy array to tensor.

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to change bgr to rgb.
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Tensor images. If returned results only have
            one element, just return tensor.
    """

    def _totensor(img, bgr2rgb, float32):
        if img.shape[2] == 3 and bgr2rgb:
            if img.dtype == 'float64':
                img = img.astype('float32')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose(2, 0, 1))
        if float32:
            img = img.float()
        return img

    if isinstance(imgs, list):
        return [_totensor(img, bgr2rgb, float32) for img in imgs]
    return _totensor(imgs, bgr2rgb, float32)
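

# Usage sketch (illustrative, not part of the original module): converting an image loaded
# with OpenCV (HxWx3, BGR, uint8) into the RGB CxHxW float tensor layout used above. The
# file path is a placeholder.
#
#   bgr = cv2.imread("face.png")                        # HxWx3, uint8, BGR
#   img = img2tensor(bgr, bgr2rgb=True, float32=True)   # 3xHxW, float32, RGB, range [0, 255]
#   batch = img.unsqueeze(0) / 255.0                    # 1x3xHxW, range [0, 1]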


def to_gray(img):
    # Rec. 601 luma weights, replicated back to three channels
    x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
    x = x.repeat(1, 3, 1, 1)
    return x


def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
    stickwidth = 4
    limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
    kps = np.array(kps)

    w, h = image_pil.size
    out_img = np.zeros([h, w, 3])

    for i in range(len(limbSeq)):
        index = limbSeq[i]
        color = color_list[index[0]]

        x = kps[index][:, 0]
        y = kps[index][:, 1]
        length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
        angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
        polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
        out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
    out_img = (out_img * 0.6).astype(np.uint8)

    for idx_kp, kp in enumerate(kps):
        color = color_list[idx_kp]
        x, y = kp
        out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)

    out_img_pil = Image.fromarray(out_img.astype(np.uint8))
    return out_img_pil
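

# Usage sketch (illustrative, not part of the original module): rendering five facial
# keypoints (such as the `face_kps` returned by process_face_embeddings below) onto a
# canvas matching the source image size. The coordinates and file names are placeholders.
#
#   kps = [[188, 242], [310, 240], [251, 312], [201, 371], [305, 369]]
#   kps_image = draw_kps(Image.open("face.png").convert("RGB"), kps)
#   kps_image.save("face_kps.png")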


def process_face_embeddings(face_helper, clip_vision_model, handler_ante, eva_transform_mean, eva_transform_std, app, device, weight_dtype, image, original_id_image=None, is_align_face=True, cal_uncond=False):
    """
    Args:
        image: numpy rgb image, range [0, 255]

    Returns:
        id_cond: concatenated ArcFace + CLIP-ViT identity embedding, torch.Size([1, 1280])
        id_vit_hidden: list of intermediate CLIP-ViT hidden states
        return_face_features_image_2: parsed face image tensor, torch.Size([1, 3, 512, 512])
        face_kps: five facial keypoints from detection
    """
    face_helper.clean_all()
    image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)  # (724, 502, 3)
    # get antelopev2 embedding
    face_info = app.get(image_bgr)
    if len(face_info) > 0:
        face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[
            -1
        ]  # only use the maximum face
        id_ante_embedding = face_info['embedding']  # (512,)
        face_kps = face_info['kps']
    else:
        id_ante_embedding = None
        face_kps = None

    # using facexlib to detect and align face
    face_helper.read_image(image_bgr)
    face_helper.get_face_landmarks_5(only_center_face=True)
    if face_kps is None:
        face_kps = face_helper.all_landmarks_5[0]
    face_helper.align_warp_face()
    if len(face_helper.cropped_faces) == 0:
        raise RuntimeError('facexlib align face failed')
    align_face = face_helper.cropped_faces[0]  # (512, 512, 3) # RGB

    # in case insightface didn't detect a face
    if id_ante_embedding is None:
        print('failed to detect face using insightface, extracting embedding from the aligned face instead')
        id_ante_embedding = handler_ante.get_feat(align_face)

    id_ante_embedding = torch.from_numpy(id_ante_embedding).to(device, weight_dtype)  # torch.Size([512])
    if id_ante_embedding.ndim == 1:
        id_ante_embedding = id_ante_embedding.unsqueeze(0)  # torch.Size([1, 512])

    # parsing
    if is_align_face:
        input = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0  # torch.Size([1, 3, 512, 512])
        input = input.to(device)
        parsing_out = face_helper.face_parse(normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0]
        parsing_out = parsing_out.argmax(dim=1, keepdim=True)  # torch.Size([1, 1, 512, 512])
        bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
        bg = sum(parsing_out == i for i in bg_label).bool()
        white_image = torch.ones_like(input)  # torch.Size([1, 3, 512, 512])
        # only keep the face features
        return_face_features_image = torch.where(bg, white_image, to_gray(input))  # torch.Size([1, 3, 512, 512])
        return_face_features_image_2 = torch.where(bg, white_image, input)  # torch.Size([1, 3, 512, 512])
    else:
        original_image_bgr = cv2.cvtColor(original_id_image, cv2.COLOR_RGB2BGR)
        input = img2tensor(original_image_bgr, bgr2rgb=True).unsqueeze(0) / 255.0  # torch.Size([1, 3, 512, 512])
        input = input.to(device)
        return_face_features_image = return_face_features_image_2 = input

    # transform img before sending to eva-clip-vit
    face_features_image = resize(return_face_features_image, clip_vision_model.image_size,
                                 InterpolationMode.BICUBIC)  # torch.Size([1, 3, 336, 336])
    face_features_image = normalize(face_features_image, eva_transform_mean, eva_transform_std)
    id_cond_vit, id_vit_hidden = clip_vision_model(face_features_image.to(weight_dtype), return_all_features=False, return_hidden=True, shuffle=False)  # torch.Size([1, 768]), list(torch.Size([1, 577, 1024]))
    id_cond_vit_norm = torch.norm(id_cond_vit, 2, 1, True)
    id_cond_vit = torch.div(id_cond_vit, id_cond_vit_norm)

    id_cond = torch.cat([id_ante_embedding, id_cond_vit], dim=-1)  # torch.Size([1, 512]), torch.Size([1, 768]) -> torch.Size([1, 1280])

    return id_cond, id_vit_hidden, return_face_features_image_2, face_kps  # torch.Size([1, 1280]), list(torch.Size([1, 577, 1024]))
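

# Wiring sketch (assumptions, not part of the original module): process_face_embeddings
# expects the caller to construct the face models along these lines; the exact classes,
# arguments, and checkpoints used elsewhere in this repo may differ.
#
#   from facexlib.utils.face_restoration_helper import FaceRestoreHelper
#   from insightface.app import FaceAnalysis
#
#   app = FaceAnalysis(name="antelopev2", providers=["CPUExecutionProvider"])
#   app.prepare(ctx_id=0, det_size=(640, 640))
#   face_helper = FaceRestoreHelper(upscale_factor=1, face_size=512, use_parse=True, device=device)
#   # clip_vision_model: an EVA-CLIP vision tower; handler_ante: an ArcFace feature extractor;
#   # eva_transform_mean / eva_transform_std: the normalization stats of that CLIP model.
#   id_cond, id_vit_hidden, face_image, face_kps = process_face_embeddings(
#       face_helper, clip_vision_model, handler_ante, eva_transform_mean, eva_transform_std,
#       app, device, torch.float16, image_rgb_uint8,
#   )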