
Commit dd2c8a6

Fix flake8 style warnings for all files except torch_runner.py
1 parent e554500

File tree

7 files changed: +33, -32 lines
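The warnings fixed here are the ones flake8 reports with its default settings (pycodestyle and pyflakes checks, 79-character line limit). A minimal sketch for reproducing the check over the touched files, assuming flake8 is installed in the environment:

import subprocess

# Paths mirror the files touched by this commit; torch_runner.py is left out
# on purpose, matching the commit message.
files = [
    "plasma/conf_parser.py",
    "plasma/models/builder.py",
    "plasma/models/mpi_runner.py",
    "plasma/models/shallow_runner.py",
    "plasma/preprocessor/normalize.py",
    "plasma/utils/downloading.py",
    "plasma/utils/performance.py",
]
# flake8's defaults include the 79-character line limit (E501) among others.
subprocess.run(["python", "-m", "flake8", *files], check=False)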

plasma/conf_parser.py

Lines changed: 8 additions & 8 deletions

@@ -96,10 +96,10 @@ def parameters(input_file):
         sig.jet, params['paths']['shot_list_dir'],
         ['ILW_unint_late.txt', 'ILW_clear_late.txt'],
         'Late jet iter like wall data')
-    jet_iterlike_wall_full = ShotListFiles(
-        sig.jet, params['paths']['shot_list_dir'],
-        ['ILW_unint_full.txt', 'ILW_clear_full.txt'],
-        'Full jet iter like wall data')
+    # jet_iterlike_wall_full = ShotListFiles(
+    #     sig.jet, params['paths']['shot_list_dir'],
+    #     ['ILW_unint_full.txt', 'ILW_clear_full.txt'],
+    #     'Full jet iter like wall data')
 
     jenkins_jet_carbon_wall = ShotListFiles(
         sig.jet, params['paths']['shot_list_dir'],
@@ -169,12 +169,12 @@ def parameters(input_file):
         params['paths']['shot_files'] = [jet_carbon_wall]
         params['paths']['shot_files_test'] = [jet_iterlike_wall]
         params['paths']['use_signals_dict'] = {
-            'etemp_profile': etemp_profile}
+            'etemp_profile': sig.etemp_profile}
     elif params['paths']['data'] == 'jet_data_dens_profile':
         params['paths']['shot_files'] = [jet_carbon_wall]
         params['paths']['shot_files_test'] = [jet_iterlike_wall]
         params['paths']['use_signals_dict'] = {
-            'edens_profile': edens_profile}
+            'edens_profile': sig.edens_profile}
     elif params['paths']['data'] == 'jet_carbon_data':
         params['paths']['shot_files'] = [jet_carbon_wall]
         params['paths']['shot_files_test'] = []
@@ -302,13 +302,13 @@ def parameters(input_file):
         params['paths']['shot_files'] = [d3d_full]
         params['paths']['shot_files_test'] = []
         params['paths']['use_signals_dict'] = {
-            'etemp_profile': etemp_profile}  # fully_defined_signals_0D
+            'etemp_profile': sig.etemp_profile}  # fully_defined_signals_0D
     elif params['paths']['data'] == 'd3d_data_dens_profile':
         # jet data but with fully defined signals
         params['paths']['shot_files'] = [d3d_full]
         params['paths']['shot_files_test'] = []
         params['paths']['use_signals_dict'] = {
-            'edens_profile': edens_profile}  # fully_defined_signals_0D
+            'edens_profile': sig.edens_profile}  # fully_defined_signals_0D
 
     # cross-machine
     elif params['paths']['data'] == 'jet_to_d3d_data':
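The substantive change in this file is qualifying the profile signals through the sig module object instead of referring to them as bare names, which flake8 reports as undefined (F821) and which would raise NameError at runtime; the commented-out jet_iterlike_wall_full block presumably silences an unused-variable warning as well. A self-contained sketch of the qualification pattern, using a stand-in namespace rather than the project's real signal module:

from types import SimpleNamespace

# Stand-in for the project's signal module (values here are hypothetical).
sig = SimpleNamespace(etemp_profile="electron temperature profile",
                      edens_profile="electron density profile")

params = {'paths': {}}
# Bare names like `etemp_profile` are undefined in this module (flake8 F821);
# reaching them through `sig` makes the reference explicit and resolvable.
params['paths']['use_signals_dict'] = {'etemp_profile': sig.etemp_profile}
print(params['paths']['use_signals_dict'])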

plasma/models/builder.py

Lines changed: 3 additions & 1 deletion

@@ -82,6 +82,7 @@ def build_model(self, predict, custom_batch_size=None):
         conf = self.conf
         model_conf = conf['model']
         rnn_size = model_conf['rnn_size']
+        use_bidirectional = model_conf['use_bidirectional']
         rnn_type = model_conf['rnn_type']
         regularization = model_conf['regularization']
         dense_regularization = model_conf['dense_regularization']
@@ -353,7 +354,8 @@ def get_all_saved_files(self):
         self.ensure_save_directory()
         unique_id = self.get_unique_id()
         path = self.conf['paths']['model_save_path']
-        filenames = [name for name in os.listdir(path) if os.path.isfile(os.path.join(path, name))]
+        filenames = [name for name in os.listdir(path)
+                     if os.path.isfile(os.path.join(path, name))]
         epochs = []
         for file in filenames:
             curr_id, epoch = self.extract_id_and_epoch_from_filename(file)
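The second hunk is a straight E501 (line too long) fix: the list comprehension is split before the if clause so each line stays under 79 characters. A stand-alone sketch of the same wrapping, run against the current directory rather than the model save path:

import os

path = "."  # hypothetical directory; the real code uses model_save_path
filenames = [name for name in os.listdir(path)
             if os.path.isfile(os.path.join(path, name))]
print(filenames)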

plasma/models/mpi_runner.py

Lines changed: 2 additions & 2 deletions

@@ -786,7 +786,7 @@ def mpi_make_predictions_and_evaluate_multiple_times(conf, shot_list, loader,
 
 
 def mpi_train(conf, shot_list_train, shot_list_validate, loader,
-              callbacks_list=None shot_list_test=None):
+              callbacks_list=None, shot_list_test=None):
     loader.set_inference_mode(False)
     conf['num_workers'] = comm.Get_size()
 
@@ -949,7 +949,7 @@ def mpi_train(conf, shot_list_train, shot_list_validate, loader,
                 round(e)), epoch_logs)
 
         print_unique("end epoch {} 0".format(e))
-        stop_training = comm.bcast(stop_training,root=0)
+        stop_training = comm.bcast(stop_training, root=0)
         print_unique("end epoch {} 1".format(e))
         if stop_training:
             print("Stopping training due to early stopping")

plasma/models/shallow_runner.py

Lines changed: 0 additions & 1 deletion

@@ -21,7 +21,6 @@
 import time
 import numpy as np
 from copy import deepcopy
-from keras.utils.generic_utils import Progbar
 
 # import matplotlib.pyplot as plt
 import matplotlib

plasma/preprocessor/normalize.py

Lines changed: 5 additions & 4 deletions

@@ -153,12 +153,13 @@ def cut_end_of_shot(self, shot):
         # only cut shots during training
         if not self.inference_mode and cut_shot_ends:
             T_min_warn = self.conf['data']['T_min_warn']
-            if shot.ttd.shape[0] - T_min_warn <= max(self.conf['model']['length'],0):
-                print("not cutting shot since length of shot after cutting by T_min_warn would be shorter than RNN length")
+            if shot.ttd.shape[0] - T_min_warn <= max(
+                    self.conf['model']['length'], 0):
+                print("not cutting shot; length of shot after cutting by ",
+                      "T_min_warn would be shorter than RNN length")
                 return
             for key in shot.signals_dict:
-                shot.signals_dict[key] = shot.signals_dict[key][:-T_min_warn,
-                                                                :]
+                shot.signals_dict[key] = shot.signals_dict[key][:-T_min_warn,:]  # noqa
             shot.ttd = shot.ttd[:-T_min_warn]
 
     # def apply_mask(self,shot):
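Instead of splitting the slice across two lines (the old layout) or adding the space E231 asks for, the commit keeps the slice on one line and silences flake8 with a bare # noqa comment. A minimal sketch of that escape hatch, using a small NumPy array in place of a real shot signal:

import numpy as np

T_min_warn = 3
signal = np.arange(20).reshape(10, 2)  # stand-in for a shot signal array
# A bare "# noqa" tells flake8 to skip every check on this line, so the
# missing space after the comma (E231) no longer shows up in the report.
trimmed = signal[:-T_min_warn,:]  # noqa
print(trimmed.shape)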

plasma/utils/downloading.py

Lines changed: 10 additions & 5 deletions

@@ -33,9 +33,11 @@
 
 
 def general_object_hash(o):
-    """Makes a hash from a dictionary, list, tuple or set to any level, that contains
-    only other hashable types (including any lists, tuples, sets, and
-    dictionaries). Relies on dill for serialization"""
+    """
+    Makes a hash from a dictionary, list, tuple or set to any level, that
+    contains only other hashable types (including any lists, tuples, sets, and
+    dictionaries). Relies on dill for serialization
+    """
 
     if isinstance(o, (set, tuple, list)):
         return tuple([general_object_hash(e) for e in o])
@@ -51,9 +53,12 @@ def general_object_hash(o):
 
 
 def myhash(x):
-    return int(hashlib.md5((dill.dumps(x).decode('unicode_escape')).encode('utf-8')).hexdigest(),16)
+    return int(hashlib.md5((dill.dumps(x).decode('unicode_escape')).encode(
+        'utf-8')).hexdigest(), 16)
     # return int(hashlib.md5((dill.dumps(x))).hexdigest(),16)
-    # return int(hashlib.md5((dill.dumps(x))))#.decode('unicode_escape')).encode('utf-8')).hexdigest(),16)
+    # return
+    # int(hashlib.md5((dill.dumps(x))))#.decode('unicode_escape')).encode(
+    # 'utf-8')).hexdigest(),16)
 
 
 def get_missing_value_array():
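The wrapped return in myhash satisfies E501 but splits a call across its arguments, which is hard to read. An alternative (not what the commit does) that stays under the limit is to bind the serialized payload to a name first; a sketch assuming dill is installed:

import hashlib

import dill  # same dependency the original module relies on


def myhash(x):
    # Serialize with dill, then hash the bytes; binding `payload` keeps
    # every line comfortably under 79 characters.
    payload = dill.dumps(x).decode('unicode_escape').encode('utf-8')
    return int(hashlib.md5(payload).hexdigest(), 16)


print(myhash({'a': [1, 2, 3]}))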

plasma/utils/performance.py

Lines changed: 5 additions & 11 deletions

@@ -14,16 +14,8 @@
 
 
 class PerformanceAnalyzer():
-    def __init__(
-            self,
-            results_dir=None,
-            shots_dir=None,
-            i=0,
-            T_min_warn=None,
-            T_max_warn=None,
-            verbose=False,
-            pred_ttd=False,
-            conf=None):
+    def __init__(self, results_dir=None, shots_dir=None, i=0, T_min_warn=None,
+                 T_max_warn=None, verbose=False, pred_ttd=False, conf=None):
         self.T_min_warn = T_min_warn
         self.T_max_warn = T_max_warn
         dt = conf['data']['dt']
@@ -35,7 +27,9 @@ def __init__(
         if T_max_warn is None:
             self.T_max_warn = T_max_warn_def
         if self.T_max_warn < self.T_min_warn:
-            print("T max warn is too small: need to increase artificially.") #computation of statistics is only correct if T_max_warn is larger than T_min_warn
+            # computation of statistics is only correct if T_max_warn is larger
+            # than T_min_warn
+            print("T max warn is too small: need to increase artificially.")
            self.T_max_warn = self.T_min_warn + 1
         self.verbose = verbose
         self.results_dir = results_dir
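Both hunks here are E501 fixes made by reflowing rather than by suppression: the exploded __init__ signature is condensed onto two lines, and the long trailing comment moves above the print call it used to follow. A tiny sketch of the comment relocation, with made-up values:

T_min_warn, T_max_warn = 30, 10
if T_max_warn < T_min_warn:
    # Statistics are only meaningful when T_max_warn exceeds T_min_warn,
    # so it is bumped artificially (the explanation fits above the call
    # instead of trailing past the 79-character limit).
    print("T max warn is too small: need to increase artificially.")
    T_max_warn = T_min_warn + 1
print(T_max_warn)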
