Add left padding support to tokenizers. (#10753)
parent 443056c401
commit bd01d9f7fd
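In short: SDTokenizer.__init__ gains a pad_left flag (default False, so existing behavior is unchanged), the flag is stored on the instance, and padding is factored into a new pad_tokens() helper. The four batch.extend(...) padding call sites in tokenize_with_weights() now route through that helper; with pad_left=True, pad tokens are inserted at the front of the sequence instead of appended at the end.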
@@ -460,7 +460,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
     return embed_out
 
 class SDTokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
@@ -468,6 +468,7 @@ class SDTokenizer:
         self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
         self.end_token = None
         self.min_padding = min_padding
+        self.pad_left = pad_left
 
         empty = self.tokenizer('')["input_ids"]
         self.tokenizer_adds_end_token = has_end_token
@@ -522,6 +523,12 @@ class SDTokenizer:
             return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
         return (embed, leftover)
 
+    def pad_tokens(self, tokens, amount):
+        if self.pad_left:
+            for i in range(amount):
+                tokens.insert(0, (self.pad_token, 1.0, 0))
+        else:
+            tokens.extend([(self.pad_token, 1.0, 0)] * amount)
 
     def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
         '''
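To make the new helper's behavior concrete, here is a minimal standalone sketch of the same logic, runnable outside SDTokenizer. The token ids 320 and 1125 and the pad token id 0 are made up for illustration; in the class, the pad token comes from self.pad_token and each entry is a (token, weight, word_id) tuple as in the diff above.

def pad_tokens(tokens, amount, pad_token=0, pad_left=False):
    # Mirrors SDTokenizer.pad_tokens: prepend when pad_left, append otherwise.
    if pad_left:
        for _ in range(amount):
            tokens.insert(0, (pad_token, 1.0, 0))
    else:
        tokens.extend([(pad_token, 1.0, 0)] * amount)
    return tokens

print(pad_tokens([(320, 1.0, 1), (1125, 1.0, 2)], 2))
# [(320, 1.0, 1), (1125, 1.0, 2), (0, 1.0, 0), (0, 1.0, 0)]
print(pad_tokens([(320, 1.0, 1), (1125, 1.0, 2)], 2, pad_left=True))
# [(0, 1.0, 0), (0, 1.0, 0), (320, 1.0, 1), (1125, 1.0, 2)]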
@@ -600,7 +607,7 @@ class SDTokenizer:
                         if self.end_token is not None:
                             batch.append((self.end_token, 1.0, 0))
                         if self.pad_to_max_length:
-                            batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
+                            self.pad_tokens(batch, remaining_length)
                     #start new batch
                     batch = []
                     if self.start_token is not None:
@@ -614,11 +621,11 @@ class SDTokenizer:
         if self.end_token is not None:
             batch.append((self.end_token, 1.0, 0))
         if min_padding is not None:
-            batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
+            self.pad_tokens(batch, min_padding)
         if self.pad_to_max_length and len(batch) < self.max_length:
-            batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
+            self.pad_tokens(batch, self.max_length - len(batch))
         if min_length is not None and len(batch) < min_length:
-            batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
+            self.pad_tokens(batch, min_length - len(batch))
 
         if not return_word_ids:
             batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
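One design note: prepending with repeated tokens.insert(0, ...) is quadratic in the pad amount, since every insert shifts the whole list. With the default max_length of 77 that cost is negligible, but a single slice assignment would do the same work in one pass; a hypothetical equivalent for the pad_left branch:

# Same effect as the insert loop in pad_tokens, in one prepend:
tokens[:0] = [(self.pad_token, 1.0, 0)] * amount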