Update deprecated type hinting in model_executor/layers (#18056)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Authored by Harry Mellor on 2025-05-13 12:17:23 +01:00; committed by GitHub
parent 906f0598fc
commit 6223dd8114
87 changed files with 523 additions and 523 deletions
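
For context, the change applied throughout these files is the PEP 585 migration enforced by ruff's UP006 (use built-in generics) and UP035 (deprecated `typing` imports) rules: `Dict`, `List`, `Tuple`, `Set`, and `Type` become `dict`, `list`, `tuple`, `set`, and `type`, while `Optional`, `Union`, `Callable`, and `Any` remain imported from `typing`. The first hunk below removes the per-file ignore for `vllm/model_executor/layers`, so these rules are now enforced for that path. A minimal illustration of the pattern, with made-up function names:

```python
# Before: deprecated typing aliases (flagged by ruff's UP006/UP035)
from typing import Dict, List, Optional, Tuple

def load_scales(names: List[str],
                block_shape: Optional[List[int]] = None
                ) -> Dict[str, Tuple[int, int]]:
    ...

# After: built-in generics (PEP 585, Python 3.9+); Optional still comes from typing
from typing import Optional

def load_scales(names: list[str],
                block_shape: Optional[list[int]] = None
                ) -> dict[str, tuple[int, int]]:
    ...
```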


@ -80,7 +80,6 @@ exclude = [
"vllm/engine/**/*.py" = ["UP006", "UP035"]
"vllm/executor/**/*.py" = ["UP006", "UP035"]
"vllm/lora/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/layers/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
"vllm/platforms/**/*.py" = ["UP006", "UP035"]


@ -1,13 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
from contextlib import contextmanager
from typing import Any, Dict, Optional
from typing import Any, Optional
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.triton_utils import HAS_TRITON
_config: Optional[Dict[str, Any]] = None
_config: Optional[dict[str, Any]] = None
@contextmanager
@ -19,7 +19,7 @@ def override_config(config):
_config = old_config
def get_config() -> Optional[Dict[str, Any]]:
def get_config() -> Optional[dict[str, Any]]:
return _config
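
`override_config` and `get_config` follow the usual pattern of a context manager temporarily swapping a module-level value. A self-contained sketch of that pattern follows; the body of `override_config` is only partially visible in this hunk, so the restore-on-exit details are an assumption:

```python
from contextlib import contextmanager
from typing import Any, Optional

_config: Optional[dict[str, Any]] = None

@contextmanager
def override_config(config: dict[str, Any]):
    global _config
    old_config = _config
    _config = config
    try:
        yield
    finally:
        _config = old_config  # restore the previous value even on error

def get_config() -> Optional[dict[str, Any]]:
    return _config

with override_config({"BLOCK_SIZE_M": 64}):
    assert get_config() == {"BLOCK_SIZE_M": 64}
assert get_config() is None
```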


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import importlib.util
from typing import Optional, Tuple
from typing import Optional
import torch
@ -61,7 +61,7 @@ def _moe_permute(
global_num_experts: int,
expert_map: Optional[torch.Tensor],
block_m: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
"""
Determine the sorted_token_ids, expert_ids for the given problem size.


@ -3,7 +3,7 @@
import functools
import json
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Optional
import torch
@ -472,14 +472,14 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
top_k: int,
config: Dict[str, Any],
config: dict[str, Any],
compute_type: tl.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[List[int]] = None) -> None:
block_shape: Optional[list[int]] = None) -> None:
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
@ -622,7 +622,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
def get_config_file_name(E: int,
N: int,
dtype: Optional[str],
block_shape: Optional[List[int]] = None) -> str:
block_shape: Optional[list[int]] = None) -> str:
device_name = current_platform.get_device_name().replace(" ", "_")
dtype_selector = "" if not dtype else f",dtype={dtype}"
block_shape_selector = ("" if not block_shape or not all(block_shape) else
@ -638,7 +638,7 @@ def get_moe_configs(
dtype: Optional[str],
block_n: Optional[int] = None,
block_k: Optional[int] = None,
) -> Optional[Dict[int, Any]]:
) -> Optional[dict[int, Any]]:
"""
Return optimized configurations for the fused MoE kernel.
@ -670,7 +670,7 @@ def get_moe_configs(
return None
def get_moe_wna16_block_config(config: Dict[str,
def get_moe_wna16_block_config(config: dict[str,
int], use_moe_wna16_cuda: bool,
num_valid_tokens: int, size_k: int, size_n: int,
num_experts: int, group_size: int,
@ -742,8 +742,8 @@ def get_default_config(
topk: int,
dtype: Optional[str],
is_marlin: bool,
block_shape: Optional[List[int]] = None,
) -> Dict[str, int]:
block_shape: Optional[list[int]] = None,
) -> dict[str, int]:
if dtype == "fp8_w8a8" and block_shape is not None:
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
# BLOCK_SIZE_K must be divisible by block_shape[1]
@ -795,13 +795,13 @@ def get_default_config(
def try_get_optimal_moe_config(
w1_shape: Tuple[int, ...],
w2_shape: Tuple[int, ...],
w1_shape: tuple[int, ...],
w2_shape: tuple[int, ...],
top_k: int,
dtype: Optional[str],
M: int,
is_marlin: bool = False,
block_shape: Optional[List[int]] = None,
block_shape: Optional[list[int]] = None,
):
from vllm.model_executor.layers.fused_moe import get_config
override_config = get_config()
@ -855,7 +855,7 @@ def fused_topk(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
@ -895,7 +895,7 @@ def grouped_topk(
topk_group: int = 0,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
@ -982,7 +982,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None:
block_shape: Optional[list[int]] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
@ -1012,7 +1012,7 @@ def inplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None:
block_shape: Optional[list[int]] = None) -> None:
pass
@ -1046,7 +1046,7 @@ def outplace_fused_experts(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor:
block_shape: Optional[list[int]] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
@ -1076,7 +1076,7 @@ def outplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor:
block_shape: Optional[list[int]] = None) -> torch.Tensor:
return torch.empty_like(hidden_states)
@ -1129,7 +1129,7 @@ def fused_experts(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor:
if (allow_deep_gemm and use_fp8_w8a8
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
@ -1184,8 +1184,8 @@ def moe_kernel_prepare_input(
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if use_fp8_w8a8:
assert B_scale is not None
if block_shape is None:
@ -1248,7 +1248,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None):
block_shape: Optional[list[int]] = None):
# Check constraints.
if use_int4_w4a16:
assert hidden_states.shape[1] // 2 == w1.shape[
@ -1452,7 +1452,7 @@ def fused_moe(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
@ -1497,7 +1497,7 @@ def fused_moe(
a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for
a2.
- block_shape: (Optional[List[int]]): Optional block size for block-wise
- block_shape: (Optional[list[int]]): Optional block size for block-wise
quantization.
Returns:
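
The docstring breaks off at `Returns:` here, but a rough usage sketch may help. The positional parameter order `fused_moe(hidden_states, w1, w2, gating_output, topk, renormalize, ...)` and the weight layouts below are assumptions inferred from the fragments visible in this diff, not a verified API reference:

```python
import torch
from vllm.model_executor.layers.fused_moe import fused_moe

# experts, ffn dim, hidden dim, tokens, top-k (toy sizes)
E, N, K, M, topk = 8, 1024, 512, 16, 2
hidden_states = torch.randn(M, K, dtype=torch.float16, device="cuda")
w1 = torch.randn(E, 2 * N, K, dtype=torch.float16, device="cuda")  # fused gate/up proj (assumed layout)
w2 = torch.randn(E, K, N, dtype=torch.float16, device="cuda")      # down proj (assumed layout)
gating_output = torch.randn(M, E, dtype=torch.float16, device="cuda")

out = fused_moe(hidden_states, w1, w2, gating_output, topk, renormalize=True)
print(out.shape)  # expected: (M, K)
```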


@ -2,7 +2,7 @@
from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple
from typing import Callable, Optional
import torch
import torch.nn.functional as F
@ -326,7 +326,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def determine_expert_map(
ep_size: int, ep_rank: int,
global_num_experts: int) -> Tuple[int, Optional[torch.Tensor]]:
global_num_experts: int) -> tuple[int, Optional[torch.Tensor]]:
"""
Calculates how many experts should be assigned to each rank for EP and
creates a mapping from global to local expert index. Experts are
@ -338,7 +338,7 @@ def determine_expert_map(
global_num_experts (int): The total number of experts in the model.
Returns:
Tuple[int, Optional[torch.Tensor]]: A tuple containing:
tuple[int, Optional[torch.Tensor]]: A tuple containing:
- local_num_experts (int): The number of experts assigned
to the current rank.
- expert_map (Optional[torch.Tensor]): A tensor of shape
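
The docstring is cut off by the hunk, but the behaviour it describes is easy to illustrate. Below is a hedged sketch of building a global-to-local expert map, assuming experts are split as evenly as possible, earlier ranks take the remainder, and `-1` marks experts not owned by this rank; the exact policy in `determine_expert_map` may differ:

```python
import torch

def sketch_expert_map(ep_size: int, ep_rank: int, global_num_experts: int):
    # assumed policy: split as evenly as possible, earlier ranks take the remainder
    base = global_num_experts // ep_size
    remainder = global_num_experts % ep_size
    local_num_experts = base + (1 if ep_rank < remainder else 0)
    start = ep_rank * base + min(ep_rank, remainder)
    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
    expert_map[start:start + local_num_experts] = torch.arange(
        local_num_experts, dtype=torch.int32)
    return local_num_experts, expert_map

print(sketch_expert_map(ep_size=2, ep_rank=1, global_num_experts=5))
# (2, tensor([-1, -1, -1,  0,  1], dtype=torch.int32))
```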
@ -909,7 +909,7 @@ class FusedMoE(torch.nn.Module):
def make_expert_params_mapping(
cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
num_experts: int) -> List[Tuple[str, str, int, str]]:
num_experts: int) -> list[tuple[str, str, int, str]]:
return [
# (param_name, weight_name, expert_id, shard_id)


@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
@ -153,7 +153,7 @@ def moe_align_block_size(
num_experts: int,
expert_map: Optional[torch.Tensor] = None,
pad_sorted_ids: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
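
A toy illustration of the alignment, under the assumption that it means padding each expert's token count up to a multiple of the block size so every GEMM block maps to exactly one expert; the real kernel additionally produces the sorted token ids and per-block expert ids seen elsewhere in this diff:

```python
import math

block_size = 16
tokens_per_expert = [5, 23, 1, 9]
padded = [math.ceil(t / block_size) * block_size for t in tokens_per_expert]
print(padded)       # [16, 32, 16, 16]
print(sum(padded))  # total token slots after padding in this toy case
```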


@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
@ -15,7 +15,7 @@ def moe_permute(
expert_map: Optional[torch.Tensor] = None,
align_block_size: Optional[int] = None,
fill_invalid_expert: int = -1
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
This function expands and permutes activation to gather uncontinuous tokens
for each expert.


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from functools import cache
from typing import List, Optional, Tuple
from typing import Optional
import torch
@ -97,7 +97,7 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: torch.Tensor,
block_shape: List[int],
block_shape: list[int],
smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
from aiter import fmoe_fp8_blockscale_g1u1
from aiter.fused_moe_bf16_asm import moe_sorting_ck
@ -142,7 +142,7 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: torch.Tensor,
block_shape: List[int],
block_shape: list[int],
smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.empty_like(a1, dtype=hidden_states_dtype)
@ -280,7 +280,7 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor:
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@ -372,14 +372,14 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool) -> Tuple[torch.Tensor, ...]:
renormalize: bool) -> tuple[torch.Tensor, ...]:
torch.ops.vllm.rocm_aiter_topk_softmax(topk_weights, topk_indices,
token_expert_indices, gating_output,
renormalize)
return topk_weights, topk_indices
def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
"""
Applies shuffle_weight function from AITER to each
input tensor and returns them.
@ -395,7 +395,7 @@ def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
def expand_weights(*tensors: torch.Tensor,
expansion_dims: list[int]) -> Tuple[torch.Tensor, ...]:
expansion_dims: list[int]) -> tuple[torch.Tensor, ...]:
"""
Expands the dimensions of input tensors.


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from math import prod
from typing import List, Optional, Tuple
from typing import Optional
import torch
@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.utils import cdiv
def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor:
def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
"""
Shrink the given tensor and apply the given view to it. This is
used to resize the intermediate fused_moe caches.
@ -22,8 +22,8 @@ def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor:
def _fp8_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
block_shape: Optional[List[int]],
) -> Tuple[torch.Tensor, torch.Tensor]:
block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Perform fp8 quantization on the inputs. If a block_shape
is provided, the output will be blocked.


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""Custom normalization layers."""
from typing import Optional, Tuple, Union
from typing import Optional, Union
import torch
import torch.nn as nn
@ -31,7 +31,7 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor,
def fused_add_rms_norm(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops
ops.fused_add_rms_norm(
x,
@ -57,7 +57,7 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
def rocm_aiter_fused_add_rms_norm(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
import aiter as rocm_aiter
@ -119,7 +119,7 @@ class RMSNorm(CustomOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
x = x.to(torch.float32)
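
For reference, a `forward_native` RMSNorm typically reduces to the computation below. This is a hedged sketch rather than a copy of the vLLM implementation, and the fused residual-add path is omitted:

```python
import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    orig_dtype = x.dtype
    x = x.to(torch.float32)                         # statistics in fp32, as above
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return x.to(orig_dtype) * weight
```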
@ -157,7 +157,7 @@ class RMSNorm(CustomOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
@ -174,7 +174,7 @@ class RMSNorm(CustomOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
from vllm_hpu_extension.kernels import rms_norm
HPUFusedRMSNorm = rms_norm()
if HPUFusedRMSNorm is None:
@ -194,7 +194,7 @@ class RMSNorm(CustomOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
@ -244,7 +244,7 @@ class GemmaRMSNorm(CustomOp):
variance_epsilon: float,
x: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
if residual is not None:
@ -267,7 +267,7 @@ class GemmaRMSNorm(CustomOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
return self.forward_static(self.weight.data, self.variance_epsilon, x,
residual)
@ -276,7 +276,7 @@ class GemmaRMSNorm(CustomOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if torch.compiler.is_compiling():
return self.forward_native(x, residual)


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple, Union
from typing import Optional, Union
import torch
from torch import nn
@ -104,7 +104,7 @@ class Mixer2RMSNormGated(CustomOp):
self,
x: torch.Tensor,
gate: torch.Tensor,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if self.tp_size > 1 or self.n_groups != 1:
return self.forward_native(x, gate)
@ -136,7 +136,7 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int):
def mamba_v2_sharded_weight_loader(
shard_spec: List[Tuple[int, int, float]],
shard_spec: list[tuple[int, int, float]],
tp_size: int,
tp_rank: int,
) -> LoaderFunction:


@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from enum import IntEnum
from typing import List, Optional, Union
from typing import Optional, Union
import torch
import torch.nn as nn
@ -46,7 +46,7 @@ class SimplePooler(nn.Module):
normalize: bool,
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None,
returned_token_ids: Optional[list[int]] = None,
) -> "SimplePooler":
if pooling_type == PoolingType.LAST:
assert step_tag_id is None and returned_token_ids is None
@ -174,7 +174,7 @@ class StepPool(SimplePooler):
normalize: bool,
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None,
returned_token_ids: Optional[list[int]] = None,
):
super().__init__(normalize=normalize, softmax=softmax)
@ -259,7 +259,7 @@ class Pooler(nn.Module):
normalize: bool,
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None,
returned_token_ids: Optional[list[int]] = None,
) -> SimplePooler:
return SimplePooler.from_pooling_type(
pooling_type=PoolingType[pooler_config.pooling_type]


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Literal, Type, get_args
from typing import Literal, get_args
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
@ -76,7 +76,7 @@ def register_quantization_config(quantization: str):
return _wrapper
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
if quantization not in QUANTIZATION_METHODS:
raise ValueError(f"Invalid quantization method: {quantization}")
@ -110,7 +110,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from .torchao import TorchAOConfig
from .tpu_int8 import Int8TpuConfig
method_to_config: dict[str, Type[QuantizationConfig]] = {
method_to_config: dict[str, type[QuantizationConfig]] = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig,


@ -4,7 +4,7 @@
# and https://arxiv.org/pdf/2401.06118.pdf
import math
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
import torch.nn.functional as F
@ -98,7 +98,7 @@ def generic_dequantize_gemm(
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: List[int],
output_partition_sizes: list[int],
bias: Optional[torch.Tensor],
) -> torch.Tensor:
output_shape = input.shape[:-1] + (scales.shape[0], )
@ -136,7 +136,7 @@ def optimized_dequantize_gemm(
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: List[int],
output_partition_sizes: list[int],
bias: Optional[torch.Tensor],
) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
@ -191,7 +191,7 @@ class AQLMConfig(QuantizationConfig):
return "aqlm"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half]
@classmethod
@ -199,11 +199,11 @@ class AQLMConfig(QuantizationConfig):
return 60
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return [] # no extra configs.
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig":
def from_config(cls, config: dict[str, Any]) -> "AQLMConfig":
in_group_size = cls.get_from_keys(config, ["in_group_size"])
nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"])
num_code_books = cls.get_from_keys(config, ["num_codebooks"])
@ -230,7 +230,7 @@ class AQLMLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
del output_size # Unused.


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
@ -25,7 +25,7 @@ class AWQConfig(QuantizationConfig):
weight_bits: int,
group_size: int,
zero_point: bool,
modules_to_not_convert: Optional[List[str]] = None,
modules_to_not_convert: Optional[list[str]] = None,
) -> None:
super().__init__()
self.weight_bits = weight_bits
@ -48,7 +48,7 @@ class AWQConfig(QuantizationConfig):
def get_name(self) -> QuantizationMethods:
return "awq"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.half]
@classmethod
@ -57,7 +57,7 @@ class AWQConfig(QuantizationConfig):
return 75
@staticmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
return [
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
# E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
@ -65,7 +65,7 @@ class AWQConfig(QuantizationConfig):
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
def from_config(cls, config: dict[str, Any]) -> "AWQConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
@ -82,7 +82,7 @@ class AWQConfig(QuantizationConfig):
return None
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: list[str]):
return any(module_name in prefix for module_name in modules_to_not_convert)
@ -98,7 +98,7 @@ class AWQLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
if input_size_per_partition % self.quant_config.group_size != 0:


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional
import torch
from torch.nn import Parameter
@ -46,8 +46,8 @@ class AWQMarlinConfig(QuantizationConfig):
def __init__(self, weight_bits: int, group_size: int, zero_point: bool,
lm_head_quantized: bool,
modules_to_not_convert: Optional[List[str]],
full_config: Dict[str, Any]) -> None:
modules_to_not_convert: Optional[list[str]],
full_config: dict[str, Any]) -> None:
super().__init__()
self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size
@ -79,7 +79,7 @@ class AWQMarlinConfig(QuantizationConfig):
return "awq_marlin"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
@ -87,11 +87,11 @@ class AWQMarlinConfig(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
def from_config(cls, config: dict[str, Any]) -> "AWQMarlinConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
@ -150,7 +150,7 @@ class AWQMarlinConfig(QuantizationConfig):
return None
@classmethod
def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits")
@ -189,7 +189,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,


@ -2,7 +2,7 @@
import inspect
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
from typing import TYPE_CHECKING, Any, Optional
import torch
from torch import nn
@ -48,7 +48,7 @@ class QuantizeMethodBase(ABC):
def method_has_implemented_embedding(
method_class: Type[QuantizeMethodBase]) -> bool:
method_class: type[QuantizeMethodBase]) -> bool:
"""
Not all quant methods have embedding implemented, so we need to check that
it exists for our given method. We check this by making sure the function
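
The truncated docstring describes an override check. One way such a check can be implemented is sketched below with stand-in classes; this is an assumed approach, not necessarily the comparison vLLM actually performs:

```python
class QuantizeMethodBase:
    def embedding(self, layer, *args):
        raise NotImplementedError

class MyQuantMethod(QuantizeMethodBase):
    def embedding(self, layer, *args):
        return "implemented"

def has_implemented_embedding(method_class: type) -> bool:
    # an overridden method is a different function object than the base class's
    return method_class.embedding is not QuantizeMethodBase.embedding

print(has_implemented_embedding(MyQuantMethod))       # True
print(has_implemented_embedding(QuantizeMethodBase))  # False
```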
@ -68,7 +68,7 @@ class QuantizationConfig(ABC):
def __init__(self):
super().__init__()
# mapping is updated by models as they initialize
self.packed_modules_mapping: Dict[str, List[str]] = dict()
self.packed_modules_mapping: dict[str, list[str]] = dict()
@abstractmethod
def get_name(self) -> QuantizationMethods:
@ -76,7 +76,7 @@ class QuantizationConfig(ABC):
raise NotImplementedError
@abstractmethod
def get_supported_act_dtypes(self) -> List[torch.dtype]:
def get_supported_act_dtypes(self) -> list[torch.dtype]:
"""List of supported activation dtypes."""
raise NotImplementedError
@ -93,13 +93,13 @@ class QuantizationConfig(ABC):
@staticmethod
@abstractmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
"""List of filenames to search for in the model directory."""
raise NotImplementedError
@classmethod
@abstractmethod
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig":
"""Create a config class from the model's quantization config."""
raise NotImplementedError
@ -115,7 +115,7 @@ class QuantizationConfig(ABC):
return None
@staticmethod
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any:
"""Get a value from the model's quantization config."""
for key in keys:
if key in config:
@ -124,7 +124,7 @@ class QuantizationConfig(ABC):
"quantization config.")
@staticmethod
def get_from_keys_or(config: Dict[str, Any], keys: List[str],
def get_from_keys_or(config: dict[str, Any], keys: list[str],
default: Any) -> Any:
"""Get a optional value from the model's quantization config."""
try:


@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
@ -105,7 +105,7 @@ class BitBLASConfig(QuantizationConfig):
return "bitblas"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
@ -114,12 +114,12 @@ class BitBLASConfig(QuantizationConfig):
return 70
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@staticmethod
def get_from_keys(config: Dict[str, Any],
keys: List[str],
def get_from_keys(config: dict[str, Any],
keys: list[str],
default: Any = None) -> Any:
"""Get a value from the model's quantization config."""
for key in keys:
@ -128,7 +128,7 @@ class BitBLASConfig(QuantizationConfig):
return default
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "BitBLASConfig":
def from_config(cls, config: dict[str, Any]) -> "BitBLASConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"], -1)
desc_act = cls.get_from_keys(config, ["desc_act"], False)
@ -193,7 +193,7 @@ class BitBLASLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
@ -329,7 +329,7 @@ class BitBLASLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
@ -29,7 +29,7 @@ class BitsAndBytesConfig(QuantizationConfig):
bnb_4bit_use_double_quant: bool = False,
llm_int8_enable_fp32_cpu_offload: bool = False,
llm_int8_has_fp16_weight: bool = False,
llm_int8_skip_modules: Optional[List[str]] = None,
llm_int8_skip_modules: Optional[list[str]] = None,
llm_int8_threshold: float = 6.0,
) -> None:
super().__init__()
@ -61,7 +61,7 @@ class BitsAndBytesConfig(QuantizationConfig):
return "bitsandbytes"
@classmethod
def get_supported_act_dtypes(self) -> List[torch.dtype]:
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.float32, torch.float16, torch.bfloat16]
@classmethod
@ -69,13 +69,13 @@ class BitsAndBytesConfig(QuantizationConfig):
return 70
@staticmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
return [
"adapter_config.json",
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
def from_config(cls, config: dict[str, Any]) -> "BitsAndBytesConfig":
def get_safe_value(config, keys, default_value=None):
try:
@ -130,7 +130,7 @@ class BitsAndBytesConfig(QuantizationConfig):
return None
def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
# Split the prefix into its dot-separated components
components = prefix.split('.')
@ -169,7 +169,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
from bitsandbytes.nn import Int8Params


@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from contextlib import suppress
from typing import Any, Dict, List, Literal, Optional, Tuple, cast
from typing import Any, Literal, Optional, cast
import torch
from compressed_tensors.config import (CompressionFormat,
@ -38,20 +38,20 @@ logger = init_logger(__name__)
__all__ = ["CompressedTensorsLinearMethod"]
SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
QUANTIZATION_SCHEME_MAP_TYPE = Dict[str, Optional[Dict[str, QuantizationArgs]]]
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]
class CompressedTensorsConfig(QuantizationConfig):
def __init__(
self,
target_scheme_map: Dict[str, Any],
ignore: List[str],
target_scheme_map: dict[str, Any],
ignore: list[str],
quant_format: str,
sparsity_scheme_map: Dict[str, SparsityCompressionConfig],
sparsity_ignore_list: List[str],
kv_cache_scheme: Optional[Dict[str, Any]] = None,
config: Optional[Dict[str, Any]] = None,
sparsity_scheme_map: dict[str, SparsityCompressionConfig],
sparsity_ignore_list: list[str],
kv_cache_scheme: Optional[dict[str, Any]] = None,
config: Optional[dict[str, Any]] = None,
):
super().__init__()
self.ignore = ignore
@ -66,7 +66,7 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self)
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
@ -102,8 +102,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return None
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
ignore: List[str] = cast(List[str], config.get("ignore", []))
def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
ignore: list[str] = cast(list[str], config.get("ignore", []))
quant_format = cast(str, config.get("format"))
target_scheme_map = cls._quantization_scheme_map_from_config(
config=config)
@ -121,8 +121,8 @@ class CompressedTensorsConfig(QuantizationConfig):
@classmethod
def _parse_sparsity_config(
cls, config: Dict[str, Any]
) -> Tuple[Dict[str, SparsityCompressionConfig], List[str]]:
cls, config: dict[str, Any]
) -> tuple[dict[str, SparsityCompressionConfig], list[str]]:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A tuple with two elements
@ -135,7 +135,7 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_config = SparsityCompressionConfig.model_validate(
sparsity_config)
sparse_scheme_map: Dict[str, SparsityCompressionConfig] = {
sparse_scheme_map: dict[str, SparsityCompressionConfig] = {
target: sparsity_config
for target in sparsity_config.targets or list()
}
@ -144,13 +144,13 @@ class CompressedTensorsConfig(QuantizationConfig):
@classmethod
def _quantization_scheme_map_from_config(
cls, config: Dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A dictionary mapping target layer names to their corresponding
quantization_args for weights and input activations
"""
target_scheme_map: Dict[str, Any] = dict()
target_scheme_map: dict[str, Any] = dict()
quant_format = cast(str, config.get("format"))
# The quant_config has multiple config_groups, each containing
@ -188,7 +188,7 @@ class CompressedTensorsConfig(QuantizationConfig):
return target_scheme_map
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return []
def _check_scheme_supported(self,
@ -565,7 +565,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""
@ -611,7 +611,7 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
super().__init__(quant_config)
@staticmethod
def validate_kv_cache_scheme(kv_cache_scheme: Optional[Dict[str, Any]]):
def validate_kv_cache_scheme(kv_cache_scheme: Optional[dict[str, Any]]):
"""
Validator for the kv cache scheme. Useful for controlling the
kv cache quantization schemes, that are being supported in vLLM


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Optional
import torch
from compressed_tensors import CompressionFormat, ModelCompressor
@ -31,7 +31,7 @@ class CompressedTensors24(CompressedTensorsScheme):
quantized: bool = False,
weight_quant: Optional[QuantizationArgs] = None,
input_quant: Optional[QuantizationArgs] = None,
model_compression_config: Optional[Dict[str, Any]] = None,
model_compression_config: Optional[dict[str, Any]] = None,
):
self.quantized = quantized
self.weight_quant = weight_quant
@ -53,7 +53,7 @@ class CompressedTensors24(CompressedTensorsScheme):
self,
layer: torch.nn.Module,
input_size: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype,
weight_loader: Callable,
@ -327,9 +327,9 @@ class CompressedTensors24(CompressedTensorsScheme):
)
return sparsity_compressor.decompress_weight(weight_data)
split_weights: List[torch.Tensor] = []
split_bitmask: List[torch.Tensor] = []
split_shape: List[Tuple[int, int]] = []
split_weights: list[torch.Tensor] = []
split_bitmask: list[torch.Tensor] = []
split_shape: list[tuple[int, int]] = []
if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
split_weights = torch.split(compressed, layer.logical_widths)


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional
from typing import Callable, Optional
import torch
from torch.nn import Parameter
@ -58,7 +58,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
layer.meta = Parameter(layer.meta.data, requires_grad=False)
def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):


@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional
from typing import Callable, Optional
import torch
from torch.nn.parameter import Parameter
@ -26,7 +26,7 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
return 80
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional
from typing import Callable, Optional
import torch
from compressed_tensors.quantization import QuantizationStrategy
@ -58,7 +58,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
prepare_fp8_layer_for_marlin(layer)
def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional
from typing import Callable, Optional
import torch
from compressed_tensors.quantization import QuantizationStrategy
@ -90,7 +90,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
layer.input_scale = None
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional, Set
from typing import Callable, Optional
import torch
from compressed_tensors.quantization import QuantizationStrategy
@ -19,7 +19,7 @@ logger = init_logger(__name__)
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
_kernel_backends_being_used: Set[str] = set()
_kernel_backends_being_used: set[str] = set()
def __init__(self, strategy: str, is_static_input_scheme: bool,
input_symmetric: bool):
@ -33,7 +33,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
return 75
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional, Set
from typing import Callable, Optional
import torch
from compressed_tensors.quantization import ActivationOrdering
@ -35,7 +35,7 @@ WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsWNA16(CompressedTensorsScheme):
_kernel_backends_being_used: Set[str] = set()
_kernel_backends_being_used: set[str] = set()
def __init__(self,
strategy: str,
@ -70,7 +70,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
return 80
def create_weights(self, layer: torch.nn.Module, output_size: int,
input_size: int, output_partition_sizes: List[int],
input_size: int, output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Type
from typing import Optional
import torch
@ -126,7 +126,7 @@ def triton_scaled_mm(input: torch.Tensor,
weight: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: Type[torch.dtype],
out_dtype: type[torch.dtype],
bias: Optional[torch.Tensor] = None,
block_size_m: int = 32,
block_size_n: int = 32,

View File

@ -1,8 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
import re
from collections.abc import Iterable, Mapping
from types import MappingProxyType
from typing import Iterable, List, Mapping, Optional
from typing import Optional
from compressed_tensors import CompressionFormat
from torch.nn import Module
@ -20,7 +21,7 @@ def is_activation_quantization_format(format: str) -> bool:
def should_ignore_layer(
layer_name: Optional[str],
ignore: Iterable[str] = tuple(),
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
) -> bool:
if layer_name is None:
return False
@ -84,7 +85,7 @@ def find_matched_target(
layer_name: Optional[str],
module: Module,
targets: Iterable[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
) -> str:
"""
Helper function to look up which "target" in the compressed-tensors
@ -171,7 +172,7 @@ def _is_equal_or_regex_match(value: str,
def _match_fused_layer(
layer_name: str, target_layers: Iterable[str],
fused_mapping: Mapping[str, List[str]]) -> Optional[str]:
fused_mapping: Mapping[str, list[str]]) -> Optional[str]:
"""
Match a fused layer name to its corresponding individual layer in
target_layers. Returns first value in fused_mapping which matches targets
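
To make the fused/unfused matching concrete, here is a hypothetical sketch; the module names, the `qkv_proj` mapping, and the helper function are illustrative and not taken from the diff:

```python
from typing import Optional

fused_mapping: dict[str, list[str]] = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
target_layers = ["model.layers.0.q_proj", "model.layers.0.k_proj",
                 "model.layers.0.v_proj"]

def match_fused_layer_sketch(layer_name: str) -> Optional[str]:
    for fused, parts in fused_mapping.items():
        if layer_name.endswith(fused):
            prefix = layer_name[:-len(fused)]
            unfused = [prefix + part for part in parts]
            matches = [t for t in target_layers if t in unfused]
            return matches[0] if matches else None
    return None

print(match_fused_layer_sketch("model.layers.0.qkv_proj"))  # model.layers.0.q_proj
```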
@ -201,7 +202,7 @@ def _match_fused_layer(
]
# for each unfused component, find a match in targets
unfused_matches: List[Optional[str]] = []
unfused_matches: list[Optional[str]] = []
for unfused in unfused_paths:
for target in target_layers:
if _is_equal_or_regex_match(unfused, target):


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
import torch.nn as nn
@ -46,7 +46,7 @@ class DeepSpeedFPConfig(QuantizationConfig):
return "deepspeedfp"
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig":
def from_config(cls, config: dict[str, Any]) -> "DeepSpeedFPConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
return cls(weight_bits=weight_bits, group_size=group_size)
@ -55,7 +55,7 @@ class DeepSpeedFPConfig(QuantizationConfig):
return DeepSpeedFPLinearMethod(self)
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
@ -64,7 +64,7 @@ class DeepSpeedFPConfig(QuantizationConfig):
return 60
@staticmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
return [
"quant_config.json",
"quantize_config.json",
@ -91,7 +91,7 @@ class DeepSpeedFPLinearMethod(LinearMethodBase):
def create_weights(self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional
import torch
@ -25,7 +25,7 @@ class ExpertsInt8Config(QuantizationConfig):
return "experts_int8"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
@ -33,11 +33,11 @@ class ExpertsInt8Config(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ExpertsInt8Config":
def from_config(cls, config: dict[str, Any]) -> "ExpertsInt8Config":
return cls()
def get_quant_method(self, layer: torch.nn.Module,


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
from torch.nn import Module
@ -28,7 +28,7 @@ logger = init_logger(__name__)
class FBGEMMFp8Config(QuantizationConfig):
"""Config class for FBGEMM Fp8."""
def __init__(self, ignore_list: List[str], input_scale_ub: float):
def __init__(self, ignore_list: list[str], input_scale_ub: float):
super().__init__()
self.ignore_list = ignore_list if ignore_list else []
self.input_scale_ub = input_scale_ub
@ -43,7 +43,7 @@ class FBGEMMFp8Config(QuantizationConfig):
return "fbgemm_fp8"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.float16]
@classmethod
@ -51,11 +51,11 @@ class FBGEMMFp8Config(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config":
def from_config(cls, config: dict[str, Any]) -> "FBGEMMFp8Config":
ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
@ -82,7 +82,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,


@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import importlib.util
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional
import torch
import torch.nn.functional as F
@ -57,8 +57,8 @@ class Fp8Config(QuantizationConfig):
self,
is_checkpoint_fp8_serialized: bool = False,
activation_scheme: str = "dynamic",
ignored_layers: Optional[List[str]] = None,
weight_block_size: Optional[List[int]] = None,
ignored_layers: Optional[list[str]] = None,
weight_block_size: Optional[list[int]] = None,
) -> None:
super().__init__()
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
@ -90,7 +90,7 @@ class Fp8Config(QuantizationConfig):
return "fp8"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
@ -98,11 +98,11 @@ class Fp8Config(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = ("fp8" in quant_method)
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
@ -191,7 +191,7 @@ class Fp8LinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional
import gguf
import torch
@ -35,7 +35,7 @@ class GGUFConfig(QuantizationConfig):
def get_name(self) -> QuantizationMethods:
return "gguf"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.half, torch.bfloat16, torch.float32]
@classmethod
@ -43,11 +43,11 @@ class GGUFConfig(QuantizationConfig):
return 60
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return [] # no extra configs.
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig":
def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
return cls()
def get_quant_method(self, layer: torch.nn.Module,
@ -215,7 +215,7 @@ class GGUFLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
self.params_dtype = params_dtype
@ -406,7 +406,7 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
class GGUFUninitializedParameter(UninitializedParameter):
cls_to_become = Parameter
data_container: List[torch.Tensor]
data_container: list[torch.Tensor]
def materialize_nested(self) -> Parameter:
dtype = {data.dtype for data in self.data_container}


@ -3,7 +3,7 @@
import enum
from enum import Enum
from fractions import Fraction
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
import torch
from torch.nn.parameter import Parameter
@ -34,11 +34,11 @@ class GPTQConfig(QuantizationConfig):
group_size: int,
desc_act: bool,
lm_head_quantized: bool,
dynamic: Dict[str, Dict[str, Union[int, bool]]],
dynamic: dict[str, dict[str, Union[int, bool]]],
) -> None:
# GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized.
# Format is Dict[str, Dict] where key is a regex string that can
# Format is dict[str, dict] where key is a regex string that can
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
# matching of a module.
# Default to positive match, override base quant config mode, if no
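
A hypothetical example of the `dynamic` format described in this comment; the regex keys and override values are made up for illustration:

```python
from typing import Union

dynamic: dict[str, dict[str, Union[int, bool]]] = {
    # "+:"-prefixed regex: positive match, override settings for the first decoder layer
    r"+:model\.layers\.0\..*": {"bits": 8, "group_size": 64, "desc_act": False},
    # "-:"-prefixed regex: negative match, leave lm_head unquantized
    r"-:lm_head": {},
}
```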
@ -84,7 +84,7 @@ class GPTQConfig(QuantizationConfig):
return "gptq"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half]
@classmethod
@ -93,11 +93,11 @@ class GPTQConfig(QuantizationConfig):
return 60
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
def from_config(cls, config: dict[str, Any]) -> "GPTQConfig":
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
dynamic = {} if dynamic is None else dynamic
@ -135,7 +135,7 @@ class GPTQLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,


@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional, Set
from typing import Any, Optional
import torch
from torch.nn.parameter import Parameter
@ -129,7 +129,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
return "gptq_bitblas"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
@ -137,11 +137,11 @@ class GPTQBitBLASConfig(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQBitBLASConfig":
def from_config(cls, config: dict[str, Any]) -> "GPTQBitBLASConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
@ -185,7 +185,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
return self.TORCH_BITBLAS_STORAGE_DTYPE
@classmethod
def is_gptq_bitblas_compatible(cls, quant_config: Dict[str, Any]):
def is_gptq_bitblas_compatible(cls, quant_config: dict[str, Any]):
# Extract data from quant config.
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
@ -224,7 +224,7 @@ class GPTQBitBLASLinearMethod(LinearMethodBase):
"""
kernel_type = BitBLASLinearKernel
_kernel_backends_being_used: Set[str] = set()
_kernel_backends_being_used: set[str] = set()
def __init__(self, quant_config: GPTQBitBLASConfig) -> None:
self.quant_config = quant_config
@ -236,7 +236,7 @@ class GPTQBitBLASLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional, Set, Union
from typing import Any, Callable, Optional, Union
import torch
@ -45,8 +45,8 @@ class GPTQMarlinConfig(QuantizationConfig):
def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool, lm_head_quantized: bool,
dynamic: Dict[str, Dict[str, Union[int, bool]]],
full_config: Dict[str, Any]) -> None:
dynamic: dict[str, dict[str, Union[int, bool]]],
full_config: dict[str, Any]) -> None:
super().__init__()
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
@ -55,7 +55,7 @@ class GPTQMarlinConfig(QuantizationConfig):
# GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized.
# Format is Dict[str, Dict] where key is a regex string that can
# Format is dict[str, dict] where key is a regex string that can
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
# matching of a module.
# Default to positive match, override base quant config mode, if no
@ -105,7 +105,7 @@ class GPTQMarlinConfig(QuantizationConfig):
return "gptq_marlin"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
@ -113,11 +113,11 @@ class GPTQMarlinConfig(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
def from_config(cls, config: dict[str, Any]) -> "GPTQMarlinConfig":
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
dynamic = {} if dynamic is None else dynamic
@ -167,7 +167,7 @@ class GPTQMarlinConfig(QuantizationConfig):
GPTQMarlinLinearMethod)
@classmethod
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
@ -199,7 +199,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
quant_config: The GPTQ Marlin quantization config.
"""
_kernel_backends_being_used: Set[str] = set()
_kernel_backends_being_used: set[str] = set()
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config
@ -212,7 +212,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,


@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
from torch.nn.parameter import Parameter
@ -90,7 +90,7 @@ class GPTQMarlin24Config(QuantizationConfig):
return "gptq_marlin_24"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half]
@classmethod
@ -99,11 +99,11 @@ class GPTQMarlin24Config(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlin24Config":
def from_config(cls, config: dict[str, Any]) -> "GPTQMarlin24Config":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
return cls(weight_bits, group_size)
@ -146,7 +146,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
@ -32,7 +32,7 @@ class HQQMarlinConfig(QuantizationConfig):
self,
weight_bits: int,
group_size: int,
skip_modules: Optional[List[str]] = None,
skip_modules: Optional[list[str]] = None,
) -> None:
super().__init__()
assert group_size == 64, ("The only supported HQQ group size is "
@ -55,7 +55,7 @@ class HQQMarlinConfig(QuantizationConfig):
return "hqq"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
@ -63,11 +63,11 @@ class HQQMarlinConfig(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "HQQMarlinConfig":
def from_config(cls, config: dict[str, Any]) -> "HQQMarlinConfig":
wq_params = (config["quant_config"]["weight_quant_params"])
weight_bits = cls.get_from_keys(wq_params, ["nbits"])
group_size = cls.get_from_keys(wq_params, ["group_size"])
@ -192,7 +192,7 @@ class HQQMarlinMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
@ -32,7 +32,7 @@ class IPEXConfig(QuantizationConfig):
method: str,
weight_bits: int,
group_size: int,
modules_to_not_convert: Optional[List[str]] = None,
modules_to_not_convert: Optional[list[str]] = None,
desc_act: Optional[bool] = None,
lm_head_quantized: Optional[bool] = None,
) -> None:
@ -63,7 +63,7 @@ class IPEXConfig(QuantizationConfig):
return "ipex"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.float16]
@classmethod
@ -71,14 +71,14 @@ class IPEXConfig(QuantizationConfig):
return -1
@staticmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
return [
"quant_config.json",
"quantize_config.json",
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "IPEXConfig":
def from_config(cls, config: dict[str, Any]) -> "IPEXConfig":
method = cls.get_from_keys(config, ["quant_method"]).lower()
if method == "awq":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])

View File

@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Optional, Tuple
from typing import Callable, Optional
import torch
@ -12,8 +12,8 @@ from vllm.scalar_type import ScalarType
@dataclass
class MPLinearLayerConfig:
full_weight_shape: Tuple[int, int] # [in, out]
partition_weight_shape: Tuple[int, int]
full_weight_shape: tuple[int, int] # [in, out]
partition_weight_shape: tuple[int, int]
weight_type: ScalarType
act_type: torch.dtype
group_size: int
@ -31,7 +31,7 @@ class MPLinearKernel(ABC):
@classmethod
@abstractmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
raise NotImplementedError
def __init__(self,
@ -75,7 +75,7 @@ class MPLinearKernel(ABC):
torch.nn.Parameter(new_param.data, requires_grad=False))
def _get_weight_params(
self, layer: torch.nn.Module) -> Tuple[
self, layer: torch.nn.Module) -> tuple[
torch.Tensor, # w_q
torch.Tensor, # w_s
Optional[torch.Tensor], # w_zp,

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Type
from typing import Optional
import vllm.envs as envs
from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501
@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKer
from vllm.platforms import current_platform
# in priority/performance order (when available)
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
_POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
MacheteLinearKernel,
AllSparkLinearKernel,
MarlinLinearKernel,
@ -29,7 +29,7 @@ _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
def choose_mp_linear_kernel(
config: MPLinearLayerConfig,
compute_capability: Optional[int] = None) -> Type[MPLinearKernel]:
compute_capability: Optional[int] = None) -> type[MPLinearKernel]:
"""
Choose an MPLinearKernel that can implement the given config for the given
compute capability. Attempts to choose the best kernel in terms of
@ -46,7 +46,7 @@ def choose_mp_linear_kernel(
ValueError: If no kernel can implement the given config.
Returns:
Type[MPLinearKernel]: Chosen kernel.
type[MPLinearKernel]: Chosen kernel.
"""
if compute_capability is None:
if current_platform is None:
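The selection logic that the docstring describes can be sketched as below; this is a simplified illustration of walking the priority-ordered kernel list, not the exact implementation (the real function also applies extra filters such as environment-based kernel disabling and the compute-capability check):

def _choose_kernel_sketch(config: MPLinearLayerConfig) -> type[MPLinearKernel]:
    failure_reasons = []
    for kernel in _POSSIBLE_KERNELS:
        # Take the first kernel that reports it can handle this config.
        ok, reason = kernel.can_implement(config)
        if ok:
            return kernel
        failure_reasons.append(f"  {kernel.__name__}: {reason}")
    raise ValueError("No kernel can implement the given config:\n"
                     + "\n".join(failure_reasons))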

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
@ -22,7 +22,7 @@ class AllSparkLinearKernel(MPLinearKernel):
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if c.has_g_idx:
return False, "Act reordering currently not supported by AllSpark"

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Optional, Tuple
from typing import Optional
import torch
@ -21,10 +21,10 @@ logger = init_logger(__name__)
class BitBLASLinearKernel(MPLinearKernel):
OPT_FEATURES: List[int] = BITBLAS_OPTIMIZE_FEATURES
OPT_FEATURES: list[int] = BITBLAS_OPTIMIZE_FEATURES
ENABLE_TUNING: bool = True
MATMUL_LAYOUT: str = "nt"
BITBLAS_DTYPES: Dict[torch.dtype, str] = {
BITBLAS_DTYPES: dict[torch.dtype, str] = {
torch.float32: "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
@ -103,7 +103,7 @@ class BitBLASLinearKernel(MPLinearKernel):
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
is_bitblas_installed = True

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
@ -25,7 +25,7 @@ class ExllamaLinearKernel(MPLinearKernel):
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if c.has_g_idx and\
c.partition_weight_shape[0] != c.full_weight_shape[0]:
return False, "Act reordering currently not supported by Exllama, "\

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from functools import partial
from typing import Optional, Tuple
from typing import Optional
import torch
@ -25,7 +25,7 @@ class MacheteLinearKernel(MPLinearKernel):
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if c.has_g_idx and\
c.partition_weight_shape[0] != c.full_weight_shape[0]:

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
@ -24,7 +24,7 @@ class MarlinLinearKernel(MPLinearKernel):
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
quant_types = query_marlin_supported_quant_types(c.zero_points)
if c.weight_type not in quant_types:

View File

@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional
import torch
@ -24,7 +24,7 @@ class ScaledMMLinearKernel(ABC):
@classmethod
@abstractmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
raise NotImplementedError
def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str,
@ -50,7 +50,7 @@ class ScaledMMLinearKernel(ABC):
raise NotImplementedError
def _get_weight_params(
self, layer: torch.nn.Module) -> Tuple[
self, layer: torch.nn.Module) -> tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
Optional[torch.Tensor], # input_scale,

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Dict, List, Optional, Type
from typing import Optional
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
AiterScaledMMLinearKernel)
@ -16,7 +16,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
from vllm.platforms import PlatformEnum, current_platform
# in priority/performance order (when available)
_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = {
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CutlassScaledMMLinearKernel],
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
@ -27,7 +27,7 @@ _POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = {
def choose_scaled_mm_linear_kernel(
config: ScaledMMLinearLayerConfig,
compute_capability: Optional[int] = None
) -> Type[ScaledMMLinearKernel]:
) -> type[ScaledMMLinearKernel]:
"""
Choose a ScaledMMLinearKernel that can implement the given config for the
given compute capability. Attempts to choose the best kernel in terms of
@ -44,7 +44,7 @@ def choose_scaled_mm_linear_kernel(
ValueError: If no kernel can implement the given config.
Returns:
Type[ScaledMMLinearKernel]: Chosen kernel.
type[ScaledMMLinearKernel]: Chosen kernel.
"""
if compute_capability is None:

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
@ -20,7 +20,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
if not current_platform.is_rocm():
return (
False,

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
@ -22,7 +22,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
if (not current_platform.is_cuda() and not current_platform.is_cpu()):
return False, "CutlassScaledMM requires running on CUDA or CPU."

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
@ -18,7 +18,7 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
if current_platform.is_cpu():
return (
False,

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import warnings
from typing import Optional, Tuple
from typing import Optional
import torch
from functorch.experimental.control_flow import cond # noqa: F401
@ -25,7 +25,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
if not current_platform.is_tpu():
return False, "ScaledMMXLA requires running on TPU."

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
from torch.nn.parameter import Parameter
@ -68,7 +68,7 @@ class MarlinConfig(QuantizationConfig):
return "marlin"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half]
@classmethod
@ -77,11 +77,11 @@ class MarlinConfig(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
def from_config(cls, config: dict[str, Any]) -> "MarlinConfig":
group_size = cls.get_from_keys(config, ["group_size"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
@ -128,7 +128,7 @@ class MarlinLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Optional, Union
import torch
from torch.nn import Module
@ -53,7 +53,7 @@ class ModelOptFp8Config(QuantizationConfig):
return "modelopt"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
@ -61,11 +61,11 @@ class ModelOptFp8Config(QuantizationConfig):
return 89
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["hf_quant_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"]
if quant_method not in QUANT_ALGOS:
@ -107,7 +107,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
@ -177,7 +177,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
self,
is_checkpoint_nvfp4_serialized: bool,
kv_cache_quant_algo: str,
exclude_modules: List[str],
exclude_modules: list[str],
group_size: int = 16,
) -> None:
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
@ -195,7 +195,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
return "nvfp4"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half, torch.float8_e4m3fn]
@classmethod
@ -203,11 +203,11 @@ class ModelOptNvFp4Config(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["hf_quant_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ModelOptNvFp4Config":
def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"]
if quant_method not in QUANT_ALGOS:
@ -227,7 +227,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
exclude_modules, group_size)
def is_layer_excluded(self, prefix: str, exclude_modules: List):
def is_layer_excluded(self, prefix: str, exclude_modules: list):
import re
for pattern in exclude_modules:
regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
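The exclusion patterns are treated as simple globs and converted to regexes on the fly. A hedged example of the intended matching (the actual match call is outside this hunk; re.fullmatch is assumed here purely for illustration):

import re

exclude_modules = ["lm_head", "model.layers.*.mlp.gate"]
prefix = "model.layers.7.mlp.gate"

for pattern in exclude_modules:
    # '.' is escaped and '*' becomes '.*', mirroring the conversion above.
    regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
    if re.fullmatch(regex_str, prefix):
        print(f"{prefix} excluded by {pattern!r}")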
@ -296,7 +296,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional
import torch
@ -23,8 +23,8 @@ class MoeWNA16Config(QuantizationConfig):
def __init__(self, linear_quant_method: str, weight_bits: int,
group_size: int, has_zp: bool, lm_head_quantized: bool,
modules_to_not_convert: Optional[List[str]],
full_config: Dict[str, Any]) -> None:
modules_to_not_convert: Optional[list[str]],
full_config: dict[str, Any]) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
@ -69,7 +69,7 @@ class MoeWNA16Config(QuantizationConfig):
return "moe_wna16"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
@ -77,11 +77,11 @@ class MoeWNA16Config(QuantizationConfig):
return 70
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config":
def from_config(cls, config: dict[str, Any]) -> "MoeWNA16Config":
linear_quant_method = cls.get_from_keys(config, ["quant_method"])
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
@ -109,7 +109,7 @@ class MoeWNA16Config(QuantizationConfig):
return None
@classmethod
def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]):
def is_moe_wna16_compatible(cls, quant_config: dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits")
@ -163,7 +163,7 @@ class MoeWNA16Config(QuantizationConfig):
return None
def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]):
def is_layer_skipped_quant(prefix: str, modules_to_not_convert: list[str]):
return any(module_name in prefix for module_name in modules_to_not_convert)

View File

@ -2,7 +2,7 @@
import os
from importlib.util import find_spec
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from torch.nn import Module
@ -34,7 +34,7 @@ class NeuronQuantConfig(QuantizationConfig):
def get_name(self) -> QuantizationMethods:
return "neuron_quant"
def get_supported_act_dtypes(self) -> List[str]:
def get_supported_act_dtypes(self) -> list[str]:
return SUPPORTED_QUANT_DTYPE_LIST
@classmethod
@ -43,11 +43,11 @@ class NeuronQuantConfig(QuantizationConfig):
"This function should not be called with Neuron Backend")
@staticmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "NeuronQuantConfig":
def from_config(cls, config: dict[str, Any]) -> "NeuronQuantConfig":
quantize_method = cls.get_from_keys(config, ["quantize_method"])
dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"])
return cls(dequant_dtype=dequant_dtype,

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
from torch.nn.parameter import Parameter
@ -32,7 +32,7 @@ class PTPCFp8Config(Fp8Config):
def __init__(
self,
activation_scheme: str = "dynamic",
ignored_layers: Optional[List[str]] = None,
ignored_layers: Optional[list[str]] = None,
) -> None:
if not current_platform.is_rocm():
raise ValueError(
@ -55,7 +55,7 @@ class PTPCFp8Config(Fp8Config):
return "ptpc_fp8"
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "PTPCFp8Config":
def from_config(cls, config: dict[str, Any]) -> "PTPCFp8Config":
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
return cls(activation_scheme=activation_scheme,

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
from torch.nn.parameter import Parameter
@ -89,7 +89,7 @@ class QQQConfig(QuantizationConfig):
return "qqq"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half]
@classmethod
@ -97,7 +97,7 @@ class QQQConfig(QuantizationConfig):
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
"""List of filenames to search for in the model directory."""
return [
"quant_config.json",
@ -105,7 +105,7 @@ class QQQConfig(QuantizationConfig):
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "QQQConfig":
def from_config(cls, config: dict[str, Any]) -> "QQQConfig":
weight_bits = cls.get_from_keys(config, ["wbits"])
group_size = cls.get_from_keys(config, ["group_size"])
return cls(weight_bits, group_size)
@ -131,7 +131,7 @@ class QQQLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import fnmatch
from typing import Any, Dict, List, Optional, cast
from typing import Any, Optional, cast
import torch
@ -29,9 +29,9 @@ logger = init_logger(__name__)
class QuarkConfig(QuantizationConfig):
def __init__(self,
quant_config: Dict[str, Any],
kv_cache_group: Optional[List[str]] = None,
kv_cache_config: Optional[Dict[str, Any]] = None,
quant_config: dict[str, Any],
kv_cache_group: Optional[list[str]] = None,
kv_cache_config: Optional[dict[str, Any]] = None,
pack_method: str = "reorder"):
super().__init__()
if kv_cache_group is None:
@ -44,7 +44,7 @@ class QuarkConfig(QuantizationConfig):
def get_linear_method(self) -> "QuarkLinearMethod":
return QuarkLinearMethod(self)
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
@ -59,7 +59,7 @@ class QuarkConfig(QuantizationConfig):
from vllm.attention.layer import Attention # Avoid circular import
# Check if the layer is skipped for quantization.
exclude_layers = cast(List[str], self.quant_config.get("exclude"))
exclude_layers = cast(list[str], self.quant_config.get("exclude"))
if should_ignore_layer(prefix,
ignore=exclude_layers,
fused_mapping=self.packed_modules_mapping):
@ -78,12 +78,12 @@ class QuarkConfig(QuantizationConfig):
return None
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig":
def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
export_config = config.get("export")
if export_config is None:
raise ValueError("The export key should be included in "
"the configurations of Quark quantized model")
kv_cache_group = cast(List[str], export_config.get("kv_cache_group"))
kv_cache_group = cast(list[str], export_config.get("kv_cache_group"))
pack_method = cast(str, export_config.get("pack_method"))
# In the export model of quark, the quantization configuration
@ -95,7 +95,7 @@ class QuarkConfig(QuantizationConfig):
kv_cache_config = None
else:
kv_cache_set = set(kv_cache_group)
layer_quant_config = cast(Dict[str, Any],
layer_quant_config = cast(dict[str, Any],
config.get("layer_quant_config"))
layer_quant_names = list(layer_quant_config.keys())
layer_quant_set = set(layer_quant_names)
@ -108,7 +108,7 @@ class QuarkConfig(QuantizationConfig):
"configuration.")
q_configs = [
cast(Dict[str, Any], layer_quant_config.get(name))
cast(dict[str, Any], layer_quant_config.get(name))
for name in kv_cache_group
]
if not all(
@ -131,7 +131,7 @@ class QuarkConfig(QuantizationConfig):
# In case q_proj output is also quantized, remove the configuration
# to keep qkv consistency.
q_proj_q_config = cast(Dict[str, Any],
q_proj_q_config = cast(dict[str, Any],
layer_quant_config.get("*q_proj"))
if q_proj_q_config is not None:
q_proj_q_config["output_tensors"] = None
@ -142,7 +142,7 @@ class QuarkConfig(QuantizationConfig):
pack_method=pack_method)
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return []
def _check_scheme_supported(self,
@ -162,8 +162,8 @@ class QuarkConfig(QuantizationConfig):
else:
return False
def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]],
input_quant: Optional[Dict[str, Any]]) -> bool:
def _is_fp8_w8a8(self, weight_quant: Optional[dict[str, Any]],
input_quant: Optional[dict[str, Any]]) -> bool:
# Confirm weights and input quantized.
if weight_quant is None or input_quant is None:
return False
@ -187,8 +187,8 @@ class QuarkConfig(QuantizationConfig):
is_per_tensor_activation = (input_quant.get("qscheme") == "per_tensor")
return is_per_tensor_activation
def _is_static_tensor_w8a8(self, weight_quant: Optional[Dict[str, Any]],
input_quant: Optional[Dict[str, Any]]) -> bool:
def _is_static_tensor_w8a8(self, weight_quant: Optional[dict[str, Any]],
input_quant: Optional[dict[str, Any]]) -> bool:
# Confirm weights and input quantized.
if weight_quant is None or input_quant is None:
return False
@ -209,8 +209,8 @@ class QuarkConfig(QuantizationConfig):
# Only symmetric weight quantization supported.
return is_int8_dtype and is_tensor and is_weight_symmetric and is_static
def _is_mx_fp4(self, weight_quant: Optional[Dict[str, Any]],
input_quant: Optional[Dict[str, Any]]) -> bool:
def _is_mx_fp4(self, weight_quant: Optional[dict[str, Any]],
input_quant: Optional[dict[str, Any]]) -> bool:
# Confirm weights and input quantized.
if weight_quant is None or input_quant is None:
logger.debug("Quark model is not in MX-FP4 format: "
@ -258,7 +258,7 @@ class QuarkConfig(QuantizationConfig):
return True
def _find_matched_config(self, layer_name: str,
module: torch.nn.Module) -> Dict[str, Any]:
module: torch.nn.Module) -> dict[str, Any]:
proj_name = layer_name.split(".")[-1]
if proj_name in self.packed_modules_mapping:
@ -283,29 +283,29 @@ class QuarkConfig(QuantizationConfig):
return shard_configs[0]
else:
layer_quant_config = cast(
Dict[str, Any], self.quant_config.get("layer_quant_config"))
dict[str, Any], self.quant_config.get("layer_quant_config"))
for name_pattern in layer_quant_config:
if fnmatch.fnmatch(layer_name, name_pattern):
return layer_quant_config[name_pattern]
layer_type = cast(str, type(module))
layer_type_quant_config = cast(
Dict[str, Any],
dict[str, Any],
self.quant_config.get("layer_type_quant_config"))
if layer_type in layer_type_quant_config:
return layer_type_quant_config[layer_type]
global_quant_config = cast(
Dict[str, Any], self.quant_config.get("global_quant_config"))
dict[str, Any], self.quant_config.get("global_quant_config"))
return global_quant_config
def _get_scheme_from_config(self, config: Dict[str, Any]) -> "QuarkScheme":
def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
if config.get("output_tensors") or config.get("bias"):
raise NotImplementedError(
"Currently, Quark models with output_tensors "
"and bias quantized are not supported")
weight_config = cast(Dict[str, Any], config.get("weight"))
input_config = cast(Dict[str, Any], config.get("input_tensors"))
weight_config = cast(dict[str, Any], config.get("weight"))
input_config = cast(dict[str, Any], config.get("input_tensors"))
if self._is_fp8_w8a8(weight_config, input_config):
is_fp8_w8a8_supported = self._check_scheme_supported(
@ -373,7 +373,7 @@ class QuarkLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""
@ -417,7 +417,7 @@ class QuarkKVCacheMethod(BaseKVCacheMethod):
super().__init__(quant_config)
@staticmethod
def validate_kv_cache_config(kv_cache_config: Optional[Dict[str, Any]]):
def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]):
"""
Validator for the kv cache configuration. Useful for controlling the
kv cache quantization schemes that are being supported in vLLM
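For orientation, the lookup order implied by _find_matched_config above can be condensed into a small sketch (config contents are hypothetical and heavily trimmed relative to a real Quark export, and the layer type is simplified to a plain string):

import fnmatch

quant_config = {
    "layer_quant_config": {"*q_proj": {"weight": {"dtype": "fp8_e4m3"}}},
    "layer_type_quant_config": {},
    "global_quant_config": {"weight": {"dtype": "int8"}},
}

def find_matched_config(layer_name: str, layer_type: str) -> dict:
    # 1) per-layer glob patterns take precedence
    for pattern, cfg in quant_config["layer_quant_config"].items():
        if fnmatch.fnmatch(layer_name, pattern):
            return cfg
    # 2) then config keyed by the module's type
    type_cfg = quant_config["layer_type_quant_config"].get(layer_type)
    if type_cfg is not None:
        return type_cfg
    # 3) otherwise fall back to the global default
    return quant_config["global_quant_config"]

print(find_matched_config("model.layers.0.self_attn.q_proj", "RowParallelLinear"))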

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Optional
import torch
@ -45,7 +45,7 @@ class QuarkMoEMethod(FusedMoEMethodBase):
class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
def __init__(self, weight_config: Dict[str, Any], input_config: Dict[str,
def __init__(self, weight_config: dict[str, Any], input_config: dict[str,
Any]):
self.weight_quant = weight_config
self.input_quant = input_config

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional
import torch
import torch.nn.functional as F
@ -18,8 +18,8 @@ __all__ = ["QuarkW4A4MXFP4"]
class QuarkW4A4MXFP4(QuarkScheme):
def __init__(self, weight_quant_spec: Dict[str, Any],
input_quant_spec: Dict[str, Any]):
def __init__(self, weight_quant_spec: dict[str, Any],
input_quant_spec: dict[str, Any]):
self.out_dtype = torch.get_default_dtype()
self.qscheme = "per_group"
self.weight_quant_spec = weight_quant_spec
@ -74,7 +74,7 @@ class QuarkW4A4MXFP4(QuarkScheme):
torch.cuda.empty_cache()
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional
from typing import Callable, Optional
import torch
from torch.nn import Parameter
@ -88,7 +88,7 @@ class QuarkW8A8Fp8(QuarkScheme):
layer.input_scale = None
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional, Set
from typing import Callable, Optional
import torch
@ -17,7 +17,7 @@ logger = init_logger(__name__)
class QuarkW8A8Int8(QuarkScheme):
_kernel_backends_being_used: Set[str] = set()
_kernel_backends_being_used: set[str] = set()
def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool],
input_symmetric: Optional[bool]):
@ -31,7 +31,7 @@ class QuarkW8A8Int8(QuarkScheme):
return 75
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

View File

@ -1,8 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
import re
from collections.abc import Iterable, Mapping
from types import MappingProxyType
from typing import Any, Iterable, List, Mapping, Optional
from typing import Any, Optional
def deep_compare(dict1: Any, dict2: Any) -> bool:
@ -21,7 +22,7 @@ def deep_compare(dict1: Any, dict2: Any) -> bool:
def should_ignore_layer(
layer_name: Optional[str],
ignore: Iterable[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
) -> bool:
if layer_name is None:
return False

View File

@ -12,7 +12,7 @@ possible on ROCm), the model can be optionally augmented with KV cache
scaling factors.
"""
from typing import Dict, Optional
from typing import Optional
from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
@ -23,7 +23,7 @@ class KVCacheQuantSchema(BaseModel):
# layer indices to their per-tensor KV cache scaling factor.
# TODO: Consider pulling this and its validation methods out into its
# own schema class (tricky as its members are variable)
scaling_factor: Dict[int, Dict[int, float]]
scaling_factor: dict[int, dict[int, float]]
@model_validator(mode="after")
def check_is_fp8(self) -> "KVCacheQuantSchema":

View File

@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
import torch.nn.functional as F
@ -24,7 +24,7 @@ class TorchAOConfig(QuantizationConfig):
def get_name(self) -> QuantizationMethods:
return "torchao"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.float32, torch.float16, torch.bfloat16]
@classmethod
@ -32,11 +32,11 @@ class TorchAOConfig(QuantizationConfig):
return 75
@staticmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
return ["config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "TorchAOConfig":
def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
"""Create the quant config from an hf model config"""
try:
from torchao.core.config import config_from_dict
@ -60,7 +60,7 @@ class TorchAOConfig(QuantizationConfig):
return TorchAOLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
def get_scaled_act_names(self) -> list[str]:
return []
@ -97,7 +97,7 @@ class TorchAOLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
import torch
from torch.nn import Module
@ -31,7 +31,7 @@ class Int8TpuConfig(QuantizationConfig):
def get_name(self) -> QuantizationMethods:
return "tpu_int8"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
@ -40,11 +40,11 @@ class Int8TpuConfig(QuantizationConfig):
"This function should not be called with TPU Backend")
@staticmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "Int8TpuConfig":
def from_config(cls, config: dict[str, Any]) -> "Int8TpuConfig":
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
return cls(activation_scheme=activation_scheme)
@ -62,7 +62,7 @@ class TPUInt8LinearMethod(LinearMethodBase):
self.quant_config = quant_config
def create_weights(self, layer: Module, input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
@ -77,7 +77,7 @@ class TPUInt8LinearMethod(LinearMethodBase):
layer.register_parameter("weight", weight)
def _quantize_weight(
self, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
self, weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
weight_dtype = weight.dtype
weight = weight.cpu().to(torch.float32)
n_bit = 8

View File

@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
@ -51,7 +51,7 @@ def _check_bitblas_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]:
if device_capability is None:
capability_tuple = current_platform.get_device_capability()
@ -133,7 +133,7 @@ def verify_bitblas_supports_shape(output_size_per_partition: int,
def check_bitblas_supports_shape(output_size_per_partition: int,
input_size_per_partition: int,
input_size: int, group_size: int) \
-> Tuple[bool, Optional[str]]:
-> tuple[bool, Optional[str]]:
try:
verify_bitblas_supports_shape(output_size_per_partition,
input_size_per_partition, input_size,
@ -166,7 +166,7 @@ def bitblas_make_empty_zp(device: torch.device) -> torch.Tensor:
def bitblas_sort_g_idx(
g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
return g_idx[g_idx_sort_indices], g_idx_sort_indices

View File

@ -4,7 +4,7 @@
import functools
import json
import os
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union
import torch
@ -32,7 +32,7 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
def apply_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
block_size: list[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
@ -95,7 +95,7 @@ def apply_w8a8_block_fp8_linear(
def apply_w8a8_block_fp8_linear_fake(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
block_size: list[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@ -114,7 +114,7 @@ direct_register_custom_op(
def input_to_float8(
x: torch.Tensor,
dtype: Optional[torch.dtype] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""This function quantizes input values to float8 values "
"with tensor-wise quantization."""
dtype = current_platform.fp8_dtype() if dtype is None else dtype
@ -129,7 +129,7 @@ def input_to_float8(
def block_quant_to_tensor_quant(
x_q_block: torch.Tensor,
x_s: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""This function converts block-wise quantization to tensor-wise
quantization. The inputs are block-wise quantization tensor `x_q_block`,
block-wise quantization scale and the block size.
@ -247,7 +247,7 @@ def per_token_group_quant_fp8(
eps: float = 1e-10,
dtype: Optional[torch.dtype] = None,
column_major_scales: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
@ -258,7 +258,7 @@ def per_token_group_quant_fp8(
dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn`
is supported for now.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
dtype = current_platform.fp8_dtype() if dtype is None else dtype
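As a reference point for what the Triton kernel computes, here is a hedged pure-PyTorch sketch of per-token-group quantization; the numerics (max-abs scaling per group) are illustrative rather than bit-for-bit identical to the kernel, and the real function can additionally emit column-major scales:

import torch

def per_token_group_quant_ref(x: torch.Tensor, group_size: int,
                              eps: float = 1e-10,
                              dtype: torch.dtype = torch.float8_e4m3fn):
    # One scale per group of `group_size` contiguous elements of the last
    # dimension, chosen so the group's max-abs value maps to the fp8 max.
    assert x.shape[-1] % group_size == 0
    x_grouped = x.reshape(*x.shape[:-1], -1, group_size).float()
    fp8_max = torch.finfo(dtype).max
    scales = x_grouped.abs().amax(dim=-1, keepdim=True).clamp(min=eps) / fp8_max
    x_q = (x_grouped / scales).clamp(-fp8_max, fp8_max).to(dtype)
    return x_q.reshape(x.shape), scales.squeeze(-1)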
@ -412,7 +412,7 @@ def _w8a8_block_fp8_matmul(
@functools.lru_cache
def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
block_k: int) -> Optional[Dict[int, Any]]:
block_k: int) -> Optional[dict[int, Any]]:
"""
Return optimized configurations for the w8a8 block fp8 kernel.
The return value will be a dictionary that maps an irregular grid of
@ -452,7 +452,7 @@ def w8a8_block_fp8_matmul(
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import re
from copy import deepcopy
from typing import Dict, Optional, Union
from typing import Optional, Union
import torch
@ -52,7 +52,7 @@ def get_dynamic_override(
layer_name: str,
key: Optional[str] = None,
default_value: Union[int, bool,
None] = None) -> Union[Dict, int, bool, None]:
None] = None) -> Union[dict, int, bool, None]:
for pattern, pattern_dict in config.dynamic.items():
# Negative match: matched modules are excluded from quantized init
if pattern.startswith("-:"):
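A hedged usage sketch for the override lookup (layer name and default are illustrative); a config whose dynamic field looks like the earlier GPTQMarlin example would drive the per-module result:

# quant_config is assumed to be the GPTQ-style config object carrying the
# .dynamic mapping (its exact type sits outside this hunk). Falls back to
# default_value when no pattern in quant_config.dynamic matches the module.
bits = get_dynamic_override(quant_config, "model.layers.0.mlp.down_proj",
                            key="bits", default_value=4)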

View File

@ -5,7 +5,7 @@ import functools
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
import torch
@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
def apply_w8a8_block_int8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
block_size: list[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
@ -43,7 +43,7 @@ def apply_w8a8_block_int8_linear(
def input_to_int8(
x: torch.Tensor,
dtype: torch.dtype = torch.int8) -> Tuple[torch.Tensor, torch.Tensor]:
dtype: torch.dtype = torch.int8) -> tuple[torch.Tensor, torch.Tensor]:
"""This function quantizes input values to int8 values with
tensor-wise quantization."""
iinfo = torch.iinfo(dtype)
@ -58,7 +58,7 @@ def input_to_int8(
def block_dequant(
x_q_block: torch.Tensor,
x_s: torch.Tensor,
block_size: List[int],
block_size: list[int],
) -> torch.Tensor:
"""This function conducts block-wise dequantization.
The inputs are block-wise quantization tensor `x_q_block`,
@ -211,7 +211,7 @@ def per_token_group_quant_int8(
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed int8 values and returns the
@ -225,7 +225,7 @@ def per_token_group_quant_int8(
is supported for now.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
assert (x.shape[-1] % group_size == 0
@ -358,7 +358,7 @@ def _w8a8_block_int8_matmul(
@functools.lru_cache
def get_w8a8_block_int8_configs(N: int, K: int, block_n: int,
block_k: int) -> Optional[Dict[int, Any]]:
block_k: int) -> Optional[dict[int, Any]]:
"""
Return optimized configurations for the w8a8 block int8 kernel.
@ -399,7 +399,7 @@ def w8a8_block_int8_matmul(
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple
from typing import Optional
import torch
@ -10,19 +10,19 @@ MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128]
MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128]
def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]:
def query_machete_supported_quant_types(zero_points: bool) -> list[ScalarType]:
if zero_points:
return [scalar_types.uint4, scalar_types.uint8]
else:
return [scalar_types.uint4b8, scalar_types.uint8b128]
def query_machete_supported_act_types(zero_points: bool) -> List[ScalarType]:
def query_machete_supported_act_types(zero_points: bool) -> list[ScalarType]:
return [torch.float16, torch.bfloat16]
def check_machete_supports_shape(in_features: int, out_featrues: int) \
-> Tuple[bool, Optional[str]]:
-> tuple[bool, Optional[str]]:
if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0:
return False, "Input features size must be divisible by "\
f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}"

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple
from typing import Optional
import numpy
import torch
@ -70,7 +70,7 @@ def _check_marlin_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]:
if device_capability is None:
capability_tuple = current_platform.get_device_capability()
@ -143,7 +143,7 @@ def verify_marlin_supports_shape(output_size_per_partition: int,
def check_marlin_supports_shape(output_size_per_partition: int,
input_size_per_partition: int,
input_size: int, group_size: int) \
-> Tuple[bool, Optional[str]]:
-> tuple[bool, Optional[str]]:
try:
verify_marlin_supports_shape(output_size_per_partition,
input_size_per_partition, input_size,
@ -231,16 +231,16 @@ def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
def marlin_sort_g_idx(
g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
return g_idx[g_idx_sort_indices], g_idx_sort_indices
def get_scale_perms():
scale_perm: List[int] = []
scale_perm: list[int] = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single: List[int] = []
scale_perm_single: list[int] = []
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
"""Utility functions used for tests and benchmarks"""
from typing import List, Optional
from typing import Optional
import numpy as np
import torch
@ -64,9 +64,9 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm):
def get_weight_perm(num_bits: int):
perm_list: List[int] = []
perm_list: list[int] = []
for i in range(32):
perm1: List[int] = []
perm1: list[int] = []
col = i // 4
for block in [0, 1]:
for row in [

View File

@ -2,7 +2,6 @@
"""Utility functions used for tests and benchmarks"""
import random
from typing import List
import numpy
import torch
@ -373,19 +372,19 @@ def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):
def get_scale_perms_24():
scale_perm: List[int] = []
scale_perm: list[int] = []
for i in range(8):
scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
scale_perm_single: List[int] = []
scale_perm_single: list[int] = []
for i in range(8):
scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
return scale_perm, scale_perm_single
def get_weight_perm_24(num_bits: int):
perm_list: List[int] = []
perm_list: list[int] = []
for i in range(32):
perm1: List[int] = []
perm1: list[int] = []
col = i // 4
col_o = col // 2
for block in [0, 1]:

View File

@ -1,7 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List
import numpy
import torch
@ -34,10 +32,10 @@ def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):
def get_qqq_scale_perms():
scale_perm: List[int] = []
scale_perm: list[int] = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single: List[int] = []
scale_perm_single: list[int] = []
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
@ -46,9 +44,9 @@ def get_qqq_scale_perms():
# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
def get_qqq_weight_perm(num_bits: int, quant_type: str):
perm_list: List[int] = []
perm_list: list[int] = []
for i in range(32):
perm1: List[int] = []
perm1: list[int] = []
col = i // 4
for block in [0, 1]:
for row in [

View File

@ -1,5 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Tuple
import torch
@ -9,7 +8,7 @@ OCP_MX_BLOCK_SIZE = 32
def per_token_group_quant_mxfp4(x: torch.Tensor,
block_k: int,
scale_calculation_mode: str = "even"
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
try:
from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
fake_quantize_fp4_fp6_per_group_with_scale)

View File

@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
"""This file is used for /tests and /benchmarks"""
from collections.abc import Mapping
from types import MappingProxyType
from typing import List, Mapping, Optional, Tuple
from typing import Optional
import numpy
import torch
@ -15,7 +16,7 @@ SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int,
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int,
int]):
# -1 means full extent
return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
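A quick illustration of the normalization (shapes are illustrative): -1 in either position expands to the corresponding full extent of the tensor, so (-1, -1) means per-tensor and (1, N) means per-token with groups of N along the hidden dimension.

x = torch.empty(16, 4096)
assert _normalize_quant_group_shape(x, (-1, -1)) == (16, 4096)  # per-tensor
assert _normalize_quant_group_shape(x, (1, -1)) == (1, 4096)    # per-token
assert _normalize_quant_group_shape(x, (1, 128)) == (1, 128)    # per-token-per-group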
@ -56,9 +57,9 @@ def group_broadcast(t, shape):
# (i.e. per-token-per-group)
def scaled_quantize(
x: torch.Tensor,
group_shape: Tuple[int, int],
group_shape: tuple[int, int],
quant_dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
group_shape = _normalize_quant_group_shape(x, group_shape)
assert quant_dtype.is_floating_point, \
"currently `scaled_quantize` only supports floating point dtypes " \
@ -97,9 +98,9 @@ def scaled_quantize(
def scaled_dequantize(
x_q: torch.Tensor,
x_s: torch.Tensor,
group_shape: Optional[Tuple[int, int]] = None,
group_shape: Optional[tuple[int, int]] = None,
out_dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
if group_shape is not None:
group_shape = _normalize_quant_group_shape(x_q, group_shape)
@ -173,8 +174,8 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor,
def is_layer_skipped(
prefix: str,
ignored_layers: List[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
ignored_layers: list[str],
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
) -> bool:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
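A hedged usage sketch: the fused mapping lets a checkpoint-level ignore entry for an individual projection be applied to the fused module that vLLM actually builds (names follow the comment above; the exact handling of partially ignored fused groups is outside this hunk):

fused_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
ignored_layers = ["model.layers.0.self_attn.q_proj"]

# The fused qkv_proj module is treated as skipped only when its constituent
# projections are consistently listed in ignored_layers.
skip = is_layer_skipped("model.layers.0.self_attn.qkv_proj",
                        ignored_layers, fused_mapping)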

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Optional, Union
import torch
@ -81,7 +81,7 @@ def all_close_1d(x: torch.Tensor) -> bool:
def convert_to_channelwise(
weight_scale: torch.Tensor,
logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
# Create channelwise buffer
weight_scale_channel = torch.empty((sum(logical_widths), 1),
dtype=torch.float32,
@ -99,7 +99,7 @@ def convert_to_channelwise(
def requantize_with_max_scale(
weight: torch.Tensor, weight_scale: torch.Tensor,
logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
# Max scale to be used for requantization.
max_w_scale = weight_scale.max()
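The requantization step can be summarized as follows; a hedged illustration of the math rather than the exact implementation: each logical shard is dequantized with its own scale and re-quantized with the shared maximum, so a single per-tensor scale can describe the whole fused weight.

# For each logical shard i occupying rows [start, end) of the fused weight:
#   w_dq = weight[start:end].float() * weight_scale[i]
#   weight[start:end] = (w_dq / max_w_scale).to(fp8_dtype)
# Afterwards max_w_scale alone is kept as the layer's weight scale.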
@ -136,7 +136,7 @@ def maybe_create_device_identity():
def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
out_dtype: torch.dtype, scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
output_shape: List, **kwargs) -> torch.Tensor:
output_shape: list, **kwargs) -> torch.Tensor:
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(qinput,
@ -154,7 +154,7 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
output_shape: list) -> torch.Tensor:
from vllm.platforms.rocm import on_mi250_mi300
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300(
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
@ -177,7 +177,7 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
output_shape: list) -> torch.Tensor:
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
@ -198,7 +198,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
output_shape: list) -> torch.Tensor:
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
# when using it.
# For now it has only been validated on ROCm platform.
@ -228,7 +228,7 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List,
output_shape: list,
**kwargs) -> torch.Tensor:
# Use unfused DQ due to limitations with scaled_mm
@ -384,7 +384,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
assert weight.dtype == torch.float8_e4m3fn
# The bits pattern 10000000(-128) represents zero in e4m3fn
# but NaN in e4m3fnuz. So here we set it to 0.

View File

@ -2,7 +2,7 @@
from functools import cached_property
from importlib.util import find_spec
from typing import Dict, Optional, Tuple
from typing import Optional
import torch
import torch.jit
@ -65,7 +65,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
seeded_seqs: Optional[dict[int, torch.Generator]] = None,
) -> torch.Tensor:
"""Sample token ids using rejection sampling. This accepts or rejects
tokens proposed by the draft model using the probability of each token
@ -161,8 +161,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
seeded_seqs: Optional[Dict[int, torch.Generator]],
) -> Tuple[torch.Tensor, torch.Tensor]:
seeded_seqs: Optional[dict[int, torch.Generator]],
) -> tuple[torch.Tensor, torch.Tensor]:
"""Perform modified rejection sampling on each sequence.
Returns:
@ -194,7 +194,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
return accepted, recovered_token_ids
def _create_uniform_samples(self,
seeded_seqs: Optional[Dict[int,
seeded_seqs: Optional[dict[int,
torch.Generator]],
batch_size: int, k: int,
device: torch.device) -> torch.Tensor:
@ -210,7 +210,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
a seed.
Args:
seeded_seqs : Optional[Dict[int, torch.Generator]]
seeded_seqs : Optional[dict[int, torch.Generator]]
A dictionary mapping indices in the batch to
`torch.Generator` objects. If `None`, all samples are
generated without a seed.
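A hedged sketch of constructing the seeded_seqs argument (batch indices and seeds are illustrative): only the listed rows draw their uniform samples from a dedicated, seeded generator, so their accept/reject decisions are reproducible.

seeded_seqs = {
    0: torch.Generator(device="cuda").manual_seed(1234),
    3: torch.Generator(device="cuda").manual_seed(42),
}
# Rows 0 and 3 of the batch are sampled deterministically; all other rows
# fall back to the default, unseeded RNG stream.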
@ -255,7 +255,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
seeded_seqs: Optional[Dict[int, torch.Generator]],
seeded_seqs: Optional[dict[int, torch.Generator]],
) -> torch.Tensor:
r"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
@ -379,7 +379,7 @@ def _multinomial(
probs: torch.Tensor,
num_samples: int,
k: int,
seeded_seqs: Dict[int, torch.Generator],
seeded_seqs: dict[int, torch.Generator],
) -> torch.Tensor:
if num_samples > 1:

View File

@ -33,7 +33,7 @@ Example models: Qwen (Qwen-VL), MiniCPM-V 2.0
"""
import math
from functools import partial
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Optional, Union
import numpy as np
import torch
@ -69,7 +69,7 @@ def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor,
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_1d_sincos_pos_embed_from_grid(
embed_dim: int, pos: np.ndarray,
version: Tuple[int, int] = (2, 0)) -> torch.Tensor:
version: tuple[int, int] = (2, 0)) -> torch.Tensor:
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,) / (H, W)
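For reference, the standard 1D sin-cos embedding that this helper family is based on can be written compactly; this is a hedged NumPy sketch of the usual MAE-style formulation, not the exact vLLM code path:

import numpy as np

def sincos_1d(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    # Half the channels carry sin, half carry cos, at frequencies
    # 1 / 10000**(2i/embed_dim), as in the original Transformer/MAE code.
    assert embed_dim % 2 == 0
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2) / (embed_dim / 2.0))
    out = np.einsum("m,d->md", pos.reshape(-1), omega)          # (M, D/2)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)   # (M, D)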
@ -96,7 +96,7 @@ def get_1d_sincos_pos_embed_from_grid(
def get_2d_sincos_pos_embed_from_grid(
embed_dim: int, grid: np.ndarray,
version: Tuple[int, int] = (2, 0)) -> torch.Tensor:
version: tuple[int, int] = (2, 0)) -> torch.Tensor:
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
@ -114,9 +114,9 @@ def get_2d_sincos_pos_embed_from_grid(
def get_2d_sincos_pos_embed(
embed_dim: int,
grid_size: Union[int, Tuple[int, int]],
grid_size: Union[int, tuple[int, int]],
cls_token: bool = False,
version: Tuple[int, int] = (2, 0),
version: tuple[int, int] = (2, 0),
) -> torch.Tensor:
"""
grid_size: int of the grid height and width

View File

@ -23,7 +23,7 @@
# limitations under the License.
"""Rotary Positional Embeddings."""
import math
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union
import torch
import torch.nn as nn
@ -140,7 +140,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""A PyTorch-native implementation of forward()."""
if offsets is not None:
positions = positions + offsets
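The core of the native path is the NeoX-style rotation of each (x1, x2) channel pair by a position-dependent angle. A simplified sketch, ignoring the cos/sin cache handling and the GPT-J interleaved layout:

```python
import torch

def apply_neox_rope(x: torch.Tensor, cos: torch.Tensor,
                    sin: torch.Tensor) -> torch.Tensor:
    """x: [num_tokens, num_heads, rotary_dim]; cos/sin: [num_tokens, rotary_dim // 2]."""
    x1, x2 = torch.chunk(x, 2, dim=-1)
    cos = cos.unsqueeze(-2)          # broadcast over the head dimension
    sin = sin.unsqueeze(-2)
    return torch.cat([x1 * cos - x2 * sin,
                      x2 * cos + x1 * sin], dim=-1)
```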
@ -174,7 +174,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
from vllm import _custom_ops as ops
# __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
@ -202,7 +202,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
from vllm._ipex_ops import ipex_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
@ -232,7 +232,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode, apply_rotary_pos_emb)
if offsets is not None:
@ -290,7 +290,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
def _apply_rotary_emb_neuron(
x: torch.Tensor,
@ -406,23 +406,23 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factors: Union[List[float], float],
scaling_factors: Union[list[float], float],
dtype: torch.dtype,
) -> None:
if isinstance(scaling_factors, float):
scaling_factors = [scaling_factors]
self.scaling_factors: List[float] = scaling_factors # noqa
self.scaling_factors: list[float] = scaling_factors # noqa
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
# Lazy initialized.
self._scaling_factor_to_offset: Dict[float, int]
self._scaling_factor_to_offset: dict[float, int]
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.base)
cache_list: List[torch.Tensor] = []
cache_list: list[torch.Tensor] = []
# offsets to the next cache in a tensor.
# Each offset corresponds to the same index in scaling_factors.
offsets: List[int] = []
offsets: list[int] = []
for scaling_factor in self.scaling_factors:
# NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling.
@ -452,7 +452,7 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
return torch.cat(cache_list, dim=0)
@property
def scaling_factor_to_offset(self) -> Dict[float, int]:
def scaling_factor_to_offset(self) -> dict[float, int]:
return self._scaling_factor_to_offset
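Linear RoPE scaling stretches positions rather than frequencies: for a factor f, the cache covers max_position_embeddings * f positions, each divided by f before the outer product with the inverse frequencies. A sketch of one per-factor cache (names illustrative):

```python
import torch

def linear_scaling_cache(inv_freq: torch.Tensor, max_pos: int,
                         scaling_factor: float) -> torch.Tensor:
    # Cover the extended context, then compress positions back into the
    # originally trained range by dividing by the scaling factor.
    t = torch.arange(int(max_pos * scaling_factor), dtype=torch.float32)
    t = t / scaling_factor
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    return torch.cat([freqs.cos(), freqs.sin()], dim=-1)
```

The per-factor caches are concatenated along dim 0, and scaling_factor_to_offset records where each factor's slice begins.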
@ -512,7 +512,7 @@ def _yarn_find_correction_range(
high_rot: int,
dim: int,
base: float = 10000,
max_position_embeddings: int = 2048) -> Tuple[int, int]:
max_position_embeddings: int = 2048) -> tuple[int, int]:
low = math.floor(
_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(
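The helper being floored and ceiled here is the standard YaRN correction-dimension formula, which solves for the (fractional) rotary dimension whose wavelength yields a target number of rotations over the training context. A sketch matching the formula from the YaRN paper (defaults are illustrative):

```python
import math

def yarn_find_correction_dim(num_rotations: float, dim: int,
                             base: float = 10000,
                             max_position_embeddings: int = 2048) -> float:
    # Solve max_pos / (2*pi * base^(2d/dim)) == num_rotations for d.
    return (dim * math.log(max_position_embeddings /
                           (num_rotations * 2 * math.pi))) / (2 * math.log(base))
```

Clamping the resulting [low, high] range to [0, dim - 1] gives the band of dimensions over which the frequencies are smoothly blended between interpolation and extrapolation.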
@ -613,8 +613,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
base: int,
is_neox_style: bool,
dtype: torch.dtype,
short_factor: List[float],
long_factor: List[float],
short_factor: list[float],
long_factor: list[float],
short_mscale: Optional[float] = None,
long_mscale: Optional[float] = None,
):
@ -662,7 +662,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
long_short_cache,
persistent=False)
def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
def _compute_inv_freq(self, rescale_factors: list[float]) -> torch.Tensor:
rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)))
@ -671,7 +671,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
def _compute_cos_sin_cache(
self,
max_position_embeddings: int,
rescale_factors: List[float],
rescale_factors: list[float],
mscale: float,
) -> torch.Tensor:
inv_freq = self._compute_inv_freq(rescale_factors)
@ -688,7 +688,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert key is not None
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)
@ -799,7 +799,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
assert key is not None
query_rot = query[..., :self.rotary_dim]
@ -930,7 +930,7 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
self,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert key is not None
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
query_ = torch.view_as_complex(query.float().reshape(
@ -958,7 +958,7 @@ class MRotaryEmbedding(RotaryEmbedding):
base: int,
is_neox_style: bool,
dtype: torch.dtype,
mrope_section: Optional[List[int]] = None,
mrope_section: Optional[list[int]] = None,
) -> None:
# In Qwen2.5-VL, the maximum index value is related to the duration of
# the input video. We enlarge max_position_embeddings to 4 times to get
@ -976,7 +976,7 @@ class MRotaryEmbedding(RotaryEmbedding):
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward().
Args:
@ -1024,16 +1024,16 @@ class MRotaryEmbedding(RotaryEmbedding):
@classmethod
def get_input_positions(
cls,
input_tokens: List[int],
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Optional[Union[List[List[int]], torch.Tensor]],
video_grid_thw: Optional[Union[List[List[int]], torch.Tensor]],
second_per_grid_ts: Optional[List[float]],
image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
second_per_grid_ts: Optional[list[float]],
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> Tuple[List[List[int]], int]:
) -> tuple[list[list[int]], int]:
"""Get mrope input positions and delta value."""
image_grid_thw = [] if image_grid_thw is None else image_grid_thw
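Conceptually, M-RoPE gives every token a 3-D position (temporal, height, width): for plain text all three components are equal and advance by one per token, while vision tokens take their indices from the (t, h, w) grid, and text that follows resumes at max(previous positions) + 1. A toy illustration of the returned position lists (values illustrative, not produced by this exact call):

```python
# Prompt: 3 text tokens, one 1x2x2 image (4 patches after merging), 1 text token.
text_pos  = [[0, 1, 2], [0, 1, 2], [0, 1, 2]]   # t == h == w for text
image_pos = [[3, 3, 3, 3],                       # temporal index
             [3, 3, 4, 4],                       # row index
             [3, 4, 3, 4]]                       # column index
tail_pos  = [[5], [5], [5]]                      # next text token: max(4) + 1
```

The returned delta (roughly, final position + 1 minus the number of tokens seen) lets later decode steps continue the position sequence without recomputing the multimodal layout.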
@ -1059,16 +1059,16 @@ class MRotaryEmbedding(RotaryEmbedding):
@classmethod
def get_input_positions_tensor(
cls,
input_tokens: List[int],
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
second_per_grid_ts: List[float],
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
second_per_grid_ts: list[float],
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> Tuple[torch.Tensor, int]:
) -> tuple[torch.Tensor, int]:
from vllm.transformers_utils.config import thinker_uses_mrope
if thinker_uses_mrope(hf_config):
return cls._omni_get_input_positions_tensor(
@ -1096,14 +1096,14 @@ class MRotaryEmbedding(RotaryEmbedding):
@classmethod
def _vl_get_input_positions_tensor(
cls,
input_tokens: List[int],
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
second_per_grid_ts: List[float],
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
second_per_grid_ts: list[float],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> Tuple[torch.Tensor, int]:
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value."""
image_token_id = hf_config.image_token_id
@ -1195,16 +1195,16 @@ class MRotaryEmbedding(RotaryEmbedding):
@classmethod
def _omni_get_input_positions_tensor(
cls,
input_tokens: List[int],
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
second_per_grid_ts: Optional[List[float]] = None,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
second_per_grid_ts: Optional[list[float]] = None,
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> Tuple[torch.Tensor, int]:
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value (Qwen2.5-Omni version).
Differences from MRotaryEmbedding:
@ -1329,7 +1329,7 @@ class MRotaryEmbedding(RotaryEmbedding):
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
pure_audio_len = place_num - 2
added_audio_len = 0
audio_llm_pos_ids_list: List[torch.Tensor] = []
audio_llm_pos_ids_list: list[torch.Tensor] = []
for t_chunk in t_index_split_chunk:
vision_ntoken_per_chunk = len(
t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
@ -1382,7 +1382,7 @@ class MRotaryEmbedding(RotaryEmbedding):
start_idx: int,
vision_idx: int,
spatial_merge_size: int,
t_index: List[int],
t_index: list[int],
grid_hs: torch.Tensor,
grid_ws: torch.Tensor,
) -> torch.Tensor:
@ -1402,8 +1402,8 @@ class MRotaryEmbedding(RotaryEmbedding):
@staticmethod
def _split_list_into_ranges(lst: torch.Tensor,
interval: int) -> List[List[int]]:
ranges: List[List[int]] = [[]
interval: int) -> list[list[int]]:
ranges: list[list[int]] = [[]
for _ in range((max(lst) // interval) + 1)]
for num in lst:
index = num // interval
@ -1415,7 +1415,7 @@ class MRotaryEmbedding(RotaryEmbedding):
mrope_position_delta: int,
context_len: int,
seq_len: int,
) -> List[List[int]]:
) -> list[list[int]]:
return [
list(
range(context_len + mrope_position_delta,
@ -1438,9 +1438,9 @@ class MRotaryEmbedding(RotaryEmbedding):
cls,
thinker_config: PretrainedConfig,
audio_len: int,
video_grid_thw: Union[List[int], torch.Tensor],
video_grid_thw: Union[list[int], torch.Tensor],
video_second_per_grid_t: float,
) -> List[int]:
) -> list[int]:
"""Get video prompt updates when `use_audio_in_video` is True.
In this case, audio and vision update ids will be split into
@ -1593,7 +1593,7 @@ class DualChunkRotaryEmbedding(CustomOp):
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
@ -1664,7 +1664,7 @@ class DualChunkRotaryEmbedding(CustomOp):
return s
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
_ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
def get_rope(
@ -1673,10 +1673,10 @@ def get_rope(
max_position: int,
base: int,
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0,
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
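A typical call site looks roughly like the following; the keyword values, and especially the rope_scaling keys, are illustrative and depend on the model's config rather than on this function:

```python
import torch

# Hypothetical usage sketch; values are illustrative.
rope = get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=4096,
    base=10000,
    is_neox_style=True,
    rope_scaling={"rope_type": "linear", "factor": 2.0},  # key names vary by model
    dtype=torch.bfloat16,
)
query, key = rope(positions, query, key)  # positions: [num_tokens]
```

Because _ROPE_DICT is keyed on the full argument tuple, repeated calls with the same configuration reuse the cached module instead of rebuilding the cos/sin cache.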

View File

@ -2,10 +2,11 @@
"""A layer that samples the next tokens from the model's outputs."""
import itertools
import warnings
from collections.abc import Iterator
from dataclasses import dataclass
from importlib.util import find_spec
from math import inf
from typing import Dict, Iterator, List, Optional, Tuple, Union
from typing import Optional, Union
import msgspec
import torch
@ -42,14 +43,14 @@ def get_sampler() -> torch.nn.Module:
# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]]
SampleResultType = list[tuple[list[int], list[int]]]
# Types of temporary data structures used for
# computing sample_result
SampleMetadataType = Dict[SamplingType, Tuple[List[int],
List[SequenceGroupToSample]]]
MultinomialSamplesType = Dict[SamplingType, torch.Tensor]
SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]]
SampleMetadataType = dict[SamplingType, tuple[list[int],
list[SequenceGroupToSample]]]
MultinomialSamplesType = dict[SamplingType, torch.Tensor]
SampleResultsDictType = dict[int, tuple[list[int], list[int]]]
# Encapsulates temporary data structures for computing
@ -76,7 +77,7 @@ class SampleResultArgsType:
MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType]
# Abbreviation of the _sample() return type
SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]]
SampleReturnType = tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]]
class SamplerOutput(
@ -90,7 +91,7 @@ class SamplerOutput(
also has optional fields for device tensors.
"""
outputs: List[CompletionSequenceGroupOutput]
outputs: list[CompletionSequenceGroupOutput]
# On-device tensor containing probabilities of each token.
sampled_token_probs: Optional[torch.Tensor] = None
@ -350,7 +351,7 @@ def _apply_min_tokens_penalty(
have not been generated yet
"""
# list of indices in logits that will be set to -inf
logits_to_penalize: List[Tuple[int, int]] = []
logits_to_penalize: list[tuple[int, int]] = []
logits_applied = 0
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
@ -366,7 +367,7 @@ def _apply_min_tokens_penalty(
min_tokens = sampling_params.min_tokens
token_ids_to_penalize = sampling_params.all_stop_token_ids
if min_tokens > 0 and token_ids_to_penalize:
seqs_to_penalize: List[int] = []
seqs_to_penalize: list[int] = []
for j, seq_id in enumerate(seq_ids):
seq_data = seq_group.seq_data[seq_id]
if len(seq_data.output_token_ids_array) < min_tokens:
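The penalty itself boils down to forcing the logits of every stop token to -inf for sequences that have not yet produced min_tokens outputs, so they cannot terminate early. A condensed sketch (names illustrative):

```python
import torch

def apply_min_tokens_penalty(logits: torch.Tensor,
                             rows_to_penalize: list[int],
                             stop_token_ids: list[int]) -> torch.Tensor:
    # logits: [num_seqs, vocab_size]
    if rows_to_penalize and stop_token_ids:
        rows = torch.tensor(rows_to_penalize).unsqueeze(1)  # [R, 1]
        cols = torch.tensor(stop_token_ids).unsqueeze(0)    # [1, S]
        logits[rows, cols] = -float("inf")                   # broadcast index
    return logits
```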
@ -436,7 +437,7 @@ def _apply_min_p(
def _greedy_sample(
selected_seq_groups: List[SequenceGroupToSample],
selected_seq_groups: list[SequenceGroupToSample],
samples: torch.Tensor,
) -> SampleResultType:
"""Run greedy sampling on a given samples.
@ -471,7 +472,7 @@ def _greedy_sample(
def _random_sample(
selected_seq_groups: List[SequenceGroupToSample],
selected_seq_groups: list[SequenceGroupToSample],
random_samples: torch.Tensor,
) -> SampleResultType:
"""Run random sampling on a given samples.
@ -522,7 +523,7 @@ def _random_sample(
def _multinomial(
probs: torch.Tensor,
num_samples: int,
seq_groups: Optional[List[SequenceGroupToSample]] = None,
seq_groups: Optional[list[SequenceGroupToSample]] = None,
) -> torch.Tensor:
if num_samples > 1:
probs = probs.repeat_interleave(num_samples, dim=0)
@ -543,7 +544,7 @@ def _multinomial(
def _top_k_top_p_multinomial_with_flashinfer(
probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
num_samples: int, seq_groups: Optional[list[SequenceGroupToSample]]):
max_top_k_round = 32
if num_samples > 1:
probs = probs.repeat_interleave(num_samples, dim=0)
@ -648,7 +649,7 @@ def _sample_with_torch(
tensors required for Pythonization
'''
categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
categorized_seq_group_ids: dict[SamplingType, list[int]] = {
t: []
for t in SamplingType
}
@ -812,7 +813,7 @@ def get_logprobs(
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sample_results: SampleResultType,
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
) -> tuple[list[Optional[PromptLogprobs]], list[SampleLogprobs]]:
"""Return sample logprobs and prompt logprobs.
The logic consists of 3 parts.
@ -841,9 +842,9 @@ def get_logprobs(
"""
# The index of query token to calculate logprobs. It includes both
# prompt and sample logprob indices.
query_indices: List[int] = []
query_indices: list[int] = []
# The next token ids to get the logprob value from.
next_token_ids: List[int] = []
next_token_ids: list[int] = []
# The largest requested number of logprobs. We find logprobs as many as the
# largest num logprobs in this API. If every logprobs is None, it will be
# set to -1.
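The per-token logprob and its rank can be read straight off the logprobs tensor; a minimal sketch of that core step, ignoring the batching over sequence groups (names illustrative):

```python
import torch

def logprob_and_rank(logprobs: torch.Tensor,
                     token_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # logprobs: [num_queries, vocab_size]; token_ids: [num_queries]
    selected = logprobs.gather(1, token_ids.unsqueeze(1)).squeeze(1)
    # Rank = 1 + number of vocab entries with a strictly higher logprob.
    ranks = (logprobs > selected.unsqueeze(1)).sum(dim=1) + 1
    return selected, ranks
```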
@ -925,8 +926,8 @@ def get_logprobs(
ranks = ranks.to('cpu')
# Find prompt/sample logprobs.
prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
sample_logprobs_per_seq_group: List[SampleLogprobs] = []
prompt_logprobs_per_seq_group: list[Optional[PromptLogprobs]] = []
sample_logprobs_per_seq_group: list[SampleLogprobs] = []
top_logprob_idx = 0
selected_logprobs_idx = 0
@ -977,7 +978,7 @@ def _get_prompt_logprob_if_needed(
for idx, token_id in enumerate(next_prompt_tokens):
# Calculate the prompt logprob of the real prompt tokens.
# {token_id: (logprob, rank_from_vocab)}
prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
prompt_logprobs_dict: dict[int, tuple[float, int]] = {
token_id: (selected_logprob_items[idx], rank_items[idx])
}
@ -1009,7 +1010,7 @@ def _get_prompt_logprob_if_needed(
def _get_sampled_logprob_if_needed(
seq_group: SequenceGroupToSample,
sample_result: Tuple[List[int], List[int]],
sample_result: tuple[list[int], list[int]],
selected_logprobs: torch.Tensor,
ranks: torch.Tensor,
top_token_ids: torch.Tensor,
@ -1130,9 +1131,9 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
def _build_sampler_output(
maybe_deferred_sample_results: MaybeDeferredSampleResultType,
sampling_metadata: SamplingMetadata,
prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
sample_logprobs: Optional[List[SampleLogprobs]],
on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
prompt_logprobs: Optional[list[Optional[PromptLogprobs]]],
sample_logprobs: Optional[list[SampleLogprobs]],
on_device_tensors: Optional[tuple[torch.Tensor, torch.Tensor,
torch.Tensor]],
skip_sampler_cpu_output: bool = False,
) -> SamplerOutput:
@ -1144,7 +1145,7 @@ def _build_sampler_output(
allows post-processing without copies to CPU/serialization, e.g. in
speculative decoding rejection sampling.
"""
sampler_output: List[CompletionSequenceGroupOutput] = []
sampler_output: list[CompletionSequenceGroupOutput] = []
if skip_sampler_cpu_output:
assert isinstance(maybe_deferred_sample_results, SampleResultArgsType)
@ -1166,7 +1167,7 @@ def _build_sampler_output(
prompt_logprobs, sample_logprobs):
seq_ids = seq_group.seq_ids
next_token_ids, parent_ids = sample_result
seq_outputs: List[SequenceOutput] = []
seq_outputs: list[SequenceOutput] = []
for parent_id, next_token_id, logprobs in zip(
parent_ids, next_token_ids, group_sample_logprobs):
seq_outputs.append(

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from abc import abstractmethod
from typing import Dict, Optional, Union
from typing import Optional, Union
import torch
import torch.jit
@ -253,6 +253,6 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
seeded_seqs: Optional[dict[int, torch.Generator]] = None,
) -> torch.Tensor:
raise NotImplementedError

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""Utility methods for model layers."""
from typing import Callable, Optional, Tuple
from typing import Callable, Optional
import torch
@ -13,7 +13,7 @@ def get_token_bin_counts_and_mask(
tokens: torch.Tensor,
vocab_size: int,
num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Compute the bin counts for the tokens.
# vocab_size + 1 for padding.
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
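These bin counts drive the repetition, presence, and frequency penalties: each sequence gets a count of how often every vocab id has appeared, plus a boolean presence mask. A sketch of the usual scatter_add formulation (assuming padding uses id vocab_size, hence the extra bin):

```python
import torch

def token_bin_counts_and_mask(tokens: torch.Tensor, vocab_size: int,
                              num_seqs: int) -> tuple[torch.Tensor, torch.Tensor]:
    # tokens: [num_seqs, max_len], padded with vocab_size.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long, device=tokens.device)
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    bin_counts = bin_counts[:, :vocab_size]    # drop the padding bin
    return bin_counts, bin_counts > 0          # counts and presence mask
```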

View File

@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple
from typing import Optional
import torch
import torch.nn.functional as F
@ -25,7 +26,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""Create weights for embedding layer."""
@ -141,7 +142,7 @@ def get_masked_input_and_mask(
input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
added_vocab_end_index: int) -> tuple[torch.Tensor, torch.Tensor]:
# torch.compile will fuse all of the pointwise ops below
# into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & (
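In plain terms, the fused ops remap token ids that fall in this rank's original-vocab or added-vocab shard to local row indices and collapse everything else to row 0, returning a mask so those rows can be zeroed after the embedding lookup. A simplified sketch of the same logic (names illustrative):

```python
import torch

def mask_and_localize(input_: torch.Tensor,
                      org_start: int, org_end: int, num_org_padding: int,
                      added_start: int, added_end: int) -> tuple[torch.Tensor, torch.Tensor]:
    org_mask = (input_ >= org_start) & (input_ < org_end)
    added_mask = (input_ >= added_start) & (input_ < added_end)
    valid = org_mask | added_mask
    # Local layout: original-vocab rows, then padding, then added-vocab rows.
    local = (input_ - org_start) * org_mask + \
            (input_ - added_start + (org_end - org_start) + num_org_padding) * added_mask
    return local * valid, ~valid   # out-of-shard ids collapse to row 0
```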
@ -298,7 +299,7 @@ class VocabParallelEmbedding(torch.nn.Module):
org_vocab_start_index, org_vocab_end_index,
added_vocab_start_index, added_vocab_end_index)
def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
def get_sharded_to_full_mapping(self) -> Optional[list[int]]:
"""Get a mapping that can be used to reindex the gathered
logits for sampling.
@ -312,9 +313,9 @@ class VocabParallelEmbedding(torch.nn.Module):
if self.tp_size < 2:
return None
base_embeddings: List[int] = []
added_embeddings: List[int] = []
padding: List[int] = []
base_embeddings: list[int] = []
added_embeddings: list[int] = []
padding: list[int] = []
for tp_rank in range(self.tp_size):
shard_indices = self._get_indices(self.num_embeddings_padded,
self.org_vocab_size_padded,