mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-22 13:24:40 +08:00
[mypy][CI/Build] Fix mypy errors (#7929)
This commit is contained in:
parent
c166e7e43e
commit
51f86bf487
@ -418,6 +418,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
prompt_len = seq_data.get_prompt_len()
|
||||
seq_lens.append(prompt_len)
|
||||
|
||||
assert sgm.sampling_params is not None
|
||||
if sgm.sampling_params.prompt_logprobs:
|
||||
# with prompt_logprobs each token in the prompt has a row in
|
||||
# logits
|
||||
@ -533,6 +534,8 @@ def test_sampler_mixed(seed: int, device: str):
|
||||
|
||||
for i, (sequence_output, metadata) in enumerate(
|
||||
zip(sampler_output, seq_group_metadata_list)):
|
||||
assert metadata.sampling_params is not None
|
||||
|
||||
if metadata.sampling_params.use_beam_search:
|
||||
continue
|
||||
|
||||
@ -550,6 +553,8 @@ def test_sampler_mixed(seed: int, device: str):
|
||||
assert expected_tokens_item is not None
|
||||
|
||||
for n, nth_output in enumerate(sequence_output.samples):
|
||||
assert metadata.sampling_params is not None
|
||||
|
||||
if (metadata.sampling_params.temperature == 0
|
||||
or metadata.sampling_params.seed is not None):
|
||||
# Ensure exact matches for greedy or random with seed
|
||||
|
||||
@ -19,7 +19,9 @@ class AudioAsset:
|
||||
|
||||
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
|
||||
s3_prefix=ASSET_DIR)
|
||||
return librosa.load(audio_path, sr=None)
|
||||
y, sr = librosa.load(audio_path, sr=None)
|
||||
assert isinstance(sr, int)
|
||||
return y, sr
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
|
||||
@ -101,6 +101,7 @@ class AsyncEngineRPCClient:
|
||||
# Maximum number of sockets that can be opened (typically 65536).
|
||||
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
|
||||
socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT)
|
||||
assert isinstance(socket_limit, int)
|
||||
if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF:
|
||||
raise ValueError(
|
||||
f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps "
|
||||
@ -141,8 +142,8 @@ class AsyncEngineRPCClient:
|
||||
poller.register(socket_from, zmq.constants.POLLIN)
|
||||
poller.register(socket_to, zmq.constants.POLLIN)
|
||||
while True:
|
||||
events = await poller.poll()
|
||||
events = dict(events)
|
||||
events_lst = await poller.poll()
|
||||
events = dict(events_lst)
|
||||
if socket_from in events:
|
||||
identity, msg = await socket_from.recv_multipart()
|
||||
await socket_to.send_multipart([identity, msg])
|
||||
|
||||
@ -14,7 +14,7 @@ from typing_extensions import TypeAlias
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import json_map_leaves
|
||||
from vllm.utils import JSONTree, is_list_of, json_map_leaves
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -54,13 +54,14 @@ class MultiModalInputs(_MultiModalInputsBase):
|
||||
return nested_tensors
|
||||
|
||||
stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
|
||||
if any(isinstance(t, list) for t in stacked):
|
||||
if is_list_of(stacked, list):
|
||||
# Do not stack nested lists
|
||||
return stacked
|
||||
|
||||
tensors_ = cast(List[torch.Tensor], stacked)
|
||||
if any(t.shape != tensors_[0].shape for t in tensors_):
|
||||
# The tensors have incompatible shapes and can't be stacked.
|
||||
return tensors_
|
||||
return stacked
|
||||
|
||||
return torch.stack(tensors_)
|
||||
|
||||
@ -101,8 +102,14 @@ class MultiModalInputs(_MultiModalInputsBase):
|
||||
*,
|
||||
device: torch.types.Device,
|
||||
) -> BatchedTensorInputs:
|
||||
return json_map_leaves(lambda x: x.to(device, non_blocking=True),
|
||||
batched_inputs)
|
||||
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
|
||||
|
||||
json_mapped = json_map_leaves(
|
||||
lambda x: x.to(device, non_blocking=True),
|
||||
json_inputs,
|
||||
)
|
||||
|
||||
return cast(BatchedTensorInputs, json_mapped)
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
@ -883,7 +883,7 @@ class SequenceGroupMetadata(
|
||||
request_id: str
|
||||
is_prompt: bool
|
||||
seq_data: Dict[int, SequenceData]
|
||||
sampling_params: SamplingParams
|
||||
sampling_params: Optional[SamplingParams]
|
||||
block_tables: Dict[int, List[int]]
|
||||
do_sample: bool = True
|
||||
pooling_params: Optional[PoolingParams] = None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user