aboutsummaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
authorᴜɴᴋɴᴡᴏɴ <u@gogs.io>2020-04-11 01:25:19 +0800
committerGitHub <noreply@github.com>2020-04-11 01:25:19 +0800
commit62dda96159055ff9d485078f257445b964eb5220 (patch)
treea8bd161dc92c368bee0e6b58797e23278565cd8b /internal/db
parent5753d4cb87388c247e91eaf3ce641d309a45e760 (diff)
access_token: migrate to GORM and add tests (#6086)
* access_token: migrate to GORM * Add tests * Fix tests * Fix test clock
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/access_tokens.go90
-rw-r--r--internal/db/access_tokens_test.go201
-rw-r--r--internal/db/db.go14
-rw-r--r--internal/db/errors/token.go16
-rw-r--r--internal/db/lfs.go2
-rw-r--r--internal/db/main_test.go11
-rw-r--r--internal/db/mocks.go19
-rw-r--r--internal/db/models.go24
-rw-r--r--internal/db/token.go81
9 files changed, 343 insertions, 115 deletions
diff --git a/internal/db/access_tokens.go b/internal/db/access_tokens.go
index c3be93ac..f4f4ee80 100644
--- a/internal/db/access_tokens.go
+++ b/internal/db/access_tokens.go
@@ -6,27 +6,111 @@ package db
import (
"fmt"
+ "time"
"github.com/jinzhu/gorm"
+ gouuid "github.com/satori/go.uuid"
"gogs.io/gogs/internal/errutil"
+ "gogs.io/gogs/internal/tool"
)
// AccessTokensStore is the persistent interface for access tokens.
//
// NOTE: All methods are sorted in alphabetical order.
type AccessTokensStore interface {
+ // Create creates a new access token and persist to database.
+ // It returns ErrAccessTokenAlreadyExist when an access token
+ // with same name already exists for the user.
+ Create(userID int64, name string) (*AccessToken, error)
+ // DeleteByID deletes the access token by given ID.
+ // 🚨 SECURITY: The "userID" is required to prevent attacker
+ // deletes arbitrary access token that belongs to another user.
+ DeleteByID(userID, id int64) error
// GetBySHA returns the access token with given SHA1.
// It returns ErrAccessTokenNotExist when not found.
GetBySHA(sha string) (*AccessToken, error)
+ // List returns all access tokens belongs to given user.
+ List(userID int64) ([]*AccessToken, error)
// Save persists all values of given access token.
+ // The Updated field is set to current time automatically.
Save(t *AccessToken) error
}
var AccessTokens AccessTokensStore
+// AccessToken is a personal access token.
+type AccessToken struct {
+ ID int64
+ UserID int64 `xorm:"uid INDEX" gorm:"COLUMN:uid;INDEX"`
+ Name string
+ Sha1 string `xorm:"UNIQUE VARCHAR(40)" gorm:"TYPE:VARCHAR(40);UNIQUE"`
+
+ Created time.Time `xorm:"-" gorm:"-" json:"-"`
+ CreatedUnix int64
+ Updated time.Time `xorm:"-" gorm:"-" json:"-"`
+ UpdatedUnix int64
+ HasRecentActivity bool `xorm:"-" gorm:"-" json:"-"`
+ HasUsed bool `xorm:"-" gorm:"-" json:"-"`
+}
+
+// NOTE: This is a GORM create hook.
+func (t *AccessToken) BeforeCreate() {
+ t.CreatedUnix = t.Created.Unix()
+}
+
+// NOTE: This is a GORM update hook.
+func (t *AccessToken) BeforeUpdate() {
+ t.UpdatedUnix = t.Updated.Unix()
+}
+
+// NOTE: This is a GORM query hook.
+func (t *AccessToken) AfterFind() {
+ t.Created = time.Unix(t.CreatedUnix, 0).Local()
+ t.Updated = time.Unix(t.UpdatedUnix, 0).Local()
+ t.HasUsed = t.Updated.After(t.Created)
+ t.HasRecentActivity = t.Updated.Add(7 * 24 * time.Hour).After(time.Now())
+}
+
+var _ AccessTokensStore = (*accessTokens)(nil)
+
type accessTokens struct {
*gorm.DB
+ clock func() time.Time
+}
+
+type ErrAccessTokenAlreadyExist struct {
+ args errutil.Args
+}
+
+func IsErrAccessTokenAlreadyExist(err error) bool {
+ _, ok := err.(ErrAccessTokenAlreadyExist)
+ return ok
+}
+
+func (err ErrAccessTokenAlreadyExist) Error() string {
+ return fmt.Sprintf("access token already exists: %v", err.args)
+}
+
+func (db *accessTokens) Create(userID int64, name string) (*AccessToken, error) {
+ err := db.Where("uid = ? AND name = ?", userID, name).First(new(AccessToken)).Error
+ if err == nil {
+ return nil, ErrAccessTokenAlreadyExist{args: errutil.Args{"userID": userID, "name": name}}
+ } else if !gorm.IsRecordNotFoundError(err) {
+ return nil, err
+ }
+
+ token := &AccessToken{
+ UserID: userID,
+ Name: name,
+ Sha1: tool.SHA1(gouuid.NewV4().String()),
+ Created: db.clock(),
+ }
+ return token, db.DB.Create(token).Error
+}
+
+func (db *accessTokens) DeleteByID(userID, id int64) error {
+ return db.Where("id = ? AND uid = ?", id, userID).Delete(new(AccessToken)).Error
}
var _ errutil.NotFound = (*ErrAccessTokenNotExist)(nil)
@@ -60,6 +144,12 @@ func (db *accessTokens) GetBySHA(sha string) (*AccessToken, error) {
return token, nil
}
+func (db *accessTokens) List(userID int64) ([]*AccessToken, error) {
+ var tokens []*AccessToken
+ return tokens, db.Where("uid = ?", userID).Find(&tokens).Error
+}
+
func (db *accessTokens) Save(t *AccessToken) error {
+ t.Updated = db.clock()
return db.DB.Save(t).Error
}
diff --git a/internal/db/access_tokens_test.go b/internal/db/access_tokens_test.go
new file mode 100644
index 00000000..f6d62745
--- /dev/null
+++ b/internal/db/access_tokens_test.go
@@ -0,0 +1,201 @@
+// Copyright 2020 The Gogs Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package db
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+
+ "gogs.io/gogs/internal/conf"
+ "gogs.io/gogs/internal/errutil"
+)
+
+func Test_accessTokens(t *testing.T) {
+ if testing.Short() {
+ t.Skip()
+ }
+
+ t.Parallel()
+
+ dbpath := filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%d.db", time.Now().Unix()))
+ gdb, err := openDB(conf.DatabaseOpts{
+ Type: "sqlite3",
+ Path: dbpath,
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Cleanup(func() {
+ _ = gdb.Close()
+
+ if t.Failed() {
+ t.Logf("Database %q left intact for inspection", dbpath)
+ return
+ }
+
+ _ = os.Remove(dbpath)
+ })
+
+ err = gdb.AutoMigrate(new(AccessToken)).Error
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ now := time.Now().Truncate(time.Second)
+ clock := func() time.Time { return now }
+ db := &accessTokens{DB: gdb, clock: clock}
+
+ for _, tc := range []struct {
+ name string
+ test func(*testing.T, *accessTokens)
+ }{
+ {"Create", test_accessTokens_Create},
+ {"DeleteByID", test_accessTokens_DeleteByID},
+ {"GetBySHA", test_accessTokens_GetBySHA},
+ {"List", test_accessTokens_List},
+ {"Save", test_accessTokens_Save},
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Cleanup(func() {
+ err := deleteTables(gdb, new(AccessToken))
+ if err != nil {
+ t.Fatal(err)
+ }
+ })
+ tc.test(t, db)
+ })
+ }
+}
+
+func test_accessTokens_Create(t *testing.T, db *accessTokens) {
+ // Create first access token with name "Test"
+ token, err := db.Create(1, "Test")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ assert.Equal(t, int64(1), token.UserID)
+ assert.Equal(t, "Test", token.Name)
+ assert.Equal(t, 40, len(token.Sha1), "sha1 length")
+ assert.Equal(t, db.clock(), token.Created)
+
+ // Try create second access token with same name should fail
+ _, err = db.Create(token.UserID, token.Name)
+ expErr := ErrAccessTokenAlreadyExist{args: errutil.Args{"userID": token.UserID, "name": token.Name}}
+ assert.Equal(t, expErr, err)
+}
+
+func test_accessTokens_DeleteByID(t *testing.T, db *accessTokens) {
+ // Create an access token with name "Test"
+ token, err := db.Create(1, "Test")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // We should be able to get it back
+ _, err = db.GetBySHA(token.Sha1)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Delete a token with mismatched user ID is noop
+ err = db.DeleteByID(2, token.ID)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = db.GetBySHA(token.Sha1)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Now delete this token with correct user ID
+ err = db.DeleteByID(token.UserID, token.ID)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // We should get token not found error
+ _, err = db.GetBySHA(token.Sha1)
+ expErr := ErrAccessTokenNotExist{args: errutil.Args{"sha": token.Sha1}}
+ assert.Equal(t, expErr, err)
+}
+
+func test_accessTokens_GetBySHA(t *testing.T, db *accessTokens) {
+ // Create an access token with name "Test"
+ token, err := db.Create(1, "Test")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // We should be able to get it back
+ _, err = db.GetBySHA(token.Sha1)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Try to get a non-existent token
+ _, err = db.GetBySHA("bad_sha")
+ expErr := ErrAccessTokenNotExist{args: errutil.Args{"sha": "bad_sha"}}
+ assert.Equal(t, expErr, err)
+}
+
+func test_accessTokens_List(t *testing.T, db *accessTokens) {
+ // Create two access tokens for user 1
+ _, err := db.Create(1, "user1_1")
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = db.Create(1, "user1_2")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Create one access token for user 2
+ _, err = db.Create(2, "user2_1")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // List all access tokens for user 1
+ tokens, err := db.List(1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ assert.Equal(t, 2, len(tokens), "number of tokens")
+
+ assert.Equal(t, int64(1), tokens[0].UserID)
+ assert.Equal(t, "user1_1", tokens[0].Name)
+
+ assert.Equal(t, int64(1), tokens[1].UserID)
+ assert.Equal(t, "user1_2", tokens[1].Name)
+}
+
+func test_accessTokens_Save(t *testing.T, db *accessTokens) {
+ // Create an access token with name "Test"
+ token, err := db.Create(1, "Test")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Updated field is zero now
+ assert.True(t, token.Updated.IsZero())
+
+ err = db.Save(token)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Get back from DB should have Updated set
+ token, err = db.GetBySHA(token.Sha1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ assert.Equal(t, db.clock(), token.Updated)
+}
diff --git a/internal/db/db.go b/internal/db/db.go
index e3796039..27135eba 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -39,7 +39,7 @@ func parsePostgreSQLHostPort(info string) (host, port string) {
return host, port
}
-func parseMSSQLHostPort(info string) (host, port string) {
+func parseMSSQLHostPort(info string) (host, port string) {
host, port = "127.0.0.1", "1433"
if strings.Contains(info, ":") {
host = strings.Split(info, ":")[0]
@@ -122,6 +122,11 @@ func getLogWriter() (io.Writer, error) {
return w, nil
}
+var tables = []interface{}{
+ new(AccessToken),
+ new(LFSObject),
+}
+
func Init() error {
db, err := openDB(conf.Database)
if err != nil {
@@ -150,16 +155,17 @@ func Init() error {
case "mssql":
conf.UseMSSQL = true
case "sqlite3":
- conf.UseMySQL = true
+ conf.UseSQLite3 = true
}
- err = db.AutoMigrate(new(LFSObject)).Error
+ err = db.AutoMigrate(tables...).Error
if err != nil {
return errors.Wrap(err, "migrate schemes")
}
+ clock := func() time.Time {return time.Now().UTC().Truncate(time.Microsecond)}
// Initialize stores, sorted in alphabetical order.
- AccessTokens = &accessTokens{DB: db}
+ AccessTokens = &accessTokens{DB: db, clock: clock}
LoginSources = &loginSources{DB: db}
LFS = &lfs{DB: db}
Perms = &perms{DB: db}
diff --git a/internal/db/errors/token.go b/internal/db/errors/token.go
deleted file mode 100644
index d6a4577a..00000000
--- a/internal/db/errors/token.go
+++ /dev/null
@@ -1,16 +0,0 @@
-package errors
-
-import "fmt"
-
-type AccessTokenNameAlreadyExist struct {
- Name string
-}
-
-func IsAccessTokenNameAlreadyExist(err error) bool {
- _, ok := err.(AccessTokenNameAlreadyExist)
- return ok
-}
-
-func (err AccessTokenNameAlreadyExist) Error() string {
- return fmt.Sprintf("access token already exist [name: %s]", err.Name)
-}
diff --git a/internal/db/lfs.go b/internal/db/lfs.go
index 26a24df5..128069ed 100644
--- a/internal/db/lfs.go
+++ b/internal/db/lfs.go
@@ -37,7 +37,7 @@ type lfs struct {
// LFSObject is the relation between an LFS object and a repository.
type LFSObject struct {
RepoID int64 `gorm:"PRIMARY_KEY;AUTO_INCREMENT:false"`
- OID lfsutil.OID `gorm:"PRIMARY_KEY;column:oid"`
+ OID lfsutil.OID `gorm:"PRIMARY_KEY;COLUMN:oid"`
Size int64 `gorm:"NOT NULL"`
Storage lfsutil.Storage `gorm:"NOT NULL"`
CreatedAt time.Time `gorm:"NOT NULL"`
diff --git a/internal/db/main_test.go b/internal/db/main_test.go
index 24393f5b..d141d3cd 100644
--- a/internal/db/main_test.go
+++ b/internal/db/main_test.go
@@ -10,6 +10,7 @@ import (
"os"
"testing"
+ "github.com/jinzhu/gorm"
log "unknwon.dev/clog/v2"
"gogs.io/gogs/internal/testutil"
@@ -28,3 +29,13 @@ func TestMain(m *testing.M) {
}
os.Exit(m.Run())
}
+
+func deleteTables(db *gorm.DB, tables ...interface{}) error {
+ for _, t := range tables {
+ err := db.Delete(t).Error
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/internal/db/mocks.go b/internal/db/mocks.go
index 8cef2b96..12622a7c 100644
--- a/internal/db/mocks.go
+++ b/internal/db/mocks.go
@@ -15,14 +15,29 @@ import (
var _ AccessTokensStore = (*MockAccessTokensStore)(nil)
type MockAccessTokensStore struct {
- MockGetBySHA func(sha string) (*AccessToken, error)
- MockSave func(t *AccessToken) error
+ MockCreate func(userID int64, name string) (*AccessToken, error)
+ MockDeleteByID func(userID, id int64) error
+ MockGetBySHA func(sha string) (*AccessToken, error)
+ MockList func(userID int64) ([]*AccessToken, error)
+ MockSave func(t *AccessToken) error
+}
+
+func (m *MockAccessTokensStore) Create(userID int64, name string) (*AccessToken, error) {
+ return m.MockCreate(userID, name)
+}
+
+func (m *MockAccessTokensStore) DeleteByID(userID, id int64) error {
+ return m.MockDeleteByID(userID, id)
}
func (m *MockAccessTokensStore) GetBySHA(sha string) (*AccessToken, error) {
return m.MockGetBySHA(sha)
}
+func (m *MockAccessTokensStore) List(userID int64) ([]*AccessToken, error) {
+ return m.MockList(userID)
+}
+
func (m *MockAccessTokensStore) Save(t *AccessToken) error {
return m.MockSave(t)
}
diff --git a/internal/db/models.go b/internal/db/models.go
index 3bb35e7f..9b5b0d9a 100644
--- a/internal/db/models.go
+++ b/internal/db/models.go
@@ -41,14 +41,14 @@ type Engine interface {
}
var (
- x *xorm.Engine
- tables []interface{}
- HasEngine bool
+ x *xorm.Engine
+ legacyTables []interface{}
+ HasEngine bool
)
func init() {
- tables = append(tables,
- new(User), new(PublicKey), new(AccessToken), new(TwoFactor), new(TwoFactorRecoveryCode),
+ legacyTables = append(legacyTables,
+ new(User), new(PublicKey), new(TwoFactor), new(TwoFactorRecoveryCode),
new(Repository), new(DeployKey), new(Collaboration), new(Access), new(Upload),
new(Watch), new(Star), new(Follow), new(Action),
new(Issue), new(PullRequest), new(Comment), new(Attachment), new(IssueUser),
@@ -120,7 +120,7 @@ func NewTestEngine() error {
}
x.SetMapper(core.GonicMapper{})
- return x.StoreEngine("InnoDB").Sync2(tables...)
+ return x.StoreEngine("InnoDB").Sync2(legacyTables...)
}
func SetEngine() (err error) {
@@ -167,7 +167,7 @@ func NewEngine() (err error) {
return fmt.Errorf("migrate: %v", err)
}
- if err = x.StoreEngine("InnoDB").Sync2(tables...); err != nil {
+ if err = x.StoreEngine("InnoDB").Sync2(legacyTables...); err != nil {
return fmt.Errorf("sync structs to database tables: %v\n", err)
}
@@ -227,8 +227,9 @@ func DumpDatabase(dirPath string) error {
}
// Purposely create a local variable to not modify global variable
- tables := append(tables, new(Version))
- for _, table := range tables {
+ allTables := append(legacyTables, new(Version))
+ allTables = append(allTables, tables...)
+ for _, table := range allTables {
tableName := strings.TrimPrefix(fmt.Sprintf("%T", table), "*db.")
tableFile := path.Join(dirPath, tableName+".json")
f, err := os.Create(tableFile)
@@ -257,8 +258,9 @@ func ImportDatabase(dirPath string, verbose bool) (err error) {
}
// Purposely create a local variable to not modify global variable
- tables := append(tables, new(Version))
- for _, table := range tables {
+ allTables := append(legacyTables, new(Version))
+ allTables = append(allTables, tables...)
+ for _, table := range allTables {
tableName := strings.TrimPrefix(fmt.Sprintf("%T", table), "*db.")
tableFile := path.Join(dirPath, tableName+".json")
if !com.IsExist(tableFile) {
diff --git a/internal/db/token.go b/internal/db/token.go
deleted file mode 100644
index afe747e0..00000000
--- a/internal/db/token.go
+++ /dev/null
@@ -1,81 +0,0 @@
-// Copyright 2014 The Gogs Authors. All rights reserved.
-// Use of this source code is governed by a MIT-style
-// license that can be found in the LICENSE file.
-
-package db
-
-import (
- "time"
-
- gouuid "github.com/satori/go.uuid"
- "xorm.io/xorm"
-
- "gogs.io/gogs/internal/db/errors"
- "gogs.io/gogs/internal/tool"
-)
-
-// AccessToken represents a personal access token.
-type AccessToken struct {
- ID int64
- UserID int64 `xorm:"uid INDEX" gorm:"COLUMN:uid"`
- Name string
- Sha1 string `xorm:"UNIQUE VARCHAR(40)"`
-
- Created time.Time `xorm:"-" gorm:"-" json:"-"`
- CreatedUnix int64
- Updated time.Time `xorm:"-" gorm:"-" json:"-"` // Note: Updated must below Created for AfterSet.
- UpdatedUnix int64
- HasRecentActivity bool `xorm:"-" gorm:"-" json:"-"`
- HasUsed bool `xorm:"-" gorm:"-" json:"-"`
-}
-
-func (t *AccessToken) BeforeInsert() {
- t.CreatedUnix = time.Now().Unix()
-}
-
-func (t *AccessToken) BeforeUpdate() {
- t.UpdatedUnix = time.Now().Unix()
-}
-
-func (t *AccessToken) AfterSet(colName string, _ xorm.Cell) {
- switch colName {
- case "created_unix":
- t.Created = time.Unix(t.CreatedUnix, 0).Local()
- case "updated_unix":
- t.Updated = time.Unix(t.UpdatedUnix, 0).Local()
- t.HasUsed = t.Updated.After(t.Created)
- t.HasRecentActivity = t.Updated.Add(7 * 24 * time.Hour).After(time.Now())
- }
-}
-
-// NewAccessToken creates new access token.
-func NewAccessToken(t *AccessToken) error {
- t.Sha1 = tool.SHA1(gouuid.NewV4().String())
- has, err := x.Get(&AccessToken{
- UserID: t.UserID,
- Name: t.Name,
- })
- if err != nil {
- return err
- } else if has {
- return errors.AccessTokenNameAlreadyExist{Name: t.Name}
- }
-
- _, err = x.Insert(t)
- return err
-}
-
-// ListAccessTokens returns a list of access tokens belongs to given user.
-func ListAccessTokens(uid int64) ([]*AccessToken, error) {
- tokens := make([]*AccessToken, 0, 5)
- return tokens, x.Where("uid=?", uid).Desc("id").Find(&tokens)
-}
-
-// DeleteAccessTokenOfUserByID deletes access token by given ID.
-func DeleteAccessTokenOfUserByID(userID, id int64) error {
- _, err := x.Delete(&AccessToken{
- ID: id,
- UserID: userID,
- })
- return err
-}