diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index 18c96b264ad82..00bb5cae43f00 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -110,8 +110,9 @@ vLLM CPU backend supports the following vLLM features: ## Related runtime environment variables -- `VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. -- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads. For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. +- `VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`. +- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads. For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node. By setting to `all`, the OpenMP threads of each rank uses all CPU cores available on the system. Default value is `auto`. 
+- `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `0`. - `VLLM_CPU_MOE_PREPACK`: whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False). ## Performance tips @@ -133,7 +134,15 @@ export VLLM_CPU_OMP_THREADS_BIND=0-29 vllm serve facebook/opt-125m ``` -- If using vLLM CPU backend on a machine with hyper-threading, it is recommended to bind only one OpenMP thread on each physical CPU core using `VLLM_CPU_OMP_THREADS_BIND`. On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores: + or using default auto thread binding: + +```console +export VLLM_CPU_KVCACHE_SPACE=40 +export VLLM_CPU_NUM_OF_RESERVED_CPU=2 +vllm serve facebook/opt-125m +``` + +- If using vLLM CPU backend on a machine with hyper-threading, it is recommended to bind only one OpenMP thread on each physical CPU core using `VLLM_CPU_OMP_THREADS_BIND` or using auto thread binding feature by default. On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores: ```console $ lscpu -e # check the mapping between logical CPU cores and physical CPU cores @@ -178,6 +187,12 @@ $ python examples/offline_inference/basic/basic.py VLLM_CPU_KVCACHE_SPACE=40 VLLM_CPU_OMP_THREADS_BIND="0-31|32-63" vllm serve meta-llama/Llama-2-7b-chat-hf -tp=2 --distributed-executor-backend mp ``` + or using default auto thread binding: + + ```console + VLLM_CPU_KVCACHE_SPACE=40 vllm serve meta-llama/Llama-2-7b-chat-hf -tp=2 --distributed-executor-backend mp + ``` + - For each thread id list in `VLLM_CPU_OMP_THREADS_BIND`, users should guarantee threads in the list belong to a same NUMA node. - Meanwhile, users should also take care of memory capacity of each NUMA node. 
The memory usage of each TP rank is the sum of `weight shard size` and `VLLM_CPU_KVCACHE_SPACE`, if it exceeds the capacity of a single NUMA node, TP worker will be killed due to out-of-memory. diff --git a/requirements/cpu.txt b/requirements/cpu.txt index e43b443977524..d7b0fc6d80a74 100644 --- a/requirements/cpu.txt +++ b/requirements/cpu.txt @@ -27,3 +27,5 @@ triton==3.2.0; platform_machine == "x86_64" # Intel Extension for PyTorch, only for x86_64 CPUs intel-openmp==2024.2.1; platform_machine == "x86_64" intel_extension_for_pytorch==2.7.0; platform_machine == "x86_64" +py-libnuma; platform_system != "Darwin" +psutil; platform_system != "Darwin" diff --git a/vllm/envs.py b/vllm/envs.py index 6f876d3df6fd8..80c5f289bba90 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -44,6 +44,7 @@ if TYPE_CHECKING: VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_OMP_THREADS_BIND: str = "" + VLLM_CPU_NUM_OF_RESERVED_CPU: int = 0 VLLM_CPU_MOE_PREPACK: bool = True VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") VLLM_XLA_CHECK_RECOMPILATION: bool = False @@ -422,7 +423,12 @@ environment_variables: dict[str, Callable[[], Any]] = { # (CPU backend only) CPU core ids bound by OpenMP threads, e.g., "0-31", # "0,1,2", "0-31,33". CPU cores of different ranks are separated by '|'. "VLLM_CPU_OMP_THREADS_BIND": - lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"), + lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "auto"), + + # (CPU backend only) CPU cores not used by OMP threads. + # Those CPU cores will not be used by OMP threads of a rank. + "VLLM_CPU_NUM_OF_RESERVED_CPU": + lambda: int(os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0")), # (CPU backend only) whether to use prepack for MoE layer. This will be # passed to ipex.llm.modules.GatedMLPMOE.
On unsupported CPUs, you might diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 71c964fbfbb5e..27c591e3babd4 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -208,6 +208,9 @@ class CpuPlatform(Platform): # Disable torch async compiling which won't work with daemonic processes os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" + # Share the cpusets list among ranks by spawning process instead + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + # Intel OpenMP setting ld_prealod_str = os.getenv("LD_PRELOAD", "") if "libiomp5.so" in ld_prealod_str: diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py index 0b710b7bc203f..9a35e88120386 100644 --- a/vllm/v1/worker/cpu_worker.py +++ b/vllm/v1/worker/cpu_worker.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import os +from importlib import util from typing import Optional import torch @@ -38,10 +39,14 @@ class CPUWorker(Worker): def init_device(self): # Setup OpenMP threads affinity. omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND - if omp_cpuids == "all": - self.local_omp_cpuid = "all" + self.local_omp_cpuid = "all" + if omp_cpuids == "auto": + self.local_omp_cpuid = self.get_cpus_id_binding_based_on_numa_nodes( + ) else: self.local_omp_cpuid = omp_cpuids.split("|")[self.rank] + + if self.local_omp_cpuid != "all": ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) if ret: logger.info(ret) @@ -99,3 +104,49 @@ class CPUWorker(Worker): assert isinstance(output, ModelRunnerOutput) return output if self.is_driver_worker else None + + def get_cpus_id_binding_based_on_numa_nodes(self) -> str: + """Return CPUs id binding based on NUMA nodes. 
+ """ + rank_to_cpus = self.local_omp_cpuid + # Setup OpenMP thread affinity based on NUMA nodes automatically + world_size = self.vllm_config.parallel_config.world_size + libnuma_found = util.find_spec("numa") is not None + psutil_found = util.find_spec("psutil") is not None + if libnuma_found and psutil_found: + import psutil + from numa import info + cpu_count = psutil.cpu_count(logical=False) + cpus_allow_list = psutil.Process().cpu_affinity() + numa_size = info.get_num_configured_nodes() + cpu_count_per_numa = cpu_count // numa_size + num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU, + cpu_count_per_numa // 2) + + # check allow node_to_cpus list + node_to_cpus = [] + for i in range(numa_size): + node_intersect = set( + info.node_to_cpus(i)).intersection(cpus_allow_list) + if bool(node_intersect): + node_to_cpus.append(list(node_intersect)) + + if world_size > len(node_to_cpus): + logger.error( + "Auto thread-binding failed due to " + "world size: %d is larger than " + "allowed NUMA nodes number: %d. " + "Please try to bind threads manually.", world_size, + len(node_to_cpus)) + else: + end = cpu_count_per_numa - num_of_reserved_cpu + rank_to_cpus_list = node_to_cpus[self.rank][:end] + rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list) + logger.info("auto thread-binding list: %s", rank_to_cpus) + else: + logger.warning( + "Auto thread-binding is not supported due to " + "the lack of package numa and psutil, " + "fallback to no thread-binding. 
To get better performance, " + "please try to manually bind threads.") + return rank_to_cpus diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index b04a9a1eb08d1..9e834befd68ab 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A CPU worker class.""" import os +from importlib import util from typing import Dict, List, Optional, Set, Tuple, Type import torch @@ -156,8 +157,10 @@ class CPUWorker(LocalOrDistributedWorkerBase): # Setup OpenMP threads affinity. omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND - if omp_cpuids == "all": - self.local_omp_cpuid = "all" + self.local_omp_cpuid = "all" + if omp_cpuids == "auto": + self.local_omp_cpuid = self.get_cpus_id_binding_based_on_numa_nodes( + ) else: self.local_omp_cpuid = omp_cpuids.split("|")[rank] @@ -399,3 +402,49 @@ class CPUWorker(LocalOrDistributedWorkerBase): return CPUCacheEngine.get_cache_block_size( self.cache_config.block_size, self.cache_config.cache_dtype, self.model_config, self.parallel_config) + + def get_cpus_id_binding_based_on_numa_nodes(self) -> str: + """Return CPUs id binding based on NUMA nodes. 
+ """ + rank_to_cpus = self.local_omp_cpuid + # Setup OpenMP thread affinity based on NUMA nodes automatically + world_size = self.vllm_config.parallel_config.world_size + libnuma_found = util.find_spec("numa") is not None + psutil_found = util.find_spec("psutil") is not None + if libnuma_found and psutil_found: + import psutil + from numa import info + cpu_count = psutil.cpu_count(logical=False) + cpus_allow_list = psutil.Process().cpu_affinity() + numa_size = info.get_num_configured_nodes() + cpu_count_per_numa = cpu_count // numa_size + num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU, + cpu_count_per_numa // 2) + + # check allow node_to_cpus list + node_to_cpus = [] + for i in range(numa_size): + node_intersect = set( + info.node_to_cpus(i)).intersection(cpus_allow_list) + if bool(node_intersect): + node_to_cpus.append(list(node_intersect)) + + if world_size > len(node_to_cpus): + logger.error( + "Auto thread-binding failed due to " + "world size: %d is larger than " + "allowed NUMA nodes number: %d. " + "Please try to bind threads manually.", world_size, + len(node_to_cpus)) + else: + end = cpu_count_per_numa - num_of_reserved_cpu + rank_to_cpus_list = node_to_cpus[self.rank][:end] + rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list) + logger.info("auto thread-binding list: %s", rank_to_cpus) + else: + logger.warning( + "Auto thread-binding is not supported due to " + "the lack of package numa and psutil, " + "fallback to no thread-binding. To get better performance, " + "please try to manually bind threads.") + return rank_to_cpus