[Bugfix][Disaggregated] patch the inflight batching on the decode node in SimpleConnector to avoid hangs in SimpleBuffer (nccl based) (#13987)

Signed-off-by: Mathis Felardos <mathis@mistral.ai>
This commit is contained in:
Mathis Felardos 2025-02-28 08:53:45 +01:00 committed by GitHub
parent 1088f06242
commit b9e41734c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -214,6 +214,7 @@ class SimpleConnector(KVConnectorBase):
input_tokens_tensor = model_input.input_tokens input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens seq_lens = model_input.attn_metadata.seq_lens
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
slot_mapping = model_input.attn_metadata.slot_mapping.flatten() slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
hidden_or_intermediate_states_for_one_req = [] hidden_or_intermediate_states_for_one_req = []
@ -225,9 +226,21 @@ class SimpleConnector(KVConnectorBase):
# enumerate different requests # enumerate different requests
# FIXME(Kuntai): This impl assumes that all requests are prefill. # FIXME(Kuntai): This impl assumes that all requests are prefill.
for idx, slen in enumerate(seq_lens): for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx]) start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen end_pos = start_pos + slen
if start_pos >= num_prefill_tokens:
# This can happen during inflight batching. See:
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
# - input_tokens[num_prefill_tokens:] contains decode tokens.
logger.warning("You should set --enable_chunked_prefill=False "
"and --max_num_batched_tokens "
"should be equal to max_seq_len_to_capture")
bypass_model_exec = False
assert start_pos == num_prefill_tokens
break
current_tokens = input_tokens_tensor[start_pos:end_pos] current_tokens = input_tokens_tensor[start_pos:end_pos]
num_tokens = slen num_tokens = slen
@ -288,7 +301,7 @@ class SimpleConnector(KVConnectorBase):
# Here we will fall back to normal model forwarding # Here we will fall back to normal model forwarding
# But optionally you can adjust model_input so that you only do # But optionally you can adjust model_input so that you only do
# prefilling on those tokens that are missing KV caches. # prefilling on those tokens that are missing KV caches.
logger.debug( logger.warning(
"[rank%d]: Failed to receive all KVs and hidden " "[rank%d]: Failed to receive all KVs and hidden "
"states, redo model forwarding.", torch.distributed.get_rank()) "states, redo model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = None hidden_or_intermediate_states = None