refactor(gds): add show_stats option; moved GDS initialization to dedicated file;

This commit is contained in:
Maifee Ul Asad 2025-10-12 00:54:04 +06:00
parent 64811809a0
commit a19f0a88e4
2 changed files with 39 additions and 27 deletions

View File

@ -54,6 +54,7 @@ class GDSConfig:
compression_aware: bool = True
max_concurrent_streams: int = 4
fallback_to_cpu: bool = True
show_stats: bool = False # Whether to show stats on exit
class GDSError(Exception):
@ -458,4 +459,36 @@ def get_gds_stats() -> Dict[str, Any]:
def configure_gds(config: GDSConfig):
"""Configure GDS settings"""
global _gds_instance
_gds_instance = GPUDirectStorage(config)
_gds_instance = GPUDirectStorage(config)
def init_gds(config: GDSConfig):
"""
Initialize GPUDirect Storage with the provided configuration
Args:
config: GDSConfig object with initialization parameters
"""
try:
# Configure GDS
configure_gds(config)
logging.info(f"GDS initialized: enabled={config.enabled}, min_size={config.min_file_size_mb}MB, streams={config.max_concurrent_streams}")
# Set up exit handler for stats if requested
if hasattr(config, 'show_stats') and config.show_stats:
import atexit
def print_gds_stats():
stats = get_gds_stats()
logging.info("=== GDS Statistics ===")
logging.info(f"Total loads: {stats['total_loads']}")
logging.info(f"GDS loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)")
logging.info(f"Fallback loads: {stats['fallback_loads']}")
logging.info(f"Total bytes via GDS: {stats['total_bytes_gds'] / (1024**3):.2f} GB")
logging.info(f"Average bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s")
logging.info("===================")
atexit.register(print_gds_stats)
except ImportError as e:
logging.warning(f"GDS initialization failed - missing dependencies: {e}")
except Exception as e:
logging.error(f"GDS initialization failed: {e}")

31
main.py
View File

@ -166,41 +166,20 @@ def init_gds():
# GDS not explicitly requested, use auto-detection
return
try:
from comfy.gds_loader import GDSConfig, configure_gds, get_gds_stats
if hasattr(args, 'enable_gds') and args.enable_gds:
from comfy.gds_loader import GDSConfig, init_gds as gds_init
# Create GDS configuration from CLI args
config = GDSConfig(
enabled=getattr(args, 'enable_gds', False) or getattr(args, 'gds_prefetch', False),
min_file_size_mb=getattr(args, 'gds_min_file_size', 100),
chunk_size_mb=getattr(args, 'gds_chunk_size', 64),
max_concurrent_streams=getattr(args, 'gds_streams', 4),
prefetch_enabled=getattr(args, 'gds_prefetch', True),
fallback_to_cpu=not getattr(args, 'gds_no_fallback', False)
fallback_to_cpu=not getattr(args, 'gds_no_fallback', False),
show_stats=getattr(args, 'gds_stats', False)
)
# Configure GDS
configure_gds(config)
logging.info(f"GDS initialized: enabled={config.enabled}, min_size={config.min_file_size_mb}MB, streams={config.max_concurrent_streams}")
# Set up exit handler for stats if requested
if getattr(args, 'gds_stats', False):
import atexit
def print_gds_stats():
stats = get_gds_stats()
logging.info("=== GDS Statistics ===")
logging.info(f"Total loads: {stats['total_loads']}")
logging.info(f"GDS loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)")
logging.info(f"Fallback loads: {stats['fallback_loads']}")
logging.info(f"Total bytes via GDS: {stats['total_bytes_gds'] / (1024**3):.2f} GB")
logging.info(f"Average bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s")
logging.info("===================")
atexit.register(print_gds_stats)
except ImportError as e:
logging.warning(f"GDS initialization failed - missing dependencies: {e}")
except Exception as e:
logging.error(f"GDS initialization failed: {e}")
gds_init(config)
# Initialize GDS
init_gds()