98 lines
3.1 KiB
Python

# SPDX-License-Identifier: Apache-2.0
from math import prod
from typing import Optional
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from vllm.utils import cdiv
def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
    """
    Take the first ``prod(v)`` elements of ``x`` and reshape them to ``v``.

    Used to carve the fused_moe intermediate caches down to the size a call
    actually needs. ``x`` must contain at least ``prod(v)`` elements.
    ``x.flatten()`` returns a view only when ``x`` is contiguous, so the
    result shares storage with ``x`` in that case.
    """
    wanted = prod(v)
    # NOTE: this host-side size check may be CUDA-graph unfriendly.
    assert wanted <= x.numel(), f"{wanted} <= {x.numel()}"
    flat = x.flatten()
    head = flat[:wanted]
    return head.view(*v)
def _fp8_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    per_act_token: bool,
    block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize the activation tensor ``A`` to fp8.

    * ``block_shape`` is None: delegates to ``ops.scaled_fp8_quant``,
      forwarding ``A_scale`` and setting ``use_per_token_if_dynamic`` from
      ``per_act_token``.
    * ``block_shape`` given (must have two entries): quantizes groups of
      ``block_shape[1]`` elements along the last dimension with
      ``per_token_group_quant_fp8``. ``A_scale`` and ``per_act_token`` are
      not used on this path, and ``block_shape[0]`` is ignored.

    Returns ``(quantized_A, scales)``.
    """
    if block_shape is not None:
        assert len(block_shape) == 2
        group_k = block_shape[1]
        quantized, scales = per_token_group_quant_fp8(A, group_k)
        # Expect one scale per group: cdiv(K, group_k) along the last dim.
        assert cdiv(quantized.size(-1), group_k) == scales.size(-1)
        return quantized, scales

    quantized, scales = ops.scaled_fp8_quant(
        A, A_scale, use_per_token_if_dynamic=per_act_token)
    return quantized, scales
def _int8_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    per_act_token: bool,
    block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize the activation tensor ``A`` to int8.

    * ``block_shape`` given (must have two entries): quantizes groups of
      ``block_shape[1]`` elements along the last dimension with
      ``per_token_group_quant_int8``. ``block_shape[0]`` is ignored.
    * ``block_shape`` is None: per-token quantization with
      ``per_token_quant_int8``. ``per_act_token`` must be True here, because
      there is no tensor-wise int8 mode.

    ``A_scale`` is not used by either path.

    Returns ``(quantized_A, scales)``.
    """
    if block_shape is not None:
        assert len(block_shape) == 2
        group_k = block_shape[1]
        quantized, scales = per_token_group_quant_int8(A, group_k)
        # Expect one scale per group: cdiv(K, group_k) along the last dim.
        assert cdiv(quantized.size(-1), group_k) == scales.size(-1)
        return quantized, scales

    # Non-blocked int8 only supports per-token (channel-wise) activations.
    assert per_act_token, \
        "int8 quantization only supports block or channel-wise"
    quantized, scales = per_token_quant_int8(A)
    return quantized, scales
def moe_kernel_quantize_input(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
qtype: Optional[torch.dtype],
per_channel_quant: bool,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if qtype == torch.float8_e4m3fn:
return _fp8_quantize(A, A_scale, per_channel_quant, block_shape)
elif qtype == torch.int8:
return _int8_quantize(A, A_scale, per_channel_quant, block_shape)
else:
assert A_scale is None
return A, A_scale
def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
"""
A permutation routine that works on fp8 types.
"""
if torch.is_floating_point(m) and m.dtype.itemsize == 1:
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
else:
return m[idx, ...]