[Core] Add a random suffix to frontend-provided request IDs (#27987)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Mark McLoughlin 2025-12-23 21:05:39 +00:00 committed by GitHub
parent 34916ae37f
commit f790068600
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 328 additions and 154 deletions

View File

@ -260,7 +260,7 @@ async def test_multi_abort(output_kind: RequestOutputKind):
# Use multi-abort to abort multiple requests at once # Use multi-abort to abort multiple requests at once
abort_request_ids = [request_ids[i] for i in REQUEST_IDS_TO_ABORT] abort_request_ids = [request_ids[i] for i in REQUEST_IDS_TO_ABORT]
await engine.abort(abort_request_ids) await engine.abort(abort_request_ids, internal=False)
# Wait for all tasks to complete # Wait for all tasks to complete
results = await asyncio.gather(*tasks, return_exceptions=True) results = await asyncio.gather(*tasks, return_exceptions=True)
@ -609,7 +609,7 @@ async def test_abort_final_output(output_kind: RequestOutputKind):
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
# Abort the request # Abort the request
await engine.abort(request_id) await engine.abort(request_id, internal=False)
# Wait for generation to complete and return final output # Wait for generation to complete and return final output
final_output = await generated final_output = await generated

View File

@ -40,10 +40,16 @@ TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
PROMPT = "I am Gyoubu Masataka Oniwa" PROMPT = "I am Gyoubu Masataka Oniwa"
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
_REQUEST_COUNTER = 0
def make_request() -> EngineCoreRequest: def make_request() -> EngineCoreRequest:
global _REQUEST_COUNTER
_REQUEST_COUNTER += 1
request_id = f"request-{_REQUEST_COUNTER}"
return EngineCoreRequest( return EngineCoreRequest(
request_id=str(uuid.uuid4()), request_id=request_id,
external_req_id=f"{request_id}-{uuid.uuid4()}",
prompt_token_ids=PROMPT_TOKENS, prompt_token_ids=PROMPT_TOKENS,
mm_features=None, mm_features=None,
sampling_params=SamplingParams(), sampling_params=SamplingParams(),

View File

@ -45,6 +45,8 @@ TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
PROMPT = "Hello my name is Robert and I love quantization kernels" PROMPT = "Hello my name is Robert and I love quantization kernels"
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
_REQUEST_COUNTER = 0
def make_request( def make_request(
params: SamplingParams, prompt_tokens_ids: list[int] | None = None params: SamplingParams, prompt_tokens_ids: list[int] | None = None
@ -52,8 +54,12 @@ def make_request(
if not prompt_tokens_ids: if not prompt_tokens_ids:
prompt_tokens_ids = PROMPT_TOKENS prompt_tokens_ids = PROMPT_TOKENS
global _REQUEST_COUNTER
_REQUEST_COUNTER += 1
request_id = f"request-{_REQUEST_COUNTER}"
return EngineCoreRequest( return EngineCoreRequest(
request_id=str(uuid.uuid4()), request_id=request_id,
external_req_id=f"{request_id}-{uuid.uuid4()}",
prompt_token_ids=prompt_tokens_ids, prompt_token_ids=prompt_tokens_ids,
mm_features=None, mm_features=None,
sampling_params=params, sampling_params=params,

View File

@ -27,6 +27,7 @@ def test_fast_inc_detok_invalid_utf8_err_case():
params = SamplingParams(skip_special_tokens=True) params = SamplingParams(skip_special_tokens=True)
request = EngineCoreRequest( request = EngineCoreRequest(
request_id="test", request_id="test",
external_req_id="test-ext",
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
mm_features=None, mm_features=None,
sampling_params=params, sampling_params=params,

View File

@ -58,12 +58,12 @@ def test_incremental_detokenization(
output_processor = OutputProcessor( output_processor = OutputProcessor(
dummy_test_vectors.tokenizer, log_stats=False, stream_interval=stream_interval dummy_test_vectors.tokenizer, log_stats=False, stream_interval=stream_interval
) )
engine_core = MockEngineCore(tokens_list=dummy_test_vectors.generation_tokens)
# Make N requests. # Make N requests.
requests = [ requests = [
EngineCoreRequest( EngineCoreRequest(
request_id=f"request-{idx}", request_id=f"request-{idx}-int",
external_req_id=f"request-{idx}",
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
mm_features=None, mm_features=None,
eos_token_id=None, eos_token_id=None,
@ -83,6 +83,11 @@ def test_incremental_detokenization(
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
] ]
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
request_ids=[req.request_id for req in requests],
)
# Add requests to the detokenizer. # Add requests to the detokenizer.
for request, prompt in zip(requests, dummy_test_vectors.prompt_strings): for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
output_processor.add_request(request, prompt) output_processor.add_request(request, prompt)
@ -438,15 +443,6 @@ def test_logprobs_processor(
dummy_test_vectors, dummy_test_vectors,
): ):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False) output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
generated_logprobs_raw=None
if num_sample_logprobs is None
else dummy_test_vectors.generation_logprobs,
prompt_logprobs_raw=None
if num_prompt_logprobs is None
else dummy_test_vectors.prompt_logprobs,
)
# Make N requests. # Make N requests.
request_id_list = [ request_id_list = [
@ -454,7 +450,8 @@ def test_logprobs_processor(
] ]
requests = [ requests = [
EngineCoreRequest( EngineCoreRequest(
request_id=request_id_list[idx], request_id=request_id_list[idx] + "-int",
external_req_id=request_id_list[idx],
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
mm_features=None, mm_features=None,
eos_token_id=None, eos_token_id=None,
@ -476,6 +473,17 @@ def test_logprobs_processor(
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
] ]
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
generated_logprobs_raw=None
if num_sample_logprobs is None
else dummy_test_vectors.generation_logprobs,
prompt_logprobs_raw=None
if num_prompt_logprobs is None
else dummy_test_vectors.prompt_logprobs,
request_ids=[req.request_id for req in requests],
)
# Add requests to the detokenizer. # Add requests to the detokenizer.
for request, prompt in zip(requests, dummy_test_vectors.prompt_strings): for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
output_processor.add_request(request, prompt) output_processor.add_request(request, prompt)
@ -621,19 +629,12 @@ def test_stop_token(
] ]
prompt_string = dummy_test_vectors.prompt_strings[0] prompt_string = dummy_test_vectors.prompt_strings[0]
prompt_tokens = dummy_test_vectors.prompt_tokens[0] prompt_tokens = dummy_test_vectors.prompt_tokens[0]
engine_core = MockEngineCore(
tokens_list=[generation_tokens],
generated_logprobs_raw=[generation_logprobs] if do_logprobs else None,
prompt_logprobs_raw=None,
eos_token_id=eos_token_id,
stop_token_ids=stop_token_ids,
ignore_eos=ignore_eos,
)
# Make request. # Make request.
request_id = "request-0" request_id = "request-0"
request = EngineCoreRequest( request = EngineCoreRequest(
request_id=request_id, request_id=request_id,
external_req_id=request_id + "-ext",
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
mm_features=None, mm_features=None,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
@ -655,6 +656,16 @@ def test_stop_token(
pooling_params=None, pooling_params=None,
) )
engine_core = MockEngineCore(
tokens_list=[generation_tokens],
generated_logprobs_raw=[generation_logprobs] if do_logprobs else None,
prompt_logprobs_raw=None,
eos_token_id=eos_token_id,
stop_token_ids=stop_token_ids,
ignore_eos=ignore_eos,
request_ids=[request.request_id],
)
# Add request to the detokenizer. # Add request to the detokenizer.
output_processor.add_request(request, prompt_string) output_processor.add_request(request, prompt_string)
@ -720,13 +731,6 @@ def test_stop_string(
dummy_test_vectors, dummy_test_vectors,
): ):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False) output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
generated_logprobs_raw=dummy_test_vectors.generation_logprobs
if num_sample_logprobs
else None,
prompt_logprobs_raw=None,
)
# Make N requests. # Make N requests.
request_id_list = [ request_id_list = [
@ -734,7 +738,8 @@ def test_stop_string(
] ]
requests = [ requests = [
EngineCoreRequest( EngineCoreRequest(
request_id=request_id_list[idx], request_id=request_id_list[idx] + "-int",
external_req_id=request_id_list[idx],
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
mm_features=None, mm_features=None,
eos_token_id=None, eos_token_id=None,
@ -756,6 +761,15 @@ def test_stop_string(
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
] ]
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
generated_logprobs_raw=dummy_test_vectors.generation_logprobs
if num_sample_logprobs
else None,
prompt_logprobs_raw=None,
request_ids=[req.request_id for req in requests],
)
# Add requests to the detokenizer. # Add requests to the detokenizer.
for request, prompt in zip(requests, dummy_test_vectors.prompt_strings): for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
output_processor.add_request(request, prompt) output_processor.add_request(request, prompt)
@ -813,9 +827,12 @@ def test_stop_string(
for idx, (ref_gen_str, stop_str) in enumerate( for idx, (ref_gen_str, stop_str) in enumerate(
zip(dummy_test_vectors.generation_strings, STOP_STRINGS) zip(dummy_test_vectors.generation_strings, STOP_STRINGS)
): ):
# Request should be aborted. # Request should be aborted (check internal ID in abort list).
internal_request_id = f"request-{idx}-int"
assert internal_request_id in aborted
# Use external ID for collecting outputs
request_id = f"request-{idx}" request_id = f"request-{idx}"
assert request_id in aborted
# Collected values that were generated. # Collected values that were generated.
gen_str = gen_strings[request_id] gen_str = gen_strings[request_id]
@ -848,13 +865,13 @@ def test_stop_string(
def test_iteration_stats(dummy_test_vectors): def test_iteration_stats(dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True) output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
engine_core_timestamp = time.monotonic() engine_core_timestamp = time.monotonic()
# Make N requests. # Make N requests.
requests = [ requests = [
EngineCoreRequest( EngineCoreRequest(
request_id=f"request-{idx}", request_id=f"request-{idx}",
external_req_id=f"request-{idx}-ext",
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
mm_features=None, mm_features=None,
eos_token_id=None, eos_token_id=None,
@ -868,6 +885,11 @@ def test_iteration_stats(dummy_test_vectors):
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
] ]
engine_core = MockEngineCore(
dummy_test_vectors.generation_tokens,
request_ids=[req.request_id for req in requests],
)
# Add all requests except one to the OutputProcessor. # Add all requests except one to the OutputProcessor.
num_active = len(dummy_test_vectors.generation_tokens) - 1 num_active = len(dummy_test_vectors.generation_tokens) - 1
for request in requests[:num_active]: for request in requests[:num_active]:
@ -922,7 +944,6 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
output_processor = OutputProcessor( output_processor = OutputProcessor(
dummy_test_vectors.tokenizer, log_stats=log_stats dummy_test_vectors.tokenizer, log_stats=log_stats
) )
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
engine_core_timestamp = time.monotonic() engine_core_timestamp = time.monotonic()
# Create LoRA requests # Create LoRA requests
@ -936,7 +957,8 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
lora_assignments = [lora1, lora2, None] lora_assignments = [lora1, lora2, None]
requests = [ requests = [
EngineCoreRequest( EngineCoreRequest(
request_id=f"request-{idx}", request_id=f"request-{idx}-int",
external_req_id=f"request-{idx}",
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
mm_features=None, mm_features=None,
eos_token_id=None, eos_token_id=None,
@ -950,6 +972,11 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
] ]
engine_core = MockEngineCore(
dummy_test_vectors.generation_tokens,
request_ids=[req.request_id for req in requests],
)
# Add all requests to the OutputProcessor # Add all requests to the OutputProcessor
for request in requests: for request in requests:
output_processor.add_request(request, None) output_processor.add_request(request, None)
@ -1015,9 +1042,9 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
outputs = EngineCoreOutputs( outputs = EngineCoreOutputs(
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats() outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
) )
# Find and mark request-0 as finished (it uses lora-1) # Find and mark request-0-int as finished (it uses lora-1)
for output in outputs.outputs: for output in outputs.outputs:
if output.request_id == "request-0": if output.request_id == "request-0-int":
output.finish_reason = FinishReason.LENGTH output.finish_reason = FinishReason.LENGTH
break break
@ -1040,9 +1067,9 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
outputs = EngineCoreOutputs( outputs = EngineCoreOutputs(
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats() outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
) )
# Find and mark request-1 as finished (it uses lora-2) # Find and mark request-1-int as finished (it uses lora-2)
for output in outputs.outputs: for output in outputs.outputs:
if output.request_id == "request-1": if output.request_id == "request-1-int":
output.finish_reason = FinishReason.LENGTH output.finish_reason = FinishReason.LENGTH
break break
@ -1064,9 +1091,9 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
outputs = EngineCoreOutputs( outputs = EngineCoreOutputs(
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats() outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
) )
# Find and mark request-2 as finished (it has no LoRA) # Find and mark request-2-int as finished (it has no LoRA)
for output in outputs.outputs: for output in outputs.outputs:
if output.request_id == "request-2": if output.request_id == "request-2-int":
output.finish_reason = FinishReason.LENGTH output.finish_reason = FinishReason.LENGTH
break break
@ -1107,7 +1134,9 @@ async def test_request_output_collector():
for idx in range(NUM_REQS) for idx in range(NUM_REQS)
] ]
collector = RequestOutputCollector(RequestOutputKind.DELTA) collector = RequestOutputCollector(
RequestOutputKind.DELTA, request_id="my-request-id-int"
)
# CASE 1: Put then get. # CASE 1: Put then get.
outputs = make_outputs() outputs = make_outputs()
@ -1163,7 +1192,9 @@ async def test_request_output_collector():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cumulative_output_collector_n(): async def test_cumulative_output_collector_n():
"""Test collector correctly handles multiple outputs by index.""" """Test collector correctly handles multiple outputs by index."""
collector = RequestOutputCollector(RequestOutputKind.CUMULATIVE) collector = RequestOutputCollector(
RequestOutputKind.CUMULATIVE, request_id="my-request-id-int"
)
outputs = [ outputs = [
RequestOutput( RequestOutput(
request_id="my-request-id", request_id="my-request-id",
@ -1242,11 +1273,13 @@ async def test_cumulative_output_collector_n():
@pytest.mark.parametrize("runner", ["generate", "pooling"]) @pytest.mark.parametrize("runner", ["generate", "pooling"])
def test_abort_requests(runner: str, dummy_test_vectors): @pytest.mark.parametrize("abort_by", ["internal", "external"])
def test_abort_requests(runner: str, abort_by: str, dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True) output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
requests = [ requests = [
EngineCoreRequest( EngineCoreRequest(
request_id=f"request-{idx}", request_id=f"request-{idx}",
external_req_id=f"external-{idx}",
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
mm_features=None, mm_features=None,
eos_token_id=None, eos_token_id=None,
@ -1265,8 +1298,13 @@ def test_abort_requests(runner: str, dummy_test_vectors):
output_kind = request.sampling_params.output_kind output_kind = request.sampling_params.output_kind
else: else:
output_kind = request.pooling_params.output_kind output_kind = request.pooling_params.output_kind
queue = RequestOutputCollector(output_kind=output_kind) queue = RequestOutputCollector(
output_kind=output_kind, request_id=request.request_id
)
output_processor.add_request(request, None, queue=queue) output_processor.add_request(request, None, queue=queue)
for request in requests: for request in requests:
output_processor.abort_requests([request.request_id]) if abort_by == "internal":
output_processor.abort_requests([request.request_id], internal=True)
else:
output_processor.abort_requests([request.external_req_id], internal=False)

View File

@ -4,11 +4,12 @@
from vllm import SamplingParams from vllm import SamplingParams
from vllm.outputs import CompletionOutput from vllm.outputs import CompletionOutput
from vllm.sampling_params import RequestOutputKind from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.parallel_sampling import ParentRequest
def test_parent_request_to_output_stream() -> None: def test_parent_request_to_output_stream() -> None:
parent_request = ParentRequest("parent_id", SamplingParams(n=2)) parent_request = ParentRequest(make_request(SamplingParams(n=2)))
parent_request.child_requests = {"child_id_0", "child_id_1"} parent_request.child_requests = {"child_id_0", "child_id_1"}
output_0 = CompletionOutput( output_0 = CompletionOutput(
index=0, text="child 0", token_ids=[], cumulative_logprob=None, logprobs=None index=0, text="child 0", token_ids=[], cumulative_logprob=None, logprobs=None
@ -17,51 +18,31 @@ def test_parent_request_to_output_stream() -> None:
index=1, text="child 1", token_ids=[], cumulative_logprob=None, logprobs=None index=1, text="child 1", token_ids=[], cumulative_logprob=None, logprobs=None
) )
# Request not finished # Request not finished
assert ("parent_id", [output_0], False) == parent_request.get_outputs( assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
"child_id_0", output_0 assert ([output_1], False) == parent_request.get_outputs("child_id_1", output_1)
) assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
assert ("parent_id", [output_1], False) == parent_request.get_outputs( assert ([output_1], False) == parent_request.get_outputs("child_id_1", output_1)
"child_id_1", output_1
)
assert ("parent_id", [output_0], False) == parent_request.get_outputs(
"child_id_0", output_0
)
assert ("parent_id", [output_1], False) == parent_request.get_outputs(
"child_id_1", output_1
)
# output_1 finished # output_1 finished
output_1.finish_reason = "ended" output_1.finish_reason = "ended"
assert ("parent_id", [output_0], False) == parent_request.get_outputs( assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
"child_id_0", output_0 assert ([output_1], False) == parent_request.get_outputs("child_id_1", output_1)
)
assert ("parent_id", [output_1], False) == parent_request.get_outputs(
"child_id_1", output_1
)
# Finished output_1 has already been returned; it must NOT be returned again # Finished output_1 has already been returned; it must NOT be returned again
assert ("parent_id", [output_0], False) == parent_request.get_outputs( assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
"child_id_0", output_0 assert parent_request.get_outputs("child_id_1", output_1) == ([], False)
)
assert parent_request.get_outputs("child_id_1", output_1) == (
"parent_id",
[],
False,
)
# output_0 finished # output_0 finished
output_0.finish_reason = "ended" output_0.finish_reason = "ended"
assert ("parent_id", [output_0], True) == parent_request.get_outputs( assert ([output_0], True) == parent_request.get_outputs("child_id_0", output_0)
"child_id_0", output_0 assert parent_request.get_outputs("child_id_1", output_1) == ([], True)
)
assert parent_request.get_outputs("child_id_1", output_1) == ("parent_id", [], True)
# Finished output_0 has already been returned; it must NOT be returned again # Finished output_0 has already been returned; it must NOT be returned again
assert parent_request.get_outputs("child_id_0", output_0) == ("parent_id", [], True) assert parent_request.get_outputs("child_id_0", output_0) == ([], True)
assert parent_request.get_outputs("child_id_1", output_1) == ("parent_id", [], True) assert parent_request.get_outputs("child_id_1", output_1) == ([], True)
def test_parent_request_to_output_final_only() -> None: def test_parent_request_to_output_final_only() -> None:
parent_request = ParentRequest( parent_request = ParentRequest(
"parent_id", SamplingParams(n=2, output_kind=RequestOutputKind.FINAL_ONLY) make_request(SamplingParams(n=2, output_kind=RequestOutputKind.FINAL_ONLY))
) )
parent_request.child_requests = {"child_id_0", "child_id_1"} parent_request.child_requests = {"child_id_0", "child_id_1"}
output_0 = CompletionOutput( output_0 = CompletionOutput(
@ -71,33 +52,33 @@ def test_parent_request_to_output_final_only() -> None:
index=1, text="child 1", token_ids=[], cumulative_logprob=None, logprobs=None index=1, text="child 1", token_ids=[], cumulative_logprob=None, logprobs=None
) )
# Request not finished, return nothing # Request not finished, return nothing
assert parent_request.get_outputs("child_id_0", output_0) == ( assert parent_request.get_outputs("child_id_0", output_0) == ([], False)
"parent_id", assert parent_request.get_outputs("child_id_1", output_1) == ([], False)
[],
False,
)
assert parent_request.get_outputs("child_id_1", output_1) == (
"parent_id",
[],
False,
)
# output_1 finished, but outputs won't be returned until all child requests finished # output_1 finished, but outputs won't be returned until all child requests finished
output_1.finish_reason = "ended" output_1.finish_reason = "ended"
assert parent_request.get_outputs("child_id_0", output_0) == ( assert parent_request.get_outputs("child_id_0", output_0) == ([], False)
"parent_id", assert parent_request.get_outputs("child_id_1", output_1) == ([], False)
[],
False,
)
assert parent_request.get_outputs("child_id_1", output_1) == (
"parent_id",
[],
False,
)
# output_0 finished, as all child requests finished, the output would be returned # output_0 finished, as all child requests finished, the output would be returned
output_0.finish_reason = "ended" output_0.finish_reason = "ended"
assert ("parent_id", [output_0, output_1], True) == parent_request.get_outputs( assert ([output_0, output_1], True) == parent_request.get_outputs(
"child_id_0", output_0 "child_id_0", output_0
) )
assert ("parent_id", [output_0, output_1], True) == parent_request.get_outputs( assert ([output_0, output_1], True) == parent_request.get_outputs(
"child_id_1", output_1 "child_id_1", output_1
) )
def make_request(sampling_params: SamplingParams) -> EngineCoreRequest:
    """Build the parent ``EngineCoreRequest`` fixture for the ParentRequest tests.

    Only ``sampling_params`` varies between callers. The request always
    carries the internal ID ``"parent_id"`` and the distinct external ID
    ``"ext_parent_id"``; every other field is a fixed placeholder.
    """
    # Fields that are identical for every request built by this helper.
    fixed_fields = dict(
        request_id="parent_id",
        external_req_id="ext_parent_id",
        prompt_token_ids=None,
        mm_features=None,
        pooling_params=None,
        eos_token_id=None,
        arrival_time=0.0,
        lora_request=None,
        cache_salt=None,
        data_parallel_rank=None,
    )
    parent = EngineCoreRequest(sampling_params=sampling_params, **fixed_fields)
    return parent

View File

@ -6,6 +6,7 @@ import pytest
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset from vllm.assets.video import VideoAsset
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig
from vllm.multimodal import MultiModalUUIDDict
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.engine import input_processor as input_processor_mod from vllm.v1.engine import input_processor as input_processor_mod
from vllm.v1.engine.input_processor import InputProcessor from vllm.v1.engine.input_processor import InputProcessor
@ -166,7 +167,7 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
monkeypatch, mm_cache_gb=0.0, enable_prefix_caching=False monkeypatch, mm_cache_gb=0.0, enable_prefix_caching=False
) )
captured: dict[str, object] = {} captured: dict[str, MultiModalUUIDDict] = {}
def fake_preprocess( def fake_preprocess(
prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None
@ -196,7 +197,16 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
) )
# Expect request-id-based overrides are passed through # Expect request-id-based overrides are passed through
assert captured["mm_uuids"] == { mm_uuids = captured["mm_uuids"]
"image": [f"{request_id}-image-0", f"{request_id}-image-1"], assert set(mm_uuids.keys()) == {"image", "video"}
"video": [f"{request_id}-video-0"], assert len(mm_uuids["image"]) == 2
} assert len(mm_uuids["video"]) == 1
assert mm_uuids["image"][0].startswith(f"{request_id}-image-") and mm_uuids[
"image"
][0].endswith("-0")
assert mm_uuids["image"][1].startswith(f"{request_id}-image-") and mm_uuids[
"image"
][1].endswith("-1")
assert mm_uuids["video"][0].startswith(f"{request_id}-video-") and mm_uuids[
"video"
][0].endswith("-0")

View File

@ -343,6 +343,7 @@ class MockEngineCore:
eos_token_id: int | None = None, eos_token_id: int | None = None,
stop_token_ids: list[int] | None = None, stop_token_ids: list[int] | None = None,
ignore_eos: bool = False, ignore_eos: bool = False,
request_ids: list[str] | None = None,
) -> None: ) -> None:
self.num_requests = len(tokens_list) self.num_requests = len(tokens_list)
self.tokens_list = tokens_list self.tokens_list = tokens_list
@ -355,6 +356,11 @@ class MockEngineCore:
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
self.stop_token_ids = stop_token_ids self.stop_token_ids = stop_token_ids
self.ignore_eos = ignore_eos self.ignore_eos = ignore_eos
self.request_ids = (
request_ids
if request_ids is not None
else [f"request-{i}" for i in range(self.num_requests)]
)
def get_outputs(self) -> list[EngineCoreOutput]: def get_outputs(self) -> list[EngineCoreOutput]:
do_logprobs = self.do_logprobs do_logprobs = self.do_logprobs
@ -386,7 +392,7 @@ class MockEngineCore:
prompt_logprobs = None prompt_logprobs = None
new_token_id = token_ids[token_idx] new_token_id = token_ids[token_idx]
output = EngineCoreOutput( output = EngineCoreOutput(
request_id=f"request-{req_idx}", request_id=self.request_ids[req_idx],
new_token_ids=[new_token_id], new_token_ids=[new_token_id],
new_logprobs=logprobs, new_logprobs=logprobs,
new_prompt_logprobs_tensors=prompt_logprobs, new_prompt_logprobs_tensors=prompt_logprobs,

View File

@ -41,10 +41,13 @@ from vllm.distributed.kv_transfer.kv_transfer_state import (
has_kv_transfer_group, has_kv_transfer_group,
) )
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import Platform from vllm.platforms.interface import Platform
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import RequestStatus from vllm.v1.request import RequestStatus
@ -1265,6 +1268,22 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
run_test_and_cleanup() run_test_and_cleanup()
class RequestIdMapper:
    """Helper class to map external request IDs to internal request IDs.

    On construction the mapper replaces ``output_processor.add_request`` with
    a wrapper that records each request's ``external_req_id`` ->
    ``request_id`` pair before delegating to the original method. Calling the
    mapper with an external ID returns the internal ID recorded for it.
    """

    def __init__(self, output_processor: OutputProcessor):
        # external request ID -> internal request ID, filled in as requests
        # pass through the wrapped add_request().
        self.req_id_mapping: dict[str, str] = {}
        # Keep the original method so the wrapper can delegate to it.
        self.original_add_request = output_processor.add_request
        # Install the recording wrapper on the processor. Nothing in this
        # class restores the original afterwards.
        output_processor.add_request = self._add_request

    def _add_request(self, request: EngineCoreRequest, *args, **kwargs):
        """Record the request's ID pair, then forward to the original method."""
        self.req_id_mapping[request.external_req_id] = request.request_id
        return self.original_add_request(request, *args, **kwargs)

    def __call__(self, external_req_id: str) -> str:
        """Return the internal request ID recorded for ``external_req_id``.

        Raises:
            KeyError: if no request with that external ID has been added
                through the wrapped ``add_request``.
        """
        try:
            return self.req_id_mapping[external_req_id]
        except KeyError:
            # Same exception type as a plain dict lookup, but the message
            # says what was being looked up so test failures are readable.
            raise KeyError(
                f"no internal request ID recorded for external ID "
                f"{external_req_id!r}"
            ) from None
def _run_abort_timeout_test(llm: LLM, timeout: int): def _run_abort_timeout_test(llm: LLM, timeout: int):
"""Helper function to run the abort timeout test logic.""" """Helper function to run the abort timeout test logic."""
remote_prefill_opts = { remote_prefill_opts = {
@ -1286,24 +1305,34 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
0 0
].req_to_blocks ].req_to_blocks
id_mapper = RequestIdMapper(llm.llm_engine.output_processor)
def req_id(outputs: list[RequestOutput]) -> str:
assert len(outputs) == 1
return id_mapper(outputs[0].request_id)
padding = "Just making this request a little longer so that we're sure " padding = "Just making this request a little longer so that we're sure "
"we're not hitting the small-request lower bound beneath which we don't " "we're not hitting the small-request lower bound beneath which we don't "
"actually trigger the whole kv transfer, but rather just recompute the " "actually trigger the whole kv transfer, but rather just recompute the "
"blocks on D." "blocks on D."
_ = llm.generate([f"What is the capital of Japan? {padding}"], sampling_params) req0_id = req_id(
llm.generate([f"What is the capital of Japan? {padding}"], sampling_params)
)
# Request finished but not freed # Request finished but not freed
assert "0" in scheduler.finished_req_ids and "0" in req_to_blocks assert req0_id in scheduler.finished_req_ids and req0_id in req_to_blocks
# Some other request, 0 still not freed # Some other request, 0 still not freed
_ = llm.generate([f"What is the capital of Italy? {padding}"], sampling_params) req1_id = req_id(
assert "0" in req_to_blocks llm.generate([f"What is the capital of Italy? {padding}"], sampling_params)
assert "1" in scheduler.finished_req_ids and "1" in req_to_blocks )
assert req0_id in req_to_blocks
assert req1_id in scheduler.finished_req_ids and req1_id in req_to_blocks
# Wait for timeout and trigger another scheduler loop # Wait for timeout and trigger another scheduler loop
time.sleep(timeout) time.sleep(timeout)
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params) _ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
# Request-0 times out and is cleared! # Request-0 times out and is cleared!
assert "0" not in req_to_blocks assert req0_id not in req_to_blocks
# Need to shutdown the background thread to release NIXL side channel port # Need to shutdown the background thread to release NIXL side channel port
llm.llm_engine.engine_core.shutdown() llm.llm_engine.engine_core.shutdown()

View File

@ -1621,7 +1621,7 @@ class LLM:
added_request_ids.append(request_id) added_request_ids.append(request_id)
except Exception as e: except Exception as e:
if added_request_ids: if added_request_ids:
self.llm_engine.abort_request(added_request_ids) self.llm_engine.abort_request(added_request_ids, internal=True)
raise e raise e
def _validate_mm_data_and_uuids( def _validate_mm_data_and_uuids(
@ -1731,7 +1731,7 @@ class LLM:
priority=priority, priority=priority,
prompt_text=prompt_text, prompt_text=prompt_text,
) )
return request_id return engine_request.request_id
def _run_engine( def _run_engine(
self, *, use_tqdm: bool | Callable[..., tqdm] = True self, *, use_tqdm: bool | Callable[..., tqdm] = True

View File

@ -75,6 +75,12 @@ class EngineCoreRequest(
trace_headers: Mapping[str, str] | None = None trace_headers: Mapping[str, str] | None = None
# The user-provided request ID. This field is set internally,
# copied from the provided request_id that's originally assigned
# to the request_id field, see InputProcessor.assign_request_id().
# Used in outputs and to support abort(req_id, internal=False).
external_req_id: str | None = None
@property @property
def params(self) -> SamplingParams | PoolingParams: def params(self) -> SamplingParams | PoolingParams:
"""Return the processed params (sampling or pooling).""" """Return the processed params (sampling or pooling)."""

View File

@ -290,12 +290,15 @@ class AsyncLLM(EngineClient):
is_pooling = isinstance(params, PoolingParams) is_pooling = isinstance(params, PoolingParams)
# Create a new output collector for the request.
queue = RequestOutputCollector(output_kind=params.output_kind)
# Convert Input --> Request. # Convert Input --> Request.
if isinstance(prompt, EngineCoreRequest): if isinstance(prompt, EngineCoreRequest):
request = prompt request = prompt
if request_id != request.request_id:
logger.warning_once(
"AsyncLLM.add_request() was passed a request_id parameter that "
"does not match the EngineCoreRequest.request_id attribute. The "
"latter will be used, and the former will be ignored."
)
else: else:
assert prompt_text is None assert prompt_text is None
request = self.input_processor.process_inputs( request = self.input_processor.process_inputs(
@ -314,6 +317,11 @@ class AsyncLLM(EngineClient):
elif isinstance(prompt, Mapping): elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt")) prompt_text = cast(str | None, prompt.get("prompt"))
self.input_processor.assign_request_id(request)
# Create a new output collector for the request.
queue = RequestOutputCollector(params.output_kind, request.request_id)
# Use cloned params that may have been updated in process_inputs() # Use cloned params that may have been updated in process_inputs()
params = request.params params = request.params
@ -325,7 +333,7 @@ class AsyncLLM(EngineClient):
assert isinstance(parent_params, SamplingParams) assert isinstance(parent_params, SamplingParams)
# Fan out child requests (for n>1). # Fan out child requests (for n>1).
parent_request = ParentRequest(request_id, parent_params) parent_request = ParentRequest(request)
for idx in range(parent_params.n): for idx in range(parent_params.n):
request_id, child_params = parent_request.get_child_info(idx) request_id, child_params = parent_request.get_child_info(idx)
child_request = request if idx == parent_params.n - 1 else copy(request) child_request = request if idx == parent_params.n - 1 else copy(request)
@ -396,6 +404,7 @@ class AsyncLLM(EngineClient):
"prompt logprobs" "prompt logprobs"
) )
q: RequestOutputCollector | None = None
try: try:
# We start the output_handler on the first call to generate() so # We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us # we can call __init__ before the event loop, which enables us
@ -446,7 +455,8 @@ class AsyncLLM(EngineClient):
# is cancelled or the generator is garbage collected. So, # is cancelled or the generator is garbage collected. So,
# we abort the request if we end up here. # we abort the request if we end up here.
except (asyncio.CancelledError, GeneratorExit): except (asyncio.CancelledError, GeneratorExit):
await self.abort(request_id) if q is not None:
await self.abort(q.request_id, internal=True)
if self.log_requests: if self.log_requests:
logger.info("Request %s aborted.", request_id) logger.info("Request %s aborted.", request_id)
raise raise
@ -465,7 +475,8 @@ class AsyncLLM(EngineClient):
# Unexpected error in the generate() task (possibly recoverable). # Unexpected error in the generate() task (possibly recoverable).
except Exception as e: except Exception as e:
await self.abort(request_id) if q is not None:
await self.abort(q.request_id, internal=True)
if self.log_requests: if self.log_requests:
logger.info("Request %s failed.", request_id) logger.info("Request %s failed.", request_id)
raise EngineGenerateError() from e raise EngineGenerateError() from e
@ -541,13 +552,15 @@ class AsyncLLM(EngineClient):
self.output_handler = asyncio.create_task(output_handler()) self.output_handler = asyncio.create_task(output_handler())
async def abort(self, request_id: str | Iterable[str]) -> None: async def abort(
self, request_id: str | Iterable[str], internal: bool = False
) -> None:
"""Abort RequestId in OutputProcessor and EngineCore.""" """Abort RequestId in OutputProcessor and EngineCore."""
request_ids = ( request_ids = (
(request_id,) if isinstance(request_id, str) else as_list(request_id) (request_id,) if isinstance(request_id, str) else as_list(request_id)
) )
all_request_ids = self.output_processor.abort_requests(request_ids) all_request_ids = self.output_processor.abort_requests(request_ids, internal)
await self.engine_core.abort_requests_async(all_request_ids) await self.engine_core.abort_requests_async(all_request_ids)
if self.log_requests: if self.log_requests:
@ -581,7 +594,7 @@ class AsyncLLM(EngineClient):
if not wait_for_inflight_requests: if not wait_for_inflight_requests:
request_ids = list(self.output_processor.request_states.keys()) request_ids = list(self.output_processor.request_states.keys())
if request_ids: if request_ids:
await self.abort(request_ids) await self.abort(request_ids, internal=True)
# Wait for running requests to drain before clearing cache. # Wait for running requests to drain before clearing cache.
if self.output_processor.has_unfinished_requests(): if self.output_processor.has_unfinished_requests():
@ -633,6 +646,7 @@ class AsyncLLM(EngineClient):
TODO: Remove truncate_prompt_tokens in v0.15. TODO: Remove truncate_prompt_tokens in v0.15.
""" """
q: RequestOutputCollector | None = None
try: try:
# We start the output_handler on the first call to generate() so # We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us # we can call __init__ before the event loop, which enables us
@ -687,7 +701,8 @@ class AsyncLLM(EngineClient):
# If the request is disconnected by the client, generate() # If the request is disconnected by the client, generate()
# is cancelled. So, we abort the request if we end up here. # is cancelled. So, we abort the request if we end up here.
except asyncio.CancelledError: except asyncio.CancelledError:
await self.abort(request_id) if q is not None:
await self.abort(q.request_id, internal=True)
if self.log_requests: if self.log_requests:
logger.info("Request %s aborted.", request_id) logger.info("Request %s aborted.", request_id)
raise raise
@ -706,7 +721,8 @@ class AsyncLLM(EngineClient):
# Unexpected error in the generate() task (possibly recoverable). # Unexpected error in the generate() task (possibly recoverable).
except Exception as e: except Exception as e:
await self.abort(request_id) if q is not None:
await self.abort(q.request_id, internal=True)
if self.log_requests: if self.log_requests:
logger.info("Request %s failed.", request_id) logger.info("Request %s failed.", request_id)
raise EngineGenerateError() from e raise EngineGenerateError() from e

View File

@ -21,7 +21,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats from vllm.v1.metrics.stats import MultiModalCacheStats
from vllm.v1.structured_output.backend_guidance import ( from vllm.v1.structured_output.backend_guidance import (
@ -406,6 +406,19 @@ class InputProcessor:
mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)] mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
return mm_uuids return mm_uuids
@staticmethod
def assign_request_id(request: EngineCoreRequest):
    """Swap the caller-supplied request ID for a unique internal one.

    The ID currently held in ``request.request_id`` is preserved in
    ``request.external_req_id``. ``request.request_id`` is then rewritten
    as that ID, a ``-`` separator, and a random suffix taken from
    ``random_uuid()`` (limited to 8 characters by the ``.8`` format
    precision) in order to ensure uniqueness.

    Raises:
        ValueError: If ``request.external_req_id`` is already set, since
            that field is meant to be populated only by this method.
    """
    if request.external_req_id is not None:
        message = (
            "The external_req_id field should not be set on EngineCoreRequests"
            " passed to vLLM; use the request_id field."
        )
        raise ValueError(message)

    # Record the externally provided ID before request_id is overwritten.
    provided_id = request.request_id
    request.external_req_id = provided_id
    # Internal ID = external ID + "-" + truncated random suffix.
    request.request_id = f"{provided_id}-{random_uuid():.8}"
def process_inputs( def process_inputs(
self, self,
request_id: str, request_id: str,

View File

@ -213,10 +213,10 @@ class LLMEngine:
def get_supported_tasks(self) -> tuple[SupportedTask, ...]: def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.engine_core.get_supported_tasks() return self.engine_core.get_supported_tasks()
def abort_request(self, request_ids: list[str]) -> None: def abort_request(self, request_ids: list[str], internal: bool = False) -> None:
"""Remove request_ids from EngineCore and Detokenizer.""" """Remove request_ids from EngineCore and Detokenizer."""
request_ids = self.output_processor.abort_requests(request_ids) request_ids = self.output_processor.abort_requests(request_ids, internal)
self.engine_core.abort_requests(request_ids) self.engine_core.abort_requests(request_ids)
def add_request( def add_request(
@ -238,6 +238,12 @@ class LLMEngine:
# Process raw inputs into the request. # Process raw inputs into the request.
if isinstance(prompt, EngineCoreRequest): if isinstance(prompt, EngineCoreRequest):
request = prompt request = prompt
if request_id != request.request_id:
logger.warning_once(
"AsyncLLM.add_request() was passed a request_id parameter that "
"does not match the EngineCoreRequest.request_id attribute. The "
"latter will be used, and the former will be ignored."
)
else: else:
assert prompt_text is None assert prompt_text is None
request = self.input_processor.process_inputs( request = self.input_processor.process_inputs(
@ -255,6 +261,8 @@ class LLMEngine:
elif isinstance(prompt, Mapping): elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt")) prompt_text = cast(str | None, prompt.get("prompt"))
self.input_processor.assign_request_id(request)
# Use cloned params that may have been updated in process_inputs() # Use cloned params that may have been updated in process_inputs()
params = request.params params = request.params
@ -268,7 +276,7 @@ class LLMEngine:
return return
# Fan out child requests (for n>1). # Fan out child requests (for n>1).
parent_req = ParentRequest(request_id, params) parent_req = ParentRequest(request)
for idx in range(n): for idx in range(n):
request_id, child_params = parent_req.get_child_info(idx) request_id, child_params = parent_req.get_child_info(idx)
child_request = request if idx == n - 1 else copy(request) child_request = request if idx == n - 1 else copy(request)

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio import asyncio
from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, cast from typing import Any, cast
@ -40,8 +41,9 @@ class RequestOutputCollector:
producer gets ahead of the consumer. producer gets ahead of the consumer.
""" """
def __init__(self, output_kind: RequestOutputKind): def __init__(self, output_kind: RequestOutputKind, request_id: str):
self.aggregate = output_kind == RequestOutputKind.DELTA self.aggregate = output_kind == RequestOutputKind.DELTA
self.request_id = request_id
self.output: RequestOutput | PoolingRequestOutput | Exception | None = None self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
self.ready = asyncio.Event() self.ready = asyncio.Event()
@ -92,6 +94,7 @@ class RequestState:
def __init__( def __init__(
self, self,
request_id: str, request_id: str,
external_req_id: str,
parent_req: ParentRequest | None, parent_req: ParentRequest | None,
request_index: int, request_index: int,
lora_request: LoRARequest | None, lora_request: LoRARequest | None,
@ -111,6 +114,7 @@ class RequestState:
temperature: float | None = None, temperature: float | None = None,
): ):
self.request_id = request_id self.request_id = request_id
self.external_req_id = external_req_id
self.parent_req = parent_req self.parent_req = parent_req
self.request_index = request_index self.request_index = request_index
self.lora_request = lora_request self.lora_request = lora_request
@ -176,8 +180,10 @@ class RequestState:
assert request.pooling_params is not None assert request.pooling_params is not None
output_kind = request.pooling_params.output_kind output_kind = request.pooling_params.output_kind
assert request.external_req_id is not None
return cls( return cls(
request_id=request.request_id, request_id=request.request_id,
external_req_id=request.external_req_id,
parent_req=parent_req, parent_req=parent_req,
request_index=request_index, request_index=request_index,
lora_request=request.lora_request, lora_request=request.lora_request,
@ -235,10 +241,13 @@ class RequestState:
] ]
self.sent_tokens_offset = len(self.detokenizer.output_token_ids) self.sent_tokens_offset = len(self.detokenizer.output_token_ids)
request_id = self.request_id external_req_id = self.external_req_id
if pooling_output is not None: if pooling_output is not None:
return self._new_request_output( return self._new_request_output(
request_id, [self._new_pooling_output(pooling_output)], finished external_req_id,
[self._new_pooling_output(pooling_output)],
finished,
) )
output = self._new_completion_output(new_token_ids, finish_reason, stop_reason) output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)
@ -246,19 +255,18 @@ class RequestState:
if self.parent_req is None: if self.parent_req is None:
outputs = [output] outputs = [output]
else: else:
request_id, outputs, finished = self.parent_req.get_outputs( outputs, finished = self.parent_req.get_outputs(self.request_id, output)
request_id, output
)
if not outputs: if not outputs:
return None return None
external_req_id = self.parent_req.external_req_id
return self._new_request_output( return self._new_request_output(
request_id, outputs, finished, kv_transfer_params external_req_id, outputs, finished, kv_transfer_params
) )
def _new_request_output( def _new_request_output(
self, self,
request_id: str, external_req_id: str,
outputs: list[CompletionOutput] | list[PoolingOutput], outputs: list[CompletionOutput] | list[PoolingOutput],
finished: bool, finished: bool,
kv_transfer_params: dict[str, Any] | None = None, kv_transfer_params: dict[str, Any] | None = None,
@ -269,7 +277,7 @@ class RequestState:
# Prompt embeddings are currently not supported by pooling requests. # Prompt embeddings are currently not supported by pooling requests.
assert self.prompt_token_ids is not None assert self.prompt_token_ids is not None
return PoolingRequestOutput( return PoolingRequestOutput(
request_id=request_id, request_id=external_req_id,
outputs=first_output, outputs=first_output,
num_cached_tokens=self.num_cached_tokens, num_cached_tokens=self.num_cached_tokens,
prompt_token_ids=self.prompt_token_ids, prompt_token_ids=self.prompt_token_ids,
@ -288,7 +296,7 @@ class RequestState:
prompt_token_ids = [0] * len(self.prompt_embeds) prompt_token_ids = [0] * len(self.prompt_embeds)
return RequestOutput( return RequestOutput(
request_id=request_id, request_id=external_req_id, # request_id is what was provided externally
lora_request=self.lora_request, lora_request=self.lora_request,
prompt=self.prompt, prompt=self.prompt,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
@ -352,6 +360,7 @@ class OutputProcessor:
self.stream_interval = stream_interval self.stream_interval = stream_interval
self.request_states: dict[str, RequestState] = {} self.request_states: dict[str, RequestState] = {}
self.parent_requests: dict[str, ParentRequest] = {} self.parent_requests: dict[str, ParentRequest] = {}
self.external_req_ids: defaultdict[str, list[str]] = defaultdict(list)
self.lora_states = LoRARequestStates(log_stats) self.lora_states = LoRARequestStates(log_stats)
self.tracer: Tracer | None = None self.tracer: Tracer | None = None
self._requests_drained = asyncio.Event() self._requests_drained = asyncio.Event()
@ -375,12 +384,41 @@ class OutputProcessor:
assert state.queue is not None assert state.queue is not None
state.queue.put(e) state.queue.put(e)
def abort_requests( def abort_requests(self, request_ids: Iterable[str], internal: bool) -> list[str]:
self, """Abort a list of requests.
request_ids: Iterable[str],
) -> list[str]: The request_ids may be either external request IDs (those passed to
request_ids_to_abort = [] InputProcessor.process_inputs()) or internal request IDs (those randomly
generated when creating the EngineCoreRequest).
If an external request ID is provided, and that external request ID
was used for multiple requests, all requests associated with that external
request ID are aborted.
In the case of parallel sampling, a request ID may be used to identify
a parent request, in which case the associated child requests are aborted
also.
"""
internal_req_ids = []
for request_id in request_ids: for request_id in request_ids:
if internal:
# Internal ID - this may be a parent request
internal_req_ids.append(request_id)
# Remove internal ID from the external->internal mapping
if req_state := self.request_states.get(request_id):
external_req_id = req_state.external_req_id
internal_ids = self.external_req_ids[external_req_id]
internal_ids.remove(request_id)
if not internal_ids:
del self.external_req_ids[external_req_id]
elif internal_ids := self.external_req_ids.pop(request_id, []):
# External ID - abort all requests in the external->internal mapping
internal_req_ids.extend(internal_ids)
request_ids_to_abort = []
for request_id in internal_req_ids:
req_state = self.request_states.pop(request_id, None) req_state = self.request_states.pop(request_id, None)
if req_state is not None: if req_state is not None:
self.lora_states.request_finished(request_id, req_state.lora_name) self.lora_states.request_finished(request_id, req_state.lora_name)
@ -404,7 +442,7 @@ class OutputProcessor:
# Abort children prior to removing the parent. # Abort children prior to removing the parent.
if parent.child_requests: if parent.child_requests:
child_reqs = list(parent.child_requests) child_reqs = list(parent.child_requests)
child_reqs = self.abort_requests(child_reqs) child_reqs = self.abort_requests(child_reqs, internal=True)
request_ids_to_abort.extend(child_reqs) request_ids_to_abort.extend(child_reqs)
self.parent_requests.pop(request_id, None) self.parent_requests.pop(request_id, None)
if not self.request_states: if not self.request_states:
@ -439,6 +477,9 @@ class OutputProcessor:
if parent_req: if parent_req:
self.parent_requests[parent_req.request_id] = parent_req self.parent_requests[parent_req.request_id] = parent_req
# Track the external_req_id -> [internal_req_id, ...] mapping
self.external_req_ids[req_state.external_req_id].append(request_id)
def process_outputs( def process_outputs(
self, self,
engine_core_outputs: list[EngineCoreOutput], engine_core_outputs: list[EngineCoreOutput],
@ -522,6 +563,12 @@ class OutputProcessor:
# Free completed requests. # Free completed requests.
if finish_reason is not None: if finish_reason is not None:
self.request_states.pop(req_id) self.request_states.pop(req_id)
internal_ids = self.external_req_ids[req_state.external_req_id]
internal_ids.remove(req_id)
if not internal_ids:
del self.external_req_ids[req_state.external_req_id]
# Remove parent request if applicable. # Remove parent request if applicable.
parent_req = req_state.parent_req parent_req = req_state.parent_req
if parent_req and not parent_req.child_requests: if parent_req and not parent_req.child_requests:
@ -597,7 +644,9 @@ class OutputProcessor:
) )
# meta # meta
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, req_state.request_id) span.set_attribute(
SpanAttributes.GEN_AI_REQUEST_ID, req_state.external_req_id
)
if req_state.top_p: if req_state.top_p:
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p) span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
if req_state.max_tokens_param: if req_state.max_tokens_param:

View File

@ -6,6 +6,7 @@ from typing import Optional, cast
from vllm.outputs import CompletionOutput from vllm.outputs import CompletionOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import IterationStats from vllm.v1.metrics.stats import IterationStats
@ -17,6 +18,7 @@ class ParentRequest:
""" """
request_id: str request_id: str
external_req_id: str
sampling_params: SamplingParams sampling_params: SamplingParams
# To track the completion of child requests # To track the completion of child requests
@ -31,8 +33,11 @@ class ParentRequest:
# To efficiently obtain child sampling params # To efficiently obtain child sampling params
cached_child_sampling_params: SamplingParams | None cached_child_sampling_params: SamplingParams | None
def __init__(self, request_id: str, sampling_params: SamplingParams) -> None: def __init__(self, request: EngineCoreRequest) -> None:
self.request_id = request_id assert request.external_req_id is not None
sampling_params = request.params
self.request_id = request.request_id
self.external_req_id = request.external_req_id
self.sampling_params = sampling_params self.sampling_params = sampling_params
self.child_requests = set() self.child_requests = set()
@ -96,7 +101,7 @@ class ParentRequest:
self, self,
child_request_id: str, child_request_id: str,
completion_output: CompletionOutput, completion_output: CompletionOutput,
) -> tuple[str, list[CompletionOutput], bool]: ) -> tuple[list[CompletionOutput], bool]:
already_finished_and_returned: bool = False already_finished_and_returned: bool = False
if completion_output.finished(): if completion_output.finished():
if child_request_id in self.child_requests: if child_request_id in self.child_requests:
@ -118,7 +123,7 @@ class ParentRequest:
outputs = [] if self.child_requests else self.output_aggregator outputs = [] if self.child_requests else self.output_aggregator
finished = not self.child_requests finished = not self.child_requests
return self.request_id, outputs, finished return outputs, finished
def observe_num_generation_tokens(self, num_generation_tokens: int): def observe_num_generation_tokens(self, num_generation_tokens: int):
self.max_num_generation_tokens = max( self.max_num_generation_tokens = max(