mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-11 06:52:18 +08:00
[BUGFIX] MistralTokenizer._call__ adds an invalid EOS token (#29607)
Signed-off-by: Julien Denize <julien.denize@mistral.ai> Signed-off-by: Julien Denize <40604584+juliendenize@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent
cc0f2a0e19
commit
b2c1d294fa
@ -331,6 +331,7 @@ class TestMistralTokenizer:
|
|||||||
)
|
)
|
||||||
== token_ids
|
== token_ids
|
||||||
)
|
)
|
||||||
|
assert mistral_tokenizer.encode_one("") == []
|
||||||
|
|
||||||
def test_encode(self, mistral_tokenizer: MistralTokenizer):
|
def test_encode(self, mistral_tokenizer: MistralTokenizer):
|
||||||
token_ids = (
|
token_ids = (
|
||||||
@ -370,6 +371,51 @@ class TestMistralTokenizer:
|
|||||||
mistral_tokenizer.encode("Hello world !", add_special_tokens=False)
|
mistral_tokenizer.encode("Hello world !", add_special_tokens=False)
|
||||||
== token_ids[1:]
|
== token_ids[1:]
|
||||||
)
|
)
|
||||||
|
assert mistral_tokenizer.encode("", add_special_tokens=False) == []
|
||||||
|
|
||||||
|
def test_call(self, mistral_tokenizer: MistralTokenizer):
|
||||||
|
token_ids = (
|
||||||
|
[1, 22177, 4304, 2662]
|
||||||
|
if mistral_tokenizer.is_tekken
|
||||||
|
else [1, 23325, 2294, 1686]
|
||||||
|
)
|
||||||
|
attn_mask = [1 for _ in range(len(token_ids))]
|
||||||
|
|
||||||
|
# Test 1: default
|
||||||
|
assert mistral_tokenizer("Hello world !") == {
|
||||||
|
"attention_mask": attn_mask[1:],
|
||||||
|
"input_ids": token_ids[1:],
|
||||||
|
}
|
||||||
|
# Test 2: special tokens
|
||||||
|
assert mistral_tokenizer("Hello world !", add_special_tokens=True) == {
|
||||||
|
"attention_mask": attn_mask,
|
||||||
|
"input_ids": token_ids,
|
||||||
|
}
|
||||||
|
# Test 3: special tokens + truncation
|
||||||
|
assert mistral_tokenizer(
|
||||||
|
"Hello world !", add_special_tokens=True, truncation=True, max_length=3
|
||||||
|
) == {
|
||||||
|
"attention_mask": attn_mask[:-1],
|
||||||
|
"input_ids": token_ids[:-1],
|
||||||
|
}
|
||||||
|
# Test 4: special tokens + no truncation + max length
|
||||||
|
assert mistral_tokenizer(
|
||||||
|
"Hello world !", add_special_tokens=True, max_length=3
|
||||||
|
) == {
|
||||||
|
"attention_mask": attn_mask,
|
||||||
|
"input_ids": token_ids,
|
||||||
|
}
|
||||||
|
# Test 5: empty string
|
||||||
|
assert mistral_tokenizer("") == {
|
||||||
|
"attention_mask": [],
|
||||||
|
"input_ids": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=(r"`text_pair` is not supported by `MistralTokenizer.__call__`."),
|
||||||
|
):
|
||||||
|
mistral_tokenizer("Hello world !", "invalid pair")
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"openai_request,add_generation_prompt,continue_final_message,expected_output,decoded_expected_output",
|
"openai_request,add_generation_prompt,continue_final_message,expected_output,decoded_expected_output",
|
||||||
@ -1087,6 +1133,24 @@ class TestMistralTokenizer:
|
|||||||
)
|
)
|
||||||
== expected_tokens[mistral_tokenizer.is_tekken]
|
== expected_tokens[mistral_tokenizer.is_tekken]
|
||||||
)
|
)
|
||||||
|
assert (
|
||||||
|
mistral_tokenizer.decode(
|
||||||
|
ids[mistral_tokenizer.is_tekken],
|
||||||
|
skip_special_tokens=skip_special_tokens,
|
||||||
|
)
|
||||||
|
== expected_tokens[mistral_tokenizer.is_tekken]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_decode_empty(
|
||||||
|
self,
|
||||||
|
mistral_tokenizer: MistralTokenizer,
|
||||||
|
):
|
||||||
|
assert (
|
||||||
|
mistral_tokenizer.decode(
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
== ""
|
||||||
|
)
|
||||||
|
|
||||||
def test_decode_int(
|
def test_decode_int(
|
||||||
self,
|
self,
|
||||||
@ -1390,6 +1454,8 @@ class TestMistralTokenizer:
|
|||||||
== expected_strings[mistral_tokenizer.is_tekken]
|
== expected_strings[mistral_tokenizer.is_tekken]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert mistral_tokenizer.convert_tokens_to_string([]) == ""
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"skip_special_tokens,tuple_expected_tokens",
|
"skip_special_tokens,tuple_expected_tokens",
|
||||||
(
|
(
|
||||||
@ -2220,3 +2286,5 @@ class TestMistralTokenizer:
|
|||||||
ids, skip_special_tokens=skip_special_tokens
|
ids, skip_special_tokens=skip_special_tokens
|
||||||
)
|
)
|
||||||
assert actual_tokens == expected_tokens
|
assert actual_tokens == expected_tokens
|
||||||
|
|
||||||
|
assert mistral_tokenizer.convert_ids_to_tokens([]) == []
|
||||||
|
|||||||
@ -312,13 +312,27 @@ class MistralTokenizer(TokenizerBase):
|
|||||||
truncation: bool = False,
|
truncation: bool = False,
|
||||||
max_length: int | None = None,
|
max_length: int | None = None,
|
||||||
):
|
):
|
||||||
return self.transformers_tokenizer(
|
if text_pair is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"`text_pair` is not supported by `MistralTokenizer.__call__`."
|
||||||
|
)
|
||||||
|
|
||||||
|
encoded = self.transformers_tokenizer(
|
||||||
text=text,
|
text=text,
|
||||||
text_pair=text_pair,
|
text_pair=text_pair,
|
||||||
add_special_tokens=add_special_tokens,
|
add_special_tokens=add_special_tokens,
|
||||||
truncation=truncation,
|
truncation=truncation,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
)
|
)
|
||||||
|
# TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
|
||||||
|
# is in, revert to only call self.transformers_tokenizer(...).
|
||||||
|
# Hack to fix wrongly added eos token, when fix will be supported the condition
|
||||||
|
# below will be False even before the revert is done.
|
||||||
|
if encoded["input_ids"] and encoded["input_ids"][-1] == self.eos_token_id:
|
||||||
|
encoded["input_ids"].pop(-1)
|
||||||
|
if attention_mask := encoded.get("attention_mask"):
|
||||||
|
attention_mask.pop(-1)
|
||||||
|
return encoded
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def vocab(self) -> list[str]:
|
def vocab(self) -> list[str]:
|
||||||
@ -349,6 +363,8 @@ class MistralTokenizer(TokenizerBase):
|
|||||||
max_length: int | None = None,
|
max_length: int | None = None,
|
||||||
add_special_tokens: bool | None = None,
|
add_special_tokens: bool | None = None,
|
||||||
) -> list[int]:
|
) -> list[int]:
|
||||||
|
# TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
|
||||||
|
# is in, directly call self.transformers_tokenizer.encode(...).
|
||||||
encoded = self.tokenizer.encode(
|
encoded = self.tokenizer.encode(
|
||||||
text, bos=add_special_tokens is not False, eos=False
|
text, bos=add_special_tokens is not False, eos=False
|
||||||
)
|
)
|
||||||
@ -387,6 +403,8 @@ class MistralTokenizer(TokenizerBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str:
|
def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str:
|
||||||
|
# TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
|
||||||
|
# is in, directly call self.transformers_tokenizer.decode(...).
|
||||||
if isinstance(ids, int):
|
if isinstance(ids, int):
|
||||||
ids = [ids]
|
ids = [ids]
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user