
Commit 67b1189

Author: Julian Kates-Harbeck (committed)
Slight model architecture modification: the dense layer after the convolutions is now sized by dense_size.
1 parent cd20621 commit 67b1189

File tree

1 file changed (+8, -14 lines)


plasma/models/builder.py

Lines changed: 8 additions & 14 deletions
@@ -19,6 +19,7 @@
 import os,sys
 import numpy as np
 from copy import deepcopy
+from plasma.utils.downloading import makedirs_process_safe
 
 
 class LossHistory(Callback):
@@ -137,16 +138,18 @@ def slicer_output_shape(input_shape,indices):
             pre_rnn_1D = Permute((2,1)) (pre_rnn_1D)
 
             for i in range(model_conf['num_conv_layers']):
-                pre_rnn_1D = Convolution1D(num_conv_filters,size_conv_filters,padding='valid',activation='relu') (pre_rnn_1D)
+                div_fac = 2**i
+                pre_rnn_1D = Convolution1D(num_conv_filters/div_fac,size_conv_filters,padding='valid',activation='relu') (pre_rnn_1D)
+                pre_rnn_1D = Convolution1D(num_conv_filters/div_fac,1,padding='valid',activation='relu') (pre_rnn_1D)
                 pre_rnn_1D = MaxPooling1D(pool_size) (pre_rnn_1D)
             pre_rnn_1D = Flatten() (pre_rnn_1D)
-            pre_rnn_1D = Dense(num_conv_filters*4,activation='relu',kernel_regularizer=l2(dense_regularization),bias_regularizer=l2(dense_regularization),activity_regularizer=l2(dense_regularization)) (pre_rnn_1D)
-            pre_rnn_1D = Dense(num_conv_filters,activation='relu',kernel_regularizer=l2(dense_regularization),bias_regularizer=l2(dense_regularization),activity_regularizer=l2(dense_regularization)) (pre_rnn_1D)
+            pre_rnn_1D = Dense(dense_size,activation='relu',kernel_regularizer=l2(dense_regularization),bias_regularizer=l2(dense_regularization),activity_regularizer=l2(dense_regularization)) (pre_rnn_1D)
+            pre_rnn_1D = Dense(dense_size/4,activation='relu',kernel_regularizer=l2(dense_regularization),bias_regularizer=l2(dense_regularization),activity_regularizer=l2(dense_regularization)) (pre_rnn_1D)
             pre_rnn = Concatenate() ([pre_rnn_0D,pre_rnn_1D])
         else:
             pre_rnn = pre_rnn_input
 
-        if model_conf['rnn_layers'] == 0:
+        if model_conf['rnn_layers'] == 0 or model_conf['extra_dense_input']:
             pre_rnn = Dense(dense_size,activation='relu',kernel_regularizer=l2(dense_regularization),bias_regularizer=l2(dense_regularization),activity_regularizer=l2(dense_regularization)) (pre_rnn)
             pre_rnn = Dense(dense_size/2,activation='relu',kernel_regularizer=l2(dense_regularization),bias_regularizer=l2(dense_regularization),activity_regularizer=l2(dense_regularization)) (pre_rnn)
             pre_rnn = Dense(dense_size/4,activation='relu',kernel_regularizer=l2(dense_regularization),bias_regularizer=l2(dense_regularization),activity_regularizer=l2(dense_regularization)) (pre_rnn)
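Note on the hunk above: the 1D convolutional branch now halves its filter count with each convolutional block (div_fac = 2**i), and the dense layers that follow the convolutions are sized from the dense_size config value instead of num_conv_filters. The following is a minimal standalone sketch of what the modified branch computes, not the repository code: it assumes the Keras 2.x functional API, placeholder input shape and config values, integer division where the diff's "/" relies on Python 2 semantics, and it omits the Permute((2,1)) step by assuming the input already arrives as (length, channels).

# Hedged sketch of the modified 1D signal branch; placeholder values throughout.
from keras.layers import Input, Convolution1D, MaxPooling1D, Flatten, Dense
from keras.regularizers import l2

profile_length = 64          # placeholder: samples per 1D profile
num_signals = 12             # placeholder: number of profile channels
num_conv_filters = 32        # placeholder config values
size_conv_filters = 3
pool_size = 2
num_conv_layers = 2
dense_size = 128
dense_regularization = 0.001

reg = dict(kernel_regularizer=l2(dense_regularization),
           bias_regularizer=l2(dense_regularization),
           activity_regularizer=l2(dense_regularization))

x = Input(shape=(profile_length, num_signals))
h = x
for i in range(num_conv_layers):
    div_fac = 2 ** i  # halve the filter count in each successive block
    h = Convolution1D(num_conv_filters // div_fac, size_conv_filters,
                      padding='valid', activation='relu')(h)
    h = Convolution1D(num_conv_filters // div_fac, 1,
                      padding='valid', activation='relu')(h)
    h = MaxPooling1D(pool_size)(h)
h = Flatten()(h)
# dense layers after the convolutions are now sized from dense_size
h = Dense(dense_size, activation='relu', **reg)(h)
h = Dense(dense_size // 4, activation='relu', **reg)(h)

Tying the post-convolution dense widths to dense_size (rather than to the filter count) lets the convolutional and dense capacities be tuned independently from the config.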
@@ -214,16 +217,7 @@ def get_save_path(self,epoch):
 
     def ensure_save_directory(self):
         prepath = self.conf['paths']['model_save_path']
-        if not os.path.exists(prepath):
-            try: #can lead to race condition
-                os.makedirs(prepath)
-            except OSError as e:
-                if e.errno == errno.EEXIST:
-                    # File exists, and it's a directory, another process beat us to creating this dir, that's OK.
-                    pass
-                else:
-                    # Our target dir exists as a file, or different error, reraise the error!
-                    raise
+        makedirs_process_safe(prepath)
 
     def load_model_weights(self,model,custom_path=None):
         if custom_path == None:
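The deleted directory-creation logic is replaced by makedirs_process_safe, imported from plasma.utils.downloading in the first hunk. That helper's implementation is not part of this diff; as a hedged sketch, it presumably wraps the same race-tolerant pattern that was removed here, roughly along these lines:

# Hedged sketch only; the real helper lives in plasma/utils/downloading.py and is not shown in this diff.
import errno
import os

def makedirs_process_safe(dirpath):
    if not os.path.exists(dirpath):
        try:  # several processes may race to create the same directory
            os.makedirs(dirpath)
        except OSError as e:
            if e.errno == errno.EEXIST:
                # Another process created the directory first; that is fine.
                pass
            else:
                # The target exists as a file, or a different error occurred: re-raise.
                raise

Centralizing this in one helper avoids repeating the errno handling at every call site that needs to create an output directory.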
