[BugFix] Fix ValueError in NewRequestData repr methods (#29392)

Signed-off-by: maang <maang_h@163.com>
This commit is contained in:
maang-h 2025-11-28 13:42:30 +08:00 committed by GitHub
parent 18523b87f6
commit c7ba1f6bc7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 42 additions and 2 deletions

View File

@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.v1.core.sched.output import NewRequestData
def _create_new_requests_data(prompt_embeds: torch.Tensor | None) -> NewRequestData:
return NewRequestData(
req_id="test_req",
prompt_token_ids=None,
mm_features=[],
sampling_params=None,
pooling_params=None,
block_ids=([],),
num_computed_tokens=0,
lora_request=None,
prompt_embeds=prompt_embeds,
)
def test_repr_with_none() -> None:
"""Test repr when prompt_embeds is None."""
new_requests_data = _create_new_requests_data(None)
assert "prompt_embeds_shape=None" in repr(new_requests_data)
assert "prompt_embeds_shape=None" in new_requests_data.anon_repr()
def test_repr_with_multi_element_tensor() -> None:
"""Test repr when prompt_embeds is a multi-element tensor."""
prompt_embeds = torch.randn(10, 768)
new_requests_data = _create_new_requests_data(prompt_embeds)
assert "prompt_embeds_shape=torch.Size([10, 768])" in repr(new_requests_data)
assert "prompt_embeds_shape=torch.Size([10, 768])" in new_requests_data.anon_repr()

View File

@ -68,7 +68,9 @@ class NewRequestData:
)
def __repr__(self) -> str:
prompt_embeds_shape = self.prompt_embeds.shape if self.prompt_embeds else None
prompt_embeds_shape = (
self.prompt_embeds.shape if self.prompt_embeds is not None else None
)
return (
f"NewRequestData("
f"req_id={self.req_id},"
@ -88,7 +90,9 @@ class NewRequestData:
prompt_token_ids_len = (
len(self.prompt_token_ids) if self.prompt_token_ids is not None else None
)
prompt_embeds_shape = self.prompt_embeds.shape if self.prompt_embeds else None
prompt_embeds_shape = (
self.prompt_embeds.shape if self.prompt_embeds is not None else None
)
return (
f"NewRequestData("
f"req_id={self.req_id},"