Commit 9c21490

fix when params_filename is None (PaddlePaddle#1106)
1 parent 410d90e
File tree

4 files changed: +27 −9 lines changed


paddleslim/auto_compression/auto_strategy.py

Lines changed: 18 additions & 6 deletions

@@ -71,14 +71,26 @@ def create_strategy_config(strategy_str, model_type):
 
     dis_config = Distillation()
     if len(tmp_s) == 3:
+        ### TODO(ceci3): choose prune algo automatically
+        if 'prune' in tmp_s[0]:
+            ### default prune config
+            default_prune_config = {
+                'pruned_ratio': float(tmp_s[1]),
+                'prune_algo': 'prune',
+                'criterion': 'l1_norm'
+            }
+        else:
+            ### default unstruture prune config
+            default_prune_config = {
+                'prune_strategy':
+                'gmp',  ### default unstruture prune strategy is gmp
+                'prune_mode': 'ratio',
+                'pruned_ratio': float(tmp_s[1]),
+                'local_sparsity': True,
+                'prune_params_type': 'conv1x1_only'
+            }
         tmp_s[0] = tmp_s[0].replace('prune', 'Prune')
         tmp_s[0] = tmp_s[0].replace('sparse', 'UnstructurePrune')
-        ### TODO(ceci3): auto choose prune algo
-        default_prune_config = {
-            'pruned_ratio': float(tmp_s[1]),
-            'prune_algo': 'prune',
-            'criterion': 'l1_norm'
-        }
         if model_type == 'transformer' and tmp_s[0] == 'Prune':
             default_prune_config['prune_algo'] = 'transformer_pruner'
         prune_config = eval(tmp_s[0])(**default_prune_config)
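
For context, a minimal runnable sketch of the dispatch this hunk introduces, assuming a strategy string of the form '<algo>_<ratio>_<quant>' as implied by the surrounding code; the helper name pick_default_prune_config is hypothetical, not part of PaddleSlim:

# Hypothetical helper mirroring the new branch: a 'prune*' strategy gets a
# structured-pruning config, anything else (e.g. 'sparse') gets the
# unstructured GMP config. Keys are copied from the diff above.
def pick_default_prune_config(strategy_str):
    tmp_s = strategy_str.split('_')
    if 'prune' in tmp_s[0]:
        # structured pruning: l1_norm criterion, ratio parsed from the string
        return {
            'pruned_ratio': float(tmp_s[1]),
            'prune_algo': 'prune',
            'criterion': 'l1_norm'
        }
    # unstructured pruning: GMP schedule over 1x1 convolutions only
    return {
        'prune_strategy': 'gmp',
        'prune_mode': 'ratio',
        'pruned_ratio': float(tmp_s[1]),
        'local_sparsity': True,
        'prune_params_type': 'conv1x1_only'
    }

print(pick_default_prune_config('prune_0.3_int8')['prune_algo'])        # prune
print(pick_default_prune_config('sparse_0.75_int8')['prune_strategy'])  # gmp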

paddleslim/auto_compression/compressor.py

Lines changed: 4 additions & 0 deletions

@@ -97,7 +97,11 @@ def __init__(self,
             deploy_hardware(str, optional): The hardware you want to deploy. Default: 'gpu'.
         """
         self.model_dir = model_dir
+        if model_filename == 'None':
+            model_filename = None
         self.model_filename = model_filename
+        if params_filename == 'None':
+            params_filename = None
         self.params_filename = params_filename
         base_path = os.path.basename(os.path.normpath(save_dir))
         parent_path = os.path.abspath(os.path.join(save_dir, os.pardir))
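
The pattern behind both additions is worth spelling out: filenames read from a YAML config or command line often arrive as the literal string 'None' rather than Python's None, and downstream model-loading code would then look for a file literally named "None". A minimal sketch of the normalization; the helper name is illustrative, not PaddleSlim API:

# Map the string 'None' back to a real None so "no separate params file"
# keeps its intended meaning instead of becoming a bogus filename.
def normalize_filename(name):
    return None if name == 'None' else name

assert normalize_filename('None') is None
assert normalize_filename('model.pdmodel') == 'model.pdmodel'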

paddleslim/auto_compression/create_compressed_program.py

Lines changed: 2 additions & 1 deletion

@@ -100,7 +100,8 @@ def _load_program_and_merge(executor,
                             feed_target_names=None):
     scope = paddle.static.global_scope()
     new_scope = paddle.static.Scope()
-    print(model_dir, model_filename, params_filename)
+    if params_filename == 'None':
+        params_filename = None
     try:
         with paddle.static.scope_guard(new_scope):
             [teacher_program, teacher_feed_target_names, teacher_fetch_targets]= paddle.fluid.io.load_inference_model( \
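
Why the guard matters at this call site: in the legacy fluid API, params_filename=None tells load_inference_model to read each parameter from its own file under the model directory, while a string names one merged parameters file, so the string 'None' would be misread as a merged file's name. A hedged sketch with placeholder paths, call shape per the legacy paddle.fluid.io API:

import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

params_filename = 'None'      # e.g. parsed verbatim from a config file
if params_filename == 'None':
    params_filename = None    # fall back to per-parameter files

# Placeholder directory; uncomment against a real exported model:
# [prog, feed_names, fetch_targets] = paddle.fluid.io.load_inference_model(
#     dirname='path/to/model', executor=exe,
#     model_filename='model.pdmodel', params_filename=params_filename)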

paddleslim/auto_compression/strategy_config.py

Lines changed: 3 additions & 2 deletions

@@ -34,8 +34,9 @@
     "weight_quantize_type"
 ])
 
-Quantization.__new__.__defaults__ = (None, ) * (len(Quantization._fields) - 1
-                                                ) + (False, )
+Quantization.__new__.__defaults__ = (None, ) * (
+    len(Quantization._fields) - 3) + (False, 'moving_average_abs_max',
+                                      'channel_wise_abs_max')
 
 ### Distillation:
 Distillation = namedtuple(
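
The line being changed uses a classic namedtuple trick: assigning a tuple to __new__.__defaults__ supplies defaults for the trailing fields, so every field defaults to None except the last three. A self-contained sketch with a shortened, illustrative field list (the real Quantization namedtuple declares more fields):

from collections import namedtuple

# Shortened field list for illustration; only the trailing three matter here.
Quantization = namedtuple('Quantization', [
    'quantize_op_types', 'weight_bits', 'is_full_quantize',
    'activation_quantize_type', 'weight_quantize_type'
])

# Defaults bind right-to-left: everything None except the last three, which
# now get False / 'moving_average_abs_max' / 'channel_wise_abs_max'; the old
# code only defaulted the final field to False.
Quantization.__new__.__defaults__ = (None, ) * (
    len(Quantization._fields) - 3) + (False, 'moving_average_abs_max',
                                      'channel_wise_abs_max')

q = Quantization()
print(q.weight_bits)               # None
print(q.is_full_quantize)          # False
print(q.activation_quantize_type)  # moving_average_abs_max
print(q.weight_quantize_type)      # channel_wise_abs_max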

0 commit comments
