diff --git a/openpmcvl/granular/README.md b/openpmcvl/granular/README.md index f9a666b..7475f8e 100644 --- a/openpmcvl/granular/README.md +++ b/openpmcvl/granular/README.md @@ -1,4 +1,108 @@ -# Granular Package +# **Granular Pipeline** +Our goal is to create a finegrained dataset of biomedical subfigure-subcaption pairs from the raw dataset of PMC figure-caption pairs. We assume that a dataset of PMC figure-caption pairs, e.g. PMC-17M, is already downloaded, formatted as a directory of JSONL files and a directory of image .jpg files. Note that all .sh files require you to pass in the JSONL numbers from the PMC dataset as arguments. -This package contains tools to extract sub-figures and sub-captions from downloaded image-caption pairs. -This enlarges the dataset, and may increase the quality of the data as well since the sub-pairs will be more focused and less confusing. +Sample command: +```bash +sbatch openpmcvl/granular/pipeline/preprocess.sh 0 1 2 3 4 5 6 7 8 9 10 11 +``` + + +## **1. Preprocess** +> **Code:** `preprocess.py & preprocess.sh`
+> **Input:** Directory of figures and PMC metadata in JSONL format
+> **Output:** Filtered figure-caption pairs in JSONL format (`${num}_meta.jsonl`)
+ +- Filter out figure-caption pairs that are not .jpg images, missing, or corrupted. +- Filter for figure-caption pairs that contain target biomedical keywords. + +Each datapoint contains the following fields: +- `id`: A unique identifier for the figure-caption pair. +- `PMC_ID`: The PMC ID of the article. +- `caption`: The caption of the figure. +- `image_path`: The path to the image file. +- `width`: The width of the image in pixels. +- `height`: The height of the image in pixels. +- `media_id`: The ID of the media file. +- `media_url`: The URL of the media file. +- `media_name`: The name of the media file. +- `keywords`: The keywords found in the caption. +- `is_medical`: Whether the caption contains any target biomedical keywords. +

+ +This script saves the output both as a directory of processed JSONL files and a merged JSONL file. The former is used in the next step of the pipeline. +

+ + +## **2. Subfigure Extraction** +> **Code:** `subfigure.py & subfigure.sh`
+> **Input:** Filtered figure-caption pairs in JSONL format (`${num}_meta.jsonl`)
+> **Output:** Directory of subfigure jpg files, and subfigure metadata in JSONL format (`${num}_subfigures.jsonl`)
+ +- Breakdown compound figures into subfigures. +- Keep original figure for non-compound figures or if an exception occurs. + +Each datapoint contains the following fields: + +When a subfigure is successfully detected and separated: +- `id`: Unique identifier for the subfigure (format: {source_figure_id}_{subfigure_number}.jpg) +- `source_fig_id`: ID of the original compound figure +- `PMC_ID`: PMC ID of the source article +- `media_name`: Original filename of the compound figure +- `position`: Coordinates of subfigure bounding box [(x1,y1), (x2,y2)] +- `score`: Detection confidence score +- `subfig_path`: Path to saved subfigure image + +When subfigure extraction fails: +- `id`: Generated ID that would have been used +- `source_fig_id`: ID of the original figure +- `PMC_ID`: PMC ID of the source article +- `media_name`: Original filename + +This script saves extracted subfigures as .jpg files in the target directory. Metadata for each subfigure is stored in separate JSONL files, with unique IDs that link back to the original figure-caption pairs in the source JSONL files. +

+ + +## **3. Subcaption Extraction** +> **Code:** `subcaption.ipynb | subcaption.py & subcaption.sh`
+> **Input:** PMC metadata in JSONL format
+> **Output:** PMC metadata in JSONL format with subcaptions
+ +- Extract subcaptions from captions. +- Keep original caption if the caption cannot be split into subcaptions. + +While this pipeline works, its slow as it goes through API calls one by one. There is a notebook (`subcaption.ipynb`) using batch API calls to speed it up. It's highly recommended to use the notebook instead of this script. +

+ + +## **4. Classification** +> **Code:** `classify.py & classify.sh`
+> **Input:** Subfigure metadata in JSONL format (`${num}_subfigures.jsonl`)
+> **Output:** Subfigure metadata in JSONL format (`${num}_subfigures_classified.jsonl`)
+ +- Classify subfigures and include metadata about their class. + +The following fields are added to each datapoint: +- `is_medical_subfigure`: Whether the subfigure is a medical subfigure. +- `medical_class_rank`: The model's confidence in the medical classification. + +This script preserves all subfigures and adds an `is_medical_subfigure` boolean flag to identify medical subfigures. It also includes a `medical_class_rank` field indicating the model's confidence in the medical classification. +

+ + +## **5. Alignment** +> **Code:** `align.py & align.sh`
+> **Input:** Subfigure metadata in JSONL format (`${num}_subfigures_classified.jsonl`)
+> **Output:** Aligned subfigure metadata in JSONL format (`${num}_aligned.jsonl`)
+ +- Find the label associated with each subfigure. +- If no label is found, it means either: + - The image is a standalone figure (not part of a compound figure) + - The OCR model failed to detect the subfigure label (e.g. "A", "B", etc.) + +The non biomedical subfigures will be removed. The following fields are added to each datapoint: +- `label`: The label associated with the subfigure. (e.g. "Subfigure-A") +- `label_position`: The position of the label in the subfigure. + + +The outputs from steps 3 and 5 contain labeled subcaptions and labeled subfigures respectively. By matching these labels (e.g. "Subfigure-A"), we can create the final subfigure-subcaption pairs. Any cases where labels are missing or captions couldn't be split will be handled in subsequent steps. Refer to notebook for more details. +

diff --git a/openpmcvl/granular/__init__.py b/openpmcvl/granular/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/openpmcvl/granular/checkpoints/__init__.py b/openpmcvl/granular/checkpoints/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/openpmcvl/granular/config/__init__.py b/openpmcvl/granular/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/openpmcvl/granular/config/yolov3_default_subfig.cfg b/openpmcvl/granular/config/yolov3_default_subfig.cfg new file mode 100644 index 0000000..7a879f4 --- /dev/null +++ b/openpmcvl/granular/config/yolov3_default_subfig.cfg @@ -0,0 +1,34 @@ +MODEL: + TYPE: YOLOv3 + BACKBONE: darknet53 + ANCHORS: [[6, 7], [9, 10], [10, 14], + [13, 11], [16, 15], [15, 20], + [21, 19], [24, 24], [34, 31]] + ANCH_MASK: [[6, 7, 8], [3, 4, 5], [0, 1, 2]] + N_CLASSES: 15 +TRAIN: + LR: 0.001 + MOMENTUM: 0.9 + DECAY: 0.0005 + BURN_IN: 1000 + MAXITER: 20000 + STEPS: (400000, 450000) + BATCHSIZE: 4 + SUBDIVISION: 16 + IMGSIZE: 608 + LOSSTYPE: l2 + IGNORETHRE: 0.7 +AUGMENTATION: + RANDRESIZE: True + JITTER: 0.3 + RANDOM_PLACING: True + HUE: 0.1 + SATURATION: 1.5 + EXPOSURE: 1.5 + LRFLIP: False + RANDOM_DISTORT: True +TEST: + CONFTHRE: 0.8 + NMSTHRE: 0.1 + IMGSIZE: 416 +NUM_GPUS: 1 diff --git a/openpmcvl/granular/models/__init__.py b/openpmcvl/granular/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/openpmcvl/granular/models/network.py b/openpmcvl/granular/models/network.py new file mode 100644 index 0000000..511d832 --- /dev/null +++ b/openpmcvl/granular/models/network.py @@ -0,0 +1,418 @@ +import torch +from torch import nn +from torch.utils.model_zoo import load_url as load_state_dict_from_url + + +def get_model_urls(): + model_urls = { + "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth", + "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth", + "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth", + "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", + "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth", + "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", + "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", + "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", + "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", + } + return model_urls + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + __constants__ = ["downsample"] + + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None, + ): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError("BasicBlock only supports groups=1 and base_width=64") + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + __constants__ = ["downsample"] + + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None, + ): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.0)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__( + self, + block, + layers, + num_classes=30, + zero_init_residual=False, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + norm_layer=None, + ): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError( + "replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation) + ) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d( + 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False + ) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer( + block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] + ) + self.layer3 = self._make_layer( + block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] + ) + self.layer4 = self._make_layer( + block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] + ) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + ) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) + + return nn.Sequential(*layers) + + def _forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x + + # Allow for accessing forward method in a inherited class + forward = _forward + + +def _resnet(arch, block, layers, pretrained, progress, **kwargs): + model = ResNet(block, layers, **kwargs) + if pretrained: + model_urls = get_model_urls() + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + return model + + +def resnet18(pretrained=False, progress=True, **kwargs): + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) + + +def resnet34(pretrained=False, progress=True, **kwargs): + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) + + +def resnet50(pretrained=False, progress=True, **kwargs): + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) + + +def resnet101(pretrained=False, progress=True, **kwargs): + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet( + "resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs + ) + + +def resnet152(pretrained=False, progress=True, **kwargs): + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet( + "resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs + ) + + +def resnext50_32x4d(pretrained=False, progress=True, **kwargs): + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs["groups"] = 32 + kwargs["width_per_group"] = 4 + return _resnet( + "resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs + ) + + +def resnext101_32x8d(pretrained=False, progress=True, **kwargs): + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs["groups"] = 32 + kwargs["width_per_group"] = 8 + return _resnet( + "resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs + ) + + +def wide_resnet50_2(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" `_ + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs["width_per_group"] = 64 * 2 + return _resnet( + "wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs + ) + + +def wide_resnet101_2(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" `_ + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs["width_per_group"] = 64 * 2 + return _resnet( + "wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs + ) + + +if __name__ == "__main__": + __all__ = [ + "ResNet", + "resnet18", + "resnet34", + "resnet50", + "resnet101", + "resnet152", + "resnext50_32x4d", + "resnext101_32x8d", + "wide_resnet50_2", + "wide_resnet101_2", + ] diff --git a/openpmcvl/granular/models/process.py b/openpmcvl/granular/models/process.py new file mode 100644 index 0000000..d683c0d --- /dev/null +++ b/openpmcvl/granular/models/process.py @@ -0,0 +1,252 @@ +from __future__ import division + +import cv2 +import numpy as np +import torch + + +def label2yolobox(labels, info_img, maxsize, lrflip): + """ + Transform coco labels to yolo box labels + Args: + labels (numpy.ndarray): label data whose shape is :math:`(N, 5)`. + Each label consists of [class, x1, y1, x2, y2] where \ + class (float): class index. + x1, y1, x2, y2 (float) : coordinates of \ + left-top and right-bottom points of bounding boxes. + Values range from 0 to width or height of the image. + info_img : tuple of h, w, nh, nw, dx, dy. + h, w (int): original shape of the image + nh, nw (int): shape of the resized image without padding + dx, dy (int): pad size + maxsize (int): target image size after pre-processing + lrflip (bool): horizontal flip flag + + Returns + ------- + labels:label data whose size is :math:`(N, 5)`. + Each label consists of [class, xc, yc, w, h] where + class (float): class index. + xc, yc (float) : center of bbox whose values range from 0 to 1. + w, h (float) : size of bbox whose values range from 0 to 1. + """ + h, w, nh, nw, dx, dy = info_img + x1 = labels[:, 1] / w + y1 = labels[:, 2] / h + x2 = (labels[:, 1] + labels[:, 3]) / w + y2 = (labels[:, 2] + labels[:, 4]) / h + labels[:, 1] = (((x1 + x2) / 2) * nw + dx) / maxsize + labels[:, 2] = (((y1 + y2) / 2) * nh + dy) / maxsize + labels[:, 3] *= nw / w / maxsize + labels[:, 4] *= nh / h / maxsize + if lrflip: + labels[:, 1] = 1 - labels[:, 1] + return labels + + +def yolobox2label(box, info_img): + """ + Transform yolo box labels to yxyx box labels. + Args: + box (list): box data with the format of [yc, xc, w, h] + in the coordinate system after pre-processing. + info_img : tuple of h, w, nh, nw, dx, dy. + h, w (int): original shape of the image + nh, nw (int): shape of the resized image without padding + dx, dy (int): pad size + Returns: + label (list): box data with the format of [y1, x1, y2, x2] + in the coordinate system of the input image. + """ + h, w, nh, nw, dx, dy = info_img + y1, x1, y2, x2 = box + box_h = ((y2 - y1) / nh) * h + box_w = ((x2 - x1) / nw) * w + y1 = ((y1 - dy) / nh) * h + x1 = ((x1 - dx) / nw) * w + label = [max(x1, 0), max(y1, 0), min(x1 + box_w, w), min(y1 + box_h, h)] + return label + + +def nms(bbox, thresh, score=None, limit=None): + """Suppress bounding boxes according to their IoUs and confidence scores. + Args: + bbox (array): Bounding boxes to be transformed. The shape is + :math:`(R, 4)`. :math:`R` is the number of bounding boxes. + thresh (float): Threshold of IoUs. + score (array): An array of confidences whose shape is :math:`(R,)`. + limit (int): The upper bound of the number of the output bounding + boxes. If it is not specified, this method selects as many + bounding boxes as possible. + + Returns + ------- + array: + An array with indices of bounding boxes that are selected. \ + They are sorted by the scores of bounding boxes in descending \ + order. \ + The shape of this array is :math:`(K,)` and its dtype is\ + :obj:`numpy.int32`. Note that :math:`K \\leq R`. + + from: https://github.com/chainer/chainercv + """ + if len(bbox) == 0: + return np.zeros((0,), dtype=np.int32) + + if score is not None: + order = score.argsort()[::-1] + bbox = bbox[order] + bbox_area = np.prod(bbox[:, 2:] - bbox[:, :2], axis=1) + + selec = np.zeros(bbox.shape[0], dtype=bool) + for i, b in enumerate(bbox): + tl = np.maximum(b[:2], bbox[selec, :2]) + br = np.minimum(b[2:], bbox[selec, 2:]) + area = np.prod(br - tl, axis=1) * (tl < br).all(axis=1) + + iou = area / (bbox_area[i] + bbox_area[selec] - area) + if (iou >= thresh).any(): + continue + + selec[i] = True + if limit is not None and np.count_nonzero(selec) >= limit: + break + + selec = np.where(selec)[0] + if score is not None: + selec = order[selec] + return selec.astype(np.int32) + + +def postprocess(prediction, dtype, conf_thre=0.7, nms_thre=0.45): + """ + Postprocess for the output of YOLO model + perform box transformation, specify the class for each detection, + and perform class-wise non-maximum suppression. + Args: + prediction (torch tensor): The shape is :math:`(N, B, 4)`. + :math:`N` is the number of predictions, + :math:`B` the number of boxes. The last axis consists of + :math:`xc, yc, w, h` where `xc` and `yc` represent a center + of a bounding box. + num_classes (int): + number of dataset classes. + conf_thre (float): + confidence threshold ranging from 0 to 1, + which is defined in the config file. + nms_thre (float): + IoU threshold of non-max suppression ranging from 0 to 1. + + Returns + ------- + output (list of torch tensor): + + """ + box_corner = prediction.new(prediction.shape) + box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2 + box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2 + box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2 + box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2 + prediction[:, :, :4] = box_corner[:, :, :4] + + output = [None for _ in range(len(prediction))] + for i, image_pred in enumerate(prediction): + # Filter out confidence scores below threshold + conf_mask = (image_pred[:, 4] >= conf_thre).squeeze() + image_pred = image_pred[conf_mask] + + # If none are remaining => process next image + if not image_pred.size(0): + continue + # Get score and class with highest confidence + class_conf = torch.ones(image_pred[:, 4:5].size()).type(dtype) + class_pred = torch.zeros(image_pred[:, 4:5].size()).type(dtype) + + # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred) + detections = torch.cat( + (image_pred[:, :5], class_conf.float(), class_pred.float()), 1 + ) + # Iterate through all predicted classes + unique_labels = detections[:, -1].cpu().unique() + if prediction.is_cuda: + unique_labels = unique_labels.cuda() + for c in unique_labels: + # Get the detections with the particular class + detections_class = detections[detections[:, -1] == c] + nms_in = detections_class.cpu().numpy() + nms_out_index = nms( + nms_in[:, :4], nms_thre, score=nms_in[:, 4] * nms_in[:, 5] + ) + detections_class = detections_class[nms_out_index] + if output[i] is None: + output[i] = detections_class + else: + output[i] = torch.cat((output[i], detections_class)) + + return output + + +def preprocess_mask(mask, imgsize, info_img): + h, w, nh, nw, dx, dy = info_img + sized = np.ones((imgsize, imgsize, 1), dtype=np.uint8) * 127 + mask = cv2.resize(mask, (nw, nh)) + sized[dy : dy + nh, dx : dx + nw, 0] = mask + + return sized + + +def preprocess(img, imgsize, jitter, random_placing=False): + """ + Image preprocess for yolo input + Pad the shorter side of the image and resize to (imgsize, imgsize) + Args: + img (numpy.ndarray): input image whose shape is :math:`(H, W, C)`. + Values range from 0 to 255. + imgsize (int): target image size after pre-processing + jitter (float): amplitude of jitter for resizing + random_placing (bool): if True, place the image at random position + + Returns + ------- + img (numpy.ndarray): input image whose shape is :math:`(C, imgsize, imgsize)`. + Values range from 0 to 1. + info_img : tuple of h, w, nh, nw, dx, dy. + h, w (int): original shape of the image + nh, nw (int): shape of the resized image without padding + dx, dy (int): pad size + """ + h, w, _ = img.shape + img = img[:, :, ::-1] + assert img is not None + + if jitter > 0: + # add jitter + dw = jitter * w + dh = jitter * h + new_ar = (w + np.random.uniform(low=-dw, high=dw)) / ( + h + np.random.uniform(low=-dh, high=dh) + ) + else: + new_ar = w / h + + if new_ar < 1: + nh = imgsize + nw = nh * new_ar + else: + nw = imgsize + nh = nw / new_ar + nw, nh = int(max(nw, 1)), int(max(nh, 1)) + + if random_placing: + dx = int(np.random.uniform(imgsize - nw)) + dy = int(np.random.uniform(imgsize - nh)) + else: + dx = (imgsize - nw) // 2 + dy = (imgsize - nh) // 2 + + img = cv2.resize(img, (nw, nh)) + sized = np.ones((imgsize, imgsize, 3), dtype=np.uint8) * 127 + sized[dy : dy + nh, dx : dx + nw, :] = img + + info_img = (h, w, nh, nw, dx, dy) + return sized, info_img diff --git a/openpmcvl/granular/models/subfigure_detector.py b/openpmcvl/granular/models/subfigure_detector.py new file mode 100644 index 0000000..22380a0 --- /dev/null +++ b/openpmcvl/granular/models/subfigure_detector.py @@ -0,0 +1,215 @@ +import math + +import torch +from einops import repeat +from pytorch_pretrained_bert.modeling import BertModel +from torch import nn +from torchvision import models + +from openpmcvl.granular.models.transformer_module import * + + +class FigCap_Former(nn.Module): + def __init__( + self, + num_query=50, + num_encoder_layers=6, + num_decoder_layers=6, + feature_dim=256, + atten_head_num=8, + mlp_ratio=4, + dropout=0.0, + activation="relu", + alignment_network=False, + bert_path="/remote-home/zihengzhao/CompoundFigure/medicat/code/pretrained_model/PubMed_BERT", + num_text_decoder_layers=6, + text_atten_head_num=8, + text_mlp_ratio=4, + text_dropout=0.0, + text_activation="relu", + resnet=34, + resnet_pretrained=False, + ): + super().__init__() + # Followings are modules for fig detection + if resnet == 18: + self.img_embed = nn.Sequential( + *list(models.resnet18(pretrained=resnet_pretrained).children())[:8] + ).cuda() + self.img_channel_squeeze = nn.Conv2d(512, feature_dim, 1) + elif resnet == 34: + self.img_embed = nn.Sequential( + *list(models.resnet34(pretrained=resnet_pretrained).children())[:8] + ).cuda() + self.img_channel_squeeze = nn.Conv2d(512, feature_dim, 1) + elif resnet == 50: + self.img_embed = nn.Sequential( + *list(models.resnet50(pretrained=resnet_pretrained).children())[:8] + ).cuda() + self.img_channel_squeeze = nn.Conv2d(2048, feature_dim, 1) + else: + print("ResNet Error: Unsupported Version ResNet%d" % resnet) + exit() + self.pos_embed = PositionEncoding(num_pos_feats=feature_dim) + + encoder_layer = TransformerEncoderLayer( + feature_dim, atten_head_num, mlp_ratio * feature_dim, dropout, activation + ) + self.img_encoder = TransformerEncoder(encoder_layer, num_encoder_layers) + + self.query = nn.Parameter(torch.rand(num_query, feature_dim)) + decoder_layer = TransformerDecoderLayer( + feature_dim, atten_head_num, mlp_ratio * feature_dim, dropout, activation + ) + self.img_decoder = TransformerDecoder( + decoder_layer=decoder_layer, num_layers=num_decoder_layers + ) + + self.box_head = nn.Sequential( + nn.Linear(feature_dim, feature_dim), + nn.ReLU(inplace=True), + nn.Linear(feature_dim, 4), + nn.Sigmoid(), + ) + self.det_class_head = nn.Sequential(nn.Linear(feature_dim, 1), nn.Sigmoid()) + + # Followings are modules for fig-cap alignment + self.alignment_network = alignment_network # exclude alignment modules(BERT) to allow multi-gpu acceleration + if self.alignment_network: + self.text_embed = BertModel.from_pretrained(bert_path) + + self.text_channel_squeeze = nn.Sequential( + nn.Linear(768, 768), + nn.ReLU(inplace=True), + nn.Dropout(p=dropout), + nn.Linear(768, feature_dim), + nn.Dropout(p=dropout), + ) + + text_decoder_layer = TransformerDecoderLayer( + feature_dim, + text_atten_head_num, + text_mlp_ratio * feature_dim, + text_dropout, + text_activation, + ) + self.text_decoder = TransformerDecoder( + decoder_layer=text_decoder_layer, num_layers=num_text_decoder_layers + ) + + self.simi_head = nn.Sequential( + nn.Linear(feature_dim * 2, feature_dim), + nn.ReLU(inplace=True), + nn.Linear(feature_dim, feature_dim), + nn.ReLU(inplace=True), + nn.Linear(feature_dim, 1), + nn.Sigmoid(), + ) + + self.img_proj = nn.Parameter(torch.rand(feature_dim, feature_dim)) + + def forward(self, images, texts): + """ + 1. Detect the subfigures (nonobject/object binary classification + box coordinates linear regression) + 2. Align captions to each of the detection output + + Args: + images (compound figure): shape (bs, c, h, w) + texts (caption tokens): shape (bs, max_length_in_this_batch) + + Returns + ------- + output_det_class: tensor (bs, query_num, 1), 0~1 indicate subfigure or no-subfigure + output_box: tensor (bs, query_num, 4), prediction of [cx, cy, w, h] + similarity: tensor (bs, query_num, caption_length), 0~1 indicate belong or not belong to the subfigure + """ + # Img Embed + x = self.img_embed(images) # (bs, 2048, h/32, w/32) + x = self.img_channel_squeeze(x) # (bs, 256, h/32, w/32) + + pos = self.pos_embed( + x.shape[0], x.shape[2], x.shape[3], x.device + ) # (bs, 256, h/32, w/32) + x = x + pos + x = x.view(x.shape[0], x.shape[1], -1) # (bs, 256, (w*h)/(32*32)) + x = x.transpose(1, 2) # (bs, (w*h)/(32*32), 256) + + # Detect + x = self.img_encoder(x) # (bs, (w*h)/(32*32), 256) + query = repeat(self.query, "l d -> bs l d", bs=x.shape[0]) # (bs, 50, 256) + query, _ = self.img_decoder(x, query) # (bs, 50, 256) + + output_det_class = self.det_class_head(query) # (bs, 50, 1) + output_box = self.box_head(query) # (bs, 50, 4) + + # Text Embed + if self.alignment_network: + t = self.text_embed(texts)[0][-1] # (bs, l, 768) + t = self.text_channel_squeeze(t) # (bs, l, 256) + + # Align + query = query @ self.img_proj # (bs, 50, 256) + t, _ = self.text_decoder(query, t) # (bs, l, 256) + + query = query.unsqueeze(2).repeat(1, 1, t.shape[1], 1) # (bs, 50, l, 256) + t = t.unsqueeze(1).repeat(1, query.shape[1], 1, 1) # (bs, 50, l, 256) + similarity = torch.cat((query, t), -1) # (bs, 50, l, 512) + similarity = self.simi_head(similarity).squeeze(-1) # (bs, 50, l) + else: + # We wont use similarity, set to 0 to make code faster + similarity = ( + 0 # torch.zeros(query.shape[0], query.shape[1], texts.shape[-1]).cuda() + ) + + return output_det_class, output_box, similarity + + +class PositionEncoding(nn.Module): + def __init__( + self, normalize=True, scale=100.0, num_pos_feats=256, temperature=10000 + ): + super().__init__() + self.num_pos_feats = num_pos_feats // 2 + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, bs, h, w, device): + # Input is b,c,h,w + mask = torch.ones(bs, h, w, device=device) + # Since image is 2D, position encoding is split into x,y directions + # 1 1 1 1 .. 2 2 2 2... 3 3 3... + y_embed = mask.cumsum( + 1, dtype=torch.float32 + ) # (b, h, w) the 'y-index' of each position + # 1 2 3 4 ... 1 2 3 4... + x_embed = mask.cumsum( + 2, dtype=torch.float32 + ) # (b, h, w) the 'x-index' of each position + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + # num_pos_feats = 128 + # 0~127 self.num_pos_feats=128, since input vector is 256, encoding is half sin, half cos + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + # Output shape=b,h,w,128 + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + # Each feature map position is encoded as a 256d vector, first 128d for y direction, last 128d for x direction + return pos # (b,n=256,h,w) diff --git a/openpmcvl/granular/models/subfigure_ocr.py b/openpmcvl/granular/models/subfigure_ocr.py new file mode 100644 index 0000000..a470b83 --- /dev/null +++ b/openpmcvl/granular/models/subfigure_ocr.py @@ -0,0 +1,183 @@ +import os + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +import yaml +from PIL import Image +from skimage import io +from torch.autograd import Variable + +from openpmcvl.granular.models.network import resnet152 +from openpmcvl.granular.models.process import postprocess, preprocess, yolobox2label +from openpmcvl.granular.models.yolov3 import YOLOv3 + + +class classifier: + def __init__(self): + self.current_dir = os.path.dirname(os.path.abspath(__file__)) + configuration_file = os.path.join( + self.current_dir, "..", "config", "yolov3_default_subfig.cfg" + ) + + with open(configuration_file, "r") as f: + configuration = yaml.load(f, Loader=yaml.FullLoader) + + self.image_size = configuration["TEST"]["IMGSIZE"] + self.nms_threshold = configuration["TEST"]["NMSTHRE"] + self.confidence_threshold = 0.0001 + self.dtype = torch.cuda.FloatTensor + self.device = torch.device("cuda") + + object_detection_model = YOLOv3(configuration["MODEL"]) + self.object_detection_model = self.load_model_from_checkpoint( + object_detection_model, "object_detection_model.pt" + ) + ## Load text recognition model + text_recognition_model = resnet152() + self.text_recognition_model = self.load_model_from_checkpoint( + text_recognition_model, "text_recognition_model.pt" + ) + + self.object_detection_model.eval() + self.text_recognition_model.eval() + + def load_model_from_checkpoint(self, model, model_name): + """Load checkpoint weights into model""" + checkpoints_path = os.path.join(self.current_dir, "..", "checkpoints") + checkpoint = os.path.join(checkpoints_path, model_name) + model.load_state_dict(torch.load(checkpoint)) + model.to(self.device) + return model + + def detect_subfigure_boundaries(self, figure_path): + """Detects the bounding boxes of subfigures in figure_path + + Args: + figure_path: A string, path to an image of a figure + from a scientific journal + Returns: + subfigure_info (list of lists): Each inner list is + x1, y1, x2, y2, confidence + """ + ## Preprocess the figure for the models + img = io.imread(figure_path) + if len(np.shape(img)) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + else: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB) + + img, info_img = preprocess(img, self.image_size, jitter=0) + img = np.transpose(img / 255.0, (2, 0, 1)) + img = np.copy(img) + img = torch.from_numpy(img).float().unsqueeze(0) + img = Variable(img.type(self.dtype)) + + img_raw = Image.open(figure_path).convert("RGB") + width, height = img_raw.size + + ## Run model on figure + with torch.no_grad(): + outputs = self.object_detection_model(img.to(self.device)) + outputs = postprocess( + outputs, + dtype=self.dtype, + conf_thre=self.confidence_threshold, + nms_thre=self.nms_threshold, + ) + + ## Reformat model outputs to display bounding boxes in our desired format + ## List of lists where each inner list is [x1, y1, x2, y2, confidence] + subfigure_info = list() + + if outputs[0] is None: + return subfigure_info + + for x1, y1, x2, y2, conf, cls_conf, cls_pred in outputs[0]: + box = yolobox2label( + [ + y1.data.cpu().numpy(), + x1.data.cpu().numpy(), + y2.data.cpu().numpy(), + x2.data.cpu().numpy(), + ], + info_img, + ) + box[0] = int(min(max(box[0], 0), width - 1)) + box[1] = int(min(max(box[1], 0), height - 1)) + box[2] = int(min(max(box[2], 0), width)) + box[3] = int(min(max(box[3], 0), height)) + # ensures no extremely small (likely incorrect) boxes are counted + small_box_threshold = 5 + if ( + box[2] - box[0] > small_box_threshold + and box[3] - box[1] > small_box_threshold + ): + box.append("%.3f" % (cls_conf.item())) + subfigure_info.append(box) + return subfigure_info + + def detect_subfigure_labels(self, figure_path, subfigure_info): + """Uses text recognition to read subfigure labels from figure_path + + Note: + To get sensible results, should be run only after + detect_subfigure_boundaries has been run + Args: + figure_path (str): A path to the image (.png, .jpg, or .gif) + file containing the article figure + subfigure_info (list of lists): Details about bounding boxes + of each subfigure from detect_subfigure_boundaries(). Each + inner list has format [x1, y1, x2, y2, confidence] where + x1, y1 are upper left bounding box coordinates as ints, + x2, y2, are lower right, and confidence the models confidence + Returns: + subfigure_info (list of tuples): Details about bounding boxes and + labels of each subfigure in figure. Tuples for each subfigure are + (x1, y1, x2, y2, label) where x1, y1 are upper left x and y coord + divided by image width/height and label is the an integer n + meaning the label is the nth letter + concate_img (np.ndarray): A numpy array representing the figure. + Used in classify_subfigures. Ideally this will be removed to + increase modularity. + """ + img_raw = Image.open(figure_path).convert("RGB") + img_raw = img_raw.copy() + width, height = img_raw.size + binary_img = np.zeros((height, width, 1)) + + detected_label_and_bbox = None + max_confidence = 0.0 + for subfigure in subfigure_info: + ## Preprocess the image for the model + bbox = tuple(subfigure[:4]) + img_patch = img_raw.crop(bbox) + img_patch = np.array(img_patch)[:, :, ::-1] + img_patch, _ = preprocess(img_patch, 28, jitter=0) + img_patch = np.transpose(img_patch / 255.0, (2, 0, 1)) + img_patch = torch.from_numpy(img_patch).type(self.dtype).unsqueeze(0) + + ## Run model on figure + label_prediction = self.text_recognition_model(img_patch.to(self.device)) + label_confidence = np.amax( + F.softmax(label_prediction, dim=1).data.cpu().numpy() + ) + x1, y1, x2, y2, box_confidence = subfigure + total_confidence = float(box_confidence) * label_confidence + if total_confidence < max_confidence: + continue + label_value = chr( + label_prediction.argmax(dim=1).data.cpu().numpy()[0] + ord("a") + ) + if label_value == "z": + continue + detected_label_and_bbox = [label_value, x1, y1, x2, y2] + + return detected_label_and_bbox + + def run(self, figure_path): + subfigure_info = self.detect_subfigure_boundaries(figure_path) + subfigure_info = self.detect_subfigure_labels(figure_path, subfigure_info) + + return subfigure_info diff --git a/openpmcvl/granular/models/transformer_module.py b/openpmcvl/granular/models/transformer_module.py new file mode 100644 index 0000000..d4deedc --- /dev/null +++ b/openpmcvl/granular/models/transformer_module.py @@ -0,0 +1,208 @@ +import copy + +import torch +import torch.nn.functional as F +from torch import nn + + +class MultiHeadAttention(nn.Module): + """Multi-Head Attention module""" + + def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): + super().__init__() + + self.n_head = n_head + self.d_k = d_k + self.d_v = d_v + + self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) + self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) + self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) + self.fc = nn.Linear(n_head * d_v, d_model, bias=False) + + self.attention = ScaledDotProductAttention( + temperature=d_k**0.5, attn_dropout=dropout + ) + + self.dropout = nn.Dropout(dropout) + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + + def forward(self, q, k, v, mask=None): + residual = q + + q = self.layer_norm(q) + k = self.layer_norm(k) + v = self.layer_norm(v) + + d_k, d_v, n_head = self.d_k, self.d_v, self.n_head + sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) + + # Pass through the pre-attention projection: b x lq x (n*dv) + # Separate different heads: b x lq x n x dv + q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) + k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) + v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) + + # Transpose for attention dot product: b x n x lq x dv + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + + if mask is not None: + mask = mask.unsqueeze(1) # For head axis broadcasting. + + q, attn = self.attention(q, k, v, mask=mask) + + # Transpose to move the head dimension back: b x lq x n x dv + # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv) + q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1) + q = self.dropout(self.fc(q)) + q += residual + + return q, attn + + +class ScaledDotProductAttention(nn.Module): + """Scaled Dot-Product Attention""" + + def __init__(self, temperature, attn_dropout=0.1): + super().__init__() + self.temperature = temperature + self.dropout = nn.Dropout(attn_dropout) + + def forward(self, q, k, v, mask=None): + attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) + + if mask is not None: + attn = attn.masked_fill(mask == 0, -1e9) + + attn = self.dropout(F.softmax(attn, dim=-1)) + output = torch.matmul(attn, v) + + return output, attn + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu" + ): + super().__init__() + self.self_attn = MultiHeadAttention( + nhead, d_model, d_k=d_model // nhead, d_v=d_model // nhead, dropout=dropout + ) # 内含 norm + atten + dropout + residual + + # Implementation of Feedforward model + self.ffn = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + _get_activation_md(activation), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + nn.Dropout(dropout), + ) + + self.norm = nn.LayerNorm(d_model) + + def forward(self, src): + q = k = src + src = self.self_attn(q, k, src)[0] + + src2 = self.norm(src) + src2 = self.ffn(src2) + src = src + src2 + + return src + + +class TransformerDecoderLayer(nn.Module): + def __init__( + self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu" + ): + super().__init__() + + self.self_attn = MultiHeadAttention( + nhead, d_model, d_k=d_model // nhead, d_v=d_model // nhead, dropout=dropout + ) + + self.cross_attn = MultiHeadAttention( + nhead, d_model, d_k=d_model // nhead, d_v=d_model // nhead, dropout=dropout + ) + + self.ffn = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + _get_activation_md(activation), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + nn.Dropout(dropout), + ) + + self.norm = nn.LayerNorm(d_model) + + def forward(self, tgt, memory): + tgt = self.cross_attn(tgt, memory, memory)[0] + + tgt = self.self_attn(tgt, tgt, tgt)[0] + + tgt2 = self.norm(tgt) + tgt2 = self.ffn(tgt2) + tgt = tgt + tgt2 + + return tgt + + +class TransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + + def forward(self, src): + output = src + + for layer in self.layers: + output = layer(output) # (bs, patch_num, feature_dim) + + return output + + +class TransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + + def forward(self, encoder_memory, query, return_intermedia=False): + query_output = query + intermedia = [] + + for i in range(len(self.layers)): + query_output = self.layers[i]( + query_output, encoder_memory + ) # (bs, query_num, feature_dim) + if return_intermedia: + intermedia.append(query_output) + + return query_output, intermedia + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def _get_activation_md(activation): + """Return an activation function given a string""" + if activation == "relu": + return nn.ReLU() + if activation == "gelu": + return nn.GELU() + if activation == "glu": + return nn.GLU() + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") diff --git a/openpmcvl/granular/models/yolo_layer.py b/openpmcvl/granular/models/yolo_layer.py new file mode 100644 index 0000000..e7c48b6 --- /dev/null +++ b/openpmcvl/granular/models/yolo_layer.py @@ -0,0 +1,705 @@ +import warnings + +import numpy as np +import torch +from torch import nn + +from openpmcvl.granular.models.network import resnet152 +from openpmcvl.granular.models.process import preprocess + + +def bboxes_iou(bboxes_a, bboxes_b, xyxy=True): + """Calculate the Intersection of Unions (IoUs) between bounding boxes. + IoU is calculated as a ratio of area of the intersection + and area of the union. + + Args: + bbox_a (array): An array whose shape is :math:`(N, 4)`. + :math:`N` is the number of bounding boxes. + The dtype should be :obj:`numpy.float32`. + bbox_b (array): An array similar to :obj:`bbox_a`, + whose shape is :math:`(K, 4)`. + The dtype should be :obj:`numpy.float32`. + + Returns + ------- + array: + An array whose shape is :math:`(N, K)`. \ + An element at index :math:`(n, k)` contains IoUs between \ + :math:`n` th bounding box in :obj:`bbox_a` and :math:`k` th bounding \ + box in :obj:`bbox_b`. + + from: https://github.com/chainer/chainercv + """ + if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4: + raise IndexError + + # top left + if xyxy: + tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2]) + # bottom right + br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:]) + area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1) + area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1) + else: + tl = torch.max( + (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2), + (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2), + ) + # bottom right + br = torch.min( + (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2), + (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2), + ) + + area_a = torch.prod(bboxes_a[:, 2:], 1) + area_b = torch.prod(bboxes_b[:, 2:], 1) + en = (tl < br).type(tl.type()).prod(dim=2) + area_i = torch.prod(br - tl, 2) * en # * ((tl < br).all()) + return area_i / (area_a[:, None] + area_b - area_i) + + +class YOLOLayer(nn.Module): + """ + detection layer corresponding to yolo_layer.c of darknet + """ + + def __init__(self, config_model, layer_no, in_ch, ignore_thre=0.7): + """ + Args: + config_model (dict) : model configuration. + ANCHORS (list of tuples) : + ANCH_MASK: (list of int list): index indicating the anchors to be + used in YOLO layers. One of the mask group is picked from the list. + N_CLASSES (int): number of classes + layer_no (int): YOLO layer number - one from (0, 1, 2). + in_ch (int): number of input channels. + ignore_thre (float): threshold of IoU above which objectness training is ignored. + """ + super(YOLOLayer, self).__init__() + strides = [32, 16, 8] # fixed + self.anchors = config_model["ANCHORS"] + self.n_anchors = len(self.anchors) + self.n_classes = config_model["N_CLASSES"] + self.ignore_thre = ignore_thre + self.l2_loss = nn.MSELoss(reduction="sum") + self.bce_loss = nn.BCELoss(reduction="sum") + + self.stride = strides[layer_no] + self.all_anchors_grid = [ + (w / self.stride, h / self.stride) for w, h in self.anchors + ] + self.masked_anchors = self.all_anchors_grid + self.ref_anchors = np.zeros((len(self.all_anchors_grid), 4)) + self.ref_anchors[:, 2:] = np.array(self.all_anchors_grid) + self.ref_anchors = torch.FloatTensor(self.ref_anchors) + self.conv = nn.Conv2d( + in_channels=in_ch, + out_channels=self.n_anchors * 5, + kernel_size=1, + stride=1, + padding=0, + ) + self.classifier_model = resnet152() + + def forward(self, xin, compound_labels=None): + """ + In this + Args: + xin (torch.Tensor): input feature map whose size is :math:`(N, C, H, W)`, \ + where N, C, H, W denote batchsize, channel width, height, width respectively. + labels (torch.Tensor): label data whose size is :math:`(N, K, 5)`. \ + N and K denote batchsize and number of labels. + Each label consists of [class, xc, yc, w, h]: + class (float): class index. + xc, yc (float) : center of bbox whose values range from 0 to 1. + w, h (float) : size of bbox whose values range from 0 to 1. + + Returns + ------- + loss (torch.Tensor): total loss - the target of backprop. + loss_xy (torch.Tensor): x, y loss - calculated by binary cross entropy (BCE) \ + with boxsize-dependent weights. + loss_wh (torch.Tensor): w, h loss - calculated by l2 without size averaging and \ + with boxsize-dependent weights. + loss_obj (torch.Tensor): objectness loss - calculated by BCE. + loss_cls (torch.Tensor): classification loss - calculated by BCE for each class. + loss_l2 (torch.Tensor): total l2 loss - only for logging. + """ + output = self.conv(xin) + + batchsize = output.shape[0] + fsize = output.shape[2] + n_ch = 5 + dtype = torch.cuda.FloatTensor if xin.is_cuda else torch.FloatTensor + + output = output.view(batchsize, self.n_anchors, n_ch, fsize, fsize) + output = output.permute(0, 1, 3, 4, 2) # .contiguous() + + # logistic activation for xy, obj, cls + output[..., np.r_[:2, 4:n_ch]] = torch.sigmoid(output[..., np.r_[:2, 4:n_ch]]) + + # Suppresses incorrect UserWarning about a non-writeable Numpy array + # PR with fix accepted shortly after torch 1.7.1 release + # https://github.com/pytorch/pytorch/pull/47271 + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # calculate pred - xywh obj cls + x_shift = dtype( + np.broadcast_to(np.arange(fsize, dtype=np.float32), output.shape[:4]) + ) + y_shift = dtype( + np.broadcast_to( + np.arange(fsize, dtype=np.float32).reshape(fsize, 1), output.shape[:4] + ) + ) + + masked_anchors = np.array(self.masked_anchors) + + w_anchors = dtype( + np.broadcast_to( + np.reshape(masked_anchors[:, 0], (1, self.n_anchors, 1, 1)), + output.shape[:4], + ) + ) + h_anchors = dtype( + np.broadcast_to( + np.reshape(masked_anchors[:, 1], (1, self.n_anchors, 1, 1)), + output.shape[:4], + ) + ) + + pred = output.clone() + pred[..., 0] += x_shift + pred[..., 1] += y_shift + pred[..., 2] = torch.exp(pred[..., 2]) * w_anchors + pred[..., 3] = torch.exp(pred[..., 3]) * h_anchors + + if compound_labels is None: # not training + pred[..., :4] *= self.stride + return pred.reshape(batchsize, -1, n_ch).data + + pred = pred[..., :4].data + + # target assignment + + tgt_mask = torch.zeros(batchsize, self.n_anchors, fsize, fsize, 4).type(dtype) + obj_mask = torch.ones(batchsize, self.n_anchors, fsize, fsize).type(dtype) + tgt_scale = torch.zeros(batchsize, self.n_anchors, fsize, fsize, 2).type(dtype) + + target = torch.zeros(batchsize, self.n_anchors, fsize, fsize, n_ch).type(dtype) + + labels, imgs = compound_labels + imgs = imgs.data.cpu().numpy() + labels = labels.cpu().data + nlabel = (labels.sum(dim=2) > 0).sum(dim=1) # number of objects + + truth_x_all = labels[:, :, 1] * fsize + truth_y_all = labels[:, :, 2] * fsize + truth_w_all = labels[:, :, 3] * fsize + truth_h_all = labels[:, :, 4] * fsize + truth_i_all = truth_x_all.to(torch.int16).numpy() + truth_j_all = truth_y_all.to(torch.int16).numpy() + + for b in range(batchsize): + n = int(nlabel[b]) + if n == 0: + continue + img = imgs[b].transpose((1, 2, 0))[:, :, ::-1] + truth_box = dtype(np.zeros((n, 4))) + truth_box[:n, 2] = truth_w_all[b, :n] + truth_box[:n, 3] = truth_h_all[b, :n] + truth_i = truth_i_all[b, :n] + truth_j = truth_j_all[b, :n] + + # calculate iou between truth and reference anchors + anchor_ious_all = bboxes_iou(truth_box.cpu(), self.ref_anchors) + best_n_all = np.argmax(anchor_ious_all, axis=1) + + truth_box[:n, 0] = truth_x_all[b, :n] + truth_box[:n, 1] = truth_y_all[b, :n] + + pred_ious = bboxes_iou(pred[b].view(-1, 4), truth_box, xyxy=False) + pred_best_iou, pred_best_iou_index = pred_ious.max(dim=1) + pred_best_iou = pred_best_iou.view(pred[b].shape[:3]) + not_obj_mask = pred_best_iou < 0.3 + + is_obj_mask = pred_best_iou > 0.7 + pred_best_iou_index = pred_best_iou_index.view(pred[b].shape[:3]) + + obj_mask[b] = (not_obj_mask + is_obj_mask).type(dtype) + + # encourage bbox with over 0.7 IOU + rp_index = np.nonzero(is_obj_mask) + for rp_i in range(rp_index.size()[0]): + rp_anchor, rp_i, rp_j = rp_index[rp_i] + truth_box_index = int(pred_best_iou_index[rp_anchor, rp_i, rp_j]) + target[b, rp_anchor, rp_i, rp_j, 4] = 1 + + # target label for the bbox + reference_label = int(labels[b, truth_box_index, 0]) + + self.classifier_model.eval() + + pred_x, pred_y, pred_w, pred_h = ( + pred[b, rp_anchor, rp_i, rp_j, :4] * self.stride + ) + x1 = int(min(max(pred_x - pred_w / 2, 0), img.shape[0] - 1)) + y1 = int(min(max(pred_y - pred_h / 2, 0), img.shape[0] - 1)) + x2 = int(min(max(pred_x + pred_w / 2, 0), img.shape[0] - 1)) + y2 = int(min(max(pred_y + pred_h / 2, 0), img.shape[0] - 1)) + + if (x1 + 2 < x2) and (y1 + 2 < y2): + patch = np.uint8(255 * img[y1:y2, x1:x2]) + patch, _ = preprocess(patch, 28, jitter=0) + patch = np.transpose(patch / 255.0, (2, 0, 1)) + patch = torch.from_numpy(patch).unsqueeze(0).type(dtype) + pred_label = int( + self.classifier_model(patch).argmax(dim=1).data.cpu().numpy()[0] + ) + + if pred_label == reference_label: + target[b, rp_anchor, rp_i, rp_j, 4] = 1 + else: + target[b, rp_anchor, rp_i, rp_j, 4] = 0 + else: + target[b, rp_anchor, rp_i, rp_j, 4] = 0 + + for ti in range(best_n_all.shape[0]): + i, j = truth_i[ti], truth_j[ti] + a = best_n_all[ti] + obj_mask[b, a, j, i] = 1 + tgt_mask[b, a, j, i, :] = 1 + target[b, a, j, i, 0] = truth_x_all[b, ti] - i + target[b, a, j, i, 1] = truth_y_all[b, ti] - j + target[b, a, j, i, 2] = torch.log( + truth_w_all[b, ti] / w_anchors[b, a, j, i] + 1e-16 + ) + target[b, a, j, i, 3] = torch.log( + truth_h_all[b, ti] / h_anchors[b, a, j, i] + 1e-16 + ) + target[b, a, j, i, 4] = 1 + tgt_scale[b, a, j, i, :] = torch.sqrt( + 2 - truth_w_all[b, ti] * truth_h_all[b, ti] / fsize / fsize + ) + + # loss calculation + + output[..., 4] *= obj_mask + output[..., np.r_[0:4]] *= tgt_mask + output[..., 2:4] *= tgt_scale + + target[..., 4] *= obj_mask + target[..., np.r_[0:4]] *= tgt_mask + target[..., 2:4] *= tgt_scale + + bceloss = nn.BCELoss( + weight=tgt_scale * tgt_scale, reduction="sum" + ) # weighted BCEloss + loss_xy = bceloss(output[..., :2], target[..., :2]) + loss_wh = self.l2_loss(output[..., 2:4], target[..., 2:4]) / 2 + loss_obj = self.bce_loss(output[..., 4], target[..., 4]) + loss_cls = 0 # self.bce_loss(output[..., 5:], target[..., 5:]) + loss_l2 = self.l2_loss(output, target) + + loss = loss_xy + loss_wh + loss_obj + loss_cls + + return loss, loss_xy, loss_wh, loss_obj, loss_cls, loss_l2 + + +class YOLOimgLayer(nn.Module): + """ + detection layer corresponding to yolo_layer.c of darknet + """ + + def __init__(self, config_model, layer_no, in_ch, ignore_thre=0.7): + """ + Args: + config_model (dict) : model configuration. + ANCHORS (list of tuples) : + ANCH_MASK: (list of int list): index indicating the anchors to be + used in YOLO layers. One of the mask group is picked from the list. + N_CLASSES (int): number of classes + layer_no (int): YOLO layer number - one from (0, 1, 2). + in_ch (int): number of input channels. + ignore_thre (float): threshold of IoU above which objectness training is ignored. + """ + super(YOLOimgLayer, self).__init__() + strides = [32, 16, 8] # fixed + self.anchors = config_model["ANCHORS"] + self.n_anchors = len(self.anchors) + self.n_classes = config_model["N_CLASSES"] + self.ignore_thre = ignore_thre + self.l2_loss = nn.MSELoss(reduction="sum") + self.bce_loss = nn.BCELoss(reduction="sum") + self.stride = strides[layer_no] + self.all_anchors_grid = [ + (w / self.stride, h / self.stride) for w, h in self.anchors + ] + self.masked_anchors = self.all_anchors_grid + self.ref_anchors = np.zeros((len(self.all_anchors_grid), 4)) + self.ref_anchors[:, 2:] = np.array(self.all_anchors_grid) + self.ref_anchors = torch.FloatTensor(self.ref_anchors) + self.conv = nn.Conv2d( + in_channels=in_ch, + out_channels=self.n_anchors * (self.n_classes + 5), + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, xin, all_labels=None): + """ + In this + Args: + xin (torch.Tensor): input feature map whose size is :math:`(N, C, H, W)`, \ + where N, C, H, W denote batchsize, channel width, height, width respectively. + labels (torch.Tensor): label data whose size is :math:`(N, K, 5)`. \ + N and K denote batchsize and number of labels. + Each label consists of [class, xc, yc, w, h]: + class (float): class index. + xc, yc (float) : center of bbox whose values range from 0 to 1. + w, h (float) : size of bbox whose values range from 0 to 1. + + Returns + ------- + loss (torch.Tensor): total loss - the target of backprop. + loss_xy (torch.Tensor): x, y loss - calculated by binary cross entropy (BCE) \ + with boxsize-dependent weights. + loss_wh (torch.Tensor): w, h loss - calculated by l2 without size averaging and \ + with boxsize-dependent weights. + loss_obj (torch.Tensor): objectness loss - calculated by BCE. + loss_cls (torch.Tensor): classification loss - calculated by BCE for each class. + loss_l2 (torch.Tensor): total l2 loss - only for logging. + """ + output = self.conv(xin) + labels, prior_labels = all_labels + + batchsize = output.shape[0] + fsize = output.shape[2] + n_ch = 5 + self.n_classes + dtype = torch.cuda.FloatTensor if xin.is_cuda else torch.FloatTensor + + output = output.view(batchsize, self.n_anchors, n_ch, fsize, fsize) + output = output.permute(0, 1, 3, 4, 2) # .contiguous() + + # logistic activation for xy, obj, cls + output[..., np.r_[:2, 4:n_ch]] = torch.sigmoid(output[..., np.r_[:2, 4:n_ch]]) + + # calculate pred - xywh obj cls + + x_shift = dtype( + np.broadcast_to(np.arange(fsize, dtype=np.float32), output.shape[:4]) + ) + y_shift = dtype( + np.broadcast_to( + np.arange(fsize, dtype=np.float32).reshape(fsize, 1), output.shape[:4] + ) + ) + + masked_anchors = np.array(self.masked_anchors) + + w_anchors = dtype( + np.broadcast_to( + np.reshape(masked_anchors[:, 0], (1, self.n_anchors, 1, 1)), + output.shape[:4], + ) + ) + h_anchors = dtype( + np.broadcast_to( + np.reshape(masked_anchors[:, 1], (1, self.n_anchors, 1, 1)), + output.shape[:4], + ) + ) + + pred = output.clone() + pred[..., :2] -= 0.5 + pred[..., 0] *= w_anchors + pred[..., 1] *= h_anchors + pred[..., 0] += x_shift + pred[..., 1] += y_shift + pred[..., 2] = torch.exp(pred[..., 2]) * w_anchors + pred[..., 3] = torch.exp(pred[..., 3]) * h_anchors + + prior_labels = prior_labels.cpu().data + nprior_label = (prior_labels.sum(dim=2) > 0).sum(dim=1) + truth_x_all_sub = prior_labels[:, :, 1] * fsize + truth_y_all_sub = prior_labels[:, :, 2] * fsize + truth_i_all_sub = truth_x_all_sub.to(torch.int16).numpy() + truth_j_all_sub = truth_y_all_sub.to(torch.int16).numpy() + + if labels is None: # not training + pred[..., :4] *= self.stride + return pred.data + + pred = pred[..., :4].data + + # target assignment + tgt_mask = torch.zeros(batchsize, self.n_anchors, fsize, fsize, n_ch).type( + dtype + ) + in_grid_distance = torch.zeros(batchsize, 80, 2).type(dtype) + tgt_scale = torch.zeros(batchsize, self.n_anchors, fsize, fsize, 2).type(dtype) + target = torch.zeros(batchsize, self.n_anchors, fsize, fsize, n_ch).type(dtype) + + labels = labels.cpu().data + nlabel = (labels.sum(dim=2) > 0).sum(dim=1) # number of objects + + truth_x_all = labels[:, :, 1] * fsize + truth_y_all = labels[:, :, 2] * fsize + truth_w_all = labels[:, :, 3] * fsize + truth_h_all = labels[:, :, 4] * fsize + + for b in range(batchsize): + n = int(nlabel[b]) + if n == 0: + continue + truth_box = dtype(np.zeros((n, 4))) + truth_box[:n, 2] = truth_w_all[b, :n] + truth_box[:n, 3] = truth_h_all[b, :n] + truth_i = truth_i_all_sub[b, :n] + truth_j = truth_j_all_sub[b, :n] + + # calculate iou between truth and reference anchors + anchor_ious_all = bboxes_iou(truth_box, (self.ref_anchors).type(dtype)) + best_n_all = torch.argmax(anchor_ious_all, dim=1) + + truth_box[:n, 0] = truth_x_all[b, :n] + truth_box[:n, 1] = truth_y_all[b, :n] + + for ti in range(n): + i, j = truth_i[ti], truth_j[ti] + + # find box with iou over 0.7 and under 0.3 (achor point) + current_truth_box = truth_box[ti : ti + 1] + current_pred_boxes = pred[b, :, j, i, :4] + pred_ious = bboxes_iou( + current_truth_box, current_pred_boxes, xyxy=False + ) + good_anchor_index = torch.nonzero((pred_ious > 0.7)[0]).cpu().numpy() + bad_anchor_index = torch.nonzero((pred_ious < 0.3)[0]).cpu().numpy() + for good_i in range(len(good_anchor_index)): + a = good_anchor_index[good_i] + tgt_mask[b, a, j, i, :] = 1 + target[b, a, j, i, 0] = torch.clamp( + (truth_x_all[b, ti] - i) + / torch.Tensor(self.masked_anchors)[a, 0] + + 0.5, + 0, + 1, + ) + target[b, a, j, i, 1] = torch.clamp( + (truth_y_all[b, ti] - j) + / torch.Tensor(self.masked_anchors)[a, 1] + + 0.5, + 0, + 1, + ) + target[b, a, j, i, 2] = torch.log( + truth_w_all[b, ti] / torch.Tensor(self.masked_anchors)[a, 0] + + 1e-16 + ) + target[b, a, j, i, 3] = torch.log( + truth_h_all[b, ti] / torch.Tensor(self.masked_anchors)[a, 1] + + 1e-16 + ) + target[b, a, j, i, 4] = 1 + target[b, a, j, i, 5 + labels[b, ti, 0].to(torch.int16).numpy()] = 1 + tgt_scale[b, a, j, i, :] = torch.sqrt( + 2 - truth_w_all[b, ti] * truth_h_all[b, ti] / fsize / fsize + ) + + i_best = min( + max(int(pred[b, a, j, i, 0].cpu().numpy()), 0), fsize - 1 + ) + j_best = min( + max(int(pred[b, a, j, i, 1].cpu().numpy()), 0), fsize - 1 + ) + current_pred_boxes_2 = pred[b, :, j_best, i_best, :4] + pred_ious_2 = bboxes_iou( + current_truth_box, current_pred_boxes_2, xyxy=False + ) + good_anchor_index_2 = ( + torch.nonzero((pred_ious_2 > 0.7)[0]).cpu().numpy() + ) + bad_anchor_index_2 = ( + torch.nonzero((pred_ious_2 < 0.3)[0]).cpu().numpy() + ) + + for good_i_2 in range(len(good_anchor_index_2)): + a = good_anchor_index_2[good_i_2] + tgt_mask[b, a, j_best, i_best, :] = 1 + target[b, a, j_best, i_best, 0] = torch.clamp( + (truth_x_all[b, ti] - i_best) + / torch.Tensor(self.masked_anchors)[a, 0] + + 0.5, + 0, + 1, + ) + target[b, a, j_best, i_best, 1] = torch.clamp( + (truth_y_all[b, ti] - j_best) + / torch.Tensor(self.masked_anchors)[a, 1] + + 0.5, + 0, + 1, + ) + target[b, a, j_best, i_best, 2] = torch.log( + truth_w_all[b, ti] / torch.Tensor(self.masked_anchors)[a, 0] + + 1e-16 + ) + target[b, a, j_best, i_best, 3] = torch.log( + truth_h_all[b, ti] / torch.Tensor(self.masked_anchors)[a, 1] + + 1e-16 + ) + target[b, a, j_best, i_best, 4] = 1 + target[ + b, + a, + j_best, + i_best, + 5 + labels[b, ti, 0].to(torch.int16).numpy(), + ] = 1 + tgt_scale[b, a, j_best, i_best, :] = torch.sqrt( + 2 - truth_w_all[b, ti] * truth_h_all[b, ti] / fsize / fsize + ) + + for bad_i_2 in range(len(bad_anchor_index_2)): + a = bad_anchor_index_2[bad_i_2] + tgt_mask[b, a, j_best, i_best, 4:] = 1 + + for bad_i in range(len(bad_anchor_index)): + a = bad_anchor_index[bad_i] + tgt_mask[b, a, j, i, 4:] = 1 + + # best anchor box + a = best_n_all[ti] + tgt_mask[b, a, j, i, :] = 1 + target[b, a, j, i, 0] = torch.clamp( + (truth_x_all[b, ti] - i) / torch.Tensor(self.masked_anchors)[a, 0] + + 0.5, + 0, + 1, + ) + target[b, a, j, i, 1] = torch.clamp( + (truth_y_all[b, ti] - j) / torch.Tensor(self.masked_anchors)[a, 1] + + 0.5, + 0, + 1, + ) + target[b, a, j, i, 2] = torch.log( + truth_w_all[b, ti] / torch.Tensor(self.masked_anchors)[a, 0] + 1e-16 + ) + target[b, a, j, i, 3] = torch.log( + truth_h_all[b, ti] / torch.Tensor(self.masked_anchors)[a, 1] + 1e-16 + ) + target[b, a, j, i, 4] = 1 + target[b, a, j, i, 5 + labels[b, ti, 0].to(torch.int16).numpy()] = 1 + tgt_scale[b, a, j, i, :] = torch.sqrt( + 2 - truth_w_all[b, ti] * truth_h_all[b, ti] / fsize / fsize + ) + + i_best = min(max(int(pred[b, a, j, i, 0].cpu().numpy()), 0), fsize - 1) + j_best = min(max(int(pred[b, a, j, i, 1].cpu().numpy()), 0), fsize - 1) + + # find box with iou over 0.7 and under 0.3 (predict center) + current_truth_box = truth_box[ti : ti + 1] + current_pred_boxes = pred[b, :, j_best, i_best, :4] + pred_ious = bboxes_iou( + current_truth_box, current_pred_boxes, xyxy=False + ) + good_anchor_index = torch.nonzero((pred_ious > 0.7)[0]).cpu().numpy() + bad_anchor_index = torch.nonzero((pred_ious < 0.3)[0]).cpu().numpy() + + for good_i in range(len(good_anchor_index)): + a = good_anchor_index[good_i] + tgt_mask[b, a, j_best, i_best, :] = 1 + target[b, a, j_best, i_best, 0] = torch.clamp( + (truth_x_all[b, ti] - i_best) + / torch.Tensor(self.masked_anchors)[a, 0] + + 0.5, + 0, + 1, + ) + target[b, a, j_best, i_best, 1] = torch.clamp( + (truth_y_all[b, ti] - j_best) + / torch.Tensor(self.masked_anchors)[a, 1] + + 0.5, + 0, + 1, + ) + target[b, a, j_best, i_best, 2] = torch.log( + truth_w_all[b, ti] / torch.Tensor(self.masked_anchors)[a, 0] + + 1e-16 + ) + target[b, a, j_best, i_best, 3] = torch.log( + truth_h_all[b, ti] / torch.Tensor(self.masked_anchors)[a, 1] + + 1e-16 + ) + target[b, a, j_best, i_best, 4] = 1 + target[ + b, + a, + j_best, + i_best, + 5 + labels[b, ti, 0].to(torch.int16).numpy(), + ] = 1 + tgt_scale[b, a, j_best, i_best, :] = torch.sqrt( + 2 - truth_w_all[b, ti] * truth_h_all[b, ti] / fsize / fsize + ) + + for bad_i in range(len(bad_anchor_index)): + a = bad_anchor_index[bad_i] + tgt_mask[b, a, j_best, i_best, 4:] = 1 + + a = best_n_all[ti] + tgt_mask[b, a, j_best, i_best, :] = 1 + target[b, a, j_best, i_best, 0] = torch.clamp( + (truth_x_all[b, ti] - i_best) + / torch.Tensor(self.masked_anchors)[a, 0] + + 0.5, + 0, + 1, + ) + target[b, a, j_best, i_best, 1] = torch.clamp( + (truth_y_all[b, ti] - j_best) + / torch.Tensor(self.masked_anchors)[a, 1] + + 0.5, + 0, + 1, + ) + target[b, a, j_best, i_best, 2] = torch.log( + truth_w_all[b, ti] / torch.Tensor(self.masked_anchors)[a, 0] + 1e-16 + ) + target[b, a, j_best, i_best, 3] = torch.log( + truth_h_all[b, ti] / torch.Tensor(self.masked_anchors)[a, 1] + 1e-16 + ) + target[b, a, j_best, i_best, 4] = 1 + target[ + b, a, j_best, i_best, 5 + labels[b, ti, 0].to(torch.int16).numpy() + ] = 1 + tgt_scale[b, a, j_best, i_best, :] = torch.sqrt( + 2 - truth_w_all[b, ti] * truth_h_all[b, ti] / fsize / fsize + ) + + # loss calculation + + output *= tgt_mask + target *= tgt_mask + target_in_grid_distance = torch.zeros(batchsize, 80, 2).type(dtype) + + output[..., 2:4] *= tgt_scale + target[..., 2:4] *= tgt_scale + + bceloss = nn.BCELoss( + weight=tgt_scale * tgt_scale, reduction="sum" + ) # weighted BCEloss + loss_xy = bceloss(output[..., :2], target[..., :2]) + loss_wh = self.l2_loss(output[..., 2:4], target[..., 2:4]) / 2 + loss_obj = self.bce_loss(output[..., 4], target[..., 4]) + loss_cls = self.bce_loss(output[..., 5:], target[..., 5:]) + loss_l2 = self.l2_loss(output, target) + loss_in_grid = self.bce_loss(in_grid_distance, target_in_grid_distance) + + loss = loss_xy + loss_wh + loss_obj + loss_cls + 0.01 * loss_in_grid + + return loss, loss_xy, loss_wh, loss_obj, loss_cls, loss_in_grid, loss_l2 diff --git a/openpmcvl/granular/models/yolov3.py b/openpmcvl/granular/models/yolov3.py new file mode 100644 index 0000000..0f8fe1e --- /dev/null +++ b/openpmcvl/granular/models/yolov3.py @@ -0,0 +1,328 @@ +from collections import defaultdict + +import torch +from torch import nn + +from openpmcvl.granular.models.yolo_layer import YOLOimgLayer, YOLOLayer + + +def add_conv(in_ch, out_ch, ksize, stride): + """ + Add a conv2d / batchnorm / leaky ReLU block. + Args: + in_ch (int): number of input channels of the convolution layer. + out_ch (int): number of output channels of the convolution layer. + ksize (int): kernel size of the convolution layer. + stride (int): stride of the convolution layer. + + Returns + ------- + stage (Sequential) : Sequential layers composing a convolution block. + """ + stage = nn.Sequential() + pad = (ksize - 1) // 2 + stage.add_module( + "conv", + nn.Conv2d( + in_channels=in_ch, + out_channels=out_ch, + kernel_size=ksize, + stride=stride, + padding=pad, + bias=False, + ), + ) + stage.add_module("batch_norm", nn.BatchNorm2d(out_ch)) + stage.add_module("leaky", nn.LeakyReLU(0.1)) + return stage + + +class resblock(nn.Module): + """ + Sequential residual blocks each of which consists of \ + two convolution layers. + Args: + ch (int): number of input and output channels. + nblocks (int): number of residual blocks. + shortcut (bool): if True, residual tensor addition is enabled. + """ + + def __init__(self, ch, nblocks=1, shortcut=True): + super().__init__() + self.shortcut = shortcut + self.module_list = nn.ModuleList() + for i in range(nblocks): + resblock_one = nn.ModuleList() + resblock_one.append(add_conv(ch, ch // 2, 1, 1)) + resblock_one.append(add_conv(ch // 2, ch, 3, 1)) + self.module_list.append(resblock_one) + + def forward(self, x): + for module in self.module_list: + h = x + for res in module: + h = res(h) + x = x + h if self.shortcut else h + return x + + +def create_yolov3_modules(config_model, ignore_thre): + """ + Build yolov3 layer modules. + Args: + config_model (dict): model configuration. + See YOLOLayer class for details. + ignore_thre (float): used in YOLOLayer. + + Returns + ------- + mlist (ModuleList): YOLOv3 module list. + """ + # DarkNet53 + mlist = nn.ModuleList() + mlist.append(add_conv(in_ch=3, out_ch=32, ksize=3, stride=1)) + mlist.append(add_conv(in_ch=32, out_ch=64, ksize=3, stride=2)) + mlist.append(resblock(ch=64)) + mlist.append(add_conv(in_ch=64, out_ch=128, ksize=3, stride=2)) + mlist.append(resblock(ch=128, nblocks=2)) + mlist.append(add_conv(in_ch=128, out_ch=256, ksize=3, stride=2)) + mlist.append(resblock(ch=256, nblocks=8)) # shortcut 1 from here + mlist.append(add_conv(in_ch=256, out_ch=512, ksize=3, stride=2)) + mlist.append(resblock(ch=512, nblocks=8)) # shortcut 2 from here + mlist.append(add_conv(in_ch=512, out_ch=1024, ksize=3, stride=2)) + mlist.append(resblock(ch=1024, nblocks=4)) + + # YOLOv3 + mlist.append(resblock(ch=1024, nblocks=2, shortcut=False)) + mlist.append(add_conv(in_ch=1024, out_ch=512, ksize=1, stride=1)) + # 1st yolo branch + mlist.append(add_conv(in_ch=512, out_ch=1024, ksize=3, stride=1)) + mlist.append( + YOLOLayer(config_model, layer_no=0, in_ch=1024, ignore_thre=ignore_thre) + ) + + mlist.append(add_conv(in_ch=512, out_ch=256, ksize=1, stride=1)) + mlist.append(nn.Upsample(scale_factor=2, mode="nearest")) + mlist.append(add_conv(in_ch=768, out_ch=256, ksize=1, stride=1)) + mlist.append(add_conv(in_ch=256, out_ch=512, ksize=3, stride=1)) + mlist.append(resblock(ch=512, nblocks=1, shortcut=False)) + mlist.append(add_conv(in_ch=512, out_ch=256, ksize=1, stride=1)) + # 2nd yolo branch + mlist.append(add_conv(in_ch=256, out_ch=512, ksize=3, stride=1)) + mlist.append( + YOLOLayer(config_model, layer_no=1, in_ch=512, ignore_thre=ignore_thre) + ) + + mlist.append(add_conv(in_ch=256, out_ch=128, ksize=1, stride=1)) + mlist.append(nn.Upsample(scale_factor=2, mode="nearest")) + mlist.append(add_conv(in_ch=384, out_ch=128, ksize=1, stride=1)) + mlist.append(add_conv(in_ch=128, out_ch=256, ksize=3, stride=1)) + mlist.append(resblock(ch=256, nblocks=2, shortcut=False)) + mlist.append( + YOLOLayer(config_model, layer_no=2, in_ch=256, ignore_thre=ignore_thre) + ) + + return mlist + + +class YOLOv3(nn.Module): + """ + YOLOv3 model module. The module list is defined by create_yolov3_modules function. \ + The network returns loss values from three YOLO layers during training \ + and detection results during test. + """ + + def __init__(self, config_model, ignore_thre=0.7): + """ + Initialization of YOLOv3 class. + Args: + config_model (dict): used in YOLOLayer. + ignore_thre (float): used in YOLOLayer. + """ + super(YOLOv3, self).__init__() + + if config_model["TYPE"] == "YOLOv3": + self.module_list = create_yolov3_modules(config_model, ignore_thre) + else: + raise Exception( + "Model name {} is not available".format(config_model["TYPE"]) + ) + + def forward(self, x, targets=None): + """ + Forward path of YOLOv3. + Args: + x (torch.Tensor) : input data whose shape is :math:`(N, C, H, W)`, \ + where N, C are batchsize and num. of channels. + targets (torch.Tensor) : label array whose shape is :math:`(N, 50, 5)` + + Returns + ------- + training: + output (torch.Tensor): loss tensor for backpropagation. + test: + output (torch.Tensor): concatenated detection results. + """ + train = targets is not None + output = [] + self.loss_dict = defaultdict(float) + route_layers = [] + for i, module in enumerate(self.module_list): + # yolo layers + if i in [14, 22, 28]: + if train: + x, *loss_dict = module(x, targets) + for name, loss in zip(["xy", "wh", "conf", "cls", "l2"], loss_dict): + self.loss_dict[name] += loss + else: + x = module(x) + output.append(x) + else: + x = module(x) + + # route layers + if i in [6, 8, 12, 20]: + route_layers.append(x) + if i == 14: + x = route_layers[2] + if i == 22: # yolo 2nd + x = route_layers[3] + if i == 16: + x = torch.cat((x, route_layers[1]), 1) + if i == 24: + x = torch.cat((x, route_layers[0]), 1) + if train: + return sum(output) + return torch.cat(output, 1) + + +def create_yolov3img_modules(config_model, ignore_thre): + """ + Build yolov3 layer modules. + Args: + config_model (dict): model configuration. + See YOLOLayer class for details. + ignore_thre (float): used in YOLOLayer. + + Returns + ------- + mlist (ModuleList): YOLOv3 module list. + """ + # DarkNet53 + mlist = nn.ModuleList() + mlist.append(add_conv(in_ch=4, out_ch=32, ksize=3, stride=1)) + mlist.append(add_conv(in_ch=32, out_ch=64, ksize=3, stride=2)) + mlist.append(resblock(ch=64)) + mlist.append(add_conv(in_ch=64, out_ch=128, ksize=3, stride=2)) + mlist.append(resblock(ch=128, nblocks=2)) + mlist.append(add_conv(in_ch=128, out_ch=256, ksize=3, stride=2)) + mlist.append(resblock(ch=256, nblocks=8)) # shortcut 1 from here + mlist.append(add_conv(in_ch=256, out_ch=512, ksize=3, stride=2)) + mlist.append(resblock(ch=512, nblocks=8)) # shortcut 2 from here + mlist.append(add_conv(in_ch=512, out_ch=1024, ksize=3, stride=2)) + mlist.append(resblock(ch=1024, nblocks=4)) + + # YOLOv3 + mlist.append(resblock(ch=1024, nblocks=2, shortcut=False)) + mlist.append(add_conv(in_ch=1024, out_ch=512, ksize=1, stride=1)) + # 1st yolo branch + mlist.append(add_conv(in_ch=512, out_ch=1024, ksize=3, stride=1)) + mlist.append( + YOLOimgLayer(config_model, layer_no=0, in_ch=1024, ignore_thre=ignore_thre) + ) + + mlist.append(add_conv(in_ch=512, out_ch=256, ksize=1, stride=1)) + mlist.append(nn.Upsample(scale_factor=2, mode="nearest")) + mlist.append(add_conv(in_ch=768, out_ch=256, ksize=1, stride=1)) + mlist.append(add_conv(in_ch=256, out_ch=512, ksize=3, stride=1)) + mlist.append(resblock(ch=512, nblocks=1, shortcut=False)) + mlist.append(add_conv(in_ch=512, out_ch=256, ksize=1, stride=1)) + # 2nd yolo branch + mlist.append(add_conv(in_ch=256, out_ch=512, ksize=3, stride=1)) + mlist.append( + YOLOimgLayer(config_model, layer_no=1, in_ch=512, ignore_thre=ignore_thre) + ) + + mlist.append(add_conv(in_ch=256, out_ch=128, ksize=1, stride=1)) + mlist.append(nn.Upsample(scale_factor=2, mode="nearest")) + mlist.append(add_conv(in_ch=384, out_ch=128, ksize=1, stride=1)) + mlist.append(add_conv(in_ch=128, out_ch=256, ksize=3, stride=1)) + mlist.append(resblock(ch=256, nblocks=2, shortcut=False)) + mlist.append( + YOLOimgLayer(config_model, layer_no=2, in_ch=256, ignore_thre=ignore_thre) + ) + + return mlist + + +class YOLOv3img(nn.Module): + """ + YOLOv3 model module. The module list is defined by create_yolov3img_modules function. \ + The network returns loss values from three YOLO layers during training \ + and detection results during test. + """ + + def __init__(self, config_model, ignore_thre=0.7): + """ + Initialization of YOLOv3 class. + Args: + config_model (dict): used in YOLOLayer. + ignore_thre (float): used in YOLOLayer. + """ + super(YOLOv3img, self).__init__() + + if config_model["TYPE"] == "YOLOv3": + self.module_list = create_yolov3img_modules(config_model, ignore_thre) + else: + raise Exception( + "Model name {} is not available".format(config_model["TYPE"]) + ) + + def forward(self, x, targets=None): + """ + Forward path of YOLOv3. + Args: + x (torch.Tensor) : input data whose shape is :math:`(N, C, H, W)`, \ + where N, C are batchsize and num. of channels. + targets (torch.Tensor) : label array whose shape is :math:`(N, 50, 5)` + + Returns + ------- + training: + output (torch.Tensor): loss tensor for backpropagation. + test: + output (torch.Tensor): concatenated detection results. + """ + train = targets[0] is not None + output = [] + self.loss_dict = defaultdict(float) + route_layers = [] + for i, module in enumerate(self.module_list): + # yolo layers + if i in [14, 22, 28]: + if train: + x, *loss_dict = module(x, targets) + for name, loss in zip( + ["xy", "wh", "conf", "cls", "in_grid", "l2"], loss_dict + ): + self.loss_dict[name] += loss + else: + x = module(x, targets) + output.append(x) + else: + x = module(x) + + # route layers + if i in [6, 8, 12, 20]: + route_layers.append(x) + if i == 14: + x = route_layers[2] + if i == 22: # yolo 2nd + x = route_layers[3] + if i == 16: + x = torch.cat((x, route_layers[1]), 1) + if i == 24: + x = torch.cat((x, route_layers[0]), 1) + if train: + return sum(output) + return output diff --git a/openpmcvl/granular/pipeline/__init__.py b/openpmcvl/granular/pipeline/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/openpmcvl/granular/pipeline/align.py b/openpmcvl/granular/pipeline/align.py new file mode 100644 index 0000000..ff679e6 --- /dev/null +++ b/openpmcvl/granular/pipeline/align.py @@ -0,0 +1,99 @@ +import argparse +import os +from typing import Dict + +from tqdm import tqdm + +from openpmcvl.granular.models.subfigure_ocr import classifier +from openpmcvl.granular.pipeline.utils import load_dataset, save_jsonl + + +def process_subfigure(model: classifier, subfig_data: Dict) -> Dict: + """ + Process a single subfigure using the OCR model. + + Args: + model (classifier): Initialized OCR model + subfig_data (Dict): Dictionary containing subfigure data + + Returns + ------- + Dict: Updated subfigure data with OCR results + """ + if "subfig_path" not in subfig_data: + subfig_data["subfig_path"] = f"{args.root_dir}/images/{subfig_data['image']}" + + try: + ocr_result = model.run(subfig_data["subfig_path"]) + except Exception as e: + ocr_result = "" + print(f"Error processing subfigure {subfig_data['image']}: {e}") + + if ocr_result: + label_letter, *label_position = ocr_result + subfig_data["label"] = f"Subfigure-{label_letter.upper()}" + subfig_data["label_position"] = label_position + else: + subfig_data["label"] = "" + subfig_data["label_position"] = [] + + return subfig_data + + +def main(args: argparse.Namespace) -> None: + """ + Main function to process subfigures and update JSONL file. + + Args: + args (argparse.Namespace): Parsed command-line arguments + """ + # Load model and dataset + model = classifier() + dataset = load_dataset(args.dataset_path) + if args.dataset_slice: + dataset = dataset[args.dataset_slice] + + # Use this line to filter out non-medical subfigures if needed + dataset = [data for data in dataset if data["is_medical_subfigure"]] + print( + f"Total {len(dataset)} medical subfigures from {os.path.basename(args.dataset_path)}" + ) + + # Label each subfigure + labeled_dataset = [] + for data in tqdm(dataset, desc="Labeling subfigures", total=len(dataset)): + updated_item = process_subfigure(model, data) + labeled_dataset.append(updated_item) + + total_labeled = len([data for data in labeled_dataset if data["label"]]) + print(f"Total {total_labeled} subfigures labeled.") + + # Save updated data + save_jsonl(labeled_dataset, args.save_path) + print(f"\nLabeled data saved to {args.save_path}\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Subfigure OCR and Labeling") + + parser.add_argument( + "--root_dir", type=str, required=True, help="Path to root directory" + ) + parser.add_argument( + "--dataset_path", type=str, required=True, help="Path to input JSONL file" + ) + parser.add_argument( + "--dataset_slice", + type=str, + help="Start and end indices for dataset slice (e.g. '0:100')", + ) + parser.add_argument( + "--save_path", type=str, required=True, help="Path to output JSONL file" + ) + + args = parser.parse_args() + if args.dataset_slice: + start, end = map(int, args.dataset_slice.split(":")) + args.dataset_slice = slice(start, end) + + main(args) diff --git a/openpmcvl/granular/pipeline/align.sh b/openpmcvl/granular/pipeline/align.sh new file mode 100644 index 0000000..0e508bd --- /dev/null +++ b/openpmcvl/granular/pipeline/align.sh @@ -0,0 +1,49 @@ +#!/bin/bash +# Batch script to align subfigures with subcaptions + +#SBATCH -c 6 +#SBATCH --gres=gpu:1 +#SBATCH --partition=a40 +#SBATCH --mem=32GB +#SBATCH --time=12:00:00 +#SBATCH --job-name=align +#SBATCH --output=%x-%j.out +#SBATCH --error=%x-%j.err + +# Set environment variables +# VENV_PATH: Path to your virtual environment (e.g. export VENV_PATH=$HOME/venv) +# PROJECT_ROOT: Path to project root directory (e.g. export PROJECT_ROOT=$HOME/project) +# PMC_ROOT: Path to PMC dataset directory (e.g. export PMC_ROOT=$HOME/data) + +# Sample command: +# sbatch openpmcvl/granular/pipeline/align.sh 0 1 2 3 4 5 6 7 8 9 10 11 + + +# Activate virtual environment +source $VENV_PATH/bin/activate + +# Set working directory +cd $PROJECT_ROOT + +# Check if the number of arguments is provided +if [ $# -eq 0 ]; then + echo "Please provide JSONL numbers as arguments." + exit 1 +fi + +# Get the list of JSONL numbers from the command line arguments +JSONL_NUMBERS="$@" + +# Iterate over each JSONL number +for num in $JSONL_NUMBERS; do + input_file="$PMC_ROOT/${num}_subfigures_classified.jsonl" + output_file="$PMC_ROOT/${num}_aligned.jsonl" + + # Run the alignment script + stdbuf -oL -eL srun python3 openpmcvl/granular/pipeline/align.py \ + --root_dir "$PMC_ROOT" \ + --dataset_path "$input_file" \ + --save_path "$output_file" + + echo "Finished aligning ${num}" +done diff --git a/openpmcvl/granular/pipeline/classify.py b/openpmcvl/granular/pipeline/classify.py new file mode 100644 index 0000000..82a3784 --- /dev/null +++ b/openpmcvl/granular/pipeline/classify.py @@ -0,0 +1,173 @@ +import argparse +from typing import Any, Dict, List + +import torch +from PIL import Image +from torch import nn +from torch.utils.data import DataLoader +from torchvision import models, transforms +from tqdm import tqdm + +from openpmcvl.granular.dataset.dataset import SubfigureDataset +from openpmcvl.granular.pipeline.utils import load_dataset, save_jsonl + + +MEDICAL_CLASS = 15 +CLASSIFICATION_THRESHOLD = 4 + + +def load_classification_model(model_path: str, device: torch.device) -> nn.Module: + """ + Loads the figure classification model. + + Args: + model_path (str): Path to the classification model checkpoint + device (torch.device): Device to use for processing + + Returns + ------- + nn.Module: Loaded classification model + """ + fig_model = models.resnext101_32x8d() + num_features = fig_model.fc.in_features + fc = list(fig_model.fc.children()) + fc.extend([nn.Linear(num_features, 28)]) + fig_model.fc = nn.Sequential(*fc) + fig_model = fig_model.to(device) + fig_model.load_state_dict(torch.load(model_path, map_location=device)) + fig_model.eval() + return fig_model + + +def classify_dataset( + model: torch.nn.Module, + data_list: List[Dict[str, Any]], + batch_size: int, + device: torch.device, + output_file: str, + num_workers: int, +): + """ + Classifies images in a dataset using the provided model and saves results to a new JSONL file. + + Args: + model (torch.nn.Module): The classification model. + data_list (List[Dict[str, Any]]): The dataset to classify. + batch_size (int): Batch size for processing. + device (torch.device): Device to use for processing. + output_file (str): Path to save the updated JSONL file with classification results. + num_workers (int): Number of workers for processing. + + Returns + ------- + None + """ + mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + transform = transforms.Compose( + [ + transforms.Resize((384, 384), interpolation=Image.LANCZOS), + transforms.ToTensor(), + transforms.Normalize(*mean_std), + ] + ) + + dataset = SubfigureDataset(data_list, transform=transform) + dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + ) + + model.eval() + model.to(device) + + results = [] + + for images, indices in tqdm( + dataloader, desc=f"Classifying for {output_file}", total=len(dataloader) + ): + images = images.to(device) + outputs = model(images) + + for output, idx in zip(outputs, indices): + sorted_pred = torch.argsort(output.cpu(), descending=True) + medical_class_rank = (sorted_pred == MEDICAL_CLASS).nonzero().item() + is_medical = medical_class_rank < CLASSIFICATION_THRESHOLD + + # Get the original item using the index + item = data_list[idx.item()] + result = { + **item, + "medical_class_rank": medical_class_rank, + "is_medical_subfigure": is_medical, + } + results.append(result) + + save_jsonl(results, output_file) + + +def main(args: argparse.Namespace) -> None: + """ + Main function to run the image classification pipeline. + + This function loads the classification model, prepares the dataset, + and performs classification on the images, saving the results to a JSONL file. + + Args: + args (argparse.Namespace): Command-line arguments containing: + - model_path (str): Path to the classification model checkpoint + - dataset_path (str): Path to the dataset + - batch_size (int): Batch size for processing + - output_file (str): Path to save the JSONL file with classification results + + Returns + ------- + None + """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + torch.set_grad_enabled(False) + + model = load_classification_model(args.model_path, device) + dataset = load_dataset(args.dataset_path) + print(f"Loaded {len(dataset)} subfigures from {args.dataset_path}.") + + classify_dataset( + model, + dataset, + args.batch_size, + device, + args.output_file, + args.num_workers, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Classify images in a dataset and update JSONL file" + ) + parser.add_argument( + "--model_path", + type=str, + required=True, + help="Path to the classification model checkpoint", + ) + parser.add_argument( + "--dataset_path", type=str, required=True, help="Path to the dataset" + ) + parser.add_argument( + "--output_file", + type=str, + required=True, + help="Path to save the JSONL file with classification results", + ) + parser.add_argument( + "--batch_size", type=int, default=128, help="Batch size for processing" + ) + parser.add_argument( + "--num_workers", type=int, default=8, help="Number of workers for processing" + ) + args = parser.parse_args() + + main(args) diff --git a/openpmcvl/granular/pipeline/classify.sh b/openpmcvl/granular/pipeline/classify.sh new file mode 100644 index 0000000..48083d9 --- /dev/null +++ b/openpmcvl/granular/pipeline/classify.sh @@ -0,0 +1,50 @@ +#!/bin/bash +# Batch script to classify subfigures into figure types + +#SBATCH -c 12 +#SBATCH --gres=gpu:1 +#SBATCH --partition=a40 +#SBATCH --mem=100GB +#SBATCH --time=15:00:00 +#SBATCH --job-name=classify +#SBATCH --output=%x-%j.out + +# Set environment variables: +# VENV_PATH: Path to virtual environment (e.g. export VENV_PATH=$HOME/venv) +# PROJECT_ROOT: Path to project root directory (e.g. export PROJECT_ROOT=$HOME/project) +# PMC_ROOT: Path to PMC dataset directory (e.g. export PMC_ROOT=$HOME/data) + +# Sample command: +# sbatch openpmcvl/granular/pipeline/classify.sh 0 1 2 3 4 5 6 7 8 9 10 11 + + +# Activate virtual environment +source $VENV_PATH/bin/activate + +# Set working directory +cd $PROJECT_ROOT + +# Check if the number of arguments is provided +if [ $# -eq 0 ]; then + echo "Please provide JSONL numbers as arguments." + exit 1 +fi + +# Get the list of JSONL numbers from the command line arguments +JSONL_NUMBERS="$@" + +# Iterate over each JSONL number +for num in $JSONL_NUMBERS; do + input_file="$PMC_ROOT/${num}_subfigures.jsonl" + output_file="$PMC_ROOT/${num}_subfigures_classified.jsonl" + + # Run the classification script + stdbuf -oL -eL srun python3 openpmcvl/granular/pipeline/classify.py \ + --model_path openpmcvl/granular/checkpoints/resnext101_figure_class.pth \ + --dataset_path "$input_file" \ + --output_file "$output_file" \ + --batch_size 256 \ + --num_workers 8 \ + + echo "Finished classifying ${num}" +done diff --git a/openpmcvl/granular/pipeline/preprocess.py b/openpmcvl/granular/pipeline/preprocess.py new file mode 100644 index 0000000..5843beb --- /dev/null +++ b/openpmcvl/granular/pipeline/preprocess.py @@ -0,0 +1,251 @@ +import argparse +import os +from concurrent.futures import ProcessPoolExecutor, as_completed +from typing import List, Set, Tuple + +from PIL import Image +from tqdm import tqdm + +from openpmcvl.granular.pipeline.utils import load_dataset, save_jsonl + + +def get_image_dimensions(image_path: str) -> Tuple[int, int]: + """ + Get the width and height of an image. + + Args: + image_path (str): The path to the image file. + + Returns + ------- + Tuple[int, int]: A tuple containing the width and height of the image. + + Raises + ------ + IOError: If the image file cannot be opened or read. + """ + with Image.open(image_path) as img: + return img.size + + +def check_keywords(caption: str, keywords: Set[str]) -> Tuple[List[str], bool]: + """ + Check for the presence of keywords in the caption. + + Args: + caption (str): The caption text to search in. + keywords (Set[str]): A set of keywords to search for. + + Returns + ------- + Tuple[List[str], bool]: A tuple containing: + - A list of found keywords. + - A boolean indicating whether any keywords were found. + """ + caption_words = set(caption.lower().split()) + found_keywords = list(keywords.intersection(caption_words)) + return found_keywords, bool(found_keywords) + + +def process_single_file( + input_file: str, + figure_root: str, + keywords: Set[str], + output_dir: str, + position: int, +) -> Tuple[List[dict], List[str], List[str]]: + """ + Process a single input file. + + Args: + input_file (str): Path to the input JSONL file. + figure_root (str): Root directory for figure images. + keywords (Set[str]): Set of keywords to search for in captions. + output_dir (str): Directory to save the processed file. + position (int): Position for the tqdm progress bar. + + Returns + ------- + Tuple[List[dict], List[str], List[str]]: Processed data, missing figures, and messages. + """ + data = load_dataset(input_file, num_datapoints=-1) + processed_data = [] + missing_figures_count = 0 + missing_figures = [] + messages = [] + + # Use tqdm with position parameter + pbar = tqdm( + data, + desc=f"Processing {os.path.basename(input_file)}", + position=position, + leave=True, + ncols=100, + ) + + for item in pbar: + pmc_id = item["PMC_ID"] + media_id = item["media_id"] + media_name = item["media_name"] + media_url = item["media_url"] + caption = item["caption"] + image_name = os.path.basename(media_url) + image_path = f"{figure_root}/{media_name}/{image_name}" + + # Check if image doesn't exist or is not a .jpg file + if not image_path.endswith(".jpg"): + continue + + if not os.path.exists(image_path): + missing_figures_count += 1 + missing_figures.append(image_path) + continue + + try: + width, height = get_image_dimensions(image_path) + except Exception as e: + msg = f"Error processing image {image_path}: {str(e)}" + messages.append(msg) + continue + + found_keywords, has_keywords = check_keywords(caption, keywords) + + processed_item = { + "id": f"{pmc_id}_{media_id}", + "PMC_ID": pmc_id, + "caption": caption, + "image_path": image_path, + "width": width, + "height": height, + "media_id": media_id, + "media_url": media_url, + "media_name": media_name, + "keywords": found_keywords, + "is_medical": has_keywords, + } + processed_data.append(processed_item) + + # Update pbar one last time to ensure it's at 100% + pbar.n = len(data) + pbar.refresh() + + # Save processed data for this input file + input_filename = os.path.splitext(os.path.basename(input_file))[0] + temp_output_file = os.path.join(output_dir, f"{input_filename}_meta.jsonl") + save_jsonl(processed_data, temp_output_file) + msg = ( + f"\nProcessed {len(processed_data)} items from {input_file}. Saved to {temp_output_file}" + f"\nTotal missing .jpg images in {input_file}: {missing_figures_count}\n" + ) + messages.append(msg) + + return processed_data, missing_figures, messages + + +def preprocess_data( + input_files: List[str], output_file: str, figure_root: str, keywords: List[str] +) -> None: + """ + Preprocess the input files and generate the output file. + + Args: + input_files (List[str]): List of input JSONL file paths. + output_file (str): Path to the output JSONL file. + figure_root (str): Root directory for figure images. + keywords (List[str]): List of keywords to search for in captions. + + Returns + ------- + None + """ + all_processed_data = [] + missing_figures = {} + output_dir = os.path.dirname(output_file) + messages = [] + + # Create a ProcessPoolExecutor to process files in parallel + with ProcessPoolExecutor() as executor: + # Submit jobs with position parameter + future_to_file = { + executor.submit( + process_single_file, input_file, figure_root, keywords, output_dir, i + ): input_file + for i, input_file in enumerate(input_files) + } + + # Use tqdm to track overall progress + overall_pbar = tqdm( + total=len(input_files), + desc="Overall Progress", + position=len(input_files), + leave=True, + ) + + for future in as_completed(future_to_file): + input_file = future_to_file[future] + try: + processed_data, missing_figs, msgs = future.result() + all_processed_data.extend(processed_data) + missing_figures[input_file] = missing_figs + messages.extend(msgs) + overall_pbar.update(1) + except Exception as exc: + print(f"\nException occurred while processing {input_file}: {exc}") + + overall_pbar.close() + + # Print all messages + for msg in messages: + print(msg) + + # Merge all processed data and save to the final output file + save_jsonl(all_processed_data, output_file) + print(f"All processed data merged and saved to {output_file}") + + # Save missing images to a separate JSONL file + missing_figures_file = os.path.join(output_dir, "missing_figures.jsonl") + save_jsonl(missing_figures, missing_figures_file) + print(f"Missing images information saved to {missing_figures_file}") + + +def main(args: argparse.Namespace): + """ + Main function to parse arguments and run the preprocessing pipeline. + + Args: + args (argparse.Namespace): Command-line arguments. + """ + preprocess_data(args.input_files, args.output_file, args.figure_root, args.keywords) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Preprocess JSONL files for figure caption analysis" + ) + parser.add_argument( + "--input_files", + type=str, + nargs="+", + required=True, + help="List of input JSONL files", + ) + parser.add_argument( + "--output_file", type=str, required=True, help="Path to the output JSONL file" + ) + parser.add_argument( + "--figure_root", + type=str, + required=True, + help="Root directory for figure images", + ) + parser.add_argument( + "--keywords", + type=str, + nargs="+", + default=["CT", "pathology", "radiology"], + help="Keywords to search for in captions", + ) + + args = parser.parse_args() + args.keywords = set(kw.lower() for kw in args.keywords) + main(args) diff --git a/openpmcvl/granular/pipeline/preprocess.sh b/openpmcvl/granular/pipeline/preprocess.sh new file mode 100644 index 0000000..a1ef227 --- /dev/null +++ b/openpmcvl/granular/pipeline/preprocess.sh @@ -0,0 +1,52 @@ +#!/bin/bash +# Batch script to preprocess PMC figure-caption pairs + +#SBATCH -c 32 +#SBATCH --partition=cpu +#SBATCH --mem=128GB +#SBATCH --time=12:00:00 +#SBATCH --job-name=preprocess +#SBATCH --output=%x-%j.out + +# Set environment variables: +# VENV_PATH: Path to virtual environment (e.g. export VENV_PATH=$HOME/venv) +# PROJECT_ROOT: Path to project root directory (e.g. export PROJECT_ROOT=$HOME/project) +# PMC_ROOT: Path to PMC dataset directory (e.g. export PMC_ROOT=$HOME/data) + +# Sample command: +# sbatch openpmcvl/granular/pipeline/preprocess.sh 0 1 2 3 4 5 6 7 8 9 10 11 + + +# Activate virtual environment +source $VENV_PATH/bin/activate + +# Set working directory +cd $PROJECT_ROOT + +# Define the paths for the input and output files +INPUT_DIR="$PMC_ROOT" +OUTPUT_FILE="$PMC_ROOT/granular_meta.jsonl" +FIGURE_ROOT="$PMC_ROOT/figures" + +# Check if the number of arguments is provided +if [ $# -eq 0 ]; then + echo "Please provide JSONL numbers as arguments." + exit 1 +fi + +# Get the list of JSONL numbers from the command line arguments +JSONL_NUMBERS="$@" + +# Construct INPUT_FILES string +INPUT_FILES="" +for num in $JSONL_NUMBERS; do + INPUT_FILES+="$INPUT_DIR/$num.jsonl " +done + +# Run the preprocess script +stdbuf -oL -eL srun python3 openpmcvl/granular/pipeline/preprocess.py \ + --input_files $INPUT_FILES \ + --output_file $OUTPUT_FILE \ + --figure_root $FIGURE_ROOT \ + --keywords MRI fMRI CT CAT PET PET-MRI MEG EEG ultrasound X-ray Xray nuclear imaging tracer isotope scan positron EKG spectroscopy radiograph tomography endoscope endoscopy colonoscopy elastography ultrasonic ultrasonography echocardiogram endomicroscopy pancreatoscopy cholangioscopy enteroscopy retroscopy chromoendoscopy sigmoidoscopy cholangiography pancreatography cholangio-pancreatography esophagogastroduodenoscopy radiology pathology histopathology \ + 2>&1 | tee -a %x-%j.out diff --git a/openpmcvl/granular/pipeline/subcaption.ipynb b/openpmcvl/granular/pipeline/subcaption.ipynb new file mode 100644 index 0000000..0fe63b5 --- /dev/null +++ b/openpmcvl/granular/pipeline/subcaption.ipynb @@ -0,0 +1,214 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import os\n", + "\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "\n", + "from openpmcvl.granular.pipeline.utils import load_dataset\n", + "\n", + "\n", + "PMC_ROOT = \"set this directory\"\n", + "\n", + "# Make sure .env file containt OPENAI_API_KEY\n", + "load_dotenv()\n", + "client = OpenAI()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# **Subcaption Extraction**\n", + "\n", + "Extracts subfigure captions from figure captions using OpenAI's GPT-4o Batch API.\n", + "\n", + "## Pipeline\n", + "1. Input: JSONL with metadata (captions + IDs)\n", + "2. Generate batch API requests (50k limit)\n", + "3. Submit to OpenAI batch processing\n", + "4. Get results as structured subcaptions\n", + "5. Save results to JSONL file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "PROMPT = \"\"\"\n", + "Subfigure labels are letters referring to individual subfigures within a larger figure.\n", + "This is a caption: \"%s\"\n", + "Check if the caption contains explicit subfigure label. \n", + "If not, output \"NO\" and end the generation. \n", + "If yes, output \"YES\", then generate the subcaption of the subfigures according to the caption. \n", + "The output should use the template:\n", + " YES\n", + " Subfigure-A: ...\n", + " Subfigure-B: ...\n", + " ...\n", + "The label should be removed from subcaption.\n", + "\"\"\".strip()\n", + "\n", + "caption = \"Try sample caption!\"\n", + "\n", + "\n", + "completion = client.chat.completions.create(\n", + " model=\"gpt-4o-2024-08-06\",\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", + " {\"role\": \"user\", \"content\": PROMPT % caption},\n", + " ],\n", + " temperature=0,\n", + " max_tokens=500,\n", + ")\n", + "\n", + "print(completion.choices[0].message.content)" + ] + }, + { + "cell_type": "code", + "execution_count": 117, + "metadata": {}, + "outputs": [], + "source": [ + "def generate_api_request(custom_id, system_content, user_content):\n", + " \"\"\"Generate a single API request in the required format.\"\"\"\n", + " return {\n", + " \"custom_id\": custom_id,\n", + " \"method\": \"POST\",\n", + " \"url\": \"/v1/chat/completions\",\n", + " \"body\": {\n", + " \"model\": \"gpt-4o-2024-08-06\",\n", + " \"messages\": [\n", + " {\"role\": \"system\", \"content\": system_content},\n", + " {\"role\": \"user\", \"content\": user_content},\n", + " ],\n", + " \"temperature\": 0,\n", + " \"max_tokens\": 2000,\n", + " },\n", + " }\n", + "\n", + "\n", + "def create_prompt(caption):\n", + " \"\"\"Create the prompt template with the given caption.\"\"\"\n", + " return PROMPT.strip() % caption\n", + "\n", + "\n", + "def generate_jsonl(dataset, requests_file):\n", + " \"\"\"Generate JSONL file with API requests.\n", + "\n", + " Args:\n", + " dataset: List of metadata containing captions and IDs\n", + " requests_file: Path to output requests JSONL file\n", + " \"\"\"\n", + " count = 0\n", + "\n", + " # Open output file and write requests line by line\n", + " with open(requests_file, \"w\") as f:\n", + " for data in dataset:\n", + " count += 1\n", + "\n", + " # Skip first 50k entries (already processed)\n", + " if count <= 50000: # Batch API can handle at most 50k requests\n", + " continue\n", + "\n", + " # Only process captions under 400 words\n", + " if len(data[\"caption\"].split()) <= 400:\n", + " # Generate API request for this caption\n", + " request = generate_api_request(\n", + " custom_id=f\"{data['id']}\",\n", + " system_content=\"You are a helpful assistant.\",\n", + " user_content=create_prompt(data[\"caption\"]),\n", + " )\n", + "\n", + " # Write request as JSON line\n", + " f.write(json.dumps(request) + \"\\n\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the metadata dataset containing captions and IDs\n", + "# If you have multiple datasets, you can merge them into one\n", + "dataset = load_dataset(os.path.join(PMC_ROOT, \"meta.jsonl\"))\n", + "\n", + "# Define output path for API requests\n", + "requests_file = os.path.join(PMC_ROOT, \"requests.jsonl\")\n", + "\n", + "# Generate JSONL file with API requests for each caption\n", + "generate_jsonl(dataset, requests_file)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Upload the requests file to OpenAI for batch processing\n", + "batch_input_file = client.files.create(file=open(requests_file, \"rb\"), purpose=\"batch\")\n", + "batch_input_file_id = batch_input_file.id\n", + "\n", + "# Create a batch job to process the requests\n", + "# This will run for up to 24 hours and process 50k subcaptions\n", + "client.batches.create(\n", + " input_file_id=batch_input_file_id,\n", + " endpoint=\"/v1/chat/completions\",\n", + " completion_window=\"24h\",\n", + " metadata={\"description\": \"50k subcaptions\"},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Note you have to run this separately for each submitted batch\n", + "# Check status of first batch job\n", + "print(client.batches.retrieve(\"batch_xxxxx\"))\n", + "\n", + "# Get the output file content from the completed batch\n", + "file_response = client.files.content(\"file-xxxxxx\")\n", + "\n", + "# Write the batch results to a JSONL file\n", + "with open(f\"{PMC_ROOT}/subcaptions.jsonl\", \"w\") as f:\n", + " f.write(file_response.text)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/openpmcvl/granular/pipeline/subcaption.py b/openpmcvl/granular/pipeline/subcaption.py new file mode 100644 index 0000000..f7cfc7e --- /dev/null +++ b/openpmcvl/granular/pipeline/subcaption.py @@ -0,0 +1,149 @@ +import argparse +import re +from sys import stderr +from typing import Dict + +from openai import OpenAI +from tqdm import tqdm + +from openpmcvl.granular.pipeline.utils import load_dataset, save_jsonl + + +PROMPT = """ +Subfigure labels are letters referring to individual subfigures within a larger figure. +Check if the caption contains explicit subfigure label. +If not, output "NO" and end the generation. +If yes, output "YES", then generate the subcaption of the subfigures according to the caption. +The output should use the template: + YES + Subfigure-A: ... + Subfigure-B: ... + ... +The label should be removed from subcaption. +""".strip() + + +def process_caption( + client: OpenAI, system_prompt: str, caption: str, model: str, max_tokens: int +) -> str: + """ + Process a caption using the language model. + + Args: + client (OpenAI): OpenAI client instance. + system_prompt (str): System prompt for the language model. + caption (str): Caption to process. + model (str): Model directory being used. + max_tokens (int): Maximum number of tokens for the model response. + + Returns + ------- + str: Processed caption from the language model. + """ + user_prompt = f"Caption: \n{caption}".strip() + + completion = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + temperature=0, + max_tokens=max_tokens, + ) + + return completion.choices[0].message.content + + +def parse_subcaptions(output: str) -> Dict[str, str]: + """ + Parse the output from the language model into subcaptions. + + Args: + output (str): Output from the language model. + + Returns + ------- + Dict[str, str]: Dictionary of subcaptions, where keys are subfigure labels and values are subcaptions. + """ + lines = output.strip().split("\n") + + if not lines[0].upper().startswith("YES"): + return {"Subfigure-A": "\n".join(lines)} + + subcaptions = {} + current_key = None + current_value = [] + + for line in lines[1:]: # Skip the "YES" line + match = re.match(r"^Subfigure-([A-Za-z]):\s*(.*)", line, re.IGNORECASE) + + if match: + if current_key: + subcaptions[current_key] = " ".join(current_value).strip() + current_key = f"Subfigure-{match.group(1).upper()}" + current_value = [match.group(2)] + elif current_key: + current_value.append(line) + + if current_key: + subcaptions[current_key] = " ".join(current_value).strip() + + return subcaptions + + +def main(args: argparse.Namespace) -> None: + """ + Main function to process captions and save results. + + Args: + args (argparse.Namespace): Command-line arguments. + """ + # Initialize OpenAI client + client = OpenAI() # base_url=args.base_url, api_key="EMPTY" + + # Load dataset + dataset = load_dataset(args.input_file) + print(f"\nDataset size: {len(dataset)}") + + # Inference loop + results = [] + + for item in tqdm( + dataset, desc="Processing captions", total=len(dataset), file=stderr + ): + caption = item["caption"] + + output = process_caption( + client=client, + system_prompt=PROMPT, + caption=caption, + model=args.model, + max_tokens=args.max_tokens, + ) + subcaptions = parse_subcaptions(output) + + item["num_subcaptions"] = len(subcaptions) + item["subcaptions"] = subcaptions + item["llm_output"] = output + + results.append(item) + + save_jsonl(results, args.output_file) + print(f"\nResults saved to {args.output_file}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Process captions into subcaptions") + + parser.add_argument("--input-file", required=True, help="Path to input JSONL file") + parser.add_argument("--output-file", required=True, help="Path to output JSON file") + parser.add_argument( + "--max-tokens", + type=int, + default=500, + help="Maximum number of tokens for API response", + ) + + args = parser.parse_args() + main(args) diff --git a/openpmcvl/granular/pipeline/subcaption.sh b/openpmcvl/granular/pipeline/subcaption.sh new file mode 100644 index 0000000..2eebdf5 --- /dev/null +++ b/openpmcvl/granular/pipeline/subcaption.sh @@ -0,0 +1,44 @@ +#!/bin/bash +# Batch script to extract subcaptions from figure captions using GPT API + +#SBATCH -c 6 +#SBATCH --partition=cpu +#SBATCH --mem=32GB +#SBATCH --time=8:00:00 +#SBATCH --job-name=subcaption +#SBATCH --output=%x-%j.out + +# Set environment variables: +# VENV_PATH: Path to virtual environment (e.g. export VENV_PATH=$HOME/venv) +# PROJECT_ROOT: Path to project root directory (e.g. export PROJECT_ROOT=$HOME/project) +# PMC_ROOT: Path to PMC dataset directory (e.g. export PMC_ROOT=$HOME/data) + +# Sample command: +# sbatch openpmcvl/granular/pipeline/subcaption.sh 0 1 2 3 4 5 6 7 8 9 10 11 + +# Activate virtual environment +source $VENV_PATH/bin/activate + +# Set working directory +cd $PROJECT_ROOT + +# Check if the number of arguments is provided +if [ $# -eq 0 ]; then + echo "Please provide JSONL numbers as arguments." + exit 1 +fi + +# Get the list of JSONL numbers from the command line arguments +JSONL_NUMBERS="$@" + +# Iterate over each JSONL number +for num in $JSONL_NUMBERS; do + # Run the subcaption script + stdbuf -oL -eL srun python3 openpmcvl/granular/pipeline/subcaption.py \ + --input-file "$PMC_ROOT/${num}_meta.jsonl" \ + --output-file "$PMC_ROOT/${num}_subcaptions.jsonl" \ + --max-tokens 500 \ + 2>&1 | tee -a %x-%j.out + + echo "Finished processing ${num}" +done diff --git a/openpmcvl/granular/pipeline/subfigure.py b/openpmcvl/granular/pipeline/subfigure.py new file mode 100644 index 0000000..0795c8a --- /dev/null +++ b/openpmcvl/granular/pipeline/subfigure.py @@ -0,0 +1,300 @@ +import argparse +import os +from pathlib import Path +from typing import List, Tuple + +import numpy as np +import torch +from torch.utils.data import DataLoader +from torchvision import utils as vutils +from tqdm import tqdm + +from openpmcvl.granular.dataset.dataset import ( + Fig_Separation_Dataset, + fig_separation_collate, +) +from openpmcvl.granular.models.subfigure_detector import FigCap_Former +from openpmcvl.granular.pipeline.utils import ( + box_cxcywh_to_xyxy, + find_jaccard_overlap, + save_jsonl, +) + + +def load_dataset(eval_file: str, batch_size: int, num_workers: int) -> DataLoader: + """ + Prepares the dataset and returns a DataLoader. + + Args: + eval_file (str): Path to the evaluation dataset file + batch_size (int): Batch size for the DataLoader + num_workers (int): Number of workers for the DataLoader + + Returns + ------- + DataLoader: Configured DataLoader for the separation dataset + """ + dataset = Fig_Separation_Dataset( + filepath=eval_file, normalization=False, only_medical=True + ) + print(f"\nDataset size: {len(dataset)}\n") + return DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + collate_fn=fig_separation_collate, + pin_memory=True, + ) + + +def load_separation_model(checkpoint_path: str, device: torch.device) -> FigCap_Former: + """ + Loads the FigCap_Former model from a checkpoint. + + Args: + checkpoint_path (str): Path to the model checkpoint + device (torch.device): Device to use for processing + + Returns + ------- + FigCap_Former: Loaded model + """ + model = FigCap_Former( + num_query=32, + num_encoder_layers=4, + num_decoder_layers=4, + num_text_decoder_layers=4, + bert_path="bert-base-uncased", + alignment_network=False, + resnet_pretrained=False, + resnet=34, + feature_dim=256, + atten_head_num=8, + text_atten_head_num=8, + mlp_ratio=4, + dropout=0.0, + activation="relu", + text_mlp_ratio=4, + text_dropout=0.0, + text_activation="relu", + ) + checkpoint = torch.load(checkpoint_path, map_location="cpu") + model.load_state_dict(checkpoint["model_state_dict"], strict=False) + + model.eval() + model.to(device) + + return model + + +def process_detections( + det_boxes: torch.Tensor, det_scores: np.ndarray, nms_threshold: float +) -> Tuple[List[List[float]], List[float]]: + """ + Processes detections using Non-Maximum Suppression (NMS). + + Args: + det_boxes (torch.Tensor): Detected bounding boxes + det_scores (np.ndarray): Confidence scores for detections + nms_threshold (float): IoU threshold for NMS + + Returns + ------- + Tuple[List[List[float]], List[float]]: Picked bounding boxes and their scores + """ + order = np.argsort(det_scores) + picked_bboxes = [] + picked_scores = [] + while order.size > 0: + index = order[-1] + picked_bboxes.append(det_boxes[index].tolist()) + picked_scores.append(det_scores[index]) + if order.size == 1: + break + iou_with_left = ( + find_jaccard_overlap( + box_cxcywh_to_xyxy(det_boxes[index]), + box_cxcywh_to_xyxy(det_boxes[order[:-1]]), + ) + .squeeze() + .numpy() + ) + left = np.where(iou_with_left < nms_threshold) + order = order[left] + return picked_bboxes, picked_scores + + +def separate_subfigures( + model: FigCap_Former, + loader: DataLoader, + save_path: str, + rcd_file: str, + score_threshold: float, + nms_threshold: float, + device: torch.device, +) -> None: + """ + Separates compound figures into subfigures and classifies them. + + Args: + model (FigCap_Former): Loaded model for subfigure detection + loader (DataLoader): DataLoader for the dataset + save_path (str): Path to save separated subfigures + rcd_file (str): File to record separation results + score_threshold (float): Confidence score threshold for detections + nms_threshold (float): IoU threshold for NMS + device (torch.device): Device to use for processing + """ + Path(save_path).mkdir(parents=True, exist_ok=True) + subfig_list = [] + failed_subfig_list = [] + subfig_count = 0 + + print("Separating subfigures...") + for batch in tqdm(loader, desc=f"File: {rcd_file}", total=len(loader)): + image = batch["image"].to(device) + img_ids = batch["image_id"] + original_images = batch["original_image"] + unpadded_hws = batch["unpadded_hws"] + + output_det_class, output_box, _ = model(image, None) + + output_box = output_box.cpu() + output_det_class = output_det_class.cpu() + filter_mask = output_det_class.squeeze() > score_threshold + + for i in range(image.shape[0]): + det_boxes = output_box[i, filter_mask[i, :], :] + det_scores = output_det_class.squeeze()[i, filter_mask[i, :]].numpy() + img_id = img_ids[i].split(".jpg")[0] + unpadded_image = original_images[i] + original_h, original_w = unpadded_hws[i] + + scale = max(original_h, original_w) / 512 + + picked_bboxes, picked_scores = process_detections( + det_boxes, det_scores, nms_threshold + ) + + for bbox, score in zip(picked_bboxes, picked_scores): + try: + subfig_path = f"{save_path}/{img_id}_{subfig_count}.jpg" + cx, cy, w, h = bbox + + # Calculate padding in terms of bounding box dimensions + pad_ratio = 0.01 + pad_w = w * pad_ratio + pad_h = h * pad_ratio + + # Adjust the coordinates with padding + x1 = round((cx - w / 2 - pad_w) * image.shape[3] * scale) + x2 = round((cx + w / 2 + pad_w) * image.shape[3] * scale) + y1 = round((cy - h / 2 - pad_h) * image.shape[2] * scale) + y2 = round((cy + h / 2 + pad_h) * image.shape[2] * scale) + + # Ensure the coordinates are within image boundaries + x1, x2 = [max(0, min(x, original_w - 1)) for x in [x1, x2]] + y1, y2 = [max(0, min(y, original_h - 1)) for y in [y1, y2]] + + subfig = unpadded_image[:, y1:y2, x1:x2].detach().cpu() + vutils.save_image(subfig, subfig_path) + + subfig_list.append( + { + "id": f"{img_id}_{subfig_count}.jpg", + "source_fig_id": img_id, + "PMC_ID": img_id.split("_")[0], + "media_name": f"{img_id}.jpg", + "position": [(x1, y1), (x2, y2)], + "score": score.item(), + "subfig_path": subfig_path, + } + ) + subfig_count += 1 + except ValueError: + print( + f"Crop Error: [x1 x2 y1 y2]:[{x1} {x2} {y1} {y2}], w:{original_w}, h:{original_h}" + ) + failed_subfig_list.append( + { + "id": f"{img_id}_{subfig_count}.jpg", + "source_fig_id": img_id, + "PMC_ID": img_id.split("_")[0], + "media_name": f"{img_id}.jpg", + } + ) + continue + + save_jsonl(subfig_list, rcd_file) + save_jsonl(failed_subfig_list, f"{rcd_file.split('.')[0]}_failed.jsonl") + + +def main(args: argparse.Namespace) -> None: + """ + Main function to process images and save results. + + Args: + args (argparse.Namespace): Command-line arguments. + """ + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + torch.set_grad_enabled(False) + + model = load_separation_model(args.separation_model, device) + dataloader = load_dataset(args.eval_file, args.batch_size, args.num_workers) + separate_subfigures( + model=model, + loader=dataloader, + save_path=args.save_path, + rcd_file=args.rcd_file, + score_threshold=args.score_threshold, + nms_threshold=args.nms_threshold, + device=device, + ) + print("\nSubfigure separation completed.\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Subfigure Separation Script") + + parser.add_argument( + "--separation_model", + type=str, + required=True, + help="Path to subfigure detection model checkpoint", + ) + parser.add_argument( + "--eval_file", type=str, required=True, help="Path to evaluation dataset file" + ) + parser.add_argument( + "--save_path", + type=str, + required=True, + help="Path to save separated subfigures", + ) + parser.add_argument( + "--rcd_file", + type=str, + required=True, + help="File to record separation results", + ) + parser.add_argument( + "--score_threshold", + type=float, + default=0.75, + help="Confidence score threshold for detections", + ) + parser.add_argument( + "--nms_threshold", type=float, default=0.4, help="IoU threshold for NMS" + ) + parser.add_argument( + "--batch_size", type=int, default=128, help="Batch size for evaluation" + ) + parser.add_argument( + "--num_workers", type=int, default=8, help="Number of workers for data loading" + ) + parser.add_argument("--gpu", type=str, default="0", help="GPU to use") + + args = parser.parse_args() + main(args) diff --git a/openpmcvl/granular/pipeline/subfigure.sh b/openpmcvl/granular/pipeline/subfigure.sh new file mode 100644 index 0000000..7940747 --- /dev/null +++ b/openpmcvl/granular/pipeline/subfigure.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# Batch script to extract subfigures from compound figures using a detection model + +#SBATCH -c 12 +#SBATCH --gres=gpu:1 +#SBATCH --partition=a40 +#SBATCH --mem=100GB +#SBATCH --time=15:00:00 +#SBATCH --job-name=subfigure +#SBATCH --output=%x-%j.out + +# Set environment variables: +# VENV_PATH: Path to virtual environment (e.g. export VENV_PATH=$HOME/venv) +# PROJECT_ROOT: Path to project root directory (e.g. export PROJECT_ROOT=$HOME/project) +# PMC_ROOT: Path to PMC dataset directory (e.g. export PMC_ROOT=$HOME/data) + +# Sample command: +# sbatch openpmcvl/granular/pipeline/subfigure.sh 0 1 2 3 4 5 6 7 8 9 10 11 + +# Activate virtual environment +source $VENV_PATH/bin/activate + +# Set working directory +cd $PROJECT_ROOT + +# Check if the number of arguments is provided +if [ $# -eq 0 ]; then + echo "Please provide JSONL numbers as arguments." + exit 1 +fi + +# Get the list of JSONL numbers from the command line arguments +JSONL_NUMBERS="$@" + +# Iterate over each JSONL number +for num in $JSONL_NUMBERS; do + # Define the paths for the evaluation file and the record file + eval_file="$PMC_ROOT/${num}_meta.jsonl" + rcd_file="$PMC_ROOT/${num}_subfigures.jsonl" + + # Run the subfigure separation script + stdbuf -oL -eL srun python3 openpmcvl/granular/pipeline/subfigure.py \ + --separation_model openpmcvl/granular/checkpoints/subfigure_detector.pth \ + --eval_file "$eval_file" \ + --save_path "$PMC_ROOT/${num}_subfigures" \ + --rcd_file "$rcd_file" \ + --score_threshold 0.5 \ + --nms_threshold 0.4 \ + --batch_size 128 \ + --num_workers 8 \ + --gpu 0 + + # Print a message indicating the completion of processing for the current JSONL number + echo "Finished processing ${num}" +done diff --git a/openpmcvl/granular/pipeline/utils.py b/openpmcvl/granular/pipeline/utils.py new file mode 100644 index 0000000..3219b37 --- /dev/null +++ b/openpmcvl/granular/pipeline/utils.py @@ -0,0 +1,100 @@ +import json +from typing import Any, Dict, List + +import torch + + +def load_dataset(file_path: str, num_datapoints: int = -1) -> List[Dict[str, Any]]: + """ + Load dataset from a JSONL file. + + Args: + file_path (str): Path to the input JSONL file. + num_datapoints (int): Number of datapoints to load. If -1, load all datapoints. + + Returns + ------- + List[Dict[str, Any]]: List of dictionaries, each representing an item in the dataset. + """ + with open(file_path, "r") as f: + dataset = [json.loads(line) for line in f] + return dataset[:num_datapoints] if num_datapoints > 0 else dataset + + +def save_jsonl(data: List[Dict[str, Any]], file_path: str) -> None: + """ + Save data to a JSONL (JSON Lines) file. + + This function takes a list of dictionaries and writes each dictionary as a separate JSON object + on a new line in the specified file. This format is known as JSONL (JSON Lines). + + Args: + data (List[Dict[str, Any]]): A list of dictionaries to be saved. Each dictionary + represents a single data point or record. + file_path (str): The path to the output file where the data will be saved. + """ + with open(file_path, "w") as f: + for item in data: + json.dump(item, f) + f.write("\n") + + +def box_cxcywh_to_xyxy(x): + """ + Convert bounding box coordinates from (center_x, center_y, width, height) to (x1, y1, x2, y2) format. + + Args: + x (torch.Tensor): Input tensor of shape (..., 4) containing bounding box coordinates in (cx, cy, w, h) format. + + Returns + ------- + torch.Tensor: Tensor of shape (..., 4) containing bounding box coordinates in (x1, y1, x2, y2) format. + """ + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +def find_intersection(set_1, set_2): + """ + Find the intersection of every box combination between two sets of boxes that are in boundary coordinates. + + Args: + set_1 (torch.Tensor): Set 1, a tensor of dimensions (n1, 4) -- (x1, y1, x2, y2) + set_2 (torch.Tensor): Set 2, a tensor of dimensions (n2, 4) + + Returns + ------- + torch.Tensor: Intersection of each of the boxes in set 1 with respect to each of the boxes in set 2, a tensor of dimensions (n1, n2) + """ + lower_bounds = torch.max(set_1[:, :2].unsqueeze(1), set_2[:, :2].unsqueeze(0)) + upper_bounds = torch.min(set_1[:, 2:].unsqueeze(1), set_2[:, 2:].unsqueeze(0)) + intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0) + return intersection_dims[:, :, 0] * intersection_dims[:, :, 1] + + +def find_jaccard_overlap(set_1, set_2): + """ + Find the Jaccard Overlap (IoU) of every box combination between two sets of boxes that are in boundary coordinates. + + Args: + set_1 (torch.Tensor): Set 1, a tensor of dimensions (n1, 4) + set_2 (torch.Tensor): Set 2, a tensor of dimensions (n2, 4) + + Returns + ------- + torch.Tensor: Jaccard Overlap of each of the boxes in set 1 with respect to each of the boxes in set 2, a tensor of dimensions (n1, n2) + """ + if set_1.dim() == 1 and set_1.shape[0] == 4: + set_1 = set_1.unsqueeze(0) + if set_2.dim() == 1 and set_2.shape[0] == 4: + set_2 = set_2.unsqueeze(0) + + intersection = find_intersection(set_1, set_2) + + areas_set_1 = (set_1[:, 2] - set_1[:, 0]) * (set_1[:, 3] - set_1[:, 1]) + areas_set_2 = (set_2[:, 2] - set_2[:, 0]) * (set_2[:, 3] - set_2[:, 1]) + + union = areas_set_1.unsqueeze(1) + areas_set_2.unsqueeze(0) - intersection + + return intersection / union