diff options
| author | Robin Voetter <robin@voetter.nl> | 2021-10-03 16:03:43 +0200 |
|---|---|---|
| committer | Robin Voetter <robin@voetter.nl> | 2021-10-04 11:25:29 +0200 |
| commit | 41e9c1bac1c447fe42a191bf16ee25ddb3bba97a (patch) | |
| tree | 600df607201912d324354f1444be3711f2dcddc8 /lib/std/math/big | |
| parent | 5907b3e3830f95d111d9d60027af1350b81f4378 (diff) | |
| download | zig-41e9c1bac1c447fe42a191bf16ee25ddb3bba97a.tar.gz zig-41e9c1bac1c447fe42a191bf16ee25ddb3bba97a.zip | |
big ints: Allow llmulaccum to wrap
Diffstat (limited to 'lib/std/math/big')
| -rw-r--r-- | lib/std/math/big/int.zig | 122 |
1 files changed, 82 insertions, 40 deletions
diff --git a/lib/std/math/big/int.zig b/lib/std/math/big/int.zig index 90961db3dc..bfcef46bce 100644 --- a/lib/std/math/big/int.zig +++ b/lib/std/math/big/int.zig @@ -658,7 +658,7 @@ pub const Mutable = struct { mem.set(Limb, rma.limbs[0 .. a.limbs.len + b.limbs.len + 1], 0); - llmulacc(.add, allocator, rma.limbs, a.limbs, b.limbs); + _ = llmulacc(.add, allocator, rma.limbs, a.limbs, b.limbs); rma.normalize(a.limbs.len + b.limbs.len); rma.positive = (a.positive == b.positive); @@ -2365,9 +2365,12 @@ const AccOp = enum { /// /// r = r (op) a * b /// r MUST NOT alias any of a or b. +/// +/// The result is computed modulo `r.len`. When `r.len >= a.len + b.len`, no overflow occurs. fn llmulacc(comptime op: AccOp, opt_allocator: ?*Allocator, r: []Limb, a: []const Limb, b: []const Limb) void { @setRuntimeSafety(debug_safety); - assert(r.len >= a.len + b.len); + assert(r.len >= a.len); + assert(r.len >= b.len); // Order greatest first. var x = a; @@ -2395,6 +2398,8 @@ fn llmulacc(comptime op: AccOp, opt_allocator: ?*Allocator, r: []Limb, a: []cons /// /// r = r (op) a * b /// r MUST NOT alias any of a or b. +/// +/// The result is computed modulo `r.len`. When `r.len >= a.len + b.len`, no overflow occurs. fn llmulaccKaratsuba( comptime op: AccOp, allocator: *Allocator, @@ -2403,7 +2408,7 @@ fn llmulaccKaratsuba( b: []const Limb, ) error{OutOfMemory}!void { @setRuntimeSafety(debug_safety); - assert(r.len >= a.len + b.len); + assert(r.len >= a.len); assert(a.len >= b.len); // Classical karatsuba algorithm: @@ -2437,49 +2442,84 @@ fn llmulaccKaratsuba( // // Note, when B is a multiple of the limb size, multiplies by B amount to shifts or // slices of a limbs array. + // + // This function computes the result of the multiplication modulo r.len. This means: + // - p2 and p1 only need to be computed modulo r.len - B. + // - In the case of p2, p2 * B^2 needs to be added modulo r.len - 2 * B. const split = b.len / 2; // B + + const limbs_after_split = r.len - split; // Limbs to compute for p1 and p2. + const limbs_after_split2 = r.len - split * 2; // Limbs to add for p2 * B^2. + + // For a0 and b0 we need the full range. const a0 = a[0..llnormalize(a[0..split])]; - const a1 = a[split..][0..llnormalize(a[split..])]; const b0 = b[0..llnormalize(b[0..split])]; - const b1 = b[split..][0..llnormalize(b[split..])]; - // Note that the above slices work because we have a.len > b.len. - // We now also have: - // a1.len >= a0.len - // a1.len >= b1.len >= b0.len - // a0.len == b0.len + // For a1 and b1 we only need `limbs_after_split` limbs. + const a1 = blk: { + var a1 = a[split..]; + a1.len = math.min(llnormalize(a1), limbs_after_split); + break :blk a1; + }; + + const b1 = blk: { + var b1 = b[split..]; + b1.len = math.min(llnormalize(b1), limbs_after_split); + break :blk b1; + }; + + // Note that the above slices relative to `split` work because we have a.len > b.len. // We need some temporary memory to store intermediate results. // Note, we can reduce the amount of temporaries we need by reordering the computation here: // ab = p2 * B^2 + (p0 + p1 + p2) * B + p0 // = p2 * B^2 + (p0 * B + p1 * B + p2 * B) + p0 // = (p2 * B^2 + p2 * B) + (p0 * B + p0) + p1 * B - // By allocating a1.len * b1.len we can be sure that all the intermediary results fit. + + // Allocate at least enough memory to be able to multiply the upper two segments of a and b, assuming + // no overflow. const tmp = try allocator.alloc(Limb, a.len - split + b.len - split); defer allocator.free(tmp); // Compute p2. - mem.set(Limb, tmp, 0); - llmulacc(.add, allocator, tmp, a1, b1); - const p2 = tmp[0 .. llnormalize(tmp)]; + // Note, we don't need to compute all of p2, just enough limbs to satisfy r. + const p2_limbs = math.min(limbs_after_split, a1.len + b1.len); - // Add terms p2 * B^2 and p2 * B to the result. - _ = llaccum(op, r[split..], p2); - _ = llaccum(op, r[split * 2..], p2); + mem.set(Limb, tmp[0..p2_limbs], 0); + llmulacc(.add, allocator, tmp[0..p2_limbs], a1[0..math.min(a1.len, p2_limbs)], b1[0..math.min(b1.len, p2_limbs)]); + const p2 = tmp[0 .. llnormalize(tmp[0..p2_limbs])]; + + // Add p2 * B to the result. + llaccum(op, r[split..], p2); + + // Add p2 * B^2 to the result if required. + if (limbs_after_split2 > 0) { + llaccum(op, r[split * 2..], p2[0..math.min(p2.len, limbs_after_split2)]); + } // Compute p0. - mem.set(Limb, p2, 0); - llmulacc(.add, allocator, tmp, a0, b0); - const p0 = tmp[0 .. llnormalize(tmp[0..a0.len + b0.len])]; + // Since a0.len, b0.len <= split and r.len >= split * 2, the full width of p0 needs to be computed. + const p0_limbs = a0.len + b0.len; + mem.set(Limb, tmp[0..p0_limbs], 0); + llmulacc(.add, allocator, tmp[0..p0_limbs], a0, b0); + const p0 = tmp[0 .. llnormalize(tmp[0..p0_limbs])]; + + // Add p0 to the result. + llaccum(op, r, p0); + + // Add p0 * B to the result. In this case, we may not need all of it. + llaccum(op, r[split..], p0[0..math.min(limbs_after_split, p0.len)]); - // Add terms p0 * B and p0 to the result. - _ = llaccum(op, r, p0); - _ = llaccum(op, r[split..], p0); // Finally, compute and add p1. - const j0_sign = llcmp(a0, a1); - const j1_sign = llcmp(b1, b0); + // From now on we only need `limbs_after_split` limbs for a0 and b0, since the result of the + // following computation will be added * B. + const a0x = a0[0..std.math.min(a0.len, limbs_after_split)]; + const b0x = b0[0..std.math.min(b0.len, limbs_after_split)]; + + const j0_sign = llcmp(a0x, a1); + const j1_sign = llcmp(b1, b0x); if (j0_sign * j1_sign == 0) { // p1 is zero, we don't need to do any computation at all. @@ -2492,24 +2532,24 @@ fn llmulaccKaratsuba( // Note that in this case, we again need some storage for intermediary results // j0 and j1. Since we have tmp.len >= 2B, we can store both // intermediaries in the already allocated array. - const j0 = tmp[0..a1.len]; - const j1 = tmp[a1.len..]; + const j0 = tmp[0..a.len - split]; + const j1 = tmp[a.len - split..]; // Ensure that no subtraction overflows. if (j0_sign == 1) { // a0 > a1. - _ = llsubcarry(j0, a0, a1); + _ = llsubcarry(j0, a0x, a1); } else { // a0 < a1. - _ = llsubcarry(j0, a1, a0); + _ = llsubcarry(j0, a1, a0x); } if (j1_sign == 1) { // b1 > b0. - _ = llsubcarry(j1, b1, b0); + _ = llsubcarry(j1, b1, b0x); } else { // b1 > b0. - _ = llsubcarry(j1, b0, b1); + _ = llsubcarry(j1, b0x, b1); } if (j0_sign * j1_sign == 1) { @@ -2528,11 +2568,13 @@ fn llmulaccKaratsuba( } } -// r = r (op) a -fn llaccum(comptime op: AccOp, r: []Limb, a: []const Limb) Limb { +/// r = r (op) a. +/// The result is computed modulo `r.len`. +fn llaccum(comptime op: AccOp, r: []Limb, a: []const Limb) void { @setRuntimeSafety(debug_safety); if (op == .sub) { - return llsubcarry(r, r, a); + _ = llsubcarry(r, r, a); + return; } assert(r.len != 0 and a.len != 0); @@ -2551,8 +2593,6 @@ fn llaccum(comptime op: AccOp, r: []Limb, a: []const Limb) Limb { while ((carry != 0) and i < r.len) : (i += 1) { carry = @boolToInt(@addWithOverflow(Limb, r[i], carry, &r[i])); } - - return carry; } /// Returns -1, 0, 1 if |a| < |b|, |a| == |b| or |a| > |b| respectively for limbs. @@ -2583,19 +2623,21 @@ pub fn llcmp(a: []const Limb, b: []const Limb) i8 { } } -// r = r (op) y * xi +/// r = r (op) y * xi +/// The result is computed modulo `r.len`. When `r.len >= a.len + b.len`, no overflow occurs. fn llmulaccLong(comptime op: AccOp, r: []Limb, a: []const Limb, b: []const Limb) void { @setRuntimeSafety(debug_safety); assert(r.len >= a.len + b.len); assert(a.len >= b.len); var i: usize = 0; - while (i < a.len) : (i += 1) { - llmulLimb(op, r[i..], b, a[i]); + while (i < b.len) : (i += 1) { + llmulLimb(op, r[i..], a, b[i]); } } -// r = r (op) y * xi +/// r = r (op) y * xi +/// The result is computed modulo `r.len`. fn llmulLimb(comptime op: AccOp, acc: []Limb, y: []const Limb, xi: Limb) void { @setRuntimeSafety(debug_safety); if (xi == 0) { |
