Merge abd1dbc548e42c47377fb29188da93617cb13bb1 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29

This commit is contained in:
lif 2025-12-25 00:06:54 +00:00 committed by GitHub
commit b59d768625
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 22 additions and 1 deletions

View File

@ -41,3 +41,20 @@ def _check_consistency(target: TokenizerLike, expected: TokenizerLike):
)
assert target.encode("prompt") == expected.encode("prompt")
def test_cached_tokenizer_preserves_class_name():
    """Verify that wrapping a tokenizer in the cache keeps its class name.

    HuggingFace transformers processors validate the tokenizer's class
    name during type checking, so the dynamically created cache wrapper
    must report the same ``__name__``/``__qualname__`` as the original
    tokenizer class.

    See: https://github.com/vllm-project/vllm/issues/31080
    """
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    expected_name = type(tokenizer).__name__
    expected_qualname = type(tokenizer).__qualname__

    cached = get_cached_tokenizer(tokenizer)

    # The wrapper class must masquerade as the original tokenizer class.
    assert type(cached).__name__ == expected_name
    assert type(cached).__qualname__ == expected_qualname

View File

@ -58,7 +58,11 @@ def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
def __reduce__(self):
return get_cached_tokenizer, (tokenizer,)
CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"
# Keep the original class name to maintain compatibility with
# HuggingFace transformers processor type checking.
# The processor checks tokenizer class name against expected types.
CachedTokenizer.__name__ = tokenizer.__class__.__name__
CachedTokenizer.__qualname__ = tokenizer.__class__.__qualname__
cached_tokenizer.__class__ = CachedTokenizer
return cached_tokenizer