diff --git a/README.md b/README.md index c6ba7ef61..3076d8e76 100644 --- a/README.md +++ b/README.md @@ -74,6 +74,7 @@ import com.mindee.product.us.bankcheck.BankCheckV1; ### Custom Documents (docTI & Custom APIs) ```java import com.mindee.MindeeClient; +import com.mindee.PredictOptions; import com.mindee.input.LocalInputSource; import com.mindee.parsing.common.PredictResponse; import com.mindee.product.generated.GeneratedV1; @@ -98,6 +99,7 @@ public class SimpleMindeeClient { Document customDocument = mindeeClient.enqueueAndParse( localInputSource, endpoint + // PredictOptions.builder().build(), ); } } @@ -116,6 +118,7 @@ This is the easiest way to get started. ```java import com.mindee.MindeeClient; +import com.mindee.PredictOptions; import com.mindee.input.LocalInputSource; import com.mindee.parsing.common.AsyncPredictResponse; import com.mindee.product.internationalid.InternationalIdV2; @@ -138,6 +141,7 @@ public class SimpleMindeeClient { AsyncPredictResponse response = mindeeClient.enqueueAndParse( InternationalIdV2.class, inputSource + // PredictOptions.builder().build(), ); // Print a summary of the response diff --git a/docs/code_samples/workflow_execution.txt b/docs/code_samples/workflow_execution.txt index 316ce7488..a0412c84a 100644 --- a/docs/code_samples/workflow_execution.txt +++ b/docs/code_samples/workflow_execution.txt @@ -24,7 +24,6 @@ public class SimpleMindeeClient { inputSource ); - // Alternatively: give an alias to the document // WorkflowResponse response = mindeeClient.executeWorkflow( // workflowId, diff --git a/src/main/java/com/mindee/MindeeClient.java b/src/main/java/com/mindee/MindeeClient.java index 1ee0060cf..272d2195b 100644 --- a/src/main/java/com/mindee/MindeeClient.java +++ b/src/main/java/com/mindee/MindeeClient.java @@ -146,6 +146,30 @@ public AsyncPredictResponse enqueue( ); } + /** + * Send a local file to an async queue. + * @param Type of inference. + * @param type Type of inference. + * @param localInputSource A local input source file. + * @param predictOptions Prediction options for the enqueuing. + * @return an instance of {@link AsyncPredictResponse}. + * @throws IOException Throws if the file can't be accessed. + */ + public AsyncPredictResponse enqueue( + Class type, + LocalInputSource localInputSource, + PredictOptions predictOptions + ) throws IOException { + return this.enqueue( + type, + new Endpoint(type), + localInputSource.getFile(), + localInputSource.getFilename(), + predictOptions, + null + ); + } + /** * Send a remote file to an async queue. * @param Type of inference. @@ -290,6 +314,60 @@ public AsyncPredictResponse enqueueAndParse( ); } + /** + * Send a local file to an async queue, poll, and parse when complete. + * @param Type of inference. + * @param type Type of inference. + * @param localInputSource A local input source file. + * @param predictOptions Prediction options for the enqueuing. + * @param pollingOptions Options for async call parameters. + * @return an instance of {@link AsyncPredictResponse}. + * @throws IOException Throws if the file can't be accessed. + * @throws InterruptedException Throws in the event of a timeout. + */ + public AsyncPredictResponse enqueueAndParse( + Class type, + LocalInputSource localInputSource, + PredictOptions predictOptions, + AsyncPollingOptions pollingOptions + ) throws IOException, InterruptedException { + return this.enqueueAndParse( + type, + new Endpoint(type), + pollingOptions, + localInputSource.getFile(), + localInputSource.getFilename(), + predictOptions, + null + ); + } + + /** + * Send a local file to an async queue, poll, and parse when complete. + * @param Type of inference. + * @param type Type of inference. + * @param localInputSource A local input source file. + * @param predictOptions Prediction options for the enqueuing. + * @return an instance of {@link AsyncPredictResponse}. + * @throws IOException Throws if the file can't be accessed. + * @throws InterruptedException Throws in the event of a timeout. + */ + public AsyncPredictResponse enqueueAndParse( + Class type, + LocalInputSource localInputSource, + PredictOptions predictOptions + ) throws IOException, InterruptedException { + return this.enqueueAndParse( + type, + new Endpoint(type), + null, + localInputSource.getFile(), + localInputSource.getFilename(), + predictOptions, + null + ); + } + /** * Send a remote file to an async queue, poll, and parse when complete. * @param Type of inference. diff --git a/src/main/java/com/mindee/PredictOptions.java b/src/main/java/com/mindee/PredictOptions.java index fa39e4184..40966adf4 100644 --- a/src/main/java/com/mindee/PredictOptions.java +++ b/src/main/java/com/mindee/PredictOptions.java @@ -24,15 +24,28 @@ public class PredictOptions { * size. */ Boolean fullText; + /** + * If set, will enqueue to a workflow queue instead of a product's endpoint. + */ + String workflowId; + /** + * If set, will enable Retrieval-Augmented Generation. + * Only works if a valid workflowId is set. + */ + Boolean rag; @Builder private PredictOptions( Boolean allWords, Boolean fullText, - Boolean cropper + Boolean cropper, + String workflowId, + Boolean rag ) { this.allWords = allWords == null ? Boolean.FALSE : allWords; this.fullText = fullText == null ? Boolean.FALSE : fullText; this.cropper = cropper == null ? Boolean.FALSE : cropper; + this.workflowId = workflowId; + this.rag = rag == null ? Boolean.FALSE : rag; } } diff --git a/src/main/java/com/mindee/http/MindeeHttpApi.java b/src/main/java/com/mindee/http/MindeeHttpApi.java index 9f5b050fb..8a1d298d6 100644 --- a/src/main/java/com/mindee/http/MindeeHttpApi.java +++ b/src/main/java/com/mindee/http/MindeeHttpApi.java @@ -41,8 +41,9 @@ public final class MindeeHttpApi extends MindeeApi { private static final ObjectMapper mapper = new ObjectMapper(); - private final Function buildBaseUrl = this::buildProductUrl; - private final Function buildWorkflowBaseUrl = this::buildWorkflowUrl; + private final Function buildProductPredicBasetUrl = this::buildProductPredictBaseUrl; + private final Function buildWorkflowPredictBaseUrl = this::buildWorkflowPredictBaseUrl; + private final Function buildWorkflowExecutionBaseUrl = this::buildWorkflowExecutionUrl; /** * The MindeeSetting needed to make the api call. */ @@ -53,24 +54,27 @@ public final class MindeeHttpApi extends MindeeApi { */ private final HttpClientBuilder httpClientBuilder; /** - * The function used to generate the API endpoint URL. + * The function used to generate the synchronous API endpoint URL. * Only needs to be set if the api calls need to be directed through internal URLs. */ private final Function urlFromEndpoint; - /** - * The function used to generate the API endpoint URL for workflow execution calls. + * The function used to generate the asynchronous API endpoint URL for a product. * Only needs to be set if the api calls need to be directed through internal URLs. */ private final Function asyncUrlFromEndpoint; + /** + * The function used to generate the asynchronous API endpoint URL for a workflow. + * Only needs to be set if the api calls need to be directed through internal URLs. + */ + private final Function asyncUrlFromWorkflow; /** * The function used to generate the Job status URL for Async calls. * Only needs to be set if the api calls need to be directed through internal URLs. */ private final Function documentUrlFromEndpoint; - /** - * The function used to generate the Job status URL for Async calls. + * The function used to generate the Job status URL for workflow execution calls. * Only needs to be set if the api calls need to be directed through internal URLs. */ private final Function workflowUrlFromId; @@ -82,6 +86,7 @@ public MindeeHttpApi(MindeeSettings mindeeSettings) { null, null, null, + null, null ); } @@ -93,7 +98,8 @@ private MindeeHttpApi( Function urlFromEndpoint, Function asyncUrlFromEndpoint, Function documentUrlFromEndpoint, - Function workflowUrlFromEndpoint + Function workflowUrlFromEndpoint, + Function asyncUrlFromWorkflow ) { this.mindeeSettings = mindeeSettings; @@ -106,26 +112,35 @@ private MindeeHttpApi( if (urlFromEndpoint != null) { this.urlFromEndpoint = urlFromEndpoint; } else { - this.urlFromEndpoint = buildBaseUrl.andThen((url) -> url.concat("/predict")); + this.urlFromEndpoint = buildProductPredicBasetUrl.andThen( + (url) -> url.concat("/predict")); + } + + if (asyncUrlFromWorkflow != null) { + this.asyncUrlFromWorkflow = asyncUrlFromWorkflow; + } else { + this.asyncUrlFromWorkflow = this.buildWorkflowPredictBaseUrl.andThen( + (url) -> url.concat("/predict_async")); } if (asyncUrlFromEndpoint != null) { this.asyncUrlFromEndpoint = asyncUrlFromEndpoint; } else { - this.asyncUrlFromEndpoint = this.urlFromEndpoint.andThen((url) -> url.concat("_async")); + this.asyncUrlFromEndpoint = this.buildProductPredicBasetUrl.andThen( + (url) -> url.concat("/predict_async")); } if (documentUrlFromEndpoint != null) { this.documentUrlFromEndpoint = documentUrlFromEndpoint; } else { - this.documentUrlFromEndpoint = this.buildBaseUrl.andThen( + this.documentUrlFromEndpoint = this.buildProductPredicBasetUrl.andThen( (url) -> url.concat("/documents/queue/")); } if (workflowUrlFromEndpoint != null) { this.workflowUrlFromId = workflowUrlFromEndpoint; } else { - this.workflowUrlFromId = this.buildWorkflowBaseUrl; + this.workflowUrlFromId = this.buildWorkflowExecutionBaseUrl; } } @@ -233,7 +248,12 @@ public AsyncPredictResponse predictAsyncPost( RequestParameters requestParameters ) throws IOException { - String url = asyncUrlFromEndpoint.apply(endpoint); + String url; + if (requestParameters.getPredictOptions().getWorkflowId() != null) { + url = asyncUrlFromWorkflow.apply(requestParameters.getPredictOptions().getWorkflowId()); + } else { + url = asyncUrlFromEndpoint.apply(endpoint); + } HttpPost post = buildHttpPost(url, requestParameters); // required to register jackson date module format to deserialize @@ -340,7 +360,7 @@ private MindeeHttpException getHttpError( return new MindeeHttpException(statusCode, message, details, errorCode); } - private String buildProductUrl(Endpoint endpoint) { + private String buildProductPredictBaseUrl(Endpoint endpoint) { return this.mindeeSettings.getBaseUrl() + "/products/" + endpoint.getAccountName() @@ -350,7 +370,11 @@ private String buildProductUrl(Endpoint endpoint) { + endpoint.getVersion(); } - private String buildWorkflowUrl(String workflowId) { + private String buildWorkflowPredictBaseUrl(String workflowId) { + return this.mindeeSettings.getBaseUrl() + "/workflows/" + workflowId; + } + + private String buildWorkflowExecutionUrl(String workflowId) { return this.mindeeSettings.getBaseUrl() + "/workflows/" + workflowId + "/executions"; } @@ -388,7 +412,9 @@ private List buildPostParams( if (Boolean.TRUE.equals(requestParameters.getPredictOptions().getFullText())) { params.add(new BasicNameValuePair("full_text_ocr", "true")); } - if (Boolean.TRUE.equals(requestParameters.getWorkflowOptions().getRag())) { + if (Boolean.TRUE.equals(requestParameters.getWorkflowOptions().getRag()) + || Boolean.TRUE.equals(requestParameters.getPredictOptions().getRag()) + ) { params.add(new BasicNameValuePair("rag", "true")); } return params; diff --git a/src/main/java/com/mindee/parsing/common/InferenceExtras.java b/src/main/java/com/mindee/parsing/common/InferenceExtras.java index ea5f1aaea..a8f065960 100644 --- a/src/main/java/com/mindee/parsing/common/InferenceExtras.java +++ b/src/main/java/com/mindee/parsing/common/InferenceExtras.java @@ -1,6 +1,7 @@ package com.mindee.parsing.common; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; @@ -17,4 +18,9 @@ public class InferenceExtras { * Full Text OCR result. */ private String fullTextOcr; + /** + * Retrieval-Augmented Generation results. + */ + @JsonProperty("rag") + private Rag rag; } diff --git a/src/main/java/com/mindee/parsing/common/Rag.java b/src/main/java/com/mindee/parsing/common/Rag.java new file mode 100644 index 000000000..ca124b5a7 --- /dev/null +++ b/src/main/java/com/mindee/parsing/common/Rag.java @@ -0,0 +1,22 @@ +package com.mindee.parsing.common; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.Setter; + +/** + * Retrieval-Augmented Generation info class. + */ +@Getter +@EqualsAndHashCode +@JsonIgnoreProperties(ignoreUnknown = true) +public class Rag { + /** + * The document ID that was matched. + */ + @Setter + @JsonProperty("matching_document_id") + private String matchingDocumentId; +} diff --git a/src/test/java/com/mindee/workflow/WorkflowIT.java b/src/test/java/com/mindee/workflow/WorkflowIT.java index 72617bef7..5cf866eda 100644 --- a/src/test/java/com/mindee/workflow/WorkflowIT.java +++ b/src/test/java/com/mindee/workflow/WorkflowIT.java @@ -1,12 +1,15 @@ package com.mindee.workflow; import com.mindee.MindeeClient; -import com.mindee.MindeeException; +import com.mindee.PredictOptions; import com.mindee.WorkflowOptions; import com.mindee.input.LocalInputSource; +import com.mindee.input.PageOptions; +import com.mindee.parsing.common.AsyncPredictResponse; import com.mindee.parsing.common.Execution; import com.mindee.parsing.common.ExecutionPriority; import com.mindee.parsing.common.WorkflowResponse; +import com.mindee.product.financialdocument.FinancialDocumentV1; import com.mindee.product.generated.GeneratedV1; import java.io.IOException; import java.time.LocalDateTime; @@ -19,6 +22,7 @@ public class WorkflowIT { private static MindeeClient client; private static LocalInputSource financialDocumentInputSource; private static String currentDateTime; + private static String workflowId; @BeforeAll static void clientSetUp() throws IOException { @@ -26,28 +30,48 @@ static void clientSetUp() throws IOException { DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd-HH:mm:ss"); currentDateTime = now.format(formatter); client = new MindeeClient(); + workflowId = System.getenv("WORKFLOW_ID"); financialDocumentInputSource = new LocalInputSource( "src/test/resources/products/financial_document/default_sample.jpg" ); } - protected Execution getFinancialDocumentWorkflow(String workflowId) throws - IOException, MindeeException { + @Test + public void givenAWorkflowIdUploadShouldReturnACorrectWorkflowObject() throws + IOException { WorkflowOptions options = WorkflowOptions.builder().alias("java-" + currentDateTime).priority( ExecutionPriority.LOW).rag(true).build(); WorkflowResponse response = client.executeWorkflow(workflowId, financialDocumentInputSource, options); - return response.getExecution(); + Execution execution = response.getExecution(); + Assertions.assertEquals("low", execution.getPriority()); + Assertions.assertEquals("java-" + currentDateTime, execution.getFile().getAlias()); } - @Test - public void givenAWorkflowIDShouldReturnACorrectWorkflowObject() throws IOException { - Execution execution = getFinancialDocumentWorkflow(System.getenv("WORKFLOW_ID")); + public void GivenAWorkflowIdPredictCustomShouldPollAndNotMatchRag() throws + IOException, InterruptedException { - Assertions.assertEquals("low", execution.getPriority()); - Assertions.assertEquals("java-" + currentDateTime, execution.getFile().getAlias()); + PredictOptions predictOptions = PredictOptions.builder().workflowId(workflowId).build(); + AsyncPredictResponse response = client.enqueueAndParse( + FinancialDocumentV1.class, financialDocumentInputSource, predictOptions); + Assertions.assertNotNull(response.getDocumentObj().toString()); + Assertions.assertNull( + response.getDocumentObj().getInference().getExtras().getRag()); + } + + @Test + public void GivenAWorkflowIdPredictCustomShouldPollAndMatchRag() throws + IOException, InterruptedException { + PredictOptions predictOptions = PredictOptions.builder().workflowId(workflowId).rag(true).build(); + AsyncPredictResponse response = client.enqueueAndParse( + FinancialDocumentV1.class, financialDocumentInputSource, predictOptions); + Assertions.assertNotNull(response.getDocumentObj().toString()); + Assertions.assertNotNull( + response.getDocumentObj().getInference().getExtras().getRag()); + Assertions.assertNotNull( + response.getDocumentObj().getInference().getExtras().getRag().getMatchingDocumentId()); } }