[Bugfix] Fix TP inference for Flex attention backend (#19657)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2025-06-16 19:21:37 +08:00 committed by GitHub
parent 4d5424029b
commit 1173804dca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 54 additions and 2 deletions

View File

@ -19,7 +19,7 @@ from vllm.v1.executor.abstract import Executor, UniProcExecutor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput
from ...utils import create_new_process_for_each_test
from ...utils import create_new_process_for_each_test, multi_gpu_test
if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.",
@ -378,3 +378,37 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
# Odd steps schedules a new batch.
assert output is None
step += 1
@multi_gpu_test(num_gpus=2)
def test_engine_core_tp(monkeypatch: pytest.MonkeyPatch):
"""
Test engine can initialize worker in tp properly
"""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
"""Setup the EngineCore."""
engine_args = EngineArgs(
model=MODEL_NAME,
tensor_parallel_size=2,
# Reduce startup time.
enforce_eager=True,
)
vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config)
with set_default_torch_num_threads(1):
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)
def get_worker_cache_config_field(worker, key: str):
return getattr(worker.cache_config, key)
num_gpu_blocks = engine_core.collective_rpc(
get_worker_cache_config_field, args=("num_gpu_blocks", ))
num_cpu_blocks = engine_core.collective_rpc(
get_worker_cache_config_field, args=("num_cpu_blocks", ))
assert all(x is not None for x in num_gpu_blocks)
assert all(x is not None for x in num_cpu_blocks)

View File

@ -13,6 +13,7 @@ from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
@ -236,7 +237,12 @@ class FlexAttentionMetadata:
def build_block_mask(self) -> BlockMask:
assert self.mask_mod is not None
return create_block_mask_compiled(
# FIXME: With TP>1, create_block_mask_compiled will raise
# CUDA error: an illegal memory access was encountered
create_block_mask_fn = (create_block_mask_compiled
if get_tensor_model_parallel_world_size() == 1
else create_block_mask)
return create_block_mask_fn(
self.mask_mod,
None,
None,

View File

@ -84,6 +84,8 @@ class EngineCore:
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
self.collective_rpc("initialize_cache",
args=(num_gpu_blocks, num_cpu_blocks))
self.structured_output_manager = StructuredOutputManager(vllm_config)

View File

@ -112,6 +112,11 @@ class Worker(WorkerBase):
buffer.data.copy_(self._sleep_saved_buffers[name].data)
self._sleep_saved_buffers = {}
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
def init_device(self):
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until

View File

@ -93,6 +93,11 @@ class TPUWorker:
if self.model_config.seed is None:
self.model_config.seed = 0
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
def init_device(self):
os.environ["PJRT_DEVICE"] = "TPU"
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D