[Misc] add use_tqdm_on_load to reduce logs (#14407)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Aaron Pham 2025-03-08 08:57:46 -05:00 committed by GitHub
parent 03fe18ae0f
commit 0b7f06b447
4 changed files with 54 additions and 22 deletions
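
Usage note: the new knob can also be toggled from the offline API. A minimal sketch, assuming (as in this diff) that extra keyword arguments passed to LLM are forwarded to EngineArgs:

    from vllm import LLM

    # use_tqdm_on_load is an EngineArgs field that defaults to True; passing
    # False here suppresses the per-shard progress bars during weight loading.
    llm = LLM(model="facebook/opt-125m", use_tqdm_on_load=False)
    outputs = llm.generate("Hello, my name is")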

View File

@@ -1277,6 +1277,8 @@ class LoadConfig:
         ignore_patterns: The list of patterns to ignore when loading the model.
             Default to "original/**/*" to avoid repeated loading of llama's
             checkpoints.
+        use_tqdm_on_load: Whether to enable tqdm for showing progress bar during
+            loading. Default to True
     """

     load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
@@ -1284,6 +1286,7 @@ class LoadConfig:
     model_loader_extra_config: Optional[Union[str, dict]] = field(
         default_factory=dict)
     ignore_patterns: Optional[Union[list[str], str]] = None
+    use_tqdm_on_load: bool = True

     def compute_hash(self) -> str:
         """

View File

@@ -217,6 +217,7 @@ class EngineArgs:
     additional_config: Optional[Dict[str, Any]] = None
     enable_reasoning: Optional[bool] = None
     reasoning_parser: Optional[str] = None
+    use_tqdm_on_load: bool = True

     def __post_init__(self):
         if not self.tokenizer:
@@ -751,6 +752,14 @@ class EngineArgs:
             default=1,
             help=('Maximum number of forward steps per '
                   'scheduler call.'))
+        parser.add_argument(
+            '--use-tqdm-on-load',
+            dest='use_tqdm_on_load',
+            action=argparse.BooleanOptionalAction,
+            default=EngineArgs.use_tqdm_on_load,
+            help='Whether to enable/disable progress bar '
+            'when loading model weights.',
+        )

         parser.add_argument(
             '--multi-step-stream-outputs',
@@ -1179,6 +1188,7 @@ class EngineArgs:
             download_dir=self.download_dir,
             model_loader_extra_config=self.model_loader_extra_config,
             ignore_patterns=self.ignore_patterns,
+            use_tqdm_on_load=self.use_tqdm_on_load,
         )

     def create_engine_config(self,
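
Since the option is registered with argparse.BooleanOptionalAction, argparse also generates the negated --no-use-tqdm-on-load form. A small self-contained sketch of that behavior (stand-alone parser, mirroring the flag above):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--use-tqdm-on-load',
        dest='use_tqdm_on_load',
        action=argparse.BooleanOptionalAction,
        default=True,
        help='Whether to enable/disable progress bar '
        'when loading model weights.',
    )

    # Both spellings are accepted; the default mirrors EngineArgs.use_tqdm_on_load.
    assert parser.parse_args([]).use_tqdm_on_load is True
    assert parser.parse_args(['--no-use-tqdm-on-load']).use_tqdm_on_load is False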

View File

@@ -354,11 +354,18 @@ class DefaultModelLoader(BaseModelLoader):
                 self.load_config.download_dir,
                 hf_folder,
                 hf_weights_files,
+                self.load_config.use_tqdm_on_load,
             )
         elif use_safetensors:
-            weights_iterator = safetensors_weights_iterator(hf_weights_files)
+            weights_iterator = safetensors_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )
         else:
-            weights_iterator = pt_weights_iterator(hf_weights_files)
+            weights_iterator = pt_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )

         if current_platform.is_tpu():
             # In PyTorch XLA, we should call `xm.mark_step` frequently so that
@@ -806,9 +813,15 @@ class BitsAndBytesModelLoader(BaseModelLoader):
     def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
         if use_safetensors:
-            iterator = safetensors_weights_iterator(hf_weights_files)
+            iterator = safetensors_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )
         else:
-            iterator = pt_weights_iterator(hf_weights_files)
+            iterator = pt_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )
         for org_name, param in iterator:
             # mapping weight names from transformers to vllm while preserving
             # original names.
@@ -1396,7 +1409,10 @@ class RunaiModelStreamerLoader(BaseModelLoader):
            revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]:
         """Get an iterator for the model weights based on the load format."""
         hf_weights_files = self._prepare_weights(model_or_path, revision)
-        return runai_safetensors_weights_iterator(hf_weights_files)
+        return runai_safetensors_weights_iterator(
+            hf_weights_files,
+            self.load_config.use_tqdm_on_load,
+        )

     def download_model(self, model_config: ModelConfig) -> None:
         """Download model if necessary"""

View File

@@ -366,16 +366,22 @@ def filter_files_not_needed_for_inference(
 _BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"  # noqa: E501


+def enable_tqdm(use_tqdm_on_load: bool):
+    return use_tqdm_on_load and (not torch.distributed.is_initialized()
+                                 or torch.distributed.get_rank() == 0)
+
+
 def np_cache_weights_iterator(
-    model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
-    hf_weights_files: List[str]
+    model_name_or_path: str,
+    cache_dir: Optional[str],
+    hf_folder: str,
+    hf_weights_files: List[str],
+    use_tqdm_on_load: bool,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model np files.

     Will dump the model weights to numpy files if they are not already dumped.
     """
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
     # Convert the model weights from torch tensors to numpy arrays for
     # faster loading.
     np_folder = os.path.join(hf_folder, "np")
@@ -389,7 +395,7 @@ def np_cache_weights_iterator(
             for bin_file in tqdm(
                     hf_weights_files,
                     desc="Loading np_cache checkpoint shards",
-                    disable=not enable_tqdm,
+                    disable=not enable_tqdm(use_tqdm_on_load),
                     bar_format=_BAR_FORMAT,
             ):
                 state = torch.load(bin_file,
@@ -414,15 +420,14 @@ def np_cache_weights_iterator(

 def safetensors_weights_iterator(
-    hf_weights_files: List[str]
+    hf_weights_files: List[str],
+    use_tqdm_on_load: bool,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
     for st_file in tqdm(
             hf_weights_files,
             desc="Loading safetensors checkpoint shards",
-            disable=not enable_tqdm,
+            disable=not enable_tqdm(use_tqdm_on_load),
             bar_format=_BAR_FORMAT,
     ):
         with safe_open(st_file, framework="pt") as f:
@@ -432,16 +437,15 @@ def safetensors_weights_iterator(

 def runai_safetensors_weights_iterator(
-    hf_weights_files: List[str]
+    hf_weights_files: List[str],
+    use_tqdm_on_load: bool,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
     with SafetensorsStreamer() as streamer:
         for st_file in tqdm(
                 hf_weights_files,
                 desc="Loading safetensors using Runai Model Streamer",
-                disable=not enable_tqdm,
+                disable=not enable_tqdm(use_tqdm_on_load),
                 bar_format=_BAR_FORMAT,
         ):
             streamer.stream_file(st_file)
@@ -449,15 +453,14 @@ def runai_safetensors_weights_iterator(

 def pt_weights_iterator(
-    hf_weights_files: List[str]
+    hf_weights_files: List[str],
+    use_tqdm_on_load: bool,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model bin/pt files."""
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
     for bin_file in tqdm(
             hf_weights_files,
             desc="Loading pt checkpoint shards",
-            disable=not enable_tqdm,
+            disable=not enable_tqdm(use_tqdm_on_load),
             bar_format=_BAR_FORMAT,
     ):
         state = torch.load(bin_file, map_location="cpu", weights_only=True)
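
The new enable_tqdm helper combines the user-facing flag with the pre-existing rank-0 check, so non-zero ranks stay quiet regardless of the flag. A standalone sketch of the same gating logic, with hypothetical distributed/rank parameters standing in for torch.distributed state:

    def enable_tqdm(use_tqdm_on_load: bool,
                    distributed: bool = False,
                    rank: int = 0) -> bool:
        # Show progress bars only when the user asked for them AND this is
        # either a non-distributed run or the rank-0 process.
        return use_tqdm_on_load and (not distributed or rank == 0)

    assert enable_tqdm(True) is True                             # default single-process case
    assert enable_tqdm(False) is False                           # user opted out via the flag
    assert enable_tqdm(True, distributed=True, rank=1) is False  # non-rank-0 worker stays quiet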