mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-18 02:24:28 +08:00
feat(gds): implement GPUDirect Storage for efficient model loading
This commit is contained in:
parent
fab0954077
commit
5f24eb699c
461
comfy/gds_loader.py
Normal file
461
comfy/gds_loader.py
Normal file
@ -0,0 +1,461 @@
|
|||||||
|
# copyright 2025 Maifee Ul Asad @ github.com/maifeeulasad
|
||||||
|
# copyright under GNU GENERAL PUBLIC LICENSE, Version 3, 29 June 2007
|
||||||
|
|
||||||
|
"""
|
||||||
|
GPUDirect Storage (GDS) Integration for ComfyUI
|
||||||
|
Direct SSD-to-GPU model loading without RAM/CPU bottlenecks
|
||||||
|
Still there will be some CPU/RAM usage, mostly for safetensors parsing and small buffers.
|
||||||
|
|
||||||
|
This module provides GPUDirect Storage functionality to load models directly
|
||||||
|
from NVMe SSDs to GPU memory, bypassing system RAM and CPU.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import torch
|
||||||
|
import time
|
||||||
|
from typing import Optional, Dict, Any, Union
|
||||||
|
from pathlib import Path
|
||||||
|
import safetensors
|
||||||
|
import gc
|
||||||
|
import mmap
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
try:
|
||||||
|
import cupy
|
||||||
|
import cupy.cuda.runtime as cuda_runtime
|
||||||
|
CUPY_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
CUPY_AVAILABLE = False
|
||||||
|
logging.warning("CuPy not available. GDS will use fallback mode.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import cudf # RAPIDS for GPU dataframes
|
||||||
|
RAPIDS_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
RAPIDS_AVAILABLE = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import pynvml
|
||||||
|
pynvml.nvmlInit()
|
||||||
|
NVML_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
NVML_AVAILABLE = False
|
||||||
|
logging.warning("NVIDIA-ML-Py not available. GPU monitoring disabled.")
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GDSConfig:
|
||||||
|
"""Configuration for GPUDirect Storage"""
|
||||||
|
enabled: bool = True
|
||||||
|
min_file_size_mb: int = 100 # Only use GDS for files larger than this
|
||||||
|
chunk_size_mb: int = 64 # Size of chunks to transfer
|
||||||
|
use_pinned_memory: bool = True
|
||||||
|
prefetch_enabled: bool = True
|
||||||
|
compression_aware: bool = True
|
||||||
|
max_concurrent_streams: int = 4
|
||||||
|
fallback_to_cpu: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class GDSError(Exception):
|
||||||
|
"""GDS-specific errors"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GPUDirectStorage:
|
||||||
|
"""
|
||||||
|
GPUDirect Storage implementation for ComfyUI
|
||||||
|
Enables direct SSD-to-GPU transfers for model loading
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Optional[GDSConfig] = None):
|
||||||
|
self.config = config or GDSConfig()
|
||||||
|
self.device = torch.cuda.current_device() if torch.cuda.is_available() else None
|
||||||
|
self.cuda_streams = []
|
||||||
|
self.pinned_buffers = {}
|
||||||
|
self.stats = {
|
||||||
|
'gds_loads': 0,
|
||||||
|
'fallback_loads': 0,
|
||||||
|
'total_bytes_gds': 0,
|
||||||
|
'total_time_gds': 0.0,
|
||||||
|
'avg_bandwidth_gbps': 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
# Initialize GDS if available
|
||||||
|
self._gds_available = self._check_gds_availability()
|
||||||
|
if self._gds_available:
|
||||||
|
self._init_gds()
|
||||||
|
else:
|
||||||
|
logging.warning("GDS not available, using fallback methods")
|
||||||
|
|
||||||
|
def _check_gds_availability(self) -> bool:
|
||||||
|
"""Check if GDS is available on the system"""
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not CUPY_AVAILABLE:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check for GPUDirect Storage support
|
||||||
|
try:
|
||||||
|
# Check CUDA version (GDS requires CUDA 11.4+)
|
||||||
|
cuda_version = torch.version.cuda
|
||||||
|
if cuda_version:
|
||||||
|
major, minor = map(int, cuda_version.split('.')[:2])
|
||||||
|
if major < 11 or (major == 11 and minor < 4):
|
||||||
|
logging.warning(f"CUDA {cuda_version} detected. GDS requires CUDA 11.4+")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if cuFile is available (part of CUDA toolkit)
|
||||||
|
try:
|
||||||
|
import cupy.cuda.cufile as cufile
|
||||||
|
# Try to initialize cuFile
|
||||||
|
cufile.initialize()
|
||||||
|
return True
|
||||||
|
except (ImportError, RuntimeError) as e:
|
||||||
|
logging.warning(f"cuFile not available: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"GDS availability check failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _init_gds(self):
|
||||||
|
"""Initialize GDS resources"""
|
||||||
|
try:
|
||||||
|
# Create CUDA streams for async operations
|
||||||
|
for i in range(self.config.max_concurrent_streams):
|
||||||
|
stream = torch.cuda.Stream()
|
||||||
|
self.cuda_streams.append(stream)
|
||||||
|
|
||||||
|
# Pre-allocate pinned memory buffers
|
||||||
|
if self.config.use_pinned_memory:
|
||||||
|
self._allocate_pinned_buffers()
|
||||||
|
|
||||||
|
logging.info(f"GDS initialized with {len(self.cuda_streams)} streams")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to initialize GDS: {e}")
|
||||||
|
self._gds_available = False
|
||||||
|
|
||||||
|
def _allocate_pinned_buffers(self):
|
||||||
|
"""Pre-allocate pinned memory buffers for staging"""
|
||||||
|
try:
|
||||||
|
# Allocate buffers of different sizes
|
||||||
|
buffer_sizes = [16, 32, 64, 128, 256] # MB
|
||||||
|
|
||||||
|
for size_mb in buffer_sizes:
|
||||||
|
size_bytes = size_mb * 1024 * 1024
|
||||||
|
# Allocate pinned memory using CuPy
|
||||||
|
if CUPY_AVAILABLE:
|
||||||
|
buffer = cupy.cuda.alloc_pinned_memory(size_bytes)
|
||||||
|
self.pinned_buffers[size_mb] = buffer
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Failed to allocate pinned buffers: {e}")
|
||||||
|
|
||||||
|
def _get_file_size(self, file_path: str) -> int:
|
||||||
|
"""Get file size in bytes"""
|
||||||
|
return os.path.getsize(file_path)
|
||||||
|
|
||||||
|
def _should_use_gds(self, file_path: str) -> bool:
|
||||||
|
"""Determine if GDS should be used for this file"""
|
||||||
|
if not self._gds_available or not self.config.enabled:
|
||||||
|
return False
|
||||||
|
|
||||||
|
file_size_mb = self._get_file_size(file_path) / (1024 * 1024)
|
||||||
|
return file_size_mb >= self.config.min_file_size_mb
|
||||||
|
|
||||||
|
def _load_with_gds(self, file_path: str) -> Dict[str, torch.Tensor]:
|
||||||
|
"""Load model using GPUDirect Storage"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if file_path.lower().endswith(('.safetensors', '.sft')):
|
||||||
|
return self._load_safetensors_gds(file_path)
|
||||||
|
else:
|
||||||
|
return self._load_pytorch_gds(file_path)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"GDS loading failed for {file_path}: {e}")
|
||||||
|
if self.config.fallback_to_cpu:
|
||||||
|
logging.info("Falling back to CPU loading")
|
||||||
|
self.stats['fallback_loads'] += 1
|
||||||
|
return self._load_fallback(file_path)
|
||||||
|
else:
|
||||||
|
raise GDSError(f"GDS loading failed: {e}")
|
||||||
|
finally:
|
||||||
|
load_time = time.time() - start_time
|
||||||
|
self.stats['total_time_gds'] += load_time
|
||||||
|
|
||||||
|
def _load_safetensors_gds(self, file_path: str) -> Dict[str, torch.Tensor]:
|
||||||
|
"""Load safetensors file using GDS"""
|
||||||
|
try:
|
||||||
|
import cupy.cuda.cufile as cufile
|
||||||
|
|
||||||
|
# Open file with cuFile for direct GPU loading
|
||||||
|
with cufile.CuFileManager() as manager:
|
||||||
|
# Memory-map the file for efficient access
|
||||||
|
with open(file_path, 'rb') as f:
|
||||||
|
# Use mmap for large files
|
||||||
|
with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mmapped_file:
|
||||||
|
|
||||||
|
# Parse safetensors header
|
||||||
|
header_size = int.from_bytes(mmapped_file[:8], 'little')
|
||||||
|
header_bytes = mmapped_file[8:8+header_size]
|
||||||
|
|
||||||
|
import json
|
||||||
|
header = json.loads(header_bytes.decode('utf-8'))
|
||||||
|
|
||||||
|
# Load tensors directly to GPU
|
||||||
|
tensors = {}
|
||||||
|
data_offset = 8 + header_size
|
||||||
|
|
||||||
|
for name, info in header.items():
|
||||||
|
if name == "__metadata__":
|
||||||
|
continue
|
||||||
|
|
||||||
|
dtype_map = {
|
||||||
|
'F32': torch.float32,
|
||||||
|
'F16': torch.float16,
|
||||||
|
'BF16': torch.bfloat16,
|
||||||
|
'I8': torch.int8,
|
||||||
|
'I16': torch.int16,
|
||||||
|
'I32': torch.int32,
|
||||||
|
'I64': torch.int64,
|
||||||
|
'U8': torch.uint8,
|
||||||
|
}
|
||||||
|
|
||||||
|
dtype = dtype_map.get(info['dtype'], torch.float32)
|
||||||
|
shape = info['shape']
|
||||||
|
start_offset = data_offset + info['data_offsets'][0]
|
||||||
|
end_offset = data_offset + info['data_offsets'][1]
|
||||||
|
|
||||||
|
# Direct GPU allocation
|
||||||
|
tensor = torch.empty(shape, dtype=dtype, device=f'cuda:{self.device}')
|
||||||
|
|
||||||
|
# Use cuFile for direct transfer
|
||||||
|
tensor_bytes = end_offset - start_offset
|
||||||
|
|
||||||
|
# Get GPU memory pointer
|
||||||
|
gpu_ptr = tensor.data_ptr()
|
||||||
|
|
||||||
|
# Direct file-to-GPU transfer
|
||||||
|
cufile.copy_from_file(
|
||||||
|
gpu_ptr,
|
||||||
|
mmapped_file[start_offset:end_offset],
|
||||||
|
tensor_bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
tensors[name] = tensor
|
||||||
|
|
||||||
|
self.stats['gds_loads'] += 1
|
||||||
|
self.stats['total_bytes_gds'] += self._get_file_size(file_path)
|
||||||
|
|
||||||
|
return tensors
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"GDS safetensors loading failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _load_pytorch_gds(self, file_path: str) -> Dict[str, torch.Tensor]:
|
||||||
|
"""Load PyTorch file using GDS with staging"""
|
||||||
|
try:
|
||||||
|
# For PyTorch files, we need to use a staging approach
|
||||||
|
# since torch.load doesn't support direct GPU loading
|
||||||
|
|
||||||
|
# Load to pinned memory first
|
||||||
|
with open(file_path, 'rb') as f:
|
||||||
|
file_size = self._get_file_size(file_path)
|
||||||
|
|
||||||
|
# Choose appropriate buffer or allocate new one
|
||||||
|
buffer_size_mb = min(256, max(64, file_size // (1024 * 1024)))
|
||||||
|
|
||||||
|
if buffer_size_mb in self.pinned_buffers:
|
||||||
|
pinned_buffer = self.pinned_buffers[buffer_size_mb]
|
||||||
|
else:
|
||||||
|
# Allocate temporary pinned buffer
|
||||||
|
pinned_buffer = cupy.cuda.alloc_pinned_memory(file_size)
|
||||||
|
|
||||||
|
# Read file to pinned memory
|
||||||
|
f.readinto(pinned_buffer)
|
||||||
|
|
||||||
|
# Use torch.load with map_location to specific GPU
|
||||||
|
# This will be faster due to pinned memory
|
||||||
|
state_dict = torch.load(
|
||||||
|
f,
|
||||||
|
map_location=f'cuda:{self.device}',
|
||||||
|
weights_only=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.stats['gds_loads'] += 1
|
||||||
|
self.stats['total_bytes_gds'] += file_size
|
||||||
|
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"GDS PyTorch loading failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _load_fallback(self, file_path: str) -> Dict[str, torch.Tensor]:
|
||||||
|
"""Fallback loading method using standard approaches"""
|
||||||
|
if file_path.lower().endswith(('.safetensors', '.sft')):
|
||||||
|
# Use safetensors with device parameter
|
||||||
|
with safetensors.safe_open(file_path, framework="pt", device=f'cuda:{self.device}') as f:
|
||||||
|
return {k: f.get_tensor(k) for k in f.keys()}
|
||||||
|
else:
|
||||||
|
# Standard PyTorch loading
|
||||||
|
return torch.load(file_path, map_location=f'cuda:{self.device}', weights_only=True)
|
||||||
|
|
||||||
|
def load_model(self, file_path: str, device: Optional[torch.device] = None) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Main entry point for loading models with GDS
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the model file
|
||||||
|
device: Target device (if None, uses current CUDA device)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of tensors loaded directly to GPU
|
||||||
|
"""
|
||||||
|
if device is not None and device.type == 'cuda':
|
||||||
|
self.device = device.index or 0
|
||||||
|
|
||||||
|
if self._should_use_gds(file_path):
|
||||||
|
logging.info(f"Loading {file_path} with GDS")
|
||||||
|
return self._load_with_gds(file_path)
|
||||||
|
else:
|
||||||
|
logging.info(f"Loading {file_path} with standard method")
|
||||||
|
self.stats['fallback_loads'] += 1
|
||||||
|
return self._load_fallback(file_path)
|
||||||
|
|
||||||
|
def prefetch_model(self, file_path: str) -> bool:
|
||||||
|
"""
|
||||||
|
Prefetch model to GPU memory cache (if supported)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the model file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if prefetch was successful
|
||||||
|
"""
|
||||||
|
if not self.config.prefetch_enabled or not self._gds_available:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Basic prefetch implementation
|
||||||
|
# This would ideally use NVIDIA's GPUDirect Storage API
|
||||||
|
# to warm up the storage cache
|
||||||
|
|
||||||
|
file_size = self._get_file_size(file_path)
|
||||||
|
logging.info(f"Prefetching {file_path} ({file_size // (1024*1024)} MB)")
|
||||||
|
|
||||||
|
# Read file metadata to warm caches
|
||||||
|
with open(file_path, 'rb') as f:
|
||||||
|
# Read first and last chunks to trigger prefetch
|
||||||
|
f.read(1024 * 1024) # First 1MB
|
||||||
|
f.seek(-min(1024 * 1024, file_size), 2) # Last 1MB
|
||||||
|
f.read()
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Prefetch failed for {file_path}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get loading statistics"""
|
||||||
|
total_loads = self.stats['gds_loads'] + self.stats['fallback_loads']
|
||||||
|
|
||||||
|
if self.stats['total_time_gds'] > 0 and self.stats['total_bytes_gds'] > 0:
|
||||||
|
bandwidth_gbps = (self.stats['total_bytes_gds'] / (1024**3)) / self.stats['total_time_gds']
|
||||||
|
self.stats['avg_bandwidth_gbps'] = bandwidth_gbps
|
||||||
|
|
||||||
|
return {
|
||||||
|
**self.stats,
|
||||||
|
'total_loads': total_loads,
|
||||||
|
'gds_usage_percent': (self.stats['gds_loads'] / max(1, total_loads)) * 100,
|
||||||
|
'gds_available': self._gds_available,
|
||||||
|
'config': self.config.__dict__
|
||||||
|
}
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Clean up GDS resources"""
|
||||||
|
try:
|
||||||
|
# Clear CUDA streams
|
||||||
|
for stream in self.cuda_streams:
|
||||||
|
stream.synchronize()
|
||||||
|
self.cuda_streams.clear()
|
||||||
|
|
||||||
|
# Free pinned buffers
|
||||||
|
for buffer in self.pinned_buffers.values():
|
||||||
|
if CUPY_AVAILABLE:
|
||||||
|
cupy.cuda.free_pinned_memory(buffer)
|
||||||
|
self.pinned_buffers.clear()
|
||||||
|
|
||||||
|
# Force garbage collection
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"GDS cleanup failed: {e}")
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
"""Destructor to ensure cleanup"""
|
||||||
|
self.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
# Global GDS instance
|
||||||
|
_gds_instance: Optional[GPUDirectStorage] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_gds_instance(config: Optional[GDSConfig] = None) -> GPUDirectStorage:
|
||||||
|
"""Get or create the global GDS instance"""
|
||||||
|
global _gds_instance
|
||||||
|
|
||||||
|
if _gds_instance is None:
|
||||||
|
_gds_instance = GPUDirectStorage(config)
|
||||||
|
|
||||||
|
return _gds_instance
|
||||||
|
|
||||||
|
|
||||||
|
def load_torch_file_gds(ckpt: str, safe_load: bool = False, device: Optional[torch.device] = None) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
GDS-enabled replacement for comfy.utils.load_torch_file
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ckpt: Path to checkpoint file
|
||||||
|
safe_load: Whether to use safe loading (for compatibility)
|
||||||
|
device: Target device
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of loaded tensors
|
||||||
|
"""
|
||||||
|
gds = get_gds_instance()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load with GDS
|
||||||
|
return gds.load_model(ckpt, device)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"GDS loading failed, falling back to standard method: {e}")
|
||||||
|
# Fallback to original method
|
||||||
|
import comfy.utils
|
||||||
|
return comfy.utils.load_torch_file(ckpt, safe_load=safe_load, device=device)
|
||||||
|
|
||||||
|
|
||||||
|
def prefetch_model_gds(file_path: str) -> bool:
|
||||||
|
"""Prefetch model for faster loading"""
|
||||||
|
gds = get_gds_instance()
|
||||||
|
return gds.prefetch_model(file_path)
|
||||||
|
|
||||||
|
|
||||||
|
def get_gds_stats() -> Dict[str, Any]:
|
||||||
|
"""Get GDS statistics"""
|
||||||
|
gds = get_gds_instance()
|
||||||
|
return gds.get_stats()
|
||||||
|
|
||||||
|
|
||||||
|
def configure_gds(config: GDSConfig):
|
||||||
|
"""Configure GDS settings"""
|
||||||
|
global _gds_instance
|
||||||
|
_gds_instance = GPUDirectStorage(config)
|
||||||
Loading…
x
Reference in New Issue
Block a user