Update deprecated type hinting in model_executor/layers (#18056)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

parent 906f0598fc
commit 6223dd8114
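The change is mechanical throughout: the deprecated `typing` aliases (`Dict`, `List`, `Tuple`, `Set`, `Type`), flagged by ruff rules UP006 and UP035, are replaced with the built-in generics standardized in PEP 585 (Python 3.9+), and the first hunk removes the per-file ruff exclusion so those rules are now enforced for vllm/model_executor/layers. A minimal before/after sketch of the pattern (the first_match function is illustrative only, not taken from the diff):

# Before: deprecated typing aliases (what ruff flags as UP006/UP035).
# first_match is a hypothetical example, not code from this commit.
from typing import Dict, List, Optional, Tuple

def first_match(config: Dict[str, int],
                keys: List[str]) -> Optional[Tuple[str, int]]:
    for key in keys:
        if key in config:
            return key, config[key]
    return None

# After: built-in generics (PEP 585, Python >= 3.9). Note that Optional
# still comes from typing, which is why it survives every import rewrite
# in this diff.
from typing import Optional

def first_match(config: dict[str, int],
                keys: list[str]) -> Optional[tuple[str, int]]:
    for key in keys:
        if key in config:
            return key, config[key]
    return None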
@@ -80,7 +80,6 @@ exclude = [
 "vllm/engine/**/*.py" = ["UP006", "UP035"]
 "vllm/executor/**/*.py" = ["UP006", "UP035"]
 "vllm/lora/**/*.py" = ["UP006", "UP035"]
-"vllm/model_executor/layers/**/*.py" = ["UP006", "UP035"]
 "vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
 "vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
 "vllm/platforms/**/*.py" = ["UP006", "UP035"]
@@ -1,13 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0

 from contextlib import contextmanager
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.triton_utils import HAS_TRITON

-_config: Optional[Dict[str, Any]] = None
+_config: Optional[dict[str, Any]] = None


 @contextmanager

@@ -19,7 +19,7 @@ def override_config(config):
    _config = old_config


-def get_config() -> Optional[Dict[str, Any]]:
+def get_config() -> Optional[dict[str, Any]]:
     return _config

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import importlib.util
-from typing import Optional, Tuple
+from typing import Optional

 import torch

@@ -61,7 +61,7 @@ def _moe_permute(
     global_num_experts: int,
     expert_map: Optional[torch.Tensor],
     block_m: int,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
+) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
            Optional[torch.Tensor]]:
     """
     Determine the sorted_token_ids, expert_ids for the given problem size.
@@ -3,7 +3,7 @@
 import functools
 import json
 import os
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Optional

 import torch

@@ -472,14 +472,14 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
                             mul_routed_weight: bool,
                             top_k: int,
-                            config: Dict[str, Any],
+                            config: dict[str, Any],
                             compute_type: tl.dtype,
                             use_fp8_w8a8: bool,
                             use_int8_w8a8: bool,
                             use_int8_w8a16: bool,
                             use_int4_w4a16: bool,
                             per_channel_quant: bool,
-                            block_shape: Optional[List[int]] = None) -> None:
+                            block_shape: Optional[list[int]] = None) -> None:
     assert topk_weights is not None or not mul_routed_weight
     assert topk_weights is None or topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -622,7 +622,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
 def get_config_file_name(E: int,
                          N: int,
                          dtype: Optional[str],
-                         block_shape: Optional[List[int]] = None) -> str:
+                         block_shape: Optional[list[int]] = None) -> str:
     device_name = current_platform.get_device_name().replace(" ", "_")
     dtype_selector = "" if not dtype else f",dtype={dtype}"
     block_shape_selector = ("" if not block_shape or not all(block_shape) else

@@ -638,7 +638,7 @@ def get_moe_configs(
     dtype: Optional[str],
     block_n: Optional[int] = None,
     block_k: Optional[int] = None,
-) -> Optional[Dict[int, Any]]:
+) -> Optional[dict[int, Any]]:
     """
     Return optimized configurations for the fused MoE kernel.

@@ -670,7 +670,7 @@ def get_moe_configs(
     return None


-def get_moe_wna16_block_config(config: Dict[str,
+def get_moe_wna16_block_config(config: dict[str,
                                             int], use_moe_wna16_cuda: bool,
                                num_valid_tokens: int, size_k: int, size_n: int,
                                num_experts: int, group_size: int,
@@ -742,8 +742,8 @@ def get_default_config(
     topk: int,
     dtype: Optional[str],
     is_marlin: bool,
-    block_shape: Optional[List[int]] = None,
-) -> Dict[str, int]:
+    block_shape: Optional[list[int]] = None,
+) -> dict[str, int]:
     if dtype == "fp8_w8a8" and block_shape is not None:
         # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
         # BLOCK_SIZE_K must be divisible by block_shape[1]

@@ -795,13 +795,13 @@ def get_default_config(


 def try_get_optimal_moe_config(
-    w1_shape: Tuple[int, ...],
-    w2_shape: Tuple[int, ...],
+    w1_shape: tuple[int, ...],
+    w2_shape: tuple[int, ...],
     top_k: int,
     dtype: Optional[str],
     M: int,
     is_marlin: bool = False,
-    block_shape: Optional[List[int]] = None,
+    block_shape: Optional[list[int]] = None,
 ):
     from vllm.model_executor.layers.fused_moe import get_config
     override_config = get_config()
@@ -855,7 +855,7 @@ def fused_topk(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     assert hidden_states.shape[0] == gating_output.shape[0], (
         "Number of tokens mismatch")

@@ -895,7 +895,7 @@ def grouped_topk(
     topk_group: int = 0,
     scoring_func: str = "softmax",
     e_score_correction_bias: Optional[torch.Tensor] = None
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:

     assert hidden_states.shape[0] == gating_output.shape[0], (
         "Number of tokens mismatch")
@@ -982,7 +982,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           w2_zp: Optional[torch.Tensor] = None,
                           a1_scale: Optional[torch.Tensor] = None,
                           a2_scale: Optional[torch.Tensor] = None,
-                          block_shape: Optional[List[int]] = None) -> None:
+                          block_shape: Optional[list[int]] = None) -> None:
     fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
                        activation, apply_router_weight_on_input, use_fp8_w8a8,
                        use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,

@@ -1012,7 +1012,7 @@ def inplace_fused_experts_fake(
         w2_zp: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[List[int]] = None) -> None:
+        block_shape: Optional[list[int]] = None) -> None:
     pass


@@ -1046,7 +1046,7 @@ def outplace_fused_experts(
         w2_zp: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[List[int]] = None) -> torch.Tensor:
+        block_shape: Optional[list[int]] = None) -> torch.Tensor:
     return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
                               False, activation, apply_router_weight_on_input,
                               use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,

@@ -1076,7 +1076,7 @@ def outplace_fused_experts_fake(
         w2_zp: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[List[int]] = None) -> torch.Tensor:
+        block_shape: Optional[list[int]] = None) -> torch.Tensor:
     return torch.empty_like(hidden_states)

@@ -1129,7 +1129,7 @@ def fused_experts(hidden_states: torch.Tensor,
                   w2_zp: Optional[torch.Tensor] = None,
                   a1_scale: Optional[torch.Tensor] = None,
                   a2_scale: Optional[torch.Tensor] = None,
-                  block_shape: Optional[List[int]] = None,
+                  block_shape: Optional[list[int]] = None,
                   allow_deep_gemm: bool = False) -> torch.Tensor:
     if (allow_deep_gemm and use_fp8_w8a8
             and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):

@@ -1184,8 +1184,8 @@ def moe_kernel_prepare_input(
     use_int8_w8a16: bool,
     use_int4_w4a16: bool,
     per_channel_quant: bool,
-    block_shape: Optional[List[int]] = None,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    block_shape: Optional[list[int]] = None,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
     if use_fp8_w8a8:
         assert B_scale is not None
         if block_shape is None:

@@ -1248,7 +1248,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                        w2_zp: Optional[torch.Tensor] = None,
                        a1_scale: Optional[torch.Tensor] = None,
                        a2_scale: Optional[torch.Tensor] = None,
-                       block_shape: Optional[List[int]] = None):
+                       block_shape: Optional[list[int]] = None):
     # Check constraints.
     if use_int4_w4a16:
         assert hidden_states.shape[1] // 2 == w1.shape[
@@ -1452,7 +1452,7 @@ def fused_moe(
     w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
-    block_shape: Optional[List[int]] = None,
+    block_shape: Optional[list[int]] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of

@@ -1497,7 +1497,7 @@ def fused_moe(
         a1.
     - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
         a2.
-    - block_shape: (Optional[List[int]]): Optional block size for block-wise
+    - block_shape: (Optional[list[int]]): Optional block size for block-wise
         quantization.

     Returns:
@@ -2,7 +2,7 @@

 from abc import abstractmethod
 from enum import Enum
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, Optional

 import torch
 import torch.nn.functional as F

@@ -326,7 +326,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):

 def determine_expert_map(
         ep_size: int, ep_rank: int,
-        global_num_experts: int) -> Tuple[int, Optional[torch.Tensor]]:
+        global_num_experts: int) -> tuple[int, Optional[torch.Tensor]]:
     """
     Calculates how many experts should be assigned to each rank for EP and
     creates a mapping from global to local expert index. Experts are

@@ -338,7 +338,7 @@ def determine_expert_map(
         global_num_experts (int): The total number of experts in the model.

     Returns:
-        Tuple[int, Optional[torch.Tensor]]: A tuple containing:
+        tuple[int, Optional[torch.Tensor]]: A tuple containing:
         - local_num_experts (int): The number of experts assigned
             to the current rank.
         - expert_map (Optional[torch.Tensor]): A tensor of shape

@@ -909,7 +909,7 @@ class FusedMoE(torch.nn.Module):
     def make_expert_params_mapping(
             cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
             ckpt_up_proj_name: str,
-            num_experts: int) -> List[Tuple[str, str, int, str]]:
+            num_experts: int) -> list[tuple[str, str, int, str]]:

         return [
             # (param_name, weight_name, expert_id, shard_id)
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Optional, Tuple
+from typing import Optional

 import torch

@@ -153,7 +153,7 @@ def moe_align_block_size(
     num_experts: int,
     expert_map: Optional[torch.Tensor] = None,
     pad_sorted_ids: bool = False
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block
     size for matrix multiplication.

@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Optional, Tuple
+from typing import Optional

 import torch

@@ -15,7 +15,7 @@ def moe_permute(
     expert_map: Optional[torch.Tensor] = None,
     align_block_size: Optional[int] = None,
     fill_invalid_expert: int = -1
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     This function expands and permutes activation to gather uncontinuous tokens
     for each expert.
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 from functools import cache
-from typing import List, Optional, Tuple
+from typing import Optional

 import torch

@@ -97,7 +97,7 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
         w1_scale: torch.Tensor,
         w2_scale: torch.Tensor,
         a1_scale: torch.Tensor,
-        block_shape: List[int],
+        block_shape: list[int],
         smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
     from aiter import fmoe_fp8_blockscale_g1u1
     from aiter.fused_moe_bf16_asm import moe_sorting_ck

@@ -142,7 +142,7 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake(
         w1_scale: torch.Tensor,
         w2_scale: torch.Tensor,
         a1_scale: torch.Tensor,
-        block_shape: List[int],
+        block_shape: list[int],
         smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:

     return torch.empty_like(a1, dtype=hidden_states_dtype)

@@ -280,7 +280,7 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
                              w2_zp: Optional[torch.Tensor] = None,
                              a1_scale: Optional[torch.Tensor] = None,
                              a2_scale: Optional[torch.Tensor] = None,
-                             block_shape: Optional[List[int]] = None,
+                             block_shape: Optional[list[int]] = None,
                              allow_deep_gemm: bool = False) -> torch.Tensor:

     from vllm.model_executor.layers.quantization.utils.fp8_utils import (

@@ -372,14 +372,14 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
                             topk_indices: torch.Tensor,
                             token_expert_indices: torch.Tensor,
                             gating_output: torch.Tensor,
-                            renormalize: bool) -> Tuple[torch.Tensor, ...]:
+                            renormalize: bool) -> tuple[torch.Tensor, ...]:
     torch.ops.vllm.rocm_aiter_topk_softmax(topk_weights, topk_indices,
                                            token_expert_indices, gating_output,
                                            renormalize)
     return topk_weights, topk_indices


-def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
     """
     Applies shuffle_weight function from AITER to each
     input tensor and returns them.

@@ -395,7 +395,7 @@ def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:


 def expand_weights(*tensors: torch.Tensor,
-                   expansion_dims: list[int]) -> Tuple[torch.Tensor, ...]:
+                   expansion_dims: list[int]) -> tuple[torch.Tensor, ...]:
     """
     Expands the dimensions of input tensors.

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 from math import prod
-from typing import List, Optional, Tuple
+from typing import Optional

 import torch

@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
 from vllm.utils import cdiv


-def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor:
+def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
     """
     Shrink the given tensor and apply the given view to it. This is
     used to resize the intermediate fused_moe caches.

@@ -22,8 +22,8 @@ def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor:
 def _fp8_quantize(
     A: torch.Tensor,
     A_scale: Optional[torch.Tensor],
-    block_shape: Optional[List[int]],
-) -> Tuple[torch.Tensor, torch.Tensor]:
+    block_shape: Optional[list[int]],
+) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Perform fp8 quantization on the inputs. If a block_shape
     is provided, the output will be blocked.
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """Custom normalization layers."""
-from typing import Optional, Tuple, Union
+from typing import Optional, Union

 import torch
 import torch.nn as nn

@@ -31,7 +31,7 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor,

 def fused_add_rms_norm(
         x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
-        variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
+        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
     from vllm import _custom_ops as ops
     ops.fused_add_rms_norm(
         x,

@@ -57,7 +57,7 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,

 def rocm_aiter_fused_add_rms_norm(
         x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
-        variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
+        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:

     import aiter as rocm_aiter

@@ -119,7 +119,7 @@ class RMSNorm(CustomOp):
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
         orig_dtype = x.dtype
         x = x.to(torch.float32)

@@ -157,7 +157,7 @@ class RMSNorm(CustomOp):
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         if self.variance_size_override is not None:
             return self.forward_native(x, residual)

@@ -174,7 +174,7 @@ class RMSNorm(CustomOp):
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         from vllm_hpu_extension.kernels import rms_norm
         HPUFusedRMSNorm = rms_norm()
         if HPUFusedRMSNorm is None:

@@ -194,7 +194,7 @@ class RMSNorm(CustomOp):
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         if self.variance_size_override is not None:
             return self.forward_native(x, residual)

@@ -244,7 +244,7 @@ class GemmaRMSNorm(CustomOp):
         variance_epsilon: float,
         x: torch.Tensor,
         residual: Optional[torch.Tensor],
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
         orig_dtype = x.dtype
         if residual is not None:

@@ -267,7 +267,7 @@ class GemmaRMSNorm(CustomOp):
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
         return self.forward_static(self.weight.data, self.variance_epsilon, x,
                                    residual)

@@ -276,7 +276,7 @@ class GemmaRMSNorm(CustomOp):
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         if torch.compiler.is_compiling():
             return self.forward_native(x, residual)

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import List, Optional, Tuple, Union
+from typing import Optional, Union

 import torch
 from torch import nn

@@ -104,7 +104,7 @@ class Mixer2RMSNormGated(CustomOp):
         self,
         x: torch.Tensor,
         gate: torch.Tensor,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:

         if self.tp_size > 1 or self.n_groups != 1:
             return self.forward_native(x, gate)

@@ -136,7 +136,7 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int):


 def mamba_v2_sharded_weight_loader(
-    shard_spec: List[Tuple[int, int, float]],
+    shard_spec: list[tuple[int, int, float]],
     tp_size: int,
     tp_rank: int,
 ) -> LoaderFunction:
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 from enum import IntEnum
-from typing import List, Optional, Union
+from typing import Optional, Union

 import torch
 import torch.nn as nn

@@ -46,7 +46,7 @@ class SimplePooler(nn.Module):
         normalize: bool,
         softmax: bool,
         step_tag_id: Optional[int] = None,
-        returned_token_ids: Optional[List[int]] = None,
+        returned_token_ids: Optional[list[int]] = None,
     ) -> "SimplePooler":
         if pooling_type == PoolingType.LAST:
             assert step_tag_id is None and returned_token_ids is None

@@ -174,7 +174,7 @@ class StepPool(SimplePooler):
         normalize: bool,
         softmax: bool,
         step_tag_id: Optional[int] = None,
-        returned_token_ids: Optional[List[int]] = None,
+        returned_token_ids: Optional[list[int]] = None,
     ):
         super().__init__(normalize=normalize, softmax=softmax)

@@ -259,7 +259,7 @@ class Pooler(nn.Module):
         normalize: bool,
         softmax: bool,
         step_tag_id: Optional[int] = None,
-        returned_token_ids: Optional[List[int]] = None,
+        returned_token_ids: Optional[list[int]] = None,
     ) -> SimplePooler:
         return SimplePooler.from_pooling_type(
             pooling_type=PoolingType[pooler_config.pooling_type]
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Literal, Type, get_args
+from typing import Literal, get_args

 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)

@@ -76,7 +76,7 @@ def register_quantization_config(quantization: str):
     return _wrapper


-def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
+def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
     if quantization not in QUANTIZATION_METHODS:
         raise ValueError(f"Invalid quantization method: {quantization}")

@@ -110,7 +110,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     from .torchao import TorchAOConfig
     from .tpu_int8 import Int8TpuConfig

-    method_to_config: dict[str, Type[QuantizationConfig]] = {
+    method_to_config: dict[str, type[QuantizationConfig]] = {
         "aqlm": AQLMConfig,
         "awq": AWQConfig,
         "deepspeedfp": DeepSpeedFPConfig,
@@ -4,7 +4,7 @@
 # and https://arxiv.org/pdf/2401.06118.pdf

 import math
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 import torch
 import torch.nn.functional as F

@@ -98,7 +98,7 @@ def generic_dequantize_gemm(
     codebooks: torch.
     Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
     scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
-    output_partition_sizes: List[int],
+    output_partition_sizes: list[int],
     bias: Optional[torch.Tensor],
 ) -> torch.Tensor:
     output_shape = input.shape[:-1] + (scales.shape[0], )

@@ -136,7 +136,7 @@ def optimized_dequantize_gemm(
     codebooks: torch.
     Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
     scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
-    output_partition_sizes: List[int],
+    output_partition_sizes: list[int],
     bias: Optional[torch.Tensor],
 ) -> torch.Tensor:
     weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

@@ -191,7 +191,7 @@ class AQLMConfig(QuantizationConfig):
         return "aqlm"

     @classmethod
-    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
         return [torch.half]

     @classmethod

@@ -199,11 +199,11 @@ class AQLMConfig(QuantizationConfig):
         return 60

     @classmethod
-    def get_config_filenames(cls) -> List[str]:
+    def get_config_filenames(cls) -> list[str]:
         return []  # no extra configs.

     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig":
+    def from_config(cls, config: dict[str, Any]) -> "AQLMConfig":
         in_group_size = cls.get_from_keys(config, ["in_group_size"])
         nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"])
         num_code_books = cls.get_from_keys(config, ["num_codebooks"])

@@ -230,7 +230,7 @@ class AQLMLinearMethod(LinearMethodBase):

     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
-                       output_partition_sizes: List[int], input_size: int,
+                       output_partition_sizes: list[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
         del output_size  # Unused.
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 import torch

@@ -25,7 +25,7 @@ class AWQConfig(QuantizationConfig):
         weight_bits: int,
         group_size: int,
         zero_point: bool,
-        modules_to_not_convert: Optional[List[str]] = None,
+        modules_to_not_convert: Optional[list[str]] = None,
     ) -> None:
         super().__init__()
         self.weight_bits = weight_bits

@@ -48,7 +48,7 @@ class AWQConfig(QuantizationConfig):
     def get_name(self) -> QuantizationMethods:
         return "awq"

-    def get_supported_act_dtypes(self) -> List[torch.dtype]:
+    def get_supported_act_dtypes(self) -> list[torch.dtype]:
         return [torch.half]

     @classmethod

@@ -57,7 +57,7 @@ class AWQConfig(QuantizationConfig):
         return 75

     @staticmethod
-    def get_config_filenames() -> List[str]:
+    def get_config_filenames() -> list[str]:
         return [
             "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
             # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq

@@ -65,7 +65,7 @@ class AWQConfig(QuantizationConfig):
         ]

     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
+    def from_config(cls, config: dict[str, Any]) -> "AWQConfig":
         weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
         group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
         zero_point = cls.get_from_keys(config, ["zero_point"])

@@ -82,7 +82,7 @@ class AWQConfig(QuantizationConfig):
         return None


-def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
+def is_layer_skipped_awq(prefix: str, modules_to_not_convert: list[str]):
     return any(module_name in prefix for module_name in modules_to_not_convert)


@@ -98,7 +98,7 @@ class AWQLinearMethod(LinearMethodBase):

     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
-                       output_partition_sizes: List[int], input_size: int,
+                       output_partition_sizes: list[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
         if input_size_per_partition % self.quant_config.group_size != 0:
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Optional

 import torch
 from torch.nn import Parameter

@@ -46,8 +46,8 @@ class AWQMarlinConfig(QuantizationConfig):

     def __init__(self, weight_bits: int, group_size: int, zero_point: bool,
                  lm_head_quantized: bool,
-                 modules_to_not_convert: Optional[List[str]],
-                 full_config: Dict[str, Any]) -> None:
+                 modules_to_not_convert: Optional[list[str]],
+                 full_config: dict[str, Any]) -> None:
         super().__init__()
         self.pack_factor = 32 // weight_bits  # packed into int32
         self.group_size = group_size

@@ -79,7 +79,7 @@ class AWQMarlinConfig(QuantizationConfig):
         return "awq_marlin"

     @classmethod
-    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
         return [torch.half, torch.bfloat16]

     @classmethod

@@ -87,11 +87,11 @@ class AWQMarlinConfig(QuantizationConfig):
         return 80

     @classmethod
-    def get_config_filenames(cls) -> List[str]:
+    def get_config_filenames(cls) -> list[str]:
         return ["quantize_config.json"]

     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
+    def from_config(cls, config: dict[str, Any]) -> "AWQMarlinConfig":
         weight_bits = cls.get_from_keys(config, ["bits"])
         group_size = cls.get_from_keys(config, ["group_size"])
         zero_point = cls.get_from_keys(config, ["zero_point"])

@@ -150,7 +150,7 @@ class AWQMarlinConfig(QuantizationConfig):
         return None

     @classmethod
-    def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
+    def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]):
         # Extract data from quant config.
         quant_method = quant_config.get("quant_method", "").lower()
         num_bits = quant_config.get("bits")

@@ -189,7 +189,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
         self,
         layer: torch.nn.Module,
         input_size_per_partition: int,
-        output_partition_sizes: List[int],
+        output_partition_sizes: list[int],
         input_size: int,
         output_size: int,
         params_dtype: torch.dtype,
@@ -2,7 +2,7 @@

 import inspect
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
+from typing import TYPE_CHECKING, Any, Optional

 import torch
 from torch import nn

@@ -48,7 +48,7 @@ class QuantizeMethodBase(ABC):


 def method_has_implemented_embedding(
-        method_class: Type[QuantizeMethodBase]) -> bool:
+        method_class: type[QuantizeMethodBase]) -> bool:
     """
     Not all quant methods have embedding implemented, so we need to check that
     it exists for our given method. We check this by making sure the function

@@ -68,7 +68,7 @@ class QuantizationConfig(ABC):
     def __init__(self):
         super().__init__()
         # mapping is updated by models as they initialize
-        self.packed_modules_mapping: Dict[str, List[str]] = dict()
+        self.packed_modules_mapping: dict[str, list[str]] = dict()

     @abstractmethod
     def get_name(self) -> QuantizationMethods:

@@ -76,7 +76,7 @@ class QuantizationConfig(ABC):
         raise NotImplementedError

     @abstractmethod
-    def get_supported_act_dtypes(self) -> List[torch.dtype]:
+    def get_supported_act_dtypes(self) -> list[torch.dtype]:
         """List of supported activation dtypes."""
         raise NotImplementedError

@@ -93,13 +93,13 @@ class QuantizationConfig(ABC):

     @staticmethod
     @abstractmethod
-    def get_config_filenames() -> List[str]:
+    def get_config_filenames() -> list[str]:
         """List of filenames to search for in the model directory."""
         raise NotImplementedError

     @classmethod
     @abstractmethod
-    def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
+    def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig":
         """Create a config class from the model's quantization config."""
         raise NotImplementedError

@@ -115,7 +115,7 @@ class QuantizationConfig(ABC):
         return None

     @staticmethod
-    def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
+    def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any:
         """Get a value from the model's quantization config."""
         for key in keys:
             if key in config:

@@ -124,7 +124,7 @@ class QuantizationConfig(ABC):
             "quantization config.")

     @staticmethod
-    def get_from_keys_or(config: Dict[str, Any], keys: List[str],
+    def get_from_keys_or(config: dict[str, Any], keys: list[str],
                          default: Any) -> Any:
         """Get a optional value from the model's quantization config."""
         try:
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 import torch

@@ -105,7 +105,7 @@ class BitBLASConfig(QuantizationConfig):
         return "bitblas"

     @classmethod
-    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
         return [torch.half, torch.bfloat16]

     @classmethod

@@ -114,12 +114,12 @@ class BitBLASConfig(QuantizationConfig):
         return 70

     @classmethod
-    def get_config_filenames(cls) -> List[str]:
+    def get_config_filenames(cls) -> list[str]:
         return ["quantize_config.json"]

     @staticmethod
-    def get_from_keys(config: Dict[str, Any],
-                      keys: List[str],
+    def get_from_keys(config: dict[str, Any],
+                      keys: list[str],
                       default: Any = None) -> Any:
         """Get a value from the model's quantization config."""
         for key in keys:

@@ -128,7 +128,7 @@ class BitBLASConfig(QuantizationConfig):
         return default

     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "BitBLASConfig":
+    def from_config(cls, config: dict[str, Any]) -> "BitBLASConfig":
         weight_bits = cls.get_from_keys(config, ["bits"])
         group_size = cls.get_from_keys(config, ["group_size"], -1)
         desc_act = cls.get_from_keys(config, ["desc_act"], False)

@@ -193,7 +193,7 @@ class BitBLASLinearMethod(LinearMethodBase):
         self,
         layer: torch.nn.Module,
         input_size_per_partition: int,
-        output_partition_sizes: List[int],
+        output_partition_sizes: list[int],
         input_size: int,
         output_size: int,
         params_dtype: torch.dtype,

@@ -329,7 +329,7 @@ class BitBLASLinearMethod(LinearMethodBase):
         self,
         layer: torch.nn.Module,
         input_size_per_partition: int,
-        output_partition_sizes: List[int],
+        output_partition_sizes: list[int],
         input_size: int,
         output_size: int,
         params_dtype: torch.dtype,
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 import torch

@@ -29,7 +29,7 @@ class BitsAndBytesConfig(QuantizationConfig):
         bnb_4bit_use_double_quant: bool = False,
         llm_int8_enable_fp32_cpu_offload: bool = False,
         llm_int8_has_fp16_weight: bool = False,
-        llm_int8_skip_modules: Optional[List[str]] = None,
+        llm_int8_skip_modules: Optional[list[str]] = None,
         llm_int8_threshold: float = 6.0,
     ) -> None:
         super().__init__()

@@ -61,7 +61,7 @@ class BitsAndBytesConfig(QuantizationConfig):
         return "bitsandbytes"

     @classmethod
-    def get_supported_act_dtypes(self) -> List[torch.dtype]:
+    def get_supported_act_dtypes(self) -> list[torch.dtype]:
         return [torch.float32, torch.float16, torch.bfloat16]

     @classmethod

@@ -69,13 +69,13 @@ class BitsAndBytesConfig(QuantizationConfig):
         return 70

     @staticmethod
-    def get_config_filenames() -> List[str]:
+    def get_config_filenames() -> list[str]:
         return [
             "adapter_config.json",
         ]

     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
+    def from_config(cls, config: dict[str, Any]) -> "BitsAndBytesConfig":

         def get_safe_value(config, keys, default_value=None):
             try:

@@ -130,7 +130,7 @@ class BitsAndBytesConfig(QuantizationConfig):
         return None


-def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
+def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
     # Split the prefix into its dot-separated components
     components = prefix.split('.')

@@ -169,7 +169,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):

     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
-                       output_partition_sizes: List[int], input_size: int,
+                       output_partition_sizes: list[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
         from bitsandbytes.nn import Int8Params
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 from contextlib import suppress
-from typing import Any, Dict, List, Literal, Optional, Tuple, cast
+from typing import Any, Literal, Optional, cast

 import torch
 from compressed_tensors.config import (CompressionFormat,

@@ -38,20 +38,20 @@ logger = init_logger(__name__)
 __all__ = ["CompressedTensorsLinearMethod"]

 SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
-QUANTIZATION_SCHEME_MAP_TYPE = Dict[str, Optional[Dict[str, QuantizationArgs]]]
+QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]


 class CompressedTensorsConfig(QuantizationConfig):

     def __init__(
         self,
-        target_scheme_map: Dict[str, Any],
-        ignore: List[str],
+        target_scheme_map: dict[str, Any],
+        ignore: list[str],
         quant_format: str,
-        sparsity_scheme_map: Dict[str, SparsityCompressionConfig],
-        sparsity_ignore_list: List[str],
-        kv_cache_scheme: Optional[Dict[str, Any]] = None,
-        config: Optional[Dict[str, Any]] = None,
+        sparsity_scheme_map: dict[str, SparsityCompressionConfig],
+        sparsity_ignore_list: list[str],
+        kv_cache_scheme: Optional[dict[str, Any]] = None,
+        config: Optional[dict[str, Any]] = None,
     ):
         super().__init__()
         self.ignore = ignore

@@ -66,7 +66,7 @@ class CompressedTensorsConfig(QuantizationConfig):
     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)

-    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
         return [torch.float16, torch.bfloat16]

     @classmethod

@@ -102,8 +102,8 @@ class CompressedTensorsConfig(QuantizationConfig):
         return None

     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
-        ignore: List[str] = cast(List[str], config.get("ignore", []))
+    def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
+        ignore: list[str] = cast(list[str], config.get("ignore", []))
         quant_format = cast(str, config.get("format"))
         target_scheme_map = cls._quantization_scheme_map_from_config(
             config=config)

@@ -121,8 +121,8 @@ class CompressedTensorsConfig(QuantizationConfig):

     @classmethod
     def _parse_sparsity_config(
-        cls, config: Dict[str, Any]
-    ) -> Tuple[Dict[str, SparsityCompressionConfig], List[str]]:
+        cls, config: dict[str, Any]
+    ) -> tuple[dict[str, SparsityCompressionConfig], list[str]]:
         """
         :param config: The `quantization_config` dictionary from config.json
         :return: A tuple with two elements

@@ -135,7 +135,7 @@ class CompressedTensorsConfig(QuantizationConfig):

         sparsity_config = SparsityCompressionConfig.model_validate(
             sparsity_config)
-        sparse_scheme_map: Dict[str, SparsityCompressionConfig] = {
+        sparse_scheme_map: dict[str, SparsityCompressionConfig] = {
             target: sparsity_config
             for target in sparsity_config.targets or list()
         }

@@ -144,13 +144,13 @@ class CompressedTensorsConfig(QuantizationConfig):

     @classmethod
     def _quantization_scheme_map_from_config(
-            cls, config: Dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
+            cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
         """
         :param config: The `quantization_config` dictionary from config.json
         :return: A dictionary mapping target layer names to their corresponding
             quantization_args for weights and input activations
         """
-        target_scheme_map: Dict[str, Any] = dict()
+        target_scheme_map: dict[str, Any] = dict()
         quant_format = cast(str, config.get("format"))

         # The quant_config has multiple config_groups, each containing

@@ -188,7 +188,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         return target_scheme_map

     @classmethod
-    def get_config_filenames(cls) -> List[str]:
+    def get_config_filenames(cls) -> list[str]:
         return []

     def _check_scheme_supported(self,

@@ -565,7 +565,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):

     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
-                       output_partition_sizes: List[int], input_size: int,
+                       output_partition_sizes: list[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
         """

@@ -611,7 +611,7 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
         super().__init__(quant_config)

     @staticmethod
-    def validate_kv_cache_scheme(kv_cache_scheme: Optional[Dict[str, Any]]):
+    def validate_kv_cache_scheme(kv_cache_scheme: Optional[dict[str, Any]]):
         """
         Validator for the kv cache scheme. Useful for controlling the
         kv cache quantization schemes, that are being supported in vLLM
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Optional

 import torch
 from compressed_tensors import CompressionFormat, ModelCompressor

@@ -31,7 +31,7 @@ class CompressedTensors24(CompressedTensorsScheme):
         quantized: bool = False,
         weight_quant: Optional[QuantizationArgs] = None,
         input_quant: Optional[QuantizationArgs] = None,
-        model_compression_config: Optional[Dict[str, Any]] = None,
+        model_compression_config: Optional[dict[str, Any]] = None,
     ):
         self.quantized = quantized
         self.weight_quant = weight_quant

@@ -53,7 +53,7 @@ class CompressedTensors24(CompressedTensorsScheme):
         self,
         layer: torch.nn.Module,
         input_size: int,
-        output_partition_sizes: List[int],
+        output_partition_sizes: list[int],
         input_size_per_partition: int,
         params_dtype: torch.dtype,
         weight_loader: Callable,

@@ -327,9 +327,9 @@ class CompressedTensors24(CompressedTensorsScheme):
             )
             return sparsity_compressor.decompress_weight(weight_data)

-        split_weights: List[torch.Tensor] = []
-        split_bitmask: List[torch.Tensor] = []
-        split_shape: List[Tuple[int, int]] = []
+        split_weights: list[torch.Tensor] = []
+        split_bitmask: list[torch.Tensor] = []
+        split_shape: list[tuple[int, int]] = []

         if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
             split_weights = torch.split(compressed, layer.logical_widths)
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Callable, List, Optional
+from typing import Callable, Optional

 import torch
 from torch.nn import Parameter

@@ -58,7 +58,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
         layer.meta = Parameter(layer.meta.data, requires_grad=False)

     def create_weights(self, layer: torch.nn.Module, input_size: int,
-                       output_partition_sizes: List[int],
+                       output_partition_sizes: list[int],
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Callable, List, Optional
+from typing import Callable, Optional

 import torch
 from torch.nn.parameter import Parameter

@@ -26,7 +26,7 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
         return 80

     def create_weights(self, layer: torch.nn.Module,
-                       output_partition_sizes: List[int],
+                       output_partition_sizes: list[int],
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Callable, List, Optional
+from typing import Callable, Optional

 import torch
 from compressed_tensors.quantization import QuantizationStrategy

@@ -58,7 +58,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
             prepare_fp8_layer_for_marlin(layer)

     def create_weights(self, layer: torch.nn.Module, input_size: int,
-                       output_partition_sizes: List[int],
+                       output_partition_sizes: list[int],
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Callable, List, Optional
+from typing import Callable, Optional

 import torch
 from compressed_tensors.quantization import QuantizationStrategy

@@ -90,7 +90,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         layer.input_scale = None

     def create_weights(self, layer: torch.nn.Module,
-                       output_partition_sizes: List[int],
+                       output_partition_sizes: list[int],
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Callable, List, Optional, Set
+from typing import Callable, Optional

 import torch
 from compressed_tensors.quantization import QuantizationStrategy

@@ -19,7 +19,7 @@ logger = init_logger(__name__)


 class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
-    _kernel_backends_being_used: Set[str] = set()
+    _kernel_backends_being_used: set[str] = set()

     def __init__(self, strategy: str, is_static_input_scheme: bool,
                  input_symmetric: bool):

@@ -33,7 +33,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
         return 75

     def create_weights(self, layer: torch.nn.Module,
-                       output_partition_sizes: List[int],
+                       output_partition_sizes: list[int],
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Callable, List, Optional, Set
+from typing import Callable, Optional

 import torch
 from compressed_tensors.quantization import ActivationOrdering

@@ -35,7 +35,7 @@ WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())


 class CompressedTensorsWNA16(CompressedTensorsScheme):
-    _kernel_backends_being_used: Set[str] = set()
+    _kernel_backends_being_used: set[str] = set()

     def __init__(self,
                  strategy: str,

@@ -70,7 +70,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
         return 80

     def create_weights(self, layer: torch.nn.Module, output_size: int,
-                       input_size: int, output_partition_sizes: List[int],
+                       input_size: int, output_partition_sizes: list[int],
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Optional, Type
+from typing import Optional

 import torch

@@ -126,7 +126,7 @@ def triton_scaled_mm(input: torch.Tensor,
                      weight: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
-                     out_dtype: Type[torch.dtype],
+                     out_dtype: type[torch.dtype],
                      bias: Optional[torch.Tensor] = None,
                      block_size_m: int = 32,
                      block_size_n: int = 32,
@@ -1,8 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0

 import re
+from collections.abc import Iterable, Mapping
 from types import MappingProxyType
-from typing import Iterable, List, Mapping, Optional
+from typing import Optional

 from compressed_tensors import CompressionFormat
 from torch.nn import Module

@@ -20,7 +21,7 @@ def is_activation_quantization_format(format: str) -> bool:
 def should_ignore_layer(
     layer_name: Optional[str],
     ignore: Iterable[str] = tuple(),
-    fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
+    fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
 ) -> bool:
     if layer_name is None:
         return False

@@ -84,7 +85,7 @@ def find_matched_target(
     layer_name: Optional[str],
     module: Module,
     targets: Iterable[str],
-    fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
+    fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
 ) -> str:
     """
     Helper function to look up which "target" in the compressed-tensors

@@ -171,7 +172,7 @@ def _is_equal_or_regex_match(value: str,

 def _match_fused_layer(
         layer_name: str, target_layers: Iterable[str],
-        fused_mapping: Mapping[str, List[str]]) -> Optional[str]:
+        fused_mapping: Mapping[str, list[str]]) -> Optional[str]:
     """
     Match a fused layer name to its corresponding individual layer in
     target_layers. Returns first value in fused_mapping which matches targets

@@ -201,7 +202,7 @@ def _match_fused_layer(
     ]

     # for each unfused component, find a match in targets
-    unfused_matches: List[Optional[str]] = []
+    unfused_matches: list[Optional[str]] = []
     for unfused in unfused_paths:
         for target in target_layers:
             if _is_equal_or_regex_match(unfused, target):
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 import torch
 import torch.nn as nn

@@ -46,7 +46,7 @@ class DeepSpeedFPConfig(QuantizationConfig):
         return "deepspeedfp"

     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig":
+    def from_config(cls, config: dict[str, Any]) -> "DeepSpeedFPConfig":
         weight_bits = cls.get_from_keys(config, ["bits"])
         group_size = cls.get_from_keys(config, ["group_size"])
         return cls(weight_bits=weight_bits, group_size=group_size)

@@ -55,7 +55,7 @@ class DeepSpeedFPConfig(QuantizationConfig):
         return DeepSpeedFPLinearMethod(self)

     @classmethod
-    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
         return [torch.half, torch.bfloat16]

     @classmethod

@@ -64,7 +64,7 @@ class DeepSpeedFPConfig(QuantizationConfig):
         return 60

     @staticmethod
-    def get_config_filenames() -> List[str]:
+    def get_config_filenames() -> list[str]:
         return [
             "quant_config.json",
             "quantize_config.json",

@@ -91,7 +91,7 @@ class DeepSpeedFPLinearMethod(LinearMethodBase):
     def create_weights(self,
                        layer: torch.nn.Module,
                        input_size_per_partition: int,
-                       output_partition_sizes: List[int],
+                       output_partition_sizes: list[int],
                        input_size: int,
                        output_size: int,
                        params_dtype: torch.dtype,
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Optional

 import torch

@@ -25,7 +25,7 @@ class ExpertsInt8Config(QuantizationConfig):
         return "experts_int8"

     @classmethod
-    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
         return [torch.bfloat16, torch.half]

     @classmethod

@@ -33,11 +33,11 @@ class ExpertsInt8Config(QuantizationConfig):
         return 80

     @classmethod
-    def get_config_filenames(cls) -> List[str]:
+    def get_config_filenames(cls) -> list[str]:
         return []

     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ExpertsInt8Config":
+    def from_config(cls, config: dict[str, Any]) -> "ExpertsInt8Config":
         return cls()

     def get_quant_method(self, layer: torch.nn.Module,
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 import torch
 from torch.nn import Module

@@ -28,7 +28,7 @@ logger = init_logger(__name__)
 class FBGEMMFp8Config(QuantizationConfig):
     """Config class for FBGEMM Fp8."""

-    def __init__(self, ignore_list: List[str], input_scale_ub: float):
+    def __init__(self, ignore_list: list[str], input_scale_ub: float):
         super().__init__()
         self.ignore_list = ignore_list if ignore_list else []
         self.input_scale_ub = input_scale_ub

@@ -43,7 +43,7 @@ class FBGEMMFp8Config(QuantizationConfig):
         return "fbgemm_fp8"

     @classmethod
-    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
         return [torch.bfloat16, torch.float16]

     @classmethod

@@ -51,11 +51,11 @@ class FBGEMMFp8Config(QuantizationConfig):
         return 80

     @classmethod
-    def get_config_filenames(cls) -> List[str]:
+    def get_config_filenames(cls) -> list[str]:
         return []

     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config":
+    def from_config(cls, config: dict[str, Any]) -> "FBGEMMFp8Config":
         ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
         input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
         return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)

@@ -82,7 +82,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
         self,
         layer: torch.nn.Module,
         input_size_per_partition: int,
-        output_partition_sizes: List[int],
+        output_partition_sizes: list[int],
         input_size: int,
         output_size: int,
         params_dtype: torch.dtype,
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import importlib.util
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Optional

 import torch
 import torch.nn.functional as F

@@ -57,8 +57,8 @@ class Fp8Config(QuantizationConfig):
         self,
         is_checkpoint_fp8_serialized: bool = False,
         activation_scheme: str = "dynamic",
-        ignored_layers: Optional[List[str]] = None,
-        weight_block_size: Optional[List[int]] = None,
+        ignored_layers: Optional[list[str]] = None,
+        weight_block_size: Optional[list[int]] = None,
     ) -> None:
         super().__init__()
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

@@ -90,7 +90,7 @@ class Fp8Config(QuantizationConfig):
         return "fp8"

     @classmethod
-    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
         return [torch.bfloat16, torch.half]

     @classmethod

@@ -98,11 +98,11 @@ class Fp8Config(QuantizationConfig):
         return 80

     @classmethod
-    def get_config_filenames(cls) -> List[str]:
+    def get_config_filenames(cls) -> list[str]:
         return []

     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
+    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
         quant_method = cls.get_from_keys(config, ["quant_method"])
         is_checkpoint_fp8_serialized = ("fp8" in quant_method)
         activation_scheme = cls.get_from_keys(config, ["activation_scheme"])

@@ -191,7 +191,7 @@ class Fp8LinearMethod(LinearMethodBase):
         self,
         layer: torch.nn.Module,
         input_size_per_partition: int,
-        output_partition_sizes: List[int],
+        output_partition_sizes: list[int],
         input_size: int,
         output_size: int,
         params_dtype: torch.dtype,
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
@ -35,7 +35,7 @@ class GGUFConfig(QuantizationConfig):
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "gguf"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16, torch.float32]
|
||||
|
||||
@classmethod
|
||||
@ -43,11 +43,11 @@ class GGUFConfig(QuantizationConfig):
|
||||
return 60
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return [] # no extra configs.
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig":
|
||||
def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
|
||||
return cls()
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
@ -215,7 +215,7 @@ class GGUFLinearMethod(LinearMethodBase):
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int], input_size: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
self.params_dtype = params_dtype
|
||||
@ -406,7 +406,7 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
|
||||
|
||||
class GGUFUninitializedParameter(UninitializedParameter):
|
||||
cls_to_become = Parameter
|
||||
data_container: List[torch.Tensor]
|
||||
data_container: list[torch.Tensor]
|
||||
|
||||
def materialize_nested(self) -> Parameter:
|
||||
dtype = {data.dtype for data in self.data_container}
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import enum
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
@ -34,11 +34,11 @@ class GPTQConfig(QuantizationConfig):
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
lm_head_quantized: bool,
|
||||
dynamic: Dict[str, Dict[str, Union[int, bool]]],
|
||||
dynamic: dict[str, dict[str, Union[int, bool]]],
|
||||
) -> None:
|
||||
# GPTQModel use `dynamic` config property to allow per module
|
||||
# quantization config so each module can be individually optimized.
|
||||
# Format is Dict[str, Dict] where key is a regex string that can
|
||||
# Format is dict[str, dict] where key is a regex string that can
|
||||
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
|
||||
# matching of a module.
|
||||
# Default to positive match, override base quant config mode, if no
|
||||
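For reference, a `dynamic` override in the format described above might look like this (the regex keys and field values are illustrative, not taken from a real checkpoint):

dynamic = {
    # "+:" (the default) positively matches modules and overrides their config
    r"+:.*\.down_proj.*": {"bits": 8, "group_size": 64},
    # "-:" negatively matches; matched modules keep the base quant config
    r"-:.*lm_head": {},
}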
@@ -84,7 +84,7 @@ class GPTQConfig(QuantizationConfig):
return "gptq"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half]

@classmethod
@@ -93,11 +93,11 @@ class GPTQConfig(QuantizationConfig):
return 60

@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
def from_config(cls, config: dict[str, Any]) -> "GPTQConfig":
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
dynamic = {} if dynamic is None else dynamic

@@ -135,7 +135,7 @@ class GPTQLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,

@@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional, Set
from typing import Any, Optional

import torch
from torch.nn.parameter import Parameter
@@ -129,7 +129,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
return "gptq_bitblas"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]

@classmethod
@@ -137,11 +137,11 @@ class GPTQBitBLASConfig(QuantizationConfig):
return 80

@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQBitBLASConfig":
def from_config(cls, config: dict[str, Any]) -> "GPTQBitBLASConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
@@ -185,7 +185,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
return self.TORCH_BITBLAS_STORAGE_DTYPE

@classmethod
def is_gptq_bitblas_compatible(cls, quant_config: Dict[str, Any]):
def is_gptq_bitblas_compatible(cls, quant_config: dict[str, Any]):
# Extract data from quant config.
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
@@ -224,7 +224,7 @@ class GPTQBitBLASLinearMethod(LinearMethodBase):
"""

kernel_type = BitBLASLinearKernel
_kernel_backends_being_used: Set[str] = set()
_kernel_backends_being_used: set[str] = set()

def __init__(self, quant_config: GPTQBitBLASConfig) -> None:
self.quant_config = quant_config
@@ -236,7 +236,7 @@ class GPTQBitBLASLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, List, Optional, Set, Union
from typing import Any, Callable, Optional, Union

import torch

@@ -45,8 +45,8 @@ class GPTQMarlinConfig(QuantizationConfig):

def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool, lm_head_quantized: bool,
dynamic: Dict[str, Dict[str, Union[int, bool]]],
full_config: Dict[str, Any]) -> None:
dynamic: dict[str, dict[str, Union[int, bool]]],
full_config: dict[str, Any]) -> None:
super().__init__()
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
@@ -55,7 +55,7 @@ class GPTQMarlinConfig(QuantizationConfig):

# GPTQModel uses `dynamic` config property to allow per module
# quantization config so each module can be individually optimized.
# Format is Dict[str, Dict] where key is a regex string that can
# Format is dict[str, dict] where key is a regex string that can
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
# matching of a module.
# Default to positive match, override base quant config mode, if no
@@ -105,7 +105,7 @@ class GPTQMarlinConfig(QuantizationConfig):
return "gptq_marlin"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]

@classmethod
@@ -113,11 +113,11 @@ class GPTQMarlinConfig(QuantizationConfig):
return 80

@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
def from_config(cls, config: dict[str, Any]) -> "GPTQMarlinConfig":
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
dynamic = {} if dynamic is None else dynamic

@@ -167,7 +167,7 @@ class GPTQMarlinConfig(QuantizationConfig):
GPTQMarlinLinearMethod)

@classmethod
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
@@ -199,7 +199,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
quant_config: The GPTQ Marlin quantization config.
"""

_kernel_backends_being_used: Set[str] = set()
_kernel_backends_being_used: set[str] = set()

def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config
@@ -212,7 +212,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional
from typing import Any, Optional

import torch
from torch.nn.parameter import Parameter
@@ -90,7 +90,7 @@ class GPTQMarlin24Config(QuantizationConfig):
return "gptq_marlin_24"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half]

@classmethod
@@ -99,11 +99,11 @@ class GPTQMarlin24Config(QuantizationConfig):
return 80

@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlin24Config":
def from_config(cls, config: dict[str, Any]) -> "GPTQMarlin24Config":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
return cls(weight_bits, group_size)
@@ -146,7 +146,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional
from typing import Any, Optional

import torch

@@ -32,7 +32,7 @@ class HQQMarlinConfig(QuantizationConfig):
self,
weight_bits: int,
group_size: int,
skip_modules: Optional[List[str]] = None,
skip_modules: Optional[list[str]] = None,
) -> None:
super().__init__()
assert group_size == 64, ("The only supported HQQ group size is "
@@ -55,7 +55,7 @@ class HQQMarlinConfig(QuantizationConfig):
return "hqq"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]

@classmethod
@@ -63,11 +63,11 @@ class HQQMarlinConfig(QuantizationConfig):
return 80

@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "HQQMarlinConfig":
def from_config(cls, config: dict[str, Any]) -> "HQQMarlinConfig":
wq_params = (config["quant_config"]["weight_quant_params"])
weight_bits = cls.get_from_keys(wq_params, ["nbits"])
group_size = cls.get_from_keys(wq_params, ["group_size"])
@@ -192,7 +192,7 @@ class HQQMarlinMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional
from typing import Any, Optional

import torch

@@ -32,7 +32,7 @@ class IPEXConfig(QuantizationConfig):
method: str,
weight_bits: int,
group_size: int,
modules_to_not_convert: Optional[List[str]] = None,
modules_to_not_convert: Optional[list[str]] = None,
desc_act: Optional[bool] = None,
lm_head_quantized: Optional[bool] = None,
) -> None:
@@ -63,7 +63,7 @@ class IPEXConfig(QuantizationConfig):
return "ipex"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.float16]

@classmethod
@@ -71,14 +71,14 @@ class IPEXConfig(QuantizationConfig):
return -1

@staticmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
return [
"quant_config.json",
"quantize_config.json",
]

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "IPEXConfig":
def from_config(cls, config: dict[str, Any]) -> "IPEXConfig":
method = cls.get_from_keys(config, ["quant_method"]).lower()
if method == "awq":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])

@@ -2,7 +2,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Optional, Tuple
from typing import Callable, Optional

import torch

@@ -12,8 +12,8 @@ from vllm.scalar_type import ScalarType

@dataclass
class MPLinearLayerConfig:
full_weight_shape: Tuple[int, int] # [in, out]
partition_weight_shape: Tuple[int, int]
full_weight_shape: tuple[int, int] # [in, out]
partition_weight_shape: tuple[int, int]
weight_type: ScalarType
act_type: torch.dtype
group_size: int
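A small, hypothetical illustration of the two shape fields above, for a layer sharded column-wise across two GPUs:

full_weight_shape = (4096, 8192)       # [in, out] of the unsharded weight
partition_weight_shape = (4096, 4096)  # this rank holds half of the output dim
assert partition_weight_shape[0] == full_weight_shape[0]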
@@ -31,7 +31,7 @@ class MPLinearKernel(ABC):
@classmethod
@abstractmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
raise NotImplementedError

def __init__(self,
@@ -75,7 +75,7 @@ class MPLinearKernel(ABC):
torch.nn.Parameter(new_param.data, requires_grad=False))

def _get_weight_params(
self, layer: torch.nn.Module) -> Tuple[
self, layer: torch.nn.Module) -> tuple[
torch.Tensor, # w_q
torch.Tensor, # w_s
Optional[torch.Tensor], # w_zp,

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import List, Optional, Type
from typing import Optional

import vllm.envs as envs
from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501
@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKer
from vllm.platforms import current_platform

# in priority/performance order (when available)
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
_POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
MacheteLinearKernel,
AllSparkLinearKernel,
MarlinLinearKernel,
@@ -29,7 +29,7 @@ _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [

def choose_mp_linear_kernel(
config: MPLinearLayerConfig,
compute_capability: Optional[int] = None) -> Type[MPLinearKernel]:
compute_capability: Optional[int] = None) -> type[MPLinearKernel]:
"""
Choose an MPLinearKernel that can implement the given config for the given
compute capability. Attempts to choose the best kernel in terms of
@@ -46,7 +46,7 @@ def choose_mp_linear_kernel(
ValueError: If no kernel can implement the given config.

Returns:
Type[MPLinearKernel]: Chosen kernel.
type[MPLinearKernel]: Chosen kernel.
"""
if compute_capability is None:
if current_platform is None:

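The selection loop itself falls outside this hunk; a minimal sketch of the pattern the docstring describes (the first kernel in priority order whose can_implement() succeeds wins, otherwise raise) could look like:

def choose_first_usable(config):
    # Sketch only: walk kernels in priority order, keep the first that fits.
    failure_reasons = []
    for kernel in _POSSIBLE_KERNELS:
        ok, reason = kernel.can_implement(config)
        if ok:
            return kernel
        failure_reasons.append(f"  {kernel.__name__}: {reason}")
    raise ValueError("No kernel can implement the given config:\n" +
                     "\n".join(failure_reasons))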
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple
from typing import Optional

import torch

@@ -22,7 +22,7 @@ class AllSparkLinearKernel(MPLinearKernel):

@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if c.has_g_idx:
return False, "Act reordering currently not supported by AllSpark"


@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List, Optional, Tuple
from typing import Optional

import torch

@@ -21,10 +21,10 @@ logger = init_logger(__name__)

class BitBLASLinearKernel(MPLinearKernel):

OPT_FEATURES: List[int] = BITBLAS_OPTIMIZE_FEATURES
OPT_FEATURES: list[int] = BITBLAS_OPTIMIZE_FEATURES
ENABLE_TUNING: bool = True
MATMUL_LAYOUT: str = "nt"
BITBLAS_DTYPES: Dict[torch.dtype, str] = {
BITBLAS_DTYPES: dict[torch.dtype, str] = {
torch.float32: "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
@@ -103,7 +103,7 @@ class BitBLASLinearKernel(MPLinearKernel):

@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:

is_bitblas_installed = True


@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple
from typing import Optional

import torch

@@ -25,7 +25,7 @@ class ExllamaLinearKernel(MPLinearKernel):

@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if c.has_g_idx and\
c.partition_weight_shape[0] != c.full_weight_shape[0]:
return False, "Act reordering currently not supported by Exllama, "\

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from functools import partial
from typing import Optional, Tuple
from typing import Optional

import torch

@@ -25,7 +25,7 @@ class MacheteLinearKernel(MPLinearKernel):

@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:

if c.has_g_idx and\
c.partition_weight_shape[0] != c.full_weight_shape[0]:

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple
from typing import Optional

import torch

@@ -24,7 +24,7 @@ class MarlinLinearKernel(MPLinearKernel):

@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:

quant_types = query_marlin_supported_quant_types(c.zero_points)
if c.weight_type not in quant_types:

@@ -2,7 +2,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional

import torch

@@ -24,7 +24,7 @@ class ScaledMMLinearKernel(ABC):
@classmethod
@abstractmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
raise NotImplementedError

def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str,
@@ -50,7 +50,7 @@ class ScaledMMLinearKernel(ABC):
raise NotImplementedError

def _get_weight_params(
self, layer: torch.nn.Module) -> Tuple[
self, layer: torch.nn.Module) -> tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
Optional[torch.Tensor], # input_scale,

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Dict, List, Optional, Type
from typing import Optional

from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
AiterScaledMMLinearKernel)
@@ -16,7 +16,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
from vllm.platforms import PlatformEnum, current_platform

# in priority/performance order (when available)
_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = {
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CutlassScaledMMLinearKernel],
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
@@ -27,7 +27,7 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
def choose_scaled_mm_linear_kernel(
config: ScaledMMLinearLayerConfig,
compute_capability: Optional[int] = None
) -> Type[ScaledMMLinearKernel]:
) -> type[ScaledMMLinearKernel]:
"""
Choose a ScaledMMLinearKernel that can implement the given config for the
given compute capability. Attempts to choose the best kernel in terms of
@@ -44,7 +44,7 @@ def choose_scaled_mm_linear_kernel(
ValueError: If no kernel can implement the given config.

Returns:
Type[ScaledMMLinearKernel]: Chosen kernel.
type[ScaledMMLinearKernel]: Chosen kernel.
"""

if compute_capability is None:

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple
from typing import Optional

import torch

@@ -20,7 +20,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):

@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
if not current_platform.is_rocm():
return (
False,

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple
from typing import Optional

import torch

@@ -22,7 +22,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):

@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:

if (not current_platform.is_cuda() and not current_platform.is_cpu()):
return False, "CutlassScaledMM requires running on CUDA or CPU."

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple
from typing import Optional

import torch

@@ -18,7 +18,7 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):

@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
if current_platform.is_cpu():
return (
False,

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import Optional, Tuple
from typing import Optional

import torch
from functorch.experimental.control_flow import cond # noqa: F401
@@ -25,7 +25,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):

@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:

if not current_platform.is_tpu():
return False, "ScaledMMXLA requires running on TPU."

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional
from typing import Any, Optional

import torch
from torch.nn.parameter import Parameter
@@ -68,7 +68,7 @@ class MarlinConfig(QuantizationConfig):
return "marlin"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half]

@classmethod
@@ -77,11 +77,11 @@ class MarlinConfig(QuantizationConfig):
return 80

@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
def from_config(cls, config: dict[str, Any]) -> "MarlinConfig":
group_size = cls.get_from_keys(config, ["group_size"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
@@ -128,7 +128,7 @@ class MarlinLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Optional, Union

import torch
from torch.nn import Module
@@ -53,7 +53,7 @@ class ModelOptFp8Config(QuantizationConfig):
return "modelopt"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half]

@classmethod
@@ -61,11 +61,11 @@ class ModelOptFp8Config(QuantizationConfig):
return 89

@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["hf_quant_config.json"]

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"]
if quant_method not in QUANT_ALGOS:
@@ -107,7 +107,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
@@ -177,7 +177,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
self,
is_checkpoint_nvfp4_serialized: bool,
kv_cache_quant_algo: str,
exclude_modules: List[str],
exclude_modules: list[str],
group_size: int = 16,
) -> None:
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
@@ -195,7 +195,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
return "nvfp4"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

@classmethod
@@ -203,11 +203,11 @@ class ModelOptNvFp4Config(QuantizationConfig):
return 80

@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["hf_quant_config.json"]

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ModelOptNvFp4Config":
def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"]
if quant_method not in QUANT_ALGOS:
@@ -227,7 +227,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
exclude_modules, group_size)

def is_layer_excluded(self, prefix: str, exclude_modules: List):
def is_layer_excluded(self, prefix: str, exclude_modules: list):
import re
for pattern in exclude_modules:
regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
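To see what the conversion above produces, a hypothetical exclusion pattern such as model.layers.*.self_attn becomes the regex model\.layers\..*\.self_attn:

import re

pattern = "model.layers.*.self_attn"  # hypothetical config entry
regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
assert re.match(regex_str, "model.layers.0.self_attn")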
@@ -296,7 +296,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional

import torch

@@ -23,8 +23,8 @@ class MoeWNA16Config(QuantizationConfig):

def __init__(self, linear_quant_method: str, weight_bits: int,
group_size: int, has_zp: bool, lm_head_quantized: bool,
modules_to_not_convert: Optional[List[str]],
full_config: Dict[str, Any]) -> None:
modules_to_not_convert: Optional[list[str]],
full_config: dict[str, Any]) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
@@ -69,7 +69,7 @@ class MoeWNA16Config(QuantizationConfig):
return "moe_wna16"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half]

@classmethod
@@ -77,11 +77,11 @@ class MoeWNA16Config(QuantizationConfig):
return 70

@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quantize_config.json"]

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config":
def from_config(cls, config: dict[str, Any]) -> "MoeWNA16Config":
linear_quant_method = cls.get_from_keys(config, ["quant_method"])
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
@@ -109,7 +109,7 @@ class MoeWNA16Config(QuantizationConfig):
return None

@classmethod
def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]):
def is_moe_wna16_compatible(cls, quant_config: dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits")
@@ -163,7 +163,7 @@ class MoeWNA16Config(QuantizationConfig):
return None


def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]):
def is_layer_skipped_quant(prefix: str, modules_to_not_convert: list[str]):
return any(module_name in prefix for module_name in modules_to_not_convert)



@@ -2,7 +2,7 @@

import os
from importlib.util import find_spec
from typing import Any, Dict, List, Optional
from typing import Any, Optional

from torch.nn import Module

@@ -34,7 +34,7 @@ class NeuronQuantConfig(QuantizationConfig):
def get_name(self) -> QuantizationMethods:
return "neuron_quant"

def get_supported_act_dtypes(self) -> List[str]:
def get_supported_act_dtypes(self) -> list[str]:
return SUPPORTED_QUANT_DTYPE_LIST

@classmethod
@@ -43,11 +43,11 @@ class NeuronQuantConfig(QuantizationConfig):
"This function should not be called with Neuron Backend")

@staticmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
return []

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "NeuronQuantConfig":
def from_config(cls, config: dict[str, Any]) -> "NeuronQuantConfig":
quantize_method = cls.get_from_keys(config, ["quantize_method"])
dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"])
return cls(dequant_dtype=dequant_dtype,

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional
from typing import Any, Optional

import torch
from torch.nn.parameter import Parameter
@@ -32,7 +32,7 @@ class PTPCFp8Config(Fp8Config):
def __init__(
self,
activation_scheme: str = "dynamic",
ignored_layers: Optional[List[str]] = None,
ignored_layers: Optional[list[str]] = None,
) -> None:
if not current_platform.is_rocm():
raise ValueError(
@@ -55,7 +55,7 @@ class PTPCFp8Config(Fp8Config):
return "ptpc_fp8"

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "PTPCFp8Config":
def from_config(cls, config: dict[str, Any]) -> "PTPCFp8Config":
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
return cls(activation_scheme=activation_scheme,

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional
from typing import Any, Optional

import torch
from torch.nn.parameter import Parameter
@@ -89,7 +89,7 @@ class QQQConfig(QuantizationConfig):
return "qqq"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half]

@classmethod
@@ -97,7 +97,7 @@ class QQQConfig(QuantizationConfig):
return 80

@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
"""List of filenames to search for in the model directory."""
return [
"quant_config.json",
@@ -105,7 +105,7 @@ class QQQConfig(QuantizationConfig):
]

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "QQQConfig":
def from_config(cls, config: dict[str, Any]) -> "QQQConfig":
weight_bits = cls.get_from_keys(config, ["wbits"])
group_size = cls.get_from_keys(config, ["group_size"])
return cls(weight_bits, group_size)
@@ -131,7 +131,7 @@ class QQQLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import fnmatch
from typing import Any, Dict, List, Optional, cast
from typing import Any, Optional, cast

import torch

@@ -29,9 +29,9 @@ logger = init_logger(__name__)
class QuarkConfig(QuantizationConfig):

def __init__(self,
quant_config: Dict[str, Any],
kv_cache_group: Optional[List[str]] = None,
kv_cache_config: Optional[Dict[str, Any]] = None,
quant_config: dict[str, Any],
kv_cache_group: Optional[list[str]] = None,
kv_cache_config: Optional[dict[str, Any]] = None,
pack_method: str = "reorder"):
super().__init__()
if kv_cache_group is None:
@@ -44,7 +44,7 @@ class QuarkConfig(QuantizationConfig):
def get_linear_method(self) -> "QuarkLinearMethod":
return QuarkLinearMethod(self)

def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]

@classmethod
@@ -59,7 +59,7 @@ class QuarkConfig(QuantizationConfig):
from vllm.attention.layer import Attention # Avoid circular import

# Check if the layer is skipped for quantization.
exclude_layers = cast(List[str], self.quant_config.get("exclude"))
exclude_layers = cast(list[str], self.quant_config.get("exclude"))
if should_ignore_layer(prefix,
ignore=exclude_layers,
fused_mapping=self.packed_modules_mapping):
@@ -78,12 +78,12 @@ class QuarkConfig(QuantizationConfig):
return None

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig":
def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
export_config = config.get("export")
if export_config is None:
raise ValueError("The export key should be included in "
"the configurations of Quark quantized model")
kv_cache_group = cast(List[str], export_config.get("kv_cache_group"))
kv_cache_group = cast(list[str], export_config.get("kv_cache_group"))
pack_method = cast(str, export_config.get("pack_method"))

# In the export model of quark, the quantization configuration
@@ -95,7 +95,7 @@ class QuarkConfig(QuantizationConfig):
kv_cache_config = None
else:
kv_cache_set = set(kv_cache_group)
layer_quant_config = cast(Dict[str, Any],
layer_quant_config = cast(dict[str, Any],
config.get("layer_quant_config"))
layer_quant_names = list(layer_quant_config.keys())
layer_quant_set = set(layer_quant_names)
@@ -108,7 +108,7 @@ class QuarkConfig(QuantizationConfig):
"configuration.")

q_configs = [
cast(Dict[str, Any], layer_quant_config.get(name))
cast(dict[str, Any], layer_quant_config.get(name))
for name in kv_cache_group
]
if not all(
@@ -131,7 +131,7 @@ class QuarkConfig(QuantizationConfig):

# In case q_proj output is also quantized, remove the configuration
# to keep qkv consistency.
q_proj_q_config = cast(Dict[str, Any],
q_proj_q_config = cast(dict[str, Any],
layer_quant_config.get("*q_proj"))
if q_proj_q_config is not None:
q_proj_q_config["output_tensors"] = None
@@ -142,7 +142,7 @@ class QuarkConfig(QuantizationConfig):
pack_method=pack_method)

@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return []

def _check_scheme_supported(self,
@@ -162,8 +162,8 @@ class QuarkConfig(QuantizationConfig):
else:
return False

def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]],
input_quant: Optional[Dict[str, Any]]) -> bool:
def _is_fp8_w8a8(self, weight_quant: Optional[dict[str, Any]],
input_quant: Optional[dict[str, Any]]) -> bool:
# Confirm weights and input quantized.
if weight_quant is None or input_quant is None:
return False
@@ -187,8 +187,8 @@ class QuarkConfig(QuantizationConfig):
is_per_tensor_activation = (input_quant.get("qscheme") == "per_tensor")
return is_per_tensor_activation

def _is_static_tensor_w8a8(self, weight_quant: Optional[Dict[str, Any]],
input_quant: Optional[Dict[str, Any]]) -> bool:
def _is_static_tensor_w8a8(self, weight_quant: Optional[dict[str, Any]],
input_quant: Optional[dict[str, Any]]) -> bool:
# Confirm weights and input quantized.
if weight_quant is None or input_quant is None:
return False
@@ -209,8 +209,8 @@ class QuarkConfig(QuantizationConfig):
# Only symmetric weight quantization supported.
return is_int8_dtype and is_tensor and is_weight_symmetric and is_static

def _is_mx_fp4(self, weight_quant: Optional[Dict[str, Any]],
input_quant: Optional[Dict[str, Any]]) -> bool:
def _is_mx_fp4(self, weight_quant: Optional[dict[str, Any]],
input_quant: Optional[dict[str, Any]]) -> bool:
# Confirm weights and input quantized.
if weight_quant is None or input_quant is None:
logger.debug("Quark model is not in MX-FP4 format: "
@@ -258,7 +258,7 @@ class QuarkConfig(QuantizationConfig):
return True

def _find_matched_config(self, layer_name: str,
module: torch.nn.Module) -> Dict[str, Any]:
module: torch.nn.Module) -> dict[str, Any]:

proj_name = layer_name.split(".")[-1]
if proj_name in self.packed_modules_mapping:
@@ -283,29 +283,29 @@ class QuarkConfig(QuantizationConfig):
return shard_configs[0]
else:
layer_quant_config = cast(
Dict[str, Any], self.quant_config.get("layer_quant_config"))
dict[str, Any], self.quant_config.get("layer_quant_config"))
for name_pattern in layer_quant_config:
if fnmatch.fnmatch(layer_name, name_pattern):
return layer_quant_config[name_pattern]

layer_type = cast(str, type(module))
layer_type_quant_config = cast(
Dict[str, Any],
dict[str, Any],
self.quant_config.get("layer_type_quant_config"))
if layer_type in layer_type_quant_config:
return layer_type_quant_config[layer_type]

global_quant_config = cast(
Dict[str, Any], self.quant_config.get("global_quant_config"))
dict[str, Any], self.quant_config.get("global_quant_config"))
return global_quant_config

def _get_scheme_from_config(self, config: Dict[str, Any]) -> "QuarkScheme":
def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
if config.get("output_tensors") or config.get("bias"):
raise NotImplementedError(
"Currently, Quark models with output_tensors "
"and bias quantized are not supported")
weight_config = cast(Dict[str, Any], config.get("weight"))
input_config = cast(Dict[str, Any], config.get("input_tensors"))
weight_config = cast(dict[str, Any], config.get("weight"))
input_config = cast(dict[str, Any], config.get("input_tensors"))

if self._is_fp8_w8a8(weight_config, input_config):
is_fp8_w8a8_supported = self._check_scheme_supported(
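As a quick illustration of the fnmatch-based lookup in _find_matched_config, using the "*q_proj" pattern that appears in this file (the layer name itself is hypothetical):

import fnmatch

assert fnmatch.fnmatch("model.layers.0.self_attn.q_proj", "*q_proj")
assert not fnmatch.fnmatch("model.layers.0.self_attn.k_proj", "*q_proj")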
@@ -373,7 +373,7 @@ class QuarkLinearMethod(LinearMethodBase):

def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""
@@ -417,7 +417,7 @@ class QuarkKVCacheMethod(BaseKVCacheMethod):
super().__init__(quant_config)

@staticmethod
def validate_kv_cache_config(kv_cache_config: Optional[Dict[str, Any]]):
def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]):
"""
Validator for the kv cache configuration. Useful for controlling the
kv cache quantization schemes, that are being supported in vLLM

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Optional

import torch

@@ -45,7 +45,7 @@ class QuarkMoEMethod(FusedMoEMethodBase):

class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):

def __init__(self, weight_config: Dict[str, Any], input_config: Dict[str,
def __init__(self, weight_config: dict[str, Any], input_config: dict[str,
Any]):
self.weight_quant = weight_config
self.input_quant = input_config

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional

import torch
import torch.nn.functional as F
@@ -18,8 +18,8 @@ __all__ = ["QuarkW4A4MXFP4"]

class QuarkW4A4MXFP4(QuarkScheme):

def __init__(self, weight_quant_spec: Dict[str, Any],
input_quant_spec: Dict[str, Any]):
def __init__(self, weight_quant_spec: dict[str, Any],
input_quant_spec: dict[str, Any]):
self.out_dtype = torch.get_default_dtype()
self.qscheme = "per_group"
self.weight_quant_spec = weight_quant_spec
@@ -74,7 +74,7 @@ class QuarkW4A4MXFP4(QuarkScheme):
torch.cuda.empty_cache()

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, List, Optional
from typing import Callable, Optional

import torch
from torch.nn import Parameter
@@ -88,7 +88,7 @@ class QuarkW8A8Fp8(QuarkScheme):
layer.input_scale = None

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, List, Optional, Set
from typing import Callable, Optional

import torch

@@ -17,7 +17,7 @@ logger = init_logger(__name__)


class QuarkW8A8Int8(QuarkScheme):
_kernel_backends_being_used: Set[str] = set()
_kernel_backends_being_used: set[str] = set()

def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool],
input_symmetric: Optional[bool]):
@@ -31,7 +31,7 @@ class QuarkW8A8Int8(QuarkScheme):
return 75

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

@@ -1,8 +1,9 @@
# SPDX-License-Identifier: Apache-2.0

import re
from collections.abc import Iterable, Mapping
from types import MappingProxyType
from typing import Any, Iterable, List, Mapping, Optional
from typing import Any, Optional


def deep_compare(dict1: Any, dict2: Any) -> bool:
@@ -21,7 +22,7 @@ def deep_compare(dict1: Any, dict2: Any) -> bool:
def should_ignore_layer(
layer_name: Optional[str],
ignore: Iterable[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
) -> bool:
if layer_name is None:
return False

@@ -12,7 +12,7 @@ possible on ROCm), the model can be optionally augmented with KV cache
scaling factors.
"""

from typing import Dict, Optional
from typing import Optional

from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator

@@ -23,7 +23,7 @@ class KVCacheQuantSchema(BaseModel):
# layer indices to their per-tensor KV cache scaling factor.
# TODO: Consider pulling this and its validation methods out into its
# own schema class (tricky as its members are variable)
scaling_factor: Dict[int, Dict[int, float]]
scaling_factor: dict[int, dict[int, float]]

@model_validator(mode="after")
def check_is_fp8(self) -> "KVCacheQuantSchema":

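A value of this shape might look like the following; reading the outer key as a TP rank and the inner keys as layer indices is an assumption based on the truncated comment above:

scaling_factor = {
    0: {0: 1.0, 1: 0.75},  # assumed: rank 0 -> per-layer KV cache scales
    1: {0: 1.0, 1: 0.80},  # assumed: rank 1
}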
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional

import torch
import torch.nn.functional as F
@@ -24,7 +24,7 @@ class TorchAOConfig(QuantizationConfig):
def get_name(self) -> QuantizationMethods:
return "torchao"

def get_supported_act_dtypes(self) -> List[torch.dtype]:
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.float32, torch.float16, torch.bfloat16]

@classmethod
@@ -32,11 +32,11 @@ class TorchAOConfig(QuantizationConfig):
return 75

@staticmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
return ["config.json"]

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "TorchAOConfig":
def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
"""Create the quant config from an hf model config"""
try:
from torchao.core.config import config_from_dict
@@ -60,7 +60,7 @@ class TorchAOConfig(QuantizationConfig):
return TorchAOLinearMethod(self)
return None

def get_scaled_act_names(self) -> List[str]:
def get_scaled_act_names(self) -> list[str]:
return []


@@ -97,7 +97,7 @@ class TorchAOLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional

import torch
from torch.nn import Module
@@ -31,7 +31,7 @@ class Int8TpuConfig(QuantizationConfig):
def get_name(self) -> QuantizationMethods:
return "tpu_int8"

def get_supported_act_dtypes(self) -> List[torch.dtype]:
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]

@classmethod
@@ -40,11 +40,11 @@ class Int8TpuConfig(QuantizationConfig):
"This function should not be called with TPU Backend")

@staticmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
return []

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "Int8TpuConfig":
def from_config(cls, config: dict[str, Any]) -> "Int8TpuConfig":
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
return cls(activation_scheme=activation_scheme)

@@ -62,7 +62,7 @@ class TPUInt8LinearMethod(LinearMethodBase):
self.quant_config = quant_config

def create_weights(self, layer: Module, input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):

@@ -77,7 +77,7 @@ class TPUInt8LinearMethod(LinearMethodBase):
layer.register_parameter("weight", weight)

def _quantize_weight(
self, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
self, weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
weight_dtype = weight.dtype
weight = weight.cpu().to(torch.float32)
n_bit = 8

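The body of _quantize_weight is cut off here; a self-contained sketch of symmetric n-bit weight quantization consistent with the n_bit = 8 setup above (an assumption, not the verbatim vLLM code):

import torch

def quantize_symmetric(weight: torch.Tensor, n_bit: int = 8):
    max_q = 2**(n_bit - 1) - 1                   # 127 for int8
    scale = weight.abs().amax().clamp(min=1e-8)  # avoid div-by-zero
    scale = scale / max_q
    q = torch.round(weight / scale).clamp(-max_q - 1, max_q).to(torch.int8)
    return q, scale                              # dequantize: q.float() * scale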
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional

import torch

@@ -51,7 +51,7 @@ def _check_bitblas_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]:

if device_capability is None:
capability_tuple = current_platform.get_device_capability()
@@ -133,7 +133,7 @@ def verify_bitblas_supports_shape(output_size_per_partition: int,
def check_bitblas_supports_shape(output_size_per_partition: int,
input_size_per_partition: int,
input_size: int, group_size: int) \
-> Tuple[bool, Optional[str]]:
-> tuple[bool, Optional[str]]:
try:
verify_bitblas_supports_shape(output_size_per_partition,
input_size_per_partition, input_size,
@@ -166,7 +166,7 @@ def bitblas_make_empty_zp(device: torch.device) -> torch.Tensor:


def bitblas_sort_g_idx(
g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
return g_idx[g_idx_sort_indices], g_idx_sort_indices


@@ -4,7 +4,7 @@
import functools
import json
import os
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union

import torch

@@ -32,7 +32,7 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
def apply_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
block_size: list[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
@@ -95,7 +95,7 @@ def apply_w8a8_block_fp8_linear(
def apply_w8a8_block_fp8_linear_fake(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
block_size: list[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@@ -114,7 +114,7 @@ direct_register_custom_op(
def input_to_float8(
x: torch.Tensor,
dtype: Optional[torch.dtype] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""This function quantizes input values to float8 values "
"with tensor-wise quantization."""
dtype = current_platform.fp8_dtype() if dtype is None else dtype
@ -129,7 +129,7 @@ def input_to_float8(
|
||||
def block_quant_to_tensor_quant(
|
||||
x_q_block: torch.Tensor,
|
||||
x_s: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""This function converts block-wise quantization to tensor-wise
|
||||
quantization. The inputs are block-wise quantization tensor `x_q_block`,
|
||||
block-wise quantization scale and the block size.
|
||||
@ -247,7 +247,7 @@ def per_token_group_quant_fp8(
|
||||
eps: float = 1e-10,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
column_major_scales: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Function to perform per-token-group quantization on an input tensor `x`.
|
||||
It converts the tensor values into signed float8 values and returns the
|
||||
quantized tensor along with the scaling factor used for quantization.
|
||||
@ -258,7 +258,7 @@ def per_token_group_quant_fp8(
|
||||
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
|
||||
is supported for now.
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
|
||||
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
|
||||
scaling factor for quantization.
|
||||
"""
|
||||
dtype = current_platform.fp8_dtype() if dtype is None else dtype
|
||||
@ -412,7 +412,7 @@ def _w8a8_block_fp8_matmul(
|
||||
|
||||
@functools.lru_cache
|
||||
def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
|
||||
block_k: int) -> Optional[Dict[int, Any]]:
|
||||
block_k: int) -> Optional[dict[int, Any]]:
|
||||
"""
|
||||
Return optimized configurations for the w8a8 block fp8 kernel.
|
||||
The return value will be a dictionary that maps an irregular grid of
|
||||
@ -452,7 +452,7 @@ def w8a8_block_fp8_matmul(
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
block_size: List[int],
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
"""This function performs matrix multiplication with block-wise
|
||||
|
||||
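
For readers skimming the `input_to_float8` hunk above: tensor-wise float8 quantization picks a single scale so that max(|x|) lands on the edge of the fp8 range. A rough self-contained sketch, assuming a PyTorch build that exposes torch.float8_e4m3fn (the real helper instead resolves the dtype via current_platform.fp8_dtype()):

import torch

def input_to_float8_sketch(
        x: torch.Tensor,
        dtype: torch.dtype = torch.float8_e4m3fn,
) -> tuple[torch.Tensor, torch.Tensor]:
    # One scale for the whole tensor: map max(|x|) onto the fp8 range.
    finfo = torch.finfo(dtype)
    amax = x.abs().max().clamp(min=1e-12)
    scale = finfo.max / amax
    x_q = (x.float() * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
    # Return the quantized tensor plus the scale used to dequantize it.
    return x_q, scale.reciprocal()
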
@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import re
from copy import deepcopy
from typing import Dict, Optional, Union
from typing import Optional, Union

import torch

@ -52,7 +52,7 @@ def get_dynamic_override(
layer_name: str,
key: Optional[str] = None,
default_value: Union[int, bool,
None] = None) -> Union[Dict, int, bool, None]:
None] = None) -> Union[dict, int, bool, None]:
for pattern, pattern_dict in config.dynamic.items():
# Negative match: matched modules are excluded from quantized init
if pattern.startswith("-:"):
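
The `-:` prefix convention in `get_dynamic_override` is easiest to see in isolation. A simplified, hypothetical re-implementation of the matching loop (treat it as a sketch of the idea, not the vLLM API; the real helper also handles explicit `+:` prefixes and layered defaults):

import re
from typing import Optional, Union

def dynamic_override_sketch(
        dynamic: dict[str, dict],
        layer_name: str,
        key: Optional[str] = None,
        default_value: Union[int, bool, None] = None,
) -> Union[dict, int, bool, None]:
    for pattern, pattern_dict in dynamic.items():
        # Negative match: matched modules are excluded from quantized init.
        if pattern.startswith("-:"):
            if re.match(pattern.removeprefix("-:"), layer_name):
                return False
        elif re.match(pattern.removeprefix("+:"), layer_name):
            return pattern_dict if key is None else pattern_dict.get(
                key, default_value)
    return default_value
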
@ -5,7 +5,7 @@ import functools
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional

import torch

@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
def apply_w8a8_block_int8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
block_size: list[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
@ -43,7 +43,7 @@ def apply_w8a8_block_int8_linear(

def input_to_int8(
x: torch.Tensor,
dtype: torch.dtype = torch.int8) -> Tuple[torch.Tensor, torch.Tensor]:
dtype: torch.dtype = torch.int8) -> tuple[torch.Tensor, torch.Tensor]:
"""This function quantizes input values to int8 values with
tensor-wise quantization."""
iinfo = torch.iinfo(dtype)
@ -58,7 +58,7 @@ def input_to_int8(
def block_dequant(
x_q_block: torch.Tensor,
x_s: torch.Tensor,
block_size: List[int],
block_size: list[int],
) -> torch.Tensor:
"""This function conducts block-wise dequantization.
The inputs are block-wise quantization tensor `x_q_block`,
@ -211,7 +211,7 @@ def per_token_group_quant_int8(
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.

It converts the tensor values into signed int8 values and returns the
@ -225,7 +225,7 @@ def per_token_group_quant_int8(
is supported for now.

Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
assert (x.shape[-1] % group_size == 0
@ -358,7 +358,7 @@ def _w8a8_block_int8_matmul(

@functools.lru_cache
def get_w8a8_block_int8_configs(N: int, K: int, block_n: int,
block_k: int) -> Optional[Dict[int, Any]]:
block_k: int) -> Optional[dict[int, Any]]:
"""
Return optimized configurations for the w8a8 block fp8 kernel.

@ -399,7 +399,7 @@ def w8a8_block_int8_matmul(
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise
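
`input_to_int8` is the integer analogue of the fp8 helper: one symmetric scale maps max(|x|) onto [-128, 127]. A hedged standalone sketch of that idea (simplified relative to the real function):

import torch

def input_to_int8_sketch(
        x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Symmetric tensor-wise quantization into the int8 range.
    iinfo = torch.iinfo(torch.int8)
    amax = x.abs().max().clamp(min=1e-12)
    scale = iinfo.max / amax
    x_q = (x.float() * scale).round().clamp(iinfo.min,
                                            iinfo.max).to(torch.int8)
    # With this convention, dequantize later as x_q.float() / scale.
    return x_q, scale
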
@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import List, Optional, Tuple
from typing import Optional

import torch

@ -10,19 +10,19 @@ MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128]
MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128]


def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]:
def query_machete_supported_quant_types(zero_points: bool) -> list[ScalarType]:
if zero_points:
return [scalar_types.uint4, scalar_types.uint8]
else:
return [scalar_types.uint4b8, scalar_types.uint8b128]


def query_machete_supported_act_types(zero_points: bool) -> List[ScalarType]:
def query_machete_supported_act_types(zero_points: bool) -> list[ScalarType]:
return [torch.float16, torch.bfloat16]


def check_machete_supports_shape(in_features: int, out_featrues: int) \
-> Tuple[bool, Optional[str]]:
-> tuple[bool, Optional[str]]:
if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0:
return False, "Input features size must be divisible by "\
f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}"
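
The check_*_supports_shape helpers in these files all share a (supported, failure_reason) return convention, which the new tuple[bool, Optional[str]] spelling describes directly. An illustrative caller with hypothetical names:

from typing import Optional

def check_divisible_sketch(in_features: int,
                           block: int = 64) -> tuple[bool, Optional[str]]:
    # Mirror the (supported, failure_reason) convention used above.
    if in_features % block != 0:
        return False, f"Input features size must be divisible by {block}"
    return True, None

ok, reason = check_divisible_sketch(192)
if not ok:
    raise ValueError(reason)
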
@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import List, Optional, Tuple
from typing import Optional

import numpy
import torch
@ -70,7 +70,7 @@ def _check_marlin_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]:

if device_capability is None:
capability_tuple = current_platform.get_device_capability()
@ -143,7 +143,7 @@ def verify_marlin_supports_shape(output_size_per_partition: int,
def check_marlin_supports_shape(output_size_per_partition: int,
input_size_per_partition: int,
input_size: int, group_size: int) \
-> Tuple[bool, Optional[str]]:
-> tuple[bool, Optional[str]]:
try:
verify_marlin_supports_shape(output_size_per_partition,
input_size_per_partition, input_size,
@ -231,16 +231,16 @@ def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:


def marlin_sort_g_idx(
g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
return g_idx[g_idx_sort_indices], g_idx_sort_indices


def get_scale_perms():
scale_perm: List[int] = []
scale_perm: list[int] = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single: List[int] = []
scale_perm_single: list[int] = []
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
"""Utility functions used for tests and benchmarks"""

from typing import List, Optional
from typing import Optional

import numpy as np
import torch
@ -64,9 +64,9 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm):


def get_weight_perm(num_bits: int):
perm_list: List[int] = []
perm_list: list[int] = []
for i in range(32):
perm1: List[int] = []
perm1: list[int] = []
col = i // 4
for block in [0, 1]:
for row in [
@ -2,7 +2,6 @@
"""Utility functions used for tests and benchmarks"""

import random
from typing import List

import numpy
import torch
@ -373,19 +372,19 @@ def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):


def get_scale_perms_24():
scale_perm: List[int] = []
scale_perm: list[int] = []
for i in range(8):
scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
scale_perm_single: List[int] = []
scale_perm_single: list[int] = []
for i in range(8):
scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
return scale_perm, scale_perm_single


def get_weight_perm_24(num_bits: int):
perm_list: List[int] = []
perm_list: list[int] = []
for i in range(32):
perm1: List[int] = []
perm1: list[int] = []
col = i // 4
col_o = col // 2
for block in [0, 1]:
@ -1,7 +1,5 @@
# SPDX-License-Identifier: Apache-2.0

from typing import List

import numpy
import torch

@ -34,10 +32,10 @@ def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):


def get_qqq_scale_perms():
scale_perm: List[int] = []
scale_perm: list[int] = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single: List[int] = []
scale_perm_single: list[int] = []
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
@ -46,9 +44,9 @@ def get_qqq_scale_perms():

# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
def get_qqq_weight_perm(num_bits: int, quant_type: str):
perm_list: List[int] = []
perm_list: list[int] = []
for i in range(32):
perm1: List[int] = []
perm1: list[int] = []
col = i // 4
for block in [0, 1]:
for row in [
@ -1,5 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Tuple

import torch

@ -9,7 +8,7 @@ OCP_MX_BLOCK_SIZE = 32
def per_token_group_quant_mxfp4(x: torch.Tensor,
block_k: int,
scale_calculation_mode: str = "even"
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
try:
from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
fake_quantize_fp4_fp6_per_group_with_scale)
@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
"""This file is used for /tests and /benchmarks"""
from collections.abc import Mapping
from types import MappingProxyType
from typing import List, Mapping, Optional, Tuple
from typing import Optional

import numpy
import torch
@ -15,7 +16,7 @@ SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]


# Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int,
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int,
int]):
# -1 means full extent
return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
@ -56,9 +57,9 @@ def group_broadcast(t, shape):
# (i.e. per-token-per-group)
def scaled_quantize(
x: torch.Tensor,
group_shape: Tuple[int, int],
group_shape: tuple[int, int],
quant_dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
group_shape = _normalize_quant_group_shape(x, group_shape)
assert quant_dtype.is_floating_point, \
"currently `scaled_quantize` only supports floating point dtypes " \
@ -97,9 +98,9 @@ def scaled_quantize(
def scaled_dequantize(
x_q: torch.Tensor,
x_s: torch.Tensor,
group_shape: Optional[Tuple[int, int]] = None,
group_shape: Optional[tuple[int, int]] = None,
out_dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
if group_shape is not None:
group_shape = _normalize_quant_group_shape(x_q, group_shape)

@ -173,8 +174,8 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor,

def is_layer_skipped(
prefix: str,
ignored_layers: List[str],
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
ignored_layers: list[str],
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
) -> bool:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
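
The `is_layer_skipped` signature above leans on the module-path convention spelled out in its comments. A simplified sketch of the lookup, assuming the same prefix layout (the real helper additionally validates that all shards of a fused module agree rather than quietly taking all()):

from collections.abc import Mapping
from types import MappingProxyType

def is_layer_skipped_sketch(
        prefix: str,
        ignored_layers: list[str],
        fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
) -> bool:
    # prefix looks like "model.layers.0.self_attn.q_proj".
    proj_name = prefix.split(".")[-1]
    if proj_name in fused_mapping:
        # A fused module is skipped only if every one of its shards is.
        shard_prefixes = [
            prefix.replace(proj_name, shard)
            for shard in fused_mapping[proj_name]
        ]
        return all(p in ignored_layers for p in shard_prefixes)
    return prefix in ignored_layers
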
@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Optional, Union

import torch

@ -81,7 +81,7 @@ def all_close_1d(x: torch.Tensor) -> bool:

def convert_to_channelwise(
weight_scale: torch.Tensor,
logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
# Create channelwise buffer
weight_scale_channel = torch.empty((sum(logical_widths), 1),
dtype=torch.float32,
@ -99,7 +99,7 @@ def convert_to_channelwise(

def requantize_with_max_scale(
weight: torch.Tensor, weight_scale: torch.Tensor,
logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
# Max scale to be used for requanitzation.
max_w_scale = weight_scale.max()

@ -136,7 +136,7 @@ def maybe_create_device_identity():
def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
out_dtype: torch.dtype, scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
output_shape: List, **kwargs) -> torch.Tensor:
output_shape: list, **kwargs) -> torch.Tensor:

# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(qinput,
@ -154,7 +154,7 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
output_shape: list) -> torch.Tensor:
from vllm.platforms.rocm import on_mi250_mi300
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300(
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
@ -177,7 +177,7 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
output_shape: list) -> torch.Tensor:
output = torch._scaled_mm(qinput,
weight,
out_dtype=out_dtype,
@ -198,7 +198,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor:
output_shape: list) -> torch.Tensor:
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
# when using it.
# For now it has only been validated on ROCm platform.
@ -228,7 +228,7 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: List,
output_shape: list,
**kwargs) -> torch.Tensor:
# Use unfused DQ due to limitations with scaled_mm

@ -384,7 +384,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
assert weight.dtype == torch.float8_e4m3fn
# The bits pattern 10000000(-128) represents zero in e4m3fn
# but NaN in e4m3fnuz. So here we set it to 0.
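
The last hunk here touches `normalize_e4m3fn_to_e4m3fnuz`, whose bit-pattern comment is worth a concrete illustration. A hedged sketch, assuming a PyTorch build that exposes torch.float8_e4m3fnuz (mainly relevant on ROCm); the real function also rescales an optional input_scale:

import torch

def normalize_e4m3fn_to_e4m3fnuz_sketch(
        weight: torch.Tensor,
        weight_scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    assert weight.dtype == torch.float8_e4m3fn
    # Reinterpret the raw bits in place; 0b10000000 encodes -0.0 in
    # e4m3fn but NaN in e4m3fnuz, so zero it out before the view.
    bits = weight.view(torch.int8)
    bits[bits == -128] = 0
    weight_fnuz = bits.view(torch.float8_e4m3fnuz)
    # e4m3fnuz values are half as large for the same bits, so the
    # scale doubles to compensate.
    return weight_fnuz, weight_scale * 2.0
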
@ -2,7 +2,7 @@

from functools import cached_property
from importlib.util import find_spec
from typing import Dict, Optional, Tuple
from typing import Optional

import torch
import torch.jit
@ -65,7 +65,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
seeded_seqs: Optional[dict[int, torch.Generator]] = None,
) -> torch.Tensor:
"""Sample token ids using rejection sampling. This accepts or rejects
tokens proposed by the draft model using the probability of each token
@ -161,8 +161,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
seeded_seqs: Optional[Dict[int, torch.Generator]],
) -> Tuple[torch.Tensor, torch.Tensor]:
seeded_seqs: Optional[dict[int, torch.Generator]],
) -> tuple[torch.Tensor, torch.Tensor]:
"""Perform modified rejection sampling on each sequence.

Returns:
@ -194,7 +194,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
return accepted, recovered_token_ids

def _create_uniform_samples(self,
seeded_seqs: Optional[Dict[int,
seeded_seqs: Optional[dict[int,
torch.Generator]],
batch_size: int, k: int,
device: torch.device) -> torch.Tensor:
@ -210,7 +210,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
a seed.

Args:
seeded_seqs : Optional[Dict[int, torch.Generator]]
seeded_seqs : Optional[dict[int, torch.Generator]]
A dictionary mapping indices in the batch to
`torch.Generator` objects. If `None`, all samples are
generated without a seed.
@ -255,7 +255,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
seeded_seqs: Optional[Dict[int, torch.Generator]],
seeded_seqs: Optional[dict[int, torch.Generator]],
) -> torch.Tensor:
r"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
@ -379,7 +379,7 @@ def _multinomial(
probs: torch.Tensor,
num_samples: int,
k: int,
seeded_seqs: Dict[int, torch.Generator],
seeded_seqs: dict[int, torch.Generator],
) -> torch.Tensor:

if num_samples > 1:
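
For context on the class being retyped here: the heart of rejection sampling in speculative decoding is the rule "accept draft token x with probability min(1, p_target(x) / p_draft(x))". A minimal batched sketch of just that acceptance step, leaving out the seeded generators and recovered-token resampling the real sampler layers on top:

import torch

def accept_draft_tokens_sketch(
        target_probs: torch.Tensor,    # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,     # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor  # [batch_size, k]
) -> torch.Tensor:
    # Gather each proposed token's probability under both models.
    idx = draft_token_ids.unsqueeze(-1)
    p_tgt = target_probs.gather(-1, idx).squeeze(-1)
    p_drf = draft_probs.gather(-1, idx).squeeze(-1)
    # Accept with probability min(1, p_tgt / p_drf).
    u = torch.rand_like(p_tgt)
    return u < (p_tgt / torch.clamp(p_drf, min=1e-10))
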
@ -33,7 +33,7 @@ Example models: Qwen (Qwen-VL), MiniCPM-V 2.0
"""
import math
from functools import partial
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Optional, Union

import numpy as np
import torch
@ -69,7 +69,7 @@ def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor,
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_1d_sincos_pos_embed_from_grid(
embed_dim: int, pos: np.ndarray,
version: Tuple[int, int] = (2, 0)) -> torch.Tensor:
version: tuple[int, int] = (2, 0)) -> torch.Tensor:
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,) / (H, W)
@ -96,7 +96,7 @@ def get_1d_sincos_pos_embed_from_grid(

def get_2d_sincos_pos_embed_from_grid(
embed_dim: int, grid: np.ndarray,
version: Tuple[int, int] = (2, 0)) -> torch.Tensor:
version: tuple[int, int] = (2, 0)) -> torch.Tensor:
assert embed_dim % 2 == 0

# use half of dimensions to encode grid_h
@ -114,9 +114,9 @@ def get_2d_sincos_pos_embed_from_grid(

def get_2d_sincos_pos_embed(
embed_dim: int,
grid_size: Union[int, Tuple[int, int]],
grid_size: Union[int, tuple[int, int]],
cls_token: bool = False,
version: Tuple[int, int] = (2, 0),
version: tuple[int, int] = (2, 0),
) -> torch.Tensor:
"""
grid_size: int of the grid height and width
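
The sincos helpers above follow the MAE reference implementation linked in the diff: half of the channels carry sin and half carry cos of the position times a geometric frequency ladder. A compact NumPy sketch of the 1D case:

import numpy as np

def sincos_pos_embed_1d_sketch(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    assert embed_dim % 2 == 0
    # Geometric ladder of frequencies, one per sin/cos channel pair.
    omega = 1.0 / 10000**(np.arange(embed_dim // 2) / (embed_dim / 2.0))
    out = np.einsum('m,d->md', pos.reshape(-1), omega)  # (M, D/2)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (M, D)
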
@ -23,7 +23,7 @@
# limitations under the License.
"""Rotary Positional Embeddings."""
import math
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union

import torch
import torch.nn as nn
@ -140,7 +140,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""A PyTorch-native implementation of forward()."""
if offsets is not None:
positions = positions + offsets
@ -174,7 +174,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
from vllm import _custom_ops as ops

# __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
@ -202,7 +202,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
from vllm._ipex_ops import ipex_ops as ops

self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
@ -232,7 +232,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode, apply_rotary_pos_emb)
if offsets is not None:
@ -290,7 +290,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:

def _apply_rotary_emb_neuron(
x: torch.Tensor,
@ -406,23 +406,23 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factors: Union[List[float], float],
scaling_factors: Union[list[float], float],
dtype: torch.dtype,
) -> None:
if isinstance(scaling_factors, float):
scaling_factors = [scaling_factors]
self.scaling_factors: List[float] = scaling_factors # noqa
self.scaling_factors: list[float] = scaling_factors # noqa
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
# Lazy initialized.
self._scaling_factor_to_offset: Dict[float, int]
self._scaling_factor_to_offset: dict[float, int]

def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.base)
cache_list: List[torch.Tensor] = []
cache_list: list[torch.Tensor] = []
# offsets to the next cache in a tensor.
# Each offset corresponds to the same index in scaling_factors.
offsets: List[int] = []
offsets: list[int] = []
for scaling_factor in self.scaling_factors:
# NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling.
@ -452,7 +452,7 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
return torch.cat(cache_list, dim=0)

@property
def scaling_factor_to_offset(self) -> Dict[float, int]:
def scaling_factor_to_offset(self) -> dict[float, int]:
return self._scaling_factor_to_offset


@ -512,7 +512,7 @@ def _yarn_find_correction_range(
high_rot: int,
dim: int,
base: float = 10000,
max_position_embeddings: int = 2048) -> Tuple[int, int]:
max_position_embeddings: int = 2048) -> tuple[int, int]:
low = math.floor(
_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(
@ -613,8 +613,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
base: int,
is_neox_style: bool,
dtype: torch.dtype,
short_factor: List[float],
long_factor: List[float],
short_factor: list[float],
long_factor: list[float],
short_mscale: Optional[float] = None,
long_mscale: Optional[float] = None,
):
@ -662,7 +662,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
long_short_cache,
persistent=False)

def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
def _compute_inv_freq(self, rescale_factors: list[float]) -> torch.Tensor:
rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)))
@ -671,7 +671,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
def _compute_cos_sin_cache(
self,
max_position_embeddings: int,
rescale_factors: List[float],
rescale_factors: list[float],
mscale: float,
) -> torch.Tensor:
inv_freq = self._compute_inv_freq(rescale_factors)
@ -688,7 +688,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert key is not None
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)
@ -799,7 +799,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
assert key is not None
query_rot = query[..., :self.rotary_dim]
@ -930,7 +930,7 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
self,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert key is not None
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
query_ = torch.view_as_complex(query.float().reshape(
@ -958,7 +958,7 @@ class MRotaryEmbedding(RotaryEmbedding):
base: int,
is_neox_style: bool,
dtype: torch.dtype,
mrope_section: Optional[List[int]] = None,
mrope_section: Optional[list[int]] = None,
) -> None:
# In Qwen2.5-VL, the maximum index value is related to the duration of
# the input video. We enlarge max_position_embeddings to 4 times to get
@ -976,7 +976,7 @@ class MRotaryEmbedding(RotaryEmbedding):
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward().

Args:
@ -1024,16 +1024,16 @@ class MRotaryEmbedding(RotaryEmbedding):
@classmethod
def get_input_positions(
cls,
input_tokens: List[int],
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Optional[Union[List[List[int]], torch.Tensor]],
video_grid_thw: Optional[Union[List[List[int]], torch.Tensor]],
second_per_grid_ts: Optional[List[float]],
image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
second_per_grid_ts: Optional[list[float]],
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> Tuple[List[List[int]], int]:
) -> tuple[list[list[int]], int]:
"""Get mrope input positions and delta value."""

image_grid_thw = [] if image_grid_thw is None else image_grid_thw
@ -1059,16 +1059,16 @@ class MRotaryEmbedding(RotaryEmbedding):
@classmethod
def get_input_positions_tensor(
cls,
input_tokens: List[int],
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
second_per_grid_ts: List[float],
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
second_per_grid_ts: list[float],
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> Tuple[torch.Tensor, int]:
) -> tuple[torch.Tensor, int]:
from vllm.transformers_utils.config import thinker_uses_mrope
if thinker_uses_mrope(hf_config):
return cls._omni_get_input_positions_tensor(
@ -1096,14 +1096,14 @@ class MRotaryEmbedding(RotaryEmbedding):
@classmethod
def _vl_get_input_positions_tensor(
cls,
input_tokens: List[int],
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
second_per_grid_ts: List[float],
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
second_per_grid_ts: list[float],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> Tuple[torch.Tensor, int]:
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value."""

image_token_id = hf_config.image_token_id
@ -1195,16 +1195,16 @@ class MRotaryEmbedding(RotaryEmbedding):
@classmethod
def _omni_get_input_positions_tensor(
cls,
input_tokens: List[int],
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
second_per_grid_ts: Optional[List[float]] = None,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
second_per_grid_ts: Optional[list[float]] = None,
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> Tuple[torch.Tensor, int]:
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value (Qwen2.5-Omni version).

Differences from MRotaryEmbedding:
@ -1329,7 +1329,7 @@ class MRotaryEmbedding(RotaryEmbedding):
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
pure_audio_len = place_num - 2
added_audio_len = 0
audio_llm_pos_ids_list: List[torch.Tensor] = []
audio_llm_pos_ids_list: list[torch.Tensor] = []
for t_chunk in t_index_split_chunk:
vision_ntoken_per_chunk = len(
t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
@ -1382,7 +1382,7 @@ class MRotaryEmbedding(RotaryEmbedding):
start_idx: int,
vision_idx: int,
spatial_merge_size: int,
t_index: List[int],
t_index: list[int],
grid_hs: torch.Tensor,
grid_ws: torch.Tensor,
) -> torch.Tensor:
@ -1402,8 +1402,8 @@ class MRotaryEmbedding(RotaryEmbedding):

@staticmethod
def _split_list_into_ranges(lst: torch.Tensor,
interval: int) -> List[List[int]]:
ranges: List[List[int]] = [[]
interval: int) -> list[list[int]]:
ranges: list[list[int]] = [[]
for _ in range((max(lst) // interval) + 1)]
for num in lst:
index = num // interval
@ -1415,7 +1415,7 @@ class MRotaryEmbedding(RotaryEmbedding):
mrope_position_delta: int,
context_len: int,
seq_len: int,
) -> List[List[int]]:
) -> list[list[int]]:
return [
list(
range(context_len + mrope_position_delta,
@ -1438,9 +1438,9 @@ class MRotaryEmbedding(RotaryEmbedding):
cls,
thinker_config: PretrainedConfig,
audio_len: int,
video_grid_thw: Union[List[int], torch.Tensor],
video_grid_thw: Union[list[int], torch.Tensor],
video_second_per_grid_t: float,
) -> List[int]:
) -> list[int]:
"""Get video prompt updates when `use_audio_in_video` is True.

In this case, audio and vision update ids will be split into
@ -1593,7 +1593,7 @@ class DualChunkRotaryEmbedding(CustomOp):
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
@ -1664,7 +1664,7 @@ class DualChunkRotaryEmbedding(CustomOp):
return s


_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
_ROPE_DICT: dict[tuple, RotaryEmbedding] = {}


def get_rope(
@ -1673,10 +1673,10 @@ def get_rope(
max_position: int,
base: int,
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0,
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
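
As a reference point for the many forward_* variants retyped above: applying rotary embeddings only ever rotates the first rotary_dim channels of each head. A minimal GPT-NeoX-style sketch (illustrative only; it sidesteps the cos/sin cache and CustomOp dispatch the real layer uses):

import torch

def apply_rope_sketch(x: torch.Tensor,
                      positions: torch.Tensor,
                      rotary_dim: int,
                      base: float = 10000.0) -> torch.Tensor:
    # x: [num_tokens, num_heads, head_size]; rotate the leading dims.
    inv_freq = 1.0 / (base**(
        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
    freqs = positions.float().unsqueeze(-1) * inv_freq
    cos, sin = freqs.cos()[:, None, :], freqs.sin()[:, None, :]
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)  # NeoX-style half split
    rotated = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return torch.cat((rotated, x_pass), dim=-1)

q = apply_rope_sketch(torch.randn(4, 8, 64), torch.arange(4), rotary_dim=64)
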
@ -2,10 +2,11 @@
"""A layer that samples the next tokens from the model's outputs."""
import itertools
import warnings
from collections.abc import Iterator
from dataclasses import dataclass
from importlib.util import find_spec
from math import inf
from typing import Dict, Iterator, List, Optional, Tuple, Union
from typing import Optional, Union

import msgspec
import torch
@ -42,14 +43,14 @@ def get_sampler() -> torch.nn.Module:


# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]]
SampleResultType = list[tuple[list[int], list[int]]]

# Types of temporary data structures used for
# computing sample_result
SampleMetadataType = Dict[SamplingType, Tuple[List[int],
List[SequenceGroupToSample]]]
MultinomialSamplesType = Dict[SamplingType, torch.Tensor]
SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]]
SampleMetadataType = dict[SamplingType, tuple[list[int],
list[SequenceGroupToSample]]]
MultinomialSamplesType = dict[SamplingType, torch.Tensor]
SampleResultsDictType = dict[int, tuple[list[int], list[int]]]


# Encapsulates temporary data structures for computing
@ -76,7 +77,7 @@ class SampleResultArgsType:
MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType]

# Abbreviation of the _sample() return type
SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]]
SampleReturnType = tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]]


class SamplerOutput(
@ -90,7 +91,7 @@ class SamplerOutput(
also has optional fields for device tensors.
"""

outputs: List[CompletionSequenceGroupOutput]
outputs: list[CompletionSequenceGroupOutput]

# On-device tensor containing probabilities of each token.
sampled_token_probs: Optional[torch.Tensor] = None
@ -350,7 +351,7 @@ def _apply_min_tokens_penalty(
have not been generated yet
"""
# list of indices in logits that will be set to -inf
logits_to_penalize: List[Tuple[int, int]] = []
logits_to_penalize: list[tuple[int, int]] = []
logits_applied = 0
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
@ -366,7 +367,7 @@ def _apply_min_tokens_penalty(
min_tokens = sampling_params.min_tokens
token_ids_to_penalize = sampling_params.all_stop_token_ids
if min_tokens > 0 and token_ids_to_penalize:
seqs_to_penalize: List[int] = []
seqs_to_penalize: list[int] = []
for j, seq_id in enumerate(seq_ids):
seq_data = seq_group.seq_data[seq_id]
if len(seq_data.output_token_ids_array) < min_tokens:
@ -436,7 +437,7 @@ def _apply_min_p(


def _greedy_sample(
selected_seq_groups: List[SequenceGroupToSample],
selected_seq_groups: list[SequenceGroupToSample],
samples: torch.Tensor,
) -> SampleResultType:
"""Run greedy sampling on a given samples.
@ -471,7 +472,7 @@ def _greedy_sample(


def _random_sample(
selected_seq_groups: List[SequenceGroupToSample],
selected_seq_groups: list[SequenceGroupToSample],
random_samples: torch.Tensor,
) -> SampleResultType:
"""Run random sampling on a given samples.
@ -522,7 +523,7 @@ def _random_sample(
def _multinomial(
probs: torch.Tensor,
num_samples: int,
seq_groups: Optional[List[SequenceGroupToSample]] = None,
seq_groups: Optional[list[SequenceGroupToSample]] = None,
) -> torch.Tensor:
if num_samples > 1:
probs = probs.repeat_interleave(num_samples, dim=0)
@ -543,7 +544,7 @@ def _multinomial(

def _top_k_top_p_multinomial_with_flashinfer(
probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
num_samples: int, seq_groups: Optional[list[SequenceGroupToSample]]):
max_top_k_round = 32
if num_samples > 1:
probs = probs.repeat_interleave(num_samples, dim=0)
@ -648,7 +649,7 @@ def _sample_with_torch(
tensors required for Pythonization
'''

categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
categorized_seq_group_ids: dict[SamplingType, list[int]] = {
t: []
for t in SamplingType
}
@ -812,7 +813,7 @@ def get_logprobs(
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sample_results: SampleResultType,
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
) -> tuple[list[Optional[PromptLogprobs]], list[SampleLogprobs]]:
"""Return sample logprobs and prompt logprobs.

The logic consists of 3 parts.
@ -841,9 +842,9 @@ def get_logprobs(
"""
# The index of query token to calculate logprobs. It includes both
# prompt and sample logprob indices.
query_indices: List[int] = []
query_indices: list[int] = []
# The next token ids to get the logprob value from.
next_token_ids: List[int] = []
next_token_ids: list[int] = []
# The largest requested number of logprobs. We find logprobs as many as the
# largest num logprobs in this API. If every logprobs is None, it will be
# set to -1.
@ -925,8 +926,8 @@ def get_logprobs(
ranks = ranks.to('cpu')

# Find prompt/sample logprobs.
prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
sample_logprobs_per_seq_group: List[SampleLogprobs] = []
prompt_logprobs_per_seq_group: list[Optional[PromptLogprobs]] = []
sample_logprobs_per_seq_group: list[SampleLogprobs] = []
top_logprob_idx = 0
selected_logprobs_idx = 0

@ -977,7 +978,7 @@ def _get_prompt_logprob_if_needed(
for idx, token_id in enumerate(next_prompt_tokens):
# Calculate the prompt logprob of the real prompt tokens.
# {token_id: (logprob, rank_from_vocab)}
prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
prompt_logprobs_dict: dict[int, tuple[float, int]] = {
token_id: (selected_logprob_items[idx], rank_items[idx])
}

@ -1009,7 +1010,7 @@ def _get_prompt_logprob_if_needed(

def _get_sampled_logprob_if_needed(
seq_group: SequenceGroupToSample,
sample_result: Tuple[List[int], List[int]],
sample_result: tuple[list[int], list[int]],
selected_logprobs: torch.Tensor,
ranks: torch.Tensor,
top_token_ids: torch.Tensor,
@ -1130,9 +1131,9 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
def _build_sampler_output(
maybe_deferred_sample_results: MaybeDeferredSampleResultType,
sampling_metadata: SamplingMetadata,
prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
sample_logprobs: Optional[List[SampleLogprobs]],
on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
prompt_logprobs: Optional[list[Optional[PromptLogprobs]]],
sample_logprobs: Optional[list[SampleLogprobs]],
on_device_tensors: Optional[tuple[torch.Tensor, torch.Tensor,
torch.Tensor]],
skip_sampler_cpu_output: bool = False,
) -> SamplerOutput:
@ -1144,7 +1145,7 @@ def _build_sampler_output(
allows post-processing without copies to CPU/serialization, e.g. in
speculative decoding rejection sampling.
"""
sampler_output: List[CompletionSequenceGroupOutput] = []
sampler_output: list[CompletionSequenceGroupOutput] = []

if skip_sampler_cpu_output:
assert isinstance(maybe_deferred_sample_results, SampleResultArgsType)
@ -1166,7 +1167,7 @@ def _build_sampler_output(
prompt_logprobs, sample_logprobs):
seq_ids = seq_group.seq_ids
next_token_ids, parent_ids = sample_result
seq_outputs: List[SequenceOutput] = []
seq_outputs: list[SequenceOutput] = []
for parent_id, next_token_id, logprobs in zip(
parent_ids, next_token_ids, group_sample_logprobs):
seq_outputs.append(
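
The sampler hunks are dominated by type-alias rewrites. These aliases are ordinary runtime objects, so the new spelling can be exercised directly; a tiny illustrative check with simplified data (not the real vLLM aliases):

SampleResultTypeSketch = list[tuple[list[int], list[int]]]

sample_results: SampleResultTypeSketch = [([5, 7], [0, 0]), ([9], [1])]
for token_ids, parent_ids in sample_results:
    assert len(token_ids) == len(parent_ids)
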
@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from abc import abstractmethod
from typing import Dict, Optional, Union
from typing import Optional, Union

import torch
import torch.jit
@ -253,6 +253,6 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
seeded_seqs: Optional[dict[int, torch.Generator]] = None,
) -> torch.Tensor:
raise NotImplementedError
@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""Utility methods for model layers."""
from typing import Callable, Optional, Tuple
from typing import Callable, Optional

import torch

@ -13,7 +13,7 @@ def get_token_bin_counts_and_mask(
tokens: torch.Tensor,
vocab_size: int,
num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Compute the bin counts for the tokens.
# vocab_size + 1 for padding.
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Sequence
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple
from typing import Optional

import torch
import torch.nn.functional as F
@ -25,7 +26,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):

def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""Create weights for embedding layer."""
@ -141,7 +142,7 @@ def get_masked_input_and_mask(
input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
added_vocab_end_index: int) -> tuple[torch.Tensor, torch.Tensor]:
# torch.compile will fuse all of the pointwise ops below
# into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & (
@ -298,7 +299,7 @@ class VocabParallelEmbedding(torch.nn.Module):
org_vocab_start_index, org_vocab_end_index,
added_vocab_start_index, added_vocab_end_index)

def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
def get_sharded_to_full_mapping(self) -> Optional[list[int]]:
"""Get a mapping that can be used to reindex the gathered
logits for sampling.

@ -312,9 +313,9 @@ class VocabParallelEmbedding(torch.nn.Module):
if self.tp_size < 2:
return None

base_embeddings: List[int] = []
added_embeddings: List[int] = []
padding: List[int] = []
base_embeddings: list[int] = []
added_embeddings: list[int] = []
padding: list[int] = []
for tp_rank in range(self.tp_size):
shard_indices = self._get_indices(self.num_embeddings_padded,
self.org_vocab_size_padded,
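
Taken together, every hunk in this commit is the same mechanical rewrite, and it is behavior-preserving: the old typing aliases and the builtin generics reduce to identical origin and argument tuples at runtime. A quick illustrative check:

import typing

old = typing.List[int]
new = list[int]
assert typing.get_origin(old) is list and typing.get_origin(new) is list
assert typing.get_args(old) == typing.get_args(new) == (int,)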