2025-01-23 16:41:39 +02:00

36 lines
1.5 KiB
Python

import importlib.metadata
import torch
import logging
import numpy as np
# Configure root logging once at import time (INFO level, timestamped lines)
# and obtain this module's logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
log = logging.getLogger(__name__)
def check_diffusers_version(required_version='0.31.0'):
    """Raise if the installed ``diffusers`` package is missing or too old.

    Versions are compared numerically on their leading dotted-integer release
    components, so e.g. ``0.100.0`` correctly counts as newer than ``0.31.0``
    (a plain string comparison gets this wrong). Any pre-/post-/dev-release
    or local suffix is ignored, so ``0.31.0.dev0`` is accepted for ``0.31.0``.

    Args:
        required_version: Minimum acceptable diffusers version
            (``'0.31.0'`` by default).

    Raises:
        AssertionError: If diffusers is not installed, or its version is
            lower than ``required_version``.
    """
    import re

    def _release(text):
        # Leading run of dot-separated integers, e.g. (0, 31, 0) from
        # "0.31.0.dev0"; an unparsable string yields () and so compares as 0.
        match = re.match(r'\d+(?:\.\d+)*', text.strip())
        return tuple(int(part) for part in match.group().split('.')) if match else ()

    try:
        version = importlib.metadata.version('diffusers')
    except importlib.metadata.PackageNotFoundError:
        raise AssertionError("diffusers is not installed.") from None

    installed = _release(version)
    required = _release(required_version)
    # Zero-pad to a common length so "0.31" and "0.31.0" compare equal.
    width = max(len(installed), len(required))
    installed += (0,) * (width - len(installed))
    required += (0,) * (width - len(required))
    if installed < required:
        raise AssertionError(f"diffusers version {version} is installed, but version {required_version} or higher is required.")
def print_memory(device=None, *, summary=False):
    """Log current and peak CUDA memory statistics for ``device``.

    Values are divided by 1024**3 and logged at INFO level with a "GB" label.

    Args:
        device: Device passed straight through to the ``torch.cuda`` memory
            query functions; ``None`` (the default) lets torch pick its
            current device, matching those functions' own default.
        summary: When true, additionally log the full (non-abbreviated)
            ``torch.cuda.memory_summary`` report. Off by default.
    """
    gib = 1024 ** 3
    allocated = torch.cuda.memory_allocated(device) / gib
    max_allocated = torch.cuda.max_memory_allocated(device) / gib
    max_reserved = torch.cuda.max_memory_reserved(device) / gib
    # Lazy %-args: formatting only happens if INFO is enabled. (The previous
    # ``{memory=:.3f}`` f-string debug specifier also echoed the variable
    # name, producing "Allocated memory: memory=1.234 GB".)
    log.info("Allocated memory: %.3f GB", allocated)
    log.info("Max allocated memory: %.3f GB", max_allocated)
    log.info("Max reserved memory: %.3f GB", max_reserved)
    if summary:
        log.info(
            "Memory Summary:\n%s",
            torch.cuda.memory_summary(device=device, abbreviated=False),
        )
def pil_list_to_torch_batch(normal_maps):
    """Stack a sequence of PIL images into one float32 tensor scaled by 1/255.

    Each image is converted with ``np.array`` and the results are stacked
    along a new leading batch axis, so all images must share the same size
    and mode. Multi-channel images (e.g. RGB) give a ``(B, H, W, C)``
    tensor; single-channel images (e.g. mode ``"L"``) give ``(B, H, W)``.
    The division by 255 assumes 8-bit pixel data.

    Args:
        normal_maps: Iterable of PIL images (anything ``np.array`` accepts
            also works, e.g. uint8 numpy arrays).

    Returns:
        A float32 ``torch.Tensor``, channels-last, with values in [0, 1]
        for 8-bit input.

    Raises:
        ValueError: If ``normal_maps`` is empty, or the images do not all
            have the same array shape.
    """
    arrays = [np.array(img) for img in normal_maps]
    if not arrays:
        raise ValueError("pil_list_to_torch_batch() needs at least one image")
    # np.stack always returns a fresh array, so nothing below aliases caller data.
    batch = np.stack(arrays, axis=0)
    # Scale in place to avoid allocating a second full-size float tensor.
    return torch.from_numpy(batch).float().div_(255.0)