|
8 | 8 | """ |
9 | 9 | |
10 | 10 | import importlib |
| 11 | +import logging |
11 | 12 | import subprocess |
12 | 13 | import sys |
| 14 | +import traceback |
| 15 | +from typing import List |
13 | 16 | |
14 | | -PII_ANNOTATION_LABELS = ["DATE_TIME", "LOC", "NRP", "ORG", "PER"] |
15 | | -MAXIMAL_STRING_SIZE = 1000000 |
16 | | - |
17 | | - |
18 | | -def pii_annotator(text: str, broadcasted_nlp) -> list[list[str]]: |
19 | | - """Extract features using en_core_web_lg model. |
20 | | -
|
21 | | - Returns: |
22 | | - list[list[str]]: Values as arrays in order defined in the PII_ANNOTATION_LABELS. |
23 | | - """ |
24 | | - ensure_installed("pyspark") |
25 | | - ensure_installed("spacy") |
| 17 | +try: |
26 | 18 |     import spacy |
| 19 | +except ImportError: |
| 20 | + print("Spacy not found. Please install it: pip install spacy") |
| 21 | + print("and download the model: python -m spacy download en_core_web_lg") |
| 22 | + spacy = None |
| 23 | + traceback.print_exc() |
| 24 | + sys.exit(1) |
| 25 | + |
| 26 | +try: |
27 | 27 |     from pyspark.sql import SparkSession |
28 | 28 |     from pyspark.sql.functions import udf |
29 | | -    from pyspark.sql.types import ArrayType, StringType, StructField, StructType |
| 29 | +    from pyspark.sql.types import ArrayType, StringType |
| 30 | +except ImportError: |
| 31 | +    print( |
| 32 | +        "PySpark not found. Please install it with the [spark] extra: pip install 'datafog[spark]'" |
| 33 | +    ) |
| 34 | + |
| 35 | +    # Set placeholders to allow module import even if pyspark is not installed |
| 36 | +    def placeholder_udf(*args, **kwargs): |
| 37 | +        return None |
| 38 | + |
| 39 | +    def placeholder_arraytype(x): |
| 40 | +        return None |
30 | 41 | |
31 | | -    if text: |
32 | | -        if len(text) > MAXIMAL_STRING_SIZE: |
33 | | -            # Cut the strings for required sizes |
34 | | -            text = text[:MAXIMAL_STRING_SIZE] |
35 | | -        nlp = broadcasted_nlp.value |
36 | | -        doc = nlp(text) |
| 42 | +    def placeholder_stringtype(): |
| 43 | +        return None |
37 | 44 | |
38 | | -        # Pre-create dictionary with labels matching to expected extracted entities |
39 | | -        classified_entities: dict[str, list[str]] = { |
40 | | -            _label: [] for _label in PII_ANNOTATION_LABELS |
41 | | -        } |
42 | | -        for ent in doc.ents: |
43 | | -            # Add entities from extracted values |
44 | | -            classified_entities[ent.label_].append(ent.text) |
| 45 | +    udf = placeholder_udf |
| 46 | +    ArrayType = placeholder_arraytype |
| 47 | +    StringType = placeholder_stringtype |
| 48 | +    SparkSession = None  # Define a placeholder |
| 49 | +    traceback.print_exc() |
| 50 | +    # Do not exit; allow basic import, but functions using Spark will fail later if called |
45 | 51 | |
46 | | -        return [_ent for _ent in classified_entities.values()] |
47 | | -    else: |
48 | | -        return [[] for _ in PII_ANNOTATION_LABELS] |
| 52 | +from datafog.processing.text_processing.spacy_pii_annotator import pii_annotator |
| 53 | + |
| 54 | +PII_ANNOTATION_LABELS = ["DATE_TIME", "LOC", "NRP", "ORG", "PER"] |
| 55 | +MAXIMAL_STRING_SIZE = 1000000 |
49 | 56 | |
50 | 57 | |
51 | 58 | def broadcast_pii_annotator_udf( |
52 | 59 |     spark_session=None, spacy_model: str = "en_core_web_lg" |
53 | 60 | ): |
54 | 61 | """Broadcast PII annotator across Spark cluster and create UDF""" |
55 | | - ensure_installed("pyspark") |
56 | | - ensure_installed("spacy") |
57 | | - import spacy |
58 | | - from pyspark.sql import SparkSession |
59 | | - from pyspark.sql.functions import udf |
60 | | - from pyspark.sql.types import ArrayType, StringType, StructField, StructType |
61 | | - |
62 | 62 |     if not spark_session: |
63 | | -        spark_session = SparkSession.builder.getOrCreate() |
| 63 | +        spark_session = SparkSession.builder.getOrCreate()  # noqa: F821 |
64 | 64 |     broadcasted_nlp = spark_session.sparkContext.broadcast(spacy.load(spacy_model)) |
65 | 65 | |
66 | 66 |     pii_annotation_udf = udf( |
67 | 67 |         lambda text: pii_annotator(text, broadcasted_nlp), |
68 | 68 |         ArrayType(ArrayType(StringType())), |
69 | 69 |     ) |
70 | 70 |     return pii_annotation_udf |
71 | | - |
72 | | - |
73 | | -def ensure_installed(self, package_name): |
74 | | -    try: |
75 | | -        importlib.import_module(package_name) |
76 | | -    except ImportError: |
77 | | -        subprocess.check_call([sys.executable, "-m", "pip", "install", package_name]) |
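
For reference, a minimal sketch of how the new broadcast_pii_annotator_udf could be exercised end to end, assuming pyspark, spacy, and the en_core_web_lg model are installed. The import path, app name, column name, and sample row are illustrative, not part of this diff:

    from pyspark.sql import SparkSession

    # Hypothetical import path for the module changed in this diff
    from datafog.processing.spark_processing.pyspark_udfs import broadcast_pii_annotator_udf

    spark = SparkSession.builder.appName("pii-annotation").getOrCreate()
    df = spark.createDataFrame([("Jane Doe visited Paris in June",)], ["text"])

    annotate = broadcast_pii_annotator_udf(spark)  # broadcasts the spaCy model once per cluster
    # Each row yields five lists of strings, ordered as PII_ANNOTATION_LABELS
    df.withColumn("pii", annotate("text")).show(truncate=False)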
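Because the except-branch placeholders keep the module importable without PySpark, the imported pii_annotator can also be smoke-tested off-cluster. A sketch, assuming it keeps the (text, broadcasted_nlp) signature and reads only the broadcast's .value attribute, as the removed in-module version did; FakeBroadcast is a hypothetical stand-in for Spark's Broadcast wrapper:

    import spacy

    from datafog.processing.text_processing.spacy_pii_annotator import pii_annotator

    class FakeBroadcast:
        # Mimics the single attribute pii_annotator reads from a Spark broadcast
        def __init__(self, value):
            self.value = value

    nlp = FakeBroadcast(spacy.load("en_core_web_lg"))
    print(pii_annotator("Jane Doe visited Paris in June", nlp))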