mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-03 06:44:29 +08:00
Merge branch 'main' into elvischenv/update-flashinfer
This commit is contained in:
commit
30d32ef5c0
@ -2,4 +2,4 @@
|
|||||||
|
|
||||||
vLLM can be deployed with [KServe](https://github.com/kserve/kserve) on Kubernetes for highly scalable distributed model serving.
|
vLLM can be deployed with [KServe](https://github.com/kserve/kserve) on Kubernetes for highly scalable distributed model serving.
|
||||||
|
|
||||||
Please see [this guide](https://kserve.github.io/website/docs/model-serving/generative-inference/overview) for more details on using vLLM with KServe.
|
You can use vLLM with KServe's [Hugging Face serving runtime](https://kserve.github.io/website/docs/model-serving/generative-inference/overview) or via [`LLMInferenceService` that uses llm-d](https://kserve.github.io/website/docs/model-serving/generative-inference/llmisvc/llmisvc-overview).
|
||||||
|
|||||||
5
docs/deployment/integrations/llm-d.md
Normal file
5
docs/deployment/integrations/llm-d.md
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# llm-d
|
||||||
|
|
||||||
|
vLLM can be deployed with [llm-d](https://github.com/llm-d/llm-d), a Kubernetes-native distributed inference serving stack providing well-lit paths for anyone to serve large generative AI models at scale. It helps achieve the fastest "time to state-of-the-art (SOTA) performance" for key OSS models across most hardware accelerators and infrastructure providers.
|
||||||
|
|
||||||
|
You can use vLLM with llm-d directly by following [this guide](https://llm-d.ai/docs/guide) or via [KServe's LLMInferenceService](https://kserve.github.io/website/docs/model-serving/generative-inference/llmisvc/llmisvc-overview).
|
||||||
@ -12,6 +12,7 @@ Alternatively, you can deploy vLLM to Kubernetes using any of the following:
|
|||||||
|
|
||||||
- [Helm](frameworks/helm.md)
|
- [Helm](frameworks/helm.md)
|
||||||
- [InftyAI/llmaz](integrations/llmaz.md)
|
- [InftyAI/llmaz](integrations/llmaz.md)
|
||||||
|
- [llm-d](integrations/llm-d.md)
|
||||||
- [KAITO](integrations/kaito.md)
|
- [KAITO](integrations/kaito.md)
|
||||||
- [KServe](integrations/kserve.md)
|
- [KServe](integrations/kserve.md)
|
||||||
- [Kthena](integrations/kthena.md)
|
- [Kthena](integrations/kthena.md)
|
||||||
|
|||||||
11
tests/evals/gsm8k/configs/Qwen3-Next-FP8-EP2.yaml
Normal file
11
tests/evals/gsm8k/configs/Qwen3-Next-FP8-EP2.yaml
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
model_name: "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8"
|
||||||
|
accuracy_threshold: 0.85
|
||||||
|
num_questions: 1319
|
||||||
|
num_fewshot: 5
|
||||||
|
server_args: >-
|
||||||
|
--max-model-len 4096
|
||||||
|
--tensor-parallel-size 2
|
||||||
|
--enable-expert-parallel
|
||||||
|
--async-scheduling
|
||||||
|
env:
|
||||||
|
VLLM_USE_FLASHINFER_MOE_FP8: "1"
|
||||||
@ -4,3 +4,4 @@ Qwen1.5-MoE-W4A16-CT.yaml
|
|||||||
DeepSeek-V2-Lite-Instruct-FP8.yaml
|
DeepSeek-V2-Lite-Instruct-FP8.yaml
|
||||||
Qwen3-30B-A3B-NVFP4.yaml
|
Qwen3-30B-A3B-NVFP4.yaml
|
||||||
Qwen3-Next-80B-A3B-NVFP4-EP2.yaml
|
Qwen3-Next-80B-A3B-NVFP4-EP2.yaml
|
||||||
|
Qwen3-Next-FP8-EP2.yaml
|
||||||
|
|||||||
@ -71,6 +71,7 @@ def test_gsm8k_correctness(config_filename):
|
|||||||
print(f"Number of questions: {eval_config['num_questions']}")
|
print(f"Number of questions: {eval_config['num_questions']}")
|
||||||
print(f"Number of few-shot examples: {eval_config['num_fewshot']}")
|
print(f"Number of few-shot examples: {eval_config['num_fewshot']}")
|
||||||
print(f"Server args: {' '.join(server_args)}")
|
print(f"Server args: {' '.join(server_args)}")
|
||||||
|
print(f"Environment variables: {env_dict}")
|
||||||
|
|
||||||
# Launch server and run evaluation
|
# Launch server and run evaluation
|
||||||
with RemoteOpenAIServer(
|
with RemoteOpenAIServer(
|
||||||
|
|||||||
@ -19,7 +19,7 @@ def pytest_collection_modifyitems(config, items):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
|
# Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
|
||||||
# accuracy issues
|
# accuracy issues: https://github.com/vllm-project/vllm/issues/30167
|
||||||
# TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace
|
# TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace
|
||||||
torch.backends.cuda.enable_flash_sdp(False)
|
torch.backends.cuda.enable_flash_sdp(False)
|
||||||
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
||||||
|
|||||||
@ -106,6 +106,7 @@ class RemoteOpenAIServer:
|
|||||||
env.update(env_dict)
|
env.update(env_dict)
|
||||||
serve_cmd = ["vllm", "serve", model, *vllm_serve_args]
|
serve_cmd = ["vllm", "serve", model, *vllm_serve_args]
|
||||||
print(f"Launching RemoteOpenAIServer with: {' '.join(serve_cmd)}")
|
print(f"Launching RemoteOpenAIServer with: {' '.join(serve_cmd)}")
|
||||||
|
print(f"Environment variables: {env}")
|
||||||
self.proc: subprocess.Popen = subprocess.Popen(
|
self.proc: subprocess.Popen = subprocess.Popen(
|
||||||
serve_cmd,
|
serve_cmd,
|
||||||
env=env,
|
env=env,
|
||||||
|
|||||||
@ -260,7 +260,7 @@ async def test_multi_abort(output_kind: RequestOutputKind):
|
|||||||
|
|
||||||
# Use multi-abort to abort multiple requests at once
|
# Use multi-abort to abort multiple requests at once
|
||||||
abort_request_ids = [request_ids[i] for i in REQUEST_IDS_TO_ABORT]
|
abort_request_ids = [request_ids[i] for i in REQUEST_IDS_TO_ABORT]
|
||||||
await engine.abort(abort_request_ids)
|
await engine.abort(abort_request_ids, internal=False)
|
||||||
|
|
||||||
# Wait for all tasks to complete
|
# Wait for all tasks to complete
|
||||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
@ -609,7 +609,7 @@ async def test_abort_final_output(output_kind: RequestOutputKind):
|
|||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
# Abort the request
|
# Abort the request
|
||||||
await engine.abort(request_id)
|
await engine.abort(request_id, internal=False)
|
||||||
|
|
||||||
# Wait for generation to complete and return final output
|
# Wait for generation to complete and return final output
|
||||||
final_output = await generated
|
final_output = await generated
|
||||||
|
|||||||
@ -40,10 +40,16 @@ TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
|
|||||||
PROMPT = "I am Gyoubu Masataka Oniwa"
|
PROMPT = "I am Gyoubu Masataka Oniwa"
|
||||||
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
|
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
|
||||||
|
|
||||||
|
_REQUEST_COUNTER = 0
|
||||||
|
|
||||||
|
|
||||||
def make_request() -> EngineCoreRequest:
|
def make_request() -> EngineCoreRequest:
|
||||||
|
global _REQUEST_COUNTER
|
||||||
|
_REQUEST_COUNTER += 1
|
||||||
|
request_id = f"request-{_REQUEST_COUNTER}"
|
||||||
return EngineCoreRequest(
|
return EngineCoreRequest(
|
||||||
request_id=str(uuid.uuid4()),
|
request_id=request_id,
|
||||||
|
external_req_id=f"{request_id}-{uuid.uuid4()}",
|
||||||
prompt_token_ids=PROMPT_TOKENS,
|
prompt_token_ids=PROMPT_TOKENS,
|
||||||
mm_features=None,
|
mm_features=None,
|
||||||
sampling_params=SamplingParams(),
|
sampling_params=SamplingParams(),
|
||||||
|
|||||||
@ -45,6 +45,8 @@ TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
|
|||||||
PROMPT = "Hello my name is Robert and I love quantization kernels"
|
PROMPT = "Hello my name is Robert and I love quantization kernels"
|
||||||
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
|
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
|
||||||
|
|
||||||
|
_REQUEST_COUNTER = 0
|
||||||
|
|
||||||
|
|
||||||
def make_request(
|
def make_request(
|
||||||
params: SamplingParams, prompt_tokens_ids: list[int] | None = None
|
params: SamplingParams, prompt_tokens_ids: list[int] | None = None
|
||||||
@ -52,8 +54,12 @@ def make_request(
|
|||||||
if not prompt_tokens_ids:
|
if not prompt_tokens_ids:
|
||||||
prompt_tokens_ids = PROMPT_TOKENS
|
prompt_tokens_ids = PROMPT_TOKENS
|
||||||
|
|
||||||
|
global _REQUEST_COUNTER
|
||||||
|
_REQUEST_COUNTER += 1
|
||||||
|
request_id = f"request-{_REQUEST_COUNTER}"
|
||||||
return EngineCoreRequest(
|
return EngineCoreRequest(
|
||||||
request_id=str(uuid.uuid4()),
|
request_id=request_id,
|
||||||
|
external_req_id=f"{request_id}-{uuid.uuid4()}",
|
||||||
prompt_token_ids=prompt_tokens_ids,
|
prompt_token_ids=prompt_tokens_ids,
|
||||||
mm_features=None,
|
mm_features=None,
|
||||||
sampling_params=params,
|
sampling_params=params,
|
||||||
|
|||||||
@ -27,6 +27,7 @@ def test_fast_inc_detok_invalid_utf8_err_case():
|
|||||||
params = SamplingParams(skip_special_tokens=True)
|
params = SamplingParams(skip_special_tokens=True)
|
||||||
request = EngineCoreRequest(
|
request = EngineCoreRequest(
|
||||||
request_id="test",
|
request_id="test",
|
||||||
|
external_req_id="test-ext",
|
||||||
prompt_token_ids=prompt_token_ids,
|
prompt_token_ids=prompt_token_ids,
|
||||||
mm_features=None,
|
mm_features=None,
|
||||||
sampling_params=params,
|
sampling_params=params,
|
||||||
|
|||||||
@ -58,12 +58,12 @@ def test_incremental_detokenization(
|
|||||||
output_processor = OutputProcessor(
|
output_processor = OutputProcessor(
|
||||||
dummy_test_vectors.tokenizer, log_stats=False, stream_interval=stream_interval
|
dummy_test_vectors.tokenizer, log_stats=False, stream_interval=stream_interval
|
||||||
)
|
)
|
||||||
engine_core = MockEngineCore(tokens_list=dummy_test_vectors.generation_tokens)
|
|
||||||
|
|
||||||
# Make N requests.
|
# Make N requests.
|
||||||
requests = [
|
requests = [
|
||||||
EngineCoreRequest(
|
EngineCoreRequest(
|
||||||
request_id=f"request-{idx}",
|
request_id=f"request-{idx}-int",
|
||||||
|
external_req_id=f"request-{idx}",
|
||||||
prompt_token_ids=prompt_tokens,
|
prompt_token_ids=prompt_tokens,
|
||||||
mm_features=None,
|
mm_features=None,
|
||||||
eos_token_id=None,
|
eos_token_id=None,
|
||||||
@ -83,6 +83,11 @@ def test_incremental_detokenization(
|
|||||||
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
engine_core = MockEngineCore(
|
||||||
|
tokens_list=dummy_test_vectors.generation_tokens,
|
||||||
|
request_ids=[req.request_id for req in requests],
|
||||||
|
)
|
||||||
|
|
||||||
# Add requests to the detokenizer.
|
# Add requests to the detokenizer.
|
||||||
for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
|
for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
|
||||||
output_processor.add_request(request, prompt)
|
output_processor.add_request(request, prompt)
|
||||||
@ -438,15 +443,6 @@ def test_logprobs_processor(
|
|||||||
dummy_test_vectors,
|
dummy_test_vectors,
|
||||||
):
|
):
|
||||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
|
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
|
||||||
engine_core = MockEngineCore(
|
|
||||||
tokens_list=dummy_test_vectors.generation_tokens,
|
|
||||||
generated_logprobs_raw=None
|
|
||||||
if num_sample_logprobs is None
|
|
||||||
else dummy_test_vectors.generation_logprobs,
|
|
||||||
prompt_logprobs_raw=None
|
|
||||||
if num_prompt_logprobs is None
|
|
||||||
else dummy_test_vectors.prompt_logprobs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Make N requests.
|
# Make N requests.
|
||||||
request_id_list = [
|
request_id_list = [
|
||||||
@ -454,7 +450,8 @@ def test_logprobs_processor(
|
|||||||
]
|
]
|
||||||
requests = [
|
requests = [
|
||||||
EngineCoreRequest(
|
EngineCoreRequest(
|
||||||
request_id=request_id_list[idx],
|
request_id=request_id_list[idx] + "-int",
|
||||||
|
external_req_id=request_id_list[idx],
|
||||||
prompt_token_ids=prompt_tokens,
|
prompt_token_ids=prompt_tokens,
|
||||||
mm_features=None,
|
mm_features=None,
|
||||||
eos_token_id=None,
|
eos_token_id=None,
|
||||||
@ -476,6 +473,17 @@ def test_logprobs_processor(
|
|||||||
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
engine_core = MockEngineCore(
|
||||||
|
tokens_list=dummy_test_vectors.generation_tokens,
|
||||||
|
generated_logprobs_raw=None
|
||||||
|
if num_sample_logprobs is None
|
||||||
|
else dummy_test_vectors.generation_logprobs,
|
||||||
|
prompt_logprobs_raw=None
|
||||||
|
if num_prompt_logprobs is None
|
||||||
|
else dummy_test_vectors.prompt_logprobs,
|
||||||
|
request_ids=[req.request_id for req in requests],
|
||||||
|
)
|
||||||
|
|
||||||
# Add requests to the detokenizer.
|
# Add requests to the detokenizer.
|
||||||
for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
|
for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
|
||||||
output_processor.add_request(request, prompt)
|
output_processor.add_request(request, prompt)
|
||||||
@ -621,19 +629,12 @@ def test_stop_token(
|
|||||||
]
|
]
|
||||||
prompt_string = dummy_test_vectors.prompt_strings[0]
|
prompt_string = dummy_test_vectors.prompt_strings[0]
|
||||||
prompt_tokens = dummy_test_vectors.prompt_tokens[0]
|
prompt_tokens = dummy_test_vectors.prompt_tokens[0]
|
||||||
engine_core = MockEngineCore(
|
|
||||||
tokens_list=[generation_tokens],
|
|
||||||
generated_logprobs_raw=[generation_logprobs] if do_logprobs else None,
|
|
||||||
prompt_logprobs_raw=None,
|
|
||||||
eos_token_id=eos_token_id,
|
|
||||||
stop_token_ids=stop_token_ids,
|
|
||||||
ignore_eos=ignore_eos,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Make request.
|
# Make request.
|
||||||
request_id = "request-0"
|
request_id = "request-0"
|
||||||
request = EngineCoreRequest(
|
request = EngineCoreRequest(
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
|
external_req_id=request_id + "-ext",
|
||||||
prompt_token_ids=prompt_tokens,
|
prompt_token_ids=prompt_tokens,
|
||||||
mm_features=None,
|
mm_features=None,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
@ -655,6 +656,16 @@ def test_stop_token(
|
|||||||
pooling_params=None,
|
pooling_params=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
engine_core = MockEngineCore(
|
||||||
|
tokens_list=[generation_tokens],
|
||||||
|
generated_logprobs_raw=[generation_logprobs] if do_logprobs else None,
|
||||||
|
prompt_logprobs_raw=None,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
stop_token_ids=stop_token_ids,
|
||||||
|
ignore_eos=ignore_eos,
|
||||||
|
request_ids=[request.request_id],
|
||||||
|
)
|
||||||
|
|
||||||
# Add request to the detokenizer.
|
# Add request to the detokenizer.
|
||||||
output_processor.add_request(request, prompt_string)
|
output_processor.add_request(request, prompt_string)
|
||||||
|
|
||||||
@ -720,13 +731,6 @@ def test_stop_string(
|
|||||||
dummy_test_vectors,
|
dummy_test_vectors,
|
||||||
):
|
):
|
||||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
|
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
|
||||||
engine_core = MockEngineCore(
|
|
||||||
tokens_list=dummy_test_vectors.generation_tokens,
|
|
||||||
generated_logprobs_raw=dummy_test_vectors.generation_logprobs
|
|
||||||
if num_sample_logprobs
|
|
||||||
else None,
|
|
||||||
prompt_logprobs_raw=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Make N requests.
|
# Make N requests.
|
||||||
request_id_list = [
|
request_id_list = [
|
||||||
@ -734,7 +738,8 @@ def test_stop_string(
|
|||||||
]
|
]
|
||||||
requests = [
|
requests = [
|
||||||
EngineCoreRequest(
|
EngineCoreRequest(
|
||||||
request_id=request_id_list[idx],
|
request_id=request_id_list[idx] + "-int",
|
||||||
|
external_req_id=request_id_list[idx],
|
||||||
prompt_token_ids=prompt_tokens,
|
prompt_token_ids=prompt_tokens,
|
||||||
mm_features=None,
|
mm_features=None,
|
||||||
eos_token_id=None,
|
eos_token_id=None,
|
||||||
@ -756,6 +761,15 @@ def test_stop_string(
|
|||||||
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
engine_core = MockEngineCore(
|
||||||
|
tokens_list=dummy_test_vectors.generation_tokens,
|
||||||
|
generated_logprobs_raw=dummy_test_vectors.generation_logprobs
|
||||||
|
if num_sample_logprobs
|
||||||
|
else None,
|
||||||
|
prompt_logprobs_raw=None,
|
||||||
|
request_ids=[req.request_id for req in requests],
|
||||||
|
)
|
||||||
|
|
||||||
# Add requests to the detokenizer.
|
# Add requests to the detokenizer.
|
||||||
for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
|
for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
|
||||||
output_processor.add_request(request, prompt)
|
output_processor.add_request(request, prompt)
|
||||||
@ -813,9 +827,12 @@ def test_stop_string(
|
|||||||
for idx, (ref_gen_str, stop_str) in enumerate(
|
for idx, (ref_gen_str, stop_str) in enumerate(
|
||||||
zip(dummy_test_vectors.generation_strings, STOP_STRINGS)
|
zip(dummy_test_vectors.generation_strings, STOP_STRINGS)
|
||||||
):
|
):
|
||||||
# Request should be aborted.
|
# Request should be aborted (check internal ID in abort list).
|
||||||
|
internal_request_id = f"request-{idx}-int"
|
||||||
|
assert internal_request_id in aborted
|
||||||
|
|
||||||
|
# Use external ID for collecting outputs
|
||||||
request_id = f"request-{idx}"
|
request_id = f"request-{idx}"
|
||||||
assert request_id in aborted
|
|
||||||
|
|
||||||
# Collected values that were generated.
|
# Collected values that were generated.
|
||||||
gen_str = gen_strings[request_id]
|
gen_str = gen_strings[request_id]
|
||||||
@ -848,13 +865,13 @@ def test_stop_string(
|
|||||||
|
|
||||||
def test_iteration_stats(dummy_test_vectors):
|
def test_iteration_stats(dummy_test_vectors):
|
||||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
|
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
|
||||||
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
|
|
||||||
engine_core_timestamp = time.monotonic()
|
engine_core_timestamp = time.monotonic()
|
||||||
|
|
||||||
# Make N requests.
|
# Make N requests.
|
||||||
requests = [
|
requests = [
|
||||||
EngineCoreRequest(
|
EngineCoreRequest(
|
||||||
request_id=f"request-{idx}",
|
request_id=f"request-{idx}",
|
||||||
|
external_req_id=f"request-{idx}-ext",
|
||||||
prompt_token_ids=prompt_tokens,
|
prompt_token_ids=prompt_tokens,
|
||||||
mm_features=None,
|
mm_features=None,
|
||||||
eos_token_id=None,
|
eos_token_id=None,
|
||||||
@ -868,6 +885,11 @@ def test_iteration_stats(dummy_test_vectors):
|
|||||||
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
engine_core = MockEngineCore(
|
||||||
|
dummy_test_vectors.generation_tokens,
|
||||||
|
request_ids=[req.request_id for req in requests],
|
||||||
|
)
|
||||||
|
|
||||||
# Add all requests except one to the OutputProcessor.
|
# Add all requests except one to the OutputProcessor.
|
||||||
num_active = len(dummy_test_vectors.generation_tokens) - 1
|
num_active = len(dummy_test_vectors.generation_tokens) - 1
|
||||||
for request in requests[:num_active]:
|
for request in requests[:num_active]:
|
||||||
@ -922,7 +944,6 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
|
|||||||
output_processor = OutputProcessor(
|
output_processor = OutputProcessor(
|
||||||
dummy_test_vectors.tokenizer, log_stats=log_stats
|
dummy_test_vectors.tokenizer, log_stats=log_stats
|
||||||
)
|
)
|
||||||
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
|
|
||||||
engine_core_timestamp = time.monotonic()
|
engine_core_timestamp = time.monotonic()
|
||||||
|
|
||||||
# Create LoRA requests
|
# Create LoRA requests
|
||||||
@ -936,7 +957,8 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
|
|||||||
lora_assignments = [lora1, lora2, None]
|
lora_assignments = [lora1, lora2, None]
|
||||||
requests = [
|
requests = [
|
||||||
EngineCoreRequest(
|
EngineCoreRequest(
|
||||||
request_id=f"request-{idx}",
|
request_id=f"request-{idx}-int",
|
||||||
|
external_req_id=f"request-{idx}",
|
||||||
prompt_token_ids=prompt_tokens,
|
prompt_token_ids=prompt_tokens,
|
||||||
mm_features=None,
|
mm_features=None,
|
||||||
eos_token_id=None,
|
eos_token_id=None,
|
||||||
@ -950,6 +972,11 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
|
|||||||
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
engine_core = MockEngineCore(
|
||||||
|
dummy_test_vectors.generation_tokens,
|
||||||
|
request_ids=[req.request_id for req in requests],
|
||||||
|
)
|
||||||
|
|
||||||
# Add all requests to the OutputProcessor
|
# Add all requests to the OutputProcessor
|
||||||
for request in requests:
|
for request in requests:
|
||||||
output_processor.add_request(request, None)
|
output_processor.add_request(request, None)
|
||||||
@ -1015,9 +1042,9 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
|
|||||||
outputs = EngineCoreOutputs(
|
outputs = EngineCoreOutputs(
|
||||||
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
|
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
|
||||||
)
|
)
|
||||||
# Find and mark request-0 as finished (it uses lora-1)
|
# Find and mark request-0-int as finished (it uses lora-1)
|
||||||
for output in outputs.outputs:
|
for output in outputs.outputs:
|
||||||
if output.request_id == "request-0":
|
if output.request_id == "request-0-int":
|
||||||
output.finish_reason = FinishReason.LENGTH
|
output.finish_reason = FinishReason.LENGTH
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -1040,9 +1067,9 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
|
|||||||
outputs = EngineCoreOutputs(
|
outputs = EngineCoreOutputs(
|
||||||
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
|
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
|
||||||
)
|
)
|
||||||
# Find and mark request-1 as finished (it uses lora-2)
|
# Find and mark request-1-int as finished (it uses lora-2)
|
||||||
for output in outputs.outputs:
|
for output in outputs.outputs:
|
||||||
if output.request_id == "request-1":
|
if output.request_id == "request-1-int":
|
||||||
output.finish_reason = FinishReason.LENGTH
|
output.finish_reason = FinishReason.LENGTH
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -1064,9 +1091,9 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
|
|||||||
outputs = EngineCoreOutputs(
|
outputs = EngineCoreOutputs(
|
||||||
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
|
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
|
||||||
)
|
)
|
||||||
# Find and mark request-2 as finished (it has no LoRA)
|
# Find and mark request-2-int as finished (it has no LoRA)
|
||||||
for output in outputs.outputs:
|
for output in outputs.outputs:
|
||||||
if output.request_id == "request-2":
|
if output.request_id == "request-2-int":
|
||||||
output.finish_reason = FinishReason.LENGTH
|
output.finish_reason = FinishReason.LENGTH
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -1107,7 +1134,9 @@ async def test_request_output_collector():
|
|||||||
for idx in range(NUM_REQS)
|
for idx in range(NUM_REQS)
|
||||||
]
|
]
|
||||||
|
|
||||||
collector = RequestOutputCollector(RequestOutputKind.DELTA)
|
collector = RequestOutputCollector(
|
||||||
|
RequestOutputKind.DELTA, request_id="my-request-id-int"
|
||||||
|
)
|
||||||
|
|
||||||
# CASE 1: Put then get.
|
# CASE 1: Put then get.
|
||||||
outputs = make_outputs()
|
outputs = make_outputs()
|
||||||
@ -1163,7 +1192,9 @@ async def test_request_output_collector():
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cumulative_output_collector_n():
|
async def test_cumulative_output_collector_n():
|
||||||
"""Test collector correctly handles multiple outputs by index."""
|
"""Test collector correctly handles multiple outputs by index."""
|
||||||
collector = RequestOutputCollector(RequestOutputKind.CUMULATIVE)
|
collector = RequestOutputCollector(
|
||||||
|
RequestOutputKind.CUMULATIVE, request_id="my-request-id-int"
|
||||||
|
)
|
||||||
outputs = [
|
outputs = [
|
||||||
RequestOutput(
|
RequestOutput(
|
||||||
request_id="my-request-id",
|
request_id="my-request-id",
|
||||||
@ -1242,11 +1273,13 @@ async def test_cumulative_output_collector_n():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("runner", ["generate", "pooling"])
|
@pytest.mark.parametrize("runner", ["generate", "pooling"])
|
||||||
def test_abort_requests(runner: str, dummy_test_vectors):
|
@pytest.mark.parametrize("abort_by", ["internal", "external"])
|
||||||
|
def test_abort_requests(runner: str, abort_by: str, dummy_test_vectors):
|
||||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
|
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
|
||||||
requests = [
|
requests = [
|
||||||
EngineCoreRequest(
|
EngineCoreRequest(
|
||||||
request_id=f"request-{idx}",
|
request_id=f"request-{idx}",
|
||||||
|
external_req_id=f"external-{idx}",
|
||||||
prompt_token_ids=prompt_tokens,
|
prompt_token_ids=prompt_tokens,
|
||||||
mm_features=None,
|
mm_features=None,
|
||||||
eos_token_id=None,
|
eos_token_id=None,
|
||||||
@ -1265,8 +1298,13 @@ def test_abort_requests(runner: str, dummy_test_vectors):
|
|||||||
output_kind = request.sampling_params.output_kind
|
output_kind = request.sampling_params.output_kind
|
||||||
else:
|
else:
|
||||||
output_kind = request.pooling_params.output_kind
|
output_kind = request.pooling_params.output_kind
|
||||||
queue = RequestOutputCollector(output_kind=output_kind)
|
queue = RequestOutputCollector(
|
||||||
|
output_kind=output_kind, request_id=request.request_id
|
||||||
|
)
|
||||||
output_processor.add_request(request, None, queue=queue)
|
output_processor.add_request(request, None, queue=queue)
|
||||||
|
|
||||||
for request in requests:
|
for request in requests:
|
||||||
output_processor.abort_requests([request.request_id])
|
if abort_by == "internal":
|
||||||
|
output_processor.abort_requests([request.request_id], internal=True)
|
||||||
|
else:
|
||||||
|
output_processor.abort_requests([request.external_req_id], internal=False)
|
||||||
|
|||||||
@ -4,11 +4,12 @@
|
|||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
from vllm.outputs import CompletionOutput
|
from vllm.outputs import CompletionOutput
|
||||||
from vllm.sampling_params import RequestOutputKind
|
from vllm.sampling_params import RequestOutputKind
|
||||||
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||||
|
|
||||||
|
|
||||||
def test_parent_request_to_output_stream() -> None:
|
def test_parent_request_to_output_stream() -> None:
|
||||||
parent_request = ParentRequest("parent_id", SamplingParams(n=2))
|
parent_request = ParentRequest(make_request(SamplingParams(n=2)))
|
||||||
parent_request.child_requests = {"child_id_0", "child_id_1"}
|
parent_request.child_requests = {"child_id_0", "child_id_1"}
|
||||||
output_0 = CompletionOutput(
|
output_0 = CompletionOutput(
|
||||||
index=0, text="child 0", token_ids=[], cumulative_logprob=None, logprobs=None
|
index=0, text="child 0", token_ids=[], cumulative_logprob=None, logprobs=None
|
||||||
@ -17,51 +18,31 @@ def test_parent_request_to_output_stream() -> None:
|
|||||||
index=1, text="child 1", token_ids=[], cumulative_logprob=None, logprobs=None
|
index=1, text="child 1", token_ids=[], cumulative_logprob=None, logprobs=None
|
||||||
)
|
)
|
||||||
# Request not finished
|
# Request not finished
|
||||||
assert ("parent_id", [output_0], False) == parent_request.get_outputs(
|
assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
|
||||||
"child_id_0", output_0
|
assert ([output_1], False) == parent_request.get_outputs("child_id_1", output_1)
|
||||||
)
|
assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
|
||||||
assert ("parent_id", [output_1], False) == parent_request.get_outputs(
|
assert ([output_1], False) == parent_request.get_outputs("child_id_1", output_1)
|
||||||
"child_id_1", output_1
|
|
||||||
)
|
|
||||||
assert ("parent_id", [output_0], False) == parent_request.get_outputs(
|
|
||||||
"child_id_0", output_0
|
|
||||||
)
|
|
||||||
assert ("parent_id", [output_1], False) == parent_request.get_outputs(
|
|
||||||
"child_id_1", output_1
|
|
||||||
)
|
|
||||||
|
|
||||||
# output_1 finished
|
# output_1 finished
|
||||||
output_1.finish_reason = "ended"
|
output_1.finish_reason = "ended"
|
||||||
assert ("parent_id", [output_0], False) == parent_request.get_outputs(
|
assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
|
||||||
"child_id_0", output_0
|
assert ([output_1], False) == parent_request.get_outputs("child_id_1", output_1)
|
||||||
)
|
|
||||||
assert ("parent_id", [output_1], False) == parent_request.get_outputs(
|
|
||||||
"child_id_1", output_1
|
|
||||||
)
|
|
||||||
# Finished output_1 had already returned, DO NOT returned again
|
# Finished output_1 had already returned, DO NOT returned again
|
||||||
assert ("parent_id", [output_0], False) == parent_request.get_outputs(
|
assert ([output_0], False) == parent_request.get_outputs("child_id_0", output_0)
|
||||||
"child_id_0", output_0
|
assert parent_request.get_outputs("child_id_1", output_1) == ([], False)
|
||||||
)
|
|
||||||
assert parent_request.get_outputs("child_id_1", output_1) == (
|
|
||||||
"parent_id",
|
|
||||||
[],
|
|
||||||
False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# output_0 finished
|
# output_0 finished
|
||||||
output_0.finish_reason = "ended"
|
output_0.finish_reason = "ended"
|
||||||
assert ("parent_id", [output_0], True) == parent_request.get_outputs(
|
assert ([output_0], True) == parent_request.get_outputs("child_id_0", output_0)
|
||||||
"child_id_0", output_0
|
assert parent_request.get_outputs("child_id_1", output_1) == ([], True)
|
||||||
)
|
|
||||||
assert parent_request.get_outputs("child_id_1", output_1) == ("parent_id", [], True)
|
|
||||||
# Finished output_0 had already returned, DO NOT returned again
|
# Finished output_0 had already returned, DO NOT returned again
|
||||||
assert parent_request.get_outputs("child_id_0", output_0) == ("parent_id", [], True)
|
assert parent_request.get_outputs("child_id_0", output_0) == ([], True)
|
||||||
assert parent_request.get_outputs("child_id_1", output_1) == ("parent_id", [], True)
|
assert parent_request.get_outputs("child_id_1", output_1) == ([], True)
|
||||||
|
|
||||||
|
|
||||||
def test_parent_request_to_output_final_only() -> None:
|
def test_parent_request_to_output_final_only() -> None:
|
||||||
parent_request = ParentRequest(
|
parent_request = ParentRequest(
|
||||||
"parent_id", SamplingParams(n=2, output_kind=RequestOutputKind.FINAL_ONLY)
|
make_request(SamplingParams(n=2, output_kind=RequestOutputKind.FINAL_ONLY))
|
||||||
)
|
)
|
||||||
parent_request.child_requests = {"child_id_0", "child_id_1"}
|
parent_request.child_requests = {"child_id_0", "child_id_1"}
|
||||||
output_0 = CompletionOutput(
|
output_0 = CompletionOutput(
|
||||||
@ -71,33 +52,33 @@ def test_parent_request_to_output_final_only() -> None:
|
|||||||
index=1, text="child 1", token_ids=[], cumulative_logprob=None, logprobs=None
|
index=1, text="child 1", token_ids=[], cumulative_logprob=None, logprobs=None
|
||||||
)
|
)
|
||||||
# Request not finished, return nothing
|
# Request not finished, return nothing
|
||||||
assert parent_request.get_outputs("child_id_0", output_0) == (
|
assert parent_request.get_outputs("child_id_0", output_0) == ([], False)
|
||||||
"parent_id",
|
assert parent_request.get_outputs("child_id_1", output_1) == ([], False)
|
||||||
[],
|
|
||||||
False,
|
|
||||||
)
|
|
||||||
assert parent_request.get_outputs("child_id_1", output_1) == (
|
|
||||||
"parent_id",
|
|
||||||
[],
|
|
||||||
False,
|
|
||||||
)
|
|
||||||
# output_1 finished, but outputs won't be returned until all child requests finished
|
# output_1 finished, but outputs won't be returned until all child requests finished
|
||||||
output_1.finish_reason = "ended"
|
output_1.finish_reason = "ended"
|
||||||
assert parent_request.get_outputs("child_id_0", output_0) == (
|
assert parent_request.get_outputs("child_id_0", output_0) == ([], False)
|
||||||
"parent_id",
|
assert parent_request.get_outputs("child_id_1", output_1) == ([], False)
|
||||||
[],
|
|
||||||
False,
|
|
||||||
)
|
|
||||||
assert parent_request.get_outputs("child_id_1", output_1) == (
|
|
||||||
"parent_id",
|
|
||||||
[],
|
|
||||||
False,
|
|
||||||
)
|
|
||||||
# output_0 finished, as all child requests finished, the output would be returned
|
# output_0 finished, as all child requests finished, the output would be returned
|
||||||
output_0.finish_reason = "ended"
|
output_0.finish_reason = "ended"
|
||||||
assert ("parent_id", [output_0, output_1], True) == parent_request.get_outputs(
|
assert ([output_0, output_1], True) == parent_request.get_outputs(
|
||||||
"child_id_0", output_0
|
"child_id_0", output_0
|
||||||
)
|
)
|
||||||
assert ("parent_id", [output_0, output_1], True) == parent_request.get_outputs(
|
assert ([output_0, output_1], True) == parent_request.get_outputs(
|
||||||
"child_id_1", output_1
|
"child_id_1", output_1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_request(sampling_params: SamplingParams) -> EngineCoreRequest:
|
||||||
|
return EngineCoreRequest(
|
||||||
|
request_id="parent_id",
|
||||||
|
external_req_id="ext_parent_id",
|
||||||
|
prompt_token_ids=None,
|
||||||
|
mm_features=None,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
pooling_params=None,
|
||||||
|
eos_token_id=None,
|
||||||
|
arrival_time=0.0,
|
||||||
|
lora_request=None,
|
||||||
|
cache_salt=None,
|
||||||
|
data_parallel_rank=None,
|
||||||
|
)
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import pytest
|
|||||||
from vllm.assets.image import ImageAsset
|
from vllm.assets.image import ImageAsset
|
||||||
from vllm.assets.video import VideoAsset
|
from vllm.assets.video import VideoAsset
|
||||||
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig
|
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig
|
||||||
|
from vllm.multimodal import MultiModalUUIDDict
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.v1.engine import input_processor as input_processor_mod
|
from vllm.v1.engine import input_processor as input_processor_mod
|
||||||
from vllm.v1.engine.input_processor import InputProcessor
|
from vllm.v1.engine.input_processor import InputProcessor
|
||||||
@ -166,7 +167,7 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
|
|||||||
monkeypatch, mm_cache_gb=0.0, enable_prefix_caching=False
|
monkeypatch, mm_cache_gb=0.0, enable_prefix_caching=False
|
||||||
)
|
)
|
||||||
|
|
||||||
captured: dict[str, object] = {}
|
captured: dict[str, MultiModalUUIDDict] = {}
|
||||||
|
|
||||||
def fake_preprocess(
|
def fake_preprocess(
|
||||||
prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None
|
prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None
|
||||||
@ -196,7 +197,16 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Expect request-id-based overrides are passed through
|
# Expect request-id-based overrides are passed through
|
||||||
assert captured["mm_uuids"] == {
|
mm_uuids = captured["mm_uuids"]
|
||||||
"image": [f"{request_id}-image-0", f"{request_id}-image-1"],
|
assert set(mm_uuids.keys()) == {"image", "video"}
|
||||||
"video": [f"{request_id}-video-0"],
|
assert len(mm_uuids["image"]) == 2
|
||||||
}
|
assert len(mm_uuids["video"]) == 1
|
||||||
|
assert mm_uuids["image"][0].startswith(f"{request_id}-image-") and mm_uuids[
|
||||||
|
"image"
|
||||||
|
][0].endswith("-0")
|
||||||
|
assert mm_uuids["image"][1].startswith(f"{request_id}-image-") and mm_uuids[
|
||||||
|
"image"
|
||||||
|
][1].endswith("-1")
|
||||||
|
assert mm_uuids["video"][0].startswith(f"{request_id}-video-") and mm_uuids[
|
||||||
|
"video"
|
||||||
|
][0].endswith("-0")
|
||||||
|
|||||||
@ -343,6 +343,7 @@ class MockEngineCore:
|
|||||||
eos_token_id: int | None = None,
|
eos_token_id: int | None = None,
|
||||||
stop_token_ids: list[int] | None = None,
|
stop_token_ids: list[int] | None = None,
|
||||||
ignore_eos: bool = False,
|
ignore_eos: bool = False,
|
||||||
|
request_ids: list[str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_requests = len(tokens_list)
|
self.num_requests = len(tokens_list)
|
||||||
self.tokens_list = tokens_list
|
self.tokens_list = tokens_list
|
||||||
@ -355,6 +356,11 @@ class MockEngineCore:
|
|||||||
self.eos_token_id = eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
self.stop_token_ids = stop_token_ids
|
self.stop_token_ids = stop_token_ids
|
||||||
self.ignore_eos = ignore_eos
|
self.ignore_eos = ignore_eos
|
||||||
|
self.request_ids = (
|
||||||
|
request_ids
|
||||||
|
if request_ids is not None
|
||||||
|
else [f"request-{i}" for i in range(self.num_requests)]
|
||||||
|
)
|
||||||
|
|
||||||
def get_outputs(self) -> list[EngineCoreOutput]:
|
def get_outputs(self) -> list[EngineCoreOutput]:
|
||||||
do_logprobs = self.do_logprobs
|
do_logprobs = self.do_logprobs
|
||||||
@ -386,7 +392,7 @@ class MockEngineCore:
|
|||||||
prompt_logprobs = None
|
prompt_logprobs = None
|
||||||
new_token_id = token_ids[token_idx]
|
new_token_id = token_ids[token_idx]
|
||||||
output = EngineCoreOutput(
|
output = EngineCoreOutput(
|
||||||
request_id=f"request-{req_idx}",
|
request_id=self.request_ids[req_idx],
|
||||||
new_token_ids=[new_token_id],
|
new_token_ids=[new_token_id],
|
||||||
new_logprobs=logprobs,
|
new_logprobs=logprobs,
|
||||||
new_prompt_logprobs_tensors=prompt_logprobs,
|
new_prompt_logprobs_tensors=prompt_logprobs,
|
||||||
|
|||||||
@ -41,10 +41,13 @@ from vllm.distributed.kv_transfer.kv_transfer_state import (
|
|||||||
has_kv_transfer_group,
|
has_kv_transfer_group,
|
||||||
)
|
)
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
|
from vllm.outputs import RequestOutput
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.platforms.interface import Platform
|
from vllm.platforms.interface import Platform
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||||
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
|
from vllm.v1.engine.output_processor import OutputProcessor
|
||||||
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
||||||
from vllm.v1.request import RequestStatus
|
from vllm.v1.request import RequestStatus
|
||||||
|
|
||||||
@ -1265,6 +1268,22 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
|
|||||||
run_test_and_cleanup()
|
run_test_and_cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
class RequestIdMapper:
|
||||||
|
"""Helper class to map external request IDs to internal request IDs."""
|
||||||
|
|
||||||
|
def __init__(self, output_processor: OutputProcessor):
|
||||||
|
self.req_id_mapping: dict[str, str] = {}
|
||||||
|
self.original_add_request = output_processor.add_request
|
||||||
|
output_processor.add_request = self._add_request
|
||||||
|
|
||||||
|
def _add_request(self, request: EngineCoreRequest, *args, **kwargs):
|
||||||
|
self.req_id_mapping[request.external_req_id] = request.request_id
|
||||||
|
return self.original_add_request(request, *args, **kwargs)
|
||||||
|
|
||||||
|
def __call__(self, external_req_id: str) -> str:
|
||||||
|
return self.req_id_mapping[external_req_id]
|
||||||
|
|
||||||
|
|
||||||
def _run_abort_timeout_test(llm: LLM, timeout: int):
|
def _run_abort_timeout_test(llm: LLM, timeout: int):
|
||||||
"""Helper function to run the abort timeout test logic."""
|
"""Helper function to run the abort timeout test logic."""
|
||||||
remote_prefill_opts = {
|
remote_prefill_opts = {
|
||||||
@ -1286,24 +1305,34 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
|
|||||||
0
|
0
|
||||||
].req_to_blocks
|
].req_to_blocks
|
||||||
|
|
||||||
|
id_mapper = RequestIdMapper(llm.llm_engine.output_processor)
|
||||||
|
|
||||||
|
def req_id(outputs: list[RequestOutput]) -> str:
|
||||||
|
assert len(outputs) == 1
|
||||||
|
return id_mapper(outputs[0].request_id)
|
||||||
|
|
||||||
padding = "Just making this request a little longer so that we're sure "
|
padding = "Just making this request a little longer so that we're sure "
|
||||||
"we're not hitting the small-request lower bound beneath which we don't "
|
"we're not hitting the small-request lower bound beneath which we don't "
|
||||||
"actually trigger the whole kv transfer, but rather just recompute the "
|
"actually trigger the whole kv transfer, but rather just recompute the "
|
||||||
"blocks on D."
|
"blocks on D."
|
||||||
_ = llm.generate([f"What is the capital of Japan? {padding}"], sampling_params)
|
req0_id = req_id(
|
||||||
|
llm.generate([f"What is the capital of Japan? {padding}"], sampling_params)
|
||||||
|
)
|
||||||
|
|
||||||
# Request finished but not freed
|
# Request finished but not freed
|
||||||
assert "0" in scheduler.finished_req_ids and "0" in req_to_blocks
|
assert req0_id in scheduler.finished_req_ids and req0_id in req_to_blocks
|
||||||
# Some other request, 0 still not freed
|
# Some other request, 0 still not freed
|
||||||
_ = llm.generate([f"What is the capital of Italy? {padding}"], sampling_params)
|
req1_id = req_id(
|
||||||
assert "0" in req_to_blocks
|
llm.generate([f"What is the capital of Italy? {padding}"], sampling_params)
|
||||||
assert "1" in scheduler.finished_req_ids and "1" in req_to_blocks
|
)
|
||||||
|
assert req0_id in req_to_blocks
|
||||||
|
assert req1_id in scheduler.finished_req_ids and req1_id in req_to_blocks
|
||||||
|
|
||||||
# Wait for timeout and trigger another scheduler loop
|
# Wait for timeout and trigger another scheduler loop
|
||||||
time.sleep(timeout)
|
time.sleep(timeout)
|
||||||
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
|
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
|
||||||
# Request-0 times out and is cleared!
|
# Request-0 times out and is cleared!
|
||||||
assert "0" not in req_to_blocks
|
assert req0_id not in req_to_blocks
|
||||||
# Need to shutdown the background thread to release NIXL side channel port
|
# Need to shutdown the background thread to release NIXL side channel port
|
||||||
llm.llm_engine.engine_core.shutdown()
|
llm.llm_engine.engine_core.shutdown()
|
||||||
|
|
||||||
|
|||||||
@ -136,7 +136,7 @@ class MMEncoderAttention(CustomOp):
|
|||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
)
|
)
|
||||||
if is_reshaped:
|
if is_reshaped:
|
||||||
output = output.view(bsz, q_len, -1)
|
output = output.reshape(bsz, q_len, -1)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def _forward_fa(
|
def _forward_fa(
|
||||||
@ -174,7 +174,7 @@ class MMEncoderAttention(CustomOp):
|
|||||||
fa_version=self._fa_version,
|
fa_version=self._fa_version,
|
||||||
)
|
)
|
||||||
if is_reshaped:
|
if is_reshaped:
|
||||||
output = output.view(bsz, q_len, -1)
|
output = output.reshape(bsz, q_len, -1)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def forward_native(
|
def forward_native(
|
||||||
|
|||||||
@ -1621,7 +1621,7 @@ class LLM:
|
|||||||
added_request_ids.append(request_id)
|
added_request_ids.append(request_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if added_request_ids:
|
if added_request_ids:
|
||||||
self.llm_engine.abort_request(added_request_ids)
|
self.llm_engine.abort_request(added_request_ids, internal=True)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def _validate_mm_data_and_uuids(
|
def _validate_mm_data_and_uuids(
|
||||||
@ -1731,7 +1731,7 @@ class LLM:
|
|||||||
priority=priority,
|
priority=priority,
|
||||||
prompt_text=prompt_text,
|
prompt_text=prompt_text,
|
||||||
)
|
)
|
||||||
return request_id
|
return engine_request.request_id
|
||||||
|
|
||||||
def _run_engine(
|
def _run_engine(
|
||||||
self, *, use_tqdm: bool | Callable[..., tqdm] = True
|
self, *, use_tqdm: bool | Callable[..., tqdm] = True
|
||||||
|
|||||||
@ -118,6 +118,7 @@ class ShortConv(MambaBase, CustomOp):
|
|||||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||||
has_initial_states_p = attn_metadata.has_initial_states_p
|
has_initial_states_p = attn_metadata.has_initial_states_p
|
||||||
|
query_start_loc_p = attn_metadata.query_start_loc_p
|
||||||
|
|
||||||
BCx, _ = self.in_proj(hidden_states)
|
BCx, _ = self.in_proj(hidden_states)
|
||||||
|
|
||||||
@ -165,11 +166,6 @@ class ShortConv(MambaBase, CustomOp):
|
|||||||
[num_decodes, num_prefills],
|
[num_decodes, num_prefills],
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
query_start_loc_p = (
|
|
||||||
attn_metadata.query_start_loc[-num_prefills - 1 :] - num_decodes
|
|
||||||
if has_prefill
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
conv_output_list = []
|
conv_output_list = []
|
||||||
|
|
||||||
|
|||||||
@ -3,17 +3,11 @@
|
|||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionBackend
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
from vllm.v1.attention.backends.mamba_attn import (
|
||||||
from vllm.config import VllmConfig
|
BaseMambaAttentionMetadata,
|
||||||
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
|
BaseMambaAttentionMetadataBuilder,
|
||||||
from vllm.v1.attention.backends.utils import (
|
|
||||||
CommonAttentionMetadata,
|
|
||||||
split_decodes_and_prefills,
|
|
||||||
)
|
)
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
|
||||||
|
|
||||||
|
|
||||||
class Mamba1AttentionBackend(AttentionBackend):
|
class Mamba1AttentionBackend(AttentionBackend):
|
||||||
@ -23,137 +17,12 @@ class Mamba1AttentionBackend(AttentionBackend):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Mamba1AttentionMetadata:
|
class Mamba1AttentionMetadata(BaseMambaAttentionMetadata):
|
||||||
query_start_loc_p: torch.Tensor
|
pass
|
||||||
state_indices_tensor: torch.Tensor
|
|
||||||
has_initial_states_p: torch.Tensor | None
|
|
||||||
num_prefills: int
|
|
||||||
num_prefill_tokens: int
|
|
||||||
num_decodes: int
|
|
||||||
num_decode_tokens: int
|
|
||||||
|
|
||||||
block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
|
|
||||||
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
|
|
||||||
block_idx_last_computed_token: torch.Tensor # shape: [batch,]
|
|
||||||
num_computed_tokens_p: torch.Tensor # shape: [batch,]
|
|
||||||
|
|
||||||
|
|
||||||
class Mamba1AttentionMetadataBuilder(
|
class Mamba1AttentionMetadataBuilder(
|
||||||
BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
|
BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
|
||||||
):
|
):
|
||||||
def __init__(
|
metadata_cls = Mamba1AttentionMetadata
|
||||||
self,
|
supports_update_block_table: bool = False
|
||||||
kv_cache_spec: AttentionSpec,
|
|
||||||
layer_names: list[str],
|
|
||||||
vllm_config: VllmConfig,
|
|
||||||
device: torch.device,
|
|
||||||
):
|
|
||||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
|
||||||
assert isinstance(kv_cache_spec, MambaSpec)
|
|
||||||
|
|
||||||
def build(
|
|
||||||
self,
|
|
||||||
common_prefix_len: int,
|
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
|
||||||
fast_build: bool = False,
|
|
||||||
) -> Mamba1AttentionMetadata:
|
|
||||||
num_reqs = common_attn_metadata.num_reqs
|
|
||||||
|
|
||||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
|
||||||
split_decodes_and_prefills(
|
|
||||||
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
has_initial_states_p = None
|
|
||||||
query_start_loc_p = None
|
|
||||||
num_computed_tokens, num_computed_tokens_p = None, None
|
|
||||||
block_idx_first_scheduled_token = None
|
|
||||||
block_idx_first_scheduled_token_p = None
|
|
||||||
|
|
||||||
# TODO(@Josephasafg) Mamba1 and Mamba2 have a lot of code in common here.
|
|
||||||
# We should consolidate this code
|
|
||||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
|
||||||
# Return a tensor of shape (#requests, #max blocks)
|
|
||||||
state_indices_tensor = common_attn_metadata.block_table_tensor
|
|
||||||
mamba_block_size = self.kv_cache_spec.block_size
|
|
||||||
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
|
|
||||||
self.device
|
|
||||||
)
|
|
||||||
(
|
|
||||||
block_idx_last_computed_token,
|
|
||||||
block_idx_first_scheduled_token,
|
|
||||||
block_idx_last_scheduled_token,
|
|
||||||
) = self._compute_prefix_caching_block_indices(
|
|
||||||
common_attn_metadata, mamba_block_size
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Always return just a single block per each request:
|
|
||||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
|
||||||
block_idx_last_scheduled_token = None
|
|
||||||
block_idx_last_computed_token = None
|
|
||||||
|
|
||||||
if num_prefills > 0:
|
|
||||||
query_start_loc_p = (
|
|
||||||
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
|
|
||||||
- num_decode_tokens
|
|
||||||
)
|
|
||||||
has_initial_states_cpu = (
|
|
||||||
common_attn_metadata.num_computed_tokens_cpu[
|
|
||||||
num_reqs - num_prefills : num_reqs
|
|
||||||
]
|
|
||||||
> 0
|
|
||||||
)
|
|
||||||
has_initial_states_p = has_initial_states_cpu.to(
|
|
||||||
common_attn_metadata.query_start_loc.device
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
|
||||||
assert num_computed_tokens is not None
|
|
||||||
num_computed_tokens_p = num_computed_tokens[
|
|
||||||
num_reqs - num_prefills : num_reqs
|
|
||||||
]
|
|
||||||
assert block_idx_first_scheduled_token is not None
|
|
||||||
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
|
|
||||||
num_reqs - num_prefills : num_reqs
|
|
||||||
]
|
|
||||||
|
|
||||||
elif (
|
|
||||||
num_decodes > 0
|
|
||||||
and num_decodes <= self.decode_cudagraph_max_bs
|
|
||||||
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
|
||||||
):
|
|
||||||
self.state_indices_tensor[:num_decodes].copy_(
|
|
||||||
state_indices_tensor, non_blocking=True
|
|
||||||
)
|
|
||||||
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
|
|
||||||
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
|
|
||||||
|
|
||||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
|
||||||
self.block_idx_last_scheduled_token[:num_decodes].copy_(
|
|
||||||
block_idx_last_scheduled_token, non_blocking=True
|
|
||||||
)
|
|
||||||
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
|
|
||||||
:num_decode_tokens
|
|
||||||
]
|
|
||||||
|
|
||||||
self.block_idx_last_computed_token[:num_decodes].copy_(
|
|
||||||
block_idx_last_computed_token, non_blocking=True
|
|
||||||
)
|
|
||||||
block_idx_last_computed_token = self.block_idx_last_computed_token[
|
|
||||||
:num_decode_tokens
|
|
||||||
]
|
|
||||||
|
|
||||||
return Mamba1AttentionMetadata(
|
|
||||||
query_start_loc_p=query_start_loc_p,
|
|
||||||
has_initial_states_p=has_initial_states_p,
|
|
||||||
state_indices_tensor=state_indices_tensor,
|
|
||||||
num_prefills=num_prefills,
|
|
||||||
num_prefill_tokens=num_prefill_tokens,
|
|
||||||
num_decodes=num_decodes,
|
|
||||||
num_decode_tokens=num_decode_tokens,
|
|
||||||
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
|
|
||||||
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
|
|
||||||
block_idx_last_computed_token=block_idx_last_computed_token,
|
|
||||||
num_computed_tokens_p=num_computed_tokens_p,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,19 +1,19 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import copy
|
|
||||||
import itertools
|
import itertools
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, replace
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionBackend
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.utils.math_utils import cdiv
|
from vllm.utils.math_utils import cdiv
|
||||||
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
|
from vllm.v1.attention.backends.mamba_attn import (
|
||||||
|
BaseMambaAttentionMetadata,
|
||||||
|
BaseMambaAttentionMetadataBuilder,
|
||||||
|
)
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (
|
||||||
CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
compute_causal_conv1d_metadata,
|
|
||||||
split_decodes_and_prefills,
|
|
||||||
)
|
)
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
@ -94,48 +94,26 @@ class Mamba2AttentionBackend(AttentionBackend):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Mamba2AttentionMetadata:
|
class Mamba2AttentionMetadata(BaseMambaAttentionMetadata):
|
||||||
num_prefills: int
|
prep_initial_states: bool = False
|
||||||
num_prefill_tokens: int
|
chunk_size: int = 0
|
||||||
num_decodes: int
|
|
||||||
num_decode_tokens: int
|
|
||||||
query_start_loc_p: torch.Tensor
|
|
||||||
seq_lens: torch.Tensor
|
|
||||||
|
|
||||||
prep_initial_states: bool
|
|
||||||
chunk_size: int
|
|
||||||
|
|
||||||
# The following tensors only contain prefill requests and will be None if
|
|
||||||
# the batch has no prefill request.
|
|
||||||
has_initial_states_p: torch.Tensor | None
|
|
||||||
seq_idx_p: torch.Tensor | None
|
|
||||||
|
|
||||||
|
# Chunk-related metadata (only for prefill)
|
||||||
|
seq_idx_p: torch.Tensor | None = None
|
||||||
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
|
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
|
||||||
# each chunk, its offests into the varlen sequence dimension. It is defined
|
# each chunk, its offests into the varlen sequence dimension. It is defined
|
||||||
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
|
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
|
||||||
# cu_chunk_seqlen_p[i+1].
|
# cu_chunk_seqlen_p[i+1].
|
||||||
cu_chunk_seqlen_p: torch.Tensor | None
|
cu_chunk_seqlen_p: torch.Tensor | None = None
|
||||||
|
|
||||||
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
|
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
|
||||||
# index of the last chunk for every sequence in the (prefill) batch.
|
# index of the last chunk for every sequence in the (prefill) batch.
|
||||||
last_chunk_indices_p: torch.Tensor | None
|
last_chunk_indices_p: torch.Tensor | None = None
|
||||||
|
|
||||||
state_indices_tensor: torch.Tensor # shape: [batch,]
|
|
||||||
block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
|
|
||||||
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
|
|
||||||
block_idx_last_computed_token: torch.Tensor # shape: [batch,]
|
|
||||||
num_computed_tokens_p: torch.Tensor # shape: [batch,]
|
|
||||||
|
|
||||||
# The following attributes are for triton implementation of causal_conv1d
|
|
||||||
nums_dict: dict | None = None
|
|
||||||
batch_ptr: torch.Tensor | None = None
|
|
||||||
token_chunk_offset_ptr: torch.Tensor | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class Mamba2AttentionMetadataBuilder(
|
class Mamba2AttentionMetadataBuilder(
|
||||||
BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
|
BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
|
||||||
):
|
):
|
||||||
supports_update_block_table: bool = True
|
metadata_cls = Mamba2AttentionMetadata
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -150,87 +128,93 @@ class Mamba2AttentionMetadataBuilder(
|
|||||||
"chunk_size needs to be set in the model config for Mamba2 models"
|
"chunk_size needs to be set in the model config for Mamba2 models"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _compute_chunk_metadata(
|
||||||
|
self,
|
||||||
|
num_prefills: int,
|
||||||
|
num_computed_tokens_p_cpu: torch.Tensor,
|
||||||
|
query_start_loc_p_cpu: torch.Tensor,
|
||||||
|
) -> tuple[list[int], list[int], list[int]]:
|
||||||
|
"""
|
||||||
|
Compute chunk-specific metadata for Mamba2.
|
||||||
|
|
||||||
|
The code below carefully constructs the chunks such that:
|
||||||
|
1. Chunks contain tokens from a *single* sequence only.
|
||||||
|
2. For every sequence, we are guaranteed that we can
|
||||||
|
retrieve the mamba state *every* chunk_size tokens.
|
||||||
|
Constraint (1) dramatically simplifies the mamba2 kernels.
|
||||||
|
Constraint (2) dramatically simplifies the implementation
|
||||||
|
of prefix caching for mamba2 (wip). We need to take care
|
||||||
|
of the interaction with chunked prefill in order to
|
||||||
|
satisfy constraint (2).
|
||||||
|
"""
|
||||||
|
# TODO (tdoublep): This code could probably be optimized.
|
||||||
|
cu_chunk_seqlen = []
|
||||||
|
seq_idx = []
|
||||||
|
last_chunk_indices = []
|
||||||
|
seqlen_pos = 0
|
||||||
|
|
||||||
|
for req_idx in range(num_prefills):
|
||||||
|
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
|
||||||
|
this_new_tokens = (
|
||||||
|
query_start_loc_p_cpu[req_idx + 1].item()
|
||||||
|
- query_start_loc_p_cpu[req_idx].item()
|
||||||
|
)
|
||||||
|
|
||||||
|
# if computed tokens are not chunk-aligned, use the first
|
||||||
|
# chunk to finish it off
|
||||||
|
if this_num_computed % self.chunk_size != 0:
|
||||||
|
seq_idx.append(req_idx)
|
||||||
|
cu_chunk_seqlen.append(seqlen_pos)
|
||||||
|
# how many tokens to finish the chunk?
|
||||||
|
chunk_len = (
|
||||||
|
cdiv(this_num_computed, self.chunk_size) * self.chunk_size
|
||||||
|
- this_num_computed
|
||||||
|
)
|
||||||
|
# we can only use at most this_new_tokens
|
||||||
|
chunk_len = min(chunk_len, this_new_tokens)
|
||||||
|
seqlen_pos += chunk_len
|
||||||
|
this_new_tokens -= chunk_len
|
||||||
|
|
||||||
|
n_chunks = cdiv(this_new_tokens, self.chunk_size)
|
||||||
|
for chunk in range(n_chunks):
|
||||||
|
seq_idx.append(req_idx)
|
||||||
|
cu_chunk_seqlen.append(seqlen_pos)
|
||||||
|
chunk_len = min(self.chunk_size, this_new_tokens)
|
||||||
|
seqlen_pos += chunk_len
|
||||||
|
this_new_tokens -= chunk_len
|
||||||
|
|
||||||
|
assert this_new_tokens == 0
|
||||||
|
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
|
||||||
|
|
||||||
|
cu_chunk_seqlen.append(seqlen_pos)
|
||||||
|
|
||||||
|
return cu_chunk_seqlen, seq_idx, last_chunk_indices
|
||||||
|
|
||||||
def build(
|
def build(
|
||||||
self,
|
self,
|
||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
fast_build: bool = False,
|
fast_build: bool = False,
|
||||||
) -> Mamba2AttentionMetadata:
|
) -> Mamba2AttentionMetadata:
|
||||||
num_reqs = common_attn_metadata.num_reqs
|
common = self._compute_common_metadata(common_attn_metadata)
|
||||||
seq_lens = common_attn_metadata.seq_lens
|
|
||||||
|
|
||||||
query_start_loc_p = None
|
|
||||||
seq_idx_p = None
|
seq_idx_p = None
|
||||||
cu_chunk_seqlen_p = None
|
cu_chunk_seqlen_p = None
|
||||||
last_chunk_indices_p = None
|
last_chunk_indices_p = None
|
||||||
|
|
||||||
# Need flags to indicate if there are initial states
|
|
||||||
has_initial_states_p = None
|
|
||||||
prep_initial_states = False
|
prep_initial_states = False
|
||||||
|
|
||||||
# for causal_conv1d
|
|
||||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
|
||||||
|
|
||||||
num_computed_tokens, num_computed_tokens_p = None, None
|
|
||||||
block_idx_first_scheduled_token = None
|
|
||||||
block_idx_first_scheduled_token_p = None
|
|
||||||
|
|
||||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
|
||||||
# Return a tensor of shape (#requests, #max blocks)
|
|
||||||
state_indices_tensor = common_attn_metadata.block_table_tensor
|
|
||||||
# Additional cache-related varaiables:
|
|
||||||
mamba_block_size = self.kv_cache_spec.block_size
|
|
||||||
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
|
|
||||||
self.device
|
|
||||||
)
|
|
||||||
(
|
|
||||||
block_idx_last_computed_token,
|
|
||||||
block_idx_first_scheduled_token,
|
|
||||||
block_idx_last_scheduled_token,
|
|
||||||
) = self._compute_prefix_caching_block_indices(
|
|
||||||
common_attn_metadata, mamba_block_size
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Always return just a single block per each request:
|
|
||||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
|
||||||
# Additional cache-related varaiables:
|
|
||||||
block_idx_last_scheduled_token = None
|
|
||||||
block_idx_last_computed_token = None
|
|
||||||
|
|
||||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
|
||||||
split_decodes_and_prefills(
|
|
||||||
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Compute seq_idx for prefill only
|
# Compute seq_idx for prefill only
|
||||||
if num_prefills > 0:
|
if common.num_prefills > 0:
|
||||||
# [batch,]
|
prep_initial_states = (
|
||||||
has_initial_states_cpu = (
|
torch.any(common.has_initial_states_p).item()
|
||||||
common_attn_metadata.num_computed_tokens_cpu[
|
if common.has_initial_states_p is not None
|
||||||
num_reqs - num_prefills : num_reqs
|
else False
|
||||||
]
|
|
||||||
> 0
|
|
||||||
)
|
|
||||||
prep_initial_states = torch.any(has_initial_states_cpu).item()
|
|
||||||
has_initial_states_p = has_initial_states_cpu.to(
|
|
||||||
common_attn_metadata.query_start_loc.device
|
|
||||||
)
|
)
|
||||||
|
|
||||||
query_start_loc_p = (
|
num_reqs = common.num_reqs
|
||||||
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
|
num_prefills = common.num_prefills
|
||||||
- num_decode_tokens
|
num_decode_tokens = common.num_decode_tokens
|
||||||
)
|
|
||||||
|
|
||||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
|
||||||
assert num_computed_tokens is not None
|
|
||||||
num_computed_tokens_p = num_computed_tokens[
|
|
||||||
num_reqs - num_prefills : num_reqs
|
|
||||||
]
|
|
||||||
assert block_idx_first_scheduled_token is not None
|
|
||||||
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
|
|
||||||
num_reqs - num_prefills : num_reqs
|
|
||||||
]
|
|
||||||
num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[
|
num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[
|
||||||
num_reqs - num_prefills : num_reqs
|
num_reqs - num_prefills : num_reqs
|
||||||
]
|
]
|
||||||
@ -239,137 +223,33 @@ class Mamba2AttentionMetadataBuilder(
|
|||||||
- num_decode_tokens
|
- num_decode_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
# The code below carefully constructs the chunks such that:
|
cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
|
||||||
# 1. Chunks contain tokens from a *single* sequence only.
|
num_prefills,
|
||||||
# 2. For every sequence, we are guaranteed that we can
|
num_computed_tokens_p_cpu,
|
||||||
# retrieve the mamba state *every* chunk_size tokens.
|
query_start_loc_p_cpu,
|
||||||
# Constraint (1) dramatically simplifies the mamba2 kernels.
|
)
|
||||||
# Constraint (2) dramatically simplifies the implementation
|
|
||||||
# of prefix caching for mamba2 (wip). We need to take care
|
|
||||||
# of the interaction with chunked prefill in order to
|
|
||||||
# satisfy constraint (2).
|
|
||||||
# TODO (tdoublep): This code could probably be optimized.
|
|
||||||
cu_chunk_seqlen = []
|
|
||||||
seq_idx = []
|
|
||||||
last_chunk_indices = []
|
|
||||||
seqlen_pos = 0
|
|
||||||
for req_idx in range(num_prefills):
|
|
||||||
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
|
|
||||||
this_new_tokens = (
|
|
||||||
query_start_loc_p_cpu[req_idx + 1].item()
|
|
||||||
- query_start_loc_p_cpu[req_idx].item()
|
|
||||||
)
|
|
||||||
|
|
||||||
# if computed tokens are not chunk-aligned, use the first
|
|
||||||
# chunk to finish it off
|
|
||||||
if this_num_computed % self.chunk_size != 0:
|
|
||||||
seq_idx.append(req_idx)
|
|
||||||
cu_chunk_seqlen.append(seqlen_pos)
|
|
||||||
# how many tokens to finish the chunk?
|
|
||||||
chunk_len = (
|
|
||||||
cdiv(this_num_computed, self.chunk_size) * self.chunk_size
|
|
||||||
- this_num_computed
|
|
||||||
)
|
|
||||||
# we can only use at most this_new_tokens
|
|
||||||
chunk_len = min(chunk_len, this_new_tokens)
|
|
||||||
seqlen_pos += chunk_len
|
|
||||||
this_new_tokens -= chunk_len
|
|
||||||
|
|
||||||
n_chunks = cdiv(this_new_tokens, self.chunk_size)
|
|
||||||
for chunk in range(n_chunks):
|
|
||||||
seq_idx.append(req_idx)
|
|
||||||
cu_chunk_seqlen.append(seqlen_pos)
|
|
||||||
chunk_len = min(self.chunk_size, this_new_tokens)
|
|
||||||
seqlen_pos += chunk_len
|
|
||||||
this_new_tokens -= chunk_len
|
|
||||||
|
|
||||||
assert this_new_tokens == 0
|
|
||||||
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
|
|
||||||
|
|
||||||
cu_chunk_seqlen.append(seqlen_pos)
|
|
||||||
|
|
||||||
seq_idx_p = torch.as_tensor(
|
seq_idx_p = torch.as_tensor(
|
||||||
seq_idx, device=query_start_loc_p.device, dtype=torch.int32
|
seq_idx,
|
||||||
|
device=common_attn_metadata.query_start_loc.device,
|
||||||
|
dtype=torch.int32,
|
||||||
)
|
)
|
||||||
cu_chunk_seqlen_p = torch.as_tensor(
|
cu_chunk_seqlen_p = torch.as_tensor(
|
||||||
cu_chunk_seqlen, device=query_start_loc_p.device, dtype=torch.int32
|
cu_chunk_seqlen,
|
||||||
|
device=common_attn_metadata.query_start_loc.device,
|
||||||
|
dtype=torch.int32,
|
||||||
)
|
)
|
||||||
last_chunk_indices_p = torch.as_tensor(
|
last_chunk_indices_p = torch.as_tensor(
|
||||||
last_chunk_indices, device=query_start_loc_p.device, dtype=torch.int32
|
last_chunk_indices,
|
||||||
|
device=common_attn_metadata.query_start_loc.device,
|
||||||
|
dtype=torch.int32,
|
||||||
)
|
)
|
||||||
|
|
||||||
nums_dict, batch_ptr, token_chunk_offset_ptr = (
|
return replace(
|
||||||
compute_causal_conv1d_metadata(query_start_loc_p)
|
common,
|
||||||
)
|
|
||||||
|
|
||||||
elif (
|
|
||||||
num_decodes <= self.decode_cudagraph_max_bs
|
|
||||||
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
|
||||||
):
|
|
||||||
self.state_indices_tensor[:num_decodes].copy_(
|
|
||||||
state_indices_tensor, non_blocking=True
|
|
||||||
)
|
|
||||||
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
|
|
||||||
|
|
||||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
|
||||||
self.block_idx_last_scheduled_token[:num_decodes].copy_(
|
|
||||||
block_idx_last_scheduled_token, non_blocking=True
|
|
||||||
)
|
|
||||||
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
|
|
||||||
:num_decode_tokens
|
|
||||||
]
|
|
||||||
|
|
||||||
self.block_idx_last_computed_token[:num_decodes].copy_(
|
|
||||||
block_idx_last_computed_token, non_blocking=True
|
|
||||||
)
|
|
||||||
block_idx_last_computed_token = self.block_idx_last_computed_token[
|
|
||||||
:num_decode_tokens
|
|
||||||
]
|
|
||||||
|
|
||||||
attn_metadata = Mamba2AttentionMetadata(
|
|
||||||
num_prefills=num_prefills,
|
|
||||||
num_prefill_tokens=num_prefill_tokens,
|
|
||||||
num_decodes=num_decodes,
|
|
||||||
num_decode_tokens=num_decode_tokens,
|
|
||||||
query_start_loc_p=query_start_loc_p,
|
|
||||||
seq_lens=seq_lens,
|
|
||||||
prep_initial_states=prep_initial_states,
|
prep_initial_states=prep_initial_states,
|
||||||
chunk_size=self.chunk_size,
|
chunk_size=self.chunk_size,
|
||||||
has_initial_states_p=has_initial_states_p,
|
|
||||||
seq_idx_p=seq_idx_p,
|
seq_idx_p=seq_idx_p,
|
||||||
state_indices_tensor=state_indices_tensor,
|
|
||||||
cu_chunk_seqlen_p=cu_chunk_seqlen_p,
|
cu_chunk_seqlen_p=cu_chunk_seqlen_p,
|
||||||
last_chunk_indices_p=last_chunk_indices_p,
|
last_chunk_indices_p=last_chunk_indices_p,
|
||||||
nums_dict=nums_dict,
|
|
||||||
batch_ptr=batch_ptr,
|
|
||||||
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
|
||||||
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
|
|
||||||
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
|
|
||||||
block_idx_last_computed_token=block_idx_last_computed_token,
|
|
||||||
num_computed_tokens_p=num_computed_tokens_p,
|
|
||||||
)
|
)
|
||||||
return attn_metadata
|
|
||||||
|
|
||||||
def update_block_table(
|
|
||||||
self,
|
|
||||||
metadata: Mamba2AttentionMetadata,
|
|
||||||
blk_table: torch.Tensor,
|
|
||||||
slot_mapping: torch.Tensor,
|
|
||||||
) -> Mamba2AttentionMetadata:
|
|
||||||
new_metadata = copy.copy(metadata)
|
|
||||||
prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
|
|
||||||
state_indices_t = blk_table if prefix_caching else blk_table[:, 0]
|
|
||||||
num_reqs = blk_table.shape[0]
|
|
||||||
|
|
||||||
# For CUDA graphs, copy to persistent buffer
|
|
||||||
if (
|
|
||||||
metadata.num_prefills == 0
|
|
||||||
and num_reqs <= self.decode_cudagraph_max_bs
|
|
||||||
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
|
||||||
):
|
|
||||||
persistent_state_indices_t = self.state_indices_tensor[:num_reqs]
|
|
||||||
persistent_state_indices_t.copy_(state_indices_t, non_blocking=True)
|
|
||||||
state_indices_t = persistent_state_indices_t
|
|
||||||
|
|
||||||
new_metadata.state_indices_tensor = state_indices_t
|
|
||||||
return new_metadata
|
|
||||||
|
|||||||
@ -2,6 +2,8 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
import copy
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import ClassVar, TypeVar
|
from typing import ClassVar, TypeVar
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -9,20 +11,52 @@ import torch
|
|||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.utils.math_utils import cdiv
|
from vllm.utils.math_utils import cdiv
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (
|
||||||
|
PAD_SLOT_ID,
|
||||||
AttentionCGSupport,
|
AttentionCGSupport,
|
||||||
AttentionMetadataBuilder,
|
AttentionMetadataBuilder,
|
||||||
CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
|
compute_causal_conv1d_metadata,
|
||||||
|
split_decodes_and_prefills,
|
||||||
)
|
)
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||||
|
|
||||||
M = TypeVar("M")
|
M = TypeVar("M", bound="BaseMambaAttentionMetadata")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BaseMambaAttentionMetadata:
|
||||||
|
num_prefills: int
|
||||||
|
num_prefill_tokens: int
|
||||||
|
num_decodes: int
|
||||||
|
num_decode_tokens: int
|
||||||
|
num_reqs: int
|
||||||
|
|
||||||
|
# The following tensors only contain prefill requests and will be None if
|
||||||
|
# the batch has no prefill request.
|
||||||
|
has_initial_states_p: torch.Tensor | None
|
||||||
|
query_start_loc_p: torch.Tensor | None
|
||||||
|
num_computed_tokens_p: torch.Tensor | None
|
||||||
|
|
||||||
|
state_indices_tensor: torch.Tensor
|
||||||
|
|
||||||
|
# The following tensors are only used for prefix caching and are None if disabled
|
||||||
|
block_idx_last_scheduled_token: torch.Tensor | None
|
||||||
|
block_idx_first_scheduled_token_p: torch.Tensor | None
|
||||||
|
block_idx_last_computed_token: torch.Tensor | None
|
||||||
|
|
||||||
|
# The following attributes are for triton implementation of causal_conv1d
|
||||||
|
nums_dict: dict | None = None
|
||||||
|
batch_ptr: torch.Tensor | None = None
|
||||||
|
token_chunk_offset_ptr: torch.Tensor | None = None
|
||||||
|
|
||||||
|
|
||||||
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||||
|
metadata_cls: type[M]
|
||||||
reorder_batch_threshold: int = 1
|
reorder_batch_threshold: int = 1
|
||||||
_cudagraph_support: ClassVar[AttentionCGSupport] = (
|
_cudagraph_support: ClassVar[AttentionCGSupport] = (
|
||||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||||
)
|
)
|
||||||
|
supports_update_block_table: bool = True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -87,6 +121,18 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
|||||||
|
|
||||||
return self.build(0, m)
|
return self.build(0, m)
|
||||||
|
|
||||||
|
def build(
|
||||||
|
self,
|
||||||
|
common_prefix_len: int,
|
||||||
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
|
fast_build: bool = False,
|
||||||
|
) -> M:
|
||||||
|
"""
|
||||||
|
Default build implementation for Mamba-like attention backends.
|
||||||
|
Subclasses (e.g., Mamba2) can override to add additional metadata.
|
||||||
|
"""
|
||||||
|
return self._compute_common_metadata(common_attn_metadata)
|
||||||
|
|
||||||
def _compute_prefix_caching_block_indices(
|
def _compute_prefix_caching_block_indices(
|
||||||
self,
|
self,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
@ -115,3 +161,147 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
|||||||
block_idx_first_scheduled_token,
|
block_idx_first_scheduled_token,
|
||||||
block_idx_last_scheduled_token,
|
block_idx_last_scheduled_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _compute_common_metadata(
|
||||||
|
self,
|
||||||
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
|
) -> M:
|
||||||
|
"""
|
||||||
|
Compute metadata common to both Mamba1 and Mamba2.
|
||||||
|
"""
|
||||||
|
num_reqs = common_attn_metadata.num_reqs
|
||||||
|
|
||||||
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||||
|
split_decodes_and_prefills(
|
||||||
|
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Need flags to indicate if there are initial states
|
||||||
|
has_initial_states_p = None
|
||||||
|
query_start_loc_p = None
|
||||||
|
num_computed_tokens = None
|
||||||
|
num_computed_tokens_p = None
|
||||||
|
|
||||||
|
# for prefix caching
|
||||||
|
block_idx_first_scheduled_token = None
|
||||||
|
block_idx_first_scheduled_token_p = None
|
||||||
|
block_idx_last_computed_token = None
|
||||||
|
block_idx_last_scheduled_token = None
|
||||||
|
|
||||||
|
# for causal_conv1d
|
||||||
|
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||||
|
|
||||||
|
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||||
|
# Return a tensor of shape (#requests, #max blocks)
|
||||||
|
state_indices_tensor = common_attn_metadata.block_table_tensor
|
||||||
|
# Additional cache-related varaiables:
|
||||||
|
mamba_block_size = self.kv_cache_spec.block_size
|
||||||
|
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
|
||||||
|
self.device
|
||||||
|
)
|
||||||
|
(
|
||||||
|
block_idx_last_computed_token,
|
||||||
|
block_idx_first_scheduled_token,
|
||||||
|
block_idx_last_scheduled_token,
|
||||||
|
) = self._compute_prefix_caching_block_indices(
|
||||||
|
common_attn_metadata, mamba_block_size
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Always return just a single block per each request:
|
||||||
|
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||||
|
|
||||||
|
if num_prefills > 0:
|
||||||
|
query_start_loc_p = (
|
||||||
|
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
|
||||||
|
- num_decode_tokens
|
||||||
|
)
|
||||||
|
has_initial_states_cpu = (
|
||||||
|
common_attn_metadata.num_computed_tokens_cpu[
|
||||||
|
num_reqs - num_prefills : num_reqs
|
||||||
|
]
|
||||||
|
> 0
|
||||||
|
)
|
||||||
|
has_initial_states_p = has_initial_states_cpu.to(
|
||||||
|
common_attn_metadata.query_start_loc.device
|
||||||
|
)
|
||||||
|
|
||||||
|
nums_dict, batch_ptr, token_chunk_offset_ptr = (
|
||||||
|
compute_causal_conv1d_metadata(query_start_loc_p)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||||
|
assert num_computed_tokens is not None
|
||||||
|
num_computed_tokens_p = num_computed_tokens[
|
||||||
|
num_reqs - num_prefills : num_reqs
|
||||||
|
]
|
||||||
|
assert block_idx_first_scheduled_token is not None
|
||||||
|
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
|
||||||
|
num_reqs - num_prefills : num_reqs
|
||||||
|
]
|
||||||
|
elif (
|
||||||
|
num_decodes <= self.decode_cudagraph_max_bs
|
||||||
|
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||||
|
):
|
||||||
|
self.state_indices_tensor[:num_decodes].copy_(
|
||||||
|
state_indices_tensor, non_blocking=True
|
||||||
|
)
|
||||||
|
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
|
||||||
|
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
|
||||||
|
|
||||||
|
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||||
|
self.block_idx_last_scheduled_token[:num_decodes].copy_(
|
||||||
|
block_idx_last_scheduled_token, non_blocking=True
|
||||||
|
)
|
||||||
|
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
|
||||||
|
:num_decode_tokens
|
||||||
|
]
|
||||||
|
|
||||||
|
self.block_idx_last_computed_token[:num_decodes].copy_(
|
||||||
|
block_idx_last_computed_token, non_blocking=True
|
||||||
|
)
|
||||||
|
block_idx_last_computed_token = self.block_idx_last_computed_token[
|
||||||
|
:num_decode_tokens
|
||||||
|
]
|
||||||
|
|
||||||
|
return self.metadata_cls(
|
||||||
|
num_prefills=num_prefills,
|
||||||
|
num_prefill_tokens=num_prefill_tokens,
|
||||||
|
num_decodes=num_decodes,
|
||||||
|
num_decode_tokens=num_decode_tokens,
|
||||||
|
query_start_loc_p=query_start_loc_p,
|
||||||
|
has_initial_states_p=has_initial_states_p,
|
||||||
|
state_indices_tensor=state_indices_tensor,
|
||||||
|
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
|
||||||
|
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
|
||||||
|
block_idx_last_computed_token=block_idx_last_computed_token,
|
||||||
|
num_computed_tokens_p=num_computed_tokens_p,
|
||||||
|
num_reqs=num_reqs,
|
||||||
|
nums_dict=nums_dict,
|
||||||
|
batch_ptr=batch_ptr,
|
||||||
|
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_block_table(
|
||||||
|
self,
|
||||||
|
metadata: M,
|
||||||
|
blk_table: torch.Tensor,
|
||||||
|
slot_mapping: torch.Tensor,
|
||||||
|
) -> M:
|
||||||
|
new_metadata = copy.copy(metadata)
|
||||||
|
prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
|
||||||
|
state_indices_t = blk_table if prefix_caching else blk_table[:, 0]
|
||||||
|
num_reqs = blk_table.shape[0]
|
||||||
|
|
||||||
|
# For CUDA graphs, copy to persistent buffer
|
||||||
|
if (
|
||||||
|
metadata.num_prefills == 0
|
||||||
|
and num_reqs <= self.decode_cudagraph_max_bs
|
||||||
|
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||||
|
):
|
||||||
|
persistent_state_indices_t = self.state_indices_tensor[:num_reqs]
|
||||||
|
persistent_state_indices_t.copy_(state_indices_t, non_blocking=True)
|
||||||
|
state_indices_t = persistent_state_indices_t
|
||||||
|
|
||||||
|
new_metadata.state_indices_tensor = state_indices_t
|
||||||
|
return new_metadata
|
||||||
|
|||||||
@ -2,15 +2,10 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionBackend
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
|
from vllm.v1.attention.backends.mamba_attn import (
|
||||||
from vllm.v1.attention.backends.utils import (
|
BaseMambaAttentionMetadata,
|
||||||
PAD_SLOT_ID,
|
BaseMambaAttentionMetadataBuilder,
|
||||||
CommonAttentionMetadata,
|
|
||||||
compute_causal_conv1d_metadata,
|
|
||||||
split_decodes_and_prefills,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -21,84 +16,11 @@ class ShortConvAttentionBackend(AttentionBackend):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ShortConvAttentionMetadata:
|
class ShortConvAttentionMetadata(BaseMambaAttentionMetadata):
|
||||||
num_prefills: int
|
pass
|
||||||
num_prefill_tokens: int
|
|
||||||
num_decodes: int
|
|
||||||
num_decode_tokens: int
|
|
||||||
|
|
||||||
query_start_loc: torch.Tensor
|
|
||||||
state_indices_tensor: torch.Tensor
|
|
||||||
has_initial_states_p: torch.Tensor | None
|
|
||||||
|
|
||||||
# For causal_conv1d
|
|
||||||
nums_dict: dict | None = None
|
|
||||||
batch_ptr: torch.Tensor | None = None
|
|
||||||
token_chunk_offset_ptr: torch.Tensor | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ShortConvAttentionMetadataBuilder(
|
class ShortConvAttentionMetadataBuilder(
|
||||||
BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata]
|
BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata]
|
||||||
):
|
):
|
||||||
def build(
|
metadata_cls = ShortConvAttentionMetadata
|
||||||
self,
|
|
||||||
common_prefix_len: int,
|
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
|
||||||
fast_build: bool = False,
|
|
||||||
) -> ShortConvAttentionMetadata:
|
|
||||||
num_reqs = common_attn_metadata.num_reqs
|
|
||||||
query_start_loc = common_attn_metadata.query_start_loc
|
|
||||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
|
||||||
|
|
||||||
# for causal_conv1d
|
|
||||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
|
||||||
|
|
||||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
|
||||||
split_decodes_and_prefills(
|
|
||||||
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
has_initial_states_p = None
|
|
||||||
if num_prefills > 0:
|
|
||||||
has_initial_states_cpu = (
|
|
||||||
common_attn_metadata.num_computed_tokens_cpu[
|
|
||||||
num_reqs - num_prefills : num_reqs
|
|
||||||
]
|
|
||||||
> 0
|
|
||||||
)
|
|
||||||
has_initial_states_p = has_initial_states_cpu.to(query_start_loc.device)
|
|
||||||
|
|
||||||
query_start_loc_p = (
|
|
||||||
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
|
|
||||||
- num_decode_tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
nums_dict, batch_ptr, token_chunk_offset_ptr = (
|
|
||||||
compute_causal_conv1d_metadata(query_start_loc_p)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif (
|
|
||||||
num_decodes > 0
|
|
||||||
and num_decodes <= self.decode_cudagraph_max_bs
|
|
||||||
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
|
||||||
):
|
|
||||||
self.state_indices_tensor[:num_decodes].copy_(
|
|
||||||
state_indices_tensor, non_blocking=True
|
|
||||||
)
|
|
||||||
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
|
|
||||||
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
|
|
||||||
|
|
||||||
attn_metadata = ShortConvAttentionMetadata(
|
|
||||||
query_start_loc=query_start_loc,
|
|
||||||
state_indices_tensor=state_indices_tensor,
|
|
||||||
has_initial_states_p=has_initial_states_p,
|
|
||||||
num_prefills=num_prefills,
|
|
||||||
num_prefill_tokens=num_prefill_tokens,
|
|
||||||
num_decodes=num_decodes,
|
|
||||||
num_decode_tokens=num_decode_tokens,
|
|
||||||
nums_dict=nums_dict,
|
|
||||||
batch_ptr=batch_ptr,
|
|
||||||
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
|
||||||
)
|
|
||||||
return attn_metadata
|
|
||||||
|
|||||||
@ -75,6 +75,12 @@ class EngineCoreRequest(
|
|||||||
|
|
||||||
trace_headers: Mapping[str, str] | None = None
|
trace_headers: Mapping[str, str] | None = None
|
||||||
|
|
||||||
|
# The user-provided request ID. This field is set internally,
|
||||||
|
# copied from the provided request_id that's originally assigned
|
||||||
|
# to the request_id field, see InputProcessor.assign_request_id().
|
||||||
|
# Used in outputs and to support abort(req_id, internal=False).
|
||||||
|
external_req_id: str | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def params(self) -> SamplingParams | PoolingParams:
|
def params(self) -> SamplingParams | PoolingParams:
|
||||||
"""Return the processed params (sampling or pooling)."""
|
"""Return the processed params (sampling or pooling)."""
|
||||||
|
|||||||
@ -290,12 +290,15 @@ class AsyncLLM(EngineClient):
|
|||||||
|
|
||||||
is_pooling = isinstance(params, PoolingParams)
|
is_pooling = isinstance(params, PoolingParams)
|
||||||
|
|
||||||
# Create a new output collector for the request.
|
|
||||||
queue = RequestOutputCollector(output_kind=params.output_kind)
|
|
||||||
|
|
||||||
# Convert Input --> Request.
|
# Convert Input --> Request.
|
||||||
if isinstance(prompt, EngineCoreRequest):
|
if isinstance(prompt, EngineCoreRequest):
|
||||||
request = prompt
|
request = prompt
|
||||||
|
if request_id != request.request_id:
|
||||||
|
logger.warning_once(
|
||||||
|
"AsyncLLM.add_request() was passed a request_id parameter that "
|
||||||
|
"does not match the EngineCoreRequest.request_id attribute. The "
|
||||||
|
"latter will be used, and the former will be ignored."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
assert prompt_text is None
|
assert prompt_text is None
|
||||||
request = self.input_processor.process_inputs(
|
request = self.input_processor.process_inputs(
|
||||||
@ -314,6 +317,11 @@ class AsyncLLM(EngineClient):
|
|||||||
elif isinstance(prompt, Mapping):
|
elif isinstance(prompt, Mapping):
|
||||||
prompt_text = cast(str | None, prompt.get("prompt"))
|
prompt_text = cast(str | None, prompt.get("prompt"))
|
||||||
|
|
||||||
|
self.input_processor.assign_request_id(request)
|
||||||
|
|
||||||
|
# Create a new output collector for the request.
|
||||||
|
queue = RequestOutputCollector(params.output_kind, request.request_id)
|
||||||
|
|
||||||
# Use cloned params that may have been updated in process_inputs()
|
# Use cloned params that may have been updated in process_inputs()
|
||||||
params = request.params
|
params = request.params
|
||||||
|
|
||||||
@ -325,7 +333,7 @@ class AsyncLLM(EngineClient):
|
|||||||
assert isinstance(parent_params, SamplingParams)
|
assert isinstance(parent_params, SamplingParams)
|
||||||
|
|
||||||
# Fan out child requests (for n>1).
|
# Fan out child requests (for n>1).
|
||||||
parent_request = ParentRequest(request_id, parent_params)
|
parent_request = ParentRequest(request)
|
||||||
for idx in range(parent_params.n):
|
for idx in range(parent_params.n):
|
||||||
request_id, child_params = parent_request.get_child_info(idx)
|
request_id, child_params = parent_request.get_child_info(idx)
|
||||||
child_request = request if idx == parent_params.n - 1 else copy(request)
|
child_request = request if idx == parent_params.n - 1 else copy(request)
|
||||||
@ -396,6 +404,7 @@ class AsyncLLM(EngineClient):
|
|||||||
"prompt logprobs"
|
"prompt logprobs"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
q: RequestOutputCollector | None = None
|
||||||
try:
|
try:
|
||||||
# We start the output_handler on the first call to generate() so
|
# We start the output_handler on the first call to generate() so
|
||||||
# we can call __init__ before the event loop, which enables us
|
# we can call __init__ before the event loop, which enables us
|
||||||
@ -446,7 +455,8 @@ class AsyncLLM(EngineClient):
|
|||||||
# is cancelled or the generator is garbage collected. So,
|
# is cancelled or the generator is garbage collected. So,
|
||||||
# we abort the request if we end up here.
|
# we abort the request if we end up here.
|
||||||
except (asyncio.CancelledError, GeneratorExit):
|
except (asyncio.CancelledError, GeneratorExit):
|
||||||
await self.abort(request_id)
|
if q is not None:
|
||||||
|
await self.abort(q.request_id, internal=True)
|
||||||
if self.log_requests:
|
if self.log_requests:
|
||||||
logger.info("Request %s aborted.", request_id)
|
logger.info("Request %s aborted.", request_id)
|
||||||
raise
|
raise
|
||||||
@ -465,7 +475,8 @@ class AsyncLLM(EngineClient):
|
|||||||
|
|
||||||
# Unexpected error in the generate() task (possibly recoverable).
|
# Unexpected error in the generate() task (possibly recoverable).
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await self.abort(request_id)
|
if q is not None:
|
||||||
|
await self.abort(q.request_id, internal=True)
|
||||||
if self.log_requests:
|
if self.log_requests:
|
||||||
logger.info("Request %s failed.", request_id)
|
logger.info("Request %s failed.", request_id)
|
||||||
raise EngineGenerateError() from e
|
raise EngineGenerateError() from e
|
||||||
@ -541,13 +552,15 @@ class AsyncLLM(EngineClient):
|
|||||||
|
|
||||||
self.output_handler = asyncio.create_task(output_handler())
|
self.output_handler = asyncio.create_task(output_handler())
|
||||||
|
|
||||||
async def abort(self, request_id: str | Iterable[str]) -> None:
|
async def abort(
|
||||||
|
self, request_id: str | Iterable[str], internal: bool = False
|
||||||
|
) -> None:
|
||||||
"""Abort RequestId in OutputProcessor and EngineCore."""
|
"""Abort RequestId in OutputProcessor and EngineCore."""
|
||||||
|
|
||||||
request_ids = (
|
request_ids = (
|
||||||
(request_id,) if isinstance(request_id, str) else as_list(request_id)
|
(request_id,) if isinstance(request_id, str) else as_list(request_id)
|
||||||
)
|
)
|
||||||
all_request_ids = self.output_processor.abort_requests(request_ids)
|
all_request_ids = self.output_processor.abort_requests(request_ids, internal)
|
||||||
await self.engine_core.abort_requests_async(all_request_ids)
|
await self.engine_core.abort_requests_async(all_request_ids)
|
||||||
|
|
||||||
if self.log_requests:
|
if self.log_requests:
|
||||||
@ -581,7 +594,7 @@ class AsyncLLM(EngineClient):
|
|||||||
if not wait_for_inflight_requests:
|
if not wait_for_inflight_requests:
|
||||||
request_ids = list(self.output_processor.request_states.keys())
|
request_ids = list(self.output_processor.request_states.keys())
|
||||||
if request_ids:
|
if request_ids:
|
||||||
await self.abort(request_ids)
|
await self.abort(request_ids, internal=True)
|
||||||
|
|
||||||
# Wait for running requests to drain before clearing cache.
|
# Wait for running requests to drain before clearing cache.
|
||||||
if self.output_processor.has_unfinished_requests():
|
if self.output_processor.has_unfinished_requests():
|
||||||
@ -633,6 +646,7 @@ class AsyncLLM(EngineClient):
|
|||||||
TODO: Remove truncate_prompt_tokens in v0.15.
|
TODO: Remove truncate_prompt_tokens in v0.15.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
q: RequestOutputCollector | None = None
|
||||||
try:
|
try:
|
||||||
# We start the output_handler on the first call to generate() so
|
# We start the output_handler on the first call to generate() so
|
||||||
# we can call __init__ before the event loop, which enables us
|
# we can call __init__ before the event loop, which enables us
|
||||||
@ -687,7 +701,8 @@ class AsyncLLM(EngineClient):
|
|||||||
# If the request is disconnected by the client, generate()
|
# If the request is disconnected by the client, generate()
|
||||||
# is cancelled. So, we abort the request if we end up here.
|
# is cancelled. So, we abort the request if we end up here.
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
await self.abort(request_id)
|
if q is not None:
|
||||||
|
await self.abort(q.request_id, internal=True)
|
||||||
if self.log_requests:
|
if self.log_requests:
|
||||||
logger.info("Request %s aborted.", request_id)
|
logger.info("Request %s aborted.", request_id)
|
||||||
raise
|
raise
|
||||||
@ -706,7 +721,8 @@ class AsyncLLM(EngineClient):
|
|||||||
|
|
||||||
# Unexpected error in the generate() task (possibly recoverable).
|
# Unexpected error in the generate() task (possibly recoverable).
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await self.abort(request_id)
|
if q is not None:
|
||||||
|
await self.abort(q.request_id, internal=True)
|
||||||
if self.log_requests:
|
if self.log_requests:
|
||||||
logger.info("Request %s failed.", request_id)
|
logger.info("Request %s failed.", request_id)
|
||||||
raise EngineGenerateError() from e
|
raise EngineGenerateError() from e
|
||||||
|
|||||||
@ -21,7 +21,7 @@ from vllm.pooling_params import PoolingParams
|
|||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.tokenizers import TokenizerLike
|
from vllm.tokenizers import TokenizerLike
|
||||||
from vllm.tokenizers.mistral import MistralTokenizer
|
from vllm.tokenizers.mistral import MistralTokenizer
|
||||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
|
||||||
from vllm.v1.engine import EngineCoreRequest
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
from vllm.v1.metrics.stats import MultiModalCacheStats
|
from vllm.v1.metrics.stats import MultiModalCacheStats
|
||||||
from vllm.v1.structured_output.backend_guidance import (
|
from vllm.v1.structured_output.backend_guidance import (
|
||||||
@ -406,6 +406,19 @@ class InputProcessor:
|
|||||||
mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
|
mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
|
||||||
return mm_uuids
|
return mm_uuids
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def assign_request_id(request: EngineCoreRequest):
|
||||||
|
"""Replace the externally supplied request ID with an internal request ID
|
||||||
|
that adds 8 random characters in order to ensure uniquness.
|
||||||
|
"""
|
||||||
|
if request.external_req_id is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"The external_req_id field should not be set on EngineCoreRequests"
|
||||||
|
" passed to vLLM; use the request_id field."
|
||||||
|
)
|
||||||
|
request.external_req_id = request.request_id
|
||||||
|
request.request_id = f"{request.external_req_id}-{random_uuid():.8}"
|
||||||
|
|
||||||
def process_inputs(
|
def process_inputs(
|
||||||
self,
|
self,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
|
|||||||
@ -213,10 +213,10 @@ class LLMEngine:
|
|||||||
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||||
return self.engine_core.get_supported_tasks()
|
return self.engine_core.get_supported_tasks()
|
||||||
|
|
||||||
def abort_request(self, request_ids: list[str]) -> None:
|
def abort_request(self, request_ids: list[str], internal: bool = False) -> None:
|
||||||
"""Remove request_ids from EngineCore and Detokenizer."""
|
"""Remove request_ids from EngineCore and Detokenizer."""
|
||||||
|
|
||||||
request_ids = self.output_processor.abort_requests(request_ids)
|
request_ids = self.output_processor.abort_requests(request_ids, internal)
|
||||||
self.engine_core.abort_requests(request_ids)
|
self.engine_core.abort_requests(request_ids)
|
||||||
|
|
||||||
def add_request(
|
def add_request(
|
||||||
@ -238,6 +238,12 @@ class LLMEngine:
|
|||||||
# Process raw inputs into the request.
|
# Process raw inputs into the request.
|
||||||
if isinstance(prompt, EngineCoreRequest):
|
if isinstance(prompt, EngineCoreRequest):
|
||||||
request = prompt
|
request = prompt
|
||||||
|
if request_id != request.request_id:
|
||||||
|
logger.warning_once(
|
||||||
|
"AsyncLLM.add_request() was passed a request_id parameter that "
|
||||||
|
"does not match the EngineCoreRequest.request_id attribute. The "
|
||||||
|
"latter will be used, and the former will be ignored."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
assert prompt_text is None
|
assert prompt_text is None
|
||||||
request = self.input_processor.process_inputs(
|
request = self.input_processor.process_inputs(
|
||||||
@ -255,6 +261,8 @@ class LLMEngine:
|
|||||||
elif isinstance(prompt, Mapping):
|
elif isinstance(prompt, Mapping):
|
||||||
prompt_text = cast(str | None, prompt.get("prompt"))
|
prompt_text = cast(str | None, prompt.get("prompt"))
|
||||||
|
|
||||||
|
self.input_processor.assign_request_id(request)
|
||||||
|
|
||||||
# Use cloned params that may have been updated in process_inputs()
|
# Use cloned params that may have been updated in process_inputs()
|
||||||
params = request.params
|
params = request.params
|
||||||
|
|
||||||
@ -268,7 +276,7 @@ class LLMEngine:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Fan out child requests (for n>1).
|
# Fan out child requests (for n>1).
|
||||||
parent_req = ParentRequest(request_id, params)
|
parent_req = ParentRequest(request)
|
||||||
for idx in range(n):
|
for idx in range(n):
|
||||||
request_id, child_params = parent_req.get_child_info(idx)
|
request_id, child_params = parent_req.get_child_info(idx)
|
||||||
child_request = request if idx == n - 1 else copy(request)
|
child_request = request if idx == n - 1 else copy(request)
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections import defaultdict
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
@ -40,8 +41,9 @@ class RequestOutputCollector:
|
|||||||
producer gets ahead of the consumer.
|
producer gets ahead of the consumer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, output_kind: RequestOutputKind):
|
def __init__(self, output_kind: RequestOutputKind, request_id: str):
|
||||||
self.aggregate = output_kind == RequestOutputKind.DELTA
|
self.aggregate = output_kind == RequestOutputKind.DELTA
|
||||||
|
self.request_id = request_id
|
||||||
self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
|
self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
|
||||||
self.ready = asyncio.Event()
|
self.ready = asyncio.Event()
|
||||||
|
|
||||||
@ -92,6 +94,7 @@ class RequestState:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
|
external_req_id: str,
|
||||||
parent_req: ParentRequest | None,
|
parent_req: ParentRequest | None,
|
||||||
request_index: int,
|
request_index: int,
|
||||||
lora_request: LoRARequest | None,
|
lora_request: LoRARequest | None,
|
||||||
@ -111,6 +114,7 @@ class RequestState:
|
|||||||
temperature: float | None = None,
|
temperature: float | None = None,
|
||||||
):
|
):
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
|
self.external_req_id = external_req_id
|
||||||
self.parent_req = parent_req
|
self.parent_req = parent_req
|
||||||
self.request_index = request_index
|
self.request_index = request_index
|
||||||
self.lora_request = lora_request
|
self.lora_request = lora_request
|
||||||
@ -176,8 +180,10 @@ class RequestState:
|
|||||||
assert request.pooling_params is not None
|
assert request.pooling_params is not None
|
||||||
output_kind = request.pooling_params.output_kind
|
output_kind = request.pooling_params.output_kind
|
||||||
|
|
||||||
|
assert request.external_req_id is not None
|
||||||
return cls(
|
return cls(
|
||||||
request_id=request.request_id,
|
request_id=request.request_id,
|
||||||
|
external_req_id=request.external_req_id,
|
||||||
parent_req=parent_req,
|
parent_req=parent_req,
|
||||||
request_index=request_index,
|
request_index=request_index,
|
||||||
lora_request=request.lora_request,
|
lora_request=request.lora_request,
|
||||||
@ -235,10 +241,13 @@ class RequestState:
|
|||||||
]
|
]
|
||||||
self.sent_tokens_offset = len(self.detokenizer.output_token_ids)
|
self.sent_tokens_offset = len(self.detokenizer.output_token_ids)
|
||||||
|
|
||||||
request_id = self.request_id
|
external_req_id = self.external_req_id
|
||||||
|
|
||||||
if pooling_output is not None:
|
if pooling_output is not None:
|
||||||
return self._new_request_output(
|
return self._new_request_output(
|
||||||
request_id, [self._new_pooling_output(pooling_output)], finished
|
external_req_id,
|
||||||
|
[self._new_pooling_output(pooling_output)],
|
||||||
|
finished,
|
||||||
)
|
)
|
||||||
|
|
||||||
output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)
|
output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)
|
||||||
@ -246,19 +255,18 @@ class RequestState:
|
|||||||
if self.parent_req is None:
|
if self.parent_req is None:
|
||||||
outputs = [output]
|
outputs = [output]
|
||||||
else:
|
else:
|
||||||
request_id, outputs, finished = self.parent_req.get_outputs(
|
outputs, finished = self.parent_req.get_outputs(self.request_id, output)
|
||||||
request_id, output
|
|
||||||
)
|
|
||||||
if not outputs:
|
if not outputs:
|
||||||
return None
|
return None
|
||||||
|
external_req_id = self.parent_req.external_req_id
|
||||||
|
|
||||||
return self._new_request_output(
|
return self._new_request_output(
|
||||||
request_id, outputs, finished, kv_transfer_params
|
external_req_id, outputs, finished, kv_transfer_params
|
||||||
)
|
)
|
||||||
|
|
||||||
def _new_request_output(
|
def _new_request_output(
|
||||||
self,
|
self,
|
||||||
request_id: str,
|
external_req_id: str,
|
||||||
outputs: list[CompletionOutput] | list[PoolingOutput],
|
outputs: list[CompletionOutput] | list[PoolingOutput],
|
||||||
finished: bool,
|
finished: bool,
|
||||||
kv_transfer_params: dict[str, Any] | None = None,
|
kv_transfer_params: dict[str, Any] | None = None,
|
||||||
@ -269,7 +277,7 @@ class RequestState:
|
|||||||
# Prompt embeddings are currently not supported by pooling requests.
|
# Prompt embeddings are currently not supported by pooling requests.
|
||||||
assert self.prompt_token_ids is not None
|
assert self.prompt_token_ids is not None
|
||||||
return PoolingRequestOutput(
|
return PoolingRequestOutput(
|
||||||
request_id=request_id,
|
request_id=external_req_id,
|
||||||
outputs=first_output,
|
outputs=first_output,
|
||||||
num_cached_tokens=self.num_cached_tokens,
|
num_cached_tokens=self.num_cached_tokens,
|
||||||
prompt_token_ids=self.prompt_token_ids,
|
prompt_token_ids=self.prompt_token_ids,
|
||||||
@ -288,7 +296,7 @@ class RequestState:
|
|||||||
prompt_token_ids = [0] * len(self.prompt_embeds)
|
prompt_token_ids = [0] * len(self.prompt_embeds)
|
||||||
|
|
||||||
return RequestOutput(
|
return RequestOutput(
|
||||||
request_id=request_id,
|
request_id=external_req_id, # request_id is what was provided externally
|
||||||
lora_request=self.lora_request,
|
lora_request=self.lora_request,
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
prompt_token_ids=prompt_token_ids,
|
prompt_token_ids=prompt_token_ids,
|
||||||
@ -352,6 +360,7 @@ class OutputProcessor:
|
|||||||
self.stream_interval = stream_interval
|
self.stream_interval = stream_interval
|
||||||
self.request_states: dict[str, RequestState] = {}
|
self.request_states: dict[str, RequestState] = {}
|
||||||
self.parent_requests: dict[str, ParentRequest] = {}
|
self.parent_requests: dict[str, ParentRequest] = {}
|
||||||
|
self.external_req_ids: defaultdict[str, list[str]] = defaultdict(list)
|
||||||
self.lora_states = LoRARequestStates(log_stats)
|
self.lora_states = LoRARequestStates(log_stats)
|
||||||
self.tracer: Tracer | None = None
|
self.tracer: Tracer | None = None
|
||||||
self._requests_drained = asyncio.Event()
|
self._requests_drained = asyncio.Event()
|
||||||
@ -375,12 +384,41 @@ class OutputProcessor:
|
|||||||
assert state.queue is not None
|
assert state.queue is not None
|
||||||
state.queue.put(e)
|
state.queue.put(e)
|
||||||
|
|
||||||
def abort_requests(
|
def abort_requests(self, request_ids: Iterable[str], internal: bool) -> list[str]:
|
||||||
self,
|
"""Abort a list of requests.
|
||||||
request_ids: Iterable[str],
|
|
||||||
) -> list[str]:
|
The request_ids may be either external request IDs (those passed to
|
||||||
request_ids_to_abort = []
|
InputProcessor.process_inputs()) or internal request IDs (those randomly
|
||||||
|
generated when creating the EngineCoreRequest).
|
||||||
|
|
||||||
|
If an external request ID is provided, and that external request ID
|
||||||
|
was used for multiple requests, all requests associated with that external
|
||||||
|
request ID are aborted.
|
||||||
|
|
||||||
|
In the case of parallel sampling, a request ID may be used to identify
|
||||||
|
a parent request, in which case the associated child requests are aborted
|
||||||
|
also.
|
||||||
|
"""
|
||||||
|
|
||||||
|
internal_req_ids = []
|
||||||
for request_id in request_ids:
|
for request_id in request_ids:
|
||||||
|
if internal:
|
||||||
|
# Internal ID - this may be a parent request
|
||||||
|
internal_req_ids.append(request_id)
|
||||||
|
|
||||||
|
# Remove internal ID from the external->internal mapping
|
||||||
|
if req_state := self.request_states.get(request_id):
|
||||||
|
external_req_id = req_state.external_req_id
|
||||||
|
internal_ids = self.external_req_ids[external_req_id]
|
||||||
|
internal_ids.remove(request_id)
|
||||||
|
if not internal_ids:
|
||||||
|
del self.external_req_ids[external_req_id]
|
||||||
|
elif internal_ids := self.external_req_ids.pop(request_id, []):
|
||||||
|
# External ID - abort all requests in the external->internal mapping
|
||||||
|
internal_req_ids.extend(internal_ids)
|
||||||
|
|
||||||
|
request_ids_to_abort = []
|
||||||
|
for request_id in internal_req_ids:
|
||||||
req_state = self.request_states.pop(request_id, None)
|
req_state = self.request_states.pop(request_id, None)
|
||||||
if req_state is not None:
|
if req_state is not None:
|
||||||
self.lora_states.request_finished(request_id, req_state.lora_name)
|
self.lora_states.request_finished(request_id, req_state.lora_name)
|
||||||
@ -404,7 +442,7 @@ class OutputProcessor:
|
|||||||
# Abort children prior to removing the parent.
|
# Abort children prior to removing the parent.
|
||||||
if parent.child_requests:
|
if parent.child_requests:
|
||||||
child_reqs = list(parent.child_requests)
|
child_reqs = list(parent.child_requests)
|
||||||
child_reqs = self.abort_requests(child_reqs)
|
child_reqs = self.abort_requests(child_reqs, internal=True)
|
||||||
request_ids_to_abort.extend(child_reqs)
|
request_ids_to_abort.extend(child_reqs)
|
||||||
self.parent_requests.pop(request_id, None)
|
self.parent_requests.pop(request_id, None)
|
||||||
if not self.request_states:
|
if not self.request_states:
|
||||||
@ -439,6 +477,9 @@ class OutputProcessor:
|
|||||||
if parent_req:
|
if parent_req:
|
||||||
self.parent_requests[parent_req.request_id] = parent_req
|
self.parent_requests[parent_req.request_id] = parent_req
|
||||||
|
|
||||||
|
# Track the external_req_id -> [internal_req_id, ...] mapping
|
||||||
|
self.external_req_ids[req_state.external_req_id].append(request_id)
|
||||||
|
|
||||||
def process_outputs(
|
def process_outputs(
|
||||||
self,
|
self,
|
||||||
engine_core_outputs: list[EngineCoreOutput],
|
engine_core_outputs: list[EngineCoreOutput],
|
||||||
@ -522,6 +563,12 @@ class OutputProcessor:
|
|||||||
# Free completed requests.
|
# Free completed requests.
|
||||||
if finish_reason is not None:
|
if finish_reason is not None:
|
||||||
self.request_states.pop(req_id)
|
self.request_states.pop(req_id)
|
||||||
|
|
||||||
|
internal_ids = self.external_req_ids[req_state.external_req_id]
|
||||||
|
internal_ids.remove(req_id)
|
||||||
|
if not internal_ids:
|
||||||
|
del self.external_req_ids[req_state.external_req_id]
|
||||||
|
|
||||||
# Remove parent request if applicable.
|
# Remove parent request if applicable.
|
||||||
parent_req = req_state.parent_req
|
parent_req = req_state.parent_req
|
||||||
if parent_req and not parent_req.child_requests:
|
if parent_req and not parent_req.child_requests:
|
||||||
@ -597,7 +644,9 @@ class OutputProcessor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# meta
|
# meta
|
||||||
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, req_state.request_id)
|
span.set_attribute(
|
||||||
|
SpanAttributes.GEN_AI_REQUEST_ID, req_state.external_req_id
|
||||||
|
)
|
||||||
if req_state.top_p:
|
if req_state.top_p:
|
||||||
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
|
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
|
||||||
if req_state.max_tokens_param:
|
if req_state.max_tokens_param:
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from typing import Optional, cast
|
|||||||
|
|
||||||
from vllm.outputs import CompletionOutput
|
from vllm.outputs import CompletionOutput
|
||||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||||
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
from vllm.v1.metrics.stats import IterationStats
|
from vllm.v1.metrics.stats import IterationStats
|
||||||
|
|
||||||
|
|
||||||
@ -17,6 +18,7 @@ class ParentRequest:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
request_id: str
|
request_id: str
|
||||||
|
external_req_id: str
|
||||||
sampling_params: SamplingParams
|
sampling_params: SamplingParams
|
||||||
|
|
||||||
# To track the completion of child requests
|
# To track the completion of child requests
|
||||||
@ -31,8 +33,11 @@ class ParentRequest:
|
|||||||
# To efficiently obtain child sampling params
|
# To efficiently obtain child sampling params
|
||||||
cached_child_sampling_params: SamplingParams | None
|
cached_child_sampling_params: SamplingParams | None
|
||||||
|
|
||||||
def __init__(self, request_id: str, sampling_params: SamplingParams) -> None:
|
def __init__(self, request: EngineCoreRequest) -> None:
|
||||||
self.request_id = request_id
|
assert request.external_req_id is not None
|
||||||
|
sampling_params = request.params
|
||||||
|
self.request_id = request.request_id
|
||||||
|
self.external_req_id = request.external_req_id
|
||||||
self.sampling_params = sampling_params
|
self.sampling_params = sampling_params
|
||||||
|
|
||||||
self.child_requests = set()
|
self.child_requests = set()
|
||||||
@ -96,7 +101,7 @@ class ParentRequest:
|
|||||||
self,
|
self,
|
||||||
child_request_id: str,
|
child_request_id: str,
|
||||||
completion_output: CompletionOutput,
|
completion_output: CompletionOutput,
|
||||||
) -> tuple[str, list[CompletionOutput], bool]:
|
) -> tuple[list[CompletionOutput], bool]:
|
||||||
already_finished_and_returned: bool = False
|
already_finished_and_returned: bool = False
|
||||||
if completion_output.finished():
|
if completion_output.finished():
|
||||||
if child_request_id in self.child_requests:
|
if child_request_id in self.child_requests:
|
||||||
@ -118,7 +123,7 @@ class ParentRequest:
|
|||||||
outputs = [] if self.child_requests else self.output_aggregator
|
outputs = [] if self.child_requests else self.output_aggregator
|
||||||
|
|
||||||
finished = not self.child_requests
|
finished = not self.child_requests
|
||||||
return self.request_id, outputs, finished
|
return outputs, finished
|
||||||
|
|
||||||
def observe_num_generation_tokens(self, num_generation_tokens: int):
|
def observe_num_generation_tokens(self, num_generation_tokens: int):
|
||||||
self.max_num_generation_tokens = max(
|
self.max_num_generation_tokens = max(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user