diff --git a/tests/test_utils.py b/tests/test_utils.py
index 28acacd251903..53a34642e5baf 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -23,9 +23,9 @@ from vllm.transformers_utils.detokenizer_utils import (
 from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
                         MemorySnapshot, PlaceholderModule, StoreBoolean,
                         bind_kv_cache, common_broadcastable_dtype,
-                        deprecate_kwargs, get_open_port, get_tcp_uri,
-                        is_lossless_cast, join_host_port, make_zmq_path,
-                        make_zmq_socket, memory_profiling,
+                        current_stream, deprecate_kwargs, get_open_port,
+                        get_tcp_uri, is_lossless_cast, join_host_port,
+                        make_zmq_path, make_zmq_socket, memory_profiling,
                         merge_async_iterators, sha256, split_host_port,
                         split_zmq_path, supports_kw, swap_dict_values)
 
@@ -957,3 +957,41 @@ def test_convert_ids_list_to_tokens():
     ]
     tokens = convert_ids_list_to_tokens(tokenizer, token_ids)
     assert tokens == ['Hello', ',', ' world', '!']
+
+
+def test_current_stream_multithread():
+    import threading
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+
+    main_default_stream = torch.cuda.current_stream()
+    child_stream = torch.cuda.Stream()
+
+    thread_stream_ready = threading.Event()
+    thread_can_exit = threading.Event()
+
+    def child_thread_func():
+        with torch.cuda.stream(child_stream):
+            thread_stream_ready.set()
+            thread_can_exit.wait(timeout=10)
+
+    child_thread = threading.Thread(target=child_thread_func)
+    child_thread.start()
+
+    try:
+        assert thread_stream_ready.wait(
+            timeout=5), "Child thread failed to enter stream context in time"
+
+        main_current_stream = current_stream()
+
+        assert main_current_stream != child_stream, "Main thread's current_stream was contaminated by child thread"
+        assert main_current_stream == main_default_stream, "Main thread's current_stream is not the default stream"
+
+        # Notify child thread it can exit
+        thread_can_exit.set()
+
+    finally:
+        # Ensure child thread exits properly
+        child_thread.join(timeout=5)
+        if child_thread.is_alive():
+            pytest.fail("Child thread failed to exit properly")
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index bbcc2a523dcb2..e4f495e22e291 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -1383,12 +1383,11 @@ def find_nccl_library() -> str:
 
 prev_set_stream = torch.cuda.set_stream
 
-_current_stream = None
+_current_stream_tls = threading.local()
 
 
 def _patched_set_stream(stream: torch.cuda.Stream) -> None:
-    global _current_stream
-    _current_stream = stream
+    _current_stream_tls.value = stream
     prev_set_stream(stream)
 
 
@@ -1407,16 +1406,16 @@ def current_stream() -> torch.cuda.Stream:
     from C/C++ code.
     """
     from vllm.platforms import current_platform
-    global _current_stream
-    if _current_stream is None:
+    if not hasattr(_current_stream_tls,
+                   "value") or _current_stream_tls.value is None:
         # when this function is called before any stream is set,
         # we return the default stream.
         # On ROCm using the default 0 stream in combination with RCCL
         # is hurting performance. Therefore creating a dedicated stream
         # per process
-        _current_stream = torch.cuda.Stream() if current_platform.is_rocm(
-        ) else torch.cuda.current_stream()
-    return _current_stream
+        _current_stream_tls.value = torch.cuda.Stream(
+        ) if current_platform.is_rocm() else torch.cuda.current_stream()
+    return _current_stream_tls.value
 
 
 def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
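
For context on why the patch swaps the module-level `_current_stream` global for `threading.local()`, here is a minimal, CUDA-free sketch of the behavior the new test guards against. The names `_global_value`, `_tls`, `set_both`, and `child` are purely illustrative and do not appear in vLLM; this is not the vLLM implementation, just the thread-isolation pattern in isolation.

```python
import threading

# A module-level global is shared by all threads, so a value written from a
# child thread "leaks" into the main thread -- the contamination the test
# above checks for. threading.local() keeps one value per thread instead.
_global_value = None
_tls = threading.local()


def set_both(value: str) -> None:
    global _global_value
    _global_value = value  # visible to every thread
    _tls.value = value     # visible only to the thread that set it


def child() -> None:
    set_both("child-stream")


t = threading.Thread(target=child)
t.start()
t.join()

print(_global_value)                      # "child-stream" (leaked across threads)
print(getattr(_tls, "value", "default"))  # "default" (main thread unaffected)
```

The patched `current_stream()` follows the same idea: each thread caches its own stream in `_current_stream_tls`, so a stream activated inside a child thread's `torch.cuda.stream(...)` context can no longer be returned to the main thread.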