Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions extensions/native/circuit/cuda/include/native/sumcheck.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,18 @@ template <typename T> struct HeaderSpecificCols {
template <typename T> struct ProdSpecificCols {
T data_ptr;
T p[EXT_DEG * 2];
MemoryReadAuxCols<T> read_records[1];
T p_evals[EXT_DEG];
MemoryWriteAuxCols<T, EXT_DEG> write_record;
MemoryWriteAuxCols<T, EXT_DEG * 2> ps_record;
T eval_rlc[EXT_DEG];
};

template <typename T> struct LogupSpecificCols {
T data_ptr;
T pq[EXT_DEG * 4];
MemoryReadAuxCols<T> read_records[1];
T p_evals[EXT_DEG];
T q_evals[EXT_DEG];
MemoryWriteAuxCols<T, EXT_DEG * 4> pqs_record;
MemoryWriteAuxCols<T, EXT_DEG> write_records[2];
T eval_rlc[EXT_DEG];
};
Expand Down
8 changes: 7 additions & 1 deletion extensions/native/circuit/cuda/src/sumcheck.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,18 @@

using namespace native;

constexpr uint32_t header_read_records_len() {
return sizeof(((HeaderSpecificCols<uint8_t> *)nullptr)->read_records)
/ sizeof(MemoryReadAuxCols<uint8_t>);
}

__device__ void fill_sumcheck_specific(RowSlice row, MemoryAuxColsFactory &mem_helper) {
RowSlice specific = row.slice_from(COL_INDEX(NativeSumcheckCols, specific));
uint32_t start_timestamp = row[COL_INDEX(NativeSumcheckCols, start_timestamp)].asUInt32();

if (row[COL_INDEX(NativeSumcheckCols, header_row)] == Fp::one()) {
for (uint32_t i = 0; i < 8; ++i) {
constexpr uint32_t header_records = header_read_records_len();
for (uint32_t i = 0; i < header_records; ++i) {
mem_fill_base(
mem_helper,
start_timestamp + i,
Expand Down
7 changes: 3 additions & 4 deletions extensions/native/circuit/src/sumcheck/chip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ use openvm_circuit::{
native_adapter::util::{memory_read_native, tracing_write_native_inplace},
},
};
use openvm_instructions::{
instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode, NATIVE_AS,
};
use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
use openvm_native_compiler::SumcheckOpcode::SUMCHECK_LAYER_EVAL;
use openvm_stark_backend::p3_field::PrimeField32;

Expand Down Expand Up @@ -227,7 +225,8 @@ where
let mut eval_acc = elem_to_ext(F::from_canonical_u32(0));
let mut alpha_acc = elem_to_ext(F::from_canonical_u32(1));

// all rows share same register values, ctx, challenges, max_round, hint_space_ptrs (optional)
// all rows share same register values, ctx, challenges, max_round, hint_space_ptrs
// (optional)
for row in rows.iter_mut() {
// c1, c2 are same during the entire execution
row.challenges[EXT_DEG..3 * EXT_DEG].copy_from_slice(&challenges[EXT_DEG..3 * EXT_DEG]);
Expand Down
2 changes: 1 addition & 1 deletion extensions/native/circuit/src/sumcheck/execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl NativeSumcheckExecutor {
#[inline(always)]
fn pre_compute_impl<F: PrimeField32>(
&self,
pc: u32,
_pc: u32,
inst: &Instruction<F>,
data: &mut NativeSumcheckPreCompute,
) -> Result<(), StaticProgramError> {
Expand Down
1 change: 1 addition & 0 deletions extensions/native/compiler/src/ir/sumcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ impl<C: Config> Builder<C> {
///
/// 2. for computing expected eval of next layer, output[1+i] = eq(0,r)*p[i][0] + eq(1,r) *
/// p[i][1].
#[allow(clippy::too_many_arguments)]
pub fn sumcheck_layer_eval(
&mut self,
input_ctx: &Array<C, Usize<C::N>>, // Context variables
Expand Down
25 changes: 10 additions & 15 deletions extensions/native/recursion/tests/sumcheck.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::iter::{once, repeat_n};
use std::iter::once;

use openvm_circuit::{arch::instructions::program::Program, utils::air_test_impl};
#[cfg(feature = "cuda")]
Expand Down Expand Up @@ -35,12 +35,10 @@ fn test_sumcheck_layer_eval_with_hint_ids() {
let num_logup_specs = 8;

let prod_evals: Vec<E> = (0..(num_prod_specs * num_layers * 2))
.into_iter()
.map(|_| new_rand_ext(&mut rng))
.collect();

let logup_evals: Vec<E> = (0..(num_logup_specs * num_layers * 4))
.into_iter()
.map(|_| new_rand_ext(&mut rng))
.collect();

Expand Down Expand Up @@ -73,13 +71,10 @@ fn test_sumcheck_layer_eval_with_hint_ids() {
standard_fri_params_with_100_bits_conjectured_security(1)
};

let mut input_stream: Vec<Vec<F>> = vec![];
input_stream.push(
prod_evals
.into_iter()
.flat_map(|e| <E as FieldExtensionAlgebra<F>>::as_base_slice(&e).to_vec())
.collect(),
);
let mut input_stream: Vec<Vec<F>> = vec![prod_evals
.into_iter()
.flat_map(|e| <E as FieldExtensionAlgebra<F>>::as_base_slice(&e).to_vec())
.collect()];
input_stream.push(
logup_evals
.into_iter()
Expand Down Expand Up @@ -137,7 +132,7 @@ fn build_test_program<C: Config>(
) {
let mode = 1; // current_layer

let mut ctx_u32s = vec![
let ctx_u32s = vec![
round,
num_prod_specs,
num_logup_specs,
Expand Down Expand Up @@ -175,16 +170,16 @@ fn build_test_program<C: Config>(

let num_prod_evals = num_prod_specs * num_layers * 2;
let prod_spec_evals: Array<C, Ext<C::F, C::EF>> = builder.dyn_array(num_prod_evals);
for idx in 0..num_prod_evals {
let e: Ext<C::F, C::EF> = builder.constant(prod_evals[idx]);
for (idx, prod_eval) in prod_evals.into_iter().enumerate() {
let e: Ext<C::F, C::EF> = builder.constant(prod_eval);

builder.set(&prod_spec_evals, idx, e);
}

let num_logup_evals = num_logup_specs * num_layers * 4;
let logup_spec_evals: Array<C, Ext<C::F, C::EF>> = builder.dyn_array(num_logup_evals);
for idx in 0..num_logup_evals {
let e: Ext<C::F, C::EF> = builder.constant(logup_evals[idx]);
for (idx, logup_eval) in logup_evals.into_iter().enumerate() {
let e: Ext<C::F, C::EF> = builder.constant(logup_eval);

builder.set(&logup_spec_evals, idx, e);
}
Expand Down
Loading