aboutsummaryrefslogtreecommitdiff
path: root/db/atlasdb
diff options
context:
space:
mode:
Diffstat (limited to 'db/atlasdb')
-rw-r--r--db/atlasdb/001_init_db.go43
-rw-r--r--db/atlasdb/db.go118
-rw-r--r--db/atlasdb/db_test.go31
-rw-r--r--db/atlasdb/migrations.go170
-rw-r--r--db/atlasdb/migrations_test.go49
5 files changed, 411 insertions, 0 deletions
diff --git a/db/atlasdb/001_init_db.go b/db/atlasdb/001_init_db.go
new file mode 100644
index 0000000..8550fc0
--- /dev/null
+++ b/db/atlasdb/001_init_db.go
@@ -0,0 +1,43 @@
+package atlasdb
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ "github.com/jmoiron/sqlx"
+)
+
+func init() {
+ migrate(up001, down001)
+}
+
+func up001(ctx context.Context, tx *sqlx.Tx) error {
+ if _, err := tx.ExecContext(ctx, strings.ReplaceAll(`
+ CREATE TABLE accounts (
+ uid TEXT PRIMARY KEY NOT NULL,
+ username TEXT NOT NULL DEFAULT '' COLLATE NOCASE,
+ auth_ip TEXT,
+ auth_token TEXT,
+ auth_expiry INTEGER,
+ last_server TEXT
+ ) STRICT;
+ `, `
+ `, "\n")); err != nil {
+ return fmt.Errorf("create accounts table: %w", err)
+ }
+ if _, err := tx.ExecContext(ctx, `CREATE INDEX accounts_username_idx ON accounts(username, uid)`); err != nil {
+ return fmt.Errorf("create accounts index: %w", err)
+ }
+ return nil
+}
+
+func down001(ctx context.Context, tx *sqlx.Tx) error {
+ if _, err := tx.ExecContext(ctx, `DROP INDEX accounts_username_idx`); err != nil {
+ return fmt.Errorf("drop accounts_username_idx index: %w", err)
+ }
+ if _, err := tx.ExecContext(ctx, `DROP TABLE accounts`); err != nil {
+ return fmt.Errorf("drop accounts table: %w", err)
+ }
+ return nil
+}
diff --git a/db/atlasdb/db.go b/db/atlasdb/db.go
new file mode 100644
index 0000000..d58a784
--- /dev/null
+++ b/db/atlasdb/db.go
@@ -0,0 +1,118 @@
+// Package atlasdb implements sqlite3 database storage for accounts and other atlas data.
+package atlasdb
+
+import (
+ "database/sql"
+ "errors"
+ "fmt"
+ "net/netip"
+ "net/url"
+ "time"
+
+ "github.com/jmoiron/sqlx"
+ "github.com/pg9182/atlas/pkg/api/api0"
+)
+
+// DB stores atlas data in a sqlite3 database.
+type DB struct {
+ x *sqlx.DB
+}
+
+// Open opens a DB from the provided sqlite3 filename.
+func Open(name string) (*DB, error) {
+ // note: WAL and a larger cache makes our writes and queries MUCH faster
+ x, err := sqlx.Connect("sqlite3", (&url.URL{
+ Path: name,
+ RawQuery: (url.Values{
+ "_journal": {"WAL"},
+ "_cache_size": {"-32000"},
+ "_busy_timeout": {"4000"},
+ }).Encode(),
+ }).String())
+ if err != nil {
+ return nil, err
+ }
+ return &DB{x}, nil
+}
+
+func (db *DB) Close() error {
+ return db.x.Close()
+}
+
+func (db *DB) GetUIDsByUsername(username string) ([]uint64, error) {
+ var u []uint64
+ if username != "" {
+ if err := db.x.Select(&u, `SELECT uid FROM accounts WHERE username = ?`, username); err != nil {
+ return nil, err
+ }
+ }
+ return u, nil
+}
+
+func (db *DB) GetAccount(uid uint64) (*api0.Account, error) {
+ var obj struct {
+ UID uint64 `db:"uid"`
+ Username string `db:"username"`
+ AuthIP string `db:"auth_ip"`
+ AuthToken string `db:"auth_token"`
+ AuthExpiry int64 `db:"auth_expiry"`
+ LastServer string `db:"last_server"`
+ }
+ if err := db.x.Get(&obj, `SELECT * FROM accounts WHERE uid = ?`, uid); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+ return nil, err
+ }
+
+ var authExpiry time.Time
+ if obj.AuthExpiry != 0 {
+ authExpiry = time.Unix(obj.AuthExpiry, 0)
+ }
+
+ var authIP netip.Addr
+ if obj.AuthIP != "" {
+ if v, err := netip.ParseAddr(obj.AuthIP); err == nil {
+ authIP = v
+ } else {
+ return nil, fmt.Errorf("parse auth_ip: %w", err)
+ }
+ }
+
+ return &api0.Account{
+ UID: obj.UID,
+ Username: obj.Username,
+ AuthIP: authIP,
+ AuthToken: obj.AuthToken,
+ AuthTokenExpiry: authExpiry,
+ LastServerID: obj.LastServer,
+ }, nil
+}
+
+func (db *DB) SaveAccount(a *api0.Account) error {
+ var authExpiry int64
+ if !a.AuthTokenExpiry.IsZero() {
+ authExpiry = a.AuthTokenExpiry.Unix()
+ }
+
+ var authIP string
+ if a.AuthIP.IsValid() {
+ authIP = a.AuthIP.StringExpanded()
+ }
+
+ if _, err := db.x.NamedExec(`
+ INSERT OR REPLACE INTO
+ accounts ( uid, username, auth_ip, auth_token, auth_expiry, last_server)
+ VALUES (:uid, :username, :auth_ip, :auth_token, :auth_expiry, :last_server)
+ `, map[string]any{
+ "uid": a.UID,
+ "username": a.Username,
+ "auth_ip": authIP,
+ "auth_token": a.AuthToken,
+ "auth_expiry": authExpiry,
+ "last_server": a.LastServerID,
+ }); err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/db/atlasdb/db_test.go b/db/atlasdb/db_test.go
new file mode 100644
index 0000000..1af59f1
--- /dev/null
+++ b/db/atlasdb/db_test.go
@@ -0,0 +1,31 @@
+package atlasdb
+
+import (
+ "context"
+ "path/filepath"
+ "testing"
+
+ _ "github.com/mattn/go-sqlite3"
+ "github.com/pg9182/atlas/pkg/api/api0/api0testutil"
+)
+
+func TestAccountStorage(t *testing.T) {
+ db, err := Open(filepath.Join(t.TempDir(), "atlas.db"))
+ if err != nil {
+ panic(err)
+ }
+ defer db.Close()
+
+ cur, tgt, err := db.Version()
+ if err != nil {
+ panic(err)
+ }
+ if cur != 0 {
+ panic("current version not 0")
+ }
+ if err := db.MigrateUp(context.Background(), tgt); err != nil {
+ panic(err)
+ }
+
+ api0testutil.TestAccountStorage(t, db)
+}
diff --git a/db/atlasdb/migrations.go b/db/atlasdb/migrations.go
new file mode 100644
index 0000000..3aa47bb
--- /dev/null
+++ b/db/atlasdb/migrations.go
@@ -0,0 +1,170 @@
+package atlasdb
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "path"
+ "runtime"
+ "sort"
+ "strconv"
+ "strings"
+
+ "github.com/jmoiron/sqlx"
+)
+
+// TODO: support versions which can't be migrated down from
+
+type migration struct {
+ Name string
+ Up func(context.Context, *sqlx.Tx) error
+ Down func(context.Context, *sqlx.Tx) error
+}
+
+var migrations = map[uint64]migration{}
+
+func migrate(up, down func(context.Context, *sqlx.Tx) error) {
+ _, fn, _, ok := runtime.Caller(1)
+ if !ok {
+ panic("add migration: failed to get filename")
+ }
+ fn = path.Base(strings.ReplaceAll(fn, `\`, `/`))
+
+ if n, _, ok := strings.Cut(fn, "_"); !ok {
+ panic("add migration: failed to parse filename")
+ } else if v, err := strconv.ParseUint(n, 10, 64); err != nil {
+ panic("add migration: failed to parse filename: " + err.Error())
+ } else if v == 0 {
+ panic("add migration: version must not be 0")
+ } else {
+ migrations[v] = migration{strings.TrimSuffix(n, ".go"), up, down}
+ }
+}
+
+// Version gets the current and required database versions. It should be checked
+// before using the database.
+func (db *DB) Version() (current, required uint64, err error) {
+ if err = db.x.Get(&current, `PRAGMA user_version`); err != nil {
+ err = fmt.Errorf("get version: %w", err)
+ return
+ }
+ for v := range migrations {
+ if v > required {
+ required = v
+ }
+ }
+ return
+}
+
+// MigrateUp migrates the database to the provided version.
+func (db *DB) MigrateUp(ctx context.Context, to uint64) error {
+ tx, err := db.x.BeginTxx(ctx, &sql.TxOptions{})
+ if err != nil {
+ return fmt.Errorf("begin transaction: %w", err)
+ }
+ defer tx.Rollback()
+
+ var cv uint64
+ if err = tx.GetContext(ctx, &cv, `PRAGMA user_version`); err != nil {
+ return fmt.Errorf("get version: %w", err)
+ }
+ if to < cv {
+ return fmt.Errorf("target version %d is less than current version %d", to, cv)
+ }
+
+ var ms []uint64
+ foundC, foundT := cv == 0, to == 0
+ for v := range migrations {
+ if v == cv {
+ foundC = true
+ }
+ if v == to {
+ foundT = true
+ }
+ if v > cv && v <= to {
+ ms = append(ms, v)
+ }
+ }
+ if !foundC {
+ return fmt.Errorf("unsupported db version %d", cv)
+ }
+ if !foundT {
+ return fmt.Errorf("unknown db version %d", cv)
+ }
+
+ sort.Slice(ms, func(i, j int) bool {
+ return ms[i] < ms[j]
+ })
+
+ for _, v := range ms {
+ if err := migrations[v].Up(ctx, tx); err != nil {
+ return fmt.Errorf("migrate %d: %w", v, err)
+ }
+ }
+
+ if _, err := tx.ExecContext(ctx, `PRAGMA user_version = `+strconv.FormatUint(to, 10)); err != nil {
+ return fmt.Errorf("update version: %w", err)
+ }
+
+ if err := tx.Commit(); err != nil {
+ return fmt.Errorf("commit transaction: %w", err)
+ }
+ return nil
+}
+
+// MigrateDown migrates the database down to the provided version. This will
+// probably eat your data.
+func (db *DB) MigrateDown(ctx context.Context, to uint64) error {
+ tx, err := db.x.BeginTxx(ctx, &sql.TxOptions{})
+ if err != nil {
+ return fmt.Errorf("begin transaction: %w", err)
+ }
+ defer tx.Rollback()
+
+ var cv uint64
+ if err = tx.GetContext(ctx, &cv, `PRAGMA user_version`); err != nil {
+ return fmt.Errorf("get version: %w", err)
+ }
+ if cv < to {
+ return fmt.Errorf("current version %d is less than target version %d", cv, to)
+ }
+
+ var ms []uint64
+ foundC, foundT := cv == 0, to == 0
+ for v := range migrations {
+ if v == cv {
+ foundC = true
+ }
+ if v == to {
+ foundT = true
+ }
+ if v <= cv && v > to {
+ ms = append(ms, v)
+ }
+ }
+ if !foundC {
+ return fmt.Errorf("unsupported db version %d", cv)
+ }
+ if !foundT {
+ return fmt.Errorf("unknown db version %d", cv)
+ }
+
+ sort.Slice(ms, func(i, j int) bool {
+ return ms[i] > ms[j]
+ })
+
+ for _, v := range ms {
+ if err := migrations[v].Down(ctx, tx); err != nil {
+ return fmt.Errorf("migrate %d: %w", v, err)
+ }
+ }
+
+ if _, err := tx.ExecContext(ctx, `PRAGMA user_version = `+strconv.FormatUint(to, 10)); err != nil {
+ return fmt.Errorf("update version: %w", err)
+ }
+
+ if err := tx.Commit(); err != nil {
+ return fmt.Errorf("commit transaction: %w", err)
+ }
+ return nil
+}
diff --git a/db/atlasdb/migrations_test.go b/db/atlasdb/migrations_test.go
new file mode 100644
index 0000000..2149717
--- /dev/null
+++ b/db/atlasdb/migrations_test.go
@@ -0,0 +1,49 @@
+package atlasdb
+
+import (
+ "context"
+ "path/filepath"
+ "sort"
+ "testing"
+
+ _ "github.com/mattn/go-sqlite3"
+)
+
+func TestMigrations(t *testing.T) {
+ db, err := Open(filepath.Join(t.TempDir(), "atlas.db"))
+ if err != nil {
+ panic(err)
+ }
+ defer db.Close()
+
+ cur, _, err := db.Version()
+ if err != nil {
+ panic(err)
+ }
+ if cur != 0 {
+ t.Fatalf("current version not 0")
+ }
+
+ var ms []uint64
+ for m := range migrations {
+ ms = append(ms, m)
+ }
+ sort.Slice(ms, func(i, j int) bool {
+ return ms[i] < ms[j]
+ })
+
+ for _, to := range ms {
+ if err := db.MigrateUp(context.Background(), to); err != nil {
+ t.Fatalf("migrate up to %d: %v", to, err)
+ }
+ if err := db.MigrateDown(context.Background(), 0); err != nil {
+ t.Fatalf("migrate down from %d to 0: %v", to, err)
+ }
+ if err := db.MigrateUp(context.Background(), to); err != nil {
+ t.Fatalf("migrate up to %d again: %v", to, err)
+ }
+ if err := db.MigrateDown(context.Background(), 0); err != nil {
+ t.Fatalf("migrate down from %d to 0 again: %v", to, err)
+ }
+ }
+}