Skip to content

Commit caf2b61

Browse files
bzzmallamanis
authored andcommitted
tsne vis: batch_size=4 & cli arg for TF Projector format
1 parent da84eeb commit caf2b61

File tree

1 file changed

+32
-11
lines changed

1 file changed

+32
-11
lines changed

etc/compute_embeddings.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import argparse
22
import json
3+
from timeit import default_timer as timer
4+
from datetime import date
35

46
import numpy as np
57
import torch
@@ -14,7 +16,8 @@ def parse_arguments():
1416
parser.add_argument("json", default=False, help="the path the json containing all papers.")
1517
parser.add_argument("outpath", default=False, help="the target path of the visualizations papers.")
1618
parser.add_argument("--seed", default=0, help="The seed for TSNE.", type=int)
17-
parser.add_argument("--model", default='sentence-transformers/all-MiniLM-L6-v2', help="Name of the HF model")
19+
parser.add_argument("--model", default='sentence-transformers/all-MiniLM-L6-v2', help="The name of the HF model")
20+
parser.add_argument("--save_emb", action='store_true', help="Save embeddings in CSV for Tensorboard Projector")
1821

1922
return parser.parse_args()
2023

@@ -23,9 +26,7 @@ def mean_pooling(token_embeddings, attention_mask):
2326
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
2427
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
2528

26-
27-
if __name__ == "__main__":
28-
args = parse_arguments()
29+
def main(args):
2930
tokenizer = transformers.AutoTokenizer.from_pretrained(args.model)
3031
model = transformers.AutoModel.from_pretrained(args.model)
3132
model.eval()
@@ -39,18 +40,38 @@ def mean_pooling(token_embeddings, attention_mask):
3940
for paper_info in data:
4041
corpus.append(tokenizer.sep_token.join([paper_info['title'], paper_info['abstract']]))
4142

42-
encoded_corpus = tokenizer(corpus, padding=True, truncation=True, return_tensors='pt')
43-
with torch.no_grad():
44-
hidden_states = model(**encoded_corpus).last_hidden_state
45-
46-
corpus_embeddings = mean_pooling(hidden_states, encoded_corpus['attention_mask'])
47-
corpus_embeddings = F.normalize(corpus_embeddings, p=2, dim=1)
43+
batch_size = 4
44+
all_embeddings=[]
45+
start = timer()
46+
for i in range(0, len(corpus), batch_size):
47+
encoded_batch = tokenizer(corpus[i:min(i+batch_size, len(corpus))], padding=True, truncation=True, return_tensors='pt')
48+
with torch.no_grad():
49+
hidden_state = model(**encoded_batch).last_hidden_state
50+
all_embeddings.append(mean_pooling(hidden_state, encoded_batch['attention_mask']))
51+
52+
all_embeddings = torch.cat(all_embeddings, dim=0)
53+
all_embeddings = F.normalize(all_embeddings, p=2, dim=1)
54+
print(f"elapsed {timer()-start:.1f}s")
55+
56+
if args.save_emb:
57+
filename = f"{args.model.replace('/', '_')}-{date.today().strftime('%d.%m.%y')}"
58+
np.savetxt(f"{filename}-emb.tsv", all_embeddings, delimiter="\t")
59+
import csv
60+
with open(f"{filename}-meta.tsv", 'w', newline='') as csvfile:
61+
w = csv.writer(csvfile, delimiter='\t', quoting=csv.QUOTE_MINIMAL)
62+
w.writerow(["year", "key", "title"])
63+
for paper in data:
64+
w.writerow([paper["year"], paper["key"], paper["title"]])
4865

4966
np.random.seed(args.seed)
50-
out = sklearn.manifold.TSNE(n_components=2, metric="cosine").fit_transform(corpus_embeddings)
67+
out = sklearn.manifold.TSNE(n_components=2, metric="cosine").fit_transform(all_embeddings)
5168

5269
for i, paper_info in enumerate(data):
5370
paper_info['tsne_embedding'] = out[i].tolist()
5471

5572
with open(args.outpath, 'w') as f:
5673
json.dump(data, f)
74+
75+
if __name__ == "__main__":
76+
args = parse_arguments()
77+
main(args)

0 commit comments

Comments
 (0)