diff options
Diffstat (limited to 'src/codegen.cpp')
| -rw-r--r-- | src/codegen.cpp | 104 |
1 files changed, 73 insertions, 31 deletions
diff --git a/src/codegen.cpp b/src/codegen.cpp index 3dd6995c61..41caa29dbd 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -806,32 +806,47 @@ static LLVMValueRef get_int_overflow_fn(CodeGen *g, ZigType *operand_type, AddSu return fn_val; } -static LLVMValueRef get_float_fn(CodeGen *g, ZigType *type_entry, ZigLLVMFnId fn_id) { - assert(type_entry->id == ZigTypeIdFloat); +static LLVMValueRef get_float_fn(CodeGen *g, ZigType *type_entry, ZigLLVMFnId fn_id, BuiltinFnId op) { + assert(type_entry->id == ZigTypeIdFloat || + type_entry->id == ZigTypeIdVector); + + bool is_vector = (type_entry->id == ZigTypeIdVector); + ZigType *float_type = is_vector ? type_entry->data.vector.elem_type : type_entry; ZigLLVMFnKey key = {}; key.id = fn_id; - key.data.floating.bit_count = (uint32_t)type_entry->data.floating.bit_count; + key.data.floating.bit_count = (uint32_t)float_type->data.floating.bit_count; + key.data.floating.vector_len = is_vector ? (uint32_t)type_entry->data.vector.len : 0; + key.data.floating.op = op; auto existing_entry = g->llvm_fn_table.maybe_get(key); if (existing_entry) return existing_entry->value; const char *name; - if (fn_id == ZigLLVMFnIdFloor) { - name = "floor"; - } else if (fn_id == ZigLLVMFnIdCeil) { - name = "ceil"; - } else if (fn_id == ZigLLVMFnIdSqrt) { - name = "sqrt"; + uint32_t num_args; + if (fn_id == ZigLLVMFnIdFMA) { + name = "fma"; + num_args = 3; + } else if (fn_id == ZigLLVMFnIdFloatOp) { + name = float_op_to_name(op, true); + num_args = 1; } else { zig_unreachable(); } char fn_name[64]; - sprintf(fn_name, "llvm.%s.f%" ZIG_PRI_usize "", name, type_entry->data.floating.bit_count); + if (is_vector) + sprintf(fn_name, "llvm.%s.v%" PRIu32 "f%" PRIu32, name, key.data.floating.vector_len, key.data.floating.bit_count); + else + sprintf(fn_name, "llvm.%s.f%" PRIu32, name, key.data.floating.bit_count); LLVMTypeRef float_type_ref = get_llvm_type(g, type_entry); - LLVMTypeRef fn_type = LLVMFunctionType(float_type_ref, &float_type_ref, 1, false); + LLVMTypeRef return_elem_types[3] = { + float_type_ref, + float_type_ref, + float_type_ref, + }; + LLVMTypeRef fn_type = LLVMFunctionType(float_type_ref, return_elem_types, num_args, false); LLVMValueRef fn_val = LLVMAddFunction(g->module, fn_name, fn_type); assert(LLVMGetIntrinsicID(fn_val)); @@ -2460,22 +2475,17 @@ static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *type_entry, return result; } -static LLVMValueRef gen_floor(CodeGen *g, LLVMValueRef val, ZigType *type_entry) { - if (type_entry->id == ZigTypeIdInt) +static LLVMValueRef gen_float_op(CodeGen *g, LLVMValueRef val, ZigType *type_entry, BuiltinFnId op) { + if ((op == BuiltinFnIdCeil || + op == BuiltinFnIdFloor) && + type_entry->id == ZigTypeIdInt) return val; + assert(type_entry->id == ZigTypeIdFloat); - LLVMValueRef floor_fn = get_float_fn(g, type_entry, ZigLLVMFnIdFloor); + LLVMValueRef floor_fn = get_float_fn(g, type_entry, ZigLLVMFnIdFloatOp, op); return LLVMBuildCall(g->builder, floor_fn, &val, 1, ""); } -static LLVMValueRef gen_ceil(CodeGen *g, LLVMValueRef val, ZigType *type_entry) { - if (type_entry->id == ZigTypeIdInt) - return val; - - LLVMValueRef ceil_fn = get_float_fn(g, type_entry, ZigLLVMFnIdCeil); - return LLVMBuildCall(g->builder, ceil_fn, &val, 1, ""); -} - enum DivKind { DivKindFloat, DivKindTrunc, @@ -2551,7 +2561,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast return result; case DivKindExact: if (want_runtime_safety) { - LLVMValueRef floored = gen_floor(g, result, type_entry); + LLVMValueRef floored = gen_float_op(g, result, type_entry, BuiltinFnIdFloor); LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactOk"); LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail"); LLVMValueRef ok_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, floored, result, ""); @@ -2573,12 +2583,12 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast LLVMBuildCondBr(g->builder, ltz, ltz_block, gez_block); LLVMPositionBuilderAtEnd(g->builder, ltz_block); - LLVMValueRef ceiled = gen_ceil(g, result, type_entry); + LLVMValueRef ceiled = gen_float_op(g, result, type_entry, BuiltinFnIdCeil); LLVMBasicBlockRef ceiled_end_block = LLVMGetInsertBlock(g->builder); LLVMBuildBr(g->builder, end_block); LLVMPositionBuilderAtEnd(g->builder, gez_block); - LLVMValueRef floored = gen_floor(g, result, type_entry); + LLVMValueRef floored = gen_float_op(g, result, type_entry, BuiltinFnIdFloor); LLVMBasicBlockRef floored_end_block = LLVMGetInsertBlock(g->builder); LLVMBuildBr(g->builder, end_block); @@ -2590,7 +2600,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast return phi; } case DivKindFloor: - return gen_floor(g, result, type_entry); + return gen_float_op(g, result, type_entry, BuiltinFnIdFloor); } zig_unreachable(); } @@ -5430,13 +5440,28 @@ static LLVMValueRef ir_render_mark_err_ret_trace_ptr(CodeGen *g, IrExecutable *e return nullptr; } -static LLVMValueRef ir_render_sqrt(CodeGen *g, IrExecutable *executable, IrInstructionSqrt *instruction) { - LLVMValueRef op = ir_llvm_value(g, instruction->op); +static LLVMValueRef ir_render_float_op(CodeGen *g, IrExecutable *executable, IrInstructionFloatOp *instruction) { + LLVMValueRef op = ir_llvm_value(g, instruction->op1); assert(instruction->base.value.type->id == ZigTypeIdFloat); - LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdSqrt); + LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdFloatOp, instruction->op); return LLVMBuildCall(g->builder, fn_val, &op, 1, ""); } +static LLVMValueRef ir_render_mul_add(CodeGen *g, IrExecutable *executable, IrInstructionMulAdd *instruction) { + LLVMValueRef op1 = ir_llvm_value(g, instruction->op1); + LLVMValueRef op2 = ir_llvm_value(g, instruction->op2); + LLVMValueRef op3 = ir_llvm_value(g, instruction->op3); + assert(instruction->base.value.type->id == ZigTypeIdFloat || + instruction->base.value.type->id == ZigTypeIdVector); + LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdFMA, BuiltinFnIdMulAdd); + LLVMValueRef args[3] = { + op1, + op2, + op3, + }; + return LLVMBuildCall(g->builder, fn_val, args, 3, ""); +} + static LLVMValueRef ir_render_bswap(CodeGen *g, IrExecutable *executable, IrInstructionBswap *instruction) { LLVMValueRef op = ir_llvm_value(g, instruction->op); ZigType *int_type = instruction->base.value.type; @@ -5779,8 +5804,10 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, return ir_render_merge_err_ret_traces(g, executable, (IrInstructionMergeErrRetTraces *)instruction); case IrInstructionIdMarkErrRetTracePtr: return ir_render_mark_err_ret_trace_ptr(g, executable, (IrInstructionMarkErrRetTracePtr *)instruction); - case IrInstructionIdSqrt: - return ir_render_sqrt(g, executable, (IrInstructionSqrt *)instruction); + case IrInstructionIdFloatOp: + return ir_render_float_op(g, executable, (IrInstructionFloatOp *)instruction); + case IrInstructionIdMulAdd: + return ir_render_mul_add(g, executable, (IrInstructionMulAdd *)instruction); case IrInstructionIdArrayToVector: return ir_render_array_to_vector(g, executable, (IrInstructionArrayToVector *)instruction); case IrInstructionIdVectorToArray: @@ -7398,6 +7425,21 @@ static void define_builtin_fns(CodeGen *g) { create_builtin_fn(g, BuiltinFnIdRem, "rem", 2); create_builtin_fn(g, BuiltinFnIdMod, "mod", 2); create_builtin_fn(g, BuiltinFnIdSqrt, "sqrt", 2); + create_builtin_fn(g, BuiltinFnIdSin, "sin", 2); + create_builtin_fn(g, BuiltinFnIdCos, "cos", 2); + create_builtin_fn(g, BuiltinFnIdExp, "exp", 2); + create_builtin_fn(g, BuiltinFnIdExp2, "exp2", 2); + create_builtin_fn(g, BuiltinFnIdLn, "ln", 2); + create_builtin_fn(g, BuiltinFnIdLog2, "log2", 2); + create_builtin_fn(g, BuiltinFnIdLog10, "log10", 2); + create_builtin_fn(g, BuiltinFnIdFabs, "fabs", 2); + create_builtin_fn(g, BuiltinFnIdFloor, "floor", 2); + create_builtin_fn(g, BuiltinFnIdCeil, "ceil", 2); + create_builtin_fn(g, BuiltinFnIdTrunc, "trunc", 2); + //Needs library support on Windows + //create_builtin_fn(g, BuiltinFnIdNearbyInt, "nearbyInt", 2); + create_builtin_fn(g, BuiltinFnIdRound, "round", 2); + create_builtin_fn(g, BuiltinFnIdMulAdd, "mulAdd", 4); create_builtin_fn(g, BuiltinFnIdInlineCall, "inlineCall", SIZE_MAX); create_builtin_fn(g, BuiltinFnIdNoInlineCall, "noInlineCall", SIZE_MAX); create_builtin_fn(g, BuiltinFnIdNewStackCall, "newStackCall", SIZE_MAX); |
