Signed-off-by: simpx <simpxx@gmail.com>

This commit is contained in:
parent a15a50fc17
commit a0e827e07c
@@ -23,9 +23,9 @@ from vllm.transformers_utils.detokenizer_utils import (
 from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
                         MemorySnapshot, PlaceholderModule, StoreBoolean,
                         bind_kv_cache, common_broadcastable_dtype,
-                        deprecate_kwargs, get_open_port, get_tcp_uri,
-                        is_lossless_cast, join_host_port, make_zmq_path,
-                        make_zmq_socket, memory_profiling,
+                        current_stream, deprecate_kwargs, get_open_port,
+                        get_tcp_uri, is_lossless_cast, join_host_port,
+                        make_zmq_path, make_zmq_socket, memory_profiling,
                         merge_async_iterators, sha256, split_host_port,
                         split_zmq_path, supports_kw, swap_dict_values)
 
@@ -957,3 +957,41 @@ def test_convert_ids_list_to_tokens():
     ]
     tokens = convert_ids_list_to_tokens(tokenizer, token_ids)
     assert tokens == ['Hello', ',', ' world', '!']
+
+
+def test_current_stream_multithread():
+    import threading
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+
+    main_default_stream = torch.cuda.current_stream()
+    child_stream = torch.cuda.Stream()
+
+    thread_stream_ready = threading.Event()
+    thread_can_exit = threading.Event()
+
+    def child_thread_func():
+        with torch.cuda.stream(child_stream):
+            thread_stream_ready.set()
+            thread_can_exit.wait(timeout=10)
+
+    child_thread = threading.Thread(target=child_thread_func)
+    child_thread.start()
+
+    try:
+        assert thread_stream_ready.wait(
+            timeout=5), "Child thread failed to enter stream context in time"
+
+        main_current_stream = current_stream()
+
+        assert main_current_stream != child_stream, "Main thread's current_stream was contaminated by child thread"
+        assert main_current_stream == main_default_stream, "Main thread's current_stream is not the default stream"
+
+        # Notify child thread it can exit
+        thread_can_exit.set()
+
+    finally:
+        # Ensure child thread exits properly
+        child_thread.join(timeout=5)
+        if child_thread.is_alive():
+            pytest.fail("Child thread failed to exit properly")
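
The test above exercises the new per-thread behaviour end to end on CUDA. The same guarantee can be seen without a GPU; the sketch below is illustrative only (not part of this commit, all names made up) and shows why a threading.local() slot is immune to the cross-thread contamination that a module-level global suffers from:

import threading

_global_value = "main"        # shared by every thread, like the old _current_stream
_tls = threading.local()      # one independent slot per thread, like _current_stream_tls
_tls.value = "main"

def child():
    global _global_value
    _global_value = "child"   # leaks into the main thread
    _tls.value = "child"      # visible only inside this child thread

t = threading.Thread(target=child)
t.start()
t.join()

assert _global_value == "child"   # the plain global was contaminated
assert _tls.value == "main"       # the main thread's thread-local slot was not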
@@ -1383,12 +1383,11 @@ def find_nccl_library() -> str:
 
 prev_set_stream = torch.cuda.set_stream
 
-_current_stream = None
+_current_stream_tls = threading.local()
 
 
 def _patched_set_stream(stream: torch.cuda.Stream) -> None:
-    global _current_stream
-    _current_stream = stream
+    _current_stream_tls.value = stream
     prev_set_stream(stream)
 
 
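
One detail worth noting, and the reason for the hasattr guard in the next hunk: an attribute set on a threading.local() instance exists only in the thread that set it. A tiny illustrative sketch (not part of the commit):

import threading

tls = threading.local()
tls.value = "set on the main thread"

def reader():
    # A freshly started thread sees no 'value' attribute at all,
    # so any reader must guard with hasattr() before touching tls.value.
    assert not hasattr(tls, "value")

t = threading.Thread(target=reader)
t.start()
t.join()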
@@ -1407,16 +1406,16 @@ def current_stream() -> torch.cuda.Stream:
     from C/C++ code.
     """
     from vllm.platforms import current_platform
-    global _current_stream
-    if _current_stream is None:
+    if not hasattr(_current_stream_tls,
+                   "value") or _current_stream_tls.value is None:
         # when this function is called before any stream is set,
         # we return the default stream.
         # On ROCm using the default 0 stream in combination with RCCL
         # is hurting performance. Therefore creating a dedicated stream
         # per process
-        _current_stream = torch.cuda.Stream() if current_platform.is_rocm(
-        ) else torch.cuda.current_stream()
-    return _current_stream
+        _current_stream_tls.value = torch.cuda.Stream(
+        ) if current_platform.is_rocm() else torch.cuda.current_stream()
+    return _current_stream_tls.value
 
 
 def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
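
A minimal usage sketch of the helper after this change (assuming a CUDA-capable machine; current_stream is imported from vllm.utils, as in the test hunk above):

import torch
from vllm.utils import current_stream

# Before any explicit stream selection this returns the default stream
# (or, per the comment in the hunk above, a dedicated stream on ROCm).
default = current_stream()

# Entering a stream context ends up calling torch.cuda.set_stream (the
# behaviour the test above relies on), which the patched setter records
# for the current thread only.
s = torch.cuda.Stream()
with torch.cuda.stream(s):
    assert current_stream() == s

# On exiting the context the thread's tracked stream is restored.
assert current_stream() == default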