[Core] Remove prompt string from engine core data structures (#17214)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-04-25 23:41:05 -07:00 committed by GitHub
parent 513f074766
commit df6f3ce883
GPG Key ID: B5690EEEBB952194
21 changed files with 40 additions and 76 deletions
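
Summary of the change: the detokenized prompt string no longer rides along inside EngineCoreRequest (nor the downstream NewRequestData, Request, and CachedRequestState structures), so it is never serialized to the engine-core process. Instead, Processor.process_inputs now returns the prompt string alongside the request, and callers hand it directly to OutputProcessor.add_request, the only component that still needs it. A minimal runnable sketch of the new call pattern, using heavily simplified stand-in classes (not vLLM's real signatures, which take many more parameters):

    # Stand-in classes sketching the post-change flow; not vLLM's real API.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class EngineCoreRequest:
        # After this commit: token IDs only, no `prompt` string field.
        request_id: str
        prompt_token_ids: list[int]

    class Processor:
        def process_inputs(
                self, request_id: str,
                prompt: str) -> tuple[Optional[str], EngineCoreRequest]:
            token_ids = [ord(c) for c in prompt]  # placeholder "tokenizer"
            # The prompt string is returned next to the request rather than
            # embedded in it.
            return prompt, EngineCoreRequest(request_id, token_ids)

    class OutputProcessor:
        def add_request(self, request: EngineCoreRequest,
                        prompt: Optional[str]) -> None:
            # The detokenizer side is the only consumer of the prompt
            # string, so it now receives it explicitly.
            print(request.request_id, repr(prompt),
                  len(request.prompt_token_ids))

    prompt_str, req = Processor().process_inputs("request-0", "Hello")
    OutputProcessor().add_request(req, prompt_str)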

View File

@@ -60,8 +60,8 @@ def _run_incremental_decode(tokenizer,
         skip_special_tokens=skip_special_tokens,
         spaces_between_special_tokens=spaces_between_special_tokens,
     )
-    request = EngineCoreRequest("", "", prompt_token_ids, None, None, None,
-                                params, None, 0.0, None)
+    request = EngineCoreRequest("", prompt_token_ids, None, None, None, params,
+                                None, 0.0, None)

     if fast is None:
         detokenizer = IncrementalDetokenizer.from_new_request(

View File

@@ -37,7 +37,6 @@ def make_request(request_id,
     return Request(
         request_id=request_id,
-        prompt=None,
         prompt_token_ids=prompt_token_ids,
         multi_modal_inputs=multi_modal_inputs,
         multi_modal_hashes=mm_hashes,

View File

@@ -29,7 +29,6 @@ def make_request(request_id,
     return Request(
         request_id=request_id,
-        prompt=None,
         prompt_token_ids=prompt_token_ids,
         multi_modal_inputs=multi_modal_inputs,
         multi_modal_hashes=mm_hashes,

View File

@@ -132,7 +132,6 @@ def create_requests(num_requests: int,
         mm_inputs = None
         request = Request(
             request_id=f"{i}",
-            prompt=None,
             prompt_token_ids=[i] * num_tokens,
             sampling_params=sampling_params,
             multi_modal_inputs=mm_inputs,

View File

@@ -31,8 +31,7 @@ PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids

 def make_request() -> EngineCoreRequest:
     return EngineCoreRequest(
-        request_id=uuid.uuid4(),
-        prompt=PROMPT,
+        request_id=str(uuid.uuid4()),
         prompt_token_ids=PROMPT_TOKENS,
         mm_inputs=None,
         mm_hashes=None,

View File

@@ -35,7 +35,6 @@ PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
 def make_request(params: SamplingParams) -> EngineCoreRequest:
     return EngineCoreRequest(
         request_id=str(uuid.uuid4()),
-        prompt=PROMPT,
         prompt_token_ids=PROMPT_TOKENS,
         mm_inputs=None,
         mm_hashes=None,

View File

@@ -50,7 +50,6 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
     # Make N requests.
     requests = [
         EngineCoreRequest(request_id=f"request-{idx}",
-                          prompt=prompt,
                           prompt_token_ids=prompt_tokens,
                           arrival_time=0,
                           mm_inputs=None,
@@ -64,14 +63,13 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
                               output_kind=request_output_kind,
                               stop=[],
                               include_stop_str_in_output=False,
-                          )) for idx, (prompt, prompt_tokens) in enumerate(
-                              zip(dummy_test_vectors.prompt_strings,
-                                  dummy_test_vectors.prompt_tokens))
+                          ))
+        for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]

     # Add requests to the detokenizer.
-    for request in requests:
-        output_processor.add_request(request)
+    for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
+        output_processor.add_request(request, prompt)

     gen_strings = {}
     gen_tokens = {}
@@ -398,7 +396,6 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
     ]
     requests = [
         EngineCoreRequest(request_id=request_id_list[idx],
-                          prompt=prompt,
                           prompt_token_ids=prompt_tokens,
                           arrival_time=0,
                           mm_inputs=None,
@@ -414,14 +411,13 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
                               include_stop_str_in_output=False,
                               logprobs=num_sample_logprobs,
                               prompt_logprobs=num_prompt_logprobs,
-                          )) for idx, (prompt, prompt_tokens) in enumerate(
-                              zip(dummy_test_vectors.prompt_strings,
-                                  dummy_test_vectors.prompt_tokens))
+                          ))
+        for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]

     # Add requests to the detokenizer.
-    for request in requests:
-        output_processor.add_request(request)
+    for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
+        output_processor.add_request(request, prompt)

     gen_tokens = {}
     gen_logprobs = {}
@@ -562,7 +558,6 @@ def test_stop_token(include_stop_str_in_output: bool,
     request_id = "request-0"
     request = EngineCoreRequest(
         request_id=request_id,
-        prompt=prompt_string,
         prompt_token_ids=prompt_tokens,
         arrival_time=0,
         mm_inputs=None,
@@ -583,7 +578,7 @@ def test_stop_token(include_stop_str_in_output: bool,
     ))

     # Add request to the detokenizer.
-    output_processor.add_request(request)
+    output_processor.add_request(request, prompt_string)

     # Loop over engine core steps; run output processor
     gen_string = ""
@@ -659,7 +654,6 @@ def test_stop_string(include_stop_str_in_output: bool,
     requests = [
         EngineCoreRequest(
             request_id=request_id_list[idx],
-            prompt=prompt,
             prompt_token_ids=prompt_tokens,
             arrival_time=0,
             mm_inputs=None,
@@ -675,14 +669,13 @@ def test_stop_string(include_stop_str_in_output: bool,
                 include_stop_str_in_output=include_stop_str_in_output,
                 logprobs=num_sample_logprobs,
                 prompt_logprobs=None,
-            )) for idx, (prompt, prompt_tokens) in enumerate(
-                zip(dummy_test_vectors.prompt_strings,
-                    dummy_test_vectors.prompt_tokens))
+            ))
+        for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]

     # Add requests to the detokenizer.
-    for request in requests:
-        output_processor.add_request(request)
+    for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
+        output_processor.add_request(request, prompt)

     gen_strings = {}
     gen_tokens = {}
@@ -774,7 +767,6 @@ def test_iteration_stats(dummy_test_vectors):
     requests = [
         EngineCoreRequest(
             request_id=f"request-{idx}",
-            prompt=prompt,
             prompt_token_ids=prompt_tokens,
             arrival_time=0,
             mm_inputs=None,
@@ -783,15 +775,13 @@ def test_iteration_stats(dummy_test_vectors):
             eos_token_id=None,
             lora_request=None,
             sampling_params=SamplingParams(),
-        ) for idx, (prompt, prompt_tokens) in enumerate(
-            zip(dummy_test_vectors.prompt_strings,
-                dummy_test_vectors.prompt_tokens))
+        ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]

     # Add all requests except one to the OutputProcessor.
     num_active = len(dummy_test_vectors.generation_tokens) - 1
     for request in requests[:num_active]:
-        output_processor.add_request(request)
+        output_processor.add_request(request, None)
     inactive_request = requests[num_active]

     # First iteration has 2 prefills.
@@ -817,7 +807,7 @@ def test_iteration_stats(dummy_test_vectors):
     assert iteration_stats.num_generation_tokens == num_active

     # Add a new request - prefill and 2 decodes in this step.
-    output_processor.add_request(inactive_request)
+    output_processor.add_request(inactive_request, None)
     num_active += 1
     outputs = engine_core.get_outputs()[:num_active]
     iteration_stats = IterationStats()

View File

@@ -77,7 +77,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
         NewRequestData(
             req_id=req_id,
             prompt_token_ids=[1, 2, 3],
-            prompt="test",
             mm_inputs=[],
             mm_hashes=[],
             mm_positions=[],

View File

@@ -195,7 +195,6 @@ def _construct_cached_request_state(req_id_suffix: int):
     return CachedRequestState(
         req_id=f"req_id_{req_id_suffix}",
         prompt_token_ids=prompt_token_ids,
-        prompt=None,
         sampling_params=_create_sampling_params(),
         mm_inputs=[],
         mm_positions=[],

View File

@@ -50,7 +50,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
         NewRequestData(
             req_id=req_id,
             prompt_token_ids=[1, 2, 3],
-            prompt="test",
             mm_inputs=[],
             mm_hashes=[],
             mm_positions=[],

View File

@@ -22,7 +22,6 @@ class NewRequestData:
     req_id: str
     prompt_token_ids: list[int]
-    prompt: Optional[str]
     mm_inputs: list[MultiModalKwargs]
     mm_hashes: list[str]
     mm_positions: list[PlaceholderRange]
@@ -40,7 +39,6 @@ class NewRequestData:
         return cls(
             req_id=request.request_id,
             prompt_token_ids=request.prompt_token_ids,
-            prompt=request.prompt,
             mm_inputs=request.mm_inputs,
             mm_hashes=request.mm_hashes,
             mm_positions=request.mm_positions,

View File

@@ -49,9 +49,6 @@ class EngineCoreRequest(
     # due to circular imports and typing we have in data.py

     request_id: str
-    # NOTE(ywang96): original text prompt is needed when a request is added to
-    # Detokenizer, but set to None when it is added to EngineCoreClient.
-    prompt: Optional[str]
     prompt_token_ids: list[int]
     mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
     mm_hashes: Optional[list[str]]
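
The deleted NOTE documented an awkward invariant: the prompt field had to be populated for the Detokenizer but manually set to None before the request was shipped to the engine core (see the client changes below). Dropping the field removes that invariant and keeps prompt text out of every cross-process ADD message. A rough, hypothetical illustration of the payload effect, using pickle as a stand-in for the actual wire format:

    import pickle
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class WithPrompt:        # request shape before this commit (simplified)
        request_id: str
        prompt: Optional[str]
        prompt_token_ids: list[int]

    @dataclass
    class WithoutPrompt:     # request shape after this commit (simplified)
        request_id: str
        prompt_token_ids: list[int]

    tokens = list(range(1024))
    before = pickle.dumps(WithPrompt("r0", "x" * 4096, tokens))
    after = pickle.dumps(WithoutPrompt("r0", tokens))
    # The serialized payload no longer carries the prompt bytes.
    print(len(before), len(after))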

View File

@@ -217,14 +217,12 @@ class AsyncLLM(EngineClient):
         queue = RequestOutputCollector(output_kind=params.output_kind)

         # Convert Input --> Request.
-        request = self.processor.process_inputs(request_id, prompt, params,
-                                                arrival_time, lora_request,
-                                                trace_headers,
-                                                prompt_adapter_request,
-                                                priority)
+        prompt_str, request = self.processor.process_inputs(
+            request_id, prompt, params, arrival_time, lora_request,
+            trace_headers, prompt_adapter_request, priority)

         if params.n == 1:
-            await self._add_request(request, None, 0, queue)
+            await self._add_request(request, prompt_str, None, 0, queue)
             return queue

         # Fan out child requests (for n>1).
@@ -234,15 +232,18 @@ class AsyncLLM(EngineClient):
             child_request = request if idx == params.n - 1 else copy(request)
             child_request.request_id = request_id
             child_request.sampling_params = params
-            await self._add_request(child_request, parent_request, idx, queue)
+            await self._add_request(child_request, prompt_str, parent_request,
+                                    idx, queue)
         return queue

     async def _add_request(self, request: EngineCoreRequest,
+                           prompt: Optional[str],
                            parent_req: Optional[ParentRequest], index: int,
                            queue: RequestOutputCollector):

         # Add the request to OutputProcessor (this process).
-        self.output_processor.add_request(request, parent_req, index, queue)
+        self.output_processor.add_request(request, prompt, parent_req, index,
+                                          queue)

         # Add the EngineCoreRequest to EngineCore (separate process).
         await self.engine_core.add_request_async(request)

View File

@@ -583,9 +583,6 @@ class SyncMPClient(MPClient):
         return future.result()

     def add_request(self, request: EngineCoreRequest) -> None:
-        # NOTE: text prompt is not needed in the core engine as it has been
-        # tokenized.
-        request.prompt = None
         self._send_input(EngineCoreRequestType.ADD, request)

     def abort_requests(self, request_ids: list[str]) -> None:
@@ -772,9 +769,6 @@ class AsyncMPClient(MPClient):
         return await future

     async def add_request_async(self, request: EngineCoreRequest) -> None:
-        # NOTE: text prompt is not needed in the core engine as it has been
-        # tokenized.
-        request.prompt = None
         await self._send_input(EngineCoreRequestType.ADD, request)
         self._ensure_output_queue_task()

@@ -867,9 +861,6 @@ class DPAsyncMPClient(AsyncMPClient):
         ]))[0]

     async def add_request_async(self, request: EngineCoreRequest) -> None:
-        # NOTE: text prompt is not needed in the core engine as it has been
-        # tokenized.
-        request.prompt = None
         request.current_wave = self.current_wave

         chosen_engine = self.get_core_engine_for_request()

View File

@@ -180,17 +180,15 @@ class LLMEngine:
         priority: int = 0,
     ) -> None:
         # Process raw inputs into the request.
-        request = self.processor.process_inputs(request_id, prompt, params,
-                                                arrival_time, lora_request,
-                                                trace_headers,
-                                                prompt_adapter_request,
-                                                priority)
+        prompt_str, request = self.processor.process_inputs(
+            request_id, prompt, params, arrival_time, lora_request,
+            trace_headers, prompt_adapter_request, priority)

         n = params.n if isinstance(params, SamplingParams) else 1

         if n == 1:
             # Make a new RequestState and queue.
-            self.output_processor.add_request(request, None, 0)
+            self.output_processor.add_request(request, prompt_str, None, 0)
             # Add the request to EngineCore.
             self.engine_core.add_request(request)
             return
@@ -204,7 +202,8 @@ class LLMEngine:
             child_request.sampling_params = params

             # Make a new RequestState and queue.
-            self.output_processor.add_request(child_request, parent_req, idx)
+            self.output_processor.add_request(child_request, prompt_str,
+                                              parent_req, idx)

             # Add the request to EngineCore.
             self.engine_core.add_request(child_request)

View File

@@ -109,6 +109,7 @@ class RequestState:
         cls,
         tokenizer: AnyTokenizer,
         request: EngineCoreRequest,
+        prompt: Optional[str],
         parent_req: Optional[ParentRequest],
         request_index: int,
         queue: Optional[RequestOutputCollector],
@@ -123,7 +124,7 @@ class RequestState:
             lora_name=(request.lora_request.name
                        if request.lora_request is not None else None),
             output_kind=request.sampling_params.output_kind,
-            prompt=request.prompt,
+            prompt=prompt,
             prompt_token_ids=request.prompt_token_ids,
             logprobs_processor=LogprobsProcessor.from_new_request(
                 tokenizer=tokenizer,
@@ -267,6 +268,7 @@ class OutputProcessor:
     def add_request(
         self,
         request: EngineCoreRequest,
+        prompt: Optional[str],
         parent_req: Optional[ParentRequest] = None,
         request_index: int = 0,
         queue: Optional[RequestOutputCollector] = None,
@@ -278,6 +280,7 @@ class OutputProcessor:
         req_state = RequestState.from_new_request(
             tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
             request=request,
+            prompt=prompt,
             parent_req=parent_req,
             request_index=request_index,
             queue=queue,
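
With the explicit parameter, the prompt string is captured once per RequestState on the frontend; for n>1 requests every child shares the same string object rather than each child request carrying its own copy. A toy sketch of that bookkeeping (class shapes and the child request IDs are illustrative, not vLLM's real ones):

    from typing import Optional

    class RequestState:
        def __init__(self, request_id: str, prompt: Optional[str]):
            self.request_id = request_id
            self.prompt = prompt  # kept for building RequestOutputs later

    class OutputProcessor:
        def __init__(self):
            self.request_states: dict[str, RequestState] = {}

        def add_request(self, request_id: str,
                        prompt: Optional[str]) -> None:
            self.request_states[request_id] = RequestState(request_id, prompt)

    op = OutputProcessor()
    prompt_str = "tell me a joke"
    for idx in range(3):  # e.g. fan-out for SamplingParams(n=3)
        op.add_request(f"request-0_{idx}", prompt_str)  # same string shared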

View File

@@ -202,7 +202,7 @@ class Processor:
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> EngineCoreRequest:
+    ) -> tuple[Optional[str], EngineCoreRequest]:

         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Support encoder-decoder models.
@@ -306,9 +306,8 @@ class Processor:
         else:
             sorted_mm_inputs = orig_sorted_mm_inputs

-        return EngineCoreRequest(
+        return decoder_inputs.get("prompt"), EngineCoreRequest(
             request_id=request_id,
-            prompt=decoder_inputs.get("prompt"),
             prompt_token_ids=decoder_inputs["prompt_token_ids"],
             mm_inputs=sorted_mm_inputs,
             mm_hashes=sorted_mm_hashes,
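
The returned prompt is Optional[str] because decoder_inputs.get("prompt") can legitimately be absent: callers may submit pre-tokenized token IDs, in which case there is no prompt string to hand to the detokenizer. A toy version of the tuple-returning shape (placeholder tokenizer, not the real processing pipeline):

    from typing import Optional, Union

    def process_inputs(
            prompt: Union[str, list[int]]
    ) -> tuple[Optional[str], list[int]]:
        if isinstance(prompt, str):
            return prompt, [ord(c) for c in prompt]  # placeholder tokenizer
        # Pre-tokenized input: no prompt string to hand back.
        return None, prompt

    print(process_inputs("hi"))        # ('hi', [104, 105])
    print(process_inputs([104, 105]))  # (None, [104, 105])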

View File

@@ -20,7 +20,6 @@ class Request:
     def __init__(
         self,
         request_id: str,
-        prompt: Optional[str],
         prompt_token_ids: list[int],
         multi_modal_inputs: Optional[list[MultiModalKwargs]],
         multi_modal_hashes: Optional[list[str]],
@@ -46,7 +45,6 @@ class Request:
             assert sampling_params.max_tokens is not None
             self.max_tokens = sampling_params.max_tokens

-        self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
         self.num_prompt_tokens = len(self.prompt_token_ids)
         self._output_token_ids: list[int] = []
@@ -81,7 +79,6 @@ class Request:
         return cls(
             request_id=request.request_id,
-            prompt=request.prompt,
             prompt_token_ids=request.prompt_token_ids,
             multi_modal_inputs=request.mm_inputs,
             multi_modal_hashes=request.mm_hashes,

View File

@@ -24,7 +24,6 @@ class CachedRequestState:
     req_id: str
     prompt_token_ids: list[int]
-    prompt: Optional[str]
     mm_inputs: list[MultiModalKwargs]
     mm_positions: list[PlaceholderRange]
     sampling_params: SamplingParams

View File

@@ -347,7 +347,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
-                prompt=new_req_data.prompt,
                 mm_inputs=new_req_data.mm_inputs,
                 mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,

View File

@@ -356,7 +356,6 @@ class TPUModelRunner:
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
-                prompt=new_req_data.prompt,
                 mm_inputs=new_req_data.mm_inputs,
                 mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,