95 lines
2.1 KiB
Go
95 lines
2.1 KiB
Go
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"embed"
|
|
"fmt"
|
|
"sort"
|
|
"strings"
|
|
)
|
|
|
|
// migrationFS holds the files matched by the go:embed pattern below.
// Migrate lists them with ReadDir("migrations") and reads each one with
// ReadFile("migrations/" + name).
//
//go:embed migrations/*.sql
|
|
var migrationFS embed.FS
|
|
|
|
// Migrate applies any pending SQL files from migrations/*.sql in
|
|
// lexicographic order against the given *sql.DB. Applied filenames are
|
|
// tracked in schema_migrations so each runs at most once. Idempotent.
|
|
func Migrate(d *sql.DB) error {
|
|
if _, err := d.Exec(`
|
|
CREATE TABLE IF NOT EXISTS schema_migrations (
|
|
name TEXT PRIMARY KEY,
|
|
applied_at TEXT NOT NULL DEFAULT (datetime('now'))
|
|
)
|
|
`); err != nil {
|
|
return fmt.Errorf("create schema_migrations: %w", err)
|
|
}
|
|
|
|
applied, err := loadApplied(d)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
entries, err := migrationFS.ReadDir("migrations")
|
|
if err != nil {
|
|
return fmt.Errorf("read migrations dir: %w", err)
|
|
}
|
|
names := make([]string, 0, len(entries))
|
|
for _, e := range entries {
|
|
if e.IsDir() || !strings.HasSuffix(e.Name(), ".sql") {
|
|
continue
|
|
}
|
|
names = append(names, e.Name())
|
|
}
|
|
sort.Strings(names)
|
|
|
|
for _, name := range names {
|
|
if applied[name] {
|
|
continue
|
|
}
|
|
body, err := migrationFS.ReadFile("migrations/" + name)
|
|
if err != nil {
|
|
return fmt.Errorf("read %s: %w", name, err)
|
|
}
|
|
if err := runMigration(d, name, string(body)); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// loadApplied returns the set of migration names recorded in
// schema_migrations, as a map from name to true.
//
// If the query, a row scan, or row iteration fails, it returns a nil map
// and an error wrapped with the "load applied" prefix, so the failure is
// attributable to this step and no partial set leaks to the caller.
func loadApplied(d *sql.DB) (map[string]bool, error) {
	rows, err := d.Query("SELECT name FROM schema_migrations")
	if err != nil {
		return nil, fmt.Errorf("load applied: %w", err)
	}
	defer rows.Close()

	out := map[string]bool{}
	for rows.Next() {
		var n string
		if err := rows.Scan(&n); err != nil {
			return nil, fmt.Errorf("load applied: scan: %w", err)
		}
		out[n] = true
	}
	// rows.Next returns false on an iteration error as well as at the end
	// of the result set, so rows.Err must be checked before trusting out.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("load applied: iterate: %w", err)
	}
	return out, nil
}
|
|
|
|
// runMigration applies a single migration inside one transaction.
//
// It begins a transaction on d, executes body, inserts name into
// schema_migrations, and commits. If executing body or recording name
// fails, the transaction is rolled back. Each failure is returned wrapped
// with the stage ("begin", "apply", "record" or "commit") and the
// migration name.
func runMigration(d *sql.DB, name, body string) error {
	tx, err := d.Begin()
	if err != nil {
		return fmt.Errorf("begin %s: %w", name, err)
	}
	// Best-effort cleanup for the early-return paths below. Once Commit has
	// been attempted the transaction is finished and this call only reports
	// sql.ErrTxDone, so the result is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	_, err = tx.Exec(body)
	if err != nil {
		return fmt.Errorf("apply %s: %w", name, err)
	}

	_, err = tx.Exec("INSERT INTO schema_migrations (name) VALUES (?)", name)
	if err != nil {
		return fmt.Errorf("record %s: %w", name, err)
	}

	err = tx.Commit()
	if err != nil {
		return fmt.Errorf("commit %s: %w", name, err)
	}
	return nil
}
|