tiled dino test

This commit is contained in:
kijai 2025-01-29 20:30:55 +02:00
parent 7e234a017e
commit 2abec9be34
3 changed files with 83 additions and 10 deletions

View File

@ -32,6 +32,57 @@ from transformers import (
Dinov2Config,
)
def split_tiles(embeds, num_split):
B, C, H, W = embeds.shape
out = []
for x in embeds: # x shape: [C, H, W]
x = x.unsqueeze(0) # shape: [1, C, H, W]
h, w = H // num_split, W // num_split
x_split = torch.cat([x[:, :, i*h:(i+1)*h, j*w:(j+1)*w]
for i in range(num_split)
for j in range(num_split)], dim=0)
out.append(x_split)
print("x_split", x_split.shape)
x_split = torch.stack(out, dim=0) # Final shape: [B, num_split*num_split, C, h, w]
return x_split
# matteo's ipadapter tiling code
def merge_hiddenstates(x, tiles):
chunk_size = tiles*tiles
x = x.split(chunk_size)
out = []
for embeds in x:
print("embeds", embeds.shape)
num_tiles = embeds.shape[0]
tile_size = int((embeds.shape[1]-1) ** 0.5)
grid_size = int(num_tiles ** 0.5)
# Extract class tokens
class_tokens = embeds[:, 0, :] # Save class tokens: [num_tiles, embeds[-1]]
avg_class_token = class_tokens.mean(dim=0, keepdim=True).unsqueeze(0) # Average token, shape: [1, 1, embeds[-1]]
patch_embeds = embeds[:, 1:, :] # Shape: [num_tiles, tile_size^2, embeds[-1]]
reshaped = patch_embeds.reshape(grid_size, grid_size, tile_size, tile_size, embeds.shape[-1])
merged = torch.cat([torch.cat([reshaped[i, j] for j in range(grid_size)], dim=1)
for i in range(grid_size)], dim=0)
merged = merged.unsqueeze(0) # Shape: [1, grid_size*tile_size, grid_size*tile_size, embeds[-1]]
# Pool to original size
pooled = torch.nn.functional.adaptive_avg_pool2d(merged.permute(0, 3, 1, 2), (tile_size, tile_size)).permute(0, 2, 3, 1)
flattened = pooled.reshape(1, tile_size*tile_size, embeds.shape[-1])
# Add back the class token
with_class = torch.cat([avg_class_token, flattened], dim=1) # Shape: original shape
out.append(with_class)
out = torch.cat(out, dim=0)
return out
class ImageEncoder(nn.Module):
def __init__(
@ -67,7 +118,7 @@ class ImageEncoder(nn.Module):
]
)
def forward(self, image, mask=None, value_range=(-1, 1)):
def forward(self, image, mask=None, value_range=(-1, 1), tiles = 1, ratio = 0.8):
if value_range is not None:
low, high = value_range
image = (image - low) / (high - low)
@ -78,12 +129,25 @@ class ImageEncoder(nn.Module):
mask = mask.to(image)
image = image * mask
inputs = self.transform(image)
outputs = self.model(inputs)
last_hidden_state = outputs.last_hidden_state
image = self.transform(image)
last_hidden_state = self.model(image).last_hidden_state
if tiles > 1:
hidden_state = None
image_split = split_tiles(image, tiles)
for i in image_split:
i = self.transform(i)
if hidden_state is None:
hidden_state = self.model(i).last_hidden_state
else:
hidden_state = torch.cat([hidden_state, self.model(i).last_hidden_state], dim=0)
hidden_state = merge_hiddenstates(hidden_state, tiles)
last_hidden_state = last_hidden_state*ratio + hidden_state * (1-ratio)
if not self.use_cls_token:
last_hidden_state = last_hidden_state[:, 1:, :]
return last_hidden_state
@ -157,9 +221,9 @@ class SingleImageEncoder(nn.Module):
super().__init__()
self.main_image_encoder = build_image_encoder(main_image_encoder)
def forward(self, image, mask=None):
def forward(self, image, mask=None, tiles = 1, ratio = 0.8):
outputs = {
'main': self.main_image_encoder(image, mask=mask),
'main': self.main_image_encoder(image, mask=mask, tiles = tiles, ratio = ratio),
}
return outputs

View File

@ -304,10 +304,10 @@ class Hunyuan3DDiTPipeline:
self.model.to(dtype=dtype)
self.conditioner.to(dtype=dtype)
def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance):
def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance, tiles=1, ratio=0.8):
self.conditioner.to(self.main_device)
bsz = image.shape[0]
cond = self.conditioner(image=image, mask=mask)
cond = self.conditioner(image=image, mask=mask, tiles=tiles, ratio=ratio)
if do_classifier_free_guidance:
un_cond = self.conditioner.unconditional_embedding(bsz)
@ -565,6 +565,8 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
num_chunks=8000,
output_type: Optional[str] = "trimesh",
enable_pbar=True,
tiles=1,
ratio=0.8,
**kwargs,
) -> List[List[trimesh.Trimesh]]:
callback = kwargs.pop("callback", None)
@ -584,6 +586,8 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
mask=mask,
do_classifier_free_guidance=do_classifier_free_guidance,
dual_guidance=False,
tiles=tiles,
ratio=ratio,
)
batch_size = image.shape[0]

View File

@ -215,6 +215,7 @@ class DownloadAndLoadHy3DPaintModel:
snapshot_download(
repo_id="tencent/Hunyuan3D-2",
allow_patterns=[f"*{model}*"],
ignore_patterns=["*diffusion_pytorch_model.bin"],
local_dir=download_path,
local_dir_use_symlinks=False,
)
@ -683,6 +684,8 @@ class Hy3DGenerateMesh:
},
"optional": {
"mask": ("MASK", ),
"tiles": ("INT", {"default": 1, "min": 1, "max": 10000000, "step": 1}),
"ratio": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
@ -691,7 +694,7 @@ class Hy3DGenerateMesh:
FUNCTION = "process"
CATEGORY = "Hunyuan3DWrapper"
def process(self, pipeline, image, steps, guidance_scale, octree_resolution, seed, mask=None):
def process(self, pipeline, image, steps, guidance_scale, octree_resolution, seed, mask=None, tiles=1, ratio=0.8):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@ -715,6 +718,8 @@ class Hy3DGenerateMesh:
num_inference_steps=steps,
mc_algo='mc',
guidance_scale=guidance_scale,
tiles=tiles,
ratio=ratio,
octree_resolution=octree_resolution,
generator=torch.manual_seed(seed))[0]