[CPU] Enable data parallel for CPU backend (#23903)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
Li, Jiang committed on 2025-08-29 17:19:58 +08:00 (committed by GitHub)
commit ad39106b16, parent 2554b27baa
6 changed files with 48 additions and 9 deletions


@@ -25,8 +25,8 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE
numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu .
# Run the image, setting --shm-size=4g for tensor parallel.
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
function cpu_tests() {
set -e
@@ -89,17 +89,33 @@ function cpu_tests() {
pytest -x -s -v \
tests/lora/test_qwen2vl.py"
# online serving
# online serving: tp+pp
docker exec cpu-test-"$NUMA_NODE" bash -c '
set -e
VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS VLLM_CPU_SGL_KERNEL=1 vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 &
server_pid=$!
timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1
vllm bench serve \
--backend vllm \
--dataset-name random \
--model meta-llama/Llama-3.2-3B-Instruct \
--num-prompts 20 \
--endpoint /v1/completions'
--endpoint /v1/completions
kill -s SIGTERM $server_pid &'
# online serving: tp+dp
docker exec cpu-test-"$NUMA_NODE" bash -c '
set -e
VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS VLLM_CPU_SGL_KERNEL=1 vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 &
server_pid=$!
timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1
vllm bench serve \
--backend vllm \
--dataset-name random \
--model meta-llama/Llama-3.2-3B-Instruct \
--num-prompts 20 \
--endpoint /v1/completions
kill -s SIGTERM $server_pid &'
}
# All CPU tests are expected to finish in less than 40 mins.


@@ -96,6 +96,7 @@ Currently, there are no pre-built CPU wheels.
- `VLLM_CPU_KVCACHE_SPACE`: specify the KV cache size (e.g., `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB of space for the KV cache); a larger value allows vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`.
- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads; can be set to CPU ID lists or `auto` (the default). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means 32 OpenMP threads are bound to CPU cores 0-31. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there are 2 tensor parallel processes: the 32 OpenMP threads of rank 0 are bound to CPU cores 0-31, and the OpenMP threads of rank 1 are bound to CPU cores 32-63 (see the sketch after this list). When set to `auto`, the OpenMP threads of each rank are bound to the CPU cores of each NUMA node respectively.
- `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when `VLLM_CPU_OMP_THREADS_BIND` is set to `auto`. Default value is `None`. If the value is not set and `auto` thread binding is used, no CPU is reserved when `world_size == 1`, and 1 CPU per rank is reserved when `world_size > 1`.
- `CPU_VISIBLE_MEMORY_NODES`: specify the visible NUMA memory nodes for vLLM CPU workers, similar to `CUDA_VISIBLE_DEVICES`. The variable only takes effect when `VLLM_CPU_OMP_THREADS_BIND` is set to `auto`. It provides finer control over the auto thread-binding feature, such as masking nodes and changing the node binding order.
- `VLLM_CPU_MOE_PREPACK` (x86 only): whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False).
- `VLLM_CPU_SGL_KERNEL` (x86 only, Experimental): whether to use small-batch optimized kernels for linear layer and MoE layer, especially for low-latency requirements like online serving. The kernels require AMX instruction set, BFloat16 weight type and weight shapes divisible by 32. Default is `0` (False).
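As a point of reference, a minimal sketch of how a `|`-separated `VLLM_CPU_OMP_THREADS_BIND` value is consumed one entry per rank (the binding string and rank count here are hypothetical; the real handling lives in the CPU worker change further down):

```python
# Illustrative sketch only: map ranks to core ranges from a '|'-separated binding.
binding = "0-31|32-63"              # hypothetical 2-rank binding
per_rank_cores = binding.split("|")
for rank, cores in enumerate(per_rank_cores):
    print(f"rank {rank}: OpenMP threads bound to cores {cores}")
# rank 0: OpenMP threads bound to cores 0-31
# rank 1: OpenMP threads bound to cores 32-63
```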
@@ -179,7 +180,7 @@ Inference batch size is an important parameter for the performance. Larger batch
- Offline Inference: `256 * world_size`
- Online Serving: `128 * world_size`
vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more details of tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommend to use TP and PP together if there are enough CPU sockets and memory nodes.
vLLM CPU supports data parallel (DP), tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more details on tuning DP, TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommended to use DP, TP and PP together if there are enough CPU sockets and memory nodes.
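For illustration only, an offline sketch of combining TP and DP on a multi-socket host, assuming the `LLM` entrypoint forwards `tensor_parallel_size` and `data_parallel_size` to the engine arguments (this diff itself only exercises the online `vllm serve ... -tp=2 -dp=2` path shown in the CI script above):

```python
# Hypothetical offline sketch; the model name is reused from the CI test above.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-3B-Instruct",
    tensor_parallel_size=2,   # assumption: one TP group spanning two NUMA nodes
    data_parallel_size=2,     # assumption: two DP replicas
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```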
### Which quantization configs does vLLM CPU support?


@@ -43,7 +43,7 @@ docker build -f docker/Dockerfile.cpu \
# Launching OpenAI server
docker run --rm \
--privileged=true \
--security-opt seccomp=unconfined \
--shm-size=4g \
-p 8000:8000 \
-e VLLM_CPU_KVCACHE_SPACE=<KV cache space> \


@@ -69,6 +69,7 @@ class CpuPlatform(Platform):
device_type: str = "cpu"
dispatch_key: str = "CPU"
dist_backend: str = "gloo"
device_control_env_var = "CPU_VISIBLE_MEMORY_NODES"
@property
def supported_dtypes(self) -> list[torch.dtype]:
@@ -297,6 +298,13 @@ class CpuPlatform(Platform):
allowed_numa_nodes.add(x.numa_node) # type: ignore
allowed_numa_nodes_list = sorted(allowed_numa_nodes)
env_key = CpuPlatform.device_control_env_var
if (env_key in os.environ and os.environ[env_key] != ""):
visible_nodes = [int(s) for s in os.environ[env_key].split(',')]
allowed_numa_nodes_list = [
x for x in visible_nodes if x in allowed_cpu_id_list
]
return allowed_numa_nodes_list, logical_cpu_list
@classmethod
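A minimal sketch of the masking behavior added above, assuming a host whose detected NUMA nodes are 0-3 (the real list comes from the platform's topology probing):

```python
# Illustrative sketch only: filter detected NUMA nodes by CPU_VISIBLE_MEMORY_NODES.
import os

detected_numa_nodes = [0, 1, 2, 3]                       # hypothetical detected nodes
value = os.environ.get("CPU_VISIBLE_MEMORY_NODES", "")   # e.g. "2,0"
if value:
    visible = [int(s) for s in value.split(",")]
    # Keep only nodes that are both visible and detected, in the order of the
    # variable, which allows masking nodes and reordering the binding sequence.
    detected_numa_nodes = [n for n in visible if n in detected_numa_nodes]
print(detected_numa_nodes)                               # "2,0" -> [2, 0]
```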


@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional
import torch
import torch.nn as nn
@@ -113,6 +113,11 @@ class CPUModelRunner(GPUModelRunner):
def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
return sampled_token_ids.tolist()
def get_dp_padding(self,
num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
# Note: For CPU backend, dp padding is not required for now.
return 0, None
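For context, a hedged sketch of how the return value is typically consumed, assuming the base runner's pattern of padding the local token count before size-synchronized execution (on CPU the override is a no-op):

```python
# Hypothetical usage sketch: `runner` stands in for a CPUModelRunner instance.
num_pad, num_tokens_across_dp = runner.get_dp_padding(num_tokens)
num_tokens += num_pad   # CPU backend: num_pad == 0 and num_tokens_across_dp is None
```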
@contextmanager
def _torch_cuda_wrapper():


@@ -55,7 +55,14 @@ class CPUWorker(Worker):
else:
self.local_omp_cpuid = "all"
else:
self.local_omp_cpuid = omp_cpuids.split("|")[self.rank]
local_dp_rank = self.parallel_config.data_parallel_rank_local
omp_cpuids = omp_cpuids.split("|")
if local_dp_rank is not None:
world_size = self.parallel_config.world_size
omp_cpuids = omp_cpuids[local_dp_rank *
world_size:(local_dp_rank + 1) *
world_size]
self.local_omp_cpuid = omp_cpuids[self.rank]
if self.local_omp_cpuid != "all":
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
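A worked example of the new slicing, assuming two local DP replicas with `world_size == 2` (e.g. TP=2) and a hypothetical four-entry binding string:

```python
# Illustrative sketch only: per-worker core selection with DP enabled.
bindings = "0-15|16-31|32-47|48-63".split("|")
world_size = 2                          # TP * PP ranks per DP replica
for local_dp_rank in range(2):          # hypothetical local DP ranks
    group = bindings[local_dp_rank * world_size:(local_dp_rank + 1) * world_size]
    for rank, cores in enumerate(group):
        print(f"dp_rank={local_dp_rank} rank={rank} -> cores {cores}")
# dp_rank=0 rank=0 -> cores 0-15
# dp_rank=0 rank=1 -> cores 16-31
# dp_rank=1 rank=0 -> cores 32-47
# dp_rank=1 rank=1 -> cores 48-63
```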
@@ -162,7 +169,9 @@ class CPUWorker(Worker):
# Reserve CPUs for other processes
reserve_cpu_num = envs.VLLM_CPU_NUM_OF_RESERVED_CPU
if reserve_cpu_num is None:
reserve_cpu_num = 1 if self.parallel_config.world_size > 1 else 0
need_reserve = (self.parallel_config.world_size > 1 or
self.parallel_config.data_parallel_size_local > 1)
reserve_cpu_num = 1 if need_reserve else 0
assert len(logical_cpu_list) > reserve_cpu_num, (
f"VLLM_CPU_NUM_OF_RESERVED_CPU ({reserve_cpu_num}) "
f"should less than {len(logical_cpu_list)}.")