def sort_table_cells_boxes(boxes, overlap_threshold=0.5):
    """
    Sort table-cell bounding boxes into reading order (row by row, left to right).

    Boxes are first ordered by their top edge (``y1``). Consecutive boxes are
    then grouped into the same row when their vertical overlap, measured
    relative to the shorter of the two boxes, reaches ``overlap_threshold``.
    Each row is finally ordered by its left edge (``x1``).

    Args:
        boxes (list of lists): Bounding boxes, each indexable as
            ``[x1, y1, x2, y2]``.
        overlap_threshold (float): Minimum ratio of vertical overlap to the
            smaller box height for two boxes to be treated as the same row.
            Defaults to 0.5.

    Returns:
        tuple: ``(sorted_boxes, flag)`` where

        sorted_boxes (list of lists): The list of bounding boxes sorted.
        flag (list of int): Cumulative row offsets into ``sorted_boxes``;
            row ``i`` spans ``sorted_boxes[flag[i]:flag[i + 1]]``. It always
            starts with 0, so an empty input yields ``[0]``.
    """

    def is_same_row(box1, box2):
        # Compare only the vertical extents; x coordinates are irrelevant here.
        _, y1a, _, y2a = box1
        _, y1b, _, y2b = box2
        overlap = max(0, min(y2a, y2b) - max(y1a, y1b))
        min_height = min(y2a - y1a, y2b - y1b)
        # A degenerate (zero or negative height) box can never share a row;
        # this also protects the division below.
        if min_height <= 0:
            return False
        return overlap / min_height >= overlap_threshold

    # No boxes: keep the historical contract of returning an empty list and
    # the initial offset instead of failing on the first-element access below.
    # len() is used so array-like inputs are handled as well as lists.
    if len(boxes) == 0:
        return [], [0]

    boxes_sorted_by_y = sorted(boxes, key=lambda box: box[1])

    rows = []
    current_row = [boxes_sorted_by_y[0]]

    for box in boxes_sorted_by_y[1:]:
        # Each box is compared with the most recently added box of the row.
        if is_same_row(current_row[-1], box):
            current_row.append(box)
        else:
            current_row.sort(key=lambda x: x[0])
            rows.append(current_row)
            current_row = [box]

    # Flush the last row, which is never empty at this point.
    current_row.sort(key=lambda x: x[0])
    rows.append(current_row)

    sorted_boxes = []
    flag = [0]
    for row in rows:
        sorted_boxes.extend(row)
        flag.append(flag[-1] + len(row))

    return sorted_boxes, flag