diff options
Diffstat (limited to 'pkg/pdef')
-rw-r--r-- | pkg/pdef/pdefgen/pdefgen.go | 171 |
1 files changed, 168 insertions, 3 deletions
diff --git a/pkg/pdef/pdefgen/pdefgen.go b/pkg/pdef/pdefgen/pdefgen.go index e9ab76f..fce1ea6 100644 --- a/pkg/pdef/pdefgen/pdefgen.go +++ b/pkg/pdef/pdefgen/pdefgen.go @@ -71,11 +71,14 @@ func main() { pln(&buf, `_ "embed"`) pln(&buf, `"encoding"`) pln(&buf, `"encoding/binary"`) + pln(&buf, `"encoding/json"`) pln(&buf, `"errors"`) pln(&buf, `"fmt"`) pln(&buf, `"io"`) pln(&buf, `"math"`) + pln(&buf, `"reflect"`) pln(&buf, `"strconv"`) + pln(&buf, `"strings"`) pln(&buf, `)`) pln(&buf, ``) @@ -168,6 +171,111 @@ func main() { generateEnum(&buf, k, v) }) + // this is needed since the default Go marshaler can't filter and doesn't + // handle NaN/Inf floats. + pln(&buf, "%s", ` + func pdataMarshalJSONStruct(obj any, filter func(path ...string) bool, path ...string) ([]byte, error) { + objVal := reflect.ValueOf(obj) + objTyp := objVal.Type() + + if objTyp.Kind() != reflect.Struct { + panic("not a struct") + } + + var b bytes.Buffer + b.WriteByte('{') + needComma := false + for i := 0; i < objTyp.NumField(); i++ { + fldTyp := objTyp.Field(i) + fldVal := objVal.Field(i) + fld := fldVal.Interface() + + fldTag := fldTyp.Tag.Get("pdef") + fldTagName, fldTagAttr, _ := strings.Cut(fldTag, ",") + if fldTagName == "" { + continue + } + if fldTagAttr != "" { + panic(fmt.Errorf("unknown pdef field tag attrs %q", fldTagAttr)) + } + + fldPath := append(path, fldTagName) + if filter != nil && !filter(fldPath...) { + continue + } + + if needComma { + b.WriteByte(',') + needComma = false + } + + b.WriteString("\"" + fldTagName + "\":") + needComma = true + + switch fldTyp.Type.Kind() { + case reflect.Struct: + buf, err := pdataMarshalJSONStruct(fld, filter, fldPath...) + if err != nil { + return nil, err + } + b.Write(buf) + case reflect.Array: + b.WriteByte('[') + for j := 0; j < fldTyp.Type.Len(); j++ { + fldValElemVal := fldVal.Index(j) + fldValElemTyp := fldValElemVal.Type() + fldValElem := fldValElemVal.Interface() + + if j != 0 { + b.WriteByte(',') + } + if fldValElemTyp.Kind() == reflect.Struct { + buf, err := pdataMarshalJSONStruct(fldValElem, filter, fldPath...) + if err != nil { + return nil, err + } + b.Write(buf) + } else { + buf, err := pdataMarshalJSONPrimitive(fldValElem) + if err != nil { + return nil, err + } + b.Write(buf) + } + } + b.WriteByte(']') + default: + buf, err := pdataMarshalJSONPrimitive(fld) + if err != nil { + return nil, err + } + b.Write(buf) + } + } + b.WriteByte('}') + return b.Bytes(), nil + } + + func pdataMarshalJSONPrimitive(v any) ([]byte, error) { + switch v := v.(type) { + case float32: + if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) { + return []byte("null"), nil // match the JS behaviour + } + return json.Marshal(v) + case int32, bool, string: + return json.Marshal(v) + case json.Marshaler: + if reflect.TypeOf(v).ConvertibleTo(reflect.TypeOf(uint8(0))) { + return v.MarshalJSON() // enum + } + panic(fmt.Errorf("unhandled type %T", v)) + default: + panic(fmt.Errorf("unhandled type %T", v)) + } + } + `) + if err := writeGo(strings.TrimSuffix(filepath.Base(f.Name()), filepath.Ext(f.Name()))+".go", buf.Bytes()); err != nil { fmt.Fprintf(os.Stderr, "error: write generated source: %v\n", err) os.Exit(1) @@ -185,6 +293,7 @@ func main() { import ( "bytes" + "encoding/json" "os" "testing" ) @@ -221,6 +330,12 @@ func main() { if !bytes.Equal(rbuf, ebuf) { t.Errorf("internal round-trip failed: re-marshaled unmarshaled data encoded by marshal does not match") } + + if buf, err := d2.MarshalJSON(); err != nil { + t.Errorf("failed to marshal as JSON: %%v", err) + } else if err = json.Unmarshal(buf, new(map[string]interface{})); err != nil { + t.Errorf("bad json marshal result: %%v", err) + } }) } } @@ -243,15 +358,17 @@ func generateStruct(buf *bytes.Buffer, name string, pd *pdef.Pdef, fields []pdef { pln(buf, `type %s struct {`, mangle(name, true)) for _, v := range fields { - pln(buf, `%s %s`+" `json:%q`", mangle(v.Name, true), pdefGoType(v.Type), v.Name) + pln(buf, `%s %s`+" `pdef:%q`", mangle(v.Name, true), pdefGoType(v.Type), v.Name) } if root { - pln(buf, `%s %s`+" `json:%q`", mangle("extraData", true), "[]byte", "extraData,omitempty") + pln(buf, `%s %s`, mangle("extraData", true), "[]byte") } pln(buf, `}`) pln(buf, `var _ encoding.BinaryUnmarshaler = (*%s)(nil)`, mangle(name, true)) pln(buf, `var _ encoding.BinaryMarshaler = %s{}`, mangle(name, true)) + pln(buf, `var _ json.Unmarshaler = (*%s)(nil)`, mangle(name, true)) + pln(buf, `var _ json.Marshaler = %s{}`, mangle(name, true)) } { pln(buf, `func (v *%s) UnmarshalBinary(b []byte) error {`, mangle(name, true)) @@ -433,6 +550,27 @@ func generateStruct(buf *bytes.Buffer, name string, pd *pdef.Pdef, fields []pdef pln(buf, `return b, nil`) pln(buf, `}`) } + { + pln(buf, `func (v *%s) UnmarshalJSON(b []byte) error {`, mangle(name, true)) + // TODO: implement this if we actually have a use for it + pln(buf, `return fmt.Errorf("not implemented")`) + pln(buf, `}`) + } + { + pln(buf, `func (v %s) MarshalJSON() ([]byte, error) {`, mangle(name, true)) + pln(buf, `return v.MarshalJSONFilter(nil)`) + pln(buf, `}`) + } + { + pln(buf, `func (v %s) MarshalJSONFilter(filter func(path ...string) bool) ([]byte, error) {`, mangle(name, true)) + if root { + pln(buf, `if x := v.%s; x != Version {`, mangle(fields[0].Name, true)) + pln(buf, `return nil, fmt.Errorf(%#v, %#v, Version, ErrUnsupportedVersion, x)`, `encode %q (v%d): %w: got %d`, name) + pln(buf, `}`) + } + pln(buf, `return pdataMarshalJSONStruct(v, filter)`) + pln(buf, `}`) + } } func generateEnum(buf *bytes.Buffer, name string, values []string) { @@ -447,8 +585,10 @@ func generateEnum(buf *bytes.Buffer, name string, values []string) { pln(buf, `var _ fmt.Stringer = %s(0)`, mangle(name, true)) pln(buf, `var _ fmt.GoStringer = %s(0)`, mangle(name, true)) - pln(buf, `//var _ encoding.TextMarshaler = %s(0)`, mangle(name, true)) + pln(buf, `var _ encoding.TextMarshaler = %s(0)`, mangle(name, true)) pln(buf, `var _ encoding.TextUnmarshaler = (*%s)(nil)`, mangle(name, true)) + pln(buf, `var _ json.Marshaler = %s(0)`, mangle(name, true)) + pln(buf, `var _ json.Unmarshaler = (*%s)(nil)`, mangle(name, true)) } { pln(buf, `func (v %s) String() string {`, mangle(name, true)) @@ -495,6 +635,31 @@ func generateEnum(buf *bytes.Buffer, name string, values []string) { pln(buf, `return nil`) pln(buf, `}`) } + { + pln(buf, `func (v %s) MarshalJSON() ([]byte, error) {`, mangle(name, true)) + pln(buf, `switch v {`) + for _, v := range values { + pln(buf, `case %s:`, mangleEnumValue(name, v)) + pln(buf, `return []byte(%#v), nil`, `"`+v+`"`) + } + pln(buf, `default:`) + pln(buf, `return []byte(strconv.Itoa(int(v))), nil`) + pln(buf, `}`) + pln(buf, `}`) + } + { + pln(buf, `func (v *%s) UnmarshalJSON(b []byte) error {`, mangle(name, true)) + pln(buf, `switch string(b) {`) + for _, v := range values { + pln(buf, `case %#v:`, `"`+v+`"`) + pln(buf, `*v = %s`, mangleEnumValue(name, v)) + } + pln(buf, `default:`) + pln(buf, `return json.Unmarshal(b, (*uint8)(v))`) + pln(buf, `}`) + pln(buf, `return nil`) + pln(buf, `}`) + } } func pdefGoType(t pdef.TypeInfo) string { |