mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-18 18:57:17 +08:00
Merge branch 'main' into elvischenv/update-flashinfer
This commit is contained in:
commit
30d32ef5c0
@ -2,4 +2,4 @@
|
||||
|
||||
vLLM can be deployed with [KServe](https://github.com/kserve/kserve) on Kubernetes for highly scalable distributed model serving.
|
||||
|
||||
Please see [this guide](https://kserve.github.io/website/docs/model-serving/generative-inference/overview) for more details on using vLLM with KServe.
|
||||
You can use vLLM with KServe's [Hugging Face serving runtime](https://kserve.github.io/website/docs/model-serving/generative-inference/overview) or via [`LLMInferenceService` that uses llm-d](https://kserve.github.io/website/docs/model-serving/generative-inference/llmisvc/llmisvc-overview).
|
||||
|
||||
5
docs/deployment/integrations/llm-d.md
Normal file
5
docs/deployment/integrations/llm-d.md
Normal file
@ -0,0 +1,5 @@
|
||||
# llm-d
|
||||
|
||||
vLLM can be deployed with [llm-d](https://github.com/llm-d/llm-d), a Kubernetes-native distributed inference serving stack providing well-lit paths for anyone to serve large generative AI models at scale. It helps achieve the fastest "time to state-of-the-art (SOTA) performance" for key OSS models across most hardware accelerators and infrastructure providers.
|
||||
|
||||
You can use vLLM with llm-d directly by following [this guide](https://llm-d.ai/docs/guide) or via [KServe's LLMInferenceService](https://kserve.github.io/website/docs/model-serving/generative-inference/llmisvc/llmisvc-overview).
|
||||
@ -12,6 +12,7 @@ Alternatively, you can deploy vLLM to Kubernetes using any of the following:
|
||||
|
||||
- [Helm](frameworks/helm.md)
|
||||
- [InftyAI/llmaz](integrations/llmaz.md)
|
||||
- [llm-d](integrations/llm-d.md)
|
||||
- [KAITO](integrations/kaito.md)
|
||||
- [KServe](integrations/kserve.md)
|
||||
- [Kthena](integrations/kthena.md)
|
||||
|
||||
11
tests/evals/gsm8k/configs/Qwen3-Next-FP8-EP2.yaml
Normal file
11
tests/evals/gsm8k/configs/Qwen3-Next-FP8-EP2.yaml
Normal file
@ -0,0 +1,11 @@
|
||||
model_name: "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8"
|
||||
accuracy_threshold: 0.85
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
server_args: >-
|
||||
--max-model-len 4096
|
||||
--tensor-parallel-size 2
|
||||
--enable-expert-parallel
|
||||
--async-scheduling
|
||||
env:
|
||||
VLLM_USE_FLASHINFER_MOE_FP8: "1"
|
||||
@ -4,3 +4,4 @@ Qwen1.5-MoE-W4A16-CT.yaml
|
||||
DeepSeek-V2-Lite-Instruct-FP8.yaml
|
||||
Qwen3-30B-A3B-NVFP4.yaml
|
||||
Qwen3-Next-80B-A3B-NVFP4-EP2.yaml
|
||||
Qwen3-Next-FP8-EP2.yaml
|
||||
|
||||
@ -71,6 +71,7 @@ def test_gsm8k_correctness(config_filename):
|
||||
print(f"Number of questions: {eval_config['num_questions']}")
|
||||
print(f"Number of few-shot examples: {eval_config['num_fewshot']}")
|
||||
print(f"Server args: {' '.join(server_args)}")
|
||||
print(f"Environment variables: {env_dict}")
|
||||
|
||||
# Launch server and run evaluation
|
||||
with RemoteOpenAIServer(
|
||||
|
||||
@ -19,7 +19,7 @@ def pytest_collection_modifyitems(config, items):
|
||||
return
|
||||
|
||||
# Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
|
||||
# accuracy issues
|
||||
# accuracy issues: https://github.com/vllm-project/vllm/issues/30167
|
||||
# TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace
|
||||
torch.backends.cuda.enable_flash_sdp(False)
|
||||
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
||||
|
||||
@ -106,6 +106,7 @@ class RemoteOpenAIServer:
|
||||
env.update(env_dict)
|
||||
serve_cmd = ["vllm", "serve", model, *vllm_serve_args]
|
||||
print(f"Launching RemoteOpenAIServer with: {' '.join(serve_cmd)}")
|
||||
print(f"Environment variables: {env}")
|
||||
self.proc: subprocess.Popen = subprocess.Popen(
|
||||
serve_cmd,
|
||||
env=env,
|
||||
|
||||
@ -260,7 +260,7 @@ async def test_multi_abort(output_kind: RequestOutputKind):
|
||||
|
||||
# Use multi-abort to abort multiple requests at once
|
||||
abort_request_ids = [request_ids[i] for i in REQUEST_IDS_TO_ABORT]
|
||||
await engine.abort(abort_request_ids)
|
||||
await engine.abort(abort_request_ids, internal=False)
|
||||
|
||||
# Wait for all tasks to complete
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
@ -609,7 +609,7 @@ async def test_abort_final_output(output_kind: RequestOutputKind):
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Abort the request
|
||||
await engine.abort(request_id)
|
||||
await engine.abort(request_id, internal=False)
|
||||
|
||||
# Wait for generation to complete and return final output
|
||||
final_output = await generated
|
||||
|
||||
@ -40,10 +40,16 @@ TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
PROMPT = "I am Gyoubu Masataka Oniwa"
|
||||
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
|
||||
|
||||
_REQUEST_COUNTER = 0
|
||||
|
||||
|
||||
def make_request() -> EngineCoreRequest:
|
||||
global _REQUEST_COUNTER
|
||||
_REQUEST_COUNTER += 1
|
||||
request_id = f"request-{_REQUEST_COUNTER}"
|
||||
return EngineCoreRequest(
|
||||
request_id=str(uuid.uuid4()),
|
||||
request_id=request_id,
|
||||
external_req_id=f"{request_id}-{uuid.uuid4()}",
|
||||
prompt_token_ids=PROMPT_TOKENS,
|
||||
mm_features=None,
|
||||
sampling_params=SamplingParams(),
|
||||
|
||||
@ -45,6 +45,8 @@ TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
PROMPT = "Hello my name is Robert and I love quantization kernels"
|
||||
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
|
||||
|
||||
_REQUEST_COUNTER = 0
|
||||
|
||||
|
||||
def make_request(
|
||||
params: SamplingParams, prompt_tokens_ids: list[int] | None = None
|
||||
@ -52,8 +54,12 @@ def make_request(
|
||||
if not prompt_tokens_ids:
|
||||
prompt_tokens_ids = PROMPT_TOKENS
|
||||
|
||||
global _REQUEST_COUNTER
|
||||
_REQUEST_COUNTER += 1
|
||||
request_id = f"request-{_REQUEST_COUNTER}"
|
||||
return EngineCoreRequest(
|
||||
request_id=str(uuid.uuid4()),
|
||||
request_id=request_id,
|
||||
external_req_id=f"{request_id}-{uuid.uuid4()}",
|
||||
prompt_token_ids=prompt_tokens_ids,
|
||||
mm_features=None,
|
||||
sampling_params=params,
|
||||
|
||||
@ -27,6 +27,7 @@ def test_fast_inc_detok_invalid_utf8_err_case():
|
||||
params = SamplingParams(skip_special_tokens=True)
|
||||
request = EngineCoreRequest(
|
||||
request_id="test",
|
||||
external_req_id="test-ext",
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
mm_features=None,
|
||||
sampling_params=params,
|
||||
|
||||
@ -58,12 +58,12 @@ def test_incremental_detokenization(
|
||||
output_processor = OutputProcessor(
|
||||
dummy_test_vectors.tokenizer, log_stats=False, stream_interval=stream_interval
|
||||
)
|
||||
engine_core = MockEngineCore(tokens_list=dummy_test_vectors.generation_tokens)
|
||||
|
||||
# Make N requests.
|
||||
requests = [
|
||||
EngineCoreRequest(
|
||||
request_id=f"request-{idx}",
|
||||
request_id=f"request-{idx}-int",
|
||||
external_req_id=f"request-{idx}",
|
||||
prompt_token_ids=prompt_tokens,
|
||||
mm_features=None,
|
||||
eos_token_id=None,
|
||||
@ -83,6 +83,11 @@ def test_incremental_detokenization(
|
||||
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||
]
|
||||
|
||||
engine_core = MockEngineCore(
|
||||
tokens_list=dummy_test_vectors.generation_tokens,
|
||||
request_ids=[req.request_id for req in requests],
|
||||
)
|
||||
|
||||
# Add requests to the detokenizer.
|
||||
for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
|
||||
output_processor.add_request(request, prompt)
|
||||
@ -438,15 +443,6 @@ def test_logprobs_processor(
|
||||
dummy_test_vectors,
|
||||
):
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
|
||||
engine_core = MockEngineCore(
|
||||
tokens_list=dummy_test_vectors.generation_tokens,
|
||||
generated_logprobs_raw=None
|
||||
if num_sample_logprobs is None
|
||||
else dummy_test_vectors.generation_logprobs,
|
||||
prompt_logprobs_raw=None
|
||||
if num_prompt_logprobs is None
|
||||
else dummy_test_vectors.prompt_logprobs,
|
||||
)
|
||||
|
||||
# Make N requests.
|
||||
request_id_list = [
|
||||
@ -454,7 +450,8 @@ def test_logprobs_processor(
|
||||
]
|
||||
requests = [
|
||||
EngineCoreRequest(
|
||||
request_id=request_id_list[idx],
|
||||
request_id=request_id_list[idx] + "-int",
|
||||
external_req_id=request_id_list[idx],
|
||||
prompt_token_ids=prompt_tokens,
|
||||
mm_features=None,
|
||||
eos_token_id=None,
|
||||
@ -476,6 +473,17 @@ def test_logprobs_processor(
|
||||
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||
]
|
||||
|
||||
engine_core = MockEngineCore(
|
||||
tokens_list=dummy_test_vectors.generation_tokens,
|
||||
generated_logprobs_raw=None
|
||||
if num_sample_logprobs is None
|
||||
else dummy_test_vectors.generation_logprobs,
|
||||
prompt_logprobs_raw=None
|
||||
if num_prompt_logprobs is None
|
||||
else dummy_test_vectors.prompt_logprobs,
|
||||
request_ids=[req.request_id for req in requests],
|
||||
)
|
||||
|
||||
# Add requests to the detokenizer.
|
||||
for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
|
||||
output_processor.add_request(request, prompt)
|
||||
@ -621,19 +629,12 @@ def test_stop_token(
|
||||
]
|
||||
prompt_string = dummy_test_vectors.prompt_strings[0]
|
||||
prompt_tokens = dummy_test_vectors.prompt_tokens[0]
|
||||
engine_core = MockEngineCore(
|
||||
tokens_list=[generation_tokens],
|
||||
generated_logprobs_raw=[generation_logprobs] if do_logprobs else None,
|
||||
prompt_logprobs_raw=None,
|
||||
eos_token_id=eos_token_id,
|
||||
stop_token_ids=stop_token_ids,
|
||||
ignore_eos=ignore_eos,
|
||||
)
|
||||
|
||||
# Make request.
|
||||
request_id = "request-0"
|
||||
request = EngineCoreRequest(
|
||||
request_id=request_id,
|
||||
external_req_id=request_id + "-ext",
|
||||
prompt_token_ids=prompt_tokens,
|
||||
mm_features=None,
|
||||
eos_token_id=eos_token_id,
|
||||
@ -655,6 +656,16 @@ def test_stop_token(
|
||||
pooling_params=None,
|
||||
)
|
||||
|
||||
engine_core = MockEngineCore(
|
||||
tokens_list=[generation_tokens],
|
||||
generated_logprobs_raw=[generation_logprobs] if do_logprobs else None,
|
||||
prompt_logprobs_raw=None,
|
||||
eos_token_id=eos_token_id,
|
||||
stop_token_ids=stop_token_ids,
|
||||
ignore_eos=ignore_eos,
|
||||
request_ids=[request.request_id],
|
||||
)
|
||||
|
||||
# Add request to the detokenizer.
|
||||
output_processor.add_request(request, prompt_string)
|
||||
|
||||
@ -720,13 +731,6 @@ def test_stop_string(
|
||||
dummy_test_vectors,
|
||||
):
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
|
||||
engine_core = MockEngineCore(
|
||||
tokens_list=dummy_test_vectors.generation_tokens,
|
||||
generated_logprobs_raw=dummy_test_vectors.generation_logprobs
|
||||
if num_sample_logprobs
|
||||
else None,
|
||||
prompt_logprobs_raw=None,
|
||||
)
|
||||
|
||||
# Make N requests.
|
||||
request_id_list = [
|
||||
@ -734,7 +738,8 @@ def test_stop_string(
|
||||
]
|
||||
requests = [
|
||||
EngineCoreRequest(
|
||||
request_id=request_id_list[idx],
|
||||
request_id=request_id_list[idx] + "-int",
|
||||
external_req_id=request_id_list[idx],
|
||||
prompt_token_ids=prompt_tokens,
|
||||
mm_features=None,
|
||||
eos_token_id=None,
|
||||
@ -756,6 +761,15 @@ def test_stop_string(
|
||||
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||
]
|
||||
|
||||
engine_core = MockEngineCore(
|
||||
tokens_list=dummy_test_vectors.generation_tokens,
|
||||
generated_logprobs_raw=dummy_test_vectors.generation_logprobs
|
||||
if num_sample_logprobs
|
||||
else None,
|
||||
prompt_logprobs_raw=None,
|
||||
request_ids=[req.request_id for req in requests],
|
||||
)
|
||||
|
||||
# Add requests to the detokenizer.
|
||||
for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
|
||||
output_processor.add_request(request, prompt)
|
||||
@ -813,9 +827,12 @@ def test_stop_string(
|
||||
for idx, (ref_gen_str, stop_str) in enumerate(
|
||||
zip(dummy_test_vectors.generation_strings, STOP_STRINGS)
|
||||
):
|
||||
# Request should be aborted.
|
||||
# Request should be aborted (check internal ID in abort list).
|
||||
internal_request_id = f"request-{idx}-int"
|
||||
assert internal_request_id in aborted
|
||||
|
||||
# Use external ID for collecting outputs
|
||||
request_id = f"request-{idx}"
|
||||
assert request_id in aborted
|
||||
|
||||
# Collected values that were generated.
|
||||
gen_str = gen_strings[request_id]
|
||||
@ -848,13 +865,13 @@ def test_stop_string(
|
||||
|
||||
def test_iteration_stats(dummy_test_vectors):
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
|
||||
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
|
||||
engine_core_timestamp = time.monotonic()
|
||||
|
||||
# Make N requests.
|
||||
requests = [
|
||||
EngineCoreRequest(
|
||||
request_id=f"request-{idx}",
|
||||
external_req_id=f"request-{idx}-ext",
|
||||
prompt_token_ids=prompt_tokens,
|
||||
mm_features=None,
|
||||
eos_token_id=None,
|
||||
@ -868,6 +885,11 @@ def test_iteration_stats(dummy_test_vectors):
|
||||
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||
]
|
||||
|
||||
engine_core = MockEngineCore(
|
||||
dummy_test_vectors.generation_tokens,
|
||||
request_ids=[req.request_id for req in requests],
|
||||
)
|
||||
|
||||
# Add all requests except one to the OutputProcessor.
|
||||
num_active = len(dummy_test_vectors.generation_tokens) - 1
|
||||
for request in requests[:num_active]:
|
||||
@ -922,7 +944,6 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
|
||||
output_processor = OutputProcessor(
|
||||
dummy_test_vectors.tokenizer, log_stats=log_stats
|
||||
)
|
||||
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
|
||||
engine_core_timestamp = time.monotonic()
|
||||
|
||||
# Create LoRA requests
|
||||
@ -936,7 +957,8 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
|
||||
lora_assignments = [lora1, lora2, None]
|
||||
requests = [
|
||||
EngineCoreRequest(
|
||||
request_id=f"request-{idx}",
|
||||
request_id=f"request-{idx}-int",
|
||||
external_req_id=f"request-{idx}",
|
||||
prompt_token_ids=prompt_tokens,
|
||||
mm_features=None,
|
||||
eos_token_id=None,
|
||||
@ -950,6 +972,11 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
|
||||
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||
]
|
||||
|
||||
engine_core = MockEngineCore(
|
||||
dummy_test_vectors.generation_tokens,
|
||||
request_ids=[req.request_id for req in requests],
|
||||
)
|
||||
|
||||
# Add all requests to the OutputProcessor
|
||||
for request in requests:
|
||||
output_processor.add_request(request, None)
|
||||
@ -1015,9 +1042,9 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
|
||||
outputs = EngineCoreOutputs(
|
||||
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
|
||||
)
|
||||
# Find and mark request-0 as finished (it uses lora-1)
|
||||
# Find and mark request-0-int as finished (it uses lora-1)
|
||||
for output in outputs.outputs:
|
||||
if output.request_id == "request-0":
|
||||
if output.request_id == "request-0-int":
|
||||
output.finish_reason = FinishReason.LENGTH
|
||||
break
|
||||
|
||||
@ -1040,9 +1067,9 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
|
||||
outputs = EngineCoreOutputs(
|
||||
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
|
||||
)
|
||||
# Find and mark request-1 as finished (it uses lora-2)
|
||||
# Find and mark request-1-int as finished (it uses lora-2)
|
||||
for output in outputs.outputs:
|
||||
if output.request_id == "request-1":
|
||||
if output.request_id == "request-1-int":
|
||||
output.finish_reason = FinishReason.LENGTH
|
||||
break
|
||||
|
||||
@ -1064,9 +1091,9 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
|
||||
outputs = EngineCoreOutputs(
|
||||
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
|
||||
)
|
||||
# Find and mark request-2 as finished (it has no LoRA)
|
||||
# Find and mark request-2-int as finished (it has no LoRA)
|
||||
for output in outputs.outputs:
|
||||
if output.request_id == "request-2":
|
||||
if output.request_id == "request-2-int":
|
||||
output.finish_reason = FinishReason.LENGTH
|
||||
break
|
||||
|
||||
@ -1107,7 +1134,9 @@ async def test_request_output_collector():
|
||||
for idx in range(NUM_REQS)
|
||||
]
|
||||
|
||||
collector = RequestOutputCollector(RequestOutputKind.DELTA)
|
||||
collector = RequestOutputCollector(
|
||||
RequestOutputKind.DELTA, request_id="my-request-id-int"
|
||||
)
|
||||
|
||||
# CASE 1: Put then get.
|
||||
outputs = make_outputs()
|
||||
@ -1163,7 +1192,9 @@ async def test_request_output_collector():
|
||||
@pytest.mark.asyncio
|
||||
async def test_cumulative_output_collector_n():
|
||||
"""Test collector correctly handles multiple outputs by index."""
|
||||
collector = RequestOutputCollector(RequestOutputKind.CUMULATIVE)
|
||||
collector = RequestOutputCollector(
|
||||
RequestOutputKind.CUMULATIVE, request_id="my-request-id-int"
|
||||
)
|
||||
outputs = [
|
||||
RequestOutput(
|
||||
request_id="my-request-id",
|
||||
@ -1242,11 +1273,13 @@ async def test_cumulative_output_collector_n():
|
||||
|
||||
|
||||
@pytest.mark.parametrize("runner", ["generate", "pooling"])
|
||||
def test_abort_requests(runner: str, dummy_test_vectors):
|
||||
@pytest.mark.parametrize("abort_by", ["internal", "external"])
|
||||
def test_abort_requests(runner: str, abort_by: str, dummy_test_vectors):
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
|
||||
requests = [
|
||||
EngineCoreRequest(
|
||||
request_id=f"request-{idx}",
|
||||
external_req_id=f"external-{idx}",
|
||||
prompt_token_ids=prompt_tokens,
|
||||
mm_features=None,
|
||||
eos_token_id=None,
|
||||
@ -1265,8 +1298,13 @@ def test_abort_requests(runner: str, dummy_test_vectors):
|
||||
output_kind = request.sampling_params.output_kind
|
||||
else:
|
||||
output_kind = request.pooling_params.output_kind
|
||||
queue = RequestOutputCollector(output_kind=output_kind)
|
||||
queue = RequestOutputCollector(
|
||||
output_kind=output_kind, request_id=request.request_id
|
||||
)
|
||||
output_processor.add_request(request, None, queue=queue)
|
||||
|
||||
for request in requests:
|
||||
output_processor.abort_requests([request.request_id])
|
||||
if abort_by == "internal":
|
||||
output_processor.abort_requests([request.request_id], internal=True)
|
||||
else:
|
||||
output_processor.abort_requests([request.external_req_id], internal=False)
|
||||
|
||||
@ -4,11 +4,12 @@
|
||||
from vllm import SamplingParams
|
||||
from vllm.outputs import CompletionOutput
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
|
||||
|
||||
def test_parent_request_to_output_stream() -> None:
|
||||
parent_request = ParentRequest("parent_id", SamplingParams(n=2))
|
||||
parent_request = ParentRequest(make_request(SamplingParams(n=2)))
|
||||
parent_request.child_requests = {"child_id_0", "child_id_1"}
|
||||
output_0 = CompletionOutput(
|
||||
index=0, text="child 0", token_ids=[], cumulative_logprob=None, logprobs=None
|
||||
@ -17,51 +18,31 @@ def test_parent_request_to_output_stream() -> None:
|
||||
index=1, text="child 1", token_ids=[], cumulative_logprob=None, logprobs=None
|
||||
)
|
||||
# Request not finished
|
||||
assert ("parent_id", [output_0], False) == parent_request.get_outputs(
|
||||
"child_id_0", output_0
|
||||
)
|
||||
assert ("parent_id", [output_1], False) == parent_request.get_outputs(
|
||||
"child_id_1", output_1
|
||||
)
|
||||
assert ("parent_id", [output_0], False) == parent_request.get_outputs(
|
||||
"child_id_0", output_0
|
||||
)
|
||||
assert ("parent_id", [output_1], False) == parent_request.get_outputs(
|
||||
"child_id_1", output_1
|
||||
)
|
||||
assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
|
||||
assert ([output_1], False) == parent_request.get_outputs("child_id_1", output_1)
|
||||
assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
|
||||
assert ([output_1], False) == parent_request.get_outputs("child_id_1", output_1)
|
||||
|
||||
# output_1 finished
|
||||
output_1.finish_reason = "ended"
|
||||
assert ("parent_id", [output_0], False) == parent_request.get_outputs(
|
||||
"child_id_0", output_0
|
||||
)
|
||||
assert ("parent_id", [output_1], False) == parent_request.get_outputs(
|
||||
"child_id_1", output_1
|
||||
)
|
||||
assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
|
||||
assert ([output_1], False) == parent_request.get_outputs("child_id_1", output_1)
|
||||
# Finished output_1 had already returned, DO NOT returned again
|
||||
assert ("parent_id", [output_0], False) == parent_request.get_outputs(
|
||||
"child_id_0", output_0
|
||||
)
|
||||
assert parent_request.get_outputs("child_id_1", output_1) == (
|
||||
"parent_id",
|
||||
[],
|
||||
False,
|
||||
)
|
||||
assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
|
||||
assert parent_request.get_outputs("child_id_1", output_1) == ([], False)
|
||||
|
||||
# output_0 finished
|
||||
output_0.finish_reason = "ended"
|
||||
assert ("parent_id", [output_0], True) == parent_request.get_outputs(
|
||||
"child_id_0", output_0
|
||||
)
|
||||
assert parent_request.get_outputs("child_id_1", output_1) == ("parent_id", [], True)
|
||||
assert ([output_0], True) == parent_request.get_outputs("child_id_0", output_0)
|
||||
assert parent_request.get_outputs("child_id_1", output_1) == ([], True)
|
||||
# Finished output_0 had already returned, DO NOT returned again
|
||||
assert parent_request.get_outputs("child_id_0", output_0) == ("parent_id", [], True)
|
||||
assert parent_request.get_outputs("child_id_1", output_1) == ("parent_id", [], True)
|
||||
assert parent_request.get_outputs("child_id_0", output_0) == ([], True)
|
||||
assert parent_request.get_outputs("child_id_1", output_1) == ([], True)
|
||||
|
||||
|
||||
def test_parent_request_to_output_final_only() -> None:
|
||||
parent_request = ParentRequest(
|
||||
"parent_id", SamplingParams(n=2, output_kind=RequestOutputKind.FINAL_ONLY)
|
||||
make_request(SamplingParams(n=2, output_kind=RequestOutputKind.FINAL_ONLY))
|
||||
)
|
||||
parent_request.child_requests = {"child_id_0", "child_id_1"}
|
||||
output_0 = CompletionOutput(
|
||||
@ -71,33 +52,33 @@ def test_parent_request_to_output_final_only() -> None:
|
||||
index=1, text="child 1", token_ids=[], cumulative_logprob=None, logprobs=None
|
||||
)
|
||||
# Request not finished, return nothing
|
||||
assert parent_request.get_outputs("child_id_0", output_0) == (
|
||||
"parent_id",
|
||||
[],
|
||||
False,
|
||||
)
|
||||
assert parent_request.get_outputs("child_id_1", output_1) == (
|
||||
"parent_id",
|
||||
[],
|
||||
False,
|
||||
)
|
||||
assert parent_request.get_outputs("child_id_0", output_0) == ([], False)
|
||||
assert parent_request.get_outputs("child_id_1", output_1) == ([], False)
|
||||
# output_1 finished, but outputs won't be returned until all child requests finished
|
||||
output_1.finish_reason = "ended"
|
||||
assert parent_request.get_outputs("child_id_0", output_0) == (
|
||||
"parent_id",
|
||||
[],
|
||||
False,
|
||||
)
|
||||
assert parent_request.get_outputs("child_id_1", output_1) == (
|
||||
"parent_id",
|
||||
[],
|
||||
False,
|
||||
)
|
||||
assert parent_request.get_outputs("child_id_0", output_0) == ([], False)
|
||||
assert parent_request.get_outputs("child_id_1", output_1) == ([], False)
|
||||
# output_0 finished, as all child requests finished, the output would be returned
|
||||
output_0.finish_reason = "ended"
|
||||
assert ("parent_id", [output_0, output_1], True) == parent_request.get_outputs(
|
||||
assert ([output_0, output_1], True) == parent_request.get_outputs(
|
||||
"child_id_0", output_0
|
||||
)
|
||||
assert ("parent_id", [output_0, output_1], True) == parent_request.get_outputs(
|
||||
assert ([output_0, output_1], True) == parent_request.get_outputs(
|
||||
"child_id_1", output_1
|
||||
)
|
||||
|
||||
|
||||
def make_request(sampling_params: SamplingParams) -> EngineCoreRequest:
|
||||
return EngineCoreRequest(
|
||||
request_id="parent_id",
|
||||
external_req_id="ext_parent_id",
|
||||
prompt_token_ids=None,
|
||||
mm_features=None,
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=None,
|
||||
eos_token_id=None,
|
||||
arrival_time=0.0,
|
||||
lora_request=None,
|
||||
cache_salt=None,
|
||||
data_parallel_rank=None,
|
||||
)
|
||||
|
||||
@ -6,6 +6,7 @@ import pytest
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.assets.video import VideoAsset
|
||||
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig
|
||||
from vllm.multimodal import MultiModalUUIDDict
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.engine import input_processor as input_processor_mod
|
||||
from vllm.v1.engine.input_processor import InputProcessor
|
||||
@ -166,7 +167,7 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
|
||||
monkeypatch, mm_cache_gb=0.0, enable_prefix_caching=False
|
||||
)
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
captured: dict[str, MultiModalUUIDDict] = {}
|
||||
|
||||
def fake_preprocess(
|
||||
prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None
|
||||
@ -196,7 +197,16 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
|
||||
)
|
||||
|
||||
# Expect request-id-based overrides are passed through
|
||||
assert captured["mm_uuids"] == {
|
||||
"image": [f"{request_id}-image-0", f"{request_id}-image-1"],
|
||||
"video": [f"{request_id}-video-0"],
|
||||
}
|
||||
mm_uuids = captured["mm_uuids"]
|
||||
assert set(mm_uuids.keys()) == {"image", "video"}
|
||||
assert len(mm_uuids["image"]) == 2
|
||||
assert len(mm_uuids["video"]) == 1
|
||||
assert mm_uuids["image"][0].startswith(f"{request_id}-image-") and mm_uuids[
|
||||
"image"
|
||||
][0].endswith("-0")
|
||||
assert mm_uuids["image"][1].startswith(f"{request_id}-image-") and mm_uuids[
|
||||
"image"
|
||||
][1].endswith("-1")
|
||||
assert mm_uuids["video"][0].startswith(f"{request_id}-video-") and mm_uuids[
|
||||
"video"
|
||||
][0].endswith("-0")
|
||||
|
||||
@ -343,6 +343,7 @@ class MockEngineCore:
|
||||
eos_token_id: int | None = None,
|
||||
stop_token_ids: list[int] | None = None,
|
||||
ignore_eos: bool = False,
|
||||
request_ids: list[str] | None = None,
|
||||
) -> None:
|
||||
self.num_requests = len(tokens_list)
|
||||
self.tokens_list = tokens_list
|
||||
@ -355,6 +356,11 @@ class MockEngineCore:
|
||||
self.eos_token_id = eos_token_id
|
||||
self.stop_token_ids = stop_token_ids
|
||||
self.ignore_eos = ignore_eos
|
||||
self.request_ids = (
|
||||
request_ids
|
||||
if request_ids is not None
|
||||
else [f"request-{i}" for i in range(self.num_requests)]
|
||||
)
|
||||
|
||||
def get_outputs(self) -> list[EngineCoreOutput]:
|
||||
do_logprobs = self.do_logprobs
|
||||
@ -386,7 +392,7 @@ class MockEngineCore:
|
||||
prompt_logprobs = None
|
||||
new_token_id = token_ids[token_idx]
|
||||
output = EngineCoreOutput(
|
||||
request_id=f"request-{req_idx}",
|
||||
request_id=self.request_ids[req_idx],
|
||||
new_token_ids=[new_token_id],
|
||||
new_logprobs=logprobs,
|
||||
new_prompt_logprobs_tensors=prompt_logprobs,
|
||||
|
||||
@ -41,10 +41,13 @@ from vllm.distributed.kv_transfer.kv_transfer_state import (
|
||||
has_kv_transfer_group,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import Platform
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.output_processor import OutputProcessor
|
||||
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
@ -1265,6 +1268,22 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
|
||||
run_test_and_cleanup()
|
||||
|
||||
|
||||
class RequestIdMapper:
|
||||
"""Helper class to map external request IDs to internal request IDs."""
|
||||
|
||||
def __init__(self, output_processor: OutputProcessor):
|
||||
self.req_id_mapping: dict[str, str] = {}
|
||||
self.original_add_request = output_processor.add_request
|
||||
output_processor.add_request = self._add_request
|
||||
|
||||
def _add_request(self, request: EngineCoreRequest, *args, **kwargs):
|
||||
self.req_id_mapping[request.external_req_id] = request.request_id
|
||||
return self.original_add_request(request, *args, **kwargs)
|
||||
|
||||
def __call__(self, external_req_id: str) -> str:
|
||||
return self.req_id_mapping[external_req_id]
|
||||
|
||||
|
||||
def _run_abort_timeout_test(llm: LLM, timeout: int):
|
||||
"""Helper function to run the abort timeout test logic."""
|
||||
remote_prefill_opts = {
|
||||
@ -1286,24 +1305,34 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
|
||||
0
|
||||
].req_to_blocks
|
||||
|
||||
id_mapper = RequestIdMapper(llm.llm_engine.output_processor)
|
||||
|
||||
def req_id(outputs: list[RequestOutput]) -> str:
|
||||
assert len(outputs) == 1
|
||||
return id_mapper(outputs[0].request_id)
|
||||
|
||||
padding = "Just making this request a little longer so that we're sure "
|
||||
"we're not hitting the small-request lower bound beneath which we don't "
|
||||
"actually trigger the whole kv transfer, but rather just recompute the "
|
||||
"blocks on D."
|
||||
_ = llm.generate([f"What is the capital of Japan? {padding}"], sampling_params)
|
||||
req0_id = req_id(
|
||||
llm.generate([f"What is the capital of Japan? {padding}"], sampling_params)
|
||||
)
|
||||
|
||||
# Request finished but not freed
|
||||
assert "0" in scheduler.finished_req_ids and "0" in req_to_blocks
|
||||
assert req0_id in scheduler.finished_req_ids and req0_id in req_to_blocks
|
||||
# Some other request, 0 still not freed
|
||||
_ = llm.generate([f"What is the capital of Italy? {padding}"], sampling_params)
|
||||
assert "0" in req_to_blocks
|
||||
assert "1" in scheduler.finished_req_ids and "1" in req_to_blocks
|
||||
req1_id = req_id(
|
||||
llm.generate([f"What is the capital of Italy? {padding}"], sampling_params)
|
||||
)
|
||||
assert req0_id in req_to_blocks
|
||||
assert req1_id in scheduler.finished_req_ids and req1_id in req_to_blocks
|
||||
|
||||
# Wait for timeout and trigger another scheduler loop
|
||||
time.sleep(timeout)
|
||||
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
|
||||
# Request-0 times out and is cleared!
|
||||
assert "0" not in req_to_blocks
|
||||
assert req0_id not in req_to_blocks
|
||||
# Need to shutdown the background thread to release NIXL side channel port
|
||||
llm.llm_engine.engine_core.shutdown()
|
||||
|
||||
|
||||
@ -136,7 +136,7 @@ class MMEncoderAttention(CustomOp):
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
if is_reshaped:
|
||||
output = output.view(bsz, q_len, -1)
|
||||
output = output.reshape(bsz, q_len, -1)
|
||||
return output
|
||||
|
||||
def _forward_fa(
|
||||
@ -174,7 +174,7 @@ class MMEncoderAttention(CustomOp):
|
||||
fa_version=self._fa_version,
|
||||
)
|
||||
if is_reshaped:
|
||||
output = output.view(bsz, q_len, -1)
|
||||
output = output.reshape(bsz, q_len, -1)
|
||||
return output
|
||||
|
||||
def forward_native(
|
||||
|
||||
@ -1621,7 +1621,7 @@ class LLM:
|
||||
added_request_ids.append(request_id)
|
||||
except Exception as e:
|
||||
if added_request_ids:
|
||||
self.llm_engine.abort_request(added_request_ids)
|
||||
self.llm_engine.abort_request(added_request_ids, internal=True)
|
||||
raise e
|
||||
|
||||
def _validate_mm_data_and_uuids(
|
||||
@ -1731,7 +1731,7 @@ class LLM:
|
||||
priority=priority,
|
||||
prompt_text=prompt_text,
|
||||
)
|
||||
return request_id
|
||||
return engine_request.request_id
|
||||
|
||||
def _run_engine(
|
||||
self, *, use_tqdm: bool | Callable[..., tqdm] = True
|
||||
|
||||
@ -118,6 +118,7 @@ class ShortConv(MambaBase, CustomOp):
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
has_initial_states_p = attn_metadata.has_initial_states_p
|
||||
query_start_loc_p = attn_metadata.query_start_loc_p
|
||||
|
||||
BCx, _ = self.in_proj(hidden_states)
|
||||
|
||||
@ -165,11 +166,6 @@ class ShortConv(MambaBase, CustomOp):
|
||||
[num_decodes, num_prefills],
|
||||
dim=0,
|
||||
)
|
||||
query_start_loc_p = (
|
||||
attn_metadata.query_start_loc[-num_prefills - 1 :] - num_decodes
|
||||
if has_prefill
|
||||
else None
|
||||
)
|
||||
|
||||
conv_output_list = []
|
||||
|
||||
|
||||
@ -3,17 +3,11 @@
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
from vllm.v1.attention.backends.mamba_attn import (
|
||||
BaseMambaAttentionMetadata,
|
||||
BaseMambaAttentionMetadataBuilder,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
|
||||
class Mamba1AttentionBackend(AttentionBackend):
|
||||
@ -23,137 +17,12 @@ class Mamba1AttentionBackend(AttentionBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class Mamba1AttentionMetadata:
|
||||
query_start_loc_p: torch.Tensor
|
||||
state_indices_tensor: torch.Tensor
|
||||
has_initial_states_p: torch.Tensor | None
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
|
||||
block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
|
||||
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
|
||||
block_idx_last_computed_token: torch.Tensor # shape: [batch,]
|
||||
num_computed_tokens_p: torch.Tensor # shape: [batch,]
|
||||
class Mamba1AttentionMetadata(BaseMambaAttentionMetadata):
|
||||
pass
|
||||
|
||||
|
||||
class Mamba1AttentionMetadataBuilder(
|
||||
BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
assert isinstance(kv_cache_spec, MambaSpec)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> Mamba1AttentionMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
|
||||
)
|
||||
)
|
||||
|
||||
has_initial_states_p = None
|
||||
query_start_loc_p = None
|
||||
num_computed_tokens, num_computed_tokens_p = None, None
|
||||
block_idx_first_scheduled_token = None
|
||||
block_idx_first_scheduled_token_p = None
|
||||
|
||||
# TODO(@Josephasafg) Mamba1 and Mamba2 have a lot of code in common here.
|
||||
# We should consolidate this code
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
# Return a tensor of shape (#requests, #max blocks)
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor
|
||||
mamba_block_size = self.kv_cache_spec.block_size
|
||||
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
|
||||
self.device
|
||||
)
|
||||
(
|
||||
block_idx_last_computed_token,
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
) = self._compute_prefix_caching_block_indices(
|
||||
common_attn_metadata, mamba_block_size
|
||||
)
|
||||
else:
|
||||
# Always return just a single block per each request:
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||
block_idx_last_scheduled_token = None
|
||||
block_idx_last_computed_token = None
|
||||
|
||||
if num_prefills > 0:
|
||||
query_start_loc_p = (
|
||||
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
|
||||
- num_decode_tokens
|
||||
)
|
||||
has_initial_states_cpu = (
|
||||
common_attn_metadata.num_computed_tokens_cpu[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
> 0
|
||||
)
|
||||
has_initial_states_p = has_initial_states_cpu.to(
|
||||
common_attn_metadata.query_start_loc.device
|
||||
)
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
assert num_computed_tokens is not None
|
||||
num_computed_tokens_p = num_computed_tokens[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
assert block_idx_first_scheduled_token is not None
|
||||
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
|
||||
elif (
|
||||
num_decodes > 0
|
||||
and num_decodes <= self.decode_cudagraph_max_bs
|
||||
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
):
|
||||
self.state_indices_tensor[:num_decodes].copy_(
|
||||
state_indices_tensor, non_blocking=True
|
||||
)
|
||||
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
|
||||
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
self.block_idx_last_scheduled_token[:num_decodes].copy_(
|
||||
block_idx_last_scheduled_token, non_blocking=True
|
||||
)
|
||||
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
|
||||
:num_decode_tokens
|
||||
]
|
||||
|
||||
self.block_idx_last_computed_token[:num_decodes].copy_(
|
||||
block_idx_last_computed_token, non_blocking=True
|
||||
)
|
||||
block_idx_last_computed_token = self.block_idx_last_computed_token[
|
||||
:num_decode_tokens
|
||||
]
|
||||
|
||||
return Mamba1AttentionMetadata(
|
||||
query_start_loc_p=query_start_loc_p,
|
||||
has_initial_states_p=has_initial_states_p,
|
||||
state_indices_tensor=state_indices_tensor,
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
|
||||
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
|
||||
block_idx_last_computed_token=block_idx_last_computed_token,
|
||||
num_computed_tokens_p=num_computed_tokens_p,
|
||||
)
|
||||
metadata_cls = Mamba1AttentionMetadata
|
||||
supports_update_block_table: bool = False
|
||||
|
||||
@ -1,19 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, replace
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.mamba_attn import (
|
||||
BaseMambaAttentionMetadata,
|
||||
BaseMambaAttentionMetadataBuilder,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
compute_causal_conv1d_metadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
@ -94,48 +94,26 @@ class Mamba2AttentionBackend(AttentionBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class Mamba2AttentionMetadata:
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
query_start_loc_p: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
prep_initial_states: bool
|
||||
chunk_size: int
|
||||
|
||||
# The following tensors only contain prefill requests and will be None if
|
||||
# the batch has no prefill request.
|
||||
has_initial_states_p: torch.Tensor | None
|
||||
seq_idx_p: torch.Tensor | None
|
||||
class Mamba2AttentionMetadata(BaseMambaAttentionMetadata):
|
||||
prep_initial_states: bool = False
|
||||
chunk_size: int = 0
|
||||
|
||||
# Chunk-related metadata (only for prefill)
|
||||
seq_idx_p: torch.Tensor | None = None
|
||||
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
|
||||
# each chunk, its offests into the varlen sequence dimension. It is defined
|
||||
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
|
||||
# cu_chunk_seqlen_p[i+1].
|
||||
cu_chunk_seqlen_p: torch.Tensor | None
|
||||
|
||||
cu_chunk_seqlen_p: torch.Tensor | None = None
|
||||
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
|
||||
# index of the last chunk for every sequence in the (prefill) batch.
|
||||
last_chunk_indices_p: torch.Tensor | None
|
||||
|
||||
state_indices_tensor: torch.Tensor # shape: [batch,]
|
||||
block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
|
||||
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
|
||||
block_idx_last_computed_token: torch.Tensor # shape: [batch,]
|
||||
num_computed_tokens_p: torch.Tensor # shape: [batch,]
|
||||
|
||||
# The following attributes are for triton implementation of causal_conv1d
|
||||
nums_dict: dict | None = None
|
||||
batch_ptr: torch.Tensor | None = None
|
||||
token_chunk_offset_ptr: torch.Tensor | None = None
|
||||
last_chunk_indices_p: torch.Tensor | None = None
|
||||
|
||||
|
||||
class Mamba2AttentionMetadataBuilder(
|
||||
BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
|
||||
):
|
||||
supports_update_block_table: bool = True
|
||||
metadata_cls = Mamba2AttentionMetadata
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -150,87 +128,93 @@ class Mamba2AttentionMetadataBuilder(
|
||||
"chunk_size needs to be set in the model config for Mamba2 models"
|
||||
)
|
||||
|
||||
def _compute_chunk_metadata(
|
||||
self,
|
||||
num_prefills: int,
|
||||
num_computed_tokens_p_cpu: torch.Tensor,
|
||||
query_start_loc_p_cpu: torch.Tensor,
|
||||
) -> tuple[list[int], list[int], list[int]]:
|
||||
"""
|
||||
Compute chunk-specific metadata for Mamba2.
|
||||
|
||||
The code below carefully constructs the chunks such that:
|
||||
1. Chunks contain tokens from a *single* sequence only.
|
||||
2. For every sequence, we are guaranteed that we can
|
||||
retrieve the mamba state *every* chunk_size tokens.
|
||||
Constraint (1) dramatically simplifies the mamba2 kernels.
|
||||
Constraint (2) dramatically simplifies the implementation
|
||||
of prefix caching for mamba2 (wip). We need to take care
|
||||
of the interaction with chunked prefill in order to
|
||||
satisfy constraint (2).
|
||||
"""
|
||||
# TODO (tdoublep): This code could probably be optimized.
|
||||
cu_chunk_seqlen = []
|
||||
seq_idx = []
|
||||
last_chunk_indices = []
|
||||
seqlen_pos = 0
|
||||
|
||||
for req_idx in range(num_prefills):
|
||||
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
|
||||
this_new_tokens = (
|
||||
query_start_loc_p_cpu[req_idx + 1].item()
|
||||
- query_start_loc_p_cpu[req_idx].item()
|
||||
)
|
||||
|
||||
# if computed tokens are not chunk-aligned, use the first
|
||||
# chunk to finish it off
|
||||
if this_num_computed % self.chunk_size != 0:
|
||||
seq_idx.append(req_idx)
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
# how many tokens to finish the chunk?
|
||||
chunk_len = (
|
||||
cdiv(this_num_computed, self.chunk_size) * self.chunk_size
|
||||
- this_num_computed
|
||||
)
|
||||
# we can only use at most this_new_tokens
|
||||
chunk_len = min(chunk_len, this_new_tokens)
|
||||
seqlen_pos += chunk_len
|
||||
this_new_tokens -= chunk_len
|
||||
|
||||
n_chunks = cdiv(this_new_tokens, self.chunk_size)
|
||||
for chunk in range(n_chunks):
|
||||
seq_idx.append(req_idx)
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
chunk_len = min(self.chunk_size, this_new_tokens)
|
||||
seqlen_pos += chunk_len
|
||||
this_new_tokens -= chunk_len
|
||||
|
||||
assert this_new_tokens == 0
|
||||
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
|
||||
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
|
||||
return cu_chunk_seqlen, seq_idx, last_chunk_indices
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> Mamba2AttentionMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
common = self._compute_common_metadata(common_attn_metadata)
|
||||
|
||||
query_start_loc_p = None
|
||||
seq_idx_p = None
|
||||
cu_chunk_seqlen_p = None
|
||||
last_chunk_indices_p = None
|
||||
|
||||
# Need flags to indicate if there are initial states
|
||||
has_initial_states_p = None
|
||||
prep_initial_states = False
|
||||
|
||||
# for causal_conv1d
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||
|
||||
num_computed_tokens, num_computed_tokens_p = None, None
|
||||
block_idx_first_scheduled_token = None
|
||||
block_idx_first_scheduled_token_p = None
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
# Return a tensor of shape (#requests, #max blocks)
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor
|
||||
# Additional cache-related varaiables:
|
||||
mamba_block_size = self.kv_cache_spec.block_size
|
||||
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
|
||||
self.device
|
||||
)
|
||||
(
|
||||
block_idx_last_computed_token,
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
) = self._compute_prefix_caching_block_indices(
|
||||
common_attn_metadata, mamba_block_size
|
||||
)
|
||||
else:
|
||||
# Always return just a single block per each request:
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||
# Additional cache-related varaiables:
|
||||
block_idx_last_scheduled_token = None
|
||||
block_idx_last_computed_token = None
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
|
||||
)
|
||||
)
|
||||
|
||||
# Compute seq_idx for prefill only
|
||||
if num_prefills > 0:
|
||||
# [batch,]
|
||||
has_initial_states_cpu = (
|
||||
common_attn_metadata.num_computed_tokens_cpu[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
> 0
|
||||
)
|
||||
prep_initial_states = torch.any(has_initial_states_cpu).item()
|
||||
has_initial_states_p = has_initial_states_cpu.to(
|
||||
common_attn_metadata.query_start_loc.device
|
||||
if common.num_prefills > 0:
|
||||
prep_initial_states = (
|
||||
torch.any(common.has_initial_states_p).item()
|
||||
if common.has_initial_states_p is not None
|
||||
else False
|
||||
)
|
||||
|
||||
query_start_loc_p = (
|
||||
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
|
||||
- num_decode_tokens
|
||||
)
|
||||
num_reqs = common.num_reqs
|
||||
num_prefills = common.num_prefills
|
||||
num_decode_tokens = common.num_decode_tokens
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
assert num_computed_tokens is not None
|
||||
num_computed_tokens_p = num_computed_tokens[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
assert block_idx_first_scheduled_token is not None
|
||||
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
@ -239,137 +223,33 @@ class Mamba2AttentionMetadataBuilder(
|
||||
- num_decode_tokens
|
||||
)
|
||||
|
||||
# The code below carefully constructs the chunks such that:
|
||||
# 1. Chunks contain tokens from a *single* sequence only.
|
||||
# 2. For every sequence, we are guaranteed that we can
|
||||
# retrieve the mamba state *every* chunk_size tokens.
|
||||
# Constraint (1) dramatically simplifies the mamba2 kernels.
|
||||
# Constraint (2) dramatically simplifies the implementation
|
||||
# of prefix caching for mamba2 (wip). We need to take care
|
||||
# of the interaction with chunked prefill in order to
|
||||
# satisfy constraint (2).
|
||||
# TODO (tdoublep): This code could probably be optimized.
|
||||
cu_chunk_seqlen = []
|
||||
seq_idx = []
|
||||
last_chunk_indices = []
|
||||
seqlen_pos = 0
|
||||
for req_idx in range(num_prefills):
|
||||
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
|
||||
this_new_tokens = (
|
||||
query_start_loc_p_cpu[req_idx + 1].item()
|
||||
- query_start_loc_p_cpu[req_idx].item()
|
||||
)
|
||||
|
||||
# if computed tokens are not chunk-aligned, use the first
|
||||
# chunk to finish it off
|
||||
if this_num_computed % self.chunk_size != 0:
|
||||
seq_idx.append(req_idx)
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
# how many tokens to finish the chunk?
|
||||
chunk_len = (
|
||||
cdiv(this_num_computed, self.chunk_size) * self.chunk_size
|
||||
- this_num_computed
|
||||
)
|
||||
# we can only use at most this_new_tokens
|
||||
chunk_len = min(chunk_len, this_new_tokens)
|
||||
seqlen_pos += chunk_len
|
||||
this_new_tokens -= chunk_len
|
||||
|
||||
n_chunks = cdiv(this_new_tokens, self.chunk_size)
|
||||
for chunk in range(n_chunks):
|
||||
seq_idx.append(req_idx)
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
chunk_len = min(self.chunk_size, this_new_tokens)
|
||||
seqlen_pos += chunk_len
|
||||
this_new_tokens -= chunk_len
|
||||
|
||||
assert this_new_tokens == 0
|
||||
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
|
||||
|
||||
cu_chunk_seqlen.append(seqlen_pos)
|
||||
cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
|
||||
num_prefills,
|
||||
num_computed_tokens_p_cpu,
|
||||
query_start_loc_p_cpu,
|
||||
)
|
||||
|
||||
seq_idx_p = torch.as_tensor(
|
||||
seq_idx, device=query_start_loc_p.device, dtype=torch.int32
|
||||
seq_idx,
|
||||
device=common_attn_metadata.query_start_loc.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
cu_chunk_seqlen_p = torch.as_tensor(
|
||||
cu_chunk_seqlen, device=query_start_loc_p.device, dtype=torch.int32
|
||||
cu_chunk_seqlen,
|
||||
device=common_attn_metadata.query_start_loc.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
last_chunk_indices_p = torch.as_tensor(
|
||||
last_chunk_indices, device=query_start_loc_p.device, dtype=torch.int32
|
||||
last_chunk_indices,
|
||||
device=common_attn_metadata.query_start_loc.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = (
|
||||
compute_causal_conv1d_metadata(query_start_loc_p)
|
||||
)
|
||||
|
||||
elif (
|
||||
num_decodes <= self.decode_cudagraph_max_bs
|
||||
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
):
|
||||
self.state_indices_tensor[:num_decodes].copy_(
|
||||
state_indices_tensor, non_blocking=True
|
||||
)
|
||||
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
self.block_idx_last_scheduled_token[:num_decodes].copy_(
|
||||
block_idx_last_scheduled_token, non_blocking=True
|
||||
)
|
||||
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
|
||||
:num_decode_tokens
|
||||
]
|
||||
|
||||
self.block_idx_last_computed_token[:num_decodes].copy_(
|
||||
block_idx_last_computed_token, non_blocking=True
|
||||
)
|
||||
block_idx_last_computed_token = self.block_idx_last_computed_token[
|
||||
:num_decode_tokens
|
||||
]
|
||||
|
||||
attn_metadata = Mamba2AttentionMetadata(
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
query_start_loc_p=query_start_loc_p,
|
||||
seq_lens=seq_lens,
|
||||
return replace(
|
||||
common,
|
||||
prep_initial_states=prep_initial_states,
|
||||
chunk_size=self.chunk_size,
|
||||
has_initial_states_p=has_initial_states_p,
|
||||
seq_idx_p=seq_idx_p,
|
||||
state_indices_tensor=state_indices_tensor,
|
||||
cu_chunk_seqlen_p=cu_chunk_seqlen_p,
|
||||
last_chunk_indices_p=last_chunk_indices_p,
|
||||
nums_dict=nums_dict,
|
||||
batch_ptr=batch_ptr,
|
||||
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
||||
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
|
||||
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
|
||||
block_idx_last_computed_token=block_idx_last_computed_token,
|
||||
num_computed_tokens_p=num_computed_tokens_p,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
def update_block_table(
|
||||
self,
|
||||
metadata: Mamba2AttentionMetadata,
|
||||
blk_table: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> Mamba2AttentionMetadata:
|
||||
new_metadata = copy.copy(metadata)
|
||||
prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
|
||||
state_indices_t = blk_table if prefix_caching else blk_table[:, 0]
|
||||
num_reqs = blk_table.shape[0]
|
||||
|
||||
# For CUDA graphs, copy to persistent buffer
|
||||
if (
|
||||
metadata.num_prefills == 0
|
||||
and num_reqs <= self.decode_cudagraph_max_bs
|
||||
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
):
|
||||
persistent_state_indices_t = self.state_indices_tensor[:num_reqs]
|
||||
persistent_state_indices_t.copy_(state_indices_t, non_blocking=True)
|
||||
state_indices_t = persistent_state_indices_t
|
||||
|
||||
new_metadata.state_indices_tensor = state_indices_t
|
||||
return new_metadata
|
||||
|
||||
@ -2,6 +2,8 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import abc
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, TypeVar
|
||||
|
||||
import torch
|
||||
@ -9,20 +11,52 @@ import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
PAD_SLOT_ID,
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
compute_causal_conv1d_metadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
M = TypeVar("M")
|
||||
M = TypeVar("M", bound="BaseMambaAttentionMetadata")
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseMambaAttentionMetadata:
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_reqs: int
|
||||
|
||||
# The following tensors only contain prefill requests and will be None if
|
||||
# the batch has no prefill request.
|
||||
has_initial_states_p: torch.Tensor | None
|
||||
query_start_loc_p: torch.Tensor | None
|
||||
num_computed_tokens_p: torch.Tensor | None
|
||||
|
||||
state_indices_tensor: torch.Tensor
|
||||
|
||||
# The following tensors are only used for prefix caching and are None if disabled
|
||||
block_idx_last_scheduled_token: torch.Tensor | None
|
||||
block_idx_first_scheduled_token_p: torch.Tensor | None
|
||||
block_idx_last_computed_token: torch.Tensor | None
|
||||
|
||||
# The following attributes are for triton implementation of causal_conv1d
|
||||
nums_dict: dict | None = None
|
||||
batch_ptr: torch.Tensor | None = None
|
||||
token_chunk_offset_ptr: torch.Tensor | None = None
|
||||
|
||||
|
||||
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
metadata_cls: type[M]
|
||||
reorder_batch_threshold: int = 1
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = (
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
)
|
||||
supports_update_block_table: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -87,6 +121,18 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
|
||||
return self.build(0, m)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> M:
|
||||
"""
|
||||
Default build implementation for Mamba-like attention backends.
|
||||
Subclasses (e.g., Mamba2) can override to add additional metadata.
|
||||
"""
|
||||
return self._compute_common_metadata(common_attn_metadata)
|
||||
|
||||
def _compute_prefix_caching_block_indices(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
@ -115,3 +161,147 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
)
|
||||
|
||||
def _compute_common_metadata(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> M:
|
||||
"""
|
||||
Compute metadata common to both Mamba1 and Mamba2.
|
||||
"""
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
|
||||
)
|
||||
)
|
||||
|
||||
# Need flags to indicate if there are initial states
|
||||
has_initial_states_p = None
|
||||
query_start_loc_p = None
|
||||
num_computed_tokens = None
|
||||
num_computed_tokens_p = None
|
||||
|
||||
# for prefix caching
|
||||
block_idx_first_scheduled_token = None
|
||||
block_idx_first_scheduled_token_p = None
|
||||
block_idx_last_computed_token = None
|
||||
block_idx_last_scheduled_token = None
|
||||
|
||||
# for causal_conv1d
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
# Return a tensor of shape (#requests, #max blocks)
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor
|
||||
# Additional cache-related varaiables:
|
||||
mamba_block_size = self.kv_cache_spec.block_size
|
||||
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
|
||||
self.device
|
||||
)
|
||||
(
|
||||
block_idx_last_computed_token,
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
) = self._compute_prefix_caching_block_indices(
|
||||
common_attn_metadata, mamba_block_size
|
||||
)
|
||||
else:
|
||||
# Always return just a single block per each request:
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||
|
||||
if num_prefills > 0:
|
||||
query_start_loc_p = (
|
||||
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
|
||||
- num_decode_tokens
|
||||
)
|
||||
has_initial_states_cpu = (
|
||||
common_attn_metadata.num_computed_tokens_cpu[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
> 0
|
||||
)
|
||||
has_initial_states_p = has_initial_states_cpu.to(
|
||||
common_attn_metadata.query_start_loc.device
|
||||
)
|
||||
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = (
|
||||
compute_causal_conv1d_metadata(query_start_loc_p)
|
||||
)
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
assert num_computed_tokens is not None
|
||||
num_computed_tokens_p = num_computed_tokens[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
assert block_idx_first_scheduled_token is not None
|
||||
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
elif (
|
||||
num_decodes <= self.decode_cudagraph_max_bs
|
||||
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
):
|
||||
self.state_indices_tensor[:num_decodes].copy_(
|
||||
state_indices_tensor, non_blocking=True
|
||||
)
|
||||
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
|
||||
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
self.block_idx_last_scheduled_token[:num_decodes].copy_(
|
||||
block_idx_last_scheduled_token, non_blocking=True
|
||||
)
|
||||
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
|
||||
:num_decode_tokens
|
||||
]
|
||||
|
||||
self.block_idx_last_computed_token[:num_decodes].copy_(
|
||||
block_idx_last_computed_token, non_blocking=True
|
||||
)
|
||||
block_idx_last_computed_token = self.block_idx_last_computed_token[
|
||||
:num_decode_tokens
|
||||
]
|
||||
|
||||
return self.metadata_cls(
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
query_start_loc_p=query_start_loc_p,
|
||||
has_initial_states_p=has_initial_states_p,
|
||||
state_indices_tensor=state_indices_tensor,
|
||||
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
|
||||
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
|
||||
block_idx_last_computed_token=block_idx_last_computed_token,
|
||||
num_computed_tokens_p=num_computed_tokens_p,
|
||||
num_reqs=num_reqs,
|
||||
nums_dict=nums_dict,
|
||||
batch_ptr=batch_ptr,
|
||||
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
||||
)
|
||||
|
||||
def update_block_table(
|
||||
self,
|
||||
metadata: M,
|
||||
blk_table: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> M:
|
||||
new_metadata = copy.copy(metadata)
|
||||
prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
|
||||
state_indices_t = blk_table if prefix_caching else blk_table[:, 0]
|
||||
num_reqs = blk_table.shape[0]
|
||||
|
||||
# For CUDA graphs, copy to persistent buffer
|
||||
if (
|
||||
metadata.num_prefills == 0
|
||||
and num_reqs <= self.decode_cudagraph_max_bs
|
||||
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
):
|
||||
persistent_state_indices_t = self.state_indices_tensor[:num_reqs]
|
||||
persistent_state_indices_t.copy_(state_indices_t, non_blocking=True)
|
||||
state_indices_t = persistent_state_indices_t
|
||||
|
||||
new_metadata.state_indices_tensor = state_indices_t
|
||||
return new_metadata
|
||||
|
||||
@ -2,15 +2,10 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
PAD_SLOT_ID,
|
||||
CommonAttentionMetadata,
|
||||
compute_causal_conv1d_metadata,
|
||||
split_decodes_and_prefills,
|
||||
from vllm.v1.attention.backends.mamba_attn import (
|
||||
BaseMambaAttentionMetadata,
|
||||
BaseMambaAttentionMetadataBuilder,
|
||||
)
|
||||
|
||||
|
||||
@ -21,84 +16,11 @@ class ShortConvAttentionBackend(AttentionBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShortConvAttentionMetadata:
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
|
||||
query_start_loc: torch.Tensor
|
||||
state_indices_tensor: torch.Tensor
|
||||
has_initial_states_p: torch.Tensor | None
|
||||
|
||||
# For causal_conv1d
|
||||
nums_dict: dict | None = None
|
||||
batch_ptr: torch.Tensor | None = None
|
||||
token_chunk_offset_ptr: torch.Tensor | None = None
|
||||
class ShortConvAttentionMetadata(BaseMambaAttentionMetadata):
|
||||
pass
|
||||
|
||||
|
||||
class ShortConvAttentionMetadataBuilder(
|
||||
BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata]
|
||||
):
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> ShortConvAttentionMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||
|
||||
# for causal_conv1d
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
|
||||
)
|
||||
)
|
||||
|
||||
has_initial_states_p = None
|
||||
if num_prefills > 0:
|
||||
has_initial_states_cpu = (
|
||||
common_attn_metadata.num_computed_tokens_cpu[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
> 0
|
||||
)
|
||||
has_initial_states_p = has_initial_states_cpu.to(query_start_loc.device)
|
||||
|
||||
query_start_loc_p = (
|
||||
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
|
||||
- num_decode_tokens
|
||||
)
|
||||
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = (
|
||||
compute_causal_conv1d_metadata(query_start_loc_p)
|
||||
)
|
||||
|
||||
elif (
|
||||
num_decodes > 0
|
||||
and num_decodes <= self.decode_cudagraph_max_bs
|
||||
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
):
|
||||
self.state_indices_tensor[:num_decodes].copy_(
|
||||
state_indices_tensor, non_blocking=True
|
||||
)
|
||||
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
|
||||
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
|
||||
|
||||
attn_metadata = ShortConvAttentionMetadata(
|
||||
query_start_loc=query_start_loc,
|
||||
state_indices_tensor=state_indices_tensor,
|
||||
has_initial_states_p=has_initial_states_p,
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
nums_dict=nums_dict,
|
||||
batch_ptr=batch_ptr,
|
||||
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
||||
)
|
||||
return attn_metadata
|
||||
metadata_cls = ShortConvAttentionMetadata
|
||||
|
||||
@ -75,6 +75,12 @@ class EngineCoreRequest(
|
||||
|
||||
trace_headers: Mapping[str, str] | None = None
|
||||
|
||||
# The user-provided request ID. This field is set internally,
|
||||
# copied from the provided request_id that's originally assigned
|
||||
# to the request_id field, see InputProcessor.assign_request_id().
|
||||
# Used in outputs and to support abort(req_id, internal=False).
|
||||
external_req_id: str | None = None
|
||||
|
||||
@property
|
||||
def params(self) -> SamplingParams | PoolingParams:
|
||||
"""Return the processed params (sampling or pooling)."""
|
||||
|
||||
@ -290,12 +290,15 @@ class AsyncLLM(EngineClient):
|
||||
|
||||
is_pooling = isinstance(params, PoolingParams)
|
||||
|
||||
# Create a new output collector for the request.
|
||||
queue = RequestOutputCollector(output_kind=params.output_kind)
|
||||
|
||||
# Convert Input --> Request.
|
||||
if isinstance(prompt, EngineCoreRequest):
|
||||
request = prompt
|
||||
if request_id != request.request_id:
|
||||
logger.warning_once(
|
||||
"AsyncLLM.add_request() was passed a request_id parameter that "
|
||||
"does not match the EngineCoreRequest.request_id attribute. The "
|
||||
"latter will be used, and the former will be ignored."
|
||||
)
|
||||
else:
|
||||
assert prompt_text is None
|
||||
request = self.input_processor.process_inputs(
|
||||
@ -314,6 +317,11 @@ class AsyncLLM(EngineClient):
|
||||
elif isinstance(prompt, Mapping):
|
||||
prompt_text = cast(str | None, prompt.get("prompt"))
|
||||
|
||||
self.input_processor.assign_request_id(request)
|
||||
|
||||
# Create a new output collector for the request.
|
||||
queue = RequestOutputCollector(params.output_kind, request.request_id)
|
||||
|
||||
# Use cloned params that may have been updated in process_inputs()
|
||||
params = request.params
|
||||
|
||||
@ -325,7 +333,7 @@ class AsyncLLM(EngineClient):
|
||||
assert isinstance(parent_params, SamplingParams)
|
||||
|
||||
# Fan out child requests (for n>1).
|
||||
parent_request = ParentRequest(request_id, parent_params)
|
||||
parent_request = ParentRequest(request)
|
||||
for idx in range(parent_params.n):
|
||||
request_id, child_params = parent_request.get_child_info(idx)
|
||||
child_request = request if idx == parent_params.n - 1 else copy(request)
|
||||
@ -396,6 +404,7 @@ class AsyncLLM(EngineClient):
|
||||
"prompt logprobs"
|
||||
)
|
||||
|
||||
q: RequestOutputCollector | None = None
|
||||
try:
|
||||
# We start the output_handler on the first call to generate() so
|
||||
# we can call __init__ before the event loop, which enables us
|
||||
@ -446,7 +455,8 @@ class AsyncLLM(EngineClient):
|
||||
# is cancelled or the generator is garbage collected. So,
|
||||
# we abort the request if we end up here.
|
||||
except (asyncio.CancelledError, GeneratorExit):
|
||||
await self.abort(request_id)
|
||||
if q is not None:
|
||||
await self.abort(q.request_id, internal=True)
|
||||
if self.log_requests:
|
||||
logger.info("Request %s aborted.", request_id)
|
||||
raise
|
||||
@ -465,7 +475,8 @@ class AsyncLLM(EngineClient):
|
||||
|
||||
# Unexpected error in the generate() task (possibly recoverable).
|
||||
except Exception as e:
|
||||
await self.abort(request_id)
|
||||
if q is not None:
|
||||
await self.abort(q.request_id, internal=True)
|
||||
if self.log_requests:
|
||||
logger.info("Request %s failed.", request_id)
|
||||
raise EngineGenerateError() from e
|
||||
@ -541,13 +552,15 @@ class AsyncLLM(EngineClient):
|
||||
|
||||
self.output_handler = asyncio.create_task(output_handler())
|
||||
|
||||
async def abort(self, request_id: str | Iterable[str]) -> None:
|
||||
async def abort(
|
||||
self, request_id: str | Iterable[str], internal: bool = False
|
||||
) -> None:
|
||||
"""Abort RequestId in OutputProcessor and EngineCore."""
|
||||
|
||||
request_ids = (
|
||||
(request_id,) if isinstance(request_id, str) else as_list(request_id)
|
||||
)
|
||||
all_request_ids = self.output_processor.abort_requests(request_ids)
|
||||
all_request_ids = self.output_processor.abort_requests(request_ids, internal)
|
||||
await self.engine_core.abort_requests_async(all_request_ids)
|
||||
|
||||
if self.log_requests:
|
||||
@ -581,7 +594,7 @@ class AsyncLLM(EngineClient):
|
||||
if not wait_for_inflight_requests:
|
||||
request_ids = list(self.output_processor.request_states.keys())
|
||||
if request_ids:
|
||||
await self.abort(request_ids)
|
||||
await self.abort(request_ids, internal=True)
|
||||
|
||||
# Wait for running requests to drain before clearing cache.
|
||||
if self.output_processor.has_unfinished_requests():
|
||||
@ -633,6 +646,7 @@ class AsyncLLM(EngineClient):
|
||||
TODO: Remove truncate_prompt_tokens in v0.15.
|
||||
"""
|
||||
|
||||
q: RequestOutputCollector | None = None
|
||||
try:
|
||||
# We start the output_handler on the first call to generate() so
|
||||
# we can call __init__ before the event loop, which enables us
|
||||
@ -687,7 +701,8 @@ class AsyncLLM(EngineClient):
|
||||
# If the request is disconnected by the client, generate()
|
||||
# is cancelled. So, we abort the request if we end up here.
|
||||
except asyncio.CancelledError:
|
||||
await self.abort(request_id)
|
||||
if q is not None:
|
||||
await self.abort(q.request_id, internal=True)
|
||||
if self.log_requests:
|
||||
logger.info("Request %s aborted.", request_id)
|
||||
raise
|
||||
@ -706,7 +721,8 @@ class AsyncLLM(EngineClient):
|
||||
|
||||
# Unexpected error in the generate() task (possibly recoverable).
|
||||
except Exception as e:
|
||||
await self.abort(request_id)
|
||||
if q is not None:
|
||||
await self.abort(q.request_id, internal=True)
|
||||
if self.log_requests:
|
||||
logger.info("Request %s failed.", request_id)
|
||||
raise EngineGenerateError() from e
|
||||
|
||||
@ -21,7 +21,7 @@ from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.metrics.stats import MultiModalCacheStats
|
||||
from vllm.v1.structured_output.backend_guidance import (
|
||||
@ -406,6 +406,19 @@ class InputProcessor:
|
||||
mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
|
||||
return mm_uuids
|
||||
|
||||
@staticmethod
|
||||
def assign_request_id(request: EngineCoreRequest):
|
||||
"""Replace the externally supplied request ID with an internal request ID
|
||||
that adds 8 random characters in order to ensure uniquness.
|
||||
"""
|
||||
if request.external_req_id is not None:
|
||||
raise ValueError(
|
||||
"The external_req_id field should not be set on EngineCoreRequests"
|
||||
" passed to vLLM; use the request_id field."
|
||||
)
|
||||
request.external_req_id = request.request_id
|
||||
request.request_id = f"{request.external_req_id}-{random_uuid():.8}"
|
||||
|
||||
def process_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
|
||||
@ -213,10 +213,10 @@ class LLMEngine:
|
||||
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||
return self.engine_core.get_supported_tasks()
|
||||
|
||||
def abort_request(self, request_ids: list[str]) -> None:
|
||||
def abort_request(self, request_ids: list[str], internal: bool = False) -> None:
|
||||
"""Remove request_ids from EngineCore and Detokenizer."""
|
||||
|
||||
request_ids = self.output_processor.abort_requests(request_ids)
|
||||
request_ids = self.output_processor.abort_requests(request_ids, internal)
|
||||
self.engine_core.abort_requests(request_ids)
|
||||
|
||||
def add_request(
|
||||
@ -238,6 +238,12 @@ class LLMEngine:
|
||||
# Process raw inputs into the request.
|
||||
if isinstance(prompt, EngineCoreRequest):
|
||||
request = prompt
|
||||
if request_id != request.request_id:
|
||||
logger.warning_once(
|
||||
"AsyncLLM.add_request() was passed a request_id parameter that "
|
||||
"does not match the EngineCoreRequest.request_id attribute. The "
|
||||
"latter will be used, and the former will be ignored."
|
||||
)
|
||||
else:
|
||||
assert prompt_text is None
|
||||
request = self.input_processor.process_inputs(
|
||||
@ -255,6 +261,8 @@ class LLMEngine:
|
||||
elif isinstance(prompt, Mapping):
|
||||
prompt_text = cast(str | None, prompt.get("prompt"))
|
||||
|
||||
self.input_processor.assign_request_id(request)
|
||||
|
||||
# Use cloned params that may have been updated in process_inputs()
|
||||
params = request.params
|
||||
|
||||
@ -268,7 +276,7 @@ class LLMEngine:
|
||||
return
|
||||
|
||||
# Fan out child requests (for n>1).
|
||||
parent_req = ParentRequest(request_id, params)
|
||||
parent_req = ParentRequest(request)
|
||||
for idx in range(n):
|
||||
request_id, child_params = parent_req.get_child_info(idx)
|
||||
child_request = request if idx == n - 1 else copy(request)
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
@ -40,8 +41,9 @@ class RequestOutputCollector:
|
||||
producer gets ahead of the consumer.
|
||||
"""
|
||||
|
||||
def __init__(self, output_kind: RequestOutputKind):
|
||||
def __init__(self, output_kind: RequestOutputKind, request_id: str):
|
||||
self.aggregate = output_kind == RequestOutputKind.DELTA
|
||||
self.request_id = request_id
|
||||
self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
|
||||
self.ready = asyncio.Event()
|
||||
|
||||
@ -92,6 +94,7 @@ class RequestState:
|
||||
def __init__(
|
||||
self,
|
||||
request_id: str,
|
||||
external_req_id: str,
|
||||
parent_req: ParentRequest | None,
|
||||
request_index: int,
|
||||
lora_request: LoRARequest | None,
|
||||
@ -111,6 +114,7 @@ class RequestState:
|
||||
temperature: float | None = None,
|
||||
):
|
||||
self.request_id = request_id
|
||||
self.external_req_id = external_req_id
|
||||
self.parent_req = parent_req
|
||||
self.request_index = request_index
|
||||
self.lora_request = lora_request
|
||||
@ -176,8 +180,10 @@ class RequestState:
|
||||
assert request.pooling_params is not None
|
||||
output_kind = request.pooling_params.output_kind
|
||||
|
||||
assert request.external_req_id is not None
|
||||
return cls(
|
||||
request_id=request.request_id,
|
||||
external_req_id=request.external_req_id,
|
||||
parent_req=parent_req,
|
||||
request_index=request_index,
|
||||
lora_request=request.lora_request,
|
||||
@ -235,10 +241,13 @@ class RequestState:
|
||||
]
|
||||
self.sent_tokens_offset = len(self.detokenizer.output_token_ids)
|
||||
|
||||
request_id = self.request_id
|
||||
external_req_id = self.external_req_id
|
||||
|
||||
if pooling_output is not None:
|
||||
return self._new_request_output(
|
||||
request_id, [self._new_pooling_output(pooling_output)], finished
|
||||
external_req_id,
|
||||
[self._new_pooling_output(pooling_output)],
|
||||
finished,
|
||||
)
|
||||
|
||||
output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)
|
||||
@ -246,19 +255,18 @@ class RequestState:
|
||||
if self.parent_req is None:
|
||||
outputs = [output]
|
||||
else:
|
||||
request_id, outputs, finished = self.parent_req.get_outputs(
|
||||
request_id, output
|
||||
)
|
||||
outputs, finished = self.parent_req.get_outputs(self.request_id, output)
|
||||
if not outputs:
|
||||
return None
|
||||
external_req_id = self.parent_req.external_req_id
|
||||
|
||||
return self._new_request_output(
|
||||
request_id, outputs, finished, kv_transfer_params
|
||||
external_req_id, outputs, finished, kv_transfer_params
|
||||
)
|
||||
|
||||
def _new_request_output(
|
||||
self,
|
||||
request_id: str,
|
||||
external_req_id: str,
|
||||
outputs: list[CompletionOutput] | list[PoolingOutput],
|
||||
finished: bool,
|
||||
kv_transfer_params: dict[str, Any] | None = None,
|
||||
@ -269,7 +277,7 @@ class RequestState:
|
||||
# Prompt embeddings are currently not supported by pooling requests.
|
||||
assert self.prompt_token_ids is not None
|
||||
return PoolingRequestOutput(
|
||||
request_id=request_id,
|
||||
request_id=external_req_id,
|
||||
outputs=first_output,
|
||||
num_cached_tokens=self.num_cached_tokens,
|
||||
prompt_token_ids=self.prompt_token_ids,
|
||||
@ -288,7 +296,7 @@ class RequestState:
|
||||
prompt_token_ids = [0] * len(self.prompt_embeds)
|
||||
|
||||
return RequestOutput(
|
||||
request_id=request_id,
|
||||
request_id=external_req_id, # request_id is what was provided externally
|
||||
lora_request=self.lora_request,
|
||||
prompt=self.prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
@ -352,6 +360,7 @@ class OutputProcessor:
|
||||
self.stream_interval = stream_interval
|
||||
self.request_states: dict[str, RequestState] = {}
|
||||
self.parent_requests: dict[str, ParentRequest] = {}
|
||||
self.external_req_ids: defaultdict[str, list[str]] = defaultdict(list)
|
||||
self.lora_states = LoRARequestStates(log_stats)
|
||||
self.tracer: Tracer | None = None
|
||||
self._requests_drained = asyncio.Event()
|
||||
@ -375,12 +384,41 @@ class OutputProcessor:
|
||||
assert state.queue is not None
|
||||
state.queue.put(e)
|
||||
|
||||
def abort_requests(
|
||||
self,
|
||||
request_ids: Iterable[str],
|
||||
) -> list[str]:
|
||||
request_ids_to_abort = []
|
||||
def abort_requests(self, request_ids: Iterable[str], internal: bool) -> list[str]:
|
||||
"""Abort a list of requests.
|
||||
|
||||
The request_ids may be either external request IDs (those passed to
|
||||
InputProcessor.process_inputs()) or internal request IDs (those randomly
|
||||
generated when creating the EngineCoreRequest).
|
||||
|
||||
If an external request ID is provided, and that external request ID
|
||||
was used for multiple requests, all requests associated with that external
|
||||
request ID are aborted.
|
||||
|
||||
In the case of parallel sampling, a request ID may be used to identify
|
||||
a parent request, in which case the associated child requests are aborted
|
||||
also.
|
||||
"""
|
||||
|
||||
internal_req_ids = []
|
||||
for request_id in request_ids:
|
||||
if internal:
|
||||
# Internal ID - this may be a parent request
|
||||
internal_req_ids.append(request_id)
|
||||
|
||||
# Remove internal ID from the external->internal mapping
|
||||
if req_state := self.request_states.get(request_id):
|
||||
external_req_id = req_state.external_req_id
|
||||
internal_ids = self.external_req_ids[external_req_id]
|
||||
internal_ids.remove(request_id)
|
||||
if not internal_ids:
|
||||
del self.external_req_ids[external_req_id]
|
||||
elif internal_ids := self.external_req_ids.pop(request_id, []):
|
||||
# External ID - abort all requests in the external->internal mapping
|
||||
internal_req_ids.extend(internal_ids)
|
||||
|
||||
request_ids_to_abort = []
|
||||
for request_id in internal_req_ids:
|
||||
req_state = self.request_states.pop(request_id, None)
|
||||
if req_state is not None:
|
||||
self.lora_states.request_finished(request_id, req_state.lora_name)
|
||||
@ -404,7 +442,7 @@ class OutputProcessor:
|
||||
# Abort children prior to removing the parent.
|
||||
if parent.child_requests:
|
||||
child_reqs = list(parent.child_requests)
|
||||
child_reqs = self.abort_requests(child_reqs)
|
||||
child_reqs = self.abort_requests(child_reqs, internal=True)
|
||||
request_ids_to_abort.extend(child_reqs)
|
||||
self.parent_requests.pop(request_id, None)
|
||||
if not self.request_states:
|
||||
@ -439,6 +477,9 @@ class OutputProcessor:
|
||||
if parent_req:
|
||||
self.parent_requests[parent_req.request_id] = parent_req
|
||||
|
||||
# Track the external_req_id -> [internal_req_id, ...] mapping
|
||||
self.external_req_ids[req_state.external_req_id].append(request_id)
|
||||
|
||||
def process_outputs(
|
||||
self,
|
||||
engine_core_outputs: list[EngineCoreOutput],
|
||||
@ -522,6 +563,12 @@ class OutputProcessor:
|
||||
# Free completed requests.
|
||||
if finish_reason is not None:
|
||||
self.request_states.pop(req_id)
|
||||
|
||||
internal_ids = self.external_req_ids[req_state.external_req_id]
|
||||
internal_ids.remove(req_id)
|
||||
if not internal_ids:
|
||||
del self.external_req_ids[req_state.external_req_id]
|
||||
|
||||
# Remove parent request if applicable.
|
||||
parent_req = req_state.parent_req
|
||||
if parent_req and not parent_req.child_requests:
|
||||
@ -597,7 +644,9 @@ class OutputProcessor:
|
||||
)
|
||||
|
||||
# meta
|
||||
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, req_state.request_id)
|
||||
span.set_attribute(
|
||||
SpanAttributes.GEN_AI_REQUEST_ID, req_state.external_req_id
|
||||
)
|
||||
if req_state.top_p:
|
||||
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
|
||||
if req_state.max_tokens_param:
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import Optional, cast
|
||||
|
||||
from vllm.outputs import CompletionOutput
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.metrics.stats import IterationStats
|
||||
|
||||
|
||||
@ -17,6 +18,7 @@ class ParentRequest:
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
external_req_id: str
|
||||
sampling_params: SamplingParams
|
||||
|
||||
# To track the completion of child requests
|
||||
@ -31,8 +33,11 @@ class ParentRequest:
|
||||
# To efficiently obtain child sampling params
|
||||
cached_child_sampling_params: SamplingParams | None
|
||||
|
||||
def __init__(self, request_id: str, sampling_params: SamplingParams) -> None:
|
||||
self.request_id = request_id
|
||||
def __init__(self, request: EngineCoreRequest) -> None:
|
||||
assert request.external_req_id is not None
|
||||
sampling_params = request.params
|
||||
self.request_id = request.request_id
|
||||
self.external_req_id = request.external_req_id
|
||||
self.sampling_params = sampling_params
|
||||
|
||||
self.child_requests = set()
|
||||
@ -96,7 +101,7 @@ class ParentRequest:
|
||||
self,
|
||||
child_request_id: str,
|
||||
completion_output: CompletionOutput,
|
||||
) -> tuple[str, list[CompletionOutput], bool]:
|
||||
) -> tuple[list[CompletionOutput], bool]:
|
||||
already_finished_and_returned: bool = False
|
||||
if completion_output.finished():
|
||||
if child_request_id in self.child_requests:
|
||||
@ -118,7 +123,7 @@ class ParentRequest:
|
||||
outputs = [] if self.child_requests else self.output_aggregator
|
||||
|
||||
finished = not self.child_requests
|
||||
return self.request_id, outputs, finished
|
||||
return outputs, finished
|
||||
|
||||
def observe_num_generation_tokens(self, num_generation_tokens: int):
|
||||
self.max_num_generation_tokens = max(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user