[misc] only tqdm for first rank (#6672)

parent 97234be0ec
commit c5201240a4
@@ -313,6 +313,13 @@ def filter_files_not_needed_for_inference(
     return hf_weights_files
 
 
+# explicitly use pure text format, with a newline at the end
+# this makes it impossible to see the animation in the progress bar
+# but will avoid messing up with ray or multiprocessing, which wraps
+# each line of output with some prefix.
+_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"  # noqa: E501
+
+
 def np_cache_weights_iterator(
     model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
     hf_weights_files: List[str]
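The trailing "\n" in _BAR_FORMAT is the whole trick: each tqdm refresh becomes a complete, newline-terminated line rather than an in-place "\r" redraw, so wrappers that prefix every output line (as ray and multiprocessing do for worker stdout) cannot corrupt it. A minimal standalone sketch of the effect, assuming only a stock tqdm install:

    from tqdm import tqdm

    _BAR_FORMAT = ("{desc}: {percentage:3.0f}% Completed | "
                   "{n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n")

    # By default tqdm redraws in place with "\r", which a line-prefixing
    # wrapper turns into garbage. With the newline-terminated format,
    # each update is printed as its own full line instead.
    for _ in tqdm(range(100), desc="demo", bar_format=_BAR_FORMAT):
        pass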
@@ -321,6 +328,8 @@ def np_cache_weights_iterator(
 
     Will dump the model weights to numpy files if they are not already dumped.
     """
+    enable_tqdm = not torch.distributed.is_initialized(
+    ) or torch.distributed.get_rank() == 0
     # Convert the model weights from torch tensors to numpy arrays for
     # faster loading.
     np_folder = os.path.join(hf_folder, "np")
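The enable_tqdm guard added here (and mirrored in the iterators below) keeps the bar on for single-process runs, where torch.distributed is never initialized, and restricts it to rank 0 otherwise. The same condition in isolation:

    import torch.distributed as dist

    # True in single-process runs (no process group) or on rank 0 of a
    # distributed run; every other rank passes disable=True to tqdm.
    enable_tqdm = not dist.is_initialized() or dist.get_rank() == 0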
@@ -331,8 +340,12 @@ def np_cache_weights_iterator(
     with get_lock(model_name_or_path, cache_dir):
         if not os.path.exists(weight_names_file):
             weight_names: List[str] = []
-            for bin_file in tqdm(hf_weights_files,
-                                 desc="Loading np_cache checkpoint shards"):
+            for bin_file in tqdm(
+                    hf_weights_files,
+                    desc="Loading np_cache checkpoint shards",
+                    disable=not enable_tqdm,
+                    bar_format=_BAR_FORMAT,
+            ):
                 state = torch.load(bin_file, map_location="cpu")
                 for name, param in state.items():
                     param_path = os.path.join(np_folder, name)
@@ -356,8 +369,14 @@ def safetensors_weights_iterator(
     hf_weights_files: List[str]
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
-    for st_file in tqdm(hf_weights_files,
-                        desc="Loading safetensors checkpoint shards"):
+    enable_tqdm = not torch.distributed.is_initialized(
+    ) or torch.distributed.get_rank() == 0
+    for st_file in tqdm(
+            hf_weights_files,
+            desc="Loading safetensors checkpoint shards",
+            disable=not enable_tqdm,
+            bar_format=_BAR_FORMAT,
+    ):
         with safe_open(st_file, framework="pt") as f:
             for name in f.keys():  # noqa: SIM118
                 param = f.get_tensor(name)
@@ -368,8 +387,14 @@ def pt_weights_iterator(
     hf_weights_files: List[str]
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model bin/pt files."""
-    for bin_file in tqdm(hf_weights_files,
-                         desc="Loading pt checkpoint shards"):
+    enable_tqdm = not torch.distributed.is_initialized(
+    ) or torch.distributed.get_rank() == 0
+    for bin_file in tqdm(
+            hf_weights_files,
+            desc="Loading pt checkpoint shards",
+            disable=not enable_tqdm,
+            bar_format=_BAR_FORMAT,
+    ):
         state = torch.load(bin_file, map_location="cpu")
         for name, param in state.items():
             yield name, param
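All three iterators receive the same change; condensed into one self-contained sketch (the simplified loader below mirrors pt_weights_iterator and is illustrative, not the exact vLLM code):

    import torch
    import torch.distributed as dist
    from tqdm import tqdm

    _BAR_FORMAT = ("{desc}: {percentage:3.0f}% Completed | "
                   "{n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n")

    def pt_weights_iterator(hf_weights_files):
        # Only rank 0 (or a single-process run) prints progress, and the
        # newline-terminated bar format keeps output readable when ray /
        # multiprocessing prefix each stdout line.
        enable_tqdm = not dist.is_initialized() or dist.get_rank() == 0
        for bin_file in tqdm(hf_weights_files,
                             desc="Loading pt checkpoint shards",
                             disable=not enable_tqdm,
                             bar_format=_BAR_FORMAT):
            state = torch.load(bin_file, map_location="cpu")
            yield from state.items()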