[MISC][Bugfix] Use less CPU when message queue has been empty for some time (#16226)

Signed-off-by: Povilas Kanapickas <povilas@radix.lt>

commit 85e2b7bb13
parent 61059bee40
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -128,15 +128,21 @@ def test_models(
 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize(
     "model, distributed_executor_backend, attention_backend, "
-    "test_suite", [
-        ("distilbert/distilgpt2", "ray", "", "L4"),
-        ("distilbert/distilgpt2", "mp", "", "L4"),
-        ("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4"),
-        ("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4"),
-        ("distilbert/distilgpt2", "ray", "", "A100"),
-        ("distilbert/distilgpt2", "mp", "", "A100"),
-        ("distilbert/distilgpt2", "mp", "FLASHINFER", "A100"),
-        ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"),
+    "test_suite, extra_env", [
+        ("distilbert/distilgpt2", "ray", "", "L4", {}),
+        ("distilbert/distilgpt2", "mp", "", "L4", {}),
+        ("distilbert/distilgpt2", "ray", "", "L4", {
+            "VLLM_SLEEP_WHEN_IDLE": "1"
+        }),
+        ("distilbert/distilgpt2", "mp", "", "L4", {
+            "VLLM_SLEEP_WHEN_IDLE": "1"
+        }),
+        ("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}),
+        ("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}),
+        ("distilbert/distilgpt2", "ray", "", "A100", {}),
+        ("distilbert/distilgpt2", "mp", "", "A100", {}),
+        ("distilbert/distilgpt2", "mp", "FLASHINFER", "A100", {}),
+        ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100", {}),
     ])
 @pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models_distributed(
@@ -148,6 +154,7 @@ def test_models_distributed(
     distributed_executor_backend: str,
     attention_backend: str,
     test_suite: str,
+    extra_env: dict[str, str],
     enable_prompt_embeds: bool,
 ) -> None:
 
@@ -173,6 +180,9 @@ def test_models_distributed(
                 attention_backend,
             )
 
+        for k, v in extra_env.items():
+            monkeypatch_context.setenv(k, v)
+
         dtype = "half"
         max_tokens = 5
 
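For reference, a minimal self-contained illustration (a hypothetical test, not from this commit; DEMO_FLAG is an invented variable) of the extra_env pattern the hunks above introduce: each parametrized case can inject environment variables through pytest's monkeypatch fixture before the code under test runs.

    import os

    import pytest


    @pytest.mark.parametrize("extra_env", [{}, {"DEMO_FLAG": "1"}])
    def test_env_injection(extra_env: dict[str, str], monkeypatch) -> None:
        with monkeypatch.context() as m:
            # Start from a known state; DEMO_FLAG is a hypothetical variable.
            m.delenv("DEMO_FLAG", raising=False)
            for k, v in extra_env.items():
                m.setenv(k, v)
            # Inside the context, the code under test sees the injected
            # value; monkeypatch restores the environment on exit.
            assert os.environ.get("DEMO_FLAG") == extra_env.get("DEMO_FLAG")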
--- a/vllm/distributed/device_communicators/shm_broadcast.py
+++ b/vllm/distributed/device_communicators/shm_broadcast.py
@@ -28,6 +28,43 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
 logger = init_logger(__name__)
 
 
+class SpinTimer:
+
+    def record_activity(self):
+        pass
+
+    def spin(self):
+        sched_yield()
+
+
+class SpinSleepTimer(SpinTimer):
+    """
+    In setups with long inactivity periods it is desirable to reduce
+    system power consumption when vLLM does nothing. This leaves more
+    CPU thermal headroom when a request eventually comes, especially
+    when multiple GPUs are connected, as each GPU would otherwise pin
+    one thread at 100% CPU usage.
+
+    The simplest solution is to reduce the polling frequency when there
+    has been no activity for a certain period of time.
+    """
+
+    def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1):
+        self.last_activity = time.monotonic()
+        self.busy_loop_s = busy_loop_s
+        self.wait_sleep_s = wait_sleep_s
+
+    def record_activity(self):
+        self.last_activity = time.monotonic()
+
+    def spin(self):
+        curr_time = time.monotonic()
+        if curr_time >= self.last_activity + self.busy_loop_s:
+            time.sleep(self.wait_sleep_s)
+        else:
+            sched_yield()
+
+
 class ShmRingBuffer:
 
     def __init__(self,
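To make the trade-off concrete, here is a standalone sketch (not part of the commit; DemoSpinSleepTimer is a local copy of the logic above with shorter thresholds) that counts how often the loop polls in one second. It assumes a POSIX platform where os.sched_yield() is available.

    import os
    import time


    class DemoSpinSleepTimer:
        """Same logic as SpinSleepTimer above, inlined so this runs alone."""

        def __init__(self, busy_loop_s: float = 0.5, wait_sleep_s: float = 0.1):
            self.last_activity = time.monotonic()
            self.busy_loop_s = busy_loop_s
            self.wait_sleep_s = wait_sleep_s

        def record_activity(self):
            self.last_activity = time.monotonic()

        def spin(self):
            if time.monotonic() >= self.last_activity + self.busy_loop_s:
                time.sleep(self.wait_sleep_s)  # idle: ~10 polls per second
            else:
                os.sched_yield()  # recently active: stay responsive


    timer = DemoSpinSleepTimer()
    start = time.monotonic()
    polls = 0
    while time.monotonic() - start < 1.0:
        timer.spin()
        polls += 1
    # Expect a very large count for the first 0.5 s of yielding, then only
    # about five sleepy polls for the remaining 0.5 s.
    print(f"polls in 1 s: {polls}")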
@@ -238,6 +275,7 @@ class MessageQueue:
             self.local_reader_rank = -1
             # rank does not matter for remote readers
             self._is_remote_reader = False
+            self._read_spin_timer = SpinTimer()
 
         self.handle = Handle(
             local_reader_ranks=local_reader_ranks,
@@ -276,6 +314,9 @@ class MessageQueue:
                 self.local_socket.connect(socket_addr)
 
             self.remote_socket = None
+
+            self._read_spin_timer = SpinSleepTimer(
+            ) if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer()
         else:
             self.buffer = None  # type: ignore
             self.current_idx = -1
@@ -407,7 +448,7 @@ class MessageQueue:
                     # we need to wait until it is written
 
                     # Release the processor to other threads
-                    sched_yield()
+                    self._read_spin_timer.spin()
 
                     # if we wait for a long time, log a message
                     if (time.monotonic() - start_time
@@ -438,6 +479,8 @@ class MessageQueue:
                 metadata_buffer[self.local_reader_rank + 1] = 1
                 self.current_idx = (self.current_idx +
                                     1) % self.buffer.max_chunks
+
+                self._read_spin_timer.record_activity()
                 break
 
     def enqueue(self, obj, timeout: Optional[float] = None):
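A hedged sketch of the reader-side pattern the last three hunks implement, simplified against a plain queue.Queue rather than the shared-memory ring buffer and metadata flags the real MessageQueue uses (IdleAwarePoller is a hypothetical name):

    import queue


    class IdleAwarePoller:
        """Stand-in for MessageQueue's use of _read_spin_timer."""

        def __init__(self, timer) -> None:
            # `timer` is anything with spin() / record_activity(), e.g. the
            # DemoSpinSleepTimer from the sketch above or vLLM's SpinTimer.
            self.timer = timer

        def get(self, q: "queue.Queue[bytes]") -> bytes:
            while True:
                try:
                    item = q.get_nowait()  # analogous to "written_flag set?"
                except queue.Empty:
                    self.timer.spin()  # yields while hot, sleeps once idle
                    continue
                self.timer.record_activity()  # successful read: reset clock
                return item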
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -122,6 +122,7 @@ if TYPE_CHECKING:
     VLLM_ALL2ALL_BACKEND: str = "naive"
     VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
     VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
+    VLLM_SLEEP_WHEN_IDLE: bool = False
 
 
 def get_default_cache_root():
@@ -841,6 +842,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Regex timeout for use by the vLLM tool parsing plugins.
     "VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS":
     lambda: int(os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1")),
+
+    # Reduce CPU usage when vLLM is idle. Enabling this will incur a small
+    # latency penalty when a request eventually arrives.
+    "VLLM_SLEEP_WHEN_IDLE":
+    lambda: bool(int(os.getenv("VLLM_SLEEP_WHEN_IDLE", "0"))),
 }
 
 # --8<-- [end:env-vars-definition]
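For reference, a hedged sketch of enabling the new flag. The shell command is illustrative, and the Python check relies on vllm.envs resolving variables lazily at attribute access, as the dict of lambdas above suggests.

    # Shell: trade a little wake-up latency for idle CPU savings, e.g.
    #   VLLM_SLEEP_WHEN_IDLE=1 vllm serve <model>
    #
    # Python: set the variable before vLLM reads it; this must run before
    # the message queue is created.
    import os

    os.environ["VLLM_SLEEP_WHEN_IDLE"] = "1"

    import vllm.envs as envs

    assert envs.VLLM_SLEEP_WHEN_IDLE is True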