aboutsummaryrefslogtreecommitdiff
path: root/merge_embeddings.py
blob: 3bb7493bfa3d5e78b199d1815e4839e1798ccdd1 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder
from ldm.modules.embedding_manager import EmbeddingManager

import argparse, os
from functools import partial

import torch

def get_placeholder_loop(placeholder_string, embedder, is_sd):
    """Interactively ask the user for a replacement placeholder string.

    The user is prompted on stdin until the entered string resolves to a
    token (i.e. the tokenizer helper does not return ``None``).

    Args:
        placeholder_string: The clashing placeholder, shown in the first prompt.
        embedder: Text embedder; ``embedder.tokenizer`` is used when ``is_sd``
            is true, ``embedder.tknz_fn`` otherwise.
        is_sd: Selects the CLIP token lookup (true) or the BERT one (false).

    Returns:
        A ``(replacement_string, token)`` tuple.

    Note:
        Only the single-token requirement is validated here; whether the
        replacement is already in use elsewhere is not checked.
    """
    # The first prompt reports the clash; later prompts report the rejected entry.
    prompt = f"Placeholder string {placeholder_string} was already used. Please enter a replacement string: "

    while True:
        candidate = input(prompt)

        if is_sd:
            token = get_clip_token_for_string(embedder.tokenizer, candidate)
        else:
            token = get_bert_token_for_string(embedder.tknz_fn, candidate)

        if token is not None:
            return candidate, token

        prompt = f"Placeholder string '{candidate}' maps to more than a single token. Please enter another string: "

            
def get_clip_token_for_string(tokenizer, string):
    """Return the token id for ``string`` if it encodes to a single token.

    The string is tokenized with truncation and padding to 77 positions.
    The encoding is accepted only when exactly two positions hold an id
    other than 49407; the id at position ``[0, 1]`` is then returned.

    Args:
        tokenizer: Callable that accepts the string plus the keyword options
            used below and returns a mapping with an ``"input_ids"`` tensor.
        string: Text to encode.

    Returns:
        The element ``input_ids[0, 1]`` (a 0-d tensor), or ``None`` when the
        count of ids different from 49407 is not exactly two.
    """
    # NOTE(review): 49407 is presumably CLIP's end-of-text id, which also
    # serves as padding, and position 0 presumably the start token - confirm.
    encoded = tokenizer(
        string,
        truncation=True,
        max_length=77,
        return_length=True,
        return_overflowing_tokens=False,
        padding="max_length",
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"]

    # Subtracting 49407 zeroes every position holding that id, so the
    # non-zero count is the number of positions holding anything else.
    differing_positions = torch.count_nonzero(input_ids - 49407)
    if differing_positions != 2:
        return None

    return input_ids[0, 1]

def get_bert_token_for_string(tokenizer, string):
    """Return the token id for ``string`` if it encodes to a single token.

    Args:
        tokenizer: Callable mapping a string to a 2-D tensor of token ids.
        string: Text to encode.

    Returns:
        The element ``ids[0, 1]`` (a 0-d tensor) when the encoding contains
        exactly three non-zero ids, otherwise ``None``.
    """
    # NOTE(review): three non-zero ids presumably means start marker, one
    # word token, end marker, with zero as padding - confirm.
    ids = tokenizer(string)
    nonzero_count = torch.count_nonzero(ids)
    return ids[0, 1] if nonzero_count == 3 else None


if __name__ == "__main__":
    # Script entry point: load several embedding-manager checkpoints, merge
    # their placeholder -> token / placeholder -> parameter tables into one
    # manager, and save the result. Placeholder clashes are resolved
    # interactively by asking the user for a replacement string.

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--manager_ckpts", 
        type=str, 
        nargs="+", 
        required=True,
        help="Paths to a set of embedding managers to be merged."
    )

    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Output path for the merged manager",
    )

    parser.add_argument(
        "-sd", "--stable_diffusion",
        action="store_true",
        help="Flag to denote that we are merging stable diffusion embeddings"
    )

    args = parser.parse_args()

    # The embedder decides which token lookup is used for replacements.
    if args.stable_diffusion:
        embedder = FrozenCLIPEmbedder().cuda()
    else:
        embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda()

    # Factory for managers bound to this embedder. It gets its own name so
    # the imported EmbeddingManager class is not shadowed.
    make_manager = partial(EmbeddingManager, embedder, ["*"])

    # Accumulated tables for the merged manager.
    string_to_token_dict = {}
    string_to_param_dict = torch.nn.ParameterDict()

    # Records which checkpoint each final placeholder came from (for the report).
    placeholder_to_src = {}

    for manager_ckpt in args.manager_ckpts:
        print(f"Parsing {manager_ckpt}...")

        manager = make_manager()
        manager.load(manager_ckpt)

        for placeholder_string in manager.string_to_token_dict:
            if placeholder_string not in string_to_token_dict:
                merged_placeholder = placeholder_string
                merged_token = manager.string_to_token_dict[placeholder_string]
            else:
                merged_placeholder, merged_token = get_placeholder_loop(
                    placeholder_string, embedder, is_sd=args.stable_diffusion
                )
                # The replacement must itself be unused; otherwise it would
                # silently overwrite an embedding merged earlier. Keep asking
                # until the user supplies a free placeholder.
                while merged_placeholder in string_to_token_dict:
                    merged_placeholder, merged_token = get_placeholder_loop(
                        merged_placeholder, embedder, is_sd=args.stable_diffusion
                    )

            string_to_token_dict[merged_placeholder] = merged_token
            # The learned parameter is always looked up under the ORIGINAL
            # placeholder name inside the source manager.
            string_to_param_dict[merged_placeholder] = manager.string_to_param_dict[placeholder_string]

            placeholder_to_src[merged_placeholder] = manager_ckpt

    print("Saving combined manager...")
    merged_manager = make_manager()
    # Replace the freshly initialised tables with the merged ones before saving.
    merged_manager.string_to_param_dict = string_to_param_dict
    merged_manager.string_to_token_dict = string_to_token_dict
    merged_manager.save(args.output_path)

    print("Managers merged. Final list of placeholders: ")
    print(placeholder_to_src)