[XPU] Fix CUDA event usage in the XPU model runner (#23708)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Kunshang Ji 2025-08-27 15:27:14 +08:00 committed by GitHub
parent 69244e67e6
commit 6446677839
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from typing import TYPE_CHECKING
import torch
@ -22,7 +23,8 @@ class XPUModelRunner(GPUModelRunner):
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(vllm_config, device)
with _torch_cuda_wrapper():
super().__init__(vllm_config, device)
# FIXME: To be verified.
self.cascade_attn_enabled = False
@ -31,3 +33,21 @@ class XPUModelRunner(GPUModelRunner):
def _sync_device(self) -> None:
    """Block until all outstanding work queued on the XPU device completes.

    Overrides the parent GPU runner's device sync so that
    ``torch.xpu.synchronize()`` is used instead of the CUDA equivalent.
    """
    torch.xpu.synchronize()
@contextmanager
def _torch_cuda_wrapper():
    """Context manager that patches ``torch.cuda.Event`` for XPU execution.

    While the context is active, ``torch.cuda.Event`` is aliased to
    ``torch.xpu.Event`` so that CUDA-event-using code inherited from the GPU
    model runner works on XPU. On exit, ``torch.cuda.Event`` is replaced by a
    no-op placeholder (see NOTE below — the original class is not restored).
    """
    class _EventPlaceholder:
        # No-op stand-in for a CUDA event: ``record()`` and ``synchronize()``
        # do nothing. Other Event APIs (e.g. ``elapsed_time``, ``wait``) are
        # not provided — callers relying on them would fail with
        # AttributeError.
        def __init__(self, *args, **kwargs) -> None:
            self.record = lambda: None
            self.synchronize = lambda: None
    try:
        # replace cuda Event with xpu Event, this should work by default
        torch.cuda.Event = torch.xpu.Event
        yield
    finally:
        # if anything goes wrong, just patch it with a placeholder
        # NOTE(review): despite the comment above, ``finally`` runs
        # unconditionally — after the context exits, ``torch.cuda.Event`` is
        # ALWAYS left as the no-op placeholder, even on the success path, and
        # the original ``torch.cuda.Event`` class is never restored. Confirm
        # this permanent global patch is intended (plausible on an XPU-only
        # platform, where any later ``torch.cuda.Event`` construction should
        # be harmless rather than crash).
        torch.cuda.Event = _EventPlaceholder