@@ -709,18 +709,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
 
     float scale = (1.0f / sqrt((float)d_head));
 
-    // if (flash_attn) {
-    //     LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
-    // }
-    // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
+    // if (flash_attn) {
+    //     LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
+    // }
+    // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
     GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
 
     bool can_use_flash_attn = true;
-    can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
-    can_use_flash_attn = can_use_flash_attn && d_head % 64 == 0;  // double check
+    can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
+    can_use_flash_attn = can_use_flash_attn && d_head % 64 == 0;  // double check
 
     // cuda max d_head seems to be 256, cpu does seem to work with 512
-    can_use_flash_attn = can_use_flash_attn && d_head <= 256;  // double check
+    can_use_flash_attn = can_use_flash_attn && d_head <= 256;  // double check
 
     if (mask != nullptr) {
         // TODO(Green-Sky): figure out if we can bend t5 to work too
@@ -731,9 +731,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     // TODO(Green-Sky): more pad or disable for funny tensor shapes
 
     ggml_tensor* kqv = nullptr;
-    // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
+    // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
     if (can_use_flash_attn && flash_attn) {
-        // LOG_DEBUG("using flash attention");
+        // LOG_DEBUG("using flash attention");
         k = ggml_cast(ctx, k, GGML_TYPE_F16);
 
         v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
@@ -743,7 +743,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
         kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0);
         ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
 
-        // kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0);
+        // kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0);
         kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_q, kqv->nb[1], kqv->nb[2], 0);
     } else {
         v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));  // [N, n_head, d_head, L_k]
@@ -761,8 +761,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
 
         kqv = ggml_mul_mat(ctx, v, kq);  // [N * n_head, L_q, d_head]
 
-        kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N);  // [N, n_head, L_q, d_head]
-        kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3);  // [N, L_q, n_head, d_head]
+        kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N);  // [N, n_head, L_q, d_head]
+        kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3);  // [N, L_q, n_head, d_head]
     }
 
     kqv = ggml_cont(ctx, kqv);
@@ -1057,7 +1057,7 @@ struct GGMLRunner {
         // get_desc().c_str(),
         // params_buffer_size / (1024.0 * 1024.0),
         // ggml_backend_is_cpu(backend) ? "RAM" : "VRAM",
-        // num_tensors);
+        // num_tensors);
         return true;
     }
 
@@ -1227,8 +1227,7 @@ class Linear : public UnaryBlock {
         params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features);
         if (bias) {
             params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_features);
-        }
-
+        }
     }
 
 public:
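
Aside on the first hunk: that is where the fused flash-attention path is gated on tensor shape before `ggml_flash_attn_ext` is ever called. A minimal, self-contained sketch of just that predicate, based only on the checks visible in the diff (the helper `can_use_flash_attn_for`, the `main` driver, and the example shapes are illustrative, not names or values from the repo):

```cpp
#include <cstdint>
#include <cstdio>

// Mirrors the runtime checks shown in ggml_nn_attention_ext: the fused kernel
// is only attempted when the kv length is a multiple of 256 and the head
// dimension is a multiple of 64 and at most 256 (the in-code comment notes
// CUDA tops out around 256 while the CPU backend seems to handle 512).
static bool can_use_flash_attn_for(int64_t L_k, int64_t d_head) {
    bool ok = true;
    ok = ok && L_k % 256 == 0;
    ok = ok && d_head % 64 == 0;
    ok = ok && d_head <= 256;
    return ok;
}

int main() {
    // Example: 4096-token self-attention with 64-dim heads -> gate passes
    printf("%d\n", can_use_flash_attn_for(4096, 64));  // prints 1
    // Example: cross-attention against a 77-token text context -> gate fails
    printf("%d\n", can_use_flash_attn_for(77, 64));     // prints 0
    return 0;
}
```

When the gate fails (or `flash_attn` was not requested), execution takes the `else` branch shown in the later hunks, i.e. the explicit permute/`ggml_mul_mat`/reshape path instead of the fused kernel.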