From 405d2802c68237703068e1ae6f157fc5c4683c2f Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 12 Nov 2025 15:52:49 +0000 Subject: [PATCH] remove maybe_create_device_identity Signed-off-by: vllmellm --- tests/compile/test_fusion.py | 4 --- tests/compile/test_silu_mul_quant_fusion.py | 4 --- .../schemes/compressed_tensors_w8a8_fp8.py | 3 -- .../layers/quantization/fbgemm_fp8.py | 2 -- .../model_executor/layers/quantization/fp8.py | 3 -- .../quantization/kernels/scaled_mm/pytorch.py | 30 +++++-------------- .../layers/quantization/utils/w8a8_utils.py | 22 -------------- 7 files changed, 7 insertions(+), 61 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index aa4d2c8cf4537..9a5af9f1d245d 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -40,9 +40,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, ScaleDesc, ) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - maybe_create_device_identity, -) from vllm.platforms import current_platform from ..utils import TestFP8Layer @@ -191,7 +188,6 @@ def test_fusion_rmsnorm_quant( torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) - maybe_create_device_identity() # needed for certain non-cutlass fp8 paths custom_ops = [] if enable_rms_norm_custom_op: diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 56b36856f7f29..c4f6f2d4c4d9b 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -28,9 +28,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, kNvfp4Quant, ) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - maybe_create_device_identity, -) from vllm.platforms import current_platform from ..utils import TestFP8Layer, override_cutlass_fp8_supported @@ -157,7 +154,6 @@ def test_fusion_silu_and_mul_quant( torch.set_default_device("cuda") torch.set_default_dtype(dtype) - maybe_create_device_identity() x = torch.rand(num_tokens, hidden_size * 2) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 01204b10ea18a..5480383126e3b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -34,7 +34,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_block_fp8_supported, - maybe_create_device_identity, ) from vllm.model_executor.parameter import ( BlockQuantScaleParameter, @@ -108,8 +107,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): weight_loader: Callable, **kwargs, ): - maybe_create_device_identity() - output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes layer.weight_block_size = None diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index bcd02554008ca..45d2e4e338190 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -31,7 +31,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTokenSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, ) from vllm.model_executor.parameter import ( @@ -112,7 +111,6 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): params_dtype: torch.dtype, **extra_weight_attrs, ): - maybe_create_device_identity() weight_loader = extra_weight_attrs.get("weight_loader") del input_size, output_size output_size_per_partition = sum(output_partition_sizes) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 91115d7437e25..57b3736ed7fd1 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -86,7 +86,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, cutlass_block_fp8_supported, cutlass_fp8_supported, - maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize, ) @@ -416,8 +415,6 @@ class Fp8LinearMethod(LinearMethodBase): params_dtype: torch.dtype, **extra_weight_attrs, ): - maybe_create_device_identity() - output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") layer.logical_widths = output_partition_sizes diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py index c272f579d8bcc..ad21d68c0f52d 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py @@ -14,17 +14,6 @@ from .ScaledMMLinearKernel import ( FP8ScaledMMLinearLayerConfig, ) -# Input scaling factors are no longer optional in _scaled_mm starting -# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale -TORCH_DEVICE_IDENTITY = None - - -def maybe_create_device_identity(): - # Allocate dummy ones tensor for torch._scaled_mm - global TORCH_DEVICE_IDENTITY - if TORCH_DEVICE_IDENTITY is None: - TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) - def torch_per_tensor_w8a8_scaled_mm( *, @@ -57,8 +46,7 @@ def torch_row_wise_w8a8_scaled_mm( bias: torch.Tensor, output_shape: list, ) -> torch.Tensor: - # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM - # when using it. + # Note: # For now it has only been validated on ROCm platform. # fp8 rowwise scaling in torch._scaled_mm is introduced in # https://github.com/pytorch/pytorch/pull/144432 using @@ -106,14 +94,18 @@ def torch_channelwise_w8a8_scaled_mm( # For the scaled_mm fallback case, we break this down, since it # does not support s_w being a vector. + # Input scaling factors are no longer optional in _scaled_mm starting + # from pytorch 2.5. Allocating a dummy tensor to pass as scales + dummy_tensor = torch.ones(1, dtype=torch.float32, device=A.device) + # GEMM # This computes C = (X * W). # Output in fp32 to allow subsequent ops to happen in-place output = torch._scaled_mm( A, B, - scale_a=TORCH_DEVICE_IDENTITY, - scale_b=TORCH_DEVICE_IDENTITY, + scale_a=dummy_tensor, + scale_b=dummy_tensor, out_dtype=torch.float32, ) # A fix for discrepancy in scaled_mm which returns tuple @@ -214,19 +206,11 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - is_static = c.activation_quant_key.scale.static - per_tensor_activation_scales = ( c.activation_quant_key.scale.group_shape.is_per_tensor() ) per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() - if not is_static: - return ( - False, - "ChannelWiseTorchScaledMMLinearKernel requires static scales", - ) - if per_tensor_activation_scales and per_tensor_weight_scales: return ( False, diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index f2d8eecdc68e5..c7fcb5a4b33b1 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -3,25 +3,10 @@ import torch -from packaging import version from vllm import _custom_ops as ops from vllm.platforms import current_platform -# Input scaling factors are no longer optional in _scaled_mm starting -# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale -TORCH_DEVICE_IDENTITY = None - -# The condition to determine if it is on a platform that supports -# torch._scaled_mm rowwise feature. -# The condition is determined once as the operations -# are time-consuming. -USE_ROWWISE_TORCH_SCALED_MM = ( - current_platform.is_rocm() - and version.parse(torch.__version__) >= version.parse("2.7") - and current_platform.has_device_capability(94) -) - def sparse_cutlass_supported() -> bool: if not current_platform.is_cuda(): @@ -129,13 +114,6 @@ def requantize_with_max_scale( return max_w_scale, weight -def maybe_create_device_identity(): - # Allocate dummy ones tensor for torch._scaled_mm - global TORCH_DEVICE_IDENTITY - if TORCH_DEVICE_IDENTITY is None: - TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) - - def normalize_e4m3fn_to_e4m3fnuz( weight: torch.Tensor, weight_scale: torch.Tensor,