Commit 57b1d0b

ml: only use flash attention on gpu by default
* env VISP_FLASH_ATTENTION=0 always disables it
* env VISP_FLASH_ATTENTION=1 always enables it
* all other values use the default
1 parent 254f7d5 commit 57b1d0b
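
For reference, a minimal standalone sketch of the environment-variable resolution described above. The names resolve_flash_attention and flash_attention_mode are illustrative stand-ins, not the project's actual identifiers; the real implementation is flash_attn_flag in src/visp/ml.cpp, shown in the diff below.

#include <cstdio>
#include <cstdlib>

// Illustrative stand-in for the project's flag type; the real code uses
// model_build_flags / model_build_flag.
enum class flash_attention_mode { enabled, disabled };

// VISP_FLASH_ATTENTION=1 forces flash attention on, =0 forces it off,
// any other value (or unset) falls back to the per-backend default.
flash_attention_mode resolve_flash_attention(bool default_enabled) {
    char const* env = std::getenv("VISP_FLASH_ATTENTION");
    if (env && env[0] == '1') return flash_attention_mode::enabled;
    if (env && env[0] == '0') return flash_attention_mode::disabled;
    return default_enabled ? flash_attention_mode::enabled
                           : flash_attention_mode::disabled;
}

int main() {
    // GPU backends default to enabled, CPU backends to disabled,
    // matching the commit description.
    bool gpu_on = resolve_flash_attention(/*default_enabled=*/true) == flash_attention_mode::enabled;
    bool cpu_on = resolve_flash_attention(/*default_enabled=*/false) == flash_attention_mode::enabled;
    std::printf("flash attention: gpu=%d cpu=%d\n", gpu_on, cpu_on);
    return 0;
}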

File tree

1 file changed, +10 -8 lines changed

src/visp/ml.cpp

Lines changed: 10 additions & 8 deletions
@@ -138,21 +138,23 @@ void backend_set_n_threads(backend_device& b, int n_threads) {
 //
 // model_build_flags
 
-model_build_flags flash_attn_flag() {
-    static model_build_flags const flag = []() {
-        char const* env = getenv("VISP_NO_FLASH_ATTENTION");
-        return !env || env[0] == '0' ? model_build_flag::flash_attention : model_build_flags{};
-    }();
-    return flag;
+model_build_flags flash_attn_flag(bool default_enabled) {
+    static char const* const env = getenv("VISP_FLASH_ATTENTION");
+    if (env && env[0] == '1') {
+        return model_build_flag::flash_attention;
+    } else if (env && env[0] == '0') {
+        return model_build_flags{};
+    }
+    return default_enabled ? model_build_flag::flash_attention : model_build_flags{};
 }
 
 model_build_flags backend_default_flags(backend_type type) {
     using enum model_build_flag;
     switch (type) {
     case backend_type::cpu:
         return conv_2d_direct_cwhn | concat_n | f16_conv_transpose | window_partition |
-               flash_attn_flag();
-    case backend_type::gpu: return flash_attn_flag();
+               flash_attn_flag(false);
+    case backend_type::gpu: return flash_attn_flag(true);
     }
     return {};
 }
