[HARDWARE][CPU] Add Option for Disabling Binding to Specific CPU Cores (#27953)

Signed-off-by: Stan Hatko <stan_hatko@live.com>
Co-authored-by: Li, Jiang <jiang1.li@intel.com>
StanHatko 2025-11-06 10:47:11 -05:00 committed by GitHub
parent 2176778cd3
commit e52e4da971
3 changed files with 15 additions and 8 deletions


@@ -94,7 +94,7 @@ Currently, there are no pre-built CPU wheels.
 ## Related runtime environment variables
 - `VLLM_CPU_KVCACHE_SPACE`: specify the KV cache size (e.g., `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB of space for the KV cache); a larger setting lets vLLM run more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of the deployment. Default value is `0`.
-- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads; can be set to CPU id lists or `auto` (the default). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means 32 OpenMP threads are bound to CPU cores 0-31. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there are 2 tensor parallel processes: the 32 OpenMP threads of rank 0 are bound to CPU cores 0-31, and the OpenMP threads of rank 1 are bound to CPU cores 32-63. When set to `auto`, the OpenMP threads of each rank are bound to the CPU cores of that rank's NUMA node.
+- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads; can be set to CPU id lists, `auto` (the default), or `nobind` (to disable binding to individual CPU cores and inherit user-defined OpenMP variables). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means 32 OpenMP threads are bound to CPU cores 0-31. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there are 2 tensor parallel processes: the 32 OpenMP threads of rank 0 are bound to CPU cores 0-31, and the OpenMP threads of rank 1 are bound to CPU cores 32-63. When set to `auto`, the OpenMP threads of each rank are bound to the CPU cores of that rank's NUMA node. If set to `nobind`, the number of OpenMP threads is determined by the standard `OMP_NUM_THREADS` environment variable.
 - `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores that are not dedicated to the OpenMP threads for each rank. The variable only takes effect when `VLLM_CPU_OMP_THREADS_BIND` is set to `auto`. Default value is `None`. If the value is not set and `auto` thread binding is used, no CPU is reserved for `world_size == 1`, and 1 CPU per rank is reserved for `world_size > 1`.
 - `CPU_VISIBLE_MEMORY_NODES`: specify the NUMA memory nodes visible to vLLM CPU workers, similar to `CUDA_VISIBLE_DEVICES`. The variable only takes effect when `VLLM_CPU_OMP_THREADS_BIND` is set to `auto`, and provides finer control over auto thread binding, such as masking nodes and changing the node binding sequence.
 - `VLLM_CPU_MOE_PREPACK` (x86 only): whether to use prepack for the MoE layer. This is passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False).
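
A minimal usage sketch of the new `nobind` mode (not part of this diff; the model name and thread count are illustrative). The environment must be configured before vLLM is imported:

```python
import os

# Opt out of vLLM's per-core binding and supply our own OpenMP settings.
os.environ["VLLM_CPU_OMP_THREADS_BIND"] = "nobind"
os.environ["OMP_NUM_THREADS"] = "16"         # standard OpenMP variable, now honored
os.environ["VLLM_CPU_KVCACHE_SPACE"] = "40"  # 40 GiB reserved for the KV cache

from vllm import LLM  # import only after the environment is set

llm = LLM(model="facebook/opt-125m")  # example model, runs on CPU
print(llm.generate(["Hello, world"])[0].outputs[0].text)
```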


@@ -14,6 +14,7 @@ from typing import TYPE_CHECKING
 import regex as re
 import torch
+from vllm import envs
 from vllm.logger import init_logger
 from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

@@ -151,7 +152,6 @@ class CpuPlatform(Platform):
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
-        import vllm.envs as envs
         from vllm.utils.mem_constants import GiB_bytes
         kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
@@ -289,11 +289,16 @@ class CpuPlatform(Platform):
         os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
         # Note: to avoid the error 'nthreads cannot be larger than environment
         # variable "NUMEXPR_MAX_THREADS" (64)'.
         os.environ["NUMEXPR_MAX_THREADS"] = str(get_max_threads())
-        # Set default threads num for OpenMP parallel
-        os.environ["OMP_NUM_THREADS"] = str(torch.get_num_threads())
+        if envs.VLLM_CPU_OMP_THREADS_BIND != "nobind":
+            # Set default threads num for OpenMP parallel
+            os.environ["OMP_NUM_THREADS"] = str(torch.get_num_threads())
+        else:
+            # In this case, setting the OpenMP configuration via
+            # OMP_NUM_THREADS is up to the user.
+            logger.info("Disabling binding processes to CPU cores...")
         # Disable torch async compiling which won't work with daemonic processes
         os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"


@@ -69,13 +69,15 @@ class CPUWorker(Worker):
                 self.local_omp_cpuid = self._get_autobind_cpu_ids(
                     lambda cpus: [cpu for cpu in cpus if cpu.id % 8 < 4]
                 )
-            elif current_platform.get_cpu_architecture() == CpuArchEnum.X86:
+            elif cpu_arch == CpuArchEnum.X86:
                 # For x86 SMT-2, use 1 CPU per core
                 self.local_omp_cpuid = self._get_autobind_cpu_ids(
                     lambda cpus: cpus[-1:]
                 )
             else:
-                self.local_omp_cpuid = "all"
+                self.local_omp_cpuid = "nobind"
+        elif omp_cpuids == "nobind":
+            self.local_omp_cpuid = "nobind"
         else:
             local_dp_rank = self.parallel_config.data_parallel_rank_local
             omp_cpuids = omp_cpuids.split("|")

@@ -86,7 +88,7 @@ class CPUWorker(Worker):
             ]
             self.local_omp_cpuid = omp_cpuids[self.rank]
-        if self.local_omp_cpuid != "all":
+        if self.local_omp_cpuid != "nobind":
             ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
             if ret:
                 logger.info(ret)
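
The worker-side effect, condensed into a self-contained sketch (`resolve_omp_cpuid` is a hypothetical helper; the real logic lives in `CPUWorker` above): `nobind` is now the single sentinel that skips `init_cpu_threads_env`, whether the user requested it explicitly or `auto` fell back to it on an architecture without a binding policy:

```python
def resolve_omp_cpuid(omp_cpuids: str, rank: int) -> str:
    """Return the binding string for this rank, or "nobind" to skip binding."""
    if omp_cpuids == "auto":
        # Architecture-specific auto-binding elided; architectures without
        # a binding policy now fall back to "nobind" instead of "all".
        return "nobind"
    if omp_cpuids == "nobind":
        # Explicit opt-out: inherit the user's OpenMP environment.
        return "nobind"
    # Explicit per-rank CPU id lists, e.g. "0-31|32-63" for two ranks.
    return omp_cpuids.split("|")[rank]


# Binding only happens when a concrete CPU list was resolved:
for cpuid in (resolve_omp_cpuid("0-31|32-63", 1), resolve_omp_cpuid("nobind", 0)):
    if cpuid != "nobind":
        print(f"would call init_cpu_threads_env({cpuid!r})")
```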