Commit 2179d3a

✨ add support for split
1 parent c749640 commit 2179d3a

17 files changed, +303 -74 lines changed

mindee/client_v2.py

Lines changed: 76 additions & 14 deletions
@@ -1,10 +1,10 @@
 from time import sleep
-from typing import Optional, Union
+from typing import Optional, Union, Type

 from mindee.client_mixin import ClientMixin
 from mindee.error.mindee_error import MindeeError
 from mindee.error.mindee_http_error_v2 import handle_error_v2
-from mindee.input import UrlInputSource
+from mindee.input import UrlInputSource, UtilityParameters
 from mindee.input.inference_parameters import InferenceParameters
 from mindee.input.polling_options import PollingOptions
 from mindee.input.sources.local_input_source import LocalInputSource
@@ -15,6 +15,7 @@
     is_valid_post_response,
 )
 from mindee.parsing.v2.common_response import CommonStatus
+from mindee.v2 import BaseInferenceResponse
 from mindee.parsing.v2.inference_response import InferenceResponse
 from mindee.parsing.v2.job_response import JobResponse

@@ -41,20 +42,21 @@ def __init__(self, api_key: Optional[str] = None) -> None:
     def enqueue_inference(
         self,
         input_source: Union[LocalInputSource, UrlInputSource],
-        params: InferenceParameters,
+        params: Union[InferenceParameters, UtilityParameters],
+        slug: Optional[str] = None,
     ) -> JobResponse:
         """
         Enqueues a document to a given model.

         :param input_source: The document/source file to use. Can be local or remote.
-
         :param params: Parameters to set when sending a file.
+        :param slug: Slug for the endpoint.
+
         :return: A valid inference response.
         """
         logger.debug("Enqueuing inference using model: %s", params.model_id)
-
         response = self.mindee_api.req_post_inference_enqueue(
-            input_source=input_source, params=params
+            input_source=input_source, params=params, slug=slug
         )
         dict_response = response.json()

@@ -79,13 +81,18 @@ def get_job(self, job_id: str) -> JobResponse:
         dict_response = response.json()
         return JobResponse(dict_response)

-    def get_inference(self, inference_id: str) -> InferenceResponse:
+    def get_inference(
+        self,
+        inference_id: str,
+        inference_response_type: Type[InferenceResponse] = InferenceResponse,
+    ) -> BaseInferenceResponse:
         """
         Get the result of an inference that was previously enqueued.

         The inference will only be available after it has finished processing.

         :param inference_id: UUID of the inference to retrieve.
+        :param inference_response_type: Class of the product to instantiate.
         :return: An inference response.
         """
         logger.debug("Fetching inference: %s", inference_id)
@@ -94,19 +101,20 @@ def get_inference(self, inference_id: str) -> InferenceResponse:
         if not is_valid_get_response(response):
             handle_error_v2(response.json())
         dict_response = response.json()
-        return InferenceResponse(dict_response)
+        return inference_response_type(dict_response)

-    def enqueue_and_get_inference(
+    def _enqueue_and_get(
         self,
         input_source: Union[LocalInputSource, UrlInputSource],
-        params: InferenceParameters,
+        params: Union[InferenceParameters, UtilityParameters],
+        inference_response_type: Optional[Type[InferenceResponse]] = InferenceResponse,
     ) -> InferenceResponse:
         """
         Enqueues to an asynchronous endpoint and automatically polls for a response.

         :param input_source: The document/source file to use. Can be local or remote.
-
         :param params: Parameters to set when sending a file.
+        :param inference_response_type: The product class to use for the response object.

         :return: A valid inference response.
         """
@@ -117,9 +125,14 @@ def enqueue_and_get_inference(
             params.polling_options.delay_sec,
             params.polling_options.max_retries,
         )
-        enqueue_response = self.enqueue_inference(input_source, params)
+        slug = (
+            inference_response_type.inference.get_slug()
+            if inference_response_type
+            else None
+        )
+        enqueue_response = self.enqueue_inference(input_source, params, slug)
         logger.debug(
-            "Successfully enqueued inference with job id: %s", enqueue_response.job.id
+            "Successfully enqueued document with job id: %s", enqueue_response.job.id
        )
         sleep(params.polling_options.initial_delay_sec)
         try_counter = 0
@@ -134,8 +147,57 @@ def enqueue_and_get_inference(
                     f"Parsing failed for job {job_response.job.id}: {detail}"
                 )
             if job_response.job.status == CommonStatus.PROCESSED.value:
-                return self.get_inference(job_response.job.id)
+                result = self.get_inference(
+                    job_response.job.id, inference_response_type or InferenceResponse
+                )
+                assert isinstance(result, InferenceResponse), (
+                    f'Invalid response type "{type(result)}"'
+                )
+                return result
             try_counter += 1
             sleep(params.polling_options.delay_sec)

         raise MindeeError(f"Couldn't retrieve document after {try_counter + 1} tries.")
+
+    def enqueue_and_get_inference(
+        self,
+        input_source: Union[LocalInputSource, UrlInputSource],
+        params: InferenceParameters,
+    ) -> InferenceResponse:
+        """
+        Enqueues to an asynchronous endpoint and automatically polls for a response.
+
+        :param input_source: The document/source file to use. Can be local or remote.
+
+        :param params: Parameters to set when sending a file.
+
+        :return: A valid inference response.
+        """
+        response = self._enqueue_and_get(input_source, params)
+        assert isinstance(response, InferenceResponse), (
+            f'Invalid response type "{type(response)}"'
+        )
+        return response
+
+    def enqueue_and_get_utility(
+        self,
+        inference_response_type: Type[InferenceResponse],
+        input_source: Union[LocalInputSource, UrlInputSource],
+        params: UtilityParameters,
+    ) -> InferenceResponse:
+        """
+        Enqueues to an asynchronous endpoint and automatically polls for a response.
+
+        :param input_source: The document/source file to use. Can be local or remote.
+
+        :param params: Parameters to set when sending a file.
+
+        :param inference_response_type: The product class to use for the response object.
+
+        :return: A valid inference response.
+        """
+        response = self._enqueue_and_get(input_source, params, inference_response_type)
+        assert isinstance(response, inference_response_type), (
+            f'Invalid response type "{type(response)}"'
+        )
+        return response
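
The client now routes both standard inferences and utility products through the shared _enqueue_and_get helper. A minimal usage sketch of the two public entry points follows; the ClientV2 class name, the top-level import, and the SplitResponse class are assumptions for illustration only, while enqueue_and_get_inference, enqueue_and_get_utility, InferenceParameters, and UtilityParameters are taken from this diff.

from mindee import ClientV2  # assumed top-level export, not shown in this diff
from mindee.input import InferenceParameters, PathInput, UtilityParameters

# Hypothetical utility response class; no concrete one appears in the files shown here.
from mindee.v2 import SplitResponse  # assumption for illustration only

client = ClientV2(api_key="my-api-key")
source = PathInput("path/to/document.pdf")

# Standard inference: the public signature is unchanged and now delegates
# internally to _enqueue_and_get with the default InferenceResponse class.
inference = client.enqueue_and_get_inference(
    source, InferenceParameters(model_id="my-model-id")
)

# Utility endpoint: the response class supplies the endpoint slug through
# inference_response_type.inference.get_slug().
split = client.enqueue_and_get_utility(
    SplitResponse,
    source,
    UtilityParameters(model_id="my-model-id"),
)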

mindee/input/__init__.py

Lines changed: 13 additions & 7 deletions
@@ -1,4 +1,7 @@
 from mindee.input.local_response import LocalResponse
+from mindee.input.base_parameters import BaseParameters
+from mindee.input.inference_parameters import InferenceParameters
+from mindee.input.utility_parameters import UtilityParameters
 from mindee.input.page_options import PageOptions
 from mindee.input.polling_options import PollingOptions
 from mindee.input.sources.base_64_input import Base64Input
@@ -11,15 +14,18 @@
 from mindee.input.workflow_options import WorkflowOptions

 __all__ = [
+    "Base64Input",
+    "BaseParameters",
+    "BytesInput",
+    "FileInput",
     "InputType",
+    "InferenceParameters",
     "LocalInputSource",
-    "UrlInputSource",
+    "LocalResponse",
+    "PageOptions",
     "PathInput",
-    "FileInput",
-    "Base64Input",
-    "BytesInput",
-    "WorkflowOptions",
     "PollingOptions",
-    "PageOptions",
-    "LocalResponse",
+    "UrlInputSource",
+    "UtilityParameters",
+    "WorkflowOptions",
 ]

mindee/input/base_parameters.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+from abc import ABC
+from dataclasses import dataclass
+from typing import Optional, List
+
+from mindee.input.polling_options import PollingOptions
+
+
+@dataclass
+class BaseParameters(ABC):
+    """Base class for parameters accepted by all V2 endpoints."""
+
+    model_id: str
+    """ID of the model, required."""
+    alias: Optional[str] = None
+    """Use an alias to link the file to your own DB. If empty, no alias will be used."""
+    webhook_ids: Optional[List[str]] = None
+    """IDs of webhooks to propagate the API response to."""
+    polling_options: Optional[PollingOptions] = None
+    """Options for polling. Set only if having timeout issues."""
+    close_file: bool = True
+    """Whether to close the file after parsing."""

mindee/input/inference_parameters.py

Lines changed: 3 additions & 13 deletions
@@ -2,7 +2,7 @@
 from dataclasses import dataclass, asdict
 from typing import List, Optional, Union

-from mindee.input.polling_options import PollingOptions
+from mindee.input.base_parameters import BaseParameters


 @dataclass
@@ -44,7 +44,7 @@ class DataSchemaField(StringDataClass):
     guidelines: Optional[str] = None
     """Optional extraction guidelines."""
     nested_fields: Optional[dict] = None
-    """Subfields when type is `nested_object`. Leave empty for other types"""
+    """Subfields when type is `nested_object`. Leave empty for other types."""


 @dataclass
@@ -78,11 +78,9 @@ def __post_init__(self) -> None:


 @dataclass
-class InferenceParameters:
+class InferenceParameters(BaseParameters):
     """Inference parameters to set when sending a file."""

-    model_id: str
-    """ID of the model, required."""
     rag: Optional[bool] = None
     """Enhance extraction accuracy with Retrieval-Augmented Generation."""
     raw_text: Optional[bool] = None
@@ -94,14 +92,6 @@ class InferenceParameters:
     Boost the precision and accuracy of all extractions.
     Calculate confidence scores for all fields, and fill their ``confidence`` attribute.
     """
-    alias: Optional[str] = None
-    """Use an alias to link the file to your own DB. If empty, no alias will be used."""
-    webhook_ids: Optional[List[str]] = None
-    """IDs of webhooks to propagate the API response to."""
-    polling_options: Optional[PollingOptions] = None
-    """Options for polling. Set only if having timeout issues."""
-    close_file: bool = True
-    """Whether to close the file after parsing."""
     text_context: Optional[str] = None
     """
     Additional text context used by the model during inference.

mindee/input/utility_parameters.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+from dataclasses import dataclass
+
+from mindee.input.base_parameters import BaseParameters
+
+
+@dataclass
+class UtilityParameters(BaseParameters):
+    """
+    Parameters accepted by any of the asynchronous **inference** utility v2 endpoints.
+    """

mindee/mindee_http/mindee_api_v2.py

Lines changed: 19 additions & 16 deletions
@@ -4,7 +4,7 @@
 import requests

 from mindee.error.mindee_error import MindeeApiV2Error
-from mindee.input import LocalInputSource, UrlInputSource
+from mindee.input import LocalInputSource, UrlInputSource, UtilityParameters
 from mindee.input.inference_parameters import InferenceParameters
 from mindee.logger import logger
 from mindee.mindee_http.base_settings import USER_AGENT
@@ -74,34 +74,37 @@ def set_from_env(self) -> None:
     def req_post_inference_enqueue(
         self,
         input_source: Union[LocalInputSource, UrlInputSource],
-        params: InferenceParameters,
+        params: Union[InferenceParameters, UtilityParameters],
+        slug: Optional[str] = None,
     ) -> requests.Response:
         """
         Make an asynchronous request to POST a document for prediction on the V2 API.

         :param input_source: Input object.
         :param params: Options for the enqueueing of the document.
+        :param slug: Slug to use for the enqueueing, defaults to 'inferences'.
         :return: requests response.
         """
+        slug = slug if slug else "inferences"
         data: Dict[str, Union[str, list]] = {"model_id": params.model_id}
-        url = f"{self.url_root}/inferences/enqueue"
-
-        if params.rag is not None:
-            data["rag"] = str(params.rag).lower()
-        if params.raw_text is not None:
-            data["raw_text"] = str(params.raw_text).lower()
-        if params.confidence is not None:
-            data["confidence"] = str(params.confidence).lower()
-        if params.polygon is not None:
-            data["polygon"] = str(params.polygon).lower()
+        url = f"{self.url_root}/{slug}/enqueue"
+        if isinstance(params, InferenceParameters):
+            if params.rag is not None:
+                data["rag"] = str(params.rag).lower()
+            if params.raw_text is not None:
+                data["raw_text"] = str(params.raw_text).lower()
+            if params.confidence is not None:
+                data["confidence"] = str(params.confidence).lower()
+            if params.polygon is not None:
+                data["polygon"] = str(params.polygon).lower()
+            if params.text_context and len(params.text_context):
+                data["text_context"] = params.text_context
+            if params.data_schema is not None:
+                data["data_schema"] = str(params.data_schema)
         if params.webhook_ids and len(params.webhook_ids) > 0:
             data["webhook_ids"] = params.webhook_ids
         if params.alias and len(params.alias):
             data["alias"] = params.alias
-        if params.text_context and len(params.text_context):
-            data["text_context"] = params.text_context
-        if params.data_schema is not None:
-            data["data_schema"] = str(params.data_schema)

         if isinstance(input_source, LocalInputSource):
             files = {"file": input_source.read_contents(params.close_file)}
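
The enqueue URL is now derived from the slug argument, falling back to the existing inferences path when none is given. A simplified sketch of that routing, not the SDK code itself; the "splits" slug value is purely illustrative.

from typing import Optional


def build_enqueue_url(url_root: str, slug: Optional[str] = None) -> str:
    # Mirrors the defaulting added above: no slug means the standard inference endpoint.
    slug = slug if slug else "inferences"
    return f"{url_root}/{slug}/enqueue"


assert build_enqueue_url("https://api.example") == "https://api.example/inferences/enqueue"
assert build_enqueue_url("https://api.example", "splits") == "https://api.example/splits/enqueue"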

mindee/parsing/v2/inference.py

Lines changed: 5 additions & 21 deletions
@@ -1,36 +1,20 @@
 from mindee.parsing.common.string_dict import StringDict
+from mindee.v2.parsing.inference import BaseInference
 from mindee.parsing.v2.inference_active_options import InferenceActiveOptions
-from mindee.parsing.v2.inference_file import InferenceFile
-from mindee.parsing.v2.inference_model import InferenceModel
 from mindee.parsing.v2.inference_result import InferenceResult


-class Inference:
+class Inference(BaseInference):
     """Inference object for a V2 API return."""

-    id: str
-    """ID of the inference."""
-    model: InferenceModel
-    """Model info for the inference."""
-    file: InferenceFile
-    """File info for the inference."""
     result: InferenceResult
     """Result of the inference."""
     active_options: InferenceActiveOptions
     """Active options for the inference."""
+    _slug: str = "inferences"
+    """Slug of the inference."""

     def __init__(self, raw_response: StringDict):
-        self.id = raw_response["id"]
-        self.model = InferenceModel(raw_response["model"])
-        self.file = InferenceFile(raw_response["file"])
+        super().__init__(raw_response)
         self.result = InferenceResult(raw_response["result"])
         self.active_options = InferenceActiveOptions(raw_response["active_options"])
-
-    def __str__(self) -> str:
-        return (
-            f"Inference\n#########"
-            f"\n{self.model}"
-            f"\n\n{self.file}"
-            f"\n\n{self.active_options}"
-            f"\n\n{self.result}\n"
-        )
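
With Inference deriving from BaseInference and declaring its own _slug, a utility product could plug into the same pattern by overriding that attribute. A hypothetical sketch: no such subclass appears in the files shown here, and the slug value is invented.

from mindee.parsing.common.string_dict import StringDict
from mindee.v2.parsing.inference import BaseInference


class SplitInference(BaseInference):  # hypothetical, for illustration only
    """Inference object for a hypothetical split utility endpoint."""

    _slug: str = "splits"  # invented slug value, not taken from this diff
    """Slug of the inference."""

    def __init__(self, raw_response: StringDict):
        super().__init__(raw_response)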
