Commit b264e7c

Author: Donglai Wei (committed)
fix formatting and e2e testing error
1 parent 7845f39 commit b264e7c

File tree: 13 files changed (+85, -67 lines)

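Note: most of the whitespace edits below add spaces around slice colons whose bounds are expressions, matching the PEP 8 recommendation (and the behavior of formatters such as Black) of treating the slice colon like a low-priority binary operator; the remaining hunks just collapse calls that now fit on one line. An illustrative sketch of the slice rule (variable names are made up, not from the diff):

    window = data[y0:y1]                                  # simple bounds: no spaces
    window = data[y0 : y0 + height, x0 : x0 + width]      # expression bounds: spaced colon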

connectomics/data/augment/monai_transforms.py

Lines changed: 20 additions & 20 deletions
@@ -99,11 +99,11 @@ def _apply_misalignment_translation(

         output = np.zeros(out_shape, img.dtype)
         if mode == "slip":
-            output = img[:, y0:y0 + out_shape[1], x0:x0 + out_shape[2]]
-            output[idx] = img[idx, y1:y1 + out_shape[1], x1:x1 + out_shape[2]]
+            output = img[:, y0 : y0 + out_shape[1], x0 : x0 + out_shape[2]]
+            output[idx] = img[idx, y1 : y1 + out_shape[1], x1 : x1 + out_shape[2]]
         else:
-            output[:idx] = img[:idx, y0:y0 + out_shape[1], x0:x0 + out_shape[2]]
-            output[idx:] = img[idx:, y1:y1 + out_shape[1], x1:x1 + out_shape[2]]
+            output[:idx] = img[:idx, y0 : y0 + out_shape[1], x0 : x0 + out_shape[2]]
+            output[idx:] = img[idx:, y1 : y1 + out_shape[1], x1 : x1 + out_shape[2]]

         if is_tensor:
             output = torch.from_numpy(output).to(device)

@@ -299,7 +299,7 @@ def _apply_missing_parts(
         x_start = self.R.randint(0, img.shape[2] - hole_w + 1)

         # Create hole (set to 0 or mean value)
-        img[section_idx, y_start: y_start + hole_h, x_start: x_start + hole_w] = 0
+        img[section_idx, y_start : y_start + hole_h, x_start : x_start + hole_w] = 0

         return img

@@ -452,24 +452,24 @@ def _apply_cut_noise(
                 noise = self.R.uniform(-self.noise_scale, self.noise_scale, noise_shape)
                 region = img[
                     :,
-                    z_start: z_start + z_len,
-                    y_start: y_start + y_len,
-                    x_start: x_start + x_len,
+                    z_start : z_start + z_len,
+                    y_start : y_start + y_len,
+                    x_start : x_start + x_len,
                 ]
                 noisy_region = np.clip(region + noise, 0, 1)
                 img[
                     :,
-                    z_start: z_start + z_len,
-                    y_start: y_start + y_len,
-                    x_start: x_start + x_len,
+                    z_start : z_start + z_len,
+                    y_start : y_start + y_len,
+                    x_start : x_start + x_len,
                 ] = noisy_region
             else:
                 # (C, H, W) - 2D with channels
                 noise_shape = (img.shape[0], y_len, x_len)
                 noise = self.R.uniform(-self.noise_scale, self.noise_scale, noise_shape)
-                region = img[:, y_start: y_start + y_len, x_start: x_start + x_len]
+                region = img[:, y_start : y_start + y_len, x_start : x_start + x_len]
                 noisy_region = np.clip(region + noise, 0, 1)
-                img[:, y_start: y_start + y_len, x_start: x_start + x_len] = noisy_region
+                img[:, y_start : y_start + y_len, x_start : x_start + x_len] = noisy_region
         elif img.ndim == 3:
             # 3D case: (Z, Y, X) or (C, H, W)
             # Heuristic: if first dim is small (<=4), assume it's channel (2D with channels)

@@ -478,29 +478,29 @@ def _apply_cut_noise(
                 # (C, H, W) - 2D with channels
                 noise_shape = (img.shape[0], y_len, x_len)
                 noise = self.R.uniform(-self.noise_scale, self.noise_scale, noise_shape)
-                region = img[:, y_start: y_start + y_len, x_start: x_start + x_len]
+                region = img[:, y_start : y_start + y_len, x_start : x_start + x_len]
                 noisy_region = np.clip(region + noise, 0, 1)
-                img[:, y_start: y_start + y_len, x_start: x_start + x_len] = noisy_region
+                img[:, y_start : y_start + y_len, x_start : x_start + x_len] = noisy_region
             else:
                 # (Z, Y, X) - 3D
                 z_len = max(1, int(self.length_ratio * img.shape[0])) # Ensure at least 1
                 z_start = self.R.randint(0, max(1, img.shape[0] - z_len + 1))
                 noise_shape = (z_len, y_len, x_len)
                 noise = self.R.uniform(-self.noise_scale, self.noise_scale, noise_shape)
                 region = img[
-                    z_start: z_start + z_len, y_start: y_start + y_len, x_start: x_start + x_len
+                    z_start : z_start + z_len, y_start : y_start + y_len, x_start : x_start + x_len
                 ]
                 noisy_region = np.clip(region + noise, 0, 1)
                 img[
-                    z_start: z_start + z_len, y_start: y_start + y_len, x_start: x_start + x_len
+                    z_start : z_start + z_len, y_start : y_start + y_len, x_start : x_start + x_len
                 ] = noisy_region
         else:
             # 2D case: (H, W)
             noise_shape = (y_len, x_len)
             noise = self.R.uniform(-self.noise_scale, self.noise_scale, noise_shape)
-            region = img[y_start: y_start + y_len, x_start: x_start + x_len]
+            region = img[y_start : y_start + y_len, x_start : x_start + x_len]
             noisy_region = np.clip(region + noise, 0, 1)
-            img[y_start: y_start + y_len, x_start: x_start + x_len] = noisy_region
+            img[y_start : y_start + y_len, x_start : x_start + x_len] = noisy_region

         if is_tensor:
             img = torch.from_numpy(img).to(device)

@@ -886,7 +886,7 @@ def _find_best_paste(
             neuron_tensor.flip(0) if neuron_tensor.ndim == 3 else neuron_tensor.flip(1)
         )

-        label_paste = labels[best_idx: best_idx + 1]
+        label_paste = labels[best_idx : best_idx + 1]

         if best_angle != 0:
             label_paste = self._rotate_3d(label_paste, best_angle)
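
The hunks above are formatting-only; whitespace inside a subscript does not change which elements are selected. A quick sanity check with a dummy array (shapes are made up, not from the test suite):

    import numpy as np

    img = np.arange(4 * 8 * 8).reshape(4, 8, 8)
    y0, x0, out_shape = 1, 2, (4, 5, 5)
    old_style = img[:, y0:y0 + out_shape[1], x0:x0 + out_shape[2]]
    new_style = img[:, y0 : y0 + out_shape[1], x0 : x0 + out_shape[2]]
    assert np.array_equal(old_style, new_style)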

connectomics/data/io/io.py

Lines changed: 3 additions & 3 deletions
@@ -171,7 +171,7 @@ def read_image_as_volume(filename: str, drop_channel: bool = False) -> np.ndarra
     Raises:
         ValueError: If file format is not supported
     """
-    image_suffix = filename[filename.rfind(".") + 1:].lower()
+    image_suffix = filename[filename.rfind(".") + 1 :].lower()
     if image_suffix not in SUPPORTED_IMAGE_FORMATS:
         raise ValueError(
             f"Unsupported format: {image_suffix}. Supported formats: {SUPPORTED_IMAGE_FORMATS}"

@@ -281,7 +281,7 @@ def read_volume(
     if filename.endswith(".nii.gz"):
         image_suffix = "nii.gz"
     else:
-        image_suffix = filename[filename.rfind(".") + 1:].lower()
+        image_suffix = filename[filename.rfind(".") + 1 :].lower()

     if image_suffix in ["h5", "hdf5"]:
         data = read_hdf5(filename, dataset)

@@ -420,7 +420,7 @@ def get_vol_shape(filename: str, dataset: Optional[str] = None) -> tuple:
     if filename.endswith(".nii.gz"):
         image_suffix = "nii.gz"
     else:
-        image_suffix = filename[filename.rfind(".") + 1:].lower()
+        image_suffix = filename[filename.rfind(".") + 1 :].lower()

     if image_suffix in ["h5", "hdf5"]:
         # HDF5: Read shape from metadata (no data loading)
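
For context, the reformatted expression takes everything after the last dot and lower-cases it; read_volume and get_vol_shape special-case ".nii.gz" first because rfind alone would yield only "gz". A standalone sketch of that logic (guess_suffix and the filenames are hypothetical):

    def guess_suffix(filename: str) -> str:
        # mirrors the suffix logic shown in the diff above
        if filename.endswith(".nii.gz"):
            return "nii.gz"
        return filename[filename.rfind(".") + 1 :].lower()

    assert guess_suffix("volume.H5") == "h5"
    assert guess_suffix("scan.nii.gz") == "nii.gz"
    assert guess_suffix("slice.TIFF") == "tiff"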

connectomics/data/io/tiles.py

Lines changed: 8 additions & 8 deletions
@@ -160,22 +160,22 @@ def reconstruct_volume_from_tiles(
             if is_image: # Image data
                 result[
                     z - z0,
-                    y_actual_start - y0: y_actual_end - y0,
-                    x_actual_start - x0: x_actual_end - x0,
+                    y_actual_start - y0 : y_actual_end - y0,
+                    x_actual_start - x0 : x_actual_end - x0,
                 ] = patch[
-                    y_actual_start - y_patch_start: y_actual_end - y_patch_start,
-                    x_actual_start - x_patch_start: x_actual_end - x_patch_start,
+                    y_actual_start - y_patch_start : y_actual_end - y_patch_start,
+                    x_actual_start - x_patch_start : x_actual_end - x_patch_start,
                     0,
                 ]
             else: # Label data
                 result[
                     z - z0,
-                    y_actual_start - y0: y_actual_end - y0,
-                    x_actual_start - x0: x_actual_end - x0,
+                    y_actual_start - y0 : y_actual_end - y0,
+                    x_actual_start - x0 : x_actual_end - x0,
                 ] = rgb_to_seg(
                     patch[
-                        y_actual_start - y_patch_start: y_actual_end - y_patch_start,
-                        x_actual_start - x_patch_start: x_actual_end - x_patch_start,
+                        y_actual_start - y_patch_start : y_actual_end - y_patch_start,
+                        x_actual_start - x_patch_start : x_actual_end - x_patch_start,
                     ]
                 )

connectomics/data/process/crop.py

Lines changed: 6 additions & 6 deletions
@@ -16,9 +16,9 @@ def crop_volume(data, sz, st=(0, 0, 0)):
     st = np.array(st).astype(np.int32)

     if data.ndim == 3:
-        return data[st[0]:st[0] + sz[0], st[1]:st[1] + sz[1], st[2]:st[2] + sz[2]]
+        return data[st[0] : st[0] + sz[0], st[1] : st[1] + sz[1], st[2] : st[2] + sz[2]]
     else: # crop spatial dimensions
-        return data[:, st[0]:st[0] + sz[0], st[1]:st[1] + sz[1], st[2]:st[2] + sz[2]]
+        return data[:, st[0] : st[0] + sz[0], st[1] : st[1] + sz[1], st[2] : st[2] + sz[2]]


 def get_valid_pos_torch(mask, vol_sz, valid_ratio):

@@ -64,9 +64,9 @@ def get_valid_pos(mask, vol_sz, valid_ratio):
     if len(vol_sz) == 3:
         mask_sum = (
             mask_sum[
-                pad_sz_pre[0]:pad_sz_post[0],
-                pad_sz_pre[1]:pad_sz_post[1],
-                pad_sz_pre[2]:pad_sz_post[2],
+                pad_sz_pre[0] : pad_sz_post[0],
+                pad_sz_pre[1] : pad_sz_post[1],
+                pad_sz_pre[2] : pad_sz_post[2],
             ]
             >= valid_thres
         )

@@ -86,7 +86,7 @@ def get_valid_pos(mask, vol_sz, valid_ratio):
         )
     else:
         mask_sum = (
-            mask_sum[pad_sz_pre[0]:pad_sz_post[0], pad_sz_pre[1]:pad_sz_post[1]] >= valid_thres
+            mask_sum[pad_sz_pre[0] : pad_sz_post[0], pad_sz_pre[1] : pad_sz_post[1]] >= valid_thres
         )
     if mask_sum.max() > 0:
         yy, xx = np.meshgrid(np.arange(mask_sum.shape[0]), np.arange(mask_sum.shape[1]))
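
crop_volume slices a sz-shaped block starting at offset st, over all three axes for 3D input or over the spatial axes of a channel-first 4D array. A usage sketch of the 3D branch shown above (array sizes are made up):

    import numpy as np

    def crop_volume(data, sz, st=(0, 0, 0)):
        # mirrors only the 3D branch from the diff above
        return data[st[0] : st[0] + sz[0], st[1] : st[1] + sz[1], st[2] : st[2] + sz[2]]

    vol = np.zeros((64, 128, 128))
    patch = crop_volume(vol, sz=(16, 32, 32), st=(8, 8, 8))
    assert patch.shape == (16, 32, 32)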

connectomics/decoding/optuna_tuner.py

Lines changed: 2 additions & 6 deletions
@@ -355,9 +355,7 @@ def _objective(self, trial: optuna.Trial) -> float:

                 print(
                     "\n❌ Trial {count} failed during post-processing "
-                    "(volume {vol_idx}):".format(
-                        count=self.trial_count, vol_idx=vol_idx
-                    )
+                    "(volume {vol_idx}):".format(count=self.trial_count, vol_idx=vol_idx)
                 )
                 print(f" Parameters: {postproc_params}")
                 print(f" Error: {e}")

@@ -381,9 +379,7 @@ def _objective(self, trial: optuna.Trial) -> float:

                 print(
                     "\n❌ Trial {count} failed during metric computation "
-                    "(volume {vol_idx}):".format(
-                        count=self.trial_count, vol_idx=vol_idx
-                    )
+                    "(volume {vol_idx}):".format(count=self.trial_count, vol_idx=vol_idx)
                 )
                 print(f" Metric: {metric_name}")
                 print(f" Segmentation shape: {segmentation.shape}, dtype: {segmentation.dtype}")

connectomics/decoding/postprocess.py

Lines changed: 1 addition & 1 deletion
@@ -360,7 +360,7 @@ def apply_binary_postprocessing(

     if len(sizes) > cc_config.top_k:
         # Get indices of top-k largest components
-        top_k_indices = np.argsort(sizes)[-cc_config.top_k:]
+        top_k_indices = np.argsort(sizes)[-cc_config.top_k :]
         top_k_labels = label_ids[top_k_indices]

         # Create mask keeping only top-k
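
The top-k selection above works because np.argsort returns indices in ascending order of size, so the last top_k entries point at the largest connected components. A minimal illustration with toy values (made up, not from the code):

    import numpy as np

    sizes = np.array([5, 120, 37, 800, 12])       # component sizes
    label_ids = np.array([1, 2, 3, 4, 5])         # corresponding labels
    top_k = 2
    top_k_indices = np.argsort(sizes)[-top_k:]    # indices of the two largest
    top_k_labels = label_ids[top_k_indices]
    assert set(top_k_labels.tolist()) == {2, 4}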

connectomics/inference/tta.py

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ def apply_preprocessing(self, tensor: torch.Tensor) -> torch.Tensor:
             pass # Keep all channels
         elif isinstance(tta_channel, int):
             if tta_channel != -1:
-                tensor = tensor[:, tta_channel: tta_channel + 1, ...]
+                tensor = tensor[:, tta_channel : tta_channel + 1, ...]
         elif isinstance(tta_channel, (list, tuple, Sequence)):
             # Convert to list of integers (handle both int and string numbers
             # from OmegaConf)
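
The tta_channel : tta_channel + 1 slice keeps the channel axis (size 1) rather than dropping it, which plain integer indexing would do. A shape check with a dummy tensor (sizes are made up):

    import torch

    tensor = torch.randn(2, 3, 16, 16)                       # (N, C, H, W)
    tta_channel = 1
    kept = tensor[:, tta_channel : tta_channel + 1, ...]     # slicing keeps the channel dim
    assert kept.shape == (2, 1, 16, 16)
    assert tensor[:, tta_channel, ...].shape == (2, 16, 16)  # integer index drops it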

connectomics/models/build.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ def load_external_weights(model, cfg):
     stripped_count = 0
     for key, value in state_dict.items():
         if key.startswith(key_prefix):
-            new_key = key[len(key_prefix):]
+            new_key = key[len(key_prefix) :]
             new_state_dict[new_key] = value
             stripped_count += 1
         else:
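
key[len(key_prefix) :] strips a leading prefix (for example the "module." that DataParallel adds) so checkpoint keys match the bare model's names. A sketch of the loop; the prefix, keys, and the else branch here are made up for illustration:

    key_prefix = "module."
    state_dict = {"module.encoder.weight": 1, "decoder.bias": 2}

    new_state_dict = {}
    stripped_count = 0
    for key, value in state_dict.items():
        if key.startswith(key_prefix):
            new_state_dict[key[len(key_prefix) :]] = value   # drop the prefix
            stripped_count += 1
        else:
            new_state_dict[key] = value                      # keep untouched keys

    assert "encoder.weight" in new_state_dict and stripped_count == 1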

connectomics/training/deep_supervision.py

Lines changed: 5 additions & 13 deletions
@@ -123,7 +123,7 @@ def compute_multitask_loss(
             num_label_channels = 1

             # Extract label channels
-            task_label = labels[:, label_ch_offset: label_ch_offset + num_label_channels, ...]
+            task_label = labels[:, label_ch_offset : label_ch_offset + num_label_channels, ...]
             label_ch_offset += num_label_channels

             # Apply specified losses for this task

@@ -144,9 +144,7 @@ def compute_multitask_loss(
                 print(f"Loss value: {loss.item()}")
                 output_range = f"[{task_output.min():.4f}, {task_output.max():.4f}]"
                 label_range = f"[{task_label.min():.4f}, {task_label.max():.4f}]"
-                print(
-                    f"Output shape: {task_output.shape}, range: {output_range}"
-                )
+                print(f"Output shape: {task_output.shape}, range: {output_range}")
                 print(f"Label shape: {task_label.shape}, range: {label_range}")
                 print(f"Output contains NaN: {torch.isnan(task_output).any()}")
                 print(f"Label contains NaN: {torch.isnan(task_label).any()}")

@@ -242,9 +240,7 @@ def compute_loss_for_scale(
                 print(f"Loss value: {loss.item()}")
                 output_range = f"[{task_output.min():.4f}, {task_output.max():.4f}]"
                 target_range = f"[{task_target.min():.4f}, {task_target.max():.4f}]"
-                print(
-                    f"Output shape: {task_output.shape}, range: {output_range}"
-                )
+                print(f"Output shape: {task_output.shape}, range: {output_range}")
                 print(f"Target shape: {task_target.shape}, range: {target_range}")
                 if self.debug_on_nan:
                     print("\nEntering debugger...")

@@ -282,9 +278,7 @@ def compute_loss_for_scale(
                 print(f"Scale: {scale_idx}, Weight: {weight}")
                 out_range = f"[{output.min():.4f}, {output.max():.4f}]"
                 tgt_range = f"[{target.min():.4f}, {target.max():.4f}]"
-                print(
-                    f"Output shape: {output.shape}, range: {out_range}"
-                )
+                print(f"Output shape: {output.shape}, range: {out_range}")
                 print(f"Target shape: {target.shape}, range: {tgt_range}")
                 print(f"Output contains NaN: {torch.isnan(output).any()}")
                 print(f"Target contains NaN: {torch.isnan(target).any()}")

@@ -394,9 +388,7 @@ def compute_standard_loss(
                 print(f"Loss index: {i}, Weight: {weight}")
                 out_range = f"[{outputs.min():.4f}, {outputs.max():.4f}]"
                 label_range = f"[{labels.min():.4f}, {labels.max():.4f}]"
-                print(
-                    f"Output shape: {outputs.shape}, range: {out_range}"
-                )
+                print(f"Output shape: {outputs.shape}, range: {out_range}")
                 print(f"Label shape: {labels.shape}, range: {label_range}")
                 print(f"Output contains NaN: {torch.isnan(outputs).any()}")
                 print(f"Label contains NaN: {torch.isnan(labels).any()}")

connectomics/training/lit/config.py

Lines changed: 1 addition & 3 deletions
@@ -137,9 +137,7 @@ def create_datamodule(
         try:
             size_mb = DATASETS[dataset_name]["size_mb"]
             prompt = f" Download {dataset_name} dataset (~{size_mb} MB)? [Y/n]: "
-            response = (
-                input(prompt).strip().lower()
-            )
+            response = input(prompt).strip().lower()
             if response in ["", "y", "yes"]:
                 if download_dataset(dataset_name, base_dir=PathLib.cwd()):
                     print("✅ Data downloaded successfully!")
