[CPU] V1 support for the CPU backend (#16441)

Li, Jiang 2025-06-04 09:43:01 +08:00 committed by GitHub
parent 52dceb172d
commit 4555143ea7
15 changed files with 465 additions and 40 deletions

View File

@ -6,6 +6,7 @@ set -ex
# allow to bind to different cores
CORE_RANGE=${CORE_RANGE:-48-95}
OMP_CORE_RANGE=${OMP_CORE_RANGE:-48-95}
NUMA_NODE=${NUMA_NODE:-1}
export CMAKE_BUILD_PARALLEL_LEVEL=32
@ -23,10 +24,8 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE
numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu .
# Run the image, setting --shm-size=4g for tensor parallel.
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \
--cpuset-mems="$NUMA_NODE" --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \
--cpuset-mems="$NUMA_NODE" --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
function cpu_tests() {
set -e
@ -56,7 +55,7 @@ function cpu_tests() {
# Run AWQ test
docker exec cpu-test-"$NUMA_NODE" bash -c "
set -e
pytest -s -v \
VLLM_USE_V1=0 pytest -s -v \
tests/quantization/test_ipex_quant.py"
# Run chunked-prefill and prefix-cache test
@ -68,8 +67,6 @@ function cpu_tests() {
# online serving
docker exec cpu-test-"$NUMA_NODE" bash -c "
set -e
export VLLM_CPU_KVCACHE_SPACE=10
export VLLM_CPU_OMP_THREADS_BIND=$1
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
python3 benchmarks/benchmark_serving.py \
@ -89,4 +86,4 @@ function cpu_tests() {
# All of the CPU tests are expected to finish in less than 40 minutes.
export -f cpu_tests
timeout 40m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
timeout 1h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"

View File

@ -40,6 +40,8 @@ This living user guide outlines a few known **important changes and limitations*
| **NVIDIA** | <nobr>🚀 Natively Supported</nobr> |
| **AMD** | <nobr>🚧 WIP</nobr> |
| **TPU** | <nobr>🚧 WIP</nobr> |
| **CPU** | <nobr>🚧 WIP</nobr> |
#### Feature / Model
| Feature / Model | Status |

View File

@ -1,6 +1,9 @@
# Common dependencies
-r common.txt
numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
numba == 0.61.2; python_version > '3.9'
# Dependencies for CPUs
packaging>=24.2
setuptools>=77.0.3,<80.0.0

View File

@ -85,7 +85,10 @@ def test_env(
CpuPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16,
block_size, False)
assert backend.get_name() == "TORCH_SDPA"
if use_v1:
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
else:
assert backend.get_name() == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.current_platform",

View File

@ -87,7 +87,6 @@ AITER_MODEL_LIST = [
pytest.param("bigcode/starcoder2-3b"), # starcoder2
pytest.param(
"TitanML/tiny-mixtral", # mixtral
marks=[pytest.mark.cpu_model],
)
])
@pytest.mark.parametrize("max_tokens", [32])

View File

@ -178,7 +178,7 @@ class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]):
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_kv_len=max_kv_len,
query_start_loc=query_start_loc,
prefill_query_start_loc=query_start_loc,
kv_start_loc=kv_start_loc,
max_decode_seq_len=input_data.max_decode_seq_len,
num_prefills=input_data.num_prefills,
@ -264,8 +264,8 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
key=k,
value=v_padded,
out=output,
seqlen_q=prefill_metadata.query_start_loc,
seqlen_k=prefill_metadata.query_start_loc,
seqlen_q=prefill_metadata.prefill_query_start_loc,
seqlen_k=prefill_metadata.prefill_query_start_loc,
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.max_query_len,
pdropout=0.0,

View File

@ -87,10 +87,13 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
# For chunked prefill only
max_query_len: Optional[int] = None
max_kv_len: Optional[int] = None
query_start_loc: Optional[torch.Tensor] = None
prefill_query_start_loc: Optional[torch.Tensor] = None
kv_start_loc: Optional[torch.Tensor] = None
prefill_block_tables: Optional[torch.Tensor] = None
# For V1 logits index only
query_start_loc: Optional[torch.Tensor] = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None
@ -375,7 +378,7 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_kv_len=max_kv_len,
query_start_loc=query_start_loc,
prefill_query_start_loc=query_start_loc,
kv_start_loc=kv_start_loc,
max_decode_seq_len=input_data.max_decode_seq_len,
num_prefills=input_data.num_prefills,
@ -470,6 +473,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# For warming-up
if attn_metadata is None:
return query
attn_type = self.attn_type
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
@ -537,8 +545,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
output = torch.empty_like(query)
if prefill_meta := attn_metadata.prefill_metadata:
assert attn_metadata.seq_lens is not None
if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore
assert attn_metadata.seq_lens is not None
self._run_sdpa_forward(output,
query,
key,
@ -555,7 +563,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
query[:prefill_meta.num_prefill_tokens, :, :],
key_cache,
value_cache,
prefill_meta.query_start_loc,
prefill_meta.prefill_query_start_loc,
prefill_meta.kv_start_loc,
prefill_meta.max_query_len,
prefill_meta.max_kv_len,

View File

@ -41,11 +41,16 @@ class TorchCompileWrapperWithCustomDispatcher:
# compiling the forward method
backend = vllm_config.compilation_config.init_backend(vllm_config)
options = None
if isinstance(backend, str) and backend == "inductor":
options = get_current_vllm_config(
).compilation_config.inductor_compile_config
compiled_callable = torch.compile(
self.forward,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=backend)
backend=backend,
options=options)
self.compiled_callable = compiled_callable
self.original_code_object = self.__class__.forward.__code__
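The wrapper change above only forwards vLLM's `inductor_compile_config` when the compilation backend is the plain string "inductor"; any other backend gets `options=None`. A minimal standalone sketch of that pattern, assuming PyTorch 2.x (the sample option keys mirror the CPU inductor defaults set later in this commit and are otherwise illustrative):

```python
import torch

def compile_with_optional_options(fn, backend="inductor", inductor_options=None):
    # Only the inductor backend understands an options dict; for any other
    # backend (e.g. "eager") pass options=None, mirroring the diff above.
    options = inductor_options if backend == "inductor" else None
    return torch.compile(fn, fullgraph=False, backend=backend, options=options)

# "dce" and "epilogue_fusion" are torch._inductor.config knobs, also used in
# the CPU platform defaults added by this commit.
compiled = compile_with_optional_options(
    lambda x: x * 2 + 1,
    backend="inductor",
    inductor_options={"dce": True, "epilogue_fusion": True})
print(compiled(torch.ones(4)))
```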

View File

@ -1399,6 +1399,7 @@ class EngineArgs:
"FLASHINFER",
"FLASHINFER_VLLM_V1",
"ROCM_AITER_MLA",
"TORCH_SDPA_VLLM_V1",
]
if (envs.is_set("VLLM_ATTENTION_BACKEND")
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
@ -1431,7 +1432,8 @@ class EngineArgs:
# Non-[CUDA, TPU] may be supported on V1, but off by default for now.
v0_hardware = not any(
(current_platform.is_cuda(), current_platform.is_tpu()))
(current_platform.is_cuda(), current_platform.is_tpu(),
current_platform.is_cpu()))
if v0_hardware and _warn_or_fallback( # noqa: SIM103
current_platform.device_name):
return False

View File

@ -57,7 +57,10 @@ class CpuPlatform(Platform):
logger.info("Using CPU MLA backend.")
return "vllm.attention.backends.cpu_mla.CPUMLABackend"
logger.info("Using Torch SDPA backend.")
return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
if use_v1:
return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
else:
return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
@ -81,6 +84,8 @@ class CpuPlatform(Platform):
if not model_config.enforce_eager:
model_config.enforce_eager = True
model_config.disable_cascade_attn = True
cache_config = vllm_config.cache_config
ipex_available = find_spec("intel_extension_for_pytorch") is not None
@ -128,7 +133,8 @@ class CpuPlatform(Platform):
f" {kv_cache_space}, expect a positive integer value.")
parallel_config = vllm_config.parallel_config
if (parallel_config.distributed_executor_backend is not None
if (parallel_config.world_size > 1
and parallel_config.distributed_executor_backend is not None
and parallel_config.distributed_executor_backend != "mp"):
logger.warning(("%s is not supported on CPU, fallback to mp "
"distributed executor backend."),
@ -141,7 +147,38 @@ class CpuPlatform(Platform):
parallel_config.sd_worker_cls = \
"vllm.worker.cpu_worker.CPUWorker"
else:
parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm.v1.worker.cpu_worker.CPUWorker"
else:
parallel_config.worker_cls = \
"vllm.worker.cpu_worker.CPUWorker"
# Note: workaround for v1 gpu_model_runner
from vllm.config import CompilationLevel
vllm_config.compilation_config.cudagraph_capture_sizes = []
compilation_config = vllm_config.compilation_config
if (envs.VLLM_USE_V1 and vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE):
compilation_config.level = CompilationLevel.DYNAMO_ONCE
compilation_config.backend = "eager"
compilation_config.custom_ops += ["none"]
compilation_config.inductor_compile_config.update({
"dce":
True,
"size_asserts":
False,
"nan_asserts":
False,
"memory_planning":
True,
"epilogue_fusion":
True,
})
if vllm_config.lora_config is not None:
compilation_config.level = CompilationLevel.NO_COMPILATION
assert vllm_config.device_config.device_type == "cpu"
@ -149,6 +186,12 @@ class CpuPlatform(Platform):
# Environment variables for CPU executor
#
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
# Note: to avoid the error 'nthreads cannot be larger than environment
# variable "NUMEXPR_MAX_THREADS" (64)'.
os.environ["NUMEXPR_MAX_THREADS"] = str(len(os.sched_getaffinity(0)))
# Set default threads num for OpenMP parallel
os.environ["OMP_NUM_THREADS"] = str(torch.get_num_threads())
@ -171,13 +214,6 @@ class CpuPlatform(Platform):
# To hint IPEX uses shared memory based AllReduce
os.environ["LOCAL_WORLD_SIZE"] = str(
vllm_config.parallel_config.tensor_parallel_size)
if sys.platform == "darwin" and \
envs.VLLM_WORKER_MULTIPROC_METHOD == "fork":
if os.environ.get('VLLM_WORKER_MULTIPROC_METHOD', None) is None:
logger.warning(
"Default to spawn method on MacOS. If this is not desired,"
" set VLLM_WORKER_MULTIPROC_METHOD to fork explicitly.")
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
if vllm_config.model_config and vllm_config.model_config.use_mla:
logger.info(
@ -204,3 +240,14 @@ class CpuPlatform(Platform):
Get device specific communicator class for distributed communication.
"""
return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator" # noqa
@classmethod
def supports_structured_output(cls) -> bool:
return True
@classmethod
def supports_v1(cls, model_config) -> bool:
"""Returns whether the current platform can support v1 for the supplied
model configuration.
"""
return True
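Taken together, the platform hooks above (the V1 TorchSDPA backend selection, the `vllm.v1.worker.cpu_worker.CPUWorker` class, the compilation-level downgrade, and `supports_v1`) are what let the V1 engine come up on CPU. A minimal usage sketch, not part of the diff: the environment variables are the ones referenced elsewhere in this commit, the model is the one used in the CI script, and the dtype and sampling settings are arbitrary assumptions.

```python
import os

# Select the V1 engine and the CPU backend knobs referenced in this commit.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_CPU_KVCACHE_SPACE"] = "4"        # GiB reserved for the KV cache
os.environ["VLLM_CPU_OMP_THREADS_BIND"] = "all"   # or e.g. "0-31"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", dtype="bfloat16")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```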

View File

@ -0,0 +1,163 @@
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl,
TorchSDPAMetadata)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.ipex_attn import PagedAttention
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_input_batch import InputBatch
class TorchSDPABackend:
accept_output_buffer: bool = False
@staticmethod
def get_name() -> str:
return "TORCH_SDPA_VLLM_V1"
@staticmethod
def get_impl_cls() -> type["TorchSDPABackendImpl"]:
return TorchSDPABackendImpl
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
return TorchSDPAMetadata
@staticmethod
def get_state_cls() -> type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
return TorchSDPAMetadataBuilderV1
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
class TorchSDPAMetadataBuilderV1:
def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
block_table: BlockTable) -> None:
self.runner = runner
self.block_table = block_table
# For reorder
self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
dtype=np.int64)
self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs,
dtype=np.int64)
self.num_prompt_req: int = 0
self.seq_start_loc_cpu = torch.zeros(
runner.max_num_reqs + 1,
dtype=torch.int32,
device="cpu",
)
self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
def reorder_batch(self, input_batch: InputBatch,
scheduler_output: SchedulerOutput) -> bool:
prompt_list_idx = 0
decode_list_idx = 0
for req_index in range(input_batch.num_reqs):
if input_batch.num_computed_tokens_cpu[
req_index] < input_batch.num_prompt_tokens[req_index]:
# prompt stage
self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
prompt_list_idx += 1
else:
# decode stage
self.reorder_decode_req_index_list[decode_list_idx] = req_index
decode_list_idx += 1
assert decode_list_idx + prompt_list_idx == input_batch.num_reqs
# Update prompt requests number
self.num_prompt_req = prompt_list_idx
reorder_req_num = 0
for req_index in range(decode_list_idx):
if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
reorder_req_num += 1
else:
break
if reorder_req_num == 0:
return False
reorder_prompt_list = (
self.reorder_prompt_req_index_list[:prompt_list_idx]
[-reorder_req_num:])
reorder_decode_list = (
self.reorder_decode_req_index_list[:decode_list_idx]
[:reorder_req_num])
assert reorder_decode_list.size == reorder_prompt_list.size
for idx in range(reorder_req_num):
prompt_req_index = reorder_prompt_list[idx].item()
decode_req_index = reorder_decode_list[idx].item()
input_batch.swap_states(prompt_req_index, decode_req_index)
return True
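`reorder_batch()` above moves prompt (prefill) requests in front of decode requests while swapping only the misplaced pairs. A dependency-free sketch of the same partitioning idea on a toy batch; the `InputBatch.swap_states` machinery is replaced by a plain index swap and all names are illustrative.

```python
def reorder_prompts_first(is_prompt: list) -> list:
    """Return the request order after swapping decode requests that sit in
    the prompt region with prompt requests that sit past it."""
    order = list(range(len(is_prompt)))
    prompt_idx = [i for i, p in enumerate(is_prompt) if p]
    decode_idx = [i for i, p in enumerate(is_prompt) if not p]
    num_prompts = len(prompt_idx)
    # Decode requests currently inside the prompt region [0, num_prompts).
    misplaced_decodes = [i for i in decode_idx if i < num_prompts]
    # Pair them with the tail of the prompt list, as the builder above does.
    misplaced_prompts = prompt_idx[-len(misplaced_decodes):] if misplaced_decodes else []
    for p, d in zip(misplaced_prompts, misplaced_decodes):
        order[p], order[d] = order[d], order[p]
    return order

# [prompt, decode, prompt, decode]: the decode at index 1 swaps with the
# prompt at index 2, so prompts occupy positions 0-1 and decodes follow.
print(reorder_prompts_first([True, False, True, False]))  # [0, 2, 1, 3]
```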
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
runner = self.runner
block_table = self.block_table
seq_lens_np = runner.seq_lens_np[:num_reqs]
num_prompt_req = self.num_prompt_req
max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
) if num_prompt_req > 0 else 0
max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item(
) if num_prompt_req < num_reqs else 0
self.seq_start_loc_np[0] = 0
np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
) - num_prefill_tokens
slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].long()
block_table_tensor = block_table.get_device_tensor()
attn_metadata = TorchSDPAMetadata(
num_prefills=num_prompt_req,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
slot_mapping=slot_mapping,
seq_lens_tensor=runner.
seq_lens_cpu[num_prompt_req:num_reqs], # decode
max_decode_seq_len=max_decode_seq_len, # decode
block_tables=block_table_tensor[num_prompt_req:num_reqs], # decode
chunked_prefill=True,
max_query_len=max_query_len,
max_kv_len=max_prefill_seq_len,
prefill_query_start_loc=runner.
query_start_loc_cpu[:num_prompt_req + 1], # prefill
kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
1], # prefill
prefill_block_tables=block_table_tensor[:
num_prompt_req], # prefill
query_start_loc=runner.query_start_loc_cpu[:num_reqs +
1], # for logits index
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
)
return attn_metadata
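`build()` splits the reordered batch at `num_prompt_req`: everything before it is treated as prefill, everything after as decode, and the start-location arrays come from cumulative sums of the per-request lengths. A toy illustration with made-up lengths (variable names follow the builder; nothing here is taken from a real batch):

```python
import numpy as np

seq_lens = np.array([7, 5, 9, 1, 1], dtype=np.int32)    # per-request KV lengths
query_lens = np.array([4, 5, 9, 1, 1], dtype=np.int32)  # tokens scheduled this step
num_prompt_req = 3                                      # first 3 requests are prefills

query_start_loc = np.zeros(len(query_lens) + 1, dtype=np.int32)
np.cumsum(query_lens, out=query_start_loc[1:])          # [0, 4, 9, 18, 19, 20]
seq_start_loc = np.zeros(len(seq_lens) + 1, dtype=np.int32)
np.cumsum(seq_lens, out=seq_start_loc[1:])              # [0, 7, 12, 21, 22, 23]

num_prefill_tokens = int(query_start_loc[num_prompt_req])           # 18
num_decode_tokens = int(query_start_loc[-1]) - num_prefill_tokens   # 2
max_prefill_seq_len = int(seq_lens[:num_prompt_req].max())          # 9
max_decode_seq_len = int(seq_lens[num_prompt_req:].max())           # 1
print(num_prefill_tokens, num_decode_tokens,
      max_prefill_seq_len, max_decode_seq_len)
```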

View File

@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
from contextlib import contextmanager
from typing import Any
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = init_logger(__name__)
class CPUModelRunner(GPUModelRunner):
def __init__(self, vllm_config: VllmConfig, device: torch.device):
super().__init__(vllm_config, device)
assert device == torch.device("cpu")
assert self.speculative_config is None, "spec decode is not supported."
self.use_cuda_graph = False
self.cascade_attn_enabled = False
self._postprocess_tenosrs()
def _postprocess_tenosrs(self) -> None:
# Note: replace device tensors with cpu tensors
def replace_tensor(obj: Any, cpu_attr_name: str,
device_attr_name) -> None:
cpu_tensor = getattr(obj, cpu_attr_name, None)
device_tensor = getattr(obj, device_attr_name, None)
if cpu_tensor is not None and device_tensor is not None:
assert isinstance(cpu_tensor, torch.Tensor)
assert isinstance(device_tensor, torch.Tensor)
setattr(obj, device_attr_name, cpu_tensor)
for k, v in vars(self).items():
if k.endswith("_cpu") and isinstance(v, torch.Tensor):
replace_tensor(self, k, k[:-4])
for k, v in vars(self.input_batch).items():
if k.endswith("_cpu_tensor") and isinstance(v, torch.Tensor):
replace_tensor(self.input_batch, k, k[:-11])
for k, v in vars(self.input_batch.block_table).items():
if k.endswith("_cpu") and isinstance(v, torch.Tensor):
replace_tensor(self.input_batch.block_table, k, k[:-4])
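The post-processing above relies on a naming convention: every persistent device buffer `foo` has a CPU twin named `foo_cpu` (or `foo_cpu_tensor` in the input batch), and on CPU the device attribute is simply re-pointed at the CPU tensor so no host/device copies ever happen. A self-contained sketch of that aliasing with an illustrative class:

```python
import torch

class Buffers:
    def __init__(self):
        # On GPU runners these would be two distinct tensors (host + device).
        self.input_ids_cpu = torch.zeros(8, dtype=torch.int64)
        self.input_ids = torch.zeros(8, dtype=torch.int64)

bufs = Buffers()
# Mirror of the replace_tensor() helper: alias the "device" attribute to its
# CPU twin so that writes through either name hit the same storage.
for name, value in list(vars(bufs).items()):
    if name.endswith("_cpu") and isinstance(value, torch.Tensor):
        setattr(bufs, name[:-4], value)

assert bufs.input_ids is bufs.input_ids_cpu
```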
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
self.model = get_model(vllm_config=self.vllm_config)
if self.lora_config:
self.model = self.load_lora_model(self.model, self.model_config,
self.scheduler_config,
self.lora_config, self.device)
def warming_up_model(self) -> None:
logger.info("Warming up model for the compilation...")
# Only generate graph for the generic shape
self._dummy_run(max(16, self.max_num_reqs))
logger.info("Warming up done.")
def _init_device_properties(self) -> None:
pass
def _sync_device(self) -> None:
pass
@contextmanager
def _set_global_compilation_settings():
import torch._inductor.config
# Note: The CPPGEMM backend requires freezing parameters.
freezing_value = torch._inductor.config.freezing
torch._inductor.config.freezing = True
# Note: workaround for "ValueError: fast mode: can't pickle cyclic objects
# including object type dict"
force_disable_caches = torch._inductor.config.force_disable_caches
torch._inductor.config.force_disable_caches = True
yield
torch._inductor.config.freezing = freezing_value
torch._inductor.config.force_disable_caches = force_disable_caches
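A possible way to exercise the context manager above; the assertions only check its save/override/restore behaviour and assume the `_set_global_compilation_settings` definition from this file is in scope (how the runner actually invokes it is not shown in this hunk).

```python
import torch._inductor.config as inductor_config

before_freezing = inductor_config.freezing
before_caches = inductor_config.force_disable_caches

with _set_global_compilation_settings():
    # Inside the block the CPU-friendly settings are active.
    assert inductor_config.freezing is True
    assert inductor_config.force_disable_caches is True

# Outside the block the previous values are restored.
assert inductor_config.freezing == before_freezing
assert inductor_config.force_disable_caches == before_caches
```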

View File

@ -0,0 +1,101 @@
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Optional
import torch
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import IntermediateTensors
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_worker import (Worker,
init_worker_distributed_environment)
logger = init_logger(__name__)
class CPUWorker(Worker):
def __init__(self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False):
super().__init__(vllm_config,
local_rank,
rank,
distributed_init_method,
is_driver_worker=is_driver_worker)
self.parallel_config.disable_custom_all_reduce = True
def init_device(self):
# Setup OpenMP threads affinity.
omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
if omp_cpuids == "all":
self.local_omp_cpuid = "all"
else:
self.local_omp_cpuid = omp_cpuids.split("|")[self.rank]
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
if ret:
logger.info(ret)
# Note: unique identifier for creating allreduce shared memory
os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(
":")[-1]
# Initialize the distributed environment.
init_worker_distributed_environment(self.vllm_config, self.rank,
self.distributed_init_method,
self.local_rank, "gloo")
# Set random seed.
set_random_seed(self.model_config.seed)
# Construct the model runner
self.model_runner: CPUModelRunner = CPUModelRunner(
self.vllm_config, torch.device("cpu"))
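`VLLM_CPU_OMP_THREADS_BIND` is either `all` or a `|`-separated list of CPU-id ranges, one entry per rank; `init_device()` above picks its own entry by rank before calling the thread-binding op. A standalone sketch of just the parsing step (the example value is an assumption for a 2-rank run):

```python
import os

def local_omp_cpuid(rank: int) -> str:
    """Pick this rank's OpenMP CPU set from VLLM_CPU_OMP_THREADS_BIND."""
    omp_cpuids = os.environ.get("VLLM_CPU_OMP_THREADS_BIND", "all")
    return "all" if omp_cpuids == "all" else omp_cpuids.split("|")[rank]

os.environ["VLLM_CPU_OMP_THREADS_BIND"] = "0-31|32-63"
print(local_omp_cpuid(0))  # 0-31
print(local_omp_cpuid(1))  # 32-63
```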
def sleep(self, level: int = 1) -> None:
logger.warning("sleep mode is not supported on CPU, ignore it.")
pass
def wake_up(self, tags: Optional[list[str]] = None) -> None:
logger.warning("sleep mode is not supported on CPU, ignore it.")
pass
def determine_available_memory(self) -> int:
return self.cache_config.cpu_kvcache_space_bytes # type: ignore
def compile_or_warm_up_model(self) -> None:
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
self.model_runner.warming_up_model()
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
if not get_pp_group().is_last_rank:
assert isinstance(output, IntermediateTensors)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
return None
assert isinstance(output, ModelRunnerOutput)
return output if self.is_driver_worker else None

View File

@ -5,7 +5,7 @@ import copy
import gc
import time
import weakref
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import numpy as np
import torch
@ -38,7 +38,6 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
check_use_alibi, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
@ -203,8 +202,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.vllm_config.compilation_config.cudagraph_capture_sizes))
# Cache the device properties.
self.device_properties = torch.cuda.get_device_properties(self.device)
self.num_sms = self.device_properties.multi_processor_count
self._init_device_properties()
# Persistent buffers for CUDA graphs.
self.input_ids = torch.zeros(self.max_num_tokens,
@ -315,6 +313,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.input_batch, scheduler_output)
return batch_reordered
# Note: used for model runner override.
def _init_device_properties(self) -> None:
"""Initialize attributes from torch.cuda.get_device_properties
"""
self.device_properties = torch.cuda.get_device_properties(self.device)
self.num_sms = self.device_properties.multi_processor_count
# Note: used for model runner override.
def _sync_device(self) -> None:
torch.cuda.synchronize()
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
"""Update the cached states and the persistent batch with the scheduler
output.
@ -538,8 +547,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
) -> tuple[dict[str, FlashAttentionMetadata], torch.Tensor,
Optional[SpecDecodeMetadata]]:
) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata]]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
@ -652,7 +660,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, seq_lens=seq_lens)
attn_metadata: dict[str, FlashAttentionMetadata] = {}
attn_metadata: dict[str, Any] = {}
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
@ -1710,7 +1718,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Must synchronize the non-blocking GPU->CPU transfers.
if prompt_logprobs_dict:
torch.cuda.synchronize()
self._sync_device()
return prompt_logprobs_dict
@ -1740,7 +1748,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype=np.int32)
if skip_attn:
attn_metadata: Optional[dict[str, FlashAttentionMetadata]] = None
attn_metadata: Optional[dict[str, Any]] = None
else:
query_start_loc = self.query_start_loc[:num_reqs + 1]
seq_lens = self.seq_lens[:num_reqs]
@ -1964,7 +1972,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampler_output = self._dummy_sampler_run(hidden_states)
else:
sampler_output = None
torch.cuda.synchronize()
self._sync_device()
del hidden_states, sampler_output
self.encoder_cache.clear()
gc.collect()

View File

@ -342,13 +342,14 @@ def init_worker_distributed_environment(
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
backend: str = "nccl",
) -> None:
"""Initialize the distributed environment."""
parallel_config = vllm_config.parallel_config
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank)
distributed_init_method, local_rank, backend)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)