diff --git a/hy3dgen/shapegen/models/conditioner.py b/hy3dgen/shapegen/models/conditioner.py
index 3616fca..0e6427f 100755
--- a/hy3dgen/shapegen/models/conditioner.py
+++ b/hy3dgen/shapegen/models/conditioner.py
@@ -77,8 +77,12 @@ class ImageEncoder(nn.Module):
         if mask is not None:
             mask = mask.to(image)
             image = image * mask
-
-        inputs = self.transform(image)
+        supported_sizes = [518, 530]
+        if image.shape[2] not in supported_sizes or image.shape[3] not in supported_sizes:
+            print(f'Image shape {image.shape} not supported. Resizing to 518x518')
+            inputs = self.transform(image)
+        else:
+            inputs = image
 
         outputs = self.model(inputs)
         last_hidden_state = outputs.last_hidden_state
diff --git a/nodes.py b/nodes.py
index a9d2862..ba943bb 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1,6 +1,7 @@
 import os
 import torch
 import torchvision.transforms as transforms
+import torch.nn.functional as F
 from PIL import Image
 from pathlib import Path
 import numpy as np
@@ -1042,6 +1043,8 @@ class Hy3DGenerateMesh:
 
         if mask is not None:
             mask = mask.unsqueeze(0).to(device)
+            if mask.shape[2] != image.shape[2] or mask.shape[3] != image.shape[3]:
+                mask = F.interpolate(mask, size=(image.shape[2], image.shape[3]), mode='nearest')
 
         pipeline.to(device)
 
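
Note (not part of the patch): a minimal standalone sketch of the mask-resize guard added in nodes.py, assuming an NCHW image/mask pair whose spatial sizes disagree; the concrete shapes below are hypothetical examples, not values taken from the repository.

import torch
import torch.nn.functional as F

# Hypothetical shapes: the image tensor is NCHW at 518x518, while the mask
# arrives at a different resolution (e.g. 512x512) from an upstream step.
image = torch.randn(1, 3, 518, 518)
mask = torch.ones(1, 1, 512, 512)

# Same guard as the patch: resize the mask only when its spatial dims differ
# from the image's, using nearest-neighbour so a binary mask stays binary.
if mask.shape[2] != image.shape[2] or mask.shape[3] != image.shape[3]:
    mask = F.interpolate(mask, size=(image.shape[2], image.shape[3]), mode='nearest')

print(mask.shape)  # torch.Size([1, 1, 518, 518]), now broadcastable against the image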