-
Notifications
You must be signed in to change notification settings - Fork 0
/
migrations.go
186 lines (165 loc) · 5.03 KB
/
migrations.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
package migrations
import (
"database/sql"
"sort"
"time"
"fmt"
)
// Bookkeeping-table definitions used by the migrator.
const (
	// migrationTableName is the table in which applied migration files are
	// recorded.
	migrationTableName = "__migrations"

	// createMigrationTableSQL is the DDL statement that creates the
	// bookkeeping table: one row per migration file, keyed by file name.
	createMigrationTableSQL = `
CREATE TABLE __migrations (
file VARCHAR(255) NOT NULL,
timestamp DATETIME NOT NULL,
PRIMARY KEY (file));
`
)
// New returns a Migrator that runs against db.
//
// It panics when db is nil. When the migration bookkeeping table is not found
// it asks the migrator to create it before returning.
func New(db *sql.DB) *Migrator {
	if db == nil {
		panic("sql.DB should'nt be nil.")
	}

	m := &Migrator{connection: db}
	if m.migrationsTableExists() {
		return m
	}

	// First use against this database: set up the bookkeeping table.
	m.createMigrationTable()
	return m
}
// Migrator holds the database handle that migrations are run against.
// Create one with New and call Migrate to apply the pending migration files.
type Migrator struct {
// connection is the database handle used for every query and statement
// the migrator issues.
connection *sql.DB
}
// GetFiles is the callback through which the migrator obtains the names of the
// candidate migration files. The user is expected to implement this function,
// for example by listing a directory with `ioutil.ReadDir`
// (https://golang.org/pkg/io/ioutil/#ReadDir) and collecting the name of each
// entry, or by calling `assets.AssetDir("migrations")`.
//
// Migrate sorts the returned slice in place with sort.Strings, so the
// migrations run in lexical order of the returned names.
type GetFiles func() []string
// GetContent is the callback through which the migrator obtains the content of
// a migration file. The user is expected to implement this function. The
// argument is always one of the strings returned by GetFiles, and the returned
// string is executed by Migrate as a SQL statement.
//
// NOTE(review): the signature has no error return, so an implementation has to
// deal with read failures itself — confirm the intended convention.
type GetContent func(string) string
// migration is one row of the migration bookkeeping table: a migration file
// that has been applied and the moment it was recorded.
type migration struct {
// file is the migration file name, as returned by GetFiles.
file string
// timestamp is the time stored alongside the file name; Migrate records
// it in UTC.
timestamp time.Time
}
// Migrate applies every migration file that is not yet recorded in the
// migration table.
//
// Steps:
//   - get the candidate file names from getFiles and sort them lexically
//     (the returned slice is sorted in place)
//   - load the file names already recorded in the migration table
//   - for every candidate that is not recorded, execute the SQL returned by
//     getContent and record the file name with the current UTC time
//   - commit
//
// Every statement is issued through a single transaction, so a failure rolls
// back the work done so far and Commit covers what was executed.
// NOTE(review): some engines commit implicitly on DDL statements, in which
// case a rollback cannot undo schema changes — confirm for the target database.
//
// Migrate panics when the transaction cannot be started or committed, when a
// migration fails to execute, or when it cannot be recorded. For the last two
// cases the panic value is the underlying error; if the rollback fails as
// well, its error is appended to the message while the original error stays
// reachable through errors.Is/errors.As.
func (migrator *Migrator) Migrate(getFiles GetFiles, getContent GetContent) {
	startTime := time.Now().UTC()
	logDebug("Starting migration: ", startTime)

	fileNames := getFiles()
	sort.Strings(fileNames)

	// Read the bookkeeping table before opening the transaction so this query
	// never competes with the transaction for a pooled connection.
	applied := make(map[string]struct{})
	for _, m := range migrator.getExistingMigrations() {
		applied[m.file] = struct{}{}
	}

	tx, err := migrator.connection.Begin()
	if err != nil {
		panic("Failed to create transaction for migration: " + err.Error())
	}

	// abort rolls the transaction back and panics with cause, keeping the
	// rollback error (if any) visible instead of dropping it.
	abort := func(cause error) {
		if rbErr := tx.Rollback(); rbErr != nil {
			cause = fmt.Errorf("%w (rollback also failed: %v)", cause, rbErr)
		}
		panic(cause)
	}

	insertSQL := fmt.Sprintf("INSERT INTO %s(file, timestamp) VALUES(?,?)", migrationTableName)

	logDebug("All migrations:", fileNames)
	for _, f := range fileNames {
		if _, done := applied[f]; done {
			continue
		}

		sqlContent := getContent(f)
		logDebug("Running migration: ", f)
		logDebug("With content: ", sqlContent)
		timestamp := time.Now().UTC()

		// Run the migration and its bookkeeping insert on the transaction,
		// not on the pool, so Rollback/Commit actually govern them.
		if _, err := tx.Exec(sqlContent); err != nil {
			logError("Failed to execute migration: ", f, err)
			abort(err)
		}
		if _, err := tx.Exec(insertSQL, f, timestamp); err != nil {
			logError("Failed to update migration table: ", err)
			abort(err)
		}

		// Guard against getFiles returning the same name twice: the second
		// run would otherwise violate the table's primary key.
		applied[f] = struct{}{}
	}

	if err := tx.Commit(); err != nil {
		panic(err)
	}

	endTime := time.Now().UTC()
	logDebug("Migration done: ", endTime)
	logDebug("Migration duration: ", endTime.Sub(startTime))
}
// migrationsTableExists reports whether a table named migrationTableName shows
// up in the result of "SHOW TABLES".
//
// Failures are logged rather than returned: when the query itself fails the
// function reports false, and rows that cannot be scanned are skipped.
func (migrator *Migrator) migrationsTableExists() bool {
	rows, err := migrator.connection.Query("SHOW TABLES")
	if err != nil {
		logError("Couldn't query for tables", err)
		// rows is nil when Query fails; iterating or closing it would
		// dereference a nil pointer.
		return false
	}
	defer func() {
		if err := rows.Close(); err != nil {
			fmt.Println(fmt.Errorf("failed to close the migrationsTableExists query: %w", err))
		}
	}()

	for rows.Next() {
		var tableName string
		if err := rows.Scan(&tableName); err != nil {
			logError("Failed to read file item row: ", err)
			continue
		}
		if tableName == migrationTableName {
			return true
		}
	}
	// Next returns false both at the end of the result set and on error;
	// surface the latter instead of treating it as "table not found" silently.
	if err := rows.Err(); err != nil {
		logError("Failed to iterate over tables: ", err)
	}
	return false
}
// createMigrationTable executes createMigrationTableSQL on the migrator's
// connection. A failure is logged and not returned to the caller.
func (migrator *Migrator) createMigrationTable() {
	if _, execErr := migrator.connection.Exec(createMigrationTableSQL); execErr != nil {
		logError("Failed to create migration table: " + execErr.Error())
	}
}
// addMigration inserts one row (file name and timestamp) into the migration
// table and returns any error from the database, wrapped with the file name.
//
// The insert is executed directly with bound parameters instead of through a
// separately prepared statement, so no statement handle is left open.
func (migrator *Migrator) addMigration(migration migration) error {
	query := fmt.Sprintf("INSERT INTO %s(file, timestamp) VALUES(?,?)", migrationTableName)
	if _, err := migrator.connection.Exec(query, migration.file, migration.timestamp); err != nil {
		return fmt.Errorf("recording migration %q: %w", migration.file, err)
	}
	return nil
}
func (migrator *Migrator) getExistingMigrations() []migration {
rows, err := migrator.connection.Query(fmt.Sprintf("SELECT file, timestamp FROM %s", migrationTableName))
if err != nil {
panic("Failed to create migration select statement: " + err.Error())
}
defer func(rows *sql.Rows) {
err := rows.Close()
if err != nil {
fmt.Println(fmt.Errorf("failed to close the getExistingMigrations query: %w", err))
}
}(rows)
migrations := make([]migration, 0, 10)
for rows.Next() {
var (
file string
timestamp time.Time
)
err = rows.Scan(&file, ×tamp)
if err != nil {
panic("Failed to scan migration row: " + err.Error())
}
migrations = append(migrations, migration{file: file, timestamp: timestamp})
}
return migrations
}