Description
Attempting to perform coreference resolution on a large dataframe containing English-language documents.
Attempting to run SpanBertCorefModel in the following pipeline:
from pyspark.ml import Pipeline
from sparknlp.base import DocumentAssembler
from sparknlp.annotator import SentenceDetectorDLModel, Tokenizer, SpanBertCorefModel
# df is the input DataFrame; its "extracted_doc" column holds the document text
document_assembler = DocumentAssembler().setInputCol("extracted_doc").setOutputCol("document")
sentence_detector = SentenceDetectorDLModel.pretrained("sentence_detector_dl", "en").setInputCols(["document"]).setOutputCol("sentences")
tokenizer = Tokenizer().setInputCols(["sentences"]).setOutputCol("tokens")
corefResolution = SpanBertCorefModel.pretrained("spanbert_base_coref").setInputCols(["sentences", "tokens"]).setOutputCol("corefs")
# coref_pipeline = Pipeline(stages=[document_assembler, sentence_detector, tokenizer]) # works fine
coref_pipeline = Pipeline(stages=[document_assembler, sentence_detector, tokenizer, corefResolution])
coref_model = coref_pipeline.fit(df)
test_result = coref_model.transform(df)
display(test_result)
The block above downloads the models correctly but then fails at the coreference resolution step:
sentence_detector_dl download started this may take some time.
Approximate size to download 354.6 KB
[OK!]
spanbert_base_coref download started this may take some time.
Approximate size to download 540.1 MB
[OK!]
Error:
~/cluster-env/clonedenv/lib/python3.8/site-packages/notebookutils/visualization/display.py in display(data, summary)
138 log4jLogger\
139 .error(f"display failed with error, language: python, error: {err}")
--> 140 raise err
141
142 log4jLogger\
~/cluster-env/clonedenv/lib/python3.8/site-packages/notebookutils/visualization/display.py in display(data, summary)
118 from IPython.display import publish_display_data
119 publish_display_data({
--> 120 "application/vnd.synapse.display-widget+json": sc._jvm.display.getDisplayResultForIPython(df._jdf, summary)
121 })
122 else:
~/cluster-env/clonedenv/lib/python3.8/site-packages/py4j/java_gateway.py in __call__(self, *args)
1302
1303 answer = self.gateway_client.send_command(command)
-> 1304 return_value = get_return_value(
1305 answer, self.gateway_client, self.target_id, self.name)
1306
/opt/spark/python/lib/pyspark.zip/pyspark/sql/utils.py in deco(*a, **kw)
109 def deco(*a, **kw):
110 try:
--> 111 return f(*a, **kw)
112 except py4j.protocol.Py4JJavaError as e:
113 converted = convert_exception(e.java_exception)
~/cluster-env/clonedenv/lib/python3.8/site-packages/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
324 value = OUTPUT_CONVERTER[type](answer[2:], gateway_client)
325 if answer[1] == REFERENCE_TYPE:
--> 326 raise Py4JJavaError(
327 "An error occurred while calling {0}{1}{2}.\n".
328 format(target_id, ".", name), value)
Py4JJavaError: An error occurred while calling z:com.microsoft.spark.notebook.visualization.display.getDisplayResultForIPython.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 39.0 failed 4 times, most recent failure: Lost task 0.3 in stage 39.0 (TID 148) (vm-e5099158 executor 1): org.apache.spark.SparkException: Failed to execute user defined function(HasSimpleAnnotate$$Lambda$6097/176097372: (array<array<struct<annotatorType:string,begin:int,end:int,result:string,metadata:map<string,string>,embeddings:array<float>>>>) => array<struct<annotatorType:string,begin:int,end:int,result:string,metadata:map<string,string>,embeddings:array<float>>>)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:762)
at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:383)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:905)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:905)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:57)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:131)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:498)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1439)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:501)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.IndexOutOfBoundsException: 1
at scala.collection.mutable.ResizableArray.apply(ResizableArray.scala:46)
at scala.collection.mutable.ResizableArray.apply$(ResizableArray.scala:45)
at scala.collection.mutable.ArrayBuffer.apply(ArrayBuffer.scala:49)
at com.johnsnowlabs.ml.tensorflow.TensorflowSpanBertCoref.$anonfun$predict$1(TensorflowSpanBertCoref.scala:52)
at com.johnsnowlabs.ml.tensorflow.TensorflowSpanBertCoref.$anonfun$predict$1$adapted(TensorflowSpanBertCoref.scala:34)
at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
at com.johnsnowlabs.ml.tensorflow.TensorflowSpanBertCoref.predict(TensorflowSpanBertCoref.scala:34)
at com.johnsnowlabs.nlp.annotators.coref.SpanBertCorefModel.annotate(SpanBertCorefModel.scala:327)
at com.johnsnowlabs.nlp.HasSimpleAnnotate.$anonfun$dfAnnotate$1(HasSimpleAnnotate.scala:46)
... 17 more
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2313)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2262)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2261)
at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2261)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1132)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1132)
at scala.Option.foreach(Option.scala:407)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1132)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2500)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2442)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2431)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:908)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2301)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2322)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2341)
at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:510)
at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:463)
at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:47)
at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:3709)
at org.apache.spark.sql.Dataset.$anonfun$head$1(Dataset.scala:2735)
at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:3700)
at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$5(SQLExecution.scala:107)
at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:181)
at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:94)
at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:68)
at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3698)
at org.apache.spark.sql.Dataset.head(Dataset.scala:2735)
at org.apache.spark.sql.Dataset.take(Dataset.scala:2942)
at org.apache.spark.sql.GetRowsHelper$.getRowsInJsonString(GetRowsHelper.scala:51)
at com.microsoft.spark.notebook.visualization.display$.generateTableConfig(Display.scala:454)
at com.microsoft.spark.notebook.visualization.display$.exec(Display.scala:189)
at com.microsoft.spark.notebook.visualization.display$.getDisplayResultInternal(Display.scala:139)
at com.microsoft.spark.notebook.visualization.display$.getDisplayResultForIPython(Display.scala:80)
at com.microsoft.spark.notebook.visualization.display.getDisplayResultForIPython(Display.scala)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.SparkException: Failed to execute user defined function(HasSimpleAnnotate$$Lambda$6097/176097372: (array<array<struct<annotatorType:string,begin:int,end:int,result:string,metadata:map<string,string>,embeddings:array<float>>>>) => array<struct<annotatorType:string,begin:int,end:int,result:string,metadata:map<string,string>,embeddings:array<float>>>)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:762)
at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:383)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:905)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:905)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:57)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:131)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:498)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1439)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:501)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
... 1 more
Caused by: java.lang.IndexOutOfBoundsException: 1
at scala.collection.mutable.ResizableArray.apply(ResizableArray.scala:46)
at scala.collection.mutable.ResizableArray.apply$(ResizableArray.scala:45)
at scala.collection.mutable.ArrayBuffer.apply(ArrayBuffer.scala:49)
at com.johnsnowlabs.ml.tensorflow.TensorflowSpanBertCoref.$anonfun$predict$1(TensorflowSpanBertCoref.scala:52)
at com.johnsnowlabs.ml.tensorflow.TensorflowSpanBertCoref.$anonfun$predict$1$adapted(TensorflowSpanBertCoref.scala:34)
at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
at com.johnsnowlabs.ml.tensorflow.TensorflowSpanBertCoref.predict(TensorflowSpanBertCoref.scala:34)
at com.johnsnowlabs.nlp.annotators.coref.SpanBertCorefModel.annotate(SpanBertCorefModel.scala:327)
at com.johnsnowlabs.nlp.HasSimpleAnnotate.$anonfun$dfAnnotate$1(HasSimpleAnnotate.scala:46)
... 17 more
Possible Solution
The demo code below works properly. Could the problem be in the tokens or sentences being detected? Some of the documents may contain artifacts, as they may have been translated into English in an earlier preprocessing step.
data = spark.createDataFrame([["John told Mary he would like to borrow a book from her."]]).toDF("text")
document_assembler = DocumentAssembler().setInputCol("text").setOutputCol("document")
sentence_detector = SentenceDetector().setInputCols(["document"]).setOutputCol("sentences")
tokenizer = Tokenizer().setInputCols(["sentences"]).setOutputCol("tokens")
corefResolution = SpanBertCorefModel.pretrained("spanbert_base_coref").setInputCols(["sentences", "tokens"]).setOutputCol("corefs")
pipeline = Pipeline(stages=[document_assembler, sentence_detector, tokenizer, corefResolution])
model = pipeline.fit(data)
model.transform(data).selectExpr("explode(corefs) AS coref").selectExpr("coref.result as token", "coref.metadata").show(truncate=False)
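One way to look into the question about tokens and sentences is to run only the first three stages (which work) and flag documents whose annotations look unusual: no sentences, no tokens, or a sentence that contains no token. The sketch below is plain Python operating on (begin, end) offsets. The function names and the row layout are hypothetical, and whether such rows actually trigger the IndexOutOfBoundsException is an unconfirmed hypothesis, not something the stack trace establishes.

```python
from typing import Dict, List, Tuple

# (begin, end) character offsets of one annotation, both inclusive
Span = Tuple[int, int]


def sentences_without_tokens(sentences: List[Span], tokens: List[Span]) -> List[int]:
    """Return the indices of sentences that do not fully contain any token."""
    empty = []
    for i, (s_begin, s_end) in enumerate(sentences):
        has_token = any(
            s_begin <= t_begin and t_end <= s_end for t_begin, t_end in tokens
        )
        if not has_token:
            empty.append(i)
    return empty


def suspicious_rows(rows: List[Dict[str, List[Span]]]) -> List[int]:
    """Return the indices of rows with no sentences, no tokens, or a token-less sentence.

    Each row is assumed (hypothetical layout) to be a dict with the keys
    "sentences" and "tokens", each mapping to a list of (begin, end) spans.
    """
    flagged = []
    for idx, row in enumerate(rows):
        sentences, tokens = row["sentences"], row["tokens"]
        if not sentences or not tokens or sentences_without_tokens(sentences, tokens):
            flagged.append(idx)
    return flagged
```

The rows could be built from a sample collected from the three-stage pipeline output, for example by taking `(a.begin, a.end)` for each annotation in the `sentences` and `tokens` columns. Any flagged documents could then be run through the full pipeline on their own to see whether they reproduce the error.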
Context
I am trying to use Spark-based coreference resolution because other libraries are either incompatible with my Spark environment or are slow and error out.
Your Environment
- Spark NLP version (sparknlp.version()): 4.0.2
- Apache Spark version (spark.version): 3.1.2.5.0-66290225
- Java version (java -version): 1.8.0_282
- Setup and installation (Pypi, Conda, Maven, etc.): PyPi
- Operating System and version: Microsoft Azure Synapse