mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-05 19:39:07 +08:00
[Bugfix][CI][V1] Work around V1 + CUDA Graph + torch._scaled_mm fallback issue (#13425)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
cd4a72a28d
commit
b3942e157e
@ -9,8 +9,8 @@ from torch.nn import Parameter
|
|||||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||||
CompressedTensorsScheme)
|
CompressedTensorsScheme)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz,
|
apply_fp8_linear, cutlass_fp8_supported, maybe_create_device_identity,
|
||||||
requantize_with_max_scale)
|
normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
|
||||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||||
ModelWeightParameter,
|
ModelWeightParameter,
|
||||||
PerTensorScaleParameter)
|
PerTensorScaleParameter)
|
||||||
@ -93,6 +93,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
input_size_per_partition: int,
|
input_size_per_partition: int,
|
||||||
params_dtype: torch.dtype, weight_loader: Callable,
|
params_dtype: torch.dtype, weight_loader: Callable,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
maybe_create_device_identity()
|
||||||
|
|
||||||
output_size_per_partition = sum(output_partition_sizes)
|
output_size_per_partition = sum(output_partition_sizes)
|
||||||
layer.logical_widths = output_partition_sizes
|
layer.logical_widths = output_partition_sizes
|
||||||
|
|
||||||
|
|||||||
@ -17,7 +17,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
|||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
is_layer_skipped)
|
is_layer_skipped)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
apply_fp8_linear, normalize_e4m3fn_to_e4m3fnuz)
|
apply_fp8_linear, maybe_create_device_identity,
|
||||||
|
normalize_e4m3fn_to_e4m3fnuz)
|
||||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||||
ModelWeightParameter)
|
ModelWeightParameter)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@ -84,6 +85,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
|||||||
params_dtype: torch.dtype,
|
params_dtype: torch.dtype,
|
||||||
**extra_weight_attrs,
|
**extra_weight_attrs,
|
||||||
):
|
):
|
||||||
|
maybe_create_device_identity()
|
||||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||||
del input_size, output_size
|
del input_size, output_size
|
||||||
output_size_per_partition = sum(output_partition_sizes)
|
output_size_per_partition = sum(output_partition_sizes)
|
||||||
|
|||||||
@ -24,8 +24,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
all_close_1d, apply_fp8_linear, convert_to_channelwise,
|
all_close_1d, apply_fp8_linear, convert_to_channelwise,
|
||||||
cutlass_block_fp8_supported, cutlass_fp8_supported,
|
cutlass_block_fp8_supported, cutlass_fp8_supported,
|
||||||
normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
|
maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
|
||||||
requantize_with_max_scale)
|
per_tensor_dequantize, requantize_with_max_scale)
|
||||||
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
|
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
|
||||||
ModelWeightParameter,
|
ModelWeightParameter,
|
||||||
PerTensorScaleParameter)
|
PerTensorScaleParameter)
|
||||||
@ -162,6 +162,8 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
params_dtype: torch.dtype,
|
params_dtype: torch.dtype,
|
||||||
**extra_weight_attrs,
|
**extra_weight_attrs,
|
||||||
):
|
):
|
||||||
|
maybe_create_device_identity()
|
||||||
|
|
||||||
output_size_per_partition = sum(output_partition_sizes)
|
output_size_per_partition = sum(output_partition_sizes)
|
||||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||||
|
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from vllm.platforms import current_platform
|
|||||||
|
|
||||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||||
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||||
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
|
TORCH_DEVICE_IDENTITY = None
|
||||||
|
|
||||||
# The condition to determine if it is on a platform that supports
|
# The condition to determine if it is on a platform that supports
|
||||||
# torch._scaled_mm rowwise feature.
|
# torch._scaled_mm rowwise feature.
|
||||||
@ -113,6 +113,13 @@ def requantize_with_max_scale(
|
|||||||
return max_w_scale, weight
|
return max_w_scale, weight
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_create_device_identity():
|
||||||
|
# Allocate dummy ones tensor for torch._scaled_mm
|
||||||
|
global TORCH_DEVICE_IDENTITY
|
||||||
|
if TORCH_DEVICE_IDENTITY is None:
|
||||||
|
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
def apply_fp8_linear(
|
def apply_fp8_linear(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
@ -215,11 +222,6 @@ def apply_fp8_linear(
|
|||||||
# For the scaled_mm fallback case, we break this down, since it
|
# For the scaled_mm fallback case, we break this down, since it
|
||||||
# does not support s_w being a vector.
|
# does not support s_w being a vector.
|
||||||
|
|
||||||
# Making sure the dummy tensor is on the same device as the weight
|
|
||||||
global TORCH_DEVICE_IDENTITY
|
|
||||||
if TORCH_DEVICE_IDENTITY.device != weight.device:
|
|
||||||
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
|
|
||||||
|
|
||||||
# GEMM
|
# GEMM
|
||||||
# This computes C = (X * W).
|
# This computes C = (X * W).
|
||||||
# Output in fp32 to allow subsequent ops to happen in-place
|
# Output in fp32 to allow subsequent ops to happen in-place
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user