From 6ace6fba2ca42b79a948a9b47af00487b5f73868 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Mon, 11 Nov 2024 18:05:38 -0500 Subject: [PATCH] [V1] `AsyncLLM` Implementation (#9826) Signed-off-by: Nick Hill Signed-off-by: rshaw@neuralmagic.com Signed-off-by: Nick Hill Co-authored-by: Nick Hill Co-authored-by: Varun Sundar Rabindranath Co-authored-by: Nick Hill Co-authored-by: Tyler Michael Smith --- .buildkite/test-pipeline.yaml | 8 + tests/entrypoints/llm/test_accuracy.py | 56 ++ tests/entrypoints/openai/test_accuracy.py | 25 +- {vllm/v1/tokenizer => tests/v1}/__init__.py | 0 tests/v1/engine/__init__.py | 0 tests/v1/engine/test_async_llm.py | 66 +++ tests/v1/engine/test_detokenizer.py | 205 +++++++ tests/v1/engine/test_engine_core.py | 137 +++++ tests/v1/engine/test_engine_core_client.py | 202 +++++++ vllm/config.py | 41 ++ vllm/engine/multiprocessing/engine.py | 13 +- vllm/engine/output_processor/stop_checker.py | 43 +- vllm/entrypoints/llm.py | 3 + vllm/entrypoints/openai/api_server.py | 11 +- vllm/envs.py | 5 + vllm/outputs.py | 30 + vllm/v1/__init__.py | 0 vllm/v1/core/kv_cache_manager.py | 13 +- vllm/v1/core/scheduler.py | 26 +- vllm/v1/engine/__init__.py | 72 +++ vllm/v1/engine/async_llm.py | 368 +++++++++++++ vllm/v1/engine/async_stream.py | 55 ++ vllm/v1/engine/core.py | 352 ++++++++++++ vllm/v1/engine/core_client.py | 218 ++++++++ vllm/v1/engine/detokenizer.py | 265 +++++++++ vllm/v1/engine/llm_engine.py | 546 ++++--------------- vllm/v1/engine/processor.py | 128 +++++ vllm/v1/request.py | 17 +- vllm/v1/tokenizer/detokenizer.py | 228 -------- 29 files changed, 2409 insertions(+), 724 deletions(-) create mode 100644 tests/entrypoints/llm/test_accuracy.py rename {vllm/v1/tokenizer => tests/v1}/__init__.py (100%) create mode 100644 tests/v1/engine/__init__.py create mode 100644 tests/v1/engine/test_async_llm.py create mode 100644 tests/v1/engine/test_detokenizer.py create mode 100644 tests/v1/engine/test_engine_core.py create mode 100644 tests/v1/engine/test_engine_core_client.py create mode 100644 vllm/v1/__init__.py create mode 100644 vllm/v1/engine/async_llm.py create mode 100644 vllm/v1/engine/async_stream.py create mode 100644 vllm/v1/engine/core.py create mode 100644 vllm/v1/engine/core_client.py create mode 100644 vllm/v1/engine/detokenizer.py create mode 100644 vllm/v1/engine/processor.py delete mode 100644 vllm/v1/tokenizer/detokenizer.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index e8456357e6db1..fbaa427bb7270 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -165,6 +165,14 @@ steps: # OOM in the CI unless we run this separately - pytest -v -s tokenization +- label: V1 Test + #mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/v1 + commands: + - pytest -v -s v1 + - label: Examples Test # 15min working_dir: "/vllm-workspace/examples" #mirror_hardwares: [amd] diff --git a/tests/entrypoints/llm/test_accuracy.py b/tests/entrypoints/llm/test_accuracy.py new file mode 100644 index 0000000000000..6bf7190a656b8 --- /dev/null +++ b/tests/entrypoints/llm/test_accuracy.py @@ -0,0 +1,56 @@ +""" +This file test accuracy of the vLLM server via LMEval. +It uses local-completions, which interacts with vLLM +through the OAI API with N concurrent connections. +This simulates real work usage of the API and makes +sure that the zmq frontend mp RPC message passing and +AsyncLLMEngine are working correctly. 
+""" + +import lm_eval +import pytest + +from vllm.platforms import current_platform + +MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct" +NUM_CONCURRENT = 500 +TASK = "gsm8k" +FILTER = "exact_match,strict-match" +RTOL = 0.03 +EXPECTED_VALUE = 0.58 + + +def run_test(): + """Run the end to end accuracy test.""" + + model_args = f"pretrained={MODEL_NAME},max_model_len=2048" + + results = lm_eval.simple_evaluate( + model="vllm", + model_args=model_args, + tasks="gsm8k", + batch_size="auto", + ) + + measured_value = results["results"][TASK][FILTER] + assert (measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + + +@pytest.mark.skipif(not current_platform.is_cuda(), + reason="V1 is currently only supported on CUDA.") +def test_lm_eval_accuracy_v1_engine(monkeypatch): + """Run with the V1 Engine.""" + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + run_test() + + +def test_lm_eval_accuracy_v0_engine(monkeypatch): + """Run with the V0 Engine.""" + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "0") + run_test() diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index a16e95f94171e..b1d4461d164aa 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -37,11 +37,11 @@ if current_platform.is_tpu(): MAX_WAIT_SECONDS = 600 -@pytest.mark.parametrize("more_args", MORE_ARGS_LIST) -def test_lm_eval_accuracy(more_args): +def run_test(more_args): + """Run the end to end accuracy test.""" + args = list(DEFAULT_ARGS) args.extend(more_args) - print(f"Running with: {args}") with RemoteOpenAIServer( @@ -64,3 +64,22 @@ def test_lm_eval_accuracy(more_args): assert (measured_value - RTOL < EXPECTED_VALUE and measured_value + RTOL > EXPECTED_VALUE ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + + +@pytest.mark.skipif(not current_platform.is_cuda(), + reason="V1 currently only supported on CUDA") +def test_lm_eval_accuracy_v1_engine(monkeypatch): + """Run with the V1 Engine.""" + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + run_test([]) + + +@pytest.mark.parametrize("more_args", MORE_ARGS_LIST) +def test_lm_eval_accuracy_v0_engine(monkeypatch, more_args): + """Run with the V0 Engine.""" + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "0") + run_test(more_args) diff --git a/vllm/v1/tokenizer/__init__.py b/tests/v1/__init__.py similarity index 100% rename from vllm/v1/tokenizer/__init__.py rename to tests/v1/__init__.py diff --git a/tests/v1/engine/__init__.py b/tests/v1/engine/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py new file mode 100644 index 0000000000000..1f26fe0fc892f --- /dev/null +++ b/tests/v1/engine/test_async_llm.py @@ -0,0 +1,66 @@ +import asyncio +from typing import Tuple + +import pytest + +from vllm import SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.platforms import current_platform +from vllm.v1.engine.async_llm import AsyncLLM + +if not current_platform.is_cuda(): + pytest.skip(reason="V1 currently only supported on CUDA.", + allow_module_level=True) + +ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B", + disable_log_requests=True) + + +async def generate(engine: AsyncLLM, request_id: str, + max_tokens: int) -> Tuple[int, str]: + count = 0 + async for _ in 
engine.generate(request_id=request_id, + prompt="Hello my name is Robert and", + sampling_params=SamplingParams( + max_tokens=max_tokens, temperature=0)): + + count += 1 + await asyncio.sleep(0.) + + return count, request_id + + +@pytest.mark.asyncio +async def test_load(monkeypatch): + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + engine = AsyncLLM.from_engine_args(ENGINE_ARGS) + + NUM_REQUESTS = 10000 + NUM_EXPECTED_TOKENS = 10 + + request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)] + + # Create concurrent requests. + tasks = [] + for request_id in request_ids: + tasks.append( + asyncio.create_task( + generate(engine, request_id, NUM_EXPECTED_TOKENS))) + + # Confirm that we got all the EXPECTED tokens from the requests. + failed_request_id = None + tokens = None + for task in tasks: + num_generated_tokens, request_id = await task + if (num_generated_tokens != NUM_EXPECTED_TOKENS + and failed_request_id is None): + failed_request_id = request_id + tokens = num_generated_tokens + + assert failed_request_id is None, ( + f"{failed_request_id} generated {tokens} but " + f"expected {NUM_EXPECTED_TOKENS}") + + engine.shutdown() diff --git a/tests/v1/engine/test_detokenizer.py b/tests/v1/engine/test_detokenizer.py new file mode 100644 index 0000000000000..07f343666cb5e --- /dev/null +++ b/tests/v1/engine/test_detokenizer.py @@ -0,0 +1,205 @@ +from typing import List + +import pytest +from transformers import AutoTokenizer + +from vllm.sampling_params import RequestOutputKind +from vllm.v1.engine import EngineCoreOutput +from vllm.v1.engine.detokenizer import Detokenizer, DetokenizerRequest + +TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3" +tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME) + +FULL_STRINGS = [ + "My name is Robert from Neural Magic and I love working on vLLM so much!", + "Red Hat is the best open source company by far across Linux, K8s, and AI.", + "Nick is the name of my brother in addition to my colleague from Red Hat.", +] + +STOP_STRINGS = ["I love working on", "company by far", "brother in"] + +FULL_TOKENS = [tokenizer(text).input_ids for text in FULL_STRINGS] +PROMPT_LEN = 5 +PROMPT_TOKENS = [ + tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS +] +GENERATION_TOKENS = [ + tokenizer(text).input_ids[PROMPT_LEN:] for text in FULL_STRINGS +] +PROMPT_STRINGS = [ + tokenizer.decode(prompt_tokens, skip_special_tokens=True) + for prompt_tokens in PROMPT_TOKENS +] +PROMPT_STRINGS_LEN = [len(prompt_string) for prompt_string in PROMPT_STRINGS] +GENERATION_STRINGS = [ + text[prompt_len:] + for text, prompt_len in zip(FULL_STRINGS, PROMPT_STRINGS_LEN) +] + + +class MockEngineCore: + """Mock outputs form premade tokens lists.""" + + def __init__(self, tokens_list: List[List[int]]): + self.tokens_list = tokens_list + self.current_idx = 0 + + def get_outputs(self) -> List[EngineCoreOutput]: + token_idx = self.current_idx + self.current_idx += 1 + + outputs = [] + for req_idx, token_ids in enumerate(self.tokens_list): + if len(token_ids) > token_idx: + output = EngineCoreOutput(request_id=f"request-{req_idx}", + new_token_ids=[token_ids[token_idx]], + finished=False) + if token_idx == len(token_ids) - 1: + output.finished = True + output.finish_reason = "stopped" + outputs.append(output) + + return outputs + + +@pytest.mark.parametrize( + "request_output_kind", + [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) +def test_incremental_detokenization(request_output_kind: RequestOutputKind): + detokenizer = 
Detokenizer(TOKENIZER_NAME) + engine_core = MockEngineCore(GENERATION_TOKENS) + + # Make N requests. + requests = [ + DetokenizerRequest( + request_id=f"request-{idx}", + prompt=prompt, + prompt_token_ids=prompt_tokens, + skip_special_tokens=False, + spaces_between_special_tokens=False, + output_kind=request_output_kind, + stop=[], + include_stop_str_in_output=False, + ) for idx, ( + prompt, + prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS)) + ] + + # Add requests to the detokenizer. + for request in requests: + detokenizer.add_request(request) + + gen_strings = {} + gen_tokens = {} + while True: + # Mock output from the EngineCore. + outputs = engine_core.get_outputs() + if len(outputs) == 0: + break + + # Step the Detokenizer. + request_outputs, requests_to_abort = detokenizer.step(outputs) + assert len(requests_to_abort) == 0 + + # Update tracking. + for request_output in request_outputs: + request_id = request_output.request_id + new_text = request_output.outputs[0].text + new_tokens = request_output.outputs[0].token_ids + if request_id not in gen_strings: + gen_strings[request_id] = new_text + gen_tokens[request_id] = new_tokens + else: + gen_strings[request_id] += new_text + gen_tokens[request_id].extend(new_tokens) + + # Confirmed tracked values matches what we expected. + for idx, (ref_gen_str, ref_gen_toks) in enumerate( + zip(GENERATION_STRINGS, GENERATION_TOKENS)): + gen_str = gen_strings[f"request-{idx}"] + gen_toks = gen_tokens[f"request-{idx}"] + + assert gen_str == ref_gen_str, f"{gen_str=}, {ref_gen_str=}" + assert gen_toks == ref_gen_toks, f"{gen_toks=}, {ref_gen_toks=}" + + assert detokenizer.get_num_unfinished_requests() == 0 + assert not detokenizer.has_unfinished_requests() + + +@pytest.mark.parametrize("include_stop_str_in_output", [True, False]) +def test_stop_string(include_stop_str_in_output: bool): + detokenizer = Detokenizer(TOKENIZER_NAME) + engine_core = MockEngineCore(GENERATION_TOKENS) + + # Make N requests. + requests = [ + DetokenizerRequest( + request_id=f"request-{idx}", + prompt=prompt, + prompt_token_ids=prompt_tokens, + skip_special_tokens=False, + spaces_between_special_tokens=False, + output_kind=RequestOutputKind.DELTA, + stop=STOP_STRINGS, + include_stop_str_in_output=include_stop_str_in_output, + ) for idx, ( + prompt, + prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS)) + ] + + # Add requests to the detokenizer. + for request in requests: + detokenizer.add_request(request) + + gen_strings = {} + aborted = [] + while True: + # Mock output from the EngineCore. + outputs = engine_core.get_outputs() + if len(outputs) == 0: + break + + # Step the Detokenizer. + request_outputs, requests_to_abort = detokenizer.step(outputs) + for request_output in request_outputs: + # If aborted, we should not get a request output. + assert request_output.request_id not in aborted + aborted.extend(requests_to_abort) + + # Update tracking. + for request_output in request_outputs: + if request_output.finished: + assert request_output.outputs[0].finish_reason == "stop" + + request_id = request_output.request_id + new_text = request_output.outputs[0].text + if request_id not in gen_strings: + gen_strings[request_id] = new_text + else: + gen_strings[request_id] += new_text + + # Confirmed tracked values matches what we expected. + for idx, (ref_gen_str, + stop_str) in enumerate(zip(GENERATION_STRINGS, STOP_STRINGS)): + + # Request should be aborted. 
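Outside of pytest, the stop-string contract being verified here is: Detokenizer.step() truncates the decoded text at the stop string and returns the request id in requests_to_abort, so the caller (AsyncLLM) can tell EngineCore to stop generating for it. A minimal sketch of that flow, assuming the same Detokenizer / DetokenizerRequest / EngineCoreOutput APIs added in this PR (the token ids are illustrative):

    from vllm.sampling_params import RequestOutputKind
    from vllm.v1.engine import EngineCoreOutput
    from vllm.v1.engine.detokenizer import Detokenizer, DetokenizerRequest

    detokenizer = Detokenizer("mistralai/Mistral-7B-Instruct-v0.3")
    detokenizer.add_request(
        DetokenizerRequest(
            request_id="req-0",
            prompt="Hello",
            prompt_token_ids=[22557],  # illustrative token ids
            skip_special_tokens=True,
            spaces_between_special_tokens=True,
            output_kind=RequestOutputKind.DELTA,
            stop=["name"],
            include_stop_str_in_output=False,
        ))

    # Feed one EngineCoreOutput per generated token; once the decoded text
    # contains the stop string, the request id comes back in requests_to_abort.
    engine_core_output = EngineCoreOutput(request_id="req-0",
                                          new_token_ids=[1234],  # illustrative
                                          finished=False)
    request_outputs, requests_to_abort = detokenizer.step([engine_core_output])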
+ request_id = f"request-{idx}" + assert request_id in aborted + + # Collected values that were generated. + gen_str = gen_strings[request_id] + + # Construct reference strings. + stop_str_idx = ref_gen_str.find(stop_str) + ref_str_exc_stop = ref_gen_str[:stop_str_idx] + ref_str_inc_stop = ref_gen_str[:stop_str_idx] + stop_str + + if include_stop_str_in_output: + assert gen_str == ref_str_inc_stop, ( + f"{gen_str=}, {ref_str_inc_stop=}") + else: + assert gen_str == ref_str_exc_stop, ( + f"{gen_str=}, {ref_str_exc_stop=}") + + assert detokenizer.get_num_unfinished_requests() == 0 + assert not detokenizer.has_unfinished_requests() diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py new file mode 100644 index 0000000000000..8451aac33acc4 --- /dev/null +++ b/tests/v1/engine/test_engine_core.py @@ -0,0 +1,137 @@ +import time +import uuid + +import pytest +from transformers import AutoTokenizer + +from vllm import SamplingParams +from vllm.engine.arg_utils import EngineArgs +from vllm.platforms import current_platform +from vllm.usage.usage_lib import UsageContext +from vllm.v1.engine import EngineCoreRequest +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.engine.core import EngineCore + +if not current_platform.is_cuda(): + pytest.skip(reason="V1 currently only supported on CUDA.", + allow_module_level=True) + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" +TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME) +PROMPT = "Hello my name is Robert and I love quantization kernels" +PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids + + +def make_request() -> EngineCoreRequest: + return EngineCoreRequest( + request_id=uuid.uuid4(), + prompt=PROMPT, + prompt_token_ids=PROMPT_TOKENS, + sampling_params=SamplingParams(), + eos_token_id=None, + arrival_time=time.time(), + lora_request=None, + ) + + +def test_engine_core(monkeypatch): + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + """Setup the EngineCore.""" + engine_args = EngineArgs(model=MODEL_NAME) + vllm_config = engine_args.create_engine_config() + executor_class = AsyncLLM._get_executor_cls(vllm_config) + + engine_core = EngineCore(vllm_config=vllm_config, + executor_class=executor_class, + usage_context=UsageContext.UNKNOWN_CONTEXT) + """Test basic request lifecycle.""" + + # First request. + engine_core.add_request(make_request()) + assert len(engine_core.scheduler.waiting) == 1 + assert len(engine_core.scheduler.running) == 0 + + _ = engine_core.step() + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 1 + + # Second request. + engine_core.add_request(make_request()) + assert len(engine_core.scheduler.waiting) == 1 + assert len(engine_core.scheduler.running) == 1 + + _ = engine_core.step() + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 2 + + # Add two requests in a row. + engine_core.add_request(make_request()) + engine_core.add_request(make_request()) + assert len(engine_core.scheduler.waiting) == 2 + assert len(engine_core.scheduler.running) == 2 + + _ = engine_core.step() + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 4 + + # Loop through until they are all done. + while len(engine_core.step()) > 0: + pass + + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 0 + """Test abort cycle.""" + + # Basic abort. 
+ req = make_request() + request_id = req.request_id + + engine_core.add_request(req) + assert len(engine_core.scheduler.waiting) == 1 + assert len(engine_core.scheduler.running) == 0 + + _ = engine_core.step() + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 1 + + engine_core.abort_requests([request_id]) + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 0 + + # Add, step, abort 1 of the 3. + req0 = make_request() + req1 = make_request() + req2 = make_request() + + engine_core.add_request(req0) + engine_core.add_request(req1) + assert len(engine_core.scheduler.waiting) == 2 + assert len(engine_core.scheduler.running) == 0 + + _ = engine_core.step() + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 2 + + engine_core.add_request(req2) + assert len(engine_core.scheduler.waiting) == 1 + assert len(engine_core.scheduler.running) == 2 + + _ = engine_core.step() + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 3 + + # Abort just one. + engine_core.abort_requests([req1.request_id]) + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 2 + + _ = engine_core.step() + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 2 + + # Abort the other requests at the same time. + engine_core.abort_requests([req2.request_id, req0.request_id]) + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 0 diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py new file mode 100644 index 0000000000000..d582101a1164f --- /dev/null +++ b/tests/v1/engine/test_engine_core_client.py @@ -0,0 +1,202 @@ +import asyncio +import time +import uuid +from typing import Dict, List + +import pytest +from transformers import AutoTokenizer + +from vllm import SamplingParams +from vllm.engine.arg_utils import EngineArgs +from vllm.platforms import current_platform +from vllm.usage.usage_lib import UsageContext +from vllm.v1.engine import EngineCoreRequest +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.engine.core_client import EngineCoreClient + +if not current_platform.is_cuda(): + pytest.skip(reason="V1 currently only supported on CUDA.", + allow_module_level=True) + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" +TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME) +PROMPT = "Hello my name is Robert and I love quantization kernels" +PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids + + +def make_request(params: SamplingParams) -> EngineCoreRequest: + return EngineCoreRequest( + request_id=str(uuid.uuid4()), + prompt=PROMPT, + prompt_token_ids=PROMPT_TOKENS, + sampling_params=params, + eos_token_id=None, + arrival_time=time.time(), + lora_request=None, + ) + + +def loop_until_done(client: EngineCoreClient, outputs: Dict): + + while True: + engine_core_outputs = client.get_output() + + if len(engine_core_outputs) == 0: + break + + all_finished = True + for out in engine_core_outputs: + outputs[out.request_id].append(out) + if not out.finished: + all_finished = False + + if all_finished: + break + + +async def loop_until_done_async(client: EngineCoreClient, outputs: Dict): + + while True: + engine_core_outputs = await client.get_output_async() + + if len(engine_core_outputs) == 0: + break + + all_finished = True + for out in engine_core_outputs: + 
outputs[out.request_id].append(out) + if not out.finished: + all_finished = False + + if all_finished: + break + + +@pytest.mark.parametrize("multiprocessing_mode", [True, False]) +def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + engine_args = EngineArgs(model=MODEL_NAME) + vllm_config = engine_args.create_engine_config() + executor_class = AsyncLLM._get_executor_cls(vllm_config) + client = EngineCoreClient.make_client( + vllm_config, + executor_class, + UsageContext.UNKNOWN_CONTEXT, + multiprocess_mode=multiprocessing_mode, + asyncio_mode=False, + ) + + MAX_TOKENS = 20 + params = SamplingParams(max_tokens=MAX_TOKENS) + """Normal Request Cycle.""" + requests = [make_request(params) for _ in range(10)] + request_ids = [req.request_id for req in requests] + + # Add requests to the engine. + for request in requests: + client.add_request(request) + time.sleep(0.01) + + outputs: Dict[str, List] = {req_id: [] for req_id in request_ids} + loop_until_done(client, outputs) + + for req_id in request_ids: + assert len(outputs[req_id]) == MAX_TOKENS, ( + f"{outputs[req_id]=}, {MAX_TOKENS=}") + """Abort Request Cycle.""" + + # Note: this code pathway will only work for multiprocessing + # since we have to call get_output() explicitly + + # Add requests to the engine. + for idx, request in enumerate(requests): + client.add_request(request) + time.sleep(0.01) + if idx % 2 == 0: + client.abort_requests([request.request_id]) + + outputs = {req_id: [] for req_id in request_ids} + loop_until_done(client, outputs) + + for idx, req_id in enumerate(request_ids): + if idx % 2 == 0: + assert len(outputs[req_id]) < MAX_TOKENS, ( + f"{len(outputs[req_id])=}, {MAX_TOKENS=}") + else: + assert len(outputs[req_id]) == MAX_TOKENS, ( + f"{len(outputs[req_id])=}, {MAX_TOKENS=}") + """Abort after request is finished.""" + + # Note: this code pathway will only work for multiprocessing + # since we have to call get_output() explicitly + + request = requests[0] + client.add_request(request) + time.sleep(10.) + + client.abort_requests([request.request_id]) + + # Shutdown the client. + client.shutdown() + + +@pytest.mark.asyncio +async def test_engine_core_client_asyncio(monkeypatch): + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + engine_args = EngineArgs(model=MODEL_NAME) + vllm_config = engine_args.create_engine_config() + executor_class = AsyncLLM._get_executor_cls(vllm_config) + client = EngineCoreClient.make_client( + vllm_config, + executor_class, + UsageContext.UNKNOWN_CONTEXT, + multiprocess_mode=True, + asyncio_mode=True, + ) + + MAX_TOKENS = 20 + params = SamplingParams(max_tokens=MAX_TOKENS) + """Normal Request Cycle.""" + + requests = [make_request(params) for _ in range(10)] + request_ids = [req.request_id for req in requests] + + # Add requests to the engine. + for request in requests: + await client.add_request_async(request) + await asyncio.sleep(0.01) + + outputs: Dict[str, List] = {req_id: [] for req_id in request_ids} + await loop_until_done_async(client, outputs) + + for req_id in request_ids: + assert len(outputs[req_id]) == MAX_TOKENS, ( + f"{outputs[req_id]=}, {MAX_TOKENS=}") + """Abort Request Cycle.""" + + # Add requests to the engine. 
+ for idx, request in enumerate(requests): + await client.add_request_async(request) + await asyncio.sleep(0.01) + if idx % 2 == 0: + await client.abort_requests_async([request.request_id]) + + outputs = {req_id: [] for req_id in request_ids} + await loop_until_done_async(client, outputs) + + for idx, req_id in enumerate(request_ids): + if idx % 2 == 0: + assert len(outputs[req_id]) < MAX_TOKENS, ( + f"{len(outputs[req_id])=}, {MAX_TOKENS=}") + else: + assert len(outputs[req_id]) == MAX_TOKENS, ( + f"{len(outputs[req_id])=}, {MAX_TOKENS=}") + + # Shutdown the client. + client.shutdown() diff --git a/vllm/config.py b/vllm/config.py index f9b230e1bc688..dc9c06d7fb16e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2106,3 +2106,44 @@ class VllmConfig: self.model_config is not None and self.load_config is not None: self.quant_config = VllmConfig._get_quantization_config( self.model_config, self.load_config) + + def __str__(self): + return ("model=%r, speculative_config=%r, tokenizer=%r, " + "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " + "override_neuron_config=%s, tokenizer_revision=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " + "pipeline_parallel_size=%d, " + "disable_custom_all_reduce=%s, quantization=%s, " + "enforce_eager=%s, kv_cache_dtype=%s, " + "quantization_param_path=%s, device_config=%s, " + "decoding_config=%r, observability_config=%r, " + "seed=%d, served_model_name=%s, " + "num_scheduler_steps=%d, enable_prefix_caching=%s, " + "use_async_output_proc=%s, mm_processor_kwargs=%s") % \ + (self.model_config.model, self.speculative_config, + self.model_config.tokenizer, + self.model_config.skip_tokenizer_init, + self.model_config.tokenizer_mode, + self.model_config.revision, + self.model_config.override_neuron_config, + self.model_config.tokenizer_revision, + self.model_config.trust_remote_code, + self.model_config.dtype, + self.model_config.max_model_len, + self.load_config.download_dir, + self.load_config.load_format, + self.parallel_config.tensor_parallel_size, + self.parallel_config.pipeline_parallel_size, + self.parallel_config.disable_custom_all_reduce, + self.model_config.quantization, + self.model_config.enforce_eager, + self.cache_config.cache_dtype, + self.model_config.quantization_param_path, + self.device_config.device, self.decoding_config, + self.observability_config, self.model_config.seed, + self.model_config.served_model_name, + self.scheduler_config.num_scheduler_steps, + self.cache_config.enable_prefix_caching, + self.model_config.use_async_output_proc, + self.model_config.mm_processor_kwargs) \ No newline at end of file diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 889845ee67312..7de23643a2e1c 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -6,7 +6,6 @@ from typing import Iterator, List, Optional, Union import cloudpickle import zmq -import vllm.envs from vllm import AsyncEngineArgs, SamplingParams from vllm.engine.llm_engine import LLMEngine # yapf conflicts with isort for this block @@ -113,17 +112,9 @@ class MQLLMEngine: load_general_plugins() engine_config = engine_args.create_engine_config() - if vllm.envs.VLLM_USE_V1: - # Lazy import: the v1 package isn't distributed - from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine - engine_class = V1LLMEngine - else: - engine_class = LLMEngine + executor_class = LLMEngine._get_executor_cls(engine_config) - 
executor_class = engine_class._get_executor_cls(engine_config) - - use_async_sockets = (engine_config.model_config.use_async_output_proc - and not vllm.envs.VLLM_USE_V1) + use_async_sockets = engine_config.model_config.use_async_output_proc return cls(ipc_path=ipc_path, use_async_sockets=use_async_sockets, diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index a71ad493d9920..4b701f81504bb 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional +from typing import Callable, List, Optional, Tuple from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams @@ -67,9 +67,13 @@ class StopChecker: return # Check if any stop strings are matched. - stop_str = self._check_stop_strings(seq, new_char_count, - sampling_params) - if stop_str is not None: + stop = self.check_stop_strings( + seq.output_text, new_char_count, sampling_params.stop, + sampling_params.include_stop_str_in_output) + if stop is not None: + stop_str, truncate_to = stop + if truncate_to != -1: + seq.output_text = seq.output_text[:truncate_to] seq.status = SequenceStatus.FINISHED_STOPPED seq.stop_reason = stop_str return @@ -85,33 +89,40 @@ class StopChecker: return @staticmethod - def _check_stop_strings(seq: Sequence, new_char_count: int, - sampling_params: SamplingParams) -> Optional[str]: + def check_stop_strings( + output_text: str, + new_char_count: int, + stop: List[str], + include_in_output: bool, + ) -> Optional[Tuple[str, int]]: """Check if any stop strings are matched and truncate sequence output text accordingly. - Returns the stop string if matched or else None. + Returns tuple (stop_string, offset) if matched or else None. + + Where stop_string is the matched stop string and offset is the + length to which output_text should be truncated, or -1 for no + truncation. """ - if not new_char_count or not sampling_params.stop: + if not new_char_count or not stop: return None - for stop_str in sampling_params.stop: + for stop_str in stop: stop_string_len = len(stop_str) # Avoid searching already-searched text. - stop_index = seq.output_text.find( - stop_str, -new_char_count - stop_string_len) + stop_index = output_text.find(stop_str, + -new_char_count - stop_string_len) if stop_index == -1: continue - if sampling_params.include_stop_str_in_output: + if include_in_output: # Truncate to end of stop string. stop_index += stop_string_len - if stop_index >= len(seq.output_text): + if stop_index >= len(output_text): # No truncation required. - return stop_str + return stop_str, -1 # Truncate the output text to either the beginning # or end of the stop string. 
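The refactored helper is now a pure function over strings rather than a Sequence mutator, which is what lets the V1 detokenizer reuse it: it returns the matched stop string plus the offset to truncate the output text to, with -1 meaning no truncation is needed. A small illustration of the new contract (the inputs are hypothetical):

    from vllm.engine.output_processor.stop_checker import StopChecker

    output_text = "The answer is 42. Thanks"
    # Pretend the last 7 characters (" Thanks") were just appended.
    result = StopChecker.check_stop_strings(output_text=output_text,
                                            new_char_count=7,
                                            stop=["42."],
                                            include_in_output=False)
    if result is not None:
        stop_str, truncate_to = result                 # ("42.", 14)
        if truncate_to != -1:
            output_text = output_text[:truncate_to]    # "The answer is "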
- seq.output_text = seq.output_text[:stop_index] - return stop_str + return stop_str, stop_index return None diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index f830839776364..a15dbd1c45119 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -210,8 +210,11 @@ class LLM: # Logic to switch between engines is done at runtime instead of import # to avoid import order issues self.engine_class = self.get_engine_class() + + # TODO(rob): enable mp by default (issue with fork vs spawn) self.llm_engine = self.engine_class.from_engine_args( engine_args, usage_context=UsageContext.LLM_CLASS) + self.request_counter = Counter() @staticmethod diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b8b7912742d45..3e4070a25cf90 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -26,7 +26,6 @@ from typing_extensions import assert_never import vllm.envs as envs from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.engine.multiprocessing.engine import run_mp_engine from vllm.engine.protocol import EngineClient @@ -61,6 +60,11 @@ from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path from vllm.version import __version__ as VLLM_VERSION +if envs.VLLM_USE_V1: + from vllm.v1.engine.async_llm import AsyncLLMEngine # type: ignore +else: + from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore + TIMEOUT_KEEP_ALIVE = 5 # seconds prometheus_multiproc_dir: tempfile.TemporaryDirectory @@ -126,7 +130,8 @@ async def build_async_engine_client_from_engine_args( # Fall back # TODO: fill out feature matrix. if (MQLLMEngineClient.is_unsupported_config(engine_args) - or disable_frontend_multiprocessing): + or envs.VLLM_USE_V1 or disable_frontend_multiprocessing): + engine_config = engine_args.create_engine_config() uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config), "uses_ray", False) @@ -143,6 +148,8 @@ async def build_async_engine_client_from_engine_args( None, build_engine) yield engine_client + if hasattr(engine_client, "shutdown"): + engine_client.shutdown() return # Otherwise, use the multiprocessing AsyncLLMEngine. diff --git a/vllm/envs.py b/vllm/envs.py index 154246c69f165..f320e35971f94 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -72,6 +72,7 @@ if TYPE_CHECKING: VLLM_CUSTOM_OPS: List[str] = [] VLLM_DISABLED_KERNELS: List[str] = [] VLLM_USE_V1: bool = False + VLLM_ENABLE_V1_MULTIPROCESSING: bool = False def get_default_cache_root(): @@ -473,6 +474,10 @@ environment_variables: Dict[str, Callable[[], Any]] = { # If set, use the V1 code path. "VLLM_USE_V1": lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))), + + # If set, enable multiprocessing in LLM for the V1 code path. 
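Like VLLM_USE_V1 just above, the multiprocessing flag defined next is a plain "0"/"1" integer read from the environment, so opting a script into the V1 path is just a matter of exporting the variables before the engine is constructed, exactly as the tests do with monkeypatch. A sketch (the model name is illustrative; the multiprocessing flag is optional and off by default):

    import os

    os.environ["VLLM_USE_V1"] = "1"                     # select the V1 engine
    os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "1"  # optional

    from vllm import LLM, SamplingParams

    llm = LLM(model="meta-llama/Llama-3.2-1B")  # illustrative model
    outputs = llm.generate(["Hello my name is"], SamplingParams(max_tokens=8))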
+ "VLLM_ENABLE_V1_MULTIPROCESSING": + lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0"))), } # end-env-vars-definition diff --git a/vllm/outputs.py b/vllm/outputs.py index 951976310e7ae..abfdb7d328126 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -113,6 +113,36 @@ class RequestOutput: self.encoder_prompt = encoder_prompt self.encoder_prompt_token_ids = encoder_prompt_token_ids + @classmethod + def new( + cls, + request_id: str, + prompt: Optional[str], + prompt_token_ids: Optional[List[int]], + text: str, + token_ids: List[int], + finished: bool = False, + ) -> "RequestOutput": + """Initialize a new RequestOutput object.""" + + # TODO: Support `n` > 1. + completion_output = CompletionOutput( + index=0, + text=text, + token_ids=token_ids, + cumulative_logprob=None, + logprobs=None, # TODO + ) + + return RequestOutput( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + prompt_logprobs=None, # TODO + outputs=[completion_output], + finished=finished, + ) + @classmethod def from_seq_group( cls, seq_group: SequenceGroup, use_cache: bool, diff --git a/vllm/v1/__init__.py b/vllm/v1/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 82094fb65dd1a..38f1c03a4d3ac 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -70,7 +70,7 @@ class KVCacheManager: Args: request: The request to get the computed blocks. - + Returns: A list of blocks that are computed for the request. """ @@ -105,7 +105,7 @@ class KVCacheManager: Args: request: The request to append slots. num_tokens: The number of tokens to append. - + Returns: A list of new blocks if new blocks are allocated, or None if new blocks are required but cannot be allocated. @@ -176,7 +176,7 @@ class KVCacheManager: num_tokens: The number of tokens to allocate. Note that this does not include the tokens that have already been computed. computed_blocks: The blocks that have already been computed. - + Returns: A list of new allocated blocks. """ @@ -240,7 +240,8 @@ class KVCacheManager: Args: request: The request to free the blocks. """ - blocks = self.req_to_blocks.pop(request.request_id) + # Default to [] in case a request is freed (aborted) before alloc. + blocks = self.req_to_blocks.pop(request.request_id, []) if self.enable_caching: # Free blocks in reverse order so that the tail blocks are # freed first. @@ -259,13 +260,13 @@ class KVCacheManager: """Get new blocks from the free block pool, and add token IDs to allocated blocks if caching is enabled. Note that we do not check block cache in this function. - + Args: num_blocks: The number of blocks to allocate. token_ids: The token IDs in the blocks. None if caching is disabled. parent_block: The parent block. Used to include block chain in the block hash. - + Returns: A list of new block. 
""" diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index a60f8b8138ecf..ee860e792281d 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -1,12 +1,13 @@ from collections import deque from dataclasses import dataclass -from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Deque, Dict, Iterable, List, Optional, Set, Union from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.logger import init_logger from vllm.multimodal import MultiModalDataDict from vllm.sampling_params import SamplingParams from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.engine import EngineCoreOutput from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -237,13 +238,12 @@ class Scheduler: self, scheduler_output: "SchedulerOutput", model_runner_output: "ModelRunnerOutput", - ) -> List[Tuple[Request, int]]: + ) -> List[EngineCoreOutput]: # NOTE(woosuk): This method doesn't consider speculative decoding. sampled_token_ids = model_runner_output.sampled_token_ids_cpu.tolist() num_scheduled_tokens = scheduler_output.num_scheduled_tokens new_running: List[Request] = [] - # (request, num_sampled_tokens) - sampled: List[Tuple[Request, int]] = [] + engine_core_outputs: List[EngineCoreOutput] = [] for request in self.running: req_id = request.request_id request.num_computed_tokens += num_scheduled_tokens[req_id] @@ -257,17 +257,29 @@ class Scheduler: # generates at most one token at each step. token_id = sampled_token_ids[req_index] request.append_output_token_ids(token_id) - sampled.append((request, 1)) + num_new_tokens = 1 # TODO: Update the KV cache manager for prefix caching. - # Check if the request is finished. + # Check for stop and update request state. + # This must be called before me make the EngineCoreOutput. stopped = self._check_stop(request) + + # Add EngineCoreOutput for this Request. + output = EngineCoreOutput( + request_id=req_id, + new_token_ids=request.output_token_ids[-num_new_tokens:], + finished=request.is_finished(), + finish_reason=request.get_finished_reason(), + stop_reason=request.stop_reason) + engine_core_outputs.append(output) + + # Breakout of the loop. if stopped: continue new_running.append(request) self.running = new_running - return sampled + return engine_core_outputs def _check_stop(self, request: Request) -> bool: if (request.num_tokens >= self.max_model_len diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index e69de29bb2d1d..8bc16651faf97 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -0,0 +1,72 @@ +import enum +from dataclasses import dataclass +from typing import List, Optional, Union + +import msgspec + +from vllm.lora.request import LoRARequest +from vllm.sampling_params import RequestOutputKind, SamplingParams + + +@dataclass +class DetokenizerRequest: + + request_id: str + prompt: Optional[str] + prompt_token_ids: List[int] + skip_special_tokens: bool + spaces_between_special_tokens: bool + output_kind: RequestOutputKind + + stop: List[str] + include_stop_str_in_output: bool + + +class EngineCoreRequest(msgspec.Struct, omit_defaults=True): + + # NOTE: prompt and prompt_token_ids should be DecoderOnlyInput, + # but this object is currently not playing well with msgspec + # due to circular imports and typing we have in data.py + + request_id: str + #NOTE(Nick): I don't think we need to pass prompt here since it should + # always be tokenized? 
+ prompt: Optional[str] + prompt_token_ids: List[int] + sampling_params: SamplingParams + eos_token_id: Optional[int] + arrival_time: float + lora_request: Optional[LoRARequest] + + +class EngineCoreOutput(msgspec.Struct, + array_like=True, + omit_defaults=True, + gc=False): + + request_id: str + new_token_ids: List[int] + finished: bool + finish_reason: Optional[str] = None + stop_reason: Union[int, str, None] = None + + +class EngineCoreOutputs(msgspec.Struct, + array_like=True, + omit_defaults=True, + gc=False): + + #NOTE(Nick): We could consider ways to make this more compact, + # e.g. columnwise layout and using an int enum for finish/stop reason + + # [num_reqs] + outputs: List[EngineCoreOutput] + + +class EngineCoreRequestType(enum.Enum): + """ + Request types defined as hex byte strings, so it can be sent over sockets + without separate encoding step. + """ + ADD = b'\x00' + ABORT = b'\x01' diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py new file mode 100644 index 0000000000000..2d7c58cfea13b --- /dev/null +++ b/vllm/v1/engine/async_llm.py @@ -0,0 +1,368 @@ +import asyncio +from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union + +from vllm.config import ModelConfig, VllmConfig +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.metrics_types import StatLoggerBase +from vllm.engine.protocol import EngineClient +from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.usage.usage_lib import UsageContext +from vllm.v1.engine.async_stream import AsyncStream +from vllm.v1.engine.core_client import EngineCoreClient +from vllm.v1.engine.detokenizer import Detokenizer +from vllm.v1.engine.processor import Processor +from vllm.v1.executor.gpu_executor import GPUExecutor + +logger = init_logger(__name__) + + +class AsyncLLM(EngineClient): + + def __init__( + self, + vllm_config: VllmConfig, + executor_class: Type[GPUExecutor], + log_stats: bool, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + input_registry: InputRegistry = INPUT_REGISTRY, + use_cached_outputs: bool = False, + log_requests: bool = True, + start_engine_loop: bool = True, + ) -> None: + assert start_engine_loop + + self.log_requests = log_requests + self.log_stats = log_stats + self.stat_loggers = stat_loggers + self.model_config = vllm_config.model_config + + # Tokenizer (+ ensure liveness if running in another process). + self.tokenizer = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + parallel_config=vllm_config.parallel_config, + enable_lora=bool(vllm_config.lora_config)) + self.tokenizer.ping() + + # Request streams (map of request_id -> AsyncStream). + self.request_streams: Dict[str, AsyncStream] = {} + # List of cancelled request ids to be aborted. + self.client_aborted_requests: List[str] = [] + + # Processor (converts Inputs --> EngineCoreRequests). 
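Stepping back to the message types defined in vllm/v1/engine/__init__.py above: they are msgspec structs precisely so the client and the engine process can exchange them as (request-type byte, msgpack payload) frames over ZMQ without a separate pickling step. A round-trip sketch of that framing, assuming the msgspec msgpack API (field values are placeholders):

    import time

    from msgspec import msgpack

    from vllm import SamplingParams
    from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
                                EngineCoreRequest, EngineCoreRequestType)

    # Client -> engine: frame 0 carries the request type, frame 1 the payload.
    request = EngineCoreRequest(request_id="req-0",
                                prompt="Hello",
                                prompt_token_ids=[1, 2, 3],  # placeholder ids
                                sampling_params=SamplingParams(max_tokens=8),
                                eos_token_id=None,
                                arrival_time=time.time(),
                                lora_request=None)
    frames = (EngineCoreRequestType.ADD.value, msgpack.Encoder().encode(request))
    assert frames[0] == b'\x00'  # the type needs no extra encoding step

    # Engine -> client: outputs travel back as one EngineCoreOutputs batch.
    outputs = EngineCoreOutputs(outputs=[
        EngineCoreOutput(request_id="req-0", new_token_ids=[42], finished=False)
    ])
    payload = msgpack.Encoder().encode(outputs)
    decoded = msgpack.Decoder(EngineCoreOutputs).decode(payload)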
+ self.processor = Processor(vllm_config.model_config, + vllm_config.lora_config, self.tokenizer, + input_registry) + + # Detokenizer (converts EngineCoreOutputs --> RequestOutput). + self.detokenizer = Detokenizer(vllm_config.model_config.tokenizer) + + # EngineCore (starts the engine in background process). + self.engine_core = EngineCoreClient.make_client( + vllm_config=vllm_config, + executor_class=executor_class, + usage_context=usage_context, + multiprocess_mode=True, + asyncio_mode=True, + ) + + self.output_handler = None + + def __del__(self): + self.shutdown() + + @classmethod + def from_engine_args( + cls, + engine_args: AsyncEngineArgs, + engine_config: Optional[VllmConfig] = None, + start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + ) -> "AsyncLLMEngine": + """Create an AsyncLLM from the EngineArgs.""" + + # Create the engine configs. + if engine_config is None: + vllm_config = engine_args.create_engine_config() + else: + vllm_config = engine_config + + executor_class = cls._get_executor_cls(vllm_config) + + # Create the AsyncLLM. + return cls( + vllm_config=vllm_config, + executor_class=executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + start_engine_loop=start_engine_loop, + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + + def shutdown(self): + """Shutdown, cleaning up the background proc and IPC.""" + + self.engine_core.shutdown() + + if handler := getattr(self, "output_handler", None): + handler.cancel() + + @classmethod + def _get_executor_cls(cls, vllm_config: VllmConfig): + return GPUExecutor + + async def add_request( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + """Add new request to the AsyncLLM.""" + + if self.detokenizer.is_request_active(request_id): + raise KeyError(f"Request {request_id} already exists.") + + # 1) Create a new AsyncStream for the request. + stream = self._add_request_to_streams(request_id) + + # 2) Convert input --> DetokenizerRequest / EngineCoreRequest. + detokenizer_req, engine_core_req = self.processor.process_inputs( + request_id, prompt, params, arrival_time, lora_request, + trace_headers, prompt_adapter_request, priority) + + # 3) Add the request to Detokenizer (this process). + self.detokenizer.add_request(detokenizer_req) + + # 4) Add the EngineCoreRequest to EngineCore (separate process). + await self.engine_core.add_request_async(engine_core_req) + + # 5) Return the generator. + return stream.generator() + + # TODO: we should support multiple prompts in one call, as you + # can do with LLM.generate. So that for multi-prompt completion + # requests we don't need to send multiple messages to core proc, + # and so we don't need multiple streams which then get + # re-multiplexed in the API server anyhow. 
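Until that lands, a caller with several prompts issues one generate() call (and one request id) per prompt and merges the streams itself, roughly like this (a sketch; the prompts and request ids are illustrative):

    import asyncio

    from vllm import SamplingParams

    async def generate_all(engine, prompts):
        params = SamplingParams(max_tokens=16)

        async def consume(i, prompt):
            final = None
            async for out in engine.generate(prompt=prompt,
                                             sampling_params=params,
                                             request_id=f"req-{i}"):
                final = out  # keep the most recent RequestOutput
            return final

        return await asyncio.gather(
            *(consume(i, p) for i, p in enumerate(prompts)))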
+ async def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> AsyncGenerator[RequestOutput, None]: + """ + Main function called by the API server to kick off a request + * 1) Making an AsyncStream corresponding to the Request. + # 2) Processing the Input. + * 3) Adding the Request to the Detokenizer. + * 4) Adding the Request to the EngineCore (separate process). + + A separate output_handler loop runs in a background AsyncIO task, + pulling outputs from EngineCore and putting them into the + per-request AsyncStream. + + The caller of generate() iterates the returned AsyncGenerator, + returning the RequestOutput back to the caller. + """ + + # We start the output_handler on the first call to generate() so that + # we can call __init__ before the event loop starts, which enables us + # to handle startup failure gracefully in the OpenAI server. + if self.output_handler is None: + self.output_handler = asyncio.create_task( + self._run_output_handler()) + + async for output in await self.add_request( + request_id, + prompt, + sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + ): + yield output + + def _finish_stream(self, request_id: str): + stream = self.request_streams.pop(request_id, None) + if stream is not None: + stream.finish() + + def _add_request_to_streams( + self, + request_id: str, + ) -> AsyncStream: + + if request_id in self.request_streams: + raise ValueError(f"Request id {request_id} already running.") + + # Avoid streams having circular ref to parent AsyncLLM object. + aborted_reqs = self.client_aborted_requests + stream = AsyncStream(request_id, aborted_reqs.append) + self.request_streams[request_id] = stream + + if self.log_requests: + logger.info("Added request %s.", request_id) + + return stream + + async def _process_cancellations(self) -> None: + """ + Process requests cancelled from user disconnecting. + + When a client disconnects, AsyncStream._cancel() is called. + We passed a callback to AsyncStream(), which appends to + self.client_aborted_requests. + + As a result, if any requests are canceled from the user side + the request_id will show up in self.client_aborted_requests. + """ + + # Avoid streams having circular ref to parent AsyncLLM object. + if not self.client_aborted_requests: + return + reqs_to_abort = self.client_aborted_requests.copy() + self.client_aborted_requests.clear() + + # Remove from Detokenizer. + self.detokenizer.abort_requests(reqs_to_abort) + + # Remove from RequestStreams. + for request_id in reqs_to_abort: + if self.log_requests: + logger.info("User-cancelled request %s.", request_id) + self._finish_stream(request_id) + + # Remove from EngineCore. + await self.engine_core.abort_requests_async(reqs_to_abort) + + def _process_request_outputs(self, request_outputs: List[RequestOutput]): + """Process outputs by putting them into per-request AsyncStreams.""" + + for request_output in request_outputs: + request_id = request_output.request_id + assert request_id in self.request_streams + + # Each request in the API server pulls from the per-request stream. + stream = self.request_streams.get(request_id) + if stream is not None: + stream.put(request_output) + + # If finished, remove from the tracker. 
+ if request_output.finished: + if self.log_requests: + logger.info("Finished request %s.", request_id) + self._finish_stream(request_id) + + async def _run_output_handler(self): + """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" + + try: + while True: + # 1) Pull EngineCoreOutput from the EngineCore. + outputs = await self.engine_core.get_output_async() + + # 2) Detokenize based on the output. + request_outputs, reqs_to_abort = self.detokenizer.step(outputs) + + # 3) Put the RequestOutputs into the per-request AsyncStreams. + self._process_request_outputs(request_outputs) + + # 4) Abort any requests that finished due to stop strings. + await self.engine_core.abort_requests_async(reqs_to_abort) + + # 5) Abort any requests due to client cancellations. + await self._process_cancellations() + + except BaseException as e: + logger.error(e) + raise e + + # TODO: can we eliminate these? + + async def abort(self, request_id: str) -> None: + # Note: Who Calls this? I dont think this is actually used. + raise ValueError("Not Supported on V1 yet.") + + def encode( + self, + prompt: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ): + raise ValueError("Not Supported on V1 yet.") + + async def get_model_config(self) -> ModelConfig: + return self.model_config + + async def get_decoding_config(self): + raise ValueError("Not Supported on V1 yet.") + + async def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + assert lora_request is None + return self.detokenizer.tokenizer + + async def is_tracing_enabled(self) -> bool: + return False + + async def do_log_stats( + self, + scheduler_outputs=None, + model_output=None, + ) -> None: + logger.debug("Called do_log_stats.") + + async def check_health(self) -> None: + logger.debug("Called check_health.") + + async def start_profile(self) -> None: + raise ValueError("Not supported on V1 yet.") + + async def stop_profile(self) -> None: + raise ValueError("Not supported on V1 yet.") + + @property + def is_running(self) -> bool: + return True + + @property + def is_stopped(self) -> bool: + return False + + @property + def errored(self) -> bool: + return False + + @property + def dead_error(self) -> BaseException: + return Exception + + +# Retain V0 name for backwards compatibility. 
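Putting the pieces together, the client-facing flow is: build the engine with from_engine_args(), iterate generate() (which lazily starts the output handler), and call shutdown() when done; cancelling the consuming task, e.g. on client disconnect, is what feeds the abort path above. A minimal driver in the same style as tests/v1/engine/test_async_llm.py (a sketch; VLLM_USE_V1=1 is assumed to be set):

    import asyncio

    from vllm import SamplingParams
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.v1.engine.async_llm import AsyncLLM

    async def main():
        engine = AsyncLLM.from_engine_args(
            AsyncEngineArgs(model="meta-llama/Llama-3.2-1B"))
        try:
            async for out in engine.generate(
                    prompt="Hello my name is",
                    sampling_params=SamplingParams(max_tokens=16, temperature=0),
                    request_id="request-0"):
                print(out.outputs[0].text)
        finally:
            engine.shutdown()

    asyncio.run(main())

The module-level alias that follows keeps existing AsyncLLMEngine imports (such as the conditional import in the OpenAI api_server above) pointed at this class.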
+AsyncLLMEngine = AsyncLLM diff --git a/vllm/v1/engine/async_stream.py b/vllm/v1/engine/async_stream.py new file mode 100644 index 0000000000000..3e6c759ad5ebd --- /dev/null +++ b/vllm/v1/engine/async_stream.py @@ -0,0 +1,55 @@ +import asyncio +from typing import Any, AsyncGenerator, Callable, Optional, Type, Union + +from vllm.outputs import EmbeddingRequestOutput, RequestOutput + + +class AsyncStream: + """A stream of RequestOutputs or EmbeddingRequestOutputs for a request + that can be iterated over asynchronously via an async generator.""" + + STOP_ITERATION = Exception() # Sentinel + + def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: + self.request_id = request_id + self._cancel = cancel + self._queue: asyncio.Queue = asyncio.Queue() + self._finished = False + + def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, + Exception]) -> None: + if not self._finished: + self._queue.put_nowait(item) + + def finish( + self, + exception: Optional[Union[BaseException, Type[BaseException]]] = None, + ) -> None: + if not self._finished: + self._finished = True + self._queue.put_nowait(exception if self._is_raisable(exception) + else AsyncStream.STOP_ITERATION) + + async def generator( + self + ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + finished = False + try: + while True: + result = await self._queue.get() + if self._is_raisable(result): + finished = True + if result == AsyncStream.STOP_ITERATION: + return + raise result + yield result + finally: + self._finished = True + if not finished: + self._cancel(self.request_id) + + @staticmethod + def _is_raisable(value: Any): + return isinstance(value, BaseException) or \ + (isinstance(value, type) and \ + issubclass(value, BaseException)) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py new file mode 100644 index 0000000000000..f9d3473d0131c --- /dev/null +++ b/vllm/v1/engine/core.py @@ -0,0 +1,352 @@ +import multiprocessing +import queue +import threading +import time +from contextlib import contextmanager +from multiprocessing.process import BaseProcess +from multiprocessing.sharedctypes import Synchronized +from typing import Any, Iterator, List, Tuple, Type, Union + +import zmq +import zmq.asyncio +from msgspec import msgpack + +from vllm.config import CacheConfig, VllmConfig +from vllm.logger import init_logger +from vllm.usage.usage_lib import UsageContext +from vllm.v1.core.scheduler import Scheduler +from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, + EngineCoreRequest, EngineCoreRequestType) +from vllm.v1.executor.gpu_executor import GPUExecutor +from vllm.v1.request import Request, RequestStatus +from vllm.version import __version__ as VLLM_VERSION + +logger = init_logger(__name__) + +POLLING_TIMEOUT_MS = 5000 +POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 +LOGGING_TIME_S = 5000 + + +class EngineCore: + """Inner loop of vLLM's Engine.""" + + def __init__( + self, + vllm_config: VllmConfig, + executor_class: Type[GPUExecutor], + usage_context: UsageContext, + ): + # Override the configs for V1. + # FIXME + if usage_context == UsageContext.LLM_CLASS: + vllm_config.scheduler_config.max_num_seqs = 1024 + vllm_config.scheduler_config.max_num_batched_tokens = 8192 + elif usage_context == UsageContext.OPENAI_API_SERVER: + vllm_config.scheduler_config.max_num_seqs = 1024 + vllm_config.scheduler_config.max_num_batched_tokens = 2048 + + # TODO (ywang96): Enable APC by default when VLM supports it. 
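Back in async_stream.py above, AsyncStream is a thin wrapper over an asyncio.Queue: the output handler put()s RequestOutputs, finish() enqueues a stop sentinel, the consumer iterates generator(), and if the consumer goes away first the cancel callback fires with the request id. A self-contained sketch of that handoff (the values are illustrative):

    import asyncio

    from vllm.outputs import RequestOutput
    from vllm.v1.engine.async_stream import AsyncStream

    async def demo():
        cancelled = []
        stream = AsyncStream("req-0", cancelled.append)

        # Producer side (the output handler loop).
        stream.put(RequestOutput.new(request_id="req-0", prompt="Hi",
                                     prompt_token_ids=[1], text=" there",
                                     token_ids=[42], finished=True))
        stream.finish()

        # Consumer side (the per-request API handler).
        async for item in stream.generator():
            print(item.outputs[0].text)

        assert not cancelled  # stream finished normally, no cancellation

    asyncio.run(demo())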
+ if not vllm_config.model_config.is_multimodal_model: + vllm_config.cache_config.enable_prefix_caching = True + + assert vllm_config.model_config.task != "embedding" + + logger.info("Initializing an LLM engine (v%s) with config: %s", + VLLM_VERSION, vllm_config) + + # Setup Model. + self.model_executor = executor_class(vllm_config) + + # Setup KV Caches and update CacheConfig after profiling. + num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches( + vllm_config.cache_config) + vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks + vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks + + # Setup scheduler. + self.scheduler = Scheduler(vllm_config.scheduler_config, + vllm_config.cache_config, + vllm_config.lora_config) + + self._last_logging_time = time.time() + + def _initialize_kv_caches(self, + cache_config: CacheConfig) -> Tuple[int, int]: + num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks( + ) + + if cache_config.num_gpu_blocks_override is not None: + num_gpu_blocks_override = cache_config.num_gpu_blocks_override + logger.info( + "Overriding num_gpu_blocks=%d with " + "num_gpu_blocks_override=%d", num_gpu_blocks, + num_gpu_blocks_override) + num_gpu_blocks = num_gpu_blocks_override + + num_cpu_blocks = 0 + self.model_executor.initialize_cache(num_gpu_blocks) + return num_gpu_blocks, num_cpu_blocks + + def add_request(self, request: EngineCoreRequest): + """Add request to the scheduler.""" + + req = Request.from_engine_core_request(request) + self.scheduler.add_request(req) + + def abort_requests(self, request_ids: List[str]): + """Abort requests from the scheduler.""" + + # TODO: The scheduler doesn't really need to know the + # specific finish reason, TBD whether we propagate that + # (i.e. client-aborted vs stop criteria met). + self.scheduler.finish_requests(request_ids, + RequestStatus.FINISHED_ABORTED) + + def step(self) -> List[EngineCoreOutput]: + """Schedule, execute, and make output.""" + + if not self.scheduler.has_unfinished_requests(): + return [] + + scheduler_output = self.scheduler.schedule() + output = self.model_executor.execute_model(scheduler_output) + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, output) + return engine_core_outputs + + +class EngineCoreProc(EngineCore): + """ZMQ-wrapper for running EngineCore in background process.""" + + READY_STR = "READY" + + def __init__( + self, + vllm_config: VllmConfig, + executor_class: Type[GPUExecutor], + usage_context: UsageContext, + input_path: str, + output_path: str, + ready_path: str, + should_shutdown: Synchronized, + ): + super().__init__(vllm_config, executor_class, usage_context) + + # Signal from main process to shutdown (multiprocessing.Value). + self.should_shutdown = should_shutdown + + # Background Threads and Queues for IO. These enable us to + # overlap ZMQ socket IO with GPU since they release the GIL, + # and to overlap some serialization/deserialization with the + # model forward pass. + # Threads handle Socket <-> Queues and core_busy_loop uses Queue. + self.input_queue = queue.Queue() + self.output_queue = queue.Queue() + threading.Thread(target=self.process_input_socket, + args=(input_path, ), + daemon=True).start() + threading.Thread(target=self.process_output_socket, + args=(output_path, ), + daemon=True).start() + + # Send Readiness signal to EngineClient. 
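The readiness signal is a one-shot PUSH/PULL handshake over an IPC path: the engine process binds a PUSH socket on ready_path and sends READY_STR (the code right below), while wait_for_startup() on the client side polls a connected PULL socket with a timeout so a dead child process is noticed instead of hanging forever. Stripped of the vLLM specifics, the pattern is (a sketch; the path is illustrative):

    import zmq

    READY = "READY"
    PATH = "ipc:///tmp/engine_ready"  # illustrative IPC path

    def announce_ready():
        """Engine-process side: bind and push once initialization is done."""
        ctx = zmq.Context()
        sock = ctx.socket(zmq.PUSH)
        sock.bind(PATH)
        sock.send_string(READY)
        ctx.destroy(linger=0)

    def wait_ready(timeout_ms: int = 5000):
        """Client side: poll with a timeout so a crashed child can be detected."""
        ctx = zmq.Context()
        sock = ctx.socket(zmq.PULL)
        sock.connect(PATH)
        try:
            while sock.poll(timeout=timeout_ms) == 0:
                pass  # the real code checks proc.is_alive() here
            assert sock.recv_string() == READY
        finally:
            ctx.destroy(linger=0)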
+ with self.make_socket(ready_path, zmq.constants.PUSH) as ready_socket: + ready_socket.send_string(EngineCoreProc.READY_STR) + + @contextmanager + def make_socket(self, path: str, type: Any) -> Iterator[zmq.Socket]: + """Context manager for use """ + + ctx = zmq.Context() + try: + socket = ctx.socket(type) + + if type == zmq.constants.PULL: + socket.connect(path) + elif type == zmq.constants.PUSH: + socket.bind(path) + else: + raise ValueError(f"Unknown Socket Type: {type}") + + yield socket + + except KeyboardInterrupt: + logger.debug("EngineCore had Keyboard Interrupt.") + + finally: + ctx.destroy(linger=0) + + @staticmethod + def wait_for_startup( + proc: BaseProcess, + ready_path: str, + ) -> None: + """Wait until the EngineCore is ready.""" + + try: + sync_ctx = zmq.Context() # type: ignore[attr-defined] + socket = sync_ctx.socket(zmq.constants.PULL) + socket.connect(ready_path) + + # Wait for EngineCore to send EngineCoreProc.READY_STR. + while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + logger.debug("Waiting for EngineCoreProc to startup.") + + if not proc.is_alive(): + raise RuntimeError("EngineCoreProc failed to start.") + + message = socket.recv_string() + assert message == EngineCoreProc.READY_STR + + except BaseException as e: + logger.exception(e) + raise e + + finally: + sync_ctx.destroy(linger=0) + + @staticmethod + def make_engine_core_process( + vllm_config: VllmConfig, + executor_class: Type[GPUExecutor], + usage_context: UsageContext, + input_path: str, + output_path: str, + ready_path: str, + should_shutdown: Synchronized, + ) -> BaseProcess: + # The current process might have CUDA context, + # so we need to spawn a new process. + # NOTE(rob): this is a problem for using EngineCoreProc w/ + # LLM, since we need a if __name__ == "__main__" guard. + context = multiprocessing.get_context("spawn") + + process_kwargs = { + "input_path": input_path, + "output_path": output_path, + "ready_path": ready_path, + "vllm_config": vllm_config, + "executor_class": executor_class, + "usage_context": usage_context, + "should_shutdown": should_shutdown + } + # Run EngineCore busy loop in background process. + proc = context.Process(target=EngineCoreProc.run_engine_core, + kwargs=process_kwargs) + proc.start() + + # Wait for startup + EngineCoreProc.wait_for_startup(proc, ready_path) + return proc + + @staticmethod + def run_engine_core(*args, **kwargs): + """Launch EngineCore busy loop in background process.""" + + try: + engine_core = EngineCoreProc(*args, **kwargs) + engine_core.run_busy_loop() + + except KeyboardInterrupt: + logger.debug("EngineCore interrupted.") + + except BaseException as e: + logger.exception(e) + raise e + + def run_busy_loop(self): + """Core busy loop of the EngineCore.""" + + # Loop until we get a shutdown signal. + while not self.should_shutdown: + # 1) Poll the input queue until there is work to do. + if not self.scheduler.has_unfinished_requests(): + while True: + try: + req = self.input_queue.get(timeout=POLLING_TIMEOUT_S) + self._handle_client_request(req) + break + except queue.Empty: + self._log_stats() + logger.debug("EngineCore busy loop waiting.") + if self.should_shutdown: + return + + # 2) Handle any new client requests (Abort or Add). + while not self.input_queue.empty(): + req = self.input_queue.get_nowait() + self._handle_client_request(req) + + # 3) Step the engine core. + outputs = self.step() + + # 4) Put EngineCoreOutputs into the output queue. 
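# Illustrative sketch (editor's example, not part of the patch): the READY
# handshake used by make_engine_core_process/wait_for_startup above. The child
# binds a PUSH socket on ready_path and sends READY_STR; the parent connects a
# PULL socket and polls until the message arrives. The ipc path is a made-up
# example, and the two functions would run in different processes.
import zmq

READY_STR = "READY"
READY_PATH = "ipc:///tmp/engine_core_ready_demo"  # hypothetical path


def child_signal_ready() -> None:
    ctx = zmq.Context()
    try:
        sock = ctx.socket(zmq.PUSH)
        sock.bind(READY_PATH)
        sock.send_string(READY_STR)
    finally:
        ctx.destroy(linger=0)


def parent_wait_ready(timeout_ms: int = 5000) -> None:
    ctx = zmq.Context()
    try:
        sock = ctx.socket(zmq.PULL)
        sock.connect(READY_PATH)
        while sock.poll(timeout=timeout_ms) == 0:
            print("Waiting for EngineCoreProc to start up.")
        assert sock.recv_string() == READY_STR
    finally:
        ctx.destroy(linger=0)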
+ self.output_queue.put_nowait(outputs) + + self._log_stats() + + def _log_stats(self): + """Log basic stats every LOGGING_TIME_S""" + + now = time.time() + + if now - self._last_logging_time > LOGGING_TIME_S: + logger.info( + "RUNNING: %s | WAITING: %s", + len(self.scheduler.running), + len(self.scheduler.waiting), + ) + + self._last_logging_time = now + + def _handle_client_request( + self, request: Union[EngineCoreRequest, List[str]]) -> None: + """Handle EngineCoreRequest or EngineCoreABORT from Client.""" + + if isinstance(request, EngineCoreRequest): + self.add_request(request) + else: + # TODO: make an EngineCoreAbort wrapper + assert isinstance(request, list) + self.abort_requests(request) + + def process_input_socket(self, input_path: str): + """Input socket IO thread.""" + + # Msgpack serialization decoding. + decoder_add_req = msgpack.Decoder(EngineCoreRequest) + decoder_abort_req = msgpack.Decoder(list[str]) + + with self.make_socket(input_path, zmq.constants.PULL) as socket: + while True: + # (RequestType, RequestData) + type_frame, data_frame = socket.recv_multipart(copy=False) + request_type = type_frame.buffer + request_data = data_frame.buffer + + # Deserialize the request data. + if request_type == EngineCoreRequestType.ADD.value: + request = decoder_add_req.decode(request_data) + elif request_type == EngineCoreRequestType.ABORT.value: + request = decoder_abort_req.decode(request_data) + else: + raise ValueError(f"Unknown RequestType: {request_type}") + + # Push to input queue for core busy loop. + self.input_queue.put_nowait(request) + + def process_output_socket(self, output_path: str): + """Output socket IO thread.""" + + # Msgpack serialization encoding. + encoder = msgpack.Encoder() + # Reuse send buffer. + buffer = bytearray() + + with self.make_socket(output_path, zmq.constants.PUSH) as socket: + while True: + engine_core_outputs = self.output_queue.get() + outputs = EngineCoreOutputs(outputs=engine_core_outputs) + encoder.encode_into(outputs, buffer) + socket.send_multipart((buffer, ), copy=False) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py new file mode 100644 index 0000000000000..f9e4677fb8c59 --- /dev/null +++ b/vllm/v1/engine/core_client.py @@ -0,0 +1,218 @@ +import multiprocessing +import time +from typing import List, Union + +import msgspec +import zmq +import zmq.asyncio + +from vllm.logger import init_logger +from vllm.utils import get_open_zmq_ipc_path +from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, + EngineCoreRequest, EngineCoreRequestType) +from vllm.v1.engine.core import EngineCore, EngineCoreProc + +logger = init_logger(__name__) + + +class EngineCoreClient: + """ + EngineCoreClient: subclasses handle different methods for pushing + and pulling from the EngineCore for asyncio / multiprocessing. + + Subclasses: + * InprocClient: In process EngineCore (for V0-style LLMEngine use) + * SyncMPClient: ZMQ + background proc EngineCore (for LLM) + * AsyncMPClient: ZMQ + background proc EngineCore w/ asyncio (for AsyncLLM) + """ + + @staticmethod + def make_client( + *args, + multiprocess_mode: bool, + asyncio_mode: bool, + **kwargs, + ) -> "EngineCoreClient": + + # TODO: support this for debugging purposes. 
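# Illustrative sketch (editor's example, not part of the patch): the msgspec
# msgpack round trip used on the output path above, with stand-in structs rather
# than the real EngineCoreOutput/EngineCoreOutputs definitions.
import msgspec


class DemoOutput(msgspec.Struct):
    request_id: str
    new_token_ids: list[int]


class DemoOutputs(msgspec.Struct):
    outputs: list[DemoOutput]


encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(DemoOutputs)

buffer = bytearray()  # reusable send buffer, as in process_output_socket
encoder.encode_into(DemoOutputs(outputs=[DemoOutput("req-0", [1, 2, 3])]), buffer)
decoded = decoder.decode(buffer)
assert decoded.outputs[0].new_token_ids == [1, 2, 3]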
+ if asyncio_mode and not multiprocess_mode: + raise NotImplementedError( + "Running EngineCore in asyncio without multiprocessing " + "is not currently supported.") + + if multiprocess_mode and asyncio_mode: + return AsyncMPClient(*args, **kwargs) + + if multiprocess_mode and not asyncio_mode: + return SyncMPClient(*args, **kwargs) + + return InprocClient(*args, **kwargs) + + def shutdown(self): + pass + + def get_output(self) -> List[EngineCoreOutput]: + raise NotImplementedError + + def add_request(self, request: EngineCoreRequest) -> None: + raise NotImplementedError + + def abort_requests(self, request_ids: List[str]) -> None: + raise NotImplementedError + + async def get_output_async(self) -> List[EngineCoreOutput]: + raise NotImplementedError + + async def add_request_async(self, request: EngineCoreRequest) -> None: + raise NotImplementedError + + async def abort_requests_async(self, request_ids: List[str]) -> None: + raise NotImplementedError + + +class InprocClient(EngineCoreClient): + """ + InprocClient: client for in-process EngineCore. Intended + for use in LLMEngine for V0-style add_request() and step() + EngineCore setup in this process (no busy loop). + + * pushes EngineCoreRequest directly into the EngineCore + * pulls EngineCoreOutputs by stepping the EngineCore + + TODO: support asyncio-mode for debugging. + """ + + def __init__(self, *args, **kwargs): + self.engine_core = EngineCore(*args, **kwargs) + + def get_output(self) -> List[EngineCoreOutput]: + return self.engine_core.step() + + def add_request(self, request: EngineCoreRequest) -> None: + self.engine_core.add_request(request) + + def abort_requests(self, request_ids: List[str]) -> None: + self.engine_core.abort_requests(request_ids) + + +class MPClient(EngineCoreClient): + """ + MPClient: base client for multi-proc EngineCore. + EngineCore runs in a background process busy loop, getting + new EngineCoreRequests and returning EngineCoreOutputs + + * pushes EngineCoreRequests via input_socket + * pulls EngineCoreOutputs via output_socket + + * AsyncMPClient subclass for AsyncLLM usage + * SyncMPClient subclass for LLM usage + """ + + def __init__( + self, + *args, + asyncio_mode: bool, + **kwargs, + ): + # Serialization setup. + self.encoder = msgspec.msgpack.Encoder() + self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs) + + # ZMQ setup. + self.ctx = (zmq.asyncio.Context() if asyncio_mode else zmq.Context()) + + # Path for IPC. + ready_path = get_open_zmq_ipc_path() + output_path = get_open_zmq_ipc_path() + input_path = get_open_zmq_ipc_path() + + # Get output (EngineCoreOutput) from EngineCore. + self.output_socket = self.ctx.socket(zmq.constants.PULL) + self.output_socket.connect(output_path) + + # Send input (EngineCoreRequest) to EngineCore. + self.input_socket = self.ctx.socket(zmq.constants.PUSH) + self.input_socket.bind(input_path) + + # Start EngineCore in background process. + self.should_shutdown = multiprocessing.Value('b', False, lock=False) + self.proc = EngineCoreProc.make_engine_core_process( + *args, + input_path=input_path, + output_path=output_path, + ready_path=ready_path, + should_shutdown=self.should_shutdown, + **kwargs, + ) + + def shutdown(self): + # Send shutdown signal to background process. + self.should_shutdown = True + + # Shut down the zmq context. + self.ctx.destroy(linger=0) + + # Shutdown the process if needed. 
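# Illustrative sketch (editor's example, not part of the patch): the three client
# flavours make_client chooses between, written as a plain dispatch helper so the
# mapping is easy to read. Argument plumbing is omitted.
def pick_client_kind(multiprocess_mode: bool, asyncio_mode: bool) -> str:
    if asyncio_mode and not multiprocess_mode:
        raise NotImplementedError(
            "Running EngineCore in asyncio without multiprocessing "
            "is not currently supported.")
    if multiprocess_mode and asyncio_mode:
        return "AsyncMPClient"  # AsyncLLM
    if multiprocess_mode:
        return "SyncMPClient"   # LLM with multiprocessing enabled
    return "InprocClient"       # V0-style LLMEngine


assert pick_client_kind(True, True) == "AsyncMPClient"
assert pick_client_kind(True, False) == "SyncMPClient"
assert pick_client_kind(False, False) == "InprocClient"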
+ if hasattr(self, "proc") and self.proc.is_alive(): + self.proc.terminate() + + time.sleep(5) + if self.proc.is_alive(): + self.proc.kill() + + def __del__(self): + self.shutdown() + + +class SyncMPClient(MPClient): + """Synchronous client for multi-proc EngineCore.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, asyncio_mode=False, **kwargs) + + def get_output(self) -> List[EngineCoreOutput]: + + (frame, ) = self.output_socket.recv_multipart(copy=False) + engine_core_outputs = self.decoder.decode(frame.buffer).outputs + return engine_core_outputs + + def _send_input(self, request_type: EngineCoreRequestType, + request: Union[EngineCoreRequest, List[str]]) -> None: + + # (RequestType, SerializedRequest) + msg = (request_type.value, self.encoder.encode(request)) + self.input_socket.send_multipart(msg, copy=False) + + def add_request(self, request: EngineCoreRequest) -> None: + self._send_input(EngineCoreRequestType.ADD, request) + + def abort_requests(self, request_ids: List[str]) -> None: + self._send_input(EngineCoreRequestType.ABORT, request_ids) + + +class AsyncMPClient(MPClient): + """Asyncio-compatible client for multi-proc EngineCore.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, asyncio_mode=True, **kwargs) + + async def get_output_async(self) -> List[EngineCoreOutput]: + + frames = await self.output_socket.recv_multipart(copy=False) + engine_core_outputs = self.decoder.decode(frames[0].buffer).outputs + + return engine_core_outputs + + async def _send_input( + self, request_type: EngineCoreRequestType, + request: Union[EngineCoreRequest, List[str]]) -> None: + + msg = (request_type.value, self.encoder.encode(request)) + await self.input_socket.send_multipart(msg, copy=False) + + async def add_request_async(self, request: EngineCoreRequest) -> None: + await self._send_input(EngineCoreRequestType.ADD, request) + + async def abort_requests_async(self, request_ids: List[str]) -> None: + if len(request_ids) > 0: + await self._send_input(EngineCoreRequestType.ABORT, request_ids) diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py new file mode 100644 index 0000000000000..1dbf8e75ec478 --- /dev/null +++ b/vllm/v1/engine/detokenizer.py @@ -0,0 +1,265 @@ +from dataclasses import dataclass +from typing import Dict, Iterable, List, Optional, Tuple + +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.sampling_params import RequestOutputKind +from vllm.transformers_utils.detokenizer_utils import ( + AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally) +from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.v1.engine import DetokenizerRequest, EngineCoreOutput + +logger = init_logger(__name__) + + +@dataclass +class IncrementalDetokenizer: + + # Generation data + output_text: str + tokens: List[str] + token_ids: List[int] + + # Stop strings + stop: List[str] + include_stop_str_in_output: bool + + # Metadata for incremental detokenization + prefix_offset: int + read_offset: int + + # Parameters for detokenization + skip_special_tokens: bool + spaces_between_special_tokens: bool + output_kind: RequestOutputKind + + # TODO: Probably decouple these + request_id: str + prompt: Optional[str] + prompt_token_ids: List[int] + + # Tokenizer for this request + tokenizer: AnyTokenizer + + # Accounting for stop string buffering + stop_buffer_length: int + _last_output_text_offset: int = 0 + + @property + 
def output_token_ids(self) -> List[int]: + assert len(self.token_ids) >= len(self.prompt_token_ids) + return self.token_ids[len(self.prompt_token_ids):] + + @classmethod + def from_new_request( + cls, + tokenizer: AnyTokenizer, + request: DetokenizerRequest, + ) -> "IncrementalDetokenizer": + + tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens( + tokenizer=tokenizer, + prompt_ids=request.prompt_token_ids, + skip_special_tokens=request.skip_special_tokens, + ) + + stops = request.stop + # Number of chars to hold back when stop strings are to be excluded + # from streamed output. + if stops and not request.include_stop_str_in_output: + stop_buffer_length = max(len(s) for s in stops) - 1 + else: + stop_buffer_length = 0 + + return cls( + output_text="", + tokens=tokens, + # Detokenizer mutates this list, so need a unique copy. + # NOTE(Nick): could we take ownership of it though? + token_ids=request.prompt_token_ids.copy(), + stop=stops, + include_stop_str_in_output=request.include_stop_str_in_output, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=request.skip_special_tokens, + spaces_between_special_tokens=request. + spaces_between_special_tokens, + output_kind=request.output_kind, + request_id=request.request_id, + prompt=request.prompt, + prompt_token_ids=request.prompt_token_ids, + tokenizer=tokenizer, + stop_buffer_length=stop_buffer_length, + ) + + def add_tokens( + self, + new_token_ids: List[int], + finish_reason: Optional[str], + stop_reason: Optional[str], + ) -> Optional[RequestOutput]: + """ + Update RequestState for the request_id by: + 1) Detokenize the new token ids incrementally. + 2) Update the RequestOutput with the new text. + """ + + # 1) Detokenize the new token ids incrementally. + # TODO(woosuk): This method becomes very inefficient when the number of + # new_token_ids is more than 1. We need to optimize this. + decoded_text = "" + for new_token_id in new_token_ids: + self.token_ids.append(new_token_id) + (new_tokens, new_decoded_token_text, prefix_offset, + read_offset) = detokenize_incrementally( + tokenizer=self.tokenizer, + all_input_ids=self.token_ids, + prev_tokens=self.tokens, + prefix_offset=self.prefix_offset, + read_offset=self.read_offset, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self. + spaces_between_special_tokens, + ) + + self.tokens.extend(new_tokens) + self.prefix_offset = prefix_offset + self.read_offset = read_offset + self.output_text += new_decoded_token_text + + decoded_text += new_decoded_token_text + + # 2) Evaluate stop criteria. + if self.stop: + stop = StopChecker.check_stop_strings( + output_text=self.output_text, + new_char_count=len(decoded_text), + stop=self.stop, + include_in_output=self.include_stop_str_in_output, + ) + if stop is not None: + stop_str, truncate_to = stop + if truncate_to != -1: + self.output_text = self.output_text[:truncate_to] + finish_reason = "stop" # TODO: use constant + stop_reason = stop_str + + # TODO: handle stop_token_ids here too? + + # 3) Update the RequestOutput object with the new text. 
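# Illustrative sketch (editor's example, not part of the patch): the stop-string
# check invoked above, assuming StopChecker.check_stop_strings (extended elsewhere
# in this PR) returns (matched_stop_str, truncate_to) or None, as the call above
# implies. The output text and stop strings are made up.
from vllm.engine.output_processor.stop_checker import StopChecker

output_text = "The answer is 42.\nDone"
new_char_count = 5  # "\nDone" was just appended
stop = ["\nDone", "###"]

hit = StopChecker.check_stop_strings(
    output_text=output_text,
    new_char_count=new_char_count,
    stop=stop,
    include_in_output=False,
)
if hit is not None:
    stop_str, truncate_to = hit
    if truncate_to != -1:
        output_text = output_text[:truncate_to]  # strip the stop string
    print(f"stopped on {stop_str!r}: {output_text!r}")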
+ finished = bool(finish_reason) + if self.output_kind == RequestOutputKind.FINAL_ONLY \ + and not finished: + return None + + delta = self.output_kind == RequestOutputKind.DELTA + output_text = self._get_next_output_text(finished, delta) + token_ids = new_token_ids if delta else self.output_token_ids + + request_output = RequestOutput.new( + self.request_id, + self.prompt, + self.prompt_token_ids, + output_text, + token_ids, + finished, + ) + + if finished: + completion_output = request_output.outputs[0] + completion_output.finish_reason = finish_reason + completion_output.stop_reason = stop_reason + + return request_output + + def _get_next_output_text(self, finished: bool, delta: bool) -> str: + """If delta is True, only new text since the last call to + this method is returned""" + + # We return the full output text if the sequence is finished. + buffer_length = 0 if finished else self.stop_buffer_length + if not delta: + return self.output_text[:-buffer_length] if buffer_length else ( + self.output_text) + length = len(self.output_text) - buffer_length + last_offset = self._last_output_text_offset + if last_offset < length: + self._last_output_text_offset = length + return self.output_text[last_offset:length] + return "" + + +class Detokenizer: + + def __init__(self, tokenizer_name: str): + # TODO: once we support LoRA, we should should pass the tokenizer + # here. We currently have two copies (this + in the LLMEngine). + self.tokenizer = get_tokenizer(tokenizer_name) + + # Request id -> IncrementalDetokenizer + self.request_states: Dict[str, IncrementalDetokenizer] = {} + + def is_request_active(self, request_id: str): + return request_id in self.request_states + + def get_num_unfinished_requests(self): + return len(self.request_states) + + def has_unfinished_requests(self) -> bool: + return len(self.request_states) > 0 + + def abort_requests( + self, + request_ids: Iterable[str], + ) -> None: + """Remove the request_ids from the Detokenizer.""" + + for request_id in request_ids: + self.request_states.pop(request_id, None) + + def add_request( + self, + request: DetokenizerRequest, + ): + """Add new request to the Detokenizer.""" + + assert (request.request_id not in self.request_states) + + request_state = IncrementalDetokenizer.from_new_request( + self.tokenizer, request) + self.request_states[request.request_id] = request_state + + def step( + self, encore_core_outputs: List[EngineCoreOutput] + ) -> Tuple[List[RequestOutput], List[str]]: + """Update state and request the RequestOutputs to the LLMEngine.""" + + request_outputs: List[RequestOutput] = [] + requests_to_abort: List[str] = [] + for engine_core_output in encore_core_outputs: + request_id = engine_core_output.request_id + detokenizer = self.request_states.get(request_id) + if detokenizer is None: + # Ignore output for already-aborted request. + continue + + # Detokenize and update state. + request_output = detokenizer.add_tokens( + new_token_ids=engine_core_output.new_token_ids, + finish_reason=engine_core_output.finish_reason, + stop_reason=engine_core_output.stop_reason, + ) + + if request_output is not None: + # Add to RequestOutputs list. + request_outputs.append(request_output) + + # Free completed requests. + if request_output.finished: + self.request_states.pop(request_id) + if not engine_core_output.finished: + requests_to_abort.append(request_id) + + # Return to EngineClient. 
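# Illustrative sketch (editor's example, not part of the patch): the hold-back
# logic of _get_next_output_text above, lifted into a tiny standalone class. With
# stop=["STOP"] the detokenizer keeps len("STOP") - 1 = 3 trailing chars back
# while streaming deltas, so a partially generated stop string never reaches the
# client before the stop check runs.
class _DeltaStreamer:

    def __init__(self, stop_buffer_length: int) -> None:
        self.output_text = ""
        self.stop_buffer_length = stop_buffer_length
        self._last_offset = 0

    def next_delta(self, finished: bool) -> str:
        buffer_length = 0 if finished else self.stop_buffer_length
        length = len(self.output_text) - buffer_length
        if self._last_offset < length:
            delta = self.output_text[self._last_offset:length]
            self._last_offset = length
            return delta
        return ""


s = _DeltaStreamer(stop_buffer_length=len("STOP") - 1)
s.output_text += "Hello ST"
print(repr(s.next_delta(finished=False)))  # 'Hello' (last 3 chars ' ST' held back)
s.output_text += "ILL going"
print(repr(s.next_delta(finished=False)))  # ' STILL go'
print(repr(s.next_delta(finished=True)))   # 'ing'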
+ return request_outputs, requests_to_abort diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 38d95ab44bb90..f37db92e8ea6b 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -1,35 +1,28 @@ -import time -from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, - Union) +from typing import Dict, List, Mapping, Optional, Type, Union -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ObservabilityConfig, ParallelConfig, SchedulerConfig, - VllmConfig) +from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics_types import StatLoggerBase -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, - EncoderDecoderLLMInputs, InputRegistry, PromptType) -from vllm.inputs.preprocess import InputPreprocessor +from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING +from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import CompletionOutput, RequestOutput +from vllm.outputs import RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import RequestOutputKind, SamplingParams -from vllm.transformers_utils.config import try_get_generation_config -from vllm.transformers_utils.tokenizer_group import ( - BaseTokenizerGroup, init_tokenizer_from_configs) +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext -from vllm.v1.core.scheduler import Scheduler +from vllm.v1.engine.core_client import EngineCoreClient +from vllm.v1.engine.detokenizer import Detokenizer +from vllm.v1.engine.processor import Processor from vllm.v1.executor.gpu_executor import GPUExecutor -from vllm.v1.request import Request, RequestStatus -from vllm.v1.tokenizer.detokenizer import Detokenizer, DetokenizerInputs -from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) class LLMEngine: + """Legacy LLMEngine for backwards compatibility.""" def __init__( self, @@ -40,146 +33,36 @@ class LLMEngine: stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, use_cached_outputs: bool = False, + multiprocess_mode: bool = False, ) -> None: - # TODO: remove the local variables and use self.* throughout the class. - model_config = self.model_config = vllm_config.model_config - cache_config = self.cache_config = vllm_config.cache_config - lora_config = self.lora_config = vllm_config.lora_config - parallel_config = self.parallel_config = vllm_config.parallel_config - scheduler_config = self.scheduler_config = vllm_config.scheduler_config - device_config = self.device_config = vllm_config.device_config - speculative_config = self.speculative_config = vllm_config.speculative_config # noqa - load_config = self.load_config = vllm_config.load_config - decoding_config = self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa + # TODO: Can we avoid this? + self.model_config = vllm_config.model_config + + # Tokenizer (+ ensure liveness if running in another process). 
+ self.tokenizer = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + parallel_config=vllm_config.parallel_config, + enable_lora=bool(vllm_config.lora_config)) + self.tokenizer.ping() + + # Processor (convert Inputs --> EngineCoreRequests) + self.processor = Processor(vllm_config.model_config, + vllm_config.lora_config, self.tokenizer, + input_registry) + + # Detokenizer (converts EngineCoreOutputs --> RequestOutput) + self.detokenizer = Detokenizer(vllm_config.model_config.tokenizer) + + # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs) + self.engine_core = EngineCoreClient.make_client( + vllm_config, + executor_class, + usage_context, + multiprocess_mode=multiprocess_mode, + asyncio_mode=False, ) - prompt_adapter_config = self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa - observability_config = self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa - ) - - # Override the configs for V1. - # FIXME - if usage_context == UsageContext.LLM_CLASS: - scheduler_config.max_num_seqs = 1024 - scheduler_config.max_num_batched_tokens = 8192 - elif usage_context == UsageContext.OPENAI_API_SERVER: - scheduler_config.max_num_seqs = 1024 - scheduler_config.max_num_batched_tokens = 2048 - - # TODO (ywang96): Enable APC by default when VLM supports it. - if not model_config.is_multimodal_model: - cache_config.enable_prefix_caching = True - - logger.info( - "Initializing an LLM engine (v%s) with config: " - "model=%r, speculative_config=%r, tokenizer=%r, " - "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " - "override_neuron_config=%s, tokenizer_revision=%s, " - "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " - "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " - "pipeline_parallel_size=%d, " - "disable_custom_all_reduce=%s, quantization=%s, " - "enforce_eager=%s, kv_cache_dtype=%s, " - "quantization_param_path=%s, device_config=%s, " - "decoding_config=%r, observability_config=%r, " - "seed=%d, served_model_name=%s, " - "num_scheduler_steps=%d, enable_prefix_caching=%s, " - "use_async_output_proc=%s, mm_processor_kwargs=%s)", - VLLM_VERSION, - model_config.model, - speculative_config, - model_config.tokenizer, - model_config.skip_tokenizer_init, - model_config.tokenizer_mode, - model_config.revision, - model_config.override_neuron_config, - model_config.tokenizer_revision, - model_config.trust_remote_code, - model_config.dtype, - model_config.max_model_len, - load_config.download_dir, - load_config.load_format, - parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.disable_custom_all_reduce, - model_config.quantization, - model_config.enforce_eager, - cache_config.cache_dtype, - model_config.quantization_param_path, - device_config.device, - decoding_config, - observability_config, - model_config.seed, - model_config.served_model_name, - scheduler_config.num_scheduler_steps, - cache_config.enable_prefix_caching, - model_config.use_async_output_proc, - model_config.mm_processor_kwargs, - ) - - self.log_stats = log_stats - - assert not self.model_config.skip_tokenizer_init - self.tokenizer = self._init_tokenizer() - if self.tokenizer: - # Ping the tokenizer to ensure liveness if it runs in a - # different process. 
- self.tokenizer.ping() - self.detokenizer = Detokenizer( - tokenizer_name=self.model_config.tokenizer, - tokenizer_mode=self.model_config.tokenizer_mode, - trust_remote_code=self.model_config.trust_remote_code) - - self.generation_config_fields = _load_generation_config_dict( - model_config) - self.input_preprocessor = InputPreprocessor(model_config, - self.tokenizer) - self.input_registry = input_registry - self.input_processor = input_registry.create_input_processor( - model_config) - - # Request id -> Request - self.requests: Dict[str, Request] = {} - # NOTE(woosuk): Now that the detokenizer works asynchronously, we need - # to keep track of how many steps each request has been lagged behind - # in terms of detokenization. - # Request id -> how many detokenizer steps the request should wait for. - self.num_lagged_steps: Dict[str, int] = {} - # OPTIMIZATION: Cache the request output and update it incrementally. - # This is used to avoid creating a new RequestOutput object every step. - # Request id -> RequestOutput - self.request_outputs: Dict[str, RequestOutput] = {} - - self.model_executor = executor_class(vllm_config=vllm_config) - assert self.model_config.task != "embedding" - self._initialize_kv_caches() - - # Create the scheduler. - # NOTE: the cache_config here have been updated with the numbers of - # GPU and CPU blocks, which are profiled in the distributed executor. - self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) - - def __del__(self): - # Small hack- implicit clean up of resources on garbage collect - # TODO: this should probably be explicitly invoked when we're done with - # the engine - self.terminate_detokenizer() - - def _initialize_kv_caches(self) -> None: - num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks( - ) - - if self.cache_config.num_gpu_blocks_override is not None: - num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override - logger.info( - "Overriding num_gpu_blocks=%d with " - "num_gpu_blocks_override=%d", num_gpu_blocks, - num_gpu_blocks_override) - num_gpu_blocks = num_gpu_blocks_override - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = 0 - self.model_executor.initialize_cache(num_gpu_blocks) @classmethod def from_engine_args( @@ -187,71 +70,49 @@ class LLMEngine: engine_args: EngineArgs, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + enable_multiprocessing: bool = False, ) -> "LLMEngine": """Creates an LLM engine from the engine arguments.""" + # Create the engine configs. - engine_config = engine_args.create_engine_config() - executor_class = cls._get_executor_cls(engine_config) - # Create the LLM engine. 
- engine = cls( - vllm_config=engine_config, - executor_class=executor_class, - log_stats=not engine_args.disable_log_stats, - usage_context=usage_context, - stat_loggers=stat_loggers, - ) - return engine + vllm_config = engine_args.create_engine_config() + executor_class = cls._get_executor_cls(vllm_config) - def _init_tokenizer(self) -> BaseTokenizerGroup: - return init_tokenizer_from_configs( - model_config=self.model_config, - scheduler_config=self.scheduler_config, - parallel_config=self.parallel_config, - enable_lora=bool(self.lora_config)) + if VLLM_ENABLE_V1_MULTIPROCESSING: + logger.debug("Enabling multiprocessing for LLMEngine.") + enable_multiprocessing = True - def _verify_args(self) -> None: - self.model_config.verify_with_parallel_config(self.parallel_config) - self.cache_config.verify_with_parallel_config(self.parallel_config) - if self.lora_config: - self.lora_config.verify_with_model_config(self.model_config) - self.lora_config.verify_with_scheduler_config( - self.scheduler_config) - if self.prompt_adapter_config: - self.prompt_adapter_config.verify_with_model_config( - self.model_config) + # Create the LLMEngine. + return cls(vllm_config=vllm_config, + executor_class=executor_class, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context, + stat_loggers=stat_loggers, + multiprocess_mode=enable_multiprocessing) - def _add_processed_request( - self, - request_id: str, - processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderLLMInputs], - params: Union[SamplingParams, PoolingParams], - arrival_time: float, - lora_request: Optional[LoRARequest], - prompt_adapter_request: Optional[PromptAdapterRequest], - trace_headers: Optional[Mapping[str, str]] = None, - ) -> None: - assert prompt_adapter_request is None - assert trace_headers is None - self._validate_model_inputs(processed_inputs) - eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) - - # TODO(woosuk): Support embedding mode. - assert isinstance(params, SamplingParams) - sampling_params = params.clone() - sampling_params.update_from_generation_config( - self.generation_config_fields, eos_token_id) - - # TODO(woosuk): Check max_logprobs - # TODO(woosuk): Support encoder-decoder models. 
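# Illustrative sketch (editor's example, not part of the patch): constructing the
# V1 LLMEngine through from_engine_args as wired above. The __main__ guard matters
# when multiprocessing is on because EngineCoreProc is started with the "spawn"
# method; the model name is a placeholder.
from vllm.engine.arg_utils import EngineArgs
from vllm.v1.engine.llm_engine import LLMEngine

if __name__ == "__main__":
    engine = LLMEngine.from_engine_args(
        EngineArgs(model="facebook/opt-125m"),
        # Also forced on when the VLLM_ENABLE_V1_MULTIPROCESSING env flag is set.
        enable_multiprocessing=True,
    )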
- req = Request(request_id, processed_inputs, params, eos_token_id, - arrival_time) - self.requests[request_id] = req - self.num_lagged_steps[request_id] = 0 - self.scheduler.add_request(req) + @classmethod + def _get_executor_cls(cls, vllm_config: VllmConfig): + return GPUExecutor def stop_remote_worker_execution_loop(self) -> None: raise NotImplementedError("TP not implemented yet.") + def get_num_unfinished_requests(self) -> int: + return self.detokenizer.get_num_unfinished_requests() + + def has_unfinished_requests(self) -> bool: + return self.detokenizer.has_unfinished_requests() + + @classmethod + def validate_outputs(cls, outputs, output_type): + return outputs + + def abort_request(self, request_ids: List[str]) -> None: + """Remove request_ids from EngineCore and Detokenizer.""" + + self.engine_core.abort_requests(request_ids) + self.detokenizer.abort_requests(request_ids) + def add_request( self, request_id: str, @@ -263,261 +124,46 @@ class LLMEngine: prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> None: - if lora_request is not None and not self.lora_config: - raise ValueError(f"Got lora_request {lora_request} but LoRA is " - "not enabled!") - if arrival_time is None: - arrival_time = time.time() - assert priority == 0, "vLLM V1 does not support priority at the moment." - preprocessed_inputs = self.input_preprocessor.preprocess( - prompt, - request_id=request_id, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - ) - processed_inputs = self.input_processor(preprocessed_inputs) + # 1) Process raw inputs into the request. + detokenizer_req, engine_core_req = self.processor.process_inputs( + request_id, prompt, params, arrival_time, lora_request, + trace_headers, prompt_adapter_request, priority) - self._add_processed_request( - request_id=request_id, - processed_inputs=processed_inputs, - params=params, - arrival_time=arrival_time, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - trace_headers=trace_headers, - ) + # 2) Add the request to Detokenizer. + self.detokenizer.add_request(detokenizer_req) - def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: - self.scheduler.finish_requests(request_id, - RequestStatus.FINISHED_ABORTED) - self._free_request(request_id) - - def get_num_unfinished_requests(self) -> int: - """Gets the number of unfinished requests.""" - return len(self.requests) - - def has_unfinished_requests(self) -> bool: - """Returns True if there are unfinished requests.""" - return len(self.requests) > 0 + # 3) Add the request to EngineCore. + self.engine_core.add_request(engine_core_req) def step(self) -> List[RequestOutput]: - # NOTE(woosuk): This method may return an empty list when the - # detokenizer is still processing the outputs. This should not be - # considered as the end of the generation process. - # FIXME(woosuk): Currently, the step method is inefficient because it - # creates RequestOutput objects for all running requests, while they - # may not be needed unless the output is streamed to the client. 
- if self.scheduler.has_unfinished_requests(): - scheduler_output = self.scheduler.schedule() - output = self.model_executor.execute_model(scheduler_output) - sampled = self.scheduler.update_from_output( - scheduler_output, output) - self.send_to_detokenizer(sampled) - req_outputs = self.recv_from_detokenizer() - return req_outputs - def send_to_detokenizer(self, sampled: List[Tuple[Request, int]]) -> None: - inputs = DetokenizerInputs( - req_ids=[], - prompt_token_ids=[], - new_token_ids=[], - skip_special_tokens=[], - spaces_between_special_tokens=[], - free_req_ids=[], # TODO(woosuk): Implement freeing. - ) - for req, num_tokens in sampled: - inputs.req_ids.append(req.request_id) - if req.num_output_tokens == num_tokens: - # The request is first detokenized. - inputs.prompt_token_ids.append(req.prompt_token_ids) - else: - # The prompt token ids are already cached in the detokenizer. - inputs.prompt_token_ids.append([]) - inputs.new_token_ids.append(req.output_token_ids[-num_tokens:]) - inputs.skip_special_tokens.append( - req.sampling_params.skip_special_tokens) - inputs.spaces_between_special_tokens.append( - req.sampling_params.spaces_between_special_tokens) + # 1) Get EngineCoreOutput from the EngineCore. + engine_core_outputs = self.engine_core.get_output() - # Update the number of lagged steps. - self.num_lagged_steps[req.request_id] += 1 - self.detokenizer.send(inputs) + # 2) Detokenizer the EngineCoreOutput. + request_outputs, requests_to_abort = self.detokenizer.step( + engine_core_outputs) - def recv_from_detokenizer(self) -> List[RequestOutput]: - detokenizer_output = self.detokenizer.recv() - if detokenizer_output is None: - return [] + # 3) Abort requests that finished due to stopping criteria. + if requests_to_abort: + self.abort_request(requests_to_abort) - req_outputs: List[RequestOutput] = [] - num_reqs = len(detokenizer_output.req_ids) - for i in range(num_reqs): - req_id = detokenizer_output.req_ids[i] - if req_id not in self.requests: - # The request has been aborted while the detokenizer was - # processing the outputs. - continue + return request_outputs - req = self.requests[req_id] - req.output_text += detokenizer_output.detokenized_texts[i] + # TODO(rob): Can we get rid of these? - self.num_lagged_steps[req_id] -= 1 - finished = (self.num_lagged_steps[req_id] == 0 - and req.is_finished()) - req_output = self._make_request_output( - req, detokenizer_output.num_output_token_ids[i], - detokenizer_output.detokenized_texts[i], finished) - req_outputs.append(req_output) - - if finished: - self._free_request(req_id) - return req_outputs - - def terminate_detokenizer(self) -> None: - self.detokenizer.terminate() - - def _make_request_output( - self, - request: Request, - num_output_tokens: int, - new_output_text: str, - finished: bool, - ) -> RequestOutput: - req_output = self.request_outputs.get(request.request_id) - if req_output is None: - # TODO: Support `n` > 1. 
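# Illustrative sketch (editor's example, not part of the patch): driving the new
# Processor -> EngineCore -> Detokenizer pipeline through the public
# add_request()/step() loop shown above. Model, prompt and request id are
# placeholders.
from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.llm_engine import LLMEngine

if __name__ == "__main__":
    engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
    engine.add_request("req-0", "Hello, my name is",
                       SamplingParams(max_tokens=16))
    while engine.has_unfinished_requests():
        for request_output in engine.step():
            if request_output.finished:
                print(request_output.outputs[0].text)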
- completion_output = CompletionOutput( - index=0, - text="", - token_ids=[], - cumulative_logprob=None, - logprobs=None, # TODO - finish_reason=None, - stop_reason=None, - lora_request=None, - ) - req_output = RequestOutput( - request_id=request.request_id, - prompt=request.prompt, - prompt_token_ids=request.prompt_token_ids, - prompt_logprobs=None, # TODO - outputs=[completion_output], - finished=False, - metrics=None, - lora_request=None, - encoder_prompt=None, - encoder_prompt_token_ids=None, - ) - self.request_outputs[request.request_id] = req_output - - completion_output = req_output.outputs[0] - if request.sampling_params.output_kind == RequestOutputKind.CUMULATIVE: - completion_output.text += new_output_text - completion_output.token_ids = ( - request.output_token_ids[:num_output_tokens]) - elif request.sampling_params.output_kind == RequestOutputKind.DELTA: - completion_output.text = new_output_text - num_prev_tokens = len(completion_output.token_ids) - completion_output.token_ids = request.output_token_ids[ - num_prev_tokens:num_output_tokens] - elif (request.sampling_params.output_kind == - RequestOutputKind.FINAL_ONLY): - if finished: - completion_output.text = request.output_text - completion_output.token_ids = request.output_token_ids - else: - completion_output.text = "" - completion_output.token_ids = [] - - if finished: - completion_output.finish_reason = request.get_finished_reason() - completion_output.stop_reason = request.stop_reason - req_output.finished = finished - return req_output - - def _free_request(self, request_id: str) -> None: - self.requests.pop(request_id, None) - self.num_lagged_steps.pop(request_id, None) - self.request_outputs.pop(request_id, None) - - def check_health(self) -> None: - if self.tokenizer: - self.tokenizer.check_health() - self.model_executor.check_health() - - def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, - EncoderDecoderLLMInputs]): - prompt_ids = inputs.get("prompt_token_ids") - if prompt_ids is None or len(prompt_ids) == 0: - raise ValueError("Prompt cannot be empty") - - if self.model_config.is_multimodal_model: - max_prompt_len = self.model_config.max_model_len - - if len(prompt_ids) > max_prompt_len: - raise ValueError( - f"The prompt (total length {len(prompt_ids)}) is too long " - f"to fit into the model (context length {max_prompt_len}). " - "Make sure that `max_model_len` is no smaller than the " - "number of text tokens plus multimodal tokens. 
For image " - "inputs, the number of image tokens depends on the number " - "of images, and possibly their aspect ratios as well.") - - @classmethod - def validate_outputs(cls, outputs, output_type): - return outputs - - def get_model_config(self) -> ModelConfig: - """Gets the model configuration.""" - return self.model_config - - def get_parallel_config(self) -> ParallelConfig: - """Gets the parallel configuration.""" - return self.parallel_config - - def get_decoding_config(self) -> DecodingConfig: - """Gets the decoding configuration.""" - return self.decoding_config - - def get_scheduler_config(self) -> SchedulerConfig: - """Gets the scheduler configuration.""" - return self.scheduler_config - - def get_lora_config(self) -> LoRAConfig: - """Gets the LoRA configuration.""" - return self.lora_config - - @classmethod - def _get_executor_cls(cls, engine_config: VllmConfig): - return GPUExecutor - - def is_tracing_enabled(self) -> bool: - return False - - def do_log_stats(self, *args, **kwargs) -> None: + def get_model_config(self): pass - def is_encoder_decoder_model(self) -> bool: - return False - - def start_profile(self) -> None: + def is_encoder_decoder_model(self): pass - def stop_profile(self) -> None: + def start_profile(self): pass - def get_tokenizer_group(self, *args, **kwargs): - return self.tokenizer + def stop_profile(self): + pass - -def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: - config = try_get_generation_config( - model_config.model, - trust_remote_code=model_config.trust_remote_code, - revision=model_config.revision, - ) - - if config is None: - return {} - - return config.to_diff_dict() + def get_tokenizer_group(self, group_type): + pass diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py new file mode 100644 index 0000000000000..d92e622810389 --- /dev/null +++ b/vllm/v1/engine/processor.py @@ -0,0 +1,128 @@ +import time +from typing import Any, Dict, Mapping, Optional, Tuple, Union + +from vllm.config import LoRAConfig, ModelConfig +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, + EncoderDecoderLLMInputs, InputRegistry, PromptType) +from vllm.inputs.preprocess import InputPreprocessor +from vllm.lora.request import LoRARequest +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.config import try_get_generation_config +from vllm.transformers_utils.tokenizer_group import AnyTokenizer +from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest + + +class Processor: + + def __init__( + self, + model_config: ModelConfig, + lora_config: Optional[LoRAConfig], + tokenizer: AnyTokenizer, + input_registry: InputRegistry = INPUT_REGISTRY, + ): + + self.model_config = model_config + self.lora_config = lora_config + self.tokenizer = tokenizer + + self.generation_config_fields = _load_generation_config_dict( + model_config) + self.input_preprocessor = InputPreprocessor(model_config, + self.tokenizer) + self.input_processor = input_registry.create_input_processor( + model_config) + + # TODO: run in an ThreadpoolExecutor or BackgroundProcess. + # This ideally should releases the GIL, so we should not block the + # asyncio loop while this is running. 
+ def process_inputs( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: float, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> Tuple[DetokenizerRequest, EngineCoreRequest]: + + # TODO(woosuk): Support embedding mode. + # TODO(woosuk): Check max_logprobs + # TODO(woosuk): Support encoder-decoder models. + + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + if arrival_time is None: + arrival_time = time.time() + assert priority == 0, "vLLM V1 does not support priority at the moment." + assert trace_headers is None, "vLLM V1 does not support tracing yet." + + # Process inputs. + preprocessed_inputs = self.input_preprocessor.preprocess( + prompt, + request_id=request_id, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + processed_inputs = self.input_processor(preprocessed_inputs) + self._validate_model_inputs(processed_inputs) + eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) + + assert isinstance(params, SamplingParams) + # TODO: can we avoid cloning here in multiproc case + sampling_params = params.clone() + sampling_params.update_from_generation_config( + self.generation_config_fields, eos_token_id) + + # Make Request for Detokenizer. + detokenizer_request = DetokenizerRequest( + request_id, processed_inputs.get("prompt"), + processed_inputs.get("prompt_token_ids"), + sampling_params.skip_special_tokens, + sampling_params.spaces_between_special_tokens, + sampling_params.output_kind, sampling_params.stop, + sampling_params.include_stop_str_in_output) + + # Make Request for EngineCore. + engine_core_request = EngineCoreRequest( + request_id, processed_inputs.get("prompt"), + processed_inputs.get("prompt_token_ids"), sampling_params, + eos_token_id, arrival_time, lora_request) + + return detokenizer_request, engine_core_request + + def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, + EncoderDecoderLLMInputs]): + prompt_ids = inputs.get("prompt_token_ids") + if prompt_ids is None or len(prompt_ids) == 0: + raise ValueError("Prompt cannot be empty") + + if self.model_config.is_multimodal_model: + max_prompt_len = self.model_config.max_model_len + + if len(prompt_ids) > max_prompt_len: + raise ValueError( + f"The prompt (total length {len(prompt_ids)}) is too long " + f"to fit into the model (context length {max_prompt_len}). " + "Make sure that `max_model_len` is no smaller than the " + "number of text tokens plus multimodal tokens. 
For image " + "inputs, the number of image tokens depends on the number " + "of images, and possibly their aspect ratios as well.") + + +def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: + config = try_get_generation_config( + model_config.model, + trust_remote_code=model_config.trust_remote_code, + revision=model_config.revision, + ) + + if config is None: + return {} + + return config.to_diff_dict() diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 087067cdac56f..00e5aea92a8df 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -1,9 +1,11 @@ import enum from typing import TYPE_CHECKING, List, Optional, Union +from vllm.inputs.data import DecoderOnlyInputs from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.sequence import RequestMetrics +from vllm.v1.engine import EngineCoreRequest from vllm.v1.utils import ConstantList if TYPE_CHECKING: @@ -43,9 +45,22 @@ class Request: self.num_prompt_tokens = len(self.prompt_token_ids) self._output_token_ids: List[int] = [] self._all_token_ids: List[int] = self.prompt_token_ids.copy() - self.output_text = "" self.num_computed_tokens = 0 + @classmethod + def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": + + return cls( + request_id=request.request_id, + inputs=DecoderOnlyInputs(type="token", + prompt_token_ids=request.prompt_token_ids, + prompt=request.prompt), + sampling_params=request.sampling_params, + eos_token_id=request.eos_token_id, + arrival_time=request.arrival_time, + lora_request=request.lora_request, + ) + @property def output_token_ids(self) -> ConstantList[int]: # Prevent directly appending to the output_token_ids since diff --git a/vllm/v1/tokenizer/detokenizer.py b/vllm/v1/tokenizer/detokenizer.py deleted file mode 100644 index 8d80ebbc5cc45..0000000000000 --- a/vllm/v1/tokenizer/detokenizer.py +++ /dev/null @@ -1,228 +0,0 @@ -import multiprocessing -from dataclasses import dataclass -from typing import Dict, List, Optional - -import msgspec -import zmq -from msgspec import msgpack - -from vllm.transformers_utils.detokenizer_utils import ( - convert_prompt_ids_to_tokens, detokenize_incrementally) -from vllm.transformers_utils.tokenizer import get_tokenizer -from vllm.utils import get_open_port - - -class DetokenizerInputs(msgspec.Struct): - - # [num_reqs] - req_ids: List[str] - # A request's prompt token ids is sent to the detokenizer only when - # the request is first detokenized. Otherwise, an empty list is sent. - prompt_token_ids: List[List[int]] - new_token_ids: List[List[int]] - skip_special_tokens: List[bool] - spaces_between_special_tokens: List[bool] - - # [num_free_reqs] - free_req_ids: List[str] - - -class DetokenizerOutputs(msgspec.Struct): - - # [num_reqs] - req_ids: List[str] - detokenized_texts: List[str] - # NOTE(woosuk): The number of the output token ids of each request - # at the time of detokenization. The detokenizer returns this to the engine - # because the request state (including the output token ids) is - # asynchronously updated in the engine, while RequestOutput requires the - # output token ids to be consistent with the detokenized text. - num_output_token_ids: List[int] - - -class Detokenizer: - - def __init__(self, tokenizer_name: str, tokenizer_mode: str, - trust_remote_code: bool): - # FIXME(woosuk): Currently, the detokenizer is just a hacky prototype. - # For example, it does not terminate properly. We need to improve this. 
- self.push_port = get_open_port() - self.pull_port = get_open_port() - # NOTE: The push port of the engine process should be the same as the - # pull port of the detokenizer process. Vice versa. - self.detokenizer = DetokenizerProc(tokenizer_name=tokenizer_name, - tokenizer_mode=tokenizer_mode, - trust_remote_code=trust_remote_code, - push_port=self.pull_port, - pull_port=self.push_port) - self.detokenizer.start() - - self.zmq_context = zmq.Context() - self.push_socket = self.zmq_context.socket(zmq.PUSH) - self.push_socket.connect(f"tcp://localhost:{self.push_port}") - self.pull_socket = self.zmq_context.socket(zmq.PULL) - self.pull_socket.connect(f"tcp://localhost:{self.pull_port}") - self.poller = zmq.Poller() - self.poller.register(self.pull_socket, zmq.POLLIN) - self.msgpack_encoder = msgpack.Encoder() - self.msgpack_decoder = msgpack.Decoder(DetokenizerOutputs) - - def send(self, inputs: DetokenizerInputs) -> None: - self.push_socket.send(self.msgpack_encoder.encode(inputs), - flags=zmq.NOBLOCK) - - def recv(self) -> Optional[DetokenizerOutputs]: - socks = dict(self.poller.poll(timeout=0)) - if self.pull_socket in socks and socks[self.pull_socket] == zmq.POLLIN: - msg = self.pull_socket.recv() - return self.msgpack_decoder.decode(msg) - return None - - def terminate(self) -> None: - self.detokenizer.kill() - self.detokenizer.join() - - -class DetokenizerProc(multiprocessing.Process): - - def __init__( - self, - tokenizer_name: str, - tokenizer_mode: str, - trust_remote_code: bool, - pull_port: int, - push_port: int, - ): - super().__init__() - self.tokenizer_name = tokenizer_name - self.tokenizer_mode = tokenizer_mode - self.trust_remote_code = trust_remote_code - # NOTE: The pull_port of the detokenizer process should be the same as - # the push_port of the engine process. Vice versa. - self.pull_port = pull_port - self.push_port = push_port - - def run(self): - # Initialize these objects after the process is forked since they are - # not picklable. - self.msgpack_encoder = msgpack.Encoder() - self.msgpack_decoder = msgpack.Decoder(DetokenizerInputs) - self.tokenizer = get_tokenizer( - tokenizer_name=self.tokenizer_name, - tokenizer_mode=self.tokenizer_mode, - trust_remote_code=self.trust_remote_code) - # req_id -> RequestState - self.request_states: Dict[str, RequestState] = {} - - self.zmq_context = zmq.Context() - self.pull_socket = self.zmq_context.socket(zmq.PULL) - self.pull_socket.bind(f"tcp://*:{self.pull_port}") - self.push_socket = self.zmq_context.socket(zmq.PUSH) - self.push_socket.bind(f"tcp://*:{self.push_port}") - - while True: - if self.pull_socket.poll(timeout=1000) == 0: - # Nothing to read - continue - message = self.pull_socket.recv() - inputs = self.msgpack_decoder.decode(message) - - for req_id in inputs.free_req_ids: - self.free(req_id) - - detokenized_texts: List[str] = [] - num_output_token_ids: List[int] = [] - num_reqs = len(inputs.req_ids) - for i in range(num_reqs): - req_id = inputs.req_ids[i] - if req_id not in self.request_states: - self.add_request( - request_id=req_id, - prompt_token_ids=inputs.prompt_token_ids[i], - skip_special_tokens=inputs.skip_special_tokens[i], - spaces_between_special_tokens=inputs. 
- spaces_between_special_tokens[i], - ) - new_str = self.detokenize(req_id, inputs.new_token_ids[i]) - detokenized_texts.append(new_str) - req_state = self.request_states[req_id] - num_output_token_ids.append( - len(req_state.token_ids) - req_state.num_prompt_tokens) - - detokenized = DetokenizerOutputs( - req_ids=inputs.req_ids, - detokenized_texts=detokenized_texts, - num_output_token_ids=num_output_token_ids, - ) - self.push_socket.send(self.msgpack_encoder.encode(detokenized), - flags=zmq.NOBLOCK) - - def add_request( - self, - request_id: str, - prompt_token_ids: List[int], - skip_special_tokens: bool, - spaces_between_special_tokens: bool, - ) -> None: - tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens( - tokenizer=self.tokenizer, - prompt_ids=prompt_token_ids, - skip_special_tokens=skip_special_tokens, - ) - self.request_states[request_id] = RequestState( - req_id=request_id, - token_ids=prompt_token_ids, - tokens=tokens, - num_prompt_tokens=len(prompt_token_ids), - prefix_offset=prefix_offset, - read_offset=read_offset, - skip_special_tokens=skip_special_tokens, - spaces_between_special_tokens=spaces_between_special_tokens, - ) - - def free(self, request_id: str) -> None: - del self.request_states[request_id] - - def detokenize(self, request_id: str, new_token_ids: List[int]) -> str: - # TODO(woosuk): This method becomes very inefficient when the number of - # new_token_ids is more than 1. We need to optimize this. - req_state = self.request_states[request_id] - decoded_text = "" - for new_token_id in new_token_ids: - req_state.token_ids.append(new_token_id) - (new_tokens, new_decoded_token_text, prefix_offset, - read_offset) = detokenize_incrementally( - tokenizer=self.tokenizer, - all_input_ids=req_state.token_ids, - prev_tokens=req_state.tokens, - prefix_offset=req_state.prefix_offset, - read_offset=req_state.read_offset, - skip_special_tokens=req_state.skip_special_tokens, - spaces_between_special_tokens=req_state. - spaces_between_special_tokens, - ) - - req_state.tokens.extend(new_tokens) - req_state.prefix_offset = prefix_offset - req_state.read_offset = read_offset - req_state.output_text += new_decoded_token_text - decoded_text += new_decoded_token_text - return decoded_text - - -@dataclass -class RequestState: - - req_id: str - - token_ids: List[int] - tokens: List[str] - num_prompt_tokens: int - - prefix_offset: int - read_offset: int - - skip_special_tokens: bool - spaces_between_special_tokens: bool - - output_text: str = ""