mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-26 15:44:30 +08:00
[XPU]fix cuda event used in XPU model runner (#23708)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
parent
69244e67e6
commit
6446677839
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
@ -22,7 +23,8 @@ class XPUModelRunner(GPUModelRunner):
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(vllm_config, device)
|
||||
with _torch_cuda_wrapper():
|
||||
super().__init__(vllm_config, device)
|
||||
# FIXME: To be verified.
|
||||
self.cascade_attn_enabled = False
|
||||
|
||||
@ -31,3 +33,21 @@ class XPUModelRunner(GPUModelRunner):
|
||||
|
||||
def _sync_device(self) -> None:
    """Synchronize the device by calling ``torch.xpu.synchronize()``."""
    torch.xpu.synchronize()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _torch_cuda_wrapper():
|
||||
|
||||
class _EventPlaceholder:
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
self.record = lambda: None
|
||||
self.synchronize = lambda: None
|
||||
|
||||
try:
|
||||
# replace cuda Event with xpu Event, this should work by default
|
||||
torch.cuda.Event = torch.xpu.Event
|
||||
yield
|
||||
finally:
|
||||
# if anything goes wrong, just patch it with a placeholder
|
||||
torch.cuda.Event = _EventPlaceholder
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user