diff options
author | pg9182 <96569817+pg9182@users.noreply.github.com> | 2022-10-19 08:11:35 -0400 |
---|---|---|
committer | pg9182 <96569817+pg9182@users.noreply.github.com> | 2022-10-19 08:11:35 -0400 |
commit | 4d1f3137d7aeebabd3298cb219a252cc457b473d (patch) | |
tree | 5fae7b96357559b90c77b0e6f42693c3149b37de /db/pdatadb/migrations.go | |
parent | 8d5f88b727fb413f95e8818b9e7e9ad26806df7c (diff) | |
download | Atlas-4d1f3137d7aeebabd3298cb219a252cc457b473d.tar.gz Atlas-4d1f3137d7aeebabd3298cb219a252cc457b473d.zip |
db/pdatadb: Implement sqlite3 PdataStorage
Diffstat (limited to 'db/pdatadb/migrations.go')
-rw-r--r-- | db/pdatadb/migrations.go | 170 |
1 files changed, 170 insertions, 0 deletions
diff --git a/db/pdatadb/migrations.go b/db/pdatadb/migrations.go new file mode 100644 index 0000000..429659a --- /dev/null +++ b/db/pdatadb/migrations.go @@ -0,0 +1,170 @@ +package pdatadb + +import ( + "context" + "database/sql" + "fmt" + "path" + "runtime" + "sort" + "strconv" + "strings" + + "github.com/jmoiron/sqlx" +) + +// TODO: support versions which can't be migrated down from + +type migration struct { + Name string + Up func(context.Context, *sqlx.Tx) error + Down func(context.Context, *sqlx.Tx) error +} + +var migrations = map[uint64]migration{} + +func migrate(up, down func(context.Context, *sqlx.Tx) error) { + _, fn, _, ok := runtime.Caller(1) + if !ok { + panic("add migration: failed to get filename") + } + fn = path.Base(strings.ReplaceAll(fn, `\`, `/`)) + + if n, _, ok := strings.Cut(fn, "_"); !ok { + panic("add migration: failed to parse filename") + } else if v, err := strconv.ParseUint(n, 10, 64); err != nil { + panic("add migration: failed to parse filename: " + err.Error()) + } else if v == 0 { + panic("add migration: version must not be 0") + } else { + migrations[v] = migration{strings.TrimSuffix(n, ".go"), up, down} + } +} + +// Version gets the current and required database versions. It should be checked +// before using the database. +func (db *DB) Version() (current, required uint64, err error) { + if err = db.x.Get(¤t, `PRAGMA user_version`); err != nil { + err = fmt.Errorf("get version: %w", err) + return + } + for v := range migrations { + if v > required { + required = v + } + } + return +} + +// MigrateUp migrates the database to the provided version. +func (db *DB) MigrateUp(ctx context.Context, to uint64) error { + tx, err := db.x.BeginTxx(ctx, &sql.TxOptions{}) + if err != nil { + return fmt.Errorf("begin transaction: %w", err) + } + defer tx.Rollback() + + var cv uint64 + if err = tx.GetContext(ctx, &cv, `PRAGMA user_version`); err != nil { + return fmt.Errorf("get version: %w", err) + } + if to < cv { + return fmt.Errorf("target version %d is less than current version %d", to, cv) + } + + var ms []uint64 + foundC, foundT := cv == 0, to == 0 + for v := range migrations { + if v == cv { + foundC = true + } + if v == to { + foundT = true + } + if v > cv && v <= to { + ms = append(ms, v) + } + } + if !foundC { + return fmt.Errorf("unsupported db version %d", cv) + } + if !foundT { + return fmt.Errorf("unknown db version %d", cv) + } + + sort.Slice(ms, func(i, j int) bool { + return ms[i] < ms[j] + }) + + for _, v := range ms { + if err := migrations[v].Up(ctx, tx); err != nil { + return fmt.Errorf("migrate %d: %w", v, err) + } + } + + if _, err := tx.ExecContext(ctx, `PRAGMA user_version = `+strconv.FormatUint(to, 10)); err != nil { + return fmt.Errorf("update version: %w", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit transaction: %w", err) + } + return nil +} + +// MigrateDown migrates the database down to the provided version. This will +// probably eat your data. +func (db *DB) MigrateDown(ctx context.Context, to uint64) error { + tx, err := db.x.BeginTxx(ctx, &sql.TxOptions{}) + if err != nil { + return fmt.Errorf("begin transaction: %w", err) + } + defer tx.Rollback() + + var cv uint64 + if err = tx.GetContext(ctx, &cv, `PRAGMA user_version`); err != nil { + return fmt.Errorf("get version: %w", err) + } + if cv < to { + return fmt.Errorf("current version %d is less than target version %d", cv, to) + } + + var ms []uint64 + foundC, foundT := cv == 0, to == 0 + for v := range migrations { + if v == cv { + foundC = true + } + if v == to { + foundT = true + } + if v <= cv && v > to { + ms = append(ms, v) + } + } + if !foundC { + return fmt.Errorf("unsupported db version %d", cv) + } + if !foundT { + return fmt.Errorf("unknown db version %d", cv) + } + + sort.Slice(ms, func(i, j int) bool { + return ms[i] > ms[j] + }) + + for _, v := range ms { + if err := migrations[v].Down(ctx, tx); err != nil { + return fmt.Errorf("migrate %d: %w", v, err) + } + } + + if _, err := tx.ExecContext(ctx, `PRAGMA user_version = `+strconv.FormatUint(to, 10)); err != nil { + return fmt.Errorf("update version: %w", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit transaction: %w", err) + } + return nil +} |