diff options
Diffstat (limited to 'internal/db/access_tokens.go')
-rw-r--r-- | internal/db/access_tokens.go | 64 |
1 files changed, 33 insertions, 31 deletions
diff --git a/internal/db/access_tokens.go b/internal/db/access_tokens.go index 8915f480..bbb7b851 100644 --- a/internal/db/access_tokens.go +++ b/internal/db/access_tokens.go @@ -5,6 +5,7 @@ package db import ( + "context" "fmt" "time" @@ -19,22 +20,23 @@ import ( // // NOTE: All methods are sorted in alphabetical order. type AccessTokensStore interface { - // Create creates a new access token and persist to database. - // It returns ErrAccessTokenAlreadyExist when an access token - // with same name already exists for the user. - Create(userID int64, name string) (*AccessToken, error) + // Create creates a new access token and persist to database. It returns + // ErrAccessTokenAlreadyExist when an access token with same name already exists + // for the user. + Create(ctx context.Context, userID int64, name string) (*AccessToken, error) // DeleteByID deletes the access token by given ID. - // 🚨 SECURITY: The "userID" is required to prevent attacker - // deletes arbitrary access token that belongs to another user. - DeleteByID(userID, id int64) error - // GetBySHA1 returns the access token with given SHA1. - // It returns ErrAccessTokenNotExist when not found. - GetBySHA1(sha1 string) (*AccessToken, error) + // + // 🚨 SECURITY: The "userID" is required to prevent attacker deletes arbitrary + // access token that belongs to another user. + DeleteByID(ctx context.Context, userID, id int64) error + // GetBySHA1 returns the access token with given SHA1. It returns + // ErrAccessTokenNotExist when not found. + GetBySHA1(ctx context.Context, sha1 string) (*AccessToken, error) // List returns all access tokens belongs to given user. - List(userID int64) ([]*AccessToken, error) - // Save persists all values of given access token. - // The Updated field is set to current time automatically. - Save(t *AccessToken) error + List(ctx context.Context, userID int64) ([]*AccessToken, error) + // Save persists all values of given access token. The Updated field is set to + // current time automatically. + Save(ctx context.Context, t *AccessToken) error } var AccessTokens AccessTokensStore @@ -42,17 +44,17 @@ var AccessTokens AccessTokensStore // AccessToken is a personal access token. type AccessToken struct { ID int64 - UserID int64 `xorm:"uid INDEX" gorm:"COLUMN:uid;INDEX"` + UserID int64 `gorm:"column:uid;index"` Name string - Sha1 string `xorm:"UNIQUE VARCHAR(40)" gorm:"TYPE:VARCHAR(40);UNIQUE"` + Sha1 string `gorm:"type:VARCHAR(40);unique"` SHA256 string `gorm:"type:VARCHAR(64);unique;not null"` - Created time.Time `xorm:"-" gorm:"-" json:"-"` + Created time.Time `gorm:"-" json:"-"` CreatedUnix int64 - Updated time.Time `xorm:"-" gorm:"-" json:"-"` + Updated time.Time `gorm:"-" json:"-"` UpdatedUnix int64 - HasRecentActivity bool `xorm:"-" gorm:"-" json:"-"` - HasUsed bool `xorm:"-" gorm:"-" json:"-"` + HasRecentActivity bool `gorm:"-" json:"-"` + HasUsed bool `gorm:"-" json:"-"` } // BeforeCreate implements the GORM create hook. @@ -97,8 +99,8 @@ func (err ErrAccessTokenAlreadyExist) Error() string { return fmt.Sprintf("access token already exists: %v", err.args) } -func (db *accessTokens) Create(userID int64, name string) (*AccessToken, error) { - err := db.Where("uid = ? AND name = ?", userID, name).First(new(AccessToken)).Error +func (db *accessTokens) Create(ctx context.Context, userID int64, name string) (*AccessToken, error) { + err := db.WithContext(ctx).Where("uid = ? AND name = ?", userID, name).First(new(AccessToken)).Error if err == nil { return nil, ErrAccessTokenAlreadyExist{args: errutil.Args{"userID": userID, "name": name}} } else if err != gorm.ErrRecordNotFound { @@ -114,7 +116,7 @@ func (db *accessTokens) Create(userID int64, name string) (*AccessToken, error) Sha1: sha256[:40], // To pass the column unique constraint, keep the length of SHA1. SHA256: sha256, } - if err = db.DB.Create(accessToken).Error; err != nil { + if err = db.DB.WithContext(ctx).Create(accessToken).Error; err != nil { return nil, err } @@ -123,8 +125,8 @@ func (db *accessTokens) Create(userID int64, name string) (*AccessToken, error) return accessToken, nil } -func (db *accessTokens) DeleteByID(userID, id int64) error { - return db.Where("id = ? AND uid = ?", id, userID).Delete(new(AccessToken)).Error +func (db *accessTokens) DeleteByID(ctx context.Context, userID, id int64) error { + return db.WithContext(ctx).Where("id = ? AND uid = ?", id, userID).Delete(new(AccessToken)).Error } var _ errutil.NotFound = (*ErrAccessTokenNotExist)(nil) @@ -146,10 +148,10 @@ func (ErrAccessTokenNotExist) NotFound() bool { return true } -func (db *accessTokens) GetBySHA1(sha1 string) (*AccessToken, error) { +func (db *accessTokens) GetBySHA1(ctx context.Context, sha1 string) (*AccessToken, error) { sha256 := cryptoutil.SHA256(sha1) token := new(AccessToken) - err := db.Where("sha256 = ?", sha256).First(token).Error + err := db.WithContext(ctx).Where("sha256 = ?", sha256).First(token).Error if err != nil { if err == gorm.ErrRecordNotFound { return nil, ErrAccessTokenNotExist{args: errutil.Args{"sha": sha1}} @@ -159,11 +161,11 @@ func (db *accessTokens) GetBySHA1(sha1 string) (*AccessToken, error) { return token, nil } -func (db *accessTokens) List(userID int64) ([]*AccessToken, error) { +func (db *accessTokens) List(ctx context.Context, userID int64) ([]*AccessToken, error) { var tokens []*AccessToken - return tokens, db.Where("uid = ?", userID).Order("id ASC").Find(&tokens).Error + return tokens, db.WithContext(ctx).Where("uid = ?", userID).Order("id ASC").Find(&tokens).Error } -func (db *accessTokens) Save(t *AccessToken) error { - return db.DB.Save(t).Error +func (db *accessTokens) Save(ctx context.Context, t *AccessToken) error { + return db.DB.WithContext(ctx).Save(t).Error } |