From a19f0a88e45e4e0cfc56eff0e4e4096c5d7e9813 Mon Sep 17 00:00:00 2001 From: Maifee Ul Asad Date: Sun, 12 Oct 2025 00:54:04 +0600 Subject: [PATCH] refactor(gds): add `show_stats` option; moved GDS initialization to dedicated file; --- comfy/gds_loader.py | 35 ++++++++++++++++++++++++++++++++++- main.py | 31 +++++-------------------------- 2 files changed, 39 insertions(+), 27 deletions(-) diff --git a/comfy/gds_loader.py b/comfy/gds_loader.py index 21c175540..7c7f2530b 100644 --- a/comfy/gds_loader.py +++ b/comfy/gds_loader.py @@ -54,6 +54,7 @@ class GDSConfig: compression_aware: bool = True max_concurrent_streams: int = 4 fallback_to_cpu: bool = True + show_stats: bool = False # Whether to show stats on exit class GDSError(Exception): @@ -458,4 +459,36 @@ def get_gds_stats() -> Dict[str, Any]: def configure_gds(config: GDSConfig): """Configure GDS settings""" global _gds_instance - _gds_instance = GPUDirectStorage(config) \ No newline at end of file + _gds_instance = GPUDirectStorage(config) + + +def init_gds(config: GDSConfig): + """ + Initialize GPUDirect Storage with the provided configuration + + Args: + config: GDSConfig object with initialization parameters + """ + try: + # Configure GDS + configure_gds(config) + logging.info(f"GDS initialized: enabled={config.enabled}, min_size={config.min_file_size_mb}MB, streams={config.max_concurrent_streams}") + + # Set up exit handler for stats if requested + if hasattr(config, 'show_stats') and config.show_stats: + import atexit + def print_gds_stats(): + stats = get_gds_stats() + logging.info("=== GDS Statistics ===") + logging.info(f"Total loads: {stats['total_loads']}") + logging.info(f"GDS loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)") + logging.info(f"Fallback loads: {stats['fallback_loads']}") + logging.info(f"Total bytes via GDS: {stats['total_bytes_gds'] / (1024**3):.2f} GB") + logging.info(f"Average bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s") + logging.info("===================") + atexit.register(print_gds_stats) + + except ImportError as e: + logging.warning(f"GDS initialization failed - missing dependencies: {e}") + except Exception as e: + logging.error(f"GDS initialization failed: {e}") \ No newline at end of file diff --git a/main.py b/main.py index 9c4414a17..9417fda87 100644 --- a/main.py +++ b/main.py @@ -166,41 +166,20 @@ def init_gds(): # GDS not explicitly requested, use auto-detection return - try: - from comfy.gds_loader import GDSConfig, configure_gds, get_gds_stats + if hasattr(args, 'enable_gds') and args.enable_gds: + from comfy.gds_loader import GDSConfig, init_gds as gds_init - # Create GDS configuration from CLI args config = GDSConfig( enabled=getattr(args, 'enable_gds', False) or getattr(args, 'gds_prefetch', False), min_file_size_mb=getattr(args, 'gds_min_file_size', 100), chunk_size_mb=getattr(args, 'gds_chunk_size', 64), max_concurrent_streams=getattr(args, 'gds_streams', 4), prefetch_enabled=getattr(args, 'gds_prefetch', True), - fallback_to_cpu=not getattr(args, 'gds_no_fallback', False) + fallback_to_cpu=not getattr(args, 'gds_no_fallback', False), + show_stats=getattr(args, 'gds_stats', False) ) - # Configure GDS - configure_gds(config) - logging.info(f"GDS initialized: enabled={config.enabled}, min_size={config.min_file_size_mb}MB, streams={config.max_concurrent_streams}") - - # Set up exit handler for stats if requested - if getattr(args, 'gds_stats', False): - import atexit - def print_gds_stats(): - stats = get_gds_stats() - logging.info("=== GDS Statistics ===") - logging.info(f"Total loads: {stats['total_loads']}") - logging.info(f"GDS loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)") - logging.info(f"Fallback loads: {stats['fallback_loads']}") - logging.info(f"Total bytes via GDS: {stats['total_bytes_gds'] / (1024**3):.2f} GB") - logging.info(f"Average bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s") - logging.info("===================") - atexit.register(print_gds_stats) - - except ImportError as e: - logging.warning(f"GDS initialization failed - missing dependencies: {e}") - except Exception as e: - logging.error(f"GDS initialization failed: {e}") + gds_init(config) # Initialize GDS init_gds()