[Bugfix][Disaggregated] patch the inflight batching on the decode node in SimpleConnector to avoid hangs in SimpleBuffer (nccl based) (#13987)
Signed-off-by: Mathis Felardos <mathis@mistral.ai>
commit b9e41734c5
parent 1088f06242
@@ -214,6 +214,7 @@ class SimpleConnector(KVConnectorBase):
 
         input_tokens_tensor = model_input.input_tokens
         seq_lens = model_input.attn_metadata.seq_lens
+        num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
         slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
 
         hidden_or_intermediate_states_for_one_req = []
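The new `num_prefill_tokens` value marks the boundary that `vllm/worker/model_runner.py::_prepare_model_input_tensors` establishes in the flattened batch: prefill tokens come first, decode tokens after. A minimal sketch of that layout, with hypothetical token values and counts (not taken from the patch):

    import torch

    # Hypothetical batch: two prefill requests (3 and 2 tokens) batched
    # in flight with two single-token decode requests.
    input_tokens_tensor = torch.tensor([11, 12, 13, 21, 22, 31, 41])
    num_prefill_tokens = 5  # what attn_metadata.num_prefill_tokens reports

    # Only the prefill slice has KV caches to transfer; the decode
    # tokens already have their caches on the decode node.
    prefill_tokens = input_tokens_tensor[:num_prefill_tokens]
    decode_tokens = input_tokens_tensor[num_prefill_tokens:]
    assert prefill_tokens.numel() + decode_tokens.numel() == input_tokens_tensor.numel()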
@@ -225,9 +226,21 @@ class SimpleConnector(KVConnectorBase):
         # enumerate different requests
         # FIXME(Kuntai): This impl assumes that all requests are prefill.
         for idx, slen in enumerate(seq_lens):
 
             start_pos = sum(seq_lens[:idx])
             end_pos = start_pos + slen
 
+            if start_pos >= num_prefill_tokens:
+                # This can happen during inflight batching. See:
+                # vllm/worker/model_runner.py::_prepare_model_input_tensors:
+                # - input_tokens[:num_prefill_tokens] contains prefill tokens.
+                # - input_tokens[num_prefill_tokens:] contains decode tokens.
+                logger.warning("You should set --enable_chunked_prefill=False "
+                               "and --max_num_batched_tokens "
+                               "should be equal to max_seq_len_to_capture")
+                bypass_model_exec = False
+                assert start_pos == num_prefill_tokens
+                break
+
             current_tokens = input_tokens_tensor[start_pos:end_pos]
             num_tokens = slen
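The guard matters because `seq_lens` covers both prefill and decode sequences, while the prefill node only pushed KV for the prefill requests; slicing past the prefill boundary would issue receives for data that was never sent, which is the SimpleBuffer (nccl based) hang the commit title describes. A standalone sketch of the early-exit logic, with hypothetical lengths:

    seq_lens = [3, 2, 1, 1]  # two prefills followed by two inflight decodes
    num_prefill_tokens = 5   # 3 + 2

    for idx, slen in enumerate(seq_lens):
        start_pos = sum(seq_lens[:idx])
        if start_pos >= num_prefill_tokens:
            # Decode sequences begin exactly at the prefill boundary, so
            # the first skipped request proves every prefill was handled.
            assert start_pos == num_prefill_tokens
            break
        print(f"receive KV for tokens [{start_pos}, {start_pos + slen})")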
@@ -288,7 +301,7 @@ class SimpleConnector(KVConnectorBase):
             # Here we will fall back to normal model forwarding
             # But optionally you can adjust model_input so that you only do
             # prefilling on those tokens that are missing KV caches.
-            logger.debug(
+            logger.warning(
                 "[rank%d]: Failed to receive all KVs and hidden "
                 "states, redo model forwarding.", torch.distributed.get_rank())
             hidden_or_intermediate_states = None
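Promoting this log from `debug` to `warning` makes the fallback visible: setting `bypass_model_exec = False` in the loop above causes the decode node to redo the full forward pass, which is correct but defeats the point of disaggregated prefill. A simplified sketch of how the flag drives that fallback, as a hypothetical condensation of the surrounding method (the function name and signature are illustrative, not vLLM API):

    import logging

    logger = logging.getLogger("simple_connector_sketch")  # hypothetical name

    def finish_recv(bypass_model_exec, hidden_or_intermediate_states, rank=0):
        # An incomplete KV receive discards any partial hidden states and
        # lets the decode node run the normal forward pass itself.
        if not bypass_model_exec:
            logger.warning(
                "[rank%d]: Failed to receive all KVs and hidden "
                "states, redo model forwarding.", rank)
            hidden_or_intermediate_states = None
        return hidden_or_intermediate_states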