diff options
Diffstat (limited to 'lib/std/math/sqrt.zig')
| -rw-r--r-- | lib/std/math/sqrt.zig | 136 |
1 files changed, 136 insertions, 0 deletions
diff --git a/lib/std/math/sqrt.zig b/lib/std/math/sqrt.zig new file mode 100644 index 0000000000..30af5915d4 --- /dev/null +++ b/lib/std/math/sqrt.zig @@ -0,0 +1,136 @@ +const std = @import("../std.zig"); +const math = std.math; +const expect = std.testing.expect; +const builtin = @import("builtin"); +const TypeId = builtin.TypeId; +const maxInt = std.math.maxInt; + +/// Returns the square root of x. +/// +/// Special Cases: +/// - sqrt(+inf) = +inf +/// - sqrt(+-0) = +-0 +/// - sqrt(x) = nan if x < 0 +/// - sqrt(nan) = nan +pub fn sqrt(x: var) (if (@typeId(@typeOf(x)) == TypeId.Int) @IntType(false, @typeOf(x).bit_count / 2) else @typeOf(x)) { + const T = @typeOf(x); + switch (@typeId(T)) { + TypeId.ComptimeFloat => return T(@sqrt(f64, x)), // TODO upgrade to f128 + TypeId.Float => return @sqrt(T, x), + TypeId.ComptimeInt => comptime { + if (x > maxInt(u128)) { + @compileError("sqrt not implemented for comptime_int greater than 128 bits"); + } + if (x < 0) { + @compileError("sqrt on negative number"); + } + return T(sqrt_int(u128, x)); + }, + TypeId.Int => return sqrt_int(T, x), + else => @compileError("sqrt not implemented for " ++ @typeName(T)), + } +} + +test "math.sqrt" { + expect(sqrt(f16(0.0)) == @sqrt(f16, 0.0)); + expect(sqrt(f32(0.0)) == @sqrt(f32, 0.0)); + expect(sqrt(f64(0.0)) == @sqrt(f64, 0.0)); +} + +test "math.sqrt16" { + const epsilon = 0.000001; + + expect(@sqrt(f16, 0.0) == 0.0); + expect(math.approxEq(f16, @sqrt(f16, 2.0), 1.414214, epsilon)); + expect(math.approxEq(f16, @sqrt(f16, 3.6), 1.897367, epsilon)); + expect(@sqrt(f16, 4.0) == 2.0); + expect(math.approxEq(f16, @sqrt(f16, 7.539840), 2.745877, epsilon)); + expect(math.approxEq(f16, @sqrt(f16, 19.230934), 4.385309, epsilon)); + expect(@sqrt(f16, 64.0) == 8.0); + expect(math.approxEq(f16, @sqrt(f16, 64.1), 8.006248, epsilon)); + expect(math.approxEq(f16, @sqrt(f16, 8942.230469), 94.563370, epsilon)); +} + +test "math.sqrt32" { + const epsilon = 0.000001; + + expect(@sqrt(f32, 0.0) == 0.0); + expect(math.approxEq(f32, @sqrt(f32, 2.0), 1.414214, epsilon)); + expect(math.approxEq(f32, @sqrt(f32, 3.6), 1.897367, epsilon)); + expect(@sqrt(f32, 4.0) == 2.0); + expect(math.approxEq(f32, @sqrt(f32, 7.539840), 2.745877, epsilon)); + expect(math.approxEq(f32, @sqrt(f32, 19.230934), 4.385309, epsilon)); + expect(@sqrt(f32, 64.0) == 8.0); + expect(math.approxEq(f32, @sqrt(f32, 64.1), 8.006248, epsilon)); + expect(math.approxEq(f32, @sqrt(f32, 8942.230469), 94.563370, epsilon)); +} + +test "math.sqrt64" { + const epsilon = 0.000001; + + expect(@sqrt(f64, 0.0) == 0.0); + expect(math.approxEq(f64, @sqrt(f64, 2.0), 1.414214, epsilon)); + expect(math.approxEq(f64, @sqrt(f64, 3.6), 1.897367, epsilon)); + expect(@sqrt(f64, 4.0) == 2.0); + expect(math.approxEq(f64, @sqrt(f64, 7.539840), 2.745877, epsilon)); + expect(math.approxEq(f64, @sqrt(f64, 19.230934), 4.385309, epsilon)); + expect(@sqrt(f64, 64.0) == 8.0); + expect(math.approxEq(f64, @sqrt(f64, 64.1), 8.006248, epsilon)); + expect(math.approxEq(f64, @sqrt(f64, 8942.230469), 94.563367, epsilon)); +} + +test "math.sqrt16.special" { + expect(math.isPositiveInf(@sqrt(f16, math.inf(f16)))); + expect(@sqrt(f16, 0.0) == 0.0); + expect(@sqrt(f16, -0.0) == -0.0); + expect(math.isNan(@sqrt(f16, -1.0))); + expect(math.isNan(@sqrt(f16, math.nan(f16)))); +} + +test "math.sqrt32.special" { + expect(math.isPositiveInf(@sqrt(f32, math.inf(f32)))); + expect(@sqrt(f32, 0.0) == 0.0); + expect(@sqrt(f32, -0.0) == -0.0); + expect(math.isNan(@sqrt(f32, -1.0))); + expect(math.isNan(@sqrt(f32, math.nan(f32)))); +} + +test "math.sqrt64.special" { + expect(math.isPositiveInf(@sqrt(f64, math.inf(f64)))); + expect(@sqrt(f64, 0.0) == 0.0); + expect(@sqrt(f64, -0.0) == -0.0); + expect(math.isNan(@sqrt(f64, -1.0))); + expect(math.isNan(@sqrt(f64, math.nan(f64)))); +} + +fn sqrt_int(comptime T: type, value: T) @IntType(false, T.bit_count / 2) { + var op = value; + var res: T = 0; + var one: T = 1 << (T.bit_count - 2); + + // "one" starts at the highest power of four <= than the argument. + while (one > op) { + one >>= 2; + } + + while (one != 0) { + if (op >= res + one) { + op -= res + one; + res += 2 * one; + } + res >>= 1; + one >>= 2; + } + + const ResultType = @IntType(false, T.bit_count / 2); + return @intCast(ResultType, res); +} + +test "math.sqrt_int" { + expect(sqrt_int(u32, 3) == 1); + expect(sqrt_int(u32, 4) == 2); + expect(sqrt_int(u32, 5) == 2); + expect(sqrt_int(u32, 8) == 2); + expect(sqrt_int(u32, 9) == 3); + expect(sqrt_int(u32, 10) == 3); +} |
