
Commit da84eeb

bzzmallamanis authored and committed
tsne vis: change the model & embeddings
Use a smaller model, 'all-MiniLM-L6-v2' (from https://www.sbert.net/docs/pretrained_models.html), that is fast and provides better quality. Use the title as well as the abstract for paper embeddings. Encode & average in batches.
1 parent 41523e3 commit da84eeb
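
For context: the mean-pooling and L2-normalization added in the diff below reproduce what the sentence-transformers library performs internally for this model. A minimal sketch of the equivalent, assuming the sentence-transformers package is installed (the script itself calls transformers directly; the corpus string here is a hypothetical example):

from sentence_transformers import SentenceTransformer

# 'all-MiniLM-L6-v2' is the model this commit switches to.
model = SentenceTransformer('all-MiniLM-L6-v2')

# Hypothetical input: title and abstract joined, as the script now does.
corpus = ["A Paper Title [SEP] The paper abstract ..."]

# encode() mean-pools token embeddings over the attention mask; with
# normalize_embeddings=True it also L2-normalizes them, matching the
# mean_pooling() + F.normalize() steps in the diff.
embeddings = model.encode(corpus, batch_size=32, normalize_embeddings=True)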

File tree: 1 file changed (+20 −9 lines)


etc/compute_embeddings.py

Lines changed: 20 additions & 9 deletions
@@ -3,6 +3,7 @@
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 import sklearn.manifold
 import transformers
 
@@ -13,30 +14,40 @@ def parse_arguments():
     parser.add_argument("json", default=False, help="the path the json containing all papers.")
     parser.add_argument("outpath", default=False, help="the target path of the visualizations papers.")
     parser.add_argument("--seed", default=0, help="The seed for TSNE.", type=int)
+    parser.add_argument("--model", default='sentence-transformers/all-MiniLM-L6-v2', help="Name of the HF model")
+
     return parser.parse_args()
 
+
+def mean_pooling(token_embeddings, attention_mask):
+    """ Mean Pooling, takes attention mask into account for correct averaging"""
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
 
 if __name__ == "__main__":
     args = parse_arguments()
-    tokenizer = transformers.AutoTokenizer.from_pretrained("deepset/sentence_bert")
-    model = transformers.AutoModel.from_pretrained("deepset/sentence_bert")
+    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model)
+    model = transformers.AutoModel.from_pretrained(args.model)
     model.eval()
 
     with open(args.json) as f:
         data = json.load(f)
 
     print(f"Num papers: {len(data)}")
 
-    all_embeddings = []
+    corpus = []
     for paper_info in data:
-        with torch.no_grad():
-            token_ids = torch.tensor([tokenizer.encode(paper_info["abstract"])][:512])
-            hidden_states, _ = model(token_ids)[-2:]
-            all_embeddings.append(hidden_states.mean(0).mean(0).numpy())
+        corpus.append(tokenizer.sep_token.join([paper_info['title'], paper_info['abstract']]))
+
+    encoded_corpus = tokenizer(corpus, padding=True, truncation=True, return_tensors='pt')
+    with torch.no_grad():
+        hidden_states = model(**encoded_corpus).last_hidden_state
+
+    corpus_embeddings = mean_pooling(hidden_states, encoded_corpus['attention_mask'])
+    corpus_embeddings = F.normalize(corpus_embeddings, p=2, dim=1)
 
     np.random.seed(args.seed)
-    all_embeddings = np.array(all_embeddings)
-    out = sklearn.manifold.TSNE(n_components=2, metric="cosine").fit_transform(all_embeddings)
+    out = sklearn.manifold.TSNE(n_components=2, metric="cosine").fit_transform(corpus_embeddings)
 
     for i, paper_info in enumerate(data):
         paper_info['tsne_embedding'] = out[i].tolist()
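
For reference, a hypothetical invocation of the updated script (papers.json and out/ are placeholder paths; --model falls back to the default defined in parse_arguments):

python etc/compute_embeddings.py papers.json out/ --seed 0 \
    --model sentence-transformers/all-MiniLM-L6-v2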
