remove maybe_create_device_identity

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-11-12 15:52:49 +00:00
parent 65ecf487ad
commit 405d2802c6
7 changed files with 7 additions and 61 deletions

View File

@@ -40,9 +40,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
ScaleDesc,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
maybe_create_device_identity,
)
from vllm.platforms import current_platform
from ..utils import TestFP8Layer
@@ -191,7 +188,6 @@ def test_fusion_rmsnorm_quant(
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
custom_ops = []
if enable_rms_norm_custom_op:

View File

@@ -28,9 +28,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
kNvfp4Quant,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
maybe_create_device_identity,
)
from vllm.platforms import current_platform
from ..utils import TestFP8Layer, override_cutlass_fp8_supported
@@ -157,7 +154,6 @@ def test_fusion_silu_and_mul_quant(
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
maybe_create_device_identity()
x = torch.rand(num_tokens, hidden_size * 2)

View File

@@ -34,7 +34,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
maybe_create_device_identity,
)
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
@@ -108,8 +107,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
weight_loader: Callable,
**kwargs,
):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
layer.weight_block_size = None

View File

@@ -31,7 +31,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTokenSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz,
)
from vllm.model_executor.parameter import (
@@ -112,7 +111,6 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
maybe_create_device_identity()
weight_loader = extra_weight_attrs.get("weight_loader")
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes)

View File

@@ -86,7 +86,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d,
cutlass_block_fp8_supported,
cutlass_fp8_supported,
maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz,
per_tensor_dequantize,
)
@@ -416,8 +415,6 @@ class Fp8LinearMethod(LinearMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes

View File

@@ -14,17 +14,6 @@ from .ScaledMMLinearKernel import (
FP8ScaledMMLinearLayerConfig,
)
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = None


def maybe_create_device_identity():
    """Lazily allocate the dummy ones tensor handed to torch._scaled_mm.

    Idempotent: the 1-element float32 ones tensor is created on the first
    call only and cached in the module-level ``TORCH_DEVICE_IDENTITY``.
    """
    global TORCH_DEVICE_IDENTITY
    if TORCH_DEVICE_IDENTITY is not None:
        # Already allocated by an earlier call — nothing to do.
        return
    TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
def torch_per_tensor_w8a8_scaled_mm(
*,
@@ -57,8 +46,7 @@ def torch_row_wise_w8a8_scaled_mm(
bias: torch.Tensor,
output_shape: list,
) -> torch.Tensor:
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
# when using it.
# Note:
# For now it has only been validated on ROCm platform.
# fp8 rowwise scaling in torch._scaled_mm is introduced in
# https://github.com/pytorch/pytorch/pull/144432 using
@@ -106,14 +94,18 @@ def torch_channelwise_w8a8_scaled_mm(
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as scales
dummy_tensor = torch.ones(1, dtype=torch.float32, device=A.device)
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm(
A,
B,
scale_a=TORCH_DEVICE_IDENTITY,
scale_b=TORCH_DEVICE_IDENTITY,
scale_a=dummy_tensor,
scale_b=dummy_tensor,
out_dtype=torch.float32,
)
# A fix for discrepancy in scaled_mm which returns tuple
@@ -214,19 +206,11 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
is_static = c.activation_quant_key.scale.static
per_tensor_activation_scales = (
c.activation_quant_key.scale.group_shape.is_per_tensor()
)
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if not is_static:
return (
False,
"ChannelWiseTorchScaledMMLinearKernel requires static scales",
)
if per_tensor_activation_scales and per_tensor_weight_scales:
return (
False,

View File

@@ -3,25 +3,10 @@
import torch
from packaging import version
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
# Input scaling factors are no longer optional in torch._scaled_mm starting
# from pytorch 2.5, so a dummy 1-element tensor is lazily allocated into this
# module-level slot and passed as the input scale where a real one is absent.
TORCH_DEVICE_IDENTITY = None
# Whether the current platform supports the torch._scaled_mm rowwise-scaling
# feature. Evaluated once at import time because the platform/version probes
# are time-consuming. Requires: ROCm, torch >= 2.7, and device capability 94
# (presumably the gfx94x GPU family — TODO confirm against current_platform).
USE_ROWWISE_TORCH_SCALED_MM = (
current_platform.is_rocm()
and version.parse(torch.__version__) >= version.parse("2.7")
and current_platform.has_device_capability(94)
)
def sparse_cutlass_supported() -> bool:
if not current_platform.is_cuda():
@@ -129,13 +114,6 @@ def requantize_with_max_scale(
return max_w_scale, weight
def maybe_create_device_identity():
    """Allocate the dummy ones tensor for torch._scaled_mm on first use.

    Populates the module-level ``TORCH_DEVICE_IDENTITY`` with a 1-element
    float32 ones tensor if it has not been created yet; later calls are
    no-ops.
    """
    global TORCH_DEVICE_IDENTITY
    if TORCH_DEVICE_IDENTITY is not None:
        # Tensor already cached — keep the existing one.
        return
    TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
def normalize_e4m3fn_to_e4m3fnuz(
weight: torch.Tensor,
weight_scale: torch.Tensor,