mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 04:26:00 +08:00)
[Bugfix] Fix pooling models on CPU backend (#23392)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
parent 998720859c
commit 88016c372a
@@ -1440,6 +1440,12 @@ def _patched_set_stream(stream: torch.cuda.Stream) -> None:
 torch.cuda.set_stream = _patched_set_stream


+class _StreamPlaceholder:
+
+    def __init__(self):
+        self.synchronize = lambda: None
+
+
 def current_stream() -> torch.cuda.Stream:
     """
     replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`.
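The added _StreamPlaceholder duck-types the one piece of the torch.cuda.Stream surface this code path relies on: a synchronize() method, made a no-op because a stream-less CPU backend has nothing to wait on. A minimal standalone sketch of the idea (the wait_for_kernels helper below is hypothetical, not vLLM code):

# Sketch only: shows why a placeholder with a no-op synchronize() is enough.
class _StreamPlaceholder:

    def __init__(self):
        # Mirror torch.cuda.Stream.synchronize() with a callable that does
        # nothing; there are no device streams to drain on CPU.
        self.synchronize = lambda: None


def wait_for_kernels(stream) -> None:
    # Hypothetical caller: it only ever touches .synchronize(), so it accepts
    # either a real torch.cuda.Stream or the placeholder.
    stream.synchronize()


wait_for_kernels(_StreamPlaceholder())  # completes immediately on CPU

The second hunk then wires this placeholder into the platform dispatch: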
@@ -1459,8 +1465,18 @@ def current_stream() -> torch.cuda.Stream:
         # On ROCm using the default 0 stream in combination with RCCL
         # is hurting performance. Therefore creating a dedicated stream
         # per process
-        _current_stream_tls.value = torch.cuda.Stream(
-        ) if current_platform.is_rocm() else torch.cuda.current_stream()
+        if current_platform.is_rocm():
+            _current_stream_tls.value = torch.cuda.Stream()
+        elif current_platform.is_cpu():
+            _current_stream_tls.value = _StreamPlaceholder()
+        else:
+            current_stream = current_platform.current_stream
+            if current_stream is not None:
+                _current_stream_tls.value = current_stream()
+            else:
+                raise ValueError(
+                    "Fail to set current stream, current platform "
+                    "may not support current_stream with torch API")
     return _current_stream_tls.value
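After this hunk, current_stream() picks the stream object per platform (a dedicated torch.cuda.Stream on ROCm, the placeholder on CPU, otherwise whatever stream getter the platform exposes, raising if there is none) and caches the result in thread-local storage so the dispatch runs once per thread. A rough self-contained sketch of that cache-then-dispatch shape, with assumed names rather than the actual vLLM module:

import threading


class _StreamPlaceholder:

    def __init__(self):
        self.synchronize = lambda: None


# One slot per thread, mirroring the _current_stream_tls usage in the diff.
_current_stream_tls = threading.local()


def current_stream_sketch():
    # Compute the stream object lazily on first use in each thread...
    if getattr(_current_stream_tls, "value", None) is None:
        # ...dispatching on the platform here; this sketch always takes the
        # CPU branch and stores the no-op placeholder.
        _current_stream_tls.value = _StreamPlaceholder()
    # ...and return the cached object on every later call.
    return _current_stream_tls.value


assert current_stream_sketch() is current_stream_sketch()  # cached per thread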