mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-19 02:54:35 +08:00
refactor(gds): add show_stats option; moved GDS initialization to dedicated file;
This commit is contained in:
parent
64811809a0
commit
a19f0a88e4
@ -54,6 +54,7 @@ class GDSConfig:
|
|||||||
compression_aware: bool = True
|
compression_aware: bool = True
|
||||||
max_concurrent_streams: int = 4
|
max_concurrent_streams: int = 4
|
||||||
fallback_to_cpu: bool = True
|
fallback_to_cpu: bool = True
|
||||||
|
show_stats: bool = False # Whether to show stats on exit
|
||||||
|
|
||||||
|
|
||||||
class GDSError(Exception):
|
class GDSError(Exception):
|
||||||
@ -458,4 +459,36 @@ def get_gds_stats() -> Dict[str, Any]:
|
|||||||
def configure_gds(config: GDSConfig):
|
def configure_gds(config: GDSConfig):
|
||||||
"""Configure GDS settings"""
|
"""Configure GDS settings"""
|
||||||
global _gds_instance
|
global _gds_instance
|
||||||
_gds_instance = GPUDirectStorage(config)
|
_gds_instance = GPUDirectStorage(config)
|
||||||
|
|
||||||
|
|
||||||
|
def init_gds(config: GDSConfig):
|
||||||
|
"""
|
||||||
|
Initialize GPUDirect Storage with the provided configuration
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: GDSConfig object with initialization parameters
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Configure GDS
|
||||||
|
configure_gds(config)
|
||||||
|
logging.info(f"GDS initialized: enabled={config.enabled}, min_size={config.min_file_size_mb}MB, streams={config.max_concurrent_streams}")
|
||||||
|
|
||||||
|
# Set up exit handler for stats if requested
|
||||||
|
if hasattr(config, 'show_stats') and config.show_stats:
|
||||||
|
import atexit
|
||||||
|
def print_gds_stats():
|
||||||
|
stats = get_gds_stats()
|
||||||
|
logging.info("=== GDS Statistics ===")
|
||||||
|
logging.info(f"Total loads: {stats['total_loads']}")
|
||||||
|
logging.info(f"GDS loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)")
|
||||||
|
logging.info(f"Fallback loads: {stats['fallback_loads']}")
|
||||||
|
logging.info(f"Total bytes via GDS: {stats['total_bytes_gds'] / (1024**3):.2f} GB")
|
||||||
|
logging.info(f"Average bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s")
|
||||||
|
logging.info("===================")
|
||||||
|
atexit.register(print_gds_stats)
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
logging.warning(f"GDS initialization failed - missing dependencies: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"GDS initialization failed: {e}")
|
||||||
31
main.py
31
main.py
@ -166,41 +166,20 @@ def init_gds():
|
|||||||
# GDS not explicitly requested, use auto-detection
|
# GDS not explicitly requested, use auto-detection
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
if hasattr(args, 'enable_gds') and args.enable_gds:
|
||||||
from comfy.gds_loader import GDSConfig, configure_gds, get_gds_stats
|
from comfy.gds_loader import GDSConfig, init_gds as gds_init
|
||||||
|
|
||||||
# Create GDS configuration from CLI args
|
|
||||||
config = GDSConfig(
|
config = GDSConfig(
|
||||||
enabled=getattr(args, 'enable_gds', False) or getattr(args, 'gds_prefetch', False),
|
enabled=getattr(args, 'enable_gds', False) or getattr(args, 'gds_prefetch', False),
|
||||||
min_file_size_mb=getattr(args, 'gds_min_file_size', 100),
|
min_file_size_mb=getattr(args, 'gds_min_file_size', 100),
|
||||||
chunk_size_mb=getattr(args, 'gds_chunk_size', 64),
|
chunk_size_mb=getattr(args, 'gds_chunk_size', 64),
|
||||||
max_concurrent_streams=getattr(args, 'gds_streams', 4),
|
max_concurrent_streams=getattr(args, 'gds_streams', 4),
|
||||||
prefetch_enabled=getattr(args, 'gds_prefetch', True),
|
prefetch_enabled=getattr(args, 'gds_prefetch', True),
|
||||||
fallback_to_cpu=not getattr(args, 'gds_no_fallback', False)
|
fallback_to_cpu=not getattr(args, 'gds_no_fallback', False),
|
||||||
|
show_stats=getattr(args, 'gds_stats', False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Configure GDS
|
gds_init(config)
|
||||||
configure_gds(config)
|
|
||||||
logging.info(f"GDS initialized: enabled={config.enabled}, min_size={config.min_file_size_mb}MB, streams={config.max_concurrent_streams}")
|
|
||||||
|
|
||||||
# Set up exit handler for stats if requested
|
|
||||||
if getattr(args, 'gds_stats', False):
|
|
||||||
import atexit
|
|
||||||
def print_gds_stats():
|
|
||||||
stats = get_gds_stats()
|
|
||||||
logging.info("=== GDS Statistics ===")
|
|
||||||
logging.info(f"Total loads: {stats['total_loads']}")
|
|
||||||
logging.info(f"GDS loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)")
|
|
||||||
logging.info(f"Fallback loads: {stats['fallback_loads']}")
|
|
||||||
logging.info(f"Total bytes via GDS: {stats['total_bytes_gds'] / (1024**3):.2f} GB")
|
|
||||||
logging.info(f"Average bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s")
|
|
||||||
logging.info("===================")
|
|
||||||
atexit.register(print_gds_stats)
|
|
||||||
|
|
||||||
except ImportError as e:
|
|
||||||
logging.warning(f"GDS initialization failed - missing dependencies: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"GDS initialization failed: {e}")
|
|
||||||
|
|
||||||
# Initialize GDS
|
# Initialize GDS
|
||||||
init_gds()
|
init_gds()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user