@@ -17,10 +17,12 @@ def run_classification(
 ):
     from util.active_transfer_learning import ATLClassifier
 
+    print("progress: 0.05", flush=True)
     classifier = ATLClassifier()
     prediction_probabilities = classifier.fit_predict(
         corpus_embeddings, corpus_labels, corpus_ids, training_ids
     )
+    print("progress: 0.8", flush=True)
     if os.path.exists("/inference"):
         pickle_path = os.path.join(
             "/inference", f"active-learner-{information_source_id}.pkl"
@@ -36,6 +38,7 @@ def run_classification(
         prediction = classifier.model.classes_[probas.argmax()]
         predictions_with_probabilities.append([proba, prediction])
 
+    print("progress: 0.9", flush=True)
     ml_results_by_record_id = {}
     for record_id, (probability, prediction) in zip(
         corpus_ids, predictions_with_probabilities
@@ -48,8 +51,12 @@ def run_classification(
             probability,
             prediction,
         )
+    print("progress: 0.95", flush=True)
     if len(ml_results_by_record_id) == 0:
-        print("No records were predicted. Try lowering the confidence threshold.")
+        print(
+            "No records were predicted. Try lowering the confidence threshold.",
+            flush=True,
+        )
     return ml_results_by_record_id
 
 
@@ -62,10 +69,12 @@ def run_extraction(
 ):
     from util.active_transfer_learning import ATLExtractor
 
+    print("progress: 0.05", flush=True)
     extractor = ATLExtractor()
     predictions, probabilities = extractor.fit_predict(
         corpus_embeddings, corpus_labels, corpus_ids, training_ids
     )
+    print("progress: 0.5", flush=True)
     if os.path.exists("/inference"):
         pickle_path = os.path.join(
             "/inference", f"active-learner-{information_source_id}.pkl"
@@ -75,8 +84,9 @@ def run_extraction(
         print("Saved model to disk", flush=True)
 
     ml_results_by_record_id = {}
-    for record_id, prediction, probability in zip(
-        corpus_ids, predictions, probabilities
+    amount = len(corpus_ids)
+    for idx, (record_id, prediction, probability) in enumerate(
+        zip(corpus_ids, predictions, probabilities)
     ):
         df = pd.DataFrame(
             list(zip(prediction, probability)),
@@ -101,14 +111,22 @@ def run_extraction(
             )
             new_start_idx = True
         ml_results_by_record_id[record_id] = predictions_with_probabilities
+        if idx % 100 == 0:
+            progress = round((idx + 1) / amount, 4) * 0.5 + 0.5
+            print("progress: ", progress, flush=True)
+
+    print("progress: 0.9", flush=True)
     if len(ml_results_by_record_id) == 0:
-        print("No records were predicted. Try lowering the confidence threshold.")
+        print(
+            "No records were predicted. Try lowering the confidence threshold.",
+            flush=True,
+        )
     return ml_results_by_record_id
 
 
 if __name__ == "__main__":
     _, payload_url = sys.argv
-    print("Preparing data for machine learning.")
+    print("Preparing data for machine learning.", flush=True)
 
     (
         information_source_id,
@@ -120,7 +138,7 @@ def run_extraction(
     is_extractor = any([isinstance(val, list) for val in corpus_labels["manual"]])
 
     if is_extractor:
-        print("Running extractor.")
+        print("Running extractor.", flush=True)
         ml_results_by_record_id = run_extraction(
             information_source_id,
             corpus_embeddings,
@@ -129,7 +147,7 @@ def run_extraction(
             training_ids,
         )
     else:
-        print("Running classifier.")
+        print("Running classifier.", flush=True)
         ml_results_by_record_id = run_classification(
             information_source_id,
             corpus_embeddings,
@@ -138,5 +156,6 @@ def run_extraction(
             training_ids,
         )
 
-    print("Finished execution.")
+    print("progress: 1", flush=True)
+    print("Finished execution.", flush=True)
     requests.put(payload_url, json=ml_results_by_record_id)
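The net effect of this diff is a small stdout progress protocol: the script prints `progress: <fraction>` at fixed checkpoints (always with `flush=True` so the lines are not held back by buffering), and inside the extraction loop it maps per-record progress into the 0.5 to 1.0 band via `round((idx + 1) / amount, 4) * 0.5 + 0.5` (for example, `idx = 99` with `amount = 1000` gives `0.1 * 0.5 + 0.5 = 0.55`). Below is a minimal sketch of how a supervising process could consume those lines; the script name `run_ml.py` and the consumer-side function are assumptions for illustration, not part of this commit.

```python
# Minimal sketch (assumed consumer, not part of this commit): read the
# "progress: <fraction>" lines the script flushes to stdout.
import subprocess


def run_with_progress(payload_url: str) -> int:
    proc = subprocess.Popen(
        ["python", "run_ml.py", payload_url],  # hypothetical script name
        stdout=subprocess.PIPE,
        text=True,
        bufsize=1,  # line-buffered, so flushed prints arrive promptly
    )
    for line in proc.stdout:
        line = line.strip()
        if line.startswith("progress:"):
            # Covers both "progress: 0.8" and the two-argument form
            # "progress:  0.5005" emitted inside the extraction loop.
            fraction = float(line.split(":", 1)[1])
            print(f"active learning job at {fraction:.0%}")
        else:
            print(line)  # forward the remaining log output unchanged
    return proc.wait()
```

Because every progress print in the diff passes `flush=True`, a consumer like this sees updates while the job runs rather than only after the process exits.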