little tweaks

This commit is contained in:
kijai 2025-02-03 20:01:32 +02:00
parent d618646a88
commit 1b7c0606e2
2 changed files with 9 additions and 2 deletions

View File

@ -77,8 +77,12 @@ class ImageEncoder(nn.Module):
if mask is not None:
mask = mask.to(image)
image = image * mask
inputs = self.transform(image)
supported_sizes = [518, 530]
if image.shape[2] not in supported_sizes or image.shape[3] not in supported_sizes:
print(f'Image shape {image.shape} not supported. Resizing to 518x518')
inputs = self.transform(image)
else:
inputs = image
outputs = self.model(inputs)
last_hidden_state = outputs.last_hidden_state

View File

@ -1,6 +1,7 @@
import os
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
from PIL import Image
from pathlib import Path
import numpy as np
@ -1042,6 +1043,8 @@ class Hy3DGenerateMesh:
if mask is not None:
mask = mask.unsqueeze(0).to(device)
if mask.shape[2] != image.shape[2] or mask.shape[3] != image.shape[3]:
mask = F.interpolate(mask, size=(image.shape[2], image.shape[3]), mode='nearest')
pipeline.to(device)