Diffstat (limited to 'merge_embeddings.py')
-rw-r--r-- | merge_embeddings.py | 111 |
1 file changed, 111 insertions(+), 0 deletions(-)
diff --git a/merge_embeddings.py b/merge_embeddings.py
new file mode 100644
index 0000000..3bb7493
--- /dev/null
+++ b/merge_embeddings.py
@@ -0,0 +1,111 @@
+from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder
+from ldm.modules.embedding_manager import EmbeddingManager
+
+import argparse, os
+from functools import partial
+
+import torch
+
+def get_placeholder_loop(placeholder_string, embedder, is_sd):
+
+    new_placeholder = None
+
+    while True:
+        if new_placeholder is None:
+            new_placeholder = input(f"Placeholder string {placeholder_string} was already used. Please enter a replacement string: ")
+        else:
+            new_placeholder = input(f"Placeholder string '{new_placeholder}' maps to more than a single token. Please enter another string: ")
+
+        token = get_clip_token_for_string(embedder.tokenizer, new_placeholder) if is_sd else get_bert_token_for_string(embedder.tknz_fn, new_placeholder)
+
+        if token is not None:
+            return new_placeholder, token
+
+def get_clip_token_for_string(tokenizer, string):
+    batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
+                               return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+    tokens = batch_encoding["input_ids"]
+
+    if torch.count_nonzero(tokens - 49407) == 2:
+        return tokens[0, 1]
+
+    return None
+
+def get_bert_token_for_string(tokenizer, string):
+    token = tokenizer(string)
+    if torch.count_nonzero(token) == 3:
+        return token[0, 1]
+
+    return None
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--manager_ckpts",
+        type=str,
+        nargs="+",
+        required=True,
+        help="Paths to a set of embedding managers to be merged."
+    )
+
+    parser.add_argument(
+        "--output_path",
+        type=str,
+        required=True,
+        help="Output path for the merged manager",
+    )
+
+    parser.add_argument(
+        "-sd", "--stable_diffusion",
+        action="store_true",
+        help="Flag to denote that we are merging stable diffusion embeddings"
+    )
+
+    args = parser.parse_args()
+
+    if args.stable_diffusion:
+        embedder = FrozenCLIPEmbedder().cuda()
+    else:
+        embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda()
+
+    EmbeddingManager = partial(EmbeddingManager, embedder, ["*"])
+
+    string_to_token_dict = {}
+    string_to_param_dict = torch.nn.ParameterDict()
+
+    placeholder_to_src = {}
+
+    for manager_ckpt in args.manager_ckpts:
+        print(f"Parsing {manager_ckpt}...")
+
+        manager = EmbeddingManager()
+        manager.load(manager_ckpt)
+
+        for placeholder_string in manager.string_to_token_dict:
+            if not placeholder_string in string_to_token_dict:
+                string_to_token_dict[placeholder_string] = manager.string_to_token_dict[placeholder_string]
+                string_to_param_dict[placeholder_string] = manager.string_to_param_dict[placeholder_string]
+
+                placeholder_to_src[placeholder_string] = manager_ckpt
+            else:
+                new_placeholder, new_token = get_placeholder_loop(placeholder_string, embedder, is_sd=args.stable_diffusion)
+                string_to_token_dict[new_placeholder] = new_token
+                string_to_param_dict[new_placeholder] = manager.string_to_param_dict[placeholder_string]
+
+                placeholder_to_src[new_placeholder] = manager_ckpt
+
+    print("Saving combined manager...")
+    merged_manager = EmbeddingManager()
+    merged_manager.string_to_param_dict = string_to_param_dict
+    merged_manager.string_to_token_dict = string_to_token_dict
+    merged_manager.save(args.output_path)
+
+    print("Managers merged. Final list of placeholders: ")
+    print(placeholder_to_src)
+
+
+
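A note on the single-token checks above: for the Stable Diffusion path, the CLIP tokenizer pads every encoding to 77 ids with its end-of-text token (id 49407) and prepends a start-of-text token (id 49406). A string that maps to exactly one token therefore leaves exactly two positions differing from 49407, which is what torch.count_nonzero(tokens - 49407) == 2 tests. A minimal sketch, assuming the Hugging Face openai/clip-vit-large-patch14 tokenizer (the checkpoint FrozenCLIPEmbedder loads by default):

    import torch
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    # Encode with the same settings as get_clip_token_for_string above.
    tokens = tokenizer("cat", truncation=True, max_length=77,
                       padding="max_length", return_tensors="pt")["input_ids"]

    # All padding positions hold 49407, so subtracting it zeroes them out,
    # leaving only the start token (49406) and the word's own token non-zero.
    print(torch.count_nonzero(tokens - 49407))  # tensor(2) for a single-token string

The BERT check follows the same idea: that tokenizer pads with id 0, so exactly three non-zero ids (CLS, the word's token, SEP) indicate a single-token string, hence the count of 3.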
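Usage follows directly from the argparse flags; with hypothetical checkpoint names, merging two Stable Diffusion embedding managers would look something like:

    python merge_embeddings.py -sd --manager_ckpts first.pt second.pt --output_path merged.pt

Omitting -sd merges latent-diffusion (BERT) embeddings instead. Note that both embedders are moved to CUDA, so a GPU is required, and that get_placeholder_loop prompts interactively when two managers define the same placeholder string, so the script should be run from a terminal rather than a non-interactive job.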