Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-11 00:05:52 +08:00)
[Bugfix] Fix TP inference for Flex attention backend (#19657)
Signed-off-by: Isotr0py <2037008807@qq.com>
parent 4d5424029b
commit 1173804dca
@@ -19,7 +19,7 @@ from vllm.v1.executor.abstract import Executor, UniProcExecutor
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import ModelRunnerOutput

-from ...utils import create_new_process_for_each_test
+from ...utils import create_new_process_for_each_test, multi_gpu_test

 if not current_platform.is_cuda():
     pytest.skip(reason="V1 currently only supported on CUDA.",

@@ -378,3 +378,37 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
                 # Odd steps schedules a new batch.
                 assert output is None
             step += 1
+
+
+@multi_gpu_test(num_gpus=2)
+def test_engine_core_tp(monkeypatch: pytest.MonkeyPatch):
+    """
+    Test engine can initialize worker in tp properly
+    """
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        """Setup the EngineCore."""
+        engine_args = EngineArgs(
+            model=MODEL_NAME,
+            tensor_parallel_size=2,
+            # Reduce startup time.
+            enforce_eager=True,
+        )
+        vllm_config = engine_args.create_engine_config()
+        executor_class = Executor.get_class(vllm_config)
+
+        with set_default_torch_num_threads(1):
+            engine_core = EngineCore(vllm_config=vllm_config,
+                                     executor_class=executor_class,
+                                     log_stats=True)
+
+        def get_worker_cache_config_field(worker, key: str):
+            return getattr(worker.cache_config, key)
+
+        num_gpu_blocks = engine_core.collective_rpc(
+            get_worker_cache_config_field, args=("num_gpu_blocks", ))
+        num_cpu_blocks = engine_core.collective_rpc(
+            get_worker_cache_config_field, args=("num_cpu_blocks", ))
+        assert all(x is not None for x in num_gpu_blocks)
+        assert all(x is not None for x in num_cpu_blocks)

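A usage note, offered as a hedged sketch rather than part of the change: it assumes at least two visible CUDA devices and that pytest can discover the test module (its path is not shown in this excerpt), so the test is selected by name; with fewer GPUs the multi_gpu_test decorator is expected to skip it.

# Illustrative only: run just the new TP test via pytest's programmatic entry point.
import pytest

exit_code = pytest.main(["-q", "-k", "test_engine_core_tp"])
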
@@ -13,6 +13,7 @@ from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,

@@ -236,7 +237,12 @@ class FlexAttentionMetadata:

     def build_block_mask(self) -> BlockMask:
         assert self.mask_mod is not None
-        return create_block_mask_compiled(
+        # FIXME: With TP>1, create_block_mask_compiled will raise
+        # CUDA error: an illegal memory access was encountered
+        create_block_mask_fn = (create_block_mask_compiled
+                                if get_tensor_model_parallel_world_size() == 1
+                                else create_block_mask)
+        return create_block_mask_fn(
             self.mask_mod,
             None,
             None,

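For context on the fallback above, a minimal standalone sketch using torch.nn.attention.flex_attention directly. The world_size value and the causal mask_mod are illustrative stand-ins (the backend queries get_tensor_model_parallel_world_size() at runtime), create_block_mask_compiled is assumed here to be a plain torch.compile wrapper, and device="cpu" only keeps the sketch runnable without a GPU, even though the backend targets CUDA.

import torch
from torch.nn.attention.flex_attention import create_block_mask

# Illustrative stand-in for get_tensor_model_parallel_world_size().
world_size = 2

# Assumed equivalent of the backend's compiled builder.
create_block_mask_compiled = torch.compile(create_block_mask)


def causal_mask_mod(b, h, q_idx, kv_idx):
    # A query position may only attend to the same or earlier key positions.
    return q_idx >= kv_idx


# With TP > 1 the compiled builder can hit "CUDA error: an illegal memory
# access was encountered", so fall back to the eager create_block_mask.
create_block_mask_fn = (create_block_mask_compiled
                        if world_size == 1 else create_block_mask)
block_mask = create_block_mask_fn(causal_mask_mod, B=None, H=None,
                                  Q_LEN=256, KV_LEN=256, device="cpu")
print(block_mask)
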
@@ -84,6 +84,8 @@ class EngineCore:

         vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
         vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
+        self.collective_rpc("initialize_cache",
+                            args=(num_gpu_blocks, num_cpu_blocks))

         self.structured_output_manager = StructuredOutputManager(vllm_config)

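A toy, self-contained illustration (not vLLM's implementation) of the pattern this hunk relies on: the engine core broadcasts "initialize_cache" over its collective RPC so every worker's cache_config mirrors the block counts the core just computed, and a callable can be broadcast the same way to read the values back, as the new test does. All names below are hypothetical.

from dataclasses import dataclass
from typing import Callable, Optional, Union


@dataclass
class ToyCacheConfig:
    num_gpu_blocks: Optional[int] = None
    num_cpu_blocks: Optional[int] = None


class ToyWorker:

    def __init__(self) -> None:
        self.cache_config = ToyCacheConfig()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks


class ToyEngineCore:

    def __init__(self, workers: list) -> None:
        self.workers = workers

    def collective_rpc(self,
                       method: Union[str, Callable],
                       args: tuple = ()) -> list:
        # Run `method` on every worker and gather the per-rank results.
        results: list = []
        for worker in self.workers:
            if isinstance(method, str):
                results.append(getattr(worker, method)(*args))
            else:
                results.append(method(worker, *args))
        return results


engine = ToyEngineCore([ToyWorker(), ToyWorker()])
engine.collective_rpc("initialize_cache", args=(1024, 256))
# Read the field back from every rank, mirroring the new test.
gpu_blocks = engine.collective_rpc(lambda w, key: getattr(w.cache_config, key),
                                   args=("num_gpu_blocks", ))
assert gpu_blocks == [1024, 1024]
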
@@ -112,6 +112,11 @@ class Worker(WorkerBase):
                     buffer.data.copy_(self._sleep_saved_buffers[name].data)
             self._sleep_saved_buffers = {}

+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        self.cache_config.num_gpu_blocks = num_gpu_blocks
+        self.cache_config.num_cpu_blocks = num_cpu_blocks
+
     def init_device(self):
         if self.device_config.device.type == "cuda":
             # torch.distributed.all_reduce does not free the input tensor until

@@ -93,6 +93,11 @@ class TPUWorker:
         if self.model_config.seed is None:
             self.model_config.seed = 0

+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        self.cache_config.num_gpu_blocks = num_gpu_blocks
+        self.cache_config.num_cpu_blocks = num_cpu_blocks
+
     def init_device(self):
         os.environ["PJRT_DEVICE"] = "TPU"
         # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D