@@ -86,10 +86,7 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) {
8686 for (int i = 0 ; i < config .numberOfLayers (); i ++) {
8787 // === Attention Block ===
8888 gridScheduler .addWorkerGrid ("layer_" + i + ".attn_rms_reduce" , rmsNormWorker );
89- // gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_matmul", fusedQkvWorker);
9089 gridScheduler .addWorkerGrid ("layer_" + i + ".attn_rms_qkv_projection" , fusedQkvWorker );
91-
92- // gridScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker);
9390 gridScheduler .addWorkerGrid ("layer_" + i + ".rope_and_kv_cache" , ropeWorker );
9491 gridScheduler .addWorkerGrid ("layer_" + i + ".attention" , parallelAttentionWorker );
9592 gridScheduler .addWorkerGrid ("layer_" + i + ".attn_output_proj" , matmul1Worker );
@@ -261,30 +258,6 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) {
261258 phi3Config .rmsNormEps (), // epsilon
262259 phi3State .localSize ); // local memory size
263260
264- // // Fused RMS Apply + QKV Projection (combined matrix)
265- // unifiedLayer.task("attn_rms_qkv_matmul",
266- // Phi3Kernels::fusedRmsNormMatmul,
267- // context,
268- // phi3State.wrapX, // input: raw hidden state (FP32)
269- // phi3State.wrapQkv, // output: combined Q+K+V
270- // weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS weights
271- // phi3State.temp, // RMS scale factor from reduction
272- // weights.wqkvLayered[layerIndex].asHalfFloatArray(), // Wqkv [opSize × dim]
273- // phi3Config.dim(), // input dimension
274- // opSize, // output dimension (Q + K + V)
275- // LOCAL_WORK_GROUP_SIZE_ALLOC);
276- //
277- // // Split combined QKV into separate Q, K, V buffers
278- // unifiedLayer.task("splitQKV",
279- // TransformerComputeKernelsLayered::splitQKV,
280- // phi3State.wrapQkv,
281- // phi3State.wrapQ,
282- // phi3State.wrapK,
283- // phi3State.wrapV,
284- // phi3Config.dim(),
285- // phi3Config.headSize() * phi3Config.numberOfKeyValueHeads());
286-
287- // AFTER: 1 task
288261 unifiedLayer .task ("attn_rms_qkv_projection" ,
289262 Phi3Kernels ::fusedRmsNormQKVMatmulDirect ,
290263 context ,
0 commit comments