diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index f8a7c2a1b..3066de2d7 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -460,7 +460,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No return embed_out class SDTokenizer: - def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}): + def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) @@ -468,6 +468,7 @@ class SDTokenizer: self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length) self.end_token = None self.min_padding = min_padding + self.pad_left = pad_left empty = self.tokenizer('')["input_ids"] self.tokenizer_adds_end_token = has_end_token @@ -522,6 +523,12 @@ class SDTokenizer: return (embed, "{} {}".format(embedding_name[len(stripped):], leftover)) return (embed, leftover) + def pad_tokens(self, tokens, amount): + if self.pad_left: + for i in range(amount): + tokens.insert(0, (self.pad_token, 1.0, 0)) + else: + tokens.extend([(self.pad_token, 1.0, 0)] * amount) def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs): ''' @@ -600,7 +607,7 @@ class SDTokenizer: if self.end_token is not None: batch.append((self.end_token, 1.0, 0)) if self.pad_to_max_length: - batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length)) + self.pad_tokens(batch, remaining_length) #start new batch batch = [] if self.start_token is not None: @@ -614,11 +621,11 @@ class SDTokenizer: if self.end_token is not None: batch.append((self.end_token, 1.0, 0)) if min_padding is not None: - batch.extend([(self.pad_token, 1.0, 0)] * min_padding) + self.pad_tokens(batch, min_padding) if self.pad_to_max_length and len(batch) < self.max_length: - batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch))) + self.pad_tokens(batch, self.max_length - len(batch)) if min_length is not None and len(batch) < min_length: - batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch))) + self.pad_tokens(batch, min_length - len(batch)) if not return_word_ids: batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]