@@ -248,6 +248,21 @@ def _prepare_program(self, program, feed_target_names, fetch_targets,
             feed_target_names, fetch_targets)
 
         config_dict = dict(config._asdict())
+        if config_dict["prune_strategy"] == "gmp" and config_dict[
+                'gmp_config'] is None:
+            _logger.info(
+                "Calculating the iterations per epoch... (this may take some time)")
+            # NOTE: This way of calculating the iterations needs to be improved.
+            iters_per_epoch = len(list(self.train_dataloader()))
+            total_iters = self.train_config.epochs * iters_per_epoch
+            config_dict['gmp_config'] = {
+                'stable_iterations': 0,
+                'pruning_iterations': 0.45 * total_iters,
+                'tunning_iterations': 0.45 * total_iters,
+                'resume_iteration': -1,
+                'pruning_steps': 100,
+                'initial_ratio': 0.15,
+            }
         ### add prune program
         self._pruner = None
         if 'prune' in strategy:
@@ -280,13 +295,14 @@ def _prepare_program(self, program, feed_target_names, fetch_targets,
                 test_program_info)
             if self.train_config.sparse_model:
                 from ..prune.unstructured_pruner import UnstructuredPruner
+                # NOTE: The initialization parameters of this pruner do not take effect; it is only created so that 'set_static_masks' can be called.
                 self._pruner = UnstructuredPruner(
                     train_program_info.program,
                     mode='ratio',
                     ratio=0.75,
                     prune_params_type='conv1x1_only',
                     place=self._places)
-                self._pruner.set_static_masks()
+                self._pruner.set_static_masks()  # keep the model's existing sparsity pattern fixed
 
         self._exe.run(train_program_info.startup_program)
 
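The second hunk uses UnstructuredPruner only as a vehicle for set_static_masks(): the model is assumed to be sparse already, and the masks are taken from its existing zero weights so that the subsequent fine-tuning cannot make those weights nonzero again. Below is a minimal standalone sketch of that pattern on a toy static-graph program with a single 1x1 convolution; the network, executor, and variable names are illustrative, and it assumes a PaddleSlim build that provides set_static_masks (which this patch depends on).

import paddle
from paddleslim.prune.unstructured_pruner import UnstructuredPruner

paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 8, 16, 16], dtype='float32')
    y = paddle.static.nn.conv2d(x, num_filters=8, filter_size=1)  # a 1x1 conv

place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)  # initialize the (here dense) parameters

# Constructor arguments mirror the diff; they do not prune anything by themselves.
pruner = UnstructuredPruner(
    main_prog,
    mode='ratio',
    ratio=0.75,
    prune_params_type='conv1x1_only',
    place=place)
pruner.set_static_masks()  # freeze the current zero pattern during later fine-tuning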