[V1] [2/n] Logging and Metrics - OutputProcessor Abstraction (#11973)
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
parent d14e98d924
commit 619ae268c3
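For orientation, a condensed sketch (not itself part of this commit) of the flow the diffs below introduce: the V1 front-end Detokenizer class is replaced by an OutputProcessor that detokenizes EngineCoreOutputs, routes RequestOutputs either into per-request asyncio queues (AsyncLLM) or back to the caller (LLMEngine), and collects per-iteration stats for logging. The sketch mirrors what AsyncLLM._run_output_handler does after this change; the engine_core client and output_processor are assumed to be constructed as shown in the diffs.

from typing import List

from vllm.v1.engine import EngineCoreOutput
from vllm.v1.engine.output_processor import OutputProcessor


async def handle_one_iteration(engine_core, output_processor: OutputProcessor,
                               engine_core_outputs: List[EngineCoreOutput]):
    # Detokenize, update per-request state, and collect per-iteration stats.
    processed = output_processor.process_outputs(engine_core_outputs)

    # With per-request queues (AsyncLLM) the RequestOutputs were already
    # pushed into the queues, so this list is empty; without queues
    # (LLMEngine) it holds the RequestOutputs to hand back to the caller.
    request_outputs = processed.request_outputs

    # Requests that stopped on a stop string are finished here but not in
    # EngineCore, so EngineCore is told to abort them.
    await engine_core.abort_requests_async(processed.reqs_to_abort)

    # Prompt/generation token counts for this iteration feed the stat loggers.
    return request_outputs, processed.iteration_stats
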
@@ -1,5 +1,5 @@
 import asyncio
-from typing import Tuple
+from typing import List, Tuple

 import pytest

@@ -13,6 +13,7 @@ if not current_platform.is_cuda():
                 allow_module_level=True)

 ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B",
+                              enforce_eager=True,
                               disable_log_requests=True)


@@ -53,17 +54,63 @@ async def test_load(monkeypatch):
                     generate(engine, request_id, NUM_EXPECTED_TOKENS)))

        # Confirm that we got all the EXPECTED tokens from the requests.
-        failed_request_id = None
-        tokens = None
         for task in tasks:
             num_generated_tokens, request_id = await task
-            if (num_generated_tokens != NUM_EXPECTED_TOKENS
-                    and failed_request_id is None):
-                failed_request_id = request_id
-                tokens = num_generated_tokens
-
-        assert failed_request_id is None, (
-            f"{failed_request_id} generated {tokens} but "
-            f"expected {NUM_EXPECTED_TOKENS}")
+            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
+                f"{request_id} generated {num_generated_tokens} but "
+                f"expected {NUM_EXPECTED_TOKENS}")
+
+        assert not engine.output_processor.has_unfinished_requests()
+        engine.shutdown()
+
+
+@pytest.mark.asyncio
+async def test_abort(monkeypatch):
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
+
+        NUM_REQUESTS = 100
+        NUM_EXPECTED_TOKENS = 100
+        REQUEST_IDS_TO_ABORT = range(1, 100, 10)
+
+        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
+
+        # Create concurrent requests.
+        tasks: List[asyncio.Task] = []
+        for request_id in request_ids:
+            tasks.append(
+                asyncio.create_task(
+                    generate(engine, request_id, NUM_EXPECTED_TOKENS)))
+
+        # API server cancels requests when they disconnect.
+        for idx in REQUEST_IDS_TO_ABORT:
+            tasks[idx].cancel()
+            await asyncio.sleep(0.1)
+
+        # Confirm the other requests are okay.
+        for idx, task in enumerate(tasks):
+            # Confirm that it was actually canceled.
+            if idx in REQUEST_IDS_TO_ABORT:
+                with pytest.raises(asyncio.CancelledError):
+                    await task
+            else:
+                # Otherwise, make sure the request was not impacted.
+                num_generated_tokens, request_id = await task
+                assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
+                    f"{request_id} generated {num_generated_tokens} but "
+                    f"expected {NUM_EXPECTED_TOKENS}")
+
+        assert not engine.output_processor.has_unfinished_requests()
+
+        # Confirm we can do another generation.
+        request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
+        task = asyncio.create_task(
+            generate(engine, request_id, NUM_EXPECTED_TOKENS))
+        num_generated_tokens, request_id = await task
+        assert num_generated_tokens == NUM_EXPECTED_TOKENS
+        assert not engine.output_processor.has_unfinished_requests()

         engine.shutdown()
@@ -3,11 +3,18 @@ from typing import List
 import pytest
 from transformers import AutoTokenizer

+from vllm.engine.arg_utils import EngineArgs
 from vllm.sampling_params import RequestOutputKind, SamplingParams
+from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
-from vllm.v1.engine.detokenizer import Detokenizer
+from vllm.v1.engine.output_processor import OutputProcessor

 TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
+VLLM_CONFIG = EngineArgs(model=TOKENIZER_NAME).create_engine_config()
+TOKENIZER_GROUP = init_tokenizer_from_configs(VLLM_CONFIG.model_config,
+                                              VLLM_CONFIG.scheduler_config,
+                                              VLLM_CONFIG.parallel_config,
+                                              VLLM_CONFIG.lora_config)
 tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

 FULL_STRINGS = [
@@ -66,7 +73,7 @@ class MockEngineCore:
     "request_output_kind",
     [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
 def test_incremental_detokenization(request_output_kind: RequestOutputKind):
-    detokenizer = Detokenizer(TOKENIZER_NAME)
+    output_processor = OutputProcessor(TOKENIZER_GROUP, log_stats=False)
     engine_core = MockEngineCore(GENERATION_TOKENS)

     # Make N requests.
@@ -93,7 +100,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind):

     # Add requests to the detokenizer.
     for request in requests:
-        detokenizer.add_request(request)
+        output_processor.add_request(request)

     gen_strings = {}
     gen_tokens = {}
@@ -104,7 +111,9 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind):
             break

         # Step the Detokenizer.
-        request_outputs, requests_to_abort = detokenizer.step(outputs)
+        processed_outputs = output_processor.process_outputs(outputs, )
+        request_outputs = processed_outputs.request_outputs
+        requests_to_abort = processed_outputs.reqs_to_abort
         assert len(requests_to_abort) == 0

         # Update tracking.
@@ -128,13 +137,13 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind):
         assert gen_str == ref_gen_str, f"{gen_str=}, {ref_gen_str=}"
         assert gen_toks == ref_gen_toks, f"{gen_toks=}, {ref_gen_toks=}"

-    assert detokenizer.get_num_unfinished_requests() == 0
-    assert not detokenizer.has_unfinished_requests()
+    assert output_processor.get_num_unfinished_requests() == 0
+    assert not output_processor.has_unfinished_requests()


 @pytest.mark.parametrize("include_stop_str_in_output", [True, False])
 def test_stop_string(include_stop_str_in_output: bool):
-    detokenizer = Detokenizer(TOKENIZER_NAME)
+    output_processor = OutputProcessor(TOKENIZER_GROUP, log_stats=False)
     engine_core = MockEngineCore(GENERATION_TOKENS)

     # Make N requests.
@@ -162,7 +171,7 @@ def test_stop_string(include_stop_str_in_output: bool):

     # Add requests to the detokenizer.
     for request in requests:
-        detokenizer.add_request(request)
+        output_processor.add_request(request)

     gen_strings = {}
     aborted = []
@@ -173,7 +182,9 @@ def test_stop_string(include_stop_str_in_output: bool):
             break

         # Step the Detokenizer.
-        request_outputs, requests_to_abort = detokenizer.step(outputs)
+        processed_outputs = output_processor.process_outputs(outputs)
+        request_outputs = processed_outputs.request_outputs
+        requests_to_abort = processed_outputs.reqs_to_abort
         for request_output in request_outputs:
             # If aborted, we should not get a request output.
             assert request_output.request_id not in aborted
@@ -214,5 +225,71 @@ def test_stop_string(include_stop_str_in_output: bool):
             assert gen_str == ref_str_exc_stop, (
                 f"{gen_str=}, {ref_str_exc_stop=}")

-    assert detokenizer.get_num_unfinished_requests() == 0
-    assert not detokenizer.has_unfinished_requests()
+    assert output_processor.get_num_unfinished_requests() == 0
+    assert not output_processor.has_unfinished_requests()
+
+
+def test_iteration_stats():
+    output_processor = OutputProcessor(TOKENIZER_GROUP, log_stats=True)
+    engine_core = MockEngineCore(GENERATION_TOKENS)
+
+    # Make N requests.
+    requests = [
+        EngineCoreRequest(
+            request_id=f"request-{idx}",
+            prompt=prompt,
+            prompt_token_ids=prompt_tokens,
+            arrival_time=0,
+            mm_inputs=None,
+            mm_hashes=None,
+            mm_placeholders=None,
+            eos_token_id=None,
+            lora_request=None,
+            sampling_params=SamplingParams(),
+        ) for idx, (
+            prompt,
+            prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
+    ]
+
+    # Add all requests except one to the OutputProcessor.
+    num_active = len(GENERATION_TOKENS) - 1
+    for request in requests[:num_active]:
+        output_processor.add_request(request)
+    inactive_request = requests[num_active]
+
+    # First iteration has 2 prefills.
+    outputs = engine_core.get_outputs()[:num_active]
+    processed_outputs = output_processor.process_outputs(outputs)
+    iteration_stats = processed_outputs.iteration_stats
+    total_prompt_tokens = sum(
+        [len(prompt_tokens) for prompt_tokens in PROMPT_TOKENS[:num_active]])
+
+    assert iteration_stats.num_prompt_tokens == total_prompt_tokens
+    assert iteration_stats.num_generation_tokens == num_active
+
+    # Just decodes in this step.
+    outputs = engine_core.get_outputs()[:num_active]
+    processed_outputs = output_processor.process_outputs(outputs)
+    iteration_stats = processed_outputs.iteration_stats
+
+    assert iteration_stats.num_prompt_tokens == 0
+    assert iteration_stats.num_generation_tokens == num_active
+
+    # Add a new request - prefill and 2 decodes in this step.
+    output_processor.add_request(inactive_request)
+    num_active += 1
+    outputs = engine_core.get_outputs()[:num_active]
+    processed_outputs = output_processor.process_outputs(outputs)
+    iteration_stats = processed_outputs.iteration_stats
+    total_prompt_tokens = len(PROMPT_TOKENS[num_active - 1])
+
+    assert iteration_stats.num_prompt_tokens == total_prompt_tokens
+    assert iteration_stats.num_generation_tokens == num_active
+
+    # Just decodes in this step.
+    outputs = engine_core.get_outputs()[:num_active]
+    processed_outputs = output_processor.process_outputs(outputs)
+    iteration_stats = processed_outputs.iteration_stats
+
+    assert iteration_stats.num_prompt_tokens == 0
+    assert iteration_stats.num_generation_tokens == num_active
@@ -1,6 +1,6 @@
 import asyncio
 import os
-from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union
+from typing import AsyncGenerator, List, Mapping, Optional, Type, Union

 from vllm.config import ModelConfig, VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -18,11 +18,11 @@ from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import kill_process_tree
 from vllm.v1.engine.core_client import EngineCoreClient
-from vllm.v1.engine.detokenizer import Detokenizer
+from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.metrics.loggers import LoggingStatLogger, StatLoggerBase
-from vllm.v1.metrics.stats import SchedulerStats
+from vllm.v1.metrics.stats import IterationStats, SchedulerStats

 logger = init_logger(__name__)

@@ -59,9 +59,6 @@ class AsyncLLM(EngineClient):
             lora_config=vllm_config.lora_config)
         self.tokenizer.ping()

-        # Request streams (map of request_id -> queue).
-        self.rid_to_queue: Dict[str, asyncio.Queue] = {}
-
         # Processor (converts Inputs --> EngineCoreRequests).
         self.processor = Processor(
             model_config=vllm_config.model_config,
@@ -71,13 +68,9 @@ class AsyncLLM(EngineClient):
             input_registry=input_registry,
         )

-        # Detokenizer (converts EngineCoreOutputs --> RequestOutput).
-        self.detokenizer = Detokenizer(
-            tokenizer_name=vllm_config.model_config.tokenizer,
-            tokenizer_mode=vllm_config.model_config.tokenizer_mode,
-            trust_remote_code=vllm_config.model_config.trust_remote_code,
-            revision=vllm_config.model_config.tokenizer_revision,
-        )
+        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
+        self.output_processor = OutputProcessor(self.tokenizer,
+                                                log_stats=self.log_stats)

         # EngineCore (starts the engine in background process).
         self.engine_core = EngineCoreClient.make_client(
@@ -140,9 +133,9 @@ class AsyncLLM(EngineClient):
         """Add new request to the AsyncLLM."""

         # 1) Create a new output queue for the request.
-        if request_id in self.rid_to_queue:
+        if self.output_processor.is_request_active(request_id):
             raise ValueError(f"Request id {request_id} already running.")
-        self.rid_to_queue[request_id] = asyncio.Queue()
+        queue: asyncio.Queue[RequestOutput] = asyncio.Queue()

         # 2) Convert Input --> Request.
         request = self.processor.process_inputs(request_id, prompt, params,
@@ -151,8 +144,8 @@ class AsyncLLM(EngineClient):
                                                 prompt_adapter_request,
                                                 priority)

-        # 3) Add the request to Detokenizer (this process).
-        self.detokenizer.add_request(request)
+        # 3) Add the request to OutputProcessor (this process).
+        self.output_processor.add_request(request, queue)

         # 4) Add the EngineCoreRequest to EngineCore (separate process).
         await self.engine_core.add_request_async(request)
@@ -160,7 +153,7 @@ class AsyncLLM(EngineClient):
         if self.log_requests:
             logger.info("Added request %s.", request_id)

-        return self.rid_to_queue[request_id]
+        return queue

     # TODO: we should support multiple prompts in one call, as you
     # can do with LLM.generate. So that for multi-prompt completion
@@ -217,10 +210,9 @@ class AsyncLLM(EngineClient):
                 # task switching under load which helps performance).
                 out = q.get_nowait() if q.qsize() > 0 else await q.get()

-                # Note: both Detokenizer and EngineCore handle their
+                # Note: both OutputProcessor and EngineCore handle their
                 # own request cleanup based on finished.
                 if out.finished:
-                    del self.rid_to_queue[request_id]
                     yield out
                     break

@@ -233,57 +225,51 @@ class AsyncLLM(EngineClient):
             await self.abort(request_id)
             raise

-    def _process_request_outputs(self, request_outputs: List[RequestOutput]):
-        """Process outputs by putting them into per-request queues."""
-
-        for request_output in request_outputs:
-            request_id = request_output.request_id
-
-            # Note: it is possible a request was aborted and removed from
-            # the state due to client cancellations, so if we encounter a
-            # request id not in the state, we skip.
-            if request_id in self.rid_to_queue:
-                self.rid_to_queue[request_id].put_nowait(request_output)
-
     async def _run_output_handler(self):
         """Background loop: pulls from EngineCore and pushes to AsyncStreams."""

         try:
             while True:
-                # 1) Pull EngineCoreOutput from the EngineCore.
+                # 1) Pull EngineCoreOutputs from the EngineCore.
                 outputs = await self.engine_core.get_output_async()

-                # 2) Detokenize based on the output.
-                request_outputs, reqs_to_abort = self.detokenizer.step(
+                # 2) Process EngineCoreOutputs.
+                processed_outputs = self.output_processor.process_outputs(
                     outputs.outputs)
+                # NOTE: RequestOutputs are pushed to their queues.
+                assert len(processed_outputs.request_outputs) == 0

-                # 3) Put the RequestOutputs into the per-request queues.
-                self._process_request_outputs(request_outputs)
+                # 3) Abort any reqs that finished due to stop strings.
+                await self.engine_core.abort_requests_async(
+                    processed_outputs.reqs_to_abort)

-                # 4) Abort any requests that finished due to stop strings.
-                await self.engine_core.abort_requests_async(reqs_to_abort)
-
-                # 5) Log any stats.
-                await self._log_stats(scheduler_stats=outputs.scheduler_stats)
+                # 4) Logging.
+                # TODO(rob): make into a coroutine and launch it in
+                # background thread once we add Prometheus.
+                self._log_stats(
+                    scheduler_stats=outputs.scheduler_stats,
+                    iteration_stats=processed_outputs.iteration_stats,
+                )

         except Exception as e:
             logger.exception("EngineCore output handler hit an error: %s", e)
             kill_process_tree(os.getpid())

     async def abort(self, request_id: str) -> None:
-        """Abort RequestId in self, detokenizer, and engine core."""
+        """Abort RequestId in OutputProcessor and EngineCore."""

         request_ids = [request_id]
         await self.engine_core.abort_requests_async(request_ids)
-        self.detokenizer.abort_requests(request_ids)
+        self.output_processor.abort_requests(request_ids)

-        # If a request finishes while we await then the request_id
-        # will be removed from the tracked queues before we get here.
-        if request_id in self.rid_to_queue:
-            del self.rid_to_queue[request_id]
+        if self.log_requests:
+            logger.info("Aborted request %s.", request_id)

-    async def _log_stats(self, scheduler_stats: SchedulerStats):
-        """Log stats to the stat loggers."""
+    def _log_stats(
+        self,
+        scheduler_stats: SchedulerStats,
+        iteration_stats: IterationStats,
+    ):
         if not self.log_stats:
             return

@@ -314,8 +300,7 @@ class AsyncLLM(EngineClient):
         self,
         lora_request: Optional[LoRARequest] = None,
     ) -> AnyTokenizer:
-        assert lora_request is None
-        return self.detokenizer.tokenizer
+        return self.tokenizer.get_lora_tokenizer(lora_request)

     async def is_tracing_enabled(self) -> bool:
         return False
@@ -105,7 +105,8 @@ class InprocClient(EngineCoreClient):
         self.engine_core.add_request(request)

     def abort_requests(self, request_ids: List[str]) -> None:
-        self.engine_core.abort_requests(request_ids)
+        if len(request_ids) > 0:
+            self.engine_core.abort_requests(request_ids)

     def shutdown(self):
         self.engine_core.shutdown()
@@ -221,7 +222,8 @@ class SyncMPClient(MPClient):
         self._send_input(EngineCoreRequestType.ADD, request)

     def abort_requests(self, request_ids: List[str]) -> None:
-        self._send_input(EngineCoreRequestType.ABORT, request_ids)
+        if len(request_ids) > 0:
+            self._send_input(EngineCoreRequestType.ABORT, request_ids)

     def profile(self, is_start: bool = True) -> None:
         self._send_input(EngineCoreRequestType.PROFILE,
@@ -1,18 +1,25 @@
 from dataclasses import dataclass
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import List, Optional, Union

 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.logger import init_logger
-from vllm.outputs import RequestOutput
 from vllm.sampling_params import RequestOutputKind
 from vllm.transformers_utils.detokenizer_utils import (
     AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
-from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest

 logger = init_logger(__name__)


+@dataclass
+class DetokenizerOutput:
+    output_text: str
+    token_ids: List[int]
+    finished: bool
+    finish_reason: Optional[str] = None
+    stop_reason: Union[int, str, None] = None
+
+
 @dataclass
 class IncrementalDetokenizer:

@@ -20,6 +27,7 @@ class IncrementalDetokenizer:
     output_text: str
     tokens: List[str]
     token_ids: List[int]
+    prompt_len: int

     # Stop strings
     stop: List[str]
@@ -34,11 +42,6 @@ class IncrementalDetokenizer:
     spaces_between_special_tokens: bool
     output_kind: RequestOutputKind

-    # TODO: Probably decouple these
-    request_id: str
-    prompt: Optional[str]
-    prompt_token_ids: List[int]
-
     # Tokenizer for this request
     tokenizer: AnyTokenizer

@@ -48,8 +51,7 @@ class IncrementalDetokenizer:

     @property
     def output_token_ids(self) -> List[int]:
-        assert len(self.token_ids) >= len(self.prompt_token_ids)
-        return self.token_ids[len(self.prompt_token_ids):]
+        return self.token_ids[self.prompt_len:]

     @classmethod
     def from_new_request(
@@ -87,25 +89,25 @@ class IncrementalDetokenizer:
             spaces_between_special_tokens=request.sampling_params.
             spaces_between_special_tokens,
             output_kind=request.sampling_params.output_kind,
-            request_id=request.request_id,
-            prompt=request.prompt,
-            prompt_token_ids=request.prompt_token_ids,
+            prompt_len=len(request.prompt_token_ids),
             tokenizer=tokenizer,
             stop_buffer_length=stop_buffer_length,
         )

-    def add_tokens(
+    def update_from_output(
         self,
-        new_token_ids: List[int],
-        finish_reason: Optional[str],
-        stop_reason: Optional[Union[int, str, None]],
-    ) -> Optional[RequestOutput]:
+        output: EngineCoreOutput,
+    ) -> Optional[DetokenizerOutput]:
         """
         Update RequestState for the request_id by:
            1) Detokenize the new token ids incrementally.
            2) Update the RequestOutput with the new text.
        """

+        new_token_ids = output.new_token_ids
+        finish_reason = output.finish_reason
+        stop_reason = output.stop_reason
+
         # 1) Detokenize the new token ids incrementally.
         # TODO(woosuk): This method becomes very inefficient when the number of
         # new_token_ids is more than 1. We need to optimize this.
@@ -158,21 +160,8 @@ class IncrementalDetokenizer:
         output_text = self._get_next_output_text(finished, delta)
         token_ids = new_token_ids if delta else self.output_token_ids

-        request_output = RequestOutput.new(
-            self.request_id,
-            self.prompt,
-            self.prompt_token_ids,
-            output_text,
-            token_ids,
-            finished,
-        )
-
-        if finished:
-            completion_output = request_output.outputs[0]
-            completion_output.finish_reason = finish_reason
-            completion_output.stop_reason = stop_reason
-
-        return request_output
+        return DetokenizerOutput(output_text, token_ids, finished,
+                                 finish_reason, stop_reason)

     def _get_next_output_text(self, finished: bool, delta: bool) -> str:
         """If delta is True, only new text since the last call to
@@ -189,85 +178,3 @@ class IncrementalDetokenizer:
             self._last_output_text_offset = length
             return self.output_text[last_offset:length]
         return ""
-
-
-class Detokenizer:
-
-    def __init__(self,
-                 tokenizer_name: str,
-                 tokenizer_mode: str = "auto",
-                 trust_remote_code: bool = False,
-                 revision: Optional[str] = None):
-        # TODO: once we support LoRA, we should should pass the tokenizer
-        # here. We currently have two copies (this + in the LLMEngine).
-        self.tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
-                                       tokenizer_mode=tokenizer_mode,
-                                       trust_remote_code=trust_remote_code,
-                                       revision=revision)
-
-        # Request id -> IncrementalDetokenizer
-        self.request_states: Dict[str, IncrementalDetokenizer] = {}
-
-    def is_request_active(self, request_id: str):
-        return request_id in self.request_states
-
-    def get_num_unfinished_requests(self):
-        return len(self.request_states)
-
-    def has_unfinished_requests(self) -> bool:
-        return len(self.request_states) > 0
-
-    def abort_requests(
-        self,
-        request_ids: Iterable[str],
-    ) -> None:
-        """Remove the request_ids from the Detokenizer."""
-
-        for request_id in request_ids:
-            self.request_states.pop(request_id, None)
-
-    def add_request(
-        self,
-        request: EngineCoreRequest,
-    ):
-        """Add new request to the Detokenizer."""
-
-        assert (request.request_id not in self.request_states)
-
-        request_state = IncrementalDetokenizer.from_new_request(
-            self.tokenizer, request)
-        self.request_states[request.request_id] = request_state
-
-    def step(
-        self, encore_core_outputs: List[EngineCoreOutput]
-    ) -> Tuple[List[RequestOutput], List[str]]:
-        """Update state and request the RequestOutputs to the LLMEngine."""
-
-        request_outputs: List[RequestOutput] = []
-        requests_to_abort: List[str] = []
-        for engine_core_output in encore_core_outputs:
-            request_id = engine_core_output.request_id
-            detokenizer = self.request_states.get(request_id)
-            if detokenizer is None:
-                # Ignore output for already-aborted request.
-                continue
-
-            # Detokenize and update state.
-            request_output = detokenizer.add_tokens(
-                new_token_ids=engine_core_output.new_token_ids,
-                finish_reason=engine_core_output.finish_reason,
-                stop_reason=engine_core_output.stop_reason,
-            )
-
-            if request_output is not None:
-                # Add to RequestOutputs list.
-                request_outputs.append(request_output)
-
-                # Free completed requests.
-                if request_output.finished:
-                    self.request_states.pop(request_id)
-                    if not engine_core_output.finished:
-                        requests_to_abort.append(request_id)
-
-        # Return to EngineClient.
-        return request_outputs, requests_to_abort
@@ -18,7 +18,7 @@ from vllm.transformers_utils.tokenizer_group import (
     BaseTokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.engine.core_client import EngineCoreClient
-from vllm.v1.engine.detokenizer import Detokenizer
+from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor

@@ -60,13 +60,9 @@ class LLMEngine:
             input_registry=input_registry,
             mm_registry=mm_registry)

-        # Detokenizer (converts EngineCoreOutputs --> RequestOutput)
-        self.detokenizer = Detokenizer(
-            tokenizer_name=vllm_config.model_config.tokenizer,
-            tokenizer_mode=vllm_config.model_config.tokenizer_mode,
-            trust_remote_code=vllm_config.model_config.trust_remote_code,
-            revision=vllm_config.model_config.tokenizer_revision,
-        )
+        # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
+        self.output_processor = OutputProcessor(self.tokenizer,
+                                                log_stats=False)

         # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
         self.engine_core = EngineCoreClient.make_client(
@@ -103,10 +99,10 @@ class LLMEngine:
             multiprocess_mode=enable_multiprocessing)

     def get_num_unfinished_requests(self) -> int:
-        return self.detokenizer.get_num_unfinished_requests()
+        return self.output_processor.get_num_unfinished_requests()

     def has_unfinished_requests(self) -> bool:
-        return self.detokenizer.has_unfinished_requests()
+        return self.output_processor.has_unfinished_requests()

     @classmethod
     def validate_outputs(cls, outputs, output_type):
@@ -116,7 +112,7 @@ class LLMEngine:
         """Remove request_ids from EngineCore and Detokenizer."""

         self.engine_core.abort_requests(request_ids)
-        self.detokenizer.abort_requests(request_ids)
+        self.output_processor.abort_requests(request_ids)

     def add_request(
         self,
@@ -137,8 +133,8 @@ class LLMEngine:
                                                 prompt_adapter_request,
                                                 priority)

-        # 2) Add the request to Detokenizer.
-        self.detokenizer.add_request(request)
+        # 2) Make a new RequestState and queue.
+        self.output_processor.add_request(request)

         # 3) Add the request to EngineCore.
         self.engine_core.add_request(request)
@@ -148,15 +144,14 @@ class LLMEngine:
         # 1) Get EngineCoreOutput from the EngineCore.
         outputs = self.engine_core.get_output()

-        # 2) Detokenizer the EngineCoreOutput.
-        request_outputs, requests_to_abort = self.detokenizer.step(
+        # 2) Process EngineCoreOutputs.
+        processed_outputs = self.output_processor.process_outputs(
             outputs.outputs)

-        # 3) Abort requests that finished due to stopping criteria.
-        if requests_to_abort:
-            self.abort_request(requests_to_abort)
+        # 3) Abort any reqs that finished due to stop strings.
+        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

-        return request_outputs
+        return processed_outputs.request_outputs

     def get_model_config(self):
         return self.model_config
vllm/v1/engine/output_processor.py (new file, 200 lines)
@@ -0,0 +1,200 @@
+import asyncio
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+from vllm.outputs import RequestOutput
+from vllm.transformers_utils.detokenizer_utils import AnyTokenizer
+from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
+from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
+from vllm.v1.engine.detokenizer import (DetokenizerOutput,
+                                        IncrementalDetokenizer)
+from vllm.v1.metrics.stats import IterationStats
+
+
+@dataclass
+class OutputProcessorOutput:
+
+    request_outputs: List[RequestOutput]
+    reqs_to_abort: List[str]
+    iteration_stats: IterationStats
+
+
+class RequestState:
+
+    def __init__(
+        self,
+        request_id: str,
+        prompt: Optional[str],
+        prompt_token_ids: List[int],
+        detokenizer: IncrementalDetokenizer,
+        queue: Optional[asyncio.Queue[RequestOutput]],
+    ):
+        self.request_id = request_id
+        self.prompt = prompt
+        self.prompt_token_ids = prompt_token_ids
+        self.prompt_len = len(prompt_token_ids)
+        self.detokenizer = detokenizer
+        self.is_prefilling = True
+        self.queue = queue
+
+    @classmethod
+    def from_new_request(
+        cls,
+        tokenizer: AnyTokenizer,
+        request: EngineCoreRequest,
+        queue: Optional[asyncio.Queue[RequestOutput]] = None,
+    ) -> "RequestState":
+        return cls(
+            request_id=request.request_id,
+            prompt=request.prompt,
+            prompt_token_ids=request.prompt_token_ids,
+            detokenizer=IncrementalDetokenizer.from_new_request(
+                tokenizer=tokenizer,
+                request=request,
+            ),
+            queue=queue,
+        )
+
+
+class OutputProcessor:
+    """Process EngineCoreOutputs into RequestOutputs."""
+
+    def __init__(
+        self,
+        tokenizer: BaseTokenizerGroup,
+        log_stats: bool,
+    ):
+        self.log_stats = log_stats
+        self.tokenizer = tokenizer
+        self.request_states: Dict[str, RequestState] = {}
+
+    def is_request_active(self, request_id: str) -> bool:
+        return request_id in self.request_states
+
+    def get_num_unfinished_requests(self):
+        return len(self.request_states)
+
+    def has_unfinished_requests(self) -> bool:
+        return len(self.request_states) > 0
+
+    def abort_requests(
+        self,
+        request_ids: List[str],
+    ) -> None:
+        for request_id in request_ids:
+            self.request_states.pop(request_id, None)
+
+    def add_request(
+        self,
+        request: EngineCoreRequest,
+        queue: Optional[asyncio.Queue[RequestOutput]] = None,
+    ) -> None:
+        request_id = request.request_id
+        if request_id in self.request_states:
+            raise ValueError(f"Request id {request_id} already running.")
+
+        self.request_states[request_id] = RequestState.from_new_request(
+            tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
+            request=request,
+            queue=queue)
+
+    def process_outputs(
+        self,
+        engine_core_outputs: List[EngineCoreOutput],
+    ) -> OutputProcessorOutput:
+        """
+        Process the EngineCoreOutputs:
+        1) Compute stats for logging
+        2) Detokenize
+        3) Create and handle RequestOutput objects:
+            * If there is a queue (for usage with AsyncLLM),
+              put the RequestOutput objects into the queue for
+              handling by the per-request generate() tasks.
+
+            * If there is no queue (for usage with LLMEngine),
+              return a list of RequestOutput objects.
+
+        ****************** NOTE FOR DEVELOPERS ******************
+
+        VLLM V1 minimizes the number of python loops over the full
+        batch to ensure system overheads are minimized. This is the
+        only function that should loop over EngineCoreOutputs.
+
+        If you need to touch every element of the batch, implement a
+        method called XXXClass.update_from_output() to be called
+        within the loop below. For examples, see:
+            * IterationStats.update_from_output()
+            * Detokenizer.update_from_output()
+
+        TODO(rob): add Protocol makes update_from_output explicit.
+
+        **********************************************************
+        """
+
+        request_outputs: List[RequestOutput] = []
+        reqs_to_abort: List[str] = []
+        iteration_stats = IterationStats(self.log_stats)
+        for engine_core_output in engine_core_outputs:
+            req_id = engine_core_output.request_id
+            req_state = self.request_states.get(req_id)
+            if req_state is None:
+                # Ignore output for already-aborted request.
+                continue
+
+            # 1) Compute stats for this iteration.
+            iteration_stats.update_from_output(engine_core_output,
+                                               req_state.is_prefilling,
+                                               req_state.prompt_len)
+            req_state.is_prefilling = False
+
+            # 2) Detokenize the token ids into text.
+            detokenizer_output = req_state.detokenizer.update_from_output(
+                engine_core_output)
+
+            # 3) Create and handle RequestOutput objects.
+            if request_output := self._make_request_output(
+                    req_state, detokenizer_output):
+                if req_state.queue is not None:
+                    # AsyncLLM: put into queue for handling by generate().
+                    req_state.queue.put_nowait(request_output)
+                else:
+                    # LLMEngine: return list of RequestOutputs.
+                    request_outputs.append(request_output)
+
+                # Free completed requests.
+                if request_output.finished:
+                    self.request_states.pop(req_id)
+                    if not engine_core_output.finished:
+                        # If req not finished in EngineCore, but Detokenizer
+                        # detected stop string, abort needed in EngineCore.
+                        reqs_to_abort.append(req_id)
+
+        return OutputProcessorOutput(
+            request_outputs=request_outputs,
+            reqs_to_abort=reqs_to_abort,
+            iteration_stats=iteration_stats,
+        )
+
+    def _make_request_output(
+        self,
+        request_state: RequestState,
+        detokenizer_output: Optional[DetokenizerOutput],
+    ) -> Optional[RequestOutput]:
+
+        if detokenizer_output is None:
+            return None
+
+        request_output = RequestOutput.new(
+            request_state.request_id,
+            request_state.prompt,
+            request_state.prompt_token_ids,
+            detokenizer_output.output_text,
+            detokenizer_output.token_ids,
+            detokenizer_output.finished,
+        )
+        if detokenizer_output.finished:
+            completion_output = request_output.outputs[0]
+            completion_output.finish_reason = detokenizer_output.finish_reason
+            completion_output.stop_reason = detokenizer_output.stop_reason
+
+        return request_output
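The docstring of process_outputs above distinguishes the queue-backed path (AsyncLLM pushes RequestOutputs into per-request asyncio queues) from the list-returning path (LLMEngine gets the RequestOutputs back). A minimal usage sketch of the list-returning path, mirroring this commit's test setup: the tokenizer group construction is copied from the test file (it downloads the tokenizer config, so it needs access to that Hugging Face repo), the EngineCoreRequest fields are the same ones the test passes, and the engine-core output is replaced by a SimpleNamespace stand-in carrying only the fields process_outputs reads, to avoid assuming the EngineCoreOutput constructor.

from types import SimpleNamespace

from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import OutputProcessor

config = EngineArgs(
    model="mistralai/Mistral-7B-Instruct-v0.3").create_engine_config()
tokenizer_group = init_tokenizer_from_configs(config.model_config,
                                              config.scheduler_config,
                                              config.parallel_config,
                                              config.lora_config)

processor = OutputProcessor(tokenizer_group, log_stats=True)
processor.add_request(                      # no queue -> LLMEngine-style path
    EngineCoreRequest(request_id="request-0",
                      prompt="Hello",
                      prompt_token_ids=[1, 22557],
                      arrival_time=0,
                      mm_inputs=None,
                      mm_hashes=None,
                      mm_placeholders=None,
                      eos_token_id=None,
                      lora_request=None,
                      sampling_params=SamplingParams()))

# Stand-in for one EngineCoreOutput (one sampled token, request not finished).
core_output = SimpleNamespace(request_id="request-0",
                              new_token_ids=[28705],
                              finished=False,
                              finish_reason=None,
                              stop_reason=None)

processed = processor.process_outputs([core_output])
print(processed.request_outputs)   # RequestOutputs returned to the caller
print(processed.reqs_to_abort)     # requests EngineCore still has to abort
print(processed.iteration_stats.num_prompt_tokens,
      processed.iteration_stats.num_generation_tokens)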
@@ -1,4 +1,8 @@
 from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from vllm.v1.engine import EngineCoreOutput


 @dataclass
@@ -10,3 +14,26 @@ class SchedulerStats:

     # gpu_cache_usage: float = 0.0
     # gpu_prefix_cache_hit_rate: float = 0.0
+
+
+class IterationStats:
+    """Stats associated with a single set of EngineCoreOutputs."""
+
+    def __init__(self, log_stats: bool):
+        self.log_stats = log_stats
+        self.num_generation_tokens = 0
+        self.num_prompt_tokens = 0
+
+    def update_from_output(self, output: "EngineCoreOutput",
+                           is_prefilling: bool, prompt_len: int):
+        if not self.log_stats:
+            return
+
+        self.num_generation_tokens += len(output.new_token_ids)
+        if is_prefilling:
+            # This relies on the invariant that EngineCore does
+            # not stream outputs for partially completed prefills
+            # (scheduler.update_from_output makes EngineCoreOutput
+            # iff num_computed_tokens == num_tokens).
+            assert (len(output.new_token_ids) > 0)
+            self.num_prompt_tokens += prompt_len
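The bookkeeping added above is easy to check by hand: on a request's first output (its prefill step) the whole prompt length is added to num_prompt_tokens and the freshly sampled tokens to num_generation_tokens; on later (decode) outputs only generation tokens accumulate. A tiny illustration, using a SimpleNamespace stand-in for EngineCoreOutput since update_from_output only reads new_token_ids:

from types import SimpleNamespace

from vllm.v1.metrics.stats import IterationStats

stats = IterationStats(log_stats=True)
out = SimpleNamespace(new_token_ids=[42])   # one sampled token per output

# Prefill step for a request with a 5-token prompt.
stats.update_from_output(out, is_prefilling=True, prompt_len=5)
assert (stats.num_prompt_tokens, stats.num_generation_tokens) == (5, 1)

# Subsequent decode step for the same request: only generation tokens grow.
stats.update_from_output(out, is_prefilling=False, prompt_len=5)
assert (stats.num_prompt_tokens, stats.num_generation_tokens) == (5, 2)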