mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 12:25:41 +08:00
[Misc] add use_tqdm_on_load to reduce logs (#14407)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
parent
03fe18ae0f
commit
0b7f06b447
@ -1277,6 +1277,8 @@ class LoadConfig:
|
||||
ignore_patterns: The list of patterns to ignore when loading the model.
|
||||
Default to "original/**/*" to avoid repeated loading of llama's
|
||||
checkpoints.
|
||||
use_tqdm_on_load: Whether to enable tqdm for showing progress bar during
|
||||
loading. Default to True
|
||||
"""
|
||||
|
||||
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
|
||||
@ -1284,6 +1286,7 @@ class LoadConfig:
|
||||
model_loader_extra_config: Optional[Union[str, dict]] = field(
|
||||
default_factory=dict)
|
||||
ignore_patterns: Optional[Union[list[str], str]] = None
|
||||
use_tqdm_on_load: bool = True
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
|
||||
@ -217,6 +217,7 @@ class EngineArgs:
|
||||
additional_config: Optional[Dict[str, Any]] = None
|
||||
enable_reasoning: Optional[bool] = None
|
||||
reasoning_parser: Optional[str] = None
|
||||
use_tqdm_on_load: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.tokenizer:
|
||||
@ -751,6 +752,14 @@ class EngineArgs:
|
||||
default=1,
|
||||
help=('Maximum number of forward steps per '
|
||||
'scheduler call.'))
|
||||
parser.add_argument(
|
||||
'--use-tqdm-on-load',
|
||||
dest='use_tqdm_on_load',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=EngineArgs.use_tqdm_on_load,
|
||||
help='Whether to enable/disable progress bar '
|
||||
'when loading model weights.',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--multi-step-stream-outputs',
|
||||
@ -1179,6 +1188,7 @@ class EngineArgs:
|
||||
download_dir=self.download_dir,
|
||||
model_loader_extra_config=self.model_loader_extra_config,
|
||||
ignore_patterns=self.ignore_patterns,
|
||||
use_tqdm_on_load=self.use_tqdm_on_load,
|
||||
)
|
||||
|
||||
def create_engine_config(self,
|
||||
|
||||
@ -354,11 +354,18 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
self.load_config.download_dir,
|
||||
hf_folder,
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
elif use_safetensors:
|
||||
weights_iterator = safetensors_weights_iterator(hf_weights_files)
|
||||
weights_iterator = safetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
else:
|
||||
weights_iterator = pt_weights_iterator(hf_weights_files)
|
||||
weights_iterator = pt_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
|
||||
if current_platform.is_tpu():
|
||||
# In PyTorch XLA, we should call `xm.mark_step` frequently so that
|
||||
@ -806,9 +813,15 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
|
||||
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
|
||||
if use_safetensors:
|
||||
iterator = safetensors_weights_iterator(hf_weights_files)
|
||||
iterator = safetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
else:
|
||||
iterator = pt_weights_iterator(hf_weights_files)
|
||||
iterator = pt_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
for org_name, param in iterator:
|
||||
# mapping weight names from transformers to vllm while preserving
|
||||
# original names.
|
||||
@ -1396,7 +1409,10 @@ class RunaiModelStreamerLoader(BaseModelLoader):
|
||||
revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Get an iterator for the model weights based on the load format."""
|
||||
hf_weights_files = self._prepare_weights(model_or_path, revision)
|
||||
return runai_safetensors_weights_iterator(hf_weights_files)
|
||||
return runai_safetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
"""Download model if necessary"""
|
||||
|
||||
@ -366,16 +366,22 @@ def filter_files_not_needed_for_inference(
|
||||
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
|
||||
|
||||
|
||||
def enable_tqdm(use_tqdm_on_load: bool):
|
||||
return use_tqdm_on_load and (not torch.distributed.is_initialized()
|
||||
or torch.distributed.get_rank() == 0)
|
||||
|
||||
|
||||
def np_cache_weights_iterator(
|
||||
model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
|
||||
hf_weights_files: List[str]
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str],
|
||||
hf_folder: str,
|
||||
hf_weights_files: List[str],
|
||||
use_tqdm_on_load: bool,
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model np files.
|
||||
|
||||
Will dump the model weights to numpy files if they are not already dumped.
|
||||
"""
|
||||
enable_tqdm = not torch.distributed.is_initialized(
|
||||
) or torch.distributed.get_rank() == 0
|
||||
# Convert the model weights from torch tensors to numpy arrays for
|
||||
# faster loading.
|
||||
np_folder = os.path.join(hf_folder, "np")
|
||||
@ -389,7 +395,7 @@ def np_cache_weights_iterator(
|
||||
for bin_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc="Loading np_cache checkpoint shards",
|
||||
disable=not enable_tqdm,
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
state = torch.load(bin_file,
|
||||
@ -414,15 +420,14 @@ def np_cache_weights_iterator(
|
||||
|
||||
|
||||
def safetensors_weights_iterator(
|
||||
hf_weights_files: List[str]
|
||||
hf_weights_files: List[str],
|
||||
use_tqdm_on_load: bool,
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files."""
|
||||
enable_tqdm = not torch.distributed.is_initialized(
|
||||
) or torch.distributed.get_rank() == 0
|
||||
for st_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc="Loading safetensors checkpoint shards",
|
||||
disable=not enable_tqdm,
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
with safe_open(st_file, framework="pt") as f:
|
||||
@ -432,16 +437,15 @@ def safetensors_weights_iterator(
|
||||
|
||||
|
||||
def runai_safetensors_weights_iterator(
|
||||
hf_weights_files: List[str]
|
||||
hf_weights_files: List[str],
|
||||
use_tqdm_on_load: bool,
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files."""
|
||||
enable_tqdm = not torch.distributed.is_initialized(
|
||||
) or torch.distributed.get_rank() == 0
|
||||
with SafetensorsStreamer() as streamer:
|
||||
for st_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc="Loading safetensors using Runai Model Streamer",
|
||||
disable=not enable_tqdm,
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
streamer.stream_file(st_file)
|
||||
@ -449,15 +453,14 @@ def runai_safetensors_weights_iterator(
|
||||
|
||||
|
||||
def pt_weights_iterator(
|
||||
hf_weights_files: List[str]
|
||||
hf_weights_files: List[str],
|
||||
use_tqdm_on_load: bool,
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model bin/pt files."""
|
||||
enable_tqdm = not torch.distributed.is_initialized(
|
||||
) or torch.distributed.get_rank() == 0
|
||||
for bin_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc="Loading pt checkpoint shards",
|
||||
disable=not enable_tqdm,
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
state = torch.load(bin_file, map_location="cpu", weights_only=True)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user