mirror of
https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git
synced 2026-01-24 16:14:24 +08:00
235 lines
8.3 KiB
Python
Executable File
235 lines
8.3 KiB
Python
Executable File
# Open Source Model Licensed under the Apache License Version 2.0
|
|
# and Other Licenses of the Third-Party Components therein:
|
|
# The below Model in this distribution may have been modified by THL A29 Limited
|
|
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
|
|
|
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
|
# The below software and/or models in this distribution may have been
|
|
# modified by THL A29 Limited ("Tencent Modifications").
|
|
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
|
|
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
|
# except for the third-party components listed below.
|
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
|
# in the respective licenses of these third-party components.
|
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
|
# components and must ensure that the usage of the third party components adheres to
|
|
# all relevant laws and regulations.
|
|
|
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
|
# their software and algorithms, including trained model weights, parameters (including
|
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torchvision import transforms
|
|
from transformers import (
|
|
CLIPVisionModelWithProjection,
|
|
CLIPVisionConfig,
|
|
Dinov2Model,
|
|
Dinov2Config,
|
|
)
|
|
|
|
def split_tiles(embeds, num_split):
    """Cut every image of a batch into a ``num_split`` x ``num_split`` grid of tiles.

    Tiles are ordered row-major: the tile at grid row ``i`` and grid column
    ``j`` ends up at index ``i * num_split + j``.  When ``H`` or ``W`` is not
    a multiple of ``num_split`` the trailing rows/columns that do not fill a
    whole tile are dropped.

    Args:
        embeds: Tensor of shape ``[B, C, H, W]``.
        num_split: Number of tiles along each spatial axis (must be >= 1).

    Returns:
        Tensor of shape
        ``[B, num_split * num_split, C, H // num_split, W // num_split]``.

    Raises:
        ValueError: If ``num_split`` is smaller than 1.
    """
    if num_split < 1:
        raise ValueError(f"num_split must be >= 1, got {num_split}")

    B, C, H, W = embeds.shape
    h, w = H // num_split, W // num_split

    # Trim the remainder so each spatial axis factors exactly into (grid, tile).
    trimmed = embeds[:, :, :h * num_split, :w * num_split]
    grid = trimmed.reshape(B, C, num_split, h, num_split, w)

    # [B, C, gi, h, gj, w] -> [B, gi, gj, C, h, w] -> [B, gi * gj, C, h, w]
    return grid.permute(0, 2, 4, 1, 3, 5).reshape(B, num_split * num_split, C, h, w)
|
|
|
|
# matteo's ipadapter tiling code
|
|
def merge_hiddenstates(x, tiles):
    """Fuse per-tile token sequences back into one token sequence per image.

    Every group of ``tiles * tiles`` consecutive entries of ``x`` is treated
    as the row-major tile grid of one image.  For each group the class tokens
    (token index 0) are averaged, the patch tokens of all tiles are laid out
    as one large square map, and that map is average-pooled back down to the
    patch resolution of a single tile.

    Args:
        x: Tensor of shape ``[B * tiles * tiles, 1 + tile_size ** 2, D]``.
        tiles: Number of tiles along each side of the grid (must be >= 1).

    Returns:
        Tensor of shape ``[B, 1 + tile_size ** 2, D]``.

    Raises:
        ValueError: If ``tiles`` is smaller than 1, if the first dimension of
            ``x`` is not a multiple of ``tiles * tiles``, or if the number of
            patch tokens is not a perfect square.
    """
    import math  # function-scope import: only this helper needs it

    if tiles < 1:
        raise ValueError(f"tiles must be >= 1, got {tiles}")

    num_tiles = tiles * tiles
    total, num_tokens, dim = x.shape
    if total % num_tiles != 0:
        raise ValueError(
            f"expected a multiple of {num_tiles} tile embeddings, got {total}"
        )
    if num_tokens < 1:
        raise ValueError("each tile embedding needs at least the class token")

    # Exact integer square root; a float sqrt could truncate to the wrong value.
    tile_size = math.isqrt(num_tokens - 1)
    if tile_size * tile_size != num_tokens - 1:
        raise ValueError(
            f"{num_tokens - 1} patch tokens do not form a square grid"
        )

    batch = total // num_tiles
    grouped = x.reshape(batch, num_tiles, num_tokens, dim)

    # Average the class tokens of all tiles belonging to the same image.
    avg_class_token = grouped[:, :, 0, :].mean(dim=1, keepdim=True)  # [B, 1, D]

    # [B, gi, gj, ti, tj, D] -> [B, D, gi, ti, gj, tj] -> [B, D, gi*ti, gj*tj]
    patches = grouped[:, :, 1:, :].reshape(batch, tiles, tiles, tile_size, tile_size, dim)
    side = tiles * tile_size
    mosaic = patches.permute(0, 5, 1, 3, 2, 4).reshape(batch, dim, side, side)

    # Pool the stitched map back to the patch resolution of a single tile.
    pooled = torch.nn.functional.adaptive_avg_pool2d(mosaic, (tile_size, tile_size))
    flattened = pooled.permute(0, 2, 3, 1).reshape(batch, tile_size * tile_size, dim)

    # Put the class token back in front of the patch tokens.
    return torch.cat([avg_class_token, flattened], dim=1)
|
|
|
|
class ImageEncoder(nn.Module):
    """Frozen vision backbone that turns images into token embeddings.

    Subclasses provide the class attributes ``MODEL_CLASS``,
    ``MODEL_CONFIG_CLASS``, ``mean`` and ``std``.  The wrapped model is put
    in eval mode and has gradients disabled.
    """

    def __init__(
        self,
        version=None,
        config=None,
        use_cls_token=True,
        image_size=224,
        **kwargs,
    ):
        """Build the backbone and its preprocessing pipeline.

        Args:
            version: Identifier passed to ``MODEL_CLASS.from_pretrained``;
                used only when ``config`` is ``None``.
            config: Dict passed to ``MODEL_CONFIG_CLASS.from_dict``; when
                given, the model is constructed from this config instead of
                being loaded with ``from_pretrained``.
            use_cls_token: Whether ``forward`` keeps the first token of the
                sequence.
            image_size: Side length images are resized and cropped to.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If both ``version`` and ``config`` are ``None``.
        """
        super().__init__()

        if config is None:
            if version is None:
                raise ValueError("either `version` or `config` must be provided")
            self.model = self.MODEL_CLASS.from_pretrained(version)
        else:
            self.model = self.MODEL_CLASS(self.MODEL_CONFIG_CLASS.from_dict(config))
        self.model.eval()
        self.model.requires_grad_(False)
        self.use_cls_token = use_cls_token

        # Read the patch size from the backbone config when it exposes one;
        # 14 is kept as the fallback for configs without that field.
        patch_size = getattr(self.model.config, "patch_size", 14)
        self.size = image_size // patch_size
        self.num_patches = self.size ** 2
        if self.use_cls_token:
            self.num_patches += 1

        resize_and_crop = [
            transforms.Resize(image_size, transforms.InterpolationMode.BILINEAR, antialias=True),
            transforms.CenterCrop(image_size),
        ]
        # Geometry-only pipeline for tiles that are cut from an image which
        # has already been normalised.
        self.tile_transform = transforms.Compose(resize_and_crop)
        self.transform = transforms.Compose(
            resize_and_crop
            + [
                transforms.Normalize(
                    mean=self.mean,
                    std=self.std,
                ),
            ]
        )

    def forward(self, image, mask=None, value_range=(-1, 1), tiles=1, ratio=0.8):
        """Encode ``image`` into a sequence of token embeddings.

        Args:
            image: Image tensor; presumably ``[B, C, H, W]`` (the tiled path
                unpacks four dimensions) - confirm against callers.
            mask: Optional tensor multiplied element-wise into the image
                before preprocessing.
            value_range: ``(low, high)`` of the incoming pixel values, used
                to rescale them to ``[0, 1]``; ``None`` skips the rescaling.
            tiles: When greater than 1, the preprocessed image is also cut
                into a ``tiles`` x ``tiles`` grid, each tile is encoded on
                its own, and the merged tile embedding is blended into the
                full-image embedding.
            ratio: Weight of the full-image embedding in that blend; the
                merged tile embedding gets ``1 - ratio``.

        Returns:
            The model's ``last_hidden_state``, without its first token when
            ``use_cls_token`` is false.
        """
        if value_range is not None:
            low, high = value_range
            image = (image - low) / (high - low)

        image = image.to(self.model.device, dtype=self.model.dtype)

        if mask is not None:
            mask = mask.to(image)
            image = image * mask

        image = self.transform(image)

        last_hidden_state = self.model(image).last_hidden_state

        if tiles > 1:
            # The tiles come from the already-normalised image, so they are
            # only resized back to the model resolution; running the full
            # transform again would normalise them a second time.
            tile_states = [
                self.model(self.tile_transform(sample_tiles)).last_hidden_state
                for sample_tiles in split_tiles(image, tiles)
            ]
            hidden_state = merge_hiddenstates(torch.cat(tile_states, dim=0), tiles)
            last_hidden_state = last_hidden_state * ratio + hidden_state * (1 - ratio)

        if not self.use_cls_token:
            last_hidden_state = last_hidden_state[:, 1:, :]

        return last_hidden_state

    def unconditional_embedding(self, batch_size):
        """Return an all-zero embedding used as the unconditional input.

        Args:
            batch_size: Number of embeddings to create.

        Returns:
            Zero tensor of shape ``[batch_size, num_patches, hidden_size]``
            on the device and in the dtype of the model's first parameter.
        """
        reference = next(self.model.parameters())
        zero = torch.zeros(
            batch_size,
            self.num_patches,
            self.model.config.hidden_size,
            device=reference.device,
            dtype=reference.dtype,
        )

        return zero
|
|
|
|
|
|
class CLIPImageEncoder(ImageEncoder):
    """Image encoder backed by a CLIP vision model with projection head."""

    # Backbone and config classes used by ImageEncoder.__init__.
    MODEL_CLASS = CLIPVisionModelWithProjection
    MODEL_CONFIG_CLASS = CLIPVisionConfig

    # Per-channel normalisation statistics handed to transforms.Normalize.
    mean = [0.48145466, 0.4578275, 0.40821073]
    std = [0.26862954, 0.26130258, 0.27577711]
|
|
|
|
|
|
class DinoImageEncoder(ImageEncoder):
    """Image encoder backed by a DINOv2 model."""

    # Backbone and config classes used by ImageEncoder.__init__.
    MODEL_CLASS = Dinov2Model
    MODEL_CONFIG_CLASS = Dinov2Config

    # Per-channel normalisation statistics handed to transforms.Normalize.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
|
|
|
|
|
|
def build_image_encoder(config):
    """Instantiate the image encoder described by ``config``.

    Args:
        config: Mapping with a ``'type'`` entry naming the encoder class
            (``'CLIPImageEncoder'`` or ``'DinoImageEncoder'``) and a
            ``'kwargs'`` entry holding its constructor keyword arguments.

    Returns:
        The constructed encoder instance.

    Raises:
        ValueError: If ``config['type']`` names neither supported encoder.
    """
    encoder_type = config['type']

    # Reject unsupported names first so the happy path stays flat.
    if encoder_type not in ('CLIPImageEncoder', 'DinoImageEncoder'):
        raise ValueError(f'Unknown image encoder type: {encoder_type}')

    if encoder_type == 'CLIPImageEncoder':
        encoder_cls = CLIPImageEncoder
    else:
        encoder_cls = DinoImageEncoder
    return encoder_cls(**config['kwargs'])
|
|
|
|
|
|
class DualImageEncoder(nn.Module):
    """Pair of image encoders whose outputs are returned side by side."""

    def __init__(
        self,
        main_image_encoder,
        additional_image_encoder,
    ):
        """Build both encoders from their ``build_image_encoder`` configs.

        Args:
            main_image_encoder: Config for the encoder stored under ``'main'``.
            additional_image_encoder: Config for the encoder stored under
                ``'additional'``.
        """
        super().__init__()
        self.main_image_encoder = build_image_encoder(main_image_encoder)
        self.additional_image_encoder = build_image_encoder(additional_image_encoder)

    def forward(self, image, mask=None, tiles=1, ratio=0.8):
        """Encode ``image`` with both encoders.

        Args:
            image: Image tensor forwarded unchanged to both encoders.
            mask: Optional mask forwarded to both encoders.
            tiles: Tiling factor forwarded to both encoders; the default of 1
                matches the encoders' own default.
            ratio: Blend weight forwarded to both encoders; the default of
                0.8 matches the encoders' own default.

        Returns:
            Dict with the keys ``'main'`` and ``'additional'`` holding the
            respective encoder outputs.
        """
        outputs = {
            'main': self.main_image_encoder(image, mask=mask, tiles=tiles, ratio=ratio),
            'additional': self.additional_image_encoder(image, mask=mask, tiles=tiles, ratio=ratio),
        }
        return outputs

    def unconditional_embedding(self, batch_size):
        """Return the unconditional embeddings of both encoders.

        Args:
            batch_size: Number of embeddings requested from each encoder.

        Returns:
            Dict with the keys ``'main'`` and ``'additional'``.
        """
        outputs = {
            'main': self.main_image_encoder.unconditional_embedding(batch_size),
            'additional': self.additional_image_encoder.unconditional_embedding(batch_size),
        }
        return outputs
|
|
|
|
|
|
class SingleImageEncoder(nn.Module):
    """Wrapper exposing one image encoder under the ``'main'`` key."""

    def __init__(self, main_image_encoder):
        """Build the wrapped encoder from its ``build_image_encoder`` config."""
        super().__init__()
        self.main_image_encoder = build_image_encoder(main_image_encoder)

    def forward(self, image, mask=None, tiles = 1, ratio = 0.8):
        """Encode ``image`` and return the result as ``{'main': ...}``.

        ``mask``, ``tiles`` and ``ratio`` are forwarded to the wrapped
        encoder as keyword arguments.
        """
        main_features = self.main_image_encoder(
            image,
            mask=mask,
            tiles=tiles,
            ratio=ratio,
        )
        return {'main': main_features}

    def unconditional_embedding(self, batch_size):
        """Return the wrapped encoder's unconditional embedding as ``{'main': ...}``."""
        main_uncond = self.main_image_encoder.unconditional_embedding(batch_size)
        return {'main': main_uncond}
|