From 644ba2042422c5856480e9c0e4f2f5579511b1b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ianar=C3=A9=20S=C3=A9vi?= Date: Mon, 25 Aug 2025 14:06:34 +0200 Subject: [PATCH] :sparkles: add easier accessors --- .../java/com/mindee/geometry/Polygon.java | 3 + .../mindee/parsing/v2/field/DynamicField.java | 44 +++++++++ .../parsing/v2/field/FieldLocation.java | 2 +- .../parsing/v2/field/InferenceFields.java | 47 ++++++++-- .../com/mindee/parsing/v2/InferenceTest.java | 91 +++++++++++++++++-- 5 files changed, 169 insertions(+), 18 deletions(-) diff --git a/src/main/java/com/mindee/geometry/Polygon.java b/src/main/java/com/mindee/geometry/Polygon.java index fa6f06144..d76a317ef 100644 --- a/src/main/java/com/mindee/geometry/Polygon.java +++ b/src/main/java/com/mindee/geometry/Polygon.java @@ -12,6 +12,9 @@ @Getter @JsonDeserialize(using = PolygonDeserializer.class) public class Polygon { + /** + * Position information as a list of points in clockwise order. + */ private List coordinates = new ArrayList<>(); @Builder diff --git a/src/main/java/com/mindee/parsing/v2/field/DynamicField.java b/src/main/java/com/mindee/parsing/v2/field/DynamicField.java index 210aff5d6..b7b5001a8 100644 --- a/src/main/java/com/mindee/parsing/v2/field/DynamicField.java +++ b/src/main/java/com/mindee/parsing/v2/field/DynamicField.java @@ -50,6 +50,50 @@ public static DynamicField of(ListField value) { return new DynamicField(FieldType.LIST_FIELD, null, value, null); } + public SimpleField getSimpleField() throws IllegalStateException { + if (type != FieldType.SIMPLE_FIELD) { + throw new IllegalStateException("Field is not a simple field"); + } + return simpleField; + } + + public ListField getListField() throws IllegalStateException { + if (type != FieldType.LIST_FIELD) { + throw new IllegalStateException("Field is not a list field"); + } + return listField; + } + + public ObjectField getObjectField() throws IllegalStateException { + if (type != FieldType.OBJECT_FIELD) { + throw new IllegalStateException("Field is not an object field"); + } + return objectField; + } + + /** + * Returns the field as the specified class. + * + * @param type the class representing the desired field type + * @param the type of field to return + * @throws IllegalArgumentException if the requested type is not SimpleField, ListField, or ObjectField + * @throws IllegalStateException if the field's internal type does not match the requested type + */ + public T getField(Class type) throws IllegalArgumentException { + if (type == SimpleField.class) { + return (T) this.getSimpleField(); + } + if (type == ListField.class) { + return (T) this.getListField(); + } + if (type == ObjectField.class) { + return (T) this.getObjectField(); + } + throw new IllegalArgumentException( + "Cannot cast to " + type.getSimpleName() + ); + } + @Override public String toString() { if (simpleField != null) return simpleField.toString(); diff --git a/src/main/java/com/mindee/parsing/v2/field/FieldLocation.java b/src/main/java/com/mindee/parsing/v2/field/FieldLocation.java index dea8b6783..35a7ec2ce 100644 --- a/src/main/java/com/mindee/parsing/v2/field/FieldLocation.java +++ b/src/main/java/com/mindee/parsing/v2/field/FieldLocation.java @@ -28,7 +28,7 @@ public class FieldLocation { private Polygon polygon; /** - * Page ID. + * 0-based page index of where the polygon is located. */ @JsonProperty("page") private int page; diff --git a/src/main/java/com/mindee/parsing/v2/field/InferenceFields.java b/src/main/java/com/mindee/parsing/v2/field/InferenceFields.java index 669a60d57..cd9d871f2 100644 --- a/src/main/java/com/mindee/parsing/v2/field/InferenceFields.java +++ b/src/main/java/com/mindee/parsing/v2/field/InferenceFields.java @@ -12,6 +12,37 @@ @EqualsAndHashCode(callSuper = true) @JsonIgnoreProperties(ignoreUnknown = true) public final class InferenceFields extends LinkedHashMap { + + /** + * Retrieves the field as a `SimpleField`. + * + * @param fieldName the name of the field + * @throws IllegalStateException if the field is not a SimpleField + */ + public SimpleField getSimpleField(String fieldName) throws IllegalStateException { + return this.get(fieldName).getSimpleField(); + } + + /** + * Retrieves the field as a `ListField`. + * + * @param fieldName the name of the field + * @throws IllegalStateException if the field is not a ListField + */ + public ListField getListField(String fieldName) throws IllegalStateException { + return this.get(fieldName).getListField(); + } + + /** + * Retrieves the field as an `ObjectField`. + * + * @param fieldName the name of the field + * @throws IllegalStateException if the field is not a ObjectField + */ + public ObjectField getObjectField(String fieldName) throws IllegalStateException { + return this.get(fieldName).getObjectField(); + } + public String toString(int indent) { String padding = String.join("", java.util.Collections.nCopies(indent, " ")); if (this.isEmpty()) { @@ -19,21 +50,21 @@ public String toString(int indent) { } StringJoiner joiner = new StringJoiner("\n"); - this.forEach((fieldKey, fieldValue) -> { + this.forEach((fieldKey, fieldInstance) -> { StringBuilder strBuilder = new StringBuilder(); strBuilder.append(padding).append(":").append(fieldKey).append(": "); - if (fieldValue.getListField() != null) { - ListField listField = fieldValue.getListField(); + if (fieldInstance.getType() == DynamicField.FieldType.LIST_FIELD) { + ListField listField = fieldInstance.getListField(); if (listField.getItems() != null && !listField.getItems().isEmpty()) { strBuilder.append(listField); } - } else if (fieldValue.getObjectField() != null) { - strBuilder.append(fieldValue.getObjectField()); - } else if (fieldValue.getSimpleField() != null) { + } else if (fieldInstance.getType() == DynamicField.FieldType.OBJECT_FIELD) { + strBuilder.append(fieldInstance.getObjectField()); + } else if (fieldInstance.getType() == DynamicField.FieldType.SIMPLE_FIELD) { strBuilder.append( - fieldValue.getSimpleField().getValue() != null - ? fieldValue.getSimpleField().toString() + fieldInstance.getSimpleField().getValue() != null + ? fieldInstance.getSimpleField().toString() : "" ); diff --git a/src/test/java/com/mindee/parsing/v2/InferenceTest.java b/src/test/java/com/mindee/parsing/v2/InferenceTest.java index dc9642d7a..45a0e3866 100644 --- a/src/test/java/com/mindee/parsing/v2/InferenceTest.java +++ b/src/test/java/com/mindee/parsing/v2/InferenceTest.java @@ -1,7 +1,15 @@ package com.mindee.parsing.v2; +import com.mindee.geometry.Point; +import com.mindee.geometry.Polygon; import com.mindee.input.LocalResponse; -import com.mindee.parsing.v2.field.*; +import com.mindee.parsing.v2.field.DynamicField; +import com.mindee.parsing.v2.field.FieldConfidence; +import com.mindee.parsing.v2.field.FieldLocation; +import com.mindee.parsing.v2.field.InferenceFields; +import com.mindee.parsing.v2.field.SimpleField; +import com.mindee.parsing.v2.field.ListField; +import com.mindee.parsing.v2.field.ObjectField; import com.mindee.parsing.v2.field.DynamicField.FieldType; import java.io.IOException; import java.util.List; @@ -62,21 +70,19 @@ void asyncPredict_whenEmpty_mustHaveValidProperties() throws IOException { switch (type) { case LIST_FIELD: assertNotNull(value.getListField(), entry.getKey() + " – ListField expected"); - assertNull(value.getObjectField(), entry.getKey() + " – ObjectField must be null"); - assertNull(value.getSimpleField(), entry.getKey() + " – SimpleField must be null"); + assertThrows(IllegalStateException.class, value::getSimpleField); + assertThrows(IllegalStateException.class, value::getObjectField); break; - case OBJECT_FIELD: assertNotNull(value.getObjectField(), entry.getKey() + " – ObjectField expected"); - assertNull(value.getListField(), entry.getKey() + " – ListField must be null"); - assertNull(value.getSimpleField(), entry.getKey() + " – SimpleField must be null"); + assertThrows(IllegalStateException.class, value::getSimpleField); + assertThrows(IllegalStateException.class, value::getListField); break; - case SIMPLE_FIELD: default: assertNotNull(value.getSimpleField(), entry.getKey() + " – SimpleField expected"); - assertNull(value.getListField(), entry.getKey() + " – ListField must be null"); - assertNull(value.getObjectField(), entry.getKey() + " – ObjectField must be null"); + assertThrows(IllegalStateException.class, value::getListField); + assertThrows(IllegalStateException.class, value::getObjectField); break; } } @@ -263,6 +269,73 @@ void standardFieldTypes_mustExposeCorrectTypes() throws IOException { } } + @Test + @DisplayName("allow getting fields using generics") + void standardFieldTypes_getWithGenerics() throws IOException { + InferenceResponse response = loadFromResource("v2/inference/standard_field_types.json"); + Inference inference = response.getInference(); + assertNotNull(inference); + InferenceFields fields = inference.getResult().getFields(); + + assertEquals( + fields.get("field_simple_bool").getSimpleField(), + fields.get("field_simple_bool").getField(SimpleField.class) + ); + assertEquals( + fields.get("field_simple_bool").getSimpleField(), + fields.getSimpleField("field_simple_bool") + ); + + assertEquals( + fields.get("field_simple_list").getListField(), + fields.get("field_simple_list").getField(ListField.class) + ); + assertEquals( + fields.get("field_simple_list").getListField(), + fields.getListField("field_simple_list") + ); + + assertEquals( + fields.get("field_object").getObjectField(), + fields.get("field_object").getField(ObjectField.class) + ); + assertEquals( + fields.get("field_object").getObjectField(), + fields.getObjectField("field_object") + ); + } + + @Test + @DisplayName("confidence and locations must be usable") + void standardFieldTypes_confidenceAndLocations() throws IOException { + InferenceResponse response = loadFromResource("v2/inference/standard_field_types.json"); + Inference inference = response.getInference(); + assertNotNull(inference); + + InferenceFields fields = inference.getResult().getFields(); + + SimpleField fieldSimpleString = fields.get("field_simple_string").getField(SimpleField.class); + FieldConfidence confidence = fieldSimpleString.getConfidence(); + boolean isCertain = confidence == FieldConfidence.Certain; + assertTrue(isCertain); + + List locations = fieldSimpleString.getLocations(); + assertEquals(1, locations.size()); + FieldLocation location = locations.get(0); + + Polygon polygon = location.getPolygon(); + List coords = polygon.getCoordinates(); + assertEquals(4, coords.size()); + double topX = coords.get(0).getX(); + assertEquals(0.0, topX); + + Point center = polygon.getCentroid(); + assertEquals(0.5, center.getX(), 0.00001); + + int pageIndex = location.getPage(); + assertEquals(0, pageIndex); + } + @Nested @DisplayName("raw_texts.json") class RawTexts {