[Bugfix][CI][V1] Work around V1 + CUDA Graph + torch._scaled_mm fallback issue (#13425)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Tyler Michael Smith 2025-02-17 19:32:48 -05:00 committed by GitHub
parent cd4a72a28d
commit b3942e157e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 19 additions and 11 deletions

View File

@ -9,8 +9,8 @@ from torch.nn import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme) CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, apply_fp8_linear, cutlass_fp8_supported, maybe_create_device_identity,
requantize_with_max_scale) normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter, from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter, ModelWeightParameter,
PerTensorScaleParameter) PerTensorScaleParameter)
@ -93,6 +93,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
input_size_per_partition: int, input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable, params_dtype: torch.dtype, weight_loader: Callable,
**kwargs): **kwargs):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes layer.logical_widths = output_partition_sizes

View File

@ -17,7 +17,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped) is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, normalize_e4m3fn_to_e4m3fnuz) apply_fp8_linear, maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter, from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter) ModelWeightParameter)
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -84,6 +85,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ):
maybe_create_device_identity()
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")
del input_size, output_size del input_size, output_size
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)

View File

@ -24,8 +24,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, apply_fp8_linear, convert_to_channelwise, all_close_1d, apply_fp8_linear, convert_to_channelwise,
cutlass_block_fp8_supported, cutlass_fp8_supported, cutlass_block_fp8_supported, cutlass_fp8_supported,
normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
requantize_with_max_scale) per_tensor_dequantize, requantize_with_max_scale)
from vllm.model_executor.parameter import (BlockQuantScaleParameter, from vllm.model_executor.parameter import (BlockQuantScaleParameter,
ModelWeightParameter, ModelWeightParameter,
PerTensorScaleParameter) PerTensorScaleParameter)
@ -162,6 +162,8 @@ class Fp8LinearMethod(LinearMethodBase):
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")

View File

@ -9,7 +9,7 @@ from vllm.platforms import current_platform
# Input scaling factors are no longer optional in _scaled_mm starting # Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) TORCH_DEVICE_IDENTITY = None
# The condition to determine if it is on a platform that supports # The condition to determine if it is on a platform that supports
# torch._scaled_mm rowwise feature. # torch._scaled_mm rowwise feature.
@ -113,6 +113,13 @@ def requantize_with_max_scale(
return max_w_scale, weight return max_w_scale, weight
def maybe_create_device_identity():
# Allocate dummy ones tensor for torch._scaled_mm
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY is None:
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
def apply_fp8_linear( def apply_fp8_linear(
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
@ -215,11 +222,6 @@ def apply_fp8_linear(
# For the scaled_mm fallback case, we break this down, since it # For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector. # does not support s_w being a vector.
# Making sure the dummy tensor is on the same device as the weight
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY.device != weight.device:
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
# GEMM # GEMM
# This computes C = (X * W). # This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place # Output in fp32 to allow subsequent ops to happen in-place