
Commit 18ff577

Donglai Wei committed
fix formatting
1 parent 7e1a353 commit 18ff577

File tree

14 files changed: +67 −63 lines

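Every hunk in this commit is a whitespace-only change: spaces around the colon in slice expressions are dropped, and a few over-long statements are re-wrapped. A minimal sketch of the slice-spacing convention being applied (the array and bounds below are hypothetical, not taken from the repository):

import numpy as np

img = np.zeros((4, 64, 64))        # hypothetical (C, H, W) array
y0, x0, h, w = 8, 8, 16, 16        # hypothetical crop bounds

patch_before = img[:, y0 : y0 + h, x0 : x0 + w]   # spacing used before this commit
patch_after = img[:, y0:y0 + h, x0:x0 + w]        # spacing used after this commit

assert (patch_before == patch_after).all()        # behavior is identical; only style changes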

connectomics/data/augment/monai_transforms.py

Lines changed: 22 additions & 20 deletions
@@ -99,11 +99,11 @@ def _apply_misalignment_translation(

         output = np.zeros(out_shape, img.dtype)
         if mode == "slip":
-            output = img[:, y0 : y0 + out_shape[1], x0 : x0 + out_shape[2]]
-            output[idx] = img[idx, y1 : y1 + out_shape[1], x1 : x1 + out_shape[2]]
+            output = img[:, y0:y0 + out_shape[1], x0:x0 + out_shape[2]]
+            output[idx] = img[idx, y1:y1 + out_shape[1], x1:x1 + out_shape[2]]
         else:
-            output[:idx] = img[:idx, y0 : y0 + out_shape[1], x0 : x0 + out_shape[2]]
-            output[idx:] = img[idx:, y1 : y1 + out_shape[1], x1 : x1 + out_shape[2]]
+            output[:idx] = img[:idx, y0:y0 + out_shape[1], x0:x0 + out_shape[2]]
+            output[idx:] = img[idx:, y1:y1 + out_shape[1], x1:x1 + out_shape[2]]

         if is_tensor:
             output = torch.from_numpy(output).to(device)
@@ -299,7 +299,7 @@ def _apply_missing_parts(
         x_start = self.R.randint(0, img.shape[2] - hole_w + 1)

         # Create hole (set to 0 or mean value)
-        img[section_idx, y_start : y_start + hole_h, x_start : x_start + hole_w] = 0
+        img[section_idx, y_start:y_start + hole_h, x_start:x_start + hole_w] = 0

         return img

@@ -452,24 +452,24 @@ def _apply_cut_noise(
                 noise = self.R.uniform(-self.noise_scale, self.noise_scale, noise_shape)
                 region = img[
                     :,
-                    z_start : z_start + z_len,
-                    y_start : y_start + y_len,
-                    x_start : x_start + x_len,
+                    z_start:z_start + z_len,
+                    y_start:y_start + y_len,
+                    x_start:x_start + x_len,
                 ]
                 noisy_region = np.clip(region + noise, 0, 1)
                 img[
                     :,
-                    z_start : z_start + z_len,
-                    y_start : y_start + y_len,
-                    x_start : x_start + x_len,
+                    z_start:z_start + z_len,
+                    y_start:y_start + y_len,
+                    x_start:x_start + x_len,
                 ] = noisy_region
             else:
                 # (C, H, W) - 2D with channels
                 noise_shape = (img.shape[0], y_len, x_len)
                 noise = self.R.uniform(-self.noise_scale, self.noise_scale, noise_shape)
-                region = img[:, y_start : y_start + y_len, x_start : x_start + x_len]
+                region = img[:, y_start:y_start + y_len, x_start:x_start + x_len]
                 noisy_region = np.clip(region + noise, 0, 1)
-                img[:, y_start : y_start + y_len, x_start : x_start + x_len] = noisy_region
+                img[:, y_start:y_start + y_len, x_start:x_start + x_len] = noisy_region
         elif img.ndim == 3:
             # 3D case: (Z, Y, X) or (C, H, W)
             # Heuristic: if first dim is small (<=4), assume it's channel (2D with channels)
@@ -478,29 +478,31 @@ def _apply_cut_noise(
                 # (C, H, W) - 2D with channels
                 noise_shape = (img.shape[0], y_len, x_len)
                 noise = self.R.uniform(-self.noise_scale, self.noise_scale, noise_shape)
-                region = img[:, y_start : y_start + y_len, x_start : x_start + x_len]
+                region = img[:, y_start:y_start + y_len, x_start:x_start + x_len]
                 noisy_region = np.clip(region + noise, 0, 1)
-                img[:, y_start : y_start + y_len, x_start : x_start + x_len] = noisy_region
+                img[:, y_start:y_start + y_len, x_start:x_start + x_len] = noisy_region
             else:
                 # (Z, Y, X) - 3D
                 z_len = max(1, int(self.length_ratio * img.shape[0]))  # Ensure at least 1
                 z_start = self.R.randint(0, max(1, img.shape[0] - z_len + 1))
                 noise_shape = (z_len, y_len, x_len)
                 noise = self.R.uniform(-self.noise_scale, self.noise_scale, noise_shape)
                 region = img[
-                    z_start : z_start + z_len, y_start : y_start + y_len, x_start : x_start + x_len
+                    z_start:z_start + z_len,
+                    y_start:y_start + y_len,
+                    x_start:x_start + x_len,
                 ]
                 noisy_region = np.clip(region + noise, 0, 1)
                 img[
-                    z_start : z_start + z_len, y_start : y_start + y_len, x_start : x_start + x_len
+                    z_start:z_start + z_len, y_start:y_start + y_len, x_start:x_start + x_len
                 ] = noisy_region
         else:
             # 2D case: (H, W)
             noise_shape = (y_len, x_len)
             noise = self.R.uniform(-self.noise_scale, self.noise_scale, noise_shape)
-            region = img[y_start : y_start + y_len, x_start : x_start + x_len]
+            region = img[y_start:y_start + y_len, x_start:x_start + x_len]
             noisy_region = np.clip(region + noise, 0, 1)
-            img[y_start : y_start + y_len, x_start : x_start + x_len] = noisy_region
+            img[y_start:y_start + y_len, x_start:x_start + x_len] = noisy_region

         if is_tensor:
             img = torch.from_numpy(img).to(device)
@@ -886,7 +888,7 @@ def _find_best_paste(
             neuron_tensor.flip(0) if neuron_tensor.ndim == 3 else neuron_tensor.flip(1)
         )

-        label_paste = labels[best_idx : best_idx + 1]
+        label_paste = labels[best_idx:best_idx + 1]

         if best_angle != 0:
             label_paste = self._rotate_3d(label_paste, best_angle)
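The _apply_cut_noise hunks above all reformat the same pattern: sample uniform noise over a randomly placed sub-region, add it, and clip back to [0, 1]. A self-contained sketch of that pattern for the 3D (Z, Y, X) case, with hypothetical sizes and a local NumPy generator standing in for self.R:

import numpy as np

rng = np.random.default_rng(0)
img = rng.random((16, 64, 64))            # hypothetical (Z, Y, X) volume in [0, 1]
length_ratio, noise_scale = 0.25, 0.2     # hypothetical parameters

z_len = max(1, int(length_ratio * img.shape[0]))
y_len = max(1, int(length_ratio * img.shape[1]))
x_len = max(1, int(length_ratio * img.shape[2]))
z_start = rng.integers(0, img.shape[0] - z_len + 1)
y_start = rng.integers(0, img.shape[1] - y_len + 1)
x_start = rng.integers(0, img.shape[2] - x_len + 1)

# Add bounded uniform noise to the sub-region and clip back to [0, 1].
noise = rng.uniform(-noise_scale, noise_scale, (z_len, y_len, x_len))
region = img[z_start:z_start + z_len, y_start:y_start + y_len, x_start:x_start + x_len]
img[z_start:z_start + z_len, y_start:y_start + y_len, x_start:x_start + x_len] = np.clip(
    region + noise, 0, 1
)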

connectomics/data/io/io.py

Lines changed: 3 additions & 3 deletions
@@ -171,7 +171,7 @@ def read_image_as_volume(filename: str, drop_channel: bool = False) -> np.ndarra
     Raises:
         ValueError: If file format is not supported
     """
-    image_suffix = filename[filename.rfind(".") + 1 :].lower()
+    image_suffix = filename[filename.rfind(".") + 1:].lower()
     if image_suffix not in SUPPORTED_IMAGE_FORMATS:
         raise ValueError(
             f"Unsupported format: {image_suffix}. Supported formats: {SUPPORTED_IMAGE_FORMATS}"
@@ -281,7 +281,7 @@ def read_volume(
     if filename.endswith(".nii.gz"):
         image_suffix = "nii.gz"
     else:
-        image_suffix = filename[filename.rfind(".") + 1 :].lower()
+        image_suffix = filename[filename.rfind(".") + 1:].lower()

     if image_suffix in ["h5", "hdf5"]:
         data = read_hdf5(filename, dataset)
@@ -420,7 +420,7 @@ def get_vol_shape(filename: str, dataset: Optional[str] = None) -> tuple:
     if filename.endswith(".nii.gz"):
         image_suffix = "nii.gz"
     else:
-        image_suffix = filename[filename.rfind(".") + 1 :].lower()
+        image_suffix = filename[filename.rfind(".") + 1:].lower()

     if image_suffix in ["h5", "hdf5"]:
         # HDF5: Read shape from metadata (no data loading)
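All three io.py hunks touch the same idiom: everything after the last dot is taken as the format suffix, with .nii.gz special-cased because it carries two extensions. A small sketch of that logic with hypothetical filenames (get_suffix is an illustrative helper, not a function in the repository):

def get_suffix(filename: str) -> str:
    # Mirror the suffix-extraction idiom from read_volume / get_vol_shape.
    if filename.endswith(".nii.gz"):
        return "nii.gz"
    return filename[filename.rfind(".") + 1:].lower()

print(get_suffix("volume.H5"))       # h5
print(get_suffix("scan.nii.gz"))     # nii.gz
print(get_suffix("labels.TIFF"))     # tiff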

connectomics/data/io/tiles.py

Lines changed: 8 additions & 8 deletions
@@ -160,22 +160,22 @@ def reconstruct_volume_from_tiles(
         if is_image:  # Image data
             result[
                 z - z0,
-                y_actual_start - y0 : y_actual_end - y0,
-                x_actual_start - x0 : x_actual_end - x0,
+                y_actual_start - y0:y_actual_end - y0,
+                x_actual_start - x0:x_actual_end - x0,
             ] = patch[
-                y_actual_start - y_patch_start : y_actual_end - y_patch_start,
-                x_actual_start - x_patch_start : x_actual_end - x_patch_start,
+                y_actual_start - y_patch_start:y_actual_end - y_patch_start,
+                x_actual_start - x_patch_start:x_actual_end - x_patch_start,
                 0,
             ]
         else:  # Label data
             result[
                 z - z0,
-                y_actual_start - y0 : y_actual_end - y0,
-                x_actual_start - x0 : x_actual_end - x0,
+                y_actual_start - y0:y_actual_end - y0,
+                x_actual_start - x0:x_actual_end - x0,
             ] = rgb_to_seg(
                 patch[
-                    y_actual_start - y_patch_start : y_actual_end - y_patch_start,
-                    x_actual_start - x_patch_start : x_actual_end - x_patch_start,
+                    y_actual_start - y_patch_start:y_actual_end - y_patch_start,
+                    x_actual_start - x_patch_start:x_actual_end - x_patch_start,
                 ]
             )


connectomics/data/process/crop.py

Lines changed: 8 additions & 11 deletions
@@ -16,9 +16,9 @@ def crop_volume(data, sz, st=(0, 0, 0)):
     st = np.array(st).astype(np.int32)

     if data.ndim == 3:
-        return data[st[0] : st[0] + sz[0], st[1] : st[1] + sz[1], st[2] : st[2] + sz[2]]
+        return data[st[0]:st[0] + sz[0], st[1]:st[1] + sz[1], st[2]:st[2] + sz[2]]
     else:  # crop spatial dimensions
-        return data[:, st[0] : st[0] + sz[0], st[1] : st[1] + sz[1], st[2] : st[2] + sz[2]]
+        return data[:, st[0]:st[0] + sz[0], st[1]:st[1] + sz[1], st[2]:st[2] + sz[2]]


 def get_valid_pos_torch(mask, vol_sz, valid_ratio):
@@ -62,14 +62,11 @@ def get_valid_pos(mask, vol_sz, valid_ratio):
     pad_sz_post = data_sz - (vol_sz - pad_sz_pre - 1)
     valid_pos = np.zeros([0, 3])
     if len(vol_sz) == 3:
-        mask_sum = (
-            mask_sum[
-                pad_sz_pre[0] : pad_sz_post[0],
-                pad_sz_pre[1] : pad_sz_post[1],
-                pad_sz_pre[2] : pad_sz_post[2],
-            ]
-            >= valid_thres
-        )
+        mask_sum = mask_sum[
+            pad_sz_pre[0]:pad_sz_post[0],
+            pad_sz_pre[1]:pad_sz_post[1],
+            pad_sz_pre[2]:pad_sz_post[2],
+        ] >= valid_thres
         if mask_sum.max() > 0:
             zz, yy, xx = np.meshgrid(
                 np.arange(mask_sum.shape[0]),
@@ -86,7 +83,7 @@ def get_valid_pos(mask, vol_sz, valid_ratio):
             )
     else:
         mask_sum = (
-            mask_sum[pad_sz_pre[0] : pad_sz_post[0], pad_sz_pre[1] : pad_sz_post[1]] >= valid_thres
+            mask_sum[pad_sz_pre[0]:pad_sz_post[0], pad_sz_pre[1]:pad_sz_post[1]] >= valid_thres
         )
         if mask_sum.max() > 0:
             yy, xx = np.meshgrid(np.arange(mask_sum.shape[0]), np.arange(mask_sum.shape[1]))

connectomics/decoding/postprocess.py

Lines changed: 1 addition & 1 deletion
@@ -360,7 +360,7 @@ def apply_binary_postprocessing(

         if len(sizes) > cc_config.top_k:
             # Get indices of top-k largest components
-            top_k_indices = np.argsort(sizes)[-cc_config.top_k :]
+            top_k_indices = np.argsort(sizes)[-cc_config.top_k:]
             top_k_labels = label_ids[top_k_indices]

             # Create mask keeping only top-k
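In the postprocess.py hunk, np.argsort(sizes)[-cc_config.top_k:] selects the indices of the top_k largest component sizes; only the space before the trailing colon changes. A quick illustration with made-up component sizes and labels:

import numpy as np

sizes = np.array([120, 5, 980, 47, 300])    # hypothetical component sizes
label_ids = np.array([1, 2, 3, 4, 5])
top_k = 2

top_k_indices = np.argsort(sizes)[-top_k:]  # indices of the two largest sizes
top_k_labels = label_ids[top_k_indices]
print(top_k_labels)                         # [5 3]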

connectomics/inference/tta.py

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ def apply_preprocessing(self, tensor: torch.Tensor) -> torch.Tensor:
             pass  # Keep all channels
         elif isinstance(tta_channel, int):
             if tta_channel != -1:
-                tensor = tensor[:, tta_channel : tta_channel + 1, ...]
+                tensor = tensor[:, tta_channel:tta_channel + 1, ...]
         elif isinstance(tta_channel, (list, tuple, Sequence)):
             # Convert to list of integers (handle both int and string numbers
             # from OmegaConf)
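The tta.py hunk keeps single-channel selection as tensor[:, c:c + 1, ...], which preserves the channel dimension, unlike integer indexing, which drops it. A short sketch with a hypothetical tensor:

import torch

tensor = torch.rand(2, 3, 8, 8, 8)     # hypothetical (B, C, Z, H, W) batch
tta_channel = 1

kept = tensor[:, tta_channel:tta_channel + 1, ...]   # channel dim preserved
dropped = tensor[:, tta_channel, ...]                # channel dim removed
print(kept.shape)      # torch.Size([2, 1, 8, 8, 8])
print(dropped.shape)   # torch.Size([2, 8, 8, 8])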

connectomics/models/build.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ def load_external_weights(model, cfg):
     stripped_count = 0
     for key, value in state_dict.items():
         if key.startswith(key_prefix):
-            new_key = key[len(key_prefix) :]
+            new_key = key[len(key_prefix):]
             new_state_dict[new_key] = value
             stripped_count += 1
         else:
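The build.py hunk reformats key[len(key_prefix):], which strips a prefix from checkpoint keys before loading weights. A minimal sketch of that loop under an assumed prefix of "model." and a made-up state dict:

state_dict = {"model.encoder.weight": 1, "model.encoder.bias": 2, "head.weight": 3}
key_prefix = "model."          # hypothetical prefix, as in wrapped checkpoints

new_state_dict = {}
stripped_count = 0
for key, value in state_dict.items():
    if key.startswith(key_prefix):
        new_state_dict[key[len(key_prefix):]] = value   # drop the prefix
        stripped_count += 1
    else:
        new_state_dict[key] = value                     # keep the key unchanged

print(new_state_dict)   # {'encoder.weight': 1, 'encoder.bias': 2, 'head.weight': 3}
print(stripped_count)   # 2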

connectomics/training/deep_supervision.py

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ def compute_multitask_loss(
             num_label_channels = 1

         # Extract label channels
-        task_label = labels[:, label_ch_offset : label_ch_offset + num_label_channels, ...]
+        task_label = labels[:, label_ch_offset:label_ch_offset + num_label_channels, ...]
         label_ch_offset += num_label_channels

         # Apply specified losses for this task

connectomics/training/lit/model.py

Lines changed: 4 additions & 4 deletions
@@ -426,8 +426,6 @@ def _compute_test_metrics(
         ):
             # Adapted Rand requires instance segmentation labels (integer labels), not binary
             # decoded_predictions should already be instance segmentation from decode_instance_*
-            # Check original shape before processing to handle batch dimension correctly
-            original_shape = decoded_predictions.shape
             pred_instance = torch.from_numpy(decoded_predictions).long()

             # Labels should also be instance segmentation (integer labels)
@@ -438,8 +436,10 @@ def _compute_test_metrics(
             while labels_instance.ndim > 3 and labels_instance.shape[0] == 1:
                 labels_instance = labels_instance.squeeze(0)

-            # Squeeze all leading dimensions of size 1 from predictions (remove batch & channel dims)
-            # Predictions can be: (B, C, Z, H, W), (B, Z, H, W), (C, Z, H, W), or (Z, H, W)
+            # Squeeze all leading dimensions of size 1 from predictions
+            # (remove batch & channel dims)
+            # Predictions can be: (B, C, Z, H, W), (B, Z, H, W),
+            # (C, Z, H, W), or (Z, H, W)
             while pred_instance.ndim > 3 and pred_instance.shape[0] == 1:
                 pred_instance = pred_instance.squeeze(0)

connectomics/training/lit/utils.py

Lines changed: 2 additions & 3 deletions
@@ -195,9 +195,8 @@ def setup_config(args) -> Config:
         print("🔧 Fast-dev-run mode: Overriding config for debugging")
         print(f" - num_gpus: {cfg.system.training.num_gpus} → 1")
         print(f" - num_cpus: {cfg.system.training.num_cpus} → 1")
-        print(
-            f" - num_workers: {cfg.system.training.num_workers} → 0 (avoid multiprocessing in debug mode)"
-        )
+        print(f" - num_workers: {cfg.system.training.num_workers} → 0 "
+              "(avoid multiprocessing in debug mode)")
         print(
             f" - batch_size: Controlled by PyTorch Lightning (--fast-dev-run={args.fast_dev_run})"
         )
