[Bugfix] Fix TypeError in scheduler when comparing mixed request_id types (#21816)

Signed-off-by: chiliu <chiliu@paypal.com>
Co-authored-by: chiliu <chiliu@paypal.com>
633WHU 2025-07-30 23:54:44 +08:00 committed by GitHub
parent ad510309ee
commit 5c765aec65
2 changed files with 64 additions and 13 deletions
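
For context on the failure this commit guards against: Python refuses to order values of different types, so once integer and string request_ids meet in the same scheduler bookkeeping, any comparison or sort over them raises a TypeError. A minimal illustration in plain Python (not vLLM code):

    # Ordering mixed id types is exactly the comparison that blows up.
    mixed_ids = ["0", 1]
    try:
        sorted(mixed_ids)
    except TypeError as exc:
        print(exc)  # '<' not supported between instances of 'int' and 'str'

The change below therefore rejects non-string ids at the EngineCore boundary and updates the tests to use string request_ids throughout.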

View File

@@ -236,7 +236,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
     Test that the engine can handle multiple concurrent batches.
     """
 
-    def make_request_with_max_tokens(req_id: int,
+    def make_request_with_max_tokens(req_id: str,
                                      max_tokens: int) -> EngineCoreRequest:
         request = make_request()
         request.request_id = req_id
@@ -297,16 +297,16 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
     assert engine_core.batch_queue is not None
 
     # Add two requests in a row. Each request have 12 prompt tokens.
-    req0 = make_request_with_max_tokens(0, 5)
+    req0 = make_request_with_max_tokens("0", 5)
     engine_core.add_request(req0)
-    req1 = make_request_with_max_tokens(1, 5)
+    req1 = make_request_with_max_tokens("1", 5)
     engine_core.add_request(req1)
 
     # Schedule Batch 1: (10, req0)
     assert engine_core.step_with_batch_queue()[0] is None
     assert engine_core.batch_queue.qsize() == 1
     scheduler_output = engine_core.batch_queue.queue[-1][1]
-    assert scheduler_output.num_scheduled_tokens[0] == 10
+    assert scheduler_output.num_scheduled_tokens["0"] == 10
     # num_computed_tokens should have been updated immediately.
     assert engine_core.scheduler.requests[
         req0.request_id].num_computed_tokens == 10
@ -315,11 +315,11 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
assert engine_core.step_with_batch_queue()[0] is None
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens[0] == 2
assert scheduler_output.num_scheduled_tokens[1] == 8
assert scheduler_output.num_scheduled_tokens["0"] == 2
assert scheduler_output.num_scheduled_tokens["1"] == 8
# num_computed_tokens should have been updated immediately.
assert engine_core.scheduler.requests[0].num_computed_tokens == 12
assert engine_core.scheduler.requests[1].num_computed_tokens == 8
assert engine_core.scheduler.requests["0"].num_computed_tokens == 12
assert engine_core.scheduler.requests["1"].num_computed_tokens == 8
assert engine_core.scheduler.get_num_unfinished_requests() == 2
@@ -331,7 +331,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
     engine_core.step_with_batch_queue()
     assert engine_core.batch_queue.qsize() == 2
     scheduler_output = engine_core.batch_queue.queue[-1][1]
-    assert scheduler_output.num_scheduled_tokens[1] == 4
+    assert scheduler_output.num_scheduled_tokens["1"] == 4
 
     # Batch queue is full. Finish Batch 2. Get first token of req0.
     output = engine_core.step_with_batch_queue()[0].get(0)
@@ -343,7 +343,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
     engine_core.step_with_batch_queue()
     assert engine_core.batch_queue.qsize() == 2
     scheduler_output = engine_core.batch_queue.queue[-1][1]
-    assert scheduler_output.num_scheduled_tokens[0] == 1
+    assert scheduler_output.num_scheduled_tokens["0"] == 1
 
     # Batch queue is full. Finish Batch 3. Get first token of req1.
     output = engine_core.step_with_batch_queue()[0].get(0)
@@ -355,14 +355,14 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
     engine_core.step_with_batch_queue()
     assert engine_core.batch_queue.qsize() == 2
     scheduler_output = engine_core.batch_queue.queue[-1][1]
-    assert scheduler_output.num_scheduled_tokens[1] == 1
+    assert scheduler_output.num_scheduled_tokens["1"] == 1
 
     # Loop until req0 is finished.
     step = 0
     req_id = 0
     expected_num_tokens = [
-        engine_core.scheduler.requests[0].num_tokens + 1,
-        engine_core.scheduler.requests[1].num_tokens + 1,
+        engine_core.scheduler.requests["0"].num_tokens + 1,
+        engine_core.scheduler.requests["1"].num_tokens + 1,
     ]
     while engine_core.scheduler.get_num_unfinished_requests() == 2:
         output = engine_core.step_with_batch_queue()[0]
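
Note on the assertion changes above: SchedulerOutput.num_scheduled_tokens and scheduler.requests are dictionaries keyed by the request_id string, so indexing them with the integer 0 or 1 simply misses the entry once ids are strings. A quick sketch with a plain dict (not the vLLM types):

    tokens = {"0": 10, "1": 8}  # keys are string request_ids
    tokens["0"]                 # 10
    tokens.get(0)               # None: the int 0 is a different key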
@@ -413,3 +413,49 @@ def test_engine_core_tp(monkeypatch: pytest.MonkeyPatch):
             get_worker_cache_config_field, args=("num_cpu_blocks", ))
     assert all(x is not None for x in num_gpu_blocks)
     assert all(x is not None for x in num_cpu_blocks)
+
+
+@create_new_process_for_each_test()
+def test_engine_core_invalid_request_id_type(monkeypatch: pytest.MonkeyPatch):
+    """Test that engine raises TypeError for non-string request_id."""
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        engine_args = EngineArgs(model=MODEL_NAME)
+        vllm_config = engine_args.create_engine_config()
+        executor_class = Executor.get_class(vllm_config)
+
+        with set_default_torch_num_threads(1):
+            engine_core = EngineCore(vllm_config=vllm_config,
+                                     executor_class=executor_class,
+                                     log_stats=True)
+
+        # Test with UUID object (common mistake)
+        uuid_request = make_request()
+        uuid_request.request_id = uuid.uuid4()  # UUID object instead of string
+        with pytest.raises(TypeError,
+                           match="request_id must be a string, got.*UUID"):
+            engine_core.add_request(uuid_request)
+
+        # Test with integer
+        int_request = make_request()
+        int_request.request_id = 12345
+        with pytest.raises(TypeError,
+                           match="request_id must be a string, got.*int"):
+            engine_core.add_request(int_request)
+
+        # Test with None
+        none_request = make_request()
+        none_request.request_id = None
+        with pytest.raises(TypeError,
+                           match="request_id must be a string, got.*NoneType"):
+            engine_core.add_request(none_request)
+
+        # Verify engine is still functional after errors
+        valid_request = make_request()
+        engine_core.add_request(valid_request)
+        assert len(engine_core.scheduler.waiting) == 1
+        assert len(engine_core.scheduler.running) == 0

View File

@@ -207,6 +207,11 @@ class EngineCore:
 
     def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler."""
+        # Validate the request_id type.
+        if not isinstance(request.request_id, str):
+            raise TypeError(
+                f"request_id must be a string, got {type(request.request_id)}")
+
         if pooling_params := request.pooling_params:
             supported_pooling_tasks = [
                 task for task in self.get_supported_tasks()
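
With the validation in place, callers must hand EngineCore plain string ids. A hedged usage sketch reusing the test-side helpers from the diff above (make_request and engine_core come from that test, not from a public API):

    import uuid

    request = make_request()
    request.request_id = str(uuid.uuid4())  # stringify UUIDs before submitting
    engine_core.add_request(request)        # passes the new isinstance check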