make mask do something

This commit is contained in:
kijai 2025-01-25 23:32:24 +02:00
parent e43874b483
commit 7e234a017e

View File

@ -73,6 +73,11 @@ class ImageEncoder(nn.Module):
image = (image - low) / (high - low)
image = image.to(self.model.device, dtype=self.model.dtype)
if mask is not None:
mask = mask.to(image)
image = image * mask
inputs = self.transform(image)
outputs = self.model(inputs)