# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.v1.core.sched.output import NewRequestData


def _create_new_requests_data(prompt_embeds: torch.Tensor | None) -> NewRequestData:
    """Build a minimal NewRequestData carrying the given prompt_embeds."""
    return NewRequestData(
        req_id="test_req",
        prompt_token_ids=None,
        mm_features=[],
        sampling_params=None,
        pooling_params=None,
        block_ids=([],),
        num_computed_tokens=0,
        lora_request=None,
        prompt_embeds=prompt_embeds,
    )


def test_repr_with_none() -> None:
    """Test repr when prompt_embeds is None."""
    new_requests_data = _create_new_requests_data(None)

    assert "prompt_embeds_shape=None" in repr(new_requests_data)
    assert "prompt_embeds_shape=None" in new_requests_data.anon_repr()


def test_repr_with_multi_element_tensor() -> None:
    """Test repr when prompt_embeds is a multi-element tensor."""
    prompt_embeds = torch.randn(10, 768)
    new_requests_data = _create_new_requests_data(prompt_embeds)

    assert "prompt_embeds_shape=torch.Size([10, 768])" in repr(new_requests_data)
    assert "prompt_embeds_shape=torch.Size([10, 768])" in new_requests_data.anon_repr()
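

# A possible additional case, not part of the original file: a 1-D tensor.
# This is only a sketch; it assumes repr() and anon_repr() render the shape as
# str(tensor.shape), which is the pattern the assertions above suggest.
def test_repr_with_one_dimensional_tensor() -> None:
    """Hypothetical extra case: repr when prompt_embeds is a 1-D tensor."""
    prompt_embeds = torch.randn(5)
    new_requests_data = _create_new_requests_data(prompt_embeds)

    assert "prompt_embeds_shape=torch.Size([5])" in repr(new_requests_data)
    assert "prompt_embeds_shape=torch.Size([5])" in new_requests_data.anon_repr()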