From d2903907e990918ee71fc836d142403de64d9334 Mon Sep 17 00:00:00 2001
From: Zedong Peng
Date: Thu, 20 Nov 2025 02:37:10 -0500
Subject: [PATCH] first trial cuda graph

---
 internal/internal_types.h |   9 ++
 src/solver.cu             | 215 ++++++++++++++++++++++++++++++++------
 src/utils.cu              |   2 +-
 3 files changed, 194 insertions(+), 32 deletions(-)

diff --git a/internal/internal_types.h b/internal/internal_types.h
index 1f42582..67d9800 100644
--- a/internal/internal_types.h
+++ b/internal/internal_types.h
@@ -125,6 +125,15 @@ typedef struct
     double feasibility_polishing_time;
     int feasibility_iteration;
+
+    // --- CUDA Graph Related Fields ---
+    cudaGraph_t graph;              // Handle to the captured graph
+    cudaGraphExec_t graph_instance; // Executable graph instance
+    bool graph_created;             // Flag indicating if the graph is instantiated
+    bool graph_needs_update;
+    cudaStream_t stream;            // Dedicated stream for solver operations
+    double *d_halpern_weight;       // Device pointer for dynamic Halpern weight
+
 } pdhg_solver_state_t;

 typedef struct
diff --git a/src/solver.cu b/src/solver.cu
index a9e5f2b..c849286 100644
--- a/src/solver.cu
+++ b/src/solver.cu
@@ -48,7 +48,7 @@ halpern_update_kernel(const double *initial_primal, double *current_primal,
                       const double *reflected_primal, const double *initial_dual,
                       double *current_dual, const double *reflected_dual,
                       int n_vars, int n_cons,
-                      double weight, double reflection_coeff);
+                      const double *weight_ptr, double reflection_coeff);
 __global__ void rescale_solution_kernel(double *primal_solution,
                                         double *dual_solution,
                                         const double *variable_rescaling,
@@ -103,13 +103,24 @@ cupdlpx_result_t *optimize(const pdhg_parameters_t *params,
     rescale_info_free(rescale_info);

     initialize_step_size_and_primal_weight(state, params);
+
     clock_t start_time = clock();
     bool do_restart = false;
+    double h_weight_val = 0.0; // Host variable for Halpern weight
+
     while (state->termination_reason == TERMINATION_REASON_UNSPECIFIED)
     {
-        if ((state->is_this_major_iteration || state->total_count == 0) ||
-            (state->total_count % get_print_frequency(state->total_count) == 0))
+        // Check whether this step requires logging or major-iteration logic
+        bool is_print_iter = (state->total_count % get_print_frequency(state->total_count) == 0);
+
+        // ====================================================
+        // 1. Major Iteration & Logging (Standard CPU Control)
+        // ====================================================
+        if ((state->is_this_major_iteration || state->total_count == 0) || is_print_iter)
         {
+            // Synchronize the stream to ensure graph execution has finished before CPU access
+            CUDA_CHECK(cudaStreamSynchronize(state->stream));
+
             compute_residual(state);
             if (state->is_this_major_iteration &&
                 state->total_count < 3 * params->termination_evaluation_frequency)
@@ -124,32 +135,128 @@ cupdlpx_result_t *optimize(const pdhg_parameters_t *params,
             display_iteration_stats(state, params->verbose);
         }

+        // ====================================================
+        // 2. Restart Handling
+        // ====================================================
         if ((state->is_this_major_iteration || state->total_count == 0))
         {
             do_restart = should_do_adaptive_restart(state, &params->restart_params,
                                                     params->termination_evaluation_frequency);
             if (do_restart)
+            {
                 perform_restart(state, params);
+
+                // If a restart happens, scalar parameters (step_size, weights) baked into the
+                // graph may change, so mark the instantiated graph for re-capture/update.
+                if (state->graph_created) {
+                    state->graph_needs_update = true;
+                    // CUDA_CHECK(cudaGraphExecDestroy(state->graph_instance));
+                    // CUDA_CHECK(cudaGraphDestroy(state->graph));
+                    // state->graph_created = false;
+                }
+            }
         }

         state->is_this_major_iteration =
-            ((state->total_count + 1) % params->termination_evaluation_frequency) ==
-            0;
+            ((state->total_count + 1) % params->termination_evaluation_frequency) == 0;

-        compute_next_pdhg_primal_solution(state);
-        compute_next_pdhg_dual_solution(state);
+        // Calculate the current Halpern weight on the host
+        h_weight_val = (double)(state->inner_count + 1) / (state->inner_count + 2);

-        if (state->is_this_major_iteration || do_restart)
+        // ====================================================
+        // 3. Core Iteration: Graph vs Standard Path
+        // ====================================================
+
+        // We use the CUDA graph ONLY for standard minor iterations.
+        // We avoid the graph if:
+        //   - it is a major iteration (requires different kernel logic with slack calc),
+        //   - a restart occurred (parameters changed), or
+        //   - the NEXT iteration will be a major/print iteration (logic inside the compute_next functions).
+
+        bool next_is_major_logic = state->is_this_major_iteration ||
+                                   ((state->total_count + 2) % get_print_frequency(state->total_count + 2) == 0);
+        CUDA_CHECK(cudaMemcpyAsync(state->d_halpern_weight, &h_weight_val, sizeof(double), cudaMemcpyHostToDevice, state->stream));
+        if (next_is_major_logic || do_restart)
         {
-            compute_fixed_point_error(state);
-            if (do_restart)
+            // --- Path A: Standard Kernel Launch (Slow Path) ---
+
+            // The Halpern weight was already staged in device memory above;
+            // the kernels now read it through state->d_halpern_weight.
+            compute_next_pdhg_primal_solution(state);
+            compute_next_pdhg_dual_solution(state);
+
+            if (state->is_this_major_iteration || do_restart)
             {
-                state->initial_fixed_point_error = state->fixed_point_error;
-                do_restart = false;
+                compute_fixed_point_error(state);
+                if (do_restart)
+                {
+                    state->initial_fixed_point_error = state->fixed_point_error;
+                    do_restart = false;
+                }
             }
+            halpern_update(state, params->reflection_coefficient);
+        }
+        else
+        {
+            // --- Path B: CUDA Graph Launch (Fast Path) ---
+
+            // 1. Update the dynamic parameter in device memory asynchronously.
+            //    Since MemcpyAsync and GraphLaunch are issued on the same stream,
+            //    the update is guaranteed to happen before the captured kernels read it.
+
+            CUDA_CHECK(cudaMemcpyAsync(state->d_halpern_weight, &h_weight_val, sizeof(double), cudaMemcpyHostToDevice, state->stream));
+
+            // if (!state->graph_created)
+            // {
+            //     // --- Capture Phase ---
+            //     CUDA_CHECK(cudaStreamBeginCapture(state->stream, cudaStreamCaptureModeGlobal));
+
+            //     // Record the sequence of operations.
+            //     // Note: These functions must use state->stream internally.
+            //     compute_next_pdhg_primal_solution(state);
+            //     compute_next_pdhg_dual_solution(state);
+            //     halpern_update(state, params->reflection_coefficient);
+
+            //     CUDA_CHECK(cudaStreamEndCapture(state->stream, &state->graph));
+            //     CUDA_CHECK(cudaGraphInstantiate(&state->graph_instance, state->graph, NULL, NULL, 0));
+            //     state->graph_created = true;
+            // }
+
+            if (!state->graph_created)
+            {
+                // Capture + Instantiate
+                CUDA_CHECK(cudaStreamBeginCapture(state->stream, cudaStreamCaptureModeGlobal));
+
+                compute_next_pdhg_primal_solution(state);
+                compute_next_pdhg_dual_solution(state);
+                halpern_update(state, params->reflection_coefficient);
+
+                CUDA_CHECK(cudaStreamEndCapture(state->stream, &state->graph));
+                CUDA_CHECK(cudaGraphInstantiate(&state->graph_instance, state->graph, NULL, NULL, 0));
+                state->graph_created = true;
+                state->graph_needs_update = false;
+            }
+            else if (state->graph_needs_update)
+            {
+                // Capture + Update
+                cudaGraph_t temp_graph;
+                CUDA_CHECK(cudaStreamBeginCapture(state->stream, cudaStreamCaptureModeGlobal));
+
+                compute_next_pdhg_primal_solution(state);
+                compute_next_pdhg_dual_solution(state);
+                halpern_update(state, params->reflection_coefficient);
+
+                CUDA_CHECK(cudaStreamEndCapture(state->stream, &temp_graph));
+                cudaGraphExecUpdateResult updateResult;
+                CUDA_CHECK(cudaGraphExecUpdate(state->graph_instance, temp_graph, NULL, &updateResult));
+                CUDA_CHECK(cudaGraphDestroy(temp_graph));
+
+                state->graph_needs_update = false;
+            }
+
+            // --- Execution Phase ---
+            CUDA_CHECK(cudaGraphLaunch(state->graph_instance, state->stream));
         }
-        halpern_update(state, params->reflection_coefficient);

         state->inner_count++;
         state->total_count++;
@@ -470,6 +577,17 @@ initialize_solver_state(const lp_problem_t *original_problem,
                           cudaMemcpyHostToDevice));
     free(ones_dual_h);

+    // --- CUDA Graph Initialization ---
+    state->graph_created = false;
+    state->graph_needs_update = false;
+    CUDA_CHECK(cudaStreamCreate(&state->stream));
+    CUDA_CHECK(cudaMalloc(&state->d_halpern_weight, sizeof(double)));
+
+    // CRITICAL: Bind the library handles to the dedicated stream.
+    // If this is skipped, library calls use the default stream and are not captured.
+    CUSPARSE_CHECK(cusparseSetStream(state->sparse_handle, state->stream));
+    CUBLAS_CHECK(cublasSetStream(state->blas_handle, state->stream));
+
     return state;
 }

@@ -538,9 +656,10 @@ halpern_update_kernel(const double *initial_primal, double *current_primal,
                       const double *reflected_primal, const double *initial_dual,
                       double *current_dual, const double *reflected_dual,
                       int n_vars, int n_cons,
-                      double weight, double reflection_coeff)
+                      const double *weight_ptr, double reflection_coeff)
 {
     int i = blockIdx.x * blockDim.x + threadIdx.x;
+    double weight = *weight_ptr;
     if (i < n_vars)
     {
         double reflected = reflection_coeff * reflected_primal[i] +
@@ -597,11 +716,15 @@ __global__ void compute_delta_solution_kernel(

 static void compute_next_pdhg_primal_solution(pdhg_solver_state_t *state)
 {
+    // Update vector descriptors (host-side operation, no stream needed)
     CUSPARSE_CHECK(cusparseDnVecSetValues(state->vec_dual_sol,
                                           state->current_dual_solution));
     CUSPARSE_CHECK(
         cusparseDnVecSetValues(state->vec_dual_prod, state->dual_product));

+    // Execute SpMV
+    // Note: This automatically uses state->stream because we called
+    // cusparseSetStream(state->sparse_handle, state->stream) in initialization.
     CUSPARSE_CHECK(cusparseSpMV(
         state->sparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &HOST_ONE,
         state->matAt, state->vec_dual_sol, &HOST_ZERO, state->vec_dual_prod,
@@ -609,12 +732,18 @@ static void compute_next_pdhg_primal_solution(pdhg_solver_state_t *state)

     double step = state->step_size / state->primal_weight;

+    // Determine which kernel to launch
     if (state->is_this_major_iteration ||
         ((state->total_count + 2) % get_print_frequency(state->total_count + 2)) ==
             0)
     {
-        compute_next_pdhg_primal_solution_major_kernel<<<state->num_blocks_primal,
-                                                         THREADS_PER_BLOCK>>>(
+        // MODIFIED: Added stream argument
+        compute_next_pdhg_primal_solution_major_kernel<<<
+            state->num_blocks_primal,
+            THREADS_PER_BLOCK,
+            0,            // Shared memory size (0 bytes)
+            state->stream // Capture stream
+            >>>(
             state->current_primal_solution, state->pdhg_primal_solution,
             state->reflected_primal_solution, state->dual_product,
             state->objective_vector, state->variable_lower_bound,
@@ -623,8 +752,13 @@ static void compute_next_pdhg_primal_solution(pdhg_solver_state_t *state)
     }
     else
     {
-        compute_next_pdhg_primal_solution_kernel<<<state->num_blocks_primal,
-                                                   THREADS_PER_BLOCK>>>(
+        // MODIFIED: Added stream argument
+        compute_next_pdhg_primal_solution_kernel<<<
+            state->num_blocks_primal,
+            THREADS_PER_BLOCK,
+            0,            // Shared memory size (0 bytes)
+            state->stream // Capture stream
+            >>>(
             state->current_primal_solution, state->reflected_primal_solution,
             state->dual_product, state->objective_vector,
             state->variable_lower_bound, state->variable_upper_bound,
@@ -634,11 +768,13 @@ static void compute_next_pdhg_primal_solution(pdhg_solver_state_t *state)

 static void compute_next_pdhg_dual_solution(pdhg_solver_state_t *state)
 {
+    // Update vector descriptors
     CUSPARSE_CHECK(cusparseDnVecSetValues(state->vec_primal_sol,
                                           state->reflected_primal_solution));
     CUSPARSE_CHECK(
         cusparseDnVecSetValues(state->vec_primal_prod, state->primal_product));

+    // Execute SpMV (Uses stream from handle)
     CUSPARSE_CHECK(cusparseSpMV(
         state->sparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &HOST_ONE,
         state->matA, state->vec_primal_sol, &HOST_ZERO, state->vec_primal_prod,
@@ -650,8 +786,13 @@ static void compute_next_pdhg_dual_solution(pdhg_solver_state_t *state)
         ((state->total_count + 2) % get_print_frequency(state->total_count + 2)) ==
             0)
     {
-        compute_next_pdhg_dual_solution_major_kernel<<<state->num_blocks_dual,
-                                                       THREADS_PER_BLOCK>>>(
+        // MODIFIED: Added stream argument
+        compute_next_pdhg_dual_solution_major_kernel<<<
+            state->num_blocks_dual,
+            THREADS_PER_BLOCK,
+            0,            // Shared memory size
+            state->stream // Capture stream
+            >>>(
             state->current_dual_solution, state->pdhg_dual_solution,
             state->reflected_dual_solution, state->primal_product,
             state->constraint_lower_bound, state->constraint_upper_bound,
@@ -659,8 +800,13 @@ static void compute_next_pdhg_dual_solution(pdhg_solver_state_t *state)
     }
     else
     {
-        compute_next_pdhg_dual_solution_kernel<<<state->num_blocks_dual,
-                                                 THREADS_PER_BLOCK>>>(
+        // MODIFIED: Added stream argument
+        compute_next_pdhg_dual_solution_kernel<<<
+            state->num_blocks_dual,
+            THREADS_PER_BLOCK,
+            0,            // Shared memory size
+            state->stream // Capture stream
+            >>>(
             state->current_dual_solution, state->reflected_dual_solution,
             state->primal_product, state->constraint_lower_bound,
             state->constraint_upper_bound, state->num_constraints, step);
@@ -670,18 +816,18 @@ static void compute_next_pdhg_dual_solution(pdhg_solver_state_t *state)

 static void halpern_update(pdhg_solver_state_t *state,
                            double reflection_coefficient)
 {
-    double weight = (double)(state->inner_count + 1) / (state->inner_count + 2);
-    halpern_update_kernel<<<state->num_blocks_primal_dual, THREADS_PER_BLOCK>>>(
+    // double weight = (double)(state->inner_count + 1) / (state->inner_count + 2);
+    halpern_update_kernel<<<state->num_blocks_primal_dual, THREADS_PER_BLOCK, 0, state->stream>>>(
         state->initial_primal_solution, state->current_primal_solution,
         state->reflected_primal_solution, state->initial_dual_solution,
         state->current_dual_solution, state->reflected_dual_solution,
-        state->num_variables, state->num_constraints, weight,
+        state->num_variables, state->num_constraints, state->d_halpern_weight,
         reflection_coefficient);
 }

 static void rescale_solution(pdhg_solver_state_t *state)
 {
-    rescale_solution_kernel<<<state->num_blocks_primal_dual, THREADS_PER_BLOCK>>>(
+    rescale_solution_kernel<<<state->num_blocks_primal_dual, THREADS_PER_BLOCK, 0, state->stream>>>(
         state->pdhg_primal_solution, state->pdhg_dual_solution,
         state->variable_rescaling, state->constraint_rescaling,
         state->objective_vector_rescaling, state->constraint_bound_rescaling,
@@ -692,7 +838,7 @@ static void perform_restart(pdhg_solver_state_t *state,
                             const pdhg_parameters_t *params)
 {
     compute_delta_solution_kernel<<<state->num_blocks_primal_dual,
-                                    THREADS_PER_BLOCK>>>(
+                                    THREADS_PER_BLOCK, 0, state->stream>>>(
         state->initial_primal_solution, state->pdhg_primal_solution,
         state->delta_primal_solution, state->initial_dual_solution,
         state->pdhg_dual_solution, state->delta_dual_solution,
@@ -778,7 +924,7 @@ initialize_step_size_and_primal_weight(pdhg_solver_state_t *state,
 static void compute_fixed_point_error(pdhg_solver_state_t *state)
 {
     compute_delta_solution_kernel<<<state->num_blocks_primal_dual,
-                                    THREADS_PER_BLOCK>>>(
+                                    THREADS_PER_BLOCK, 0, state->stream>>>(
         state->current_primal_solution, state->reflected_primal_solution,
         state->delta_primal_solution, state->current_dual_solution,
         state->reflected_dual_solution, state->delta_dual_solution,
@@ -893,6 +1039,13 @@ void pdhg_solver_state_free(pdhg_solver_state_t *state)
         CUDA_CHECK(cudaFree(state->ones_primal_d));
     if (state->ones_dual_d)
         CUDA_CHECK(cudaFree(state->ones_dual_d));
+
+    if (state->graph_created) {
+        CUDA_CHECK(cudaGraphExecDestroy(state->graph_instance));
+        CUDA_CHECK(cudaGraphDestroy(state->graph));
+    }
+    CUDA_CHECK(cudaStreamDestroy(state->stream));
+    CUDA_CHECK(cudaFree(state->d_halpern_weight));
     free(state);
 }

@@ -1264,7 +1417,7 @@ static pdhg_solver_state_t *initialize_dual_feas_polish_state(
     {                                                                         \
         int threads = 256;                                                    \
         int blocks = (n + threads - 1) / threads;                             \
-        zero_finite_value_vectors_kernel<<<blocks, threads>>>(vec, n);        \
+        zero_finite_value_vectors_kernel<<<blocks, threads, 0, state->stream>>>(vec, n); \
         CUDA_CHECK(cudaDeviceSynchronize());                                  \
     }

@@ -1410,7 +1563,7 @@ __global__ void compute_delta_dual_solution_kernel(

 static void compute_primal_fixed_point_error(pdhg_solver_state_t *state)
 {
-    compute_delta_primal_solution_kernel<<<state->num_blocks_primal, THREADS_PER_BLOCK>>>(
+    compute_delta_primal_solution_kernel<<<state->num_blocks_primal, THREADS_PER_BLOCK, 0, state->stream>>>(
         state->current_primal_solution,
         state->reflected_primal_solution,
         state->delta_primal_solution,
@@ -1426,7 +1579,7 @@ static void compute_primal_fixed_point_error(pdhg_solver_state_t *state)

 static void compute_dual_fixed_point_error(pdhg_solver_state_t *state)
 {
-    compute_delta_dual_solution_kernel<<<state->num_blocks_dual, THREADS_PER_BLOCK>>>(
+    compute_delta_dual_solution_kernel<<<state->num_blocks_dual, THREADS_PER_BLOCK, 0, state->stream>>>(
         state->current_dual_solution,
         state->reflected_dual_solution,
         state->delta_dual_solution,
diff --git a/src/utils.cu b/src/utils.cu
index 96ae8d2..1e60d3b 100644
--- a/src/utils.cu
+++ b/src/utils.cu
@@ -585,7 +585,7 @@ void compute_residual(pdhg_solver_state_t *state)
         state->matAt, state->vec_dual_sol, &HOST_ZERO, state->vec_dual_prod,
         CUDA_R_64F, CUSPARSE_SPMV_CSR_ALG2, state->dual_spmv_buffer));

-    compute_residual_kernel<<<state->num_blocks_primal_dual, THREADS_PER_BLOCK>>>(
+    compute_residual_kernel<<<state->num_blocks_primal_dual, THREADS_PER_BLOCK, 0, state->stream>>>(
         state->primal_residual, state->primal_product,
         state->constraint_lower_bound, state->constraint_upper_bound,
         state->pdhg_dual_solution, state->dual_residual, state->dual_product,
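
For reference, the pattern this patch applies to the PDHG inner loop reduces to: capture the fixed launch sequence once on a dedicated stream, instantiate it, and replay it with cudaGraphLaunch, while per-iteration scalars are staged in device memory with cudaMemcpyAsync on the same stream so each replay reads a fresh value without rebuilding the graph. Below is a minimal, self-contained sketch of that pattern; the kernel, file name, and CHECK macro are illustrative only (they are not part of the solver), and the five-argument cudaGraphInstantiate call mirrors the pre-CUDA-12 signature used in the patch.

// cuda_graph_sketch.cu -- illustrative only; names and file are hypothetical
#include <cstdio>
#include <cuda_runtime.h>

#define CHECK(call)                                                   \
    do {                                                              \
        cudaError_t err_ = (call);                                    \
        if (err_ != cudaSuccess) {                                    \
            fprintf(stderr, "CUDA error %s at %s:%d\n",               \
                    cudaGetErrorString(err_), __FILE__, __LINE__);    \
            return 1;                                                 \
        }                                                             \
    } while (0)

// The kernel reads its scalar through a device pointer, so the value can
// change between launches without invalidating the captured graph.
__global__ void scale_kernel(double *x, const double *weight_ptr, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    double w = *weight_ptr;
    if (i < n)
        x[i] *= w;
}

int main()
{
    const int n = 1 << 20;
    double *d_x = NULL, *d_weight = NULL;
    double h_weight = 0.0;
    cudaStream_t stream;
    cudaGraph_t graph;
    cudaGraphExec_t graph_exec;
    bool graph_created = false;

    CHECK(cudaStreamCreate(&stream));
    CHECK(cudaMalloc(&d_x, n * sizeof(double)));
    CHECK(cudaMalloc(&d_weight, sizeof(double)));
    CHECK(cudaMemset(d_x, 0, n * sizeof(double)));

    for (int iter = 0; iter < 100; ++iter)
    {
        // Stage the per-iteration scalar on the same stream as the graph launch,
        // so it is ordered before the captured kernel reads it.
        h_weight = (double)(iter + 1) / (iter + 2);
        CHECK(cudaMemcpyAsync(d_weight, &h_weight, sizeof(double),
                              cudaMemcpyHostToDevice, stream));

        if (!graph_created)
        {
            // Capture the fixed launch sequence once, then reuse it every iteration.
            CHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
            scale_kernel<<<(n + 255) / 256, 256, 0, stream>>>(d_x, d_weight, n);
            CHECK(cudaStreamEndCapture(stream, &graph));
            CHECK(cudaGraphInstantiate(&graph_exec, graph, NULL, NULL, 0));
            graph_created = true;
        }

        // Replay the captured sequence; a single launch call per iteration.
        CHECK(cudaGraphLaunch(graph_exec, stream));
    }
    CHECK(cudaStreamSynchronize(stream));

    CHECK(cudaGraphExecDestroy(graph_exec));
    CHECK(cudaGraphDestroy(graph));
    CHECK(cudaFree(d_weight));
    CHECK(cudaFree(d_x));
    CHECK(cudaStreamDestroy(stream));
    return 0;
}

The patch additionally keeps a graph_needs_update flag: after a restart it re-captures the sequence and pushes it into the existing executable graph with cudaGraphExecUpdate, which avoids a full re-instantiation when the graph topology is unchanged.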