mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-02 09:47:05 +08:00
fix
Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
This commit is contained in:
parent
e8985d9716
commit
0b5e466c8d
@ -28,6 +28,33 @@ def test_version():
|
||||
assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
|
||||
|
||||
|
||||
def test_get_raw_stream_patch():
|
||||
"""Test that get_raw_stream patch is applied only for torch 2.9.x."""
|
||||
import builtins
|
||||
|
||||
# Check if get_raw_stream exists in builtins
|
||||
has_patch = hasattr(builtins, "get_raw_stream")
|
||||
|
||||
# Import torch to get actual version
|
||||
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
is_torch_2_9 = is_torch_equal_or_newer("2.9.0") and not is_torch_equal_or_newer(
|
||||
"2.10.0"
|
||||
)
|
||||
|
||||
if is_torch_2_9:
|
||||
# For torch 2.9.x, the patch should be applied
|
||||
assert has_patch, "get_raw_stream should be patched for torch 2.9.x"
|
||||
# Verify it's callable (it should be the _cuda_getCurrentRawStream function)
|
||||
get_raw_stream = builtins.get_raw_stream # type: ignore[attr-defined]
|
||||
assert callable(get_raw_stream)
|
||||
# Verify it's the correct function from torch._C
|
||||
from torch._C import _cuda_getCurrentRawStream
|
||||
|
||||
assert get_raw_stream is _cuda_getCurrentRawStream
|
||||
|
||||
|
||||
def test_copy_pass():
|
||||
vllm_config = VllmConfig()
|
||||
inductor_pass = FixFunctionalizationPass(vllm_config)
|
||||
|
||||
@ -371,11 +371,15 @@ def _update_scheduler_patched(self) -> None:
|
||||
# For more context, see https://github.com/vllm-project/vllm/issues/30905.
|
||||
def _patch_get_raw_stream_if_needed():
|
||||
"""Workaround for TorchInductor autotune get_raw_stream() bug."""
|
||||
import builtins
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from torch._C import _cuda_getCurrentRawStream as _get_raw_stream
|
||||
# Only apply the patch for torch 2.9.x versions
|
||||
if is_torch_equal_or_newer("2.9.0") and not is_torch_equal_or_newer("2.10.0"):
|
||||
import builtins
|
||||
|
||||
builtins.get_raw_stream = _get_raw_stream
|
||||
from torch._C import _cuda_getCurrentRawStream as _get_raw_stream
|
||||
|
||||
builtins.get_raw_stream = _get_raw_stream
|
||||
|
||||
|
||||
_patch_get_raw_stream_if_needed()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user