Skip to content

Commit be953be

Browse files
committed
apply upsampling before the network
1 parent faed7f1 commit be953be

File tree

5 files changed

+16
-11
lines changed

5 files changed

+16
-11
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,6 @@ preprocessed/
99
picked_fb_test_data/
1010
example/checkpoints/
1111

12-
*.egg-info
12+
*.egg-info
13+
0_fbp_env/
14+
build/

example/0_ex_train.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
#%% ========== Loading required packages
22
import pandas as pd
3-
import sys
43
import os.path as osp
54
import matplotlib.pyplot as plt
65
from time import time
76
from pathlib import Path
87
import shutil
98
from matplotlib.ticker import MaxNLocator
10-
11-
sys.path.append(osp.abspath(osp.join(__file__, "../../")))
9+
# import sys
10+
# sys.path.append(osp.abspath(osp.join(__file__, "../../")))
1211
from first_break_picking import train
1312
from first_break_picking.tools import seed_everything
1413
from first_break_picking.data import save_shots_fb

example/1_ex_predict.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
#%% ========== Loading required packages
2-
import sys
32
import os.path as osp
43
from pathlib import Path
54
import shutil
65

7-
sys.path.append(osp.abspath(osp.join(__file__, "../../")))
6+
# import sys
7+
# sys.path.append(osp.abspath(osp.join(__file__, "../../")))
88
from first_break_picking.data import save_shots_fb
99
from first_break_picking import predict
1010

example/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#%% ============ Initiate ===============
2-
num_epcohs = 5
2+
num_epcohs = 20
33

44
batch_size = 15
55
split_nt = 22

first_break_picking/train_eval/predict.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def predict(base_dir: str,
347347
with torch.no_grad():
348348
if validation:
349349
true_masks = []
350-
for shot_number, (batch, true_mask, fbt_file_name ) in enumerate(loop):
350+
for shot_number, (batch, true_mask, fbt_file_name) in enumerate(loop):
351351
fbt_file_name = fbt_file_name[0]
352352

353353
shot1, predicted_pick, predicted_segment, true_mask1 = predict_validation(
@@ -358,7 +358,7 @@ def predict(base_dir: str,
358358
overlap=overlap,
359359
shot_id=fbt_file_name,
360360
smoothing_threshold=smoothing_threshold,
361-
upsampler=upsampler,
361+
# upsampler=upsampler,
362362
data_info=data_info,
363363
case_specific_parameters=case_specific_parameters
364364
)
@@ -373,14 +373,18 @@ def predict(base_dir: str,
373373
for shot_number, (batch, fbt_file_name) in enumerate(loop):
374374
fbt_file_name = fbt_file_name[0]
375375

376+
# nsp: data.shape=[1, 3, 1, 512, 22]
377+
batch, _ = upsampler(batch.squeeze(0), batch.squeeze(0))
378+
#nsp: data.shape= [3, 1, 512, 512])
379+
376380
shot, predicted_pick, predicted_segment = predict_test(
377-
batch=batch,
381+
batch=batch.unsqueeze(0),
378382
model=model,
379383
split_nt=split_nt,
380384
overlap=overlap,
381385
shot_id=fbt_file_name,
382386
smoothing_threshold=smoothing_threshold,
383-
upsampler=upsampler,
387+
# upsampler=upsampler,
384388
data_info=data_info,
385389
case_specific_parameters=case_specific_parameters
386390
)

0 commit comments

Comments
 (0)