From 8d5f88b727fb413f95e8818b9e7e9ad26806df7c Mon Sep 17 00:00:00 2001 From: pg9182 <96569817+pg9182@users.noreply.github.com> Date: Wed, 19 Oct 2022 08:11:23 -0400 Subject: db/atlasdb: Implement sqlite3 AccountStorage --- db/atlasdb/001_init_db.go | 43 +++++++++++ db/atlasdb/db.go | 118 +++++++++++++++++++++++++++++ db/atlasdb/db_test.go | 31 ++++++++ db/atlasdb/migrations.go | 170 ++++++++++++++++++++++++++++++++++++++++++ db/atlasdb/migrations_test.go | 49 ++++++++++++ 5 files changed, 411 insertions(+) create mode 100644 db/atlasdb/001_init_db.go create mode 100644 db/atlasdb/db.go create mode 100644 db/atlasdb/db_test.go create mode 100644 db/atlasdb/migrations.go create mode 100644 db/atlasdb/migrations_test.go (limited to 'db') diff --git a/db/atlasdb/001_init_db.go b/db/atlasdb/001_init_db.go new file mode 100644 index 0000000..8550fc0 --- /dev/null +++ b/db/atlasdb/001_init_db.go @@ -0,0 +1,43 @@ +package atlasdb + +import ( + "context" + "fmt" + "strings" + + "github.com/jmoiron/sqlx" +) + +func init() { + migrate(up001, down001) +} + +func up001(ctx context.Context, tx *sqlx.Tx) error { + if _, err := tx.ExecContext(ctx, strings.ReplaceAll(` + CREATE TABLE accounts ( + uid TEXT PRIMARY KEY NOT NULL, + username TEXT NOT NULL DEFAULT '' COLLATE NOCASE, + auth_ip TEXT, + auth_token TEXT, + auth_expiry INTEGER, + last_server TEXT + ) STRICT; + `, ` + `, "\n")); err != nil { + return fmt.Errorf("create accounts table: %w", err) + } + if _, err := tx.ExecContext(ctx, `CREATE INDEX accounts_username_idx ON accounts(username, uid)`); err != nil { + return fmt.Errorf("create accounts index: %w", err) + } + return nil +} + +func down001(ctx context.Context, tx *sqlx.Tx) error { + if _, err := tx.ExecContext(ctx, `DROP INDEX accounts_username_idx`); err != nil { + return fmt.Errorf("drop accounts_username_idx index: %w", err) + } + if _, err := tx.ExecContext(ctx, `DROP TABLE accounts`); err != nil { + return fmt.Errorf("drop accounts table: %w", err) + } + return nil +} diff --git a/db/atlasdb/db.go b/db/atlasdb/db.go new file mode 100644 index 0000000..d58a784 --- /dev/null +++ b/db/atlasdb/db.go @@ -0,0 +1,118 @@ +// Package atlasdb implements sqlite3 database storage for accounts and other atlas data. +package atlasdb + +import ( + "database/sql" + "errors" + "fmt" + "net/netip" + "net/url" + "time" + + "github.com/jmoiron/sqlx" + "github.com/pg9182/atlas/pkg/api/api0" +) + +// DB stores atlas data in a sqlite3 database. +type DB struct { + x *sqlx.DB +} + +// Open opens a DB from the provided sqlite3 filename. +func Open(name string) (*DB, error) { + // note: WAL and a larger cache makes our writes and queries MUCH faster + x, err := sqlx.Connect("sqlite3", (&url.URL{ + Path: name, + RawQuery: (url.Values{ + "_journal": {"WAL"}, + "_cache_size": {"-32000"}, + "_busy_timeout": {"4000"}, + }).Encode(), + }).String()) + if err != nil { + return nil, err + } + return &DB{x}, nil +} + +func (db *DB) Close() error { + return db.x.Close() +} + +func (db *DB) GetUIDsByUsername(username string) ([]uint64, error) { + var u []uint64 + if username != "" { + if err := db.x.Select(&u, `SELECT uid FROM accounts WHERE username = ?`, username); err != nil { + return nil, err + } + } + return u, nil +} + +func (db *DB) GetAccount(uid uint64) (*api0.Account, error) { + var obj struct { + UID uint64 `db:"uid"` + Username string `db:"username"` + AuthIP string `db:"auth_ip"` + AuthToken string `db:"auth_token"` + AuthExpiry int64 `db:"auth_expiry"` + LastServer string `db:"last_server"` + } + if err := db.x.Get(&obj, `SELECT * FROM accounts WHERE uid = ?`, uid); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + + var authExpiry time.Time + if obj.AuthExpiry != 0 { + authExpiry = time.Unix(obj.AuthExpiry, 0) + } + + var authIP netip.Addr + if obj.AuthIP != "" { + if v, err := netip.ParseAddr(obj.AuthIP); err == nil { + authIP = v + } else { + return nil, fmt.Errorf("parse auth_ip: %w", err) + } + } + + return &api0.Account{ + UID: obj.UID, + Username: obj.Username, + AuthIP: authIP, + AuthToken: obj.AuthToken, + AuthTokenExpiry: authExpiry, + LastServerID: obj.LastServer, + }, nil +} + +func (db *DB) SaveAccount(a *api0.Account) error { + var authExpiry int64 + if !a.AuthTokenExpiry.IsZero() { + authExpiry = a.AuthTokenExpiry.Unix() + } + + var authIP string + if a.AuthIP.IsValid() { + authIP = a.AuthIP.StringExpanded() + } + + if _, err := db.x.NamedExec(` + INSERT OR REPLACE INTO + accounts ( uid, username, auth_ip, auth_token, auth_expiry, last_server) + VALUES (:uid, :username, :auth_ip, :auth_token, :auth_expiry, :last_server) + `, map[string]any{ + "uid": a.UID, + "username": a.Username, + "auth_ip": authIP, + "auth_token": a.AuthToken, + "auth_expiry": authExpiry, + "last_server": a.LastServerID, + }); err != nil { + return err + } + return nil +} diff --git a/db/atlasdb/db_test.go b/db/atlasdb/db_test.go new file mode 100644 index 0000000..1af59f1 --- /dev/null +++ b/db/atlasdb/db_test.go @@ -0,0 +1,31 @@ +package atlasdb + +import ( + "context" + "path/filepath" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/pg9182/atlas/pkg/api/api0/api0testutil" +) + +func TestAccountStorage(t *testing.T) { + db, err := Open(filepath.Join(t.TempDir(), "atlas.db")) + if err != nil { + panic(err) + } + defer db.Close() + + cur, tgt, err := db.Version() + if err != nil { + panic(err) + } + if cur != 0 { + panic("current version not 0") + } + if err := db.MigrateUp(context.Background(), tgt); err != nil { + panic(err) + } + + api0testutil.TestAccountStorage(t, db) +} diff --git a/db/atlasdb/migrations.go b/db/atlasdb/migrations.go new file mode 100644 index 0000000..3aa47bb --- /dev/null +++ b/db/atlasdb/migrations.go @@ -0,0 +1,170 @@ +package atlasdb + +import ( + "context" + "database/sql" + "fmt" + "path" + "runtime" + "sort" + "strconv" + "strings" + + "github.com/jmoiron/sqlx" +) + +// TODO: support versions which can't be migrated down from + +type migration struct { + Name string + Up func(context.Context, *sqlx.Tx) error + Down func(context.Context, *sqlx.Tx) error +} + +var migrations = map[uint64]migration{} + +func migrate(up, down func(context.Context, *sqlx.Tx) error) { + _, fn, _, ok := runtime.Caller(1) + if !ok { + panic("add migration: failed to get filename") + } + fn = path.Base(strings.ReplaceAll(fn, `\`, `/`)) + + if n, _, ok := strings.Cut(fn, "_"); !ok { + panic("add migration: failed to parse filename") + } else if v, err := strconv.ParseUint(n, 10, 64); err != nil { + panic("add migration: failed to parse filename: " + err.Error()) + } else if v == 0 { + panic("add migration: version must not be 0") + } else { + migrations[v] = migration{strings.TrimSuffix(n, ".go"), up, down} + } +} + +// Version gets the current and required database versions. It should be checked +// before using the database. +func (db *DB) Version() (current, required uint64, err error) { + if err = db.x.Get(¤t, `PRAGMA user_version`); err != nil { + err = fmt.Errorf("get version: %w", err) + return + } + for v := range migrations { + if v > required { + required = v + } + } + return +} + +// MigrateUp migrates the database to the provided version. +func (db *DB) MigrateUp(ctx context.Context, to uint64) error { + tx, err := db.x.BeginTxx(ctx, &sql.TxOptions{}) + if err != nil { + return fmt.Errorf("begin transaction: %w", err) + } + defer tx.Rollback() + + var cv uint64 + if err = tx.GetContext(ctx, &cv, `PRAGMA user_version`); err != nil { + return fmt.Errorf("get version: %w", err) + } + if to < cv { + return fmt.Errorf("target version %d is less than current version %d", to, cv) + } + + var ms []uint64 + foundC, foundT := cv == 0, to == 0 + for v := range migrations { + if v == cv { + foundC = true + } + if v == to { + foundT = true + } + if v > cv && v <= to { + ms = append(ms, v) + } + } + if !foundC { + return fmt.Errorf("unsupported db version %d", cv) + } + if !foundT { + return fmt.Errorf("unknown db version %d", cv) + } + + sort.Slice(ms, func(i, j int) bool { + return ms[i] < ms[j] + }) + + for _, v := range ms { + if err := migrations[v].Up(ctx, tx); err != nil { + return fmt.Errorf("migrate %d: %w", v, err) + } + } + + if _, err := tx.ExecContext(ctx, `PRAGMA user_version = `+strconv.FormatUint(to, 10)); err != nil { + return fmt.Errorf("update version: %w", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit transaction: %w", err) + } + return nil +} + +// MigrateDown migrates the database down to the provided version. This will +// probably eat your data. +func (db *DB) MigrateDown(ctx context.Context, to uint64) error { + tx, err := db.x.BeginTxx(ctx, &sql.TxOptions{}) + if err != nil { + return fmt.Errorf("begin transaction: %w", err) + } + defer tx.Rollback() + + var cv uint64 + if err = tx.GetContext(ctx, &cv, `PRAGMA user_version`); err != nil { + return fmt.Errorf("get version: %w", err) + } + if cv < to { + return fmt.Errorf("current version %d is less than target version %d", cv, to) + } + + var ms []uint64 + foundC, foundT := cv == 0, to == 0 + for v := range migrations { + if v == cv { + foundC = true + } + if v == to { + foundT = true + } + if v <= cv && v > to { + ms = append(ms, v) + } + } + if !foundC { + return fmt.Errorf("unsupported db version %d", cv) + } + if !foundT { + return fmt.Errorf("unknown db version %d", cv) + } + + sort.Slice(ms, func(i, j int) bool { + return ms[i] > ms[j] + }) + + for _, v := range ms { + if err := migrations[v].Down(ctx, tx); err != nil { + return fmt.Errorf("migrate %d: %w", v, err) + } + } + + if _, err := tx.ExecContext(ctx, `PRAGMA user_version = `+strconv.FormatUint(to, 10)); err != nil { + return fmt.Errorf("update version: %w", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit transaction: %w", err) + } + return nil +} diff --git a/db/atlasdb/migrations_test.go b/db/atlasdb/migrations_test.go new file mode 100644 index 0000000..2149717 --- /dev/null +++ b/db/atlasdb/migrations_test.go @@ -0,0 +1,49 @@ +package atlasdb + +import ( + "context" + "path/filepath" + "sort" + "testing" + + _ "github.com/mattn/go-sqlite3" +) + +func TestMigrations(t *testing.T) { + db, err := Open(filepath.Join(t.TempDir(), "atlas.db")) + if err != nil { + panic(err) + } + defer db.Close() + + cur, _, err := db.Version() + if err != nil { + panic(err) + } + if cur != 0 { + t.Fatalf("current version not 0") + } + + var ms []uint64 + for m := range migrations { + ms = append(ms, m) + } + sort.Slice(ms, func(i, j int) bool { + return ms[i] < ms[j] + }) + + for _, to := range ms { + if err := db.MigrateUp(context.Background(), to); err != nil { + t.Fatalf("migrate up to %d: %v", to, err) + } + if err := db.MigrateDown(context.Background(), 0); err != nil { + t.Fatalf("migrate down from %d to 0: %v", to, err) + } + if err := db.MigrateUp(context.Background(), to); err != nil { + t.Fatalf("migrate up to %d again: %v", to, err) + } + if err := db.MigrateDown(context.Background(), 0); err != nil { + t.Fatalf("migrate down from %d to 0 again: %v", to, err) + } + } +} -- cgit v1.2.3