File tree Expand file tree Collapse file tree 1 file changed +10
-8
lines changed
Expand file tree Collapse file tree 1 file changed +10
-8
lines changed Original file line number Diff line number Diff line change @@ -138,21 +138,23 @@ void backend_set_n_threads(backend_device& b, int n_threads) {
//
// model_build_flags

141- model_build_flags flash_attn_flag () {
142- static model_build_flags const flag = []() {
143- char const * env = getenv (" VISP_NO_FLASH_ATTENTION" );
144- return !env || env[0 ] == ' 0' ? model_build_flag::flash_attention : model_build_flags{};
145- }();
146- return flag;
141+ model_build_flags flash_attn_flag (bool default_enabled) {
142+ static char const * const env = getenv (" VISP_FLASH_ATTENTION" );
143+ if (env && env[0 ] == ' 1' ) {
144+ return model_build_flag::flash_attention;
145+ } else if (env && env[0 ] == ' 0' ) {
146+ return model_build_flags{};
147+ }
148+ return default_enabled ? model_build_flag::flash_attention : model_build_flags{};
147149}
148150
149151model_build_flags backend_default_flags (backend_type type) {
150152 using enum model_build_flag;
151153 switch (type) {
152154 case backend_type::cpu:
153155 return conv_2d_direct_cwhn | concat_n | f16_conv_transpose | window_partition |
154- flash_attn_flag ();
155- case backend_type::gpu: return flash_attn_flag ();
156+ flash_attn_flag (false );
157+ case backend_type::gpu: return flash_attn_flag (true );
156158 }
157159 return {};
158160}
You can’t perform that action at this time.
0 commit comments