From 2abec9be34d010d8e68ad236bf2b27fd62d00d3d Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Wed, 29 Jan 2025 20:30:55 +0200
Subject: [PATCH] tiled dino test

---
 hy3dgen/shapegen/models/conditioner.py | 78 +++++++++++++++++++++++---
 hy3dgen/shapegen/pipelines.py          |  8 ++-
 nodes.py                               |  7 ++-
 3 files changed, 83 insertions(+), 10 deletions(-)

diff --git a/hy3dgen/shapegen/models/conditioner.py b/hy3dgen/shapegen/models/conditioner.py
index 3616fca..ffae00b 100755
--- a/hy3dgen/shapegen/models/conditioner.py
+++ b/hy3dgen/shapegen/models/conditioner.py
@@ -32,6 +32,57 @@ from transformers import (
     Dinov2Config,
 )
 
+def split_tiles(embeds, num_split):
+    B, C, H, W = embeds.shape
+    out = []
+    for x in embeds:  # x shape: [C, H, W]
+        x = x.unsqueeze(0)  # shape: [1, C, H, W]
+        h, w = H // num_split, W // num_split
+        x_split = torch.cat([x[:, :, i*h:(i+1)*h, j*w:(j+1)*w]
+                             for i in range(num_split)
+                             for j in range(num_split)], dim=0)
+        out.append(x_split)
+        print("x_split", x_split.shape)
+
+    x_split = torch.stack(out, dim=0)  # Final shape: [B, num_split*num_split, C, h, w]
+
+    return x_split
+
+# matteo's ipadapter tiling code
+def merge_hiddenstates(x, tiles):
+    chunk_size = tiles*tiles
+    x = x.split(chunk_size)
+
+    out = []
+    for embeds in x:
+        print("embeds", embeds.shape)
+        num_tiles = embeds.shape[0]
+        tile_size = int((embeds.shape[1]-1) ** 0.5)
+        grid_size = int(num_tiles ** 0.5)
+
+        # Extract class tokens
+        class_tokens = embeds[:, 0, :]  # Save class tokens: [num_tiles, embeds[-1]]
+        avg_class_token = class_tokens.mean(dim=0, keepdim=True).unsqueeze(0)  # Average token, shape: [1, 1, embeds[-1]]
+
+        patch_embeds = embeds[:, 1:, :]  # Shape: [num_tiles, tile_size^2, embeds[-1]]
+        reshaped = patch_embeds.reshape(grid_size, grid_size, tile_size, tile_size, embeds.shape[-1])
+
+        merged = torch.cat([torch.cat([reshaped[i, j] for j in range(grid_size)], dim=1)
+                            for i in range(grid_size)], dim=0)
+
+        merged = merged.unsqueeze(0)  # Shape: [1, grid_size*tile_size, grid_size*tile_size, embeds[-1]]
+
+        # Pool to original size
+        pooled = torch.nn.functional.adaptive_avg_pool2d(merged.permute(0, 3, 1, 2), (tile_size, tile_size)).permute(0, 2, 3, 1)
+        flattened = pooled.reshape(1, tile_size*tile_size, embeds.shape[-1])
+
+        # Add back the class token
+        with_class = torch.cat([avg_class_token, flattened], dim=1)  # Shape: original shape
+        out.append(with_class)
+
+    out = torch.cat(out, dim=0)
+
+    return out
 
 class ImageEncoder(nn.Module):
     def __init__(
@@ -67,7 +118,7 @@ class ImageEncoder(nn.Module):
             ]
         )
 
-    def forward(self, image, mask=None, value_range=(-1, 1)):
+    def forward(self, image, mask=None, value_range=(-1, 1), tiles=1, ratio=0.8):
         if value_range is not None:
             low, high = value_range
             image = (image - low) / (high - low)
@@ -78,12 +129,25 @@ class ImageEncoder(nn.Module):
             mask = mask.to(image)
             image = image * mask
 
-        inputs = self.transform(image)
-        outputs = self.model(inputs)
-
-        last_hidden_state = outputs.last_hidden_state
+        image = self.transform(image)
+
+        last_hidden_state = self.model(image).last_hidden_state
+
+        if tiles > 1:
+            hidden_state = None
+            image_split = split_tiles(image, tiles)
+            for i in image_split:
+                i = self.transform(i)
+                if hidden_state is None:
+                    hidden_state = self.model(i).last_hidden_state
+                else:
+                    hidden_state = torch.cat([hidden_state, self.model(i).last_hidden_state], dim=0)
+            hidden_state = merge_hiddenstates(hidden_state, tiles)
+            last_hidden_state = last_hidden_state*ratio + hidden_state * (1-ratio)
+
         if not self.use_cls_token:
             last_hidden_state = last_hidden_state[:, 1:, :]
+
         return last_hidden_state
 
 
@@ -157,9 +221,9 @@ class SingleImageEncoder(nn.Module):
         super().__init__()
         self.main_image_encoder = build_image_encoder(main_image_encoder)
 
-    def forward(self, image, mask=None):
+    def forward(self, image, mask=None, tiles=1, ratio=0.8):
         outputs = {
-            'main': self.main_image_encoder(image, mask=mask),
+            'main': self.main_image_encoder(image, mask=mask, tiles=tiles, ratio=ratio),
         }
         return outputs
 
diff --git a/hy3dgen/shapegen/pipelines.py b/hy3dgen/shapegen/pipelines.py
index 4435bf2..dcabe15 100755
--- a/hy3dgen/shapegen/pipelines.py
+++ b/hy3dgen/shapegen/pipelines.py
@@ -304,10 +304,10 @@ class Hunyuan3DDiTPipeline:
         self.model.to(dtype=dtype)
         self.conditioner.to(dtype=dtype)
 
-    def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance):
+    def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance, tiles=1, ratio=0.8):
         self.conditioner.to(self.main_device)
         bsz = image.shape[0]
-        cond = self.conditioner(image=image, mask=mask)
+        cond = self.conditioner(image=image, mask=mask, tiles=tiles, ratio=ratio)
 
         if do_classifier_free_guidance:
             un_cond = self.conditioner.unconditional_embedding(bsz)
@@ -565,6 +565,8 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
         num_chunks=8000,
         output_type: Optional[str] = "trimesh",
         enable_pbar=True,
+        tiles=1,
+        ratio=0.8,
         **kwargs,
     ) -> List[List[trimesh.Trimesh]]:
         callback = kwargs.pop("callback", None)
@@ -584,6 +586,8 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
             mask=mask,
             do_classifier_free_guidance=do_classifier_free_guidance,
             dual_guidance=False,
+            tiles=tiles,
+            ratio=ratio,
         )
 
         batch_size = image.shape[0]
diff --git a/nodes.py b/nodes.py
index ce4abbb..56ab5ae 100644
--- a/nodes.py
+++ b/nodes.py
@@ -215,6 +215,7 @@ class DownloadAndLoadHy3DPaintModel:
         snapshot_download(
             repo_id="tencent/Hunyuan3D-2",
             allow_patterns=[f"*{model}*"],
+            ignore_patterns=["*diffusion_pytorch_model.bin"],
             local_dir=download_path,
             local_dir_use_symlinks=False,
         )
@@ -683,6 +684,8 @@ class Hy3DGenerateMesh:
             },
             "optional": {
                 "mask": ("MASK", ),
+                "tiles": ("INT", {"default": 1, "min": 1, "max": 10000000, "step": 1}),
+                "ratio": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.01}),
             }
         }
 
@@ -691,7 +694,7 @@ class Hy3DGenerateMesh:
     FUNCTION = "process"
     CATEGORY = "Hunyuan3DWrapper"
 
-    def process(self, pipeline, image, steps, guidance_scale, octree_resolution, seed, mask=None):
+    def process(self, pipeline, image, steps, guidance_scale, octree_resolution, seed, mask=None, tiles=1, ratio=0.8):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
 
@@ -715,6 +718,8 @@ class Hy3DGenerateMesh:
             num_inference_steps=steps,
             mc_algo='mc',
             guidance_scale=guidance_scale,
+            tiles=tiles,
+            ratio=ratio,
             octree_resolution=octree_resolution,
             generator=torch.manual_seed(seed))[0]
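
Usage sketch (not part of the patch): the new `tiles` and `ratio` arguments
flow from the Hy3DGenerateMesh node through
`Hunyuan3DDiTFlowMatchingPipeline.__call__` and `encode_cond` into
`ImageEncoder.forward`, where `tiles > 1` runs extra DINOv2 passes over an
N x N grid of image tiles and blends the merged tile embeddings with the
global embedding (`ratio` weights the global pass, `1 - ratio` the tiled
one). A minimal sketch of exercising this directly, assuming the upstream
`from_pretrained` loader; the input image path is a hypothetical
placeholder:

    import torch
    from PIL import Image
    from hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline

    # Load the shape-generation pipeline (weights fetched from the
    # tencent/Hunyuan3D-2 repo on first use).
    pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained("tencent/Hunyuan3D-2")

    # "input.png" is a placeholder path for the conditioning image.
    image = Image.open("input.png")

    # tiles=2 splits the transformed image into a 2x2 grid for the extra
    # DINO passes; ratio=0.8 keeps 80% of the global embedding and mixes
    # in 20% of the merged tile embeddings.
    mesh = pipeline(
        image=image,
        num_inference_steps=30,
        guidance_scale=5.0,
        octree_resolution=256,
        tiles=2,
        ratio=0.8,
        generator=torch.manual_seed(0),
    )[0]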