Merge branch 'main' into elvischenv/update-flashinfer

This commit is contained in:
elvischenv 2025-12-24 09:53:17 +08:00 committed by GitHub
commit 30d32ef5c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
30 changed files with 657 additions and 606 deletions

View File

@ -2,4 +2,4 @@
vLLM can be deployed with [KServe](https://github.com/kserve/kserve) on Kubernetes for highly scalable distributed model serving.
Please see [this guide](https://kserve.github.io/website/docs/model-serving/generative-inference/overview) for more details on using vLLM with KServe.
You can use vLLM with KServe's [Hugging Face serving runtime](https://kserve.github.io/website/docs/model-serving/generative-inference/overview) or via [`LLMInferenceService` that uses llm-d](https://kserve.github.io/website/docs/model-serving/generative-inference/llmisvc/llmisvc-overview).

View File

@ -0,0 +1,5 @@
# llm-d
vLLM can be deployed with [llm-d](https://github.com/llm-d/llm-d), a Kubernetes-native distributed inference serving stack providing well-lit paths for anyone to serve large generative AI models at scale. It helps achieve the fastest "time to state-of-the-art (SOTA) performance" for key OSS models across most hardware accelerators and infrastructure providers.
You can use vLLM with llm-d directly by following [this guide](https://llm-d.ai/docs/guide) or via [KServe's LLMInferenceService](https://kserve.github.io/website/docs/model-serving/generative-inference/llmisvc/llmisvc-overview).

View File

@ -12,6 +12,7 @@ Alternatively, you can deploy vLLM to Kubernetes using any of the following:
- [Helm](frameworks/helm.md)
- [InftyAI/llmaz](integrations/llmaz.md)
- [llm-d](integrations/llm-d.md)
- [KAITO](integrations/kaito.md)
- [KServe](integrations/kserve.md)
- [Kthena](integrations/kthena.md)

View File

@ -0,0 +1,11 @@
model_name: "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8"
accuracy_threshold: 0.85
num_questions: 1319
num_fewshot: 5
server_args: >-
--max-model-len 4096
--tensor-parallel-size 2
--enable-expert-parallel
--async-scheduling
env:
VLLM_USE_FLASHINFER_MOE_FP8: "1"

View File

@ -4,3 +4,4 @@ Qwen1.5-MoE-W4A16-CT.yaml
DeepSeek-V2-Lite-Instruct-FP8.yaml
Qwen3-30B-A3B-NVFP4.yaml
Qwen3-Next-80B-A3B-NVFP4-EP2.yaml
Qwen3-Next-FP8-EP2.yaml

View File

@ -71,6 +71,7 @@ def test_gsm8k_correctness(config_filename):
print(f"Number of questions: {eval_config['num_questions']}")
print(f"Number of few-shot examples: {eval_config['num_fewshot']}")
print(f"Server args: {' '.join(server_args)}")
print(f"Environment variables: {env_dict}")
# Launch server and run evaluation
with RemoteOpenAIServer(

View File

@ -19,7 +19,7 @@ def pytest_collection_modifyitems(config, items):
return
# Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
# accuracy issues
# accuracy issues: https://github.com/vllm-project/vllm/issues/30167
# TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)

View File

@ -106,6 +106,7 @@ class RemoteOpenAIServer:
env.update(env_dict)
serve_cmd = ["vllm", "serve", model, *vllm_serve_args]
print(f"Launching RemoteOpenAIServer with: {' '.join(serve_cmd)}")
print(f"Environment variables: {env}")
self.proc: subprocess.Popen = subprocess.Popen(
serve_cmd,
env=env,

View File

@ -260,7 +260,7 @@ async def test_multi_abort(output_kind: RequestOutputKind):
# Use multi-abort to abort multiple requests at once
abort_request_ids = [request_ids[i] for i in REQUEST_IDS_TO_ABORT]
await engine.abort(abort_request_ids)
await engine.abort(abort_request_ids, internal=False)
# Wait for all tasks to complete
results = await asyncio.gather(*tasks, return_exceptions=True)
@ -609,7 +609,7 @@ async def test_abort_final_output(output_kind: RequestOutputKind):
await asyncio.sleep(0.5)
# Abort the request
await engine.abort(request_id)
await engine.abort(request_id, internal=False)
# Wait for generation to complete and return final output
final_output = await generated

View File

@ -40,10 +40,16 @@ TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
PROMPT = "I am Gyoubu Masataka Oniwa"
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
_REQUEST_COUNTER = 0
def make_request() -> EngineCoreRequest:
global _REQUEST_COUNTER
_REQUEST_COUNTER += 1
request_id = f"request-{_REQUEST_COUNTER}"
return EngineCoreRequest(
request_id=str(uuid.uuid4()),
request_id=request_id,
external_req_id=f"{request_id}-{uuid.uuid4()}",
prompt_token_ids=PROMPT_TOKENS,
mm_features=None,
sampling_params=SamplingParams(),

View File

@ -45,6 +45,8 @@ TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
PROMPT = "Hello my name is Robert and I love quantization kernels"
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
_REQUEST_COUNTER = 0
def make_request(
params: SamplingParams, prompt_tokens_ids: list[int] | None = None
@ -52,8 +54,12 @@ def make_request(
if not prompt_tokens_ids:
prompt_tokens_ids = PROMPT_TOKENS
global _REQUEST_COUNTER
_REQUEST_COUNTER += 1
request_id = f"request-{_REQUEST_COUNTER}"
return EngineCoreRequest(
request_id=str(uuid.uuid4()),
request_id=request_id,
external_req_id=f"{request_id}-{uuid.uuid4()}",
prompt_token_ids=prompt_tokens_ids,
mm_features=None,
sampling_params=params,

View File

@ -27,6 +27,7 @@ def test_fast_inc_detok_invalid_utf8_err_case():
params = SamplingParams(skip_special_tokens=True)
request = EngineCoreRequest(
request_id="test",
external_req_id="test-ext",
prompt_token_ids=prompt_token_ids,
mm_features=None,
sampling_params=params,

View File

@ -58,12 +58,12 @@ def test_incremental_detokenization(
output_processor = OutputProcessor(
dummy_test_vectors.tokenizer, log_stats=False, stream_interval=stream_interval
)
engine_core = MockEngineCore(tokens_list=dummy_test_vectors.generation_tokens)
# Make N requests.
requests = [
EngineCoreRequest(
request_id=f"request-{idx}",
request_id=f"request-{idx}-int",
external_req_id=f"request-{idx}",
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
@ -83,6 +83,11 @@ def test_incremental_detokenization(
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
request_ids=[req.request_id for req in requests],
)
# Add requests to the detokenizer.
for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
output_processor.add_request(request, prompt)
@ -438,15 +443,6 @@ def test_logprobs_processor(
dummy_test_vectors,
):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
generated_logprobs_raw=None
if num_sample_logprobs is None
else dummy_test_vectors.generation_logprobs,
prompt_logprobs_raw=None
if num_prompt_logprobs is None
else dummy_test_vectors.prompt_logprobs,
)
# Make N requests.
request_id_list = [
@ -454,7 +450,8 @@ def test_logprobs_processor(
]
requests = [
EngineCoreRequest(
request_id=request_id_list[idx],
request_id=request_id_list[idx] + "-int",
external_req_id=request_id_list[idx],
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
@ -476,6 +473,17 @@ def test_logprobs_processor(
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
generated_logprobs_raw=None
if num_sample_logprobs is None
else dummy_test_vectors.generation_logprobs,
prompt_logprobs_raw=None
if num_prompt_logprobs is None
else dummy_test_vectors.prompt_logprobs,
request_ids=[req.request_id for req in requests],
)
# Add requests to the detokenizer.
for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
output_processor.add_request(request, prompt)
@ -621,19 +629,12 @@ def test_stop_token(
]
prompt_string = dummy_test_vectors.prompt_strings[0]
prompt_tokens = dummy_test_vectors.prompt_tokens[0]
engine_core = MockEngineCore(
tokens_list=[generation_tokens],
generated_logprobs_raw=[generation_logprobs] if do_logprobs else None,
prompt_logprobs_raw=None,
eos_token_id=eos_token_id,
stop_token_ids=stop_token_ids,
ignore_eos=ignore_eos,
)
# Make request.
request_id = "request-0"
request = EngineCoreRequest(
request_id=request_id,
external_req_id=request_id + "-ext",
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=eos_token_id,
@ -655,6 +656,16 @@ def test_stop_token(
pooling_params=None,
)
engine_core = MockEngineCore(
tokens_list=[generation_tokens],
generated_logprobs_raw=[generation_logprobs] if do_logprobs else None,
prompt_logprobs_raw=None,
eos_token_id=eos_token_id,
stop_token_ids=stop_token_ids,
ignore_eos=ignore_eos,
request_ids=[request.request_id],
)
# Add request to the detokenizer.
output_processor.add_request(request, prompt_string)
@ -720,13 +731,6 @@ def test_stop_string(
dummy_test_vectors,
):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
generated_logprobs_raw=dummy_test_vectors.generation_logprobs
if num_sample_logprobs
else None,
prompt_logprobs_raw=None,
)
# Make N requests.
request_id_list = [
@ -734,7 +738,8 @@ def test_stop_string(
]
requests = [
EngineCoreRequest(
request_id=request_id_list[idx],
request_id=request_id_list[idx] + "-int",
external_req_id=request_id_list[idx],
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
@ -756,6 +761,15 @@ def test_stop_string(
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
generated_logprobs_raw=dummy_test_vectors.generation_logprobs
if num_sample_logprobs
else None,
prompt_logprobs_raw=None,
request_ids=[req.request_id for req in requests],
)
# Add requests to the detokenizer.
for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
output_processor.add_request(request, prompt)
@ -813,9 +827,12 @@ def test_stop_string(
for idx, (ref_gen_str, stop_str) in enumerate(
zip(dummy_test_vectors.generation_strings, STOP_STRINGS)
):
# Request should be aborted.
# Request should be aborted (check internal ID in abort list).
internal_request_id = f"request-{idx}-int"
assert internal_request_id in aborted
# Use external ID for collecting outputs
request_id = f"request-{idx}"
assert request_id in aborted
# Collected values that were generated.
gen_str = gen_strings[request_id]
@ -848,13 +865,13 @@ def test_stop_string(
def test_iteration_stats(dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
engine_core_timestamp = time.monotonic()
# Make N requests.
requests = [
EngineCoreRequest(
request_id=f"request-{idx}",
external_req_id=f"request-{idx}-ext",
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
@ -868,6 +885,11 @@ def test_iteration_stats(dummy_test_vectors):
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
engine_core = MockEngineCore(
dummy_test_vectors.generation_tokens,
request_ids=[req.request_id for req in requests],
)
# Add all requests except one to the OutputProcessor.
num_active = len(dummy_test_vectors.generation_tokens) - 1
for request in requests[:num_active]:
@ -922,7 +944,6 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
output_processor = OutputProcessor(
dummy_test_vectors.tokenizer, log_stats=log_stats
)
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
engine_core_timestamp = time.monotonic()
# Create LoRA requests
@ -936,7 +957,8 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
lora_assignments = [lora1, lora2, None]
requests = [
EngineCoreRequest(
request_id=f"request-{idx}",
request_id=f"request-{idx}-int",
external_req_id=f"request-{idx}",
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
@ -950,6 +972,11 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
engine_core = MockEngineCore(
dummy_test_vectors.generation_tokens,
request_ids=[req.request_id for req in requests],
)
# Add all requests to the OutputProcessor
for request in requests:
output_processor.add_request(request, None)
@ -1015,9 +1042,9 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
outputs = EngineCoreOutputs(
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
)
# Find and mark request-0 as finished (it uses lora-1)
# Find and mark request-0-int as finished (it uses lora-1)
for output in outputs.outputs:
if output.request_id == "request-0":
if output.request_id == "request-0-int":
output.finish_reason = FinishReason.LENGTH
break
@ -1040,9 +1067,9 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
outputs = EngineCoreOutputs(
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
)
# Find and mark request-1 as finished (it uses lora-2)
# Find and mark request-1-int as finished (it uses lora-2)
for output in outputs.outputs:
if output.request_id == "request-1":
if output.request_id == "request-1-int":
output.finish_reason = FinishReason.LENGTH
break
@ -1064,9 +1091,9 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
outputs = EngineCoreOutputs(
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
)
# Find and mark request-2 as finished (it has no LoRA)
# Find and mark request-2-int as finished (it has no LoRA)
for output in outputs.outputs:
if output.request_id == "request-2":
if output.request_id == "request-2-int":
output.finish_reason = FinishReason.LENGTH
break
@ -1107,7 +1134,9 @@ async def test_request_output_collector():
for idx in range(NUM_REQS)
]
collector = RequestOutputCollector(RequestOutputKind.DELTA)
collector = RequestOutputCollector(
RequestOutputKind.DELTA, request_id="my-request-id-int"
)
# CASE 1: Put then get.
outputs = make_outputs()
@ -1163,7 +1192,9 @@ async def test_request_output_collector():
@pytest.mark.asyncio
async def test_cumulative_output_collector_n():
"""Test collector correctly handles multiple outputs by index."""
collector = RequestOutputCollector(RequestOutputKind.CUMULATIVE)
collector = RequestOutputCollector(
RequestOutputKind.CUMULATIVE, request_id="my-request-id-int"
)
outputs = [
RequestOutput(
request_id="my-request-id",
@ -1242,11 +1273,13 @@ async def test_cumulative_output_collector_n():
@pytest.mark.parametrize("runner", ["generate", "pooling"])
def test_abort_requests(runner: str, dummy_test_vectors):
@pytest.mark.parametrize("abort_by", ["internal", "external"])
def test_abort_requests(runner: str, abort_by: str, dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
requests = [
EngineCoreRequest(
request_id=f"request-{idx}",
external_req_id=f"external-{idx}",
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
@ -1265,8 +1298,13 @@ def test_abort_requests(runner: str, dummy_test_vectors):
output_kind = request.sampling_params.output_kind
else:
output_kind = request.pooling_params.output_kind
queue = RequestOutputCollector(output_kind=output_kind)
queue = RequestOutputCollector(
output_kind=output_kind, request_id=request.request_id
)
output_processor.add_request(request, None, queue=queue)
for request in requests:
output_processor.abort_requests([request.request_id])
if abort_by == "internal":
output_processor.abort_requests([request.request_id], internal=True)
else:
output_processor.abort_requests([request.external_req_id], internal=False)

View File

@ -4,11 +4,12 @@
from vllm import SamplingParams
from vllm.outputs import CompletionOutput
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.parallel_sampling import ParentRequest
def test_parent_request_to_output_stream() -> None:
parent_request = ParentRequest("parent_id", SamplingParams(n=2))
parent_request = ParentRequest(make_request(SamplingParams(n=2)))
parent_request.child_requests = {"child_id_0", "child_id_1"}
output_0 = CompletionOutput(
index=0, text="child 0", token_ids=[], cumulative_logprob=None, logprobs=None
@ -17,51 +18,31 @@ def test_parent_request_to_output_stream() -> None:
index=1, text="child 1", token_ids=[], cumulative_logprob=None, logprobs=None
)
# Request not finished
assert ("parent_id", [output_0], False) == parent_request.get_outputs(
"child_id_0", output_0
)
assert ("parent_id", [output_1], False) == parent_request.get_outputs(
"child_id_1", output_1
)
assert ("parent_id", [output_0], False) == parent_request.get_outputs(
"child_id_0", output_0
)
assert ("parent_id", [output_1], False) == parent_request.get_outputs(
"child_id_1", output_1
)
assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
assert ([output_1], False) == parent_request.get_outputs("child_id_1", output_1)
assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
assert ([output_1], False) == parent_request.get_outputs("child_id_1", output_1)
# output_1 finished
output_1.finish_reason = "ended"
assert ("parent_id", [output_0], False) == parent_request.get_outputs(
"child_id_0", output_0
)
assert ("parent_id", [output_1], False) == parent_request.get_outputs(
"child_id_1", output_1
)
assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
assert ([output_1], False) == parent_request.get_outputs("child_id_1", output_1)
# Finished output_1 had already returned, DO NOT returned again
assert ("parent_id", [output_0], False) == parent_request.get_outputs(
"child_id_0", output_0
)
assert parent_request.get_outputs("child_id_1", output_1) == (
"parent_id",
[],
False,
)
assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
assert parent_request.get_outputs("child_id_1", output_1) == ([], False)
# output_0 finished
output_0.finish_reason = "ended"
assert ("parent_id", [output_0], True) == parent_request.get_outputs(
"child_id_0", output_0
)
assert parent_request.get_outputs("child_id_1", output_1) == ("parent_id", [], True)
assert ([output_0], True) == parent_request.get_outputs("child_id_0", output_0)
assert parent_request.get_outputs("child_id_1", output_1) == ([], True)
# Finished output_0 had already returned, DO NOT returned again
assert parent_request.get_outputs("child_id_0", output_0) == ("parent_id", [], True)
assert parent_request.get_outputs("child_id_1", output_1) == ("parent_id", [], True)
assert parent_request.get_outputs("child_id_0", output_0) == ([], True)
assert parent_request.get_outputs("child_id_1", output_1) == ([], True)
def test_parent_request_to_output_final_only() -> None:
parent_request = ParentRequest(
"parent_id", SamplingParams(n=2, output_kind=RequestOutputKind.FINAL_ONLY)
make_request(SamplingParams(n=2, output_kind=RequestOutputKind.FINAL_ONLY))
)
parent_request.child_requests = {"child_id_0", "child_id_1"}
output_0 = CompletionOutput(
@ -71,33 +52,33 @@ def test_parent_request_to_output_final_only() -> None:
index=1, text="child 1", token_ids=[], cumulative_logprob=None, logprobs=None
)
# Request not finished, return nothing
assert parent_request.get_outputs("child_id_0", output_0) == (
"parent_id",
[],
False,
)
assert parent_request.get_outputs("child_id_1", output_1) == (
"parent_id",
[],
False,
)
assert parent_request.get_outputs("child_id_0", output_0) == ([], False)
assert parent_request.get_outputs("child_id_1", output_1) == ([], False)
# output_1 finished, but outputs won't be returned until all child requests finished
output_1.finish_reason = "ended"
assert parent_request.get_outputs("child_id_0", output_0) == (
"parent_id",
[],
False,
)
assert parent_request.get_outputs("child_id_1", output_1) == (
"parent_id",
[],
False,
)
assert parent_request.get_outputs("child_id_0", output_0) == ([], False)
assert parent_request.get_outputs("child_id_1", output_1) == ([], False)
# output_0 finished, as all child requests finished, the output would be returned
output_0.finish_reason = "ended"
assert ("parent_id", [output_0, output_1], True) == parent_request.get_outputs(
assert ([output_0, output_1], True) == parent_request.get_outputs(
"child_id_0", output_0
)
assert ("parent_id", [output_0, output_1], True) == parent_request.get_outputs(
assert ([output_0, output_1], True) == parent_request.get_outputs(
"child_id_1", output_1
)
def make_request(sampling_params: SamplingParams) -> EngineCoreRequest:
return EngineCoreRequest(
request_id="parent_id",
external_req_id="ext_parent_id",
prompt_token_ids=None,
mm_features=None,
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=None,
arrival_time=0.0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
)

View File

@ -6,6 +6,7 @@ import pytest
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig
from vllm.multimodal import MultiModalUUIDDict
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import input_processor as input_processor_mod
from vllm.v1.engine.input_processor import InputProcessor
@ -166,7 +167,7 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
monkeypatch, mm_cache_gb=0.0, enable_prefix_caching=False
)
captured: dict[str, object] = {}
captured: dict[str, MultiModalUUIDDict] = {}
def fake_preprocess(
prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None
@ -196,7 +197,16 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
)
# Expect request-id-based overrides are passed through
assert captured["mm_uuids"] == {
"image": [f"{request_id}-image-0", f"{request_id}-image-1"],
"video": [f"{request_id}-video-0"],
}
mm_uuids = captured["mm_uuids"]
assert set(mm_uuids.keys()) == {"image", "video"}
assert len(mm_uuids["image"]) == 2
assert len(mm_uuids["video"]) == 1
assert mm_uuids["image"][0].startswith(f"{request_id}-image-") and mm_uuids[
"image"
][0].endswith("-0")
assert mm_uuids["image"][1].startswith(f"{request_id}-image-") and mm_uuids[
"image"
][1].endswith("-1")
assert mm_uuids["video"][0].startswith(f"{request_id}-video-") and mm_uuids[
"video"
][0].endswith("-0")

View File

@ -343,6 +343,7 @@ class MockEngineCore:
eos_token_id: int | None = None,
stop_token_ids: list[int] | None = None,
ignore_eos: bool = False,
request_ids: list[str] | None = None,
) -> None:
self.num_requests = len(tokens_list)
self.tokens_list = tokens_list
@ -355,6 +356,11 @@ class MockEngineCore:
self.eos_token_id = eos_token_id
self.stop_token_ids = stop_token_ids
self.ignore_eos = ignore_eos
self.request_ids = (
request_ids
if request_ids is not None
else [f"request-{i}" for i in range(self.num_requests)]
)
def get_outputs(self) -> list[EngineCoreOutput]:
do_logprobs = self.do_logprobs
@ -386,7 +392,7 @@ class MockEngineCore:
prompt_logprobs = None
new_token_id = token_ids[token_idx]
output = EngineCoreOutput(
request_id=f"request-{req_idx}",
request_id=self.request_ids[req_idx],
new_token_ids=[new_token_id],
new_logprobs=logprobs,
new_prompt_logprobs_tensors=prompt_logprobs,

View File

@ -41,10 +41,13 @@ from vllm.distributed.kv_transfer.kv_transfer_state import (
has_kv_transfer_group,
)
from vllm.forward_context import ForwardContext
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.platforms.interface import Platform
from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import RequestStatus
@ -1265,6 +1268,22 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
run_test_and_cleanup()
class RequestIdMapper:
"""Helper class to map external request IDs to internal request IDs."""
def __init__(self, output_processor: OutputProcessor):
self.req_id_mapping: dict[str, str] = {}
self.original_add_request = output_processor.add_request
output_processor.add_request = self._add_request
def _add_request(self, request: EngineCoreRequest, *args, **kwargs):
self.req_id_mapping[request.external_req_id] = request.request_id
return self.original_add_request(request, *args, **kwargs)
def __call__(self, external_req_id: str) -> str:
return self.req_id_mapping[external_req_id]
def _run_abort_timeout_test(llm: LLM, timeout: int):
"""Helper function to run the abort timeout test logic."""
remote_prefill_opts = {
@ -1286,24 +1305,34 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
0
].req_to_blocks
id_mapper = RequestIdMapper(llm.llm_engine.output_processor)
def req_id(outputs: list[RequestOutput]) -> str:
assert len(outputs) == 1
return id_mapper(outputs[0].request_id)
padding = "Just making this request a little longer so that we're sure "
"we're not hitting the small-request lower bound beneath which we don't "
"actually trigger the whole kv transfer, but rather just recompute the "
"blocks on D."
_ = llm.generate([f"What is the capital of Japan? {padding}"], sampling_params)
req0_id = req_id(
llm.generate([f"What is the capital of Japan? {padding}"], sampling_params)
)
# Request finished but not freed
assert "0" in scheduler.finished_req_ids and "0" in req_to_blocks
assert req0_id in scheduler.finished_req_ids and req0_id in req_to_blocks
# Some other request, 0 still not freed
_ = llm.generate([f"What is the capital of Italy? {padding}"], sampling_params)
assert "0" in req_to_blocks
assert "1" in scheduler.finished_req_ids and "1" in req_to_blocks
req1_id = req_id(
llm.generate([f"What is the capital of Italy? {padding}"], sampling_params)
)
assert req0_id in req_to_blocks
assert req1_id in scheduler.finished_req_ids and req1_id in req_to_blocks
# Wait for timeout and trigger another scheduler loop
time.sleep(timeout)
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
# Request-0 times out and is cleared!
assert "0" not in req_to_blocks
assert req0_id not in req_to_blocks
# Need to shutdown the background thread to release NIXL side channel port
llm.llm_engine.engine_core.shutdown()

View File

@ -136,7 +136,7 @@ class MMEncoderAttention(CustomOp):
cu_seqlens=cu_seqlens,
)
if is_reshaped:
output = output.view(bsz, q_len, -1)
output = output.reshape(bsz, q_len, -1)
return output
def _forward_fa(
@ -174,7 +174,7 @@ class MMEncoderAttention(CustomOp):
fa_version=self._fa_version,
)
if is_reshaped:
output = output.view(bsz, q_len, -1)
output = output.reshape(bsz, q_len, -1)
return output
def forward_native(

View File

@ -1621,7 +1621,7 @@ class LLM:
added_request_ids.append(request_id)
except Exception as e:
if added_request_ids:
self.llm_engine.abort_request(added_request_ids)
self.llm_engine.abort_request(added_request_ids, internal=True)
raise e
def _validate_mm_data_and_uuids(
@ -1731,7 +1731,7 @@ class LLM:
priority=priority,
prompt_text=prompt_text,
)
return request_id
return engine_request.request_id
def _run_engine(
self, *, use_tqdm: bool | Callable[..., tqdm] = True

View File

@ -118,6 +118,7 @@ class ShortConv(MambaBase, CustomOp):
conv_state = self_kv_cache[0].transpose(-1, -2)
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states_p
query_start_loc_p = attn_metadata.query_start_loc_p
BCx, _ = self.in_proj(hidden_states)
@ -165,11 +166,6 @@ class ShortConv(MambaBase, CustomOp):
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1 :] - num_decodes
if has_prefill
else None
)
conv_output_list = []

View File

@ -3,17 +3,11 @@
from dataclasses import dataclass
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
split_decodes_and_prefills,
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class Mamba1AttentionBackend(AttentionBackend):
@ -23,137 +17,12 @@ class Mamba1AttentionBackend(AttentionBackend):
@dataclass
class Mamba1AttentionMetadata:
query_start_loc_p: torch.Tensor
state_indices_tensor: torch.Tensor
has_initial_states_p: torch.Tensor | None
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
block_idx_last_computed_token: torch.Tensor # shape: [batch,]
num_computed_tokens_p: torch.Tensor # shape: [batch,]
class Mamba1AttentionMetadata(BaseMambaAttentionMetadata):
pass
class Mamba1AttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
):
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert isinstance(kv_cache_spec, MambaSpec)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> Mamba1AttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
has_initial_states_p = None
query_start_loc_p = None
num_computed_tokens, num_computed_tokens_p = None, None
block_idx_first_scheduled_token = None
block_idx_first_scheduled_token_p = None
# TODO(@Josephasafg) Mamba1 and Mamba2 have a lot of code in common here.
# We should consolidate this code
if self.vllm_config.cache_config.enable_prefix_caching:
# Return a tensor of shape (#requests, #max blocks)
state_indices_tensor = common_attn_metadata.block_table_tensor
mamba_block_size = self.kv_cache_spec.block_size
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
self.device
)
(
block_idx_last_computed_token,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
) = self._compute_prefix_caching_block_indices(
common_attn_metadata, mamba_block_size
)
else:
# Always return just a single block per each request:
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
block_idx_last_scheduled_token = None
block_idx_last_computed_token = None
if num_prefills > 0:
query_start_loc_p = (
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
- num_decode_tokens
)
has_initial_states_cpu = (
common_attn_metadata.num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
> 0
)
has_initial_states_p = has_initial_states_cpu.to(
common_attn_metadata.query_start_loc.device
)
if self.vllm_config.cache_config.enable_prefix_caching:
assert num_computed_tokens is not None
num_computed_tokens_p = num_computed_tokens[
num_reqs - num_prefills : num_reqs
]
assert block_idx_first_scheduled_token is not None
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
num_reqs - num_prefills : num_reqs
]
elif (
num_decodes > 0
and num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True
)
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
if self.vllm_config.cache_config.enable_prefix_caching:
self.block_idx_last_scheduled_token[:num_decodes].copy_(
block_idx_last_scheduled_token, non_blocking=True
)
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
:num_decode_tokens
]
self.block_idx_last_computed_token[:num_decodes].copy_(
block_idx_last_computed_token, non_blocking=True
)
block_idx_last_computed_token = self.block_idx_last_computed_token[
:num_decode_tokens
]
return Mamba1AttentionMetadata(
query_start_loc_p=query_start_loc_p,
has_initial_states_p=has_initial_states_p,
state_indices_tensor=state_indices_tensor,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
block_idx_last_computed_token=block_idx_last_computed_token,
num_computed_tokens_p=num_computed_tokens_p,
)
metadata_cls = Mamba1AttentionMetadata
supports_update_block_table: bool = False

View File

@ -1,19 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import itertools
from dataclasses import dataclass
from dataclasses import dataclass, replace
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder,
)
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec
@ -94,48 +94,26 @@ class Mamba2AttentionBackend(AttentionBackend):
@dataclass
class Mamba2AttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
query_start_loc_p: torch.Tensor
seq_lens: torch.Tensor
prep_initial_states: bool
chunk_size: int
# The following tensors only contain prefill requests and will be None if
# the batch has no prefill request.
has_initial_states_p: torch.Tensor | None
seq_idx_p: torch.Tensor | None
class Mamba2AttentionMetadata(BaseMambaAttentionMetadata):
prep_initial_states: bool = False
chunk_size: int = 0
# Chunk-related metadata (only for prefill)
seq_idx_p: torch.Tensor | None = None
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
# each chunk, its offests into the varlen sequence dimension. It is defined
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
# cu_chunk_seqlen_p[i+1].
cu_chunk_seqlen_p: torch.Tensor | None
cu_chunk_seqlen_p: torch.Tensor | None = None
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
# index of the last chunk for every sequence in the (prefill) batch.
last_chunk_indices_p: torch.Tensor | None
state_indices_tensor: torch.Tensor # shape: [batch,]
block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
block_idx_last_computed_token: torch.Tensor # shape: [batch,]
num_computed_tokens_p: torch.Tensor # shape: [batch,]
# The following attributes are for triton implementation of causal_conv1d
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
token_chunk_offset_ptr: torch.Tensor | None = None
last_chunk_indices_p: torch.Tensor | None = None
class Mamba2AttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
):
supports_update_block_table: bool = True
metadata_cls = Mamba2AttentionMetadata
def __init__(
self,
@ -150,87 +128,93 @@ class Mamba2AttentionMetadataBuilder(
"chunk_size needs to be set in the model config for Mamba2 models"
)
def _compute_chunk_metadata(
self,
num_prefills: int,
num_computed_tokens_p_cpu: torch.Tensor,
query_start_loc_p_cpu: torch.Tensor,
) -> tuple[list[int], list[int], list[int]]:
"""
Compute chunk-specific metadata for Mamba2.
The code below carefully constructs the chunks such that:
1. Chunks contain tokens from a *single* sequence only.
2. For every sequence, we are guaranteed that we can
retrieve the mamba state *every* chunk_size tokens.
Constraint (1) dramatically simplifies the mamba2 kernels.
Constraint (2) dramatically simplifies the implementation
of prefix caching for mamba2 (wip). We need to take care
of the interaction with chunked prefill in order to
satisfy constraint (2).
"""
# TODO (tdoublep): This code could probably be optimized.
cu_chunk_seqlen = []
seq_idx = []
last_chunk_indices = []
seqlen_pos = 0
for req_idx in range(num_prefills):
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
this_new_tokens = (
query_start_loc_p_cpu[req_idx + 1].item()
- query_start_loc_p_cpu[req_idx].item()
)
# if computed tokens are not chunk-aligned, use the first
# chunk to finish it off
if this_num_computed % self.chunk_size != 0:
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
# how many tokens to finish the chunk?
chunk_len = (
cdiv(this_num_computed, self.chunk_size) * self.chunk_size
- this_num_computed
)
# we can only use at most this_new_tokens
chunk_len = min(chunk_len, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
n_chunks = cdiv(this_new_tokens, self.chunk_size)
for chunk in range(n_chunks):
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
chunk_len = min(self.chunk_size, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
assert this_new_tokens == 0
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
cu_chunk_seqlen.append(seqlen_pos)
return cu_chunk_seqlen, seq_idx, last_chunk_indices
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> Mamba2AttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
seq_lens = common_attn_metadata.seq_lens
common = self._compute_common_metadata(common_attn_metadata)
query_start_loc_p = None
seq_idx_p = None
cu_chunk_seqlen_p = None
last_chunk_indices_p = None
# Need flags to indicate if there are initial states
has_initial_states_p = None
prep_initial_states = False
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
num_computed_tokens, num_computed_tokens_p = None, None
block_idx_first_scheduled_token = None
block_idx_first_scheduled_token_p = None
if self.vllm_config.cache_config.enable_prefix_caching:
# Return a tensor of shape (#requests, #max blocks)
state_indices_tensor = common_attn_metadata.block_table_tensor
# Additional cache-related varaiables:
mamba_block_size = self.kv_cache_spec.block_size
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
self.device
)
(
block_idx_last_computed_token,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
) = self._compute_prefix_caching_block_indices(
common_attn_metadata, mamba_block_size
)
else:
# Always return just a single block per each request:
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
# Additional cache-related varaiables:
block_idx_last_scheduled_token = None
block_idx_last_computed_token = None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
# Compute seq_idx for prefill only
if num_prefills > 0:
# [batch,]
has_initial_states_cpu = (
common_attn_metadata.num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
> 0
)
prep_initial_states = torch.any(has_initial_states_cpu).item()
has_initial_states_p = has_initial_states_cpu.to(
common_attn_metadata.query_start_loc.device
if common.num_prefills > 0:
prep_initial_states = (
torch.any(common.has_initial_states_p).item()
if common.has_initial_states_p is not None
else False
)
query_start_loc_p = (
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
- num_decode_tokens
)
num_reqs = common.num_reqs
num_prefills = common.num_prefills
num_decode_tokens = common.num_decode_tokens
if self.vllm_config.cache_config.enable_prefix_caching:
assert num_computed_tokens is not None
num_computed_tokens_p = num_computed_tokens[
num_reqs - num_prefills : num_reqs
]
assert block_idx_first_scheduled_token is not None
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
num_reqs - num_prefills : num_reqs
]
num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
@ -239,137 +223,33 @@ class Mamba2AttentionMetadataBuilder(
- num_decode_tokens
)
# The code below carefully constructs the chunks such that:
# 1. Chunks contain tokens from a *single* sequence only.
# 2. For every sequence, we are guaranteed that we can
# retrieve the mamba state *every* chunk_size tokens.
# Constraint (1) dramatically simplifies the mamba2 kernels.
# Constraint (2) dramatically simplifies the implementation
# of prefix caching for mamba2 (wip). We need to take care
# of the interaction with chunked prefill in order to
# satisfy constraint (2).
# TODO (tdoublep): This code could probably be optimized.
cu_chunk_seqlen = []
seq_idx = []
last_chunk_indices = []
seqlen_pos = 0
for req_idx in range(num_prefills):
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
this_new_tokens = (
query_start_loc_p_cpu[req_idx + 1].item()
- query_start_loc_p_cpu[req_idx].item()
)
# if computed tokens are not chunk-aligned, use the first
# chunk to finish it off
if this_num_computed % self.chunk_size != 0:
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
# how many tokens to finish the chunk?
chunk_len = (
cdiv(this_num_computed, self.chunk_size) * self.chunk_size
- this_num_computed
)
# we can only use at most this_new_tokens
chunk_len = min(chunk_len, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
n_chunks = cdiv(this_new_tokens, self.chunk_size)
for chunk in range(n_chunks):
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
chunk_len = min(self.chunk_size, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
assert this_new_tokens == 0
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
cu_chunk_seqlen.append(seqlen_pos)
cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
num_prefills,
num_computed_tokens_p_cpu,
query_start_loc_p_cpu,
)
seq_idx_p = torch.as_tensor(
seq_idx, device=query_start_loc_p.device, dtype=torch.int32
seq_idx,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
)
cu_chunk_seqlen_p = torch.as_tensor(
cu_chunk_seqlen, device=query_start_loc_p.device, dtype=torch.int32
cu_chunk_seqlen,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
)
last_chunk_indices_p = torch.as_tensor(
last_chunk_indices, device=query_start_loc_p.device, dtype=torch.int32
last_chunk_indices,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
)
nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata(query_start_loc_p)
)
elif (
num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True
)
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
if self.vllm_config.cache_config.enable_prefix_caching:
self.block_idx_last_scheduled_token[:num_decodes].copy_(
block_idx_last_scheduled_token, non_blocking=True
)
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
:num_decode_tokens
]
self.block_idx_last_computed_token[:num_decodes].copy_(
block_idx_last_computed_token, non_blocking=True
)
block_idx_last_computed_token = self.block_idx_last_computed_token[
:num_decode_tokens
]
attn_metadata = Mamba2AttentionMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
query_start_loc_p=query_start_loc_p,
seq_lens=seq_lens,
return replace(
common,
prep_initial_states=prep_initial_states,
chunk_size=self.chunk_size,
has_initial_states_p=has_initial_states_p,
seq_idx_p=seq_idx_p,
state_indices_tensor=state_indices_tensor,
cu_chunk_seqlen_p=cu_chunk_seqlen_p,
last_chunk_indices_p=last_chunk_indices_p,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
block_idx_last_computed_token=block_idx_last_computed_token,
num_computed_tokens_p=num_computed_tokens_p,
)
return attn_metadata
def update_block_table(
self,
metadata: Mamba2AttentionMetadata,
blk_table: torch.Tensor,
slot_mapping: torch.Tensor,
) -> Mamba2AttentionMetadata:
new_metadata = copy.copy(metadata)
prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
state_indices_t = blk_table if prefix_caching else blk_table[:, 0]
num_reqs = blk_table.shape[0]
# For CUDA graphs, copy to persistent buffer
if (
metadata.num_prefills == 0
and num_reqs <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
persistent_state_indices_t = self.state_indices_tensor[:num_reqs]
persistent_state_indices_t.copy_(state_indices_t, non_blocking=True)
state_indices_t = persistent_state_indices_t
new_metadata.state_indices_tensor = state_indices_t
return new_metadata

View File

@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
import copy
from dataclasses import dataclass
from typing import ClassVar, TypeVar
import torch
@ -9,20 +11,52 @@ import torch
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
M = TypeVar("M")
M = TypeVar("M", bound="BaseMambaAttentionMetadata")
@dataclass
class BaseMambaAttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
num_reqs: int
# The following tensors only contain prefill requests and will be None if
# the batch has no prefill request.
has_initial_states_p: torch.Tensor | None
query_start_loc_p: torch.Tensor | None
num_computed_tokens_p: torch.Tensor | None
state_indices_tensor: torch.Tensor
# The following tensors are only used for prefix caching and are None if disabled
block_idx_last_scheduled_token: torch.Tensor | None
block_idx_first_scheduled_token_p: torch.Tensor | None
block_idx_last_computed_token: torch.Tensor | None
# The following attributes are for triton implementation of causal_conv1d
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
token_chunk_offset_ptr: torch.Tensor | None = None
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
metadata_cls: type[M]
reorder_batch_threshold: int = 1
_cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)
supports_update_block_table: bool = True
def __init__(
self,
@ -87,6 +121,18 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
return self.build(0, m)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> M:
"""
Default build implementation for Mamba-like attention backends.
Subclasses (e.g., Mamba2) can override to add additional metadata.
"""
return self._compute_common_metadata(common_attn_metadata)
def _compute_prefix_caching_block_indices(
self,
common_attn_metadata: CommonAttentionMetadata,
@ -115,3 +161,147 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
)
def _compute_common_metadata(
self,
common_attn_metadata: CommonAttentionMetadata,
) -> M:
"""
Compute metadata common to both Mamba1 and Mamba2.
"""
num_reqs = common_attn_metadata.num_reqs
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
# Need flags to indicate if there are initial states
has_initial_states_p = None
query_start_loc_p = None
num_computed_tokens = None
num_computed_tokens_p = None
# for prefix caching
block_idx_first_scheduled_token = None
block_idx_first_scheduled_token_p = None
block_idx_last_computed_token = None
block_idx_last_scheduled_token = None
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
if self.vllm_config.cache_config.enable_prefix_caching:
# Return a tensor of shape (#requests, #max blocks)
state_indices_tensor = common_attn_metadata.block_table_tensor
# Additional cache-related varaiables:
mamba_block_size = self.kv_cache_spec.block_size
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
self.device
)
(
block_idx_last_computed_token,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
) = self._compute_prefix_caching_block_indices(
common_attn_metadata, mamba_block_size
)
else:
# Always return just a single block per each request:
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
if num_prefills > 0:
query_start_loc_p = (
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
- num_decode_tokens
)
has_initial_states_cpu = (
common_attn_metadata.num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
> 0
)
has_initial_states_p = has_initial_states_cpu.to(
common_attn_metadata.query_start_loc.device
)
nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata(query_start_loc_p)
)
if self.vllm_config.cache_config.enable_prefix_caching:
assert num_computed_tokens is not None
num_computed_tokens_p = num_computed_tokens[
num_reqs - num_prefills : num_reqs
]
assert block_idx_first_scheduled_token is not None
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
num_reqs - num_prefills : num_reqs
]
elif (
num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True
)
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
if self.vllm_config.cache_config.enable_prefix_caching:
self.block_idx_last_scheduled_token[:num_decodes].copy_(
block_idx_last_scheduled_token, non_blocking=True
)
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
:num_decode_tokens
]
self.block_idx_last_computed_token[:num_decodes].copy_(
block_idx_last_computed_token, non_blocking=True
)
block_idx_last_computed_token = self.block_idx_last_computed_token[
:num_decode_tokens
]
return self.metadata_cls(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
query_start_loc_p=query_start_loc_p,
has_initial_states_p=has_initial_states_p,
state_indices_tensor=state_indices_tensor,
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
block_idx_last_computed_token=block_idx_last_computed_token,
num_computed_tokens_p=num_computed_tokens_p,
num_reqs=num_reqs,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
)
def update_block_table(
self,
metadata: M,
blk_table: torch.Tensor,
slot_mapping: torch.Tensor,
) -> M:
new_metadata = copy.copy(metadata)
prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
state_indices_t = blk_table if prefix_caching else blk_table[:, 0]
num_reqs = blk_table.shape[0]
# For CUDA graphs, copy to persistent buffer
if (
metadata.num_prefills == 0
and num_reqs <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
persistent_state_indices_t = self.state_indices_tensor[:num_reqs]
persistent_state_indices_t.copy_(state_indices_t, non_blocking=True)
state_indices_t = persistent_state_indices_t
new_metadata.state_indices_tensor = state_indices_t
return new_metadata

View File

@ -2,15 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills,
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder,
)
@ -21,84 +16,11 @@ class ShortConvAttentionBackend(AttentionBackend):
@dataclass
class ShortConvAttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
query_start_loc: torch.Tensor
state_indices_tensor: torch.Tensor
has_initial_states_p: torch.Tensor | None
# For causal_conv1d
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
token_chunk_offset_ptr: torch.Tensor | None = None
class ShortConvAttentionMetadata(BaseMambaAttentionMetadata):
pass
class ShortConvAttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata]
):
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> ShortConvAttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
query_start_loc = common_attn_metadata.query_start_loc
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
has_initial_states_p = None
if num_prefills > 0:
has_initial_states_cpu = (
common_attn_metadata.num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
> 0
)
has_initial_states_p = has_initial_states_cpu.to(query_start_loc.device)
query_start_loc_p = (
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
- num_decode_tokens
)
nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata(query_start_loc_p)
)
elif (
num_decodes > 0
and num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True
)
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
attn_metadata = ShortConvAttentionMetadata(
query_start_loc=query_start_loc,
state_indices_tensor=state_indices_tensor,
has_initial_states_p=has_initial_states_p,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
)
return attn_metadata
metadata_cls = ShortConvAttentionMetadata

View File

@ -75,6 +75,12 @@ class EngineCoreRequest(
trace_headers: Mapping[str, str] | None = None
# The user-provided request ID. This field is set internally,
# copied from the provided request_id that's originally assigned
# to the request_id field, see InputProcessor.assign_request_id().
# Used in outputs and to support abort(req_id, internal=False).
external_req_id: str | None = None
@property
def params(self) -> SamplingParams | PoolingParams:
"""Return the processed params (sampling or pooling)."""

View File

@ -290,12 +290,15 @@ class AsyncLLM(EngineClient):
is_pooling = isinstance(params, PoolingParams)
# Create a new output collector for the request.
queue = RequestOutputCollector(output_kind=params.output_kind)
# Convert Input --> Request.
if isinstance(prompt, EngineCoreRequest):
request = prompt
if request_id != request.request_id:
logger.warning_once(
"AsyncLLM.add_request() was passed a request_id parameter that "
"does not match the EngineCoreRequest.request_id attribute. The "
"latter will be used, and the former will be ignored."
)
else:
assert prompt_text is None
request = self.input_processor.process_inputs(
@ -314,6 +317,11 @@ class AsyncLLM(EngineClient):
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))
self.input_processor.assign_request_id(request)
# Create a new output collector for the request.
queue = RequestOutputCollector(params.output_kind, request.request_id)
# Use cloned params that may have been updated in process_inputs()
params = request.params
@ -325,7 +333,7 @@ class AsyncLLM(EngineClient):
assert isinstance(parent_params, SamplingParams)
# Fan out child requests (for n>1).
parent_request = ParentRequest(request_id, parent_params)
parent_request = ParentRequest(request)
for idx in range(parent_params.n):
request_id, child_params = parent_request.get_child_info(idx)
child_request = request if idx == parent_params.n - 1 else copy(request)
@ -396,6 +404,7 @@ class AsyncLLM(EngineClient):
"prompt logprobs"
)
q: RequestOutputCollector | None = None
try:
# We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us
@ -446,7 +455,8 @@ class AsyncLLM(EngineClient):
# is cancelled or the generator is garbage collected. So,
# we abort the request if we end up here.
except (asyncio.CancelledError, GeneratorExit):
await self.abort(request_id)
if q is not None:
await self.abort(q.request_id, internal=True)
if self.log_requests:
logger.info("Request %s aborted.", request_id)
raise
@ -465,7 +475,8 @@ class AsyncLLM(EngineClient):
# Unexpected error in the generate() task (possibly recoverable).
except Exception as e:
await self.abort(request_id)
if q is not None:
await self.abort(q.request_id, internal=True)
if self.log_requests:
logger.info("Request %s failed.", request_id)
raise EngineGenerateError() from e
@ -541,13 +552,15 @@ class AsyncLLM(EngineClient):
self.output_handler = asyncio.create_task(output_handler())
async def abort(self, request_id: str | Iterable[str]) -> None:
async def abort(
self, request_id: str | Iterable[str], internal: bool = False
) -> None:
"""Abort RequestId in OutputProcessor and EngineCore."""
request_ids = (
(request_id,) if isinstance(request_id, str) else as_list(request_id)
)
all_request_ids = self.output_processor.abort_requests(request_ids)
all_request_ids = self.output_processor.abort_requests(request_ids, internal)
await self.engine_core.abort_requests_async(all_request_ids)
if self.log_requests:
@ -581,7 +594,7 @@ class AsyncLLM(EngineClient):
if not wait_for_inflight_requests:
request_ids = list(self.output_processor.request_states.keys())
if request_ids:
await self.abort(request_ids)
await self.abort(request_ids, internal=True)
# Wait for running requests to drain before clearing cache.
if self.output_processor.has_unfinished_requests():
@ -633,6 +646,7 @@ class AsyncLLM(EngineClient):
TODO: Remove truncate_prompt_tokens in v0.15.
"""
q: RequestOutputCollector | None = None
try:
# We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us
@ -687,7 +701,8 @@ class AsyncLLM(EngineClient):
# If the request is disconnected by the client, generate()
# is cancelled. So, we abort the request if we end up here.
except asyncio.CancelledError:
await self.abort(request_id)
if q is not None:
await self.abort(q.request_id, internal=True)
if self.log_requests:
logger.info("Request %s aborted.", request_id)
raise
@ -706,7 +721,8 @@ class AsyncLLM(EngineClient):
# Unexpected error in the generate() task (possibly recoverable).
except Exception as e:
await self.abort(request_id)
if q is not None:
await self.abort(q.request_id, internal=True)
if self.log_requests:
logger.info("Request %s failed.", request_id)
raise EngineGenerateError() from e

View File

@ -21,7 +21,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats
from vllm.v1.structured_output.backend_guidance import (
@ -406,6 +406,19 @@ class InputProcessor:
mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
return mm_uuids
@staticmethod
def assign_request_id(request: EngineCoreRequest):
"""Replace the externally supplied request ID with an internal request ID
that adds 8 random characters in order to ensure uniquness.
"""
if request.external_req_id is not None:
raise ValueError(
"The external_req_id field should not be set on EngineCoreRequests"
" passed to vLLM; use the request_id field."
)
request.external_req_id = request.request_id
request.request_id = f"{request.external_req_id}-{random_uuid():.8}"
def process_inputs(
self,
request_id: str,

View File

@ -213,10 +213,10 @@ class LLMEngine:
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.engine_core.get_supported_tasks()
def abort_request(self, request_ids: list[str]) -> None:
def abort_request(self, request_ids: list[str], internal: bool = False) -> None:
"""Remove request_ids from EngineCore and Detokenizer."""
request_ids = self.output_processor.abort_requests(request_ids)
request_ids = self.output_processor.abort_requests(request_ids, internal)
self.engine_core.abort_requests(request_ids)
def add_request(
@ -238,6 +238,12 @@ class LLMEngine:
# Process raw inputs into the request.
if isinstance(prompt, EngineCoreRequest):
request = prompt
if request_id != request.request_id:
logger.warning_once(
"AsyncLLM.add_request() was passed a request_id parameter that "
"does not match the EngineCoreRequest.request_id attribute. The "
"latter will be used, and the former will be ignored."
)
else:
assert prompt_text is None
request = self.input_processor.process_inputs(
@ -255,6 +261,8 @@ class LLMEngine:
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))
self.input_processor.assign_request_id(request)
# Use cloned params that may have been updated in process_inputs()
params = request.params
@ -268,7 +276,7 @@ class LLMEngine:
return
# Fan out child requests (for n>1).
parent_req = ParentRequest(request_id, params)
parent_req = ParentRequest(request)
for idx in range(n):
request_id, child_params = parent_req.get_child_info(idx)
child_request = request if idx == n - 1 else copy(request)

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, cast
@ -40,8 +41,9 @@ class RequestOutputCollector:
producer gets ahead of the consumer.
"""
def __init__(self, output_kind: RequestOutputKind):
def __init__(self, output_kind: RequestOutputKind, request_id: str):
self.aggregate = output_kind == RequestOutputKind.DELTA
self.request_id = request_id
self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
self.ready = asyncio.Event()
@ -92,6 +94,7 @@ class RequestState:
def __init__(
self,
request_id: str,
external_req_id: str,
parent_req: ParentRequest | None,
request_index: int,
lora_request: LoRARequest | None,
@ -111,6 +114,7 @@ class RequestState:
temperature: float | None = None,
):
self.request_id = request_id
self.external_req_id = external_req_id
self.parent_req = parent_req
self.request_index = request_index
self.lora_request = lora_request
@ -176,8 +180,10 @@ class RequestState:
assert request.pooling_params is not None
output_kind = request.pooling_params.output_kind
assert request.external_req_id is not None
return cls(
request_id=request.request_id,
external_req_id=request.external_req_id,
parent_req=parent_req,
request_index=request_index,
lora_request=request.lora_request,
@ -235,10 +241,13 @@ class RequestState:
]
self.sent_tokens_offset = len(self.detokenizer.output_token_ids)
request_id = self.request_id
external_req_id = self.external_req_id
if pooling_output is not None:
return self._new_request_output(
request_id, [self._new_pooling_output(pooling_output)], finished
external_req_id,
[self._new_pooling_output(pooling_output)],
finished,
)
output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)
@ -246,19 +255,18 @@ class RequestState:
if self.parent_req is None:
outputs = [output]
else:
request_id, outputs, finished = self.parent_req.get_outputs(
request_id, output
)
outputs, finished = self.parent_req.get_outputs(self.request_id, output)
if not outputs:
return None
external_req_id = self.parent_req.external_req_id
return self._new_request_output(
request_id, outputs, finished, kv_transfer_params
external_req_id, outputs, finished, kv_transfer_params
)
def _new_request_output(
self,
request_id: str,
external_req_id: str,
outputs: list[CompletionOutput] | list[PoolingOutput],
finished: bool,
kv_transfer_params: dict[str, Any] | None = None,
@ -269,7 +277,7 @@ class RequestState:
# Prompt embeddings are currently not supported by pooling requests.
assert self.prompt_token_ids is not None
return PoolingRequestOutput(
request_id=request_id,
request_id=external_req_id,
outputs=first_output,
num_cached_tokens=self.num_cached_tokens,
prompt_token_ids=self.prompt_token_ids,
@ -288,7 +296,7 @@ class RequestState:
prompt_token_ids = [0] * len(self.prompt_embeds)
return RequestOutput(
request_id=request_id,
request_id=external_req_id, # request_id is what was provided externally
lora_request=self.lora_request,
prompt=self.prompt,
prompt_token_ids=prompt_token_ids,
@ -352,6 +360,7 @@ class OutputProcessor:
self.stream_interval = stream_interval
self.request_states: dict[str, RequestState] = {}
self.parent_requests: dict[str, ParentRequest] = {}
self.external_req_ids: defaultdict[str, list[str]] = defaultdict(list)
self.lora_states = LoRARequestStates(log_stats)
self.tracer: Tracer | None = None
self._requests_drained = asyncio.Event()
@ -375,12 +384,41 @@ class OutputProcessor:
assert state.queue is not None
state.queue.put(e)
def abort_requests(
self,
request_ids: Iterable[str],
) -> list[str]:
request_ids_to_abort = []
def abort_requests(self, request_ids: Iterable[str], internal: bool) -> list[str]:
"""Abort a list of requests.
The request_ids may be either external request IDs (those passed to
InputProcessor.process_inputs()) or internal request IDs (those randomly
generated when creating the EngineCoreRequest).
If an external request ID is provided, and that external request ID
was used for multiple requests, all requests associated with that external
request ID are aborted.
In the case of parallel sampling, a request ID may be used to identify
a parent request, in which case the associated child requests are aborted
also.
"""
internal_req_ids = []
for request_id in request_ids:
if internal:
# Internal ID - this may be a parent request
internal_req_ids.append(request_id)
# Remove internal ID from the external->internal mapping
if req_state := self.request_states.get(request_id):
external_req_id = req_state.external_req_id
internal_ids = self.external_req_ids[external_req_id]
internal_ids.remove(request_id)
if not internal_ids:
del self.external_req_ids[external_req_id]
elif internal_ids := self.external_req_ids.pop(request_id, []):
# External ID - abort all requests in the external->internal mapping
internal_req_ids.extend(internal_ids)
request_ids_to_abort = []
for request_id in internal_req_ids:
req_state = self.request_states.pop(request_id, None)
if req_state is not None:
self.lora_states.request_finished(request_id, req_state.lora_name)
@ -404,7 +442,7 @@ class OutputProcessor:
# Abort children prior to removing the parent.
if parent.child_requests:
child_reqs = list(parent.child_requests)
child_reqs = self.abort_requests(child_reqs)
child_reqs = self.abort_requests(child_reqs, internal=True)
request_ids_to_abort.extend(child_reqs)
self.parent_requests.pop(request_id, None)
if not self.request_states:
@ -439,6 +477,9 @@ class OutputProcessor:
if parent_req:
self.parent_requests[parent_req.request_id] = parent_req
# Track the external_req_id -> [internal_req_id, ...] mapping
self.external_req_ids[req_state.external_req_id].append(request_id)
def process_outputs(
self,
engine_core_outputs: list[EngineCoreOutput],
@ -522,6 +563,12 @@ class OutputProcessor:
# Free completed requests.
if finish_reason is not None:
self.request_states.pop(req_id)
internal_ids = self.external_req_ids[req_state.external_req_id]
internal_ids.remove(req_id)
if not internal_ids:
del self.external_req_ids[req_state.external_req_id]
# Remove parent request if applicable.
parent_req = req_state.parent_req
if parent_req and not parent_req.child_requests:
@ -597,7 +644,9 @@ class OutputProcessor:
)
# meta
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, req_state.request_id)
span.set_attribute(
SpanAttributes.GEN_AI_REQUEST_ID, req_state.external_req_id
)
if req_state.top_p:
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
if req_state.max_tokens_param:

View File

@ -6,6 +6,7 @@ from typing import Optional, cast
from vllm.outputs import CompletionOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import IterationStats
@ -17,6 +18,7 @@ class ParentRequest:
"""
request_id: str
external_req_id: str
sampling_params: SamplingParams
# To track the completion of child requests
@ -31,8 +33,11 @@ class ParentRequest:
# To efficiently obtain child sampling params
cached_child_sampling_params: SamplingParams | None
def __init__(self, request_id: str, sampling_params: SamplingParams) -> None:
self.request_id = request_id
def __init__(self, request: EngineCoreRequest) -> None:
assert request.external_req_id is not None
sampling_params = request.params
self.request_id = request.request_id
self.external_req_id = request.external_req_id
self.sampling_params = sampling_params
self.child_requests = set()
@ -96,7 +101,7 @@ class ParentRequest:
self,
child_request_id: str,
completion_output: CompletionOutput,
) -> tuple[str, list[CompletionOutput], bool]:
) -> tuple[list[CompletionOutput], bool]:
already_finished_and_returned: bool = False
if completion_output.finished():
if child_request_id in self.child_requests:
@ -118,7 +123,7 @@ class ParentRequest:
outputs = [] if self.child_requests else self.output_aggregator
finished = not self.child_requests
return self.request_id, outputs, finished
return outputs, finished
def observe_num_generation_tokens(self, num_generation_tokens: int):
self.max_num_generation_tokens = max(