mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 02:27:02 +08:00
remove maybe_create_device_identity
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
65ecf487ad
commit
405d2802c6
@ -40,9 +40,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
ScaleDesc,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
maybe_create_device_identity,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import TestFP8Layer
|
||||
@ -191,7 +188,6 @@ def test_fusion_rmsnorm_quant(
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(1)
|
||||
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
|
||||
|
||||
custom_ops = []
|
||||
if enable_rms_norm_custom_op:
|
||||
|
||||
@ -28,9 +28,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
maybe_create_device_identity,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import TestFP8Layer, override_cutlass_fp8_supported
|
||||
@ -157,7 +154,6 @@ def test_fusion_silu_and_mul_quant(
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
maybe_create_device_identity()
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size * 2)
|
||||
|
||||
|
||||
@ -34,7 +34,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_block_fp8_supported,
|
||||
maybe_create_device_identity,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BlockQuantScaleParameter,
|
||||
@ -108,8 +107,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
maybe_create_device_identity()
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.weight_block_size = None
|
||||
|
||||
@ -31,7 +31,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTokenSym,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
maybe_create_device_identity,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
@ -112,7 +111,6 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
maybe_create_device_identity()
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
del input_size, output_size
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
@ -86,7 +86,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
all_close_1d,
|
||||
cutlass_block_fp8_supported,
|
||||
cutlass_fp8_supported,
|
||||
maybe_create_device_identity,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
per_tensor_dequantize,
|
||||
)
|
||||
@ -416,8 +415,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
maybe_create_device_identity()
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
@ -14,17 +14,6 @@ from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||
TORCH_DEVICE_IDENTITY = None
|
||||
|
||||
|
||||
def maybe_create_device_identity():
|
||||
# Allocate dummy ones tensor for torch._scaled_mm
|
||||
global TORCH_DEVICE_IDENTITY
|
||||
if TORCH_DEVICE_IDENTITY is None:
|
||||
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
|
||||
|
||||
|
||||
def torch_per_tensor_w8a8_scaled_mm(
|
||||
*,
|
||||
@ -57,8 +46,7 @@ def torch_row_wise_w8a8_scaled_mm(
|
||||
bias: torch.Tensor,
|
||||
output_shape: list,
|
||||
) -> torch.Tensor:
|
||||
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
|
||||
# when using it.
|
||||
# Note:
|
||||
# For now it has only been validated on ROCm platform.
|
||||
# fp8 rowwise scaling in torch._scaled_mm is introduced in
|
||||
# https://github.com/pytorch/pytorch/pull/144432 using
|
||||
@ -106,14 +94,18 @@ def torch_channelwise_w8a8_scaled_mm(
|
||||
# For the scaled_mm fallback case, we break this down, since it
|
||||
# does not support s_w being a vector.
|
||||
|
||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||
# from pytorch 2.5. Allocating a dummy tensor to pass as scales
|
||||
dummy_tensor = torch.ones(1, dtype=torch.float32, device=A.device)
|
||||
|
||||
# GEMM
|
||||
# This computes C = (X * W).
|
||||
# Output in fp32 to allow subsequent ops to happen in-place
|
||||
output = torch._scaled_mm(
|
||||
A,
|
||||
B,
|
||||
scale_a=TORCH_DEVICE_IDENTITY,
|
||||
scale_b=TORCH_DEVICE_IDENTITY,
|
||||
scale_a=dummy_tensor,
|
||||
scale_b=dummy_tensor,
|
||||
out_dtype=torch.float32,
|
||||
)
|
||||
# A fix for discrepancy in scaled_mm which returns tuple
|
||||
@ -214,19 +206,11 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
is_static = c.activation_quant_key.scale.static
|
||||
|
||||
per_tensor_activation_scales = (
|
||||
c.activation_quant_key.scale.group_shape.is_per_tensor()
|
||||
)
|
||||
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
||||
|
||||
if not is_static:
|
||||
return (
|
||||
False,
|
||||
"ChannelWiseTorchScaledMMLinearKernel requires static scales",
|
||||
)
|
||||
|
||||
if per_tensor_activation_scales and per_tensor_weight_scales:
|
||||
return (
|
||||
False,
|
||||
|
||||
@ -3,25 +3,10 @@
|
||||
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||
TORCH_DEVICE_IDENTITY = None
|
||||
|
||||
# The condition to determine if it is on a platform that supports
|
||||
# torch._scaled_mm rowwise feature.
|
||||
# The condition is determined once as the operations
|
||||
# are time-consuming.
|
||||
USE_ROWWISE_TORCH_SCALED_MM = (
|
||||
current_platform.is_rocm()
|
||||
and version.parse(torch.__version__) >= version.parse("2.7")
|
||||
and current_platform.has_device_capability(94)
|
||||
)
|
||||
|
||||
|
||||
def sparse_cutlass_supported() -> bool:
|
||||
if not current_platform.is_cuda():
|
||||
@ -129,13 +114,6 @@ def requantize_with_max_scale(
|
||||
return max_w_scale, weight
|
||||
|
||||
|
||||
def maybe_create_device_identity():
|
||||
# Allocate dummy ones tensor for torch._scaled_mm
|
||||
global TORCH_DEVICE_IDENTITY
|
||||
if TORCH_DEVICE_IDENTITY is None:
|
||||
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
|
||||
|
||||
|
||||
def normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user