aboutsummaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
authorJoe Chen <jc@unknwon.io>2022-06-11 11:10:25 +0800
committerGitHub <noreply@github.com>2022-06-11 11:10:25 +0800
commit75fbb8244086a2ad964d1c51e3bdbdfb95df90ac (patch)
treec97a7d8b17a4a351f304bfd4a7dd01a75711a1c7 /internal/db
parentf837ea6346ac720d586eda51ab89c5c71a1f3e65 (diff)
db: use `context` for backup and restore (#7044)
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/backup.go61
-rw-r--r--internal/db/backup_test.go20
2 files changed, 53 insertions, 28 deletions
diff --git a/internal/db/backup.go b/internal/db/backup.go
index 836672b1..bd13fbb4 100644
--- a/internal/db/backup.go
+++ b/internal/db/backup.go
@@ -3,6 +3,7 @@ package db
import (
"bufio"
"bytes"
+ "context"
"fmt"
"io"
"os"
@@ -30,18 +31,24 @@ func getTableType(t interface{}) string {
}
// DumpDatabase dumps all data from database to file system in JSON Lines format.
-func DumpDatabase(db *gorm.DB, dirPath string, verbose bool) error {
+func DumpDatabase(ctx context.Context, db *gorm.DB, dirPath string, verbose bool) error {
err := os.MkdirAll(dirPath, os.ModePerm)
if err != nil {
return err
}
- err = dumpLegacyTables(dirPath, verbose)
+ err = dumpLegacyTables(ctx, dirPath, verbose)
if err != nil {
return errors.Wrap(err, "dump legacy tables")
}
for _, table := range Tables {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ }
+
tableName := getTableType(table)
if verbose {
log.Trace("Dumping table %q...", tableName)
@@ -55,7 +62,7 @@ func DumpDatabase(db *gorm.DB, dirPath string, verbose bool) error {
}
defer func() { _ = f.Close() }()
- return dumpTable(db, table, f)
+ return dumpTable(ctx, db, table, f)
}()
if err != nil {
return errors.Wrapf(err, "dump table %q", tableName)
@@ -65,11 +72,13 @@ func DumpDatabase(db *gorm.DB, dirPath string, verbose bool) error {
return nil
}
-func dumpTable(db *gorm.DB, table interface{}, w io.Writer) error {
- query := db.Model(table).Order("id ASC")
+func dumpTable(ctx context.Context, db *gorm.DB, table interface{}, w io.Writer) error {
+ query := db.WithContext(ctx).Model(table)
switch table.(type) {
case *LFSObject:
- query = db.Model(table).Order("repo_id, oid ASC")
+ query = query.Order("repo_id, oid ASC")
+ default:
+ query = query.Order("id ASC")
}
rows, err := query.Rows()
@@ -98,10 +107,16 @@ func dumpTable(db *gorm.DB, table interface{}, w io.Writer) error {
return rows.Err()
}
-func dumpLegacyTables(dirPath string, verbose bool) error {
+func dumpLegacyTables(ctx context.Context, dirPath string, verbose bool) error {
// Purposely create a local variable to not modify global variable
legacyTables := append(legacyTables, new(Version))
for _, table := range legacyTables {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ }
+
tableName := getTableType(table)
if verbose {
log.Trace("Dumping table %q...", tableName)
@@ -113,7 +128,7 @@ func dumpLegacyTables(dirPath string, verbose bool) error {
return fmt.Errorf("create JSON file: %v", err)
}
- if err = x.Asc("id").Iterate(table, func(idx int, bean interface{}) (err error) {
+ if err = x.Context(ctx).Asc("id").Iterate(table, func(idx int, bean interface{}) (err error) {
return jsoniter.NewEncoder(f).Encode(bean)
}); err != nil {
_ = f.Close()
@@ -125,13 +140,19 @@ func dumpLegacyTables(dirPath string, verbose bool) error {
}
// ImportDatabase imports data from backup archive in JSON Lines format.
-func ImportDatabase(db *gorm.DB, dirPath string, verbose bool) error {
- err := importLegacyTables(dirPath, verbose)
+func ImportDatabase(ctx context.Context, db *gorm.DB, dirPath string, verbose bool) error {
+ err := importLegacyTables(ctx, dirPath, verbose)
if err != nil {
return errors.Wrap(err, "import legacy tables")
}
for _, table := range Tables {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ }
+
tableName := strings.TrimPrefix(fmt.Sprintf("%T", table), "*db.")
err := func() error {
tableFile := filepath.Join(dirPath, tableName+".json")
@@ -150,7 +171,7 @@ func ImportDatabase(db *gorm.DB, dirPath string, verbose bool) error {
}
defer func() { _ = f.Close() }()
- return importTable(db, table, f)
+ return importTable(ctx, db, table, f)
}()
if err != nil {
return errors.Wrapf(err, "import table %q", tableName)
@@ -160,13 +181,13 @@ func ImportDatabase(db *gorm.DB, dirPath string, verbose bool) error {
return nil
}
-func importTable(db *gorm.DB, table interface{}, r io.Reader) error {
- err := db.Migrator().DropTable(table)
+func importTable(ctx context.Context, db *gorm.DB, table interface{}, r io.Reader) error {
+ err := db.WithContext(ctx).Migrator().DropTable(table)
if err != nil {
return errors.Wrap(err, "drop table")
}
- err = db.Migrator().AutoMigrate(table)
+ err = db.WithContext(ctx).Migrator().AutoMigrate(table)
if err != nil {
return errors.Wrap(err, "auto migrate")
}
@@ -191,7 +212,7 @@ func importTable(db *gorm.DB, table interface{}, r io.Reader) error {
return errors.Wrap(err, "unmarshal JSON to struct")
}
- err = db.Create(elem).Error
+ err = db.WithContext(ctx).Create(elem).Error
if err != nil {
return errors.Wrap(err, "create row")
}
@@ -200,14 +221,14 @@ func importTable(db *gorm.DB, table interface{}, r io.Reader) error {
// PostgreSQL needs manually reset table sequence for auto increment keys
if conf.UsePostgreSQL && !skipResetIDSeq[rawTableName] {
seqName := rawTableName + "_id_seq"
- if _, err = x.Exec(fmt.Sprintf(`SELECT setval('%s', COALESCE((SELECT MAX(id)+1 FROM "%s"), 1), false);`, seqName, rawTableName)); err != nil {
+ if _, err = x.Context(ctx).Exec(fmt.Sprintf(`SELECT setval('%s', COALESCE((SELECT MAX(id)+1 FROM "%s"), 1), false);`, seqName, rawTableName)); err != nil {
return errors.Wrapf(err, "reset table %q.%q", rawTableName, seqName)
}
}
return nil
}
-func importLegacyTables(dirPath string, verbose bool) error {
+func importLegacyTables(ctx context.Context, dirPath string, verbose bool) error {
snakeMapper := core.SnakeMapper{}
skipInsertProcessors := map[string]bool{
@@ -218,6 +239,12 @@ func importLegacyTables(dirPath string, verbose bool) error {
// Purposely create a local variable to not modify global variable
legacyTables := append(legacyTables, new(Version))
for _, table := range legacyTables {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ }
+
tableName := strings.TrimPrefix(fmt.Sprintf("%T", table), "*db.")
tableFile := filepath.Join(dirPath, tableName+".json")
if !osutil.IsFile(tableFile) {
diff --git a/internal/db/backup_test.go b/internal/db/backup_test.go
index e30ac7a2..047a2dca 100644
--- a/internal/db/backup_test.go
+++ b/internal/db/backup_test.go
@@ -6,12 +6,14 @@ package db
import (
"bytes"
+ "context"
"os"
"path/filepath"
"testing"
"time"
"github.com/pkg/errors"
+ "github.com/stretchr/testify/require"
"gorm.io/gorm"
"gogs.io/gogs/internal/auth"
@@ -22,7 +24,7 @@ import (
"gogs.io/gogs/internal/testutil"
)
-func Test_dumpAndImport(t *testing.T) {
+func TestDumpAndImport(t *testing.T) {
if testing.Short() {
t.Skip()
}
@@ -43,8 +45,6 @@ func Test_dumpAndImport(t *testing.T) {
}
func setupDBToDump(t *testing.T, db *gorm.DB) {
- t.Helper()
-
vals := []interface{}{
&Access{
ID: 1,
@@ -126,31 +126,29 @@ func setupDBToDump(t *testing.T, db *gorm.DB) {
}
for _, val := range vals {
err := db.Create(val).Error
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
}
}
func dumpTables(t *testing.T, db *gorm.DB) {
- t.Helper()
+ ctx := context.Background()
for _, table := range Tables {
tableName := getTableType(table)
var buf bytes.Buffer
- err := dumpTable(db, table, &buf)
+ err := dumpTable(ctx, db, table, &buf)
if err != nil {
t.Fatalf("%s: %v", tableName, err)
}
golden := filepath.Join("testdata", "backup", tableName+".golden.json")
- testutil.AssertGolden(t, golden, testutil.Update("Test_dumpAndImport"), buf.String())
+ testutil.AssertGolden(t, golden, testutil.Update("TestDumpAndImport"), buf.String())
}
}
func importTables(t *testing.T, db *gorm.DB) {
- t.Helper()
+ ctx := context.Background()
for _, table := range Tables {
tableName := getTableType(table)
@@ -163,7 +161,7 @@ func importTables(t *testing.T, db *gorm.DB) {
}
defer func() { _ = f.Close() }()
- return importTable(db, table, f)
+ return importTable(ctx, db, table, f)
}()
if err != nil {
t.Fatalf("%s: %v", tableName, err)