diff --git a/crates/recursion/circuit/src/challenger.rs b/crates/recursion/circuit/src/challenger.rs index 4e902bf928..13602835a7 100644 --- a/crates/recursion/circuit/src/challenger.rs +++ b/crates/recursion/circuit/src/challenger.rs @@ -115,8 +115,7 @@ pub fn reduce_32(builder: &mut Builder, vals: &[Felt]) -> Va let mut power = C::N::one(); let result: Var = builder.eval(C::N::zero()); for val in vals.iter() { - let bits = builder.num2bits_f_circuit(*val); - let val = builder.bits2num_v_circuit(&bits); + let val = builder.felt2var_circuit(*val); builder.assign(result, result + val * power); power *= C::N::from_canonical_u64(1u64 << 32); } diff --git a/crates/recursion/circuit/src/fri.rs b/crates/recursion/circuit/src/fri.rs index ea6cb30c2b..f7bf84ac04 100644 --- a/crates/recursion/circuit/src/fri.rs +++ b/crates/recursion/circuit/src/fri.rs @@ -204,6 +204,10 @@ pub fn verify_query( let index_sibling: Var<_> = builder.eval(one - index_bits.clone()[offset]); let index_pair = &index_bits[(offset + 1)..]; + // Reduce folded_eval (mod the BabyBear prime) since it gets used multiple times below and + // the reductions will be repeated. + builder.reduce_e(folded_eval); + let evals_ext = [ builder.select_ef(index_sibling, folded_eval, step.sibling_value), builder.select_ef(index_sibling, step.sibling_value, folded_eval), diff --git a/crates/recursion/compiler/src/constraints/mod.rs b/crates/recursion/compiler/src/constraints/mod.rs index 0a5d065e62..23e419e11e 100644 --- a/crates/recursion/compiler/src/constraints/mod.rs +++ b/crates/recursion/compiler/src/constraints/mod.rs @@ -358,6 +358,10 @@ impl ConstraintCompiler { opcode: ConstraintOpcode::ReduceE, args: vec![vec![a.id()]], }), + DslIr::CircuitFelt2Var(a, b) => constraints.push(Constraint { + opcode: ConstraintOpcode::CircuitFelt2Var, + args: vec![vec![b.id()], vec![a.id()]], + }), _ => panic!("unsupported {:?}", instruction), }; } diff --git a/crates/recursion/compiler/src/constraints/opcodes.rs b/crates/recursion/compiler/src/constraints/opcodes.rs index 02ed47eea5..edb6b1c2e0 100644 --- a/crates/recursion/compiler/src/constraints/opcodes.rs +++ b/crates/recursion/compiler/src/constraints/opcodes.rs @@ -46,6 +46,7 @@ pub enum ConstraintOpcode { CommitVkeyHash, CommitCommitedValuesDigest, CircuitFelts2Ext, + CircuitFelt2Var, PermuteBabyBear, ReduceE, } diff --git a/crates/recursion/compiler/src/ir/builder.rs b/crates/recursion/compiler/src/ir/builder.rs index a695b5818b..a1e2753c8e 100644 --- a/crates/recursion/compiler/src/ir/builder.rs +++ b/crates/recursion/compiler/src/ir/builder.rs @@ -474,6 +474,12 @@ impl Builder { self.operations.push(DslIr::ReduceE(ext)); } + pub fn felt2var_circuit(&mut self, felt: Felt) -> Var { + let var = self.uninit(); + self.operations.push(DslIr::CircuitFelt2Var(felt, var)); + var + } + pub fn cycle_tracker(&mut self, name: &str) { self.operations.push(DslIr::CycleTracker(name.to_string())); } diff --git a/crates/recursion/compiler/src/ir/instructions.rs b/crates/recursion/compiler/src/ir/instructions.rs index d2ad86dd60..4e63afd31f 100644 --- a/crates/recursion/compiler/src/ir/instructions.rs +++ b/crates/recursion/compiler/src/ir/instructions.rs @@ -193,6 +193,8 @@ pub enum DslIr { /// Decompose a field element into bits (bits = num2bits(felt)). Should only be used when /// target is a gnark circuit. CircuitNum2BitsF(Felt, Vec>), + /// Convert a Felt to a Var in a circuit. Avoids decomposing to bits and then reconstructing. + CircuitFelt2Var(Felt, Var), // Hashing. /// Permutes an array of baby bear elements using Poseidon2 (output = p2_permute(array)). diff --git a/crates/recursion/gnark-ffi/go/sp1/babybear/babybear.go b/crates/recursion/gnark-ffi/go/sp1/babybear/babybear.go index 8259d486c7..9e14da5608 100644 --- a/crates/recursion/gnark-ffi/go/sp1/babybear/babybear.go +++ b/crates/recursion/gnark-ffi/go/sp1/babybear/babybear.go @@ -53,7 +53,19 @@ func Zero() Variable { } } +func One() Variable { + return Variable{ + Value: frontend.Variable("1"), + NbBits: 1, + } +} + func NewF(value string) Variable { + if value == "0" { + return Zero() + } else if value == "1" { + return One() + } return Variable{ Value: frontend.Variable(value), NbBits: 31, @@ -105,7 +117,7 @@ func (c *Chip) MulFConst(a Variable, b int) Variable { } func (c *Chip) negF(a Variable) Variable { - if a.NbBits == 31 { + if a.NbBits <= 31 { return Variable{Value: c.api.Sub(modulus, a.Value), NbBits: 31} } @@ -283,7 +295,7 @@ func (p *Chip) reduceFast(x Variable) Variable { } func (p *Chip) ReduceSlow(x Variable) Variable { - if x.NbBits == 31 { + if x.NbBits <= 31 { return x } return Variable{ diff --git a/crates/recursion/gnark-ffi/go/sp1/sp1.go b/crates/recursion/gnark-ffi/go/sp1/sp1.go index eebef363b1..ac83d83624 100644 --- a/crates/recursion/gnark-ffi/go/sp1/sp1.go +++ b/crates/recursion/gnark-ffi/go/sp1/sp1.go @@ -201,6 +201,8 @@ func (circuit *Circuit) Define(api frontend.API) error { api.AssertIsEqual(circuit.CommitedValuesDigest, element) case "CircuitFelts2Ext": exts[cs.Args[0][0]] = babybear.Felts2Ext(felts[cs.Args[1][0]], felts[cs.Args[2][0]], felts[cs.Args[3][0]], felts[cs.Args[4][0]]) + case "CircuitFelt2Var": + vars[cs.Args[0][0]] = fieldAPI.ReduceSlow(felts[cs.Args[1][0]]).Value case "ReduceE": exts[cs.Args[0][0]] = fieldAPI.ReduceE(exts[cs.Args[0][0]]) default: