mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-11 15:29:08 +08:00
[Log] Optimize Startup Log (#26601)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
afc47e4de7
commit
e251e457c5
@ -13,6 +13,7 @@ from vllm.distributed.device_communicators.pynccl import register_nccl_symmetric
|
|||||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||||
is_symmetric_memory_enabled,
|
is_symmetric_memory_enabled,
|
||||||
)
|
)
|
||||||
|
from vllm.distributed.parallel_state import is_global_first_rank
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
@ -95,35 +96,35 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
|||||||
from .all2all import NaiveAll2AllManager
|
from .all2all import NaiveAll2AllManager
|
||||||
|
|
||||||
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||||
logger.info("Using naive all2all manager.")
|
|
||||||
elif all2all_backend == "allgather_reducescatter":
|
elif all2all_backend == "allgather_reducescatter":
|
||||||
from .all2all import AgRsAll2AllManager
|
from .all2all import AgRsAll2AllManager
|
||||||
|
|
||||||
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
|
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
|
||||||
logger.info("Using AllGather-ReduceScatter all2all manager.")
|
|
||||||
elif all2all_backend == "pplx":
|
elif all2all_backend == "pplx":
|
||||||
from .all2all import PPLXAll2AllManager
|
from .all2all import PPLXAll2AllManager
|
||||||
|
|
||||||
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
|
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
|
||||||
logger.info("Using PPLX all2all manager.")
|
|
||||||
elif all2all_backend == "deepep_high_throughput":
|
elif all2all_backend == "deepep_high_throughput":
|
||||||
from .all2all import DeepEPHTAll2AllManager
|
from .all2all import DeepEPHTAll2AllManager
|
||||||
|
|
||||||
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
|
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
|
||||||
logger.info("Using DeepEP High-Throughput all2all manager.")
|
|
||||||
elif all2all_backend == "deepep_low_latency":
|
elif all2all_backend == "deepep_low_latency":
|
||||||
from .all2all import DeepEPLLAll2AllManager
|
from .all2all import DeepEPLLAll2AllManager
|
||||||
|
|
||||||
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
|
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
|
||||||
logger.info("Using DeepEP Low-Latency all2all manager.")
|
|
||||||
elif all2all_backend == "flashinfer_all2allv":
|
elif all2all_backend == "flashinfer_all2allv":
|
||||||
from .all2all import FlashInferAllToAllManager
|
from .all2all import FlashInferAllToAllManager
|
||||||
|
|
||||||
self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
|
self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
|
||||||
logger.info("Using Flashinfer all2allv manager.")
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
|
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
|
||||||
|
|
||||||
|
if is_global_first_rank():
|
||||||
|
logger.info(
|
||||||
|
"Using %s all2all manager.",
|
||||||
|
self.all2all_manager.__class__.__name__,
|
||||||
|
)
|
||||||
|
|
||||||
def all_reduce(self, input_):
|
def all_reduce(self, input_):
|
||||||
# since currently we perform copy input -> symm_input -> out-of-place AR
|
# since currently we perform copy input -> symm_input -> out-of-place AR
|
||||||
# return symm_output, we don't need to check if input is symmetric
|
# return symm_output, we don't need to check if input is symmetric
|
||||||
|
|||||||
@ -105,11 +105,10 @@ class PyNcclCommunicator:
|
|||||||
self.disabled = False
|
self.disabled = False
|
||||||
|
|
||||||
self.nccl_version = self.nccl.ncclGetRawVersion()
|
self.nccl_version = self.nccl.ncclGetRawVersion()
|
||||||
logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())
|
|
||||||
|
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
# get the unique id from NCCL
|
# get the unique id from NCCL
|
||||||
self.unique_id = self.nccl.ncclGetUniqueId()
|
self.unique_id = self.nccl.ncclGetUniqueId()
|
||||||
|
logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())
|
||||||
else:
|
else:
|
||||||
# construct an empty unique id
|
# construct an empty unique id
|
||||||
self.unique_id = ncclUniqueId()
|
self.unique_id = ncclUniqueId()
|
||||||
|
|||||||
@ -1144,7 +1144,7 @@ def find_nccl_library() -> str:
|
|||||||
so_file = "librccl.so.1"
|
so_file = "librccl.so.1"
|
||||||
else:
|
else:
|
||||||
raise ValueError("NCCL only supports CUDA and ROCm backends.")
|
raise ValueError("NCCL only supports CUDA and ROCm backends.")
|
||||||
logger.info("Found nccl from library %s", so_file)
|
logger.debug_once("Found nccl from library %s", so_file)
|
||||||
return so_file
|
return so_file
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -139,7 +139,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
# FORCE_NUM_KV_SPLITS=1
|
# FORCE_NUM_KV_SPLITS=1
|
||||||
force_num_kv_splits = os.environ.get("FORCE_NUM_KV_SPLITS", None)
|
force_num_kv_splits = os.environ.get("FORCE_NUM_KV_SPLITS", None)
|
||||||
if force_num_kv_splits:
|
if force_num_kv_splits:
|
||||||
logger.warning_once("Forcing num_kv_splits to %d", int(force_num_kv_splits))
|
logger.debug_once("Forcing num_kv_splits to %d", int(force_num_kv_splits))
|
||||||
self._num_kv_splits = int(force_num_kv_splits)
|
self._num_kv_splits = int(force_num_kv_splits)
|
||||||
else:
|
else:
|
||||||
self._num_kv_splits = -1 # => Auto-detect
|
self._num_kv_splits = -1 # => Auto-detect
|
||||||
|
|||||||
@ -19,6 +19,7 @@ import zmq
|
|||||||
|
|
||||||
from vllm.config import ParallelConfig, VllmConfig
|
from vllm.config import ParallelConfig, VllmConfig
|
||||||
from vllm.distributed import stateless_destroy_torch_distributed_process_group
|
from vllm.distributed import stateless_destroy_torch_distributed_process_group
|
||||||
|
from vllm.distributed.parallel_state import is_global_first_rank
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.logging_utils.dump_input import dump_engine_exception
|
from vllm.logging_utils.dump_input import dump_engine_exception
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
@ -91,11 +92,12 @@ class EngineCore:
|
|||||||
load_general_plugins()
|
load_general_plugins()
|
||||||
|
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
logger.info(
|
if is_global_first_rank():
|
||||||
"Initializing a V1 LLM engine (v%s) with config: %s",
|
logger.info(
|
||||||
VLLM_VERSION,
|
"Initializing a V1 LLM engine (v%s) with config: %s",
|
||||||
vllm_config,
|
VLLM_VERSION,
|
||||||
)
|
vllm_config,
|
||||||
|
)
|
||||||
|
|
||||||
self.log_stats = log_stats
|
self.log_stats = log_stats
|
||||||
|
|
||||||
|
|||||||
@ -2876,7 +2876,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
with DeviceMemoryProfiler() as m:
|
with DeviceMemoryProfiler() as m:
|
||||||
time_before_load = time.perf_counter()
|
time_before_load = time.perf_counter()
|
||||||
model_loader = get_model_loader(self.load_config)
|
model_loader = get_model_loader(self.load_config)
|
||||||
logger.info("Loading model from scratch...")
|
|
||||||
self.model = model_loader.load_model(
|
self.model = model_loader.load_model(
|
||||||
vllm_config=self.vllm_config, model_config=self.model_config
|
vllm_config=self.vllm_config, model_config=self.model_config
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user