From abd1dbc548e42c47377fb29188da93617cb13bb1 Mon Sep 17 00:00:00 2001 From: majiayu000 <1835304752@qq.com> Date: Wed, 24 Dec 2025 16:02:48 +0800 Subject: [PATCH] [Bugfix] Preserve original tokenizer class name in CachedTokenizer The HuggingFace transformers processor validates the tokenizer type by checking the class name. When vLLM creates a CachedTokenizer with a modified class name (e.g., 'CachedQwen2TokenizerFast'), the processor type check fails with a TypeError. This fix preserves the original tokenizer class name and qualname in CachedTokenizer, ensuring compatibility with HuggingFace transformers processor type checking. Fixes #31080 Signed-off-by: Claude Signed-off-by: majiayu000 <1835304752@qq.com> --- tests/tokenizers_/test_hf.py | 17 +++++++++++++++++ vllm/tokenizers/hf.py | 6 +++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/tests/tokenizers_/test_hf.py b/tests/tokenizers_/test_hf.py index c1238900ce0d1..a343727188ee1 100644 --- a/tests/tokenizers_/test_hf.py +++ b/tests/tokenizers_/test_hf.py @@ -41,3 +41,20 @@ def _check_consistency(target: TokenizerLike, expected: TokenizerLike): ) assert target.encode("prompt") == expected.encode("prompt") + + +def test_cached_tokenizer_preserves_class_name(): + """Test that cached tokenizer preserves original class name. + + This is important for compatibility with HuggingFace transformers + processor type checking, which validates tokenizer class name. 
+ See: https://github.com/vllm-project/vllm/issues/31080 + """ + tokenizer = AutoTokenizer.from_pretrained("gpt2") + original_class_name = tokenizer.__class__.__name__ + + cached_tokenizer = get_cached_tokenizer(tokenizer) + + # The cached tokenizer's class should have the same name as original + assert cached_tokenizer.__class__.__name__ == original_class_name + assert cached_tokenizer.__class__.__qualname__ == tokenizer.__class__.__qualname__ diff --git a/vllm/tokenizers/hf.py b/vllm/tokenizers/hf.py index a7b565dca5d8f..1a8a68d647140 100644 --- a/vllm/tokenizers/hf.py +++ b/vllm/tokenizers/hf.py @@ -58,7 +58,11 @@ def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer: def __reduce__(self): return get_cached_tokenizer, (tokenizer,) - CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" + # Keep the original class name to maintain compatibility with + # HuggingFace transformers processor type checking. + # The processor checks tokenizer class name against expected types. + CachedTokenizer.__name__ = tokenizer.__class__.__name__ + CachedTokenizer.__qualname__ = tokenizer.__class__.__qualname__ cached_tokenizer.__class__ = CachedTokenizer return cached_tokenizer