diff --git a/monailabel/datastore/cvat.py b/monailabel/datastore/cvat.py
index 32c8c2020..461883ea0 100644
--- a/monailabel/datastore/cvat.py
+++ b/monailabel/datastore/cvat.py
@@ -225,7 +225,7 @@ def download_from_cvat(self, max_retry_count=5, retry_wait_time=10):
         task_id, task_name = self.get_cvat_task_id(project_id, create=False)
         logger.info(f"Preparing to download/update final labels from: {project_id} => {task_id} => {task_name}")
 
-        # Step 1: Initiate export process
+        # Step 1: Initiate export process using the new POST endpoint.
         export_url = f"{self.api_url}/api/tasks/{task_id}/dataset/export?format=Segmentation+mask+1.1&location=local&save_images=false"
         try:
             response = requests.post(export_url, auth=self.auth)
@@ -242,29 +242,35 @@ def download_from_cvat(self, max_retry_count=5, retry_wait_time=10):
             logger.exception(f"Error while initiating export process: {e}")
             return None
 
-        # Step 2: Poll export status
+        # Step 2: Poll export status using the new GET endpoint.
         status_url = f"{self.api_url}/api/requests/{rq_id}"
         for _ in range(max_retry_count):
             try:
                 status_response = requests.get(status_url, auth=self.auth)
-                status = status_response.json().get("status")
-                if status == "finished":
+                status_data = status_response.json()
+                current_status = status_data.get("status")
+                if current_status == "finished":
                     logger.info("Export process completed successfully.")
                     break
-                elif status == "failed":
-                    logger.error(f"Export process failed: {status_response.json()}")
+                elif current_status == "failed":
+                    logger.error(f"Export process failed: {status_data}")
                     return None
                 logger.info(f"Export in progress... Retrying in {retry_wait_time} seconds.")
                 time.sleep(retry_wait_time)
             except Exception as e:
                 logger.exception(f"Error checking export status: {e}")
-                time.sleep(retry_wait_time)
+                time.sleep(retry_wait_time)
         else:
             logger.error("Export process did not complete within the maximum retries.")
             return None
 
-        # Step 3: Download the dataset
-        download_url = f"{self.api_url}/api/tasks/{task_id}/annotations?format=Segmentation+mask+1.1&location=local&save_images=false&action=download"
+        # Step 3: Retrieve the download URL from the export status.
+        result_url = status_data.get("result_url")
+        if not result_url:
+            logger.error("Export process finished but no result_url was provided.")
+            return None
+
+        # Step 4: Download the ZIP file from the result_url.
         tmp_folder = tempfile.TemporaryDirectory().name
         os.makedirs(tmp_folder, exist_ok=True)
 
@@ -272,37 +278,35 @@ def download_from_cvat(self, max_retry_count=5, retry_wait_time=10):
         retry_count = 0
         for retry in range(max_retry_count):
             try:
-                logger.info(f"Downloading exported dataset from: {download_url}")
-                r = requests.get(download_url, allow_redirects=True, auth=self.auth)
-
+                logger.info(f"Downloading exported dataset from: {result_url}")
+                r = requests.get(result_url, allow_redirects=True, auth=self.auth)
                 with open(tmp_zip, "wb") as fp:
                     fp.write(r.content)
                 shutil.unpack_archive(tmp_zip, tmp_folder)
 
+                # Process the segmentation files
                 segmentations_dir = os.path.join(tmp_folder, "SegmentationClass")
                 final_labels = self._datastore.label_path(DefaultLabelTag.FINAL)
                 for f in os.listdir(segmentations_dir):
                     label = os.path.join(segmentations_dir, f)
                     if os.path.isfile(label) and label.endswith(".png"):
                         os.makedirs(final_labels, exist_ok=True)
-
                         dest = os.path.join(final_labels, f)
                         if self.normalize_label:
                             img = np.array(Image.open(label))
                             mask = np.zeros_like(img)
-
                             labelmap = self._load_labelmap_txt(os.path.join(tmp_folder, "labelmap.txt"))
                             for name, color in labelmap.items():
                                 if name in self.label_map:
                                     idx = self.label_map.get(name)
                                     mask[np.all(img == color, axis=-1)] = idx
-
-                            Image.fromarray(mask[:, :, 0]).save(dest)  # single channel
-                            logger.info(f"Copy Final Label: {label} to {dest}; unique: {np.unique(mask)}")
+                            Image.fromarray(mask[:, :, 0]).save(dest)
+                            logger.info(f"Copied Final Label: {label} to {dest}; unique: {np.unique(mask)}")
                         else:
                             Image.open(label).save(dest)
-                            logger.info(f"Copy Final Label: {label} to {dest}")
+                            logger.info(f"Copied Final Label: {label} to {dest}")
 
-                # Rename task after consuming/downloading the labels
+                # Rename the task to indicate that labels have been processed.
                 patch_url = f"{self.api_url}/api/tasks/{task_id}"
                 body = {"name": f"{self.done_prefix}_{task_name}"}
                 requests.patch(patch_url, allow_redirects=True, auth=self.auth, json=body)
@@ -311,7 +315,7 @@ def download_from_cvat(self, max_retry_count=5, retry_wait_time=10):
                 if retry_count:
                     logger.exception(e)
                 logger.error(f"{retry} => Failed to download...")
-                retry_count = retry_count + 1
+                retry_count += 1
 
         return None
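
For reference, a minimal standalone sketch of the export flow this patch adopts: initiate the export with a POST, poll the request status, then download the archive from the returned result_url. The endpoints and query parameters are taken from the diff above; CVAT_URL, AUTH, and TASK_ID are placeholder values, the rq_id is assumed to come back in the POST response body as implied by the patch, and retry/error handling is omitted.

import time

import requests

CVAT_URL = "http://localhost:8080"  # placeholder CVAT server address
AUTH = ("user", "password")         # placeholder basic-auth credentials
TASK_ID = 1                         # placeholder task id

# Step 1: Initiate the export; the JSON response is assumed to carry the request id (rq_id).
export_url = (
    f"{CVAT_URL}/api/tasks/{TASK_ID}/dataset/export"
    "?format=Segmentation+mask+1.1&location=local&save_images=false"
)
rq_id = requests.post(export_url, auth=AUTH).json().get("rq_id")

# Step 2: Poll the request status until it reports "finished" (or give up).
result_url = None
for _ in range(5):
    status_data = requests.get(f"{CVAT_URL}/api/requests/{rq_id}", auth=AUTH).json()
    if status_data.get("status") == "finished":
        result_url = status_data.get("result_url")
        break
    if status_data.get("status") == "failed":
        break
    time.sleep(10)

# Steps 3 and 4: Download the exported ZIP from the result_url, if one was provided.
if result_url:
    with open("dataset.zip", "wb") as fp:
        fp.write(requests.get(result_url, allow_redirects=True, auth=AUTH).content)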