remove maybe_create_device_identity

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-11-12 15:52:49 +00:00
parent 65ecf487ad
commit 405d2802c6
7 changed files with 7 additions and 61 deletions

View File

@ -40,9 +40,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
ScaleDesc, ScaleDesc,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
maybe_create_device_identity,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ..utils import TestFP8Layer from ..utils import TestFP8Layer
@ -191,7 +188,6 @@ def test_fusion_rmsnorm_quant(
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
torch.manual_seed(1) torch.manual_seed(1)
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
custom_ops = [] custom_ops = []
if enable_rms_norm_custom_op: if enable_rms_norm_custom_op:

View File

@ -28,9 +28,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym, kFp8StaticTensorSym,
kNvfp4Quant, kNvfp4Quant,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
maybe_create_device_identity,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ..utils import TestFP8Layer, override_cutlass_fp8_supported from ..utils import TestFP8Layer, override_cutlass_fp8_supported
@ -157,7 +154,6 @@ def test_fusion_silu_and_mul_quant(
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
maybe_create_device_identity()
x = torch.rand(num_tokens, hidden_size * 2) x = torch.rand(num_tokens, hidden_size * 2)

View File

@ -34,7 +34,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported, cutlass_block_fp8_supported,
maybe_create_device_identity,
) )
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
BlockQuantScaleParameter, BlockQuantScaleParameter,
@ -108,8 +107,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
weight_loader: Callable, weight_loader: Callable,
**kwargs, **kwargs,
): ):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes layer.logical_widths = output_partition_sizes
layer.weight_block_size = None layer.weight_block_size = None

View File

@ -31,7 +31,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTokenSym, kFp8StaticTokenSym,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz, normalize_e4m3fn_to_e4m3fnuz,
) )
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
@ -112,7 +111,6 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ):
maybe_create_device_identity()
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")
del input_size, output_size del input_size, output_size
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)

View File

@ -86,7 +86,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, all_close_1d,
cutlass_block_fp8_supported, cutlass_block_fp8_supported,
cutlass_fp8_supported, cutlass_fp8_supported,
maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz, normalize_e4m3fn_to_e4m3fnuz,
per_tensor_dequantize, per_tensor_dequantize,
) )
@ -416,8 +415,6 @@ class Fp8LinearMethod(LinearMethodBase):
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes layer.logical_widths = output_partition_sizes

View File

@ -14,17 +14,6 @@ from .ScaledMMLinearKernel import (
FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig,
) )
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = None
def maybe_create_device_identity():
# Allocate dummy ones tensor for torch._scaled_mm
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY is None:
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
def torch_per_tensor_w8a8_scaled_mm( def torch_per_tensor_w8a8_scaled_mm(
*, *,
@ -57,8 +46,7 @@ def torch_row_wise_w8a8_scaled_mm(
bias: torch.Tensor, bias: torch.Tensor,
output_shape: list, output_shape: list,
) -> torch.Tensor: ) -> torch.Tensor:
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM # Note:
# when using it.
# For now it has only been validated on ROCm platform. # For now it has only been validated on ROCm platform.
# fp8 rowwise scaling in torch._scaled_mm is introduced in # fp8 rowwise scaling in torch._scaled_mm is introduced in
# https://github.com/pytorch/pytorch/pull/144432 using # https://github.com/pytorch/pytorch/pull/144432 using
@ -106,14 +94,18 @@ def torch_channelwise_w8a8_scaled_mm(
# For the scaled_mm fallback case, we break this down, since it # For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector. # does not support s_w being a vector.
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as scales
dummy_tensor = torch.ones(1, dtype=torch.float32, device=A.device)
# GEMM # GEMM
# This computes C = (X * W). # This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place # Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm( output = torch._scaled_mm(
A, A,
B, B,
scale_a=TORCH_DEVICE_IDENTITY, scale_a=dummy_tensor,
scale_b=TORCH_DEVICE_IDENTITY, scale_b=dummy_tensor,
out_dtype=torch.float32, out_dtype=torch.float32,
) )
# A fix for discrepancy in scaled_mm which returns tuple # A fix for discrepancy in scaled_mm which returns tuple
@ -214,19 +206,11 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@classmethod @classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
is_static = c.activation_quant_key.scale.static
per_tensor_activation_scales = ( per_tensor_activation_scales = (
c.activation_quant_key.scale.group_shape.is_per_tensor() c.activation_quant_key.scale.group_shape.is_per_tensor()
) )
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if not is_static:
return (
False,
"ChannelWiseTorchScaledMMLinearKernel requires static scales",
)
if per_tensor_activation_scales and per_tensor_weight_scales: if per_tensor_activation_scales and per_tensor_weight_scales:
return ( return (
False, False,

View File

@ -3,25 +3,10 @@
import torch import torch
from packaging import version
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = None
# The condition to determine if it is on a platform that supports
# torch._scaled_mm rowwise feature.
# The condition is determined once as the operations
# are time-consuming.
USE_ROWWISE_TORCH_SCALED_MM = (
current_platform.is_rocm()
and version.parse(torch.__version__) >= version.parse("2.7")
and current_platform.has_device_capability(94)
)
def sparse_cutlass_supported() -> bool: def sparse_cutlass_supported() -> bool:
if not current_platform.is_cuda(): if not current_platform.is_cuda():
@ -129,13 +114,6 @@ def requantize_with_max_scale(
return max_w_scale, weight return max_w_scale, weight
def maybe_create_device_identity():
# Allocate dummy ones tensor for torch._scaled_mm
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY is None:
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
def normalize_e4m3fn_to_e4m3fnuz( def normalize_e4m3fn_to_e4m3fnuz(
weight: torch.Tensor, weight: torch.Tensor,
weight_scale: torch.Tensor, weight_scale: torch.Tensor,