
Commit cd6ba8f

Implementing inference api (#15)
* save active learner as pickle
* change structure and use relative imports
* change scikit-learn version
* write active learner to pickle for extraction tasks
* save active learner only if is_managed
* rename pickle file
* address PR comments
1 parent c1617c4 commit cd6ba8f

File tree

7 files changed (+73, -21 lines)


requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -107,7 +107,7 @@ s3transfer==0.6.0
     # via
     #   -r requirements/torch-cpu-requirements.txt
     #   boto3
-scikit-learn==1.0.2
+scikit-learn==1.1.2
     # via
     #   -r requirements/requirements.in
     #   sequencelearn

requirements/requirements.in

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
 -r torch-cpu-requirements.txt
-scikit-learn==1.0.2
+scikit-learn==1.1.2
 scipy==1.9.0
 sequencelearn==0.0.9

run.sh

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 #!/bin/bash
 
 /usr/bin/curl -s "$1" > input.json;
-/usr/bin/curl -s echo "$2" >> active_transfer_learning.py;
+/usr/bin/curl -s echo "$2" >> util/active_transfer_learning.py;
 /usr/bin/curl -s "$3" > embedding.csv.bz2;
 /usr/local/bin/python run_ml.py "$4";
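For orientation: run.sh still takes four positional arguments. "$1" is a URL that yields input.json, "$2" appears to resolve to the active-learning module source (now appended under util/ instead of the repository root), "$3" is the compressed embedding, and "$4" is the payload URL handed on to run_ml.py. Only the destination of the second download changes in this commit.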

run_ml.py

Lines changed: 59 additions & 14 deletions

@@ -1,20 +1,33 @@
 #!/usr/bin/env python3
+import os
 import sys
-import util
+from util import util
 import requests
-from collections import defaultdict
 import pandas as pd
+import pickle
+from typing import List, Dict, Tuple, Any
 
-CONSTANT__OUTSIDE = "OUTSIDE"  # enum from graphql-gateway; if it changes, the extraction service breaks!
 
-
-def run_classification(corpus_embeddings, corpus_labels, corpus_ids, training_ids):
-    from active_transfer_learning import ATLClassifier
+def run_classification(
+    information_source_id: str,
+    corpus_embeddings: Dict[str, List[List[float]]],
+    corpus_labels: List[str],
+    corpus_ids: List[str],
+    training_ids: List[str],
+):
+    from util.active_transfer_learning import ATLClassifier
 
     classifier = ATLClassifier()
     prediction_probabilities = classifier.fit_predict(
         corpus_embeddings, corpus_labels, corpus_ids, training_ids
     )
+    if os.path.exists("/inference"):
+        pickle_path = os.path.join(
+            "/inference", f"active-learner-{information_source_id}.pkl"
+        )
+        with open(pickle_path, "wb") as f:
+            pickle.dump(classifier, f)
+        print("Saved model to disk", flush=True)
 
     prediction_indices = prediction_probabilities.argmax(axis=1)
     predictions_with_probabilities = []
@@ -40,15 +53,31 @@ def run_classification(corpus_embeddings, corpus_labels, corpus_ids, training_ids):
     return ml_results_by_record_id
 
 
-def run_extraction(corpus_embeddings, corpus_labels, corpus_ids, training_ids):
-    from active_transfer_learning import ATLExtractor
+def run_extraction(
+    information_source_id: str,
+    corpus_embeddings: Dict[str, List[Any]],
+    corpus_labels: List[Tuple[str, str, List[Any]]],
+    corpus_ids: List[str],
+    training_ids: List[str],
+):
+    from util.active_transfer_learning import ATLExtractor
 
     extractor = ATLExtractor()
     predictions, probabilities = extractor.fit_predict(
         corpus_embeddings, corpus_labels, corpus_ids, training_ids
    )
+    if os.path.exists("/inference"):
+        pickle_path = os.path.join(
+            "/inference", f"active-learner-{information_source_id}.pkl"
+        )
+        with open(pickle_path, "wb") as f:
+            pickle.dump(extractor, f)
+        print("Saved model to disk", flush=True)
+
     ml_results_by_record_id = {}
-    for record_id, prediction, probability in zip(corpus_ids, predictions, probabilities):
+    for record_id, prediction, probability in zip(
+        corpus_ids, predictions, probabilities
+    ):
         df = pd.DataFrame(
             list(zip(prediction, probability)),
             columns=["prediction", "probability"],
@@ -57,7 +86,7 @@ def run_extraction(corpus_embeddings, corpus_labels, corpus_ids, training_ids):
         predictions_with_probabilities = []
         new_start_idx = True
         for idx, row in df.loc[
-            (df.prediction != CONSTANT__OUTSIDE)
+            (df.prediction != util.CONSTANT__OUTSIDE)
             & (df.prediction.isin(extractor.label_names))
             & (df.probability > extractor.min_confidence)
         ].iterrows():
@@ -67,7 +96,9 @@ def run_extraction(corpus_embeddings, corpus_labels, corpus_ids, training_ids):
             if row.prediction != row.next:
                 prob = df.loc[start_idx:idx].probability.mean()
                 end_idx = idx + 1
-                predictions_with_probabilities.append([float(prob), row.prediction, start_idx, end_idx])
+                predictions_with_probabilities.append(
+                    [float(prob), row.prediction, start_idx, end_idx]
+                )
                 new_start_idx = True
         ml_results_by_record_id[record_id] = predictions_with_probabilities
     if len(ml_results_by_record_id) == 0:
@@ -79,18 +110,32 @@ def run_extraction(corpus_embeddings, corpus_labels, corpus_ids, training_ids):
 _, payload_url = sys.argv
 print("Preparing data for machine learning.")
 
-corpus_embeddings, corpus_labels, corpus_ids, training_ids = util.get_corpus()
+(
+    information_source_id,
+    corpus_embeddings,
+    corpus_labels,
+    corpus_ids,
+    training_ids,
+) = util.get_corpus()
 is_extractor = any([isinstance(val, list) for val in corpus_labels["manual"]])
 
 if is_extractor:
     print("Running extractor.")
     ml_results_by_record_id = run_extraction(
-        corpus_embeddings, corpus_labels, corpus_ids, training_ids
+        information_source_id,
+        corpus_embeddings,
+        corpus_labels,
+        corpus_ids,
+        training_ids,
     )
 else:
     print("Running classifier.")
     ml_results_by_record_id = run_classification(
-        corpus_embeddings, corpus_labels, corpus_ids, training_ids
+        information_source_id,
+        corpus_embeddings,
+        corpus_labels,
+        corpus_ids,
+        training_ids,
    )
 
 print("Finished execution.")

util/__init__.py

Whitespace-only changes.

active_transfer_learning.py renamed to util/active_transfer_learning.py

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-import util
+from . import util
 from typing import Callable, List, Optional
 
 

util.py renamed to util/util.py

Lines changed: 10 additions & 3 deletions

@@ -2,7 +2,7 @@
 import numpy as np
 import pandas as pd
 
-from run_ml import CONSTANT__OUTSIDE
+CONSTANT__OUTSIDE = "OUTSIDE"  # enum from graphql-gateway; if it changes, the extraction service breaks!
 
 pd.options.mode.chained_assignment = None  # default='warn'
 
@@ -20,6 +20,7 @@
 def get_corpus():
     with open("input.json", "r") as infile:
         input_data = json.load(infile)
+    information_source_id = input_data["information_source_id"]
     embedding_type = input_data["embedding_type"]
     embedding_name = input_data["embedding_name"]
     labels = input_data["labels"]
@@ -45,10 +46,16 @@ def get_corpus():
                 if x != "data"
             ]
         }
-    except:
+    except Exception:
         print("Can't parse the embedding. Please contact the support.")
         raise ValueError("Can't parse the embedding. Please contact the support.")
-    return embeddings, labels, ids, training_ids
+    return (
+        information_source_id,
+        embeddings,
+        labels,
+        ids,
+        training_ids,
+    )
 
 
 def transform_corpus_classification_inference(embeddings):
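Two notes on this file. Defining CONSTANT__OUTSIDE here, instead of importing it from run_ml, removes the circular import between run_ml and util. And get_corpus now requires an information_source_id key in input.json; a sketch of the fields the function reads after this change, with placeholder values (the real file also carries the ids, training ids, and embedding payload that get_corpus returns):

# Placeholder shape of input.json as consumed by get_corpus; only the keys
# visible in this diff are listed, and all values are illustrative.
example_input = {
    "information_source_id": "<information-source-id>",  # new required key
    "embedding_type": "...",
    "embedding_name": "...",
    "labels": {"manual": []},  # run_ml.py inspects labels["manual"]
}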
