diff options
author | ᴜɴᴋɴᴡᴏɴ <u@gogs.io> | 2020-09-06 10:11:08 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-09-06 10:11:08 +0800 |
commit | 519e59b5778571ace3f681b81a21b92a38ede890 (patch) | |
tree | 373a2021f3adcf8b2c44e6f28282733e00909d88 /internal | |
parent | 771d3673f5dc1ba96ed9bf2ebe22c7f3066ec31d (diff) |
db: migrate to GORM v2 (#6309)
Diffstat (limited to 'internal')
-rw-r--r-- | internal/db/access_tokens.go | 24 | ||||
-rw-r--r-- | internal/db/access_tokens_test.go | 21 | ||||
-rw-r--r-- | internal/db/backup.go | 27 | ||||
-rw-r--r-- | internal/db/backup_test.go | 2 | ||||
-rw-r--r-- | internal/db/db.go | 111 | ||||
-rw-r--r-- | internal/db/lfs.go | 4 | ||||
-rw-r--r-- | internal/db/lfs_test.go | 3 | ||||
-rw-r--r-- | internal/db/login_source_files.go | 11 | ||||
-rw-r--r-- | internal/db/login_source_files_test.go | 2 | ||||
-rw-r--r-- | internal/db/login_sources.go | 25 | ||||
-rw-r--r-- | internal/db/login_sources_test.go | 36 | ||||
-rw-r--r-- | internal/db/main_test.go | 58 | ||||
-rw-r--r-- | internal/db/models.go | 6 | ||||
-rw-r--r-- | internal/db/perms.go | 9 | ||||
-rw-r--r-- | internal/db/repos.go | 19 | ||||
-rw-r--r-- | internal/db/repos_test.go | 3 | ||||
-rw-r--r-- | internal/db/two_factors.go | 22 | ||||
-rw-r--r-- | internal/db/two_factors_test.go | 3 | ||||
-rw-r--r-- | internal/db/users.go | 22 | ||||
-rw-r--r-- | internal/db/users_test.go | 5 | ||||
-rw-r--r-- | internal/dbutil/logger.go | 15 | ||||
-rw-r--r-- | internal/dbutil/writer.go | 37 | ||||
-rw-r--r-- | internal/dbutil/writer_test.go | 58 |
23 files changed, 270 insertions, 253 deletions
diff --git a/internal/db/access_tokens.go b/internal/db/access_tokens.go index 115f0105..ab62c52d 100644 --- a/internal/db/access_tokens.go +++ b/internal/db/access_tokens.go @@ -8,8 +8,8 @@ import ( "fmt" "time" - "github.com/jinzhu/gorm" gouuid "github.com/satori/go.uuid" + "gorm.io/gorm" "gogs.io/gogs/internal/cryptoutil" "gogs.io/gogs/internal/errutil" @@ -55,24 +55,26 @@ type AccessToken struct { } // NOTE: This is a GORM create hook. -func (t *AccessToken) BeforeCreate() { - if t.CreatedUnix > 0 { - return +func (t *AccessToken) BeforeCreate(tx *gorm.DB) error { + if t.CreatedUnix == 0 { + t.CreatedUnix = tx.NowFunc().Unix() } - t.CreatedUnix = gorm.NowFunc().Unix() + return nil } // NOTE: This is a GORM update hook. -func (t *AccessToken) BeforeUpdate() { - t.UpdatedUnix = gorm.NowFunc().Unix() +func (t *AccessToken) BeforeUpdate(tx *gorm.DB) error { + t.UpdatedUnix = tx.NowFunc().Unix() + return nil } // NOTE: This is a GORM query hook. -func (t *AccessToken) AfterFind() { +func (t *AccessToken) AfterFind(tx *gorm.DB) error { t.Created = time.Unix(t.CreatedUnix, 0).Local() t.Updated = time.Unix(t.UpdatedUnix, 0).Local() t.HasUsed = t.Updated.After(t.Created) - t.HasRecentActivity = t.Updated.Add(7 * 24 * time.Hour).After(gorm.NowFunc()) + t.HasRecentActivity = t.Updated.Add(7 * 24 * time.Hour).After(tx.NowFunc()) + return nil } var _ AccessTokensStore = (*accessTokens)(nil) @@ -98,7 +100,7 @@ func (db *accessTokens) Create(userID int64, name string) (*AccessToken, error) err := db.Where("uid = ? AND name = ?", userID, name).First(new(AccessToken)).Error if err == nil { return nil, ErrAccessTokenAlreadyExist{args: errutil.Args{"userID": userID, "name": name}} - } else if !gorm.IsRecordNotFoundError(err) { + } else if err != gorm.ErrRecordNotFound { return nil, err } @@ -137,7 +139,7 @@ func (db *accessTokens) GetBySHA(sha string) (*AccessToken, error) { token := new(AccessToken) err := db.Where("sha1 = ?", sha).First(token).Error if err != nil { - if gorm.IsRecordNotFoundError(err) { + if err == gorm.ErrRecordNotFound { return nil, ErrAccessTokenNotExist{args: errutil.Args{"sha": sha}} } return nil, err diff --git a/internal/db/access_tokens_test.go b/internal/db/access_tokens_test.go index 8a2f3805..ef629320 100644 --- a/internal/db/access_tokens_test.go +++ b/internal/db/access_tokens_test.go @@ -8,24 +8,33 @@ import ( "testing" "time" - "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" + "gorm.io/gorm" "gogs.io/gogs/internal/errutil" ) func TestAccessToken_BeforeCreate(t *testing.T) { + now := time.Now() + db := &gorm.DB{ + Config: &gorm.Config{ + NowFunc: func() time.Time { + return now + }, + }, + } + t.Run("CreatedUnix has been set", func(t *testing.T) { token := &AccessToken{CreatedUnix: 1} - token.BeforeCreate() + _ = token.BeforeCreate(db) assert.Equal(t, int64(1), token.CreatedUnix) assert.Equal(t, int64(0), token.UpdatedUnix) }) t.Run("CreatedUnix has not been set", func(t *testing.T) { token := &AccessToken{} - token.BeforeCreate() - assert.Equal(t, gorm.NowFunc().Unix(), token.CreatedUnix) + _ = token.BeforeCreate(db) + assert.Equal(t, db.NowFunc().Unix(), token.CreatedUnix) assert.Equal(t, int64(0), token.UpdatedUnix) }) } @@ -80,7 +89,7 @@ func test_accessTokens_Create(t *testing.T, db *accessTokens) { if err != nil { t.Fatal(err) } - assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), token.Created.UTC().Format(time.RFC3339)) + assert.Equal(t, db.NowFunc().Format(time.RFC3339), token.Created.UTC().Format(time.RFC3339)) // Try create second access token with same name should fail _, err = db.Create(token.UserID, token.Name) @@ -193,5 +202,5 @@ func test_accessTokens_Save(t *testing.T, db *accessTokens) { if err != nil { t.Fatal(err) } - assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), token.Updated.UTC().Format(time.RFC3339)) + assert.Equal(t, db.NowFunc().Format(time.RFC3339), token.Updated.UTC().Format(time.RFC3339)) } diff --git a/internal/db/backup.go b/internal/db/backup.go index a02d70d6..985b3c0f 100644 --- a/internal/db/backup.go +++ b/internal/db/backup.go @@ -8,10 +8,12 @@ import ( "path/filepath" "reflect" "strings" + "sync" - "github.com/jinzhu/gorm" jsoniter "github.com/json-iterator/go" "github.com/pkg/errors" + "gorm.io/gorm" + "gorm.io/gorm/schema" log "unknwon.dev/clog/v2" "xorm.io/core" "xorm.io/xorm" @@ -153,17 +155,21 @@ func ImportDatabase(db *gorm.DB, dirPath string, verbose bool) error { } func importTable(db *gorm.DB, table interface{}, r io.Reader) error { - err := db.DropTableIfExists(table).Error + err := db.Migrator().DropTable(table) if err != nil { return errors.Wrap(err, "drop table") } - err = db.AutoMigrate(table).Error + err = db.Migrator().AutoMigrate(table) if err != nil { return errors.Wrap(err, "auto migrate") } - rawTableName := db.NewScope(table).TableName() + s, err := schema.Parse(table, &sync.Map{}, db.NamingStrategy) + if err != nil { + return errors.Wrap(err, "parse schema") + } + rawTableName := s.Table skipResetIDSeq := map[string]bool{ "lfs_object": true, } @@ -235,21 +241,26 @@ func importLegacyTables(dirPath string, verbose bool) error { return fmt.Errorf("insert strcut: %v", err) } - meta := make(map[string]interface{}) + var meta struct { + ID int64 + CreatedUnix int64 + DeadlineUnix int64 + ClosedDateUnix int64 + } if err = jsoniter.Unmarshal(scanner.Bytes(), &meta); err != nil { log.Error("Failed to unmarshal to map: %v", err) } // Reset created_unix back to the date save in archive because Insert method updates its value if isInsertProcessor && !skipInsertProcessors[rawTableName] { - if _, err = x.Exec("UPDATE `"+rawTableName+"` SET created_unix=? WHERE id=?", meta["CreatedUnix"], meta["ID"]); err != nil { - log.Error("Failed to reset 'created_unix': %v", err) + if _, err = x.Exec("UPDATE `"+rawTableName+"` SET created_unix=? WHERE id=?", meta.CreatedUnix, meta.ID); err != nil { + log.Error("Failed to reset '%s.created_unix': %v", rawTableName, err) } } switch rawTableName { case "milestone": - if _, err = x.Exec("UPDATE `"+rawTableName+"` SET deadline_unix=?, closed_date_unix=? WHERE id=?", meta["DeadlineUnix"], meta["ClosedDateUnix"], meta["ID"]); err != nil { + if _, err = x.Exec("UPDATE `"+rawTableName+"` SET deadline_unix=?, closed_date_unix=? WHERE id=?", meta.DeadlineUnix, meta.ClosedDateUnix, meta.ID); err != nil { log.Error("Failed to reset 'milestone.deadline_unix', 'milestone.closed_date_unix': %v", err) } } diff --git a/internal/db/backup_test.go b/internal/db/backup_test.go index 334b8d7d..fafaf28a 100644 --- a/internal/db/backup_test.go +++ b/internal/db/backup_test.go @@ -11,8 +11,8 @@ import ( "testing" "time" - "github.com/jinzhu/gorm" "github.com/pkg/errors" + "gorm.io/gorm" "gogs.io/gogs/internal/cryptoutil" "gogs.io/gogs/internal/lfsutil" diff --git a/internal/db/db.go b/internal/db/db.go index a6e53683..03da40ca 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -6,18 +6,19 @@ package db import ( "fmt" - "io" "net/url" "path/filepath" "strings" "time" - "github.com/jinzhu/gorm" - _ "github.com/jinzhu/gorm/dialects/mssql" - _ "github.com/jinzhu/gorm/dialects/mysql" - _ "github.com/jinzhu/gorm/dialects/postgres" - _ "github.com/jinzhu/gorm/dialects/sqlite" "github.com/pkg/errors" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/driver/sqlserver" + "gorm.io/gorm" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" log "unknwon.dev/clog/v2" "gogs.io/gogs/internal/conf" @@ -96,16 +97,7 @@ func parseDSN(opts conf.DatabaseOpts) (dsn string, err error) { return dsn, nil } -func openDB(opts conf.DatabaseOpts) (*gorm.DB, error) { - dsn, err := parseDSN(opts) - if err != nil { - return nil, errors.Wrap(err, "parse DSN") - } - - return gorm.Open(opts.Type, dsn) -} - -func getLogWriter() (io.Writer, error) { +func newLogWriter() (logger.Writer, error) { sec := conf.File.Section("log.gorm") w, err := log.NewFileWriter( filepath.Join(conf.Log.RootPath, "gorm.log"), @@ -119,7 +111,30 @@ func getLogWriter() (io.Writer, error) { if err != nil { return nil, errors.Wrap(err, `create "gorm.log"`) } - return w, nil + return &dbutil.Logger{Writer: w}, nil +} + +func openDB(opts conf.DatabaseOpts, cfg *gorm.Config) (*gorm.DB, error) { + dsn, err := parseDSN(opts) + if err != nil { + return nil, errors.Wrap(err, "parse DSN") + } + + var dialector gorm.Dialector + switch opts.Type { + case "mysql": + dialector = mysql.Open(dsn) + case "postgres": + dialector = postgres.Open(dsn) + case "mssql": + dialector = sqlserver.Open(dsn) + case "sqlite3": + dialector = sqlite.Open(dsn) + default: + panic("unreachable") + } + + return gorm.Open(dialector, cfg) } // NOTE: Lines are sorted in alphabetical order, each letter in its own line. @@ -129,56 +144,72 @@ var tables = []interface{}{ } func Init() (*gorm.DB, error) { - db, err := openDB(conf.Database) + w, err := newLogWriter() if err != nil { - return nil, errors.Wrap(err, "open database") + return nil, errors.Wrap(err, "new log writer") + } + + level := logger.Info + if conf.IsProdMode() { + level = logger.Warn } - db.SingularTable(true) - db.DB().SetMaxOpenConns(conf.Database.MaxOpenConns) - db.DB().SetMaxIdleConns(conf.Database.MaxIdleConns) - db.DB().SetConnMaxLifetime(time.Minute) - w, err := getLogWriter() + // NOTE: AutoMigrate does not respect logger passed in gorm.Config. + logger.Default = logger.New(w, logger.Config{ + SlowThreshold: 100 * time.Millisecond, + LogLevel: level, + }) + + db, err := openDB(conf.Database, &gorm.Config{ + NamingStrategy: schema.NamingStrategy{ + SingularTable: true, + }, + NowFunc: func() time.Time { + return time.Now().UTC().Truncate(time.Microsecond) + }, + }) if err != nil { - return nil, errors.Wrap(err, "get log writer") + return nil, errors.Wrap(err, "open database") } - db.SetLogger(&dbutil.Writer{Writer: w}) - if !conf.IsProdMode() { - db = db.LogMode(true) + + sqlDB, err := db.DB() + if err != nil { + return nil, errors.Wrap(err, "get underlying *sql.DB") } + sqlDB.SetMaxOpenConns(conf.Database.MaxOpenConns) + sqlDB.SetMaxIdleConns(conf.Database.MaxIdleConns) + sqlDB.SetConnMaxLifetime(time.Minute) switch conf.Database.Type { + case "postgres": + conf.UsePostgreSQL = true case "mysql": conf.UseMySQL = true db = db.Set("gorm:table_options", "ENGINE=InnoDB") - case "postgres": - conf.UsePostgreSQL = true - case "mssql": - conf.UseMSSQL = true case "sqlite3": conf.UseSQLite3 = true + case "mssql": + conf.UseMSSQL = true + default: + panic("unreachable") } // NOTE: GORM has problem detecting existing columns, see https://github.com/gogs/gogs/issues/6091. // Therefore only use it to create new tables, and do customized migration with future changes. for _, table := range tables { - if db.HasTable(table) { + if db.Migrator().HasTable(table) { continue } name := strings.TrimPrefix(fmt.Sprintf("%T", table), "*db.") - err = db.AutoMigrate(table).Error + err = db.Migrator().AutoMigrate(table) if err != nil { return nil, errors.Wrapf(err, "auto migrate %q", name) } log.Trace("Auto migrated %q", name) } - gorm.NowFunc = func() time.Time { - return time.Now().UTC().Truncate(time.Microsecond) - } - - sourceFiles, err := loadLoginSourceFiles(filepath.Join(conf.CustomDir(), "conf", "auth.d")) + sourceFiles, err := loadLoginSourceFiles(filepath.Join(conf.CustomDir(), "conf", "auth.d"), db.NowFunc) if err != nil { return nil, errors.Wrap(err, "load login source files") } @@ -192,5 +223,5 @@ func Init() (*gorm.DB, error) { TwoFactors = &twoFactors{DB: db} Users = &users{DB: db} - return db, db.DB().Ping() + return db, nil } diff --git a/internal/db/lfs.go b/internal/db/lfs.go index 4db2fa2d..43515d8c 100644 --- a/internal/db/lfs.go +++ b/internal/db/lfs.go @@ -8,7 +8,7 @@ import ( "fmt" "time" - "github.com/jinzhu/gorm" + "gorm.io/gorm" "gogs.io/gogs/internal/errutil" "gogs.io/gogs/internal/lfsutil" @@ -76,7 +76,7 @@ func (db *lfs) GetObjectByOID(repoID int64, oid lfsutil.OID) (*LFSObject, error) object := new(LFSObject) err := db.Where("repo_id = ? AND oid = ?", repoID, oid).First(object).Error if err != nil { - if gorm.IsRecordNotFoundError(err) { + if err == gorm.ErrRecordNotFound { return nil, ErrLFSObjectNotExist{args: errutil.Args{"repoID": repoID, "oid": oid}} } return nil, err diff --git a/internal/db/lfs_test.go b/internal/db/lfs_test.go index 49d94406..4aff4f22 100644 --- a/internal/db/lfs_test.go +++ b/internal/db/lfs_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" "gogs.io/gogs/internal/errutil" @@ -61,7 +60,7 @@ func test_lfs_CreateObject(t *testing.T, db *lfs) { if err != nil { t.Fatal(err) } - assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), object.CreatedAt.Format(time.RFC3339)) + assert.Equal(t, db.NowFunc().Format(time.RFC3339), object.CreatedAt.Format(time.RFC3339)) // Try create second LFS object with same oid should fail err = db.CreateObject(repoID, oid, 12, lfsutil.StorageLocal) diff --git a/internal/db/login_source_files.go b/internal/db/login_source_files.go index 4c5cb377..f2a24ad9 100644 --- a/internal/db/login_source_files.go +++ b/internal/db/login_source_files.go @@ -10,8 +10,8 @@ import ( "path/filepath" "strings" "sync" + "time" - "github.com/jinzhu/gorm" "github.com/pkg/errors" "gopkg.in/ini.v1" @@ -39,6 +39,7 @@ var _ loginSourceFilesStore = (*loginSourceFiles)(nil) type loginSourceFiles struct { sync.RWMutex sources []*LoginSource + clock func() time.Time } var _ errutil.NotFound = (*ErrLoginSourceNotExist)(nil) @@ -98,7 +99,7 @@ func (s *loginSourceFiles) Update(source *LoginSource) { s.Lock() defer s.Unlock() - source.Updated = gorm.NowFunc() + source.Updated = s.clock() for _, old := range s.sources { if old.ID == source.ID { *old = *source @@ -109,12 +110,12 @@ func (s *loginSourceFiles) Update(source *LoginSource) { } // loadLoginSourceFiles loads login sources from file system. -func loadLoginSourceFiles(authdPath string) (loginSourceFilesStore, error) { +func loadLoginSourceFiles(authdPath string, clock func() time.Time) (loginSourceFilesStore, error) { if !osutil.IsDir(authdPath) { - return &loginSourceFiles{}, nil + return &loginSourceFiles{clock: clock}, nil } - store := &loginSourceFiles{} + store := &loginSourceFiles{clock: clock} return store, filepath.Walk(authdPath, func(path string, info os.FileInfo, err error) error { if err != nil { return err diff --git a/internal/db/login_source_files_test.go b/internal/db/login_source_files_test.go index d87239a8..d9cdc788 100644 --- a/internal/db/login_source_files_test.go +++ b/internal/db/login_source_files_test.go @@ -6,6 +6,7 @@ package db import ( "testing" + "time" "github.com/stretchr/testify/assert" @@ -70,6 +71,7 @@ func Test_loginSourceFiles_Update(t *testing.T) { {ID: 101, IsActived: true, IsDefault: true}, {ID: 102, IsActived: false}, }, + clock: time.Now, } source102 := &LoginSource{ diff --git a/internal/db/login_sources.go b/internal/db/login_sources.go index 3c47a14c..4bb1f4c2 100644 --- a/internal/db/login_sources.go +++ b/internal/db/login_sources.go @@ -9,9 +9,9 @@ import ( "strconv" "time" - "github.com/jinzhu/gorm" jsoniter "github.com/json-iterator/go" "github.com/pkg/errors" + "gorm.io/gorm" "gogs.io/gogs/internal/auth/ldap" "gogs.io/gogs/internal/errutil" @@ -62,7 +62,7 @@ type LoginSource struct { } // NOTE: This is a GORM save hook. -func (s *LoginSource) BeforeSave() (err error) { +func (s *LoginSource) BeforeSave(tx *gorm.DB) (err error) { if s.Config == nil { return nil } @@ -71,21 +71,22 @@ func (s *LoginSource) BeforeSave() (err error) { } // NOTE: This is a GORM create hook. -func (s *LoginSource) BeforeCreate() { - if s.CreatedUnix > 0 { - return +func (s *LoginSource) BeforeCreate(tx *gorm.DB) error { + if s.CreatedUnix == 0 { + s.CreatedUnix = tx.NowFunc().Unix() + s.UpdatedUnix = s.CreatedUnix } - s.CreatedUnix = gorm.NowFunc().Unix() - s.UpdatedUnix = s.CreatedUnix + return nil } // NOTE: This is a GORM update hook. -func (s *LoginSource) BeforeUpdate() { - s.UpdatedUnix = gorm.NowFunc().Unix() +func (s *LoginSource) BeforeUpdate(tx *gorm.DB) error { + s.UpdatedUnix = tx.NowFunc().Unix() + return nil } // NOTE: This is a GORM query hook. -func (s *LoginSource) AfterFind() error { +func (s *LoginSource) AfterFind(tx *gorm.DB) error { s.Created = time.Unix(s.CreatedUnix, 0).Local() s.Updated = time.Unix(s.UpdatedUnix, 0).Local() @@ -204,7 +205,7 @@ func (db *loginSources) Create(opts CreateLoginSourceOpts) (*LoginSource, error) err := db.Where("name = ?", opts.Name).First(new(LoginSource)).Error if err == nil { return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}} - } else if !gorm.IsRecordNotFoundError(err) { + } else if err != gorm.ErrRecordNotFound { return nil, err } @@ -253,7 +254,7 @@ func (db *loginSources) GetByID(id int64) (*LoginSource, error) { source := new(LoginSource) err := db.Where("id = ?", id).First(source).Error if err != nil { - if gorm.IsRecordNotFoundError(err) { + if err == gorm.ErrRecordNotFound { return db.files.GetByID(id) } return nil, err diff --git a/internal/db/login_sources_test.go b/internal/db/login_sources_test.go index ac657add..cb3a16e7 100644 --- a/internal/db/login_sources_test.go +++ b/internal/db/login_sources_test.go @@ -8,16 +8,25 @@ import ( "testing" "time" - "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" + "gorm.io/gorm" "gogs.io/gogs/internal/errutil" ) func TestLoginSource_BeforeSave(t *testing.T) { + now := time.Now() + db := &gorm.DB{ + Config: &gorm.Config{ + NowFunc: func() time.Time { + return now + }, + }, + } + t.Run("Config has not been set", func(t *testing.T) { s := &LoginSource{} - err := s.BeforeSave() + err := s.BeforeSave(db) if err != nil { t.Fatal(err) } @@ -28,7 +37,7 @@ func TestLoginSource_BeforeSave(t *testing.T) { s := &LoginSource{ Config: &PAMConfig{ServiceName: "pam_service"}, } - err := s.BeforeSave() + err := s.BeforeSave(db) if err != nil { t.Fatal(err) } @@ -37,18 +46,27 @@ func TestLoginSource_BeforeSave(t *testing.T) { } func TestLoginSource_BeforeCreate(t *testing.T) { + now := time.Now() + db := &gorm.DB{ + Config: &gorm.Config{ + NowFunc: func() time.Time { + return now + }, + }, + } + t.Run("CreatedUnix has been set", func(t *testing.T) { s := &LoginSource{CreatedUnix: 1} - s.BeforeCreate() + _ = s.BeforeCreate(db) assert.Equal(t, int64(1), s.CreatedUnix) assert.Equal(t, int64(0), s.UpdatedUnix) }) t.Run("CreatedUnix has not been set", func(t *testing.T) { s := &LoginSource{} - s.BeforeCreate() - assert.Equal(t, gorm.NowFunc().Unix(), s.CreatedUnix) - assert.Equal(t, gorm.NowFunc().Unix(), s.UpdatedUnix) + _ = s.BeforeCreate(db) + assert.Equal(t, db.NowFunc().Unix(), s.CreatedUnix) + assert.Equal(t, db.NowFunc().Unix(), s.UpdatedUnix) }) } @@ -108,8 +126,8 @@ func test_loginSources_Create(t *testing.T, db *loginSources) { if err != nil { t.Fatal(err) } - assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339)) - assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339)) + assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339)) + assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339)) // Try create second login source with same name should fail _, err = db.Create(CreateLoginSourceOpts{Name: source.Name}) diff --git a/internal/db/main_test.go b/internal/db/main_test.go index 6813118e..09f175c5 100644 --- a/internal/db/main_test.go +++ b/internal/db/main_test.go @@ -8,12 +8,15 @@ import ( "flag" "fmt" "io/ioutil" + stdlog "log" "os" "path/filepath" "testing" "time" - "github.com/jinzhu/gorm" + "gorm.io/gorm" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" log "unknwon.dev/clog/v2" "gogs.io/gogs/internal/conf" @@ -21,10 +24,11 @@ import ( "gogs.io/gogs/internal/testutil" ) -var printSQL = flag.Bool("print-sql", false, "Print SQL executed") - func TestMain(m *testing.M) { flag.Parse() + + var w logger.Writer + level := logger.Silent if !testing.Verbose() { // Remove the primary logger and register a noop logger. log.Remove(log.DefaultConsoleName) @@ -33,10 +37,18 @@ func TestMain(m *testing.M) { fmt.Println(err) os.Exit(1) } + + w = &dbutil.Logger{Writer: ioutil.Discard} + } else { + w = stdlog.New(os.Stdout, "\r\n", stdlog.LstdFlags) + level = logger.Info } - now := time.Now().UTC().Truncate(time.Second) - gorm.NowFunc = func() time.Time { return now } + // NOTE: AutoMigrate does not respect logger passed in gorm.Config. + logger.Default = logger.New(w, logger.Config{ + SlowThreshold: 100 * time.Millisecond, + LogLevel: level, + }) os.Exit(m.Run()) } @@ -48,7 +60,7 @@ func clearTables(t *testing.T, db *gorm.DB, tables ...interface{}) error { } for _, t := range tables { - err := db.Delete(t).Error + err := db.Where("TRUE").Delete(t).Error if err != nil { return err } @@ -60,15 +72,29 @@ func initTestDB(t *testing.T, suite string, tables ...interface{}) *gorm.DB { t.Helper() dbpath := filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix())) - db, err := openDB(conf.DatabaseOpts{ - Type: "sqlite3", - Path: dbpath, - }) + now := time.Now().UTC().Truncate(time.Second) + db, err := openDB( + conf.DatabaseOpts{ + Type: "sqlite3", + Path: dbpath, + }, + &gorm.Config{ + NamingStrategy: schema.NamingStrategy{ + SingularTable: true, + }, + NowFunc: func() time.Time { + return now + }, + }, + ) if err != nil { t.Fatal(err) } t.Cleanup(func() { - _ = db.Close() + sqlDB, err := db.DB() + if err == nil { + _ = sqlDB.Close() + } if t.Failed() { t.Logf("Database %q left intact for inspection", dbpath) @@ -78,15 +104,7 @@ func initTestDB(t *testing.T, suite string, tables ...interface{}) *gorm.DB { _ = os.Remove(dbpath) }) - db.SingularTable(true) - if !testing.Verbose() { - db.SetLogger(&dbutil.Writer{Writer: ioutil.Discard}) - } - if *printSQL { - db.LogMode(true) - } - - err = db.AutoMigrate(tables...).Error + err = db.Migrator().AutoMigrate(tables...) if err != nil { t.Fatal(err) } diff --git a/internal/db/models.go b/internal/db/models.go index d2868bb6..52ef4495 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -13,7 +13,7 @@ import ( "strings" "time" - "github.com/jinzhu/gorm" + "gorm.io/gorm" log "unknwon.dev/clog/v2" "xorm.io/core" "xorm.io/xorm" @@ -68,6 +68,7 @@ func getEngine() (*xorm.Engine, error) { Param = "&" } + driver := conf.Database.Type connStr := "" switch conf.Database.Type { case "mysql": @@ -92,6 +93,7 @@ func getEngine() (*xorm.Engine, error) { connStr = fmt.Sprintf("postgres://%s:%s@%s:%s/%s%ssslmode=%s", url.QueryEscape(conf.Database.User), url.QueryEscape(conf.Database.Password), host, port, conf.Database.Name, Param, conf.Database.SSLMode) } + driver = "pgx" case "mssql": conf.UseMSSQL = true @@ -108,7 +110,7 @@ func getEngine() (*xorm.Engine, error) { default: return nil, fmt.Errorf("unknown database type: %s", conf.Database.Type) } - return xorm.NewEngine(conf.Database.Type, connStr) + return xorm.NewEngine(driver, connStr) } func NewTestEngine() error { diff --git a/internal/db/perms.go b/internal/db/perms.go index 129f2187..1295300d 100644 --- a/internal/db/perms.go +++ b/internal/db/perms.go @@ -5,8 +5,7 @@ package db import ( - "github.com/jinzhu/gorm" - "github.com/t-tiger/gorm-bulk-insert" + "gorm.io/gorm" log "unknwon.dev/clog/v2" ) @@ -53,7 +52,7 @@ func (db *perms) AccessMode(userID int64, repo *Repository) (mode AccessMode) { access := new(Access) err := db.Where("user_id = ? AND repo_id = ?", userID, repo.ID).First(access).Error if err != nil { - if !gorm.IsRecordNotFoundError(err) { + if err != gorm.ErrRecordNotFound { log.Error("Failed to get access [user_id: %d, repo_id: %d]: %v", userID, repo.ID, err) } return mode @@ -66,7 +65,7 @@ func (db *perms) Authorize(userID int64, repo *Repository, desired AccessMode) b } func (db *perms) SetRepoPerms(repoID int64, accessMap map[int64]AccessMode) error { - records := make([]interface{}, 0, len(accessMap)) + records := make([]*Access, 0, len(accessMap)) for userID, mode := range accessMap { records = append(records, &Access{ UserID: userID, @@ -81,6 +80,6 @@ func (db *perms) SetRepoPerms(repoID int64, accessMap map[int64]AccessMode) erro return err } - return gormbulk.BulkInsert(tx, records, 3000) + return tx.Create(&records).Error }) } diff --git a/internal/db/repos.go b/internal/db/repos.go index 08cc6485..fc5bce03 100644 --- a/internal/db/repos.go +++ b/internal/db/repos.go @@ -9,7 +9,7 @@ import ( "strings" "time" - "github.com/jinzhu/gorm" + "gorm.io/gorm" "gogs.io/gogs/internal/errutil" ) @@ -26,19 +26,24 @@ type ReposStore interface { var Repos ReposStore // NOTE: This is a GORM create hook. -func (r *Repository) BeforeCreate() { - r.CreatedUnix = gorm.NowFunc().Unix() +func (r *Repository) BeforeCreate(tx *gorm.DB) error { + if r.CreatedUnix == 0 { + r.CreatedUnix = tx.NowFunc().Unix() + } + return nil } // NOTE: This is a GORM update hook. -func (r *Repository) BeforeUpdate() { - r.UpdatedUnix = gorm.NowFunc().Unix() +func (r *Repository) BeforeUpdate(tx *gorm.DB) error { + r.UpdatedUnix = tx.NowFunc().Unix() + return nil } // NOTE: This is a GORM query hook. -func (r *Repository) AfterFind() { +func (r *Repository) AfterFind(tx *gorm.DB) error { r.Created = time.Unix(r.CreatedUnix, 0).Local() r.Updated = time.Unix(r.UpdatedUnix, 0).Local() + return nil } var _ ReposStore = (*repos)(nil) @@ -129,7 +134,7 @@ func (db *repos) GetByName(ownerID int64, name string) (*Repository, error) { repo := new(Repository) err := db.Where("owner_id = ? AND lower_name = ?", ownerID, strings.ToLower(name)).First(repo).Error if err != nil { - if gorm.IsRecordNotFoundError(err) { + if err == gorm.ErrRecordNotFound { return nil, ErrRepoNotExist{args: map[string]interface{}{"ownerID": ownerID, "name": name}} } return nil, err diff --git a/internal/db/repos_test.go b/internal/db/repos_test.go index 42eedbba..7b8f2b00 100644 --- a/internal/db/repos_test.go +++ b/internal/db/repos_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" "gogs.io/gogs/internal/errutil" @@ -80,7 +79,7 @@ func test_repos_create(t *testing.T, db *repos) { if err != nil { t.Fatal(err) } - assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), repo.Created.UTC().Format(time.RFC3339)) + assert.Equal(t, db.NowFunc().Format(time.RFC3339), repo.Created.UTC().Format(time.RFC3339)) } func test_repos_GetByName(t *testing.T, db *repos) { diff --git a/internal/db/two_factors.go b/internal/db/two_factors.go index aa90c2ff..7692e5d5 100644 --- a/internal/db/two_factors.go +++ b/internal/db/two_factors.go @@ -10,9 +10,8 @@ import ( "strings" "time" - "github.com/jinzhu/gorm" "github.com/pkg/errors" - "github.com/t-tiger/gorm-bulk-insert" + "gorm.io/gorm" log "unknwon.dev/clog/v2" "gogs.io/gogs/internal/cryptoutil" @@ -39,13 +38,17 @@ type TwoFactorsStore interface { var TwoFactors TwoFactorsStore // NOTE: This is a GORM create hook. -func (t *TwoFactor) BeforeCreate() { - t.CreatedUnix = gorm.NowFunc().Unix() +func (t *TwoFactor) BeforeCreate(tx *gorm.DB) error { + if t.CreatedUnix == 0 { + t.CreatedUnix = tx.NowFunc().Unix() + } + return nil } // NOTE: This is a GORM query hook. -func (t *TwoFactor) AfterFind() { +func (t *TwoFactor) AfterFind(tx *gorm.DB) error { t.Created = time.Unix(t.CreatedUnix, 0).Local() + return nil } var _ TwoFactorsStore = (*twoFactors)(nil) @@ -69,18 +72,13 @@ func (db *twoFactors) Create(userID int64, key, secret string) error { return errors.Wrap(err, "generate recovery codes") } - records := make([]interface{}, 0, len(recoveryCodes)) - for _, code := range recoveryCodes { - records = append(records, code) - } - return db.Transaction(func(tx *gorm.DB) error { err := tx.Create(tf).Error if err != nil { return err } - return gormbulk.BulkInsert(tx, records, 3000) + return tx.Create(&recoveryCodes).Error }) } @@ -107,7 +105,7 @@ func (db *twoFactors) GetByUserID(userID int64) (*TwoFactor, error) { tf := new(TwoFactor) err := db.Where("user_id = ?", userID).First(tf).Error if err != nil { - if gorm.IsRecordNotFoundError(err) { + if err == gorm.ErrRecordNotFound { return nil, ErrTwoFactorNotFound{args: errutil.Args{"userID": userID}} } return nil, err diff --git a/internal/db/two_factors_test.go b/internal/db/two_factors_test.go index 4f0ad9da..20dfae61 100644 --- a/internal/db/two_factors_test.go +++ b/internal/db/two_factors_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" "gogs.io/gogs/internal/errutil" @@ -58,7 +57,7 @@ func test_twoFactors_Create(t *testing.T, db *twoFactors) { if err != nil { t.Fatal(err) } - assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), tf.Created.UTC().Format(time.RFC3339)) + assert.Equal(t, db.NowFunc().Format(time.RFC3339), tf.Created.UTC().Format(time.RFC3339)) // Verify there are 10 recover codes generated var count int64 diff --git a/internal/db/users.go b/internal/db/users.go index a002c290..64535ea3 100644 --- a/internal/db/users.go +++ b/internal/db/users.go @@ -9,8 +9,8 @@ import ( "strings" "time" - "github.com/jinzhu/gorm" "github.com/pkg/errors" + "gorm.io/gorm" "gogs.io/gogs/internal/cryptoutil" "gogs.io/gogs/internal/errutil" @@ -48,15 +48,19 @@ type UsersStore interface { var Users UsersStore // NOTE: This is a GORM create hook. -func (u *User) BeforeCreate() { - u.CreatedUnix = gorm.NowFunc().Unix() - u.UpdatedUnix = u.CreatedUnix +func (u *User) BeforeCreate(tx *gorm.DB) error { + if u.CreatedUnix == 0 { + u.CreatedUnix = tx.NowFunc().Unix() + u.UpdatedUnix = u.CreatedUnix + } + return nil } // NOTE: This is a GORM query hook. -func (u *User) AfterFind() { +func (u *User) AfterFind(tx *gorm.DB) error { u.Created = time.Unix(u.CreatedUnix, 0).Local() u.Updated = time.Unix(u.UpdatedUnix, 0).Local() + return nil } var _ UsersStore = (*users)(nil) @@ -253,7 +257,7 @@ func (db *users) GetByEmail(email string) (*User, error) { err := db.Where("email = ? AND type = ? AND is_active = ?", email, UserIndividual, true).First(user).Error if err == nil { return user, nil - } else if !gorm.IsRecordNotFoundError(err) { + } else if err != gorm.ErrRecordNotFound { return nil, err } @@ -261,7 +265,7 @@ func (db *users) GetByEmail(email string) (*User, error) { emailAddress := new(EmailAddress) err = db.Where("email = ? AND is_activated = ?", email, true).First(emailAddress).Error if err != nil { - if gorm.IsRecordNotFoundError(err) { + if err == gorm.ErrRecordNotFound { return nil, ErrUserNotExist{args: errutil.Args{"email": email}} } return nil, err @@ -274,7 +278,7 @@ func (db *users) GetByID(id int64) (*User, error) { user := new(User) err := db.Where("id = ?", id).First(user).Error if err != nil { - if gorm.IsRecordNotFoundError(err) { + if err == gorm.ErrRecordNotFound { return nil, ErrUserNotExist{args: errutil.Args{"userID": id}} } return nil, err @@ -286,7 +290,7 @@ func (db *users) GetByUsername(username string) (*User, error) { user := new(User) err := db.Where("lower_name = ?", strings.ToLower(username)).First(user).Error if err != nil { - if gorm.IsRecordNotFoundError(err) { + if err == gorm.ErrRecordNotFound { return nil, ErrUserNotExist{args: errutil.Args{"name": username}} } return nil, err diff --git a/internal/db/users_test.go b/internal/db/users_test.go index 2be196e3..7ff3b9f5 100644 --- a/internal/db/users_test.go +++ b/internal/db/users_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" "gogs.io/gogs/internal/errutil" @@ -129,8 +128,8 @@ func test_users_Create(t *testing.T, db *users) { if err != nil { t.Fatal(err) } - assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), user.Created.UTC().Format(time.RFC3339)) - assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), user.Updated.UTC().Format(time.RFC3339)) + assert.Equal(t, db.NowFunc().Format(time.RFC3339), user.Created.UTC().Format(time.RFC3339)) + assert.Equal(t, db.NowFunc().Format(time.RFC3339), user.Updated.UTC().Format(time.RFC3339)) } func test_users_GetByEmail(t *testing.T, db *users) { diff --git a/internal/dbutil/logger.go b/internal/dbutil/logger.go new file mode 100644 index 00000000..66426ae7 --- /dev/null +++ b/internal/dbutil/logger.go @@ -0,0 +1,15 @@ +package dbutil + +import ( + "fmt" + "io" +) + +// Logger is a wrapper of io.Writer for the GORM's logger.Writer. +type Logger struct { + io.Writer +} + +func (l *Logger) Printf(format string, args ...interface{}) { + fmt.Fprintf(l.Writer, format, args...) +} diff --git a/internal/dbutil/writer.go b/internal/dbutil/writer.go deleted file mode 100644 index 3b8d7363..00000000 --- a/internal/dbutil/writer.go +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2020 The Gogs Authors. All rights reserved. -// Use of this source code is governed by a MIT-style -// license that can be found in the LICENSE file. - -package dbutil - -import ( - "fmt" - "io" -) - -// Writer is a wrapper of io.Writer for the gorm.logger. -type Writer struct { - io.Writer -} - -func (w *Writer) Print(v ...interface{}) { - if len(v) == 0 { - return - } - - if len(v) == 1 { - fmt.Fprint(w.Writer, v[0]) - return - } - - switch v[0] { - case "sql": - fmt.Fprintf(w.Writer, "[sql] [%s] [%s] %s %v (%d rows affected)", v[1:]...) - case "log": - fmt.Fprintf(w.Writer, "[log] [%s] %s", v[1:]...) - case "error": - fmt.Fprintf(w.Writer, "[err] [%s] %s", v[1:]...) - default: - fmt.Fprint(w.Writer, v...) - } -} diff --git a/internal/dbutil/writer_test.go b/internal/dbutil/writer_test.go deleted file mode 100644 index a484d442..00000000 --- a/internal/dbutil/writer_test.go +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2020 The Gogs Authors. All rights reserved. -// Use of this source code is governed by a MIT-style -// license that can be found in the LICENSE file. - -package dbutil - -import ( - "bytes" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestWriter_Print(t *testing.T) { - tests := []struct { - name string - vs []interface{} - expOutput string - }{ - { - name: "no values", - }, - { - name: "only one value", - vs: []interface{}{"test"}, - expOutput: "test", - }, - { - name: "two values", - vs: []interface{}{"test", "output"}, - expOutput: "testoutput", - }, - - { - name: "sql", - vs: []interface{}{"sql", "writer.go:65", "1ms", "SELECT * FROM users WHERE user_id = $1", []int{1}, 1}, - expOutput: "[sql] [writer.go:65] [1ms] SELECT * FROM users WHERE user_id = $1 [1] (1 rows affected)", - }, - { - name: "log", - vs: []interface{}{"log", "writer.go:65", "something"}, - expOutput: "[log] [writer.go:65] something", - }, - { - name: "error", - vs: []interface{}{"error", "writer.go:65", "something bad"}, - expOutput: "[err] [writer.go:65] something bad", - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var buf bytes.Buffer - w := &Writer{Writer: &buf} - w.Print(test.vs...) - assert.Equal(t, test.expOutput, buf.String()) - }) - } -} |