aboutsummaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
authorJoe Chen <jc@unknwon.io>2022-06-08 19:26:20 +0800
committerGitHub <noreply@github.com>2022-06-08 19:26:20 +0800
commit7229dd893f15ae30d20332706e40d8a87e0f94b0 (patch)
tree1e60a6751834efc257a4ab31c69b55bae3959e02 /internal/db
parent0918d8758b7470c5e1f64a62c2e48e4168993394 (diff)
db: use `context` and go-mockgen for `PermsStore` (#7033)
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/access_tokens_test.go12
-rw-r--r--internal/db/mock_gen.go22
-rw-r--r--internal/db/mocks.go407
-rw-r--r--internal/db/perms.go40
-rw-r--r--internal/db/perms_test.go266
-rw-r--r--internal/db/repo.go3
-rw-r--r--internal/db/repo_branch.go3
-rw-r--r--internal/db/ssh_key.go3
-rw-r--r--internal/db/user.go7
9 files changed, 581 insertions, 182 deletions
diff --git a/internal/db/access_tokens_test.go b/internal/db/access_tokens_test.go
index 38c41e2d..b135a7b7 100644
--- a/internal/db/access_tokens_test.go
+++ b/internal/db/access_tokens_test.go
@@ -66,9 +66,7 @@ func TestAccessTokens(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() {
err := clearTables(t, db.DB, tables...)
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
})
tc.test(t, db)
})
@@ -123,8 +121,8 @@ func accessTokensDeleteByID(t *testing.T, db *accessTokens) {
// We should get token not found error
_, err = db.GetBySHA1(ctx, token.Sha1)
- expErr := ErrAccessTokenNotExist{args: errutil.Args{"sha": token.Sha1}}
- assert.Equal(t, expErr, err)
+ wantErr := ErrAccessTokenNotExist{args: errutil.Args{"sha": token.Sha1}}
+ assert.Equal(t, wantErr, err)
}
func accessTokensGetBySHA(t *testing.T, db *accessTokens) {
@@ -140,8 +138,8 @@ func accessTokensGetBySHA(t *testing.T, db *accessTokens) {
// Try to get a non-existent token
_, err = db.GetBySHA1(ctx, "bad_sha")
- expErr := ErrAccessTokenNotExist{args: errutil.Args{"sha": "bad_sha"}}
- assert.Equal(t, expErr, err)
+ wantErr := ErrAccessTokenNotExist{args: errutil.Args{"sha": "bad_sha"}}
+ assert.Equal(t, wantErr, err)
}
func accessTokensList(t *testing.T, db *accessTokens) {
diff --git a/internal/db/mock_gen.go b/internal/db/mock_gen.go
index 564dcc57..ce347a63 100644
--- a/internal/db/mock_gen.go
+++ b/internal/db/mock_gen.go
@@ -10,7 +10,7 @@ import (
"gogs.io/gogs/internal/lfsutil"
)
-//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -o mocks.go
+//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i PermsStore -o mocks.go
func SetMockAccessTokensStore(t *testing.T, mock AccessTokensStore) {
before := AccessTokens
@@ -101,26 +101,6 @@ func (m *mockLoginSourceFileStore) Save() error {
return m.MockSave()
}
-var _ PermsStore = (*MockPermsStore)(nil)
-
-type MockPermsStore struct {
- MockAccessMode func(userID, repoID int64, opts AccessModeOptions) AccessMode
- MockAuthorize func(userID, repoID int64, desired AccessMode, opts AccessModeOptions) bool
- MockSetRepoPerms func(repoID int64, accessMap map[int64]AccessMode) error
-}
-
-func (m *MockPermsStore) AccessMode(userID, repoID int64, opts AccessModeOptions) AccessMode {
- return m.MockAccessMode(userID, repoID, opts)
-}
-
-func (m *MockPermsStore) Authorize(userID, repoID int64, desired AccessMode, opts AccessModeOptions) bool {
- return m.MockAuthorize(userID, repoID, desired, opts)
-}
-
-func (m *MockPermsStore) SetRepoPerms(repoID int64, accessMap map[int64]AccessMode) error {
- return m.MockSetRepoPerms(repoID, accessMap)
-}
-
func SetMockPermsStore(t *testing.T, mock PermsStore) {
before := Perms
Perms = mock
diff --git a/internal/db/mocks.go b/internal/db/mocks.go
index eaecf7b6..e969d83f 100644
--- a/internal/db/mocks.go
+++ b/internal/db/mocks.go
@@ -657,3 +657,410 @@ func (c AccessTokensStoreTouchFuncCall) Args() []interface{} {
func (c AccessTokensStoreTouchFuncCall) Results() []interface{} {
return []interface{}{c.Result0}
}
+
+// MockPermsStore is a mock implementation of the PermsStore interface (from
+// the package gogs.io/gogs/internal/db) used for unit testing.
+type MockPermsStore struct {
+ // AccessModeFunc is an instance of a mock function object controlling
+ // the behavior of the method AccessMode.
+ AccessModeFunc *PermsStoreAccessModeFunc
+ // AuthorizeFunc is an instance of a mock function object controlling
+ // the behavior of the method Authorize.
+ AuthorizeFunc *PermsStoreAuthorizeFunc
+ // SetRepoPermsFunc is an instance of a mock function object controlling
+ // the behavior of the method SetRepoPerms.
+ SetRepoPermsFunc *PermsStoreSetRepoPermsFunc
+}
+
+// NewMockPermsStore creates a new mock of the PermsStore interface. All
+// methods return zero values for all results, unless overwritten.
+func NewMockPermsStore() *MockPermsStore {
+ return &MockPermsStore{
+ AccessModeFunc: &PermsStoreAccessModeFunc{
+ defaultHook: func(context.Context, int64, int64, AccessModeOptions) (r0 AccessMode) {
+ return
+ },
+ },
+ AuthorizeFunc: &PermsStoreAuthorizeFunc{
+ defaultHook: func(context.Context, int64, int64, AccessMode, AccessModeOptions) (r0 bool) {
+ return
+ },
+ },
+ SetRepoPermsFunc: &PermsStoreSetRepoPermsFunc{
+ defaultHook: func(context.Context, int64, map[int64]AccessMode) (r0 error) {
+ return
+ },
+ },
+ }
+}
+
+// NewStrictMockPermsStore creates a new mock of the PermsStore interface.
+// All methods panic on invocation, unless overwritten.
+func NewStrictMockPermsStore() *MockPermsStore {
+ return &MockPermsStore{
+ AccessModeFunc: &PermsStoreAccessModeFunc{
+ defaultHook: func(context.Context, int64, int64, AccessModeOptions) AccessMode {
+ panic("unexpected invocation of MockPermsStore.AccessMode")
+ },
+ },
+ AuthorizeFunc: &PermsStoreAuthorizeFunc{
+ defaultHook: func(context.Context, int64, int64, AccessMode, AccessModeOptions) bool {
+ panic("unexpected invocation of MockPermsStore.Authorize")
+ },
+ },
+ SetRepoPermsFunc: &PermsStoreSetRepoPermsFunc{
+ defaultHook: func(context.Context, int64, map[int64]AccessMode) error {
+ panic("unexpected invocation of MockPermsStore.SetRepoPerms")
+ },
+ },
+ }
+}
+
+// NewMockPermsStoreFrom creates a new mock of the MockPermsStore interface.
+// All methods delegate to the given implementation, unless overwritten.
+func NewMockPermsStoreFrom(i PermsStore) *MockPermsStore {
+ return &MockPermsStore{
+ AccessModeFunc: &PermsStoreAccessModeFunc{
+ defaultHook: i.AccessMode,
+ },
+ AuthorizeFunc: &PermsStoreAuthorizeFunc{
+ defaultHook: i.Authorize,
+ },
+ SetRepoPermsFunc: &PermsStoreSetRepoPermsFunc{
+ defaultHook: i.SetRepoPerms,
+ },
+ }
+}
+
+// PermsStoreAccessModeFunc describes the behavior when the AccessMode
+// method of the parent MockPermsStore instance is invoked.
+type PermsStoreAccessModeFunc struct {
+ defaultHook func(context.Context, int64, int64, AccessModeOptions) AccessMode
+ hooks []func(context.Context, int64, int64, AccessModeOptions) AccessMode
+ history []PermsStoreAccessModeFuncCall
+ mutex sync.Mutex
+}
+
+// AccessMode delegates to the next hook function in the queue and stores
+// the parameter and result values of this invocation.
+func (m *MockPermsStore) AccessMode(v0 context.Context, v1 int64, v2 int64, v3 AccessModeOptions) AccessMode {
+ r0 := m.AccessModeFunc.nextHook()(v0, v1, v2, v3)
+ m.AccessModeFunc.appendCall(PermsStoreAccessModeFuncCall{v0, v1, v2, v3, r0})
+ return r0
+}
+
+// SetDefaultHook sets function that is called when the AccessMode method of
+// the parent MockPermsStore instance is invoked and the hook queue is
+// empty.
+func (f *PermsStoreAccessModeFunc) SetDefaultHook(hook func(context.Context, int64, int64, AccessModeOptions) AccessMode) {
+ f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// AccessMode method of the parent MockPermsStore instance invokes the hook
+// at the front of the queue and discards it. After the queue is empty, the
+// default hook function is invoked for any future action.
+func (f *PermsStoreAccessModeFunc) PushHook(hook func(context.Context, int64, int64, AccessModeOptions) AccessMode) {
+ f.mutex.Lock()
+ f.hooks = append(f.hooks, hook)
+ f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *PermsStoreAccessModeFunc) SetDefaultReturn(r0 AccessMode) {
+ f.SetDefaultHook(func(context.Context, int64, int64, AccessModeOptions) AccessMode {
+ return r0
+ })
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *PermsStoreAccessModeFunc) PushReturn(r0 AccessMode) {
+ f.PushHook(func(context.Context, int64, int64, AccessModeOptions) AccessMode {
+ return r0
+ })
+}
+
+func (f *PermsStoreAccessModeFunc) nextHook() func(context.Context, int64, int64, AccessModeOptions) AccessMode {
+ f.mutex.Lock()
+ defer f.mutex.Unlock()
+
+ if len(f.hooks) == 0 {
+ return f.defaultHook
+ }
+
+ hook := f.hooks[0]
+ f.hooks = f.hooks[1:]
+ return hook
+}
+
+func (f *PermsStoreAccessModeFunc) appendCall(r0 PermsStoreAccessModeFuncCall) {
+ f.mutex.Lock()
+ f.history = append(f.history, r0)
+ f.mutex.Unlock()
+}
+
+// History returns a sequence of PermsStoreAccessModeFuncCall objects
+// describing the invocations of this function.
+func (f *PermsStoreAccessModeFunc) History() []PermsStoreAccessModeFuncCall {
+ f.mutex.Lock()
+ history := make([]PermsStoreAccessModeFuncCall, len(f.history))
+ copy(history, f.history)
+ f.mutex.Unlock()
+
+ return history
+}
+
+// PermsStoreAccessModeFuncCall is an object that describes an invocation of
+// method AccessMode on an instance of MockPermsStore.
+type PermsStoreAccessModeFuncCall struct {
+ // Arg0 is the value of the 1st argument passed to this method
+ // invocation.
+ Arg0 context.Context
+ // Arg1 is the value of the 2nd argument passed to this method
+ // invocation.
+ Arg1 int64
+ // Arg2 is the value of the 3rd argument passed to this method
+ // invocation.
+ Arg2 int64
+ // Arg3 is the value of the 4th argument passed to this method
+ // invocation.
+ Arg3 AccessModeOptions
+ // Result0 is the value of the 1st result returned from this method
+ // invocation.
+ Result0 AccessMode
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c PermsStoreAccessModeFuncCall) Args() []interface{} {
+ return []interface{}{c.Arg0, c.Arg1, c.Arg2, c.Arg3}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c PermsStoreAccessModeFuncCall) Results() []interface{} {
+ return []interface{}{c.Result0}
+}
+
+// PermsStoreAuthorizeFunc describes the behavior when the Authorize method
+// of the parent MockPermsStore instance is invoked.
+type PermsStoreAuthorizeFunc struct {
+ defaultHook func(context.Context, int64, int64, AccessMode, AccessModeOptions) bool
+ hooks []func(context.Context, int64, int64, AccessMode, AccessModeOptions) bool
+ history []PermsStoreAuthorizeFuncCall
+ mutex sync.Mutex
+}
+
+// Authorize delegates to the next hook function in the queue and stores the
+// parameter and result values of this invocation.
+func (m *MockPermsStore) Authorize(v0 context.Context, v1 int64, v2 int64, v3 AccessMode, v4 AccessModeOptions) bool {
+ r0 := m.AuthorizeFunc.nextHook()(v0, v1, v2, v3, v4)
+ m.AuthorizeFunc.appendCall(PermsStoreAuthorizeFuncCall{v0, v1, v2, v3, v4, r0})
+ return r0
+}
+
+// SetDefaultHook sets function that is called when the Authorize method of
+// the parent MockPermsStore instance is invoked and the hook queue is
+// empty.
+func (f *PermsStoreAuthorizeFunc) SetDefaultHook(hook func(context.Context, int64, int64, AccessMode, AccessModeOptions) bool) {
+ f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// Authorize method of the parent MockPermsStore instance invokes the hook
+// at the front of the queue and discards it. After the queue is empty, the
+// default hook function is invoked for any future action.
+func (f *PermsStoreAuthorizeFunc) PushHook(hook func(context.Context, int64, int64, AccessMode, AccessModeOptions) bool) {
+ f.mutex.Lock()
+ f.hooks = append(f.hooks, hook)
+ f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *PermsStoreAuthorizeFunc) SetDefaultReturn(r0 bool) {
+ f.SetDefaultHook(func(context.Context, int64, int64, AccessMode, AccessModeOptions) bool {
+ return r0
+ })
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *PermsStoreAuthorizeFunc) PushReturn(r0 bool) {
+ f.PushHook(func(context.Context, int64, int64, AccessMode, AccessModeOptions) bool {
+ return r0
+ })
+}
+
+func (f *PermsStoreAuthorizeFunc) nextHook() func(context.Context, int64, int64, AccessMode, AccessModeOptions) bool {
+ f.mutex.Lock()
+ defer f.mutex.Unlock()
+
+ if len(f.hooks) == 0 {
+ return f.defaultHook
+ }
+
+ hook := f.hooks[0]
+ f.hooks = f.hooks[1:]
+ return hook
+}
+
+func (f *PermsStoreAuthorizeFunc) appendCall(r0 PermsStoreAuthorizeFuncCall) {
+ f.mutex.Lock()
+ f.history = append(f.history, r0)
+ f.mutex.Unlock()
+}
+
+// History returns a sequence of PermsStoreAuthorizeFuncCall objects
+// describing the invocations of this function.
+func (f *PermsStoreAuthorizeFunc) History() []PermsStoreAuthorizeFuncCall {
+ f.mutex.Lock()
+ history := make([]PermsStoreAuthorizeFuncCall, len(f.history))
+ copy(history, f.history)
+ f.mutex.Unlock()
+
+ return history
+}
+
+// PermsStoreAuthorizeFuncCall is an object that describes an invocation of
+// method Authorize on an instance of MockPermsStore.
+type PermsStoreAuthorizeFuncCall struct {
+ // Arg0 is the value of the 1st argument passed to this method
+ // invocation.
+ Arg0 context.Context
+ // Arg1 is the value of the 2nd argument passed to this method
+ // invocation.
+ Arg1 int64
+ // Arg2 is the value of the 3rd argument passed to this method
+ // invocation.
+ Arg2 int64
+ // Arg3 is the value of the 4th argument passed to this method
+ // invocation.
+ Arg3 AccessMode
+ // Arg4 is the value of the 5th argument passed to this method
+ // invocation.
+ Arg4 AccessModeOptions
+ // Result0 is the value of the 1st result returned from this method
+ // invocation.
+ Result0 bool
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c PermsStoreAuthorizeFuncCall) Args() []interface{} {
+ return []interface{}{c.Arg0, c.Arg1, c.Arg2, c.Arg3, c.Arg4}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c PermsStoreAuthorizeFuncCall) Results() []interface{} {
+ return []interface{}{c.Result0}
+}
+
+// PermsStoreSetRepoPermsFunc describes the behavior when the SetRepoPerms
+// method of the parent MockPermsStore instance is invoked.
+type PermsStoreSetRepoPermsFunc struct {
+ defaultHook func(context.Context, int64, map[int64]AccessMode) error
+ hooks []func(context.Context, int64, map[int64]AccessMode) error
+ history []PermsStoreSetRepoPermsFuncCall
+ mutex sync.Mutex
+}
+
+// SetRepoPerms delegates to the next hook function in the queue and stores
+// the parameter and result values of this invocation.
+func (m *MockPermsStore) SetRepoPerms(v0 context.Context, v1 int64, v2 map[int64]AccessMode) error {
+ r0 := m.SetRepoPermsFunc.nextHook()(v0, v1, v2)
+ m.SetRepoPermsFunc.appendCall(PermsStoreSetRepoPermsFuncCall{v0, v1, v2, r0})
+ return r0
+}
+
+// SetDefaultHook sets function that is called when the SetRepoPerms method
+// of the parent MockPermsStore instance is invoked and the hook queue is
+// empty.
+func (f *PermsStoreSetRepoPermsFunc) SetDefaultHook(hook func(context.Context, int64, map[int64]AccessMode) error) {
+ f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// SetRepoPerms method of the parent MockPermsStore instance invokes the
+// hook at the front of the queue and discards it. After the queue is empty,
+// the default hook function is invoked for any future action.
+func (f *PermsStoreSetRepoPermsFunc) PushHook(hook func(context.Context, int64, map[int64]AccessMode) error) {
+ f.mutex.Lock()
+ f.hooks = append(f.hooks, hook)
+ f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *PermsStoreSetRepoPermsFunc) SetDefaultReturn(r0 error) {
+ f.SetDefaultHook(func(context.Context, int64, map[int64]AccessMode) error {
+ return r0
+ })
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *PermsStoreSetRepoPermsFunc) PushReturn(r0 error) {
+ f.PushHook(func(context.Context, int64, map[int64]AccessMode) error {
+ return r0
+ })
+}
+
+func (f *PermsStoreSetRepoPermsFunc) nextHook() func(context.Context, int64, map[int64]AccessMode) error {
+ f.mutex.Lock()
+ defer f.mutex.Unlock()
+
+ if len(f.hooks) == 0 {
+ return f.defaultHook
+ }
+
+ hook := f.hooks[0]
+ f.hooks = f.hooks[1:]
+ return hook
+}
+
+func (f *PermsStoreSetRepoPermsFunc) appendCall(r0 PermsStoreSetRepoPermsFuncCall) {
+ f.mutex.Lock()
+ f.history = append(f.history, r0)
+ f.mutex.Unlock()
+}
+
+// History returns a sequence of PermsStoreSetRepoPermsFuncCall objects
+// describing the invocations of this function.
+func (f *PermsStoreSetRepoPermsFunc) History() []PermsStoreSetRepoPermsFuncCall {
+ f.mutex.Lock()
+ history := make([]PermsStoreSetRepoPermsFuncCall, len(f.history))
+ copy(history, f.history)
+ f.mutex.Unlock()
+
+ return history
+}
+
+// PermsStoreSetRepoPermsFuncCall is an object that describes an invocation
+// of method SetRepoPerms on an instance of MockPermsStore.
+type PermsStoreSetRepoPermsFuncCall struct {
+ // Arg0 is the value of the 1st argument passed to this method
+ // invocation.
+ Arg0 context.Context
+ // Arg1 is the value of the 2nd argument passed to this method
+ // invocation.
+ Arg1 int64
+ // Arg2 is the value of the 3rd argument passed to this method
+ // invocation.
+ Arg2 map[int64]AccessMode
+ // Result0 is the value of the 1st result returned from this method
+ // invocation.
+ Result0 error
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c PermsStoreSetRepoPermsFuncCall) Args() []interface{} {
+ return []interface{}{c.Arg0, c.Arg1, c.Arg2}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c PermsStoreSetRepoPermsFuncCall) Results() []interface{} {
+ return []interface{}{c.Result0}
+}
diff --git a/internal/db/perms.go b/internal/db/perms.go
index 4a9e6330..e3de37e0 100644
--- a/internal/db/perms.go
+++ b/internal/db/perms.go
@@ -5,6 +5,8 @@
package db
import (
+ "context"
+
"gorm.io/gorm"
log "unknwon.dev/clog/v2"
)
@@ -14,24 +16,26 @@ import (
// NOTE: All methods are sorted in alphabetical order.
type PermsStore interface {
// AccessMode returns the access mode of given user has to the repository.
- AccessMode(userID, repoID int64, opts AccessModeOptions) AccessMode
- // Authorize returns true if the user has as good as desired access mode to the repository.
- Authorize(userID, repoID int64, desired AccessMode, opts AccessModeOptions) bool
- // SetRepoPerms does a full update to which users have which level of access to given repository.
- // Keys of the "accessMap" are user IDs.
- SetRepoPerms(repoID int64, accessMap map[int64]AccessMode) error
+ AccessMode(ctx context.Context, userID, repoID int64, opts AccessModeOptions) AccessMode
+ // Authorize returns true if the user has as good as desired access mode to the
+ // repository.
+ Authorize(ctx context.Context, userID, repoID int64, desired AccessMode, opts AccessModeOptions) bool
+ // SetRepoPerms does a full update to which users have which level of access to
+ // given repository. Keys of the "accessMap" are user IDs.
+ SetRepoPerms(ctx context.Context, repoID int64, accessMap map[int64]AccessMode) error
}
var Perms PermsStore
-// Access represents the highest access level of a user has to a repository.
-// The only access type that is not in this table is the real owner of a repository.
-// In case of an organization repository, the members of the owners team are in this table.
+// Access represents the highest access level of a user has to a repository. The
+// only access type that is not in this table is the real owner of a repository.
+// In case of an organization repository, the members of the owners team are in
+// this table.
type Access struct {
ID int64
- UserID int64 `xorm:"UNIQUE(s)" gorm:"uniqueIndex:access_user_repo_unique;NOT NULL"`
- RepoID int64 `xorm:"UNIQUE(s)" gorm:"uniqueIndex:access_user_repo_unique;NOT NULL"`
- Mode AccessMode `gorm:"NOT NULL"`
+ UserID int64 `xorm:"UNIQUE(s)" gorm:"uniqueIndex:access_user_repo_unique;not null"`
+ RepoID int64 `xorm:"UNIQUE(s)" gorm:"uniqueIndex:access_user_repo_unique;not null"`
+ Mode AccessMode `gorm:"not null"`
}
// AccessMode is the access mode of a user has to a repository.
@@ -83,7 +87,7 @@ type AccessModeOptions struct {
Private bool // Whether the repository is private.
}
-func (db *perms) AccessMode(userID, repoID int64, opts AccessModeOptions) (mode AccessMode) {
+func (db *perms) AccessMode(ctx context.Context, userID, repoID int64, opts AccessModeOptions) (mode AccessMode) {
if repoID <= 0 {
return AccessModeNone
}
@@ -103,7 +107,7 @@ func (db *perms) AccessMode(userID, repoID int64, opts AccessModeOptions) (mode
}
access := new(Access)
- err := db.Where("user_id = ? AND repo_id = ?", userID, repoID).First(access).Error
+ err := db.WithContext(ctx).Where("user_id = ? AND repo_id = ?", userID, repoID).First(access).Error
if err != nil {
if err != gorm.ErrRecordNotFound {
log.Error("Failed to get access [user_id: %d, repo_id: %d]: %v", userID, repoID, err)
@@ -113,11 +117,11 @@ func (db *perms) AccessMode(userID, repoID int64, opts AccessModeOptions) (mode
return access.Mode
}
-func (db *perms) Authorize(userID, repoID int64, desired AccessMode, opts AccessModeOptions) bool {
- return desired <= db.AccessMode(userID, repoID, opts)
+func (db *perms) Authorize(ctx context.Context, userID, repoID int64, desired AccessMode, opts AccessModeOptions) bool {
+ return desired <= db.AccessMode(ctx, userID, repoID, opts)
}
-func (db *perms) SetRepoPerms(repoID int64, accessMap map[int64]AccessMode) error {
+func (db *perms) SetRepoPerms(ctx context.Context, repoID int64, accessMap map[int64]AccessMode) error {
records := make([]*Access, 0, len(accessMap))
for userID, mode := range accessMap {
records = append(records, &Access{
@@ -127,7 +131,7 @@ func (db *perms) SetRepoPerms(repoID int64, accessMap map[int64]AccessMode) erro
})
}
- return db.Transaction(func(tx *gorm.DB) error {
+ return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err := tx.Where("repo_id = ?", repoID).Delete(new(Access)).Error
if err != nil {
return err
diff --git a/internal/db/perms_test.go b/internal/db/perms_test.go
index 9b821e6b..d7ddb6d5 100644
--- a/internal/db/perms_test.go
+++ b/internal/db/perms_test.go
@@ -5,12 +5,14 @@
package db
import (
+ "context"
"testing"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
-func Test_perms(t *testing.T) {
+func TestPerms(t *testing.T) {
if testing.Short() {
t.Skip()
}
@@ -26,16 +28,14 @@ func Test_perms(t *testing.T) {
name string
test func(*testing.T, *perms)
}{
- {"AccessMode", test_perms_AccessMode},
- {"Authorize", test_perms_Authorize},
- {"SetRepoPerms", test_perms_SetRepoPerms},
+ {"AccessMode", permsAccessMode},
+ {"Authorize", permsAuthorize},
+ {"SetRepoPerms", permsSetRepoPerms},
} {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() {
err := clearTables(t, db.DB, tables...)
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
})
tc.test(t, db)
})
@@ -45,21 +45,23 @@ func Test_perms(t *testing.T) {
}
}
-func test_perms_AccessMode(t *testing.T, db *perms) {
+func permsAccessMode(t *testing.T, db *perms) {
+ ctx := context.Background()
+
// Set up permissions
- err := db.SetRepoPerms(1, map[int64]AccessMode{
- 2: AccessModeWrite,
- 3: AccessModeAdmin,
- })
- if err != nil {
- t.Fatal(err)
- }
- err = db.SetRepoPerms(2, map[int64]AccessMode{
- 1: AccessModeRead,
- })
- if err != nil {
- t.Fatal(err)
- }
+ err := db.SetRepoPerms(ctx, 1,
+ map[int64]AccessMode{
+ 2: AccessModeWrite,
+ 3: AccessModeAdmin,
+ },
+ )
+ require.NoError(t, err)
+ err = db.SetRepoPerms(ctx, 2,
+ map[int64]AccessMode{
+ 1: AccessModeRead,
+ },
+ )
+ require.NoError(t, err)
publicRepoID := int64(1)
publicRepoOpts := AccessModeOptions{
@@ -73,99 +75,101 @@ func test_perms_AccessMode(t *testing.T, db *perms) {
}
tests := []struct {
- name string
- userID int64
- repoID int64
- opts AccessModeOptions
- expAccessMode AccessMode
+ name string
+ userID int64
+ repoID int64
+ opts AccessModeOptions
+ wantAccessMode AccessMode
}{
{
- name: "nil repository",
- expAccessMode: AccessModeNone,
+ name: "nil repository",
+ wantAccessMode: AccessModeNone,
},
{
- name: "anonymous user has read access to public repository",
- repoID: publicRepoID,
- opts: publicRepoOpts,
- expAccessMode: AccessModeRead,
+ name: "anonymous user has read access to public repository",
+ repoID: publicRepoID,
+ opts: publicRepoOpts,
+ wantAccessMode: AccessModeRead,
},
{
- name: "anonymous user has no access to private repository",
- repoID: privateRepoID,
- opts: privateRepoOpts,
- expAccessMode: AccessModeNone,
+ name: "anonymous user has no access to private repository",
+ repoID: privateRepoID,
+ opts: privateRepoOpts,
+ wantAccessMode: AccessModeNone,
},
{
- name: "user is the owner",
- userID: 98,
- repoID: publicRepoID,
- opts: publicRepoOpts,
- expAccessMode: AccessModeOwner,
+ name: "user is the owner",
+ userID: 98,
+ repoID: publicRepoID,
+ opts: publicRepoOpts,
+ wantAccessMode: AccessModeOwner,
},
{
- name: "user 1 has read access to public repo",
- userID: 1,
- repoID: publicRepoID,
- opts: publicRepoOpts,
- expAccessMode: AccessModeRead,
+ name: "user 1 has read access to public repo",
+ userID: 1,
+ repoID: publicRepoID,
+ opts: publicRepoOpts,
+ wantAccessMode: AccessModeRead,
},
{
- name: "user 2 has write access to public repo",
- userID: 2,
- repoID: publicRepoID,
- opts: publicRepoOpts,
- expAccessMode: AccessModeWrite,
+ name: "user 2 has write access to public repo",
+ userID: 2,
+ repoID: publicRepoID,
+ opts: publicRepoOpts,
+ wantAccessMode: AccessModeWrite,
},
{
- name: "user 3 has admin access to public repo",
- userID: 3,
- repoID: publicRepoID,
- opts: publicRepoOpts,
- expAccessMode: AccessModeAdmin,
+ name: "user 3 has admin access to public repo",
+ userID: 3,
+ repoID: publicRepoID,
+ opts: publicRepoOpts,
+ wantAccessMode: AccessModeAdmin,
},
{
- name: "user 1 has read access to private repo",
- userID: 1,
- repoID: privateRepoID,
- opts: privateRepoOpts,
- expAccessMode: AccessModeRead,
+ name: "user 1 has read access to private repo",
+ userID: 1,
+ repoID: privateRepoID,
+ opts: privateRepoOpts,
+ wantAccessMode: AccessModeRead,
},
{
- name: "user 2 has no access to private repo",
- userID: 2,
- repoID: privateRepoID,
- opts: privateRepoOpts,
- expAccessMode: AccessModeNone,
+ name: "user 2 has no access to private repo",
+ userID: 2,
+ repoID: privateRepoID,
+ opts: privateRepoOpts,
+ wantAccessMode: AccessModeNone,
},
{
- name: "user 3 has no access to private repo",
- userID: 3,
- repoID: privateRepoID,
- opts: privateRepoOpts,
- expAccessMode: AccessModeNone,
+ name: "user 3 has no access to private repo",
+ userID: 3,
+ repoID: privateRepoID,
+ opts: privateRepoOpts,
+ wantAccessMode: AccessModeNone,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- mode := db.AccessMode(test.userID, test.repoID, test.opts)
- assert.Equal(t, test.expAccessMode, mode)
+ mode := db.AccessMode(ctx, test.userID, test.repoID, test.opts)
+ assert.Equal(t, test.wantAccessMode, mode)
})
}
}
-func test_perms_Authorize(t *testing.T, db *perms) {
+func permsAuthorize(t *testing.T, db *perms) {
+ ctx := context.Background()
+
// Set up permissions
- err := db.SetRepoPerms(1, map[int64]AccessMode{
- 1: AccessModeRead,
- 2: AccessModeWrite,
- 3: AccessModeAdmin,
- })
- if err != nil {
- t.Fatal(err)
- }
+ err := db.SetRepoPerms(ctx, 1,
+ map[int64]AccessMode{
+ 1: AccessModeRead,
+ 2: AccessModeWrite,
+ 3: AccessModeAdmin,
+ },
+ )
+ require.NoError(t, err)
repo := &Repository{
ID: 1,
@@ -173,74 +177,78 @@ func test_perms_Authorize(t *testing.T, db *perms) {
}
tests := []struct {
- name string
- userID int64
- desired AccessMode
- expAuthorized bool
+ name string
+ userID int64
+ desired AccessMode
+ wantAuthorized bool
}{
{
- name: "user 1 has read and wants read",
- userID: 1,
- desired: AccessModeRead,
- expAuthorized: true,
+ name: "user 1 has read and wants read",
+ userID: 1,
+ desired: AccessModeRead,
+ wantAuthorized: true,
},
{
- name: "user 1 has read and wants write",
- userID: 1,
- desired: AccessModeWrite,
- expAuthorized: false,
+ name: "user 1 has read and wants write",
+ userID: 1,
+ desired: AccessModeWrite,
+ wantAuthorized: false,
},
{
- name: "user 2 has write and wants read",
- userID: 2,
- desired: AccessModeRead,
- expAuthorized: true,
+ name: "user 2 has write and wants read",
+ userID: 2,
+ desired: AccessModeRead,
+ wantAuthorized: true,
},
{
- name: "user 2 has write and wants write",
- userID: 2,
- desired: AccessModeWrite,
- expAuthorized: true,
+ name: "user 2 has write and wants write",
+ userID: 2,
+ desired: AccessModeWrite,
+ wantAuthorized: true,
},
{
- name: "user 2 has write and wants admin",
- userID: 2,
- desired: AccessModeAdmin,
- expAuthorized: false,
+ name: "user 2 has write and wants admin",
+ userID: 2,
+ desired: AccessModeAdmin,
+ wantAuthorized: false,
},
{
- name: "user 3 has admin and wants read",
- userID: 3,
- desired: AccessModeRead,
- expAuthorized: true,
+ name: "user 3 has admin and wants read",
+ userID: 3,
+ desired: AccessModeRead,
+ wantAuthorized: true,
},
{
- name: "user 3 has admin and wants write",
- userID: 3,
- desired: AccessModeWrite,
- expAuthorized: true,
+ name: "user 3 has admin and wants write",
+ userID: 3,
+ desired: AccessModeWrite,
+ wantAuthorized: true,
},
{
- name: "user 3 has admin and wants admin",
- userID: 3,
- desired: AccessModeAdmin,
- expAuthorized: true,
+ name: "user 3 has admin and wants admin",
+ userID: 3,
+ desired: AccessModeAdmin,
+ wantAuthorized: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- authorized := db.Authorize(test.userID, repo.ID, test.desired, AccessModeOptions{
- OwnerID: repo.OwnerID,
- Private: repo.IsPrivate,
- })
- assert.Equal(t, test.expAuthorized, authorized)
+ authorized := db.Authorize(ctx, test.userID, repo.ID, test.desired,
+ AccessModeOptions{
+ OwnerID: repo.OwnerID,
+ Private: repo.IsPrivate,
+ },
+ )
+ assert.Equal(t, test.wantAuthorized, authorized)
})
}
}
-func test_perms_SetRepoPerms(t *testing.T, db *perms) {
+func permsSetRepoPerms(t *testing.T, db *perms) {
+ ctx := context.Background()
+
for _, update := range []struct {
repoID int64
accessMap map[int64]AccessMode
@@ -279,7 +287,7 @@ func test_perms_SetRepoPerms(t *testing.T, db *perms) {
},
},
} {
- err := db.SetRepoPerms(update.repoID, update.accessMap)
+ err := db.SetRepoPerms(ctx, update.repoID, update.accessMap)
if err != nil {
t.Fatal(err)
}
@@ -287,21 +295,19 @@ func test_perms_SetRepoPerms(t *testing.T, db *perms) {
var accesses []*Access
err := db.Order("user_id, repo_id").Find(&accesses).Error
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
// Ignore ID fields
for _, a := range accesses {
a.ID = 0
}
- expAccesses := []*Access{
+ wantAccesses := []*Access{
{UserID: 1, RepoID: 2, Mode: AccessModeWrite},
{UserID: 2, RepoID: 1, Mode: AccessModeWrite},
{UserID: 2, RepoID: 2, Mode: AccessModeRead},
{UserID: 3, RepoID: 1, Mode: AccessModeAdmin},
{UserID: 5, RepoID: 2, Mode: AccessModeWrite},
}
- assert.Equal(t, expAccesses, accesses)
+ assert.Equal(t, wantAccesses, accesses)
}
diff --git a/internal/db/repo.go b/internal/db/repo.go
index e044712b..360e967c 100644
--- a/internal/db/repo.go
+++ b/internal/db/repo.go
@@ -6,6 +6,7 @@ package db
import (
"bytes"
+ "context"
"fmt"
"image"
_ "image/jpeg"
@@ -557,7 +558,7 @@ func (repo *Repository) ComposeCompareURL(oldCommitID, newCommitID string) strin
}
func (repo *Repository) HasAccess(userID int64) bool {
- return Perms.Authorize(userID, repo.ID, AccessModeRead,
+ return Perms.Authorize(context.TODO(), userID, repo.ID, AccessModeRead,
AccessModeOptions{
OwnerID: repo.OwnerID,
Private: repo.IsPrivate,
diff --git a/internal/db/repo_branch.go b/internal/db/repo_branch.go
index 22c15b69..dc9e8795 100644
--- a/internal/db/repo_branch.go
+++ b/internal/db/repo_branch.go
@@ -5,6 +5,7 @@
package db
import (
+ "context"
"fmt"
"strings"
@@ -175,7 +176,7 @@ func UpdateOrgProtectBranch(repo *Repository, protectBranch *ProtectBranch, whit
userIDs := tool.StringsToInt64s(strings.Split(whitelistUserIDs, ","))
validUserIDs = make([]int64, 0, len(userIDs))
for _, userID := range userIDs {
- if !Perms.Authorize(userID, repo.ID, AccessModeWrite,
+ if !Perms.Authorize(context.TODO(), userID, repo.ID, AccessModeWrite,
AccessModeOptions{
OwnerID: repo.OwnerID,
Private: repo.IsPrivate,
diff --git a/internal/db/ssh_key.go b/internal/db/ssh_key.go
index b53c66d6..d99a0dc6 100644
--- a/internal/db/ssh_key.go
+++ b/internal/db/ssh_key.go
@@ -5,6 +5,7 @@
package db
import (
+ "context"
"encoding/base64"
"encoding/binary"
"errors"
@@ -752,7 +753,7 @@ func DeleteDeployKey(doer *User, id int64) error {
if err != nil {
return fmt.Errorf("GetRepositoryByID: %v", err)
}
- if !Perms.Authorize(doer.ID, repo.ID, AccessModeAdmin,
+ if !Perms.Authorize(context.TODO(), doer.ID, repo.ID, AccessModeAdmin,
AccessModeOptions{
OwnerID: repo.OwnerID,
Private: repo.IsPrivate,
diff --git a/internal/db/user.go b/internal/db/user.go
index d5d2d60f..c6bee120 100644
--- a/internal/db/user.go
+++ b/internal/db/user.go
@@ -6,6 +6,7 @@ package db
import (
"bytes"
+ "context"
"crypto/sha256"
"crypto/subtle"
"encoding/hex"
@@ -369,7 +370,7 @@ func (u *User) DeleteAvatar() error {
// IsAdminOfRepo returns true if user has admin or higher access of repository.
func (u *User) IsAdminOfRepo(repo *Repository) bool {
- return Perms.Authorize(u.ID, repo.ID, AccessModeAdmin,
+ return Perms.Authorize(context.TODO(), u.ID, repo.ID, AccessModeAdmin,
AccessModeOptions{
OwnerID: repo.OwnerID,
Private: repo.IsPrivate,
@@ -379,7 +380,7 @@ func (u *User) IsAdminOfRepo(repo *Repository) bool {
// IsWriterOfRepo returns true if user has write access to given repository.
func (u *User) IsWriterOfRepo(repo *Repository) bool {
- return Perms.Authorize(u.ID, repo.ID, AccessModeWrite,
+ return Perms.Authorize(context.TODO(), u.ID, repo.ID, AccessModeWrite,
AccessModeOptions{
OwnerID: repo.OwnerID,
Private: repo.IsPrivate,
@@ -941,7 +942,7 @@ func GetUserByID(id int64) (*User, error) {
// GetAssigneeByID returns the user with read access of repository by given ID.
func GetAssigneeByID(repo *Repository, userID int64) (*User, error) {
- if !Perms.Authorize(userID, repo.ID, AccessModeRead,
+ if !Perms.Authorize(context.TODO(), userID, repo.ID, AccessModeRead,
AccessModeOptions{
OwnerID: repo.OwnerID,
Private: repo.IsPrivate,