mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 07:55:01 +08:00
[Bugfix] Fix TP inference for Flex attention backend (#19657)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
4d5424029b
commit
1173804dca
@ -19,7 +19,7 @@ from vllm.v1.executor.abstract import Executor, UniProcExecutor
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
|
||||
from ...utils import create_new_process_for_each_test
|
||||
from ...utils import create_new_process_for_each_test, multi_gpu_test
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip(reason="V1 currently only supported on CUDA.",
|
||||
@ -378,3 +378,37 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
|
||||
# Odd steps schedules a new batch.
|
||||
assert output is None
|
||||
step += 1
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
def test_engine_core_tp(monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Test engine can initialize worker in tp properly
|
||||
"""
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
"""Setup the EngineCore."""
|
||||
engine_args = EngineArgs(
|
||||
model=MODEL_NAME,
|
||||
tensor_parallel_size=2,
|
||||
# Reduce startup time.
|
||||
enforce_eager=True,
|
||||
)
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
with set_default_torch_num_threads(1):
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=True)
|
||||
|
||||
def get_worker_cache_config_field(worker, key: str):
|
||||
return getattr(worker.cache_config, key)
|
||||
|
||||
num_gpu_blocks = engine_core.collective_rpc(
|
||||
get_worker_cache_config_field, args=("num_gpu_blocks", ))
|
||||
num_cpu_blocks = engine_core.collective_rpc(
|
||||
get_worker_cache_config_field, args=("num_cpu_blocks", ))
|
||||
assert all(x is not None for x in num_gpu_blocks)
|
||||
assert all(x is not None for x in num_cpu_blocks)
|
||||
|
||||
@ -13,6 +13,7 @@ from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
@ -236,7 +237,12 @@ class FlexAttentionMetadata:
|
||||
|
||||
def build_block_mask(self) -> BlockMask:
|
||||
assert self.mask_mod is not None
|
||||
return create_block_mask_compiled(
|
||||
# FIXME: With TP>1, create_block_mask_compiled will raise
|
||||
# CUDA error: an illegal memory access was encountered
|
||||
create_block_mask_fn = (create_block_mask_compiled
|
||||
if get_tensor_model_parallel_world_size() == 1
|
||||
else create_block_mask)
|
||||
return create_block_mask_fn(
|
||||
self.mask_mod,
|
||||
None,
|
||||
None,
|
||||
|
||||
@ -84,6 +84,8 @@ class EngineCore:
|
||||
|
||||
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
self.collective_rpc("initialize_cache",
|
||||
args=(num_gpu_blocks, num_cpu_blocks))
|
||||
|
||||
self.structured_output_manager = StructuredOutputManager(vllm_config)
|
||||
|
||||
|
||||
@ -112,6 +112,11 @@ class Worker(WorkerBase):
|
||||
buffer.data.copy_(self._sleep_saved_buffers[name].data)
|
||||
self._sleep_saved_buffers = {}
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
def init_device(self):
|
||||
if self.device_config.device.type == "cuda":
|
||||
# torch.distributed.all_reduce does not free the input tensor until
|
||||
|
||||
@ -93,6 +93,11 @@ class TPUWorker:
|
||||
if self.model_config.seed is None:
|
||||
self.model_config.seed = 0
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
def init_device(self):
|
||||
os.environ["PJRT_DEVICE"] = "TPU"
|
||||
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user