@@ -1,5 +1,7 @@
 import argparse
 import json
+from timeit import default_timer as timer
+from datetime import date
 
 import numpy as np
 import torch
@@ -14,7 +16,8 @@ def parse_arguments():
     parser.add_argument("json", default=False, help="the path to the JSON containing all papers.")
     parser.add_argument("outpath", default=False, help="the target path for the visualization JSON.")
     parser.add_argument("--seed", default=0, help="The seed for TSNE.", type=int)
-    parser.add_argument("--model", default='sentence-transformers/all-MiniLM-L6-v2', help="Name of the HF model")
+    parser.add_argument("--model", default='sentence-transformers/all-MiniLM-L6-v2', help="The name of the HF model")
+    parser.add_argument("--save_emb", action='store_true', help="Save embeddings as TSV for the TensorBoard Projector")
 
     return parser.parse_args()
 
@@ -23,9 +26,7 @@ def mean_pooling(token_embeddings, attention_mask):
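+    # Mask-aware mean: padded positions are zeroed out and excluded from the average.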
     input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
     return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
 
-
-if __name__ == "__main__":
-    args = parse_arguments()
+def main(args):
     tokenizer = transformers.AutoTokenizer.from_pretrained(args.model)
     model = transformers.AutoModel.from_pretrained(args.model)
     model.eval()
@@ -39,18 +40,38 @@ def mean_pooling(token_embeddings, attention_mask):
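+    # One input string per paper: title and abstract joined by the model's SEP token.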
     for paper_info in data:
         corpus.append(tokenizer.sep_token.join([paper_info['title'], paper_info['abstract']]))
 
-    encoded_corpus = tokenizer(corpus, padding=True, truncation=True, return_tensors='pt')
-    with torch.no_grad():
-        hidden_states = model(**encoded_corpus).last_hidden_state
-
-    corpus_embeddings = mean_pooling(hidden_states, encoded_corpus['attention_mask'])
-    corpus_embeddings = F.normalize(corpus_embeddings, p=2, dim=1)
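+    # Embed the corpus in fixed-size batches so peak memory stays bounded.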
+    batch_size = 4
+    all_embeddings = []
+    start = timer()
+    for i in range(0, len(corpus), batch_size):
+        encoded_batch = tokenizer(corpus[i:i + batch_size], padding=True, truncation=True, return_tensors='pt')
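+        # Inference only: no_grad avoids building the autograd graph.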
+        with torch.no_grad():
+            hidden_state = model(**encoded_batch).last_hidden_state
+        all_embeddings.append(mean_pooling(hidden_state, encoded_batch['attention_mask']))
+
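+    # Concatenate the per-batch results and L2-normalize, so cosine similarity
+    # between any two papers reduces to a plain dot product.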
+    all_embeddings = torch.cat(all_embeddings, dim=0)
+    all_embeddings = F.normalize(all_embeddings, p=2, dim=1)
+    print(f"elapsed {timer() - start:.1f}s")
+
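+    # The TensorBoard Embedding Projector loads two TSV files: one with the
+    # vectors, one with per-row metadata in the same order.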
+    if args.save_emb:
+        filename = f"{args.model.replace('/', '_')}-{date.today().strftime('%d.%m.%y')}"
+        np.savetxt(f"{filename}-emb.tsv", all_embeddings, delimiter="\t")
+        import csv
+        with open(f"{filename}-meta.tsv", 'w', newline='') as csvfile:
+            w = csv.writer(csvfile, delimiter='\t', quoting=csv.QUOTE_MINIMAL)
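+            # Header row, then one row per paper, matching the embedding order.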
+            w.writerow(["year", "key", "title"])
+            for paper in data:
+                w.writerow([paper["year"], paper["key"], paper["title"]])
 
     np.random.seed(args.seed)
-    out = sklearn.manifold.TSNE(n_components=2, metric="cosine").fit_transform(corpus_embeddings)
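+    # Project to 2-D with t-SNE, using cosine distance on the normalized embeddings.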
+    out = sklearn.manifold.TSNE(n_components=2, metric="cosine").fit_transform(all_embeddings)
 
     for i, paper_info in enumerate(data):
         paper_info['tsne_embedding'] = out[i].tolist()
 
     with open(args.outpath, 'w') as f:
         json.dump(data, f)
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    main(args)