Skip to content

Commit 1408aa3

Browse files
ahmedlone127DevinTDHa
authored andcommitted
introducing perfered engine logic
1 parent 5a43dfc commit 1408aa3

File tree

139 files changed

+378
-302
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

139 files changed

+378
-302
lines changed

python/sparknlp/annotator/audio/hubert_for_ctc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def loadSavedModel(folder, spark_session):
165165
return HubertForCTC(java_model=jModel)
166166

167167
@staticmethod
168-
def pretrained(name="asr_hubert_large_ls960", lang="en", remote_loc=None):
168+
def pretrained(name="asr_hubert_large_ls960", lang="en", remote_loc=None, engine ="onnx"):
169169
"""Downloads and loads a pretrained model.
170170
171171
Parameters
@@ -185,4 +185,4 @@ def pretrained(name="asr_hubert_large_ls960", lang="en", remote_loc=None):
185185
The restored model
186186
"""
187187
from sparknlp.pretrained import ResourceDownloader
188-
return ResourceDownloader.downloadModel(HubertForCTC, name, lang, remote_loc)
188+
return ResourceDownloader.downloadModel(HubertForCTC, name, lang, remote_loc, engine)

python/sparknlp/annotator/audio/wav2vec2_for_ctc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def loadSavedModel(folder, spark_session):
138138
return Wav2Vec2ForCTC(java_model=jModel)
139139

140140
@staticmethod
141-
def pretrained(name="asr_wav2vec2_base_960h", lang="en", remote_loc=None):
141+
def pretrained(name="asr_wav2vec2_base_960h", lang="en", remote_loc=None,engine ="onnx"):
142142
"""Downloads and loads a pretrained model.
143143
144144
Parameters
@@ -158,4 +158,4 @@ def pretrained(name="asr_wav2vec2_base_960h", lang="en", remote_loc=None):
158158
The restored model
159159
"""
160160
from sparknlp.pretrained import ResourceDownloader
161-
return ResourceDownloader.downloadModel(Wav2Vec2ForCTC, name, lang, remote_loc)
161+
return ResourceDownloader.downloadModel(Wav2Vec2ForCTC, name, lang, remote_loc,engine )

python/sparknlp/annotator/audio/whisper_for_ctc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def loadSavedModel(folder, spark_session):
228228
return WhisperForCTC(java_model=jModel)
229229

230230
@staticmethod
231-
def pretrained(name="asr_whisper_tiny_opt", lang="xx", remote_loc=None):
231+
def pretrained(name="asr_whisper_tiny_opt", lang="xx", remote_loc=None,engine ="onnx"):
232232
"""Downloads and loads a pretrained model.
233233
234234
Parameters
@@ -248,4 +248,4 @@ def pretrained(name="asr_whisper_tiny_opt", lang="xx", remote_loc=None):
248248
The restored model
249249
"""
250250
from sparknlp.pretrained import ResourceDownloader
251-
return ResourceDownloader.downloadModel(WhisperForCTC, name, lang, remote_loc)
251+
return ResourceDownloader.downloadModel(WhisperForCTC, name, lang, remote_loc,engine )

python/sparknlp/annotator/classifier_dl/albert_for_multiple_choice.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def loadSavedModel(folder, spark_session):
138138
return AlbertForMultipleChoice(java_model=jModel)
139139

140140
@staticmethod
141-
def pretrained(name="albert_base_uncased_multiple_choice", lang="en", remote_loc=None):
141+
def pretrained(name="albert_base_uncased_multiple_choice", lang="en", remote_loc=None,engine ="onnx"):
142142
"""Downloads and loads a pretrained model.
143143
144144
Parameters
@@ -158,4 +158,4 @@ def pretrained(name="albert_base_uncased_multiple_choice", lang="en", remote_loc
158158
The restored model
159159
"""
160160
from sparknlp.pretrained import ResourceDownloader
161-
return ResourceDownloader.downloadModel(AlbertForMultipleChoice, name, lang, remote_loc)
161+
return ResourceDownloader.downloadModel(AlbertForMultipleChoice, name, lang, remote_loc,engine )

python/sparknlp/annotator/classifier_dl/albert_for_question_answering.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def loadSavedModel(folder, spark_session):
149149
return AlbertForQuestionAnswering(java_model=jModel)
150150

151151
@staticmethod
152-
def pretrained(name="albert_base_qa_squad2", lang="en", remote_loc=None):
152+
def pretrained(name="albert_base_qa_squad2", lang="en", remote_loc=None,engine ="onnx"):
153153
"""Downloads and loads a pretrained model.
154154
155155
Parameters
@@ -169,4 +169,4 @@ def pretrained(name="albert_base_qa_squad2", lang="en", remote_loc=None):
169169
The restored model
170170
"""
171171
from sparknlp.pretrained import ResourceDownloader
172-
return ResourceDownloader.downloadModel(AlbertForQuestionAnswering, name, lang, remote_loc)
172+
return ResourceDownloader.downloadModel(AlbertForQuestionAnswering, name, lang, remote_loc,engine )

python/sparknlp/annotator/classifier_dl/albert_for_sequence_classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def loadSavedModel(folder, spark_session):
178178
return AlbertForSequenceClassification(java_model=jModel)
179179

180180
@staticmethod
181-
def pretrained(name="albert_base_sequence_classifier_imdb", lang="en", remote_loc=None):
181+
def pretrained(name="albert_base_sequence_classifier_imdb", lang="en", remote_loc=None,engine ="onnx"):
182182
"""Downloads and loads a pretrained model.
183183
184184
Parameters
@@ -198,4 +198,4 @@ def pretrained(name="albert_base_sequence_classifier_imdb", lang="en", remote_lo
198198
The restored model
199199
"""
200200
from sparknlp.pretrained import ResourceDownloader
201-
return ResourceDownloader.downloadModel(AlbertForSequenceClassification, name, lang, remote_loc)
201+
return ResourceDownloader.downloadModel(AlbertForSequenceClassification, name, lang, remote_loc,engine )

python/sparknlp/annotator/classifier_dl/albert_for_token_classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def loadSavedModel(folder, spark_session):
156156
return AlbertForTokenClassification(java_model=jModel)
157157

158158
@staticmethod
159-
def pretrained(name="albert_base_token_classifier_conll03", lang="en", remote_loc=None):
159+
def pretrained(name="albert_base_token_classifier_conll03", lang="en", remote_loc=None,engine ="onnx"):
160160
"""Downloads and loads a pretrained model.
161161
162162
Parameters
@@ -176,4 +176,4 @@ def pretrained(name="albert_base_token_classifier_conll03", lang="en", remote_lo
176176
The restored model
177177
"""
178178
from sparknlp.pretrained import ResourceDownloader
179-
return ResourceDownloader.downloadModel(AlbertForTokenClassification, name, lang, remote_loc)
179+
return ResourceDownloader.downloadModel(AlbertForTokenClassification, name, lang, remote_loc,engine )

python/sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def loadSavedModel(folder, spark_session):
188188
return AlbertForZeroShotClassification(java_model=jModel)
189189

190190
@staticmethod
191-
def pretrained(name="albert_zero_shot_classifier_onnx", lang="en", remote_loc=None):
191+
def pretrained(name="albert_zero_shot_classifier_onnx", lang="en", remote_loc=None,engine ="onnx"):
192192
"""Downloads and loads a pretrained model.
193193
194194
Parameters
@@ -208,4 +208,4 @@ def pretrained(name="albert_zero_shot_classifier_onnx", lang="en", remote_loc=No
208208
The restored model
209209
"""
210210
from sparknlp.pretrained import ResourceDownloader
211-
return ResourceDownloader.downloadModel(AlbertForZeroShotClassification, name, lang, remote_loc)
211+
return ResourceDownloader.downloadModel(AlbertForZeroShotClassification, name, lang, remote_loc,engine )

python/sparknlp/annotator/classifier_dl/bart_for_zero_shot_classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def loadSavedModel(folder, spark_session):
202202
return BartForZeroShotClassification(java_model=jModel)
203203

204204
@staticmethod
205-
def pretrained(name="bart_large_zero_shot_classifier_mnli", lang="en", remote_loc=None):
205+
def pretrained(name="bart_large_zero_shot_classifier_mnli", lang="en", remote_loc=None,engine ="onnx"):
206206
"""Downloads and loads a pretrained model.
207207
208208
Parameters
@@ -222,4 +222,4 @@ def pretrained(name="bart_large_zero_shot_classifier_mnli", lang="en", remote_lo
222222
The restored model
223223
"""
224224
from sparknlp.pretrained import ResourceDownloader
225-
return ResourceDownloader.downloadModel(BartForZeroShotClassification, name, lang, remote_loc)
225+
return ResourceDownloader.downloadModel(BartForZeroShotClassification, name, lang, remote_loc,engine )

python/sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def loadSavedModel(folder, spark_session):
138138
return BertForMultipleChoice(java_model=jModel)
139139

140140
@staticmethod
141-
def pretrained(name="bert_base_uncased_multiple_choice", lang="en", remote_loc=None):
141+
def pretrained(name="bert_base_uncased_multiple_choice", lang="en", remote_loc=None,engine ="onnx"):
142142
"""Downloads and loads a pretrained model.
143143
144144
Parameters
@@ -158,4 +158,4 @@ def pretrained(name="bert_base_uncased_multiple_choice", lang="en", remote_loc=N
158158
The restored model
159159
"""
160160
from sparknlp.pretrained import ResourceDownloader
161-
return ResourceDownloader.downloadModel(BertForMultipleChoice, name, lang, remote_loc)
161+
return ResourceDownloader.downloadModel(BertForMultipleChoice, name, lang, remote_loc,engine )

0 commit comments

Comments
 (0)