Files
CableGUI/internal/db/migrate.go

95 lines
2.1 KiB
Go

package db
import (
"database/sql"
"embed"
"fmt"
"sort"
"strings"
)
// migrationFS holds the SQL migration scripts compiled into the binary:
// every file matching migrations/*.sql (relative to this package's source
// directory) is embedded at build time. Migrate lists and reads the scripts
// from here. The build fails if the pattern matches no files.
//go:embed migrations/*.sql
var migrationFS embed.FS
// Migrate brings the database schema up to date using the SQL scripts
// embedded from migrations/*.sql.
//
// It first ensures the schema_migrations bookkeeping table exists. It then
// executes every embedded .sql file whose name is not yet recorded there,
// in lexicographic (byte-wise) filename order. Each file is handed to
// runMigration, which runs the script and records its name in one
// transaction. A file that succeeded is therefore never executed again, and
// calling Migrate repeatedly is safe.
//
// Migrate stops at the first error and returns it. Files applied earlier in
// the same call remain committed.
func Migrate(d *sql.DB) error {
	// Bookkeeping table: one row per applied migration file.
	const trackingTableDDL = `
CREATE TABLE IF NOT EXISTS schema_migrations (
name TEXT PRIMARY KEY,
applied_at TEXT NOT NULL DEFAULT (datetime('now'))
)
`
	if _, err := d.Exec(trackingTableDDL); err != nil {
		return fmt.Errorf("create schema_migrations: %w", err)
	}

	alreadyApplied, err := loadApplied(d)
	if err != nil {
		return err
	}

	dirEntries, err := migrationFS.ReadDir("migrations")
	if err != nil {
		return fmt.Errorf("read migrations dir: %w", err)
	}

	// Collect the scripts that still need to run. Filtering before sorting
	// yields the same order as sorting first, because the applied set does
	// not change during this call.
	var pending []string
	for _, entry := range dirEntries {
		fileName := entry.Name()
		switch {
		case entry.IsDir(), !strings.HasSuffix(fileName, ".sql"):
			// Not a migration script.
		case alreadyApplied[fileName]:
			// Recorded in schema_migrations by an earlier run.
		default:
			pending = append(pending, fileName)
		}
	}
	sort.Strings(pending)

	for _, fileName := range pending {
		// embed.FS paths always use forward slashes, regardless of OS.
		script, readErr := migrationFS.ReadFile("migrations/" + fileName)
		if readErr != nil {
			return fmt.Errorf("read %s: %w", fileName, readErr)
		}
		if err := runMigration(d, fileName, string(script)); err != nil {
			return err
		}
	}
	return nil
}
// loadApplied returns the set of migration filenames already recorded in
// the schema_migrations table, keyed by filename with value true.
//
// Every failure is wrapped with a "load applied" prefix: the query itself,
// scanning a row, or an error surfaced while iterating. The returned map is
// nil whenever the error is non-nil, so callers never see a partially read
// set.
func loadApplied(d *sql.DB) (map[string]bool, error) {
	rows, err := d.Query("SELECT name FROM schema_migrations")
	if err != nil {
		return nil, fmt.Errorf("load applied: %w", err)
	}
	defer rows.Close()

	out := make(map[string]bool)
	for rows.Next() {
		var n string
		if err := rows.Scan(&n); err != nil {
			return nil, fmt.Errorf("load applied: scan: %w", err)
		}
		out[n] = true
	}
	// rows.Next returning false can mean either exhaustion or a failure;
	// rows.Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("load applied: iterate: %w", err)
	}
	return out, nil
}
// runMigration executes one migration script and records its filename in
// schema_migrations, both inside a single transaction. The script's effects
// and its bookkeeping row are committed together or rolled back together.
//
// name is the tracking key and also labels every error ("begin", "apply",
// "record", "commit"). body is passed to Exec unchanged. Whether several
// statements in one Exec are accepted depends on the SQL driver in use
// (NOTE(review): presumably SQLite given datetime('now') elsewhere; confirm).
func runMigration(d *sql.DB, name, body string) error {
	txn, err := d.Begin()
	if err != nil {
		return fmt.Errorf("begin %s: %w", name, err)
	}
	// One deferred rollback covers every early return below. After Commit
	// has been attempted the transaction is already finished, so this call
	// only returns sql.ErrTxDone, which is deliberately discarded.
	defer func() { _ = txn.Rollback() }()

	if _, err = txn.Exec(body); err != nil {
		return fmt.Errorf("apply %s: %w", name, err)
	}

	const recordSQL = "INSERT INTO schema_migrations (name) VALUES (?)"
	if _, err = txn.Exec(recordSQL, name); err != nil {
		return fmt.Errorf("record %s: %w", name, err)
	}

	if err = txn.Commit(); err != nil {
		err = fmt.Errorf("commit %s: %w", name, err)
	}
	return err
}