mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 02:09:08 +08:00
remove maybe_create_device_identity
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
65ecf487ad
commit
405d2802c6
@ -40,9 +40,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|||||||
QuantKey,
|
QuantKey,
|
||||||
ScaleDesc,
|
ScaleDesc,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|
||||||
maybe_create_device_identity,
|
|
||||||
)
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from ..utils import TestFP8Layer
|
from ..utils import TestFP8Layer
|
||||||
@ -191,7 +188,6 @@ def test_fusion_rmsnorm_quant(
|
|||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
torch.set_default_dtype(dtype)
|
torch.set_default_dtype(dtype)
|
||||||
torch.manual_seed(1)
|
torch.manual_seed(1)
|
||||||
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
|
|
||||||
|
|
||||||
custom_ops = []
|
custom_ops = []
|
||||||
if enable_rms_norm_custom_op:
|
if enable_rms_norm_custom_op:
|
||||||
|
|||||||
@ -28,9 +28,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|||||||
kFp8StaticTensorSym,
|
kFp8StaticTensorSym,
|
||||||
kNvfp4Quant,
|
kNvfp4Quant,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|
||||||
maybe_create_device_identity,
|
|
||||||
)
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from ..utils import TestFP8Layer, override_cutlass_fp8_supported
|
from ..utils import TestFP8Layer, override_cutlass_fp8_supported
|
||||||
@ -157,7 +154,6 @@ def test_fusion_silu_and_mul_quant(
|
|||||||
|
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
torch.set_default_dtype(dtype)
|
torch.set_default_dtype(dtype)
|
||||||
maybe_create_device_identity()
|
|
||||||
|
|
||||||
x = torch.rand(num_tokens, hidden_size * 2)
|
x = torch.rand(num_tokens, hidden_size * 2)
|
||||||
|
|
||||||
|
|||||||
@ -34,7 +34,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
cutlass_block_fp8_supported,
|
cutlass_block_fp8_supported,
|
||||||
maybe_create_device_identity,
|
|
||||||
)
|
)
|
||||||
from vllm.model_executor.parameter import (
|
from vllm.model_executor.parameter import (
|
||||||
BlockQuantScaleParameter,
|
BlockQuantScaleParameter,
|
||||||
@ -108,8 +107,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
weight_loader: Callable,
|
weight_loader: Callable,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
maybe_create_device_identity()
|
|
||||||
|
|
||||||
output_size_per_partition = sum(output_partition_sizes)
|
output_size_per_partition = sum(output_partition_sizes)
|
||||||
layer.logical_widths = output_partition_sizes
|
layer.logical_widths = output_partition_sizes
|
||||||
layer.weight_block_size = None
|
layer.weight_block_size = None
|
||||||
|
|||||||
@ -31,7 +31,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|||||||
kFp8StaticTokenSym,
|
kFp8StaticTokenSym,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
maybe_create_device_identity,
|
|
||||||
normalize_e4m3fn_to_e4m3fnuz,
|
normalize_e4m3fn_to_e4m3fnuz,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.parameter import (
|
from vllm.model_executor.parameter import (
|
||||||
@ -112,7 +111,6 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
|||||||
params_dtype: torch.dtype,
|
params_dtype: torch.dtype,
|
||||||
**extra_weight_attrs,
|
**extra_weight_attrs,
|
||||||
):
|
):
|
||||||
maybe_create_device_identity()
|
|
||||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||||
del input_size, output_size
|
del input_size, output_size
|
||||||
output_size_per_partition = sum(output_partition_sizes)
|
output_size_per_partition = sum(output_partition_sizes)
|
||||||
|
|||||||
@ -86,7 +86,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|||||||
all_close_1d,
|
all_close_1d,
|
||||||
cutlass_block_fp8_supported,
|
cutlass_block_fp8_supported,
|
||||||
cutlass_fp8_supported,
|
cutlass_fp8_supported,
|
||||||
maybe_create_device_identity,
|
|
||||||
normalize_e4m3fn_to_e4m3fnuz,
|
normalize_e4m3fn_to_e4m3fnuz,
|
||||||
per_tensor_dequantize,
|
per_tensor_dequantize,
|
||||||
)
|
)
|
||||||
@ -416,8 +415,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
params_dtype: torch.dtype,
|
params_dtype: torch.dtype,
|
||||||
**extra_weight_attrs,
|
**extra_weight_attrs,
|
||||||
):
|
):
|
||||||
maybe_create_device_identity()
|
|
||||||
|
|
||||||
output_size_per_partition = sum(output_partition_sizes)
|
output_size_per_partition = sum(output_partition_sizes)
|
||||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||||
layer.logical_widths = output_partition_sizes
|
layer.logical_widths = output_partition_sizes
|
||||||
|
|||||||
@ -14,17 +14,6 @@ from .ScaledMMLinearKernel import (
|
|||||||
FP8ScaledMMLinearLayerConfig,
|
FP8ScaledMMLinearLayerConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
|
||||||
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
|
||||||
TORCH_DEVICE_IDENTITY = None
|
|
||||||
|
|
||||||
|
|
||||||
def maybe_create_device_identity():
|
|
||||||
# Allocate dummy ones tensor for torch._scaled_mm
|
|
||||||
global TORCH_DEVICE_IDENTITY
|
|
||||||
if TORCH_DEVICE_IDENTITY is None:
|
|
||||||
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
|
|
||||||
|
|
||||||
|
|
||||||
def torch_per_tensor_w8a8_scaled_mm(
|
def torch_per_tensor_w8a8_scaled_mm(
|
||||||
*,
|
*,
|
||||||
@ -57,8 +46,7 @@ def torch_row_wise_w8a8_scaled_mm(
|
|||||||
bias: torch.Tensor,
|
bias: torch.Tensor,
|
||||||
output_shape: list,
|
output_shape: list,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
|
# Note:
|
||||||
# when using it.
|
|
||||||
# For now it has only been validated on ROCm platform.
|
# For now it has only been validated on ROCm platform.
|
||||||
# fp8 rowwise scaling in torch._scaled_mm is introduced in
|
# fp8 rowwise scaling in torch._scaled_mm is introduced in
|
||||||
# https://github.com/pytorch/pytorch/pull/144432 using
|
# https://github.com/pytorch/pytorch/pull/144432 using
|
||||||
@ -106,14 +94,18 @@ def torch_channelwise_w8a8_scaled_mm(
|
|||||||
# For the scaled_mm fallback case, we break this down, since it
|
# For the scaled_mm fallback case, we break this down, since it
|
||||||
# does not support s_w being a vector.
|
# does not support s_w being a vector.
|
||||||
|
|
||||||
|
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||||
|
# from pytorch 2.5. Allocating a dummy tensor to pass as scales
|
||||||
|
dummy_tensor = torch.ones(1, dtype=torch.float32, device=A.device)
|
||||||
|
|
||||||
# GEMM
|
# GEMM
|
||||||
# This computes C = (X * W).
|
# This computes C = (X * W).
|
||||||
# Output in fp32 to allow subsequent ops to happen in-place
|
# Output in fp32 to allow subsequent ops to happen in-place
|
||||||
output = torch._scaled_mm(
|
output = torch._scaled_mm(
|
||||||
A,
|
A,
|
||||||
B,
|
B,
|
||||||
scale_a=TORCH_DEVICE_IDENTITY,
|
scale_a=dummy_tensor,
|
||||||
scale_b=TORCH_DEVICE_IDENTITY,
|
scale_b=dummy_tensor,
|
||||||
out_dtype=torch.float32,
|
out_dtype=torch.float32,
|
||||||
)
|
)
|
||||||
# A fix for discrepancy in scaled_mm which returns tuple
|
# A fix for discrepancy in scaled_mm which returns tuple
|
||||||
@ -214,19 +206,11 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
is_static = c.activation_quant_key.scale.static
|
|
||||||
|
|
||||||
per_tensor_activation_scales = (
|
per_tensor_activation_scales = (
|
||||||
c.activation_quant_key.scale.group_shape.is_per_tensor()
|
c.activation_quant_key.scale.group_shape.is_per_tensor()
|
||||||
)
|
)
|
||||||
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
||||||
|
|
||||||
if not is_static:
|
|
||||||
return (
|
|
||||||
False,
|
|
||||||
"ChannelWiseTorchScaledMMLinearKernel requires static scales",
|
|
||||||
)
|
|
||||||
|
|
||||||
if per_tensor_activation_scales and per_tensor_weight_scales:
|
if per_tensor_activation_scales and per_tensor_weight_scales:
|
||||||
return (
|
return (
|
||||||
False,
|
False,
|
||||||
|
|||||||
@ -3,25 +3,10 @@
|
|||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
|
||||||
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
|
||||||
TORCH_DEVICE_IDENTITY = None
|
|
||||||
|
|
||||||
# The condition to determine if it is on a platform that supports
|
|
||||||
# torch._scaled_mm rowwise feature.
|
|
||||||
# The condition is determined once as the operations
|
|
||||||
# are time-consuming.
|
|
||||||
USE_ROWWISE_TORCH_SCALED_MM = (
|
|
||||||
current_platform.is_rocm()
|
|
||||||
and version.parse(torch.__version__) >= version.parse("2.7")
|
|
||||||
and current_platform.has_device_capability(94)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def sparse_cutlass_supported() -> bool:
|
def sparse_cutlass_supported() -> bool:
|
||||||
if not current_platform.is_cuda():
|
if not current_platform.is_cuda():
|
||||||
@ -129,13 +114,6 @@ def requantize_with_max_scale(
|
|||||||
return max_w_scale, weight
|
return max_w_scale, weight
|
||||||
|
|
||||||
|
|
||||||
def maybe_create_device_identity():
|
|
||||||
# Allocate dummy ones tensor for torch._scaled_mm
|
|
||||||
global TORCH_DEVICE_IDENTITY
|
|
||||||
if TORCH_DEVICE_IDENTITY is None:
|
|
||||||
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_e4m3fn_to_e4m3fnuz(
|
def normalize_e4m3fn_to_e4m3fnuz(
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
weight_scale: torch.Tensor,
|
weight_scale: torch.Tensor,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user