[CI/Build] Fix mypy errors (#6968)
commit 9f0e69b653
parent f230cc2ca6
@@ -1,6 +1,6 @@
 import contextlib
 import functools
-from typing import List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Union
 
 import torch
 
@@ -336,7 +336,7 @@ def scaled_fp8_quant(
     """
     # This code assumes batch_dim and num_tokens are flattened
     assert (input.ndim == 2)
-    shape = input.shape
+    shape: Union[Tuple[int, int], torch.Size] = input.shape
     if num_token_padding:
         shape = (max(num_token_padding, input.shape[0]), shape[1])
     output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)
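This annotation is what quiets mypy here: `input.shape` makes mypy infer `shape` as `torch.Size`, so the later reassignment to a plain `(rows, cols)` tuple would be flagged as an incompatible assignment; declaring the union of both types up front accepts either form. The typing change in the first hunk swaps `Type` (presumably no longer referenced) for the `Union` this annotation needs. The standalone sketch below reproduces the pattern outside of `scaled_fp8_quant`; the function name `pad_shape` and its signature are made up for illustration.

from typing import Optional, Tuple, Union

import torch


def pad_shape(input: torch.Tensor,
              num_token_padding: Optional[int] = None) -> torch.Tensor:
    # Hypothetical reduction of the scaled_fp8_quant pattern.  Without the
    # annotation, mypy infers `shape` as torch.Size from the first
    # assignment and then rejects the reassignment below, because
    # tuple[int, int] is not a torch.Size.  Declaring the union lets both
    # assignments type-check.
    shape: Union[Tuple[int, int], torch.Size] = input.shape
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    return torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)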
@@ -53,9 +53,7 @@ class MultiModalInputs(_MultiModalInputsBase):
     """
 
     @staticmethod
-    def _try_concat(
-        tensors: List[NestedTensors],
-    ) -> Union[GenericSequence[NestedTensors], NestedTensors]:
+    def _try_concat(tensors: List[NestedTensors]) -> BatchedTensors:
         """
         If each input tensor in the batch has the same shape, return a single
         batched tensor; otherwise, return a list of :class:`NestedTensors` with
@@ -105,7 +103,7 @@ class MultiModalInputs(_MultiModalInputsBase):
         return {
             k: MultiModalInputs._try_concat(item_list)
             for k, item_list in item_lists.items()
-        } # type: ignore
+        }
 
     @staticmethod
     def as_kwargs(
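These last two hunks are one fix: once `_try_concat` is declared to return the package's `BatchedTensors` alias instead of the spelled-out union, the dict comprehension in the batching helper produces exactly the value type its own signature promises, so the `# type: ignore` on the closing brace can be dropped. The sketch below is a rough, self-contained illustration of that relationship; `NestedTensors`, `BatchedTensors`, and the `batch` helper are simplified stand-ins, not the real definitions in vllm/multimodal.

from typing import Dict, List, Sequence, Union

import torch

# Simplified stand-ins for the aliases defined in vllm.multimodal; the
# real definitions may differ.
NestedTensors = Union[List[torch.Tensor], torch.Tensor]
BatchedTensors = Union[Sequence[NestedTensors], NestedTensors]


def _try_concat(tensors: List[NestedTensors]) -> BatchedTensors:
    # Placeholder body: per the docstring in the diff, the real method
    # returns one stacked tensor when every input shares a shape and
    # otherwise returns the inputs as a list.
    return tensors


def batch(item_lists: Dict[str, List[NestedTensors]]) -> Dict[str, BatchedTensors]:
    # Because _try_concat now advertises BatchedTensors, the value type of
    # this comprehension matches the declared return type, and mypy accepts
    # it without a "# type: ignore" on the closing brace.
    return {k: _try_concat(item_list) for k, item_list in item_lists.items()}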