aboutsummaryrefslogtreecommitdiff
path: root/pkg/pdef
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/pdef')
-rw-r--r--pkg/pdef/pdefgen/pdefgen.go171
1 files changed, 168 insertions, 3 deletions
diff --git a/pkg/pdef/pdefgen/pdefgen.go b/pkg/pdef/pdefgen/pdefgen.go
index e9ab76f..fce1ea6 100644
--- a/pkg/pdef/pdefgen/pdefgen.go
+++ b/pkg/pdef/pdefgen/pdefgen.go
@@ -71,11 +71,14 @@ func main() {
pln(&buf, `_ "embed"`)
pln(&buf, `"encoding"`)
pln(&buf, `"encoding/binary"`)
+ pln(&buf, `"encoding/json"`)
pln(&buf, `"errors"`)
pln(&buf, `"fmt"`)
pln(&buf, `"io"`)
pln(&buf, `"math"`)
+ pln(&buf, `"reflect"`)
pln(&buf, `"strconv"`)
+ pln(&buf, `"strings"`)
pln(&buf, `)`)
pln(&buf, ``)
@@ -168,6 +171,111 @@ func main() {
generateEnum(&buf, k, v)
})
+ // this is needed since the default Go marshaler can't filter and doesn't
+ // handle NaN/Inf floats.
+ pln(&buf, "%s", `
+ func pdataMarshalJSONStruct(obj any, filter func(path ...string) bool, path ...string) ([]byte, error) {
+ objVal := reflect.ValueOf(obj)
+ objTyp := objVal.Type()
+
+ if objTyp.Kind() != reflect.Struct {
+ panic("not a struct")
+ }
+
+ var b bytes.Buffer
+ b.WriteByte('{')
+ needComma := false
+ for i := 0; i < objTyp.NumField(); i++ {
+ fldTyp := objTyp.Field(i)
+ fldVal := objVal.Field(i)
+ fld := fldVal.Interface()
+
+ fldTag := fldTyp.Tag.Get("pdef")
+ fldTagName, fldTagAttr, _ := strings.Cut(fldTag, ",")
+ if fldTagName == "" {
+ continue
+ }
+ if fldTagAttr != "" {
+ panic(fmt.Errorf("unknown pdef field tag attrs %q", fldTagAttr))
+ }
+
+ fldPath := append(path, fldTagName)
+ if filter != nil && !filter(fldPath...) {
+ continue
+ }
+
+ if needComma {
+ b.WriteByte(',')
+ needComma = false
+ }
+
+ b.WriteString("\"" + fldTagName + "\":")
+ needComma = true
+
+ switch fldTyp.Type.Kind() {
+ case reflect.Struct:
+ buf, err := pdataMarshalJSONStruct(fld, filter, fldPath...)
+ if err != nil {
+ return nil, err
+ }
+ b.Write(buf)
+ case reflect.Array:
+ b.WriteByte('[')
+ for j := 0; j < fldTyp.Type.Len(); j++ {
+ fldValElemVal := fldVal.Index(j)
+ fldValElemTyp := fldValElemVal.Type()
+ fldValElem := fldValElemVal.Interface()
+
+ if j != 0 {
+ b.WriteByte(',')
+ }
+ if fldValElemTyp.Kind() == reflect.Struct {
+ buf, err := pdataMarshalJSONStruct(fldValElem, filter, fldPath...)
+ if err != nil {
+ return nil, err
+ }
+ b.Write(buf)
+ } else {
+ buf, err := pdataMarshalJSONPrimitive(fldValElem)
+ if err != nil {
+ return nil, err
+ }
+ b.Write(buf)
+ }
+ }
+ b.WriteByte(']')
+ default:
+ buf, err := pdataMarshalJSONPrimitive(fld)
+ if err != nil {
+ return nil, err
+ }
+ b.Write(buf)
+ }
+ }
+ b.WriteByte('}')
+ return b.Bytes(), nil
+ }
+
+ func pdataMarshalJSONPrimitive(v any) ([]byte, error) {
+ switch v := v.(type) {
+ case float32:
+ if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
+ return []byte("null"), nil // match the JS behaviour
+ }
+ return json.Marshal(v)
+ case int32, bool, string:
+ return json.Marshal(v)
+ case json.Marshaler:
+ if reflect.TypeOf(v).ConvertibleTo(reflect.TypeOf(uint8(0))) {
+ return v.MarshalJSON() // enum
+ }
+ panic(fmt.Errorf("unhandled type %T", v))
+ default:
+ panic(fmt.Errorf("unhandled type %T", v))
+ }
+ }
+ `)
+
if err := writeGo(strings.TrimSuffix(filepath.Base(f.Name()), filepath.Ext(f.Name()))+".go", buf.Bytes()); err != nil {
fmt.Fprintf(os.Stderr, "error: write generated source: %v\n", err)
os.Exit(1)
@@ -185,6 +293,7 @@ func main() {
import (
"bytes"
+ "encoding/json"
"os"
"testing"
)
@@ -221,6 +330,12 @@ func main() {
if !bytes.Equal(rbuf, ebuf) {
t.Errorf("internal round-trip failed: re-marshaled unmarshaled data encoded by marshal does not match")
}
+
+ if buf, err := d2.MarshalJSON(); err != nil {
+ t.Errorf("failed to marshal as JSON: %%v", err)
+ } else if err = json.Unmarshal(buf, new(map[string]interface{})); err != nil {
+ t.Errorf("bad json marshal result: %%v", err)
+ }
})
}
}
@@ -243,15 +358,17 @@ func generateStruct(buf *bytes.Buffer, name string, pd *pdef.Pdef, fields []pdef
{
pln(buf, `type %s struct {`, mangle(name, true))
for _, v := range fields {
- pln(buf, `%s %s`+" `json:%q`", mangle(v.Name, true), pdefGoType(v.Type), v.Name)
+ pln(buf, `%s %s`+" `pdef:%q`", mangle(v.Name, true), pdefGoType(v.Type), v.Name)
}
if root {
- pln(buf, `%s %s`+" `json:%q`", mangle("extraData", true), "[]byte", "extraData,omitempty")
+ pln(buf, `%s %s`, mangle("extraData", true), "[]byte")
}
pln(buf, `}`)
pln(buf, `var _ encoding.BinaryUnmarshaler = (*%s)(nil)`, mangle(name, true))
pln(buf, `var _ encoding.BinaryMarshaler = %s{}`, mangle(name, true))
+ pln(buf, `var _ json.Unmarshaler = (*%s)(nil)`, mangle(name, true))
+ pln(buf, `var _ json.Marshaler = %s{}`, mangle(name, true))
}
{
pln(buf, `func (v *%s) UnmarshalBinary(b []byte) error {`, mangle(name, true))
@@ -433,6 +550,27 @@ func generateStruct(buf *bytes.Buffer, name string, pd *pdef.Pdef, fields []pdef
pln(buf, `return b, nil`)
pln(buf, `}`)
}
+ {
+ pln(buf, `func (v *%s) UnmarshalJSON(b []byte) error {`, mangle(name, true))
+ // TODO: implement this if we actually have a use for it
+ pln(buf, `return fmt.Errorf("not implemented")`)
+ pln(buf, `}`)
+ }
+ {
+ pln(buf, `func (v %s) MarshalJSON() ([]byte, error) {`, mangle(name, true))
+ pln(buf, `return v.MarshalJSONFilter(nil)`)
+ pln(buf, `}`)
+ }
+ {
+ pln(buf, `func (v %s) MarshalJSONFilter(filter func(path ...string) bool) ([]byte, error) {`, mangle(name, true))
+ if root {
+ pln(buf, `if x := v.%s; x != Version {`, mangle(fields[0].Name, true))
+ pln(buf, `return nil, fmt.Errorf(%#v, %#v, Version, ErrUnsupportedVersion, x)`, `encode %q (v%d): %w: got %d`, name)
+ pln(buf, `}`)
+ }
+ pln(buf, `return pdataMarshalJSONStruct(v, filter)`)
+ pln(buf, `}`)
+ }
}
func generateEnum(buf *bytes.Buffer, name string, values []string) {
@@ -447,8 +585,10 @@ func generateEnum(buf *bytes.Buffer, name string, values []string) {
pln(buf, `var _ fmt.Stringer = %s(0)`, mangle(name, true))
pln(buf, `var _ fmt.GoStringer = %s(0)`, mangle(name, true))
- pln(buf, `//var _ encoding.TextMarshaler = %s(0)`, mangle(name, true))
+ pln(buf, `var _ encoding.TextMarshaler = %s(0)`, mangle(name, true))
pln(buf, `var _ encoding.TextUnmarshaler = (*%s)(nil)`, mangle(name, true))
+ pln(buf, `var _ json.Marshaler = %s(0)`, mangle(name, true))
+ pln(buf, `var _ json.Unmarshaler = (*%s)(nil)`, mangle(name, true))
}
{
pln(buf, `func (v %s) String() string {`, mangle(name, true))
@@ -495,6 +635,31 @@ func generateEnum(buf *bytes.Buffer, name string, values []string) {
pln(buf, `return nil`)
pln(buf, `}`)
}
+ {
+ pln(buf, `func (v %s) MarshalJSON() ([]byte, error) {`, mangle(name, true))
+ pln(buf, `switch v {`)
+ for _, v := range values {
+ pln(buf, `case %s:`, mangleEnumValue(name, v))
+ pln(buf, `return []byte(%#v), nil`, `"`+v+`"`)
+ }
+ pln(buf, `default:`)
+ pln(buf, `return []byte(strconv.Itoa(int(v))), nil`)
+ pln(buf, `}`)
+ pln(buf, `}`)
+ }
+ {
+ pln(buf, `func (v *%s) UnmarshalJSON(b []byte) error {`, mangle(name, true))
+ pln(buf, `switch string(b) {`)
+ for _, v := range values {
+ pln(buf, `case %#v:`, `"`+v+`"`)
+ pln(buf, `*v = %s`, mangleEnumValue(name, v))
+ }
+ pln(buf, `default:`)
+ pln(buf, `return json.Unmarshal(b, (*uint8)(v))`)
+ pln(buf, `}`)
+ pln(buf, `return nil`)
+ pln(buf, `}`)
+ }
}
func pdefGoType(t pdef.TypeInfo) string {