set_paths (2 changes: 1 addition & 1 deletion)
@@ -22,4 +22,4 @@ echo "[INFO] Environment configured for LLaMA3 with TornadoVM at: $TORNADO_SDK"
# 3. You can run LLaMA3 with GPU acceleration using TornadoVM
#
# To use this script: source ./setup_environment.sh
-# or: . ./setup_environment.sh
+# or: . ./setup_environment.sh
TransformerComputeKernelsLayered.java
@@ -284,6 +284,7 @@ public static void fusedRmsNormFFNGateUpQ8_0(
* @param localMemSize
* Size of local memory allocation (must match work group size)
*/

public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, FloatArray x, int size, float ermsNorm, int localMemSize) {
int gid = context.globalIdx;
int lid = context.localIdx;
@@ -331,20 +332,170 @@ public static void reductionOneBlockWithLayer(KernelContext context, FloatArray
}

/**
* Performs fused RMS (Root Mean Square) normalization using parallel reduction. It first computes the mean of squares and the scaling factor across all work groups,
* then applies the computed normalization factor to the input and weight elements.
*
* <p>
* Formula: output[i] = weight[i] * (normalizationFactor * x[i])
*
* Algorithm: 1. Each thread computes the square of its input element 2. Each work group performs a parallel reduction of the squares 3. Partial sums are stored per work group 4. All threads
* combine the partial sums and compute the normalization factor 5. The normalization factor is applied to the input and weight elements.
*
* @param context
* Kernel execution context
* @param output
* Array for normalized output
* @param x
* Input array to normalize
* @param weights
* Weight values for each element
* @param temp
* Temporary array storing one partial sum per work group
* @param size
* Number of elements to process
* @param ermsNorm
* Epsilon value squared for numerical stability
* @param localMemSize
* Size of local memory allocation (must match work group size)
*/

public static void reductionOneBlockWithLayerFuse(KernelContext context, FloatArray output, FloatArray x, FloatArray weights, FloatArray temp, int size, float ermsNorm, int localMemSize) {
int gid = context.globalIdx;
int lid = context.localIdx;
int groupId = context.groupIdx;
int groupSize = context.localGroupSizeX;

// Allocate local memory with the provided size
float[] localX = context.allocateFloatLocalArray(localMemSize);

// Load input value and compute square
if (gid < size) {
float v = x.get(gid);
localX[lid] = v * v;
} else {
localX[lid] = 0.0f;
}

// Perform parallel reduction within the work group
for (int stride = (groupSize / 2); stride > 0; stride /= 2) {
context.localBarrier();
if (lid < stride) {
localX[lid] += localX[lid + stride];
}
}

// Each workgroup stores its partial sum in a different location
if (lid == 0) {
// Store the partial sum from each workgroup
temp.set(groupId, localX[0]);
}

context.globalBarrier();

float localss = 0.0f;
int numGroups = (size + groupSize - 1) / groupSize;
for (int i = 0; i < numGroups; i++) { // Combine the partial sums from all work groups
localss += temp.get(i);
}
localss /= size;
localss += ermsNorm;
localss = 1.0f / TornadoMath.sqrt(localss);

if (gid < size) {
float in = x.get(gid);
float w = weights.get(gid);
output.set(gid, w * (localss * in));
}
}
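For orientation, the following is a minimal host-side sketch of how a fused kernel like this is typically wired into a TornadoVM task graph. The graph name, task name, sizes, and the `FusedRmsNormExample` class are illustrative assumptions, not code from this PR.

```java
// Illustrative host-side wiring for the fused RMS-norm kernel (assumed names and sizes).
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

public class FusedRmsNormExample {
    public static void main(String[] args) {
        int size = 4096;        // model dimension (illustrative)
        int localSize = 256;    // must match the localMemSize passed to the kernel
        FloatArray x = new FloatArray(size);
        FloatArray weights = new FloatArray(size);
        FloatArray output = new FloatArray(size);
        FloatArray temp = new FloatArray((size + localSize - 1) / localSize); // one slot per work group
        x.init(1.0f);
        weights.init(1.0f);

        KernelContext context = new KernelContext();
        TaskGraph graph = new TaskGraph("norm")
                .transferToDevice(DataTransferMode.FIRST_EXECUTION, x, weights, temp)
                .task("rms", TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuse,
                        context, output, x, weights, temp, size, 1e-5f, localSize)
                .transferToHost(DataTransferMode.EVERY_EXECUTION, output);

        WorkerGrid worker = new WorkerGrid1D(size);   // one thread per element
        worker.setLocalWork(localSize, 1, 1);         // work-group width == localMemSize
        GridScheduler scheduler = new GridScheduler("norm.rms", worker);

        ImmutableTaskGraph itg = graph.snapshot();
        new TornadoExecutionPlan(itg).withGridScheduler(scheduler).execute();
    }
}
```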

/**
* Performs fused RMS (Root Mean Square) normalization using parallel reduction, writing the result as half floats. It first computes the mean of squares and the scaling factor across all
* work groups, then applies the computed normalization factor to the input and weight elements.
*
* <p>
* Formula: output[i] = weight[i] * (normalizationFactor * x[i])
*
* Algorithm: 1. Each thread computes the square of its input element 2. Each work group performs a parallel reduction of the squares 3. Partial sums are stored per work group 4. All threads
* combine the partial sums and compute the normalization factor 5. The normalization factor is applied to the input and weight elements.
*
* @param context
* Kernel execution context
* @param outputFP16
* Half-float array for normalized output
* @param x
* Input array to normalize
* @param weights
* Weight values for each element
* @param temp
* Temporary array storing one partial sum per work group
* @param size
* Number of elements to process
* @param ermsNorm
* Epsilon value squared for numerical stability
* @param localMemSize
* Size of local memory allocation (must match work group size)
*/

public static void reductionOneBlockWithLayerFuseFP16(KernelContext context, HalfFloatArray outputFP16, FloatArray x, FloatArray weights, FloatArray temp, int size, float ermsNorm, int localMemSize) {
int gid = context.globalIdx;
int lid = context.localIdx;
int groupId = context.groupIdx;
int groupSize = context.localGroupSizeX;

// Allocate local memory with the provided size
float[] localX = context.allocateFloatLocalArray(localMemSize);

// Load input value and compute square
if (gid < size) {
float v = x.get(gid);
localX[lid] = v * v;
} else {
localX[lid] = 0.0f;
}

// Perform parallel reduction within the work group
for (int stride = (groupSize / 2); stride > 0; stride /= 2) {
context.localBarrier();
if (lid < stride) {
localX[lid] += localX[lid + stride];
}
}

// Each workgroup stores its partial sum in a different location
if (lid == 0) {
// Store the partial sum from each workgroup
temp.set(groupId, localX[0]);
}

context.globalBarrier();

float localss = 0.0f;
int numGroups = (size + groupSize - 1) / groupSize;
for (int i = 0; i < numGroups; i++) { // Combine the partial sums from all work groups
localss += temp.get(i);
}
localss /= size;
localss += ermsNorm;
localss = 1.0f / TornadoMath.sqrt(localss);

if (gid < size) {
float in = x.get(gid);
float w = weights.get(gid);
outputFP16.set(gid, new HalfFloat(w * (localss * in)));
}
}
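Both fused kernels compute the same arithmetic as the sequential reference below, which a unit test could compare against. The `RmsNormReference` class is our own sketch, not part of the PR.

```java
// Sequential reference: output[i] = weight[i] * x[i] / sqrt(mean(x^2) + eps).
public final class RmsNormReference {
    public static float[] rmsNorm(float[] x, float[] weights, float ermsNorm) {
        float sumOfSquares = 0.0f;
        for (float v : x) {
            sumOfSquares += v * v;
        }
        float norm = 1.0f / (float) Math.sqrt(sumOfSquares / x.length + ermsNorm);
        float[] out = new float[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = weights[i] * (norm * x[i]); // same order of operations as the kernels
        }
        return out;
    }
}
```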


/**
* Applies the computed normalization factor to input and weight elements. This is the second phase of RMS normalization.
* <p>
* Formula: output[i] = weight[i] * (normalizationFactor * x[i])
*
* @param context Kernel execution context
* @param output Array for normalized output
* @param x Input values to normalize
* @param weights Weight values for each element
* @param temp Temporary array containing normalization factor at index 0
*/
public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray output, FloatArray x, FloatArray weights, FloatArray temp) {
int gid = context.globalIdx;
@@ -355,25 +506,17 @@ public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray

/**
* Copies keys and values into the key-value cache for attention computation. Enables efficient access to past key-value pairs during autoregressive generation.
*
* <p>
* Cache layout: [layer][position][dimension] - Each layer has its own key and value cache - Each position in sequence has a key and value vector
*
* @param destKeyCache Destination array for key cache
* @param srcKey Source keys to copy
* @param destValueCache Destination array for value cache
* @param srcValue Source values to copy
* @param positioNlayer Array containing current position
* @param kvDim Dimension of key/value vectors
* @param layer Current transformer layer index
* @param contextLength Maximum sequence length
*/
public static void copyToCache(FloatArray destKeyCache, FloatArray srcKey, FloatArray destValueCache, FloatArray srcValue, IntArray positioNlayer, int kvDim, int layer, int contextLength) {
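The method body is collapsed in this diff. Under the [layer][position][dimension] layout described above, the copy reduces to indexing along these lines (a sketch of the implied logic, not the file's exact code):

```java
// Sketch of the implied cache indexing (illustrative; the real kernel body is elided above).
int pos = positioNlayer.get(0);                 // current sequence position
int loff = layer * contextLength * kvDim;       // start of this layer's cache region
for (int i = 0; i < kvDim; i++) {
    destKeyCache.set(loff + pos * kvDim + i, srcKey.get(i));
    destValueCache.set(loff + pos * kvDim + i, srcValue.get(i));
}
```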

@@ -463,21 +606,15 @@ public static void splitQKV(FloatArray qkv, FloatArray q, FloatArray k, FloatArr
/**
* Applies Rotary Position Encoding (RoPE) to query and key vectors. RoPE rotates pairs of dimensions based on their position in the sequence, enabling the model to learn relative positional
* information.
*
* <p>
* For each pair of dimensions (2*i, 2*i+1): - Compute rotation angle based on position and frequency - Apply 2D rotation to the pair
*
* @param context Kernel execution context
* @param positionHolder Array containing current position
* @param sq Query vectors to rotate
* @param sk Key vectors to rotate
* @param kv_dim Dimension of key/value vectors
* @param head_size Dimension of each attention head
*/
public static void ropeRotation(KernelContext context, IntArray positionHolder, FloatArray sq, FloatArray sk, int kv_dim, int head_size) {
int i = context.globalIdx * 2;
@@ -552,31 +689,20 @@ public static void ropeRotationPhi3(KernelContext context, IntArray positionHold
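To make the pair rotation described above concrete, here is a plain-Java sketch of the rotation applied to one dimension pair, assuming the conventional base frequency of 10000; the helper name and array-based signature are ours.

```java
// Plain-Java sketch of the RoPE pair rotation (base frequency 10000 assumed).
static void rotatePair(float[] vec, int i, int position, int head_size) {
    int head_dim = i % head_size;                       // index of the pair inside its head
    float freq = 1.0f / (float) Math.pow(10000.0f, head_dim / (float) head_size);
    float val = position * freq;                        // rotation angle for this position
    float fcr = (float) Math.cos(val);
    float fci = (float) Math.sin(val);
    float v0 = vec[i];
    float v1 = vec[i + 1];
    vec[i] = v0 * fcr - v1 * fci;                       // standard 2D rotation of (2*i, 2*i+1)
    vec[i + 1] = v0 * fci + v1 * fcr;
}
```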

/**
* Computes attention for a single head. Implements scaled dot-product attention with softmax normalization.
*
* <p>
* Steps: 1. Compute attention scores: Q·K / sqrt(head_size) 2. Apply softmax (with max subtraction for numerical stability) 3. Compute weighted sum of values
*
* @param allQ All query vectors
* @param key_cache Cached keys
* @param value_cache Cached values
* @param allXb Output buffer
* @param h Head index to process
* @param headSize Dimension per head
* @param kvDim Key/value dimension
* @param kvMul Number of query heads per key/value head (grouped-query attention)
* @param loff Layer offset in cache
* @param pos Current position
* @param wrapAtt Attention weights buffer
*/
private static void processHeadTornado(FloatArray allQ, FloatArray key_cache, FloatArray value_cache, FloatArray allXb, int h, int headSize, int kvDim, int kvMul, long loff, int pos,
FloatArray wrapAtt) {
@@ -1117,23 +1243,16 @@ public static void processHeadsFlashAttentionOpt(KernelContext context, FloatArr
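For reference, the three steps listed in the attention javadoc above correspond to this sequential single-head sketch; it is our simplification with flat arrays, while the actual kernels parallelize and tile these loops.

```java
// Sequential single-head attention sketch: scores -> softmax -> weighted sum (illustrative).
static void attendOneHead(float[] q, float[] keys, float[] values, float[] out, int headSize, int pos) {
    float[] att = new float[pos + 1];
    float max = Float.NEGATIVE_INFINITY;
    for (int t = 0; t <= pos; t++) {                 // 1. scores: Q.K / sqrt(headSize)
        float score = 0.0f;
        for (int j = 0; j < headSize; j++) {
            score += q[j] * keys[t * headSize + j];
        }
        att[t] = score / (float) Math.sqrt(headSize);
        max = Math.max(max, att[t]);
    }
    float sum = 0.0f;
    for (int t = 0; t <= pos; t++) {                 // 2. softmax with max subtraction
        att[t] = (float) Math.exp(att[t] - max);
        sum += att[t];
    }
    for (int j = 0; j < headSize; j++) {             // 3. weighted sum of values
        float acc = 0.0f;
        for (int t = 0; t <= pos; t++) {
            acc += (att[t] / sum) * values[t * headSize + j];
        }
        out[j] = acc;
    }
}
```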

/**
* Performs optimized matrix-vector multiplication where each work group processes one row of the matrix.
*
* <p>
* Algorithm: 1. Each work group handles one output dimension 2. Threads in the work group compute partial dot products 3. Parallel reduction yields the final row result
*
* @param context Kernel execution context
* @param x Input vector
* @param hb Output vector
* @param w Weight matrix (row-major)
* @param n Input dimension
* @param d Output dimension
* @param localWorkGroupSize Number of threads per work group
*/
public static void matrixVectorGeneric(KernelContext context, FloatArray x, FloatArray hb, FloatArray w, int n, int d, int localWorkGroupSize) {
// One row per workgroup (not per thread)
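The kernel body is elided in the diff; the simplified kernel below shows the one-row-per-work-group pattern the javadoc describes, in the same KernelContext style, as a sketch under the stated row-major layout rather than the file's actual implementation.

```java
// Simplified work-group-per-row matrix-vector kernel in the style described above.
public static void matVecRowPerGroup(KernelContext context, FloatArray x, FloatArray out, FloatArray w, int n, int localWorkGroupSize) {
    int row = context.groupIdx;                      // one output row per work group
    int lid = context.localIdx;
    float[] partial = context.allocateFloatLocalArray(localWorkGroupSize);

    float acc = 0.0f;
    for (int col = lid; col < n; col += localWorkGroupSize) {
        acc += w.get(row * n + col) * x.get(col);    // strided partial dot product
    }
    partial[lid] = acc;

    // Parallel reduction of the partial dot products within the work group
    for (int stride = localWorkGroupSize / 2; stride > 0; stride /= 2) {
        context.localBarrier();
        if (lid < stride) {
            partial[lid] += partial[lid + stride];
        }
    }
    if (lid == 0) {
        out.set(row, partial[0]);                    // final result for this row
    }
}
```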
@@ -50,7 +50,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
for (int i = 0; i < config.numberOfLayers(); i++) {
// === Attention Block ===
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker);
-tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply_fp16", rmsNormWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQKVWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
@@ -199,21 +198,10 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
// === Attention Block ===
// RMS Normalization
unifiedLayer.task("attn_rms_reduce",
-TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
-context, state.temp, state.wrapX,
+TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuseFP16,
+context, state.wrapXbFP16, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp,
config.dim(), config.rmsNormEps(), state.localSize);

-if (shouldUseFinalNormalization()) {
-unifiedLayer.task("attn_rms_finalize",
-TransformerComputeKernelsLayered::reductionFinalNormalization,
-context, state.temp, config.dim(), config.rmsNormEps());
-}
-
-unifiedLayer.task("attn_rms_apply_fp16",
-TransformerComputeKernels::mapContextWithQuantize,
-context, state.wrapXbFP16, state.wrapX,
-weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp);

// QKV Projection (fused)
unifiedLayer.task("qkv_projection",
TransformerComputeKernelsLayered::fusedQKVMatmulX,
@@ -161,21 +161,10 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
// === Attention Block ===
// RMS Normalization
unifiedLayer.task("attn_rms_reduce",
-TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
-context, state.temp, state.wrapX,
+TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuse,
+context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp,
config.dim(), config.rmsNormEps(), state.localSize);

-if (shouldUseFinalNormalization()) {
-unifiedLayer.task("attn_rms_finalize",
-TransformerComputeKernelsLayered::reductionFinalNormalization,
-context, state.temp, config.dim(), config.rmsNormEps());
-}
-
-unifiedLayer.task("attn_rms_apply",
-TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
-context, state.wrapXb, state.wrapX,
-weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp);

// QKV Projection (fused with Q8 dequantization)
unifiedLayer.task("qkv_projection",
TransformerComputeKernelsLayered::fusedQKVMatmulQ8,
@@ -306,7 +295,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
// --- Attention Block ---
// RMS Normalization
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker);
-tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply", rmsNormWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQkvWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
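One constraint worth keeping in mind when editing these planners: the local work size registered for `attn_rms_reduce` must equal the `localMemSize` argument passed to the fused kernel, since the kernel allocates its local array with that size. A minimal sketch of such a registration helper follows; the class and method names are hypothetical.

```java
// Hypothetical helper: registers the fused RMS-norm task with a work group
// whose local size matches the localMemSize argument given to the kernel.
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;

final class RmsNormGridSetup {
    static void addRmsNormWorker(GridScheduler scheduler, int layer, int dim, int localSize) {
        WorkerGrid worker = new WorkerGrid1D(dim);   // one thread per element of the input
        worker.setLocalWork(localSize, 1, 1);        // must equal localMemSize in the kernel
        scheduler.addWorkerGrid("layer_" + layer + ".attn_rms_reduce", worker);
    }
}
```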