mirror of https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git (synced 2026-01-23 13:24:26 +08:00)

commit 2abec9be34 (parent 7e234a017e): tiled dino test
@@ -32,6 +32,57 @@ from transformers import (
     Dinov2Config,
 )
 
+
+def split_tiles(embeds, num_split):
+    B, C, H, W = embeds.shape
+    out = []
+    for x in embeds:  # x shape: [C, H, W]
+        x = x.unsqueeze(0)  # shape: [1, C, H, W]
+        h, w = H // num_split, W // num_split
+        x_split = torch.cat([x[:, :, i*h:(i+1)*h, j*w:(j+1)*w]
+                             for i in range(num_split)
+                             for j in range(num_split)], dim=0)
+        out.append(x_split)
+        print("x_split", x_split.shape)
+
+    x_split = torch.stack(out, dim=0)  # Final shape: [B, num_split*num_split, C, h, w]
+
+    return x_split
+
+
+# matteo's ipadapter tiling code
+def merge_hiddenstates(x, tiles):
+    chunk_size = tiles*tiles
+    x = x.split(chunk_size)
+
+    out = []
+    for embeds in x:
+        print("embeds", embeds.shape)
+        num_tiles = embeds.shape[0]
+        tile_size = int((embeds.shape[1]-1) ** 0.5)
+        grid_size = int(num_tiles ** 0.5)
+
+        # Extract class tokens
+        class_tokens = embeds[:, 0, :]  # Save class tokens: [num_tiles, embeds[-1]]
+        avg_class_token = class_tokens.mean(dim=0, keepdim=True).unsqueeze(0)  # Average token, shape: [1, 1, embeds[-1]]
+
+        patch_embeds = embeds[:, 1:, :]  # Shape: [num_tiles, tile_size^2, embeds[-1]]
+        reshaped = patch_embeds.reshape(grid_size, grid_size, tile_size, tile_size, embeds.shape[-1])
+
+        merged = torch.cat([torch.cat([reshaped[i, j] for j in range(grid_size)], dim=1)
+                            for i in range(grid_size)], dim=0)
+
+        merged = merged.unsqueeze(0)  # Shape: [1, grid_size*tile_size, grid_size*tile_size, embeds[-1]]
+
+        # Pool to original size
+        pooled = torch.nn.functional.adaptive_avg_pool2d(merged.permute(0, 3, 1, 2), (tile_size, tile_size)).permute(0, 2, 3, 1)
+        flattened = pooled.reshape(1, tile_size*tile_size, embeds.shape[-1])
+
+        # Add back the class token
+        with_class = torch.cat([avg_class_token, flattened], dim=1)  # Shape: original shape
+        out.append(with_class)
+
+    out = torch.cat(out, dim=0)
+
+    return out
+
+
 class ImageEncoder(nn.Module):
     def __init__(
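For reference, a minimal shape check for the two helpers added above (a sketch, not part of the commit; it assumes split_tiles and merge_hiddenstates are importable and uses made-up tensor sizes):

import torch

# [B, C, H, W] input; 518 divides evenly into a 2x2 grid of 259px tiles
embeds = torch.randn(2, 3, 518, 518)
tiles = split_tiles(embeds, 2)
print(tiles.shape)  # torch.Size([2, 4, 3, 259, 259])

# Fake ViT output for the 4 tiles: 1 class token + 16*16 patch tokens, dim 64
hidden = torch.randn(4, 1 + 16 * 16, 64)
merged = merge_hiddenstates(hidden, 2)
print(merged.shape)  # torch.Size([1, 257, 64]) -- back to a single-image shape

merge_hiddenstates stitches the per-tile patch grids back into one large grid, average-pools it down to the original token count, and averages the per-tile class tokens, so the merged tensor is drop-in compatible with the untiled hidden state.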
@@ -67,7 +118,7 @@ class ImageEncoder(nn.Module):
             ]
         )
 
-    def forward(self, image, mask=None, value_range=(-1, 1)):
+    def forward(self, image, mask=None, value_range=(-1, 1), tiles = 1, ratio = 0.8):
         if value_range is not None:
             low, high = value_range
             image = (image - low) / (high - low)
@@ -78,12 +129,25 @@ class ImageEncoder(nn.Module):
             mask = mask.to(image)
             image = image * mask
 
-        inputs = self.transform(image)
-        outputs = self.model(inputs)
-
-        last_hidden_state = outputs.last_hidden_state
+        image = self.transform(image)
+
+        last_hidden_state = self.model(image).last_hidden_state
+
+        if tiles > 1:
+            hidden_state = None
+            image_split = split_tiles(image, tiles)
+            for i in image_split:
+                i = self.transform(i)
+                if hidden_state is None:
+                    hidden_state = self.model(i).last_hidden_state
+                else:
+                    hidden_state = torch.cat([hidden_state, self.model(i).last_hidden_state], dim=0)
+            hidden_state = merge_hiddenstates(hidden_state, tiles)
+            last_hidden_state = last_hidden_state*ratio + hidden_state * (1-ratio)
 
         if not self.use_cls_token:
             last_hidden_state = last_hidden_state[:, 1:, :]
 
 
         return last_hidden_state
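The last two added lines are the heart of the change: the full-image DINO embedding is linearly blended with the merged tile embeddings, with ratio weighting the global pass. A standalone sketch of that blend (the function name and signature are mine, not from the commit):

import torch

def blend_hidden_states(global_hs: torch.Tensor,
                        tiled_hs: torch.Tensor,
                        ratio: float = 0.8) -> torch.Tensor:
    # Both tensors are [B, 1 + num_patches, dim]; ratio=1.0 reproduces the
    # untiled behaviour, ratio=0.0 keeps only the merged tile detail.
    return global_hs * ratio + tiled_hs * (1.0 - ratio)

With the default ratio=0.8 the global embedding dominates and the tiled pass contributes 20% of the conditioning, adding high-resolution detail without discarding the global layout.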
@@ -157,9 +221,9 @@ class SingleImageEncoder(nn.Module):
         super().__init__()
         self.main_image_encoder = build_image_encoder(main_image_encoder)
 
-    def forward(self, image, mask=None):
+    def forward(self, image, mask=None, tiles = 1, ratio = 0.8):
         outputs = {
-            'main': self.main_image_encoder(image, mask=mask),
+            'main': self.main_image_encoder(image, mask=mask, tiles = tiles, ratio = ratio),
         }
         return outputs
@@ -304,10 +304,10 @@ class Hunyuan3DDiTPipeline:
         self.model.to(dtype=dtype)
         self.conditioner.to(dtype=dtype)
 
-    def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance):
+    def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance, tiles=1, ratio=0.8):
         self.conditioner.to(self.main_device)
         bsz = image.shape[0]
-        cond = self.conditioner(image=image, mask=mask)
+        cond = self.conditioner(image=image, mask=mask, tiles=tiles, ratio=ratio)
 
         if do_classifier_free_guidance:
             un_cond = self.conditioner.unconditional_embedding(bsz)
@@ -565,6 +565,8 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
         num_chunks=8000,
         output_type: Optional[str] = "trimesh",
         enable_pbar=True,
+        tiles=1,
+        ratio=0.8,
         **kwargs,
     ) -> List[List[trimesh.Trimesh]]:
         callback = kwargs.pop("callback", None)
@@ -584,6 +586,8 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
             mask=mask,
             do_classifier_free_guidance=do_classifier_free_guidance,
             dual_guidance=False,
+            tiles=tiles,
+            ratio=ratio,
         )
         batch_size = image.shape[0]
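Taken together, the pipeline changes just thread the two new arguments down to the conditioner, so callers can opt into tiling at the top level. A hypothetical invocation (setup elided; argument names follow the diff, concrete values are illustrative):

mesh = pipeline(
    image=image,
    mask=mask,
    tiles=2,      # encode a 2x2 tile grid in addition to the full image
    ratio=0.8,    # 80% global embedding, 20% merged tile embeddings
)[0]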
nodes.py (7 changed lines)
@@ -215,6 +215,7 @@ class DownloadAndLoadHy3DPaintModel:
             snapshot_download(
                 repo_id="tencent/Hunyuan3D-2",
                 allow_patterns=[f"*{model}*"],
+                ignore_patterns=["*diffusion_pytorch_model.bin"],
                 local_dir=download_path,
                 local_dir_use_symlinks=False,
             )
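An aside in the same commit: ignore_patterns lets huggingface_hub exclude files matching a glob from the snapshot download, here avoiding the .bin weight files. A minimal sketch of the effect (the allow pattern and target directory are illustrative, not from the diff):

from huggingface_hub import snapshot_download

path = snapshot_download(
    repo_id="tencent/Hunyuan3D-2",
    allow_patterns=["*hunyuan3d-dit-v2-0*"],           # hypothetical model filter
    ignore_patterns=["*diffusion_pytorch_model.bin"],  # don't fetch the .bin weights
    local_dir="models/hunyuan3d",                      # hypothetical path
)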
@@ -683,6 +684,8 @@ class Hy3DGenerateMesh:
             },
             "optional": {
                 "mask": ("MASK", ),
+                "tiles": ("INT", {"default": 1, "min": 1, "max": 10000000, "step": 1}),
+                "ratio": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.01}),
             }
         }
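In ComfyUI, entries under "optional" surface as node inputs and are passed to the node's FUNCTION as keyword arguments with the same names, which is why process() below gains matching tiles=1 and ratio=0.8 defaults. A toy illustration of that mapping (not ComfyUI code; the widget values are made up):

user_inputs = {"tiles": 2, "ratio": 0.5}   # values set on the node
node = Hy3DGenerateMesh()
mesh = node.process(pipeline, image, steps=30, guidance_scale=5.5,
                    octree_resolution=256, seed=0, **user_inputs)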
@@ -691,7 +694,7 @@ class Hy3DGenerateMesh:
     FUNCTION = "process"
     CATEGORY = "Hunyuan3DWrapper"
 
-    def process(self, pipeline, image, steps, guidance_scale, octree_resolution, seed, mask=None):
+    def process(self, pipeline, image, steps, guidance_scale, octree_resolution, seed, mask=None, tiles=1, ratio=0.8):
 
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
@@ -715,6 +718,8 @@ class Hy3DGenerateMesh:
             num_inference_steps=steps,
             mc_algo='mc',
             guidance_scale=guidance_scale,
+            tiles=tiles,
+            ratio=ratio,
             octree_resolution=octree_resolution,
             generator=torch.manual_seed(seed))[0]