# SPDX-License-Identifier: Apache-2.0
from math import prod
from typing import Optional

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    per_token_group_quant_int8, per_token_quant_int8)
from vllm.utils import cdiv


def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
    """
    Shrink the given tensor and apply the given view to it. This is
    used to resize the intermediate fused_moe caches.
    """
    assert prod(v) <= x.numel(), \
        f"{prod(v)} <= {x.numel()}"  # CUDAGRAPH unfriendly?
    return x.flatten()[:prod(v)].view(*v)
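
# Illustrative usage sketch (an assumption for illustration, not part of the
# upstream file): carving a shaped view out of a preallocated fused_moe
# workspace. The tensor names and shapes below are made up.
#
#   workspace = torch.empty(1 << 20, dtype=torch.bfloat16, device="cuda")
#   cache1 = _resize_cache(workspace, (num_tokens, top_k, hidden_size))
#
# For a contiguous workspace this should return a view, so no new memory is
# allocated; the assert above only guards that the requested shape fits.

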
def _fp8_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    per_act_token: bool,
    block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Perform fp8 quantization on the inputs. If a block_shape
    is provided, the output will be blocked.
    """
    if block_shape is None:
        A, A_scale = ops.scaled_fp8_quant(
            A, A_scale, use_per_token_if_dynamic=per_act_token)
    else:
        assert len(block_shape) == 2
        _, block_k = block_shape[0], block_shape[1]
        A, A_scale = per_token_group_quant_fp8(A, block_k)
        assert cdiv(A.size(-1), block_k) == A_scale.size(-1)

    return A, A_scale
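
# Illustrative usage sketch (not part of the upstream file). With a
# two-element block_shape only block_k (the second entry) matters here: the
# hidden dimension is quantized in groups of block_k, so A_scale has
# cdiv(hidden_size, block_k) entries per token. The tensor names below are
# made up.
#
#   A_q, A_scale = _fp8_quantize(hidden_states, None, per_act_token=False,
#                                block_shape=[128, 128])
#   # A_q has dtype torch.float8_e4m3fn and the same shape as hidden_states;
#   # A_scale holds one scale per 128-wide group along the last dimension.

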
def _int8_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    per_act_token: bool,
    block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Perform int8 quantization on the inputs. If a block_shape
    is provided, the output will be blocked.
    """

    # If weights are per-channel (per_channel_quant=True), then
    # activations apply per-token quantization. Otherwise, assume
    # activation tensor-wise fp8/int8 quantization, dynamic or static.
    if block_shape is None:
        assert per_act_token, \
            "int8 quantization only supports block or channel-wise"
        A, A_scale = per_token_quant_int8(A)
    else:
        assert len(block_shape) == 2
        _, block_k = block_shape[0], block_shape[1]
        A, A_scale = per_token_group_quant_int8(A, block_k)
        assert cdiv(A.size(-1), block_k) == A_scale.size(-1)

    return A, A_scale
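
# Illustrative usage sketch (not part of the upstream file): there is no
# tensor-wise int8 path here, so callers either pass per_act_token=True
# (one scale per token) or a block_shape (one scale per group of block_k
# elements). The tensor name below is made up.
#
#   A_q, A_scale = _int8_quantize(hidden_states, None, per_act_token=True)
#   # A_q has dtype torch.int8; A_scale holds one scale per row of
#   # hidden_states.

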
def moe_kernel_quantize_input(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    qtype: Optional[torch.dtype],
    per_channel_quant: bool,
    block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Quantize the input activations for a fused MoE kernel. fp8 and int8
    requests are dispatched to the helpers above; any other (or None) qtype
    passes the input through unchanged.
    """
    if qtype == torch.float8_e4m3fn:
        return _fp8_quantize(A, A_scale, per_channel_quant, block_shape)
    elif qtype == torch.int8:
        return _int8_quantize(A, A_scale, per_channel_quant, block_shape)
    else:
        assert A_scale is None
        return A, A_scale
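
# Illustrative usage sketch (not part of the upstream file): a fused_moe
# kernel would call this once per activation tensor, passing the dtype its
# weights are quantized to. The tensor name below is made up.
#
#   A_q, A_scale = moe_kernel_quantize_input(
#       hidden_states, None, qtype=torch.float8_e4m3fn,
#       per_channel_quant=False, block_shape=[128, 128])
#
# With qtype=None the activations are returned unchanged and A_scale stays
# None.

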
def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """
    A permutation routine that works on fp8 types.
    """
    if torch.is_floating_point(m) and m.dtype.itemsize == 1:
        return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
    else:
        return m[idx, ...]
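
# Illustrative usage sketch (not part of the upstream file): gathering rows of
# an fp8 tensor in expert-sorted order. Fancy indexing is not supported for
# float8 dtypes on every torch build, which is presumably why the routine
# round-trips 1-byte floating-point tensors through a uint8 view. The names
# below are made up.
#
#   sorted_ids = torch.argsort(topk_ids.flatten())
#   permuted = _fp8_perm(A_q, sorted_ids)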