mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 12:45:38 +08:00
[Bugfix] Fix disagg hang caused by the prefill and decode communication issues (#12723)
Signed-off-by: Lu Fang <lufang@fb.com>
This commit is contained in:
parent
932c6b7461
commit
45cbc4991d
@ -10,7 +10,6 @@
|
||||
stop the prefill instance when the decode instance is slow.
|
||||
"""
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Deque, List, Optional, Union
|
||||
|
||||
@ -29,13 +28,13 @@ class SimpleBuffer(KVLookupBufferBase):
|
||||
def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase,
|
||||
buffer_size_thresh: float):
|
||||
"""
|
||||
signal_pipe: on CPU
|
||||
|
||||
NOTE: on-device recv will block all threads in the process, making the
|
||||
KV cache producer unable to listen to new request while transmitting
|
||||
KV cache. Luckily CPU recv only blocks the current thread so we use
|
||||
signal_pipe: on CPU
|
||||
|
||||
NOTE: on-device recv will block all threads in the process, making the
|
||||
KV cache producer unable to listen to new request while transmitting
|
||||
KV cache. Luckily CPU recv only blocks the current thread so we use
|
||||
CPU recv to listen to new request.
|
||||
|
||||
|
||||
data_pipe: on device (e.g. GPU)
|
||||
"""
|
||||
|
||||
@ -43,7 +42,7 @@ class SimpleBuffer(KVLookupBufferBase):
|
||||
|
||||
self.buffer_size = 0
|
||||
self.buffer_size_threshold = buffer_size_thresh
|
||||
self.buffer_lock = threading.Lock()
|
||||
self.buffer_cv = threading.Condition()
|
||||
self.signal_pipe = signal_pipe
|
||||
self.data_pipe = data_pipe
|
||||
self.request_handling_thread: Optional[threading.Thread] = None
|
||||
@ -116,11 +115,19 @@ class SimpleBuffer(KVLookupBufferBase):
|
||||
hidden = hidden.clone()
|
||||
|
||||
buffer_item = [input_tokens, roi, key, value, hidden]
|
||||
data_size = sum([self._get_element_size(data) for data in buffer_item])
|
||||
|
||||
with self.buffer_lock:
|
||||
for data in buffer_item:
|
||||
self.buffer_size += self._get_element_size(data)
|
||||
with self.buffer_cv:
|
||||
if self.buffer_size + data_size > self.buffer_size_threshold:
|
||||
# log outside the while loop to avoid this message being logged
|
||||
# repeatedly.
|
||||
logger.debug("KV transfer buffer is full. Handling...")
|
||||
while self.buffer_size + data_size > self.buffer_size_threshold:
|
||||
self.buffer_cv.wait()
|
||||
|
||||
self.buffer_size += data_size
|
||||
self.buffer.append(buffer_item)
|
||||
self.buffer_cv.notify()
|
||||
|
||||
def _is_end_signal(self, signal):
|
||||
return signal is None
|
||||
@ -143,35 +150,31 @@ class SimpleBuffer(KVLookupBufferBase):
|
||||
roi = (roi > 0.5)
|
||||
tokens_roi_recver = [input_tokens, roi]
|
||||
|
||||
matched_length = 0
|
||||
|
||||
# perform input tokens and roi matching
|
||||
# FIXME: this matching is O(n), ideally it should be O(1)
|
||||
# but this buffer size won't (and shouldn't) be too large so
|
||||
# the fix is not urgent.
|
||||
with self.buffer_lock:
|
||||
|
||||
def is_buffer_available(
|
||||
tokens_roi_recver: List[torch.Tensor], ) -> bool:
|
||||
# perform input tokens and roi matching
|
||||
# FIXME: this matching is O(n), ideally it should be O(1)
|
||||
# but this buffer size won't (and shouldn't) be too large so
|
||||
# the fix is not urgent.
|
||||
for _ in range(len(self.buffer)):
|
||||
|
||||
temp_length = self._matches(self.buffer[0],
|
||||
tokens_roi_recver)
|
||||
if temp_length > 0:
|
||||
matched_length = temp_length
|
||||
break
|
||||
if self._matches(self.buffer[0],
|
||||
tokens_roi_recver) > 0:
|
||||
return True
|
||||
# rotate the element we just accessed to the end
|
||||
self.buffer.rotate(-1)
|
||||
return False
|
||||
|
||||
if matched_length > 0:
|
||||
# need to clone the tensor
|
||||
# in case the tensor is freed before sending finishes
|
||||
matched_item = self.buffer.popleft()
|
||||
for tensor in matched_item:
|
||||
self._send_tensor_and_dec_size(tensor)
|
||||
|
||||
else:
|
||||
# no match, just send None
|
||||
for _ in range(5):
|
||||
self.data_pipe.send_tensor(None)
|
||||
with self.buffer_cv:
|
||||
while not is_buffer_available(tokens_roi_recver):
|
||||
logger.debug(
|
||||
"KV transfer buffer is not available. Waiting...")
|
||||
self.buffer_cv.wait()
|
||||
# need to clone the tensor
|
||||
# in case the tensor is freed before sending finishes
|
||||
matched_item = self.buffer.popleft()
|
||||
for tensor in matched_item:
|
||||
self._send_tensor_and_dec_size(tensor)
|
||||
self.buffer_cv.notify()
|
||||
|
||||
except RuntimeError as e:
|
||||
if 'Connection closed by peer' not in str(e):
|
||||
@ -208,20 +211,10 @@ class SimpleBuffer(KVLookupBufferBase):
|
||||
|
||||
return [input_tokens, roi, key, value, hidden]
|
||||
|
||||
def full_handler(self):
|
||||
time.sleep(0.001)
|
||||
|
||||
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor,
|
||||
hidden: torch.Tensor) -> None:
|
||||
|
||||
if self.buffer_size > self.buffer_size_threshold:
|
||||
# log outside the while loop to avoid this message being logged
|
||||
# repeatedly.
|
||||
logger.debug("KV transfer buffer is full. Handling...")
|
||||
while self.buffer_size > self.buffer_size_threshold:
|
||||
self.full_handler()
|
||||
|
||||
self._add_to_buffer(input_tokens, roi, key, value, hidden)
|
||||
|
||||
# when calling the insert, the current process is a sender
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user