@@ -1,20 +1,33 @@
 #!/usr/bin/env python3
+import os
 import sys
-import util
+from util import util
 import requests
-from collections import defaultdict
 import pandas as pd
+import pickle
+from typing import List, Dict, Tuple, Any

-CONSTANT__OUTSIDE = "OUTSIDE"  # enum from graphql-gateway; if it changes, the extraction service breaks!

-
-def run_classification(corpus_embeddings, corpus_labels, corpus_ids, training_ids):
-    from active_transfer_learning import ATLClassifier
+def run_classification(
+    information_source_id: str,
+    corpus_embeddings: Dict[str, List[List[float]]],
+    corpus_labels: List[str],
+    corpus_ids: List[str],
+    training_ids: List[str],
+):
+    from util.active_transfer_learning import ATLClassifier

     classifier = ATLClassifier()
     prediction_probabilities = classifier.fit_predict(
         corpus_embeddings, corpus_labels, corpus_ids, training_ids
     )
+    if os.path.exists("/inference"):
+        pickle_path = os.path.join(
+            "/inference", f"active-learner-{information_source_id}.pkl"
+        )
+        with open(pickle_path, "wb") as f:
+            pickle.dump(classifier, f)
+        print("Saved model to disk", flush=True)

     prediction_indices = prediction_probabilities.argmax(axis=1)
     predictions_with_probabilities = []
@@ -40,15 +53,31 @@ def run_classification(corpus_embeddings, corpus_labels, corpus_ids, training_id
     return ml_results_by_record_id


-def run_extraction(corpus_embeddings, corpus_labels, corpus_ids, training_ids):
-    from active_transfer_learning import ATLExtractor
+def run_extraction(
+    information_source_id: str,
+    corpus_embeddings: Dict[str, List[Any]],
+    corpus_labels: List[Tuple[str, str, List[Any]]],
+    corpus_ids: List[str],
+    training_ids: List[str],
+):
+    from util.active_transfer_learning import ATLExtractor

     extractor = ATLExtractor()
     predictions, probabilities = extractor.fit_predict(
         corpus_embeddings, corpus_labels, corpus_ids, training_ids
     )
+    if os.path.exists("/inference"):
+        pickle_path = os.path.join(
+            "/inference", f"active-learner-{information_source_id}.pkl"
+        )
+        with open(pickle_path, "wb") as f:
+            pickle.dump(extractor, f)
+        print("Saved model to disk", flush=True)
+
     ml_results_by_record_id = {}
-    for record_id, prediction, probability in zip(corpus_ids, predictions, probabilities):
+    for record_id, prediction, probability in zip(
+        corpus_ids, predictions, probabilities
+    ):
         df = pd.DataFrame(
             list(zip(prediction, probability)),
             columns=["prediction", "probability"],
@@ -57,7 +86,7 @@ def run_extraction(corpus_embeddings, corpus_labels, corpus_ids, training_ids):
         predictions_with_probabilities = []
         new_start_idx = True
         for idx, row in df.loc[
-            (df.prediction != CONSTANT__OUTSIDE)
+            (df.prediction != util.CONSTANT__OUTSIDE)
             & (df.prediction.isin(extractor.label_names))
             & (df.probability > extractor.min_confidence)
         ].iterrows():
@@ -67,7 +96,9 @@ def run_extraction(corpus_embeddings, corpus_labels, corpus_ids, training_ids):
             if row.prediction != row.next:
                 prob = df.loc[start_idx:idx].probability.mean()
                 end_idx = idx + 1
-                predictions_with_probabilities.append([float(prob), row.prediction, start_idx, end_idx])
+                predictions_with_probabilities.append(
+                    [float(prob), row.prediction, start_idx, end_idx]
+                )
                 new_start_idx = True
         ml_results_by_record_id[record_id] = predictions_with_probabilities
     if len(ml_results_by_record_id) == 0:
@@ -79,18 +110,32 @@ def run_extraction(corpus_embeddings, corpus_labels, corpus_ids, training_ids):
     _, payload_url = sys.argv
     print("Preparing data for machine learning.")

-    corpus_embeddings, corpus_labels, corpus_ids, training_ids = util.get_corpus()
+    (
+        information_source_id,
+        corpus_embeddings,
+        corpus_labels,
+        corpus_ids,
+        training_ids,
+    ) = util.get_corpus()
     is_extractor = any([isinstance(val, list) for val in corpus_labels["manual"]])

     if is_extractor:
         print("Running extractor.")
         ml_results_by_record_id = run_extraction(
-            corpus_embeddings, corpus_labels, corpus_ids, training_ids
+            information_source_id,
+            corpus_embeddings,
+            corpus_labels,
+            corpus_ids,
+            training_ids,
         )
     else:
         print("Running classifier.")
         ml_results_by_record_id = run_classification(
-            corpus_embeddings, corpus_labels, corpus_ids, training_ids
+            information_source_id,
+            corpus_embeddings,
+            corpus_labels,
+            corpus_ids,
+            training_ids,
         )

     print("Finished execution.")
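The key addition in both branches is the persistence step: when an /inference volume is present, the fitted active learner is pickled to /inference/active-learner-<information_source_id>.pkl. A minimal loader sketch for the consuming side follows, assuming only the path convention visible in this diff; the helper name load_active_learner and the idea that the unpickled ATLClassifier/ATLExtractor is reused for predictions are illustrative assumptions, not part of the commit.

# Sketch only (not part of the commit): assumes the inference service mounts the
# same /inference volume this script writes to; load_active_learner is a
# hypothetical helper name.
import os
import pickle
from typing import Any


def load_active_learner(information_source_id: str, base_dir: str = "/inference") -> Any:
    # Mirrors the naming used above: active-learner-<information_source_id>.pkl
    pickle_path = os.path.join(base_dir, f"active-learner-{information_source_id}.pkl")
    if not os.path.exists(pickle_path):
        raise FileNotFoundError(f"no persisted active learner at {pickle_path}")
    with open(pickle_path, "rb") as f:
        return pickle.load(f)  # the pickled ATLClassifier or ATLExtractor instance

Because pickle stores classes by reference, the loading environment must also be able to import util.active_transfer_learning, where ATLClassifier and ATLExtractor are defined.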