From b3942e157ed540ca6cb82ef4a431233ffd9c03b2 Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Mon, 17 Feb 2025 19:32:48 -0500
Subject: [PATCH] [Bugfix][CI][V1] Work around V1 + CUDA Graph +
 torch._scaled_mm fallback issue (#13425)

Signed-off-by: Tyler Michael Smith
---
 .../schemes/compressed_tensors_w8a8_fp8.py      |  6 ++++--
 .../layers/quantization/fbgemm_fp8.py           |  4 +++-
 vllm/model_executor/layers/quantization/fp8.py  |  6 ++++--
 .../layers/quantization/utils/w8a8_utils.py     | 14 ++++++++------
 4 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index 5dcc41a9e5dab..32072e9fa570f 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -9,8 +9,8 @@ from torch.nn import Parameter
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz,
-    requantize_with_max_scale)
+    apply_fp8_linear, cutlass_fp8_supported, maybe_create_device_identity,
+    normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -93,6 +93,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
+        maybe_create_device_identity()
+
         output_size_per_partition = sum(output_partition_sizes)
         layer.logical_widths = output_partition_sizes
 
diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
index 3bb8188f725c8..20f2c3da600d7 100644
--- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py
+++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -17,7 +17,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, normalize_e4m3fn_to_e4m3fnuz)
+    apply_fp8_linear, maybe_create_device_identity,
+    normalize_e4m3fn_to_e4m3fnuz)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter)
 from vllm.platforms import current_platform
@@ -84,6 +85,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        maybe_create_device_identity()
         weight_loader = extra_weight_attrs.get("weight_loader")
         del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index f928ea7e23ca8..fe8ff7ca5e12f 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -24,8 +24,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
     cutlass_block_fp8_supported, cutlass_fp8_supported,
-    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
-    requantize_with_max_scale)
+    maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
+    per_tensor_dequantize, requantize_with_max_scale)
 from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -162,6 +162,8 @@ class Fp8LinearMethod(LinearMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        maybe_create_device_identity()
+
         output_size_per_partition = sum(output_partition_sizes)
         weight_loader = extra_weight_attrs.get("weight_loader")
 
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index bea6390f71ff7..0f93b7f6c45ba 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -9,7 +9,7 @@ from vllm.platforms import current_platform
 
 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
-TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
+TORCH_DEVICE_IDENTITY = None
 
 # The condition to determine if it is on a platform that supports
 # torch._scaled_mm rowwise feature.
@@ -113,6 +113,13 @@ def requantize_with_max_scale(
     return max_w_scale, weight
 
 
+def maybe_create_device_identity():
+    # Allocate dummy ones tensor for torch._scaled_mm
+    global TORCH_DEVICE_IDENTITY
+    if TORCH_DEVICE_IDENTITY is None:
+        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
+
+
 def apply_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -215,11 +222,6 @@ def apply_fp8_linear(
             # For the scaled_mm fallback case, we break this down, since it
             # does not support s_w being a vector.
 
-            # Making sure the dummy tensor is on the same device as the weight
-            global TORCH_DEVICE_IDENTITY
-            if TORCH_DEVICE_IDENTITY.device != weight.device:
-                TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
-
            # GEMM
             # This computes C = (X * W).
             # Output in fp32 to allow subsequent ops to happen in-place
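
Reviewer note (not part of the patch): the sketch below is a minimal, self-contained illustration of the pattern the patch adopts. The module-level identity scale is no longer allocated at import time; instead it is created lazily from `create_weights()`, so the dummy tensor already exists before any forward pass and before CUDA graph capture of the `torch._scaled_mm` fallback. `ToyLinearMethod` and its signature are hypothetical stand-ins, not vLLM APIs; only `TORCH_DEVICE_IDENTITY` and `maybe_create_device_identity()` mirror names from the patch.

```python
import torch

# Mirrors the patched w8a8_utils module state: nothing is allocated at
# import time. torch._scaled_mm (PyTorch >= 2.5) requires explicit scales,
# and a 1-element ones tensor serves as an "identity" scale.
TORCH_DEVICE_IDENTITY = None


def maybe_create_device_identity():
    # Lazily allocate the dummy ones tensor the first time it is needed.
    global TORCH_DEVICE_IDENTITY
    if TORCH_DEVICE_IDENTITY is None:
        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)


class ToyLinearMethod:
    """Hypothetical stand-in for the patched LinearMethod/scheme classes."""

    def create_weights(self, layer: torch.nn.Module, output_size: int,
                       input_size: int, params_dtype: torch.dtype):
        # Runs once while the model is being built, well before inference
        # (and therefore before any CUDA graph capture), so the identity
        # tensor is guaranteed to exist by the time the fallback GEMM runs.
        maybe_create_device_identity()
        layer.weight = torch.nn.Parameter(
            torch.empty(output_size, input_size, dtype=params_dtype),
            requires_grad=False)


if __name__ == "__main__":
    layer = torch.nn.Module()
    ToyLinearMethod().create_weights(layer, 8, 4, torch.float16)
    assert TORCH_DEVICE_IDENTITY is not None  # created up front, reused later
```

The design choice this illustrates: before the patch, the first use of the fallback on a given device rebound the global and allocated a new tensor via `.to(weight.device)` inside `apply_fp8_linear`, a side effect in the hot path that appears to be the V1 + CUDA Graph interaction the title refers to. Creating the tensor during weight creation keeps that allocation out of the captured region.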