mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-31 14:27:09 +08:00
[V1] AsyncLLM Implementation (#9826)
Signed-off-by: Nick Hill <nickhill@us.ibm.com> Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com> Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: Nick Hill <nickhill@us.ibm.com> Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
08f93e7439
commit
6ace6fba2c
@ -165,6 +165,14 @@ steps:
|
||||
# OOM in the CI unless we run this separately
|
||||
- pytest -v -s tokenization
|
||||
|
||||
- label: V1 Test
|
||||
#mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/v1
|
||||
commands:
|
||||
- pytest -v -s v1
|
||||
|
||||
- label: Examples Test # 15min
|
||||
working_dir: "/vllm-workspace/examples"
|
||||
#mirror_hardwares: [amd]
|
||||
|
||||
56
tests/entrypoints/llm/test_accuracy.py
Normal file
56
tests/entrypoints/llm/test_accuracy.py
Normal file
@ -0,0 +1,56 @@
|
||||
"""
|
||||
This file test accuracy of the vLLM server via LMEval.
|
||||
It uses local-completions, which interacts with vLLM
|
||||
through the OAI API with N concurrent connections.
|
||||
This simulates real work usage of the API and makes
|
||||
sure that the zmq frontend mp RPC message passing and
|
||||
AsyncLLMEngine are working correctly.
|
||||
"""
|
||||
|
||||
import lm_eval
|
||||
import pytest
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
|
||||
NUM_CONCURRENT = 500
|
||||
TASK = "gsm8k"
|
||||
FILTER = "exact_match,strict-match"
|
||||
RTOL = 0.03
|
||||
EXPECTED_VALUE = 0.58
|
||||
|
||||
|
||||
def run_test():
|
||||
"""Run the end to end accuracy test."""
|
||||
|
||||
model_args = f"pretrained={MODEL_NAME},max_model_len=2048"
|
||||
|
||||
results = lm_eval.simple_evaluate(
|
||||
model="vllm",
|
||||
model_args=model_args,
|
||||
tasks="gsm8k",
|
||||
batch_size="auto",
|
||||
)
|
||||
|
||||
measured_value = results["results"][TASK][FILTER]
|
||||
assert (measured_value - RTOL < EXPECTED_VALUE
|
||||
and measured_value + RTOL > EXPECTED_VALUE
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(),
|
||||
reason="V1 is currently only supported on CUDA.")
|
||||
def test_lm_eval_accuracy_v1_engine(monkeypatch):
|
||||
"""Run with the V1 Engine."""
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
run_test()
|
||||
|
||||
|
||||
def test_lm_eval_accuracy_v0_engine(monkeypatch):
|
||||
"""Run with the V0 Engine."""
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "0")
|
||||
run_test()
|
||||
@ -37,11 +37,11 @@ if current_platform.is_tpu():
|
||||
MAX_WAIT_SECONDS = 600
|
||||
|
||||
|
||||
@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
|
||||
def test_lm_eval_accuracy(more_args):
|
||||
def run_test(more_args):
|
||||
"""Run the end to end accuracy test."""
|
||||
|
||||
args = list(DEFAULT_ARGS)
|
||||
args.extend(more_args)
|
||||
|
||||
print(f"Running with: {args}")
|
||||
|
||||
with RemoteOpenAIServer(
|
||||
@ -64,3 +64,22 @@ def test_lm_eval_accuracy(more_args):
|
||||
assert (measured_value - RTOL < EXPECTED_VALUE
|
||||
and measured_value + RTOL > EXPECTED_VALUE
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(),
|
||||
reason="V1 currently only supported on CUDA")
|
||||
def test_lm_eval_accuracy_v1_engine(monkeypatch):
|
||||
"""Run with the V1 Engine."""
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
run_test([])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
|
||||
def test_lm_eval_accuracy_v0_engine(monkeypatch, more_args):
|
||||
"""Run with the V0 Engine."""
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "0")
|
||||
run_test(more_args)
|
||||
|
||||
0
tests/v1/engine/__init__.py
Normal file
0
tests/v1/engine/__init__.py
Normal file
66
tests/v1/engine/test_async_llm.py
Normal file
66
tests/v1/engine/test_async_llm.py
Normal file
@ -0,0 +1,66 @@
|
||||
import asyncio
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip(reason="V1 currently only supported on CUDA.",
|
||||
allow_module_level=True)
|
||||
|
||||
ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B",
|
||||
disable_log_requests=True)
|
||||
|
||||
|
||||
async def generate(engine: AsyncLLM, request_id: str,
|
||||
max_tokens: int) -> Tuple[int, str]:
|
||||
count = 0
|
||||
async for _ in engine.generate(request_id=request_id,
|
||||
prompt="Hello my name is Robert and",
|
||||
sampling_params=SamplingParams(
|
||||
max_tokens=max_tokens, temperature=0)):
|
||||
|
||||
count += 1
|
||||
await asyncio.sleep(0.)
|
||||
|
||||
return count, request_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load(monkeypatch):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
|
||||
|
||||
NUM_REQUESTS = 10000
|
||||
NUM_EXPECTED_TOKENS = 10
|
||||
|
||||
request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
|
||||
|
||||
# Create concurrent requests.
|
||||
tasks = []
|
||||
for request_id in request_ids:
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
generate(engine, request_id, NUM_EXPECTED_TOKENS)))
|
||||
|
||||
# Confirm that we got all the EXPECTED tokens from the requests.
|
||||
failed_request_id = None
|
||||
tokens = None
|
||||
for task in tasks:
|
||||
num_generated_tokens, request_id = await task
|
||||
if (num_generated_tokens != NUM_EXPECTED_TOKENS
|
||||
and failed_request_id is None):
|
||||
failed_request_id = request_id
|
||||
tokens = num_generated_tokens
|
||||
|
||||
assert failed_request_id is None, (
|
||||
f"{failed_request_id} generated {tokens} but "
|
||||
f"expected {NUM_EXPECTED_TOKENS}")
|
||||
|
||||
engine.shutdown()
|
||||
205
tests/v1/engine/test_detokenizer.py
Normal file
205
tests/v1/engine/test_detokenizer.py
Normal file
@ -0,0 +1,205 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.v1.engine import EngineCoreOutput
|
||||
from vllm.v1.engine.detokenizer import Detokenizer, DetokenizerRequest
|
||||
|
||||
TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
|
||||
|
||||
FULL_STRINGS = [
|
||||
"My name is Robert from Neural Magic and I love working on vLLM so much!",
|
||||
"Red Hat is the best open source company by far across Linux, K8s, and AI.",
|
||||
"Nick is the name of my brother in addition to my colleague from Red Hat.",
|
||||
]
|
||||
|
||||
STOP_STRINGS = ["I love working on", "company by far", "brother in"]
|
||||
|
||||
FULL_TOKENS = [tokenizer(text).input_ids for text in FULL_STRINGS]
|
||||
PROMPT_LEN = 5
|
||||
PROMPT_TOKENS = [
|
||||
tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS
|
||||
]
|
||||
GENERATION_TOKENS = [
|
||||
tokenizer(text).input_ids[PROMPT_LEN:] for text in FULL_STRINGS
|
||||
]
|
||||
PROMPT_STRINGS = [
|
||||
tokenizer.decode(prompt_tokens, skip_special_tokens=True)
|
||||
for prompt_tokens in PROMPT_TOKENS
|
||||
]
|
||||
PROMPT_STRINGS_LEN = [len(prompt_string) for prompt_string in PROMPT_STRINGS]
|
||||
GENERATION_STRINGS = [
|
||||
text[prompt_len:]
|
||||
for text, prompt_len in zip(FULL_STRINGS, PROMPT_STRINGS_LEN)
|
||||
]
|
||||
|
||||
|
||||
class MockEngineCore:
|
||||
"""Mock outputs form premade tokens lists."""
|
||||
|
||||
def __init__(self, tokens_list: List[List[int]]):
|
||||
self.tokens_list = tokens_list
|
||||
self.current_idx = 0
|
||||
|
||||
def get_outputs(self) -> List[EngineCoreOutput]:
|
||||
token_idx = self.current_idx
|
||||
self.current_idx += 1
|
||||
|
||||
outputs = []
|
||||
for req_idx, token_ids in enumerate(self.tokens_list):
|
||||
if len(token_ids) > token_idx:
|
||||
output = EngineCoreOutput(request_id=f"request-{req_idx}",
|
||||
new_token_ids=[token_ids[token_idx]],
|
||||
finished=False)
|
||||
if token_idx == len(token_ids) - 1:
|
||||
output.finished = True
|
||||
output.finish_reason = "stopped"
|
||||
outputs.append(output)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"request_output_kind",
|
||||
[RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
|
||||
def test_incremental_detokenization(request_output_kind: RequestOutputKind):
|
||||
detokenizer = Detokenizer(TOKENIZER_NAME)
|
||||
engine_core = MockEngineCore(GENERATION_TOKENS)
|
||||
|
||||
# Make N requests.
|
||||
requests = [
|
||||
DetokenizerRequest(
|
||||
request_id=f"request-{idx}",
|
||||
prompt=prompt,
|
||||
prompt_token_ids=prompt_tokens,
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=False,
|
||||
output_kind=request_output_kind,
|
||||
stop=[],
|
||||
include_stop_str_in_output=False,
|
||||
) for idx, (
|
||||
prompt,
|
||||
prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
|
||||
]
|
||||
|
||||
# Add requests to the detokenizer.
|
||||
for request in requests:
|
||||
detokenizer.add_request(request)
|
||||
|
||||
gen_strings = {}
|
||||
gen_tokens = {}
|
||||
while True:
|
||||
# Mock output from the EngineCore.
|
||||
outputs = engine_core.get_outputs()
|
||||
if len(outputs) == 0:
|
||||
break
|
||||
|
||||
# Step the Detokenizer.
|
||||
request_outputs, requests_to_abort = detokenizer.step(outputs)
|
||||
assert len(requests_to_abort) == 0
|
||||
|
||||
# Update tracking.
|
||||
for request_output in request_outputs:
|
||||
request_id = request_output.request_id
|
||||
new_text = request_output.outputs[0].text
|
||||
new_tokens = request_output.outputs[0].token_ids
|
||||
if request_id not in gen_strings:
|
||||
gen_strings[request_id] = new_text
|
||||
gen_tokens[request_id] = new_tokens
|
||||
else:
|
||||
gen_strings[request_id] += new_text
|
||||
gen_tokens[request_id].extend(new_tokens)
|
||||
|
||||
# Confirmed tracked values matches what we expected.
|
||||
for idx, (ref_gen_str, ref_gen_toks) in enumerate(
|
||||
zip(GENERATION_STRINGS, GENERATION_TOKENS)):
|
||||
gen_str = gen_strings[f"request-{idx}"]
|
||||
gen_toks = gen_tokens[f"request-{idx}"]
|
||||
|
||||
assert gen_str == ref_gen_str, f"{gen_str=}, {ref_gen_str=}"
|
||||
assert gen_toks == ref_gen_toks, f"{gen_toks=}, {ref_gen_toks=}"
|
||||
|
||||
assert detokenizer.get_num_unfinished_requests() == 0
|
||||
assert not detokenizer.has_unfinished_requests()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
|
||||
def test_stop_string(include_stop_str_in_output: bool):
|
||||
detokenizer = Detokenizer(TOKENIZER_NAME)
|
||||
engine_core = MockEngineCore(GENERATION_TOKENS)
|
||||
|
||||
# Make N requests.
|
||||
requests = [
|
||||
DetokenizerRequest(
|
||||
request_id=f"request-{idx}",
|
||||
prompt=prompt,
|
||||
prompt_token_ids=prompt_tokens,
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=False,
|
||||
output_kind=RequestOutputKind.DELTA,
|
||||
stop=STOP_STRINGS,
|
||||
include_stop_str_in_output=include_stop_str_in_output,
|
||||
) for idx, (
|
||||
prompt,
|
||||
prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
|
||||
]
|
||||
|
||||
# Add requests to the detokenizer.
|
||||
for request in requests:
|
||||
detokenizer.add_request(request)
|
||||
|
||||
gen_strings = {}
|
||||
aborted = []
|
||||
while True:
|
||||
# Mock output from the EngineCore.
|
||||
outputs = engine_core.get_outputs()
|
||||
if len(outputs) == 0:
|
||||
break
|
||||
|
||||
# Step the Detokenizer.
|
||||
request_outputs, requests_to_abort = detokenizer.step(outputs)
|
||||
for request_output in request_outputs:
|
||||
# If aborted, we should not get a request output.
|
||||
assert request_output.request_id not in aborted
|
||||
aborted.extend(requests_to_abort)
|
||||
|
||||
# Update tracking.
|
||||
for request_output in request_outputs:
|
||||
if request_output.finished:
|
||||
assert request_output.outputs[0].finish_reason == "stop"
|
||||
|
||||
request_id = request_output.request_id
|
||||
new_text = request_output.outputs[0].text
|
||||
if request_id not in gen_strings:
|
||||
gen_strings[request_id] = new_text
|
||||
else:
|
||||
gen_strings[request_id] += new_text
|
||||
|
||||
# Confirmed tracked values matches what we expected.
|
||||
for idx, (ref_gen_str,
|
||||
stop_str) in enumerate(zip(GENERATION_STRINGS, STOP_STRINGS)):
|
||||
|
||||
# Request should be aborted.
|
||||
request_id = f"request-{idx}"
|
||||
assert request_id in aborted
|
||||
|
||||
# Collected values that were generated.
|
||||
gen_str = gen_strings[request_id]
|
||||
|
||||
# Construct reference strings.
|
||||
stop_str_idx = ref_gen_str.find(stop_str)
|
||||
ref_str_exc_stop = ref_gen_str[:stop_str_idx]
|
||||
ref_str_inc_stop = ref_gen_str[:stop_str_idx] + stop_str
|
||||
|
||||
if include_stop_str_in_output:
|
||||
assert gen_str == ref_str_inc_stop, (
|
||||
f"{gen_str=}, {ref_str_inc_stop=}")
|
||||
else:
|
||||
assert gen_str == ref_str_exc_stop, (
|
||||
f"{gen_str=}, {ref_str_exc_stop=}")
|
||||
|
||||
assert detokenizer.get_num_unfinished_requests() == 0
|
||||
assert not detokenizer.has_unfinished_requests()
|
||||
137
tests/v1/engine/test_engine_core.py
Normal file
137
tests/v1/engine/test_engine_core.py
Normal file
@ -0,0 +1,137 @@
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
from vllm.v1.engine.core import EngineCore
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip(reason="V1 currently only supported on CUDA.",
|
||||
allow_module_level=True)
|
||||
|
||||
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
PROMPT = "Hello my name is Robert and I love quantization kernels"
|
||||
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
|
||||
|
||||
|
||||
def make_request() -> EngineCoreRequest:
|
||||
return EngineCoreRequest(
|
||||
request_id=uuid.uuid4(),
|
||||
prompt=PROMPT,
|
||||
prompt_token_ids=PROMPT_TOKENS,
|
||||
sampling_params=SamplingParams(),
|
||||
eos_token_id=None,
|
||||
arrival_time=time.time(),
|
||||
lora_request=None,
|
||||
)
|
||||
|
||||
|
||||
def test_engine_core(monkeypatch):
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
"""Setup the EngineCore."""
|
||||
engine_args = EngineArgs(model=MODEL_NAME)
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
executor_class = AsyncLLM._get_executor_cls(vllm_config)
|
||||
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT)
|
||||
"""Test basic request lifecycle."""
|
||||
|
||||
# First request.
|
||||
engine_core.add_request(make_request())
|
||||
assert len(engine_core.scheduler.waiting) == 1
|
||||
assert len(engine_core.scheduler.running) == 0
|
||||
|
||||
_ = engine_core.step()
|
||||
assert len(engine_core.scheduler.waiting) == 0
|
||||
assert len(engine_core.scheduler.running) == 1
|
||||
|
||||
# Second request.
|
||||
engine_core.add_request(make_request())
|
||||
assert len(engine_core.scheduler.waiting) == 1
|
||||
assert len(engine_core.scheduler.running) == 1
|
||||
|
||||
_ = engine_core.step()
|
||||
assert len(engine_core.scheduler.waiting) == 0
|
||||
assert len(engine_core.scheduler.running) == 2
|
||||
|
||||
# Add two requests in a row.
|
||||
engine_core.add_request(make_request())
|
||||
engine_core.add_request(make_request())
|
||||
assert len(engine_core.scheduler.waiting) == 2
|
||||
assert len(engine_core.scheduler.running) == 2
|
||||
|
||||
_ = engine_core.step()
|
||||
assert len(engine_core.scheduler.waiting) == 0
|
||||
assert len(engine_core.scheduler.running) == 4
|
||||
|
||||
# Loop through until they are all done.
|
||||
while len(engine_core.step()) > 0:
|
||||
pass
|
||||
|
||||
assert len(engine_core.scheduler.waiting) == 0
|
||||
assert len(engine_core.scheduler.running) == 0
|
||||
"""Test abort cycle."""
|
||||
|
||||
# Basic abort.
|
||||
req = make_request()
|
||||
request_id = req.request_id
|
||||
|
||||
engine_core.add_request(req)
|
||||
assert len(engine_core.scheduler.waiting) == 1
|
||||
assert len(engine_core.scheduler.running) == 0
|
||||
|
||||
_ = engine_core.step()
|
||||
assert len(engine_core.scheduler.waiting) == 0
|
||||
assert len(engine_core.scheduler.running) == 1
|
||||
|
||||
engine_core.abort_requests([request_id])
|
||||
assert len(engine_core.scheduler.waiting) == 0
|
||||
assert len(engine_core.scheduler.running) == 0
|
||||
|
||||
# Add, step, abort 1 of the 3.
|
||||
req0 = make_request()
|
||||
req1 = make_request()
|
||||
req2 = make_request()
|
||||
|
||||
engine_core.add_request(req0)
|
||||
engine_core.add_request(req1)
|
||||
assert len(engine_core.scheduler.waiting) == 2
|
||||
assert len(engine_core.scheduler.running) == 0
|
||||
|
||||
_ = engine_core.step()
|
||||
assert len(engine_core.scheduler.waiting) == 0
|
||||
assert len(engine_core.scheduler.running) == 2
|
||||
|
||||
engine_core.add_request(req2)
|
||||
assert len(engine_core.scheduler.waiting) == 1
|
||||
assert len(engine_core.scheduler.running) == 2
|
||||
|
||||
_ = engine_core.step()
|
||||
assert len(engine_core.scheduler.waiting) == 0
|
||||
assert len(engine_core.scheduler.running) == 3
|
||||
|
||||
# Abort just one.
|
||||
engine_core.abort_requests([req1.request_id])
|
||||
assert len(engine_core.scheduler.waiting) == 0
|
||||
assert len(engine_core.scheduler.running) == 2
|
||||
|
||||
_ = engine_core.step()
|
||||
assert len(engine_core.scheduler.waiting) == 0
|
||||
assert len(engine_core.scheduler.running) == 2
|
||||
|
||||
# Abort the other requests at the same time.
|
||||
engine_core.abort_requests([req2.request_id, req0.request_id])
|
||||
assert len(engine_core.scheduler.waiting) == 0
|
||||
assert len(engine_core.scheduler.running) == 0
|
||||
202
tests/v1/engine/test_engine_core_client.py
Normal file
202
tests/v1/engine/test_engine_core_client.py
Normal file
@ -0,0 +1,202 @@
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, List
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip(reason="V1 currently only supported on CUDA.",
|
||||
allow_module_level=True)
|
||||
|
||||
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
PROMPT = "Hello my name is Robert and I love quantization kernels"
|
||||
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
|
||||
|
||||
|
||||
def make_request(params: SamplingParams) -> EngineCoreRequest:
|
||||
return EngineCoreRequest(
|
||||
request_id=str(uuid.uuid4()),
|
||||
prompt=PROMPT,
|
||||
prompt_token_ids=PROMPT_TOKENS,
|
||||
sampling_params=params,
|
||||
eos_token_id=None,
|
||||
arrival_time=time.time(),
|
||||
lora_request=None,
|
||||
)
|
||||
|
||||
|
||||
def loop_until_done(client: EngineCoreClient, outputs: Dict):
|
||||
|
||||
while True:
|
||||
engine_core_outputs = client.get_output()
|
||||
|
||||
if len(engine_core_outputs) == 0:
|
||||
break
|
||||
|
||||
all_finished = True
|
||||
for out in engine_core_outputs:
|
||||
outputs[out.request_id].append(out)
|
||||
if not out.finished:
|
||||
all_finished = False
|
||||
|
||||
if all_finished:
|
||||
break
|
||||
|
||||
|
||||
async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
|
||||
|
||||
while True:
|
||||
engine_core_outputs = await client.get_output_async()
|
||||
|
||||
if len(engine_core_outputs) == 0:
|
||||
break
|
||||
|
||||
all_finished = True
|
||||
for out in engine_core_outputs:
|
||||
outputs[out.request_id].append(out)
|
||||
if not out.finished:
|
||||
all_finished = False
|
||||
|
||||
if all_finished:
|
||||
break
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing_mode", [True, False])
|
||||
def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
engine_args = EngineArgs(model=MODEL_NAME)
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
executor_class = AsyncLLM._get_executor_cls(vllm_config)
|
||||
client = EngineCoreClient.make_client(
|
||||
vllm_config,
|
||||
executor_class,
|
||||
UsageContext.UNKNOWN_CONTEXT,
|
||||
multiprocess_mode=multiprocessing_mode,
|
||||
asyncio_mode=False,
|
||||
)
|
||||
|
||||
MAX_TOKENS = 20
|
||||
params = SamplingParams(max_tokens=MAX_TOKENS)
|
||||
"""Normal Request Cycle."""
|
||||
requests = [make_request(params) for _ in range(10)]
|
||||
request_ids = [req.request_id for req in requests]
|
||||
|
||||
# Add requests to the engine.
|
||||
for request in requests:
|
||||
client.add_request(request)
|
||||
time.sleep(0.01)
|
||||
|
||||
outputs: Dict[str, List] = {req_id: [] for req_id in request_ids}
|
||||
loop_until_done(client, outputs)
|
||||
|
||||
for req_id in request_ids:
|
||||
assert len(outputs[req_id]) == MAX_TOKENS, (
|
||||
f"{outputs[req_id]=}, {MAX_TOKENS=}")
|
||||
"""Abort Request Cycle."""
|
||||
|
||||
# Note: this code pathway will only work for multiprocessing
|
||||
# since we have to call get_output() explicitly
|
||||
|
||||
# Add requests to the engine.
|
||||
for idx, request in enumerate(requests):
|
||||
client.add_request(request)
|
||||
time.sleep(0.01)
|
||||
if idx % 2 == 0:
|
||||
client.abort_requests([request.request_id])
|
||||
|
||||
outputs = {req_id: [] for req_id in request_ids}
|
||||
loop_until_done(client, outputs)
|
||||
|
||||
for idx, req_id in enumerate(request_ids):
|
||||
if idx % 2 == 0:
|
||||
assert len(outputs[req_id]) < MAX_TOKENS, (
|
||||
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
|
||||
else:
|
||||
assert len(outputs[req_id]) == MAX_TOKENS, (
|
||||
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
|
||||
"""Abort after request is finished."""
|
||||
|
||||
# Note: this code pathway will only work for multiprocessing
|
||||
# since we have to call get_output() explicitly
|
||||
|
||||
request = requests[0]
|
||||
client.add_request(request)
|
||||
time.sleep(10.)
|
||||
|
||||
client.abort_requests([request.request_id])
|
||||
|
||||
# Shutdown the client.
|
||||
client.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engine_core_client_asyncio(monkeypatch):
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
engine_args = EngineArgs(model=MODEL_NAME)
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
executor_class = AsyncLLM._get_executor_cls(vllm_config)
|
||||
client = EngineCoreClient.make_client(
|
||||
vllm_config,
|
||||
executor_class,
|
||||
UsageContext.UNKNOWN_CONTEXT,
|
||||
multiprocess_mode=True,
|
||||
asyncio_mode=True,
|
||||
)
|
||||
|
||||
MAX_TOKENS = 20
|
||||
params = SamplingParams(max_tokens=MAX_TOKENS)
|
||||
"""Normal Request Cycle."""
|
||||
|
||||
requests = [make_request(params) for _ in range(10)]
|
||||
request_ids = [req.request_id for req in requests]
|
||||
|
||||
# Add requests to the engine.
|
||||
for request in requests:
|
||||
await client.add_request_async(request)
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
outputs: Dict[str, List] = {req_id: [] for req_id in request_ids}
|
||||
await loop_until_done_async(client, outputs)
|
||||
|
||||
for req_id in request_ids:
|
||||
assert len(outputs[req_id]) == MAX_TOKENS, (
|
||||
f"{outputs[req_id]=}, {MAX_TOKENS=}")
|
||||
"""Abort Request Cycle."""
|
||||
|
||||
# Add requests to the engine.
|
||||
for idx, request in enumerate(requests):
|
||||
await client.add_request_async(request)
|
||||
await asyncio.sleep(0.01)
|
||||
if idx % 2 == 0:
|
||||
await client.abort_requests_async([request.request_id])
|
||||
|
||||
outputs = {req_id: [] for req_id in request_ids}
|
||||
await loop_until_done_async(client, outputs)
|
||||
|
||||
for idx, req_id in enumerate(request_ids):
|
||||
if idx % 2 == 0:
|
||||
assert len(outputs[req_id]) < MAX_TOKENS, (
|
||||
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
|
||||
else:
|
||||
assert len(outputs[req_id]) == MAX_TOKENS, (
|
||||
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
|
||||
|
||||
# Shutdown the client.
|
||||
client.shutdown()
|
||||
@ -2106,3 +2106,44 @@ class VllmConfig:
|
||||
self.model_config is not None and self.load_config is not None:
|
||||
self.quant_config = VllmConfig._get_quantization_config(
|
||||
self.model_config, self.load_config)
|
||||
|
||||
def __str__(self):
|
||||
return ("model=%r, speculative_config=%r, tokenizer=%r, "
|
||||
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
|
||||
"override_neuron_config=%s, tokenizer_revision=%s, "
|
||||
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
|
||||
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
|
||||
"pipeline_parallel_size=%d, "
|
||||
"disable_custom_all_reduce=%s, quantization=%s, "
|
||||
"enforce_eager=%s, kv_cache_dtype=%s, "
|
||||
"quantization_param_path=%s, device_config=%s, "
|
||||
"decoding_config=%r, observability_config=%r, "
|
||||
"seed=%d, served_model_name=%s, "
|
||||
"num_scheduler_steps=%d, enable_prefix_caching=%s, "
|
||||
"use_async_output_proc=%s, mm_processor_kwargs=%s") % \
|
||||
(self.model_config.model, self.speculative_config,
|
||||
self.model_config.tokenizer,
|
||||
self.model_config.skip_tokenizer_init,
|
||||
self.model_config.tokenizer_mode,
|
||||
self.model_config.revision,
|
||||
self.model_config.override_neuron_config,
|
||||
self.model_config.tokenizer_revision,
|
||||
self.model_config.trust_remote_code,
|
||||
self.model_config.dtype,
|
||||
self.model_config.max_model_len,
|
||||
self.load_config.download_dir,
|
||||
self.load_config.load_format,
|
||||
self.parallel_config.tensor_parallel_size,
|
||||
self.parallel_config.pipeline_parallel_size,
|
||||
self.parallel_config.disable_custom_all_reduce,
|
||||
self.model_config.quantization,
|
||||
self.model_config.enforce_eager,
|
||||
self.cache_config.cache_dtype,
|
||||
self.model_config.quantization_param_path,
|
||||
self.device_config.device, self.decoding_config,
|
||||
self.observability_config, self.model_config.seed,
|
||||
self.model_config.served_model_name,
|
||||
self.scheduler_config.num_scheduler_steps,
|
||||
self.cache_config.enable_prefix_caching,
|
||||
self.model_config.use_async_output_proc,
|
||||
self.model_config.mm_processor_kwargs)
|
||||
@ -6,7 +6,6 @@ from typing import Iterator, List, Optional, Union
|
||||
import cloudpickle
|
||||
import zmq
|
||||
|
||||
import vllm.envs
|
||||
from vllm import AsyncEngineArgs, SamplingParams
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
# yapf conflicts with isort for this block
|
||||
@ -113,17 +112,9 @@ class MQLLMEngine:
|
||||
load_general_plugins()
|
||||
|
||||
engine_config = engine_args.create_engine_config()
|
||||
if vllm.envs.VLLM_USE_V1:
|
||||
# Lazy import: the v1 package isn't distributed
|
||||
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
|
||||
engine_class = V1LLMEngine
|
||||
else:
|
||||
engine_class = LLMEngine
|
||||
executor_class = LLMEngine._get_executor_cls(engine_config)
|
||||
|
||||
executor_class = engine_class._get_executor_cls(engine_config)
|
||||
|
||||
use_async_sockets = (engine_config.model_config.use_async_output_proc
|
||||
and not vllm.envs.VLLM_USE_V1)
|
||||
use_async_sockets = engine_config.model_config.use_async_output_proc
|
||||
|
||||
return cls(ipc_path=ipc_path,
|
||||
use_async_sockets=use_async_sockets,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
@ -67,9 +67,13 @@ class StopChecker:
|
||||
return
|
||||
|
||||
# Check if any stop strings are matched.
|
||||
stop_str = self._check_stop_strings(seq, new_char_count,
|
||||
sampling_params)
|
||||
if stop_str is not None:
|
||||
stop = self.check_stop_strings(
|
||||
seq.output_text, new_char_count, sampling_params.stop,
|
||||
sampling_params.include_stop_str_in_output)
|
||||
if stop is not None:
|
||||
stop_str, truncate_to = stop
|
||||
if truncate_to != -1:
|
||||
seq.output_text = seq.output_text[:truncate_to]
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
seq.stop_reason = stop_str
|
||||
return
|
||||
@ -85,33 +89,40 @@ class StopChecker:
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def _check_stop_strings(seq: Sequence, new_char_count: int,
|
||||
sampling_params: SamplingParams) -> Optional[str]:
|
||||
def check_stop_strings(
|
||||
output_text: str,
|
||||
new_char_count: int,
|
||||
stop: List[str],
|
||||
include_in_output: bool,
|
||||
) -> Optional[Tuple[str, int]]:
|
||||
"""Check if any stop strings are matched and truncate sequence
|
||||
output text accordingly.
|
||||
|
||||
Returns the stop string if matched or else None.
|
||||
Returns tuple (stop_string, offset) if matched or else None.
|
||||
|
||||
Where stop_string is the matched stop string and offset is the
|
||||
length to which output_text should be truncated, or -1 for no
|
||||
truncation.
|
||||
"""
|
||||
if not new_char_count or not sampling_params.stop:
|
||||
if not new_char_count or not stop:
|
||||
return None
|
||||
|
||||
for stop_str in sampling_params.stop:
|
||||
for stop_str in stop:
|
||||
stop_string_len = len(stop_str)
|
||||
# Avoid searching already-searched text.
|
||||
stop_index = seq.output_text.find(
|
||||
stop_str, -new_char_count - stop_string_len)
|
||||
stop_index = output_text.find(stop_str,
|
||||
-new_char_count - stop_string_len)
|
||||
if stop_index == -1:
|
||||
continue
|
||||
|
||||
if sampling_params.include_stop_str_in_output:
|
||||
if include_in_output:
|
||||
# Truncate to end of stop string.
|
||||
stop_index += stop_string_len
|
||||
if stop_index >= len(seq.output_text):
|
||||
if stop_index >= len(output_text):
|
||||
# No truncation required.
|
||||
return stop_str
|
||||
return stop_str, -1
|
||||
|
||||
# Truncate the output text to either the beginning
|
||||
# or end of the stop string.
|
||||
seq.output_text = seq.output_text[:stop_index]
|
||||
return stop_str
|
||||
return stop_str, stop_index
|
||||
return None
|
||||
|
||||
@ -210,8 +210,11 @@ class LLM:
|
||||
# Logic to switch between engines is done at runtime instead of import
|
||||
# to avoid import order issues
|
||||
self.engine_class = self.get_engine_class()
|
||||
|
||||
# TODO(rob): enable mp by default (issue with fork vs spawn)
|
||||
self.llm_engine = self.engine_class.from_engine_args(
|
||||
engine_args, usage_context=UsageContext.LLM_CLASS)
|
||||
|
||||
self.request_counter = Counter()
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -26,7 +26,6 @@ from typing_extensions import assert_never
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||
from vllm.engine.multiprocessing.engine import run_mp_engine
|
||||
from vllm.engine.protocol import EngineClient
|
||||
@ -61,6 +60,11 @@ from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
from vllm.v1.engine.async_llm import AsyncLLMEngine # type: ignore
|
||||
else:
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore
|
||||
|
||||
TIMEOUT_KEEP_ALIVE = 5 # seconds
|
||||
|
||||
prometheus_multiproc_dir: tempfile.TemporaryDirectory
|
||||
@ -126,7 +130,8 @@ async def build_async_engine_client_from_engine_args(
|
||||
# Fall back
|
||||
# TODO: fill out feature matrix.
|
||||
if (MQLLMEngineClient.is_unsupported_config(engine_args)
|
||||
or disable_frontend_multiprocessing):
|
||||
or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):
|
||||
|
||||
engine_config = engine_args.create_engine_config()
|
||||
uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
|
||||
"uses_ray", False)
|
||||
@ -143,6 +148,8 @@ async def build_async_engine_client_from_engine_args(
|
||||
None, build_engine)
|
||||
|
||||
yield engine_client
|
||||
if hasattr(engine_client, "shutdown"):
|
||||
engine_client.shutdown()
|
||||
return
|
||||
|
||||
# Otherwise, use the multiprocessing AsyncLLMEngine.
|
||||
|
||||
@ -72,6 +72,7 @@ if TYPE_CHECKING:
|
||||
VLLM_CUSTOM_OPS: List[str] = []
|
||||
VLLM_DISABLED_KERNELS: List[str] = []
|
||||
VLLM_USE_V1: bool = False
|
||||
VLLM_ENABLE_V1_MULTIPROCESSING: bool = False
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@ -473,6 +474,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
|
||||
# If set, use the V1 code path.
|
||||
"VLLM_USE_V1":
|
||||
lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),
|
||||
|
||||
# If set, enable multiprocessing in LLM for the V1 code path.
|
||||
"VLLM_ENABLE_V1_MULTIPROCESSING":
|
||||
lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0"))),
|
||||
}
|
||||
|
||||
# end-env-vars-definition
|
||||
|
||||
@ -113,6 +113,36 @@ class RequestOutput:
|
||||
self.encoder_prompt = encoder_prompt
|
||||
self.encoder_prompt_token_ids = encoder_prompt_token_ids
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
cls,
|
||||
request_id: str,
|
||||
prompt: Optional[str],
|
||||
prompt_token_ids: Optional[List[int]],
|
||||
text: str,
|
||||
token_ids: List[int],
|
||||
finished: bool = False,
|
||||
) -> "RequestOutput":
|
||||
"""Initialize a new RequestOutput object."""
|
||||
|
||||
# TODO: Support `n` > 1.
|
||||
completion_output = CompletionOutput(
|
||||
index=0,
|
||||
text=text,
|
||||
token_ids=token_ids,
|
||||
cumulative_logprob=None,
|
||||
logprobs=None, # TODO
|
||||
)
|
||||
|
||||
return RequestOutput(
|
||||
request_id=request_id,
|
||||
prompt=prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt_logprobs=None, # TODO
|
||||
outputs=[completion_output],
|
||||
finished=finished,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_seq_group(
|
||||
cls, seq_group: SequenceGroup, use_cache: bool,
|
||||
|
||||
0
vllm/v1/__init__.py
Normal file
0
vllm/v1/__init__.py
Normal file
@ -70,7 +70,7 @@ class KVCacheManager:
|
||||
|
||||
Args:
|
||||
request: The request to get the computed blocks.
|
||||
|
||||
|
||||
Returns:
|
||||
A list of blocks that are computed for the request.
|
||||
"""
|
||||
@ -105,7 +105,7 @@ class KVCacheManager:
|
||||
Args:
|
||||
request: The request to append slots.
|
||||
num_tokens: The number of tokens to append.
|
||||
|
||||
|
||||
Returns:
|
||||
A list of new blocks if new blocks are allocated, or None
|
||||
if new blocks are required but cannot be allocated.
|
||||
@ -176,7 +176,7 @@ class KVCacheManager:
|
||||
num_tokens: The number of tokens to allocate. Note that this does
|
||||
not include the tokens that have already been computed.
|
||||
computed_blocks: The blocks that have already been computed.
|
||||
|
||||
|
||||
Returns:
|
||||
A list of new allocated blocks.
|
||||
"""
|
||||
@ -240,7 +240,8 @@ class KVCacheManager:
|
||||
Args:
|
||||
request: The request to free the blocks.
|
||||
"""
|
||||
blocks = self.req_to_blocks.pop(request.request_id)
|
||||
# Default to [] in case a request is freed (aborted) before alloc.
|
||||
blocks = self.req_to_blocks.pop(request.request_id, [])
|
||||
if self.enable_caching:
|
||||
# Free blocks in reverse order so that the tail blocks are
|
||||
# freed first.
|
||||
@ -259,13 +260,13 @@ class KVCacheManager:
|
||||
"""Get new blocks from the free block pool, and add token IDs to
|
||||
allocated blocks if caching is enabled.
|
||||
Note that we do not check block cache in this function.
|
||||
|
||||
|
||||
Args:
|
||||
num_blocks: The number of blocks to allocate.
|
||||
token_ids: The token IDs in the blocks. None if caching is disabled.
|
||||
parent_block: The parent block. Used to include block chain
|
||||
in the block hash.
|
||||
|
||||
|
||||
Returns:
|
||||
A list of new block.
|
||||
"""
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
from typing import Deque, Dict, Iterable, List, Optional, Set, Union
|
||||
|
||||
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager
|
||||
from vllm.v1.engine import EngineCoreOutput
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
@ -237,13 +238,12 @@ class Scheduler:
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
model_runner_output: "ModelRunnerOutput",
|
||||
) -> List[Tuple[Request, int]]:
|
||||
) -> List[EngineCoreOutput]:
|
||||
# NOTE(woosuk): This method doesn't consider speculative decoding.
|
||||
sampled_token_ids = model_runner_output.sampled_token_ids_cpu.tolist()
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
new_running: List[Request] = []
|
||||
# (request, num_sampled_tokens)
|
||||
sampled: List[Tuple[Request, int]] = []
|
||||
engine_core_outputs: List[EngineCoreOutput] = []
|
||||
for request in self.running:
|
||||
req_id = request.request_id
|
||||
request.num_computed_tokens += num_scheduled_tokens[req_id]
|
||||
@ -257,17 +257,29 @@ class Scheduler:
|
||||
# generates at most one token at each step.
|
||||
token_id = sampled_token_ids[req_index]
|
||||
request.append_output_token_ids(token_id)
|
||||
sampled.append((request, 1))
|
||||
num_new_tokens = 1
|
||||
# TODO: Update the KV cache manager for prefix caching.
|
||||
|
||||
# Check if the request is finished.
|
||||
# Check for stop and update request state.
|
||||
# This must be called before me make the EngineCoreOutput.
|
||||
stopped = self._check_stop(request)
|
||||
|
||||
# Add EngineCoreOutput for this Request.
|
||||
output = EngineCoreOutput(
|
||||
request_id=req_id,
|
||||
new_token_ids=request.output_token_ids[-num_new_tokens:],
|
||||
finished=request.is_finished(),
|
||||
finish_reason=request.get_finished_reason(),
|
||||
stop_reason=request.stop_reason)
|
||||
engine_core_outputs.append(output)
|
||||
|
||||
# Breakout of the loop.
|
||||
if stopped:
|
||||
continue
|
||||
|
||||
new_running.append(request)
|
||||
self.running = new_running
|
||||
return sampled
|
||||
return engine_core_outputs
|
||||
|
||||
def _check_stop(self, request: Request) -> bool:
|
||||
if (request.num_tokens >= self.max_model_len
|
||||
|
||||
@ -0,0 +1,72 @@
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import msgspec
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetokenizerRequest:
|
||||
|
||||
request_id: str
|
||||
prompt: Optional[str]
|
||||
prompt_token_ids: List[int]
|
||||
skip_special_tokens: bool
|
||||
spaces_between_special_tokens: bool
|
||||
output_kind: RequestOutputKind
|
||||
|
||||
stop: List[str]
|
||||
include_stop_str_in_output: bool
|
||||
|
||||
|
||||
class EngineCoreRequest(msgspec.Struct, omit_defaults=True):
|
||||
|
||||
# NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
|
||||
# but this object is currently not playing well with msgspec
|
||||
# due to circular imports and typing we have in data.py
|
||||
|
||||
request_id: str
|
||||
#NOTE(Nick): I don't think we need to pass prompt here since it should
|
||||
# always be tokenized?
|
||||
prompt: Optional[str]
|
||||
prompt_token_ids: List[int]
|
||||
sampling_params: SamplingParams
|
||||
eos_token_id: Optional[int]
|
||||
arrival_time: float
|
||||
lora_request: Optional[LoRARequest]
|
||||
|
||||
|
||||
class EngineCoreOutput(msgspec.Struct,
|
||||
array_like=True,
|
||||
omit_defaults=True,
|
||||
gc=False):
|
||||
|
||||
request_id: str
|
||||
new_token_ids: List[int]
|
||||
finished: bool
|
||||
finish_reason: Optional[str] = None
|
||||
stop_reason: Union[int, str, None] = None
|
||||
|
||||
|
||||
class EngineCoreOutputs(msgspec.Struct,
|
||||
array_like=True,
|
||||
omit_defaults=True,
|
||||
gc=False):
|
||||
|
||||
#NOTE(Nick): We could consider ways to make this more compact,
|
||||
# e.g. columnwise layout and using an int enum for finish/stop reason
|
||||
|
||||
# [num_reqs]
|
||||
outputs: List[EngineCoreOutput]
|
||||
|
||||
|
||||
class EngineCoreRequestType(enum.Enum):
|
||||
"""
|
||||
Request types defined as hex byte strings, so it can be sent over sockets
|
||||
without separate encoding step.
|
||||
"""
|
||||
ADD = b'\x00'
|
||||
ABORT = b'\x01'
|
||||
368
vllm/v1/engine/async_llm.py
Normal file
368
vllm/v1/engine/async_llm.py
Normal file
@ -0,0 +1,368 @@
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.metrics_types import StatLoggerBase
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.v1.engine.async_stream import AsyncStream
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.detokenizer import Detokenizer
|
||||
from vllm.v1.engine.processor import Processor
|
||||
from vllm.v1.executor.gpu_executor import GPUExecutor
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AsyncLLM(EngineClient):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[GPUExecutor],
|
||||
log_stats: bool,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
||||
input_registry: InputRegistry = INPUT_REGISTRY,
|
||||
use_cached_outputs: bool = False,
|
||||
log_requests: bool = True,
|
||||
start_engine_loop: bool = True,
|
||||
) -> None:
|
||||
assert start_engine_loop
|
||||
|
||||
self.log_requests = log_requests
|
||||
self.log_stats = log_stats
|
||||
self.stat_loggers = stat_loggers
|
||||
self.model_config = vllm_config.model_config
|
||||
|
||||
# Tokenizer (+ ensure liveness if running in another process).
|
||||
self.tokenizer = init_tokenizer_from_configs(
|
||||
model_config=vllm_config.model_config,
|
||||
scheduler_config=vllm_config.scheduler_config,
|
||||
parallel_config=vllm_config.parallel_config,
|
||||
enable_lora=bool(vllm_config.lora_config))
|
||||
self.tokenizer.ping()
|
||||
|
||||
# Request streams (map of request_id -> AsyncStream).
|
||||
self.request_streams: Dict[str, AsyncStream] = {}
|
||||
# List of cancelled request ids to be aborted.
|
||||
self.client_aborted_requests: List[str] = []
|
||||
|
||||
# Processor (converts Inputs --> EngineCoreRequests).
|
||||
self.processor = Processor(vllm_config.model_config,
|
||||
vllm_config.lora_config, self.tokenizer,
|
||||
input_registry)
|
||||
|
||||
# Detokenizer (converts EngineCoreOutputs --> RequestOutput).
|
||||
self.detokenizer = Detokenizer(vllm_config.model_config.tokenizer)
|
||||
|
||||
# EngineCore (starts the engine in background process).
|
||||
self.engine_core = EngineCoreClient.make_client(
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
usage_context=usage_context,
|
||||
multiprocess_mode=True,
|
||||
asyncio_mode=True,
|
||||
)
|
||||
|
||||
self.output_handler = None
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(
|
||||
cls,
|
||||
engine_args: AsyncEngineArgs,
|
||||
engine_config: Optional[VllmConfig] = None,
|
||||
start_engine_loop: bool = True,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
||||
) -> "AsyncLLMEngine":
|
||||
"""Create an AsyncLLM from the EngineArgs."""
|
||||
|
||||
# Create the engine configs.
|
||||
if engine_config is None:
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
else:
|
||||
vllm_config = engine_config
|
||||
|
||||
executor_class = cls._get_executor_cls(vllm_config)
|
||||
|
||||
# Create the AsyncLLM.
|
||||
return cls(
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_requests=not engine_args.disable_log_requests,
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
start_engine_loop=start_engine_loop,
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
)
|
||||
|
||||
def shutdown(self):
|
||||
"""Shutdown, cleaning up the background proc and IPC."""
|
||||
|
||||
self.engine_core.shutdown()
|
||||
|
||||
if handler := getattr(self, "output_handler", None):
|
||||
handler.cancel()
|
||||
|
||||
@classmethod
|
||||
def _get_executor_cls(cls, vllm_config: VllmConfig):
|
||||
return GPUExecutor
|
||||
|
||||
async def add_request(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: PromptType,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
|
||||
"""Add new request to the AsyncLLM."""
|
||||
|
||||
if self.detokenizer.is_request_active(request_id):
|
||||
raise KeyError(f"Request {request_id} already exists.")
|
||||
|
||||
# 1) Create a new AsyncStream for the request.
|
||||
stream = self._add_request_to_streams(request_id)
|
||||
|
||||
# 2) Convert input --> DetokenizerRequest / EngineCoreRequest.
|
||||
detokenizer_req, engine_core_req = self.processor.process_inputs(
|
||||
request_id, prompt, params, arrival_time, lora_request,
|
||||
trace_headers, prompt_adapter_request, priority)
|
||||
|
||||
# 3) Add the request to Detokenizer (this process).
|
||||
self.detokenizer.add_request(detokenizer_req)
|
||||
|
||||
# 4) Add the EngineCoreRequest to EngineCore (separate process).
|
||||
await self.engine_core.add_request_async(engine_core_req)
|
||||
|
||||
# 5) Return the generator.
|
||||
return stream.generator()
|
||||
|
||||
# TODO: we should support multiple prompts in one call, as you
|
||||
# can do with LLM.generate. So that for multi-prompt completion
|
||||
# requests we don't need to send multiple messages to core proc,
|
||||
# and so we don't need multiple streams which then get
|
||||
# re-multiplexed in the API server anyhow.
|
||||
async def generate(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
"""
|
||||
Main function called by the API server to kick off a request
|
||||
* 1) Making an AsyncStream corresponding to the Request.
|
||||
# 2) Processing the Input.
|
||||
* 3) Adding the Request to the Detokenizer.
|
||||
* 4) Adding the Request to the EngineCore (separate process).
|
||||
|
||||
A separate output_handler loop runs in a background AsyncIO task,
|
||||
pulling outputs from EngineCore and putting them into the
|
||||
per-request AsyncStream.
|
||||
|
||||
The caller of generate() iterates the returned AsyncGenerator,
|
||||
returning the RequestOutput back to the caller.
|
||||
"""
|
||||
|
||||
# We start the output_handler on the first call to generate() so that
|
||||
# we can call __init__ before the event loop starts, which enables us
|
||||
# to handle startup failure gracefully in the OpenAI server.
|
||||
if self.output_handler is None:
|
||||
self.output_handler = asyncio.create_task(
|
||||
self._run_output_handler())
|
||||
|
||||
async for output in await self.add_request(
|
||||
request_id,
|
||||
prompt,
|
||||
sampling_params,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority,
|
||||
):
|
||||
yield output
|
||||
|
||||
def _finish_stream(self, request_id: str):
|
||||
stream = self.request_streams.pop(request_id, None)
|
||||
if stream is not None:
|
||||
stream.finish()
|
||||
|
||||
def _add_request_to_streams(
|
||||
self,
|
||||
request_id: str,
|
||||
) -> AsyncStream:
|
||||
|
||||
if request_id in self.request_streams:
|
||||
raise ValueError(f"Request id {request_id} already running.")
|
||||
|
||||
# Avoid streams having circular ref to parent AsyncLLM object.
|
||||
aborted_reqs = self.client_aborted_requests
|
||||
stream = AsyncStream(request_id, aborted_reqs.append)
|
||||
self.request_streams[request_id] = stream
|
||||
|
||||
if self.log_requests:
|
||||
logger.info("Added request %s.", request_id)
|
||||
|
||||
return stream
|
||||
|
||||
async def _process_cancellations(self) -> None:
|
||||
"""
|
||||
Process requests cancelled from user disconnecting.
|
||||
|
||||
When a client disconnects, AsyncStream._cancel() is called.
|
||||
We passed a callback to AsyncStream(), which appends to
|
||||
self.client_aborted_requests.
|
||||
|
||||
As a result, if any requests are canceled from the user side
|
||||
the request_id will show up in self.client_aborted_requests.
|
||||
"""
|
||||
|
||||
# Avoid streams having circular ref to parent AsyncLLM object.
|
||||
if not self.client_aborted_requests:
|
||||
return
|
||||
reqs_to_abort = self.client_aborted_requests.copy()
|
||||
self.client_aborted_requests.clear()
|
||||
|
||||
# Remove from Detokenizer.
|
||||
self.detokenizer.abort_requests(reqs_to_abort)
|
||||
|
||||
# Remove from RequestStreams.
|
||||
for request_id in reqs_to_abort:
|
||||
if self.log_requests:
|
||||
logger.info("User-cancelled request %s.", request_id)
|
||||
self._finish_stream(request_id)
|
||||
|
||||
# Remove from EngineCore.
|
||||
await self.engine_core.abort_requests_async(reqs_to_abort)
|
||||
|
||||
def _process_request_outputs(self, request_outputs: List[RequestOutput]):
|
||||
"""Process outputs by putting them into per-request AsyncStreams."""
|
||||
|
||||
for request_output in request_outputs:
|
||||
request_id = request_output.request_id
|
||||
assert request_id in self.request_streams
|
||||
|
||||
# Each request in the API server pulls from the per-request stream.
|
||||
stream = self.request_streams.get(request_id)
|
||||
if stream is not None:
|
||||
stream.put(request_output)
|
||||
|
||||
# If finished, remove from the tracker.
|
||||
if request_output.finished:
|
||||
if self.log_requests:
|
||||
logger.info("Finished request %s.", request_id)
|
||||
self._finish_stream(request_id)
|
||||
|
||||
async def _run_output_handler(self):
|
||||
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
|
||||
|
||||
try:
|
||||
while True:
|
||||
# 1) Pull EngineCoreOutput from the EngineCore.
|
||||
outputs = await self.engine_core.get_output_async()
|
||||
|
||||
# 2) Detokenize based on the output.
|
||||
request_outputs, reqs_to_abort = self.detokenizer.step(outputs)
|
||||
|
||||
# 3) Put the RequestOutputs into the per-request AsyncStreams.
|
||||
self._process_request_outputs(request_outputs)
|
||||
|
||||
# 4) Abort any requests that finished due to stop strings.
|
||||
await self.engine_core.abort_requests_async(reqs_to_abort)
|
||||
|
||||
# 5) Abort any requests due to client cancellations.
|
||||
await self._process_cancellations()
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(e)
|
||||
raise e
|
||||
|
||||
# TODO: can we eliminate these?
|
||||
|
||||
async def abort(self, request_id: str) -> None:
|
||||
# Note: Who Calls this? I dont think this is actually used.
|
||||
raise ValueError("Not Supported on V1 yet.")
|
||||
|
||||
def encode(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
pooling_params: PoolingParams,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
priority: int = 0,
|
||||
):
|
||||
raise ValueError("Not Supported on V1 yet.")
|
||||
|
||||
async def get_model_config(self) -> ModelConfig:
|
||||
return self.model_config
|
||||
|
||||
async def get_decoding_config(self):
|
||||
raise ValueError("Not Supported on V1 yet.")
|
||||
|
||||
async def get_tokenizer(
|
||||
self,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
) -> AnyTokenizer:
|
||||
assert lora_request is None
|
||||
return self.detokenizer.tokenizer
|
||||
|
||||
async def is_tracing_enabled(self) -> bool:
|
||||
return False
|
||||
|
||||
async def do_log_stats(
|
||||
self,
|
||||
scheduler_outputs=None,
|
||||
model_output=None,
|
||||
) -> None:
|
||||
logger.debug("Called do_log_stats.")
|
||||
|
||||
async def check_health(self) -> None:
|
||||
logger.debug("Called check_health.")
|
||||
|
||||
async def start_profile(self) -> None:
|
||||
raise ValueError("Not supported on V1 yet.")
|
||||
|
||||
async def stop_profile(self) -> None:
|
||||
raise ValueError("Not supported on V1 yet.")
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_stopped(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def errored(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def dead_error(self) -> BaseException:
|
||||
return Exception
|
||||
|
||||
|
||||
# Retain V0 name for backwards compatibility.
|
||||
AsyncLLMEngine = AsyncLLM
|
||||
55
vllm/v1/engine/async_stream.py
Normal file
55
vllm/v1/engine/async_stream.py
Normal file
@ -0,0 +1,55 @@
|
||||
import asyncio
|
||||
from typing import Any, AsyncGenerator, Callable, Optional, Type, Union
|
||||
|
||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
|
||||
|
||||
class AsyncStream:
|
||||
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request
|
||||
that can be iterated over asynchronously via an async generator."""
|
||||
|
||||
STOP_ITERATION = Exception() # Sentinel
|
||||
|
||||
def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
|
||||
self.request_id = request_id
|
||||
self._cancel = cancel
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
self._finished = False
|
||||
|
||||
def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
|
||||
Exception]) -> None:
|
||||
if not self._finished:
|
||||
self._queue.put_nowait(item)
|
||||
|
||||
def finish(
|
||||
self,
|
||||
exception: Optional[Union[BaseException, Type[BaseException]]] = None,
|
||||
) -> None:
|
||||
if not self._finished:
|
||||
self._finished = True
|
||||
self._queue.put_nowait(exception if self._is_raisable(exception)
|
||||
else AsyncStream.STOP_ITERATION)
|
||||
|
||||
async def generator(
|
||||
self
|
||||
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
|
||||
finished = False
|
||||
try:
|
||||
while True:
|
||||
result = await self._queue.get()
|
||||
if self._is_raisable(result):
|
||||
finished = True
|
||||
if result == AsyncStream.STOP_ITERATION:
|
||||
return
|
||||
raise result
|
||||
yield result
|
||||
finally:
|
||||
self._finished = True
|
||||
if not finished:
|
||||
self._cancel(self.request_id)
|
||||
|
||||
@staticmethod
|
||||
def _is_raisable(value: Any):
|
||||
return isinstance(value, BaseException) or \
|
||||
(isinstance(value, type) and \
|
||||
issubclass(value, BaseException))
|
||||
352
vllm/v1/engine/core.py
Normal file
352
vllm/v1/engine/core.py
Normal file
@ -0,0 +1,352 @@
|
||||
import multiprocessing
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from multiprocessing.process import BaseProcess
|
||||
from multiprocessing.sharedctypes import Synchronized
|
||||
from typing import Any, Iterator, List, Tuple, Type, Union
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from msgspec import msgpack
|
||||
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.v1.core.scheduler import Scheduler
|
||||
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
|
||||
EngineCoreRequest, EngineCoreRequestType)
|
||||
from vllm.v1.executor.gpu_executor import GPUExecutor
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
POLLING_TIMEOUT_MS = 5000
|
||||
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
|
||||
LOGGING_TIME_S = 5000
|
||||
|
||||
|
||||
class EngineCore:
|
||||
"""Inner loop of vLLM's Engine."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[GPUExecutor],
|
||||
usage_context: UsageContext,
|
||||
):
|
||||
# Override the configs for V1.
|
||||
# FIXME
|
||||
if usage_context == UsageContext.LLM_CLASS:
|
||||
vllm_config.scheduler_config.max_num_seqs = 1024
|
||||
vllm_config.scheduler_config.max_num_batched_tokens = 8192
|
||||
elif usage_context == UsageContext.OPENAI_API_SERVER:
|
||||
vllm_config.scheduler_config.max_num_seqs = 1024
|
||||
vllm_config.scheduler_config.max_num_batched_tokens = 2048
|
||||
|
||||
# TODO (ywang96): Enable APC by default when VLM supports it.
|
||||
if not vllm_config.model_config.is_multimodal_model:
|
||||
vllm_config.cache_config.enable_prefix_caching = True
|
||||
|
||||
assert vllm_config.model_config.task != "embedding"
|
||||
|
||||
logger.info("Initializing an LLM engine (v%s) with config: %s",
|
||||
VLLM_VERSION, vllm_config)
|
||||
|
||||
# Setup Model.
|
||||
self.model_executor = executor_class(vllm_config)
|
||||
|
||||
# Setup KV Caches and update CacheConfig after profiling.
|
||||
num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(
|
||||
vllm_config.cache_config)
|
||||
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
# Setup scheduler.
|
||||
self.scheduler = Scheduler(vllm_config.scheduler_config,
|
||||
vllm_config.cache_config,
|
||||
vllm_config.lora_config)
|
||||
|
||||
self._last_logging_time = time.time()
|
||||
|
||||
def _initialize_kv_caches(self,
|
||||
cache_config: CacheConfig) -> Tuple[int, int]:
|
||||
num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks(
|
||||
)
|
||||
|
||||
if cache_config.num_gpu_blocks_override is not None:
|
||||
num_gpu_blocks_override = cache_config.num_gpu_blocks_override
|
||||
logger.info(
|
||||
"Overriding num_gpu_blocks=%d with "
|
||||
"num_gpu_blocks_override=%d", num_gpu_blocks,
|
||||
num_gpu_blocks_override)
|
||||
num_gpu_blocks = num_gpu_blocks_override
|
||||
|
||||
num_cpu_blocks = 0
|
||||
self.model_executor.initialize_cache(num_gpu_blocks)
|
||||
return num_gpu_blocks, num_cpu_blocks
|
||||
|
||||
def add_request(self, request: EngineCoreRequest):
|
||||
"""Add request to the scheduler."""
|
||||
|
||||
req = Request.from_engine_core_request(request)
|
||||
self.scheduler.add_request(req)
|
||||
|
||||
def abort_requests(self, request_ids: List[str]):
|
||||
"""Abort requests from the scheduler."""
|
||||
|
||||
# TODO: The scheduler doesn't really need to know the
|
||||
# specific finish reason, TBD whether we propagate that
|
||||
# (i.e. client-aborted vs stop criteria met).
|
||||
self.scheduler.finish_requests(request_ids,
|
||||
RequestStatus.FINISHED_ABORTED)
|
||||
|
||||
def step(self) -> List[EngineCoreOutput]:
|
||||
"""Schedule, execute, and make output."""
|
||||
|
||||
if not self.scheduler.has_unfinished_requests():
|
||||
return []
|
||||
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
output = self.model_executor.execute_model(scheduler_output)
|
||||
engine_core_outputs = self.scheduler.update_from_output(
|
||||
scheduler_output, output)
|
||||
return engine_core_outputs
|
||||
|
||||
|
||||
class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

    # Handshake string sent once to the client to signal startup completed.
    READY_STR = "READY"

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[GPUExecutor],
        usage_context: UsageContext,
        input_path: str,
        output_path: str,
        ready_path: str,
        should_shutdown: Synchronized,
    ):
        """Initialize the engine, start the IO threads, signal readiness.

        Args:
            vllm_config: Full engine configuration (forwarded to EngineCore).
            executor_class: Executor type used to run the model.
            usage_context: Usage-reporting context (forwarded to EngineCore).
            input_path: ZMQ path this process PULLs EngineCoreRequests from.
            output_path: ZMQ path this process PUSHes EngineCoreOutputs to.
            ready_path: ZMQ path used once to send READY_STR to the client.
            should_shutdown: Shared flag the parent sets to request shutdown.
        """
        super().__init__(vllm_config, executor_class, usage_context)

        # Signal from main process to shutdown (multiprocessing.Value).
        self.should_shutdown = should_shutdown

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,
        # and to overlap some serialization/deserialization with the
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
        self.input_queue = queue.Queue()
        self.output_queue = queue.Queue()
        threading.Thread(target=self.process_input_socket,
                         args=(input_path, ),
                         daemon=True).start()
        threading.Thread(target=self.process_output_socket,
                         args=(output_path, ),
                         daemon=True).start()

        # Send Readiness signal to EngineClient.
        with self.make_socket(ready_path, zmq.constants.PUSH) as ready_socket:
            ready_socket.send_string(EngineCoreProc.READY_STR)

    @contextmanager
    def make_socket(self, path: str, type: Any) -> Iterator[zmq.Socket]:
        """Context manager yielding a ZMQ socket for *path*.

        PULL sockets connect to the path; PUSH sockets bind it. The
        context (and therefore the socket) is destroyed on exit.
        """

        ctx = zmq.Context()
        try:
            socket = ctx.socket(type)

            if type == zmq.constants.PULL:
                socket.connect(path)
            elif type == zmq.constants.PUSH:
                socket.bind(path)
            else:
                raise ValueError(f"Unknown Socket Type: {type}")

            yield socket

        except KeyboardInterrupt:
            logger.debug("EngineCore had Keyboard Interrupt.")

        finally:
            # Destroying the context also closes the socket.
            ctx.destroy(linger=0)

    @staticmethod
    def wait_for_startup(
        proc: BaseProcess,
        ready_path: str,
    ) -> None:
        """Wait until the EngineCore is ready."""

        try:
            sync_ctx = zmq.Context()  # type: ignore[attr-defined]
            socket = sync_ctx.socket(zmq.constants.PULL)
            socket.connect(ready_path)

            # Wait for EngineCore to send EngineCoreProc.READY_STR.
            while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
                logger.debug("Waiting for EngineCoreProc to startup.")

                # Fail fast if the child died before it could signal.
                if not proc.is_alive():
                    raise RuntimeError("EngineCoreProc failed to start.")

            message = socket.recv_string()
            assert message == EngineCoreProc.READY_STR

        except BaseException as e:
            logger.exception(e)
            raise e

        finally:
            # NOTE(review): if zmq.Context() itself raised, sync_ctx is
            # unbound here and this line masks the original error — confirm.
            sync_ctx.destroy(linger=0)

    @staticmethod
    def make_engine_core_process(
        vllm_config: VllmConfig,
        executor_class: Type[GPUExecutor],
        usage_context: UsageContext,
        input_path: str,
        output_path: str,
        ready_path: str,
        should_shutdown: Synchronized,
    ) -> BaseProcess:
        """Spawn the EngineCore background process and block until ready."""
        # The current process might have CUDA context,
        # so we need to spawn a new process.
        # NOTE(rob): this is a problem for using EngineCoreProc w/
        # LLM, since we need a if __name__ == "__main__" guard.
        context = multiprocessing.get_context("spawn")

        process_kwargs = {
            "input_path": input_path,
            "output_path": output_path,
            "ready_path": ready_path,
            "vllm_config": vllm_config,
            "executor_class": executor_class,
            "usage_context": usage_context,
            "should_shutdown": should_shutdown
        }
        # Run EngineCore busy loop in background process.
        proc = context.Process(target=EngineCoreProc.run_engine_core,
                               kwargs=process_kwargs)
        proc.start()

        # Wait for startup
        EngineCoreProc.wait_for_startup(proc, ready_path)
        return proc

    @staticmethod
    def run_engine_core(*args, **kwargs):
        """Launch EngineCore busy loop in background process."""

        try:
            engine_core = EngineCoreProc(*args, **kwargs)
            engine_core.run_busy_loop()

        except KeyboardInterrupt:
            logger.debug("EngineCore interrupted.")

        except BaseException as e:
            logger.exception(e)
            raise e

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        # Loop until we get a shutdown signal.
        # NOTE(review): should_shutdown is a shared multiprocessing value;
        # its truthiness tracks the stored byte only if the writer mutates
        # `.value` in place rather than rebinding the attribute — confirm
        # the client side does so.
        while not self.should_shutdown:
            # 1) Poll the input queue until there is work to do.
            if not self.scheduler.has_unfinished_requests():
                while True:
                    try:
                        req = self.input_queue.get(timeout=POLLING_TIMEOUT_S)
                        self._handle_client_request(req)
                        break
                    except queue.Empty:
                        # Idle: periodically log and re-check for shutdown.
                        self._log_stats()
                        logger.debug("EngineCore busy loop waiting.")
                        if self.should_shutdown:
                            return

            # 2) Handle any new client requests (Abort or Add).
            while not self.input_queue.empty():
                req = self.input_queue.get_nowait()
                self._handle_client_request(req)

            # 3) Step the engine core.
            outputs = self.step()

            # 4) Put EngineCoreOutputs into the output queue.
            self.output_queue.put_nowait(outputs)

            self._log_stats()

    def _log_stats(self):
        """Log basic stats every LOGGING_TIME_S"""

        # assumes self._last_logging_time was initialized by EngineCore's
        # __init__ (not visible here) — TODO confirm.
        now = time.time()

        if now - self._last_logging_time > LOGGING_TIME_S:
            logger.info(
                "RUNNING: %s | WAITING: %s",
                len(self.scheduler.running),
                len(self.scheduler.waiting),
            )

            self._last_logging_time = now

    def _handle_client_request(
            self, request: Union[EngineCoreRequest, List[str]]) -> None:
        """Handle EngineCoreRequest or EngineCoreABORT from Client."""

        if isinstance(request, EngineCoreRequest):
            self.add_request(request)
        else:
            # A bare list of request ids means "abort these requests".
            # TODO: make an EngineCoreAbort wrapper
            assert isinstance(request, list)
            self.abort_requests(request)

    def process_input_socket(self, input_path: str):
        """Input socket IO thread."""

        # Msgpack serialization decoding.
        decoder_add_req = msgpack.Decoder(EngineCoreRequest)
        decoder_abort_req = msgpack.Decoder(list[str])

        with self.make_socket(input_path, zmq.constants.PULL) as socket:
            while True:
                # (RequestType, RequestData)
                type_frame, data_frame = socket.recv_multipart(copy=False)
                request_type = type_frame.buffer
                request_data = data_frame.buffer

                # Deserialize the request data.
                if request_type == EngineCoreRequestType.ADD.value:
                    request = decoder_add_req.decode(request_data)
                elif request_type == EngineCoreRequestType.ABORT.value:
                    request = decoder_abort_req.decode(request_data)
                else:
                    raise ValueError(f"Unknown RequestType: {request_type}")

                # Push to input queue for core busy loop.
                self.input_queue.put_nowait(request)

    def process_output_socket(self, output_path: str):
        """Output socket IO thread."""

        # Msgpack serialization encoding.
        encoder = msgpack.Encoder()
        # Reuse send buffer.
        buffer = bytearray()

        with self.make_socket(output_path, zmq.constants.PUSH) as socket:
            while True:
                engine_core_outputs = self.output_queue.get()
                outputs = EngineCoreOutputs(outputs=engine_core_outputs)
                encoder.encode_into(outputs, buffer)
                socket.send_multipart((buffer, ), copy=False)
218
vllm/v1/engine/core_client.py
Normal file
218
vllm/v1/engine/core_client.py
Normal file
@ -0,0 +1,218 @@
|
||||
import multiprocessing
|
||||
import time
|
||||
from typing import List, Union
|
||||
|
||||
import msgspec
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import get_open_zmq_ipc_path
|
||||
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
|
||||
EngineCoreRequest, EngineCoreRequestType)
|
||||
from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class EngineCoreClient:
    """
    EngineCoreClient: subclasses handle different methods for pushing
    and pulling from the EngineCore for asyncio / multiprocessing.

    Subclasses:
    * InprocClient: In process EngineCore (for V0-style LLMEngine use)
    * SyncMPClient: ZMQ + background proc EngineCore (for LLM)
    * AsyncMPClient: ZMQ + background proc EngineCore w/ asyncio (for AsyncLLM)
    """

    @staticmethod
    def make_client(
        *args,
        multiprocess_mode: bool,
        asyncio_mode: bool,
        **kwargs,
    ) -> "EngineCoreClient":
        # TODO: support this for debugging purposes.
        if asyncio_mode and not multiprocess_mode:
            raise NotImplementedError(
                "Running EngineCore in asyncio without multiprocessing "
                "is not currently supported.")

        if multiprocess_mode:
            # Background-process clients: pick the flavor matching the
            # caller's concurrency model.
            client_cls = AsyncMPClient if asyncio_mode else SyncMPClient
            return client_cls(*args, **kwargs)

        # In-process fallback (V0-style LLMEngine usage).
        return InprocClient(*args, **kwargs)

    def shutdown(self):
        # Default: nothing to release.
        pass

    def get_output(self) -> List[EngineCoreOutput]:
        raise NotImplementedError

    def add_request(self, request: EngineCoreRequest) -> None:
        raise NotImplementedError

    def abort_requests(self, request_ids: List[str]) -> None:
        raise NotImplementedError

    async def get_output_async(self) -> List[EngineCoreOutput]:
        raise NotImplementedError

    async def add_request_async(self, request: EngineCoreRequest) -> None:
        raise NotImplementedError

    async def abort_requests_async(self, request_ids: List[str]) -> None:
        raise NotImplementedError
|
||||
class InprocClient(EngineCoreClient):
    """
    InprocClient: client for in-process EngineCore. Intended
    for use in LLMEngine for V0-style add_request() and step()
    EngineCore setup in this process (no busy loop).

    * pushes EngineCoreRequest directly into the EngineCore
    * pulls EngineCoreOutputs by stepping the EngineCore

    TODO: support asyncio-mode for debugging.
    """

    def __init__(self, *args, **kwargs):
        # Own the EngineCore directly; no IPC or background process.
        self.engine_core = EngineCore(*args, **kwargs)

    def get_output(self) -> List[EngineCoreOutput]:
        # One engine step yields the next batch of outputs.
        return self.engine_core.step()

    def add_request(self, request: EngineCoreRequest) -> None:
        # Hand the request straight to the engine.
        self.engine_core.add_request(request)

    def abort_requests(self, request_ids: List[str]) -> None:
        # Forward aborts straight to the engine.
        self.engine_core.abort_requests(request_ids)
||||
class MPClient(EngineCoreClient):
    """
    MPClient: base client for multi-proc EngineCore.
    EngineCore runs in a background process busy loop, getting
    new EngineCoreRequests and returning EngineCoreOutputs

    * pushes EngineCoreRequests via input_socket
    * pulls EngineCoreOutputs via output_socket

    * AsyncMPClient subclass for AsyncLLM usage
    * SyncMPClient subclass for LLM usage
    """

    def __init__(
        self,
        *args,
        asyncio_mode: bool,
        **kwargs,
    ):
        """Set up serialization, ZMQ sockets, and the background process.

        Args:
            asyncio_mode: Use a zmq.asyncio context (AsyncMPClient) instead
                of the blocking one (SyncMPClient).
            *args, **kwargs: Forwarded to EngineCoreProc.make_engine_core_process.
        """
        # Serialization setup.
        self.encoder = msgspec.msgpack.Encoder()
        self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs)

        # ZMQ setup.
        self.ctx = (zmq.asyncio.Context() if asyncio_mode else zmq.Context())

        # Paths for IPC.
        ready_path = get_open_zmq_ipc_path()
        output_path = get_open_zmq_ipc_path()
        input_path = get_open_zmq_ipc_path()

        # Get output (EngineCoreOutput) from EngineCore.
        self.output_socket = self.ctx.socket(zmq.constants.PULL)
        self.output_socket.connect(output_path)

        # Send input (EngineCoreRequest) to EngineCore.
        self.input_socket = self.ctx.socket(zmq.constants.PUSH)
        self.input_socket.bind(input_path)

        # Start EngineCore in background process. The shared flag is a
        # lock-free multiprocessing.Value visible to the child process.
        self.should_shutdown = multiprocessing.Value('b', False, lock=False)
        self.proc = EngineCoreProc.make_engine_core_process(
            *args,
            input_path=input_path,
            output_path=output_path,
            ready_path=ready_path,
            should_shutdown=self.should_shutdown,
            **kwargs,
        )

    def shutdown(self):
        """Signal the background EngineCore to stop and release resources.

        Idempotent, and safe on a partially-constructed instance (e.g.
        when __init__ raised and __del__ still runs).
        """
        # Send shutdown signal to background process.
        # BUGFIX: mutate the shared value in place. The previous
        # `self.should_shutdown = True` only rebound this process's
        # attribute, so the child's busy loop never saw the signal.
        if hasattr(self, "should_shutdown"):
            self.should_shutdown.value = True

        # Shut down the zmq context (this also closes the sockets).
        if hasattr(self, "ctx"):
            self.ctx.destroy(linger=0)

        # Shutdown the process if needed.
        if hasattr(self, "proc") and self.proc.is_alive():
            self.proc.terminate()

            # Give it a grace period before force-killing.
            time.sleep(5)
            if self.proc.is_alive():
                self.proc.kill()

    def __del__(self):
        # Implicit cleanup on garbage collection; shutdown() is idempotent.
        self.shutdown()
||||
class SyncMPClient(MPClient):
    """Synchronous client for multi-proc EngineCore."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, asyncio_mode=False, **kwargs)

    def get_output(self) -> List[EngineCoreOutput]:
        """Block until the next batch of EngineCoreOutputs arrives."""
        (frame, ) = self.output_socket.recv_multipart(copy=False)
        return self.decoder.decode(frame.buffer).outputs

    def _send_input(self, request_type: EngineCoreRequestType,
                    request: Union[EngineCoreRequest, List[str]]) -> None:
        """Serialize and ship a request to the core process."""
        # (RequestType, SerializedRequest)
        payload = (request_type.value, self.encoder.encode(request))
        self.input_socket.send_multipart(payload, copy=False)

    def add_request(self, request: EngineCoreRequest) -> None:
        self._send_input(EngineCoreRequestType.ADD, request)

    def abort_requests(self, request_ids: List[str]) -> None:
        self._send_input(EngineCoreRequestType.ABORT, request_ids)
|
||||
class AsyncMPClient(MPClient):
    """Asyncio-compatible client for multi-proc EngineCore."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, asyncio_mode=True, **kwargs)

    async def get_output_async(self) -> List[EngineCoreOutput]:
        """Await the next batch of EngineCoreOutputs from the core process."""
        frames = await self.output_socket.recv_multipart(copy=False)
        return self.decoder.decode(frames[0].buffer).outputs

    async def _send_input(
            self, request_type: EngineCoreRequestType,
            request: Union[EngineCoreRequest, List[str]]) -> None:
        """Serialize and ship a request to the core process."""
        # (RequestType, SerializedRequest)
        payload = (request_type.value, self.encoder.encode(request))
        await self.input_socket.send_multipart(payload, copy=False)

    async def add_request_async(self, request: EngineCoreRequest) -> None:
        await self._send_input(EngineCoreRequestType.ADD, request)

    async def abort_requests_async(self, request_ids: List[str]) -> None:
        # Skip the IPC round trip entirely for an empty abort list.
        if len(request_ids) > 0:
            await self._send_input(EngineCoreRequestType.ABORT, request_ids)
||||
265
vllm/v1/engine/detokenizer.py
Normal file
265
vllm/v1/engine/detokenizer.py
Normal file
@ -0,0 +1,265 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.transformers_utils.detokenizer_utils import (
|
||||
AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.v1.engine import DetokenizerRequest, EngineCoreOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class IncrementalDetokenizer:
    """Per-request state for incremental detokenization and stop-string
    handling, producing RequestOutputs as new token ids arrive."""

    # Generation data
    output_text: str
    tokens: List[str]
    token_ids: List[int]

    # Stop strings
    stop: List[str]
    include_stop_str_in_output: bool

    # Metadata for incremental detokenization
    prefix_offset: int
    read_offset: int

    # Parameters for detokenization
    skip_special_tokens: bool
    spaces_between_special_tokens: bool
    output_kind: RequestOutputKind

    # TODO: Probably decouple these
    request_id: str
    prompt: Optional[str]
    prompt_token_ids: List[int]

    # Tokenizer for this request
    tokenizer: AnyTokenizer

    # Accounting for stop string buffering
    stop_buffer_length: int
    # Offset into output_text already emitted in DELTA mode.
    _last_output_text_offset: int = 0

    @property
    def output_token_ids(self) -> List[int]:
        """Generated token ids only (token_ids minus the prompt prefix)."""
        assert len(self.token_ids) >= len(self.prompt_token_ids)
        return self.token_ids[len(self.prompt_token_ids):]

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: DetokenizerRequest,
    ) -> "IncrementalDetokenizer":
        """Build initial detokenizer state from a new request's prompt."""

        tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
            tokenizer=tokenizer,
            prompt_ids=request.prompt_token_ids,
            skip_special_tokens=request.skip_special_tokens,
        )

        stops = request.stop
        # Number of chars to hold back when stop strings are to be excluded
        # from streamed output.
        if stops and not request.include_stop_str_in_output:
            stop_buffer_length = max(len(s) for s in stops) - 1
        else:
            stop_buffer_length = 0

        return cls(
            output_text="",
            tokens=tokens,
            # Detokenizer mutates this list, so need a unique copy.
            # NOTE(Nick): could we take ownership of it though?
            token_ids=request.prompt_token_ids.copy(),
            stop=stops,
            include_stop_str_in_output=request.include_stop_str_in_output,
            prefix_offset=prefix_offset,
            read_offset=read_offset,
            skip_special_tokens=request.skip_special_tokens,
            spaces_between_special_tokens=request.
            spaces_between_special_tokens,
            output_kind=request.output_kind,
            request_id=request.request_id,
            prompt=request.prompt,
            prompt_token_ids=request.prompt_token_ids,
            tokenizer=tokenizer,
            stop_buffer_length=stop_buffer_length,
        )

    def add_tokens(
        self,
        new_token_ids: List[int],
        finish_reason: Optional[str],
        stop_reason: Optional[str],
    ) -> Optional[RequestOutput]:
        """
        Update RequestState for the request_id by:
            1) Detokenize the new token ids incrementally.
            2) Update the RequestOutput with the new text.

        Returns None when nothing should be emitted yet (FINAL_ONLY
        output kind with an unfinished request); otherwise a
        RequestOutput carrying the newly available text/tokens.
        """

        # 1) Detokenize the new token ids incrementally.
        # TODO(woosuk): This method becomes very inefficient when the number of
        # new_token_ids is more than 1. We need to optimize this.
        decoded_text = ""
        for new_token_id in new_token_ids:
            self.token_ids.append(new_token_id)
            (new_tokens, new_decoded_token_text, prefix_offset,
             read_offset) = detokenize_incrementally(
                 tokenizer=self.tokenizer,
                 all_input_ids=self.token_ids,
                 prev_tokens=self.tokens,
                 prefix_offset=self.prefix_offset,
                 read_offset=self.read_offset,
                 skip_special_tokens=self.skip_special_tokens,
                 spaces_between_special_tokens=self.
                 spaces_between_special_tokens,
             )

            # Advance the incremental-detokenization window.
            self.tokens.extend(new_tokens)
            self.prefix_offset = prefix_offset
            self.read_offset = read_offset
            self.output_text += new_decoded_token_text

            decoded_text += new_decoded_token_text

        # 2) Evaluate stop criteria.
        if self.stop:
            stop = StopChecker.check_stop_strings(
                output_text=self.output_text,
                new_char_count=len(decoded_text),
                stop=self.stop,
                include_in_output=self.include_stop_str_in_output,
            )
            if stop is not None:
                stop_str, truncate_to = stop
                # truncate_to == -1 means keep the full text (stop string
                # included in output); otherwise cut at the stop match.
                if truncate_to != -1:
                    self.output_text = self.output_text[:truncate_to]
                finish_reason = "stop"  # TODO: use constant
                stop_reason = stop_str

        # TODO: handle stop_token_ids here too?

        # 3) Update the RequestOutput object with the new text.
        finished = bool(finish_reason)
        if self.output_kind == RequestOutputKind.FINAL_ONLY \
            and not finished:
            return None

        delta = self.output_kind == RequestOutputKind.DELTA
        output_text = self._get_next_output_text(finished, delta)
        token_ids = new_token_ids if delta else self.output_token_ids

        request_output = RequestOutput.new(
            self.request_id,
            self.prompt,
            self.prompt_token_ids,
            output_text,
            token_ids,
            finished,
        )

        if finished:
            completion_output = request_output.outputs[0]
            completion_output.finish_reason = finish_reason
            completion_output.stop_reason = stop_reason

        return request_output

    def _get_next_output_text(self, finished: bool, delta: bool) -> str:
        """If delta is True, only new text since the last call to
        this method is returned.

        Holds back stop_buffer_length chars while unfinished so a stop
        string straddling a chunk boundary is never streamed out."""

        # We return the full output text if the sequence is finished.
        buffer_length = 0 if finished else self.stop_buffer_length
        if not delta:
            return self.output_text[:-buffer_length] if buffer_length else (
                self.output_text)
        length = len(self.output_text) - buffer_length
        last_offset = self._last_output_text_offset
        if last_offset < length:
            self._last_output_text_offset = length
            return self.output_text[last_offset:length]
        return ""
|
||||
class Detokenizer:
    """Tracks incremental detokenization state for all in-flight requests."""

    def __init__(self, tokenizer_name: str):
        # TODO: once we support LoRA, we should should pass the tokenizer
        # here. We currently have two copies (this + in the LLMEngine).
        self.tokenizer = get_tokenizer(tokenizer_name)

        # Request id -> IncrementalDetokenizer
        self.request_states: Dict[str, IncrementalDetokenizer] = {}

    def is_request_active(self, request_id: str):
        return request_id in self.request_states

    def get_num_unfinished_requests(self):
        return len(self.request_states)

    def has_unfinished_requests(self) -> bool:
        return bool(self.request_states)

    def abort_requests(
        self,
        request_ids: Iterable[str],
    ) -> None:
        """Remove the request_ids from the Detokenizer."""
        for req_id in request_ids:
            # Unknown ids are ignored (may already be finished).
            self.request_states.pop(req_id, None)

    def add_request(
        self,
        request: DetokenizerRequest,
    ):
        """Add new request to the Detokenizer."""
        assert (request.request_id not in self.request_states)

        self.request_states[request.request_id] = (
            IncrementalDetokenizer.from_new_request(self.tokenizer, request))

    def step(
        self, encore_core_outputs: List[EngineCoreOutput]
    ) -> Tuple[List[RequestOutput], List[str]]:
        """Update state and request the RequestOutputs to the LLMEngine."""

        request_outputs: List[RequestOutput] = []
        requests_to_abort: List[str] = []
        for core_output in encore_core_outputs:
            req_id = core_output.request_id
            state = self.request_states.get(req_id)
            if state is None:
                # Ignore output for already-aborted request.
                continue

            # Detokenize and update state.
            request_output = state.add_tokens(
                new_token_ids=core_output.new_token_ids,
                finish_reason=core_output.finish_reason,
                stop_reason=core_output.stop_reason,
            )
            if request_output is None:
                continue

            request_outputs.append(request_output)

            if request_output.finished:
                # Free completed requests. If the engine core has not
                # itself finished the request (e.g. a stop string was
                # detected here), ask the caller to abort it there.
                self.request_states.pop(req_id)
                if not core_output.finished:
                    requests_to_abort.append(req_id)

        # Return to EngineClient.
        return request_outputs, requests_to_abort
||||
@ -1,35 +1,28 @@
|
||||
import time
|
||||
from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type,
|
||||
Union)
|
||||
from typing import Dict, List, Mapping, Optional, Type, Union
|
||||
|
||||
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
|
||||
ObservabilityConfig, ParallelConfig, SchedulerConfig,
|
||||
VllmConfig)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.metrics_types import StatLoggerBase
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
|
||||
EncoderDecoderLLMInputs, InputRegistry, PromptType)
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.transformers_utils.config import try_get_generation_config
|
||||
from vllm.transformers_utils.tokenizer_group import (
|
||||
BaseTokenizerGroup, init_tokenizer_from_configs)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.v1.core.scheduler import Scheduler
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.detokenizer import Detokenizer
|
||||
from vllm.v1.engine.processor import Processor
|
||||
from vllm.v1.executor.gpu_executor import GPUExecutor
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.tokenizer.detokenizer import Detokenizer, DetokenizerInputs
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
"""Legacy LLMEngine for backwards compatibility."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -40,146 +33,36 @@ class LLMEngine:
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
||||
input_registry: InputRegistry = INPUT_REGISTRY,
|
||||
use_cached_outputs: bool = False,
|
||||
multiprocess_mode: bool = False,
|
||||
) -> None:
|
||||
|
||||
# TODO: remove the local variables and use self.* throughout the class.
|
||||
model_config = self.model_config = vllm_config.model_config
|
||||
cache_config = self.cache_config = vllm_config.cache_config
|
||||
lora_config = self.lora_config = vllm_config.lora_config
|
||||
parallel_config = self.parallel_config = vllm_config.parallel_config
|
||||
scheduler_config = self.scheduler_config = vllm_config.scheduler_config
|
||||
device_config = self.device_config = vllm_config.device_config
|
||||
speculative_config = self.speculative_config = vllm_config.speculative_config # noqa
|
||||
load_config = self.load_config = vllm_config.load_config
|
||||
decoding_config = self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa
|
||||
# TODO: Can we avoid this?
|
||||
self.model_config = vllm_config.model_config
|
||||
|
||||
# Tokenizer (+ ensure liveness if running in another process).
|
||||
self.tokenizer = init_tokenizer_from_configs(
|
||||
model_config=vllm_config.model_config,
|
||||
scheduler_config=vllm_config.scheduler_config,
|
||||
parallel_config=vllm_config.parallel_config,
|
||||
enable_lora=bool(vllm_config.lora_config))
|
||||
self.tokenizer.ping()
|
||||
|
||||
# Processor (convert Inputs --> EngineCoreRequests)
|
||||
self.processor = Processor(vllm_config.model_config,
|
||||
vllm_config.lora_config, self.tokenizer,
|
||||
input_registry)
|
||||
|
||||
# Detokenizer (converts EngineCoreOutputs --> RequestOutput)
|
||||
self.detokenizer = Detokenizer(vllm_config.model_config.tokenizer)
|
||||
|
||||
# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
|
||||
self.engine_core = EngineCoreClient.make_client(
|
||||
vllm_config,
|
||||
executor_class,
|
||||
usage_context,
|
||||
multiprocess_mode=multiprocess_mode,
|
||||
asyncio_mode=False,
|
||||
)
|
||||
prompt_adapter_config = self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa
|
||||
observability_config = self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
|
||||
)
|
||||
|
||||
# Override the configs for V1.
|
||||
# FIXME
|
||||
if usage_context == UsageContext.LLM_CLASS:
|
||||
scheduler_config.max_num_seqs = 1024
|
||||
scheduler_config.max_num_batched_tokens = 8192
|
||||
elif usage_context == UsageContext.OPENAI_API_SERVER:
|
||||
scheduler_config.max_num_seqs = 1024
|
||||
scheduler_config.max_num_batched_tokens = 2048
|
||||
|
||||
# TODO (ywang96): Enable APC by default when VLM supports it.
|
||||
if not model_config.is_multimodal_model:
|
||||
cache_config.enable_prefix_caching = True
|
||||
|
||||
logger.info(
|
||||
"Initializing an LLM engine (v%s) with config: "
|
||||
"model=%r, speculative_config=%r, tokenizer=%r, "
|
||||
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
|
||||
"override_neuron_config=%s, tokenizer_revision=%s, "
|
||||
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
|
||||
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
|
||||
"pipeline_parallel_size=%d, "
|
||||
"disable_custom_all_reduce=%s, quantization=%s, "
|
||||
"enforce_eager=%s, kv_cache_dtype=%s, "
|
||||
"quantization_param_path=%s, device_config=%s, "
|
||||
"decoding_config=%r, observability_config=%r, "
|
||||
"seed=%d, served_model_name=%s, "
|
||||
"num_scheduler_steps=%d, enable_prefix_caching=%s, "
|
||||
"use_async_output_proc=%s, mm_processor_kwargs=%s)",
|
||||
VLLM_VERSION,
|
||||
model_config.model,
|
||||
speculative_config,
|
||||
model_config.tokenizer,
|
||||
model_config.skip_tokenizer_init,
|
||||
model_config.tokenizer_mode,
|
||||
model_config.revision,
|
||||
model_config.override_neuron_config,
|
||||
model_config.tokenizer_revision,
|
||||
model_config.trust_remote_code,
|
||||
model_config.dtype,
|
||||
model_config.max_model_len,
|
||||
load_config.download_dir,
|
||||
load_config.load_format,
|
||||
parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
parallel_config.disable_custom_all_reduce,
|
||||
model_config.quantization,
|
||||
model_config.enforce_eager,
|
||||
cache_config.cache_dtype,
|
||||
model_config.quantization_param_path,
|
||||
device_config.device,
|
||||
decoding_config,
|
||||
observability_config,
|
||||
model_config.seed,
|
||||
model_config.served_model_name,
|
||||
scheduler_config.num_scheduler_steps,
|
||||
cache_config.enable_prefix_caching,
|
||||
model_config.use_async_output_proc,
|
||||
model_config.mm_processor_kwargs,
|
||||
)
|
||||
|
||||
self.log_stats = log_stats
|
||||
|
||||
assert not self.model_config.skip_tokenizer_init
|
||||
self.tokenizer = self._init_tokenizer()
|
||||
if self.tokenizer:
|
||||
# Ping the tokenizer to ensure liveness if it runs in a
|
||||
# different process.
|
||||
self.tokenizer.ping()
|
||||
self.detokenizer = Detokenizer(
|
||||
tokenizer_name=self.model_config.tokenizer,
|
||||
tokenizer_mode=self.model_config.tokenizer_mode,
|
||||
trust_remote_code=self.model_config.trust_remote_code)
|
||||
|
||||
self.generation_config_fields = _load_generation_config_dict(
|
||||
model_config)
|
||||
self.input_preprocessor = InputPreprocessor(model_config,
|
||||
self.tokenizer)
|
||||
self.input_registry = input_registry
|
||||
self.input_processor = input_registry.create_input_processor(
|
||||
model_config)
|
||||
|
||||
# Request id -> Request
|
||||
self.requests: Dict[str, Request] = {}
|
||||
# NOTE(woosuk): Now that the detokenizer works asynchronously, we need
|
||||
# to keep track of how many steps each request has been lagged behind
|
||||
# in terms of detokenization.
|
||||
# Request id -> how many detokenizer steps the request should wait for.
|
||||
self.num_lagged_steps: Dict[str, int] = {}
|
||||
# OPTIMIZATION: Cache the request output and update it incrementally.
|
||||
# This is used to avoid creating a new RequestOutput object every step.
|
||||
# Request id -> RequestOutput
|
||||
self.request_outputs: Dict[str, RequestOutput] = {}
|
||||
|
||||
self.model_executor = executor_class(vllm_config=vllm_config)
|
||||
assert self.model_config.task != "embedding"
|
||||
self._initialize_kv_caches()
|
||||
|
||||
# Create the scheduler.
|
||||
# NOTE: the cache_config here have been updated with the numbers of
|
||||
# GPU and CPU blocks, which are profiled in the distributed executor.
|
||||
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
|
||||
|
||||
def __del__(self):
|
||||
# Small hack- implicit clean up of resources on garbage collect
|
||||
# TODO: this should probably be explicitly invoked when we're done with
|
||||
# the engine
|
||||
self.terminate_detokenizer()
|
||||
|
||||
def _initialize_kv_caches(self) -> None:
|
||||
num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks(
|
||||
)
|
||||
|
||||
if self.cache_config.num_gpu_blocks_override is not None:
|
||||
num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
|
||||
logger.info(
|
||||
"Overriding num_gpu_blocks=%d with "
|
||||
"num_gpu_blocks_override=%d", num_gpu_blocks,
|
||||
num_gpu_blocks_override)
|
||||
num_gpu_blocks = num_gpu_blocks_override
|
||||
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = 0
|
||||
self.model_executor.initialize_cache(num_gpu_blocks)
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(
|
||||
@ -187,71 +70,49 @@ class LLMEngine:
|
||||
engine_args: EngineArgs,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
||||
enable_multiprocessing: bool = False,
|
||||
) -> "LLMEngine":
|
||||
"""Creates an LLM engine from the engine arguments."""
|
||||
|
||||
# Create the engine configs.
|
||||
engine_config = engine_args.create_engine_config()
|
||||
executor_class = cls._get_executor_cls(engine_config)
|
||||
# Create the LLM engine.
|
||||
engine = cls(
|
||||
vllm_config=engine_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
)
|
||||
return engine
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
executor_class = cls._get_executor_cls(vllm_config)
|
||||
|
||||
def _init_tokenizer(self) -> BaseTokenizerGroup:
|
||||
return init_tokenizer_from_configs(
|
||||
model_config=self.model_config,
|
||||
scheduler_config=self.scheduler_config,
|
||||
parallel_config=self.parallel_config,
|
||||
enable_lora=bool(self.lora_config))
|
||||
if VLLM_ENABLE_V1_MULTIPROCESSING:
|
||||
logger.debug("Enabling multiprocessing for LLMEngine.")
|
||||
enable_multiprocessing = True
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
||||
if self.lora_config:
|
||||
self.lora_config.verify_with_model_config(self.model_config)
|
||||
self.lora_config.verify_with_scheduler_config(
|
||||
self.scheduler_config)
|
||||
if self.prompt_adapter_config:
|
||||
self.prompt_adapter_config.verify_with_model_config(
|
||||
self.model_config)
|
||||
# Create the LLMEngine.
|
||||
return cls(vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
multiprocess_mode=enable_multiprocessing)
|
||||
|
||||
def _add_processed_request(
|
||||
self,
|
||||
request_id: str,
|
||||
processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderLLMInputs],
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
arrival_time: float,
|
||||
lora_request: Optional[LoRARequest],
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
) -> None:
|
||||
assert prompt_adapter_request is None
|
||||
assert trace_headers is None
|
||||
self._validate_model_inputs(processed_inputs)
|
||||
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
|
||||
|
||||
# TODO(woosuk): Support embedding mode.
|
||||
assert isinstance(params, SamplingParams)
|
||||
sampling_params = params.clone()
|
||||
sampling_params.update_from_generation_config(
|
||||
self.generation_config_fields, eos_token_id)
|
||||
|
||||
# TODO(woosuk): Check max_logprobs
|
||||
# TODO(woosuk): Support encoder-decoder models.
|
||||
req = Request(request_id, processed_inputs, params, eos_token_id,
|
||||
arrival_time)
|
||||
self.requests[request_id] = req
|
||||
self.num_lagged_steps[request_id] = 0
|
||||
self.scheduler.add_request(req)
|
||||
@classmethod
|
||||
def _get_executor_cls(cls, vllm_config: VllmConfig):
|
||||
return GPUExecutor
|
||||
|
||||
def stop_remote_worker_execution_loop(self) -> None:
|
||||
raise NotImplementedError("TP not implemented yet.")
|
||||
|
||||
def get_num_unfinished_requests(self) -> int:
|
||||
return self.detokenizer.get_num_unfinished_requests()
|
||||
|
||||
def has_unfinished_requests(self) -> bool:
|
||||
return self.detokenizer.has_unfinished_requests()
|
||||
|
||||
@classmethod
|
||||
def validate_outputs(cls, outputs, output_type):
|
||||
return outputs
|
||||
|
||||
def abort_request(self, request_ids: List[str]) -> None:
|
||||
"""Remove request_ids from EngineCore and Detokenizer."""
|
||||
|
||||
self.engine_core.abort_requests(request_ids)
|
||||
self.detokenizer.abort_requests(request_ids)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request_id: str,
|
||||
@ -263,261 +124,46 @@ class LLMEngine:
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> None:
|
||||
if lora_request is not None and not self.lora_config:
|
||||
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
||||
"not enabled!")
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
assert priority == 0, "vLLM V1 does not support priority at the moment."
|
||||
|
||||
preprocessed_inputs = self.input_preprocessor.preprocess(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
processed_inputs = self.input_processor(preprocessed_inputs)
|
||||
# 1) Process raw inputs into the request.
|
||||
detokenizer_req, engine_core_req = self.processor.process_inputs(
|
||||
request_id, prompt, params, arrival_time, lora_request,
|
||||
trace_headers, prompt_adapter_request, priority)
|
||||
|
||||
self._add_processed_request(
|
||||
request_id=request_id,
|
||||
processed_inputs=processed_inputs,
|
||||
params=params,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
# 2) Add the request to Detokenizer.
|
||||
self.detokenizer.add_request(detokenizer_req)
|
||||
|
||||
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
|
||||
self.scheduler.finish_requests(request_id,
|
||||
RequestStatus.FINISHED_ABORTED)
|
||||
self._free_request(request_id)
|
||||
|
||||
def get_num_unfinished_requests(self) -> int:
|
||||
"""Gets the number of unfinished requests."""
|
||||
return len(self.requests)
|
||||
|
||||
def has_unfinished_requests(self) -> bool:
|
||||
"""Returns True if there are unfinished requests."""
|
||||
return len(self.requests) > 0
|
||||
# 3) Add the request to EngineCore.
|
||||
self.engine_core.add_request(engine_core_req)
|
||||
|
||||
def step(self) -> List[RequestOutput]:
|
||||
# NOTE(woosuk): This method may return an empty list when the
|
||||
# detokenizer is still processing the outputs. This should not be
|
||||
# considered as the end of the generation process.
|
||||
# FIXME(woosuk): Currently, the step method is inefficient because it
|
||||
# creates RequestOutput objects for all running requests, while they
|
||||
# may not be needed unless the output is streamed to the client.
|
||||
if self.scheduler.has_unfinished_requests():
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
output = self.model_executor.execute_model(scheduler_output)
|
||||
sampled = self.scheduler.update_from_output(
|
||||
scheduler_output, output)
|
||||
self.send_to_detokenizer(sampled)
|
||||
req_outputs = self.recv_from_detokenizer()
|
||||
return req_outputs
|
||||
|
||||
def send_to_detokenizer(self, sampled: List[Tuple[Request, int]]) -> None:
|
||||
inputs = DetokenizerInputs(
|
||||
req_ids=[],
|
||||
prompt_token_ids=[],
|
||||
new_token_ids=[],
|
||||
skip_special_tokens=[],
|
||||
spaces_between_special_tokens=[],
|
||||
free_req_ids=[], # TODO(woosuk): Implement freeing.
|
||||
)
|
||||
for req, num_tokens in sampled:
|
||||
inputs.req_ids.append(req.request_id)
|
||||
if req.num_output_tokens == num_tokens:
|
||||
# The request is first detokenized.
|
||||
inputs.prompt_token_ids.append(req.prompt_token_ids)
|
||||
else:
|
||||
# The prompt token ids are already cached in the detokenizer.
|
||||
inputs.prompt_token_ids.append([])
|
||||
inputs.new_token_ids.append(req.output_token_ids[-num_tokens:])
|
||||
inputs.skip_special_tokens.append(
|
||||
req.sampling_params.skip_special_tokens)
|
||||
inputs.spaces_between_special_tokens.append(
|
||||
req.sampling_params.spaces_between_special_tokens)
|
||||
# 1) Get EngineCoreOutput from the EngineCore.
|
||||
engine_core_outputs = self.engine_core.get_output()
|
||||
|
||||
# Update the number of lagged steps.
|
||||
self.num_lagged_steps[req.request_id] += 1
|
||||
self.detokenizer.send(inputs)
|
||||
# 2) Detokenizer the EngineCoreOutput.
|
||||
request_outputs, requests_to_abort = self.detokenizer.step(
|
||||
engine_core_outputs)
|
||||
|
||||
def recv_from_detokenizer(self) -> List[RequestOutput]:
|
||||
detokenizer_output = self.detokenizer.recv()
|
||||
if detokenizer_output is None:
|
||||
return []
|
||||
# 3) Abort requests that finished due to stopping criteria.
|
||||
if requests_to_abort:
|
||||
self.abort_request(requests_to_abort)
|
||||
|
||||
req_outputs: List[RequestOutput] = []
|
||||
num_reqs = len(detokenizer_output.req_ids)
|
||||
for i in range(num_reqs):
|
||||
req_id = detokenizer_output.req_ids[i]
|
||||
if req_id not in self.requests:
|
||||
# The request has been aborted while the detokenizer was
|
||||
# processing the outputs.
|
||||
continue
|
||||
return request_outputs
|
||||
|
||||
req = self.requests[req_id]
|
||||
req.output_text += detokenizer_output.detokenized_texts[i]
|
||||
# TODO(rob): Can we get rid of these?
|
||||
|
||||
self.num_lagged_steps[req_id] -= 1
|
||||
finished = (self.num_lagged_steps[req_id] == 0
|
||||
and req.is_finished())
|
||||
req_output = self._make_request_output(
|
||||
req, detokenizer_output.num_output_token_ids[i],
|
||||
detokenizer_output.detokenized_texts[i], finished)
|
||||
req_outputs.append(req_output)
|
||||
|
||||
if finished:
|
||||
self._free_request(req_id)
|
||||
return req_outputs
|
||||
|
||||
def terminate_detokenizer(self) -> None:
|
||||
self.detokenizer.terminate()
|
||||
|
||||
def _make_request_output(
|
||||
self,
|
||||
request: Request,
|
||||
num_output_tokens: int,
|
||||
new_output_text: str,
|
||||
finished: bool,
|
||||
) -> RequestOutput:
|
||||
req_output = self.request_outputs.get(request.request_id)
|
||||
if req_output is None:
|
||||
# TODO: Support `n` > 1.
|
||||
completion_output = CompletionOutput(
|
||||
index=0,
|
||||
text="",
|
||||
token_ids=[],
|
||||
cumulative_logprob=None,
|
||||
logprobs=None, # TODO
|
||||
finish_reason=None,
|
||||
stop_reason=None,
|
||||
lora_request=None,
|
||||
)
|
||||
req_output = RequestOutput(
|
||||
request_id=request.request_id,
|
||||
prompt=request.prompt,
|
||||
prompt_token_ids=request.prompt_token_ids,
|
||||
prompt_logprobs=None, # TODO
|
||||
outputs=[completion_output],
|
||||
finished=False,
|
||||
metrics=None,
|
||||
lora_request=None,
|
||||
encoder_prompt=None,
|
||||
encoder_prompt_token_ids=None,
|
||||
)
|
||||
self.request_outputs[request.request_id] = req_output
|
||||
|
||||
completion_output = req_output.outputs[0]
|
||||
if request.sampling_params.output_kind == RequestOutputKind.CUMULATIVE:
|
||||
completion_output.text += new_output_text
|
||||
completion_output.token_ids = (
|
||||
request.output_token_ids[:num_output_tokens])
|
||||
elif request.sampling_params.output_kind == RequestOutputKind.DELTA:
|
||||
completion_output.text = new_output_text
|
||||
num_prev_tokens = len(completion_output.token_ids)
|
||||
completion_output.token_ids = request.output_token_ids[
|
||||
num_prev_tokens:num_output_tokens]
|
||||
elif (request.sampling_params.output_kind ==
|
||||
RequestOutputKind.FINAL_ONLY):
|
||||
if finished:
|
||||
completion_output.text = request.output_text
|
||||
completion_output.token_ids = request.output_token_ids
|
||||
else:
|
||||
completion_output.text = ""
|
||||
completion_output.token_ids = []
|
||||
|
||||
if finished:
|
||||
completion_output.finish_reason = request.get_finished_reason()
|
||||
completion_output.stop_reason = request.stop_reason
|
||||
req_output.finished = finished
|
||||
return req_output
|
||||
|
||||
def _free_request(self, request_id: str) -> None:
|
||||
self.requests.pop(request_id, None)
|
||||
self.num_lagged_steps.pop(request_id, None)
|
||||
self.request_outputs.pop(request_id, None)
|
||||
|
||||
def check_health(self) -> None:
|
||||
if self.tokenizer:
|
||||
self.tokenizer.check_health()
|
||||
self.model_executor.check_health()
|
||||
|
||||
def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
|
||||
EncoderDecoderLLMInputs]):
|
||||
prompt_ids = inputs.get("prompt_token_ids")
|
||||
if prompt_ids is None or len(prompt_ids) == 0:
|
||||
raise ValueError("Prompt cannot be empty")
|
||||
|
||||
if self.model_config.is_multimodal_model:
|
||||
max_prompt_len = self.model_config.max_model_len
|
||||
|
||||
if len(prompt_ids) > max_prompt_len:
|
||||
raise ValueError(
|
||||
f"The prompt (total length {len(prompt_ids)}) is too long "
|
||||
f"to fit into the model (context length {max_prompt_len}). "
|
||||
"Make sure that `max_model_len` is no smaller than the "
|
||||
"number of text tokens plus multimodal tokens. For image "
|
||||
"inputs, the number of image tokens depends on the number "
|
||||
"of images, and possibly their aspect ratios as well.")
|
||||
|
||||
@classmethod
|
||||
def validate_outputs(cls, outputs, output_type):
|
||||
return outputs
|
||||
|
||||
def get_model_config(self) -> ModelConfig:
|
||||
"""Gets the model configuration."""
|
||||
return self.model_config
|
||||
|
||||
def get_parallel_config(self) -> ParallelConfig:
|
||||
"""Gets the parallel configuration."""
|
||||
return self.parallel_config
|
||||
|
||||
def get_decoding_config(self) -> DecodingConfig:
|
||||
"""Gets the decoding configuration."""
|
||||
return self.decoding_config
|
||||
|
||||
def get_scheduler_config(self) -> SchedulerConfig:
|
||||
"""Gets the scheduler configuration."""
|
||||
return self.scheduler_config
|
||||
|
||||
def get_lora_config(self) -> LoRAConfig:
|
||||
"""Gets the LoRA configuration."""
|
||||
return self.lora_config
|
||||
|
||||
@classmethod
|
||||
def _get_executor_cls(cls, engine_config: VllmConfig):
|
||||
return GPUExecutor
|
||||
|
||||
def is_tracing_enabled(self) -> bool:
|
||||
return False
|
||||
|
||||
def do_log_stats(self, *args, **kwargs) -> None:
|
||||
def get_model_config(self):
|
||||
pass
|
||||
|
||||
def is_encoder_decoder_model(self) -> bool:
|
||||
return False
|
||||
|
||||
def start_profile(self) -> None:
|
||||
def is_encoder_decoder_model(self):
|
||||
pass
|
||||
|
||||
def stop_profile(self) -> None:
|
||||
def start_profile(self):
|
||||
pass
|
||||
|
||||
def get_tokenizer_group(self, *args, **kwargs):
|
||||
return self.tokenizer
|
||||
def stop_profile(self):
|
||||
pass
|
||||
|
||||
|
||||
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
|
||||
config = try_get_generation_config(
|
||||
model_config.model,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
revision=model_config.revision,
|
||||
)
|
||||
|
||||
if config is None:
|
||||
return {}
|
||||
|
||||
return config.to_diff_dict()
|
||||
def get_tokenizer_group(self, group_type):
|
||||
pass
|
||||
|
||||
128
vllm/v1/engine/processor.py
Normal file
128
vllm/v1/engine/processor.py
Normal file
@ -0,0 +1,128 @@
|
||||
import time
|
||||
from typing import Any, Dict, Mapping, Optional, Tuple, Union
|
||||
|
||||
from vllm.config import LoRAConfig, ModelConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
|
||||
EncoderDecoderLLMInputs, InputRegistry, PromptType)
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.config import try_get_generation_config
|
||||
from vllm.transformers_utils.tokenizer_group import AnyTokenizer
|
||||
from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
|
||||
|
||||
|
||||
class Processor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
tokenizer: AnyTokenizer,
|
||||
input_registry: InputRegistry = INPUT_REGISTRY,
|
||||
):
|
||||
|
||||
self.model_config = model_config
|
||||
self.lora_config = lora_config
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.generation_config_fields = _load_generation_config_dict(
|
||||
model_config)
|
||||
self.input_preprocessor = InputPreprocessor(model_config,
|
||||
self.tokenizer)
|
||||
self.input_processor = input_registry.create_input_processor(
|
||||
model_config)
|
||||
|
||||
# TODO: run in an ThreadpoolExecutor or BackgroundProcess.
|
||||
# This ideally should releases the GIL, so we should not block the
|
||||
# asyncio loop while this is running.
|
||||
def process_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: PromptType,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
arrival_time: float,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> Tuple[DetokenizerRequest, EngineCoreRequest]:
|
||||
|
||||
# TODO(woosuk): Support embedding mode.
|
||||
# TODO(woosuk): Check max_logprobs
|
||||
# TODO(woosuk): Support encoder-decoder models.
|
||||
|
||||
if lora_request is not None and not self.lora_config:
|
||||
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
||||
"not enabled!")
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
assert priority == 0, "vLLM V1 does not support priority at the moment."
|
||||
assert trace_headers is None, "vLLM V1 does not support tracing yet."
|
||||
|
||||
# Process inputs.
|
||||
preprocessed_inputs = self.input_preprocessor.preprocess(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
processed_inputs = self.input_processor(preprocessed_inputs)
|
||||
self._validate_model_inputs(processed_inputs)
|
||||
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
|
||||
|
||||
assert isinstance(params, SamplingParams)
|
||||
# TODO: can we avoid cloning here in multiproc case
|
||||
sampling_params = params.clone()
|
||||
sampling_params.update_from_generation_config(
|
||||
self.generation_config_fields, eos_token_id)
|
||||
|
||||
# Make Request for Detokenizer.
|
||||
detokenizer_request = DetokenizerRequest(
|
||||
request_id, processed_inputs.get("prompt"),
|
||||
processed_inputs.get("prompt_token_ids"),
|
||||
sampling_params.skip_special_tokens,
|
||||
sampling_params.spaces_between_special_tokens,
|
||||
sampling_params.output_kind, sampling_params.stop,
|
||||
sampling_params.include_stop_str_in_output)
|
||||
|
||||
# Make Request for EngineCore.
|
||||
engine_core_request = EngineCoreRequest(
|
||||
request_id, processed_inputs.get("prompt"),
|
||||
processed_inputs.get("prompt_token_ids"), sampling_params,
|
||||
eos_token_id, arrival_time, lora_request)
|
||||
|
||||
return detokenizer_request, engine_core_request
|
||||
|
||||
def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
|
||||
EncoderDecoderLLMInputs]):
|
||||
prompt_ids = inputs.get("prompt_token_ids")
|
||||
if prompt_ids is None or len(prompt_ids) == 0:
|
||||
raise ValueError("Prompt cannot be empty")
|
||||
|
||||
if self.model_config.is_multimodal_model:
|
||||
max_prompt_len = self.model_config.max_model_len
|
||||
|
||||
if len(prompt_ids) > max_prompt_len:
|
||||
raise ValueError(
|
||||
f"The prompt (total length {len(prompt_ids)}) is too long "
|
||||
f"to fit into the model (context length {max_prompt_len}). "
|
||||
"Make sure that `max_model_len` is no smaller than the "
|
||||
"number of text tokens plus multimodal tokens. For image "
|
||||
"inputs, the number of image tokens depends on the number "
|
||||
"of images, and possibly their aspect ratios as well.")
|
||||
|
||||
|
||||
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
|
||||
config = try_get_generation_config(
|
||||
model_config.model,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
revision=model_config.revision,
|
||||
)
|
||||
|
||||
if config is None:
|
||||
return {}
|
||||
|
||||
return config.to_diff_dict()
|
||||
@ -1,9 +1,11 @@
|
||||
import enum
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
from vllm.inputs.data import DecoderOnlyInputs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import RequestMetrics
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.utils import ConstantList
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -43,9 +45,22 @@ class Request:
|
||||
self.num_prompt_tokens = len(self.prompt_token_ids)
|
||||
self._output_token_ids: List[int] = []
|
||||
self._all_token_ids: List[int] = self.prompt_token_ids.copy()
|
||||
self.output_text = ""
|
||||
self.num_computed_tokens = 0
|
||||
|
||||
@classmethod
|
||||
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
|
||||
|
||||
return cls(
|
||||
request_id=request.request_id,
|
||||
inputs=DecoderOnlyInputs(type="token",
|
||||
prompt_token_ids=request.prompt_token_ids,
|
||||
prompt=request.prompt),
|
||||
sampling_params=request.sampling_params,
|
||||
eos_token_id=request.eos_token_id,
|
||||
arrival_time=request.arrival_time,
|
||||
lora_request=request.lora_request,
|
||||
)
|
||||
|
||||
@property
|
||||
def output_token_ids(self) -> ConstantList[int]:
|
||||
# Prevent directly appending to the output_token_ids since
|
||||
|
||||
@ -1,228 +0,0 @@
|
||||
import multiprocessing
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import msgspec
|
||||
import zmq
|
||||
from msgspec import msgpack
|
||||
|
||||
from vllm.transformers_utils.detokenizer_utils import (
|
||||
convert_prompt_ids_to_tokens, detokenize_incrementally)
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.utils import get_open_port
|
||||
|
||||
|
||||
class DetokenizerInputs(msgspec.Struct):
|
||||
|
||||
# [num_reqs]
|
||||
req_ids: List[str]
|
||||
# A request's prompt token ids is sent to the detokenizer only when
|
||||
# the request is first detokenized. Otherwise, an empty list is sent.
|
||||
prompt_token_ids: List[List[int]]
|
||||
new_token_ids: List[List[int]]
|
||||
skip_special_tokens: List[bool]
|
||||
spaces_between_special_tokens: List[bool]
|
||||
|
||||
# [num_free_reqs]
|
||||
free_req_ids: List[str]
|
||||
|
||||
|
||||
class DetokenizerOutputs(msgspec.Struct):
|
||||
|
||||
# [num_reqs]
|
||||
req_ids: List[str]
|
||||
detokenized_texts: List[str]
|
||||
# NOTE(woosuk): The number of the output token ids of each request
|
||||
# at the time of detokenization. The detokenizer returns this to the engine
|
||||
# because the request state (including the output token ids) is
|
||||
# asynchronously updated in the engine, while RequestOutput requires the
|
||||
# output token ids to be consistent with the detokenized text.
|
||||
num_output_token_ids: List[int]
|
||||
|
||||
|
||||
class Detokenizer:
|
||||
|
||||
def __init__(self, tokenizer_name: str, tokenizer_mode: str,
|
||||
trust_remote_code: bool):
|
||||
# FIXME(woosuk): Currently, the detokenizer is just a hacky prototype.
|
||||
# For example, it does not terminate properly. We need to improve this.
|
||||
self.push_port = get_open_port()
|
||||
self.pull_port = get_open_port()
|
||||
# NOTE: The push port of the engine process should be the same as the
|
||||
# pull port of the detokenizer process. Vice versa.
|
||||
self.detokenizer = DetokenizerProc(tokenizer_name=tokenizer_name,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
trust_remote_code=trust_remote_code,
|
||||
push_port=self.pull_port,
|
||||
pull_port=self.push_port)
|
||||
self.detokenizer.start()
|
||||
|
||||
self.zmq_context = zmq.Context()
|
||||
self.push_socket = self.zmq_context.socket(zmq.PUSH)
|
||||
self.push_socket.connect(f"tcp://localhost:{self.push_port}")
|
||||
self.pull_socket = self.zmq_context.socket(zmq.PULL)
|
||||
self.pull_socket.connect(f"tcp://localhost:{self.pull_port}")
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.pull_socket, zmq.POLLIN)
|
||||
self.msgpack_encoder = msgpack.Encoder()
|
||||
self.msgpack_decoder = msgpack.Decoder(DetokenizerOutputs)
|
||||
|
||||
def send(self, inputs: DetokenizerInputs) -> None:
|
||||
self.push_socket.send(self.msgpack_encoder.encode(inputs),
|
||||
flags=zmq.NOBLOCK)
|
||||
|
||||
def recv(self) -> Optional[DetokenizerOutputs]:
|
||||
socks = dict(self.poller.poll(timeout=0))
|
||||
if self.pull_socket in socks and socks[self.pull_socket] == zmq.POLLIN:
|
||||
msg = self.pull_socket.recv()
|
||||
return self.msgpack_decoder.decode(msg)
|
||||
return None
|
||||
|
||||
def terminate(self) -> None:
|
||||
self.detokenizer.kill()
|
||||
self.detokenizer.join()
|
||||
|
||||
|
||||
class DetokenizerProc(multiprocessing.Process):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_name: str,
|
||||
tokenizer_mode: str,
|
||||
trust_remote_code: bool,
|
||||
pull_port: int,
|
||||
push_port: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.tokenizer_name = tokenizer_name
|
||||
self.tokenizer_mode = tokenizer_mode
|
||||
self.trust_remote_code = trust_remote_code
|
||||
# NOTE: The pull_port of the detokenizer process should be the same as
|
||||
# the push_port of the engine process. Vice versa.
|
||||
self.pull_port = pull_port
|
||||
self.push_port = push_port
|
||||
|
||||
def run(self):
|
||||
# Initialize these objects after the process is forked since they are
|
||||
# not picklable.
|
||||
self.msgpack_encoder = msgpack.Encoder()
|
||||
self.msgpack_decoder = msgpack.Decoder(DetokenizerInputs)
|
||||
self.tokenizer = get_tokenizer(
|
||||
tokenizer_name=self.tokenizer_name,
|
||||
tokenizer_mode=self.tokenizer_mode,
|
||||
trust_remote_code=self.trust_remote_code)
|
||||
# req_id -> RequestState
|
||||
self.request_states: Dict[str, RequestState] = {}
|
||||
|
||||
self.zmq_context = zmq.Context()
|
||||
self.pull_socket = self.zmq_context.socket(zmq.PULL)
|
||||
self.pull_socket.bind(f"tcp://*:{self.pull_port}")
|
||||
self.push_socket = self.zmq_context.socket(zmq.PUSH)
|
||||
self.push_socket.bind(f"tcp://*:{self.push_port}")
|
||||
|
||||
while True:
|
||||
if self.pull_socket.poll(timeout=1000) == 0:
|
||||
# Nothing to read
|
||||
continue
|
||||
message = self.pull_socket.recv()
|
||||
inputs = self.msgpack_decoder.decode(message)
|
||||
|
||||
for req_id in inputs.free_req_ids:
|
||||
self.free(req_id)
|
||||
|
||||
detokenized_texts: List[str] = []
|
||||
num_output_token_ids: List[int] = []
|
||||
num_reqs = len(inputs.req_ids)
|
||||
for i in range(num_reqs):
|
||||
req_id = inputs.req_ids[i]
|
||||
if req_id not in self.request_states:
|
||||
self.add_request(
|
||||
request_id=req_id,
|
||||
prompt_token_ids=inputs.prompt_token_ids[i],
|
||||
skip_special_tokens=inputs.skip_special_tokens[i],
|
||||
spaces_between_special_tokens=inputs.
|
||||
spaces_between_special_tokens[i],
|
||||
)
|
||||
new_str = self.detokenize(req_id, inputs.new_token_ids[i])
|
||||
detokenized_texts.append(new_str)
|
||||
req_state = self.request_states[req_id]
|
||||
num_output_token_ids.append(
|
||||
len(req_state.token_ids) - req_state.num_prompt_tokens)
|
||||
|
||||
detokenized = DetokenizerOutputs(
|
||||
req_ids=inputs.req_ids,
|
||||
detokenized_texts=detokenized_texts,
|
||||
num_output_token_ids=num_output_token_ids,
|
||||
)
|
||||
self.push_socket.send(self.msgpack_encoder.encode(detokenized),
|
||||
flags=zmq.NOBLOCK)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt_token_ids: List[int],
|
||||
skip_special_tokens: bool,
|
||||
spaces_between_special_tokens: bool,
|
||||
) -> None:
|
||||
tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
|
||||
tokenizer=self.tokenizer,
|
||||
prompt_ids=prompt_token_ids,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
)
|
||||
self.request_states[request_id] = RequestState(
|
||||
req_id=request_id,
|
||||
token_ids=prompt_token_ids,
|
||||
tokens=tokens,
|
||||
num_prompt_tokens=len(prompt_token_ids),
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
spaces_between_special_tokens=spaces_between_special_tokens,
|
||||
)
|
||||
|
||||
def free(self, request_id: str) -> None:
|
||||
del self.request_states[request_id]
|
||||
|
||||
def detokenize(self, request_id: str, new_token_ids: List[int]) -> str:
|
||||
# TODO(woosuk): This method becomes very inefficient when the number of
|
||||
# new_token_ids is more than 1. We need to optimize this.
|
||||
req_state = self.request_states[request_id]
|
||||
decoded_text = ""
|
||||
for new_token_id in new_token_ids:
|
||||
req_state.token_ids.append(new_token_id)
|
||||
(new_tokens, new_decoded_token_text, prefix_offset,
|
||||
read_offset) = detokenize_incrementally(
|
||||
tokenizer=self.tokenizer,
|
||||
all_input_ids=req_state.token_ids,
|
||||
prev_tokens=req_state.tokens,
|
||||
prefix_offset=req_state.prefix_offset,
|
||||
read_offset=req_state.read_offset,
|
||||
skip_special_tokens=req_state.skip_special_tokens,
|
||||
spaces_between_special_tokens=req_state.
|
||||
spaces_between_special_tokens,
|
||||
)
|
||||
|
||||
req_state.tokens.extend(new_tokens)
|
||||
req_state.prefix_offset = prefix_offset
|
||||
req_state.read_offset = read_offset
|
||||
req_state.output_text += new_decoded_token_text
|
||||
decoded_text += new_decoded_token_text
|
||||
return decoded_text
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestState:
|
||||
|
||||
req_id: str
|
||||
|
||||
token_ids: List[int]
|
||||
tokens: List[str]
|
||||
num_prompt_tokens: int
|
||||
|
||||
prefix_offset: int
|
||||
read_offset: int
|
||||
|
||||
skip_special_tokens: bool
|
||||
spaces_between_special_tokens: bool
|
||||
|
||||
output_text: str = ""
|
||||
Loading…
x
Reference in New Issue
Block a user