diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh
index 0a11935607e2..61aa7df13b4d 100644
--- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh
@@ -6,6 +6,7 @@ set -ex

 # allow to bind to different cores
 CORE_RANGE=${CORE_RANGE:-48-95}
+OMP_CORE_RANGE=${OMP_CORE_RANGE:-48-95}
 NUMA_NODE=${NUMA_NODE:-1}

 export CMAKE_BUILD_PARALLEL_LEVEL=32
@@ -23,10 +24,8 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE
 numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu .

 # Run the image, setting --shm-size=4g for tensor parallel.
-docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \
- --cpuset-mems="$NUMA_NODE" --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
-docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \
- --cpuset-mems="$NUMA_NODE" --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
+docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
+docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2

 function cpu_tests() {
   set -e
@@ -56,7 +55,7 @@ function cpu_tests() {
   # Run AWQ test
   docker exec cpu-test-"$NUMA_NODE" bash -c "
     set -e
-    pytest -s -v \
+    VLLM_USE_V1=0 pytest -s -v \
     tests/quantization/test_ipex_quant.py"

   # Run chunked-prefill and prefix-cache test
@@ -68,8 +67,6 @@ function cpu_tests() {
   # online serving
   docker exec cpu-test-"$NUMA_NODE" bash -c "
     set -e
-    export VLLM_CPU_KVCACHE_SPACE=10
-    export VLLM_CPU_OMP_THREADS_BIND=$1
     python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
     timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
     python3 benchmarks/benchmark_serving.py \
@@ -89,4 +86,4 @@ function cpu_tests() {

 # All of CPU tests are expected to be finished less than 40 mins.
 export -f cpu_tests
-timeout 40m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
+timeout 1h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md
index a2321bf98900..7c4909cb5d91 100644
--- a/docs/usage/v1_guide.md
+++ b/docs/usage/v1_guide.md
@@ -40,6 +40,8 @@ This living user guide outlines a few known **important changes and limitations*
 | **NVIDIA** | 🚀 Natively Supported |
 | **AMD** | 🚧 WIP |
 | **TPU** | 🚧 WIP |
+| **CPU** | 🚧 WIP |
+
 #### Feature / Model

 | Feature / Model | Status |
diff --git a/requirements/cpu.txt b/requirements/cpu.txt
index 1213301584ce..e43b44397752 100644
--- a/requirements/cpu.txt
+++ b/requirements/cpu.txt
@@ -1,6 +1,9 @@
 # Common dependencies
 -r common.txt

+numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
+numba == 0.61.2; python_version > '3.9'
+
 # Dependencies for CPUs
 packaging>=24.2
 setuptools>=77.0.3,<80.0.0
diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py
index 435fe6225614..f3e64155703c 100644
--- a/tests/kernels/attention/test_attention_selector.py
+++ b/tests/kernels/attention/test_attention_selector.py
@@ -85,7 +85,10 @@ def test_env(
                    CpuPlatform()):
             backend = get_attn_backend(16, torch.float16, torch.float16,
                                        block_size, False)
-            assert backend.get_name() == "TORCH_SDPA"
+            if use_v1:
+                assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
+            else:
+                assert backend.get_name() == "TORCH_SDPA"

     elif device == "hip":
         with patch("vllm.attention.selector.current_platform",
diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py
index ed9e54722514..f656f90c4bd3 100644
--- a/tests/models/language/generation/test_common.py
+++ b/tests/models/language/generation/test_common.py
@@ -87,7 +87,6 @@ AITER_MODEL_LIST = [
         pytest.param("bigcode/starcoder2-3b"),  # starcoder2
         pytest.param(
             "TitanML/tiny-mixtral",  # mixtral
-            marks=[pytest.mark.cpu_model],
         )
     ])
 @pytest.mark.parametrize("max_tokens", [32])
diff --git a/vllm/attention/backends/cpu_mla.py b/vllm/attention/backends/cpu_mla.py
index cf7883e121ab..793cb87b7434 100644
--- a/vllm/attention/backends/cpu_mla.py
+++ b/vllm/attention/backends/cpu_mla.py
@@ -178,7 +178,7 @@ class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]):
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=max_query_len,
             max_kv_len=max_kv_len,
-            query_start_loc=query_start_loc,
+            prefill_query_start_loc=query_start_loc,
             kv_start_loc=kv_start_loc,
             max_decode_seq_len=input_data.max_decode_seq_len,
             num_prefills=input_data.num_prefills,
@@ -264,8 +264,8 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
             key=k,
             value=v_padded,
             out=output,
-            seqlen_q=prefill_metadata.query_start_loc,
-            seqlen_k=prefill_metadata.query_start_loc,
+            seqlen_q=prefill_metadata.prefill_query_start_loc,
+            seqlen_k=prefill_metadata.prefill_query_start_loc,
             max_seqlen_q=prefill_metadata.max_query_len,
             max_seqlen_k=prefill_metadata.max_query_len,
             pdropout=0.0,
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index f3fb5adcf05c..23231c323f13 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -87,10 +87,13 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
     # For chunked prefill only
     max_query_len: Optional[int] = None
     max_kv_len: Optional[int] = None
-    query_start_loc: Optional[torch.Tensor] = None
+    prefill_query_start_loc: Optional[torch.Tensor] = None
     kv_start_loc: Optional[torch.Tensor] = None
     prefill_block_tables: Optional[torch.Tensor] = None
+    # For V1 logits index only
+    query_start_loc: Optional[torch.Tensor] = None
+
     # Begin encoder attn & enc/dec cross-attn fields...
     # Encoder sequence lengths representation
     encoder_seq_lens: Optional[List[int]] = None
@@ -375,7 +378,7 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=max_query_len,
             max_kv_len=max_kv_len,
-            query_start_loc=query_start_loc,
+            prefill_query_start_loc=query_start_loc,
             kv_start_loc=kv_start_loc,
             max_decode_seq_len=input_data.max_decode_seq_len,
             num_prefills=input_data.num_prefills,
@@ -470,6 +473,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+
+        # For warming-up
+        if attn_metadata is None:
+            return query
+
         attn_type = self.attn_type
         if (attn_type == AttentionType.ENCODER
                 and (not attn_metadata.is_all_encoder_attn_metadata_set)):
@@ -537,8 +545,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):

         output = torch.empty_like(query)
         if prefill_meta := attn_metadata.prefill_metadata:
-            assert attn_metadata.seq_lens is not None
             if not prefill_meta.prefill_metadata.chunked_prefill:  # type: ignore
+                assert attn_metadata.seq_lens is not None
                 self._run_sdpa_forward(output,
                                        query,
                                        key,
@@ -555,7 +563,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
                     query[:prefill_meta.num_prefill_tokens, :, :],
                     key_cache,
                     value_cache,
-                    prefill_meta.query_start_loc,
+                    prefill_meta.prefill_query_start_loc,
                     prefill_meta.kv_start_loc,
                     prefill_meta.max_query_len,
                     prefill_meta.max_kv_len,
diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py
index 8c8d0b5cb229..2a261c84c3fc 100644
--- a/vllm/compilation/wrapper.py
+++ b/vllm/compilation/wrapper.py
@@ -41,11 +41,16 @@ class TorchCompileWrapperWithCustomDispatcher:
             # compiling the forward method
             backend = vllm_config.compilation_config.init_backend(vllm_config)
+            options = None
+            if isinstance(backend, str) and backend == "inductor":
+                options = get_current_vllm_config(
+                ).compilation_config.inductor_compile_config

             compiled_callable = torch.compile(
                 self.forward,
                 fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
-                backend=backend)
+                backend=backend,
+                options=options)

         self.compiled_callable = compiled_callable
         self.original_code_object = self.__class__.forward.__code__
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 2197d44ca825..b1c4b27a0ca4 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1399,6 +1399,7 @@ class EngineArgs:
             "FLASHINFER",
             "FLASHINFER_VLLM_V1",
             "ROCM_AITER_MLA",
+            "TORCH_SDPA_VLLM_V1",
         ]
         if (envs.is_set("VLLM_ATTENTION_BACKEND")
                 and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
@@ -1431,7 +1432,8 @@ class EngineArgs:

         # Non-[CUDA, TPU] may be supported on V1, but off by default for now.
-        v0_hardware = not any(
-            (current_platform.is_cuda(), current_platform.is_tpu()))
+        v0_hardware = not any(
+            (current_platform.is_cuda(), current_platform.is_tpu(),
+             current_platform.is_cpu()))
         if v0_hardware and _warn_or_fallback(  # noqa: SIM103
                 current_platform.device_name):
             return False
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index 2739f5c8c690..265959d626e0 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -57,7 +57,10 @@ class CpuPlatform(Platform):
             logger.info("Using CPU MLA backend.")
             return "vllm.attention.backends.cpu_mla.CPUMLABackend"
         logger.info("Using Torch SDPA backend.")
-        return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
+        if use_v1:
+            return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
+        else:
+            return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"

     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
@@ -81,6 +84,8 @@ class CpuPlatform(Platform):
         if not model_config.enforce_eager:
             model_config.enforce_eager = True

+        model_config.disable_cascade_attn = True
+
         cache_config = vllm_config.cache_config

         ipex_available = find_spec("intel_extension_for_pytorch") is not None
@@ -128,7 +133,8 @@ class CpuPlatform(Platform):
                 f" {kv_cache_space}, expect a positive integer value.")

         parallel_config = vllm_config.parallel_config
-        if (parallel_config.distributed_executor_backend is not None
+        if (parallel_config.world_size > 1
+                and parallel_config.distributed_executor_backend is not None
                 and parallel_config.distributed_executor_backend != "mp"):
             logger.warning(("%s is not supported on CPU, fallback to mp "
                             "distributed executor backend."),
@@ -141,7 +147,38 @@ class CpuPlatform(Platform):
                 parallel_config.sd_worker_cls = \
                     "vllm.worker.cpu_worker.CPUWorker"
         else:
-            parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
+            if envs.VLLM_USE_V1:
+                parallel_config.worker_cls = \
+                    "vllm.v1.worker.cpu_worker.CPUWorker"
+            else:
+                parallel_config.worker_cls = \
+                    "vllm.worker.cpu_worker.CPUWorker"
+
+        # Note: workaround for v1 gpu_model_runner
+        from vllm.config import CompilationLevel
+        vllm_config.compilation_config.cudagraph_capture_sizes = []
+
+        compilation_config = vllm_config.compilation_config
+        if (envs.VLLM_USE_V1 and vllm_config.compilation_config.level
+                == CompilationLevel.PIECEWISE):
+            compilation_config.level = CompilationLevel.DYNAMO_ONCE
+            compilation_config.backend = "eager"
+            compilation_config.custom_ops += ["none"]
+            compilation_config.inductor_compile_config.update({
+                "dce":
+                True,
+                "size_asserts":
+                False,
+                "nan_asserts":
+                False,
+                "memory_planning":
+                True,
+                "epilogue_fusion":
+                True,
+            })
+
+        if vllm_config.lora_config is not None:
+            compilation_config.level = CompilationLevel.NO_COMPILATION

         assert vllm_config.device_config.device_type == "cpu"

@@ -149,6 +186,12 @@ class CpuPlatform(Platform):
         # Environment variables for CPU executor
         #

+        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+        # Note: to avoid the error 'nthreads cannot be larger than environment
+        # variable "NUMEXPR_MAX_THREADS" (64)'.
+        os.environ["NUMEXPR_MAX_THREADS"] = str(len(os.sched_getaffinity(0)))
+
         # Set default threads num for OpenMP parallel
         os.environ["OMP_NUM_THREADS"] = str(torch.get_num_threads())
@@ -171,13 +214,6 @@ class CpuPlatform(Platform):
         # To hint IPEX uses shared memory based AllReduce
         os.environ["LOCAL_WORLD_SIZE"] = str(
             vllm_config.parallel_config.tensor_parallel_size)
-        if sys.platform == "darwin" and \
-                envs.VLLM_WORKER_MULTIPROC_METHOD == "fork":
-            if os.environ.get('VLLM_WORKER_MULTIPROC_METHOD', None) is None:
-                logger.warning(
-                    "Default to spawn method on MacOS. If this is not desired,"
-                    " set VLLM_WORKER_MULTIPROC_METHOD to fork explicitly.")
-                os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'

         if vllm_config.model_config and vllm_config.model_config.use_mla:
             logger.info(
@@ -204,3 +240,14 @@ class CpuPlatform(Platform):
         Get device specific communicator class for distributed communication.
         """
         return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"  # noqa
+
+    @classmethod
+    def supports_structured_output(cls) -> bool:
+        return True
+
+    @classmethod
+    def supports_v1(cls, model_config) -> bool:
+        """Returns whether the current platform can support v1 for the supplied
+        model configuration.
+        """
+        return True
diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py
new file mode 100644
index 000000000000..d7a580c2883c
--- /dev/null
+++ b/vllm/v1/attention/backends/cpu_attn.py
@@ -0,0 +1,163 @@
+# SPDX-License-Identifier: Apache-2.0
+import numpy as np
+import torch
+
+from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl,
+                                                TorchSDPAMetadata)
+from vllm.attention.backends.utils import CommonAttentionState
+from vllm.attention.ops.ipex_attn import PagedAttention
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.worker.block_table import BlockTable
+from vllm.v1.worker.cpu_model_runner import CPUModelRunner
+from vllm.v1.worker.gpu_input_batch import InputBatch
+
+
+class TorchSDPABackend:
+    accept_output_buffer: bool = False
+
+    @staticmethod
+    def get_name() -> str:
+        return "TORCH_SDPA_VLLM_V1"
+
+    @staticmethod
+    def get_impl_cls() -> type["TorchSDPABackendImpl"]:
+        return TorchSDPABackendImpl
+
+    @staticmethod
+    def get_metadata_cls() -> type["AttentionMetadata"]:
+        return TorchSDPAMetadata
+
+    @staticmethod
+    def get_state_cls() -> type["CommonAttentionState"]:
+        return CommonAttentionState
+
+    @staticmethod
+    def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
+        return TorchSDPAMetadataBuilderV1
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> tuple[int, ...]:
+        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
+                                                 num_kv_heads, head_size)
+
+    @staticmethod
+    def use_cascade_attention(*args, **kwargs) -> bool:
+        return False
+
+
+class TorchSDPAMetadataBuilderV1:
+
+    def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
+                 block_table: BlockTable) -> None:
+        self.runner = runner
+        self.block_table = block_table
+
+        # For reorder
+        self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
+                                                      dtype=np.int64)
+        self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs,
+                                                      dtype=np.int64)
+        self.num_prompt_req: int = 0
+
+        self.seq_start_loc_cpu = torch.zeros(
+            runner.max_num_reqs + 1,
+            dtype=torch.int32,
+            device="cpu",
+        )
+        self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
+
+    def reorder_batch(self, input_batch: InputBatch,
+                      scheduler_output: SchedulerOutput) -> bool:
+        prompt_list_idx = 0
+        decode_list_idx = 0
+        for req_index in range(input_batch.num_reqs):
+            if input_batch.num_computed_tokens_cpu[
+                    req_index] < input_batch.num_prompt_tokens[req_index]:
+                # prompt stage
+                self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
+                prompt_list_idx += 1
+            else:
+                # decode stage
+                self.reorder_decode_req_index_list[decode_list_idx] = req_index
+                decode_list_idx += 1
+        assert decode_list_idx + prompt_list_idx == input_batch.num_reqs
+
+        # Update prompt requests number
+        self.num_prompt_req = prompt_list_idx
+
+        reorder_req_num = 0
+        for req_index in range(decode_list_idx):
+            if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
+                reorder_req_num += 1
+            else:
+                break
+
+        if reorder_req_num == 0:
+            return False
+
+        reorder_prompt_list = (
+            self.reorder_prompt_req_index_list[:prompt_list_idx]
+            [-reorder_req_num:])
+        reorder_decode_list = (
+            self.reorder_decode_req_index_list[:decode_list_idx]
+            [:reorder_req_num])
+        assert reorder_decode_list.size == reorder_prompt_list.size
+
+        for idx in range(reorder_req_num):
+            prompt_req_index = reorder_prompt_list[idx].item()
+            decode_req_index = reorder_decode_list[idx].item()
+            input_batch.swap_states(prompt_req_index, decode_req_index)
+
+        return True
+
+    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata):
+        runner = self.runner
+        block_table = self.block_table
+        seq_lens_np = runner.seq_lens_np[:num_reqs]
+
+        num_prompt_req = self.num_prompt_req
+        max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
+        ) if num_prompt_req > 0 else 0
+        max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item(
+        ) if num_prompt_req < num_reqs else 0
+        self.seq_start_loc_np[0] = 0
+        np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
+
+        num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
+        num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
+        ) - num_prefill_tokens
+
+        slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].long()
+        block_table_tensor = block_table.get_device_tensor()
+        attn_metadata = TorchSDPAMetadata(
+            num_prefills=num_prompt_req,
+            num_prefill_tokens=num_prefill_tokens,
+            num_decode_tokens=num_decode_tokens,
+            slot_mapping=slot_mapping,
+            seq_lens_tensor=runner.
+            seq_lens_cpu[num_prompt_req:num_reqs],  # decode
+            max_decode_seq_len=max_decode_seq_len,  # decode
+            block_tables=block_table_tensor[num_prompt_req:num_reqs],  # decode
+            chunked_prefill=True,
+            max_query_len=max_query_len,
+            max_kv_len=max_prefill_seq_len,
+            prefill_query_start_loc=runner.
+            query_start_loc_cpu[:num_prompt_req + 1],  # prefill
+            kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
+                                                1],  # prefill
+            prefill_block_tables=block_table_tensor[:
+                                                    num_prompt_req],  # prefill
+            query_start_loc=runner.query_start_loc_cpu[:num_reqs +
+                                                       1],  # for logits index
+            multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=False,
+        )
+
+        return attn_metadata
diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py
new file mode 100644
index 000000000000..607cfc0ef69c
--- /dev/null
+++ b/vllm/v1/worker/cpu_model_runner.py
@@ -0,0 +1,86 @@
+# SPDX-License-Identifier: Apache-2.0
+from contextlib import contextmanager
+from typing import Any
+
+import torch
+
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor.model_loader import get_model
+from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+
+logger = init_logger(__name__)
+
+
+class CPUModelRunner(GPUModelRunner):
+
+    def __init__(self, vllm_config: VllmConfig, device: torch.device):
+        super().__init__(vllm_config, device)
+
+        assert device == torch.device("cpu")
+        assert self.speculative_config is None, "spec decode is not supported."
+
+        self.use_cuda_graph = False
+        self.cascade_attn_enabled = False
+
+        self._postprocess_tensors()
+
+    def _postprocess_tensors(self) -> None:
+        # Note: replace device tensors with CPU tensors
+        def replace_tensor(obj: Any, cpu_attr_name: str,
+                           device_attr_name) -> None:
+            cpu_tensor = getattr(obj, cpu_attr_name, None)
+            device_tensor = getattr(obj, device_attr_name, None)
+            if cpu_tensor is not None and device_tensor is not None:
+                assert isinstance(cpu_tensor, torch.Tensor)
+                assert isinstance(device_tensor, torch.Tensor)
+                setattr(obj, device_attr_name, cpu_tensor)
+
+        for k, v in vars(self).items():
+            if k.endswith("_cpu") and isinstance(v, torch.Tensor):
+                replace_tensor(self, k, k[:-4])
+
+        for k, v in vars(self.input_batch).items():
+            if k.endswith("_cpu_tensor") and isinstance(v, torch.Tensor):
+                replace_tensor(self.input_batch, k, k[:-11])
+
+        for k, v in vars(self.input_batch.block_table).items():
+            if k.endswith("_cpu") and isinstance(v, torch.Tensor):
+                replace_tensor(self.input_batch.block_table, k, k[:-4])
+
+    def load_model(self) -> None:
+        logger.info("Starting to load model %s...", self.model_config.model)
+        self.model = get_model(vllm_config=self.vllm_config)
+
+        if self.lora_config:
+            self.model = self.load_lora_model(self.model, self.model_config,
+                                              self.scheduler_config,
+                                              self.lora_config, self.device)
+
+    def warming_up_model(self) -> None:
+        logger.info("Warming up model for the compilation...")
+        # Only generate graph for the generic shape
+        self._dummy_run(max(16, self.max_num_reqs))
+        logger.info("Warming up done.")
+
+    def _init_device_properties(self) -> None:
+        pass
+
+    def _sync_device(self) -> None:
+        pass
+
+
+@contextmanager
+def _set_global_compilation_settings():
+    import torch._inductor.config
+
+    # Note: The CPPGEMM backend requires freezing parameters.
+    freezing_value = torch._inductor.config.freezing
+    torch._inductor.config.freezing = True
+    # Note: workaround for "ValueError: fast mode: can't pickle cyclic objects
+    # including object type dict"
+    force_disable_caches = torch._inductor.config.force_disable_caches
+    torch._inductor.config.force_disable_caches = True
+    yield
+    torch._inductor.config.freezing = freezing_value
+    torch._inductor.config.force_disable_caches = force_disable_caches
diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py
new file mode 100644
index 000000000000..0b710b7bc203
--- /dev/null
+++ b/vllm/v1/worker/cpu_worker.py
@@ -0,0 +1,101 @@
+# SPDX-License-Identifier: Apache-2.0
+import os
+from typing import Optional
+
+import torch
+
+from vllm import envs
+from vllm.config import VllmConfig
+from vllm.distributed.parallel_state import get_pp_group, get_tp_group
+from vllm.logger import init_logger
+from vllm.model_executor.utils import set_random_seed
+from vllm.sequence import IntermediateTensors
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.worker.cpu_model_runner import CPUModelRunner
+from vllm.v1.worker.gpu_worker import (Worker,
+                                       init_worker_distributed_environment)
+
+logger = init_logger(__name__)
+
+
+class CPUWorker(Worker):
+
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 local_rank: int,
+                 rank: int,
+                 distributed_init_method: str,
+                 is_driver_worker: bool = False):
+        super().__init__(vllm_config,
+                         local_rank,
+                         rank,
+                         distributed_init_method,
+                         is_driver_worker=is_driver_worker)
+
+        self.parallel_config.disable_custom_all_reduce = True
+
+    def init_device(self):
+        # Setup OpenMP threads affinity.
+        omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
+        if omp_cpuids == "all":
+            self.local_omp_cpuid = "all"
+        else:
+            self.local_omp_cpuid = omp_cpuids.split("|")[self.rank]
+            ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
+            if ret:
+                logger.info(ret)
+
+        # Note: unique identifier for creating allreduce shared memory
+        os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(
+            ":")[-1]
+        # Initialize the distributed environment.
+        init_worker_distributed_environment(self.vllm_config, self.rank,
+                                            self.distributed_init_method,
+                                            self.local_rank, "gloo")
+        # Set random seed.
+        set_random_seed(self.model_config.seed)
+
+        # Construct the model runner
+        self.model_runner: CPUModelRunner = CPUModelRunner(
+            self.vllm_config, torch.device("cpu"))
+
+    def sleep(self, level: int = 1) -> None:
+        logger.warning("sleep mode is not supported on CPU, ignore it.")
+        pass
+
+    def wake_up(self, tags: Optional[list[str]] = None) -> None:
+        logger.warning("sleep mode is not supported on CPU, ignore it.")
+        pass
+
+    def determine_available_memory(self) -> int:
+        return self.cache_config.cpu_kvcache_space_bytes  # type: ignore
+
+    def compile_or_warm_up_model(self) -> None:
+        # Reset the seed to ensure that the random state is not affected by
+        # the model initialization and profiling.
+        set_random_seed(self.model_config.seed)
+        self.model_runner.warming_up_model()
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        scheduler_output: "SchedulerOutput",
+    ) -> Optional[ModelRunnerOutput]:
+        intermediate_tensors = None
+        if not get_pp_group().is_first_rank:
+            intermediate_tensors = IntermediateTensors(
+                get_pp_group().recv_tensor_dict(
+                    all_gather_group=get_tp_group()))
+
+        output = self.model_runner.execute_model(scheduler_output,
+                                                 intermediate_tensors)
+
+        if not get_pp_group().is_last_rank:
+            assert isinstance(output, IntermediateTensors)
+            get_pp_group().send_tensor_dict(output.tensors,
+                                            all_gather_group=get_tp_group())
+            return None
+
+        assert isinstance(output, ModelRunnerOutput)
+        return output if self.is_driver_worker else None
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 6a566a602b19..6ea6bb020ed7 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -5,7 +5,7 @@ import copy
 import gc
 import time
 import weakref
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union

 import numpy as np
 import torch
@@ -38,7 +38,6 @@ from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
                         check_use_alibi, is_pin_memory_available)
-from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
@@ -203,8 +202,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 self.vllm_config.compilation_config.cudagraph_capture_sizes))

         # Cache the device properties.
-        self.device_properties = torch.cuda.get_device_properties(self.device)
-        self.num_sms = self.device_properties.multi_processor_count
+        self._init_device_properties()

         # Persistent buffers for CUDA graphs.
         self.input_ids = torch.zeros(self.max_num_tokens,
@@ -315,6 +313,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 self.input_batch, scheduler_output)
         return batch_reordered

+    # Note: used for model runner override.
+    def _init_device_properties(self) -> None:
+        """Initialize attributes from torch.cuda.get_device_properties
+        """
+        self.device_properties = torch.cuda.get_device_properties(self.device)
+        self.num_sms = self.device_properties.multi_processor_count
+
+    # Note: used for model runner override.
+    def _sync_device(self) -> None:
+        torch.cuda.synchronize()
+
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the
         scheduler output.
@@ -538,8 +547,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> tuple[dict[str, FlashAttentionMetadata], torch.Tensor,
-               Optional[SpecDecodeMetadata]]:
+    ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata]]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
@@ -652,7 +660,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=query_start_loc, seq_lens=seq_lens)

-        attn_metadata: dict[str, FlashAttentionMetadata] = {}
+        attn_metadata: dict[str, Any] = {}
         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
@@ -1710,7 +1718,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):

         # Must synchronize the non-blocking GPU->CPU transfers.
         if prompt_logprobs_dict:
-            torch.cuda.synchronize()
+            self._sync_device()

         return prompt_logprobs_dict

@@ -1740,7 +1748,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                      dtype=np.int32)

         if skip_attn:
-            attn_metadata: Optional[dict[str, FlashAttentionMetadata]] = None
+            attn_metadata: Optional[dict[str, Any]] = None
         else:
             query_start_loc = self.query_start_loc[:num_reqs + 1]
             seq_lens = self.seq_lens[:num_reqs]
@@ -1964,7 +1972,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             sampler_output = self._dummy_sampler_run(hidden_states)
         else:
             sampler_output = None
-        torch.cuda.synchronize()
+        self._sync_device()
         del hidden_states, sampler_output
         self.encoder_cache.clear()
         gc.collect()
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index f36cf5d5c319..3bf3b2221a44 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -342,13 +342,14 @@ def init_worker_distributed_environment(
     rank: int,
     distributed_init_method: Optional[str] = None,
     local_rank: int = -1,
+    backend: str = "nccl",
 ) -> None:
     """Initialize the distributed environment."""
     parallel_config = vllm_config.parallel_config
     set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

     init_distributed_environment(parallel_config.world_size, rank,
-                                 distributed_init_method, local_rank)
+                                 distributed_init_method, local_rank, backend)

     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
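
Usage sketch (illustrative, not part of the patch): with these changes applied, a CPU build can opt into the V1 engine via the environment variables referenced in this diff; the concrete values below are example assumptions mirroring the CI script.

    # assumed smoke test; only VLLM_USE_V1, VLLM_CPU_KVCACHE_SPACE and
    # VLLM_CPU_OMP_THREADS_BIND come from this diff, the values are examples
    export VLLM_USE_V1=1
    export VLLM_CPU_KVCACHE_SPACE=4
    export VLLM_CPU_OMP_THREADS_BIND=48-95
    python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half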