[Bugfix] Fix cuda event usage with CPU model runner (#23643)

Signed-off-by: jiang1.li <jiang1.li@intel.com>

commit 9b0187003e
parent 44ac25eae2
vllm/v1/worker/cpu_model_runner.py

@@ -11,6 +11,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+from vllm.v1.worker.utils import CpuGpuBuffer

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -21,7 +22,8 @@ logger = init_logger(__name__)
 class CPUModelRunner(GPUModelRunner):

     def __init__(self, vllm_config: VllmConfig, device: torch.device):
-        super().__init__(vllm_config, device)
+        with _torch_cuda_wrapper():
+            super().__init__(vllm_config, device)

         assert device == torch.device("cpu")
         assert self.speculative_config is None, "spec decode is not supported."
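
Why __init__ is wrapped: GPUModelRunner.__init__ creates torch.cuda.Event
objects, which cannot be used on a host without a working GPU. A minimal
demo of the failure mode being worked around (hypothetical, not part of
the diff; the exact error depends on whether PyTorch was built with CUDA):

    # Hypothetical demo: on a machine without a usable GPU, creating or
    # recording a CUDA event raises (at construction on CPU-only builds,
    # at record() on CUDA builds with no visible device).
    import torch

    if not torch.cuda.is_available():
        try:
            event = torch.cuda.Event()
            event.record()
        except Exception as exc:
            print(f"CUDA events unavailable: {exc!r}")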
@@ -71,8 +73,8 @@ class CPUModelRunner(GPUModelRunner):
             setattr(obj, device_attr_name, cpu_tensor)

         for k, v in vars(self).items():
-            if k.endswith("_cpu") and isinstance(v, torch.Tensor):
-                replace_tensor(self, k, k[:-4])
+            if isinstance(v, CpuGpuBuffer):
+                v.gpu = v.cpu

         for k, v in vars(self.input_batch).items():
             if k.endswith("_cpu_tensor") and isinstance(v, torch.Tensor):
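
On CPU the two halves of a CpuGpuBuffer should be one and the same tensor,
so the postprocessing step above aliases .gpu to .cpu. A sketch of the idea
using a simplified stand-in class (the real CpuGpuBuffer lives in
vllm/v1/worker/utils.py and differs in detail):

    # Simplified stand-in for vllm.v1.worker.utils.CpuGpuBuffer, only to
    # illustrate the aliasing; the real class differs in detail.
    import torch

    class CpuGpuBuffer:
        def __init__(self, *size: int, dtype: torch.dtype,
                     device: torch.device, pin_memory: bool) -> None:
            self.cpu = torch.zeros(*size, dtype=dtype, device="cpu",
                                   pin_memory=pin_memory)
            self.gpu = torch.zeros(*size, dtype=dtype, device=device)

    buf = CpuGpuBuffer(4, dtype=torch.int64, device=torch.device("cpu"),
                       pin_memory=False)
    buf.gpu = buf.cpu  # what the runner does: .gpu now views host memory
    buf.cpu[0] = 7
    assert buf.gpu[0].item() == 7  # same storage, no copies needed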
@@ -108,6 +110,26 @@ class CPUModelRunner(GPUModelRunner):
     def _sync_device(self) -> None:
         pass

+    def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
+        return sampled_token_ids.tolist()
+
+
+@contextmanager
+def _torch_cuda_wrapper():
+
+    class _EventPlaceholder:
+
+        def __init__(self, *args, **kwargs) -> None:
+            self.record = lambda: None
+            self.synchronize = lambda: None
+
+    try:
+        cuda_event = torch.cuda.Event
+        torch.cuda.Event = _EventPlaceholder
+        yield
+    finally:
+        torch.cuda.Event = cuda_event
+

 @contextmanager
 def _set_global_compilation_settings(config: VllmConfig):
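
For context on the _to_list override: in the GPU runner, turning sampled
token ids into Python lists goes through a pinned staging buffer plus a
CUDA event, so the device-to-host copy waits only on its own transfer
rather than forcing a device-wide sync. That path cannot work without
CUDA; on CPU the tensor is already in host memory, so a plain tolist()
suffices. A hedged sketch of the event-synchronized pattern (buffer and
event names are illustrative, not vLLM's exact code):

    # Illustrative sketch of the GPU-side pattern the CPU override
    # bypasses; names are hypothetical.
    import torch

    def gpu_to_list(sampled: torch.Tensor, pinned: torch.Tensor,
                    event: torch.cuda.Event) -> list[list[int]]:
        staging = pinned[:sampled.shape[0]]
        staging.copy_(sampled, non_blocking=True)  # async D2H into pinned
        event.record()       # mark the copy on the current stream
        event.synchronize()  # wait for this transfer only
        return staging.tolist()

    if torch.cuda.is_available():
        sampled = torch.randint(0, 100, (8, 1), device="cuda")
        pinned = torch.empty(16, 1, dtype=sampled.dtype, pin_memory=True)
        print(gpu_to_list(sampled, pinned, torch.cuda.Event()))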
vllm/v1/worker/gpu_model_runner.py

@@ -321,7 +321,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             (self.max_model_len, 1),
             dtype=torch.int64,
             device="cpu",
-            pin_memory=True)
+            pin_memory=self.pin_memory)

     def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
         return CpuGpuBuffer(*args,
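
The gpu_model_runner.py change fixes the same class of bug from the other
side: unconditionally passing pin_memory=True asks CUDA to page-lock the
host allocation, which fails when no GPU is usable, so the buffer now
honors self.pin_memory (which the runner derives from pin-memory
availability). A small illustration of the guard (hypothetical helper,
mirroring the idea rather than vLLM's code):

    # Hypothetical helper mirroring the fix: request pinned memory only
    # when it can actually be provided.
    import torch

    def make_host_buffer(n: int) -> torch.Tensor:
        pin = torch.cuda.is_available()  # stand-in for self.pin_memory
        return torch.zeros(n, 1, dtype=torch.int64, device="cpu",
                           pin_memory=pin)

    buf = make_host_buffer(1024)  # safe on GPU-less hosts as well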