aboutsummaryrefslogtreecommitdiff
path: root/src/codegen.cpp
diff options
context:
space:
mode:
authorAndrew Kelley <andrew@ziglang.org>2019-06-23 17:10:33 -0400
committerAndrew Kelley <andrew@ziglang.org>2019-06-23 17:10:33 -0400
commitb2cbc59e4c354264f295a98d3df077773acd8400 (patch)
tree1e21b4ee73bc9eadafb52b912cda174109f7b69d /src/codegen.cpp
parentca3660f6bf3a1f8d77692acf72eefe148802d342 (diff)
parent71e014caecaa54fdd8a0516710d2d9597da41398 (diff)
downloadzig-b2cbc59e4c354264f295a98d3df077773acd8400.tar.gz
zig-b2cbc59e4c354264f295a98d3df077773acd8400.zip
Merge branch 'simd2' of https://github.com/shawnl/zig into shawnl-simd2
Diffstat (limited to 'src/codegen.cpp')
-rw-r--r--src/codegen.cpp104
1 files changed, 73 insertions, 31 deletions
diff --git a/src/codegen.cpp b/src/codegen.cpp
index 3dd6995c61..41caa29dbd 100644
--- a/src/codegen.cpp
+++ b/src/codegen.cpp
@@ -806,32 +806,47 @@ static LLVMValueRef get_int_overflow_fn(CodeGen *g, ZigType *operand_type, AddSu
return fn_val;
}
-static LLVMValueRef get_float_fn(CodeGen *g, ZigType *type_entry, ZigLLVMFnId fn_id) {
- assert(type_entry->id == ZigTypeIdFloat);
+static LLVMValueRef get_float_fn(CodeGen *g, ZigType *type_entry, ZigLLVMFnId fn_id, BuiltinFnId op) {
+ assert(type_entry->id == ZigTypeIdFloat ||
+ type_entry->id == ZigTypeIdVector);
+
+ bool is_vector = (type_entry->id == ZigTypeIdVector);
+ ZigType *float_type = is_vector ? type_entry->data.vector.elem_type : type_entry;
ZigLLVMFnKey key = {};
key.id = fn_id;
- key.data.floating.bit_count = (uint32_t)type_entry->data.floating.bit_count;
+ key.data.floating.bit_count = (uint32_t)float_type->data.floating.bit_count;
+ key.data.floating.vector_len = is_vector ? (uint32_t)type_entry->data.vector.len : 0;
+ key.data.floating.op = op;
auto existing_entry = g->llvm_fn_table.maybe_get(key);
if (existing_entry)
return existing_entry->value;
const char *name;
- if (fn_id == ZigLLVMFnIdFloor) {
- name = "floor";
- } else if (fn_id == ZigLLVMFnIdCeil) {
- name = "ceil";
- } else if (fn_id == ZigLLVMFnIdSqrt) {
- name = "sqrt";
+ uint32_t num_args;
+ if (fn_id == ZigLLVMFnIdFMA) {
+ name = "fma";
+ num_args = 3;
+ } else if (fn_id == ZigLLVMFnIdFloatOp) {
+ name = float_op_to_name(op, true);
+ num_args = 1;
} else {
zig_unreachable();
}
char fn_name[64];
- sprintf(fn_name, "llvm.%s.f%" ZIG_PRI_usize "", name, type_entry->data.floating.bit_count);
+ if (is_vector)
+ sprintf(fn_name, "llvm.%s.v%" PRIu32 "f%" PRIu32, name, key.data.floating.vector_len, key.data.floating.bit_count);
+ else
+ sprintf(fn_name, "llvm.%s.f%" PRIu32, name, key.data.floating.bit_count);
LLVMTypeRef float_type_ref = get_llvm_type(g, type_entry);
- LLVMTypeRef fn_type = LLVMFunctionType(float_type_ref, &float_type_ref, 1, false);
+ LLVMTypeRef return_elem_types[3] = {
+ float_type_ref,
+ float_type_ref,
+ float_type_ref,
+ };
+ LLVMTypeRef fn_type = LLVMFunctionType(float_type_ref, return_elem_types, num_args, false);
LLVMValueRef fn_val = LLVMAddFunction(g->module, fn_name, fn_type);
assert(LLVMGetIntrinsicID(fn_val));
@@ -2460,22 +2475,17 @@ static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *type_entry,
return result;
}
-static LLVMValueRef gen_floor(CodeGen *g, LLVMValueRef val, ZigType *type_entry) {
- if (type_entry->id == ZigTypeIdInt)
+static LLVMValueRef gen_float_op(CodeGen *g, LLVMValueRef val, ZigType *type_entry, BuiltinFnId op) {
+ if ((op == BuiltinFnIdCeil ||
+ op == BuiltinFnIdFloor) &&
+ type_entry->id == ZigTypeIdInt)
return val;
+ assert(type_entry->id == ZigTypeIdFloat);
- LLVMValueRef floor_fn = get_float_fn(g, type_entry, ZigLLVMFnIdFloor);
+ LLVMValueRef floor_fn = get_float_fn(g, type_entry, ZigLLVMFnIdFloatOp, op);
return LLVMBuildCall(g->builder, floor_fn, &val, 1, "");
}
-static LLVMValueRef gen_ceil(CodeGen *g, LLVMValueRef val, ZigType *type_entry) {
- if (type_entry->id == ZigTypeIdInt)
- return val;
-
- LLVMValueRef ceil_fn = get_float_fn(g, type_entry, ZigLLVMFnIdCeil);
- return LLVMBuildCall(g->builder, ceil_fn, &val, 1, "");
-}
-
enum DivKind {
DivKindFloat,
DivKindTrunc,
@@ -2551,7 +2561,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
return result;
case DivKindExact:
if (want_runtime_safety) {
- LLVMValueRef floored = gen_floor(g, result, type_entry);
+ LLVMValueRef floored = gen_float_op(g, result, type_entry, BuiltinFnIdFloor);
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
LLVMValueRef ok_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, floored, result, "");
@@ -2573,12 +2583,12 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMBuildCondBr(g->builder, ltz, ltz_block, gez_block);
LLVMPositionBuilderAtEnd(g->builder, ltz_block);
- LLVMValueRef ceiled = gen_ceil(g, result, type_entry);
+ LLVMValueRef ceiled = gen_float_op(g, result, type_entry, BuiltinFnIdCeil);
LLVMBasicBlockRef ceiled_end_block = LLVMGetInsertBlock(g->builder);
LLVMBuildBr(g->builder, end_block);
LLVMPositionBuilderAtEnd(g->builder, gez_block);
- LLVMValueRef floored = gen_floor(g, result, type_entry);
+ LLVMValueRef floored = gen_float_op(g, result, type_entry, BuiltinFnIdFloor);
LLVMBasicBlockRef floored_end_block = LLVMGetInsertBlock(g->builder);
LLVMBuildBr(g->builder, end_block);
@@ -2590,7 +2600,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
return phi;
}
case DivKindFloor:
- return gen_floor(g, result, type_entry);
+ return gen_float_op(g, result, type_entry, BuiltinFnIdFloor);
}
zig_unreachable();
}
@@ -5430,13 +5440,28 @@ static LLVMValueRef ir_render_mark_err_ret_trace_ptr(CodeGen *g, IrExecutable *e
return nullptr;
}
-static LLVMValueRef ir_render_sqrt(CodeGen *g, IrExecutable *executable, IrInstructionSqrt *instruction) {
- LLVMValueRef op = ir_llvm_value(g, instruction->op);
+static LLVMValueRef ir_render_float_op(CodeGen *g, IrExecutable *executable, IrInstructionFloatOp *instruction) {
+ LLVMValueRef op = ir_llvm_value(g, instruction->op1);
assert(instruction->base.value.type->id == ZigTypeIdFloat);
- LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdSqrt);
+ LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdFloatOp, instruction->op);
return LLVMBuildCall(g->builder, fn_val, &op, 1, "");
}
+static LLVMValueRef ir_render_mul_add(CodeGen *g, IrExecutable *executable, IrInstructionMulAdd *instruction) {
+ LLVMValueRef op1 = ir_llvm_value(g, instruction->op1);
+ LLVMValueRef op2 = ir_llvm_value(g, instruction->op2);
+ LLVMValueRef op3 = ir_llvm_value(g, instruction->op3);
+ assert(instruction->base.value.type->id == ZigTypeIdFloat ||
+ instruction->base.value.type->id == ZigTypeIdVector);
+ LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdFMA, BuiltinFnIdMulAdd);
+ LLVMValueRef args[3] = {
+ op1,
+ op2,
+ op3,
+ };
+ return LLVMBuildCall(g->builder, fn_val, args, 3, "");
+}
+
static LLVMValueRef ir_render_bswap(CodeGen *g, IrExecutable *executable, IrInstructionBswap *instruction) {
LLVMValueRef op = ir_llvm_value(g, instruction->op);
ZigType *int_type = instruction->base.value.type;
@@ -5779,8 +5804,10 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
return ir_render_merge_err_ret_traces(g, executable, (IrInstructionMergeErrRetTraces *)instruction);
case IrInstructionIdMarkErrRetTracePtr:
return ir_render_mark_err_ret_trace_ptr(g, executable, (IrInstructionMarkErrRetTracePtr *)instruction);
- case IrInstructionIdSqrt:
- return ir_render_sqrt(g, executable, (IrInstructionSqrt *)instruction);
+ case IrInstructionIdFloatOp:
+ return ir_render_float_op(g, executable, (IrInstructionFloatOp *)instruction);
+ case IrInstructionIdMulAdd:
+ return ir_render_mul_add(g, executable, (IrInstructionMulAdd *)instruction);
case IrInstructionIdArrayToVector:
return ir_render_array_to_vector(g, executable, (IrInstructionArrayToVector *)instruction);
case IrInstructionIdVectorToArray:
@@ -7398,6 +7425,21 @@ static void define_builtin_fns(CodeGen *g) {
create_builtin_fn(g, BuiltinFnIdRem, "rem", 2);
create_builtin_fn(g, BuiltinFnIdMod, "mod", 2);
create_builtin_fn(g, BuiltinFnIdSqrt, "sqrt", 2);
+ create_builtin_fn(g, BuiltinFnIdSin, "sin", 2);
+ create_builtin_fn(g, BuiltinFnIdCos, "cos", 2);
+ create_builtin_fn(g, BuiltinFnIdExp, "exp", 2);
+ create_builtin_fn(g, BuiltinFnIdExp2, "exp2", 2);
+ create_builtin_fn(g, BuiltinFnIdLn, "ln", 2);
+ create_builtin_fn(g, BuiltinFnIdLog2, "log2", 2);
+ create_builtin_fn(g, BuiltinFnIdLog10, "log10", 2);
+ create_builtin_fn(g, BuiltinFnIdFabs, "fabs", 2);
+ create_builtin_fn(g, BuiltinFnIdFloor, "floor", 2);
+ create_builtin_fn(g, BuiltinFnIdCeil, "ceil", 2);
+ create_builtin_fn(g, BuiltinFnIdTrunc, "trunc", 2);
+ //Needs library support on Windows
+ //create_builtin_fn(g, BuiltinFnIdNearbyInt, "nearbyInt", 2);
+ create_builtin_fn(g, BuiltinFnIdRound, "round", 2);
+ create_builtin_fn(g, BuiltinFnIdMulAdd, "mulAdd", 4);
create_builtin_fn(g, BuiltinFnIdInlineCall, "inlineCall", SIZE_MAX);
create_builtin_fn(g, BuiltinFnIdNoInlineCall, "noInlineCall", SIZE_MAX);
create_builtin_fn(g, BuiltinFnIdNewStackCall, "newStackCall", SIZE_MAX);