about | summary | refs | log | tree | commit | diff
path: root/std/math/sqrt.zig
diff options
context:
space:
mode:
Diffstat (limited to 'std/math/sqrt.zig')
-rw-r--r-- std/math/sqrt.zig 318
1 files changed, 56 insertions, 262 deletions
diff --git a/std/math/sqrt.zig b/std/math/sqrt.zig
index 690f8b6901..e12ecf9683 100644
--- a/std/math/sqrt.zig
+++ b/std/math/sqrt.zig
@@ -14,27 +14,9 @@ const TypeId = builtin.TypeId;
pub fn sqrt(x: var) (if (@typeId(@typeOf(x)) == TypeId.Int) @IntType(false, @typeOf(x).bit_count / 2) else @typeOf(x)) {
const T = @typeOf(x);
switch (@typeId(T)) {
- TypeId.FloatLiteral => {
- return T(sqrt64(x));
- },
- TypeId.Float => {
- switch (T) {
- f32 => {
- switch (builtin.arch) {
- builtin.Arch.x86_64 => return @import("x86_64/sqrt.zig").sqrt32(x),
- else => return sqrt32(x),
- }
- },
- f64 => {
- switch (builtin.arch) {
- builtin.Arch.x86_64 => return @import("x86_64/sqrt.zig").sqrt64(x),
- else => return sqrt64(x),
- }
- },
- else => @compileError("sqrt not implemented for " ++ @typeName(T)),
- }
- },
- TypeId.IntLiteral => comptime {
+ TypeId.ComptimeFloat => return T(@sqrt(f64, x)), // TODO upgrade to f128
+ TypeId.Float => return @sqrt(T, x),
+ TypeId.ComptimeInt => comptime {
if (x > @maxValue(u128)) {
@compileError("sqrt not implemented for comptime_int greater than 128 bits");
}
@@ -43,269 +25,81 @@ pub fn sqrt(x: var) (if (@typeId(@typeOf(x)) == TypeId.Int) @IntType(false, @typ
}
return T(sqrt_int(u128, x));
},
- TypeId.Int => {
- return sqrt_int(T, x);
- },
+ TypeId.Int => return sqrt_int(T, x),
else => @compileError("sqrt not implemented for " ++ @typeName(T)),
}
}
-fn sqrt32(x: f32) f32 {
- const tiny: f32 = 1.0e-30;
- const sign: i32 = @bitCast(i32, u32(0x80000000));
- var ix: i32 = @bitCast(i32, x);
-
- if ((ix & 0x7F800000) == 0x7F800000) {
- return x * x + x; // sqrt(nan) = nan, sqrt(+inf) = +inf, sqrt(-inf) = snan
- }
-
- // zero
- if (ix <= 0) {
- if (ix & ~sign == 0) {
- return x; // sqrt (+-0) = +-0
- }
- if (ix < 0) {
- return math.snan(f32);
- }
- }
-
- // normalize
- var m = ix >> 23;
- if (m == 0) {
- // subnormal
- var i: i32 = 0;
- while (ix & 0x00800000 == 0) : (i += 1) {
- ix <<= 1;
- }
- m -= i - 1;
- }
-
- m -= 127; // unbias exponent
- ix = (ix & 0x007FFFFF) | 0x00800000;
-
- if (m & 1 != 0) { // odd m, double x to even
- ix += ix;
- }
-
- m >>= 1; // m = [m / 2]
-
- // sqrt(x) bit by bit
- ix += ix;
- var q: i32 = 0; // q = sqrt(x)
- var s: i32 = 0;
- var r: i32 = 0x01000000; // r = moving bit right -> left
-
- while (r != 0) {
- const t = s + r;
- if (t <= ix) {
- s = t + r;
- ix -= t;
- q += r;
- }
- ix += ix;
- r >>= 1;
- }
-
- // floating add to find rounding direction
- if (ix != 0) {
- var z = 1.0 - tiny; // inexact
- if (z >= 1.0) {
- z = 1.0 + tiny;
- if (z > 1.0) {
- q += 2;
- } else {
- if (q & 1 != 0) {
- q += 1;
- }
- }
- }
- }
-
- ix = (q >> 1) + 0x3f000000;
- ix += m << 23;
- return @bitCast(f32, ix);
+test "math.sqrt" {
+ assert(sqrt(f16(0.0)) == @sqrt(f16, 0.0));
+ assert(sqrt(f32(0.0)) == @sqrt(f32, 0.0));
+ assert(sqrt(f64(0.0)) == @sqrt(f64, 0.0));
}
-// NOTE: The original code is full of implicit signed -> unsigned assumptions and u32 wraparound
-// behaviour. Most intermediate i32 values are changed to u32 where appropriate but there are
-// potentially some edge cases remaining that are not handled in the same way.
-fn sqrt64(x: f64) f64 {
- const tiny: f64 = 1.0e-300;
- const sign: u32 = 0x80000000;
- const u = @bitCast(u64, x);
-
- var ix0 = u32(u >> 32);
- var ix1 = u32(u & 0xFFFFFFFF);
-
- // sqrt(nan) = nan, sqrt(+inf) = +inf, sqrt(-inf) = nan
- if (ix0 & 0x7FF00000 == 0x7FF00000) {
- return x * x + x;
- }
-
- // sqrt(+-0) = +-0
- if (x == 0.0) {
- return x;
- }
- // sqrt(-ve) = snan
- if (ix0 & sign != 0) {
- return math.snan(f64);
- }
-
- // normalize x
- var m = i32(ix0 >> 20);
- if (m == 0) {
- // subnormal
- while (ix0 == 0) {
- m -= 21;
- ix0 |= ix1 >> 11;
- ix1 <<= 21;
- }
-
- // subnormal
- var i: u32 = 0;
- while (ix0 & 0x00100000 == 0) : (i += 1) {
- ix0 <<= 1;
- }
- m -= i32(i) - 1;
- ix0 |= ix1 >> u5(32 - i);
- ix1 <<= u5(i);
- }
-
- // unbias exponent
- m -= 1023;
- ix0 = (ix0 & 0x000FFFFF) | 0x00100000;
- if (m & 1 != 0) {
- ix0 += ix0 + (ix1 >> 31);
- ix1 = ix1 +% ix1;
- }
- m >>= 1;
-
- // sqrt(x) bit by bit
- ix0 += ix0 + (ix1 >> 31);
- ix1 = ix1 +% ix1;
-
- var q: u32 = 0;
- var q1: u32 = 0;
- var s0: u32 = 0;
- var s1: u32 = 0;
- var r: u32 = 0x00200000;
- var t: u32 = undefined;
- var t1: u32 = undefined;
-
- while (r != 0) {
- t = s0 +% r;
- if (t <= ix0) {
- s0 = t + r;
- ix0 -= t;
- q += r;
- }
- ix0 = ix0 +% ix0 +% (ix1 >> 31);
- ix1 = ix1 +% ix1;
- r >>= 1;
- }
-
- r = sign;
- while (r != 0) {
- t1 = s1 +% r;
- t = s0;
- if (t < ix0 or (t == ix0 and t1 <= ix1)) {
- s1 = t1 +% r;
- if (t1 & sign == sign and s1 & sign == 0) {
- s0 += 1;
- }
- ix0 -= t;
- if (ix1 < t1) {
- ix0 -= 1;
- }
- ix1 = ix1 -% t1;
- q1 += r;
- }
- ix0 = ix0 +% ix0 +% (ix1 >> 31);
- ix1 = ix1 +% ix1;
- r >>= 1;
- }
-
- // rounding direction
- if (ix0 | ix1 != 0) {
- var z = 1.0 - tiny; // raise inexact
- if (z >= 1.0) {
- z = 1.0 + tiny;
- if (q1 == 0xFFFFFFFF) {
- q1 = 0;
- q += 1;
- } else if (z > 1.0) {
- if (q1 == 0xFFFFFFFE) {
- q += 1;
- }
- q1 += 2;
- } else {
- q1 += q1 & 1;
- }
- }
- }
-
- ix0 = (q >> 1) + 0x3FE00000;
- ix1 = q1 >> 1;
- if (q & 1 != 0) {
- ix1 |= 0x80000000;
- }
-
- // NOTE: musl here appears to rely on signed twos-complement wraparound. +% has the same
- // behaviour at least.
- var iix0 = i32(ix0);
- iix0 = iix0 +% (m << 20);
+test "math.sqrt16" {
+ const epsilon = 0.000001;
- const uz = (u64(iix0) << 32) | ix1;
- return @bitCast(f64, uz);
-}
-
-test "math.sqrt" {
- assert(sqrt(f32(0.0)) == sqrt32(0.0));
- assert(sqrt(f64(0.0)) == sqrt64(0.0));
+ assert(@sqrt(f16, 0.0) == 0.0);
+ assert(math.approxEq(f16, @sqrt(f16, 2.0), 1.414214, epsilon));
+ assert(math.approxEq(f16, @sqrt(f16, 3.6), 1.897367, epsilon));
+ assert(@sqrt(f16, 4.0) == 2.0);
+ assert(math.approxEq(f16, @sqrt(f16, 7.539840), 2.745877, epsilon));
+ assert(math.approxEq(f16, @sqrt(f16, 19.230934), 4.385309, epsilon));
+ assert(@sqrt(f16, 64.0) == 8.0);
+ assert(math.approxEq(f16, @sqrt(f16, 64.1), 8.006248, epsilon));
+ assert(math.approxEq(f16, @sqrt(f16, 8942.230469), 94.563370, epsilon));
}
test "math.sqrt32" {
const epsilon = 0.000001;
- assert(sqrt32(0.0) == 0.0);
- assert(math.approxEq(f32, sqrt32(2.0), 1.414214, epsilon));
- assert(math.approxEq(f32, sqrt32(3.6), 1.897367, epsilon));
- assert(sqrt32(4.0) == 2.0);
- assert(math.approxEq(f32, sqrt32(7.539840), 2.745877, epsilon));
- assert(math.approxEq(f32, sqrt32(19.230934), 4.385309, epsilon));
- assert(sqrt32(64.0) == 8.0);
- assert(math.approxEq(f32, sqrt32(64.1), 8.006248, epsilon));
- assert(math.approxEq(f32, sqrt32(8942.230469), 94.563370, epsilon));
+ assert(@sqrt(f32, 0.0) == 0.0);
+ assert(math.approxEq(f32, @sqrt(f32, 2.0), 1.414214, epsilon));
+ assert(math.approxEq(f32, @sqrt(f32, 3.6), 1.897367, epsilon));
+ assert(@sqrt(f32, 4.0) == 2.0);
+ assert(math.approxEq(f32, @sqrt(f32, 7.539840), 2.745877, epsilon));
+ assert(math.approxEq(f32, @sqrt(f32, 19.230934), 4.385309, epsilon));
+ assert(@sqrt(f32, 64.0) == 8.0);
+ assert(math.approxEq(f32, @sqrt(f32, 64.1), 8.006248, epsilon));
+ assert(math.approxEq(f32, @sqrt(f32, 8942.230469), 94.563370, epsilon));
}
test "math.sqrt64" {
const epsilon = 0.000001;
- assert(sqrt64(0.0) == 0.0);
- assert(math.approxEq(f64, sqrt64(2.0), 1.414214, epsilon));
- assert(math.approxEq(f64, sqrt64(3.6), 1.897367, epsilon));
- assert(sqrt64(4.0) == 2.0);
- assert(math.approxEq(f64, sqrt64(7.539840), 2.745877, epsilon));
- assert(math.approxEq(f64, sqrt64(19.230934), 4.385309, epsilon));
- assert(sqrt64(64.0) == 8.0);
- assert(math.approxEq(f64, sqrt64(64.1), 8.006248, epsilon));
- assert(math.approxEq(f64, sqrt64(8942.230469), 94.563367, epsilon));
+ assert(@sqrt(f64, 0.0) == 0.0);
+ assert(math.approxEq(f64, @sqrt(f64, 2.0), 1.414214, epsilon));
+ assert(math.approxEq(f64, @sqrt(f64, 3.6), 1.897367, epsilon));
+ assert(@sqrt(f64, 4.0) == 2.0);
+ assert(math.approxEq(f64, @sqrt(f64, 7.539840), 2.745877, epsilon));
+ assert(math.approxEq(f64, @sqrt(f64, 19.230934), 4.385309, epsilon));
+ assert(@sqrt(f64, 64.0) == 8.0);
+ assert(math.approxEq(f64, @sqrt(f64, 64.1), 8.006248, epsilon));
+ assert(math.approxEq(f64, @sqrt(f64, 8942.230469), 94.563367, epsilon));
+}
+
+test "math.sqrt16.special" {
+ assert(math.isPositiveInf(@sqrt(f16, math.inf(f16))));
+ assert(@sqrt(f16, 0.0) == 0.0);
+ assert(@sqrt(f16, -0.0) == -0.0);
+ assert(math.isNan(@sqrt(f16, -1.0)));
+ assert(math.isNan(@sqrt(f16, math.nan(f16))));
}
test "math.sqrt32.special" {
- assert(math.isPositiveInf(sqrt32(math.inf(f32))));
- assert(sqrt32(0.0) == 0.0);
- assert(sqrt32(-0.0) == -0.0);
- assert(math.isNan(sqrt32(-1.0)));
- assert(math.isNan(sqrt32(math.nan(f32))));
+ assert(math.isPositiveInf(@sqrt(f32, math.inf(f32))));
+ assert(@sqrt(f32, 0.0) == 0.0);
+ assert(@sqrt(f32, -0.0) == -0.0);
+ assert(math.isNan(@sqrt(f32, -1.0)));
+ assert(math.isNan(@sqrt(f32, math.nan(f32))));
}
test "math.sqrt64.special" {
- assert(math.isPositiveInf(sqrt64(math.inf(f64))));
- assert(sqrt64(0.0) == 0.0);
- assert(sqrt64(-0.0) == -0.0);
- assert(math.isNan(sqrt64(-1.0)));
- assert(math.isNan(sqrt64(math.nan(f64))));
+ assert(math.isPositiveInf(@sqrt(f64, math.inf(f64))));
+ assert(@sqrt(f64, 0.0) == 0.0);
+ assert(@sqrt(f64, -0.0) == -0.0);
+ assert(math.isNan(@sqrt(f64, -1.0)));
+ assert(math.isNan(@sqrt(f64, math.nan(f64))));
}
fn sqrt_int(comptime T: type, value: T) @IntType(false, T.bit_count / 2) {
@@ -328,7 +122,7 @@ fn sqrt_int(comptime T: type, value: T) @IntType(false, T.bit_count / 2) {
}
const ResultType = @IntType(false, T.bit_count / 2);
- return ResultType(res);
+ return @intCast(ResultType, res);
}
test "math.sqrt_int" {