[Core] Move EngineCoreRequest to Request conversion out of EngineCore (#21627)

Signed-off-by: linzebing <linzebing1995@gmail.com>
Zebing Lin 2025-07-30 18:00:54 -04:00 committed by GitHub
parent 601f856d56
commit ca9e2be3ed
3 changed files with 73 additions and 48 deletions
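
In short: the EngineCoreRequest -> Request conversion (multimodal input cache
lookup, Request.from_engine_core_request, structured-output grammar init) moves
out of EngineCore.add_request() into a new EngineCore.preprocess_add_request()
step, so it can run on the input-processing thread in parallel with model
forward; add_request() now takes an already-built Request plus a request_wave.
A minimal sketch of the resulting two-step call pattern, using the engine_core
instance and make_request helper from the tests below (illustration only, not
part of the diff):

    # Old: engine_core.add_request(make_request())
    # New: convert the EngineCoreRequest first, then hand the Request
    # (and its DP wave number) to the scheduler.
    req, request_wave = engine_core.preprocess_add_request(make_request())
    engine_core.add_request(req, request_wave)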

Changed file 1 of 3:

@@ -65,7 +65,8 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
     """Test basic request lifecycle."""
     # First request.
-    engine_core.add_request(make_request())
+    engine_core.add_request(
+        *engine_core.preprocess_add_request(make_request()))
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 0
@@ -74,7 +75,8 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
     assert len(engine_core.scheduler.running) == 1
     # Second request.
-    engine_core.add_request(make_request())
+    engine_core.add_request(
+        *engine_core.preprocess_add_request(make_request()))
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 1
@@ -83,8 +85,10 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
     assert len(engine_core.scheduler.running) == 2
     # Add two requests in a row.
-    engine_core.add_request(make_request())
-    engine_core.add_request(make_request())
+    engine_core.add_request(
+        *engine_core.preprocess_add_request(make_request()))
+    engine_core.add_request(
+        *engine_core.preprocess_add_request(make_request()))
     assert len(engine_core.scheduler.waiting) == 2
     assert len(engine_core.scheduler.running) == 2
@@ -104,7 +108,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
     req = make_request()
     request_id = req.request_id
-    engine_core.add_request(req)
+    engine_core.add_request(*engine_core.preprocess_add_request(req))
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 0
     assert engine_core.scheduler.has_unfinished_requests()
@@ -131,8 +135,8 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
     req1 = make_request()
     req2 = make_request()
-    engine_core.add_request(req0)
-    engine_core.add_request(req1)
+    engine_core.add_request(*engine_core.preprocess_add_request(req0))
+    engine_core.add_request(*engine_core.preprocess_add_request(req1))
     assert len(engine_core.scheduler.waiting) == 2
     assert len(engine_core.scheduler.running) == 0
@@ -140,7 +144,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 2
-    engine_core.add_request(req2)
+    engine_core.add_request(*engine_core.preprocess_add_request(req2))
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 2
@@ -166,12 +170,12 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
     req0 = make_request()
     req1 = make_request()
     req0.request_id = req1.request_id = "test"
-    engine_core.add_request(req0)
+    engine_core.add_request(*engine_core.preprocess_add_request(req0))
     while (outs := engine_core.step()[0].get(0)) and outs.outputs:
         pass
-    engine_core.add_request(req1)
+    engine_core.add_request(*engine_core.preprocess_add_request(req1))
     while (outs := engine_core.step()[0].get(0)) and outs.outputs:
         pass
@ -207,7 +211,7 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
repetition_penalty=0.1,
stop_token_ids=[1001, 1002],
)
engine_core.add_request(request)
engine_core.add_request(*engine_core.preprocess_add_request(request))
def _check_engine_state():
assert len(engine_core.scheduler.waiting) == 1
@@ -226,7 +230,7 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
         top_p=0.99,
         top_k=50,
     )
-    engine_core.add_request(request2)
+    engine_core.add_request(*engine_core.preprocess_add_request(request2))
     _check_engine_state()
@@ -298,9 +302,9 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
     # Add two requests in a row. Each request have 12 prompt tokens.
     req0 = make_request_with_max_tokens("0", 5)
-    engine_core.add_request(req0)
+    engine_core.add_request(*engine_core.preprocess_add_request(req0))
     req1 = make_request_with_max_tokens("1", 5)
-    engine_core.add_request(req1)
+    engine_core.add_request(*engine_core.preprocess_add_request(req1))
     # Schedule Batch 1: (10, req0)
     assert engine_core.step_with_batch_queue()[0] is None
@@ -436,7 +440,8 @@ def test_engine_core_invalid_request_id_type(monkeypatch: pytest.MonkeyPatch):
     with pytest.raises(TypeError,
                        match="request_id must be a string, got.*UUID"):
-        engine_core.add_request(uuid_request)
+        engine_core.add_request(
+            *engine_core.preprocess_add_request(uuid_request))
     # Test with integer
     int_request = make_request()
@@ -444,7 +449,8 @@ def test_engine_core_invalid_request_id_type(monkeypatch: pytest.MonkeyPatch):
     with pytest.raises(TypeError,
                        match="request_id must be a string, got.*int"):
-        engine_core.add_request(int_request)
+        engine_core.add_request(
+            *engine_core.preprocess_add_request(int_request))
     # Test with None
     none_request = make_request()
@@ -452,10 +458,12 @@ def test_engine_core_invalid_request_id_type(monkeypatch: pytest.MonkeyPatch):
     with pytest.raises(TypeError,
                        match="request_id must be a string, got.*NoneType"):
-        engine_core.add_request(none_request)
+        engine_core.add_request(
+            *engine_core.preprocess_add_request(none_request))
     # Verify engine is still functional after errors
     valid_request = make_request()
-    engine_core.add_request(valid_request)
+    engine_core.add_request(
+        *engine_core.preprocess_add_request(valid_request))
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 0

Changed file 2 of 3:

@@ -205,8 +205,12 @@ class EngineCore:
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return self.model_executor.supported_tasks
-    def add_request(self, request: EngineCoreRequest):
-        """Add request to the scheduler."""
+    def add_request(self, request: Request, request_wave: int = 0):
+        """Add request to the scheduler.
+
+        `request_wave`: indicate which wave of requests this is expected to
+        belong to in DP case
+        """
         # Validate the request_id type.
         if not isinstance(request.request_id, str):
             raise TypeError(
@@ -222,27 +226,12 @@ class EngineCore:
             raise ValueError(f"Unsupported task: {pooling_params.task!r} "
                              f"Supported tasks: {supported_pooling_tasks}")
-        if request.mm_hashes is not None:
-            # Here, if hash exists for a multimodal input, then it will be
-            # fetched from the cache, else it will be added to the cache.
-            # Note that the cache here is mirrored with the client cache, so
-            # anything that has a hash must have a HIT cache entry here
-            # as well.
-            assert request.mm_inputs is not None
-            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
-                request.mm_inputs, request.mm_hashes)
-        req = Request.from_engine_core_request(request)
-        if req.use_structured_output:
-            # Start grammar compilation asynchronously
-            self.structured_output_manager.grammar_init(req)
-        if req.kv_transfer_params is not None and (
+        if request.kv_transfer_params is not None and (
                 not self.scheduler.get_kv_connector()):
             logger.warning("Got kv_transfer_params, but no KVConnector found. "
                            "Disabling KVTransfer for this request.")
-        self.scheduler.add_request(req)
+        self.scheduler.add_request(request)
     def abort_requests(self, request_ids: list[str]):
         """Abort requests from the scheduler."""
@@ -414,6 +403,31 @@ class EngineCore:
         self.model_executor.save_tensorized_model(
             tensorizer_config=tensorizer_config, )
+    def preprocess_add_request(
+            self, request: EngineCoreRequest) -> tuple[Request, int]:
+        """Preprocess the request.
+
+        This function could be directly used in input processing thread to allow
+        request initialization running in parallel with Model forward
+        """
+        if request.mm_hashes is not None:
+            assert request.mm_inputs is not None
+            # Note on thread safety: no race condition.
+            # `mm_input_cache_server` is reset at the end of LLMEngine init,
+            # and will only accessed in the input processing thread afterwards.
+            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
+                request.mm_inputs, request.mm_hashes)
+        req = Request.from_engine_core_request(request)
+        if req.use_structured_output:
+            # Note on thread safety: no race condition.
+            # `grammar_init` is only invoked in input processing thread. For
+            # `structured_output_manager`, each request is independent and
+            # grammar compilation is async. Scheduler always checks grammar
+            # compilation status before scheduling request.
+            self.structured_output_manager.grammar_init(req)
+        return req, request.current_wave
 class EngineCoreProc(EngineCore):
     """ZMQ-wrapper for running EngineCore in background process."""
@@ -707,7 +721,8 @@ class EngineCoreProc(EngineCore):
         """Dispatch request from client."""
         if request_type == EngineCoreRequestType.ADD:
-            self.add_request(request)
+            req, request_wave = request
+            self.add_request(req, request_wave)
         elif request_type == EngineCoreRequestType.ABORT:
             self.abort_requests(request)
         elif request_type == EngineCoreRequestType.UTILITY:
@@ -806,10 +821,11 @@ class EngineCoreProc(EngineCore):
                 bytes(type_frame.buffer))
             # Deserialize the request data.
-            decoder = add_request_decoder if (
-                request_type
-                == EngineCoreRequestType.ADD) else generic_decoder
-            request = decoder.decode(data_frames)
+            if request_type == EngineCoreRequestType.ADD:
+                request = add_request_decoder.decode(data_frames)
+                request = self.preprocess_add_request(request)
+            else:
+                request = generic_decoder.decode(data_frames)
             # Push to input queue for core busy loop.
             self.input_queue.put_nowait((request_type, request))
@@ -939,17 +955,17 @@ class DPEngineCoreProc(EngineCoreProc):
         if dp_group := getattr(self, "dp_group", None):
             stateless_destroy_torch_distributed_process_group(dp_group)
-    def add_request(self, request: EngineCoreRequest):
-        if self.has_coordinator and request.current_wave != self.current_wave:
-            if request.current_wave > self.current_wave:
-                self.current_wave = request.current_wave
+    def add_request(self, request: Request, request_wave: int = 0):
+        if self.has_coordinator and request_wave != self.current_wave:
+            if request_wave > self.current_wave:
+                self.current_wave = request_wave
             elif not self.engines_running:
                 # Request received for an already-completed wave, notify
                 # front-end that we need to start the next one.
                 self.output_queue.put_nowait(
                     (-1, EngineCoreOutputs(start_wave=self.current_wave)))
-        super().add_request(request)
+        super().add_request(request, request_wave)
     def _handle_client_request(self, request_type: EngineCoreRequestType,
                                request: Any) -> None:
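
Putting the two EngineCoreProc hunks above together: in the background-process
path, the input socket thread now decodes an ADD payload and runs
preprocess_add_request() itself, so the busy loop's _handle_client_request()
only has to unpack the resulting (Request, wave) pair. A simplified recap of
that flow, condensed from the diff above (not additional API):

    # Input socket thread (simplified):
    if request_type == EngineCoreRequestType.ADD:
        request = add_request_decoder.decode(data_frames)
        request = self.preprocess_add_request(request)  # -> (Request, current_wave)
    else:
        request = generic_decoder.decode(data_frames)
    self.input_queue.put_nowait((request_type, request))

    # Core busy loop, _handle_client_request (simplified):
    if request_type == EngineCoreRequestType.ADD:
        req, request_wave = request
        self.add_request(req, request_wave)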

Changed file 3 of 3:

@@ -250,7 +250,8 @@ class InprocClient(EngineCoreClient):
         return self.engine_core.get_supported_tasks()
     def add_request(self, request: EngineCoreRequest) -> None:
-        self.engine_core.add_request(request)
+        req, request_wave = self.engine_core.preprocess_add_request(request)
+        self.engine_core.add_request(req, request_wave)
     def abort_requests(self, request_ids: list[str]) -> None:
         if len(request_ids) > 0: