1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
|
package atlasdb
import (
"context"
"database/sql"
"fmt"
"path"
"runtime"
"sort"
"strconv"
"strings"
"github.com/jmoiron/sqlx"
)
// migration is one registered schema change. MigrateUp invokes Up and
// MigrateDown invokes Down, each passing the transaction it opened. Name is
// filled in by migrate from the registering source file's name.
//
// TODO: support versions which can't be migrated down from
type migration struct {
	Name     string
	Up, Down func(context.Context, *sqlx.Tx) error
}
// migrations is the registry of known migrations, keyed by version number.
// Entries are added by migrate.
var migrations = make(map[uint64]migration)
// migrate registers up/down functions for the migration defined in the
// calling source file. The caller's filename must have the form
// "<version>_<name>.go", where <version> is a non-zero decimal integer small
// enough to be stored in SQLite's 32-bit signed user_version. The migration's
// Name is set to <name> (the part after the first "_", without ".go").
//
// migrate panics if the caller's filename cannot be determined or parsed, if
// the version is out of range, or if the version was already registered.
func migrate(up, down func(context.Context, *sqlx.Tx) error) {
	// user_version is a 32-bit signed integer in the SQLite database header, so
	// larger versions could not be recorded by MigrateUp/MigrateDown.
	const maxUserVersion = 1<<31 - 1

	_, fn, _, ok := runtime.Caller(1)
	if !ok {
		panic("add migration: failed to get filename")
	}
	// Normalize backslashes so path.Base works regardless of the separator
	// used in the reported path.
	fn = path.Base(strings.ReplaceAll(fn, `\`, `/`))

	n, rest, ok := strings.Cut(fn, "_")
	if !ok {
		panic("add migration: failed to parse filename")
	}
	v, err := strconv.ParseUint(n, 10, 64)
	switch {
	case err != nil:
		panic("add migration: failed to parse filename: " + err.Error())
	case v == 0:
		panic("add migration: version must not be 0")
	case v > maxUserVersion:
		panic("add migration: version " + strconv.FormatUint(v, 10) +
			" exceeds maximum sqlite user_version " + strconv.FormatUint(maxUserVersion, 10))
	}
	// Prefixes such as "1" and "001" parse to the same version; refuse to let a
	// second file silently replace the first.
	if _, dup := migrations[v]; dup {
		panic("add migration: duplicate version " + strconv.FormatUint(v, 10))
	}
	migrations[v] = migration{
		Name: strings.TrimSuffix(rest, ".go"),
		Up:   up,
		Down: down,
	}
}
// Version gets the current and required database versions. It should be checked
// before using the database.
func (db *DB) Version() (current, required uint64, err error) {
if err = db.x.Get(¤t, `PRAGMA user_version`); err != nil {
err = fmt.Errorf("get version: %w", err)
return
}
for v := range migrations {
if v > required {
required = v
}
}
return
}
// MigrateUp migrates the database up to the provided version. Inside a single
// transaction it applies the Up function of every registered migration whose
// version is greater than the current version and at most to, in ascending
// order, then sets user_version to to and commits.
//
// Version 0 is always accepted for both the current and the target version.
// It returns an error if to is below the current version, if either version is
// neither 0 nor a registered migration, if a migration has no Up function, or
// if any step fails. On error the transaction is rolled back.
func (db *DB) MigrateUp(ctx context.Context, to uint64) error {
	tx, err := db.x.BeginTxx(ctx, &sql.TxOptions{})
	if err != nil {
		return fmt.Errorf("begin transaction: %w", err)
	}
	// Ensures rollback on every early return; after a successful Commit the
	// rollback error is ignored.
	defer tx.Rollback()

	var cv uint64
	if err = tx.GetContext(ctx, &cv, `PRAGMA user_version`); err != nil {
		return fmt.Errorf("get version: %w", err)
	}
	if to < cv {
		return fmt.Errorf("target version %d is less than current version %d", to, cv)
	}

	var ms []uint64
	foundC, foundT := cv == 0, to == 0
	for v := range migrations {
		if v == cv {
			foundC = true
		}
		if v == to {
			foundT = true
		}
		if v > cv && v <= to {
			ms = append(ms, v)
		}
	}
	if !foundC {
		return fmt.Errorf("unsupported db version %d", cv)
	}
	if !foundT {
		// Report the unknown target, not the (valid) current version.
		return fmt.Errorf("unknown db version %d", to)
	}

	// Map iteration order is random; apply migrations oldest first.
	sort.Slice(ms, func(i, j int) bool {
		return ms[i] < ms[j]
	})
	for _, v := range ms {
		m := migrations[v]
		if m.Up == nil {
			return fmt.Errorf("migrate %d: no up migration defined", v)
		}
		if err := m.Up(ctx, tx); err != nil {
			return fmt.Errorf("migrate %d: %w", v, err)
		}
	}

	// PRAGMA statements do not accept bound parameters; to is a formatted
	// integer, so concatenation is safe here.
	if _, err := tx.ExecContext(ctx, `PRAGMA user_version = `+strconv.FormatUint(to, 10)); err != nil {
		return fmt.Errorf("update version: %w", err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit transaction: %w", err)
	}
	return nil
}
// MigrateDown migrates the database down to the provided version. This will
// probably eat your data. Inside a single transaction it runs the Down function
// of every registered migration whose version is at most the current version
// and greater than to, in descending order, then sets user_version to to and
// commits.
//
// Version 0 is always accepted for both the current and the target version.
// It returns an error if the current version is below to, if either version is
// neither 0 nor a registered migration, if a migration in range cannot be
// reverted (nil Down), or if any step fails. On error the transaction is
// rolled back.
func (db *DB) MigrateDown(ctx context.Context, to uint64) error {
	tx, err := db.x.BeginTxx(ctx, &sql.TxOptions{})
	if err != nil {
		return fmt.Errorf("begin transaction: %w", err)
	}
	// Ensures rollback on every early return; after a successful Commit the
	// rollback error is ignored.
	defer tx.Rollback()

	var cv uint64
	if err = tx.GetContext(ctx, &cv, `PRAGMA user_version`); err != nil {
		return fmt.Errorf("get version: %w", err)
	}
	if cv < to {
		return fmt.Errorf("current version %d is less than target version %d", cv, to)
	}

	var ms []uint64
	foundC, foundT := cv == 0, to == 0
	for v := range migrations {
		if v == cv {
			foundC = true
		}
		if v == to {
			foundT = true
		}
		if v <= cv && v > to {
			ms = append(ms, v)
		}
	}
	if !foundC {
		return fmt.Errorf("unsupported db version %d", cv)
	}
	if !foundT {
		// Report the unknown target, not the (valid) current version.
		return fmt.Errorf("unknown db version %d", to)
	}

	// Map iteration order is random; revert migrations newest first.
	sort.Slice(ms, func(i, j int) bool {
		return ms[i] > ms[j]
	})
	for _, v := range ms {
		m := migrations[v]
		if m.Down == nil {
			// Calling a nil func would panic; surface it as an error instead.
			return fmt.Errorf("migrate %d: migration cannot be reverted", v)
		}
		if err := m.Down(ctx, tx); err != nil {
			return fmt.Errorf("migrate %d: %w", v, err)
		}
	}

	// PRAGMA statements do not accept bound parameters; to is a formatted
	// integer, so concatenation is safe here.
	if _, err := tx.ExecContext(ctx, `PRAGMA user_version = `+strconv.FormatUint(to, 10)); err != nil {
		return fmt.Errorf("update version: %w", err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit transaction: %w", err)
	}
	return nil
}
|