diff --git a/hy3dgen/shapegen/models/conditioner.py b/hy3dgen/shapegen/models/conditioner.py index 1af4c0c..3616fca 100755 --- a/hy3dgen/shapegen/models/conditioner.py +++ b/hy3dgen/shapegen/models/conditioner.py @@ -73,6 +73,11 @@ class ImageEncoder(nn.Module): image = (image - low) / (high - low) image = image.to(self.model.device, dtype=self.model.dtype) + + if mask is not None: + mask = mask.to(image) + image = image * mask + inputs = self.transform(image) outputs = self.model(inputs)