[CI/Build] Fix mypy errors (#6968)

This commit is contained in:
Cyrus Leung 2024-07-31 10:49:48 +08:00 committed by GitHub
parent f230cc2ca6
commit 9f0e69b653
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 6 deletions

View File

@ -1,6 +1,6 @@
import contextlib
import functools
from typing import List, Optional, Tuple, Type
from typing import List, Optional, Tuple, Union
import torch
@ -336,7 +336,7 @@ def scaled_fp8_quant(
"""
# This code assumes batch_dim and num_tokens are flattened
assert (input.ndim == 2)
shape = input.shape
shape: Union[Tuple[int, int], torch.Size] = input.shape
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)

View File

@ -53,9 +53,7 @@ class MultiModalInputs(_MultiModalInputsBase):
"""
@staticmethod
def _try_concat(
tensors: List[NestedTensors],
) -> Union[GenericSequence[NestedTensors], NestedTensors]:
def _try_concat(tensors: List[NestedTensors]) -> BatchedTensors:
"""
If each input tensor in the batch has the same shape, return a single
batched tensor; otherwise, return a list of :class:`NestedTensors` with
@ -105,7 +103,7 @@ class MultiModalInputs(_MultiModalInputsBase):
return {
k: MultiModalInputs._try_concat(item_list)
for k, item_list in item_lists.items()
} # type: ignore
}
@staticmethod
def as_kwargs(