mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:15:01 +08:00
[CI/Build] Fix mypy errors (#6968)
This commit is contained in:
parent
f230cc2ca6
commit
9f0e69b653
@ -1,6 +1,6 @@
|
||||
import contextlib
|
||||
import functools
|
||||
from typing import List, Optional, Tuple, Type
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -336,7 +336,7 @@ def scaled_fp8_quant(
|
||||
"""
|
||||
# This code assumes batch_dim and num_tokens are flattened
|
||||
assert (input.ndim == 2)
|
||||
shape = input.shape
|
||||
shape: Union[Tuple[int, int], torch.Size] = input.shape
|
||||
if num_token_padding:
|
||||
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
||||
output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)
|
||||
|
||||
@ -53,9 +53,7 @@ class MultiModalInputs(_MultiModalInputsBase):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _try_concat(
|
||||
tensors: List[NestedTensors],
|
||||
) -> Union[GenericSequence[NestedTensors], NestedTensors]:
|
||||
def _try_concat(tensors: List[NestedTensors]) -> BatchedTensors:
|
||||
"""
|
||||
If each input tensor in the batch has the same shape, return a single
|
||||
batched tensor; otherwise, return a list of :class:`NestedTensors` with
|
||||
@ -105,7 +103,7 @@ class MultiModalInputs(_MultiModalInputsBase):
|
||||
return {
|
||||
k: MultiModalInputs._try_concat(item_list)
|
||||
for k, item_list in item_lists.items()
|
||||
} # type: ignore
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def as_kwargs(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user