From feaf202e93a18228cd34ae2e6698aece83a15e87 Mon Sep 17 00:00:00 2001
From: Remy
Date: Wed, 10 Sep 2025 15:24:42 +0900
Subject: [PATCH] [Bugfix] Guard `_may_reorder_batch` for encoder-only models
 on CPU (#24319) (#24348)

Encoder-only and pooling models could crash on the CPU backend:
`CPUModelRunner._may_reorder_batch` unconditionally accessed
`attn_groups[0][0].metadata_builder` and asserted it is a
`TorchSDPAMetadataBuilderV1`, which fails when `attn_groups` is empty
or holds a different builder. Replace the assert with guards that skip
batch reordering in those cases, stop force-disabling the hybrid KV
cache manager on CPU, and mark the BAAI/bge-base-en-v1.5 and
sentence-transformers/stsb-roberta-base-v2 embedding tests as core/CPU
model tests.

Signed-off-by: Remy
Co-authored-by: Li, Jiang
---
 .../models/language/pooling/test_embedding.py | 10 ++++++++--
 vllm/config/__init__.py                       |  3 ++-
 vllm/v1/worker/cpu_model_runner.py            | 20 +++++++++++++++----
 3 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py
index 0733ac85c11f..41574b844a66 100644
--- a/tests/models/language/pooling/test_embedding.py
+++ b/tests/models/language/pooling/test_embedding.py
@@ -27,11 +27,17 @@ from ...utils import check_embeddings_close
         pytest.param("ssmits/Qwen2-7B-Instruct-embed-base",
                      marks=[pytest.mark.cpu_model]),
         # [Encoder-only]
-        pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]),
+        pytest.param(
+            "BAAI/bge-base-en-v1.5",
+            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
+        ),
         pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
         pytest.param("intfloat/multilingual-e5-small"),
         # [Cross-Encoder]
-        pytest.param("sentence-transformers/stsb-roberta-base-v2"),
+        pytest.param(
+            "sentence-transformers/stsb-roberta-base-v2",
+            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
+        ),
     ],
 )
 def test_models(
diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index a859f44eb40d..7422527a6854 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -3665,7 +3665,8 @@ class VllmConfig:
         # logger should only print warning message for hybrid models. As we
         # can't know whether the model is hybrid or not now, so we don't log
         # warning message here and will log it later.
-        if not (current_platform.is_cuda() or current_platform.is_rocm()):
+        if not (current_platform.is_cuda() or current_platform.is_rocm()
+                or current_platform.is_cpu()):
             # Hybrid KV cache manager is not supported on non-GPU platforms.
             self.scheduler_config.disable_hybrid_kv_cache_manager = True
         if self.kv_transfer_config is not None:
diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py
index feb49978d751..d5ec19b86b06 100644
--- a/vllm/v1/worker/cpu_model_runner.py
+++ b/vllm/v1/worker/cpu_model_runner.py
@@ -55,11 +55,23 @@ class CPUModelRunner(GPUModelRunner):
             raise ValueError("Multiple KVCacheGroups is not"
                              "currently supported with CPU model runner.")
 
-        assert type(self.attn_groups[0]
-                    [0].metadata_builder) is TorchSDPAMetadataBuilderV1
+        # Guard against encoder-only / pooling models, where `attn_groups`
+        # may be empty or lack the expected metadata_builder. Without this
+        # check, indexing `attn_groups[0][0]` would raise an IndexError on
+        # the CPU backend.
+        if not hasattr(self, "attn_groups") or not self.attn_groups:
+            return
+        if not self.attn_groups[0]:
+            return
 
-        self.attn_groups[0][0].metadata_builder.reorder_batch(
-            self.input_batch, scheduler_output)
+        mb = getattr(self.attn_groups[0][0], "metadata_builder", None)
+        if not isinstance(mb, TorchSDPAMetadataBuilderV1):
+            # Encoder-only / rerank models do not benefit from reordering,
+            # so it is safe to skip here.
+            return
+
+        # Normal path for decoder / attention-heavy models
+        mb.reorder_batch(self.input_batch, scheduler_output)
 
     def _postprocess_tensors(self) -> None:
         # Note: replace device tensors with cpu tensors
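
Reviewer note: the following is a minimal, self-contained sketch of the
guard pattern this patch introduces. `Group`, `SDPABuilder`,
`OtherBuilder`, and `may_reorder_batch` are hypothetical stand-ins for
vLLM's attention groups, `TorchSDPAMetadataBuilderV1`, and
`CPUModelRunner._may_reorder_batch`; it only illustrates why the old
unconditional `attn_groups[0][0]` access breaks for encoder-only /
pooling models and how the guarded version degrades to a no-op.

# Hypothetical stand-ins; not vLLM's real classes or API.
class SDPABuilder:
    """Plays the role of TorchSDPAMetadataBuilderV1."""

    def reorder_batch(self, input_batch, scheduler_output):
        print(f"reordering {input_batch!r} for {scheduler_output!r}")


class OtherBuilder:
    """Plays the role of any non-SDPA metadata builder."""


class Group:
    """Plays the role of an attention group holding a metadata builder."""

    def __init__(self, builder):
        self.metadata_builder = builder


def may_reorder_batch(attn_groups, input_batch, scheduler_output):
    # Empty or missing groups (encoder-only / pooling models): no-op.
    if not attn_groups or not attn_groups[0]:
        return
    # Unexpected builder type: skip instead of asserting.
    mb = getattr(attn_groups[0][0], "metadata_builder", None)
    if not isinstance(mb, SDPABuilder):
        return
    mb.reorder_batch(input_batch, scheduler_output)


may_reorder_batch([], "batch", "sched")                     # was: IndexError
may_reorder_batch([[Group(OtherBuilder())]], "batch", "s")  # was: AssertionError
may_reorder_batch([[Group(SDPABuilder())]], "batch", "s")   # reorders as before

The guard is fail-open by design: per the comments in the patch,
encoder-only / rerank models do not benefit from reordering, so
skipping the reorder for them is a semantic no-op rather than a
silently dropped optimization for decoder models.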