diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 8b8f0e8c6578d..0f734763f13fd 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -25,8 +25,8 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu . # Run the image, setting --shm-size=4g for tensor parallel. -docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" -docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2 +docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" +docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2 function cpu_tests() { set -e @@ -89,17 +89,33 @@ function cpu_tests() { pytest -x -s -v \ tests/lora/test_qwen2vl.py" - # online serving + # online serving: tp+pp docker exec cpu-test-"$NUMA_NODE" bash -c ' set -e VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS VLLM_CPU_SGL_KERNEL=1 vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 & + server_pid=$! timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1 vllm bench serve \ --backend vllm \ --dataset-name random \ --model meta-llama/Llama-3.2-3B-Instruct \ --num-prompts 20 \ - --endpoint /v1/completions' + --endpoint /v1/completions + kill -s SIGTERM $server_pid &' + + # online serving: tp+dp + docker exec cpu-test-"$NUMA_NODE" bash -c ' + set -e + VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS VLLM_CPU_SGL_KERNEL=1 vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 & + server_pid=$! + timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1 + vllm bench serve \ + --backend vllm \ + --dataset-name random \ + --model meta-llama/Llama-3.2-3B-Instruct \ + --num-prompts 20 \ + --endpoint /v1/completions + kill -s SIGTERM $server_pid &' } # All of CPU tests are expected to be finished less than 40 mins. diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index e76ec35e1edcb..7f0ecb2bc0b74 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -96,6 +96,7 @@ Currently, there are no pre-built CPU wheels. - `VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`. - `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads, can be set as CPU id lists or `auto` (by default). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node respectively. - `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `None`. If the value is not set and use `auto` thread binding, no CPU will be reserved for `world_size == 1`, 1 CPU per rank will be reserved for `world_size > 1`. +- `CPU_VISIBLE_MEMORY_NODES`: specify visible NUMA memory nodes for vLLM CPU workers, similar to ```CUDA_VISIBLE_DEVICES```. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. The variable provides more control for the auto thread-binding feature, such as masking nodes and changing nodes binding sequence. - `VLLM_CPU_MOE_PREPACK` (x86 only): whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False). - `VLLM_CPU_SGL_KERNEL` (x86 only, Experimental): whether to use small-batch optimized kernels for linear layer and MoE layer, especially for low-latency requirements like online serving. The kernels require AMX instruction set, BFloat16 weight type and weight shapes divisible by 32. Default is `0` (False). @@ -179,7 +180,7 @@ Inference batch size is an important parameter for the performance. Larger batch - Offline Inference: `256 * world_size` - Online Serving: `128 * world_size` -vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more details of tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommend to use TP and PP together if there are enough CPU sockets and memory nodes. +vLLM CPU supports data parallel (DP), tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more details of tuning DP, TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommend to use DP, TP and PP together if there are enough CPU sockets and memory nodes. ### Which quantization configs does vLLM CPU support? diff --git a/docs/getting_started/installation/cpu/x86.inc.md b/docs/getting_started/installation/cpu/x86.inc.md index 6dc6f94249c34..f7af259ace628 100644 --- a/docs/getting_started/installation/cpu/x86.inc.md +++ b/docs/getting_started/installation/cpu/x86.inc.md @@ -43,7 +43,7 @@ docker build -f docker/Dockerfile.cpu \ # Launching OpenAI server docker run --rm \ - --privileged=true \ + --security-opt seccomp=unconfined \ --shm-size=4g \ -p 8000:8000 \ -e VLLM_CPU_KVCACHE_SPACE= \ diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 5686fae5cd7d1..12d5e0bf08652 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -69,6 +69,7 @@ class CpuPlatform(Platform): device_type: str = "cpu" dispatch_key: str = "CPU" dist_backend: str = "gloo" + device_control_env_var = "CPU_VISIBLE_MEMORY_NODES" @property def supported_dtypes(self) -> list[torch.dtype]: @@ -297,6 +298,13 @@ class CpuPlatform(Platform): allowed_numa_nodes.add(x.numa_node) # type: ignore allowed_numa_nodes_list = sorted(allowed_numa_nodes) + env_key = CpuPlatform.device_control_env_var + if (env_key in os.environ and os.environ[env_key] != ""): + visible_nodes = [int(s) for s in os.environ[env_key].split(',')] + allowed_numa_nodes_list = [ + x for x in visible_nodes if x in allowed_cpu_id_list + ] + return allowed_numa_nodes_list, logical_cpu_list @classmethod diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index 7d0726112704a..226d7792a42f7 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import contextmanager -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional import torch import torch.nn as nn @@ -113,6 +113,11 @@ class CPUModelRunner(GPUModelRunner): def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: return sampled_token_ids.tolist() + def get_dp_padding(self, + num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: + # Note: For CPU backend, dp padding is not required for now. + return 0, None + @contextmanager def _torch_cuda_wrapper(): diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py index be78597926e09..b87c4fe09bb90 100644 --- a/vllm/v1/worker/cpu_worker.py +++ b/vllm/v1/worker/cpu_worker.py @@ -55,7 +55,14 @@ class CPUWorker(Worker): else: self.local_omp_cpuid = "all" else: - self.local_omp_cpuid = omp_cpuids.split("|")[self.rank] + local_dp_rank = self.parallel_config.data_parallel_rank_local + omp_cpuids = omp_cpuids.split("|") + if local_dp_rank is not None: + world_size = self.parallel_config.world_size + omp_cpuids = omp_cpuids[local_dp_rank * + world_size:(local_dp_rank + 1) * + world_size] + self.local_omp_cpuid = omp_cpuids[self.rank] if self.local_omp_cpuid != "all": ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) @@ -162,7 +169,9 @@ class CPUWorker(Worker): # Reserve CPUs for other processes reserve_cpu_num = envs.VLLM_CPU_NUM_OF_RESERVED_CPU if reserve_cpu_num is None: - reserve_cpu_num = 1 if self.parallel_config.world_size > 1 else 0 + need_reserve = (self.parallel_config.world_size > 1 or + self.parallel_config.data_parallel_size_local > 1) + reserve_cpu_num = 1 if need_reserve else 0 assert len(logical_cpu_list) > reserve_cpu_num, ( f"VLLM_CPU_NUM_OF_RESERVED_CPU ({reserve_cpu_num}) " f"should less than {len(logical_cpu_list)}.")