Fix infinite generation loop

Fix: prevent infinite token generation loop on repetitive patterns (A5A5...) in generate()
Ceaser1717 2025-10-19 22:29:20 +05:30 committed by GitHub
parent 9b4e9788e4
commit 031930fb29
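The idea behind the fix, reduced to a standalone sketch (illustrative only, not code from this commit; has_repeating_pair_cycle is a hypothetical helper): the detector has to operate on token IDs rather than decoded text, so a text loop like 'A5A5A5...' shows up as a short cycle of IDs repeating at the tail of the sequence.

import re

def has_repeating_pair_cycle(token_ids, min_repeats=6):
    """Return True if token_ids contains a two-token cycle repeated
    at least min_repeats times in a row (e.g. the IDs behind 'A5A5A5...')."""
    id_str = ",".join(str(t) for t in token_ids)
    # \b aligns the match to whole IDs; \1 demands the same pair again.
    pattern = r'\b(\d+,\d+)(?:,\1){%d,}' % (min_repeats - 1)
    return re.search(pattern, id_str) is not None

print(has_repeating_pair_cycle([12, 7] + [301, 53] * 6))   # True: (301, 53) cycles
print(has_repeating_pair_cycle([12, 7, 301, 53, 88, 19]))  # False: no cycle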


@@ -9,6 +9,8 @@ from transformers import AutoTokenizer
from safetensors.torch import load_model
from model import Transformer, ModelArgs
import re
import torch
def sample(logits, temperature: float = 1.0):
@@ -183,3 +185,85 @@ if __name__ == "__main__":
args = parser.parse_args()
assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
@torch.inference_mode()
def generate(
    model: Transformer,
    prompt_tokens: List[List[int]],
    max_new_tokens: int,
    eos_id: int,
    temperature: float = 1.0
) -> List[List[int]]:
    """
    Generates new tokens with added repetition protection logic.
    Prevents infinite loops when repetitive patterns like 'A5A5A5...' occur.
    """
    prompt_lens = [len(t) for t in prompt_tokens]
    assert max(prompt_lens) <= model.max_seq_len, (
        f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
    )
    total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
    for i, t in enumerate(prompt_tokens):
        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
    prev_pos = 0
    finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
    prompt_mask = tokens != -1
    # --- New repetition tracking variables ---
    repeat_threshold = 10
    repeat_count = 0
    last_token = None
    for cur_pos in range(min(prompt_lens), total_len):
        logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        if temperature > 0:
            next_token = sample(logits, temperature)
        else:
            next_token = logits.argmax(dim=-1)
        next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
        tokens[:, cur_pos] = next_token
        # --- 🔍 Repetition detection logic ---
        # Compare this step's sampled tokens (whole batch) with the previous step's.
        token_text = str(next_token.tolist())
        if last_token == token_text:
            repeat_count += 1
        else:
            repeat_count = 0
            last_token = token_text
        # If the same token repeats too many times in a row → stop.
        if repeat_count > repeat_threshold:
            print("[⚠️] Stopping generation: excessive repetition detected.")
            break
        # Detect a long repeating two-token cycle, the ID-level signature of text
        # loops such as 'A5A5A5...'. The check must run on token IDs, not decoded
        # text: a string of decimal IDs never contains the literal characters 'A5'.
        # Only the first sequence in the batch is inspected.
        id_str = ",".join(str(x) for x in tokens[0].tolist() if x != -1)
        if re.search(r'\b(\d+,\d+)(?:,\1){5,}', id_str):
            print("[⚠️] Infinite token cycle detected; stopping early.")
            break
        # Normal stopping condition
        finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
        prev_pos = cur_pos
        if finished.all():
            break
    # Extract generated completion tokens
    completion_tokens = []
    for i, toks in enumerate(tokens.tolist()):
        toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens]
        # Drop the -1 padding left behind when generation stops early.
        toks = [x for x in toks if x != -1]
        if eos_id in toks:
            toks = toks[:toks.index(eos_id)]
        completion_tokens.append(toks)
    return completion_tokens
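
For reference, a minimal sketch of how the patched generate() might be driven, mirroring what main() presumably sets up; model, ckpt_path, and the sampling settings here are assumptions, not part of the commit:

# model: an initialized Transformer with weights loaded via load_model(), on CUDA.
# ckpt_path: hypothetical checkpoint directory that also holds the tokenizer files.
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
prompt = tokenizer.encode("The pattern repeats: A5A5A5")
completions = generate(
    model,
    [prompt],                          # batch with a single prompt
    max_new_tokens=256,
    eos_id=tokenizer.eos_token_id,
    temperature=0.7,
)
print(tokenizer.decode(completions[0]))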