import argparse
import json

import numpy as np
import sklearn.manifold
import torch
import torch.nn.functional as F
import transformers
def parse_arguments():
    """Parse command-line arguments for the paper-embedding visualization script.

    Returns:
        argparse.Namespace with attributes ``json``, ``outpath``, ``seed``
        and ``model``.
    """
    parser = argparse.ArgumentParser()
    # Positional arguments are always required, so the former `default=False`
    # on them was dead code and has been dropped.
    parser.add_argument("json", help="the path to the json containing all papers.")
    parser.add_argument("outpath", help="the target path of the visualizations.")
    parser.add_argument("--seed", default=0, help="The seed for TSNE.", type=int)
    parser.add_argument("--model", default='sentence-transformers/all-MiniLM-L6-v2',
                        help="Name of the HF model")

    return parser.parse_args()
1720
def mean_pooling(token_embeddings, attention_mask):
    """Average token embeddings over the sequence dimension.

    Padding positions are excluded: the attention mask is broadcast over the
    hidden dimension and used as a weight, so only real tokens contribute to
    the mean. The denominator is clamped to avoid division by zero for
    all-padding rows.
    """
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts
1826
if __name__ == "__main__":
    args = parse_arguments()

    # Load the HF tokenizer/model once; eval() disables dropout so the
    # embeddings are deterministic.
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model)
    model = transformers.AutoModel.from_pretrained(args.model)
    model.eval()

    with open(args.json) as f:
        data = json.load(f)

    print(f"Num papers: {len(data)}")

    # One "document" per paper: title and abstract joined by the tokenizer's
    # separator token, mirroring sentence-transformers usage.
    corpus = []
    for paper_info in data:
        corpus.append(tokenizer.sep_token.join([paper_info['title'], paper_info['abstract']]))

    # Batch-encode the whole corpus (padded/truncated to the model's max length).
    encoded_corpus = tokenizer(corpus, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        hidden_states = model(**encoded_corpus).last_hidden_state

    # Masked mean pooling + L2 normalization gives unit-length sentence vectors,
    # so the cosine metric below is well defined.
    corpus_embeddings = mean_pooling(hidden_states, encoded_corpus['attention_mask'])
    corpus_embeddings = F.normalize(corpus_embeddings, p=2, dim=1)

    # Seed both NumPy's global RNG and TSNE itself: passing random_state
    # explicitly is what actually makes the t-SNE layout reproducible.
    np.random.seed(args.seed)
    # TSNE expects an ndarray; a torch tensor is rejected by sklearn's input
    # validation in many versions, so convert explicitly.
    out = sklearn.manifold.TSNE(
        n_components=2, metric="cosine", random_state=args.seed
    ).fit_transform(corpus_embeddings.cpu().numpy())

    for i, paper_info in enumerate(data):
        paper_info['tsne_embedding'] = out[i].tolist()
0 commit comments