diff --git a/dlib/cuda/cpu_dlib.cpp b/dlib/cuda/cpu_dlib.cpp
index ac0ace4d79..3585693767 100644
--- a/dlib/cuda/cpu_dlib.cpp
+++ b/dlib/cuda/cpu_dlib.cpp
@@ -1494,7 +1494,6 @@ namespace dlib
             }
             p_scale[n] = 1.0f / std::sqrt(p_scale[n] / (ks * num) + static_cast<float>(eps));
         }
-        scale.host();
 
         // Apply RMS normalization
         p_src = src.host();
@@ -1648,14 +1647,22 @@ namespace dlib
                 for (long k = 0; k < num_channels; ++k)
                     max_val = std::max(max_val, ss[k * num_locations]);
 
-                float sum = 0.0f;
-                for (long k = 0; k < num_channels; ++k)
+                if (max_val == -std::numeric_limits<float>::infinity())
                 {
-                    dd[k * num_locations] = std::exp(ss[k * num_locations] - max_val);
-                    sum += dd[k * num_locations];
+                    for (long k = 0; k < num_channels; ++k)
+                        dd[k * num_locations] = 0.0f;
+                }
+                else
+                {
+                    float sum = 0.0f;
+                    for (long k = 0; k < num_channels; ++k)
+                    {
+                        dd[k * num_locations] = std::exp(ss[k * num_locations] - max_val);
+                        sum += dd[k * num_locations];
+                    }
+                    for (long k = 0; k < num_channels; ++k)
+                        dd[k * num_locations] /= sum;
                 }
-                for (long k = 0; k < num_channels; ++k)
-                    dd[k * num_locations] /= sum;
 
                 ++ss;
                 ++dd;
@@ -3366,6 +3373,69 @@ namespace dlib
         }
     }
 
+    // ------------------------------------------------------------------------------------
+
+    void apply_rotary_positional_embedding(
+        bool is_backward,
+        resizable_tensor& data,
+        const resizable_tensor& cos_cache,
+        const resizable_tensor& sin_cache)
+    {
+        const long batch_size = data.num_samples();
+        const long num_heads = data.k();
+        const long seq_len = data.nr();
+        const long d_head = data.nc();
+        const long half_d = d_head / 2;
+
+        DLIB_CASSERT(cos_cache.nr() == seq_len, "cos_cache rows must match seq_len");
+        DLIB_CASSERT(cos_cache.nc() == half_d, "cos_cache cols must be d_head/2");
+        DLIB_CASSERT(sin_cache.nr() == seq_len, "sin_cache rows must match seq_len");
+        DLIB_CASSERT(sin_cache.nc() == half_d, "sin_cache cols must be d_head/2");
+
+        const bool is_odd = (d_head % 2 != 0);
+        const long rot_dim = is_odd ? d_head - 1 : d_head;
+
+        float* data_ptr = data.host();
+        const float* cos_ptr = cos_cache.host();
+        const float* sin_ptr = sin_cache.host();
+
+        const size_t total_elements = batch_size * num_heads * seq_len * half_d;
+
+        parallel_for(0, total_elements, [&](long idx)
+        {
+            const long pair_idx = idx % half_d;
+            const long pos = (idx / half_d) % seq_len;
+            const long head = (idx / (half_d * seq_len)) % num_heads;
+            const long batch = idx / (half_d * seq_len * num_heads);
+
+            const long dim_i = pair_idx * 2;
+            if (dim_i >= rot_dim) return;
+
+            const long data_offset = ((batch * num_heads + head) * seq_len + pos) * d_head + dim_i;
+            const long trig_offset = pos * half_d + pair_idx;
+
+            const float c = cos_ptr[trig_offset];
+            const float s = sin_ptr[trig_offset];
+            const float x0 = data_ptr[data_offset];
+            const float x1 = data_ptr[data_offset + 1];
+
+            if (!is_backward)
+            {
+                // Forward: [cos -sin] [x0]
+                //          [sin  cos] [x1]
+                data_ptr[data_offset]     = x0 * c - x1 * s;
+                data_ptr[data_offset + 1] = x0 * s + x1 * c;
+            }
+            else
+            {
+                // Backward (inverse rotation): [ cos  sin] [x0]
+                //                              [-sin  cos] [x1]
+                data_ptr[data_offset]     = x0 * c + x1 * s;
+                data_ptr[data_offset + 1] = -x0 * s + x1 * c;
+            }
+        });
+    }
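+
+    // Illustrative sketch only (not part of this patch's API): one way the
+    // cos/sin caches consumed above could be precomputed.  The base of 10000
+    // is the conventional RoPE choice and is an assumption of this example,
+    // not something the function above mandates.
+    //
+    //   const long half_d = d_head / 2;
+    //   cos_cache.set_size(1, 1, seq_len, half_d);
+    //   sin_cache.set_size(1, 1, seq_len, half_d);
+    //   for (long pos = 0; pos < seq_len; ++pos)
+    //       for (long i = 0; i < half_d; ++i)
+    //       {
+    //           const double theta = pos / std::pow(10000.0, (2.0 * i) / d_head);
+    //           cos_cache.host()[pos * half_d + i] = static_cast<float>(std::cos(theta));
+    //           sin_cache.host()[pos * half_d + i] = static_cast<float>(std::sin(theta));
+    //       }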
+
+    // ------------------------------------------------------------------------------------
 
 }
diff --git a/dlib/cuda/cpu_dlib.h b/dlib/cuda/cpu_dlib.h
index 4e29c8a8d9..1689ba0cf7 100644
--- a/dlib/cuda/cpu_dlib.h
+++ b/dlib/cuda/cpu_dlib.h
@@ -584,6 +584,15 @@ namespace dlib
             float scale_factor
         );
 
+    // -----------------------------------------------------------------------------------
+
+        void apply_rotary_positional_embedding(
+            bool is_backward,
+            resizable_tensor& data,
+            const resizable_tensor& cos_cache,
+            const resizable_tensor& sin_cache
+        );
+
     // -----------------------------------------------------------------------------------
 
         class pooling
@@ -761,6 +770,138 @@ namespace dlib
 
     // -----------------------------------------------------------------------------------
 
+        class compute_loss_cross_entropy_per_logit
+        {
+            /*!
+                Computes the cross-entropy loss for causal language modeling.
+                Every sequence position contributes one prediction: position t
+                (for t < seq_len-1) predicts the input token at position t+1,
+                and the final position predicts the label taken from truth.
+            !*/
+        public:
+            compute_loss_cross_entropy_per_logit() {}
+
+            template <typename const_label_iterator>
+            void operator()(
+                const_label_iterator truth,
+                const tensor& input_tensor,
+                const tensor& output_tensor,
+                tensor& grad,
+                double& loss,
+                long ignore_index
+            ) const
+            {
+                DLIB_CASSERT(output_tensor.k() == 1);
+                DLIB_CASSERT(input_tensor.k() == 1);
+                DLIB_CASSERT(input_tensor.nc() == 1);
+
+                const long batch_size = output_tensor.num_samples();
+                const long seq_len = output_tensor.nr();
+                const long vocab_size = output_tensor.nc();
+
+                const float* out_data = output_tensor.host();
+                const float* in_data = input_tensor.host();
+                float* g = grad.host();
+
+                std::fill(g, g + grad.size(), 0.0f);
+
+                long valid_tokens = 0;
+
+                if (ignore_index < 0)
+                {
+                    valid_tokens = batch_size * seq_len;
+                }
+                else
+                {
+                    for (long i = 0; i < batch_size; ++i)
+                    {
+                        for (long t = 0; t < seq_len; ++t)
+                        {
+                            unsigned long target_class;
+                            if (t < seq_len - 1)
+                            {
+                                target_class = static_cast<unsigned long>(
+                                    in_data[tensor_index(input_tensor, i, 0, t + 1, 0)]
+                                );
+                            }
+                            else
+                                target_class = *(truth + i);
+
+                            if (static_cast<long>(target_class) != ignore_index)
+                                valid_tokens++;
+                        }
+                    }
+                }
+                if (valid_tokens == 0)
+                {
+                    loss = 0.0;
+                    return;
+                }
+
+                const double scale = 1.0 / valid_tokens;
+                loss = 0.0;
+
+                for (long i = 0; i < batch_size; ++i)
+                {
+                    // Loop over all positions (0 to seq_len-1)
+                    for (long t = 0; t < seq_len; ++t)
+                    {
+                        unsigned long target_class;
+
+                        // Extract target token
+                        if (t < seq_len - 1)
+                        {
+                            // For positions 0 to seq_len-2: target from input_tensor[t+1]
+                            target_class = static_cast<unsigned long>(
+                                in_data[tensor_index(input_tensor, i, 0, t + 1, 0)]
+                            );
+                        }
+                        else
+                        {
+                            // For last position (seq_len-1): target from truth
+                            target_class = *(truth + i);
+                        }
+
+                        if (ignore_index >= 0 && static_cast<long>(target_class) == ignore_index)
+                            continue;
+
+                        DLIB_CASSERT(target_class < static_cast<unsigned long>(vocab_size));
+
+                        // Find max logit for numerical stability
+                        float max_val = out_data[tensor_index(output_tensor, i, 0, t, 0)];
+                        for (long c = 1; c < vocab_size; ++c)
+                        {
+                            const float val = out_data[tensor_index(output_tensor, i, 0, t, c)];
+                            max_val = std::max(max_val, val);
+                        }
+
+                        // Compute softmax denominator
+                        float sum_exp = 0.0f;
+                        for (long c = 0; c < vocab_size; ++c)
+                        {
+                            const unsigned long idx = tensor_index(output_tensor, i, 0, t, c);
+                            const float exp_val = std::exp(out_data[idx] - max_val);
+                            g[idx] = exp_val;
+                            sum_exp += exp_val;
+                        }
+
+                        // Compute loss and gradients
+                        for (long c = 0; c < vocab_size; ++c)
+                        {
+                            const unsigned long idx = tensor_index(output_tensor, i, 0, t, c);
+                            const float softmax_val = g[idx] / sum_exp;
+
+                            if (static_cast<unsigned long>(c) == target_class)
+                            {
+                                loss += scale * (-std::log(std::max(softmax_val, 1e-10f)));
+                                g[idx] = scale * (softmax_val - 1.0f);
+                            }
+                            else
+                            {
+                                g[idx] = scale * softmax_val;
+                            }
+                        }
+                    }
+                }
+            }
+        };
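+
+        // Worked example (illustrative): with seq_len == 4 and input tokens
+        // [t0, t1, t2, t3], the loss above pairs each position with its target as
+        //
+        //   position 0 -> t1      (read from input_tensor)
+        //   position 1 -> t2      (read from input_tensor)
+        //   position 2 -> t3      (read from input_tensor)
+        //   position 3 -> *truth  (the per-sample label)
+        //
+        // so all seq_len positions contribute one next-token prediction each.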
+
+    // -----------------------------------------------------------------------------------
+
         class compute_loss_binary_log_per_pixel
         {
diff --git a/dlib/cuda/cublas_dlibapi.cpp b/dlib/cuda/cublas_dlibapi.cpp
index 064e92c3df..3e4c38d8e8 100644
--- a/dlib/cuda/cublas_dlibapi.cpp
+++ b/dlib/cuda/cublas_dlibapi.cpp
@@ -159,16 +159,21 @@ namespace dlib
             const auto transa = trans_lhs ? CUBLAS_OP_T : CUBLAS_OP_N;
             const auto transb = trans_rhs ? CUBLAS_OP_T : CUBLAS_OP_N;
 
-            long num_samples = std::min({ lhs.num_samples(), rhs.num_samples(), dest.num_samples() });
-            long num_channels = std::min({ lhs.k(), rhs.k(), dest.k() });
-
-            auto is_matrix = [](const auto& tensor) {
-                return ((tensor.num_samples() * tensor.k() == 1 && tensor.nr() * tensor.nc() > 1) ||
-                    (tensor.num_samples() * tensor.k() > 1 && tensor.nr() * tensor.nc() == 1));
-            };
-            const bool lhs_is_matrix = is_matrix(lhs), rhs_is_matrix = is_matrix(rhs), dest_is_matrix = is_matrix(dest);
-
-            if (lhs_is_matrix && rhs_is_matrix && dest_is_matrix) num_samples = num_channels = 1;
+            const bool lhs_is_matrix = is_2d_matrix(lhs);
+            const bool rhs_is_matrix = is_2d_matrix(rhs);
+            const bool dest_is_matrix = is_2d_matrix(dest);
+
+            const size_t lhs_plane_size = lhs.nr() * lhs.nc();
+            const size_t rhs_plane_size = rhs.nr() * rhs.nc();
+            const size_t dest_plane_size = dest.nr() * dest.nc();
+
+            const long num_channels = std::min({ lhs.k(), rhs.k(), dest.k() });
+            long num_samples;
+            if (lhs_is_matrix && rhs_is_matrix && dest_is_matrix)
+                num_samples = 1;
+            else if (!lhs_is_matrix && rhs_is_matrix)
+                num_samples = lhs.num_samples();
+            else
+                num_samples = std::min({ lhs.num_samples(), rhs.num_samples(), dest.num_samples() });
 
             size_t lhs_rows = lhs.nr();
             size_t lhs_cols = lhs.nc();
@@ -176,12 +181,14 @@ namespace dlib
                 lhs_rows = lhs.num_samples();
                 lhs_cols = lhs.k();
             }
+
             size_t rhs_rows = rhs.nr();
             size_t rhs_cols = rhs.nc();
             if (rhs_is_matrix && (rhs.num_samples() > 1 || rhs.k() > 1)) {
                 rhs_rows = rhs.num_samples();
                 rhs_cols = rhs.k();
             }
+
             size_t dest_rows = dest.nr();
             size_t dest_cols = dest.nc();
             if (dest_is_matrix && (dest.num_samples() > 1 || dest.k() > 1)) {
@@ -189,10 +196,6 @@ namespace dlib
                 dest_cols = dest.k();
             }
 
-            const size_t lhs_plane_size = lhs_rows * lhs_cols;
-            const size_t rhs_plane_size = rhs_rows * rhs_cols;
-            const size_t dest_plane_size = dest_rows * dest_cols;
-
             for (long b = 0; b < num_samples; ++b)
             {
                 for (long c = 0; c < num_channels; ++c)
@@ -203,12 +206,18 @@ namespace dlib
                         rhs.device() + (b * num_channels + c) * rhs_plane_size;
                     auto dest_slice = dest_is_matrix ? dest.device() :
                         dest.device() + (b * num_channels + c) * dest_plane_size;
+
                     const int k = trans_rhs ? rhs_cols : rhs_rows;
 
                     CHECK_CUBLAS(cublasSgemm(
-                        context(), transb, transa, dest_cols, dest_rows, k,
-                        &alpha, rhs_slice, rhs_cols, lhs_slice, lhs_cols,
-                        &beta, dest_slice, dest_cols
+                        context(),
+                        transb, transa,
+                        dest_cols, dest_rows, k,
+                        &alpha,
+                        rhs_slice, rhs_cols,
+                        lhs_slice, lhs_cols,
+                        &beta,
+                        dest_slice, dest_cols
                     ));
                 }
             }
diff --git a/dlib/cuda/cuda_dlib.cu b/dlib/cuda/cuda_dlib.cu
index 672efe9c22..56b3680896 100644
--- a/dlib/cuda/cuda_dlib.cu
+++ b/dlib/cuda/cuda_dlib.cu
@@ -2407,12 +2407,9 @@ namespace dlib
 
    // ----------------------------------------------------------------------------------------
 
-        __global__ void _cuda_rms_normalize(
-            float* dest,
+        __global__ void _cuda_rms_normalize_accumulate(
             float* scale,
             const float* src,
-            const float* gamma,
-            float eps,
             size_t ns,
             size_t ks,
             size_t num
@@ -2422,28 +2419,42 @@ namespace dlib
            {
                const auto ps = src + n * ks * num;
                float sum_squares = 0.0f;
                for (auto i : grid_stride_range(0, ks * num))
                {
                    sum_squares += ps[i] * ps[i];
                }
                warp_reduce_atomic_add(scale[n], sum_squares / (ks * num));
            }
-            __syncthreads();
+        }
 
+        __global__ void _cuda_rms_normalize_invert(
+            float* scale,
+            float eps,
+            size_t ns
+        )
+        {
            for (auto n : grid_stride_range_y(0, ns))
            {
-                for (auto i : grid_stride_range(0, 1))
-                {
+                if (threadIdx.x == 0)
                    scale[n] = 1.0f / std::sqrt(scale[n] + eps);
-                }
            }
-            __syncthreads();
+        }
 
+        __global__ void _cuda_rms_normalize_apply(
+            float* dest,
+            const float* scale,
+            const float* src,
+            const float* gamma,
+            size_t ns,
+            size_t ks,
+            size_t num
+        )
+        {
            for (auto n : grid_stride_range_y(0, ns))
            {
                const auto ps = src + n * ks * num;
                const auto pd = dest + n * ks * num;
                for (auto i : grid_stride_range(0, ks * num))
                {
                    pd[i] = ps[i] * scale[n] * gamma[i / num];
                }
@@ -2457,7 +2468,7 @@ namespace dlib
            const tensor& src,
            const tensor& gamma
        )
        {
            DLIB_CASSERT(
                gamma.k() == src.k() &&
                gamma.nr() == 1 &&
@@ -2478,26 +2489,31 @@ namespace dlib
            scale.set_size(ns);
            scale = 0;
 
-            launch_kernel(_cuda_rms_normalize, max_jobs(ks * num, ns),
-                dest.device(), scale.device(), src.device(), gamma.device(), eps, ns, ks, num);
+            launch_kernel(_cuda_rms_normalize_accumulate, max_jobs(ks * num, ns),
+                scale.device(), src.device(), ns, ks, num);
+
+            launch_kernel(_cuda_rms_normalize_invert, max_jobs(1, ns),
+                scale.device(), eps, ns);
+
+            launch_kernel(_cuda_rms_normalize_apply, max_jobs(ks * num, ns),
+                dest.device(), scale.device(), src.device(), gamma.device(), ns, ks, num);
        }
 
    // ----------------------------------------------------------------------------------------
 
-        __global__ void _cuda_rms_normalize_gradient(
-            float* src_grad,
+        __global__ void _cuda_rms_normalize_gradient_accumulate(
            float* gamma_grad,
            float* dscale,
            const float* src,
            const float* gradient_input,
            const float* scale,
            const float* gamma,
            size_t ns,
            size_t ks,
            size_t num
        )
        {
            for (auto nk : grid_stride_range_y(0, ns * ks))
            {
                const auto n = nk / ks;
                const auto k = nk % ks;
@@ -2509,22 +2525,34 @@ namespace dlib
                for (auto i : grid_stride_range(0, num))
                {
                    const float x_hat = ps[i] * scale[n];
-                    const float dx = pgi[i] * gamma[i / num];
+                    const float dx = pgi[i] * gamma[k];
                    temp_gg += pgi[i] * x_hat;
                    temp_ds += dx * ps[i] * scale_pow;
                }
                warp_reduce_atomic_add(gamma_grad[k], temp_gg);
                warp_reduce_atomic_add(dscale[n], temp_ds);
            }
-            __syncthreads();
+        }
 
+        __global__ void _cuda_rms_normalize_gradient_apply(
+            float* src_grad,
+            const float* dscale,
+            const float* src,
+            const float* gradient_input,
+            const float* scale,
+            const float* gamma,
+            size_t ns,
+            size_t ks,
+            size_t num
+        )
+        {
            const float invnum = 1.0f / (ks * num);
            for (auto n : grid_stride_range_y(0, ns))
            {
                const auto ps = src + n * ks * num;
                const auto pgi = gradient_input + n * ks * num;
                const auto psg = src_grad + n * ks * num;
                for (auto i : grid_stride_range(0, ks * num))
                {
                    const float dx = pgi[i] * gamma[i / num];
                    psg[i] += dx * scale[n] + dscale[n] * 2 * ps[i] * invnum;
@@ -2541,7 +2569,7 @@ namespace dlib
            tensor& gamma_grad,
            resizable_tensor& dscale
        )
        {
            DLIB_CASSERT(src.num_samples() == scale.size());
            DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad));
            DLIB_CASSERT(gamma.k() == src.k());
@@ -2558,9 +2586,13 @@ namespace dlib
            dscale.copy_size(scale);
            dscale = 0;
 
-            // Lancement du kernel CUDA
-            launch_kernel(_cuda_rms_normalize_gradient, max_jobs(ks * num, ns),
-                src_grad.device(), gamma_grad.device(), dscale.device(),
+            launch_kernel(_cuda_rms_normalize_gradient_accumulate, max_jobs(ks * num, ns * ks),
+                gamma_grad.device(), dscale.device(),
+                src.device(), gradient_input.device(), scale.device(), gamma.device(),
+                ns, ks, num);
+
+            launch_kernel(_cuda_rms_normalize_gradient_apply, max_jobs(ks * num, ns),
+                src_grad.device(), dscale.device(),
                src.device(), gradient_input.device(), scale.device(), gamma.device(),
                ns, ks, num);
        }
@@ -2736,12 +2768,23 @@ namespace dlib
 
    // ----------------------------------------------------------------------------------------
    // CUDA Kernels for ACT operations
-        __global__ void _cuda_compute_act_halt_probabilities(
-            float* halt_probs,
+
+        // Kernel 1: initialize logits with bias
+        __global__ void _cuda_act_init_logits(
+            float* logits,
+            float b_halt,
+            size_t total_positions
+        )
+        {
+            for (auto pos : grid_stride_range(0, total_positions))
+                logits[pos] = b_halt;
+        }
+
+        // Kernel 2: compute dot product and accumulate into logits
+        __global__ void _cuda_act_accumulate_logits(
            float* logits,
            const float* input_data,
            const float* W_halt,
-            float b_halt,
            size_t batch_size,
            size_t seq_len,
            size_t d_model,
@@ -2751,11 +2794,6 @@ namespace dlib
        {
            const long total_positions = batch_size * seq_len;
 
-            for (auto pos : grid_stride_range_y(0, total_positions))
-                for (auto i : grid_stride_range(0, 1))
-                    logits[pos] = b_halt;
-            __syncthreads();
-
            for (auto pos : grid_stride_range_y(0, total_positions))
            {
                const long n = pos / seq_len;
@@ -2773,12 +2811,17 @@ namespace dlib
                warp_reduce_atomic_add(logits[pos], temp);
            }
-            __syncthreads();
+        }
 
+        // Kernel 3: apply sigmoid to compute halt probabilities
+        __global__ void _cuda_act_apply_sigmoid(
+            float* halt_probs,
+            const float* logits,
+            size_t total_positions
+        )
+        {
            for (auto pos : grid_stride_range(0, total_positions))
-            {
                halt_probs[pos] = 1.0f / (1.0f + expf(-logits[pos]));
-            }
        }
 
        void compute_act_halt_probabilities(
@@ -2798,18 +2841,36 @@ namespace dlib
            halt_probs.set_size(total_positions, 1, 1, 1);
            logits.set_size(total_positions, 1, 1, 1);
 
-            launch_kernel(_cuda_compute_act_halt_probabilities,
+            // Extract bias from halt_params (last element)
+            const float b_halt = halt_params.host()[feature_dim];
+
+            // Phase 1: initialize logits with bias
+            launch_kernel(_cuda_act_init_logits,
+                max_jobs(total_positions),
+                logits.device(),
+                b_halt,
+                total_positions);
+
+            // Phase 2: accumulate dot product into logits
+            // Note: sequential kernel launches on the same stream provide implicit synchronization
+            launch_kernel(_cuda_act_accumulate_logits,
                max_jobs(feature_dim, total_positions),
-                halt_probs.device(),
                logits.device(),
                input_data.device(),
                halt_params.device(),
-                halt_params.host()[feature_dim],
                batch_size,
                seq_len,
                d_model,
                num_channels,
                feature_dim);
+
+            // Phase 3: apply sigmoid
+            // Note: sequential kernel launches on the same stream provide implicit synchronization
+            launch_kernel(_cuda_act_apply_sigmoid,
+                max_jobs(total_positions),
+                halt_probs.device(),
+                logits.device(),
+                total_positions);
        }
 
        __global__ void _cuda_update_act_state(
@@ -2993,6 +3054,263 @@ namespace dlib
 
    // ----------------------------------------------------------------------------------------
 
+        __global__ void apply_rope_kernel(
+            float* __restrict__ data,
+            const float* __restrict__ cos_cache,
+            const float* __restrict__ sin_cache,
+            const size_t total_pairs,
+            const long num_heads,
+            const long seq_len,
+            const long d_head,
+            const long half_d,
+            const long rot_dim,
+            const bool is_backward)
+        {
+            for (auto pair_id : grid_stride_range(0, total_pairs))
+            {
+                const long pair_idx = pair_id % half_d;
+                const long pos = (pair_id / half_d) % seq_len;
+                const long head = (pair_id / (half_d * seq_len)) % num_heads;
+                const long batch = pair_id / (half_d * seq_len * num_heads);
+
+                const long dim_i = pair_idx * 2;
+                if (dim_i >= rot_dim) continue;
+
+                const long base_offset = ((batch * num_heads + head) * seq_len + pos) * d_head;
+                const long data_offset = base_offset + dim_i;
+                const long trig_offset = pos * half_d + pair_idx;
+
+                const float c = cos_cache[trig_offset];
+                const float s = sin_cache[trig_offset];
+                const float x0 = data[data_offset];
+                const float x1 = data[data_offset + 1];
+
+                if (!is_backward)
+                {
+                    // Forward: standard rotation
+                    data[data_offset]     = x0 * c - x1 * s;
+                    data[data_offset + 1] = x0 * s + x1 * c;
+                }
+                else
+                {
+                    // Backward: inverse rotation
+                    data[data_offset]     = x0 * c + x1 * s;
+                    data[data_offset + 1] = -x0 * s + x1 * c;
+                }
+            }
+        }
+
+        void apply_rotary_positional_embedding(
+            bool is_backward,
+            tensor& data,
+            const tensor& cos_cache,
+            const tensor& sin_cache)
+        {
+            const long batch_size = data.num_samples();
+            const long num_heads = data.k();
+            const long seq_len = data.nr();
+            const long d_head = data.nc();
+            const long half_d = d_head / 2;
+
+            DLIB_CASSERT(cos_cache.nr() == seq_len, "cos_cache.nr() must match seq_len");
+            DLIB_CASSERT(cos_cache.nc() == half_d, "cos_cache.nc() must be d_head/2");
+            DLIB_CASSERT(sin_cache.nr() == seq_len, "sin_cache.nr() must match seq_len");
+            DLIB_CASSERT(sin_cache.nc() == half_d, "sin_cache.nc() must be d_head/2");
+
+            const bool is_odd = (d_head % 2 != 0);
+            const long rot_dim = is_odd ? d_head - 1 : d_head;
+
+            const size_t total_elements = batch_size * num_heads * seq_len * half_d;
+            if (total_elements == 0) return;
+
+            launch_kernel(apply_rope_kernel, max_jobs(total_elements),
+                data.device(),
+                cos_cache.device(),
+                sin_cache.device(),
+                total_elements,
+                num_heads,
+                seq_len,
+                d_head,
+                half_d,
+                rot_dim,
+                is_backward
+            );
+        }
+
+        // ----------------------------------------------------------------------------------------
+
+        __global__ void _cuda_count_valid_tokens(
+            float* valid_count,
+            const unsigned long* truth,
+            const float* input_data,
+            size_t batch_size,
+            size_t seq_len,
+            long ignore_index
+        )
+        {
+            float count = 0.0f;
+
+            for (auto sample_idx : grid_stride_range(0, batch_size))
+            {
+                for (size_t t = 0; t < seq_len; ++t)
+                {
+                    unsigned long target_class;
+                    if (t < seq_len - 1)
+                    {
+                        const size_t input_idx = sample_idx * seq_len + (t + 1);
+                        target_class = static_cast<unsigned long>(input_data[input_idx]);
+                    }
+                    else
+                    {
+                        target_class = truth[sample_idx];
+                    }
+
+                    if (ignore_index < 0 || static_cast<long>(target_class) != ignore_index)
+                        count += 1.0f;
+                }
+            }
+
+            warp_reduce_atomic_add(*valid_count, count);
+        }
+
+        __global__ void _cuda_compute_loss_cross_entropy_per_logit(
+            float* loss_out,
+            float* g,
+            const unsigned long* truth,
+            const float* input_data,
+            const float* out_data,
+            size_t batch_size,
+            size_t seq_len,
+            size_t vocab_size,
+            float scale,
+            long ignore_index
+        )
+        {
+            float total_loss = 0;
+
+            for (auto sample_idx : grid_stride_range(0, batch_size))
+            {
+                for (size_t t = 0; t < seq_len; ++t)
+                {
+                    unsigned long target_class;
+                    if (t < seq_len - 1)
+                    {
+                        const size_t input_idx = sample_idx * seq_len + (t + 1);
+                        target_class = static_cast<unsigned long>(input_data[input_idx]);
+                    }
+                    else
+                    {
+                        target_class = truth[sample_idx];
+                    }
+
+                    const size_t base_idx = sample_idx * seq_len * vocab_size + t * vocab_size;
+
+                    if (ignore_index >= 0 && static_cast<long>(target_class) == ignore_index)
+                    {
+                        for (size_t c = 0; c < vocab_size; ++c)
+                            g[base_idx + c] = 0.0f;
+                        continue;
+                    }
+
+                    float max_val = out_data[base_idx];
+                    for (size_t c = 1; c < vocab_size; ++c)
+                        max_val = ::max(max_val, out_data[base_idx + c]);
+
+                    float sum_exp = 0.0f;
+                    for (size_t c = 0; c < vocab_size; ++c)
+                    {
+                        const size_t idx = base_idx + c;
+                        const float exp_val = ::exp(out_data[idx] - max_val);
+                        g[idx] = exp_val;
+                        sum_exp += exp_val;
+                    }
+
+                    for (size_t c = 0; c < vocab_size; ++c)
+                    {
+                        const size_t idx = base_idx + c;
+                        const float softmax_val = g[idx] / sum_exp;
+
+                        if (c == target_class)
+                        {
+                            total_loss += -::log(::max(softmax_val, 1e-10f));
+                            g[idx] = scale * (softmax_val - 1.0f);
+                        }
+                        else
+                        {
+                            g[idx] = scale * softmax_val;
+                        }
+                    }
+                }
+            }
+
+            warp_reduce_atomic_add(*loss_out, total_loss);
+        }
+
+        void compute_loss_cross_entropy_per_logit::do_work(
+            cuda_data_ptr<float> loss_work_buffer,
+            cuda_data_ptr<unsigned long> truth_buffer,
+            const tensor& input_tensor,
+            const tensor& subnetwork_output,
+            tensor& gradient,
+            double& loss,
+            long ignore_index
+        )
+        {
+            CHECK_CUDA(cudaMemset(gradient.device(), 0, gradient.size() * sizeof(float)));
+            CHECK_CUDA(cudaMemset(loss_work_buffer, 0, sizeof(float)));
+
+            const long batch_size = subnetwork_output.num_samples();
+            const long seq_len = subnetwork_output.nr();
+            const long vocab_size = subnetwork_output.nc();
+
+            double scale;
+            if (ignore_index < 0)
+            {
+                scale = 1.0 / (batch_size * seq_len);
+            }
+            else
+            {
+                cuda_data_void_ptr count_buf = device_global_buffer(sizeof(float));
+                auto valid_count_ptr = static_pointer_cast<float>(count_buf, 1);
+                CHECK_CUDA(cudaMemset(valid_count_ptr, 0, sizeof(float)));
+
+                launch_kernel(_cuda_count_valid_tokens, max_jobs(batch_size),
+                    valid_count_ptr.data(),
+                    truth_buffer.data(),
+                    input_tensor.device(),
+                    batch_size,
+                    seq_len,
+                    ignore_index
+                );
+
+                float valid_count;
+                dlib::cuda::memcpy(&valid_count, valid_count_ptr);
+
+                if (valid_count == 0)
+                {
+                    loss = 0.0;
+                    return;
+                }
+
+                scale = 1.0 / valid_count;
+            }
+
+            launch_kernel(_cuda_compute_loss_cross_entropy_per_logit, max_jobs(batch_size),
+                loss_work_buffer.data(),
+                gradient.device(),
+                truth_buffer.data(),
+                input_tensor.device(),
+                subnetwork_output.device(),
+                batch_size,
+                seq_len,
+                vocab_size,
+                static_cast<float>(scale),
+                ignore_index
+            );
+
+            float floss;
+            dlib::cuda::memcpy(&floss, loss_work_buffer);
+            loss = scale * floss;
+        }
+
+        // ----------------------------------------------------------------------------------------
 
        __device__ float cuda_log1pexp(float x)
        {
diff --git a/dlib/cuda/cuda_dlib.h b/dlib/cuda/cuda_dlib.h
index 26e1d29e4f..e1a345cf9e 100644
--- a/dlib/cuda/cuda_dlib.h
+++ b/dlib/cuda/cuda_dlib.h
@@ -656,6 +656,65 @@ namespace dlib
            float scale_factor
        );
 
+    // ----------------------------------------------------------------------------------------
+
+        void apply_rotary_positional_embedding(
+            bool is_backward,
+            tensor& data,
+            const tensor& cos_cache,
+            const tensor& sin_cache
+        );
+
+    // ----------------------------------------------------------------------------------------
+
+        class compute_loss_cross_entropy_per_logit
+        {
+            /*!
+                The point of this class is to compute the loss computed by
+                loss_cross_entropy_per_logit_, but to do so with CUDA.
+            !*/
+        public:
+            compute_loss_cross_entropy_per_logit() {}
+
+            template <typename const_label_iterator>
+            void operator() (
+                const_label_iterator truth,
+                const tensor& input_tensor,      // Source tokens
+                const tensor& subnetwork_output, // Logits
+                tensor& gradient,
+                double& loss,
+                long ignore_index
+            ) const
+            {
+                const size_t bytes_per_sample = sizeof(unsigned long);
+                buf = device_global_buffer(subnetwork_output.num_samples() * bytes_per_sample + sizeof(float));
+                cuda_data_ptr<float> loss_buf = static_pointer_cast<float>(buf, 1);
+                buf = buf + sizeof(float);
+
+                for (long i = 0; i < subnetwork_output.num_samples(); ++i, ++truth)
+                {
+                    const unsigned long t = *truth;
+                    memcpy(buf + i * bytes_per_sample, &t, bytes_per_sample);
+                }
+
+                auto truth_buf = static_pointer_cast<unsigned long>(buf, subnetwork_output.num_samples());
+                do_work(loss_buf, truth_buf, input_tensor, subnetwork_output, gradient, loss, ignore_index);
+            }
+
+        private:
+            static void do_work(
+                cuda_data_ptr<float> loss_work_buffer,
+                cuda_data_ptr<unsigned long> truth_buffer,
+                const tensor& input_tensor,
+                const tensor& subnetwork_output,
+                tensor& gradient,
+                double& loss,
+                long ignore_index
+            );
+
+            mutable cuda_data_void_ptr buf;
+        };
+
    // ----------------------------------------------------------------------------------------
 
        class compute_loss_binary_log_per_pixel
        {
diff --git a/dlib/cuda/tensor.h b/dlib/cuda/tensor.h
index 6a893df311..138413b642 100644
--- a/dlib/cuda/tensor.h
+++ b/dlib/cuda/tensor.h
@@ -220,6 +220,17 @@ namespace dlib
               t.size() == (size_t)t.nc();
    }
 
+// ----------------------------------------------------------------------------------------
+
+    inline bool is_2d_matrix(
+        const tensor& t
+    )
+    {
+        return !is_vector(t) &&
+               (t.size() == (size_t)(t.num_samples() * t.k()) ||
+                t.size() == (size_t)(t.nr() * t.nc()));
+    }
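+
+    // For example (illustrative): a tensor shaped {num_samples=1, k=1, nr=3, nc=5}
+    // or {num_samples=3, k=5, nr=1, nc=1} is treated as a single 3x5 matrix here,
+    // while a {2, 3, 4, 5} tensor is a batch of planes rather than one 2D matrix.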
+
// ----------------------------------------------------------------------------------------
 
    inline const matrix_op<op_pointer_to_mat<float> > mat (
diff --git a/dlib/cuda/tensor_abstract.h b/dlib/cuda/tensor_abstract.h
index 62f649391e..3a3d83eda7 100644
--- a/dlib/cuda/tensor_abstract.h
+++ b/dlib/cuda/tensor_abstract.h
@@ -359,6 +359,18 @@ namespace dlib
            - t.size() == t.nc()
    !*/
 
+// ----------------------------------------------------------------------------------------
+
+    inline bool is_2d_matrix(
+        const tensor& t
+    );
+    /*!
+        ensures
+            - returns true if and only if is_vector(t) == false and one of the
+              following is true:
+                - t.size() == t.num_samples() * t.k()
+                - t.size() == t.nr() * t.nc()
+    !*/
+
// ----------------------------------------------------------------------------------------
 
    const matrix_exp mat (
diff --git a/dlib/cuda/tensor_tools.cpp b/dlib/cuda/tensor_tools.cpp
index d9429df2f4..64f437480f 100644
--- a/dlib/cuda/tensor_tools.cpp
+++ b/dlib/cuda/tensor_tools.cpp
@@ -242,39 +242,54 @@ namespace dlib { namespace tt
        }
        else if (mode == operation_mode::PLANE_WISE)
        {
-            auto is_matrix = [](const auto& tensor) {
-                return ((tensor.num_samples() * tensor.k() == 1 && tensor.nr() * tensor.nc() > 1) ||
-                    (tensor.num_samples() * tensor.k() > 1 && tensor.nr() * tensor.nc() == 1));
-            };
-
-            long num_samples = std::min({ lhs.num_samples(), rhs.num_samples(), dest.num_samples() });
-            long num_channels = std::min({ lhs.k(), rhs.k(), dest.k() });
-            const bool lhs_is_matrix = is_matrix(lhs), rhs_is_matrix = is_matrix(rhs), dest_is_matrix = is_matrix(dest);
+            const bool lhs_is_matrix = is_2d_matrix(lhs);
+            const bool rhs_is_matrix = is_2d_matrix(rhs);
+            const bool dest_is_matrix = is_2d_matrix(dest);
+
+            const size_t lhs_plane_size = lhs.nr() * lhs.nc();
+            const size_t rhs_plane_size = rhs.nr() * rhs.nc();
+            const size_t dest_plane_size = dest.nr() * dest.nc();
+
+            const long num_channels = std::min({ lhs.k(), rhs.k(), dest.k() });
+            long num_samples;
+            if (lhs_is_matrix && rhs_is_matrix && dest_is_matrix)
+                num_samples = 1;
+            else if (!lhs_is_matrix && rhs_is_matrix)
+                num_samples = lhs.num_samples();
+            else
+                num_samples = std::min({ lhs.num_samples(), rhs.num_samples(), dest.num_samples() });
 
-            if (lhs_is_matrix && rhs_is_matrix && dest_is_matrix) {
-                num_samples = num_channels = 1;
+            size_t lhs_rows = lhs.nr();
+            size_t lhs_cols = lhs.nc();
+            if (lhs_is_matrix && (lhs.num_samples() > 1 || lhs.k() > 1)) {
+                lhs_rows = lhs.num_samples();
+                lhs_cols = lhs.k();
+            }
+            size_t rhs_rows = rhs.nr();
+            size_t rhs_cols = rhs.nc();
+            if (rhs_is_matrix && (rhs.num_samples() > 1 || rhs.k() > 1)) {
+                rhs_rows = rhs.num_samples();
+                rhs_cols = rhs.k();
+            }
+            size_t dest_rows = dest.nr();
+            size_t dest_cols = dest.nc();
+            if (dest_is_matrix && (dest.num_samples() > 1 || dest.k() > 1)) {
+                dest_rows = dest.num_samples();
+                dest_cols = dest.k();
            }
 
-            long lhs_rows = (lhs_is_matrix && lhs.num_samples() > 1) ? lhs.num_samples() : lhs.nr();
-            long lhs_cols = (lhs_is_matrix && lhs.k() > 1) ? lhs.k() : lhs.nc();
-            long rhs_rows = (rhs_is_matrix && rhs.num_samples() > 1) ? rhs.num_samples() : rhs.nr();
-            long rhs_cols = (rhs_is_matrix && rhs.k() > 1) ? rhs.k() : rhs.nc();
-            long dest_rows = (dest_is_matrix && dest.num_samples() > 1) ? dest.num_samples() : dest.nr();
-            long dest_cols = (dest_is_matrix && dest.k() > 1) ? dest.k() : dest.nc();
-
-            const size_t lhs_plane_size = lhs_rows * lhs_cols;
-            const size_t rhs_plane_size = rhs_rows * rhs_cols;
-            const size_t dest_plane_size = dest_rows * dest_cols;
-
+            // Process each plane
            for (long b = 0; b < num_samples; ++b)
            {
                for (long c = 0; c < num_channels; ++c)
                {
-                    auto lhs_slice = lhs_is_matrix ? alias_tensor(lhs_rows, lhs_cols)(lhs, 0) :
+                    auto lhs_slice = lhs_is_matrix ?
+                        alias_tensor(lhs_rows, lhs_cols)(lhs, 0) :
                        alias_tensor(lhs_rows, lhs_cols)(lhs, (b * num_channels + c) * lhs_plane_size);
-                    auto rhs_slice = rhs_is_matrix ? alias_tensor(rhs_rows, rhs_cols)(rhs, 0) :
+                    auto rhs_slice = rhs_is_matrix ?
+                        alias_tensor(rhs_rows, rhs_cols)(rhs, 0) :
                        alias_tensor(rhs_rows, rhs_cols)(rhs, (b * num_channels + c) * rhs_plane_size);
-                    auto dest_slice = dest_is_matrix ? alias_tensor(dest_rows, dest_cols)(dest, 0) :
+                    auto dest_slice = dest_is_matrix ?
+                        alias_tensor(dest_rows, dest_cols)(dest, 0) :
                        alias_tensor(dest_rows, dest_cols)(dest, (b * num_channels + c) * dest_plane_size);
 
                    if (beta != 0)
@@ -1496,6 +1511,22 @@ namespace dlib { namespace tt
#endif
    }
 
+// ----------------------------------------------------------------------------------------
+
+    void apply_rotary_positional_embedding(
+        bool is_backward,
+        resizable_tensor& data,
+        const resizable_tensor& cos_cache,
+        const resizable_tensor& sin_cache
+    )
+    {
+#ifdef DLIB_USE_CUDA
+        cuda::apply_rotary_positional_embedding(is_backward, data, cos_cache, sin_cache);
+#else
+        cpu::apply_rotary_positional_embedding(is_backward, data, cos_cache, sin_cache);
+#endif
+    }
+
// ----------------------------------------------------------------------------------------
 
}}
diff --git a/dlib/cuda/tensor_tools.h b/dlib/cuda/tensor_tools.h
index fe0260ea88..89d3d6c627 100644
--- a/dlib/cuda/tensor_tools.h
+++ b/dlib/cuda/tensor_tools.h
@@ -2516,6 +2516,39 @@ namespace dlib { namespace tt
            - scale_factor: scaling strength (0 = no scaling)
    !*/
 
+// ----------------------------------------------------------------------------------------
+
+    void apply_rotary_positional_embedding(
+        bool is_backward,
+        resizable_tensor& data,
+        const resizable_tensor& cos_cache,
+        const resizable_tensor& sin_cache
+    );
+    /*!
+        requires
+            - data.nr() == cos_cache.nr()
+            - data.nr() == sin_cache.nr()
+            - cos_cache.nc() == data.nc() / 2
+            - sin_cache.nc() == data.nc() / 2
+            - cos_cache.num_samples() == 1
+            - cos_cache.k() == 1
+            - sin_cache.num_samples() == 1
+            - sin_cache.k() == 1
+            - data.nc() >= 2
+        ensures
+            - Applies rotary positional embeddings (RoPE) to the input tensor.
+            - data is modified in-place, with the rotation applied pairwise to
+              dimensions.  For each position pos and dimension pair (i, i+1),
+              letting x0 = data[pos,i] and x1 = data[pos,i+1] denote the values
+              before the update:
+                if (!is_backward):
+                    // Forward rotation (encoding)
+                    data[pos,i]   = x0 * cos_cache[pos,i/2] - x1 * sin_cache[pos,i/2]
+                    data[pos,i+1] = x0 * sin_cache[pos,i/2] + x1 * cos_cache[pos,i/2]
+                else:
+                    // Backward rotation (inverse transformation for gradients)
+                    data[pos,i]   =  x0 * cos_cache[pos,i/2] + x1 * sin_cache[pos,i/2]
+                    data[pos,i+1] = -x0 * sin_cache[pos,i/2] + x1 * cos_cache[pos,i/2]
+    !*/
+
// ----------------------------------------------------------------------------------------
 
}}
diff --git a/dlib/data_io.h b/dlib/data_io.h
index 15c630e9e9..505f75108c 100644
--- a/dlib/data_io.h
+++ b/dlib/data_io.h
@@ -8,6 +8,7 @@
 #include "data_io/mnist.h"
 #include "data_io/cifar.h"
 #include "data_io/arc_agi.h"
+#include "data_io/language_model_data.h"
 
 #ifndef DLIB_ISO_CPP_ONLY
 #include "data_io/load_image_dataset.h"
diff --git a/dlib/data_io/arc_agi.h b/dlib/data_io/arc_agi.h
index 9153e8d4fd..64356dda8c 100644
--- a/dlib/data_io/arc_agi.h
+++ b/dlib/data_io/arc_agi.h
@@ -715,8 +715,8 @@ namespace dlib
        sequence.push_back(TOKEN_GEN_START);
 
        // Convert to dlib column vector
-        arc_token_sequence_t result(static_cast<long>(sequence.size()));
-        for (long i = 0; i < static_cast<long>(sequence.size()); ++i)
+        arc_token_sequence_t result(sequence.size());
+        for (size_t i = 0; i < sequence.size(); ++i)
            result(i) = sequence[i];
        return result;
    }
@@ -736,8 +736,8 @@ namespace dlib
        append_flat_grid(sequence, test_pair.output);
        sequence.push_back(TOKEN_END_OF_OUTPUT);
 
-        arc_token_sequence_t result(static_cast<long>(sequence.size()));
-        for (long i = 0; i < static_cast<long>(sequence.size()); ++i)
+        arc_token_sequence_t result(sequence.size());
+        for (size_t i = 0; i < sequence.size(); ++i)
            result(i) = sequence[i];
        return result;
    }
diff --git a/dlib/data_io/language_model_data.h b/dlib/data_io/language_model_data.h
new file mode 100644
index 0000000000..d1f4aa6ae2
--- /dev/null
+++ b/dlib/data_io/language_model_data.h
@@ -0,0 +1,976 @@
+#ifndef DLIB_LANGUAGE_MODEL_DATA_H_
+#define DLIB_LANGUAGE_MODEL_DATA_H_
+
+#include "language_model_data_abstract.h"
+
+#include <algorithm>
+#include <array>
+#include <cmath>
+#include <cstring>
+#include <fstream>
+#include <iomanip>
+#include <iostream>
+#include <map>
+#include <set>
+#include <sstream>
+#include <string>
+#include <vector>
+#include "../matrix.h"
+#include "../rand.h"
+#include "../serialize.h"
+
+namespace dlib
+{
+
+    // ---------------------------------------------------------------------------------
+
+    enum class file_content_type
+    {
+        TEXT_PLAIN,   // Plain text file (including CSV, code, etc.)
+        TEXT_XML,     // XML or HTML markup
+        IMAGE,        // Image formats (PNG, JPEG, GIF, TIFF, BMP, etc.)
+        VIDEO,        // Video formats (MP4, AVI, MKV, etc.)
+        AUDIO,        // Audio formats (MP3, WAV, FLAC, etc.)
+        EXECUTABLE,   // Executable files (EXE, DLL, ELF, Mach-O)
+        COMPRESSED,   // Compressed archives (ZIP, GZIP, 7Z, RAR, etc.)
+        PDF,          // PDF documents
+        OFFICE,       // Office documents (DOCX, XLSX, PPTX, etc.)
+        UNKNOWN       // Unknown or undetermined file type
+    };
+
+    // ---------------------------------------------------------------------------------
+
+    namespace impl
+    {
+        // Magic number signature structure
+        struct magic_signature
+        {
+            const unsigned char* bytes;
+            size_t length;
+            file_content_type type;
+            size_t offset;  // Byte offset where signature should appear
+        };
+
+        // Common magic number signatures (ordered by frequency/priority)
+        static const unsigned char sig_png[] = { 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A };
+        static const unsigned char sig_jpg1[] = { 0xFF, 0xD8, 0xFF, 0xE0 };
+        static const unsigned char sig_jpg2[] = { 0xFF, 0xD8, 0xFF, 0xE1 };
+        static const unsigned char sig_jpg3[] = { 0xFF, 0xD8, 0xFF, 0xDB };
+        static const unsigned char sig_jpg4[] = { 0xFF, 0xD8, 0xFF, 0xEE };
+        static const unsigned char sig_gif87[] = { 0x47, 0x49, 0x46, 0x38, 0x37, 0x61 };    // GIF87a
+        static const unsigned char sig_gif89[] = { 0x47, 0x49, 0x46, 0x38, 0x39, 0x61 };    // GIF89a
+        static const unsigned char sig_tiff_le[] = { 0x49, 0x49, 0x2A, 0x00 };              // Little endian
+        static const unsigned char sig_tiff_be[] = { 0x4D, 0x4D, 0x00, 0x2A };              // Big endian
+        static const unsigned char sig_bmp[] = { 0x42, 0x4D };
+        static const unsigned char sig_webp[] = { 0x52, 0x49, 0x46, 0x46 };  // RIFF (check for WEBP at offset 8)
+
+        static const unsigned char sig_pdf[] = { 0x25, 0x50, 0x44, 0x46 };   // %PDF
+
+        static const unsigned char sig_zip[] = { 0x50, 0x4B, 0x03, 0x04 };
+        static const unsigned char sig_gzip[] = { 0x1F, 0x8B };
+        static const unsigned char sig_7z[] = { 0x37, 0x7A, 0xBC, 0xAF, 0x27, 0x1C };
+        static const unsigned char sig_rar[] = { 0x52, 0x61, 0x72, 0x21, 0x1A, 0x07 };
+
+        static const unsigned char sig_exe[] = { 0x4D, 0x5A };               // MZ (DOS/Windows executable)
+        static const unsigned char sig_elf[] = { 0x7F, 0x45, 0x4C, 0x46 };   // ELF (Unix/Linux executable)
+        static const unsigned char sig_macho_32[] = { 0xFE, 0xED, 0xFA, 0xCE };  // Mach-O 32-bit
+        static const unsigned char sig_macho_64[] = { 0xFE, 0xED, 0xFA, 0xCF };  // Mach-O 64-bit
+
+        static const unsigned char sig_mp3_id3[] = { 0x49, 0x44, 0x33 };     // ID3
+        static const unsigned char sig_mp3_ff[] = { 0xFF, 0xFB };
+        static const unsigned char sig_wav[] = { 0x52, 0x49, 0x46, 0x46 };   // RIFF (check for WAVE at offset 8)
+        static const unsigned char sig_flac[] = { 0x66, 0x4C, 0x61, 0x43 };  // fLaC
+        static const unsigned char sig_ogg[] = { 0x4F, 0x67, 0x67, 0x53 };   // OggS
+
+        static const unsigned char sig_mp4[] = { 0x66, 0x74, 0x79, 0x70 };   // ftyp (at offset 4)
+        static const unsigned char sig_avi[] = { 0x52, 0x49, 0x46, 0x46 };   // RIFF (check for AVI at offset 8)
+        static const unsigned char sig_mkv[] = { 0x1A, 0x45, 0xDF, 0xA3 };
+
+        static const magic_signature signatures[] = {
+            // Images
+            {sig_png, sizeof(sig_png), file_content_type::IMAGE, 0},
+            {sig_jpg1, sizeof(sig_jpg1), file_content_type::IMAGE, 0},
+            {sig_jpg2, sizeof(sig_jpg2), file_content_type::IMAGE, 0},
+            {sig_jpg3, sizeof(sig_jpg3), file_content_type::IMAGE, 0},
+            {sig_jpg4, sizeof(sig_jpg4), file_content_type::IMAGE, 0},
+            {sig_gif87, sizeof(sig_gif87), file_content_type::IMAGE, 0},
+            {sig_gif89, sizeof(sig_gif89), file_content_type::IMAGE, 0},
+            {sig_tiff_le, sizeof(sig_tiff_le), file_content_type::IMAGE, 0},
+            {sig_tiff_be, sizeof(sig_tiff_be), file_content_type::IMAGE, 0},
+            {sig_bmp, sizeof(sig_bmp), file_content_type::IMAGE, 0},
+            // RIFF container; refined to WEBP/WAVE/AVI by check_riff_type() below
+            {sig_webp, sizeof(sig_webp), file_content_type::IMAGE, 0},
+
+            // PDF
+            {sig_pdf, sizeof(sig_pdf), file_content_type::PDF, 0},
+
+            // Compressed
+            {sig_zip, sizeof(sig_zip), file_content_type::COMPRESSED, 0},
+            {sig_gzip, sizeof(sig_gzip), file_content_type::COMPRESSED, 0},
+            {sig_7z, sizeof(sig_7z), file_content_type::COMPRESSED, 0},
+            {sig_rar, sizeof(sig_rar), file_content_type::COMPRESSED, 0},
+
+            // Executables
+            {sig_exe, sizeof(sig_exe), file_content_type::EXECUTABLE, 0},
+            {sig_elf, sizeof(sig_elf), file_content_type::EXECUTABLE, 0},
+            {sig_macho_32, sizeof(sig_macho_32), file_content_type::EXECUTABLE, 0},
+            {sig_macho_64, sizeof(sig_macho_64), file_content_type::EXECUTABLE, 0},
+
+            // Audio
+            {sig_mp3_id3, sizeof(sig_mp3_id3), file_content_type::AUDIO, 0},
+            {sig_mp3_ff, sizeof(sig_mp3_ff), file_content_type::AUDIO, 0},
+            {sig_flac, sizeof(sig_flac), file_content_type::AUDIO, 0},
+            {sig_ogg, sizeof(sig_ogg), file_content_type::AUDIO, 0},
+
+            // Video
+            {sig_mp4, sizeof(sig_mp4), file_content_type::VIDEO, 4},
+            {sig_mkv, sizeof(sig_mkv), file_content_type::VIDEO, 0}
+        };
+
+        // Portable case-insensitive string comparison (C++14 compatible)
+        inline bool iequals_n(const char* s1, const char* s2, size_t n)
+        {
+            for (size_t i = 0; i < n; ++i)
+            {
+                const char c1 = (s1[i] >= 'A' && s1[i] <= 'Z') ? s1[i] + 32 : s1[i];
+                const char c2 = (s2[i] >= 'A' && s2[i] <= 'Z') ? s2[i] + 32 : s2[i];
+                if (c1 != c2) return false;
+            }
+            return true;
+        }
+
+        // Case-insensitive check for file extension
+        inline bool has_extension(const std::string& filename, const char* ext)
+        {
+            const size_t ext_len = std::strlen(ext);
+            if (filename.length() < ext_len) return false;
+
+            const size_t start = filename.length() - ext_len;
+            for (size_t i = 0; i < ext_len; ++i)
+            {
+                const char fc = filename[start + i];
+                const char ec = ext[i];
+                const char fc_lower = (fc >= 'A' && fc <= 'Z') ? fc + 32 : fc;
+                const char ec_lower = (ec >= 'A' && ec <= 'Z') ? ec + 32 : ec;
+                if (fc_lower != ec_lower) return false;
+            }
+            return true;
+        }
+
+        // Calculate Shannon entropy for a buffer
+        inline double calculate_entropy(const unsigned char* buffer, size_t length)
+        {
+            if (length == 0) return 0.0;
+
+            // Count byte frequency
+            std::array<size_t, 256> counts = {};
+            for (size_t i = 0; i < length; ++i)
+                counts[buffer[i]]++;
+
+            // Calculate entropy using Shannon's formula: H = -sum(p * log2(p))
+            double entropy = 0.0;
+            const double length_d = static_cast<double>(length);
+
+            for (size_t i = 0; i < 256; ++i)
+            {
+                if (counts[i] > 0)
+                {
+                    const double probability = static_cast<double>(counts[i]) / length_d;
+                    entropy -= probability * std::log2(probability);
+                }
+            }
+
+            return entropy;
+        }
+
+        // Check if buffer contains mostly printable ASCII/UTF-8 text
+        inline bool is_text_content(const unsigned char* buffer, size_t length)
+        {
+            if (length == 0) return false;
+
+            size_t printable_count = 0;
+            size_t whitespace_count = 0;
+            size_t control_count = 0;
+
+            for (size_t i = 0; i < length; ++i)
+            {
+                const unsigned char ch = buffer[i];
+
+                // Common whitespace characters
+                if (ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r')
+                {
+                    whitespace_count++;
+                    printable_count++;
+                }
+                // Printable ASCII range
+                else if (ch >= 32 && ch <= 126)
+                {
+                    printable_count++;
+                }
+                // UTF-8 continuation bytes (10xxxxxx)
+                else if ((ch & 0xC0) == 0x80)
+                {
+                    printable_count++;
+                }
+                // UTF-8 multi-byte sequence starts (110xxxxx, 1110xxxx, 11110xxx)
+                else if ((ch & 0xE0) == 0xC0 || (ch & 0xF0) == 0xE0 || (ch & 0xF8) == 0xF0)
+                {
+                    printable_count++;
+                }
+                // Control characters (excluding common whitespace)
+                else if (ch < 32)
+                {
+                    control_count++;
+                }
+            }
+
+            // Consider as text if >90% printable and <10% control chars
+            const double printable_ratio = static_cast<double>(printable_count) / length;
+            const double control_ratio = static_cast<double>(control_count) / length;
+
+            return printable_ratio > 0.90 && control_ratio < 0.10;
+        }
+
+        // Check for XML/HTML markers
+        inline bool is_xml_content(const unsigned char* buffer, size_t length)
+        {
+            if (length < 5) return false;
+
+            const char* str = reinterpret_cast<const char*>(buffer);
+
+            // Check for XML declaration "<?xml" (case-insensitive)
+            if (length >= 5 && buffer[0] == '<' && buffer[1] == '?')
+            {
+                if (iequals_n(str + 2, "xml", 3))
+                    return true;
+            }
+
+            // Check for HTML doctype (case-insensitive)
+            if (length >= 9 && buffer[0] == '<' && buffer[1] == '!')
+            {
+                if (iequals_n(str + 2, "DOCTYPE", 7))
+                    return true;
+            }
+
+            // Check for HTML tags (case-insensitive)
+            if (length >= 6 && buffer[0] == '<')
+            {
+                if (iequals_n(str + 1, "html>", 5) || iequals_n(str + 1, "html ", 5))
+                    return true;
+            }
+
+            return false;
+        }
+
+        // Special check for RIFF-based formats (WAV, AVI, WEBP)
+        inline file_content_type check_riff_type(const unsigned char* buffer, size_t length)
+        {
+            if (length < 12) return file_content_type::UNKNOWN;
+
+            // RIFF format: "RIFF" + size (4 bytes) + format type (4 bytes)
+            if (std::memcmp(buffer + 8, "WAVE", 4) == 0)
+                return file_content_type::AUDIO;
+            else if (std::memcmp(buffer + 8, "AVI ", 4) == 0)
+                return file_content_type::VIDEO;
+            else if (std::memcmp(buffer + 8, "WEBP", 4) == 0)
+                return file_content_type::IMAGE;
+
+            return file_content_type::UNKNOWN;
+        }
+
+        // Check if ZIP is actually an Office document (DOCX, XLSX, PPTX)
+        inline file_content_type check_office_type(const std::string& filename)
+        {
+            if (has_extension(filename, ".docx") ||
+                has_extension(filename, ".xlsx") ||
+                has_extension(filename, ".pptx"))
+            {
+                return file_content_type::OFFICE;
+            }
+
+            return file_content_type::COMPRESSED;
+        }
+    }
+
+    // ---------------------------------------------------------------------------------
+
+    inline bool detect_file_type(
+        const std::string& filename,
+        file_content_type& detected_type
+    )
+    {
+        detected_type = file_content_type::UNKNOWN;
+
+        // Open file in binary mode
+        std::ifstream file(filename, std::ios::binary);
+        if (!file.is_open())
+            return false;
+
+        // Read initial bytes for analysis (8KB should be sufficient)
+        constexpr size_t BUFFER_SIZE = 8192;
+        std::array<unsigned char, BUFFER_SIZE> buffer;
+
+        file.read(reinterpret_cast<char*>(buffer.data()), BUFFER_SIZE);
+        const size_t bytes_read = static_cast<size_t>(file.gcount());
+        file.close();
+
+        if (bytes_read == 0)
+            return false;
+
+        // Step 1: Check for known magic number signatures
+        for (const auto& sig : impl::signatures)
+        {
+            if (bytes_read >= sig.offset + sig.length)
+            {
+                if (std::memcmp(buffer.data() + sig.offset, sig.bytes, sig.length) == 0)
+                {
+                    detected_type = sig.type;
+
+                    // Special handling for RIFF-based formats
+                    if (sig.bytes == impl::sig_webp || sig.bytes == impl::sig_wav ||
+                        sig.bytes == impl::sig_avi)
+                    {
+                        const auto riff_type = impl::check_riff_type(buffer.data(), bytes_read);
+                        if (riff_type != file_content_type::UNKNOWN)
+                            detected_type = riff_type;
+                    }
+
+                    // Special handling for ZIP (could be Office document)
+                    if (detected_type == file_content_type::COMPRESSED &&
+                        sig.bytes == impl::sig_zip)
+                    {
+                        detected_type = impl::check_office_type(filename);
+                    }
+
+                    // A recognized binary format is not usable as text
+                    return false;
+                }
+            }
+        }
+
+        // Step 2: Check for XML/HTML content
+        if (impl::is_xml_content(buffer.data(), bytes_read))
+        {
+            detected_type = file_content_type::TEXT_XML;
+            return true;
+        }
+
+        // Step 3: Calculate entropy to distinguish text from binary
+        const double entropy = impl::calculate_entropy(buffer.data(), bytes_read);
+
+        // Step 4: Use heuristics to classify content.  Entropy thresholds:
+        //   < 5.0   : almost certainly plain text
+        //   5.0-6.5 : could be text or structured binary
+        //   > 6.5   : likely compressed/encrypted/random binary
+
+        const bool is_text = impl::is_text_content(buffer.data(), bytes_read);
+
+        if (is_text && entropy < 6.5)
+        {
+            // Mostly printable content with text-like entropy; this also covers
+            // source code that contains some special characters
+            detected_type = file_content_type::TEXT_PLAIN;
+            return true;
+        }
+
+        // Likely binary content (no recognized format)
+        detected_type = file_content_type::UNKNOWN;
+        return false;
+    }
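+
+    // Usage sketch (illustrative):
+    //
+    //   file_content_type type;
+    //   if (detect_file_type("corpus/sample.txt", type))
+    //       std::cout << "textual file, safe to tokenize\n";      // TEXT_PLAIN or TEXT_XML
+    //   else
+    //       std::cout << "binary or unreadable file, skipped\n";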
+
+    // ---------------------------------------------------------------------------------
+
+    // Compute Levenshtein (edit) distance between two token sequences
+    inline size_t edit_distance(const std::vector<int>& tokens1, const std::vector<int>& tokens2)
+    {
+        const size_t len1 = tokens1.size();
+        const size_t len2 = tokens2.size();
+
+        if (len1 == 0) return len2;
+        if (len2 == 0) return len1;
+
+        // DP table: dp[i][j] = edit distance between tokens1[0..i-1] and tokens2[0..j-1]
+        std::vector<std::vector<size_t>> dp(len1 + 1, std::vector<size_t>(len2 + 1));
+
+        // Initialize base cases
+        for (size_t i = 0; i <= len1; ++i)
+            dp[i][0] = i;
+        for (size_t j = 0; j <= len2; ++j)
+            dp[0][j] = j;
+
+        // Fill DP table
+        for (size_t i = 1; i <= len1; ++i) {
+            for (size_t j = 1; j <= len2; ++j) {
+                if (tokens1[i - 1] == tokens2[j - 1]) {
+                    dp[i][j] = dp[i - 1][j - 1];  // No edit needed
+                }
+                else {
+                    dp[i][j] = 1 + std::min({ dp[i - 1][j],      // Deletion
+                                              dp[i][j - 1],      // Insertion
+                                              dp[i - 1][j - 1]   // Substitution
+                    });
+                }
+            }
+        }
+
+        return dp[len1][len2];
+    }
+
+    // Compute normalized edit distance as a similarity score between 0 and 1
+    inline double normalized_edit_similarity(const std::vector<int>& tokens1, const std::vector<int>& tokens2)
+    {
+        if (tokens1.empty() && tokens2.empty())
+            return 1.0;
+
+        const size_t max_len = std::max(tokens1.size(), tokens2.size());
+        if (max_len == 0)
+            return 1.0;
+
+        const size_t dist = edit_distance(tokens1, tokens2);
+        return 1.0 - (static_cast<double>(dist) / max_len);
+    }
+
+    // Compute token-level precision, recall, and F1-score
+    struct token_overlap_metrics
+    {
+        double precision;  // What fraction of generated tokens appear in reference
+        double recall;     // What fraction of reference tokens appear in generated
+        double f1_score;   // Harmonic mean of precision and recall
+
+        void print() const
+        {
+            std::cout << "Token overlap metrics:\n"
+                << "  Precision: " << std::fixed << std::setprecision(4) << (precision * 100.0) << "%\n"
+                << "  Recall:    " << std::fixed << std::setprecision(4) << (recall * 100.0) << "%\n"
+                << "  F1-score:  " << std::fixed << std::setprecision(4) << (f1_score * 100.0) << "%\n";
+        }
+    };
+
+    inline token_overlap_metrics compute_token_overlap(
+        const std::vector<int>& reference,
+        const std::vector<int>& generated)
+    {
+        token_overlap_metrics metrics{ 0.0, 0.0, 0.0 };
+
+        if (reference.empty() || generated.empty())
+            return metrics;
+
+        // Count matching tokens
+        std::multiset<int> ref_tokens(reference.begin(), reference.end());
+        std::multiset<int> gen_tokens(generated.begin(), generated.end());
+
+        size_t matches = 0;
+        for (int token : gen_tokens) {
+            auto it = ref_tokens.find(token);
+            if (it != ref_tokens.end()) {
+                ++matches;
+                ref_tokens.erase(it);  // Remove to handle duplicates correctly
+            }
+        }
+
+        // Calculate precision and recall
+        metrics.precision = static_cast<double>(matches) / generated.size();
+        metrics.recall = static_cast<double>(matches) / reference.size();
+
+        // Calculate F1-score
+        if (metrics.precision + metrics.recall > 0.0) {
+            metrics.f1_score = 2.0 * (metrics.precision * metrics.recall) /
+                (metrics.precision + metrics.recall);
+        }
+
+        return metrics;
+    }
+
+    // Compute BLEU-like n-gram overlap score
+    inline double compute_ngram_overlap(
+        const std::vector<int>& reference,
+        const std::vector<int>& generated,
+        int max_n = 4)
+    {
+        if (reference.empty() || generated.empty())
+            return 0.0;
+
+        double total_score = 0.0;
+        int valid_n_count = 0;
+
+        // Compute overlap for n-grams of size 1 to max_n
+        for (int n = 1; n <= max_n; ++n) {
+            if (static_cast<size_t>(n) > reference.size() ||
+                static_cast<size_t>(n) > generated.size())
+                break;
+
+            // Extract n-grams from reference
+            std::map<std::vector<int>, size_t> ref_ngrams;
+            for (size_t i = 0; i <= reference.size() - n; ++i) {
+                std::vector<int> ngram(reference.begin() + i, reference.begin() + i + n);
+                ref_ngrams[ngram]++;
+            }
+
+            // Count matching n-grams in generated
+            size_t matches = 0;
+            size_t total_gen_ngrams = 0;
+            for (size_t i = 0; i <= generated.size() - n; ++i) {
+                std::vector<int> ngram(generated.begin() + i, generated.begin() + i + n);
+                total_gen_ngrams++;
+
+                auto it = ref_ngrams.find(ngram);
+                if (it != ref_ngrams.end() && it->second > 0) {
+                    matches++;
+                    it->second--;  // Decrement to handle multiple occurrences
+                }
+            }
+
+            if (total_gen_ngrams > 0) {
+                total_score += static_cast<double>(matches) / total_gen_ngrams;
+                valid_n_count++;
+            }
+        }
+
+        // Return average n-gram precision
+        return valid_n_count > 0 ? total_score / valid_n_count : 0.0;
+    }
+
+    // Text similarity report
+    struct text_similarity_report
+    {
+        double edit_similarity;         // Normalized Levenshtein distance
+        token_overlap_metrics overlap;  // Token-level precision/recall/F1
+        double ngram_score;             // N-gram overlap (BLEU-like)
+
+        void print() const
+        {
+            std::cout << "\n=== Text similarity report ===\n";
+            std::cout << "Edit similarity (order-sensitive): "
+                << std::fixed << std::setprecision(4) << (edit_similarity * 100.0) << "%\n\n";
+
+            overlap.print();
+
+            std::cout << "\nN-gram overlap (BLEU-like): "
+                << std::fixed << std::setprecision(4) << (ngram_score * 100.0) << "%\n";
+            std::cout << "==============================\n\n";
+        }
+    };
+
+    inline text_similarity_report compute_text_similarity(
+        const std::vector<int>& reference,
+        const std::vector<int>& generated)
+    {
+        text_similarity_report report;
+
+        report.edit_similarity = normalized_edit_similarity(reference, generated);
+        report.overlap = compute_token_overlap(reference, generated);
+        report.ngram_score = compute_ngram_overlap(reference, generated, 4);
+
+        return report;
+    }
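+
+    // Usage sketch (illustrative; tokenize() stands in for whatever tokenizer
+    // produced the sequences and is not defined in this file):
+    //
+    //   std::vector<int> reference = tokenize(reference_text);
+    //   std::vector<int> generated = tokenize(generated_text);
+    //   compute_text_similarity(reference, generated).print();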
+
+    class inference_context
+    {
+    public:
+        inference_context(
+            long window_size = 256,
+            long context_multiplier = 10,
+            long padding_token = 0
+        ) : window_size_(window_size),
+            context_capacity_(window_size * context_multiplier),
+            padding_token_(padding_token),
+            current_size_(0)
+        {
+            DLIB_CASSERT(window_size > 0, "Window size must be positive");
+            DLIB_CASSERT(context_multiplier > 0, "Context multiplier must be positive");
+            context_.reserve(context_capacity_);
+        }
+
+        void add_token(unsigned long token)
+        {
+            if (current_size_ == context_capacity_)
+            {
+                // FIFO: remove oldest, add newest
+                context_.erase(context_.begin());
+                context_.push_back(static_cast<int>(token));
+            }
+            else
+            {
+                // Still room in context
+                context_.push_back(static_cast<int>(token));
+                current_size_++;
+            }
+        }
+
+        void add_tokens(const std::vector<unsigned long>& tokens)
+        {
+            for (unsigned long token : tokens) add_token(token);
+        }
+
+        void add_tokens(const std::vector<int>& tokens)
+        {
+            for (int token : tokens) add_token(static_cast<unsigned long>(token));
+        }
+
+        matrix<int, 0, 1> get_input_window(long custom_window_size = -1) const
+        {
+            long win_size = (custom_window_size > 0) ? custom_window_size : window_size_;
+            matrix<int, 0, 1> window(win_size, 1);
+
+            if (current_size_ >= win_size)
+            {
+                // Context has enough tokens - take last win_size tokens
+                for (long i = 0; i < win_size; ++i)
+                    window(i) = context_[current_size_ - win_size + i];
+            }
+            else
+            {
+                // Context has fewer tokens - left pad
+                long padding_needed = win_size - current_size_;
+
+                for (long i = 0; i < padding_needed; ++i)
+                    window(i) = padding_token_;
+                for (long i = 0; i < current_size_; ++i)
+                    window(padding_needed + i) = context_[i];
+            }
+
+            return window;
+        }
+
+        void reset()
+        {
+            context_.clear();
+            current_size_ = 0;
+        }
+
+        void resize_context(long new_capacity)
+        {
+            DLIB_CASSERT(new_capacity > 0, "New capacity must be positive");
+
+            if (new_capacity < current_size_)
+            {
+                // Keep only the last new_capacity tokens
+                context_.erase(context_.begin(), context_.begin() + (current_size_ - new_capacity));
+                current_size_ = new_capacity;
+            }
+
+            context_capacity_ = new_capacity;
+            context_.reserve(context_capacity_);
+        }
+
+        long size() const { return current_size_; }
+        long capacity() const { return context_capacity_; }
+        long window_size() const { return window_size_; }
+        bool is_full() const { return current_size_ >= context_capacity_; }
+        const std::vector<int>& get_full_context() const { return context_; }
+
+        std::string to_string(bool show_all = false) const
+        {
+            std::ostringstream ss;
+            ss << "InferenceContext[size=" << current_size_
+                << "/" << context_capacity_
+                << ", window=" << window_size_ << "]\n";
+
+            if (show_all && current_size_ > 0)
+            {
+                ss << "Tokens: [";
+                long display_count = show_all ? current_size_ : std::min(20L, current_size_);
+                for (long i = 0; i < display_count; ++i)
+                {
+                    ss << context_[i];
+                    if (i < display_count - 1) ss << ", ";
+                }
+                if (current_size_ > display_count)
+                {
+                    ss << " ... +" << (current_size_ - display_count) << " more";
+                }
+                ss << "]";
+            }
+
+            return ss.str();
+        }
+" << (current_size_ - display_count) << " more"; + } + ss << "]"; + } + + return ss.str(); + } + + friend void serialize(const inference_context& item, std::ostream& out) + { + serialize("inference_context", out); + serialize(item.window_size_, out); + serialize(item.context_capacity_, out); + serialize(item.padding_token_, out); + serialize(item.current_size_, out); + serialize(item.context_, out); + } + + friend void deserialize(inference_context& item, std::istream& in) + { + std::string name; + deserialize(name, in); + if (name != "inference_context") + { + throw serialization_error("Error deserializing object of type 'inference_context': " + "expected 'inference_context' but got '" + name + "'"); + } + + deserialize(item.window_size_, in); + deserialize(item.context_capacity_, in); + deserialize(item.padding_token_, in); + deserialize(item.current_size_, in); + deserialize(item.context_, in); + } + + private: + std::vector context_; // Full context history + long window_size_; // Window size for model input + long context_capacity_; // Maximum context size + long padding_token_; // Token used for left padding + long current_size_; // Current number of tokens + }; + + inline void build_single_token_prediction_dataset( + const std::vector>& token_sequences, + long window_len, + long padding_token, + bool use_left_padding, + std::vector>& X, + std::vector& Y) + { + X.clear(); + Y.clear(); + + for (const auto& seq : token_sequences) + { + const long len = static_cast(seq.size()); + if (len <= 1) continue; + + long start = 0; + if (len < window_len) + { + if (!use_left_padding) continue; + start = (len - window_len); + } + + // Generate initial padded samples for sequences >= window_len + if (use_left_padding && len >= window_len) + { + for (long pos = 1; pos < window_len; ++pos) + { + matrix window(window_len, 1); + long pad = window_len - pos; + + for (long i = 0; i < pad; ++i) window(i) = padding_token; + for (long i = 0; i < pos; ++i) window(pad + i) = seq[i]; + + X.push_back(window); + Y.push_back(seq[pos]); + } + } + + // Slide window through sequence + for (long pos = start; pos < len - 1; ++pos) + { + matrix window(window_len, 1); + + for (long i = 0; i < window_len; ++i) + { + long idx = pos + i; + window(i) = (idx >= 0 && idx < len) ? seq[idx] : padding_token; + } + + long target_idx = pos + window_len; + if (target_idx >= 0 && target_idx < len) + { + X.push_back(window); + Y.push_back(seq[target_idx]); + } + } + } + } + + inline void build_multi_token_prediction_dataset( + const std::vector>& source_sequences, + const std::vector>& target_sequences, + long src_window_len, + long tgt_window_len, + long padding_token, + std::vector>& X, + std::vector>& Y) + { + DLIB_CASSERT(source_sequences.size() == target_sequences.size(), + "Source and target must have same size"); + + X.clear(); + Y.clear(); + + for (size_t i = 0; i < source_sequences.size(); ++i) + { + const auto& src = source_sequences[i]; + const auto& tgt = target_sequences[i]; + + const long src_len = static_cast(src.size()); + const long tgt_len = static_cast(tgt.size()); + + if (src_len == 0 || tgt_len == 0) continue; + + long src_pos = (src_len < src_window_len) ? 
(src_len - src_window_len) : 0; + long tgt_pos = 0; + + while (true) + { + // Build source window + matrix src_window(src_window_len, 1); + long src_real = 0; + + for (long j = 0; j < src_window_len; ++j) + { + long idx = src_pos + j; + if (idx >= 0 && idx < src_len) + { + src_window(j) = src[idx]; + src_real++; + } + else + { + src_window(j) = padding_token; + } + } + + // Build target window + matrix tgt_window(tgt_window_len, 1); + long tgt_real = 0; + + for (long j = 0; j < tgt_window_len; ++j) + { + long idx = tgt_pos + j; + if (idx < tgt_len) + { + tgt_window(j) = tgt[idx]; + tgt_real++; + } + else + { + tgt_window(j) = padding_token; + } + } + + // Stop if no real tokens in either window + if (src_real == 0 || tgt_real == 0) break; + + X.push_back(src_window); + Y.push_back(tgt_window); + + // Stop if both sequences fully consumed + if (src_pos + src_window_len >= src_len && + tgt_pos + tgt_window_len >= tgt_len) break; + + src_pos++; + tgt_pos++; + } + } + } + + template + void shuffle_training_dataset( + std::vector& samples, + std::vector& labels, + unsigned long seed = 0) + { + DLIB_CASSERT(samples.size() == labels.size(), + "samples and labels must have the same size"); + + const size_t dataset_size = samples.size(); + if (dataset_size <= 1) return; + + dlib::rand rng; + if (seed != 0) rng = dlib::rand(seed); + + // Fisher-Yates shuffle algorithm + for (size_t i = dataset_size - 1; i > 0; --i) + { + size_t j = rng.get_random_32bit_number() % (i + 1); + + // Swap samples[i] with samples[j] + std::swap(samples[i], samples[j]); + + // Swap labels[i] with labels[j] + std::swap(labels[i], labels[j]); + } + } + + template + void augment_training_dataset( + std::vector& samples, + std::vector& labels, + int unk_token, + int padding_token, + double augmentation_ratio = 0.2, + long min_noise_tokens = 1, + long max_noise_tokens = 3, + unsigned long seed = 0) + { + DLIB_CASSERT(samples.size() == labels.size(), + "samples and labels must have the same size"); + DLIB_CASSERT(augmentation_ratio >= 0.0 && augmentation_ratio <= 2.0, + "augmentation_ratio must be between 0.0 and 2.0"); + DLIB_CASSERT(min_noise_tokens >= 0 && max_noise_tokens >= min_noise_tokens, + "Invalid noise token range: min=" << min_noise_tokens << ", max=" << max_noise_tokens); + + const size_t original_size = samples.size(); + if (original_size == 0 || augmentation_ratio == 0.0) return; + + // Calculate number of augmented samples to create + const size_t num_augmented = static_cast(original_size * augmentation_ratio); + if (num_augmented == 0) return; + + // Reserve space to avoid multiple reallocations + samples.reserve(original_size + num_augmented); + labels.reserve(original_size + num_augmented); + + dlib::rand rng; + if (seed != 0) rng = dlib::rand(seed); + + for (size_t aug_idx = 0; aug_idx < num_augmented; ++aug_idx) + { + // Select a random sample to augment + const size_t source_idx = rng.get_random_32bit_number() % original_size; + + // Create a copy of the sample and its label + auto augmented_sample = samples[source_idx]; + auto augmented_label = labels[source_idx]; + + // Identify non-padding positions in the sample + std::vector valid_positions; + const long sample_length = augmented_sample.nr(); + + for (long i = 0; i < sample_length; ++i) + { + if (augmented_sample(i) != padding_token) + valid_positions.push_back(i); + } + + // Skip if no valid positions to add noise + if (valid_positions.empty()) continue; + + // Determine number of tokens to replace with noise + const long num_valid = 
static_cast(valid_positions.size()); + const long effective_max = std::min(max_noise_tokens, num_valid); + const long effective_min = std::min(min_noise_tokens, effective_max); + + long num_noise = effective_min; + if (effective_max > effective_min) + { + num_noise = effective_min + + (rng.get_random_32bit_number() % (effective_max - effective_min + 1)); + } + + // Ensure noise ratio is reasonable (max 30% of non-padding tokens) + const long max_reasonable = std::max(1L, static_cast(num_valid * 0.3)); + num_noise = std::min(num_noise, max_reasonable); + + // Randomly select positions to replace with UNK + std::vector noise_positions = valid_positions; + + // Fisher-Yates shuffle to select random positions + for (long i = static_cast(noise_positions.size()) - 1; i > 0; --i) + { + long j = rng.get_random_32bit_number() % (i + 1); + std::swap(noise_positions[i], noise_positions[j]); + } + + // Apply noise to the first num_noise positions + for (long i = 0; i < num_noise; ++i) + { + augmented_sample(noise_positions[i]) = unk_token; + } + + // Add augmented sample and label to the dataset + samples.push_back(std::move(augmented_sample)); + labels.push_back(std::move(augmented_label)); + } + } + +} // namespace dlib + +#endif // DLIB_LANGUAGE_MODEL_DATA_H_ \ No newline at end of file diff --git a/dlib/data_io/language_model_data_abstract.h b/dlib/data_io/language_model_data_abstract.h new file mode 100644 index 0000000000..2b797223e2 --- /dev/null +++ b/dlib/data_io/language_model_data_abstract.h @@ -0,0 +1,556 @@ +// Copyright (C) 2025 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_LANGUAGE_MODEL_DATA_ABSTRACT_H_ +#ifdef DLIB_LANGUAGE_MODEL_DATA_ABSTRACT_H_ + +#include +#include +#include +#include "../matrix.h" +#include "../serialize.h" + +namespace dlib +{ + // --------------------------------------------------------------------------------- + + enum class file_content_type + { + /*! + WHAT THIS ENUM REPRESENTS + Enumeration of recognized file content types for classification purposes. + Used by detect_file_type() to identify the nature of file contents. + + VALUES + TEXT_PLAIN - Plain text files (including CSV, source code, logs, etc.) + TEXT_XML - XML or HTML markup documents + IMAGE - Image formats (PNG, JPEG, GIF, TIFF, BMP, WEBP, etc.) + VIDEO - Video formats (MP4, AVI, MKV, etc.) + AUDIO - Audio formats (MP3, WAV, FLAC, OGG, etc.) + EXECUTABLE - Executable binary files (EXE, DLL, ELF, Mach-O) + COMPRESSED - Compressed archives (ZIP, GZIP, 7Z, RAR, etc.) + PDF - PDF documents + OFFICE - Office documents (DOCX, XLSX, PPTX) + UNKNOWN - File type could not be determined or is not recognized + + NOTES + - Detection is based on file content analysis, not file extensions + - Magic number signatures are checked first for binary formats + - Entropy analysis and heuristics are used for text vs binary classification + !*/ + }; + + // --------------------------------------------------------------------------------- + + inline bool detect_file_type( + const std::string& filename, + file_content_type& detected_type + ); + /*! 
+        ensures
+            - Efficiently detects the content type of a file by analyzing its internal
+              structure using magic number signatures and entropy-based heuristics
+            - Opens and reads the first 8KB of the file for analysis
+            - Returns true if file contains text-based content (TEXT_PLAIN or TEXT_XML)
+            - Returns false if file contains binary content or cannot be opened
+            - Sets detected_type to the most specific content type that could be identified
+            - If file cannot be opened, returns false and sets detected_type to UNKNOWN
+
+        FILE DETECTION METHODOLOGY
+            The function uses a multi-stage detection process:
+
+            Stage 1: magic number detection (binary formats)
+            - Checks for ~30 common file format signatures (magic numbers)
+            - Supported formats include:
+              * Images: PNG, JPEG (4 variants), GIF (87a/89a), TIFF (LE/BE), BMP, WEBP
+              * Documents: PDF
+              * Compressed: ZIP, GZIP, 7Z, RAR
+              * Executables: Windows PE (EXE/DLL), Unix ELF, macOS Mach-O (32/64-bit)
+              * Audio: MP3 (ID3/FF), WAV, FLAC, OGG
+              * Video: MP4, AVI, MKV
+            - Special handling for container formats:
+              * RIFF containers (WAV/AVI/WEBP) are distinguished by format identifier
+              * ZIP files are checked against filename to detect Office documents (DOCX/XLSX/PPTX)
+            - If a magic number is found, returns false (binary) with the appropriate type
+
+            Stage 2: XML/HTML detection
+            - Checks for XML declarations, DOCTYPE markers, and common HTML tags
+              near the start of the file
+            - If found, returns true with detected_type set to TEXT_XML
+
+            Stage 3: text vs binary heuristics
+            - Classifies the sampled bytes by printable-character ratio and
+              Shannon entropy:
+              * > 90% printable characters
+              * < 10% control characters
+              * Entropy < 5.5 (high confidence text)
+              * Entropy < 6.5 (text with special characters)
+              * Entropy >= 6.8 (likely binary/compressed/encrypted)
+
+        TYPICAL USAGE
+            file_content_type type;
+
+            // Detect file type
+            bool is_text = detect_file_type("document.pdf", type);
+
+            if (type == file_content_type::PDF)
+                std::cout << "PDF document detected\n";
+            else if (type == file_content_type::IMAGE)
+                std::cout << "Image file detected\n";
+            else if (is_text)
+                std::cout << "Text file detected\n";
+            else
+                std::cout << "Binary file or unknown format\n";
+
+            // Filter text files for processing
+            std::vector filenames = get_file_list();
+            for (const auto& fname : filenames)
+            {
+                file_content_type ftype;
+                if (detect_file_type(fname, ftype))
+                {
+                    // Process text file
+                    process_text_file(fname);
+                }
+            }
+    !*/
+
+    // ---------------------------------------------------------------------------------
+
+    inline size_t edit_distance(
+        const std::vector& tokens1,
+        const std::vector& tokens2
+    );
+    /*!
+        ensures
+            - Computes the Levenshtein (edit) distance between two token sequences
+            - Returns the minimum number of single-token edits (insertions, deletions,
+              or substitutions) required to transform tokens1 into tokens2
+            - Uses dynamic programming with O(n*m) time complexity and O(n*m) space
+            - Returns tokens2.size() if tokens1 is empty
+            - Returns tokens1.size() if tokens2 is empty
+            - Returns 0 if both sequences are identical
+    !*/
+
+    inline double normalized_edit_similarity(
+        const std::vector& tokens1,
+        const std::vector& tokens2
+    );
+    /*!
+        ensures
+            - Computes a normalized similarity score based on edit distance
+            - Returns a value in the range [0.0, 1.0] where:
+              * 1.0 indicates identical sequences
+              * 0.0 indicates completely different sequences
+            - Formula: 1.0 - (edit_distance / max_length)
+            - If both sequences are empty, returns 1.0 (considered identical)
+            - This metric is order-sensitive: [1,2,3] vs [3,2,1] will have low similarity
+    !*/
+
+    // ---------------------------------------------------------------------------------
+
+    struct token_overlap_metrics
+    {
+        /*!
+            WHAT THIS OBJECT REPRESENTS
+                Stores token-level evaluation metrics that treat sequences as
+                bags of tokens (order-independent). Useful for assessing vocabulary
+                overlap between reference and generated text.
+
+            FIELDS
+                precision - Fraction of generated tokens that appear in the reference
+                            Range: [0.0, 1.0]
+                            Formula: matching_tokens / total_generated_tokens
+
+                recall    - Fraction of reference tokens that appear in the generated text
+                            Range: [0.0, 1.0]
+                            Formula: matching_tokens / total_reference_tokens
+
+                f1_score  - Harmonic mean of precision and recall
+                            Range: [0.0, 1.0]
+                            Formula: 2 * (precision * recall) / (precision + recall)
+
+            INTERPRETATION
+                - High precision: generated text uses vocabulary from reference
+                - High recall: generated text covers reference vocabulary
+                - High F1: good balance between precision and recall
+                - Unlike edit distance, this metric ignores token order
+        !*/
+
+        double precision;
+        double recall;
+        double f1_score;
+
+        void print() const;
+        /*!
+            ensures
+                - Prints formatted metrics to standard output
+                - Format: "Precision: XX.XX%\n Recall: XX.XX%\n F1-score: XX.XX%"
+        !*/
+    };
+
+    inline token_overlap_metrics compute_token_overlap(
+        const std::vector& reference,
+        const std::vector& generated
+    );
+    /*!
+        ensures
+            - Computes token-level precision, recall, and F1-score between reference
+              and generated token sequences
+            - Treats sequences as multisets (bags) of tokens, ignoring order
+            - Handles duplicate tokens correctly by matching each token at most once
+            - Returns metrics with all values set to 0.0 if either sequence is empty
+            - Precision = fraction of generated tokens found in reference
+            - Recall = fraction of reference tokens found in generated
+            - F1 = harmonic mean of precision and recall
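+            - For example (illustrative): reference = [7, 7, 8] and generated = [7, 9]
+              match the token 7 exactly once, giving precision = 1/2, recall = 1/3,
+              and f1_score = 0.4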
+    !*/
+
+    // ---------------------------------------------------------------------------------
+
+    inline double compute_ngram_overlap(
+        const std::vector& reference,
+        const std::vector& generated,
+        int max_n = 4
+    );
+    /*!
+        requires
+            - max_n >= 1
+        ensures
+            - Computes an n-gram overlap score similar to the BLEU metric
+            - Evaluates matching n-grams for n = 1, 2, 3, ..., max_n
+            - Returns average n-gram precision across all n values
+            - Score range: [0.0, 1.0] where 1.0 is perfect overlap
+            - Returns 0.0 if either sequence is empty
+            - Stops computing for n-values where n > sequence length
+
+        COMPARISON TO BLEU
+            - Similar to BLEU but simplified (no brevity penalty, no geometric mean)
+            - Uses arithmetic mean instead of geometric mean for simplicity
+            - Suitable for quick similarity assessment in language model evaluation
+    !*/
+
+    // ---------------------------------------------------------------------------------
+
+    struct text_similarity_report
+    {
+        /*!
+            WHAT THIS OBJECT REPRESENTS
+                Comprehensive similarity report combining multiple metrics to evaluate
+                how closely generated text matches reference text. Provides both
+                order-sensitive and order-insensitive measures.
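+                As an illustration, a reference of [1, 2, 3] compared against a
+                generated sequence of [3, 2, 1] scores 1.0 on overlap.f1_score
+                (the token multisets are identical) but only about 0.33 on
+                edit_similarity, since two of the three positions differ.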
+ + FIELDS + edit_similarity - Normalized Levenshtein distance (order-sensitive) + Range: [0.0, 1.0] + Measures token-by-token match considering order + + overlap - Token-level precision/recall/F1 metrics + Order-insensitive bag-of-tokens comparison + Useful for vocabulary coverage assessment + + ngram_score - BLEU-like n-gram overlap score (order-aware locally) + Range: [0.0, 1.0] + Captures phrase-level similarity + + INTERPRETATION GUIDE + Use edit_similarity when: + - Exact token order matters + - Evaluating sequence prediction tasks + - Need strict alignment measure + + Use overlap metrics when: + - Vocabulary coverage is important + - Order is less critical + - Want to know what fraction of tokens are correct + + Use ngram_score when: + - Local phrase structure matters + - Evaluating fluency and coherence + - Need metric between strict order and pure bag-of-words + !*/ + + double edit_similarity; + token_overlap_metrics overlap; + double ngram_score; + + void print() const; + /*! + ensures + - Prints comprehensive formatted report to standard output + - Displays all three metric categories with clear labels + - Format optimized for readability with percentages and section headers + !*/ + }; + + inline text_similarity_report compute_text_similarity( + const std::vector& reference, + const std::vector& generated + ); + /*! + ensures + - Computes comprehensive similarity metrics between reference and generated + token sequences + - Returns text_similarity_report containing: + * edit_similarity: normalized Levenshtein distance + * overlap: token-level precision/recall/F1 scores + * ngram_score: BLEU-like n-gram overlap (up to 4-grams) + - This is the primary function for evaluating text generation quality + - Provides multiple complementary views of similarity + !*/ + + // --------------------------------------------------------------------------------- + + class inference_context + { + /*! + WHAT THIS OBJECT REPRESENTS + This class manages a token context for inference with language models. + It maintains a full history context and provides a sliding window view + for model input. + + Features: + - Full context history with configurable capacity + - Sliding window extraction for model input + - Left padding when context not full + - FIFO policy when context reaches capacity + - Dynamic resizing without data loss + + TYPICAL USAGE + inference_context ctx(256, 10, 0); // window=256, capacity=2560, pad=0 + + ctx.add_tokens({1, 2, 3, 4, 5}); // Add tokens + auto input = ctx.get_input_window(); // Get last 256 tokens (padded if needed) + + // Feed to model, get prediction, add to context + unsigned long next_token = model(input); + ctx.add_token(next_token); + !*/ + public: + inference_context( + long window_size = 256, + long context_multiplier = 10, + long padding_token = 0 + ); + /*! + requires + - window_size > 0 + - context_multiplier > 0 + ensures + - Constructs an inference context manager + - context_capacity = window_size * context_multiplier + - Context is initially empty (will be left-padded) + !*/ + + void add_token(unsigned long token); + /*! + ensures + - Adds a single token to the context + - If context is full, removes oldest token (FIFO) + - New token is always added at the end + !*/ + + void add_tokens(const std::vector& tokens); + void add_tokens(const std::vector& tokens); + /*! 
+ ensures + - Adds multiple tokens to the context + - Tokens are added in order + - FIFO policy applies if capacity exceeded + !*/ + + matrix get_input_window(long custom_window_size = -1) const; + /*! + ensures + - Returns a window of tokens suitable for model input + - Window size is custom_window_size if specified, otherwise window_size_ + - Window contains the last N tokens from context + - Left-padded with padding_token if context has fewer than N tokens + - Returns matrix of shape (N, 1) compatible with Dlib + !*/ + + void reset(); + /*! + ensures + - Clears all tokens from context + - Resets current_size to 0 + - Context capacity remains unchanged + !*/ + + void resize_context(long new_capacity); + /*! + requires + - new_capacity > 0 + ensures + - Resizes the context capacity + - Preserves existing tokens (up to new capacity) + - If new_capacity < current_size, keeps only the last new_capacity tokens + !*/ + + long size() const; + /*! + ensures + - Returns the current number of tokens in context + !*/ + + long capacity() const; + /*! + ensures + - Returns the maximum capacity of the context + !*/ + + long window_size() const; + /*! + ensures + - Returns the default window size for model input + !*/ + + bool is_full() const; + /*! + ensures + - Returns true if context is at full capacity + !*/ + + const std::vector& get_full_context() const; + /*! + ensures + - Returns a const reference to the full context vector + !*/ + + std::string to_string(bool show_all = false) const; + /*! + ensures + - Returns a string representation of the context for debugging + !*/ + + friend void serialize(const inference_context& item, std::ostream& out); + /*! + ensures + - Serializes the inference_context to an output stream + - Saves all context data and configuration parameters + !*/ + + friend void deserialize(inference_context& item, std::istream& in); + /*! + ensures + - Deserializes the inference_context from an input stream + - Restores all context data and configuration parameters + !*/ + + private: + std::vector context_; // Full context history + long context_capacity_; // Maximum context size + long window_size_; // Window size for model input + long padding_token_; // Token used for left padding + long current_size_; // Current number of tokens + }; + + inline void build_single_token_prediction_dataset( + const std::vector>& token_sequences, + long window_len, + long padding_token, + bool use_left_padding, + std::vector>& X, + std::vector& Y); + /*! + ensures + - Constructs training samples for single next-token prediction using a sliding window approach + - For each sequence, creates input windows of size window_len paired with the immediately following token + - If use_left_padding is true: + * Sequences shorter than window_len are left-padded with padding_token + * Sequences >= window_len generate initial samples with progressive left padding + - If use_left_padding is false: + * Sequences shorter than window_len are skipped + - Returns samples in X (input windows) and Y (target tokens) + - X contains matrix of shape (window_len, 1) + - Y contains unsigned long values representing the next token + !*/ + + inline void build_multi_token_prediction_dataset( + const std::vector>& source_sequences, + const std::vector>& target_sequences, + long src_window_len, + long tgt_window_len, + long padding_token, + std::vector>& X, + std::vector>& Y); + /*! 
+ requires + - source_sequences.size() == target_sequences.size() + - src_window_len > 0 + - tgt_window_len > 0 + ensures + - Constructs training samples for sequence-to-sequence prediction + - For each (source, target) pair, creates aligned windows that slide synchronously + - Source windows are left-padded with padding_token when source length < src_window_len + - Target windows are right-padded with padding_token when insufficient tokens remain + - Sliding continues while both windows contain at least one real (non-padding) token + - Stops when both sequences are fully consumed (all tokens have appeared in windows) + - Returns samples in X (source windows) and Y (target windows) + - X contains matrix of shape (src_window_len, 1) + - Y contains matrix of shape (tgt_window_len, 1) + !*/ + + template + void shuffle_training_dataset( + std::vector& samples, + std::vector& labels, + unsigned long seed = 0 + ); + /*! + requires + - samples.size() == labels.size() + ensures + - Randomly shuffles the training dataset in-place + - Applies the same permutation to both samples and labels to maintain correspondence + - If seed == 0, uses a random seed based on current time + - If seed != 0, uses the provided seed for reproducible shuffling + - After shuffling, samples[i] still corresponds to labels[i] + - Uses Fisher-Yates shuffle algorithm for uniform random permutation + !*/ + + template + void augment_training_dataset( + std::vector& samples, + std::vector& labels, + int unk_token, + int padding_token, + double augmentation_ratio = 0.2, + long min_noise_tokens = 1, + long max_noise_tokens = 3, + unsigned long seed = 0 + ); + /*! + requires + - samples.size() == labels.size() + - 0.0 <= augmentation_ratio <= 2.0 + - min_noise_tokens >= 0 + - max_noise_tokens >= min_noise_tokens + ensures + - Augments the training dataset by adding noisy copies of existing samples + - Creates floor(samples.size() * augmentation_ratio) new augmented samples + - For each augmented sample: + * Randomly selects a source sample from the original dataset + * Creates a copy of the sample and its corresponding label + * Randomly replaces between min_noise_tokens and max_noise_tokens + non-padding tokens with unk_token + * Only tokens != padding_token are eligible for noise injection + * Number of noise tokens is capped at 30% of non-padding tokens + to maintain sample quality + - Corresponding labels are appended to labels vector (unchanged) + - Original samples and labels are preserved + - If seed == 0, uses random seed based on current time + - If seed != 0, uses provided seed for reproducible augmentation + - Default augmentation_ratio of 0.2 (20%) follows common practices + in language model training literature + !*/ + +} // namespace dlib + +#endif // DLIB_LANGUAGE_MODEL_DATA_ABSTRACT_H_ \ No newline at end of file diff --git a/dlib/dnn.h b/dlib/dnn.h index bc38dc4b73..313c19b6f7 100644 --- a/dlib/dnn.h +++ b/dlib/dnn.h @@ -32,6 +32,7 @@ #include "dnn/utilities.h" #include "dnn/validation.h" #include "dnn/visitors.h" +#include "dnn/transformer.h" #endif // DLIB_DNn_ diff --git a/dlib/dnn/layers.h b/dlib/dnn/layers.h index 6f9389fced..cde2f7ed9f 100644 --- a/dlib/dnn/layers.h +++ b/dlib/dnn/layers.h @@ -1017,19 +1017,10 @@ namespace dlib void setup(const SUBNET& sub) { const auto& input = sub.get_output(); - input_k = input.k(); - input_nr = input.nr(); - input_nc = input.nc(); - - // Calculate output dimensions using input dims where target is -1 - if (k_ == -1) output_k = input_k; - if (nr_ == -1) output_nr = 
input_nr; - if (nc_ == -1) output_nc = input_nc; + update_dimensions_from_input(input); - // Check if this is well a pure reshape long input_elements = input_k * input_nr * input_nc; long output_elements = output_k * output_nr * output_nc; - if (input_elements != output_elements && input_k == output_k) needs_rescale = true; DLIB_CASSERT(input_elements == output_elements || needs_rescale, "Cannot reshape tensor of " << input_elements << " elements into shape with " << output_elements << " elements. " << @@ -1039,8 +1030,14 @@ namespace dlib template void forward(const SUBNET& sub, resizable_tensor& output) { - // Set the output size (always preserving batch dimension) const tensor& input = sub.get_output(); + + // Check if dimensions changed (after deserialization or fine-tuning) + // This ensures dimensions are always synchronized with current input + if (input_k != input.k() || input_nr != input.nr() || input_nc != input.nc()) + update_dimensions_from_input(input); + + // Set the output size (always preserving batch dimension) output.set_size(input.num_samples(), output_k, output_nr, output_nc); if (!needs_rescale) @@ -1142,7 +1139,25 @@ namespace dlib << "/>\n"; } - private: + private: + void update_dimensions_from_input(const tensor& input) + { + // Update input dimensions + input_k = input.k(); + input_nr = input.nr(); + input_nc = input.nc(); + + // Recalculate output dimensions for dynamic axes (-1) + if (k_ == -1) output_k = input_k; + if (nr_ == -1) output_nr = input_nr; + if (nc_ == -1) output_nc = input_nc; + + // Check if rescaling is needed + long input_elements = input_k * input_nr * input_nc; + long output_elements = output_k * output_nr * output_nc; + needs_rescale = (input_elements != output_elements && input_k == output_k); + } + long input_k, input_nr, input_nc; // Input dimensions long output_k, output_nr, output_nc; // Output dimensions bool needs_rescale; @@ -2407,7 +2422,7 @@ namespace dlib { const auto& prev_output = sub.get_output(); DLIB_CASSERT((long)num_inputs == prev_output.nc(), - "The size of the input tensor to this linear layer doesn't match the size the linear layer was trained with."); + "The size of the input tensor to this linear layer doesn't match the size the linear layer was trained with."); output.set_size(prev_output.num_samples(), prev_output.k(), prev_output.nr(), num_outputs); auto o = alias_tensor(output.num_samples() * output.k() * output.nr(), num_outputs)(output, 0); @@ -2441,8 +2456,6 @@ namespace dlib } } - //prev_gradient is not const, so that sgi isn't const - //since sgi is used as a destination for tt::gemm auto& prev_gradient = sub.get_gradient_input(); alias_tensor_instance sgi = alias_tensor(prev_gradient.num_samples() * prev_gradient.k() * prev_gradient.nr(), num_inputs)(prev_gradient, 0); auto w = weights(params, 0); @@ -5441,7 +5454,8 @@ namespace dlib embeddings_() : num_embeddings(num_embeddings_), embedding_dim(embedding_dim_), learning_rate_multiplier(1.0f), - scale_by_freq(true) + scale_by_freq(true), + output_scale(std::sqrt(static_cast(embedding_dim_))) { } @@ -5473,12 +5487,17 @@ namespace dlib } } + float get_output_scale() const { return output_scale; } + template void setup(const SUBNET& /*sub*/) { embs.set_size(num_embeddings, embedding_dim); tt::tensor_rand rnd(std::rand()); rnd.fill_gaussian(embs); + + const float init_scale = 1.0f / std::sqrt(static_cast(embedding_dim)); + tt::affine_transform(embs, embs, init_scale); } template @@ -5488,6 +5507,7 @@ namespace dlib output.set_size(prev.num_samples(), prev.k(), 
prev.nr(), embedding_dim); tt::embeddings(output, prev, embs); + tt::affine_transform(output, output, output_scale); } template @@ -5502,7 +5522,8 @@ namespace dlib auto& prev_src = sub.get_output(); calc_token_freqs(prev_src, gradient_input); - tt::embeddings_gradient(prev_src, gradient_input, embs, freqs, learning_rate_multiplier, scale_by_freq); + const float scaled_lr = learning_rate_multiplier * output_scale; + tt::embeddings_gradient(prev_src, gradient_input, embs, freqs, scaled_lr, scale_by_freq); } } @@ -5520,6 +5541,7 @@ namespace dlib serialize(item.embedding_dim, out); serialize(item.learning_rate_multiplier, out); serialize(item.scale_by_freq, out); + serialize(item.output_scale, out); } friend void deserialize(embeddings_& item, std::istream& in) { @@ -5532,12 +5554,14 @@ namespace dlib deserialize(item.embedding_dim, in); deserialize(item.learning_rate_multiplier, in); deserialize(item.scale_by_freq, in); + deserialize(item.output_scale, in); } friend std::ostream& operator<<(std::ostream& out, const embeddings_& item) { out << "embeddings (num_embeddings=" << item.num_embeddings << ", embedding_dim=" << item.embedding_dim + << ", scale=" << item.output_scale << ") learning_rate_mult=" << item.learning_rate_multiplier; return out; } @@ -5545,6 +5569,7 @@ namespace dlib { out << "\n"; out << mat(item.embs); @@ -5576,6 +5601,7 @@ namespace dlib unsigned long num_embeddings, embedding_dim; double learning_rate_multiplier; bool scale_by_freq; + float output_scale; }; template < @@ -5587,6 +5613,113 @@ namespace dlib // ---------------------------------------------------------------------------------------- + class tril_padding_context + { + public: + static void set(const tensor& input_tokens, long padding_token) + { + if (padding_token < 0) { + clear(); + return; + } + std::lock_guard lock(get_mutex_()); + const long batch_size = input_tokens.num_samples(); + const long seq_len = input_tokens.nr(); + const float* data = input_tokens.host(); + get_padding_lengths_().resize(batch_size); + for (long s = 0; s < batch_size; ++s) + { + long count = 0; + for (long t = 0; t < seq_len; ++t) + { + const long idx = s * seq_len + t; + const long token = static_cast(data[idx]); + if (token == padding_token) + count++; + else + break; + } + get_padding_lengths_()[s] = count; + } + get_is_set_() = true; + } + + static void set_from_lengths(const std::vector& lengths) + { + std::lock_guard lock(get_mutex_()); + get_padding_lengths_() = lengths; + get_is_set_() = true; + } + + static void set_uniform(long padding_length, long batch_size) + { + std::lock_guard lock(get_mutex_()); + get_padding_lengths_().assign(batch_size, padding_length); + get_is_set_() = true; + } + + static void clear() + { + std::lock_guard lock(get_mutex_()); + get_padding_lengths_().clear(); + get_is_set_() = false; + } + + static long get_padding_length(long sample_idx) + { + std::lock_guard lock(get_mutex_()); + if (!get_is_set_() || sample_idx < 0 || + sample_idx >= static_cast(get_padding_lengths_().size())) + return 0; + return get_padding_lengths_()[sample_idx]; + } + + static std::vector get_all_lengths() + { + std::lock_guard lock(get_mutex_()); + return get_padding_lengths_(); + } + + static bool is_set() + { + std::lock_guard lock(get_mutex_()); + return get_is_set_(); + } + + private: + static std::mutex& get_mutex_() + { + static std::mutex m; + return m; + } + + static std::vector& get_padding_lengths_() + { + static std::vector lengths; + return lengths; + } + + static bool& get_is_set_() + { + static bool 
is_set = false; + return is_set; + } + }; + + template + long count_leading_padding(const matrix& seq, T padding_token) + { + long count = 0; + for (long i = 0; i < seq.size(); ++i) + { + if (seq(i) == padding_token) count++; + else break; + } + return count; + } + +// ---------------------------------------------------------------------------------------- + struct neg_infinity_tag {}; struct zero_tag {}; @@ -5601,7 +5734,7 @@ namespace dlib class tril_ { public: - tril_(): diag(diag_), diag_value(compute_diag_value()) {} + tril_(): diag(diag_), prefix_size(0), diag_value(compute_diag_value()) {} template void setup(const SUBNET& /*sub*/) @@ -5614,10 +5747,28 @@ namespace dlib auto& prev = sub.get_output(); output.set_size(prev.num_samples(), prev.k(), prev.nr(), prev.nc()); + // Check padding context and update cached lengths if needed + if (tril_padding_context::is_set()) + { + auto new_lengths = tril_padding_context::get_all_lengths(); + if (new_lengths != cached_padding_lengths_) + { + cached_padding_lengths_ = new_lengths; + invalidate_mask(); + } + } + else if (!cached_padding_lengths_.empty()) + { + // Context was cleared, reset padding + cached_padding_lengths_.clear(); + invalidate_mask(); + } + check_mask(prev); tt::multiply(false, output, prev, binary_mask); if (diag_value != 0.0f) tt::add(1, output, 1, output_mask); } + template void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) { @@ -5630,6 +5781,15 @@ namespace dlib const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } + + void set_prefix_size(long n_prefix_size) + { + if (prefix_size != n_prefix_size) { + prefix_size = n_prefix_size; + invalidate_mask(); + } + } + long get_prefix_size() const { return prefix_size; } friend void serialize(const tril_& item, std::ostream& out) { @@ -5667,25 +5827,66 @@ namespace dlib return static_cast(num_) / static_cast(den_); } + void invalidate_mask() + { + binary_mask.set_size(0, 0, 0, 0); + output_mask.set_size(0, 0, 0, 0); + } + void check_mask(const tensor& t) { - if (!have_same_dimensions(binary_mask, t)) { + if (!have_same_dimensions(binary_mask, t)) + { binary_mask.copy_size(t); binary_mask = 1; - if (diag_value != 0.0f) { + + const bool use_output_mask = (diag_value != 0.0f); + if (use_output_mask) { output_mask.copy_size(t); output_mask = 0; - } - for (long s = 0; s < output_mask.num_samples(); ++s) + } + + const bool has_padding = !cached_padding_lengths_.empty(); + + for (long s = 0; s < t.num_samples(); ++s) { - for (long k = 0; k < output_mask.k(); ++k) + const long pad_len = has_padding && + s < static_cast(cached_padding_lengths_.size()) + ? 
cached_padding_lengths_[s] : 0; + + for (long k = 0; k < t.k(); ++k) { - for (long r = 0; r < output_mask.nr(); ++r) + for (long r = 0; r < t.nr(); ++r) { - for (long c = std::max(r + diag + 1, 0L); c < output_mask.nc(); ++c) + // Mask padding columns + for (long c = 0; c < pad_len; ++c) + { + const long idx = tensor_index(t, s, k, r, c); + binary_mask.host()[idx] = 0; + if (use_output_mask) + output_mask.host()[idx] = diag_value; + } + + // Mask future positions (causal) + const long causal_start = std::max({ r + diag + 1, prefix_size, pad_len }); + for (long c = causal_start; c < t.nc(); ++c) { - if (diag_value != 0.0f) output_mask.host()[tensor_index(output_mask, s, k, r, c)] = diag_value; - binary_mask.host()[tensor_index(binary_mask, s, k, r, c)] = 0; + const long idx = tensor_index(t, s, k, r, c); + binary_mask.host()[idx] = 0; + if (use_output_mask) + output_mask.host()[idx] = diag_value; + } + + // Mask padding rows + if (r < pad_len) + { + for (long c = 0; c < t.nc(); ++c) + { + const long idx = tensor_index(t, s, k, r, c); + binary_mask.host()[idx] = 0; + if (use_output_mask) + output_mask.host()[idx] = diag_value; + } } } } @@ -5699,7 +5900,9 @@ namespace dlib resizable_tensor params; // unused resizable_tensor binary_mask, output_mask; long diag; + long prefix_size; float diag_value; + std::vector cached_padding_lengths_; }; template @@ -5742,8 +5945,7 @@ namespace dlib num_channels_(item.num_channels_), feature_dim_(item.feature_dim_), ponder_cost_(item.ponder_cost_), - avg_steps_(item.avg_steps_), - params(item.params), + avg_steps_(item.avg_steps_), halting_probs_(item.halting_probs_), cumulative_halting_(item.cumulative_halting_), remainders_(item.remainders_), @@ -5751,7 +5953,8 @@ namespace dlib logits_(item.logits_), grad_logits_(item.grad_logits_), input_cache_(item.input_cache_), - true_effective_weights_(item.true_effective_weights_) + true_effective_weights_(item.true_effective_weights_), + params(item.params) { } @@ -5770,8 +5973,7 @@ namespace dlib num_channels_ = item.num_channels_; feature_dim_ = item.feature_dim_; ponder_cost_ = item.ponder_cost_; - avg_steps_ = item.avg_steps_; - params = item.params; + avg_steps_ = item.avg_steps_; halting_probs_ = item.halting_probs_; cumulative_halting_ = item.cumulative_halting_; remainders_ = item.remainders_; @@ -5780,6 +5982,7 @@ namespace dlib grad_logits_ = item.grad_logits_; input_cache_ = item.input_cache_; true_effective_weights_ = item.true_effective_weights_; + params = item.params; return *this; } @@ -6077,9 +6280,6 @@ namespace dlib long num_channels_; long feature_dim_; - // Learnable parameters - resizable_tensor params; - // Working memory resizable_tensor halting_probs_; // p_t^n: Halting probabilities resizable_tensor cumulative_halting_; // h_t^n: Cumulative halting probabilities @@ -6093,6 +6293,9 @@ namespace dlib // Statistics for monitoring float ponder_cost_; // R(x): Current ponder cost float avg_steps_; // Average number of computation steps + + // Learnable parameters + resizable_tensor params; }; template @@ -6107,6 +6310,808 @@ namespace dlib template using act16 = add_layer, SUBNET>; // Deep version +// ---------------------------------------------------------------------------------------- + + // YaRN configuration structure + struct yarn_config + { + // Alpha controls overall intensity of scaling (typical ~1.0) + float alpha = 1.0f; + + // Beta controls curvature of scaling across head dimensions (typical 0.25..0.5) + float beta = 0.5f; + + // original_len is the context length used at 
training time + // If 0, it will be set to the first seq_len observed (common pattern) + long original_len = 0; + + // Enable/disable YaRN; if false, behavior is identical to classical RoPE + bool enabled = true; + }; + + class rotary_positional_embedding_ + { + public: + explicit rotary_positional_embedding_() : + seq_len(0), + d_head(0), + theta_base(10000.0f) + { + } + + rotary_positional_embedding_(const rotary_positional_embedding_& other) : + seq_len(other.seq_len), + d_head(other.d_head), + theta_base(other.theta_base), + cos_cache(other.cos_cache), + sin_cache(other.sin_cache), + yarn(other.yarn) + { + } + + rotary_positional_embedding_& operator=(const rotary_positional_embedding_& other) + { + if (this != &other) { + seq_len = other.seq_len; + d_head = other.d_head; + theta_base = other.theta_base; + cos_cache = other.cos_cache; + sin_cache = other.sin_cache; + yarn = other.yarn; + } + return *this; + } + + // Set base used to compute inverse frequencies (theta base > 0) + void set_theta_base(float base) + { + DLIB_CASSERT(base > 0, "Theta base must be positive"); + theta_base = base; + } + + float get_theta_base() const { return theta_base; } + long get_seq_len() const { return seq_len; } + long get_d_head() const { return d_head; } + + // Configure YaRN hyperparameters + void set_yarn_params(float alpha, float beta, long original_len = 0, bool enabled = true) + { + DLIB_CASSERT(alpha >= 0, "alpha must be non-negative"); + DLIB_CASSERT(beta >= 0, "beta must be non-negative"); + yarn.alpha = alpha; + yarn.beta = beta; + yarn.original_len = original_len; + yarn.enabled = enabled; + } + const yarn_config& get_yarn_config() const { return yarn; } + + template + void setup(const SUBNET& sub) + { + const tensor& input = sub.get_output(); + + // Expected input shape: (batch, num_heads, seq_len, d_head) + seq_len = input.nr(); + d_head = input.nc(); + + DLIB_CASSERT(d_head >= 2, "d_head must be at least 2 for rotation"); + DLIB_CASSERT(seq_len > 0, "seq_len must be positive"); + + // If original_len not set, treat the setup seq_len as the model's training length + if (yarn.original_len == 0) yarn.original_len = seq_len; + + // Precompute rotation angles and trigonometric values + compute_and_cache_trig_values(seq_len); + } + + template + void forward(const SUBNET& sub, resizable_tensor& output) + { + const tensor& input = sub.get_output(); + + // Validate shape; we expect shape (batch, num_heads, seq_len, d_head) + const long in_seq_len = input.nr(); + const long in_d_head = input.nc(); + + DLIB_CASSERT(in_d_head >= 2, "d_head must be at least 2 for rotation"); + DLIB_CASSERT(in_seq_len > 0, "seq_len must be positive"); + + // If setup() was not called or the incoming sequence length changed from + // the cached seq_len (e.g. inference with a different context window), + // recompute trig caches for the current seq_len. 
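+            // Note: when YaRN is enabled the cached angles also depend on the
+            // ratio between the current and original sequence lengths, so a
+            // cache built for one seq_len is not simply a truncation of a
+            // longer one; it must be recomputed whenever the length changes.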
+ if (seq_len != in_seq_len || d_head != in_d_head + || cos_cache.size() == 0 || sin_cache.size() == 0) + { + // If we don't have a recorded original_len yet, set it here (first observed seq_len) + if (yarn.original_len == 0) yarn.original_len = in_seq_len; + + // Update internal dimensions and recompute caches targeted to in_seq_len + seq_len = in_seq_len; + d_head = in_d_head; + compute_and_cache_trig_values(seq_len); + } + + output.copy_size(input); + + // Copy input to output + tt::copy_tensor(false, output, 0, input, 0, input.k()); + + // Apply rotary embedding in-place + tt::apply_rotary_positional_embedding( + false, // forward pass + output, + cos_cache, + sin_cache + ); + } + + template + void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) + { + tensor& prev_grad = sub.get_gradient_input(); + + // Apply inverse rotation to gradients + resizable_tensor grad_output; + grad_output.copy_size(gradient_input); + tt::copy_tensor(false, grad_output, 0, gradient_input, 0, gradient_input.k()); + + tt::apply_rotary_positional_embedding( + true, // backward pass (inverse rotation) + grad_output, + cos_cache, + sin_cache + ); + + // Accumulate gradients + tt::copy_tensor(true, prev_grad, 0, grad_output, 0, grad_output.k()); + } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const rotary_positional_embedding_& item, std::ostream& out) + { + serialize("rope_", out); + serialize(item.theta_base, out); + serialize(item.cos_cache, out); + serialize(item.sin_cache, out); + + // yarn config + serialize(item.yarn.alpha, out); + serialize(item.yarn.beta, out); + serialize(item.yarn.original_len, out); + serialize(item.yarn.enabled, out); + } + + friend void deserialize(rotary_positional_embedding_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "rope_") + throw serialization_error("Unexpected version '" + version + + "' while deserializing rope_"); + + deserialize(item.theta_base, in); + deserialize(item.cos_cache, in); + deserialize(item.sin_cache, in); + + // yarn config + deserialize(item.yarn.alpha, in); + deserialize(item.yarn.beta, in); + deserialize(item.yarn.original_len, in); + deserialize(item.yarn.enabled, in); + } + + friend std::ostream& operator<<(std::ostream& out, const rotary_positional_embedding_& item) + { + out << "rope (theta_base=" << item.theta_base + << ", yarn.alpha=" << item.yarn.alpha + << ", yarn.beta=" << item.yarn.beta + << ", yarn.original_len=" << item.yarn.original_len + << ", yarn.enabled=" << (item.yarn.enabled ? 
"true" : "false") + << ")"; + return out; + } + + friend void to_xml(const rotary_positional_embedding_& item, std::ostream& out) + { + out << "\n"; + } + + inline dpoint map_input_to_output(const dpoint& p) const { return p; } + inline dpoint map_output_to_input(const dpoint& p) const { return p; } + + private: + // Compute and cache cosine/sine tables for target_seq_len + // This function uses YaRN scaling when yarn.enabled is true + void compute_and_cache_trig_values(long target_seq_len) + { + if (seq_len == 0 || d_head == 0) return; + + // Half the head dimension (we rotate pairs) + const long half_dim = d_head / 2; + + // Allocate cache tensors: shape (1, 1, seq_len, half_dim) + cos_cache.set_size(1, 1, seq_len, half_dim); + sin_cache.set_size(1, 1, seq_len, half_dim); + + // Compute on host side + float* cos_ptr = cos_cache.host(); + float* sin_ptr = sin_cache.host(); + + // Precompute inv_freq constant per dimension (independent of position) + // inv_freq_i = theta_base^(-2i/d_head) + std::vector inv_freq(half_dim); + for (long i = 0; i < half_dim; ++i) + inv_freq[i] = std::pow(theta_base, -2.0f * i / static_cast(d_head)); + + // Determine the training length to use for YaRN scaling + const long train_len = (yarn.original_len > 0) ? yarn.original_len : target_seq_len; + + // Compute cos/sin for each position and frequency index, using YaRN if enabled + for (long pos = 0; pos < target_seq_len; ++pos) + { + for (long i = 0; i < half_dim; ++i) + { + // Base angle: pos * inv_freq[i] + float pos_scaled = static_cast(pos); + + if (yarn.enabled) + { + // Compute dimension-normalized index in [0,1] + const float dim_norm = static_cast(i) / static_cast(half_dim); + + // exponent = alpha * dim_norm^beta + // Note: we use half_dim for normalization so higher-frequency dims get smaller exponent + const float exponent = yarn.alpha * std::pow(dim_norm, yarn.beta); + + // scale = (target_len / train_len)^exponent + // This allows small-dim (low freq) to scale less than high-dim if desired + const float ratio = static_cast(target_seq_len) / static_cast(train_len); + const float scale = std::pow(ratio, exponent); + + // Scaled position used to compute the angle + pos_scaled = static_cast(pos) * scale; + } + + const float angle = pos_scaled * inv_freq[i]; + + const long idx = pos * half_dim + i; + cos_ptr[idx] = std::cos(angle); + sin_ptr[idx] = std::sin(angle); + } + } + } + + // Configuration + long seq_len; + long d_head; + float theta_base; + + // Precomputed trigonometric values + // Shape: (1, 1, seq_len, d_head/2) + resizable_tensor cos_cache; + resizable_tensor sin_cache; + + // YaRN configuration + yarn_config yarn; + + // No trainable parameters + resizable_tensor params; + }; + + template + using rope = add_layer; + +// ---------------------------------------------------------------------------------------- + + template < + long patch_size_, + long embedding_dim_, + long use_class_token_, + long use_position_embeddings_ + > + class patch_embeddings_ + { + static_assert(patch_size_ > 0, "Patch size must be positive"); + static_assert(embedding_dim_ > 0, "Embedding dimension must be positive"); + static_assert(use_class_token_ == 0 || use_class_token_ == 1, + "use_class_token must be 0 or 1"); + static_assert(use_position_embeddings_ == 0 || use_position_embeddings_ == 1, + "use_position_embeddings must be 0 or 1"); + + public: + + patch_embeddings_() : + in_channels(0), + num_patches_h(0), + num_patches_w(0), + cached_input_h(0), + cached_input_w(0), + cached_input_k(0), + 
learning_rate_multiplier(1.0) + { + } + + patch_embeddings_(const patch_embeddings_& other) : + in_channels(other.in_channels), + num_patches_h(other.num_patches_h), + num_patches_w(other.num_patches_w), + cached_input_h(other.cached_input_h), + cached_input_w(other.cached_input_w), + cached_input_k(other.cached_input_k), + learning_rate_multiplier(other.learning_rate_multiplier), + params(other.params), + filters_alias(other.filters_alias), + biases_alias(other.biases_alias), + pos_embed_alias(other.pos_embed_alias), + cls_token_alias(other.cls_token_alias) + { + } + + patch_embeddings_& operator=(const patch_embeddings_& other) + { + if (this != &other) { + in_channels = other.in_channels; + num_patches_h = other.num_patches_h; + num_patches_w = other.num_patches_w; + cached_input_h = other.cached_input_h; + cached_input_w = other.cached_input_w; + cached_input_k = other.cached_input_k; + learning_rate_multiplier = other.learning_rate_multiplier; + params = other.params; + filters_alias = other.filters_alias; + biases_alias = other.biases_alias; + pos_embed_alias = other.pos_embed_alias; + cls_token_alias = other.cls_token_alias; + // Note: conv_op is non-copyable and stateless, will be re-setup on forward() + } + return *this; + } + + long get_patch_size() const { return patch_size_; } + long get_embedding_dim() const { return embedding_dim_; } + long uses_class_token() const { return use_class_token_; } + long uses_position_embeddings() const { return use_position_embeddings_; } + long get_num_patches() const { return num_patches_h * num_patches_w; } + + double get_learning_rate_multiplier() const { return learning_rate_multiplier; } + void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; } + + template + void setup(const SUBNET& sub) + { + const tensor& input = sub.get_output(); + in_channels = input.k(); + + DLIB_CASSERT(input.nr() % patch_size_ == 0, + "Image height must be divisible by patch size. Got height=" << input.nr() + << ", patch_size=" << patch_size_); + DLIB_CASSERT(input.nc() % patch_size_ == 0, + "Image width must be divisible by patch size. Got width=" << input.nc() + << ", patch_size=" << patch_size_); + + num_patches_h = input.nr() / patch_size_; + num_patches_w = input.nc() / patch_size_; + const long num_patches = num_patches_h * num_patches_w; + const long sequence_length = num_patches + use_class_token_; + + // Calculate total parameter size: + // - projection_filters: embedding_dim * in_channels * patch_size * patch_size + // - projection_biases: embedding_dim + // - position_embeddings (optional): sequence_length * embedding_dim + // - class_token (optional): embedding_dim + const long filter_size = embedding_dim_ * in_channels * patch_size_ * patch_size_; + const long bias_size = embedding_dim_; + const long pos_embed_size = use_position_embeddings_ ? sequence_length * embedding_dim_ : 0; + const long cls_token_size = use_class_token_ ? 
embedding_dim_ : 0; + const long total_params = filter_size + bias_size + pos_embed_size + cls_token_size; + + // Allocate all parameters in a single contiguous tensor + params.set_size(total_params); + + // Setup alias tensors for accessing parameter regions + filters_alias = alias_tensor(embedding_dim_, in_channels, patch_size_, patch_size_); + biases_alias = alias_tensor(1, embedding_dim_, 1, 1); + + if (use_position_embeddings_) { + pos_embed_alias = alias_tensor(1, 1, sequence_length, embedding_dim_); + } + if (use_class_token_) { + cls_token_alias = alias_tensor(1, 1, 1, embedding_dim_); + } + + // Initialize parameters with Xavier/Glorot for filters + tt::tensor_rand rnd; + const float fan_in = static_cast(in_channels * patch_size_ * patch_size_); + const float fan_out = static_cast(embedding_dim_); + const float xavier_stddev = std::sqrt(2.0f / (fan_in + fan_out)); + + // Initialize filter weights + auto filt = filters_alias(params, 0); + rnd.fill_gaussian(filt, 0.0f, xavier_stddev); + + // Initialize biases to zero + auto bias = biases_alias(params, filters_alias.size()); + bias = 0; + + // Initialize position embeddings if enabled + if (use_position_embeddings_) { + auto pos = pos_embed_alias(params, filters_alias.size() + biases_alias.size()); + rnd.fill_gaussian(pos, 0.0f, 0.02f); + } + + // Initialize class token if enabled + if (use_class_token_) { + long cls_offset = filters_alias.size() + biases_alias.size(); + if (use_position_embeddings_) cls_offset += pos_embed_alias.size(); + auto cls = cls_token_alias(params, cls_offset); + rnd.fill_gaussian(cls, 0.0f, 0.02f); + } + + // Cache input dimensions and setup convolution + cached_input_h = input.nr(); + cached_input_w = input.nc(); + cached_input_k = input.k(); + conv_op.setup(input, filt, patch_size_, patch_size_, 0, 0); + } + + template + void forward(const SUBNET& sub, resizable_tensor& output) + { + const tensor& input = sub.get_output(); + const long batch_size = input.num_samples(); + + // Re-setup convolution if input spatial dimensions changed + if (input.nr() != cached_input_h || + input.nc() != cached_input_w || + input.k() != cached_input_k || + params.size() == 0) + { + DLIB_CASSERT(input.nr() % patch_size_ == 0, + "Image height must be divisible by patch size. Got height=" << input.nr() + << ", patch_size=" << patch_size_); + DLIB_CASSERT(input.nc() % patch_size_ == 0, + "Image width must be divisible by patch size. 
Got width=" << input.nc() + << ", patch_size=" << patch_size_); + + cached_input_h = input.nr(); + cached_input_w = input.nc(); + cached_input_k = input.k(); + num_patches_h = input.nr() / patch_size_; + num_patches_w = input.nc() / patch_size_; + } + + const long num_patches = num_patches_h * num_patches_w; + const long sequence_length = num_patches + use_class_token_; + + // Get parameter aliases + auto filt = filters_alias(params, 0); + auto bias = biases_alias(params, filters_alias.size()); + conv_op.setup(input, filt, patch_size_, patch_size_, 0, 0); + + // Step 1: apply convolution (patch extraction + projection) + conv_output.set_size(batch_size, embedding_dim_, num_patches_h, num_patches_w); + conv_op(false, conv_output, input, filt); + + // Add bias using broadcasting + tt::add(1.0f, conv_output, 1.0f, bias); + + // Step 2: reshape from (batch, embed, H/P, W/P) to (batch, 1, num_patches, embed) + patch_sequence.set_size(batch_size, 1, num_patches, embedding_dim_); + reshape_conv_to_sequence(conv_output, patch_sequence); + + // Step 3: prepend class token if enabled + if (use_class_token_) { + long cls_offset = filters_alias.size() + biases_alias.size(); + if (use_position_embeddings_) cls_offset += pos_embed_alias.size(); + auto cls = cls_token_alias(params, cls_offset); + + output.set_size(batch_size, 1, sequence_length, embedding_dim_); + prepend_class_token(patch_sequence, cls, output); + } + else { + output.copy_size(patch_sequence); + tt::copy_tensor(false, output, 0, patch_sequence, 0, patch_sequence.k()); + } + + // Step 4: add position embeddings if enabled + if (use_position_embeddings_) { + auto pos = pos_embed_alias(params, filters_alias.size() + biases_alias.size()); + tt::add(1.0f, output, 1.0f, pos); + } + } + + template + void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad) + { + const long batch_size = gradient_input.num_samples(); + const long num_patches = num_patches_h * num_patches_w; + + // Get parameter aliases from params + auto filt = filters_alias(params, 0); + + // Get gradient aliases from params_grad + auto filt_grad = filters_alias(params_grad, 0); + auto bias_grad = biases_alias(params_grad, filters_alias.size()); + + // Step 1: gradient for position embeddings (if enabled) + if (use_position_embeddings_) { + auto pos_grad = pos_embed_alias(params_grad, filters_alias.size() + biases_alias.size()); + // Zero out and accumulate across batch + pos_grad = 0; + sum_across_batch_to_alias(gradient_input, pos_grad); + tt::affine_transform(pos_grad, pos_grad, static_cast(learning_rate_multiplier)); + } + + // Step 2: split gradient between class token and patches + grad_patch_sequence.set_size(batch_size, 1, num_patches, embedding_dim_); + + if (use_class_token_) { + long cls_offset = filters_alias.size() + biases_alias.size(); + if (use_position_embeddings_) cls_offset += pos_embed_alias.size(); + auto cls_grad = cls_token_alias(params_grad, cls_offset); + + cls_grad = 0; + split_class_token_gradient_to_alias(gradient_input, cls_grad, grad_patch_sequence); + tt::affine_transform(cls_grad, cls_grad, static_cast(learning_rate_multiplier)); + } + else { + tt::copy_tensor(false, grad_patch_sequence, 0, gradient_input, 0, gradient_input.k()); + } + + // Step 3: reshape gradient from sequence back to spatial format + grad_conv_output.set_size(batch_size, embedding_dim_, num_patches_h, num_patches_w); + reshape_sequence_to_conv(grad_patch_sequence, grad_conv_output); + + // Step 4: gradient for projection bias + bias_grad = 0; + 
tt::assign_conv_bias_gradient(bias_grad, grad_conv_output); + tt::affine_transform(bias_grad, bias_grad, static_cast(learning_rate_multiplier)); + + // Step 5: gradient for convolution filters + const tensor& input = sub.get_output(); + filt_grad = 0; + conv_op.get_gradient_for_filters(false, grad_conv_output, input, filt_grad); + tt::affine_transform(filt_grad, filt_grad, static_cast(learning_rate_multiplier)); + + // Step 6: gradient for input (accumulate) + tensor& grad_input = sub.get_gradient_input(); + conv_op.get_gradient_for_data(true, grad_conv_output, filt, grad_input); + } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + friend void serialize(const patch_embeddings_& item, std::ostream& out) + { + serialize("patch_embeddings_", out); + serialize(item.in_channels, out); + serialize(item.num_patches_h, out); + serialize(item.num_patches_w, out); + serialize(item.cached_input_h, out); + serialize(item.cached_input_w, out); + serialize(item.cached_input_k, out); + serialize(item.learning_rate_multiplier, out); + serialize(item.params, out); + serialize(item.filters_alias, out); + serialize(item.biases_alias, out); + if (use_position_embeddings_) + serialize(item.pos_embed_alias, out); + if (use_class_token_) + serialize(item.cls_token_alias, out); + } + + friend void deserialize(patch_embeddings_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "patch_embeddings_") + throw serialization_error("Unexpected version '" + version + + "' found while deserializing patch_embeddings_."); + + deserialize(item.in_channels, in); + deserialize(item.num_patches_h, in); + deserialize(item.num_patches_w, in); + deserialize(item.cached_input_h, in); + deserialize(item.cached_input_w, in); + deserialize(item.cached_input_k, in); + deserialize(item.learning_rate_multiplier, in); + deserialize(item.params, in); + deserialize(item.filters_alias, in); + deserialize(item.biases_alias, in); + if (use_position_embeddings_) + deserialize(item.pos_embed_alias, in); + if (use_class_token_) + deserialize(item.cls_token_alias, in); + } + + friend std::ostream& operator<<(std::ostream& out, const patch_embeddings_& item) + { + out << "patch_embeddings (patch_size=" << patch_size_ + << ", embedding_dim=" << embedding_dim_ + << ", num_patches=" << item.get_num_patches() + << ", use_class_token=" << use_class_token_ + << ", use_position_embeddings=" << use_position_embeddings_ + << ") learning_rate_mult=" << item.learning_rate_multiplier; + return out; + } + + friend void to_xml(const patch_embeddings_& item, std::ostream& out) + { + out << "\n"; + } + + private: + + // Reshape conv output (batch, embed, H/P, W/P) to sequence (batch, 1, num_patches, embed) + void reshape_conv_to_sequence(const tensor& src, tensor& dest) + { + const long batch_size = src.num_samples(); + const long embed_dim = src.k(); + const long h = src.nr(); + const long w = src.nc(); + const long num_patches = h * w; + + const float* src_ptr = src.host(); + float* dest_ptr = dest.host_write_only(); + + // src[n, d, i, j] -> dest[n, 0, i*w + j, d] + for (long n = 0; n < batch_size; ++n) { + for (long i = 0; i < h; ++i) { + for (long j = 0; j < w; ++j) { + const long patch_idx = i * w + j; + for (long d = 0; d < embed_dim; ++d) { + const long src_idx = ((n * embed_dim + d) * h + i) * w + j; + const long dest_idx = (n * num_patches + patch_idx) * embed_dim + d; + dest_ptr[dest_idx] = src_ptr[src_idx]; + } + } + } + } + } + + // Reshape 
sequence (batch, 1, num_patches, embed) to conv format (batch, embed, H/P, W/P) + void reshape_sequence_to_conv(const tensor& src, tensor& dest) + { + const long batch_size = src.num_samples(); + const long num_patches = src.nr(); + const long embed_dim = src.nc(); + const long h = dest.nr(); + const long w = dest.nc(); + + const float* src_ptr = src.host(); + float* dest_ptr = dest.host_write_only(); + + // src[n, 0, i*w + j, d] -> dest[n, d, i, j] + for (long n = 0; n < batch_size; ++n) { + for (long i = 0; i < h; ++i) { + for (long j = 0; j < w; ++j) { + const long patch_idx = i * w + j; + for (long d = 0; d < embed_dim; ++d) { + const long src_idx = (n * num_patches + patch_idx) * embed_dim + d; + const long dest_idx = ((n * embed_dim + d) * h + i) * w + j; + dest_ptr[dest_idx] = src_ptr[src_idx]; + } + } + } + } + } + + // Prepend class token to patch sequence + void prepend_class_token(const tensor& patches, const tensor& cls_token, tensor& output) + { + const long batch_size = patches.num_samples(); + const long num_patches = patches.nr(); + const long embed_dim = patches.nc(); + const long seq_len = num_patches + 1; + + const float* patches_ptr = patches.host(); + const float* cls_ptr = cls_token.host(); + float* out_ptr = output.host_write_only(); + + for (long n = 0; n < batch_size; ++n) { + // Copy class token to position 0 + for (long d = 0; d < embed_dim; ++d) { + out_ptr[n * seq_len * embed_dim + d] = cls_ptr[d]; + } + // Copy patch embeddings to positions 1..seq_len-1 + for (long s = 0; s < num_patches; ++s) { + for (long d = 0; d < embed_dim; ++d) { + out_ptr[(n * seq_len + s + 1) * embed_dim + d] = + patches_ptr[(n * num_patches + s) * embed_dim + d]; + } + } + } + } + + // Split gradient between class token and patches (writes to alias) + void split_class_token_gradient_to_alias(const tensor& grad_in, tensor& grad_cls, tensor& grad_patches) + { + const long batch_size = grad_in.num_samples(); + const long seq_len = grad_in.nr(); + const long embed_dim = grad_in.nc(); + const long num_patches = seq_len - 1; + + const float* grad_in_ptr = grad_in.host(); + float* grad_cls_ptr = grad_cls.host(); + float* grad_patches_ptr = grad_patches.host_write_only(); + + for (long n = 0; n < batch_size; ++n) { + // Accumulate gradient for class token across batch + for (long d = 0; d < embed_dim; ++d) { + grad_cls_ptr[d] += grad_in_ptr[n * seq_len * embed_dim + d]; + } + // Copy gradient for patches + for (long s = 0; s < num_patches; ++s) { + for (long d = 0; d < embed_dim; ++d) { + grad_patches_ptr[(n * num_patches + s) * embed_dim + d] = + grad_in_ptr[(n * seq_len + s + 1) * embed_dim + d]; + } + } + } + } + + // Sum tensor across batch dimension (writes to alias) + void sum_across_batch_to_alias(const tensor& src, tensor& dest) + { + const long batch_size = src.num_samples(); + const long seq_len = src.nr(); + const long embed_dim = src.nc(); + + const float* src_ptr = src.host(); + float* dest_ptr = dest.host(); + + for (long n = 0; n < batch_size; ++n) { + for (long s = 0; s < seq_len; ++s) { + for (long d = 0; d < embed_dim; ++d) { + dest_ptr[s * embed_dim + d] += src_ptr[(n * seq_len + s) * embed_dim + d]; + } + } + } + } + + // Configuration + long in_channels; + long num_patches_h, num_patches_w; + long cached_input_h, cached_input_w, cached_input_k; + double learning_rate_multiplier; + + // All learnable parameters stored in a single tensor + resizable_tensor params; + + // Alias tensors for accessing parameter regions + alias_tensor filters_alias; // (embedding_dim, 
in_channels, patch_size, patch_size) + alias_tensor biases_alias; // (1, embedding_dim, 1, 1) + alias_tensor pos_embed_alias; // (1, 1, sequence_length, embedding_dim) if enabled + alias_tensor cls_token_alias; // (1, 1, 1, embedding_dim) if enabled + + // Intermediate tensors for forward/backward + resizable_tensor conv_output; + resizable_tensor patch_sequence; + resizable_tensor grad_conv_output; + resizable_tensor grad_patch_sequence; + + // Convolution operation + tt::tensor_conv conv_op; + }; + + template + using patch_embeddings = add_layer, SUBNET>; + // ---------------------------------------------------------------------------------------- } diff --git a/dlib/dnn/layers_abstract.h b/dlib/dnn/layers_abstract.h index cbfe81ad66..3222052ae3 100644 --- a/dlib/dnn/layers_abstract.h +++ b/dlib/dnn/layers_abstract.h @@ -4543,6 +4543,81 @@ namespace dlib > using embeddings = add_layer, SUBNET>; +// ---------------------------------------------------------------------------------------- + + class tril_padding_context + { + /*! + WHAT THIS OBJECT REPRESENTS + This class provides a shared context for communicating padding information + to tril_ layers during forward passes. It solves the problem of nested + architectures where tril_ layers cannot directly access the input sequence. + The context stores per-sample padding lengths that are computed once + before each forward pass and consulted by all tril_ layers. + + THREAD SAFETY + All methods are thread-safe through internal mutex protection. + + TYPICAL USAGE + // Before forward pass: + tril_padding_context::set(input_tensor, padding_token); + // Or from pre-computed lengths: + tril_padding_context::set_from_lengths(padding_lengths); + !*/ + public: + static void set(const tensor& input_tokens, long padding_token); + /*! + ensures + - Computes and stores padding lengths by scanning input_tokens + - For each sample, counts leading tokens equal to padding_token + - #is_set() == true (if padding_token >= 0) + - If padding_token < 0, clears the context instead + !*/ + + static void set_from_lengths(const std::vector& lengths); + /*! + ensures + - Stores the provided padding lengths directly + - #is_set() == true + - #get_padding_length(i) == lengths[i] for all valid i + !*/ + + static void set_uniform(long padding_length, long batch_size); + /*! + ensures + - Sets uniform padding length for all samples + - #is_set() == true + - #get_padding_length(i) == padding_length for i in [0, batch_size) + !*/ + + static void clear(); + /*! + ensures + - #is_set() == false + - Releases stored padding lengths + !*/ + + static long get_padding_length(long sample_idx); + /*! + ensures + - If is_set() and sample_idx is valid: returns padding length for that sample + - Otherwise: returns 0 + !*/ + + static std::vector get_all_lengths(); + /*! + ensures + - Returns a copy of all stored padding lengths + - Returns empty vector if !is_set() + !*/ + + static bool is_set(); + /*! + ensures + - Returns true if padding context has been initialized + !*/ + }; + // ---------------------------------------------------------------------------------------- struct neg_infinity_tag {}; @@ -4665,6 +4740,25 @@ namespace dlib - Returns the parameters of this layer. !*/ + void set_prefix_size(long n_prefix_size); + /*! + ensures + - #get_prefix_size() == n_prefix_size + - Invalidates cached mask if value changed + !*/ + long get_prefix_size() const; + + void set_padding_token(long token_id); + /*! 
+ ensures + - #get_padding_token() == token_id + - If token_id >= 0: enables automatic padding context usage + - If token_id < 0: disables padding masking + !*/ + long get_padding_token() const; + + bool uses_padding_context() const; + friend void serialize(const tril_& item, std::ostream& out); /*! ensures @@ -4818,6 +4912,343 @@ namespace dlib template using act16 = add_layer, SUBNET>; +// ---------------------------------------------------------------------------------------- + + class rotary_positional_embedding_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements a rotary positional embedding (RoPE) layer for neural + networks, as described in "RoFormer: Enhanced Transformer with Rotary Position + Embedding" by Su et al. + + Rotary positional embeddings encode positional information by rotating pairs + of feature dimensions according to their position in the sequence. This method + provides better relative position encoding compared to traditional learned + positional embeddings, particularly for sequence-to-sequence tasks. + + The transformation is applied as a rotation matrix in 2D subspaces: + For each pair of dimensions (i, i+1) at position pos: + [x'_i ] [cos(θ) -sin(θ)] [x_i ] + [x'_i+1] = [sin(θ) cos(θ)] [x_i+1] + + where θ(pos, i) = pos * base^(-2i/d_head) and base is typically 10000. + + DYNAMIC SEQUENCE LENGTH SUPPORT: + This layer automatically adapts to different sequence lengths during + inference. When a sequence of different length is processed, the rotation + angles are recomputed on-the-fly. This allows models trained on shorter + sequences to handle longer contexts at inference time. + + YARN EXTENSION (OPTIONAL): + Optionally supports YaRN (Yet another RoPE extensioN) scaling for + improved extrapolation to longer sequences than seen during training. + YaRN applies frequency-dependent scaling that preserves low-frequency + information while adapting high-frequency components. Enable via + set_yarn_params(). + + This layer has no trainable parameters. All rotation angles are precomputed + during setup based on the sequence length and head dimension. + !*/ + + public: + + rotary_positional_embedding_( + ); + /*! + ensures + - #get_theta_base() == 10000.0 + - #get_seq_len() == 0 + - #get_d_head() == 0 + !*/ + + rotary_positional_embedding_( + const rotary_positional_embedding_& item + ); + /*! + ensures + - Creates a copy of item + - #get_theta_base() == item.get_theta_base() + - #get_seq_len() == item.get_seq_len() + - #get_d_head() == item.get_d_head() + - All precomputed trigonometric caches are copied + !*/ + + rotary_positional_embedding_& operator=( + const rotary_positional_embedding_& item + ); + /*! + ensures + - Assigns item to *this + - returns #*this + !*/ + + void set_theta_base( + float base + ); + /*! + requires + - base > 0 + ensures + - #get_theta_base() == base + - Sets the base frequency for computing rotation angles + - Higher values result in slower rotation with increasing position + - Common values: 10000 (default), 500000 (for longer sequences) + - This should be called before setup() to take effect + !*/ + + float get_theta_base( + ) const; + /*! + ensures + - Returns the base frequency used for rotation angle computation + !*/ + + long get_seq_len( + ) const; + /*! 
+ ensures + - Returns the most recent sequence length processed by this layer + - Returns 0 if forward() has not been called yet + - Note: this value may change between forward() calls if sequences + of different lengths are processed + !*/ + + long get_d_head( + ) const; + /*! + ensures + - Returns the head dimension that this layer was configured for + - Returns 0 if forward() has not been called yet + - This value remains constant once set (determined by network architecture) + !*/ + + void set_yarn_params( + float alpha, + float beta, + long original_len = 0, + bool enabled = true + ); + /*! + requires + - alpha >= 0 + - beta >= 0 + ensures + - Configures YaRN (Yet another RoPE extensioN) scaling parameters + - alpha controls the overall intensity of scaling (typical: 1.0) + - beta controls the curvature of scaling across frequency dimensions (typical: 0.25 to 0.5) + - original_len is the sequence length used during training + If 0, it will be set to the first sequence length observed in forward() + - enabled determines whether YaRN scaling is active + - YaRN allows better extrapolation to sequence lengths longer than training + - Should be called before forward() to take effect + !*/ + + const yarn_config& get_yarn_config( + ) const; + /*! + ensures + - Returns the current YaRN configuration + !*/ + + template + void setup( + const SUBNET& sub + ); + /*! + requires + - sub.get_output().nr() > 0 + - sub.get_output().nc() >= 2 + ensures + - Initializes this layer based on the input dimensions + - #get_seq_len() == sub.get_output().nr() + - #get_d_head() == sub.get_output().nc() + - Precomputes and caches all cosine and sine values for the rotation + angles based on the sequence length and head dimension + - The cos_cache and sin_cache tensors are allocated with shape: + (1, 1, seq_len, d_head/2) + - If d_head is odd, only (d_head-1) dimensions will be rotated + - If YaRN is enabled and original_len is 0, the observed sequence + length is recorded as the training length for YaRN scaling + !*/ + + template + void forward( + const SUBNET& sub, + resizable_tensor& output + ); + /*! + requires + - sub.get_output().nc() >= 2 + - sub.get_output().nr() > 0 + ensures + - Applies rotary positional embeddings to the input + - #output has the same dimensions as sub.get_output() + - If the input sequence length differs from get_seq_len(), or if + this is the first forward pass after deserialization, the rotation + angles are automatically recomputed for the current sequence length. + - For each position pos and dimension pair (i, i+1): + output[pos,i] = input[pos,i] * cos(θ_pos,i/2) - input[pos,i+1] * sin(θ_pos,i/2) + output[pos,i+1] = input[pos,i] * sin(θ_pos,i/2) + input[pos,i+1] * cos(θ_pos,i/2) + - The rotation preserves the magnitude of feature vectors while encoding + relative positional information + - If d_head is odd, the last dimension is copied without rotation + - Expected input shape: (batch_size, num_heads, seq_len, d_head) + - YaRN scaling is applied if enabled via set_yarn_params() + !*/ + + template + void backward( + const tensor& gradient_input, + SUBNET& sub, + tensor& params_grad + ); + /*! 
+ requires + - setup() has been called + - gradient_input has the same dimensions as the output from forward() + ensures + - Computes gradients with respect to the input + - Applies the inverse rotation to gradient_input + - The inverse rotation is: + grad_input[pos,i] = grad_out[pos,i] * cos(θ) + grad_out[pos,i+1] * sin(θ) + grad_input[pos,i+1] = -grad_out[pos,i] * sin(θ) + grad_out[pos,i+1] * cos(θ) + - Accumulated gradients are added to sub.get_gradient_input() + - params_grad is not used (this layer has no trainable parameters) + !*/ + + const tensor& get_layer_params() const; + tensor& get_layer_params(); + inline dpoint map_input_to_output(const dpoint& p) const; + inline dpoint map_output_to_input(const dpoint& p) const; + + friend void serialize(const rotary_positional_embedding_& item, std::ostream& out); + friend void deserialize(rotary_positional_embedding_& item, std::istream& in); + friend std::ostream& operator<<(std::ostream& out, const rotary_positional_embedding_& item); + friend void to_xml(const rotary_positional_embedding_& item, std::ostream& out); + /*! + provides serialization support and output operators + !*/ + + }; + + template + using rope = add_layer; + +// ---------------------------------------------------------------------------------------- + + template < + long patch_size, + long embedding_dim, + long use_class_token, + long use_position_embeddings + > + class patch_embeddings_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This layer implements patch embeddings for Vision Transformers (ViT), as described + in "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" + (Dosovitskiy et al., 2021). + + The layer performs the following operations: + 1. Convolves the input image with filters of size (patch_size x patch_size) + and stride (patch_size) to create a set of projected patches + 2. Reshapes the resulting spatial feature maps into a sequence of vectors + 3. If use_class_token == 1, prepends a learnable 'class token' to the sequence + 4. If use_position_embeddings == 1, adds learnable position embeddings to + the entire sequence + + The input to this layer is a 4D tensor of shape: + (batch_size, in_channels, height, width) + + The output is a 4D tensor representing a sequence: + (batch_size, 1, sequence_length, embedding_dim) + where sequence_length is (height/patch_size * width/patch_size) + use_class_token + + TEMPLATE PARAMETERS + - patch_size: the side length of the square patches (e.g., 16) + - embedding_dim: the dimensionality of the resulting embeddings (e.g., 768) + - use_class_token: set to 1 to prepend a learnable CLS token, 0 otherwise + - use_position_embeddings: set to 1 to add learnable absolute position + embeddings to the sequence, 0 otherwise + !*/ + + public: + + patch_embeddings_( + ); + /*! + ensures + - #get_patch_size() == patch_size + - #get_embedding_dim() == embedding_dim + - #uses_class_token() == use_class_token + - #uses_position_embeddings() == use_position_embeddings + - #get_learning_rate_multiplier() == 1 + !*/ + + long get_patch_size() const; + long get_embedding_dim() const; + long uses_class_token() const; + long uses_position_embeddings() const; + + double get_learning_rate_multiplier() const; + void set_learning_rate_multiplier(double val); + /*! + ensures + - #get_learning_rate_multiplier() == val + !*/ + + template + void setup( + const SUBNET& sub + ); + /*! 
+            requires
+                - sub.get_output().nr() % patch_size == 0
+                - sub.get_output().nc() % patch_size == 0
+            ensures
+                - Initializes the learned parameters:
+                    - projection filters: (embedding_dim, in_channels, patch_size, patch_size)
+                    - projection biases: (embedding_dim)
+                    - (optional) class token and position embeddings.
+                - Parameters are initialized using Xavier/Glorot initialization for filters
+                  and zero/truncated normal for other components.
+        !*/
+
+        template <typename SUBNET>
+        void forward(
+            const SUBNET& sub,
+            resizable_tensor& output
+        );
+        /*!
+            requires
+                - setup(sub) has been called.
+            ensures
+                - #output.num_samples() == sub.get_output().num_samples()
+                - #output.k() == 1
+                - #output.nr() == (sub.get_output().nr()/patch_size * sub.get_output().nc()/patch_size) + use_class_token
+                - #output.nc() == embedding_dim
+        !*/
+
+        template <typename SUBNET>
+        void backward(
+            const tensor& gradient_input,
+            SUBNET& sub,
+            tensor& params_grad
+        );
+        /*!
+            requires
+                - gradient_input has the same dimensions as the output of forward()
+            ensures
+                - Computes the gradient of the loss with respect to the input of this
+                  layer and adds it to #sub.get_gradient_input()
+                - Computes the gradients with respect to this layer's parameters
+                  (filters, biases, and, when enabled, the class token and position
+                  embeddings) and stores them in params_grad
+        !*/
+    };
+
+    template <
+        long patch_size,
+        long embedding_dim,
+        long use_class_token,
+        long use_position_embeddings,
+        typename SUBNET
+        >
+    using patch_embeddings = add_layer<
+        patch_embeddings_<patch_size, embedding_dim, use_class_token, use_position_embeddings>,
+        SUBNET>;
+
+// ----------------------------------------------------------------------------------------

}
diff --git a/dlib/dnn/loss.h b/dlib/dnn/loss.h
index 36b37a2956..823f2c2352 100644
--- a/dlib/dnn/loss.h
+++ b/dlib/dnn/loss.h
@@ -911,6 +911,124 @@ namespace dlib
     using loss_multibinary_log = add_loss_layer<loss_multibinary_log_, SUBNET>;

// ----------------------------------------------------------------------------------------
+
+    class loss_cross_entropy_per_logit_
+    {
+    public:
+        typedef unsigned long training_label_type;
+        typedef unsigned long output_label_type;
+
+        loss_cross_entropy_per_logit_() : ignore_index_(-1) {}
+
+        void set_ignore_index(long idx) { ignore_index_ = idx; }
+        long get_ignore_index() const { return ignore_index_; }
+
+        template <typename SUB_TYPE, typename label_iterator>
+        void to_label(
+            const tensor& input_tensor,
+            const SUB_TYPE& sub,
+            label_iterator iter
+        ) const
+        {
+            const tensor& output_tensor = sub.get_output();
+            DLIB_CASSERT(sub.sample_expansion_factor() == 1);
+            DLIB_CASSERT(output_tensor.k() == 1);
+            DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
+
+            const long batch_size = output_tensor.num_samples();
+            const long seq_len = output_tensor.nr();
+            const long vocab_size = output_tensor.nc();
+
+            // Note that output_tensor.nc() should match the vocabulary size
+            const float* out_data = output_tensor.host();
+
+            for (long i = 0; i < batch_size; ++i, ++iter)
+            {
+                // For each sample, find the class with the maximum logit at the last
+                // position of the sequence (position seq_len-1).
This is the predicted + // next token for autoregressive generation + long max_idx = 0; + float max_val = out_data[tensor_index(output_tensor, i, 0, seq_len - 1, 0)]; + for (long c = 1; c < vocab_size; ++c) + { + const float val = out_data[tensor_index(output_tensor, i, 0, seq_len - 1, c)]; + if (val > max_val) + { + max_val = val; + max_idx = c; + } + } + *iter = static_cast(max_idx); + } + } + + template + double compute_loss_value_and_gradient( + const tensor& input_tensor, + const_label_iterator truth, + SUBNET& sub + ) const + { + const tensor& output_tensor = sub.get_output(); + tensor& grad = sub.get_gradient_input(); + + DLIB_CASSERT(sub.sample_expansion_factor() == 1); + DLIB_CASSERT(input_tensor.num_samples() != 0); + DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); + DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); + DLIB_CASSERT(output_tensor.nr() == grad.nr() && + output_tensor.nc() == grad.nc() && + output_tensor.k() == grad.k()); + + double loss = 0.0; +#ifdef DLIB_USE_CUDA + cuda_compute(truth, input_tensor, output_tensor, grad, loss, ignore_index_); +#else + cpu_compute(truth, input_tensor, output_tensor, grad, loss, ignore_index_); +#endif + return loss; + } + + friend void serialize(const loss_cross_entropy_per_logit_& item, std::ostream& out) + { + serialize("loss_cross_entropy_per_logit_", out); + serialize(item.ignore_index_, out); + } + + friend void deserialize(loss_cross_entropy_per_logit_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "loss_cross_entropy_per_logit_") + throw serialization_error("Unexpected version found while deserializing dlib::loss_cross_entropy_per_logit_."); + deserialize(item.ignore_index_, in); + } + + friend std::ostream& operator<<(std::ostream& out, const loss_cross_entropy_per_logit_& item) + { + out << "loss_cross_entropy_per_logit"; + out << " (ignore_index=" << item.ignore_index_ << ")"; + return out; + } + + friend void to_xml(const loss_cross_entropy_per_logit_& item, std::ostream& out) + { + out << "\n"; + } + + private: + long ignore_index_; + +#ifdef DLIB_USE_CUDA + cuda::compute_loss_cross_entropy_per_logit cuda_compute; +#else + cpu::compute_loss_cross_entropy_per_logit cpu_compute; +#endif + }; + + template + using loss_cross_entropy_per_logit = add_loss_layer; + // ---------------------------------------------------------------------------------------- enum class use_image_pyramid : uint8_t diff --git a/dlib/dnn/loss_abstract.h b/dlib/dnn/loss_abstract.h index 9ddfb6a4a2..54d7413e55 100644 --- a/dlib/dnn/loss_abstract.h +++ b/dlib/dnn/loss_abstract.h @@ -810,6 +810,134 @@ namespace dlib using loss_multibinary_log = add_loss_layer; // ---------------------------------------------------------------------------------------- + + class loss_cross_entropy_per_logit_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This loss layer implements cross-entropy loss for next token prediction + in transformer-based language models. Unlike loss_multiclass_log_ which + requires the output to be flattened through an fc layer, this loss function + is designed to work directly with sequence outputs from linear layers. 
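+
+            As a concrete example (the sizes are illustrative, not mandated by the
+            interface): a batch of 2 token sequences of length 8 over a vocabulary
+            of 1000 tokens produces an output tensor of shape (2, 1, 8, 1000), and
+            training requires exactly 2 labels, each a token ID in [0, 1000).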
+
+            This loss expects the network to produce an output tensor with these dimensions:
+                - output_tensor.num_samples() == batch size
+                - output_tensor.k() == 1 (always)
+                - output_tensor.nr() == sequence length
+                - output_tensor.nc() == vocabulary size (number of classes)
+
+            The key feature of this loss is that it computes the cross-entropy loss
+            only on the LAST position of each sequence (position nr()-1), which is
+            the standard approach for autoregressive next token prediction.
+
+            TYPICAL NETWORK ARCHITECTURE (schematic; the inner layer stack depends
+            on the model and is elided here):
+                using net_type = loss_cross_entropy_per_logit<
+                    linear<vocab_size,
+                    ...    // transformer blocks, embeddings, input layer, etc.
+                    >>;
+
+            TRAINING LABELS:
+                - Label type: unsigned long (scalar value per sample)
+                - Each label represents the target token ID: 0 <= label < vocab_size
+                - One label per sequence (predicting the token after the last position)
+
+            LOSS COMPUTATION:
+                For each sample i in the batch:
+                    1. Extract logits at position [i, 0, seq_len-1, :]
+                    2. Compute softmax: probs = softmax(logits)
+                    3. Compute loss: loss += -log(probs[target_label])
+
+                Final loss = sum(all_losses) / batch_size
+        !*/
+
+    public:
+        typedef unsigned long training_label_type;
+        typedef unsigned long output_label_type;
+
+        template <typename SUB_TYPE, typename label_iterator>
+        void to_label(
+            const tensor& input_tensor,
+            const SUB_TYPE& sub,
+            label_iterator iter
+        ) const;
+        /*!
+            requires
+                - SUBNET implements the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+                - sub.get_output().k() == 1
+                - sub.sample_expansion_factor() == 1
+            ensures
+                - Converts the output of the subnetwork into predicted labels.
+                - For each sample in the batch, extracts the logits at the last
+                  sequence position (nr()-1) and assigns the index of the maximum
+                  logit as the predicted label.
+                - Interprets the output tensor as:
+                  output[i, 0, nr()-1, c] = logit for class c in sample i
+        !*/
+
+        template <typename const_label_iterator, typename SUBNET>
+        double compute_loss_value_and_gradient(
+            const tensor& input_tensor,
+            const_label_iterator truth,
+            SUBNET& sub
+        ) const;
+        /*!
+            requires
+                - SUBNET implements the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+                - sub.sample_expansion_factor() == 1
+                - sub.get_output().k() == 1
+                - sub.get_output().num_samples() == input_tensor.num_samples()
+                - The output tensor has shape [batch_size, 1, seq_len, vocab_size]
+                - truth == an iterator pointing to the first label in a sequence
+                  of input_tensor.num_samples() labels
+                - All values pointed to by truth are < sub.get_output().nc()
+                  (i.e., valid token IDs within vocabulary)
+            ensures
+                - Computes the cross-entropy loss for next token prediction.
+                - For each sample, the loss is computed only at the last sequence
+                  position (nr()-1) using the corresponding label from truth.
+                - The loss is averaged over all samples in the batch.
+                - Returns the computed loss value.
+                - Computes gradients with respect to the output logits and stores
+                  them in sub.get_gradient_input().
+                - Gradients are non-zero only at the last position of each sequence.
+                - The gradient computation uses numerically stable softmax.
+        !*/
+
+        friend void serialize(const loss_cross_entropy_per_logit_& item, std::ostream& out);
+        friend void deserialize(loss_cross_entropy_per_logit_& item, std::istream& in);
+        /*!
+            provides serialization support for loss_cross_entropy_per_logit_
+        !*/
+
+        friend std::ostream& operator<<(std::ostream& out, const loss_cross_entropy_per_logit_& item);
+        /*!
+            prints a human readable string describing the loss layer to the output stream
+        !*/
+
+        friend void to_xml(const loss_cross_entropy_per_logit_& item, std::ostream& out);
+        /*!
+ provides XML serialization support for loss_cross_entropy_per_logit_ + !*/ + }; + + template + using loss_cross_entropy_per_logit = add_loss_layer; + /*! + This adds the loss_cross_entropy_per_logit_ loss layer onto SUBNET. + + TYPICAL USAGE IN TRANSFORMER NETWORKS: + This loss layer is specifically designed for transformer-based language + models that use autoregressive next token prediction. It should be used + as the final layer of a network that outputs logits for each position + in a sequence. + !*/ + // ---------------------------------------------------------------------------------------- enum class use_image_pyramid : uint8_t diff --git a/dlib/dnn/lr_scheduler.h b/dlib/dnn/lr_scheduler.h new file mode 100644 index 0000000000..0ca8444c36 --- /dev/null +++ b/dlib/dnn/lr_scheduler.h @@ -0,0 +1,385 @@ +// Copyright (C) 2025 Cydral (cydraltechnology@gmail.com) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_DNN_LR_SCHEDULER_H_ +#define DLIB_DNN_LR_SCHEDULER_H_ + +#include "lr_scheduler_abstract.h" +#include "../serialize.h" +#include +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + constexpr double lr_scheduler_pi = 3.14159265358979323846; + } + +// ---------------------------------------------------------------------------------------- + + enum class lr_decay_type + { + COSINE, + LINEAR, + CONSTANT, + EXPONENTIAL + }; + +// ---------------------------------------------------------------------------------------- + + class lr_scheduler + { + public: + + lr_scheduler( + ) : + current_step_(0), + warmup_steps_(2000), + hold_steps_(0), + total_steps_(100000), + initial_lr_(1e-7), + peak_lr_(3e-4), + min_lr_(1e-6), + decay_type_(lr_decay_type::COSINE) + { + compute_decay_steps(); + } + + lr_scheduler( + double peak_lr, + size_t warmup_steps, + size_t total_steps, + double min_lr = 1e-6, + lr_decay_type decay_type = lr_decay_type::COSINE + ) : + current_step_(0), + warmup_steps_(warmup_steps), + hold_steps_(0), + total_steps_(total_steps), + initial_lr_(min_lr), + peak_lr_(peak_lr), + min_lr_(min_lr), + decay_type_(decay_type) + { + DLIB_CASSERT(peak_lr > 0, "peak_lr must be positive"); + DLIB_CASSERT(min_lr >= 0, "min_lr must be non-negative"); + DLIB_CASSERT(min_lr < peak_lr, "min_lr must be less than peak_lr"); + DLIB_CASSERT(warmup_steps < total_steps, "warmup_steps must be less than total_steps"); + compute_decay_steps(); + } + + double get_learning_rate( + ) const + { + // Phase 1: Warmup + if (current_step_ < warmup_steps_) + { + if (warmup_steps_ == 0) + return peak_lr_; + const double progress = static_cast(current_step_) / warmup_steps_; + return initial_lr_ + (peak_lr_ - initial_lr_) * progress; + } + + // Phase 2: Hold (optional) + const size_t post_warmup = current_step_ - warmup_steps_; + if (post_warmup < hold_steps_) + return peak_lr_; + + // Phase 3: Decay + if (decay_steps_ == 0) + return peak_lr_; + + const size_t decay_step = post_warmup - hold_steps_; + const double progress = std::min(1.0, static_cast(decay_step) / decay_steps_); + + switch (decay_type_) + { + case lr_decay_type::COSINE: + return min_lr_ + 0.5 * (peak_lr_ - min_lr_) * (1.0 + std::cos(impl::lr_scheduler_pi * progress)); + + case lr_decay_type::LINEAR: + return peak_lr_ - (peak_lr_ - min_lr_) * progress; + + case lr_decay_type::EXPONENTIAL: + return peak_lr_ * std::pow(min_lr_ / peak_lr_, progress); + + case lr_decay_type::CONSTANT: + default: + return 
peak_lr_; + } + } + + double get_learning_rate( + size_t step + ) const + { + lr_scheduler temp = *this; + temp.current_step_ = step; + return temp.get_learning_rate(); + } + + void step( + size_t n = 1 + ) + { + current_step_ += n; + } + + void reset( + ) + { + current_step_ = 0; + } + + void set_current_step( + size_t step + ) + { + current_step_ = step; + } + + size_t get_current_step( + ) const { return current_step_; } + + size_t get_warmup_steps( + ) const { return warmup_steps_; } + + size_t get_hold_steps( + ) const { return hold_steps_; } + + size_t get_total_steps( + ) const { return total_steps_; } + + size_t get_decay_steps( + ) const { return decay_steps_; } + + double get_initial_lr( + ) const { return initial_lr_; } + + double get_peak_lr( + ) const { return peak_lr_; } + + double get_min_lr( + ) const { return min_lr_; } + + lr_decay_type get_decay_type( + ) const { return decay_type_; } + + void set_peak_lr( + double lr + ) + { + DLIB_CASSERT(lr > 0 && lr > min_lr_); + peak_lr_ = lr; + } + + void set_min_lr( + double lr + ) + { + DLIB_CASSERT(lr >= 0 && lr < peak_lr_); + min_lr_ = lr; + } + + void set_initial_lr( + double lr + ) + { + DLIB_CASSERT(lr >= 0 && lr <= peak_lr_); + initial_lr_ = lr; + } + + void set_warmup_steps( + size_t steps + ) + { + DLIB_CASSERT(steps < total_steps_); + warmup_steps_ = steps; + compute_decay_steps(); + } + + void set_hold_steps( + size_t steps + ) + { + hold_steps_ = steps; + compute_decay_steps(); + } + + void set_total_steps( + size_t steps + ) + { + DLIB_CASSERT(steps > warmup_steps_); + total_steps_ = steps; + compute_decay_steps(); + } + + void set_decay_type( + lr_decay_type type + ) + { + decay_type_ = type; + } + + bool is_warmup_complete( + ) const { return current_step_ >= warmup_steps_; } + + bool is_training_complete( + ) const { return current_step_ >= total_steps_; } + + double get_warmup_progress( + ) const + { + if (warmup_steps_ == 0) + return 1.0; + return std::min(1.0, static_cast(current_step_) / warmup_steps_); + } + + double get_total_progress( + ) const + { + if (total_steps_ == 0) + return 1.0; + return std::min(1.0, static_cast(current_step_) / total_steps_); + } + + std::string get_phase_name( + ) const + { + if (current_step_ < warmup_steps_) + return "warmup"; + else if (current_step_ < warmup_steps_ + hold_steps_) + return "hold"; + else + return "decay"; + } + + private: + + void compute_decay_steps( + ) + { + const size_t non_decay = warmup_steps_ + hold_steps_; + decay_steps_ = (total_steps_ > non_decay) ? 
(total_steps_ - non_decay) : 0; + } + + size_t current_step_; + size_t warmup_steps_; + size_t hold_steps_; + size_t total_steps_; + size_t decay_steps_; + double initial_lr_; + double peak_lr_; + double min_lr_; + lr_decay_type decay_type_; + }; + +// ---------------------------------------------------------------------------------------- + + inline void serialize( + const lr_scheduler& item, + std::ostream& out + ) + { + serialize("lr_scheduler", out); + serialize(item.get_current_step(), out); + serialize(item.get_warmup_steps(), out); + serialize(item.get_hold_steps(), out); + serialize(item.get_total_steps(), out); + serialize(item.get_decay_steps(), out); + serialize(item.get_initial_lr(), out); + serialize(item.get_peak_lr(), out); + serialize(item.get_min_lr(), out); + serialize(static_cast(item.get_decay_type()), out); + } + + inline void deserialize( + lr_scheduler& item, + std::istream& in + ) + { + std::string version; + deserialize(version, in); + if (version != "lr_scheduler") + throw serialization_error("Unexpected version '" + version + + "' found while deserializing lr_scheduler."); + + size_t current_step, warmup_steps, hold_steps, total_steps, decay_steps; + double initial_lr, peak_lr, min_lr; + int decay_type_int; + + deserialize(current_step, in); + deserialize(warmup_steps, in); + deserialize(hold_steps, in); + deserialize(total_steps, in); + deserialize(decay_steps, in); + deserialize(initial_lr, in); + deserialize(peak_lr, in); + deserialize(min_lr, in); + deserialize(decay_type_int, in); + + item = lr_scheduler(peak_lr, warmup_steps, total_steps, min_lr, + static_cast(decay_type_int)); + item.set_initial_lr(initial_lr); + item.set_hold_steps(hold_steps); + item.set_current_step(current_step); + } + + inline std::ostream& operator<<( + std::ostream& out, + const lr_scheduler& item + ) + { + out << "lr_scheduler (" + << "step=" << item.get_current_step() + << ", lr=" << item.get_learning_rate() + << ", phase=" << item.get_phase_name() + << ", warmup=" << item.get_warmup_steps() + << ", total=" << item.get_total_steps() + << ", peak=" << item.get_peak_lr() + << ", min=" << item.get_min_lr() + << ")"; + return out; + } + +// ---------------------------------------------------------------------------------------- + + inline lr_scheduler make_transformer_scheduler( + double peak_lr, + size_t total_steps, + double warmup_fraction = 0.02, + double min_lr = 1e-6, + lr_decay_type decay_type = lr_decay_type::COSINE + ) + { + DLIB_CASSERT(peak_lr > 0, "peak_lr must be positive"); + DLIB_CASSERT(total_steps > 0, "total_steps must be positive"); + DLIB_CASSERT(warmup_fraction > 0 && warmup_fraction < 1, "warmup_fraction must be in (0, 1)"); + DLIB_CASSERT(min_lr >= 0 && min_lr < peak_lr, "min_lr must be in [0, peak_lr)"); + + size_t warmup_steps = static_cast(total_steps * warmup_fraction); + warmup_steps = std::max(size_t(100), warmup_steps); + return lr_scheduler(peak_lr, warmup_steps, total_steps, min_lr, decay_type); + } + + inline size_t estimate_total_steps( + size_t dataset_size, + size_t batch_size, + size_t num_epochs + ) + { + DLIB_CASSERT(batch_size > 0, "batch_size must be positive"); + const size_t steps_per_epoch = (dataset_size + batch_size - 1) / batch_size; + return steps_per_epoch * num_epochs; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DNN_LR_SCHEDULER_H_ diff --git a/dlib/dnn/lr_scheduler_abstract.h b/dlib/dnn/lr_scheduler_abstract.h new file mode 100644 index 
0000000000..f1ced39e50 --- /dev/null +++ b/dlib/dnn/lr_scheduler_abstract.h @@ -0,0 +1,481 @@ +// Copyright (C) 2025 Cydral (cydraltechnology@gmail.com) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_DNN_LR_SCHEDULER_ABSTRACT_H_ +#ifdef DLIB_DNN_LR_SCHEDULER_ABSTRACT_H_ + +#include +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + enum class lr_decay_type + { + /*! + WHAT THIS ENUM REPRESENTS + This enum specifies the type of learning rate decay to use after the + warmup phase completes. The decay function determines how the learning + rate decreases from peak_lr to min_lr over the remaining training steps. + !*/ + + COSINE, + /*! + Cosine annealing decay. The learning rate follows a cosine curve: + lr = min_lr + 0.5 * (peak_lr - min_lr) * (1 + cos(pi * progress)) + + This is the recommended decay type for transformer training as it provides + smooth decay with a gradual slowdown near the end of training. + !*/ + + LINEAR, + /*! + Linear decay. The learning rate decreases linearly: + lr = peak_lr - (peak_lr - min_lr) * progress + + Simple and predictable decay suitable for general deep learning tasks. + !*/ + + CONSTANT, + /*! + No decay after warmup. The learning rate remains at peak_lr: + lr = peak_lr + + Useful when using external learning rate control or for debugging. + !*/ + + EXPONENTIAL + /*! + Exponential decay. The learning rate decreases exponentially: + lr = peak_lr * (min_lr / peak_lr)^progress + + Provides rapid initial decay that slows down over time. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class lr_scheduler + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements a learning rate scheduler with warmup and decay + phases, designed for training transformer-based neural networks. It is + intended to be used alongside dnn_trainer to provide dynamic learning + rate adjustment during training. + + The schedule consists of three phases: + 1. WARMUP: Linear increase from initial_lr to peak_lr + 2. HOLD (optional): Maintain peak_lr for hold_steps + 3. DECAY: Decrease from peak_lr to min_lr using selected decay type + + MATHEMATICAL FORMULATION + Warmup phase (step < warmup_steps): + lr = initial_lr + (peak_lr - initial_lr) * (step / warmup_steps) + + Hold phase (warmup_steps <= step < warmup_steps + hold_steps): + lr = peak_lr + + Decay phase (step >= warmup_steps + hold_steps): + progress = (step - warmup_steps - hold_steps) / decay_steps + + For COSINE: + lr = min_lr + 0.5 * (peak_lr - min_lr) * (1 + cos(pi * progress)) + + For LINEAR: + lr = peak_lr - (peak_lr - min_lr) * progress + + For EXPONENTIAL: + lr = peak_lr * (min_lr / peak_lr)^progress + + For CONSTANT: + lr = peak_lr + + THREAD SAFETY + This object is not thread-safe. Each trainer should have its own scheduler + instance. If using multiple trainers in parallel, each should maintain its + own lr_scheduler. + + SERIALIZATION + This object supports serialization through serialize() and deserialize() + functions, allowing training to be checkpointed and resumed. 
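+            WORKED EXAMPLE
+                The following values are derived directly from the formulas above,
+                using the same configuration as the TYPICAL USAGE snippet below
+                (peak_lr = 3e-4, warmup_steps = 2000, total_steps = 100000,
+                min_lr = 1e-6, COSINE decay; this constructor sets initial_lr
+                equal to min_lr):
+                    step 0       -> lr = 1e-6      (start of warmup)
+                    step 1000    -> lr ~= 1.5e-4   (halfway through warmup)
+                    step 2000    -> lr = 3e-4      (peak reached)
+                    step 51000   -> lr ~= 1.5e-4   (cosine decay midpoint, progress = 0.5)
+                    step 100000  -> lr = 1e-6      (decay complete)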
+ + TYPICAL USAGE + // Create scheduler + lr_scheduler scheduler( + 3e-4, // peak_lr + 2000, // warmup_steps + 100000, // total_steps + 1e-6, // min_lr + lr_decay_type::COSINE + ); + + // Training loop + while (!scheduler.is_training_complete()) { + trainer.set_learning_rate(scheduler.get_learning_rate()); + trainer.train_one_step(data, labels); + scheduler.step(); + } + !*/ + + public: + + lr_scheduler( + ); + /*! + ensures + - Constructs a default scheduler with reasonable defaults for transformer training + - #get_peak_lr() == 3e-4 + - #get_min_lr() == 1e-6 + - #get_initial_lr() == 1e-7 + - #get_warmup_steps() == 2000 + - #get_hold_steps() == 0 + - #get_total_steps() == 100000 + - #get_decay_type() == lr_decay_type::COSINE + - #get_current_step() == 0 + !*/ + + lr_scheduler( + double peak_lr, + size_t warmup_steps, + size_t total_steps, + double min_lr = 1e-6, + lr_decay_type decay_type = lr_decay_type::COSINE + ); + /*! + requires + - peak_lr > 0 + - min_lr >= 0 + - min_lr < peak_lr + - warmup_steps < total_steps + ensures + - #get_peak_lr() == peak_lr + - #get_min_lr() == min_lr + - #get_initial_lr() == min_lr + - #get_warmup_steps() == warmup_steps + - #get_hold_steps() == 0 + - #get_total_steps() == total_steps + - #get_decay_type() == decay_type + - #get_current_step() == 0 + !*/ + + double get_learning_rate( + ) const; + /*! + ensures + - Returns the learning rate for the current step based on the schedule + - The returned value is always >= get_min_lr() + - The returned value is always <= get_peak_lr() + - During warmup: returns a value linearly interpolated between + get_initial_lr() and get_peak_lr() + - During hold: returns get_peak_lr() + - During decay: returns a value determined by get_decay_type() + !*/ + + double get_learning_rate( + size_t step + ) const; + /*! + ensures + - Returns the learning rate that would be used at the specified step + - Does not modify the scheduler state + - Equivalent to temporarily setting current_step to step and calling + get_learning_rate(), then restoring the original current_step + !*/ + + void step( + size_t n = 1 + ); + /*! + ensures + - #get_current_step() == get_current_step() + n + - Advances the scheduler by n steps + !*/ + + void reset( + ); + /*! + ensures + - #get_current_step() == 0 + - Resets the scheduler to its initial state + !*/ + + void set_current_step( + size_t step + ); + /*! + ensures + - #get_current_step() == step + - Useful for resuming training from a checkpoint + !*/ + + size_t get_current_step( + ) const; + /*! + ensures + - Returns the current training step + !*/ + + size_t get_warmup_steps( + ) const; + /*! + ensures + - Returns the number of warmup steps configured for this scheduler + - During warmup, the learning rate increases linearly from + get_initial_lr() to get_peak_lr() + !*/ + + size_t get_hold_steps( + ) const; + /*! + ensures + - Returns the number of hold steps configured for this scheduler + - During hold, the learning rate remains constant at get_peak_lr() + !*/ + + size_t get_total_steps( + ) const; + /*! + ensures + - Returns the total number of training steps configured for this scheduler + - Training is considered complete when get_current_step() >= get_total_steps() + !*/ + + size_t get_decay_steps( + ) const; + /*! + ensures + - Returns the number of steps in the decay phase + - Computed as: get_total_steps() - get_warmup_steps() - get_hold_steps() + !*/ + + double get_initial_lr( + ) const; + /*! 
+ ensures + - Returns the initial learning rate at the start of warmup + - This is the learning rate used at step 0 + !*/ + + double get_peak_lr( + ) const; + /*! + ensures + - Returns the peak learning rate reached at the end of warmup + - This is the maximum learning rate during training + !*/ + + double get_min_lr( + ) const; + /*! + ensures + - Returns the minimum learning rate at the end of training + - The learning rate will never go below this value + !*/ + + lr_decay_type get_decay_type( + ) const; + /*! + ensures + - Returns the decay type used after warmup completes + !*/ + + void set_peak_lr( + double lr + ); + /*! + requires + - lr > 0 + - lr > get_min_lr() + ensures + - #get_peak_lr() == lr + !*/ + + void set_min_lr( + double lr + ); + /*! + requires + - lr >= 0 + - lr < get_peak_lr() + ensures + - #get_min_lr() == lr + !*/ + + void set_initial_lr( + double lr + ); + /*! + requires + - lr >= 0 + - lr <= get_peak_lr() + ensures + - #get_initial_lr() == lr + !*/ + + void set_warmup_steps( + size_t steps + ); + /*! + requires + - steps < get_total_steps() + ensures + - #get_warmup_steps() == steps + - #get_decay_steps() is recomputed accordingly + !*/ + + void set_hold_steps( + size_t steps + ); + /*! + ensures + - #get_hold_steps() == steps + - #get_decay_steps() is recomputed accordingly + !*/ + + void set_total_steps( + size_t steps + ); + /*! + requires + - steps > get_warmup_steps() + ensures + - #get_total_steps() == steps + - #get_decay_steps() is recomputed accordingly + !*/ + + void set_decay_type( + lr_decay_type type + ); + /*! + ensures + - #get_decay_type() == type + !*/ + + bool is_warmup_complete( + ) const; + /*! + ensures + - Returns true if the warmup phase has completed + - Equivalent to: get_current_step() >= get_warmup_steps() + !*/ + + bool is_training_complete( + ) const; + /*! + ensures + - Returns true if all training steps have been completed + - Equivalent to: get_current_step() >= get_total_steps() + !*/ + + double get_warmup_progress( + ) const; + /*! + ensures + - Returns a value between 0.0 and 1.0 indicating progress through warmup + - Returns 1.0 if warmup is complete + - Computed as: min(1.0, get_current_step() / get_warmup_steps()) + !*/ + + double get_total_progress( + ) const; + /*! + ensures + - Returns a value between 0.0 and 1.0 indicating overall training progress + - Computed as: min(1.0, get_current_step() / get_total_steps()) + !*/ + + std::string get_phase_name( + ) const; + /*! + ensures + - Returns "warmup" if in the warmup phase + - Returns "hold" if in the hold phase + - Returns "decay" if in the decay phase + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + void serialize( + const lr_scheduler& item, + std::ostream& out + ); + /*! + ensures + - Serializes the complete state of item to the output stream out + - The serialized state includes: current_step, warmup_steps, hold_steps, + total_steps, decay_steps, initial_lr, peak_lr, min_lr, and decay_type + !*/ + + void deserialize( + lr_scheduler& item, + std::istream& in + ); + /*! + ensures + - Deserializes the state of item from the input stream in + - Restores all configuration and progress state + throws + - serialization_error if the data in 'in' is not valid lr_scheduler data + !*/ + + std::ostream& operator<<( + std::ostream& out, + const lr_scheduler& item + ); + /*! 
+ ensures + - Prints a human-readable summary of the scheduler state to out + - Includes: current step, current learning rate, phase name, and configuration + !*/ + +// ---------------------------------------------------------------------------------------- + + lr_scheduler make_transformer_scheduler( + double peak_lr, + size_t total_steps, + double warmup_fraction = 0.02, + double min_lr = 1e-6, + lr_decay_type decay_type = lr_decay_type::COSINE + ); + /*! + requires + - peak_lr > 0 + - total_steps > 0 + - 0 < warmup_fraction < 1 + - min_lr >= 0 + - min_lr < peak_lr + ensures + - Returns an lr_scheduler configured with common transformer training settings + - The warmup_steps is computed as: max(100, total_steps * warmup_fraction) + - returns a scheduler S such that: + - S.get_peak_lr() == peak_lr + - S.get_total_steps() == total_steps + - S.get_min_lr() == min_lr + - S.get_decay_type() == decay_type + - S.get_warmup_steps() == max(100, total_steps * warmup_fraction) + !*/ + + size_t estimate_total_steps( + size_t dataset_size, + size_t batch_size, + size_t num_epochs + ); + /*! + requires + - batch_size > 0 + ensures + - Returns an estimate of the total number of training steps + - Computed as: ceil(dataset_size / batch_size) * num_epochs + - Useful for configuring lr_scheduler when you know the dataset size, + batch size, and desired number of epochs + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DNN_LR_SCHEDULER_ABSTRACT_H_ diff --git a/dlib/dnn/solvers.h b/dlib/dnn/solvers.h index 6eab32be12..d28a5aa93f 100644 --- a/dlib/dnn/solvers.h +++ b/dlib/dnn/solvers.h @@ -397,6 +397,349 @@ namespace dlib float t; }; + // ---------------------------------------------------------------------------------------- + + /*! + AdamW optimizer with decoupled weight decay regularization. + + This optimizer implements the AdamW algorithm from "Decoupled Weight Decay + Regularization" (Loshchilov & Hutter, ICLR 2019). Unlike standard Adam, + AdamW decouples the weight decay from the gradient-based optimization step, + leading to better generalization and easier hyperparameter tuning. + + THEORETICAL FOUNDATION: + Standard Adam with L2 regularization computes: + theta_t = theta_{t-1} - alpha * m_hat_t / sqrt(v_hat_t + epsilon) + where gradients include the L2 regularization term + + AdamW decouples weight decay and computes: + m_t = beta1 * m_{t-1} + (1-beta1) * gradient_L + v_t = beta2 * v_{t-1} + (1-beta2) * (gradient_L)^2 + theta_t = theta_{t-1} - alpha * (m_hat_t/sqrt(v_hat_t) + lambda*theta_{t-1}) + + This formulation makes the optimal weight decay factor independent of + the learning rate, improving generalization especially for long training runs. + + IMPLEMENTATION STRATEGY: + 1. Compute standard Adam update with weight_decay = 0 (decoupled) + 2. Explicitly apply weight decay: update = update - lr * wd * params + 3. 
The update is then added to parameters by the trainer + + KEY DIFFERENCES FROM ADAM: + - Weight decay is applied directly to parameters (multiplicative) + - Weight decay does not interact with adaptive learning rates + - Better hyperparameter independence (learning rate vs weight decay) + - Superior generalization on image classification and NLP tasks + + CONSTRUCTOR PARAMETERS: + - weight_decay: Decoupled weight decay coefficient (default: 0.01) + Typical range: 0.0001 to 0.1 + Higher values = stronger regularization + - momentum1 (beta1): Exponential decay rate for first moment (default: 0.9) + Controls the momentum of gradient moving average + - momentum2 (beta2): Exponential decay rate for second moment (default: 0.999) + Controls the momentum of squared gradient moving average + + REFERENCES: + - Loshchilov & Hutter (2019). "Decoupled Weight Decay Regularization" + ICLR 2019. https://arxiv.org/abs/1711.05101 + - Kingma & Ba (2015). "Adam: A Method for Stochastic Optimization" + ICLR 2015. https://arxiv.org/abs/1412.6980 + + NOTE: AdamW is the standard optimizer for modern transformer models including + GPT, BERT, LLaMA, Mistral, Qwen, DeepSeek, and other large language models. + It consistently outperforms standard Adam with L2 regularization. + !*/ + class adamw + { + public: + + explicit adamw( + float weight_decay_ = 0.01f, + float momentum1_ = 0.9f, + float momentum2_ = 0.999f + ) + { + weight_decay = weight_decay_; + momentum1 = momentum1_; + momentum2 = momentum2_; + t = 0; + } + + float get_momentum1() const { return momentum1; } + float get_momentum2() const { return momentum2; } + float get_weight_decay() const { return weight_decay; } + + template + const tensor& operator() ( + const float learning_rate, + const layer_type& l, + const tensor& params_grad + ) + { + const tensor& params = l.get_layer_params(); + DLIB_CASSERT(params.size() != 0); + + if (v.size() == 0) + { + m.copy_size(params_grad); + m = 0; + v.copy_size(params_grad); + v = 0; + s.copy_size(params_grad); + } + + ++t; + + // Step 1: compute standard Adam update with decoupled weight decay (wd = 0) + // This populates 's' with the adaptive gradient step: -alpha * m_hat_t / sqrt(v_hat_t) + // By passing weight_decay = 0, we decouple the regularization from the adaptive update + tt::compute_adam_update(0, params.size(), s, m, v, t, + learning_rate * get_learning_rate_multiplier(l), + 0, // Critical: weight_decay = 0 for decoupled regularization + momentum1, momentum2, params, params_grad); + + // Step 2: apply decoupled weight decay explicitly + // Formula: s = s - alpha * lambda * theta_{t-1} + // This implements the AdamW update: theta_t = theta_{t-1} - alpha * (m_hat_t/sqrt(v_hat_t) + lambda * theta_{t-1}) + const double lr = learning_rate * get_learning_rate_multiplier(l); + const double wd = weight_decay * get_weight_decay_multiplier(l); + + if (wd != 0) + { + // Compute: s = s + params * (-lr * wd) + tt::affine_transform(s, s, params, 1.0, -lr * wd); + } + + return s; + } + + template + const tensor& operator() ( + const float learning_rate, + const fc_& l, + const tensor& params_grad + ) + { + update_considering_bias(learning_rate, l, params_grad, params_grad.size() - l.get_num_outputs()); + return s; + } + + template < + long _num_filters, + long _nr, + long _nc, + int _stride_y, + int _stride_x, + int _padding_y, + int _padding_x + > + const tensor& operator() ( + const float learning_rate, + const con_<_num_filters, _nr, _nc, _stride_y, _stride_x, _padding_y, _padding_x>& l, + const tensor& params_grad 
+ ) + { + update_considering_bias(learning_rate, l, params_grad, params_grad.size() - l.num_filters()); + return s; + } + + template < + long _num_filters, + long _nr, + long _nc, + int _stride_y, + int _stride_x, + int _padding_y, + int _padding_x + > + const tensor& operator() ( + const float learning_rate, + const cont_<_num_filters, _nr, _nc, _stride_y, _stride_x, _padding_y, _padding_x>& l, + const tensor& params_grad + ) + { + update_considering_bias(learning_rate, l, params_grad, params_grad.size() - l.num_filters()); + return s; + } + + template < layer_mode mode > + const tensor& operator() ( + const float learning_rate, + const bn_& l, + const tensor& params_grad + ) + { + update_considering_bias(learning_rate, l, params_grad, params_grad.size() / 2); + return s; + } + + friend void serialize(const adamw& item, std::ostream& out) + { + serialize("adamw", out); + serialize(item.m, out); + serialize(item.v, out); + serialize(item.s, out); + serialize(item.weight_decay, out); + serialize(item.momentum1, out); + serialize(item.momentum2, out); + serialize(item.t, out); + } + + friend void deserialize(adamw& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "adamw") + throw serialization_error("Unexpected version found while deserializing dlib::adamw."); + deserialize(item.m, in); + deserialize(item.v, in); + deserialize(item.s, in); + deserialize(item.weight_decay, in); + deserialize(item.momentum1, in); + deserialize(item.momentum2, in); + deserialize(item.t, in); + } + + friend std::ostream& operator<< (std::ostream& out, const adamw& item) + { + out << "adamw: weight_decay=" << item.get_weight_decay() + << ", momentum1=" << item.get_momentum1() + << ", momentum2=" << item.get_momentum2(); + return out; + } + + private: + + /*! + Updates parameters that may have different learning rate and weight decay + multipliers for weights vs biases (e.g., fully connected and convolutional layers). 
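+
+            For a concrete sense of scale (values illustrative): with
+            learning_rate = 1e-3 and weight_decay = 0.01, and all per-layer
+            multipliers equal to 1, each step shrinks every weight by an extra
+            lr * wd = 1e-5 fraction of its current value, independently of the
+            adaptive Adam step computed from the gradient moments.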
+ + BIAS HANDLING: + Most layers separate weights and biases: + - Weights: indices [0, bias_offset) + - Biases: indices [bias_offset, end) + + Different multipliers may apply to each section: + - bias_learning_rate_multiplier (typically 1.0 or 2.0) + - bias_weight_decay_multiplier (typically 0.0 - no decay on biases) + + PARAMETERS: + - learning_rate: base learning rate from trainer + - l: layer containing parameters and multiplier settings + - params_grad: gradient tensor + - bias_offset: index where biases start in the parameter tensor + !*/ + template + void update_considering_bias( + const float learning_rate, + const layer_type& l, + const tensor& params_grad, + unsigned long bias_offset + ) + { + const tensor& params = l.get_layer_params(); + DLIB_CASSERT(params.size() != 0); + + if (v.size() == 0) + { + m.copy_size(params_grad); + m = 0; + v.copy_size(params_grad); + v = 0; + s.copy_size(params_grad); + } + + ++t; + + // Step 1: compute adaptive gradient update with decoupled weight decay + if (l.get_bias_learning_rate_multiplier() == 1) + { + // Simple case: uniform learning rate for all parameters + tt::compute_adam_update(0, params.size(), s, m, v, t, + learning_rate * get_learning_rate_multiplier(l), + 0, // Decoupled: weight_decay = 0 in Adam computation + momentum1, momentum2, params, params_grad); + } + else + { + // Complex case: different learning rates for weights and biases + + // Process weights: indices [0, bias_offset) + tt::compute_adam_update(0, bias_offset, s, m, v, t, + learning_rate * get_learning_rate_multiplier(l), + 0, // Decoupled weight decay + momentum1, momentum2, params, params_grad); + + // Process biases: indices [bias_offset, end) + // Apply bias learning rate multiplier + tt::compute_adam_update(bias_offset, params.size(), s, m, v, t, + learning_rate * get_learning_rate_multiplier(l) * l.get_bias_learning_rate_multiplier(), + 0, // Decoupled weight decay + momentum1, momentum2, params, params_grad); + } + + // Step 2: apply decoupled weight decay + // Formula: s = s - lr * wd * params + // This is applied separately to weights and biases because they may have + // different weight decay multipliers + double lr = learning_rate * get_learning_rate_multiplier(l); + double wd = weight_decay * get_weight_decay_multiplier(l); + + if (l.get_bias_learning_rate_multiplier() == 1 && l.get_bias_weight_decay_multiplier() == 1) + { + // Simple case: uniform weight decay for all parameters + if (wd != 0) + tt::affine_transform(s, s, params, 1.0, -lr * wd); + } + else + { + // Complex case: different weight decay for weights vs biases + + // Apply weight decay to weights: indices [0, bias_offset) + // Computation: s[i] = 1.0 * s[i] + (-lr * wd) * params[i] + 0.0 * params[i] + // The third source (params) is not used since C = 0.0 + if (wd != 0) + { + tt::affine_transform_range(0, bias_offset, + s, // dest + s, // src1 (A coefficient) + params, // src2 (B coefficient) + params, // src3 (C coefficient = 0, so this is unused) + 1.0, // A: keep current update + -lr * wd, // B: subtract weight decay term + 0.0); // C: ignore third source + } + + // Apply weight decay to biases: indices [bias_offset, end) + // Note: typically bias_weight_decay_multiplier = 0 (no regularization on biases) + // This is a common practice in deep learning to prevent biases from becoming too small + lr *= l.get_bias_learning_rate_multiplier(); + wd *= l.get_bias_weight_decay_multiplier(); + + if (wd != 0) + { + tt::affine_transform_range(bias_offset, v.size(), + s, + s, + params, + params, + 
+
+        resizable_tensor m;  // First moment estimate (exponential moving average of gradients)
+        resizable_tensor v;  // Second moment estimate (exponential moving average of squared gradients)
+        resizable_tensor s;  // Parameter update computed by the optimizer
+        float weight_decay;  // Weight decay coefficient (lambda in the paper)
+        float momentum1;     // Beta1: decay rate for the first moment
+        float momentum2;     // Beta2: decay rate for the second moment
+        float t;             // Time step counter for bias correction
+    };
+
+// ----------------------------------------------------------------------------------------
+
 }
 
diff --git a/dlib/dnn/solvers_abstract.h b/dlib/dnn/solvers_abstract.h
index 7a07452170..20c37987dd 100644
--- a/dlib/dnn/solvers_abstract.h
+++ b/dlib/dnn/solvers_abstract.h
@@ -9,8 +9,6 @@ namespace dlib
 {
 
-// ----------------------------------------------------------------------------------------
-// ----------------------------------------------------------------------------------------
 // ----------------------------------------------------------------------------------------
 
     class EXAMPLE_SOLVER
@@ -69,8 +67,6 @@ namespace dlib
         Prints the solver's name and parameters to out.
     !*/
 
-// ----------------------------------------------------------------------------------------
-// ----------------------------------------------------------------------------------------
 // ----------------------------------------------------------------------------------------
 
     class sgd
@@ -196,6 +192,82 @@ namespace dlib
         Prints the solver's name and parameters to out.
     !*/
 
+// ----------------------------------------------------------------------------------------
+
+    class adamw
+    {
+        /*!
+            WHAT THIS OBJECT REPRESENTS
+                This object implements the EXAMPLE_SOLVER interface defined above.  In
+                particular, it implements the AdamW parameter update method with decoupled
+                weight decay regularization as described in the paper:
+                    Loshchilov, Ilya, and Frank Hutter. "Decoupled weight decay
+                    regularization." International Conference on Learning Representations,
+                    2019.
+
+                The key difference from standard Adam is that the weight decay is
+                decoupled from the gradient-based optimization step.  This leads to better
+                generalization performance and makes the optimal weight decay factor more
+                independent of the learning rate setting.  AdamW has become the standard
+                optimizer for training large language models and transformer architectures.
+
+                The update is computed as:
+                    m_t = momentum1*m_{t-1} + (1-momentum1)*params_grad
+                    v_t = momentum2*v_{t-1} + (1-momentum2)*(params_grad^2)
+                    V = -learning_rate * (m_hat_t/(sqrt(v_hat_t)+eps) + weight_decay*l.get_layer_params())
+                where m_hat_t and v_hat_t are the bias-corrected moment estimates and eps
+                is a small constant that prevents division by zero.
+
+                Note that the actual learning rate and weight decay used by the solver are
+                multiplied by the per layer multipliers.  That is, the solver will call
+                get_learning_rate_multiplier(l) and get_weight_decay_multiplier(l) and
+                multiply these values with the nominal learning rate and weight decay,
+                respectively, to determine the values it will use during each step.  It is
+                also overloaded to allow additional learning rate and weight decay
+                multipliers to be applied to the bias parameters of layers such as fc_,
+                con_, and cont_.
+        !*/
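+
+        // For example, a trainer using this solver might be set up as follows
+        // (illustrative sketch; net_type stands for any dlib network type):
+        //
+        //   dnn_trainer<net_type, adamw> trainer(net, adamw());
+        //   trainer.set_learning_rate(1e-3);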
+
+    public:
+
+        adamw(
+        );
+        /*!
+            ensures
+                - #get_weight_decay() == 0.01
+                - #get_momentum1() == 0.9
+                - #get_momentum2() == 0.999
+        !*/
+
+        explicit adamw(
+            float weight_decay,
+            float momentum1 = 0.9,
+            float momentum2 = 0.999
+        );
+        /*!
+            requires
+                - weight_decay >= 0
+                - 0 <= momentum1 < 1
+                - 0 <= momentum2 < 1
+            ensures
+                - #get_weight_decay() == weight_decay
+                - #get_momentum1() == momentum1
+                - #get_momentum2() == momentum2
+        !*/
+
+        float get_weight_decay() const;
+        float get_momentum1() const;
+        float get_momentum2() const;
+    };
+
+    void serialize(const adamw& item, std::ostream& out);
+    void deserialize(adamw& item, std::istream& in);
+    /*!
+        provides serialization support
+    !*/
+
+    std::ostream& operator<< (std::ostream& out, const adamw& item);
+    /*!
+        Prints the solver's name and parameters to out.
+    !*/
+
 // ----------------------------------------------------------------------------------------
 
 }
 
diff --git a/dlib/dnn/trainer.h b/dlib/dnn/trainer.h
index c329791e78..3cdc6fa1ec 100644
--- a/dlib/dnn/trainer.h
+++ b/dlib/dnn/trainer.h
@@ -11,6 +11,7 @@
 #include 
 #include 
 #include "../serialize.h"
+#include "lr_scheduler.h"
 #include "../pipe.h"
 #include "../threads.h"
 
diff --git a/dlib/dnn/transformer.h b/dlib/dnn/transformer.h
new file mode 100644
index 0000000000..786e8ea8a0
--- /dev/null
+++ b/dlib/dnn/transformer.h
@@ -0,0 +1,1019 @@
+// Copyright (C) 2025  Cydral Technology (cydraltechnology@gmail.com)
+// License: Boost Software License   See LICENSE.txt for the full license.
+#ifndef DLIB_DNN_TRANSFORMER_H_
+#define DLIB_DNN_TRANSFORMER_H_
+
+#include "transformer_abstract.h"
+#include "layers.h"
+
+namespace dlib
+{
+    // ----------------------------------------------------------------------------------------
+
+    template <long d_k_>
+    class scale_weights_ : public multiply_
+    {
+    public:
+        explicit scale_weights_() : multiply_(1.0f / std::sqrt(static_cast<float>(d_k_))) {}
+    };
+
+    template <long d_k, typename SUBNET>
+    using scale_weights = add_layer<scale_weights_<d_k>, SUBNET>;
+
+    // ----------------------------------------------------------------------------------------
+
+    template <long num_embeddings, long embedding_dim, typename SUBNET>
+    using positional_embeddings = positional_encodings<
+        embeddings<num_embeddings, embedding_dim, SUBNET>>;
+
+    // ----------------------------------------------------------------------------------------
+
+    // CANONICAL TRANSFORMER ARCHITECTURE
+    namespace canonical_transformer
+    {
+
+        template <long seq_len, long d_model, long num_heads, typename SUBNET>
+        using query = reshape_to<num_heads, seq_len, d_model / num_heads,
+            linear_no_bias<d_model, SUBNET>>;
+
+        template <long seq_len, long d_model, long num_heads, typename SUBNET>
+        using key = reshape_to<num_heads, seq_len, d_model / num_heads,
+            linear_no_bias<d_model, SUBNET>>;
+
+        template <long seq_len, long d_model, long num_heads, typename SUBNET>
+        using value = reshape_to<num_heads, seq_len, d_model / num_heads,
+            linear_no_bias<d_model, SUBNET>>;
+
+        template