diff options
author | ᴜɴᴋɴᴡᴏɴ <u@gogs.io> | 2020-04-11 01:25:19 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-04-11 01:25:19 +0800 |
commit | 62dda96159055ff9d485078f257445b964eb5220 (patch) | |
tree | a8bd161dc92c368bee0e6b58797e23278565cd8b /internal/db | |
parent | 5753d4cb87388c247e91eaf3ce641d309a45e760 (diff) |
access_token: migrate to GORM and add tests (#6086)
* access_token: migrate to GORM
* Add tests
* Fix tests
* Fix test clock
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/access_tokens.go | 90 | ||||
-rw-r--r-- | internal/db/access_tokens_test.go | 201 | ||||
-rw-r--r-- | internal/db/db.go | 14 | ||||
-rw-r--r-- | internal/db/errors/token.go | 16 | ||||
-rw-r--r-- | internal/db/lfs.go | 2 | ||||
-rw-r--r-- | internal/db/main_test.go | 11 | ||||
-rw-r--r-- | internal/db/mocks.go | 19 | ||||
-rw-r--r-- | internal/db/models.go | 24 | ||||
-rw-r--r-- | internal/db/token.go | 81 |
9 files changed, 343 insertions, 115 deletions
diff --git a/internal/db/access_tokens.go b/internal/db/access_tokens.go index c3be93ac..f4f4ee80 100644 --- a/internal/db/access_tokens.go +++ b/internal/db/access_tokens.go @@ -6,27 +6,111 @@ package db import ( "fmt" + "time" "github.com/jinzhu/gorm" + gouuid "github.com/satori/go.uuid" "gogs.io/gogs/internal/errutil" + "gogs.io/gogs/internal/tool" ) // AccessTokensStore is the persistent interface for access tokens. // // NOTE: All methods are sorted in alphabetical order. type AccessTokensStore interface { + // Create creates a new access token and persist to database. + // It returns ErrAccessTokenAlreadyExist when an access token + // with same name already exists for the user. + Create(userID int64, name string) (*AccessToken, error) + // DeleteByID deletes the access token by given ID. + // 🚨 SECURITY: The "userID" is required to prevent attacker + // deletes arbitrary access token that belongs to another user. + DeleteByID(userID, id int64) error // GetBySHA returns the access token with given SHA1. // It returns ErrAccessTokenNotExist when not found. GetBySHA(sha string) (*AccessToken, error) + // List returns all access tokens belongs to given user. + List(userID int64) ([]*AccessToken, error) // Save persists all values of given access token. + // The Updated field is set to current time automatically. Save(t *AccessToken) error } var AccessTokens AccessTokensStore +// AccessToken is a personal access token. +type AccessToken struct { + ID int64 + UserID int64 `xorm:"uid INDEX" gorm:"COLUMN:uid;INDEX"` + Name string + Sha1 string `xorm:"UNIQUE VARCHAR(40)" gorm:"TYPE:VARCHAR(40);UNIQUE"` + + Created time.Time `xorm:"-" gorm:"-" json:"-"` + CreatedUnix int64 + Updated time.Time `xorm:"-" gorm:"-" json:"-"` + UpdatedUnix int64 + HasRecentActivity bool `xorm:"-" gorm:"-" json:"-"` + HasUsed bool `xorm:"-" gorm:"-" json:"-"` +} + +// NOTE: This is a GORM create hook. +func (t *AccessToken) BeforeCreate() { + t.CreatedUnix = t.Created.Unix() +} + +// NOTE: This is a GORM update hook. +func (t *AccessToken) BeforeUpdate() { + t.UpdatedUnix = t.Updated.Unix() +} + +// NOTE: This is a GORM query hook. +func (t *AccessToken) AfterFind() { + t.Created = time.Unix(t.CreatedUnix, 0).Local() + t.Updated = time.Unix(t.UpdatedUnix, 0).Local() + t.HasUsed = t.Updated.After(t.Created) + t.HasRecentActivity = t.Updated.Add(7 * 24 * time.Hour).After(time.Now()) +} + +var _ AccessTokensStore = (*accessTokens)(nil) + type accessTokens struct { *gorm.DB + clock func() time.Time +} + +type ErrAccessTokenAlreadyExist struct { + args errutil.Args +} + +func IsErrAccessTokenAlreadyExist(err error) bool { + _, ok := err.(ErrAccessTokenAlreadyExist) + return ok +} + +func (err ErrAccessTokenAlreadyExist) Error() string { + return fmt.Sprintf("access token already exists: %v", err.args) +} + +func (db *accessTokens) Create(userID int64, name string) (*AccessToken, error) { + err := db.Where("uid = ? AND name = ?", userID, name).First(new(AccessToken)).Error + if err == nil { + return nil, ErrAccessTokenAlreadyExist{args: errutil.Args{"userID": userID, "name": name}} + } else if !gorm.IsRecordNotFoundError(err) { + return nil, err + } + + token := &AccessToken{ + UserID: userID, + Name: name, + Sha1: tool.SHA1(gouuid.NewV4().String()), + Created: db.clock(), + } + return token, db.DB.Create(token).Error +} + +func (db *accessTokens) DeleteByID(userID, id int64) error { + return db.Where("id = ? AND uid = ?", id, userID).Delete(new(AccessToken)).Error } var _ errutil.NotFound = (*ErrAccessTokenNotExist)(nil) @@ -60,6 +144,12 @@ func (db *accessTokens) GetBySHA(sha string) (*AccessToken, error) { return token, nil } +func (db *accessTokens) List(userID int64) ([]*AccessToken, error) { + var tokens []*AccessToken + return tokens, db.Where("uid = ?", userID).Find(&tokens).Error +} + func (db *accessTokens) Save(t *AccessToken) error { + t.Updated = db.clock() return db.DB.Save(t).Error } diff --git a/internal/db/access_tokens_test.go b/internal/db/access_tokens_test.go new file mode 100644 index 00000000..f6d62745 --- /dev/null +++ b/internal/db/access_tokens_test.go @@ -0,0 +1,201 @@ +// Copyright 2020 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package db + +import ( + "fmt" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "gogs.io/gogs/internal/conf" + "gogs.io/gogs/internal/errutil" +) + +func Test_accessTokens(t *testing.T) { + if testing.Short() { + t.Skip() + } + + t.Parallel() + + dbpath := filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%d.db", time.Now().Unix())) + gdb, err := openDB(conf.DatabaseOpts{ + Type: "sqlite3", + Path: dbpath, + }) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + _ = gdb.Close() + + if t.Failed() { + t.Logf("Database %q left intact for inspection", dbpath) + return + } + + _ = os.Remove(dbpath) + }) + + err = gdb.AutoMigrate(new(AccessToken)).Error + if err != nil { + t.Fatal(err) + } + + now := time.Now().Truncate(time.Second) + clock := func() time.Time { return now } + db := &accessTokens{DB: gdb, clock: clock} + + for _, tc := range []struct { + name string + test func(*testing.T, *accessTokens) + }{ + {"Create", test_accessTokens_Create}, + {"DeleteByID", test_accessTokens_DeleteByID}, + {"GetBySHA", test_accessTokens_GetBySHA}, + {"List", test_accessTokens_List}, + {"Save", test_accessTokens_Save}, + } { + t.Run(tc.name, func(t *testing.T) { + t.Cleanup(func() { + err := deleteTables(gdb, new(AccessToken)) + if err != nil { + t.Fatal(err) + } + }) + tc.test(t, db) + }) + } +} + +func test_accessTokens_Create(t *testing.T, db *accessTokens) { + // Create first access token with name "Test" + token, err := db.Create(1, "Test") + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, int64(1), token.UserID) + assert.Equal(t, "Test", token.Name) + assert.Equal(t, 40, len(token.Sha1), "sha1 length") + assert.Equal(t, db.clock(), token.Created) + + // Try create second access token with same name should fail + _, err = db.Create(token.UserID, token.Name) + expErr := ErrAccessTokenAlreadyExist{args: errutil.Args{"userID": token.UserID, "name": token.Name}} + assert.Equal(t, expErr, err) +} + +func test_accessTokens_DeleteByID(t *testing.T, db *accessTokens) { + // Create an access token with name "Test" + token, err := db.Create(1, "Test") + if err != nil { + t.Fatal(err) + } + + // We should be able to get it back + _, err = db.GetBySHA(token.Sha1) + if err != nil { + t.Fatal(err) + } + + // Delete a token with mismatched user ID is noop + err = db.DeleteByID(2, token.ID) + if err != nil { + t.Fatal(err) + } + _, err = db.GetBySHA(token.Sha1) + if err != nil { + t.Fatal(err) + } + + // Now delete this token with correct user ID + err = db.DeleteByID(token.UserID, token.ID) + if err != nil { + t.Fatal(err) + } + + // We should get token not found error + _, err = db.GetBySHA(token.Sha1) + expErr := ErrAccessTokenNotExist{args: errutil.Args{"sha": token.Sha1}} + assert.Equal(t, expErr, err) +} + +func test_accessTokens_GetBySHA(t *testing.T, db *accessTokens) { + // Create an access token with name "Test" + token, err := db.Create(1, "Test") + if err != nil { + t.Fatal(err) + } + + // We should be able to get it back + _, err = db.GetBySHA(token.Sha1) + if err != nil { + t.Fatal(err) + } + + // Try to get a non-existent token + _, err = db.GetBySHA("bad_sha") + expErr := ErrAccessTokenNotExist{args: errutil.Args{"sha": "bad_sha"}} + assert.Equal(t, expErr, err) +} + +func test_accessTokens_List(t *testing.T, db *accessTokens) { + // Create two access tokens for user 1 + _, err := db.Create(1, "user1_1") + if err != nil { + t.Fatal(err) + } + _, err = db.Create(1, "user1_2") + if err != nil { + t.Fatal(err) + } + + // Create one access token for user 2 + _, err = db.Create(2, "user2_1") + if err != nil { + t.Fatal(err) + } + + // List all access tokens for user 1 + tokens, err := db.List(1) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, 2, len(tokens), "number of tokens") + + assert.Equal(t, int64(1), tokens[0].UserID) + assert.Equal(t, "user1_1", tokens[0].Name) + + assert.Equal(t, int64(1), tokens[1].UserID) + assert.Equal(t, "user1_2", tokens[1].Name) +} + +func test_accessTokens_Save(t *testing.T, db *accessTokens) { + // Create an access token with name "Test" + token, err := db.Create(1, "Test") + if err != nil { + t.Fatal(err) + } + + // Updated field is zero now + assert.True(t, token.Updated.IsZero()) + + err = db.Save(token) + if err != nil { + t.Fatal(err) + } + + // Get back from DB should have Updated set + token, err = db.GetBySHA(token.Sha1) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, db.clock(), token.Updated) +} diff --git a/internal/db/db.go b/internal/db/db.go index e3796039..27135eba 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -39,7 +39,7 @@ func parsePostgreSQLHostPort(info string) (host, port string) { return host, port } -func parseMSSQLHostPort(info string) (host, port string) { +func parseMSSQLHostPort(info string) (host, port string) { host, port = "127.0.0.1", "1433" if strings.Contains(info, ":") { host = strings.Split(info, ":")[0] @@ -122,6 +122,11 @@ func getLogWriter() (io.Writer, error) { return w, nil } +var tables = []interface{}{ + new(AccessToken), + new(LFSObject), +} + func Init() error { db, err := openDB(conf.Database) if err != nil { @@ -150,16 +155,17 @@ func Init() error { case "mssql": conf.UseMSSQL = true case "sqlite3": - conf.UseMySQL = true + conf.UseSQLite3 = true } - err = db.AutoMigrate(new(LFSObject)).Error + err = db.AutoMigrate(tables...).Error if err != nil { return errors.Wrap(err, "migrate schemes") } + clock := func() time.Time {return time.Now().UTC().Truncate(time.Microsecond)} // Initialize stores, sorted in alphabetical order. - AccessTokens = &accessTokens{DB: db} + AccessTokens = &accessTokens{DB: db, clock: clock} LoginSources = &loginSources{DB: db} LFS = &lfs{DB: db} Perms = &perms{DB: db} diff --git a/internal/db/errors/token.go b/internal/db/errors/token.go deleted file mode 100644 index d6a4577a..00000000 --- a/internal/db/errors/token.go +++ /dev/null @@ -1,16 +0,0 @@ -package errors - -import "fmt" - -type AccessTokenNameAlreadyExist struct { - Name string -} - -func IsAccessTokenNameAlreadyExist(err error) bool { - _, ok := err.(AccessTokenNameAlreadyExist) - return ok -} - -func (err AccessTokenNameAlreadyExist) Error() string { - return fmt.Sprintf("access token already exist [name: %s]", err.Name) -} diff --git a/internal/db/lfs.go b/internal/db/lfs.go index 26a24df5..128069ed 100644 --- a/internal/db/lfs.go +++ b/internal/db/lfs.go @@ -37,7 +37,7 @@ type lfs struct { // LFSObject is the relation between an LFS object and a repository. type LFSObject struct { RepoID int64 `gorm:"PRIMARY_KEY;AUTO_INCREMENT:false"` - OID lfsutil.OID `gorm:"PRIMARY_KEY;column:oid"` + OID lfsutil.OID `gorm:"PRIMARY_KEY;COLUMN:oid"` Size int64 `gorm:"NOT NULL"` Storage lfsutil.Storage `gorm:"NOT NULL"` CreatedAt time.Time `gorm:"NOT NULL"` diff --git a/internal/db/main_test.go b/internal/db/main_test.go index 24393f5b..d141d3cd 100644 --- a/internal/db/main_test.go +++ b/internal/db/main_test.go @@ -10,6 +10,7 @@ import ( "os" "testing" + "github.com/jinzhu/gorm" log "unknwon.dev/clog/v2" "gogs.io/gogs/internal/testutil" @@ -28,3 +29,13 @@ func TestMain(m *testing.M) { } os.Exit(m.Run()) } + +func deleteTables(db *gorm.DB, tables ...interface{}) error { + for _, t := range tables { + err := db.Delete(t).Error + if err != nil { + return err + } + } + return nil +} diff --git a/internal/db/mocks.go b/internal/db/mocks.go index 8cef2b96..12622a7c 100644 --- a/internal/db/mocks.go +++ b/internal/db/mocks.go @@ -15,14 +15,29 @@ import ( var _ AccessTokensStore = (*MockAccessTokensStore)(nil) type MockAccessTokensStore struct { - MockGetBySHA func(sha string) (*AccessToken, error) - MockSave func(t *AccessToken) error + MockCreate func(userID int64, name string) (*AccessToken, error) + MockDeleteByID func(userID, id int64) error + MockGetBySHA func(sha string) (*AccessToken, error) + MockList func(userID int64) ([]*AccessToken, error) + MockSave func(t *AccessToken) error +} + +func (m *MockAccessTokensStore) Create(userID int64, name string) (*AccessToken, error) { + return m.MockCreate(userID, name) +} + +func (m *MockAccessTokensStore) DeleteByID(userID, id int64) error { + return m.MockDeleteByID(userID, id) } func (m *MockAccessTokensStore) GetBySHA(sha string) (*AccessToken, error) { return m.MockGetBySHA(sha) } +func (m *MockAccessTokensStore) List(userID int64) ([]*AccessToken, error) { + return m.MockList(userID) +} + func (m *MockAccessTokensStore) Save(t *AccessToken) error { return m.MockSave(t) } diff --git a/internal/db/models.go b/internal/db/models.go index 3bb35e7f..9b5b0d9a 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -41,14 +41,14 @@ type Engine interface { } var ( - x *xorm.Engine - tables []interface{} - HasEngine bool + x *xorm.Engine + legacyTables []interface{} + HasEngine bool ) func init() { - tables = append(tables, - new(User), new(PublicKey), new(AccessToken), new(TwoFactor), new(TwoFactorRecoveryCode), + legacyTables = append(legacyTables, + new(User), new(PublicKey), new(TwoFactor), new(TwoFactorRecoveryCode), new(Repository), new(DeployKey), new(Collaboration), new(Access), new(Upload), new(Watch), new(Star), new(Follow), new(Action), new(Issue), new(PullRequest), new(Comment), new(Attachment), new(IssueUser), @@ -120,7 +120,7 @@ func NewTestEngine() error { } x.SetMapper(core.GonicMapper{}) - return x.StoreEngine("InnoDB").Sync2(tables...) + return x.StoreEngine("InnoDB").Sync2(legacyTables...) } func SetEngine() (err error) { @@ -167,7 +167,7 @@ func NewEngine() (err error) { return fmt.Errorf("migrate: %v", err) } - if err = x.StoreEngine("InnoDB").Sync2(tables...); err != nil { + if err = x.StoreEngine("InnoDB").Sync2(legacyTables...); err != nil { return fmt.Errorf("sync structs to database tables: %v\n", err) } @@ -227,8 +227,9 @@ func DumpDatabase(dirPath string) error { } // Purposely create a local variable to not modify global variable - tables := append(tables, new(Version)) - for _, table := range tables { + allTables := append(legacyTables, new(Version)) + allTables = append(allTables, tables...) + for _, table := range allTables { tableName := strings.TrimPrefix(fmt.Sprintf("%T", table), "*db.") tableFile := path.Join(dirPath, tableName+".json") f, err := os.Create(tableFile) @@ -257,8 +258,9 @@ func ImportDatabase(dirPath string, verbose bool) (err error) { } // Purposely create a local variable to not modify global variable - tables := append(tables, new(Version)) - for _, table := range tables { + allTables := append(legacyTables, new(Version)) + allTables = append(allTables, tables...) + for _, table := range allTables { tableName := strings.TrimPrefix(fmt.Sprintf("%T", table), "*db.") tableFile := path.Join(dirPath, tableName+".json") if !com.IsExist(tableFile) { diff --git a/internal/db/token.go b/internal/db/token.go deleted file mode 100644 index afe747e0..00000000 --- a/internal/db/token.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2014 The Gogs Authors. All rights reserved. -// Use of this source code is governed by a MIT-style -// license that can be found in the LICENSE file. - -package db - -import ( - "time" - - gouuid "github.com/satori/go.uuid" - "xorm.io/xorm" - - "gogs.io/gogs/internal/db/errors" - "gogs.io/gogs/internal/tool" -) - -// AccessToken represents a personal access token. -type AccessToken struct { - ID int64 - UserID int64 `xorm:"uid INDEX" gorm:"COLUMN:uid"` - Name string - Sha1 string `xorm:"UNIQUE VARCHAR(40)"` - - Created time.Time `xorm:"-" gorm:"-" json:"-"` - CreatedUnix int64 - Updated time.Time `xorm:"-" gorm:"-" json:"-"` // Note: Updated must below Created for AfterSet. - UpdatedUnix int64 - HasRecentActivity bool `xorm:"-" gorm:"-" json:"-"` - HasUsed bool `xorm:"-" gorm:"-" json:"-"` -} - -func (t *AccessToken) BeforeInsert() { - t.CreatedUnix = time.Now().Unix() -} - -func (t *AccessToken) BeforeUpdate() { - t.UpdatedUnix = time.Now().Unix() -} - -func (t *AccessToken) AfterSet(colName string, _ xorm.Cell) { - switch colName { - case "created_unix": - t.Created = time.Unix(t.CreatedUnix, 0).Local() - case "updated_unix": - t.Updated = time.Unix(t.UpdatedUnix, 0).Local() - t.HasUsed = t.Updated.After(t.Created) - t.HasRecentActivity = t.Updated.Add(7 * 24 * time.Hour).After(time.Now()) - } -} - -// NewAccessToken creates new access token. -func NewAccessToken(t *AccessToken) error { - t.Sha1 = tool.SHA1(gouuid.NewV4().String()) - has, err := x.Get(&AccessToken{ - UserID: t.UserID, - Name: t.Name, - }) - if err != nil { - return err - } else if has { - return errors.AccessTokenNameAlreadyExist{Name: t.Name} - } - - _, err = x.Insert(t) - return err -} - -// ListAccessTokens returns a list of access tokens belongs to given user. -func ListAccessTokens(uid int64) ([]*AccessToken, error) { - tokens := make([]*AccessToken, 0, 5) - return tokens, x.Where("uid=?", uid).Desc("id").Find(&tokens) -} - -// DeleteAccessTokenOfUserByID deletes access token by given ID. -func DeleteAccessTokenOfUserByID(userID, id int64) error { - _, err := x.Delete(&AccessToken{ - ID: id, - UserID: userID, - }) - return err -} |