|
1 | 1 | import json |
2 | | -from dataclasses import dataclass |
| 2 | +from dataclasses import dataclass, asdict |
3 | 3 | from typing import List, Optional, Union |
4 | 4 |
|
5 | 5 | from mindee.input.polling_options import PollingOptions |
6 | 6 |
|
7 | 7 |
|
8 | | -class DataSchema: |
9 | | - """Modify the Data Schema.""" |
@dataclass
class StringDataClass:
    """Base dataclass whose ``str()`` form is its JSON representation."""

    @staticmethod
    def _no_none_values(x) -> dict:
        """
        Build a dict from ``(key, value)`` pairs, skipping None values.

        Used as the ``dict_factory`` for ``dataclasses.asdict`` so that unset
        optional attributes do not show up in the JSON output.

        :param x: Iterable of ``(key, value)`` pairs.
        :return: Dict holding only the pairs whose value is not None.
        """
        kept = {}
        for key, value in x:
            # Only None is dropped; falsy values such as 0, False or "" stay.
            if value is not None:
                kept[key] = value
        return kept

    def __str__(self) -> str:
        """Return the instance as single-line JSON with keys sorted."""
        as_mapping = asdict(self, dict_factory=self._no_none_values)
        return json.dumps(as_mapping, indent=None, sort_keys=True)
| 21 | + |
| 22 | + |
@dataclass
class DataSchemaField(StringDataClass):
    """
    Definition of a single field of the data schema.

    ``title``, ``name``, ``is_array`` and ``type`` are required; every other
    attribute is optional and defaults to None.
    """

    title: str
    """Display name of the field; it also impacts inference results."""
    name: str
    """Name of the field within the data schema."""
    is_array: bool
    """True when the field can hold multiple values."""
    type: str
    """Data type of the field."""
    classification_values: Optional[List[str]] = None
    """Allowed values for a `classification` field. Leave empty for other types."""
    unique_values: Optional[bool] = None
    """
    Whether duplicate values are removed from the array.
    Only applicable when `is_array` is True.
    """
    description: Optional[str] = None
    """Detailed description of what the field represents."""
    guidelines: Optional[str] = None
    """Optional guidelines for the extraction."""
    nested_fields: Optional[dict] = None
    """Subfields for a `nested_object` field. Leave empty for other types."""
| 48 | + |
| 49 | + |
@dataclass
class DataSchemaReplace(StringDataClass):
    """The structure to completely replace the data schema of the model."""

    fields: List[Union[DataSchemaField, dict]]
    """
    Fields of the replacement schema.
    Entries given as dicts are converted to `DataSchemaField` on initialization.
    """

    def __post_init__(self) -> None:
        """
        Validate ``fields`` and normalize dict entries to ``DataSchemaField``.

        :raises ValueError: If ``fields`` is empty or None.
        :raises TypeError: If a dict entry has keys ``DataSchemaField`` does
            not accept, or lacks one of its required keys.
        """
        if not self.fields:
            raise ValueError("Data schema replacement fields cannot be empty.")
        # Decide per element rather than from the first element only, so a
        # list mixing dicts and DataSchemaField instances is handled correctly.
        self.fields = [
            DataSchemaField(**field) if isinstance(field, dict) else field
            for field in self.fields
        ]
| 64 | + |
| 65 | + |
@dataclass
class DataSchema(StringDataClass):
    """Changes to apply to the data schema."""

    replace: Optional[Union[DataSchemaReplace, dict, str]] = None
    """When set, the model's data schema is completely replaced by this value."""

    def __post_init__(self) -> None:
        """
        Turn a dict or JSON string ``replace`` into a ``DataSchemaReplace``.

        Any other value, including None, is left untouched.
        """
        raw_replace = self.replace
        if isinstance(raw_replace, dict):
            replace_kwargs = raw_replace
        elif isinstance(raw_replace, str):
            # A string is taken to be a JSON document with the same keys.
            replace_kwargs = json.loads(raw_replace)
        else:
            return
        self.replace = DataSchemaReplace(**replace_kwargs)
37 | 78 |
|
38 | 79 |
|
39 | 80 | @dataclass |
@@ -66,8 +107,14 @@ class InferenceParameters: |
66 | 107 | Additional text context used by the model during inference. |
67 | 108 | Not recommended, for specific use only. |
68 | 109 | """ |
69 | | - data_schema: Optional[DataSchema] = None |
| 110 | + data_schema: Optional[Union[DataSchema, str, dict]] = None |
70 | 111 | """ |
71 | 112 | Dynamic changes to the data schema of the model for this inference. |
72 | 113 | Not recommended, for specific use only. |
73 | 114 | """ |
| 115 | + |
| 116 | + def __post_init__(self): |
| 117 | + if isinstance(self.data_schema, str): |
| 118 | + self.data_schema = DataSchema(**json.loads(self.data_schema)) |
| 119 | + elif isinstance(self.data_schema, dict): |
| 120 | + self.data_schema = DataSchema(**self.data_schema) |
0 commit comments