From 32c9be2200a26589b356fe333246b33b5006f4d7 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Sat, 5 Jul 2025 17:41:10 +0800
Subject: [PATCH] [v1] Re-add fp32 support to v1 engine through FlexAttention
 (#19754)

Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py
---
 .github/workflows/lint-and-deploy.yaml        |  2 +-
 .../attention/test_attention_selector.py      | 28 +++++++++++++++++++
 tests/v1/worker/test_gpu_model_runner.py      |  5 ++++
 vllm/engine/arg_utils.py                      |  7 -----
 .../model_loader/tensorizer_loader.py         |  8 ++++--
 vllm/platforms/cuda.py                        |  4 +++
 vllm/v1/attention/backends/flex_attention.py  | 12 +++++++-
 vllm/v1/sample/ops/topk_topp_sampler.py       |  5 +++-
 8 files changed, 59 insertions(+), 12 deletions(-)

diff --git a/.github/workflows/lint-and-deploy.yaml b/.github/workflows/lint-and-deploy.yaml
index 64011922ad825..74a7a3a3530f5 100644
--- a/.github/workflows/lint-and-deploy.yaml
+++ b/.github/workflows/lint-and-deploy.yaml
@@ -68,7 +68,7 @@ jobs:
           export AWS_ACCESS_KEY_ID=minioadmin
           export AWS_SECRET_ACCESS_KEY=minioadmin
           sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" &
-          helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
+          helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set image.env[2].name=VLLM_CPU_CI_ENV --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string image.env[2].value="1" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
 
       - name: curl test
        run: |
diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py
index a8ed749ba13b5..0437bb8293cea 100644
--- a/tests/kernels/attention/test_attention_selector.py
+++ b/tests/kernels/attention/test_attention_selector.py
@@ -181,6 +181,34 @@ def test_env(
         assert backend.get_name() == expected
 
 
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+@pytest.mark.parametrize("use_v1", [True, False])
+def test_fp32_fallback(
+    device: str,
+    use_v1: bool,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    """Test attention backend selection with fp32."""
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
+
+        if device == "cpu":
+            with patch("vllm.attention.selector.current_platform",
+                       CpuPlatform()):
+                backend = get_attn_backend(16, torch.float32, torch.float32,
+                                           16, False)
+            assert (backend.get_name() == "TORCH_SDPA_VLLM_V1"
+                    if use_v1 else backend.get_name() == "TORCH_SDPA")
+
+        elif device == "cuda":
+            with patch("vllm.attention.selector.current_platform",
+                       CudaPlatform()):
+                backend = get_attn_backend(16, torch.float32, torch.float32,
+                                           16, False)
+            assert (backend.get_name() == "FLEX_ATTENTION"
+                    if use_v1 else backend.get_name() == "XFORMERS")
+
+
 def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
     """Test FlashAttn validation."""
     # TODO: When testing for v1, pipe in `use_v1` as an argument to
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index 2e1deecbd9e67..d13df553db623 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -450,6 +450,7 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
 
 
 def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
+    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     error_msg = f"{layer_1} must come before the current layer"
@@ -478,6 +479,7 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
 
 
 def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
+    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     invalid_layer = "model.layers.0.cross_attn.attn"
@@ -506,6 +508,7 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
 
 
 def test_init_kv_cache_with_kv_sharing_target_same_as_current():
+    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     error_msg = f"{layer_1} cannot be the same as the current layer"
@@ -534,6 +537,7 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
 
 
 def test_init_kv_cache_without_kv_sharing():
+    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     vllm_config = get_vllm_config()
@@ -601,6 +605,7 @@ def test_init_kv_cache_without_kv_sharing():
 
 
 def test_init_kv_cache_with_kv_sharing_valid():
+    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     vllm_config = get_vllm_config()
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 284f092361311..cf94b6a642818 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1393,13 +1393,6 @@ class EngineArgs:
                                recommend_to_remove=False)
             return False
 
-        # Only Fp16 and Bf16 dtypes since we only support FA.
-        V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
-        if model_config.dtype not in V1_SUPPORTED_DTYPES:
-            _raise_or_fallback(feature_name=f"--dtype {model_config.dtype}",
-                               recommend_to_remove=False)
-            return False
-
         # No Mamba or Encoder-Decoder so far.
         if not model_config.is_v1_compatible:
             _raise_or_fallback(feature_name=model_config.architectures,
diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py
index b9982f312fe52..0b62e744e445c 100644
--- a/vllm/model_executor/model_loader/tensorizer_loader.py
+++ b/vllm/model_executor/model_loader/tensorizer_loader.py
@@ -104,8 +104,12 @@ class TensorizerLoader(BaseModelLoader):
 
         if is_vllm_tensorized(self.tensorizer_config):
             tensorizer_config = self._patch_tensorizer_config(model_config)
-            model = init_tensorizer_model(tensorizer_config=tensorizer_config,
-                                          vllm_config=vllm_config)
+            device_config = vllm_config.device_config
+            with set_default_torch_dtype(model_config.dtype):
+                with torch.device(device_config.device):
+                    model = init_tensorizer_model(
+                        tensorizer_config=tensorizer_config,
+                        vllm_config=vllm_config)
             self.load_weights(model, model_config)
             return model
         return self._load_model_serialized_cpu(vllm_config=vllm_config)
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 15cab757d2c03..f82c1e569977c 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -251,6 +251,10 @@ class CudaPlatformBase(Platform):
 
             # Default backends for V1 engine
             # Prefer FlashInfer for Blackwell GPUs if installed
+            if dtype not in (torch.float16, torch.bfloat16):
+                logger.info_once(
+                    f"Using FlexAttention backend for {dtype} on V1 engine.")
+                return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
             if cls.is_device_capability(100):
                 try:
                     import flashinfer  # noqa: F401
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index f5407aaeb54f8..ebd5914ee40ac 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -463,6 +463,13 @@ class FlexAttentionImpl(AttentionImpl):
         query = query[:, :, :num_actual_tokens, :]
         # Doesn't work for now -> constraint violation
         # torch._dynamo.try_mark_dynamic(query, 2)
+
+        # default M=64, N=64 may run out of shared memory on
+        # some GPUs with fp32, so we use smaller M and N.
+        extra_kernel_options = {
+            "BLOCK_M": 32,
+            "BLOCK_N": 32
+        } if query.dtype == torch.float32 else {}
         out = flex_attention_compiled(
             query,
             key_cache,
@@ -471,7 +478,10 @@ class FlexAttentionImpl(AttentionImpl):
             attn_metadata.block_mask,
             self.scale,
             enable_gqa=enable_gqa,
-            kernel_options={"FORCE_USE_FLEX_ATTENTION": True},
+            kernel_options={
+                "FORCE_USE_FLEX_ATTENTION": True,
+                **extra_kernel_options
+            },
         )
 
         # Flex doesn't have an out variant today, rely on epilogue fusion
diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index 30396f1594337..87a84e5bf4350 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -101,7 +101,10 @@ class TopKTopPSampler(nn.Module):
                 "per-request generators. Falling back to "
                 "PyTorch-native implementation.")
             return self.forward_native(logits, generators, k, p)
-        return flashinfer_sample(logits, k, p, generators)
+        # flashinfer sampling functions expect contiguous logits.
+        # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
+        # because of the slicing operation in logits_processor.
+        return flashinfer_sample(logits.contiguous(), k, p, generators)
 
     def forward_tpu(
         self,
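
For context, a minimal standalone sketch (not part of the patch) of the core workaround above: passing smaller Triton tile sizes to PyTorch's FlexAttention through kernel_options when the inputs are fp32. The tensor shapes, sizes, and device here are illustrative assumptions.

    # Illustrative sketch, not taken from vLLM: shrink FlexAttention tile
    # sizes for fp32 inputs, mirroring the BLOCK_M/BLOCK_N override in
    # flex_attention.py above.
    import torch
    from torch.nn.attention.flex_attention import flex_attention

    # FlexAttention runs as a Triton kernel when compiled.
    flex_attention = torch.compile(flex_attention)

    # Assumed shapes: (batch, heads, seq_len, head_dim).
    q = torch.randn(1, 8, 128, 64, dtype=torch.float32, device="cuda")
    k = torch.randn(1, 8, 128, 64, dtype=torch.float32, device="cuda")
    v = torch.randn(1, 8, 128, 64, dtype=torch.float32, device="cuda")

    # Smaller tiles reduce shared-memory pressure for fp32.
    kernel_options = ({"BLOCK_M": 32, "BLOCK_N": 32}
                      if q.dtype == torch.float32 else {})
    out = flex_attention(q, k, v, kernel_options=kernel_options)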
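
A hypothetical end-to-end check of the restored fp32 path on the V1 engine; the model name and prompt are placeholders and not part of this change.

    # Hypothetical usage sketch: with this patch, requesting an fp32 model on
    # the V1 engine selects the FlexAttention backend instead of being
    # rejected by the dtype check that this patch removes from arg_utils.py.
    import os

    os.environ["VLLM_USE_V1"] = "1"

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m", dtype="float32")  # placeholder model
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=8))
    print(outputs[0].outputs[0].text)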