Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 22:55:51 +08:00)
[CPU] V1 support for the CPU backend (#16441)

parent 52dceb172d
commit 4555143ea7
@@ -6,6 +6,7 @@ set -ex
 # allow to bind to different cores
 CORE_RANGE=${CORE_RANGE:-48-95}
+OMP_CORE_RANGE=${OMP_CORE_RANGE:-48-95}
 NUMA_NODE=${NUMA_NODE:-1}
 
 export CMAKE_BUILD_PARALLEL_LEVEL=32
@@ -23,10 +24,8 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE
 numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu .
 
 # Run the image, setting --shm-size=4g for tensor parallel.
-docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \
-  --cpuset-mems="$NUMA_NODE" --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
-docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \
-  --cpuset-mems="$NUMA_NODE" --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
+docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
+docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
 
 function cpu_tests() {
   set -e
@@ -56,7 +55,7 @@ function cpu_tests() {
   # Run AWQ test
   docker exec cpu-test-"$NUMA_NODE" bash -c "
     set -e
-    pytest -s -v \
+    VLLM_USE_V1=0 pytest -s -v \
     tests/quantization/test_ipex_quant.py"
 
   # Run chunked-prefill and prefix-cache test
@@ -68,8 +67,6 @@ function cpu_tests() {
   # online serving
   docker exec cpu-test-"$NUMA_NODE" bash -c "
     set -e
-    export VLLM_CPU_KVCACHE_SPACE=10
-    export VLLM_CPU_OMP_THREADS_BIND=$1
     python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
     timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
     python3 benchmarks/benchmark_serving.py \
@@ -89,4 +86,4 @@ function cpu_tests() {
 
 # All of CPU tests are expected to be finished less than 40 mins.
 export -f cpu_tests
-timeout 40m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
+timeout 1h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
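The container environment above drives the CPU backend entirely through environment variables. A minimal sketch of the same settings applied from Python; the values mirror the CI script and are examples, not defaults:

import os

# KV cache budget in GiB, as set via --env VLLM_CPU_KVCACHE_SPACE=4 above.
os.environ["VLLM_CPU_KVCACHE_SPACE"] = "4"
# Cores used by the OpenMP worker threads, as set via VLLM_CPU_OMP_THREADS_BIND.
os.environ["VLLM_CPU_OMP_THREADS_BIND"] = "48-95"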
@@ -40,6 +40,8 @@ This living user guide outlines a few known **important changes and limitations**
 | **NVIDIA** | <nobr>🚀 Natively Supported</nobr> |
 | **AMD** | <nobr>🚧 WIP</nobr> |
 | **TPU** | <nobr>🚧 WIP</nobr> |
+| **CPU** | <nobr>🚧 WIP</nobr> |
 
 #### Feature / Model
 
 | Feature / Model | Status |
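With CPU now listed as a work-in-progress V1 target, opting in looks like any other V1 platform: set VLLM_USE_V1 before constructing the engine. A minimal sketch, assuming a CPU build of vLLM is installed; the model name is only an example taken from the CI script above:

import os

os.environ["VLLM_USE_V1"] = "1"  # opt in to the V1 engine on CPU

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)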
@@ -1,6 +1,9 @@
 # Common dependencies
 -r common.txt
 
+numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
+numba == 0.61.2; python_version > '3.9'
+
 # Dependencies for CPUs
 packaging>=24.2
 setuptools>=77.0.3,<80.0.0
@@ -85,7 +85,10 @@ def test_env(
                        CpuPlatform()):
             backend = get_attn_backend(16, torch.float16, torch.float16,
                                        block_size, False)
-            assert backend.get_name() == "TORCH_SDPA"
+            if use_v1:
+                assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
+            else:
+                assert backend.get_name() == "TORCH_SDPA"
 
     elif device == "hip":
         with patch("vllm.attention.selector.current_platform",
@@ -87,7 +87,6 @@ AITER_MODEL_LIST = [
     pytest.param("bigcode/starcoder2-3b"),  # starcoder2
     pytest.param(
         "TitanML/tiny-mixtral",  # mixtral
-        marks=[pytest.mark.cpu_model],
     )
 ])
 @pytest.mark.parametrize("max_tokens", [32])
@@ -178,7 +178,7 @@ class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]):
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=max_query_len,
             max_kv_len=max_kv_len,
-            query_start_loc=query_start_loc,
+            prefill_query_start_loc=query_start_loc,
             kv_start_loc=kv_start_loc,
             max_decode_seq_len=input_data.max_decode_seq_len,
             num_prefills=input_data.num_prefills,
@@ -264,8 +264,8 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
             key=k,
             value=v_padded,
             out=output,
-            seqlen_q=prefill_metadata.query_start_loc,
-            seqlen_k=prefill_metadata.query_start_loc,
+            seqlen_q=prefill_metadata.prefill_query_start_loc,
+            seqlen_k=prefill_metadata.prefill_query_start_loc,
             max_seqlen_q=prefill_metadata.max_query_len,
             max_seqlen_k=prefill_metadata.max_query_len,
             pdropout=0.0,
@@ -87,10 +87,13 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
     # For chunked prefill only
     max_query_len: Optional[int] = None
     max_kv_len: Optional[int] = None
-    query_start_loc: Optional[torch.Tensor] = None
+    prefill_query_start_loc: Optional[torch.Tensor] = None
     kv_start_loc: Optional[torch.Tensor] = None
     prefill_block_tables: Optional[torch.Tensor] = None
 
+    # For V1 logits index only
+    query_start_loc: Optional[torch.Tensor] = None
+
     # Begin encoder attn & enc/dec cross-attn fields...
     # Encoder sequence lengths representation
     encoder_seq_lens: Optional[List[int]] = None
@@ -375,7 +378,7 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=max_query_len,
             max_kv_len=max_kv_len,
-            query_start_loc=query_start_loc,
+            prefill_query_start_loc=query_start_loc,
             kv_start_loc=kv_start_loc,
             max_decode_seq_len=input_data.max_decode_seq_len,
             num_prefills=input_data.num_prefills,
@@ -470,6 +473,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        # For warming-up
+        if attn_metadata is None:
+            return query
+
         attn_type = self.attn_type
         if (attn_type == AttentionType.ENCODER
                 and (not attn_metadata.is_all_encoder_attn_metadata_set)):
@@ -537,8 +545,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
 
         output = torch.empty_like(query)
         if prefill_meta := attn_metadata.prefill_metadata:
-            assert attn_metadata.seq_lens is not None
             if not prefill_meta.prefill_metadata.chunked_prefill:  # type: ignore
+                assert attn_metadata.seq_lens is not None
                 self._run_sdpa_forward(output,
                                        query,
                                        key,
@@ -555,7 +563,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
                     query[:prefill_meta.num_prefill_tokens, :, :],
                     key_cache,
                     value_cache,
-                    prefill_meta.query_start_loc,
+                    prefill_meta.prefill_query_start_loc,
                     prefill_meta.kv_start_loc,
                     prefill_meta.max_query_len,
                     prefill_meta.max_kv_len,
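The metadata now carries two start-location tensors: prefill_query_start_loc feeds the chunked-prefill kernels, while query_start_loc spans the whole batch and is used only for V1 logits indexing. Both are ordinary cumulative-offset arrays; a small illustrative sketch with made-up numbers:

import numpy as np

query_lens = np.array([5, 3, 1, 1], dtype=np.int32)  # two prefills, two decodes
start_loc = np.zeros(len(query_lens) + 1, dtype=np.int32)
start_loc[1:] = np.cumsum(query_lens)
print(start_loc)  # [ 0  5  8  9 10]: request i owns tokens start_loc[i]:start_loc[i+1]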
@@ -41,11 +41,16 @@ class TorchCompileWrapperWithCustomDispatcher:
             # compiling the forward method
 
             backend = vllm_config.compilation_config.init_backend(vllm_config)
+            options = None
+            if isinstance(backend, str) and backend == "inductor":
+                options = get_current_vllm_config(
+                ).compilation_config.inductor_compile_config
 
             compiled_callable = torch.compile(
                 self.forward,
                 fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
-                backend=backend)
+                backend=backend,
+                options=options)
 
         self.compiled_callable = compiled_callable
         self.original_code_object = self.__class__.forward.__code__
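The wrapper change threads the inductor_compile_config dict into torch.compile only when the backend is inductor, since that options dict is inductor-specific. A standalone sketch of the pattern; the function and argument names here are illustrative, not vLLM API:

from typing import Optional

import torch

def compile_forward(forward_fn, backend: str,
                    inductor_options: Optional[dict] = None):
    # Pass the options dict only when the inductor backend is in use;
    # other backends (e.g. "eager") are handed options=None.
    options = inductor_options if backend == "inductor" else None
    return torch.compile(forward_fn, backend=backend, options=options)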
@@ -1399,6 +1399,7 @@ class EngineArgs:
             "FLASHINFER",
             "FLASHINFER_VLLM_V1",
             "ROCM_AITER_MLA",
+            "TORCH_SDPA_VLLM_V1",
         ]
         if (envs.is_set("VLLM_ATTENTION_BACKEND")
                 and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
@@ -1431,7 +1432,8 @@ class EngineArgs:
 
         # Non-[CUDA, TPU] may be supported on V1, but off by default for now.
         v0_hardware = not any(
-            (current_platform.is_cuda(), current_platform.is_tpu()))
+            (current_platform.is_cuda(), current_platform.is_tpu(),
+             current_platform.is_cpu()))
         if v0_hardware and _warn_or_fallback(  # noqa: SIM103
                 current_platform.device_name):
             return False
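With TORCH_SDPA_VLLM_V1 added to the V1 backend allow-list, explicitly pinning the CPU attention backend no longer trips the V1 fallback check. A hedged usage sketch using environment variables only; no new APIs are assumed:

import os

os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "TORCH_SDPA_VLLM_V1"
# ...then construct the engine as usual; the allow-list check above accepts this value.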
@@ -57,7 +57,10 @@ class CpuPlatform(Platform):
             logger.info("Using CPU MLA backend.")
             return "vllm.attention.backends.cpu_mla.CPUMLABackend"
         logger.info("Using Torch SDPA backend.")
-        return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
+        if use_v1:
+            return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
+        else:
+            return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
 
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
@@ -81,6 +84,8 @@ class CpuPlatform(Platform):
         if not model_config.enforce_eager:
             model_config.enforce_eager = True
 
+        model_config.disable_cascade_attn = True
+
         cache_config = vllm_config.cache_config
 
         ipex_available = find_spec("intel_extension_for_pytorch") is not None
@@ -128,7 +133,8 @@ class CpuPlatform(Platform):
                 f" {kv_cache_space}, expect a positive integer value.")
 
         parallel_config = vllm_config.parallel_config
-        if (parallel_config.distributed_executor_backend is not None
+        if (parallel_config.world_size > 1
+                and parallel_config.distributed_executor_backend is not None
                 and parallel_config.distributed_executor_backend != "mp"):
             logger.warning(("%s is not supported on CPU, fallback to mp "
                             "distributed executor backend."),
@@ -141,7 +147,38 @@ class CpuPlatform(Platform):
                 parallel_config.sd_worker_cls = \
                     "vllm.worker.cpu_worker.CPUWorker"
             else:
-                parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
+                if envs.VLLM_USE_V1:
+                    parallel_config.worker_cls = \
+                        "vllm.v1.worker.cpu_worker.CPUWorker"
+                else:
+                    parallel_config.worker_cls = \
+                        "vllm.worker.cpu_worker.CPUWorker"
+
+        # Note: workaround for v1 gpu_model_runner
+        from vllm.config import CompilationLevel
+        vllm_config.compilation_config.cudagraph_capture_sizes = []
+
+        compilation_config = vllm_config.compilation_config
+        if (envs.VLLM_USE_V1 and vllm_config.compilation_config.level
+                == CompilationLevel.PIECEWISE):
+            compilation_config.level = CompilationLevel.DYNAMO_ONCE
+            compilation_config.backend = "eager"
+            compilation_config.custom_ops += ["none"]
+            compilation_config.inductor_compile_config.update({
+                "dce":
+                True,
+                "size_asserts":
+                False,
+                "nan_asserts":
+                False,
+                "memory_planning":
+                True,
+                "epilogue_fusion":
+                True,
+            })
+
+        if vllm_config.lora_config is not None:
+            compilation_config.level = CompilationLevel.NO_COMPILATION
 
         assert vllm_config.device_config.device_type == "cpu"
@@ -149,6 +186,12 @@ class CpuPlatform(Platform):
         # Environment variables for CPU executor
         #
 
+        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+        # Note: to avoid the error 'nthreads cannot be larger than environment
+        # variable "NUMEXPR_MAX_THREADS" (64)'.
+        os.environ["NUMEXPR_MAX_THREADS"] = str(len(os.sched_getaffinity(0)))
+
         # Set default threads num for OpenMP parallel
         os.environ["OMP_NUM_THREADS"] = str(torch.get_num_threads())
@@ -171,13 +214,6 @@ class CpuPlatform(Platform):
         # To hint IPEX uses shared memory based AllReduce
         os.environ["LOCAL_WORLD_SIZE"] = str(
             vllm_config.parallel_config.tensor_parallel_size)
-        if sys.platform == "darwin" and \
-                envs.VLLM_WORKER_MULTIPROC_METHOD == "fork":
-            if os.environ.get('VLLM_WORKER_MULTIPROC_METHOD', None) is None:
-                logger.warning(
-                    "Default to spawn method on MacOS. If this is not desired,"
-                    " set VLLM_WORKER_MULTIPROC_METHOD to fork explicitly.")
-            os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
 
         if vllm_config.model_config and vllm_config.model_config.use_mla:
             logger.info(
@@ -204,3 +240,14 @@ class CpuPlatform(Platform):
         Get device specific communicator class for distributed communication.
         """
         return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"  # noqa
+
+    @classmethod
+    def supports_structured_output(cls) -> bool:
+        return True
+
+    @classmethod
+    def supports_v1(cls, model_config) -> bool:
+        """Returns whether the current platform can support v1 for the supplied
+        model configuration.
+        """
+        return True
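The platform only records the worker class as a dotted path ("vllm.v1.worker.cpu_worker.CPUWorker" when VLLM_USE_V1 is set); the executor resolves it at runtime. A minimal sketch of that resolution idea using importlib; vLLM has its own helper for this, so the function below is illustrative rather than the actual API:

import importlib

def resolve_worker_cls(qualname: str):
    # "vllm.v1.worker.cpu_worker.CPUWorker" -> module path + class name
    module_name, _, class_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)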
vllm/v1/attention/backends/cpu_attn.py (new file, 163 lines)
@@ -0,0 +1,163 @@
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import torch

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl,
                                                TorchSDPAMetadata)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.ipex_attn import PagedAttention
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_input_batch import InputBatch


class TorchSDPABackend:
    accept_output_buffer: bool = False

    @staticmethod
    def get_name() -> str:
        return "TORCH_SDPA_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["TorchSDPABackendImpl"]:
        return TorchSDPABackendImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return TorchSDPAMetadata

    @staticmethod
    def get_state_cls() -> type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
        return TorchSDPAMetadataBuilderV1

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return False


class TorchSDPAMetadataBuilderV1:

    def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
                 block_table: BlockTable) -> None:
        self.runner = runner
        self.block_table = block_table

        # For reorder
        self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
                                                      dtype=np.int64)
        self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs,
                                                      dtype=np.int64)
        self.num_prompt_req: int = 0

        self.seq_start_loc_cpu = torch.zeros(
            runner.max_num_reqs + 1,
            dtype=torch.int32,
            device="cpu",
        )
        self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()

    def reorder_batch(self, input_batch: InputBatch,
                      scheduler_output: SchedulerOutput) -> bool:
        prompt_list_idx = 0
        decode_list_idx = 0
        for req_index in range(input_batch.num_reqs):
            if input_batch.num_computed_tokens_cpu[
                    req_index] < input_batch.num_prompt_tokens[req_index]:
                # prompt stage
                self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
                prompt_list_idx += 1
            else:
                # decode stage
                self.reorder_decode_req_index_list[decode_list_idx] = req_index
                decode_list_idx += 1
        assert decode_list_idx + prompt_list_idx == input_batch.num_reqs

        # Update prompt requests number
        self.num_prompt_req = prompt_list_idx

        reorder_req_num = 0
        for req_index in range(decode_list_idx):
            if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
                reorder_req_num += 1
            else:
                break

        if reorder_req_num == 0:
            return False

        reorder_prompt_list = (
            self.reorder_prompt_req_index_list[:prompt_list_idx]
            [-reorder_req_num:])
        reorder_decode_list = (
            self.reorder_decode_req_index_list[:decode_list_idx]
            [:reorder_req_num])
        assert reorder_decode_list.size == reorder_prompt_list.size

        for idx in range(reorder_req_num):
            prompt_req_index = reorder_prompt_list[idx].item()
            decode_req_index = reorder_decode_list[idx].item()
            input_batch.swap_states(prompt_req_index, decode_req_index)

        return True

    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
        runner = self.runner
        block_table = self.block_table
        seq_lens_np = runner.seq_lens_np[:num_reqs]

        num_prompt_req = self.num_prompt_req
        max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
        ) if num_prompt_req > 0 else 0
        max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item(
        ) if num_prompt_req < num_reqs else 0
        self.seq_start_loc_np[0] = 0
        np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
        num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
        num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
        ) - num_prefill_tokens
        slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].long()
        block_table_tensor = block_table.get_device_tensor()
        attn_metadata = TorchSDPAMetadata(
            num_prefills=num_prompt_req,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            slot_mapping=slot_mapping,
            seq_lens_tensor=runner.
            seq_lens_cpu[num_prompt_req:num_reqs],  # decode
            max_decode_seq_len=max_decode_seq_len,  # decode
            block_tables=block_table_tensor[num_prompt_req:num_reqs],  # decode
            chunked_prefill=True,
            max_query_len=max_query_len,
            max_kv_len=max_prefill_seq_len,
            prefill_query_start_loc=runner.
            query_start_loc_cpu[:num_prompt_req + 1],  # prefill
            kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
                                                1],  # prefill
            prefill_block_tables=block_table_tensor[:
                                                    num_prompt_req],  # prefill
            query_start_loc=runner.query_start_loc_cpu[:num_reqs +
                                                       1],  # for logits index
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
        )

        return attn_metadata
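reorder_batch groups all prompt-stage requests ahead of decode-stage requests so the prefill and decode paths each see a contiguous slice of the batch. A self-contained sketch of how the swap pairs are chosen; it is illustrative only, whereas the real builder operates on the persistent InputBatch:

import numpy as np

def swaps_to_put_prompts_first(is_prompt: np.ndarray) -> list[tuple[int, int]]:
    """Return (prompt_idx, decode_idx) pairs whose requests must swap slots."""
    prompt_idx = np.flatnonzero(is_prompt)
    decode_idx = np.flatnonzero(~is_prompt)
    num_prompts = prompt_idx.size
    # Decodes sitting in the first `num_prompts` slots trade places with
    # prompts sitting beyond that region.
    misplaced_decodes = decode_idx[decode_idx < num_prompts]
    misplaced_prompts = prompt_idx[prompt_idx >= num_prompts]
    return list(zip(misplaced_prompts.tolist(), misplaced_decodes.tolist()))

print(swaps_to_put_prompts_first(np.array([False, True, False, True])))  # [(3, 0)]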
vllm/v1/worker/cpu_model_runner.py (new file, 86 lines)
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
from contextlib import contextmanager
from typing import Any

import torch

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

logger = init_logger(__name__)


class CPUModelRunner(GPUModelRunner):

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        super().__init__(vllm_config, device)

        assert device == torch.device("cpu")
        assert self.speculative_config is None, "spec decode is not supported."

        self.use_cuda_graph = False
        self.cascade_attn_enabled = False

        self._postprocess_tensors()

    def _postprocess_tensors(self) -> None:
        # Note: replace device tensors with cpu tensors
        def replace_tensor(obj: Any, cpu_attr_name: str,
                           device_attr_name) -> None:
            cpu_tensor = getattr(obj, cpu_attr_name, None)
            device_tensor = getattr(obj, device_attr_name, None)
            if cpu_tensor is not None and device_tensor is not None:
                assert isinstance(cpu_tensor, torch.Tensor)
                assert isinstance(device_tensor, torch.Tensor)
                setattr(obj, device_attr_name, cpu_tensor)

        for k, v in vars(self).items():
            if k.endswith("_cpu") and isinstance(v, torch.Tensor):
                replace_tensor(self, k, k[:-4])

        for k, v in vars(self.input_batch).items():
            if k.endswith("_cpu_tensor") and isinstance(v, torch.Tensor):
                replace_tensor(self.input_batch, k, k[:-11])

        for k, v in vars(self.input_batch.block_table).items():
            if k.endswith("_cpu") and isinstance(v, torch.Tensor):
                replace_tensor(self.input_batch.block_table, k, k[:-4])

    def load_model(self) -> None:
        logger.info("Starting to load model %s...", self.model_config.model)
        self.model = get_model(vllm_config=self.vllm_config)

        if self.lora_config:
            self.model = self.load_lora_model(self.model, self.model_config,
                                              self.scheduler_config,
                                              self.lora_config, self.device)

    def warming_up_model(self) -> None:
        logger.info("Warming up model for the compilation...")
        # Only generate graph for the generic shape
        self._dummy_run(max(16, self.max_num_reqs))
        logger.info("Warming up done.")

    def _init_device_properties(self) -> None:
        pass

    def _sync_device(self) -> None:
        pass


@contextmanager
def _set_global_compilation_settings():
    import torch._inductor.config

    # Note: The CPPGEMM backend requires freezing parameters.
    freezing_value = torch._inductor.config.freezing
    torch._inductor.config.freezing = True
    # Note: workaround for "ValueError: fast mode: can't pickle cyclic objects
    # including object type dict"
    force_disable_caches = torch._inductor.config.force_disable_caches
    torch._inductor.config.force_disable_caches = True
    yield
    torch._inductor.config.freezing = freezing_value
    torch._inductor.config.force_disable_caches = force_disable_caches
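The post-processing step relies on a naming convention: every persistent buffer named <name>_cpu has a device-side twin named <name>, and on CPU the twin is simply re-pointed at the CPU tensor so both names share one storage and device copies become no-ops. A toy sketch of the idea, with class and attribute names invented for illustration:

import torch

class Buffers:
    def __init__(self) -> None:
        self.input_ids_cpu = torch.zeros(8, dtype=torch.int64)
        self.input_ids = torch.zeros(8, dtype=torch.int64)  # would live on the device

buffers = Buffers()
for name, value in list(vars(buffers).items()):
    if name.endswith("_cpu") and isinstance(value, torch.Tensor):
        setattr(buffers, name[:-4], value)  # alias the "device" tensor to the CPU one

assert buffers.input_ids is buffers.input_ids_cpu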
vllm/v1/worker/cpu_worker.py (new file, 101 lines)
@@ -0,0 +1,101 @@
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Optional

import torch

from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import IntermediateTensors
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_worker import (Worker,
                                       init_worker_distributed_environment)

logger = init_logger(__name__)


class CPUWorker(Worker):

    def __init__(self,
                 vllm_config: VllmConfig,
                 local_rank: int,
                 rank: int,
                 distributed_init_method: str,
                 is_driver_worker: bool = False):
        super().__init__(vllm_config,
                         local_rank,
                         rank,
                         distributed_init_method,
                         is_driver_worker=is_driver_worker)

        self.parallel_config.disable_custom_all_reduce = True

    def init_device(self):
        # Setup OpenMP threads affinity.
        omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
        if omp_cpuids == "all":
            self.local_omp_cpuid = "all"
        else:
            self.local_omp_cpuid = omp_cpuids.split("|")[self.rank]
            ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
            if ret:
                logger.info(ret)

        # Note: unique identifier for creating allreduce shared memory
        os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(
            ":")[-1]
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank, "gloo")
        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Construct the model runner
        self.model_runner: CPUModelRunner = CPUModelRunner(
            self.vllm_config, torch.device("cpu"))

    def sleep(self, level: int = 1) -> None:
        logger.warning("sleep mode is not supported on CPU, ignore it.")
        pass

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        logger.warning("sleep mode is not supported on CPU, ignore it.")
        pass

    def determine_available_memory(self) -> int:
        return self.cache_config.cpu_kvcache_space_bytes  # type: ignore

    def compile_or_warm_up_model(self) -> None:
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
        self.model_runner.warming_up_model()

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group()))

        output = self.model_runner.execute_model(scheduler_output,
                                                 intermediate_tensors)

        if not get_pp_group().is_last_rank:
            assert isinstance(output, IntermediateTensors)
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())
            return None

        assert isinstance(output, ModelRunnerOutput)
        return output if self.is_driver_worker else None
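VLLM_CPU_OMP_THREADS_BIND carries one core range per worker, separated by '|', and each rank picks its own slice. A tiny sketch of that parsing; the ranges are examples:

omp_cpuids = "48-71|72-95"           # one range per tensor-parallel rank
rank = 1
local_omp_cpuid = omp_cpuids.split("|")[rank]
print(local_omp_cpuid)               # "72-95"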
@@ -5,7 +5,7 @@ import copy
 import gc
 import time
 import weakref
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 import numpy as np
 import torch
@@ -38,7 +38,6 @@ from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
                         check_use_alibi, is_pin_memory_available)
-from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
@@ -203,8 +202,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 self.vllm_config.compilation_config.cudagraph_capture_sizes))
 
         # Cache the device properties.
-        self.device_properties = torch.cuda.get_device_properties(self.device)
-        self.num_sms = self.device_properties.multi_processor_count
+        self._init_device_properties()
 
         # Persistent buffers for CUDA graphs.
         self.input_ids = torch.zeros(self.max_num_tokens,
@@ -315,6 +313,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 self.input_batch, scheduler_output)
         return batch_reordered
 
+    # Note: used for model runner override.
+    def _init_device_properties(self) -> None:
+        """Initialize attributes from torch.cuda.get_device_properties
+        """
+        self.device_properties = torch.cuda.get_device_properties(self.device)
+        self.num_sms = self.device_properties.multi_processor_count
+
+    # Note: used for model runner override.
+    def _sync_device(self) -> None:
+        torch.cuda.synchronize()
+
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
         output.
@@ -538,8 +547,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> tuple[dict[str, FlashAttentionMetadata], torch.Tensor,
-               Optional[SpecDecodeMetadata]]:
+    ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata]]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
@@ -652,7 +660,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=query_start_loc, seq_lens=seq_lens)
 
-        attn_metadata: dict[str, FlashAttentionMetadata] = {}
+        attn_metadata: dict[str, Any] = {}
         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
@@ -1710,7 +1718,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         # Must synchronize the non-blocking GPU->CPU transfers.
         if prompt_logprobs_dict:
-            torch.cuda.synchronize()
+            self._sync_device()
 
         return prompt_logprobs_dict
 
@@ -1740,7 +1748,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                dtype=np.int32)
 
         if skip_attn:
-            attn_metadata: Optional[dict[str, FlashAttentionMetadata]] = None
+            attn_metadata: Optional[dict[str, Any]] = None
         else:
             query_start_loc = self.query_start_loc[:num_reqs + 1]
             seq_lens = self.seq_lens[:num_reqs]
@@ -1964,7 +1972,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             sampler_output = self._dummy_sampler_run(hidden_states)
         else:
             sampler_output = None
-        torch.cuda.synchronize()
+        self._sync_device()
         del hidden_states, sampler_output
         self.encoder_cache.clear()
         gc.collect()
@@ -342,13 +342,14 @@ def init_worker_distributed_environment(
     rank: int,
     distributed_init_method: Optional[str] = None,
     local_rank: int = -1,
+    backend: str = "nccl",
 ) -> None:
     """Initialize the distributed environment."""
     parallel_config = vllm_config.parallel_config
     set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
 
     init_distributed_environment(parallel_config.world_size, rank,
-                                 distributed_init_method, local_rank)
+                                 distributed_init_method, local_rank, backend)
 
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
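The new backend argument exists so CPU workers can build their process groups with gloo instead of NCCL, which is GPU-only. A hedged sketch of that choice in plain torch.distributed terms; the helper name and arguments are illustrative:

import torch.distributed as dist

def init_process_group_for(device_type: str, init_method: str, world_size: int,
                           rank: int) -> None:
    # CPU workers cannot use NCCL; fall back to the gloo backend.
    backend = "gloo" if device_type == "cpu" else "nccl"
    dist.init_process_group(backend=backend, init_method=init_method,
                            world_size=world_size, rank=rank)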