From 51f86bf48730c3766f39c15aecc1268780879835 Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Wed, 28 Aug 2024 14:47:44 +0800
Subject: [PATCH] [mypy][CI/Build] Fix mypy errors (#7929)

---
 tests/samplers/test_sampler.py        |  5 +++++
 vllm/assets/audio.py                  |  4 +++-
 vllm/entrypoints/openai/rpc/client.py |  5 +++--
 vllm/multimodal/base.py               | 17 ++++++++++++-----
 vllm/sequence.py                      |  2 +-
 5 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index 719254a398c03..19a5ca5e27502 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -418,6 +418,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
             prompt_len = seq_data.get_prompt_len()
             seq_lens.append(prompt_len)
 
+            assert sgm.sampling_params is not None
             if sgm.sampling_params.prompt_logprobs:
                 # with prompt_logprobs each token in the prompt has a row in
                 # logits
@@ -533,6 +534,8 @@ def test_sampler_mixed(seed: int, device: str):
 
         for i, (sequence_output, metadata) in enumerate(
                 zip(sampler_output, seq_group_metadata_list)):
+            assert metadata.sampling_params is not None
+
             if metadata.sampling_params.use_beam_search:
                 continue
 
@@ -550,6 +553,8 @@ def test_sampler_mixed(seed: int, device: str):
             assert expected_tokens_item is not None
 
             for n, nth_output in enumerate(sequence_output.samples):
+                assert metadata.sampling_params is not None
+
                 if (metadata.sampling_params.temperature == 0
                         or metadata.sampling_params.seed is not None):
                     # Ensure exact matches for greedy or random with seed
diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py
index b00a61ebfec65..49bb6aeee90bc 100644
--- a/vllm/assets/audio.py
+++ b/vllm/assets/audio.py
@@ -19,7 +19,9 @@ class AudioAsset:
 
         audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
                                             s3_prefix=ASSET_DIR)
-        return librosa.load(audio_path, sr=None)
+        y, sr = librosa.load(audio_path, sr=None)
+        assert isinstance(sr, int)
+        return y, sr
 
     @property
     def url(self) -> str:
diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py
index dc316ca1160c6..a472e12e8ca48 100644
--- a/vllm/entrypoints/openai/rpc/client.py
+++ b/vllm/entrypoints/openai/rpc/client.py
@@ -101,6 +101,7 @@ class AsyncEngineRPCClient:
         # Maximum number of sockets that can be opened (typically 65536).
         # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
         socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT)
+        assert isinstance(socket_limit, int)
         if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF:
             raise ValueError(
                 f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps "
@@ -141,8 +142,8 @@ class AsyncEngineRPCClient:
         poller.register(socket_from, zmq.constants.POLLIN)
         poller.register(socket_to, zmq.constants.POLLIN)
         while True:
-            events = await poller.poll()
-            events = dict(events)
+            events_lst = await poller.poll()
+            events = dict(events_lst)
             if socket_from in events:
                 identity, msg = await socket_from.recv_multipart()
                 await socket_to.send_multipart([identity, msg])
diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py
index 5b00117c64e53..f26e3292c264d 100644
--- a/vllm/multimodal/base.py
+++ b/vllm/multimodal/base.py
@@ -14,7 +14,7 @@ from typing_extensions import TypeAlias
 from vllm.config import ModelConfig
 from vllm.inputs import InputContext
 from vllm.logger import init_logger
-from vllm.utils import json_map_leaves
+from vllm.utils import JSONTree, is_list_of, json_map_leaves
 
 logger = init_logger(__name__)
 
@@ -54,13 +54,14 @@ class MultiModalInputs(_MultiModalInputsBase):
             return nested_tensors
 
         stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
-        if any(isinstance(t, list) for t in stacked):
+        if is_list_of(stacked, list):
+            # Do not stack nested lists
             return stacked
 
         tensors_ = cast(List[torch.Tensor], stacked)
         if any(t.shape != tensors_[0].shape for t in tensors_):
             # The tensors have incompatible shapes and can't be stacked.
-            return tensors_
+            return stacked
 
         return torch.stack(tensors_)
 
@@ -101,8 +102,14 @@ class MultiModalInputs(_MultiModalInputsBase):
         *,
         device: torch.types.Device,
     ) -> BatchedTensorInputs:
-        return json_map_leaves(lambda x: x.to(device, non_blocking=True),
-                               batched_inputs)
+        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
+
+        json_mapped = json_map_leaves(
+            lambda x: x.to(device, non_blocking=True),
+            json_inputs,
+        )
+
+        return cast(BatchedTensorInputs, json_mapped)
 
 
 _T = TypeVar("_T")
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 964072dd7c8f1..f289a9aec80c5 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -883,7 +883,7 @@ class SequenceGroupMetadata(
     request_id: str
     is_prompt: bool
     seq_data: Dict[int, SequenceData]
-    sampling_params: SamplingParams
+    sampling_params: Optional[SamplingParams]
     block_tables: Dict[int, List[int]]
     do_sample: bool = True
    pooling_params: Optional[PoolingParams] = None
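
The changes above lean on two idioms that mypy accepts for narrowing types: an assert on a value before it is used (for Optional attributes and library return values), and an explicit cast() where the runtime structure is known but cannot be proven statically. A minimal sketch of both idioms, assuming nothing beyond the standard library (the functions first_token_id and head are illustrative names, not part of vLLM):

    from typing import List, Optional, cast

    def first_token_id(token_ids: Optional[List[int]]) -> int:
        # After this assert, mypy narrows Optional[List[int]] to List[int];
        # the same pattern is applied to sampling_params in the tests above.
        assert token_ids is not None
        return token_ids[0]

    def head(values: object) -> int:
        # Where the structure is only known at runtime, an explicit cast
        # tells mypy what to expect, mirroring the JSONTree and
        # BatchedTensorInputs casts in vllm/multimodal/base.py.
        ints = cast(List[int], values)
        return ints[0]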