aboutsummaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorᴜɴᴋɴᴡᴏɴ <u@gogs.io>2020-09-06 10:11:08 +0800
committerGitHub <noreply@github.com>2020-09-06 10:11:08 +0800
commit519e59b5778571ace3f681b81a21b92a38ede890 (patch)
tree373a2021f3adcf8b2c44e6f28282733e00909d88 /internal
parent771d3673f5dc1ba96ed9bf2ebe22c7f3066ec31d (diff)
db: migrate to GORM v2 (#6309)
Diffstat (limited to 'internal')
-rw-r--r--internal/db/access_tokens.go24
-rw-r--r--internal/db/access_tokens_test.go21
-rw-r--r--internal/db/backup.go27
-rw-r--r--internal/db/backup_test.go2
-rw-r--r--internal/db/db.go111
-rw-r--r--internal/db/lfs.go4
-rw-r--r--internal/db/lfs_test.go3
-rw-r--r--internal/db/login_source_files.go11
-rw-r--r--internal/db/login_source_files_test.go2
-rw-r--r--internal/db/login_sources.go25
-rw-r--r--internal/db/login_sources_test.go36
-rw-r--r--internal/db/main_test.go58
-rw-r--r--internal/db/models.go6
-rw-r--r--internal/db/perms.go9
-rw-r--r--internal/db/repos.go19
-rw-r--r--internal/db/repos_test.go3
-rw-r--r--internal/db/two_factors.go22
-rw-r--r--internal/db/two_factors_test.go3
-rw-r--r--internal/db/users.go22
-rw-r--r--internal/db/users_test.go5
-rw-r--r--internal/dbutil/logger.go15
-rw-r--r--internal/dbutil/writer.go37
-rw-r--r--internal/dbutil/writer_test.go58
23 files changed, 270 insertions, 253 deletions
diff --git a/internal/db/access_tokens.go b/internal/db/access_tokens.go
index 115f0105..ab62c52d 100644
--- a/internal/db/access_tokens.go
+++ b/internal/db/access_tokens.go
@@ -8,8 +8,8 @@ import (
"fmt"
"time"
- "github.com/jinzhu/gorm"
gouuid "github.com/satori/go.uuid"
+ "gorm.io/gorm"
"gogs.io/gogs/internal/cryptoutil"
"gogs.io/gogs/internal/errutil"
@@ -55,24 +55,26 @@ type AccessToken struct {
}
// NOTE: This is a GORM create hook.
-func (t *AccessToken) BeforeCreate() {
- if t.CreatedUnix > 0 {
- return
+func (t *AccessToken) BeforeCreate(tx *gorm.DB) error {
+ if t.CreatedUnix == 0 {
+ t.CreatedUnix = tx.NowFunc().Unix()
}
- t.CreatedUnix = gorm.NowFunc().Unix()
+ return nil
}
// NOTE: This is a GORM update hook.
-func (t *AccessToken) BeforeUpdate() {
- t.UpdatedUnix = gorm.NowFunc().Unix()
+func (t *AccessToken) BeforeUpdate(tx *gorm.DB) error {
+ t.UpdatedUnix = tx.NowFunc().Unix()
+ return nil
}
// NOTE: This is a GORM query hook.
-func (t *AccessToken) AfterFind() {
+func (t *AccessToken) AfterFind(tx *gorm.DB) error {
t.Created = time.Unix(t.CreatedUnix, 0).Local()
t.Updated = time.Unix(t.UpdatedUnix, 0).Local()
t.HasUsed = t.Updated.After(t.Created)
- t.HasRecentActivity = t.Updated.Add(7 * 24 * time.Hour).After(gorm.NowFunc())
+ t.HasRecentActivity = t.Updated.Add(7 * 24 * time.Hour).After(tx.NowFunc())
+ return nil
}
var _ AccessTokensStore = (*accessTokens)(nil)
@@ -98,7 +100,7 @@ func (db *accessTokens) Create(userID int64, name string) (*AccessToken, error)
err := db.Where("uid = ? AND name = ?", userID, name).First(new(AccessToken)).Error
if err == nil {
return nil, ErrAccessTokenAlreadyExist{args: errutil.Args{"userID": userID, "name": name}}
- } else if !gorm.IsRecordNotFoundError(err) {
+ } else if err != gorm.ErrRecordNotFound {
return nil, err
}
@@ -137,7 +139,7 @@ func (db *accessTokens) GetBySHA(sha string) (*AccessToken, error) {
token := new(AccessToken)
err := db.Where("sha1 = ?", sha).First(token).Error
if err != nil {
- if gorm.IsRecordNotFoundError(err) {
+ if err == gorm.ErrRecordNotFound {
return nil, ErrAccessTokenNotExist{args: errutil.Args{"sha": sha}}
}
return nil, err
diff --git a/internal/db/access_tokens_test.go b/internal/db/access_tokens_test.go
index 8a2f3805..ef629320 100644
--- a/internal/db/access_tokens_test.go
+++ b/internal/db/access_tokens_test.go
@@ -8,24 +8,33 @@ import (
"testing"
"time"
- "github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
+ "gorm.io/gorm"
"gogs.io/gogs/internal/errutil"
)
func TestAccessToken_BeforeCreate(t *testing.T) {
+ now := time.Now()
+ db := &gorm.DB{
+ Config: &gorm.Config{
+ NowFunc: func() time.Time {
+ return now
+ },
+ },
+ }
+
t.Run("CreatedUnix has been set", func(t *testing.T) {
token := &AccessToken{CreatedUnix: 1}
- token.BeforeCreate()
+ _ = token.BeforeCreate(db)
assert.Equal(t, int64(1), token.CreatedUnix)
assert.Equal(t, int64(0), token.UpdatedUnix)
})
t.Run("CreatedUnix has not been set", func(t *testing.T) {
token := &AccessToken{}
- token.BeforeCreate()
- assert.Equal(t, gorm.NowFunc().Unix(), token.CreatedUnix)
+ _ = token.BeforeCreate(db)
+ assert.Equal(t, db.NowFunc().Unix(), token.CreatedUnix)
assert.Equal(t, int64(0), token.UpdatedUnix)
})
}
@@ -80,7 +89,7 @@ func test_accessTokens_Create(t *testing.T, db *accessTokens) {
if err != nil {
t.Fatal(err)
}
- assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), token.Created.UTC().Format(time.RFC3339))
+ assert.Equal(t, db.NowFunc().Format(time.RFC3339), token.Created.UTC().Format(time.RFC3339))
// Try create second access token with same name should fail
_, err = db.Create(token.UserID, token.Name)
@@ -193,5 +202,5 @@ func test_accessTokens_Save(t *testing.T, db *accessTokens) {
if err != nil {
t.Fatal(err)
}
- assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), token.Updated.UTC().Format(time.RFC3339))
+ assert.Equal(t, db.NowFunc().Format(time.RFC3339), token.Updated.UTC().Format(time.RFC3339))
}
diff --git a/internal/db/backup.go b/internal/db/backup.go
index a02d70d6..985b3c0f 100644
--- a/internal/db/backup.go
+++ b/internal/db/backup.go
@@ -8,10 +8,12 @@ import (
"path/filepath"
"reflect"
"strings"
+ "sync"
- "github.com/jinzhu/gorm"
jsoniter "github.com/json-iterator/go"
"github.com/pkg/errors"
+ "gorm.io/gorm"
+ "gorm.io/gorm/schema"
log "unknwon.dev/clog/v2"
"xorm.io/core"
"xorm.io/xorm"
@@ -153,17 +155,21 @@ func ImportDatabase(db *gorm.DB, dirPath string, verbose bool) error {
}
func importTable(db *gorm.DB, table interface{}, r io.Reader) error {
- err := db.DropTableIfExists(table).Error
+ err := db.Migrator().DropTable(table)
if err != nil {
return errors.Wrap(err, "drop table")
}
- err = db.AutoMigrate(table).Error
+ err = db.Migrator().AutoMigrate(table)
if err != nil {
return errors.Wrap(err, "auto migrate")
}
- rawTableName := db.NewScope(table).TableName()
+ s, err := schema.Parse(table, &sync.Map{}, db.NamingStrategy)
+ if err != nil {
+ return errors.Wrap(err, "parse schema")
+ }
+ rawTableName := s.Table
skipResetIDSeq := map[string]bool{
"lfs_object": true,
}
@@ -235,21 +241,26 @@ func importLegacyTables(dirPath string, verbose bool) error {
return fmt.Errorf("insert strcut: %v", err)
}
- meta := make(map[string]interface{})
+ var meta struct {
+ ID int64
+ CreatedUnix int64
+ DeadlineUnix int64
+ ClosedDateUnix int64
+ }
if err = jsoniter.Unmarshal(scanner.Bytes(), &meta); err != nil {
log.Error("Failed to unmarshal to map: %v", err)
}
// Reset created_unix back to the date save in archive because Insert method updates its value
if isInsertProcessor && !skipInsertProcessors[rawTableName] {
- if _, err = x.Exec("UPDATE `"+rawTableName+"` SET created_unix=? WHERE id=?", meta["CreatedUnix"], meta["ID"]); err != nil {
- log.Error("Failed to reset 'created_unix': %v", err)
+ if _, err = x.Exec("UPDATE `"+rawTableName+"` SET created_unix=? WHERE id=?", meta.CreatedUnix, meta.ID); err != nil {
+ log.Error("Failed to reset '%s.created_unix': %v", rawTableName, err)
}
}
switch rawTableName {
case "milestone":
- if _, err = x.Exec("UPDATE `"+rawTableName+"` SET deadline_unix=?, closed_date_unix=? WHERE id=?", meta["DeadlineUnix"], meta["ClosedDateUnix"], meta["ID"]); err != nil {
+ if _, err = x.Exec("UPDATE `"+rawTableName+"` SET deadline_unix=?, closed_date_unix=? WHERE id=?", meta.DeadlineUnix, meta.ClosedDateUnix, meta.ID); err != nil {
log.Error("Failed to reset 'milestone.deadline_unix', 'milestone.closed_date_unix': %v", err)
}
}
diff --git a/internal/db/backup_test.go b/internal/db/backup_test.go
index 334b8d7d..fafaf28a 100644
--- a/internal/db/backup_test.go
+++ b/internal/db/backup_test.go
@@ -11,8 +11,8 @@ import (
"testing"
"time"
- "github.com/jinzhu/gorm"
"github.com/pkg/errors"
+ "gorm.io/gorm"
"gogs.io/gogs/internal/cryptoutil"
"gogs.io/gogs/internal/lfsutil"
diff --git a/internal/db/db.go b/internal/db/db.go
index a6e53683..03da40ca 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -6,18 +6,19 @@ package db
import (
"fmt"
- "io"
"net/url"
"path/filepath"
"strings"
"time"
- "github.com/jinzhu/gorm"
- _ "github.com/jinzhu/gorm/dialects/mssql"
- _ "github.com/jinzhu/gorm/dialects/mysql"
- _ "github.com/jinzhu/gorm/dialects/postgres"
- _ "github.com/jinzhu/gorm/dialects/sqlite"
"github.com/pkg/errors"
+ "gorm.io/driver/mysql"
+ "gorm.io/driver/postgres"
+ "gorm.io/driver/sqlite"
+ "gorm.io/driver/sqlserver"
+ "gorm.io/gorm"
+ "gorm.io/gorm/logger"
+ "gorm.io/gorm/schema"
log "unknwon.dev/clog/v2"
"gogs.io/gogs/internal/conf"
@@ -96,16 +97,7 @@ func parseDSN(opts conf.DatabaseOpts) (dsn string, err error) {
return dsn, nil
}
-func openDB(opts conf.DatabaseOpts) (*gorm.DB, error) {
- dsn, err := parseDSN(opts)
- if err != nil {
- return nil, errors.Wrap(err, "parse DSN")
- }
-
- return gorm.Open(opts.Type, dsn)
-}
-
-func getLogWriter() (io.Writer, error) {
+func newLogWriter() (logger.Writer, error) {
sec := conf.File.Section("log.gorm")
w, err := log.NewFileWriter(
filepath.Join(conf.Log.RootPath, "gorm.log"),
@@ -119,7 +111,30 @@ func getLogWriter() (io.Writer, error) {
if err != nil {
return nil, errors.Wrap(err, `create "gorm.log"`)
}
- return w, nil
+ return &dbutil.Logger{Writer: w}, nil
+}
+
+func openDB(opts conf.DatabaseOpts, cfg *gorm.Config) (*gorm.DB, error) {
+ dsn, err := parseDSN(opts)
+ if err != nil {
+ return nil, errors.Wrap(err, "parse DSN")
+ }
+
+ var dialector gorm.Dialector
+ switch opts.Type {
+ case "mysql":
+ dialector = mysql.Open(dsn)
+ case "postgres":
+ dialector = postgres.Open(dsn)
+ case "mssql":
+ dialector = sqlserver.Open(dsn)
+ case "sqlite3":
+ dialector = sqlite.Open(dsn)
+ default:
+ panic("unreachable")
+ }
+
+ return gorm.Open(dialector, cfg)
}
// NOTE: Lines are sorted in alphabetical order, each letter in its own line.
@@ -129,56 +144,72 @@ var tables = []interface{}{
}
func Init() (*gorm.DB, error) {
- db, err := openDB(conf.Database)
+ w, err := newLogWriter()
if err != nil {
- return nil, errors.Wrap(err, "open database")
+ return nil, errors.Wrap(err, "new log writer")
+ }
+
+ level := logger.Info
+ if conf.IsProdMode() {
+ level = logger.Warn
}
- db.SingularTable(true)
- db.DB().SetMaxOpenConns(conf.Database.MaxOpenConns)
- db.DB().SetMaxIdleConns(conf.Database.MaxIdleConns)
- db.DB().SetConnMaxLifetime(time.Minute)
- w, err := getLogWriter()
+ // NOTE: AutoMigrate does not respect logger passed in gorm.Config.
+ logger.Default = logger.New(w, logger.Config{
+ SlowThreshold: 100 * time.Millisecond,
+ LogLevel: level,
+ })
+
+ db, err := openDB(conf.Database, &gorm.Config{
+ NamingStrategy: schema.NamingStrategy{
+ SingularTable: true,
+ },
+ NowFunc: func() time.Time {
+ return time.Now().UTC().Truncate(time.Microsecond)
+ },
+ })
if err != nil {
- return nil, errors.Wrap(err, "get log writer")
+ return nil, errors.Wrap(err, "open database")
}
- db.SetLogger(&dbutil.Writer{Writer: w})
- if !conf.IsProdMode() {
- db = db.LogMode(true)
+
+ sqlDB, err := db.DB()
+ if err != nil {
+ return nil, errors.Wrap(err, "get underlying *sql.DB")
}
+ sqlDB.SetMaxOpenConns(conf.Database.MaxOpenConns)
+ sqlDB.SetMaxIdleConns(conf.Database.MaxIdleConns)
+ sqlDB.SetConnMaxLifetime(time.Minute)
switch conf.Database.Type {
+ case "postgres":
+ conf.UsePostgreSQL = true
case "mysql":
conf.UseMySQL = true
db = db.Set("gorm:table_options", "ENGINE=InnoDB")
- case "postgres":
- conf.UsePostgreSQL = true
- case "mssql":
- conf.UseMSSQL = true
case "sqlite3":
conf.UseSQLite3 = true
+ case "mssql":
+ conf.UseMSSQL = true
+ default:
+ panic("unreachable")
}
// NOTE: GORM has problem detecting existing columns, see https://github.com/gogs/gogs/issues/6091.
// Therefore only use it to create new tables, and do customized migration with future changes.
for _, table := range tables {
- if db.HasTable(table) {
+ if db.Migrator().HasTable(table) {
continue
}
name := strings.TrimPrefix(fmt.Sprintf("%T", table), "*db.")
- err = db.AutoMigrate(table).Error
+ err = db.Migrator().AutoMigrate(table)
if err != nil {
return nil, errors.Wrapf(err, "auto migrate %q", name)
}
log.Trace("Auto migrated %q", name)
}
- gorm.NowFunc = func() time.Time {
- return time.Now().UTC().Truncate(time.Microsecond)
- }
-
- sourceFiles, err := loadLoginSourceFiles(filepath.Join(conf.CustomDir(), "conf", "auth.d"))
+ sourceFiles, err := loadLoginSourceFiles(filepath.Join(conf.CustomDir(), "conf", "auth.d"), db.NowFunc)
if err != nil {
return nil, errors.Wrap(err, "load login source files")
}
@@ -192,5 +223,5 @@ func Init() (*gorm.DB, error) {
TwoFactors = &twoFactors{DB: db}
Users = &users{DB: db}
- return db, db.DB().Ping()
+ return db, nil
}
diff --git a/internal/db/lfs.go b/internal/db/lfs.go
index 4db2fa2d..43515d8c 100644
--- a/internal/db/lfs.go
+++ b/internal/db/lfs.go
@@ -8,7 +8,7 @@ import (
"fmt"
"time"
- "github.com/jinzhu/gorm"
+ "gorm.io/gorm"
"gogs.io/gogs/internal/errutil"
"gogs.io/gogs/internal/lfsutil"
@@ -76,7 +76,7 @@ func (db *lfs) GetObjectByOID(repoID int64, oid lfsutil.OID) (*LFSObject, error)
object := new(LFSObject)
err := db.Where("repo_id = ? AND oid = ?", repoID, oid).First(object).Error
if err != nil {
- if gorm.IsRecordNotFoundError(err) {
+ if err == gorm.ErrRecordNotFound {
return nil, ErrLFSObjectNotExist{args: errutil.Args{"repoID": repoID, "oid": oid}}
}
return nil, err
diff --git a/internal/db/lfs_test.go b/internal/db/lfs_test.go
index 49d94406..4aff4f22 100644
--- a/internal/db/lfs_test.go
+++ b/internal/db/lfs_test.go
@@ -8,7 +8,6 @@ import (
"testing"
"time"
- "github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"gogs.io/gogs/internal/errutil"
@@ -61,7 +60,7 @@ func test_lfs_CreateObject(t *testing.T, db *lfs) {
if err != nil {
t.Fatal(err)
}
- assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), object.CreatedAt.Format(time.RFC3339))
+ assert.Equal(t, db.NowFunc().Format(time.RFC3339), object.CreatedAt.Format(time.RFC3339))
// Try create second LFS object with same oid should fail
err = db.CreateObject(repoID, oid, 12, lfsutil.StorageLocal)
diff --git a/internal/db/login_source_files.go b/internal/db/login_source_files.go
index 4c5cb377..f2a24ad9 100644
--- a/internal/db/login_source_files.go
+++ b/internal/db/login_source_files.go
@@ -10,8 +10,8 @@ import (
"path/filepath"
"strings"
"sync"
+ "time"
- "github.com/jinzhu/gorm"
"github.com/pkg/errors"
"gopkg.in/ini.v1"
@@ -39,6 +39,7 @@ var _ loginSourceFilesStore = (*loginSourceFiles)(nil)
type loginSourceFiles struct {
sync.RWMutex
sources []*LoginSource
+ clock func() time.Time
}
var _ errutil.NotFound = (*ErrLoginSourceNotExist)(nil)
@@ -98,7 +99,7 @@ func (s *loginSourceFiles) Update(source *LoginSource) {
s.Lock()
defer s.Unlock()
- source.Updated = gorm.NowFunc()
+ source.Updated = s.clock()
for _, old := range s.sources {
if old.ID == source.ID {
*old = *source
@@ -109,12 +110,12 @@ func (s *loginSourceFiles) Update(source *LoginSource) {
}
// loadLoginSourceFiles loads login sources from file system.
-func loadLoginSourceFiles(authdPath string) (loginSourceFilesStore, error) {
+func loadLoginSourceFiles(authdPath string, clock func() time.Time) (loginSourceFilesStore, error) {
if !osutil.IsDir(authdPath) {
- return &loginSourceFiles{}, nil
+ return &loginSourceFiles{clock: clock}, nil
}
- store := &loginSourceFiles{}
+ store := &loginSourceFiles{clock: clock}
return store, filepath.Walk(authdPath, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
diff --git a/internal/db/login_source_files_test.go b/internal/db/login_source_files_test.go
index d87239a8..d9cdc788 100644
--- a/internal/db/login_source_files_test.go
+++ b/internal/db/login_source_files_test.go
@@ -6,6 +6,7 @@ package db
import (
"testing"
+ "time"
"github.com/stretchr/testify/assert"
@@ -70,6 +71,7 @@ func Test_loginSourceFiles_Update(t *testing.T) {
{ID: 101, IsActived: true, IsDefault: true},
{ID: 102, IsActived: false},
},
+ clock: time.Now,
}
source102 := &LoginSource{
diff --git a/internal/db/login_sources.go b/internal/db/login_sources.go
index 3c47a14c..4bb1f4c2 100644
--- a/internal/db/login_sources.go
+++ b/internal/db/login_sources.go
@@ -9,9 +9,9 @@ import (
"strconv"
"time"
- "github.com/jinzhu/gorm"
jsoniter "github.com/json-iterator/go"
"github.com/pkg/errors"
+ "gorm.io/gorm"
"gogs.io/gogs/internal/auth/ldap"
"gogs.io/gogs/internal/errutil"
@@ -62,7 +62,7 @@ type LoginSource struct {
}
// NOTE: This is a GORM save hook.
-func (s *LoginSource) BeforeSave() (err error) {
+func (s *LoginSource) BeforeSave(tx *gorm.DB) (err error) {
if s.Config == nil {
return nil
}
@@ -71,21 +71,22 @@ func (s *LoginSource) BeforeSave() (err error) {
}
// NOTE: This is a GORM create hook.
-func (s *LoginSource) BeforeCreate() {
- if s.CreatedUnix > 0 {
- return
+func (s *LoginSource) BeforeCreate(tx *gorm.DB) error {
+ if s.CreatedUnix == 0 {
+ s.CreatedUnix = tx.NowFunc().Unix()
+ s.UpdatedUnix = s.CreatedUnix
}
- s.CreatedUnix = gorm.NowFunc().Unix()
- s.UpdatedUnix = s.CreatedUnix
+ return nil
}
// NOTE: This is a GORM update hook.
-func (s *LoginSource) BeforeUpdate() {
- s.UpdatedUnix = gorm.NowFunc().Unix()
+func (s *LoginSource) BeforeUpdate(tx *gorm.DB) error {
+ s.UpdatedUnix = tx.NowFunc().Unix()
+ return nil
}
// NOTE: This is a GORM query hook.
-func (s *LoginSource) AfterFind() error {
+func (s *LoginSource) AfterFind(tx *gorm.DB) error {
s.Created = time.Unix(s.CreatedUnix, 0).Local()
s.Updated = time.Unix(s.UpdatedUnix, 0).Local()
@@ -204,7 +205,7 @@ func (db *loginSources) Create(opts CreateLoginSourceOpts) (*LoginSource, error)
err := db.Where("name = ?", opts.Name).First(new(LoginSource)).Error
if err == nil {
return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}}
- } else if !gorm.IsRecordNotFoundError(err) {
+ } else if err != gorm.ErrRecordNotFound {
return nil, err
}
@@ -253,7 +254,7 @@ func (db *loginSources) GetByID(id int64) (*LoginSource, error) {
source := new(LoginSource)
err := db.Where("id = ?", id).First(source).Error
if err != nil {
- if gorm.IsRecordNotFoundError(err) {
+ if err == gorm.ErrRecordNotFound {
return db.files.GetByID(id)
}
return nil, err
diff --git a/internal/db/login_sources_test.go b/internal/db/login_sources_test.go
index ac657add..cb3a16e7 100644
--- a/internal/db/login_sources_test.go
+++ b/internal/db/login_sources_test.go
@@ -8,16 +8,25 @@ import (
"testing"
"time"
- "github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
+ "gorm.io/gorm"
"gogs.io/gogs/internal/errutil"
)
func TestLoginSource_BeforeSave(t *testing.T) {
+ now := time.Now()
+ db := &gorm.DB{
+ Config: &gorm.Config{
+ NowFunc: func() time.Time {
+ return now
+ },
+ },
+ }
+
t.Run("Config has not been set", func(t *testing.T) {
s := &LoginSource{}
- err := s.BeforeSave()
+ err := s.BeforeSave(db)
if err != nil {
t.Fatal(err)
}
@@ -28,7 +37,7 @@ func TestLoginSource_BeforeSave(t *testing.T) {
s := &LoginSource{
Config: &PAMConfig{ServiceName: "pam_service"},
}
- err := s.BeforeSave()
+ err := s.BeforeSave(db)
if err != nil {
t.Fatal(err)
}
@@ -37,18 +46,27 @@ func TestLoginSource_BeforeSave(t *testing.T) {
}
func TestLoginSource_BeforeCreate(t *testing.T) {
+ now := time.Now()
+ db := &gorm.DB{
+ Config: &gorm.Config{
+ NowFunc: func() time.Time {
+ return now
+ },
+ },
+ }
+
t.Run("CreatedUnix has been set", func(t *testing.T) {
s := &LoginSource{CreatedUnix: 1}
- s.BeforeCreate()
+ _ = s.BeforeCreate(db)
assert.Equal(t, int64(1), s.CreatedUnix)
assert.Equal(t, int64(0), s.UpdatedUnix)
})
t.Run("CreatedUnix has not been set", func(t *testing.T) {
s := &LoginSource{}
- s.BeforeCreate()
- assert.Equal(t, gorm.NowFunc().Unix(), s.CreatedUnix)
- assert.Equal(t, gorm.NowFunc().Unix(), s.UpdatedUnix)
+ _ = s.BeforeCreate(db)
+ assert.Equal(t, db.NowFunc().Unix(), s.CreatedUnix)
+ assert.Equal(t, db.NowFunc().Unix(), s.UpdatedUnix)
})
}
@@ -108,8 +126,8 @@ func test_loginSources_Create(t *testing.T, db *loginSources) {
if err != nil {
t.Fatal(err)
}
- assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339))
- assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339))
+ assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339))
+ assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339))
// Try create second login source with same name should fail
_, err = db.Create(CreateLoginSourceOpts{Name: source.Name})
diff --git a/internal/db/main_test.go b/internal/db/main_test.go
index 6813118e..09f175c5 100644
--- a/internal/db/main_test.go
+++ b/internal/db/main_test.go
@@ -8,12 +8,15 @@ import (
"flag"
"fmt"
"io/ioutil"
+ stdlog "log"
"os"
"path/filepath"
"testing"
"time"
- "github.com/jinzhu/gorm"
+ "gorm.io/gorm"
+ "gorm.io/gorm/logger"
+ "gorm.io/gorm/schema"
log "unknwon.dev/clog/v2"
"gogs.io/gogs/internal/conf"
@@ -21,10 +24,11 @@ import (
"gogs.io/gogs/internal/testutil"
)
-var printSQL = flag.Bool("print-sql", false, "Print SQL executed")
-
func TestMain(m *testing.M) {
flag.Parse()
+
+ var w logger.Writer
+ level := logger.Silent
if !testing.Verbose() {
// Remove the primary logger and register a noop logger.
log.Remove(log.DefaultConsoleName)
@@ -33,10 +37,18 @@ func TestMain(m *testing.M) {
fmt.Println(err)
os.Exit(1)
}
+
+ w = &dbutil.Logger{Writer: ioutil.Discard}
+ } else {
+ w = stdlog.New(os.Stdout, "\r\n", stdlog.LstdFlags)
+ level = logger.Info
}
- now := time.Now().UTC().Truncate(time.Second)
- gorm.NowFunc = func() time.Time { return now }
+ // NOTE: AutoMigrate does not respect logger passed in gorm.Config.
+ logger.Default = logger.New(w, logger.Config{
+ SlowThreshold: 100 * time.Millisecond,
+ LogLevel: level,
+ })
os.Exit(m.Run())
}
@@ -48,7 +60,7 @@ func clearTables(t *testing.T, db *gorm.DB, tables ...interface{}) error {
}
for _, t := range tables {
- err := db.Delete(t).Error
+ err := db.Where("TRUE").Delete(t).Error
if err != nil {
return err
}
@@ -60,15 +72,29 @@ func initTestDB(t *testing.T, suite string, tables ...interface{}) *gorm.DB {
t.Helper()
dbpath := filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix()))
- db, err := openDB(conf.DatabaseOpts{
- Type: "sqlite3",
- Path: dbpath,
- })
+ now := time.Now().UTC().Truncate(time.Second)
+ db, err := openDB(
+ conf.DatabaseOpts{
+ Type: "sqlite3",
+ Path: dbpath,
+ },
+ &gorm.Config{
+ NamingStrategy: schema.NamingStrategy{
+ SingularTable: true,
+ },
+ NowFunc: func() time.Time {
+ return now
+ },
+ },
+ )
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
- _ = db.Close()
+ sqlDB, err := db.DB()
+ if err == nil {
+ _ = sqlDB.Close()
+ }
if t.Failed() {
t.Logf("Database %q left intact for inspection", dbpath)
@@ -78,15 +104,7 @@ func initTestDB(t *testing.T, suite string, tables ...interface{}) *gorm.DB {
_ = os.Remove(dbpath)
})
- db.SingularTable(true)
- if !testing.Verbose() {
- db.SetLogger(&dbutil.Writer{Writer: ioutil.Discard})
- }
- if *printSQL {
- db.LogMode(true)
- }
-
- err = db.AutoMigrate(tables...).Error
+ err = db.Migrator().AutoMigrate(tables...)
if err != nil {
t.Fatal(err)
}
diff --git a/internal/db/models.go b/internal/db/models.go
index d2868bb6..52ef4495 100644
--- a/internal/db/models.go
+++ b/internal/db/models.go
@@ -13,7 +13,7 @@ import (
"strings"
"time"
- "github.com/jinzhu/gorm"
+ "gorm.io/gorm"
log "unknwon.dev/clog/v2"
"xorm.io/core"
"xorm.io/xorm"
@@ -68,6 +68,7 @@ func getEngine() (*xorm.Engine, error) {
Param = "&"
}
+ driver := conf.Database.Type
connStr := ""
switch conf.Database.Type {
case "mysql":
@@ -92,6 +93,7 @@ func getEngine() (*xorm.Engine, error) {
connStr = fmt.Sprintf("postgres://%s:%s@%s:%s/%s%ssslmode=%s",
url.QueryEscape(conf.Database.User), url.QueryEscape(conf.Database.Password), host, port, conf.Database.Name, Param, conf.Database.SSLMode)
}
+ driver = "pgx"
case "mssql":
conf.UseMSSQL = true
@@ -108,7 +110,7 @@ func getEngine() (*xorm.Engine, error) {
default:
return nil, fmt.Errorf("unknown database type: %s", conf.Database.Type)
}
- return xorm.NewEngine(conf.Database.Type, connStr)
+ return xorm.NewEngine(driver, connStr)
}
func NewTestEngine() error {
diff --git a/internal/db/perms.go b/internal/db/perms.go
index 129f2187..1295300d 100644
--- a/internal/db/perms.go
+++ b/internal/db/perms.go
@@ -5,8 +5,7 @@
package db
import (
- "github.com/jinzhu/gorm"
- "github.com/t-tiger/gorm-bulk-insert"
+ "gorm.io/gorm"
log "unknwon.dev/clog/v2"
)
@@ -53,7 +52,7 @@ func (db *perms) AccessMode(userID int64, repo *Repository) (mode AccessMode) {
access := new(Access)
err := db.Where("user_id = ? AND repo_id = ?", userID, repo.ID).First(access).Error
if err != nil {
- if !gorm.IsRecordNotFoundError(err) {
+ if err != gorm.ErrRecordNotFound {
log.Error("Failed to get access [user_id: %d, repo_id: %d]: %v", userID, repo.ID, err)
}
return mode
@@ -66,7 +65,7 @@ func (db *perms) Authorize(userID int64, repo *Repository, desired AccessMode) b
}
func (db *perms) SetRepoPerms(repoID int64, accessMap map[int64]AccessMode) error {
- records := make([]interface{}, 0, len(accessMap))
+ records := make([]*Access, 0, len(accessMap))
for userID, mode := range accessMap {
records = append(records, &Access{
UserID: userID,
@@ -81,6 +80,6 @@ func (db *perms) SetRepoPerms(repoID int64, accessMap map[int64]AccessMode) erro
return err
}
- return gormbulk.BulkInsert(tx, records, 3000)
+ return tx.Create(&records).Error
})
}
diff --git a/internal/db/repos.go b/internal/db/repos.go
index 08cc6485..fc5bce03 100644
--- a/internal/db/repos.go
+++ b/internal/db/repos.go
@@ -9,7 +9,7 @@ import (
"strings"
"time"
- "github.com/jinzhu/gorm"
+ "gorm.io/gorm"
"gogs.io/gogs/internal/errutil"
)
@@ -26,19 +26,24 @@ type ReposStore interface {
var Repos ReposStore
// NOTE: This is a GORM create hook.
-func (r *Repository) BeforeCreate() {
- r.CreatedUnix = gorm.NowFunc().Unix()
+func (r *Repository) BeforeCreate(tx *gorm.DB) error {
+ if r.CreatedUnix == 0 {
+ r.CreatedUnix = tx.NowFunc().Unix()
+ }
+ return nil
}
// NOTE: This is a GORM update hook.
-func (r *Repository) BeforeUpdate() {
- r.UpdatedUnix = gorm.NowFunc().Unix()
+func (r *Repository) BeforeUpdate(tx *gorm.DB) error {
+ r.UpdatedUnix = tx.NowFunc().Unix()
+ return nil
}
// NOTE: This is a GORM query hook.
-func (r *Repository) AfterFind() {
+func (r *Repository) AfterFind(tx *gorm.DB) error {
r.Created = time.Unix(r.CreatedUnix, 0).Local()
r.Updated = time.Unix(r.UpdatedUnix, 0).Local()
+ return nil
}
var _ ReposStore = (*repos)(nil)
@@ -129,7 +134,7 @@ func (db *repos) GetByName(ownerID int64, name string) (*Repository, error) {
repo := new(Repository)
err := db.Where("owner_id = ? AND lower_name = ?", ownerID, strings.ToLower(name)).First(repo).Error
if err != nil {
- if gorm.IsRecordNotFoundError(err) {
+ if err == gorm.ErrRecordNotFound {
return nil, ErrRepoNotExist{args: map[string]interface{}{"ownerID": ownerID, "name": name}}
}
return nil, err
diff --git a/internal/db/repos_test.go b/internal/db/repos_test.go
index 42eedbba..7b8f2b00 100644
--- a/internal/db/repos_test.go
+++ b/internal/db/repos_test.go
@@ -8,7 +8,6 @@ import (
"testing"
"time"
- "github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"gogs.io/gogs/internal/errutil"
@@ -80,7 +79,7 @@ func test_repos_create(t *testing.T, db *repos) {
if err != nil {
t.Fatal(err)
}
- assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), repo.Created.UTC().Format(time.RFC3339))
+ assert.Equal(t, db.NowFunc().Format(time.RFC3339), repo.Created.UTC().Format(time.RFC3339))
}
func test_repos_GetByName(t *testing.T, db *repos) {
diff --git a/internal/db/two_factors.go b/internal/db/two_factors.go
index aa90c2ff..7692e5d5 100644
--- a/internal/db/two_factors.go
+++ b/internal/db/two_factors.go
@@ -10,9 +10,8 @@ import (
"strings"
"time"
- "github.com/jinzhu/gorm"
"github.com/pkg/errors"
- "github.com/t-tiger/gorm-bulk-insert"
+ "gorm.io/gorm"
log "unknwon.dev/clog/v2"
"gogs.io/gogs/internal/cryptoutil"
@@ -39,13 +38,17 @@ type TwoFactorsStore interface {
var TwoFactors TwoFactorsStore
// NOTE: This is a GORM create hook.
-func (t *TwoFactor) BeforeCreate() {
- t.CreatedUnix = gorm.NowFunc().Unix()
+func (t *TwoFactor) BeforeCreate(tx *gorm.DB) error {
+ if t.CreatedUnix == 0 {
+ t.CreatedUnix = tx.NowFunc().Unix()
+ }
+ return nil
}
// NOTE: This is a GORM query hook.
-func (t *TwoFactor) AfterFind() {
+func (t *TwoFactor) AfterFind(tx *gorm.DB) error {
t.Created = time.Unix(t.CreatedUnix, 0).Local()
+ return nil
}
var _ TwoFactorsStore = (*twoFactors)(nil)
@@ -69,18 +72,13 @@ func (db *twoFactors) Create(userID int64, key, secret string) error {
return errors.Wrap(err, "generate recovery codes")
}
- records := make([]interface{}, 0, len(recoveryCodes))
- for _, code := range recoveryCodes {
- records = append(records, code)
- }
-
return db.Transaction(func(tx *gorm.DB) error {
err := tx.Create(tf).Error
if err != nil {
return err
}
- return gormbulk.BulkInsert(tx, records, 3000)
+ return tx.Create(&recoveryCodes).Error
})
}
@@ -107,7 +105,7 @@ func (db *twoFactors) GetByUserID(userID int64) (*TwoFactor, error) {
tf := new(TwoFactor)
err := db.Where("user_id = ?", userID).First(tf).Error
if err != nil {
- if gorm.IsRecordNotFoundError(err) {
+ if err == gorm.ErrRecordNotFound {
return nil, ErrTwoFactorNotFound{args: errutil.Args{"userID": userID}}
}
return nil, err
diff --git a/internal/db/two_factors_test.go b/internal/db/two_factors_test.go
index 4f0ad9da..20dfae61 100644
--- a/internal/db/two_factors_test.go
+++ b/internal/db/two_factors_test.go
@@ -8,7 +8,6 @@ import (
"testing"
"time"
- "github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"gogs.io/gogs/internal/errutil"
@@ -58,7 +57,7 @@ func test_twoFactors_Create(t *testing.T, db *twoFactors) {
if err != nil {
t.Fatal(err)
}
- assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), tf.Created.UTC().Format(time.RFC3339))
+ assert.Equal(t, db.NowFunc().Format(time.RFC3339), tf.Created.UTC().Format(time.RFC3339))
// Verify there are 10 recover codes generated
var count int64
diff --git a/internal/db/users.go b/internal/db/users.go
index a002c290..64535ea3 100644
--- a/internal/db/users.go
+++ b/internal/db/users.go
@@ -9,8 +9,8 @@ import (
"strings"
"time"
- "github.com/jinzhu/gorm"
"github.com/pkg/errors"
+ "gorm.io/gorm"
"gogs.io/gogs/internal/cryptoutil"
"gogs.io/gogs/internal/errutil"
@@ -48,15 +48,19 @@ type UsersStore interface {
var Users UsersStore
// NOTE: This is a GORM create hook.
-func (u *User) BeforeCreate() {
- u.CreatedUnix = gorm.NowFunc().Unix()
- u.UpdatedUnix = u.CreatedUnix
+func (u *User) BeforeCreate(tx *gorm.DB) error {
+ if u.CreatedUnix == 0 {
+ u.CreatedUnix = tx.NowFunc().Unix()
+ u.UpdatedUnix = u.CreatedUnix
+ }
+ return nil
}
// NOTE: This is a GORM query hook.
-func (u *User) AfterFind() {
+func (u *User) AfterFind(tx *gorm.DB) error {
u.Created = time.Unix(u.CreatedUnix, 0).Local()
u.Updated = time.Unix(u.UpdatedUnix, 0).Local()
+ return nil
}
var _ UsersStore = (*users)(nil)
@@ -253,7 +257,7 @@ func (db *users) GetByEmail(email string) (*User, error) {
err := db.Where("email = ? AND type = ? AND is_active = ?", email, UserIndividual, true).First(user).Error
if err == nil {
return user, nil
- } else if !gorm.IsRecordNotFoundError(err) {
+ } else if err != gorm.ErrRecordNotFound {
return nil, err
}
@@ -261,7 +265,7 @@ func (db *users) GetByEmail(email string) (*User, error) {
emailAddress := new(EmailAddress)
err = db.Where("email = ? AND is_activated = ?", email, true).First(emailAddress).Error
if err != nil {
- if gorm.IsRecordNotFoundError(err) {
+ if err == gorm.ErrRecordNotFound {
return nil, ErrUserNotExist{args: errutil.Args{"email": email}}
}
return nil, err
@@ -274,7 +278,7 @@ func (db *users) GetByID(id int64) (*User, error) {
user := new(User)
err := db.Where("id = ?", id).First(user).Error
if err != nil {
- if gorm.IsRecordNotFoundError(err) {
+ if err == gorm.ErrRecordNotFound {
return nil, ErrUserNotExist{args: errutil.Args{"userID": id}}
}
return nil, err
@@ -286,7 +290,7 @@ func (db *users) GetByUsername(username string) (*User, error) {
user := new(User)
err := db.Where("lower_name = ?", strings.ToLower(username)).First(user).Error
if err != nil {
- if gorm.IsRecordNotFoundError(err) {
+ if err == gorm.ErrRecordNotFound {
return nil, ErrUserNotExist{args: errutil.Args{"name": username}}
}
return nil, err
diff --git a/internal/db/users_test.go b/internal/db/users_test.go
index 2be196e3..7ff3b9f5 100644
--- a/internal/db/users_test.go
+++ b/internal/db/users_test.go
@@ -8,7 +8,6 @@ import (
"testing"
"time"
- "github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"gogs.io/gogs/internal/errutil"
@@ -129,8 +128,8 @@ func test_users_Create(t *testing.T, db *users) {
if err != nil {
t.Fatal(err)
}
- assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), user.Created.UTC().Format(time.RFC3339))
- assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), user.Updated.UTC().Format(time.RFC3339))
+ assert.Equal(t, db.NowFunc().Format(time.RFC3339), user.Created.UTC().Format(time.RFC3339))
+ assert.Equal(t, db.NowFunc().Format(time.RFC3339), user.Updated.UTC().Format(time.RFC3339))
}
func test_users_GetByEmail(t *testing.T, db *users) {
diff --git a/internal/dbutil/logger.go b/internal/dbutil/logger.go
new file mode 100644
index 00000000..66426ae7
--- /dev/null
+++ b/internal/dbutil/logger.go
@@ -0,0 +1,15 @@
+package dbutil
+
+import (
+ "fmt"
+ "io"
+)
+
+// Logger is a wrapper of io.Writer for the GORM's logger.Writer.
+type Logger struct {
+ io.Writer
+}
+
+func (l *Logger) Printf(format string, args ...interface{}) {
+ fmt.Fprintf(l.Writer, format, args...)
+}
diff --git a/internal/dbutil/writer.go b/internal/dbutil/writer.go
deleted file mode 100644
index 3b8d7363..00000000
--- a/internal/dbutil/writer.go
+++ /dev/null
@@ -1,37 +0,0 @@
-// Copyright 2020 The Gogs Authors. All rights reserved.
-// Use of this source code is governed by a MIT-style
-// license that can be found in the LICENSE file.
-
-package dbutil
-
-import (
- "fmt"
- "io"
-)
-
-// Writer is a wrapper of io.Writer for the gorm.logger.
-type Writer struct {
- io.Writer
-}
-
-func (w *Writer) Print(v ...interface{}) {
- if len(v) == 0 {
- return
- }
-
- if len(v) == 1 {
- fmt.Fprint(w.Writer, v[0])
- return
- }
-
- switch v[0] {
- case "sql":
- fmt.Fprintf(w.Writer, "[sql] [%s] [%s] %s %v (%d rows affected)", v[1:]...)
- case "log":
- fmt.Fprintf(w.Writer, "[log] [%s] %s", v[1:]...)
- case "error":
- fmt.Fprintf(w.Writer, "[err] [%s] %s", v[1:]...)
- default:
- fmt.Fprint(w.Writer, v...)
- }
-}
diff --git a/internal/dbutil/writer_test.go b/internal/dbutil/writer_test.go
deleted file mode 100644
index a484d442..00000000
--- a/internal/dbutil/writer_test.go
+++ /dev/null
@@ -1,58 +0,0 @@
-// Copyright 2020 The Gogs Authors. All rights reserved.
-// Use of this source code is governed by a MIT-style
-// license that can be found in the LICENSE file.
-
-package dbutil
-
-import (
- "bytes"
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-func TestWriter_Print(t *testing.T) {
- tests := []struct {
- name string
- vs []interface{}
- expOutput string
- }{
- {
- name: "no values",
- },
- {
- name: "only one value",
- vs: []interface{}{"test"},
- expOutput: "test",
- },
- {
- name: "two values",
- vs: []interface{}{"test", "output"},
- expOutput: "testoutput",
- },
-
- {
- name: "sql",
- vs: []interface{}{"sql", "writer.go:65", "1ms", "SELECT * FROM users WHERE user_id = $1", []int{1}, 1},
- expOutput: "[sql] [writer.go:65] [1ms] SELECT * FROM users WHERE user_id = $1 [1] (1 rows affected)",
- },
- {
- name: "log",
- vs: []interface{}{"log", "writer.go:65", "something"},
- expOutput: "[log] [writer.go:65] something",
- },
- {
- name: "error",
- vs: []interface{}{"error", "writer.go:65", "something bad"},
- expOutput: "[err] [writer.go:65] something bad",
- },
- }
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- var buf bytes.Buffer
- w := &Writer{Writer: &buf}
- w.Print(test.vs...)
- assert.Equal(t, test.expOutput, buf.String())
- })
- }
-}