add local rank

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
yewentao256 2025-10-15 12:47:14 -07:00
parent 0fbcfd64f7
commit b44c430a1d
2 changed files with 25 additions and 2 deletions

View File

@ -1623,6 +1623,29 @@ def is_global_first_rank() -> bool:
return True
def is_local_first_rank() -> bool:
"""
Check if the current process is the first local rank (rank 0 on its node).
"""
try:
# prefer the initialized world group if available
global _WORLD
if _WORLD is not None:
return _WORLD.local_rank == 0
if not torch.distributed.is_initialized():
return True
# fallback to environment-provided local rank if available
# note: envs.LOCAL_RANK is set when using env:// launchers (e.g., torchrun)
try:
return int(envs.LOCAL_RANK) == 0 # type: ignore[arg-type]
except Exception:
return torch.distributed.get_rank() == 0
except Exception:
return True
def _node_count(pg: ProcessGroup | StatelessProcessGroup) -> int:
"""
Returns the total number of nodes in the process group.

View File

@ -13,7 +13,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.distributed.parallel_state import is_global_first_rank
from vllm.distributed.parallel_state import is_local_first_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
@ -312,7 +312,7 @@ class DefaultModelLoader(BaseModelLoader):
loaded_weights = load_weights_and_online_quantize(self, model, model_config)
self.counter_after_loading_weights = time.perf_counter()
if is_global_first_rank():
if is_local_first_rank():
logger.info(
"Loading weights took %.2f seconds",
self.counter_after_loading_weights