diff options
author | Joe Chen <jc@unknwon.io> | 2022-06-11 11:10:25 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-06-11 11:10:25 +0800 |
commit | 75fbb8244086a2ad964d1c51e3bdbdfb95df90ac (patch) | |
tree | c97a7d8b17a4a351f304bfd4a7dd01a75711a1c7 /internal/db | |
parent | f837ea6346ac720d586eda51ab89c5c71a1f3e65 (diff) |
db: use `context` for backup and restore (#7044)
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/backup.go | 61 | ||||
-rw-r--r-- | internal/db/backup_test.go | 20 |
2 files changed, 53 insertions, 28 deletions
diff --git a/internal/db/backup.go b/internal/db/backup.go index 836672b1..bd13fbb4 100644 --- a/internal/db/backup.go +++ b/internal/db/backup.go @@ -3,6 +3,7 @@ package db import ( "bufio" "bytes" + "context" "fmt" "io" "os" @@ -30,18 +31,24 @@ func getTableType(t interface{}) string { } // DumpDatabase dumps all data from database to file system in JSON Lines format. -func DumpDatabase(db *gorm.DB, dirPath string, verbose bool) error { +func DumpDatabase(ctx context.Context, db *gorm.DB, dirPath string, verbose bool) error { err := os.MkdirAll(dirPath, os.ModePerm) if err != nil { return err } - err = dumpLegacyTables(dirPath, verbose) + err = dumpLegacyTables(ctx, dirPath, verbose) if err != nil { return errors.Wrap(err, "dump legacy tables") } for _, table := range Tables { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + tableName := getTableType(table) if verbose { log.Trace("Dumping table %q...", tableName) @@ -55,7 +62,7 @@ func DumpDatabase(db *gorm.DB, dirPath string, verbose bool) error { } defer func() { _ = f.Close() }() - return dumpTable(db, table, f) + return dumpTable(ctx, db, table, f) }() if err != nil { return errors.Wrapf(err, "dump table %q", tableName) @@ -65,11 +72,13 @@ func DumpDatabase(db *gorm.DB, dirPath string, verbose bool) error { return nil } -func dumpTable(db *gorm.DB, table interface{}, w io.Writer) error { - query := db.Model(table).Order("id ASC") +func dumpTable(ctx context.Context, db *gorm.DB, table interface{}, w io.Writer) error { + query := db.WithContext(ctx).Model(table) switch table.(type) { case *LFSObject: - query = db.Model(table).Order("repo_id, oid ASC") + query = query.Order("repo_id, oid ASC") + default: + query = query.Order("id ASC") } rows, err := query.Rows() @@ -98,10 +107,16 @@ func dumpTable(db *gorm.DB, table interface{}, w io.Writer) error { return rows.Err() } -func dumpLegacyTables(dirPath string, verbose bool) error { +func dumpLegacyTables(ctx context.Context, dirPath string, verbose bool) error { // Purposely create a local variable to not modify global variable legacyTables := append(legacyTables, new(Version)) for _, table := range legacyTables { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + tableName := getTableType(table) if verbose { log.Trace("Dumping table %q...", tableName) @@ -113,7 +128,7 @@ func dumpLegacyTables(dirPath string, verbose bool) error { return fmt.Errorf("create JSON file: %v", err) } - if err = x.Asc("id").Iterate(table, func(idx int, bean interface{}) (err error) { + if err = x.Context(ctx).Asc("id").Iterate(table, func(idx int, bean interface{}) (err error) { return jsoniter.NewEncoder(f).Encode(bean) }); err != nil { _ = f.Close() @@ -125,13 +140,19 @@ func dumpLegacyTables(dirPath string, verbose bool) error { } // ImportDatabase imports data from backup archive in JSON Lines format. -func ImportDatabase(db *gorm.DB, dirPath string, verbose bool) error { - err := importLegacyTables(dirPath, verbose) +func ImportDatabase(ctx context.Context, db *gorm.DB, dirPath string, verbose bool) error { + err := importLegacyTables(ctx, dirPath, verbose) if err != nil { return errors.Wrap(err, "import legacy tables") } for _, table := range Tables { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + tableName := strings.TrimPrefix(fmt.Sprintf("%T", table), "*db.") err := func() error { tableFile := filepath.Join(dirPath, tableName+".json") @@ -150,7 +171,7 @@ func ImportDatabase(db *gorm.DB, dirPath string, verbose bool) error { } defer func() { _ = f.Close() }() - return importTable(db, table, f) + return importTable(ctx, db, table, f) }() if err != nil { return errors.Wrapf(err, "import table %q", tableName) @@ -160,13 +181,13 @@ func ImportDatabase(db *gorm.DB, dirPath string, verbose bool) error { return nil } -func importTable(db *gorm.DB, table interface{}, r io.Reader) error { - err := db.Migrator().DropTable(table) +func importTable(ctx context.Context, db *gorm.DB, table interface{}, r io.Reader) error { + err := db.WithContext(ctx).Migrator().DropTable(table) if err != nil { return errors.Wrap(err, "drop table") } - err = db.Migrator().AutoMigrate(table) + err = db.WithContext(ctx).Migrator().AutoMigrate(table) if err != nil { return errors.Wrap(err, "auto migrate") } @@ -191,7 +212,7 @@ func importTable(db *gorm.DB, table interface{}, r io.Reader) error { return errors.Wrap(err, "unmarshal JSON to struct") } - err = db.Create(elem).Error + err = db.WithContext(ctx).Create(elem).Error if err != nil { return errors.Wrap(err, "create row") } @@ -200,14 +221,14 @@ func importTable(db *gorm.DB, table interface{}, r io.Reader) error { // PostgreSQL needs manually reset table sequence for auto increment keys if conf.UsePostgreSQL && !skipResetIDSeq[rawTableName] { seqName := rawTableName + "_id_seq" - if _, err = x.Exec(fmt.Sprintf(`SELECT setval('%s', COALESCE((SELECT MAX(id)+1 FROM "%s"), 1), false);`, seqName, rawTableName)); err != nil { + if _, err = x.Context(ctx).Exec(fmt.Sprintf(`SELECT setval('%s', COALESCE((SELECT MAX(id)+1 FROM "%s"), 1), false);`, seqName, rawTableName)); err != nil { return errors.Wrapf(err, "reset table %q.%q", rawTableName, seqName) } } return nil } -func importLegacyTables(dirPath string, verbose bool) error { +func importLegacyTables(ctx context.Context, dirPath string, verbose bool) error { snakeMapper := core.SnakeMapper{} skipInsertProcessors := map[string]bool{ @@ -218,6 +239,12 @@ func importLegacyTables(dirPath string, verbose bool) error { // Purposely create a local variable to not modify global variable legacyTables := append(legacyTables, new(Version)) for _, table := range legacyTables { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + tableName := strings.TrimPrefix(fmt.Sprintf("%T", table), "*db.") tableFile := filepath.Join(dirPath, tableName+".json") if !osutil.IsFile(tableFile) { diff --git a/internal/db/backup_test.go b/internal/db/backup_test.go index e30ac7a2..047a2dca 100644 --- a/internal/db/backup_test.go +++ b/internal/db/backup_test.go @@ -6,12 +6,14 @@ package db import ( "bytes" + "context" "os" "path/filepath" "testing" "time" "github.com/pkg/errors" + "github.com/stretchr/testify/require" "gorm.io/gorm" "gogs.io/gogs/internal/auth" @@ -22,7 +24,7 @@ import ( "gogs.io/gogs/internal/testutil" ) -func Test_dumpAndImport(t *testing.T) { +func TestDumpAndImport(t *testing.T) { if testing.Short() { t.Skip() } @@ -43,8 +45,6 @@ func Test_dumpAndImport(t *testing.T) { } func setupDBToDump(t *testing.T, db *gorm.DB) { - t.Helper() - vals := []interface{}{ &Access{ ID: 1, @@ -126,31 +126,29 @@ func setupDBToDump(t *testing.T, db *gorm.DB) { } for _, val := range vals { err := db.Create(val).Error - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) } } func dumpTables(t *testing.T, db *gorm.DB) { - t.Helper() + ctx := context.Background() for _, table := range Tables { tableName := getTableType(table) var buf bytes.Buffer - err := dumpTable(db, table, &buf) + err := dumpTable(ctx, db, table, &buf) if err != nil { t.Fatalf("%s: %v", tableName, err) } golden := filepath.Join("testdata", "backup", tableName+".golden.json") - testutil.AssertGolden(t, golden, testutil.Update("Test_dumpAndImport"), buf.String()) + testutil.AssertGolden(t, golden, testutil.Update("TestDumpAndImport"), buf.String()) } } func importTables(t *testing.T, db *gorm.DB) { - t.Helper() + ctx := context.Background() for _, table := range Tables { tableName := getTableType(table) @@ -163,7 +161,7 @@ func importTables(t *testing.T, db *gorm.DB) { } defer func() { _ = f.Close() }() - return importTable(db, table, f) + return importTable(ctx, db, table, f) }() if err != nil { t.Fatalf("%s: %v", tableName, err) |