diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py
index fbbfc630d27d..bc7894e92814 100644
--- a/tests/v1/engine/test_engine_core.py
+++ b/tests/v1/engine/test_engine_core.py
@@ -19,7 +19,7 @@
 from vllm.v1.executor.abstract import Executor, UniProcExecutor
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import ModelRunnerOutput
-from ...utils import create_new_process_for_each_test
+from ...utils import create_new_process_for_each_test, multi_gpu_test
 
 if not current_platform.is_cuda():
     pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -378,3 +378,37 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
             # Odd steps schedules a new batch.
             assert output is None
         step += 1
+
+
+@multi_gpu_test(num_gpus=2)
+def test_engine_core_tp(monkeypatch: pytest.MonkeyPatch):
+    """
+    Test engine can initialize worker in tp properly
+    """
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        """Setup the EngineCore."""
+        engine_args = EngineArgs(
+            model=MODEL_NAME,
+            tensor_parallel_size=2,
+            # Reduce startup time.
+            enforce_eager=True,
+        )
+        vllm_config = engine_args.create_engine_config()
+        executor_class = Executor.get_class(vllm_config)
+
+        with set_default_torch_num_threads(1):
+            engine_core = EngineCore(vllm_config=vllm_config,
+                                     executor_class=executor_class,
+                                     log_stats=True)
+
+        def get_worker_cache_config_field(worker, key: str):
+            return getattr(worker.cache_config, key)
+
+        num_gpu_blocks = engine_core.collective_rpc(
+            get_worker_cache_config_field, args=("num_gpu_blocks", ))
+        num_cpu_blocks = engine_core.collective_rpc(
+            get_worker_cache_config_field, args=("num_cpu_blocks", ))
+        assert all(x is not None for x in num_gpu_blocks)
+        assert all(x is not None for x in num_cpu_blocks)
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index a572b89470f4..17b0f259cb76 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -13,6 +13,7 @@ from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
@@ -236,7 +237,12 @@ class FlexAttentionMetadata:
 
     def build_block_mask(self) -> BlockMask:
         assert self.mask_mod is not None
-        return create_block_mask_compiled(
+        # FIXME: With TP>1, create_block_mask_compiled will raise
+        # CUDA error: an illegal memory access was encountered
+        create_block_mask_fn = (create_block_mask_compiled
+                                if get_tensor_model_parallel_world_size() == 1
+                                else create_block_mask)
+        return create_block_mask_fn(
             self.mask_mod,
             None,
             None,
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 07761bf000a6..57fcf8daa5a1 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -84,6 +84,8 @@ class EngineCore:
         vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
         vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
+        self.collective_rpc("initialize_cache",
+                            args=(num_gpu_blocks, num_cpu_blocks))
 
         self.structured_output_manager = StructuredOutputManager(vllm_config)
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index b7d244f27045..58795e3fe292 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -112,6 +112,11 @@ class Worker(WorkerBase):
             buffer.data.copy_(self._sleep_saved_buffers[name].data)
         self._sleep_saved_buffers = {}
 
+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        self.cache_config.num_gpu_blocks = num_gpu_blocks
+        self.cache_config.num_cpu_blocks = num_cpu_blocks
+
     def init_device(self):
         if self.device_config.device.type == "cuda":
             # torch.distributed.all_reduce does not free the input tensor until
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index 5da481baeeea..87af8e476707 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -93,6 +93,11 @@ class TPUWorker:
         if self.model_config.seed is None:
             self.model_config.seed = 0
 
+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        self.cache_config.num_gpu_blocks = num_gpu_blocks
+        self.cache_config.num_cpu_blocks = num_cpu_blocks
+
     def init_device(self):
         os.environ["PJRT_DEVICE"] = "TPU"
         # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D