[Misc] add use_tqdm_on_load to reduce logs (#14407)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Authored by Aaron Pham on 2025-03-08 08:57:46 -05:00, committed by GitHub
parent 03fe18ae0f
commit 0b7f06b447
4 changed files with 54 additions and 22 deletions

vllm/config.py

@@ -1277,6 +1277,8 @@ class LoadConfig:
     ignore_patterns: The list of patterns to ignore when loading the model.
         Default to "original/**/*" to avoid repeated loading of llama's
         checkpoints.
+    use_tqdm_on_load: Whether to enable tqdm for showing a progress bar
+        during loading. Default to True.
     """

     load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
@@ -1284,6 +1286,7 @@ class LoadConfig:
     model_loader_extra_config: Optional[Union[str, dict]] = field(
         default_factory=dict)
     ignore_patterns: Optional[Union[list[str], str]] = None
+    use_tqdm_on_load: bool = True

     def compute_hash(self) -> str:
         """

vllm/engine/arg_utils.py

@@ -217,6 +217,7 @@ class EngineArgs:
     additional_config: Optional[Dict[str, Any]] = None
     enable_reasoning: Optional[bool] = None
     reasoning_parser: Optional[str] = None
+    use_tqdm_on_load: bool = True

     def __post_init__(self):
         if not self.tokenizer:
@@ -751,6 +752,14 @@ class EngineArgs:
             default=1,
             help=('Maximum number of forward steps per '
                   'scheduler call.'))
+        parser.add_argument(
+            '--use-tqdm-on-load',
+            dest='use_tqdm_on_load',
+            action=argparse.BooleanOptionalAction,
+            default=EngineArgs.use_tqdm_on_load,
+            help='Whether to enable/disable the progress bar '
+                 'when loading model weights.',
+        )

         parser.add_argument(
             '--multi-step-stream-outputs',
@@ -1179,6 +1188,7 @@ class EngineArgs:
             download_dir=self.download_dir,
             model_loader_extra_config=self.model_loader_extra_config,
             ignore_patterns=self.ignore_patterns,
+            use_tqdm_on_load=self.use_tqdm_on_load,
         )

     def create_engine_config(self,
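Because the option is registered with argparse.BooleanOptionalAction, argparse generates the paired --no-use-tqdm-on-load flag automatically. A hedged sketch of how the pair behaves, driving EngineArgs' own parser (the FlexibleArgumentParser import path and the model name are assumptions for illustration):

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

# BooleanOptionalAction yields both --use-tqdm-on-load and
# --no-use-tqdm-on-load from the single registration above.
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args(
    ["--model", "facebook/opt-125m", "--no-use-tqdm-on-load"])
assert args.use_tqdm_on_load is False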

vllm/model_executor/model_loader/loader.py

@@ -354,11 +354,18 @@ class DefaultModelLoader(BaseModelLoader):
                 self.load_config.download_dir,
                 hf_folder,
                 hf_weights_files,
+                self.load_config.use_tqdm_on_load,
             )
         elif use_safetensors:
-            weights_iterator = safetensors_weights_iterator(hf_weights_files)
+            weights_iterator = safetensors_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )
         else:
-            weights_iterator = pt_weights_iterator(hf_weights_files)
+            weights_iterator = pt_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )

         if current_platform.is_tpu():
             # In PyTorch XLA, we should call `xm.mark_step` frequently so that
@@ -806,9 +813,15 @@ class BitsAndBytesModelLoader(BaseModelLoader):

     def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
         if use_safetensors:
-            iterator = safetensors_weights_iterator(hf_weights_files)
+            iterator = safetensors_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )
         else:
-            iterator = pt_weights_iterator(hf_weights_files)
+            iterator = pt_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )
         for org_name, param in iterator:
             # mapping weight names from transformers to vllm while preserving
             # original names.
@@ -1396,7 +1409,10 @@ class RunaiModelStreamerLoader(BaseModelLoader):
                      revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]:
         """Get an iterator for the model weights based on the load format."""
         hf_weights_files = self._prepare_weights(model_or_path, revision)
-        return runai_safetensors_weights_iterator(hf_weights_files)
+        return runai_safetensors_weights_iterator(
+            hf_weights_files,
+            self.load_config.use_tqdm_on_load,
+        )

     def download_model(self, model_config: ModelConfig) -> None:
         """Download model if necessary"""

vllm/model_executor/model_loader/weight_utils.py

@@ -366,16 +366,22 @@ def filter_files_not_needed_for_inference(
 _BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"  # noqa: E501


+def enable_tqdm(use_tqdm_on_load: bool):
+    return use_tqdm_on_load and (not torch.distributed.is_initialized()
+                                 or torch.distributed.get_rank() == 0)
+
+
 def np_cache_weights_iterator(
-    model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
-    hf_weights_files: List[str]
+    model_name_or_path: str,
+    cache_dir: Optional[str],
+    hf_folder: str,
+    hf_weights_files: List[str],
+    use_tqdm_on_load: bool,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model np files.

     Will dump the model weights to numpy files if they are not already dumped.
     """
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
     # Convert the model weights from torch tensors to numpy arrays for
     # faster loading.
     np_folder = os.path.join(hf_folder, "np")
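The new enable_tqdm helper centralizes the rank gating that each iterator previously recomputed inline: a bar is shown only when the flag is on and the process is either non-distributed or distributed rank 0. A minimal single-process sketch of that behavior (the helper body is copied verbatim from the diff; the assertions are illustrative):

import torch

def enable_tqdm(use_tqdm_on_load: bool):
    return use_tqdm_on_load and (not torch.distributed.is_initialized()
                                 or torch.distributed.get_rank() == 0)

# With torch.distributed uninitialized, the flag alone decides;
# in a distributed run, only rank 0 would ever show a bar.
assert enable_tqdm(True) is True
assert enable_tqdm(False) is False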
@@ -389,7 +395,7 @@ def np_cache_weights_iterator(
             for bin_file in tqdm(
                     hf_weights_files,
                     desc="Loading np_cache checkpoint shards",
-                    disable=not enable_tqdm,
+                    disable=not enable_tqdm(use_tqdm_on_load),
                     bar_format=_BAR_FORMAT,
             ):
                 state = torch.load(bin_file,
@@ -414,15 +420,14 @@ def np_cache_weights_iterator(


 def safetensors_weights_iterator(
-    hf_weights_files: List[str]
+    hf_weights_files: List[str],
+    use_tqdm_on_load: bool,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
     for st_file in tqdm(
             hf_weights_files,
             desc="Loading safetensors checkpoint shards",
-            disable=not enable_tqdm,
+            disable=not enable_tqdm(use_tqdm_on_load),
             bar_format=_BAR_FORMAT,
     ):
         with safe_open(st_file, framework="pt") as f:
@@ -432,16 +437,15 @@ def safetensors_weights_iterator(


 def runai_safetensors_weights_iterator(
-    hf_weights_files: List[str]
+    hf_weights_files: List[str],
+    use_tqdm_on_load: bool,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
     with SafetensorsStreamer() as streamer:
         for st_file in tqdm(
                 hf_weights_files,
                 desc="Loading safetensors using Runai Model Streamer",
-                disable=not enable_tqdm,
+                disable=not enable_tqdm(use_tqdm_on_load),
                 bar_format=_BAR_FORMAT,
         ):
             streamer.stream_file(st_file)
@@ -449,15 +453,14 @@ def runai_safetensors_weights_iterator(


 def pt_weights_iterator(
-    hf_weights_files: List[str]
+    hf_weights_files: List[str],
+    use_tqdm_on_load: bool,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model bin/pt files."""
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
     for bin_file in tqdm(
             hf_weights_files,
             desc="Loading pt checkpoint shards",
-            disable=not enable_tqdm,
+            disable=not enable_tqdm(use_tqdm_on_load),
             bar_format=_BAR_FORMAT,
     ):
         state = torch.load(bin_file, map_location="cpu", weights_only=True)
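Because the iterators now take the flag as an explicit parameter, they can also be driven directly, for example from a standalone script that inspects checkpoint shards without building an engine. A hedged sketch, assuming a local directory of safetensors shards (the path and glob pattern are illustrative):

import glob

from vllm.model_executor.model_loader.weight_utils import (
    safetensors_weights_iterator)

# Pass False to iterate silently; True restores the per-shard bar,
# still gated to rank 0 by enable_tqdm in distributed runs.
shards = sorted(glob.glob("/path/to/checkpoint/*.safetensors"))
for name, tensor in safetensors_weights_iterator(shards, False):
    print(name, tuple(tensor.shape))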