diff options
| author | Andrew Kelley <andrew@ziglang.org> | 2023-11-03 19:50:49 -0700 |
|---|---|---|
| committer | Andrew Kelley <andrew@ziglang.org> | 2023-11-03 20:05:32 -0700 |
| commit | 13964796561a5e6ca16c76020bc3371858decb31 (patch) | |
| tree | 181663792c456e56d0a5975eee5ac1eff0a7355a /src/reduce | |
| parent | dff13e088b4da3fab50d72a8967b0d7cffeb9d7d (diff) | |
| download | zig-13964796561a5e6ca16c76020bc3371858decb31.tar.gz zig-13964796561a5e6ca16c76020bc3371858decb31.zip | |
zig reduce: redesign
Now it works like this:
1. Walk the AST of the source file looking for independent
reductions and collecting them all into an array list.
2. Randomize the list of transformations. A future enhancement will add
priority weights to the sorting but for now they are completely
shuffled.
3. Apply a subset consisting of 1/2 of the transformations and check for
interestingness.
4. If not interesting, halve the subset size again and check again.
5. Repeat until the subset size is 1, then march the transformation
index forward by 1 with each non-interesting attempt.
At any point if a subset of transformations succeeds in producing an interesting
result, restart the whole process, reparsing the AST and re-generating the list
of all possible transformations and shuffling it again.
As for std.zig.render, the fixups operate based on AST Node Index rather
than Nth index of the function occurrence. This allows precise control
over how to mutate the input.
Diffstat (limited to 'src/reduce')
| -rw-r--r-- | src/reduce/Walk.zig | 792 |
1 files changed, 792 insertions, 0 deletions
// src/reduce/Walk.zig
//
// Walks the AST of a Zig source file and records every independent reduction
// ("transformation") found along the way. The traversal visits declarations,
// statements and expressions; the only transformation currently produced is
// `gut_function`.

const std = @import("std");
const Ast = std.zig.Ast;
const Walk = @This();
const assert = std.debug.assert;

/// The tree being walked. Never mutated by the walk.
ast: *const Ast,
/// Output list; one entry is appended per reduction opportunity found.
transformations: *std.ArrayList(Transformation),

pub const Transformation = union(enum) {
    /// Replace the fn decl AST Node with one whose body is only `@trap()` with
    /// discarded parameters.
    gut_function: Ast.Node.Index,
};

pub const Error = error{OutOfMemory};

/// Clears `transformations` (retaining its capacity) and refills it with every
/// transformation found in `ast`, starting from the root declarations.
///
/// Entries are appended in AST walk order. This function does not shuffle or
/// prioritize them; any reordering is left to the caller.
///
/// Fails only with `error.OutOfMemory` when the list cannot grow.
pub fn findTransformations(ast: *const Ast, transformations: *std.ArrayList(Transformation)) Error!void {
    transformations.clearRetainingCapacity();

    var walk: Walk = .{
        .ast = ast,
        .transformations = transformations,
    };
    try walkMembers(&walk, walk.ast.rootDecls());
}

/// Walks each container member in order.
fn walkMembers(w: *Walk, members: []const Ast.Node.Index) Error!void {
    for (members) |member| {
        try walkMember(w, member);
    }
}

/// Walks one container-level declaration. Function declarations whose body is
/// not already gutted are recorded as `gut_function` candidates.
fn walkMember(w: *Walk, decl: Ast.Node.Index) Error!void {
    const ast = w.ast;
    const datas = ast.nodes.items(.data);
    switch (ast.nodes.items(.tag)[decl]) {
        .fn_decl => {
            const fn_proto = datas[decl].lhs;
            try walkExpression(w, fn_proto);
            const body_node = datas[decl].rhs;
            if (!isFnBodyGutted(ast, body_node)) {
                try w.transformations.append(.{ .gut_function = decl });
            }
            // Keep descending: the body may declare nested functions.
            try walkExpression(w, body_node);
        },
        .fn_proto_simple,
        .fn_proto_multi,
        .fn_proto_one,
        .fn_proto,
        => {
            try walkExpression(w, decl);
        },

        .@"usingnamespace" => {
            const expr = datas[decl].lhs;
            try walkExpression(w, expr);
        },

        .global_var_decl,
        .local_var_decl,
        .simple_var_decl,
        .aligned_var_decl,
        => try walkVarDecl(w, ast.fullVarDecl(decl).?),

        .test_decl => {
            try walkExpression(w, datas[decl].rhs);
        },

        .container_field_init,
        .container_field_align,
        .container_field,
        => try walkContainerField(w, ast.fullContainerField(decl).?),

        .@"comptime" => try walkExpression(w, decl),

        .root => unreachable,
        else => unreachable,
    }
}

/// Walks an expression or statement node and recurses into its children.
/// Declaration-only tags (fn decls, var decls, container fields, ...) are
/// handled by `walkMember` / `walkVarDecl` and must not reach this function.
fn walkExpression(w: *Walk, node: Ast.Node.Index) Error!void {
    const ast = w.ast;
    const token_tags = ast.tokens.items(.tag);
    const main_tokens = ast.nodes.items(.main_token);
    const node_tags = ast.nodes.items(.tag);
    const datas = ast.nodes.items(.data);
    switch (node_tags[node]) {
        .identifier => {},

        .number_literal,
        .char_literal,
        .unreachable_literal,
        .anyframe_literal,
        .string_literal,
        => {},

        .multiline_string_literal => {},

        .error_value => {},

        .block_two,
        .block_two_semicolon,
        => {
            // Up to two statements are stored inline in lhs/rhs; 0 means absent.
            const statements = [2]Ast.Node.Index{ datas[node].lhs, datas[node].rhs };
            if (datas[node].lhs == 0) {
                return walkBlock(w, node, statements[0..0]);
            } else if (datas[node].rhs == 0) {
                return walkBlock(w, node, statements[0..1]);
            } else {
                return walkBlock(w, node, statements[0..2]);
            }
        },
        .block,
        .block_semicolon,
        => {
            const statements = ast.extra_data[datas[node].lhs..datas[node].rhs];
            return walkBlock(w, node, statements);
        },

        .@"errdefer" => {
            const expr = datas[node].rhs;
            return walkExpression(w, expr);
        },

        .@"defer" => {
            const expr = datas[node].rhs;
            return walkExpression(w, expr);
        },
        .@"comptime", .@"nosuspend" => {
            const block = datas[node].lhs;
            return walkExpression(w, block);
        },

        .@"suspend" => {
            const body = datas[node].lhs;
            return walkExpression(w, body);
        },

        .@"catch" => {
            try walkExpression(w, datas[node].lhs); // target
            try walkExpression(w, datas[node].rhs); // fallback
        },

        .field_access => {
            // Only the lhs is walked here; the rhs is not treated as a node.
            const field_access = datas[node];
            try walkExpression(w, field_access.lhs);
        },

        .error_union,
        .switch_range,
        => {
            const infix = datas[node];
            try walkExpression(w, infix.lhs);
            return walkExpression(w, infix.rhs);
        },
        .for_range => {
            const infix = datas[node];
            try walkExpression(w, infix.lhs);
            // rhs is 0 for an open-ended range.
            if (infix.rhs != 0) {
                return walkExpression(w, infix.rhs);
            }
        },

        .add,
        .add_wrap,
        .add_sat,
        .array_cat,
        .array_mult,
        .assign,
        .assign_bit_and,
        .assign_bit_or,
        .assign_shl,
        .assign_shl_sat,
        .assign_shr,
        .assign_bit_xor,
        .assign_div,
        .assign_sub,
        .assign_sub_wrap,
        .assign_sub_sat,
        .assign_mod,
        .assign_add,
        .assign_add_wrap,
        .assign_add_sat,
        .assign_mul,
        .assign_mul_wrap,
        .assign_mul_sat,
        .bang_equal,
        .bit_and,
        .bit_or,
        .shl,
        .shl_sat,
        .shr,
        .bit_xor,
        .bool_and,
        .bool_or,
        .div,
        .equal_equal,
        .greater_or_equal,
        .greater_than,
        .less_or_equal,
        .less_than,
        .merge_error_sets,
        .mod,
        .mul,
        .mul_wrap,
        .mul_sat,
        .sub,
        .sub_wrap,
        .sub_sat,
        .@"orelse",
        => {
            const infix = datas[node];
            try walkExpression(w, infix.lhs);
            try walkExpression(w, infix.rhs);
        },

        .assign_destructure => {
            // extra_data layout: [count, lhs_0, lhs_1, ...].
            const lhs_count = ast.extra_data[datas[node].lhs];
            assert(lhs_count > 1);
            const lhs_exprs = ast.extra_data[datas[node].lhs + 1 ..][0..lhs_count];
            const rhs = datas[node].rhs;

            for (lhs_exprs) |lhs_node| {
                switch (node_tags[lhs_node]) {
                    .global_var_decl,
                    .local_var_decl,
                    .simple_var_decl,
                    .aligned_var_decl,
                    => try walkVarDecl(w, ast.fullVarDecl(lhs_node).?),
                    else => try walkExpression(w, lhs_node),
                }
            }
            return walkExpression(w, rhs);
        },

        .bit_not,
        .bool_not,
        .negation,
        .negation_wrap,
        .optional_type,
        .address_of,
        => {
            return walkExpression(w, datas[node].lhs);
        },

        .@"try",
        .@"resume",
        .@"await",
        => {
            return walkExpression(w, datas[node].lhs);
        },

        // Array and pointer type expressions are currently not descended into.
        .array_type,
        .array_type_sentinel,
        => {},

        .ptr_type_aligned,
        .ptr_type_sentinel,
        .ptr_type,
        .ptr_type_bit_range,
        => {},

        .array_init_one,
        .array_init_one_comma,
        .array_init_dot_two,
        .array_init_dot_two_comma,
        .array_init_dot,
        .array_init_dot_comma,
        .array_init,
        .array_init_comma,
        => {
            var elements: [2]Ast.Node.Index = undefined;
            return walkArrayInit(w, ast.fullArrayInit(&elements, node).?);
        },

        .struct_init_one,
        .struct_init_one_comma,
        .struct_init_dot_two,
        .struct_init_dot_two_comma,
        .struct_init_dot,
        .struct_init_dot_comma,
        .struct_init,
        .struct_init_comma,
        => {
            var buf: [2]Ast.Node.Index = undefined;
            return walkStructInit(w, node, ast.fullStructInit(&buf, node).?);
        },

        .call_one,
        .call_one_comma,
        .async_call_one,
        .async_call_one_comma,
        .call,
        .call_comma,
        .async_call,
        .async_call_comma,
        => {
            var buf: [1]Ast.Node.Index = undefined;
            return walkCall(w, ast.fullCall(&buf, node).?);
        },

        .array_access => {
            const suffix = datas[node];
            try walkExpression(w, suffix.lhs);
            try walkExpression(w, suffix.rhs);
        },

        .slice_open, .slice, .slice_sentinel => return walkSlice(w, node, ast.fullSlice(node).?),

        .deref => {
            try walkExpression(w, datas[node].lhs);
        },

        .unwrap_optional => {
            try walkExpression(w, datas[node].lhs);
        },

        .@"break" => {
            // lhs is the optional label token, rhs the optional operand; 0
            // means absent. Visiting the label is a no-op today.
            const label_token = datas[node].lhs;
            const target = datas[node].rhs;
            if (label_token != 0) try walkIdentifier(w, label_token);
            if (target != 0) try walkExpression(w, target);
        },

        .@"continue" => {
            const label = datas[node].lhs;
            if (label != 0) {
                return walkIdentifier(w, label); // label
            }
        },

        .@"return" => {
            if (datas[node].lhs != 0) {
                try walkExpression(w, datas[node].lhs);
            }
        },

        .grouped_expression => {
            try walkExpression(w, datas[node].lhs);
        },

        .container_decl,
        .container_decl_trailing,
        .container_decl_arg,
        .container_decl_arg_trailing,
        .container_decl_two,
        .container_decl_two_trailing,
        .tagged_union,
        .tagged_union_trailing,
        .tagged_union_enum_tag,
        .tagged_union_enum_tag_trailing,
        .tagged_union_two,
        .tagged_union_two_trailing,
        => {
            var buf: [2]Ast.Node.Index = undefined;
            return walkContainerDecl(w, node, ast.fullContainerDecl(&buf, node).?);
        },

        .error_set_decl => {
            const error_token = main_tokens[node];
            const lbrace = error_token + 1;
            const rbrace = datas[node].rhs;

            var i = lbrace + 1;
            while (i < rbrace) : (i += 1) {
                switch (token_tags[i]) {
                    // Doc comments on error names carry nothing to walk; skip
                    // them instead of treating them as unreachable.
                    .doc_comment, .comma => {},
                    .identifier => try walkIdentifier(w, i),
                    else => unreachable,
                }
            }
        },

        .builtin_call_two, .builtin_call_two_comma => {
            if (datas[node].lhs == 0) {
                return walkBuiltinCall(w, main_tokens[node], &.{});
            } else if (datas[node].rhs == 0) {
                return walkBuiltinCall(w, main_tokens[node], &.{datas[node].lhs});
            } else {
                return walkBuiltinCall(w, main_tokens[node], &.{ datas[node].lhs, datas[node].rhs });
            }
        },
        .builtin_call, .builtin_call_comma => {
            const params = ast.extra_data[datas[node].lhs..datas[node].rhs];
            return walkBuiltinCall(w, main_tokens[node], params);
        },

        .fn_proto_simple,
        .fn_proto_multi,
        .fn_proto_one,
        .fn_proto,
        => {
            var buf: [1]Ast.Node.Index = undefined;
            return walkFnProto(w, ast.fullFnProto(&buf, node).?);
        },

        .anyframe_type => {
            if (datas[node].rhs != 0) {
                return walkExpression(w, datas[node].rhs);
            }
        },

        .@"switch",
        .switch_comma,
        => {
            const condition = datas[node].lhs;
            const extra = ast.extraData(datas[node].rhs, Ast.Node.SubRange);
            const cases = ast.extra_data[extra.start..extra.end];

            try walkExpression(w, condition); // condition expression
            try walkExpressions(w, cases);
        },

        .switch_case_one,
        .switch_case_inline_one,
        .switch_case,
        .switch_case_inline,
        => return walkSwitchCase(w, ast.fullSwitchCase(node).?),

        .while_simple,
        .while_cont,
        .@"while",
        => return walkWhile(w, ast.fullWhile(node).?),

        .for_simple,
        .@"for",
        => return walkFor(w, ast.fullFor(node).?),

        .if_simple,
        .@"if",
        => return walkIf(w, ast.fullIf(node).?),

        .asm_simple,
        .@"asm",
        => return walkAsm(w, ast.fullAsm(node).?),

        .enum_literal => {
            return walkIdentifier(w, main_tokens[node]); // name
        },

        .fn_decl => unreachable,
        .container_field => unreachable,
        .container_field_init => unreachable,
        .container_field_align => unreachable,
        .root => unreachable,
        .global_var_decl => unreachable,
        .local_var_decl => unreachable,
        .simple_var_decl => unreachable,
        .aligned_var_decl => unreachable,
        .@"usingnamespace" => unreachable,
        .test_decl => unreachable,
        // Asm operands are unpacked by `walkAsm` and never passed here.
        .asm_output => unreachable,
        .asm_input => unreachable,
    }
}

/// Walks the name, the optional type/align/addrspace/section expressions and
/// the optional initializer of a variable declaration.
fn walkVarDecl(w: *Walk, var_decl: Ast.full.VarDecl) Error!void {
    try walkIdentifier(w, var_decl.ast.mut_token + 1); // name

    if (var_decl.ast.type_node != 0) {
        try walkExpression(w, var_decl.ast.type_node);
    }

    if (var_decl.ast.align_node != 0) {
        try walkExpression(w, var_decl.ast.align_node);
    }

    if (var_decl.ast.addrspace_node != 0) {
        try walkExpression(w, var_decl.ast.addrspace_node);
    }

    if (var_decl.ast.section_node != 0) {
        try walkExpression(w, var_decl.ast.section_node);
    }

    // Declarations such as `extern var x: T;` have no initializer.
    if (var_decl.ast.init_node != 0) {
        try walkExpression(w, var_decl.ast.init_node);
    }
}

/// Walks the type, alignment and default-value expressions of a container
/// field. Each of them is optional (node index 0 means absent).
fn walkContainerField(w: *Walk, field: Ast.full.ContainerField) Error!void {
    if (field.ast.type_expr != 0) {
        try walkExpression(w, field.ast.type_expr); // type
    }
    if (field.ast.align_expr != 0) {
        try walkExpression(w, field.ast.align_expr); // alignment
    }
    if (field.ast.value_expr != 0) {
        try walkExpression(w, field.ast.value_expr); // value
    }
}

/// Walks every statement of a block. Variable declarations are routed to
/// `walkVarDecl` because `walkExpression` does not accept them.
fn walkBlock(
    w: *Walk,
    block_node: Ast.Node.Index,
    statements: []const Ast.Node.Index,
) Error!void {
    _ = block_node;
    const ast = w.ast;
    const node_tags = ast.nodes.items(.tag);

    for (statements) |stmt| {
        switch (node_tags[stmt]) {
            .global_var_decl,
            .local_var_decl,
            .simple_var_decl,
            .aligned_var_decl,
            => try walkVarDecl(w, ast.fullVarDecl(stmt).?),

            else => try walkExpression(w, stmt),
        }
    }
}

/// Walks the element count, optional sentinel and element type of an array
/// type. Not currently called by `walkExpression`.
fn walkArrayType(w: *Walk, array_type: Ast.full.ArrayType) Error!void {
    try walkExpression(w, array_type.ast.elem_count);
    if (array_type.ast.sentinel != 0) {
        try walkExpression(w, array_type.ast.sentinel);
    }
    return walkExpression(w, array_type.ast.elem_type);
}

/// Walks the optional type expression and every element of an array
/// initializer.
fn walkArrayInit(w: *Walk, array_init: Ast.full.ArrayInit) Error!void {
    if (array_init.ast.type_expr != 0) {
        try walkExpression(w, array_init.ast.type_expr); // T
    }
    for (array_init.ast.elements) |elem_init| {
        try walkExpression(w, elem_init);
    }
}

/// Walks the optional type expression and every field initializer of a
/// struct initializer.
fn walkStructInit(
    w: *Walk,
    struct_node: Ast.Node.Index,
    struct_init: Ast.full.StructInit,
) Error!void {
    _ = struct_node;
    if (struct_init.ast.type_expr != 0) {
        try walkExpression(w, struct_init.ast.type_expr); // T
    }
    for (struct_init.ast.fields) |field_init| {
        try walkExpression(w, field_init);
    }
}

/// Walks the callee expression and then each argument of a call.
fn walkCall(w: *Walk, call: Ast.full.Call) Error!void {
    try walkExpression(w, call.ast.fn_expr);
    try walkParamList(w, call.ast.params);
}

/// Walks the sliced operand, the start expression and the optional end and
/// sentinel expressions of a slice.
fn walkSlice(
    w: *Walk,
    slice_node: Ast.Node.Index,
    slice: Ast.full.Slice,
) Error!void {
    _ = slice_node;
    try walkExpression(w, slice.ast.sliced);
    try walkExpression(w, slice.ast.start);
    if (slice.ast.end != 0) {
        try walkExpression(w, slice.ast.end);
    }
    if (slice.ast.sentinel != 0) {
        try walkExpression(w, slice.ast.sentinel);
    }
}

/// Visits an identifier token. Currently a no-op placeholder.
fn walkIdentifier(w: *Walk, token_index: Ast.TokenIndex) Error!void {
    _ = w;
    _ = token_index;
}

/// Walks the optional container argument and then every member of a container
/// declaration.
fn walkContainerDecl(
    w: *Walk,
    container_decl_node: Ast.Node.Index,
    container_decl: Ast.full.ContainerDecl,
) Error!void {
    _ = container_decl_node;
    if (container_decl.ast.arg != 0) {
        try walkExpression(w, container_decl.ast.arg);
    }
    for (container_decl.ast.members) |member| {
        try walkMember(w, member);
    }
}

/// Walks each argument of a builtin call; the builtin name itself is ignored.
fn walkBuiltinCall(
    w: *Walk,
    builtin_token: Ast.TokenIndex,
    params: []const Ast.Node.Index,
) Error!void {
    _ = builtin_token;
    for (params) |param_node| {
        try walkExpression(w, param_node);
    }
}

/// Walks parameter types, the optional align/addrspace/section/callconv
/// expressions and the return type of a function prototype.
fn walkFnProto(w: *Walk, fn_proto: Ast.full.FnProto) Error!void {
    const ast = w.ast;

    {
        var it = fn_proto.iterate(ast);
        while (it.next()) |param| {
            if (param.type_expr != 0) {
                try walkExpression(w, param.type_expr);
            }
        }
    }

    if (fn_proto.ast.align_expr != 0) {
        try walkExpression(w, fn_proto.ast.align_expr);
    }

    if (fn_proto.ast.addrspace_expr != 0) {
        try walkExpression(w, fn_proto.ast.addrspace_expr);
    }

    if (fn_proto.ast.section_expr != 0) {
        try walkExpression(w, fn_proto.ast.section_expr);
    }

    if (fn_proto.ast.callconv_expr != 0) {
        try walkExpression(w, fn_proto.ast.callconv_expr);
    }

    try walkExpression(w, fn_proto.ast.return_type);
}

/// Walks each expression in order.
fn walkExpressions(w: *Walk, expressions: []const Ast.Node.Index) Error!void {
    for (expressions) |expression| {
        try walkExpression(w, expression);
    }
}

/// Walks the case values and then the target expression of a switch prong.
fn walkSwitchCase(w: *Walk, switch_case: Ast.full.SwitchCase) Error!void {
    for (switch_case.ast.values) |value_expr| {
        try walkExpression(w, value_expr);
    }
    try walkExpression(w, switch_case.ast.target_expr);
}

/// Walks the condition, optional continue expression, body and optional else
/// branch of a `while` loop. Each child is visited exactly once so that a
/// function declared inside the condition is not recorded twice.
fn walkWhile(w: *Walk, while_node: Ast.full.While) Error!void {
    try walkExpression(w, while_node.ast.cond_expr); // condition

    if (while_node.ast.cont_expr != 0) {
        try walkExpression(w, while_node.ast.cont_expr);
    }

    if (while_node.ast.then_expr != 0) {
        try walkExpression(w, while_node.ast.then_expr);
    }
    if (while_node.ast.else_expr != 0) {
        try walkExpression(w, while_node.ast.else_expr);
    }
}

/// Walks the inputs, body and optional else branch of a `for` loop.
fn walkFor(w: *Walk, for_node: Ast.full.For) Error!void {
    try walkParamList(w, for_node.ast.inputs);
    if (for_node.ast.then_expr != 0) {
        try walkExpression(w, for_node.ast.then_expr);
    }
    if (for_node.ast.else_expr != 0) {
        try walkExpression(w, for_node.ast.else_expr);
    }
}

/// Walks the condition, then-branch and optional else-branch of an `if`.
fn walkIf(w: *Walk, if_node: Ast.full.If) Error!void {
    try walkExpression(w, if_node.ast.cond_expr); // condition

    if (if_node.ast.then_expr != 0) {
        try walkExpression(w, if_node.ast.then_expr);
    }
    if (if_node.ast.else_expr != 0) {
        try walkExpression(w, if_node.ast.else_expr);
    }
}

/// Walks the template and the operand expressions of an inline assembly
/// expression.
fn walkAsm(w: *Walk, asm_node: Ast.full.Asm) Error!void {
    const ast = w.ast;
    const node_tags = ast.nodes.items(.tag);
    const datas = ast.nodes.items(.data);

    try walkExpression(w, asm_node.ast.template);
    for (asm_node.ast.items) |item| {
        // Operand nodes are not expressions themselves (`walkExpression`
        // rejects them); only the expression they wrap is walked.
        switch (node_tags[item]) {
            // `[a] "b" (-> T)` stores T in lhs; `[a] "b" (c)` leaves lhs 0.
            .asm_output => if (datas[item].lhs != 0) {
                try walkExpression(w, datas[item].lhs);
            },
            // `[a] "b" (lhs)`.
            .asm_input => try walkExpression(w, datas[item].lhs),
            else => try walkExpression(w, item),
        }
    }
}

/// Walks each node of a parameter/argument list in order.
fn walkParamList(w: *Walk, params: []const Ast.Node.Index) Error!void {
    for (params) |param_node| {
        try walkExpression(w, param_node);
    }
}

/// Check if it is already gutted (i.e. its body replaced with `@trap()`).
///
/// A gutted body is a block consisting of zero or more `_ = identifier;`
/// discards followed by `@trap()` as its final statement. Anything else,
/// including an empty block, is reported as not gutted.
fn isFnBodyGutted(ast: *const Ast, body_node: Ast.Node.Index) bool {
    const node_tags = ast.nodes.items(.tag);
    const datas = ast.nodes.items(.data);
    var statements_buf: [2]Ast.Node.Index = undefined;
    const statements = switch (node_tags[body_node]) {
        .block_two,
        .block_two_semicolon,
        => blk: {
            statements_buf[0..2].* = .{ datas[body_node].lhs, datas[body_node].rhs };
            break :blk if (datas[body_node].lhs == 0)
                statements_buf[0..0]
            else if (datas[body_node].rhs == 0)
                statements_buf[0..1]
            else
                statements_buf[0..2];
        },

        .block,
        .block_semicolon,
        => ast.extra_data[datas[body_node].lhs..datas[body_node].rhs],

        else => return false,
    };
    for (statements, 0..) |stmt, i| {
        switch (categorizeStmt(ast, stmt)) {
            // skip over discards
            .discard_identifier => continue,
            // `@trap()` only counts when it is the last statement.
            .trap_call => return i + 1 == statements.len,
            else => return false,
        }
    }
    return false;
}

/// Classification of a statement for the purpose of `isFnBodyGutted`.
const StmtCategory = enum {
    /// `_ = some_identifier;`
    discard_identifier,
    /// `@trap()` with no arguments.
    trap_call,
    /// Anything else.
    other,
};

/// Classifies `stmt` as a parameter discard, a `@trap()` call, or other.
fn categorizeStmt(ast: *const Ast, stmt: Ast.Node.Index) StmtCategory {
    const node_tags = ast.nodes.items(.tag);
    const datas = ast.nodes.items(.data);
    const main_tokens = ast.nodes.items(.main_token);
    switch (node_tags[stmt]) {
        .builtin_call_two, .builtin_call_two_comma => {
            if (datas[stmt].lhs == 0) {
                return categorizeBuiltinCall(ast, main_tokens[stmt], &.{});
            } else if (datas[stmt].rhs == 0) {
                return categorizeBuiltinCall(ast, main_tokens[stmt], &.{datas[stmt].lhs});
            } else {
                return categorizeBuiltinCall(ast, main_tokens[stmt], &.{ datas[stmt].lhs, datas[stmt].rhs });
            }
        },
        .builtin_call, .builtin_call_comma => {
            const params = ast.extra_data[datas[stmt].lhs..datas[stmt].rhs];
            return categorizeBuiltinCall(ast, main_tokens[stmt], params);
        },
        .assign => {
            const infix = datas[stmt];
            if (isDiscardIdent(ast, infix.lhs) and node_tags[infix.rhs] == .identifier)
                return .discard_identifier;
            return .other;
        },
        else => return .other,
    }
}

/// Returns `.trap_call` for an argument-less `@trap` builtin call and `.other`
/// for every other builtin call.
fn categorizeBuiltinCall(
    ast: *const Ast,
    builtin_token: Ast.TokenIndex,
    params: []const Ast.Node.Index,
) StmtCategory {
    if (params.len != 0) return .other;
    const name_bytes = ast.tokenSlice(builtin_token);
    if (std.mem.eql(u8, name_bytes, "@trap"))
        return .trap_call;
    return .other;
}

/// Returns whether `node` is the identifier `_`.
fn isDiscardIdent(ast: *const Ast, node: Ast.Node.Index) bool {
    const node_tags = ast.nodes.items(.tag);
    const main_tokens = ast.nodes.items(.main_token);
    switch (node_tags[node]) {
        .identifier => {
            const token_index = main_tokens[node];
            const name_bytes = ast.tokenSlice(token_index);
            return std.mem.eql(u8, name_bytes, "_");
        },
        else => return false,
    }
}
