diff options
author | Joe Chen <jc@unknwon.io> | 2022-06-08 19:26:20 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-06-08 19:26:20 +0800 |
commit | 7229dd893f15ae30d20332706e40d8a87e0f94b0 (patch) | |
tree | 1e60a6751834efc257a4ab31c69b55bae3959e02 /internal/db | |
parent | 0918d8758b7470c5e1f64a62c2e48e4168993394 (diff) |
db: use `context` and go-mockgen for `PermsStore` (#7033)
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/access_tokens_test.go | 12 | ||||
-rw-r--r-- | internal/db/mock_gen.go | 22 | ||||
-rw-r--r-- | internal/db/mocks.go | 407 | ||||
-rw-r--r-- | internal/db/perms.go | 40 | ||||
-rw-r--r-- | internal/db/perms_test.go | 266 | ||||
-rw-r--r-- | internal/db/repo.go | 3 | ||||
-rw-r--r-- | internal/db/repo_branch.go | 3 | ||||
-rw-r--r-- | internal/db/ssh_key.go | 3 | ||||
-rw-r--r-- | internal/db/user.go | 7 |
9 files changed, 581 insertions, 182 deletions
diff --git a/internal/db/access_tokens_test.go b/internal/db/access_tokens_test.go index 38c41e2d..b135a7b7 100644 --- a/internal/db/access_tokens_test.go +++ b/internal/db/access_tokens_test.go @@ -66,9 +66,7 @@ func TestAccessTokens(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Cleanup(func() { err := clearTables(t, db.DB, tables...) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) }) tc.test(t, db) }) @@ -123,8 +121,8 @@ func accessTokensDeleteByID(t *testing.T, db *accessTokens) { // We should get token not found error _, err = db.GetBySHA1(ctx, token.Sha1) - expErr := ErrAccessTokenNotExist{args: errutil.Args{"sha": token.Sha1}} - assert.Equal(t, expErr, err) + wantErr := ErrAccessTokenNotExist{args: errutil.Args{"sha": token.Sha1}} + assert.Equal(t, wantErr, err) } func accessTokensGetBySHA(t *testing.T, db *accessTokens) { @@ -140,8 +138,8 @@ func accessTokensGetBySHA(t *testing.T, db *accessTokens) { // Try to get a non-existent token _, err = db.GetBySHA1(ctx, "bad_sha") - expErr := ErrAccessTokenNotExist{args: errutil.Args{"sha": "bad_sha"}} - assert.Equal(t, expErr, err) + wantErr := ErrAccessTokenNotExist{args: errutil.Args{"sha": "bad_sha"}} + assert.Equal(t, wantErr, err) } func accessTokensList(t *testing.T, db *accessTokens) { diff --git a/internal/db/mock_gen.go b/internal/db/mock_gen.go index 564dcc57..ce347a63 100644 --- a/internal/db/mock_gen.go +++ b/internal/db/mock_gen.go @@ -10,7 +10,7 @@ import ( "gogs.io/gogs/internal/lfsutil" ) -//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -o mocks.go +//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i PermsStore -o mocks.go func SetMockAccessTokensStore(t *testing.T, mock AccessTokensStore) { before := AccessTokens @@ -101,26 +101,6 @@ func (m *mockLoginSourceFileStore) Save() error { return m.MockSave() } -var _ PermsStore = (*MockPermsStore)(nil) - -type MockPermsStore struct { - MockAccessMode func(userID, repoID int64, opts AccessModeOptions) AccessMode - MockAuthorize func(userID, repoID int64, desired AccessMode, opts AccessModeOptions) bool - MockSetRepoPerms func(repoID int64, accessMap map[int64]AccessMode) error -} - -func (m *MockPermsStore) AccessMode(userID, repoID int64, opts AccessModeOptions) AccessMode { - return m.MockAccessMode(userID, repoID, opts) -} - -func (m *MockPermsStore) Authorize(userID, repoID int64, desired AccessMode, opts AccessModeOptions) bool { - return m.MockAuthorize(userID, repoID, desired, opts) -} - -func (m *MockPermsStore) SetRepoPerms(repoID int64, accessMap map[int64]AccessMode) error { - return m.MockSetRepoPerms(repoID, accessMap) -} - func SetMockPermsStore(t *testing.T, mock PermsStore) { before := Perms Perms = mock diff --git a/internal/db/mocks.go b/internal/db/mocks.go index eaecf7b6..e969d83f 100644 --- a/internal/db/mocks.go +++ b/internal/db/mocks.go @@ -657,3 +657,410 @@ func (c AccessTokensStoreTouchFuncCall) Args() []interface{} { func (c AccessTokensStoreTouchFuncCall) Results() []interface{} { return []interface{}{c.Result0} } + +// MockPermsStore is a mock implementation of the PermsStore interface (from +// the package gogs.io/gogs/internal/db) used for unit testing. +type MockPermsStore struct { + // AccessModeFunc is an instance of a mock function object controlling + // the behavior of the method AccessMode. + AccessModeFunc *PermsStoreAccessModeFunc + // AuthorizeFunc is an instance of a mock function object controlling + // the behavior of the method Authorize. + AuthorizeFunc *PermsStoreAuthorizeFunc + // SetRepoPermsFunc is an instance of a mock function object controlling + // the behavior of the method SetRepoPerms. + SetRepoPermsFunc *PermsStoreSetRepoPermsFunc +} + +// NewMockPermsStore creates a new mock of the PermsStore interface. All +// methods return zero values for all results, unless overwritten. +func NewMockPermsStore() *MockPermsStore { + return &MockPermsStore{ + AccessModeFunc: &PermsStoreAccessModeFunc{ + defaultHook: func(context.Context, int64, int64, AccessModeOptions) (r0 AccessMode) { + return + }, + }, + AuthorizeFunc: &PermsStoreAuthorizeFunc{ + defaultHook: func(context.Context, int64, int64, AccessMode, AccessModeOptions) (r0 bool) { + return + }, + }, + SetRepoPermsFunc: &PermsStoreSetRepoPermsFunc{ + defaultHook: func(context.Context, int64, map[int64]AccessMode) (r0 error) { + return + }, + }, + } +} + +// NewStrictMockPermsStore creates a new mock of the PermsStore interface. +// All methods panic on invocation, unless overwritten. +func NewStrictMockPermsStore() *MockPermsStore { + return &MockPermsStore{ + AccessModeFunc: &PermsStoreAccessModeFunc{ + defaultHook: func(context.Context, int64, int64, AccessModeOptions) AccessMode { + panic("unexpected invocation of MockPermsStore.AccessMode") + }, + }, + AuthorizeFunc: &PermsStoreAuthorizeFunc{ + defaultHook: func(context.Context, int64, int64, AccessMode, AccessModeOptions) bool { + panic("unexpected invocation of MockPermsStore.Authorize") + }, + }, + SetRepoPermsFunc: &PermsStoreSetRepoPermsFunc{ + defaultHook: func(context.Context, int64, map[int64]AccessMode) error { + panic("unexpected invocation of MockPermsStore.SetRepoPerms") + }, + }, + } +} + +// NewMockPermsStoreFrom creates a new mock of the MockPermsStore interface. +// All methods delegate to the given implementation, unless overwritten. +func NewMockPermsStoreFrom(i PermsStore) *MockPermsStore { + return &MockPermsStore{ + AccessModeFunc: &PermsStoreAccessModeFunc{ + defaultHook: i.AccessMode, + }, + AuthorizeFunc: &PermsStoreAuthorizeFunc{ + defaultHook: i.Authorize, + }, + SetRepoPermsFunc: &PermsStoreSetRepoPermsFunc{ + defaultHook: i.SetRepoPerms, + }, + } +} + +// PermsStoreAccessModeFunc describes the behavior when the AccessMode +// method of the parent MockPermsStore instance is invoked. +type PermsStoreAccessModeFunc struct { + defaultHook func(context.Context, int64, int64, AccessModeOptions) AccessMode + hooks []func(context.Context, int64, int64, AccessModeOptions) AccessMode + history []PermsStoreAccessModeFuncCall + mutex sync.Mutex +} + +// AccessMode delegates to the next hook function in the queue and stores +// the parameter and result values of this invocation. +func (m *MockPermsStore) AccessMode(v0 context.Context, v1 int64, v2 int64, v3 AccessModeOptions) AccessMode { + r0 := m.AccessModeFunc.nextHook()(v0, v1, v2, v3) + m.AccessModeFunc.appendCall(PermsStoreAccessModeFuncCall{v0, v1, v2, v3, r0}) + return r0 +} + +// SetDefaultHook sets function that is called when the AccessMode method of +// the parent MockPermsStore instance is invoked and the hook queue is +// empty. +func (f *PermsStoreAccessModeFunc) SetDefaultHook(hook func(context.Context, int64, int64, AccessModeOptions) AccessMode) { + f.defaultHook = hook +} + +// PushHook adds a function to the end of hook queue. Each invocation of the +// AccessMode method of the parent MockPermsStore instance invokes the hook +// at the front of the queue and discards it. After the queue is empty, the +// default hook function is invoked for any future action. +func (f *PermsStoreAccessModeFunc) PushHook(hook func(context.Context, int64, int64, AccessModeOptions) AccessMode) { + f.mutex.Lock() + f.hooks = append(f.hooks, hook) + f.mutex.Unlock() +} + +// SetDefaultReturn calls SetDefaultHook with a function that returns the +// given values. +func (f *PermsStoreAccessModeFunc) SetDefaultReturn(r0 AccessMode) { + f.SetDefaultHook(func(context.Context, int64, int64, AccessModeOptions) AccessMode { + return r0 + }) +} + +// PushReturn calls PushHook with a function that returns the given values. +func (f *PermsStoreAccessModeFunc) PushReturn(r0 AccessMode) { + f.PushHook(func(context.Context, int64, int64, AccessModeOptions) AccessMode { + return r0 + }) +} + +func (f *PermsStoreAccessModeFunc) nextHook() func(context.Context, int64, int64, AccessModeOptions) AccessMode { + f.mutex.Lock() + defer f.mutex.Unlock() + + if len(f.hooks) == 0 { + return f.defaultHook + } + + hook := f.hooks[0] + f.hooks = f.hooks[1:] + return hook +} + +func (f *PermsStoreAccessModeFunc) appendCall(r0 PermsStoreAccessModeFuncCall) { + f.mutex.Lock() + f.history = append(f.history, r0) + f.mutex.Unlock() +} + +// History returns a sequence of PermsStoreAccessModeFuncCall objects +// describing the invocations of this function. +func (f *PermsStoreAccessModeFunc) History() []PermsStoreAccessModeFuncCall { + f.mutex.Lock() + history := make([]PermsStoreAccessModeFuncCall, len(f.history)) + copy(history, f.history) + f.mutex.Unlock() + + return history +} + +// PermsStoreAccessModeFuncCall is an object that describes an invocation of +// method AccessMode on an instance of MockPermsStore. +type PermsStoreAccessModeFuncCall struct { + // Arg0 is the value of the 1st argument passed to this method + // invocation. + Arg0 context.Context + // Arg1 is the value of the 2nd argument passed to this method + // invocation. + Arg1 int64 + // Arg2 is the value of the 3rd argument passed to this method + // invocation. + Arg2 int64 + // Arg3 is the value of the 4th argument passed to this method + // invocation. + Arg3 AccessModeOptions + // Result0 is the value of the 1st result returned from this method + // invocation. + Result0 AccessMode +} + +// Args returns an interface slice containing the arguments of this +// invocation. +func (c PermsStoreAccessModeFuncCall) Args() []interface{} { + return []interface{}{c.Arg0, c.Arg1, c.Arg2, c.Arg3} +} + +// Results returns an interface slice containing the results of this +// invocation. +func (c PermsStoreAccessModeFuncCall) Results() []interface{} { + return []interface{}{c.Result0} +} + +// PermsStoreAuthorizeFunc describes the behavior when the Authorize method +// of the parent MockPermsStore instance is invoked. +type PermsStoreAuthorizeFunc struct { + defaultHook func(context.Context, int64, int64, AccessMode, AccessModeOptions) bool + hooks []func(context.Context, int64, int64, AccessMode, AccessModeOptions) bool + history []PermsStoreAuthorizeFuncCall + mutex sync.Mutex +} + +// Authorize delegates to the next hook function in the queue and stores the +// parameter and result values of this invocation. +func (m *MockPermsStore) Authorize(v0 context.Context, v1 int64, v2 int64, v3 AccessMode, v4 AccessModeOptions) bool { + r0 := m.AuthorizeFunc.nextHook()(v0, v1, v2, v3, v4) + m.AuthorizeFunc.appendCall(PermsStoreAuthorizeFuncCall{v0, v1, v2, v3, v4, r0}) + return r0 +} + +// SetDefaultHook sets function that is called when the Authorize method of +// the parent MockPermsStore instance is invoked and the hook queue is +// empty. +func (f *PermsStoreAuthorizeFunc) SetDefaultHook(hook func(context.Context, int64, int64, AccessMode, AccessModeOptions) bool) { + f.defaultHook = hook +} + +// PushHook adds a function to the end of hook queue. Each invocation of the +// Authorize method of the parent MockPermsStore instance invokes the hook +// at the front of the queue and discards it. After the queue is empty, the +// default hook function is invoked for any future action. +func (f *PermsStoreAuthorizeFunc) PushHook(hook func(context.Context, int64, int64, AccessMode, AccessModeOptions) bool) { + f.mutex.Lock() + f.hooks = append(f.hooks, hook) + f.mutex.Unlock() +} + +// SetDefaultReturn calls SetDefaultHook with a function that returns the +// given values. +func (f *PermsStoreAuthorizeFunc) SetDefaultReturn(r0 bool) { + f.SetDefaultHook(func(context.Context, int64, int64, AccessMode, AccessModeOptions) bool { + return r0 + }) +} + +// PushReturn calls PushHook with a function that returns the given values. +func (f *PermsStoreAuthorizeFunc) PushReturn(r0 bool) { + f.PushHook(func(context.Context, int64, int64, AccessMode, AccessModeOptions) bool { + return r0 + }) +} + +func (f *PermsStoreAuthorizeFunc) nextHook() func(context.Context, int64, int64, AccessMode, AccessModeOptions) bool { + f.mutex.Lock() + defer f.mutex.Unlock() + + if len(f.hooks) == 0 { + return f.defaultHook + } + + hook := f.hooks[0] + f.hooks = f.hooks[1:] + return hook +} + +func (f *PermsStoreAuthorizeFunc) appendCall(r0 PermsStoreAuthorizeFuncCall) { + f.mutex.Lock() + f.history = append(f.history, r0) + f.mutex.Unlock() +} + +// History returns a sequence of PermsStoreAuthorizeFuncCall objects +// describing the invocations of this function. +func (f *PermsStoreAuthorizeFunc) History() []PermsStoreAuthorizeFuncCall { + f.mutex.Lock() + history := make([]PermsStoreAuthorizeFuncCall, len(f.history)) + copy(history, f.history) + f.mutex.Unlock() + + return history +} + +// PermsStoreAuthorizeFuncCall is an object that describes an invocation of +// method Authorize on an instance of MockPermsStore. +type PermsStoreAuthorizeFuncCall struct { + // Arg0 is the value of the 1st argument passed to this method + // invocation. + Arg0 context.Context + // Arg1 is the value of the 2nd argument passed to this method + // invocation. + Arg1 int64 + // Arg2 is the value of the 3rd argument passed to this method + // invocation. + Arg2 int64 + // Arg3 is the value of the 4th argument passed to this method + // invocation. + Arg3 AccessMode + // Arg4 is the value of the 5th argument passed to this method + // invocation. + Arg4 AccessModeOptions + // Result0 is the value of the 1st result returned from this method + // invocation. + Result0 bool +} + +// Args returns an interface slice containing the arguments of this +// invocation. +func (c PermsStoreAuthorizeFuncCall) Args() []interface{} { + return []interface{}{c.Arg0, c.Arg1, c.Arg2, c.Arg3, c.Arg4} +} + +// Results returns an interface slice containing the results of this +// invocation. +func (c PermsStoreAuthorizeFuncCall) Results() []interface{} { + return []interface{}{c.Result0} +} + +// PermsStoreSetRepoPermsFunc describes the behavior when the SetRepoPerms +// method of the parent MockPermsStore instance is invoked. +type PermsStoreSetRepoPermsFunc struct { + defaultHook func(context.Context, int64, map[int64]AccessMode) error + hooks []func(context.Context, int64, map[int64]AccessMode) error + history []PermsStoreSetRepoPermsFuncCall + mutex sync.Mutex +} + +// SetRepoPerms delegates to the next hook function in the queue and stores +// the parameter and result values of this invocation. +func (m *MockPermsStore) SetRepoPerms(v0 context.Context, v1 int64, v2 map[int64]AccessMode) error { + r0 := m.SetRepoPermsFunc.nextHook()(v0, v1, v2) + m.SetRepoPermsFunc.appendCall(PermsStoreSetRepoPermsFuncCall{v0, v1, v2, r0}) + return r0 +} + +// SetDefaultHook sets function that is called when the SetRepoPerms method +// of the parent MockPermsStore instance is invoked and the hook queue is +// empty. +func (f *PermsStoreSetRepoPermsFunc) SetDefaultHook(hook func(context.Context, int64, map[int64]AccessMode) error) { + f.defaultHook = hook +} + +// PushHook adds a function to the end of hook queue. Each invocation of the +// SetRepoPerms method of the parent MockPermsStore instance invokes the +// hook at the front of the queue and discards it. After the queue is empty, +// the default hook function is invoked for any future action. +func (f *PermsStoreSetRepoPermsFunc) PushHook(hook func(context.Context, int64, map[int64]AccessMode) error) { + f.mutex.Lock() + f.hooks = append(f.hooks, hook) + f.mutex.Unlock() +} + +// SetDefaultReturn calls SetDefaultHook with a function that returns the +// given values. +func (f *PermsStoreSetRepoPermsFunc) SetDefaultReturn(r0 error) { + f.SetDefaultHook(func(context.Context, int64, map[int64]AccessMode) error { + return r0 + }) +} + +// PushReturn calls PushHook with a function that returns the given values. +func (f *PermsStoreSetRepoPermsFunc) PushReturn(r0 error) { + f.PushHook(func(context.Context, int64, map[int64]AccessMode) error { + return r0 + }) +} + +func (f *PermsStoreSetRepoPermsFunc) nextHook() func(context.Context, int64, map[int64]AccessMode) error { + f.mutex.Lock() + defer f.mutex.Unlock() + + if len(f.hooks) == 0 { + return f.defaultHook + } + + hook := f.hooks[0] + f.hooks = f.hooks[1:] + return hook +} + +func (f *PermsStoreSetRepoPermsFunc) appendCall(r0 PermsStoreSetRepoPermsFuncCall) { + f.mutex.Lock() + f.history = append(f.history, r0) + f.mutex.Unlock() +} + +// History returns a sequence of PermsStoreSetRepoPermsFuncCall objects +// describing the invocations of this function. +func (f *PermsStoreSetRepoPermsFunc) History() []PermsStoreSetRepoPermsFuncCall { + f.mutex.Lock() + history := make([]PermsStoreSetRepoPermsFuncCall, len(f.history)) + copy(history, f.history) + f.mutex.Unlock() + + return history +} + +// PermsStoreSetRepoPermsFuncCall is an object that describes an invocation +// of method SetRepoPerms on an instance of MockPermsStore. +type PermsStoreSetRepoPermsFuncCall struct { + // Arg0 is the value of the 1st argument passed to this method + // invocation. + Arg0 context.Context + // Arg1 is the value of the 2nd argument passed to this method + // invocation. + Arg1 int64 + // Arg2 is the value of the 3rd argument passed to this method + // invocation. + Arg2 map[int64]AccessMode + // Result0 is the value of the 1st result returned from this method + // invocation. + Result0 error +} + +// Args returns an interface slice containing the arguments of this +// invocation. +func (c PermsStoreSetRepoPermsFuncCall) Args() []interface{} { + return []interface{}{c.Arg0, c.Arg1, c.Arg2} +} + +// Results returns an interface slice containing the results of this +// invocation. +func (c PermsStoreSetRepoPermsFuncCall) Results() []interface{} { + return []interface{}{c.Result0} +} diff --git a/internal/db/perms.go b/internal/db/perms.go index 4a9e6330..e3de37e0 100644 --- a/internal/db/perms.go +++ b/internal/db/perms.go @@ -5,6 +5,8 @@ package db import ( + "context" + "gorm.io/gorm" log "unknwon.dev/clog/v2" ) @@ -14,24 +16,26 @@ import ( // NOTE: All methods are sorted in alphabetical order. type PermsStore interface { // AccessMode returns the access mode of given user has to the repository. - AccessMode(userID, repoID int64, opts AccessModeOptions) AccessMode - // Authorize returns true if the user has as good as desired access mode to the repository. - Authorize(userID, repoID int64, desired AccessMode, opts AccessModeOptions) bool - // SetRepoPerms does a full update to which users have which level of access to given repository. - // Keys of the "accessMap" are user IDs. - SetRepoPerms(repoID int64, accessMap map[int64]AccessMode) error + AccessMode(ctx context.Context, userID, repoID int64, opts AccessModeOptions) AccessMode + // Authorize returns true if the user has as good as desired access mode to the + // repository. + Authorize(ctx context.Context, userID, repoID int64, desired AccessMode, opts AccessModeOptions) bool + // SetRepoPerms does a full update to which users have which level of access to + // given repository. Keys of the "accessMap" are user IDs. + SetRepoPerms(ctx context.Context, repoID int64, accessMap map[int64]AccessMode) error } var Perms PermsStore -// Access represents the highest access level of a user has to a repository. -// The only access type that is not in this table is the real owner of a repository. -// In case of an organization repository, the members of the owners team are in this table. +// Access represents the highest access level of a user has to a repository. The +// only access type that is not in this table is the real owner of a repository. +// In case of an organization repository, the members of the owners team are in +// this table. type Access struct { ID int64 - UserID int64 `xorm:"UNIQUE(s)" gorm:"uniqueIndex:access_user_repo_unique;NOT NULL"` - RepoID int64 `xorm:"UNIQUE(s)" gorm:"uniqueIndex:access_user_repo_unique;NOT NULL"` - Mode AccessMode `gorm:"NOT NULL"` + UserID int64 `xorm:"UNIQUE(s)" gorm:"uniqueIndex:access_user_repo_unique;not null"` + RepoID int64 `xorm:"UNIQUE(s)" gorm:"uniqueIndex:access_user_repo_unique;not null"` + Mode AccessMode `gorm:"not null"` } // AccessMode is the access mode of a user has to a repository. @@ -83,7 +87,7 @@ type AccessModeOptions struct { Private bool // Whether the repository is private. } -func (db *perms) AccessMode(userID, repoID int64, opts AccessModeOptions) (mode AccessMode) { +func (db *perms) AccessMode(ctx context.Context, userID, repoID int64, opts AccessModeOptions) (mode AccessMode) { if repoID <= 0 { return AccessModeNone } @@ -103,7 +107,7 @@ func (db *perms) AccessMode(userID, repoID int64, opts AccessModeOptions) (mode } access := new(Access) - err := db.Where("user_id = ? AND repo_id = ?", userID, repoID).First(access).Error + err := db.WithContext(ctx).Where("user_id = ? AND repo_id = ?", userID, repoID).First(access).Error if err != nil { if err != gorm.ErrRecordNotFound { log.Error("Failed to get access [user_id: %d, repo_id: %d]: %v", userID, repoID, err) @@ -113,11 +117,11 @@ func (db *perms) AccessMode(userID, repoID int64, opts AccessModeOptions) (mode return access.Mode } -func (db *perms) Authorize(userID, repoID int64, desired AccessMode, opts AccessModeOptions) bool { - return desired <= db.AccessMode(userID, repoID, opts) +func (db *perms) Authorize(ctx context.Context, userID, repoID int64, desired AccessMode, opts AccessModeOptions) bool { + return desired <= db.AccessMode(ctx, userID, repoID, opts) } -func (db *perms) SetRepoPerms(repoID int64, accessMap map[int64]AccessMode) error { +func (db *perms) SetRepoPerms(ctx context.Context, repoID int64, accessMap map[int64]AccessMode) error { records := make([]*Access, 0, len(accessMap)) for userID, mode := range accessMap { records = append(records, &Access{ @@ -127,7 +131,7 @@ func (db *perms) SetRepoPerms(repoID int64, accessMap map[int64]AccessMode) erro }) } - return db.Transaction(func(tx *gorm.DB) error { + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { err := tx.Where("repo_id = ?", repoID).Delete(new(Access)).Error if err != nil { return err diff --git a/internal/db/perms_test.go b/internal/db/perms_test.go index 9b821e6b..d7ddb6d5 100644 --- a/internal/db/perms_test.go +++ b/internal/db/perms_test.go @@ -5,12 +5,14 @@ package db import ( + "context" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func Test_perms(t *testing.T) { +func TestPerms(t *testing.T) { if testing.Short() { t.Skip() } @@ -26,16 +28,14 @@ func Test_perms(t *testing.T) { name string test func(*testing.T, *perms) }{ - {"AccessMode", test_perms_AccessMode}, - {"Authorize", test_perms_Authorize}, - {"SetRepoPerms", test_perms_SetRepoPerms}, + {"AccessMode", permsAccessMode}, + {"Authorize", permsAuthorize}, + {"SetRepoPerms", permsSetRepoPerms}, } { t.Run(tc.name, func(t *testing.T) { t.Cleanup(func() { err := clearTables(t, db.DB, tables...) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) }) tc.test(t, db) }) @@ -45,21 +45,23 @@ func Test_perms(t *testing.T) { } } -func test_perms_AccessMode(t *testing.T, db *perms) { +func permsAccessMode(t *testing.T, db *perms) { + ctx := context.Background() + // Set up permissions - err := db.SetRepoPerms(1, map[int64]AccessMode{ - 2: AccessModeWrite, - 3: AccessModeAdmin, - }) - if err != nil { - t.Fatal(err) - } - err = db.SetRepoPerms(2, map[int64]AccessMode{ - 1: AccessModeRead, - }) - if err != nil { - t.Fatal(err) - } + err := db.SetRepoPerms(ctx, 1, + map[int64]AccessMode{ + 2: AccessModeWrite, + 3: AccessModeAdmin, + }, + ) + require.NoError(t, err) + err = db.SetRepoPerms(ctx, 2, + map[int64]AccessMode{ + 1: AccessModeRead, + }, + ) + require.NoError(t, err) publicRepoID := int64(1) publicRepoOpts := AccessModeOptions{ @@ -73,99 +75,101 @@ func test_perms_AccessMode(t *testing.T, db *perms) { } tests := []struct { - name string - userID int64 - repoID int64 - opts AccessModeOptions - expAccessMode AccessMode + name string + userID int64 + repoID int64 + opts AccessModeOptions + wantAccessMode AccessMode }{ { - name: "nil repository", - expAccessMode: AccessModeNone, + name: "nil repository", + wantAccessMode: AccessModeNone, }, { - name: "anonymous user has read access to public repository", - repoID: publicRepoID, - opts: publicRepoOpts, - expAccessMode: AccessModeRead, + name: "anonymous user has read access to public repository", + repoID: publicRepoID, + opts: publicRepoOpts, + wantAccessMode: AccessModeRead, }, { - name: "anonymous user has no access to private repository", - repoID: privateRepoID, - opts: privateRepoOpts, - expAccessMode: AccessModeNone, + name: "anonymous user has no access to private repository", + repoID: privateRepoID, + opts: privateRepoOpts, + wantAccessMode: AccessModeNone, }, { - name: "user is the owner", - userID: 98, - repoID: publicRepoID, - opts: publicRepoOpts, - expAccessMode: AccessModeOwner, + name: "user is the owner", + userID: 98, + repoID: publicRepoID, + opts: publicRepoOpts, + wantAccessMode: AccessModeOwner, }, { - name: "user 1 has read access to public repo", - userID: 1, - repoID: publicRepoID, - opts: publicRepoOpts, - expAccessMode: AccessModeRead, + name: "user 1 has read access to public repo", + userID: 1, + repoID: publicRepoID, + opts: publicRepoOpts, + wantAccessMode: AccessModeRead, }, { - name: "user 2 has write access to public repo", - userID: 2, - repoID: publicRepoID, - opts: publicRepoOpts, - expAccessMode: AccessModeWrite, + name: "user 2 has write access to public repo", + userID: 2, + repoID: publicRepoID, + opts: publicRepoOpts, + wantAccessMode: AccessModeWrite, }, { - name: "user 3 has admin access to public repo", - userID: 3, - repoID: publicRepoID, - opts: publicRepoOpts, - expAccessMode: AccessModeAdmin, + name: "user 3 has admin access to public repo", + userID: 3, + repoID: publicRepoID, + opts: publicRepoOpts, + wantAccessMode: AccessModeAdmin, }, { - name: "user 1 has read access to private repo", - userID: 1, - repoID: privateRepoID, - opts: privateRepoOpts, - expAccessMode: AccessModeRead, + name: "user 1 has read access to private repo", + userID: 1, + repoID: privateRepoID, + opts: privateRepoOpts, + wantAccessMode: AccessModeRead, }, { - name: "user 2 has no access to private repo", - userID: 2, - repoID: privateRepoID, - opts: privateRepoOpts, - expAccessMode: AccessModeNone, + name: "user 2 has no access to private repo", + userID: 2, + repoID: privateRepoID, + opts: privateRepoOpts, + wantAccessMode: AccessModeNone, }, { - name: "user 3 has no access to private repo", - userID: 3, - repoID: privateRepoID, - opts: privateRepoOpts, - expAccessMode: AccessModeNone, + name: "user 3 has no access to private repo", + userID: 3, + repoID: privateRepoID, + opts: privateRepoOpts, + wantAccessMode: AccessModeNone, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - mode := db.AccessMode(test.userID, test.repoID, test.opts) - assert.Equal(t, test.expAccessMode, mode) + mode := db.AccessMode(ctx, test.userID, test.repoID, test.opts) + assert.Equal(t, test.wantAccessMode, mode) }) } } -func test_perms_Authorize(t *testing.T, db *perms) { +func permsAuthorize(t *testing.T, db *perms) { + ctx := context.Background() + // Set up permissions - err := db.SetRepoPerms(1, map[int64]AccessMode{ - 1: AccessModeRead, - 2: AccessModeWrite, - 3: AccessModeAdmin, - }) - if err != nil { - t.Fatal(err) - } + err := db.SetRepoPerms(ctx, 1, + map[int64]AccessMode{ + 1: AccessModeRead, + 2: AccessModeWrite, + 3: AccessModeAdmin, + }, + ) + require.NoError(t, err) repo := &Repository{ ID: 1, @@ -173,74 +177,78 @@ func test_perms_Authorize(t *testing.T, db *perms) { } tests := []struct { - name string - userID int64 - desired AccessMode - expAuthorized bool + name string + userID int64 + desired AccessMode + wantAuthorized bool }{ { - name: "user 1 has read and wants read", - userID: 1, - desired: AccessModeRead, - expAuthorized: true, + name: "user 1 has read and wants read", + userID: 1, + desired: AccessModeRead, + wantAuthorized: true, }, { - name: "user 1 has read and wants write", - userID: 1, - desired: AccessModeWrite, - expAuthorized: false, + name: "user 1 has read and wants write", + userID: 1, + desired: AccessModeWrite, + wantAuthorized: false, }, { - name: "user 2 has write and wants read", - userID: 2, - desired: AccessModeRead, - expAuthorized: true, + name: "user 2 has write and wants read", + userID: 2, + desired: AccessModeRead, + wantAuthorized: true, }, { - name: "user 2 has write and wants write", - userID: 2, - desired: AccessModeWrite, - expAuthorized: true, + name: "user 2 has write and wants write", + userID: 2, + desired: AccessModeWrite, + wantAuthorized: true, }, { - name: "user 2 has write and wants admin", - userID: 2, - desired: AccessModeAdmin, - expAuthorized: false, + name: "user 2 has write and wants admin", + userID: 2, + desired: AccessModeAdmin, + wantAuthorized: false, }, { - name: "user 3 has admin and wants read", - userID: 3, - desired: AccessModeRead, - expAuthorized: true, + name: "user 3 has admin and wants read", + userID: 3, + desired: AccessModeRead, + wantAuthorized: true, }, { - name: "user 3 has admin and wants write", - userID: 3, - desired: AccessModeWrite, - expAuthorized: true, + name: "user 3 has admin and wants write", + userID: 3, + desired: AccessModeWrite, + wantAuthorized: true, }, { - name: "user 3 has admin and wants admin", - userID: 3, - desired: AccessModeAdmin, - expAuthorized: true, + name: "user 3 has admin and wants admin", + userID: 3, + desired: AccessModeAdmin, + wantAuthorized: true, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - authorized := db.Authorize(test.userID, repo.ID, test.desired, AccessModeOptions{ - OwnerID: repo.OwnerID, - Private: repo.IsPrivate, - }) - assert.Equal(t, test.expAuthorized, authorized) + authorized := db.Authorize(ctx, test.userID, repo.ID, test.desired, + AccessModeOptions{ + OwnerID: repo.OwnerID, + Private: repo.IsPrivate, + }, + ) + assert.Equal(t, test.wantAuthorized, authorized) }) } } -func test_perms_SetRepoPerms(t *testing.T, db *perms) { +func permsSetRepoPerms(t *testing.T, db *perms) { + ctx := context.Background() + for _, update := range []struct { repoID int64 accessMap map[int64]AccessMode @@ -279,7 +287,7 @@ func test_perms_SetRepoPerms(t *testing.T, db *perms) { }, }, } { - err := db.SetRepoPerms(update.repoID, update.accessMap) + err := db.SetRepoPerms(ctx, update.repoID, update.accessMap) if err != nil { t.Fatal(err) } @@ -287,21 +295,19 @@ func test_perms_SetRepoPerms(t *testing.T, db *perms) { var accesses []*Access err := db.Order("user_id, repo_id").Find(&accesses).Error - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Ignore ID fields for _, a := range accesses { a.ID = 0 } - expAccesses := []*Access{ + wantAccesses := []*Access{ {UserID: 1, RepoID: 2, Mode: AccessModeWrite}, {UserID: 2, RepoID: 1, Mode: AccessModeWrite}, {UserID: 2, RepoID: 2, Mode: AccessModeRead}, {UserID: 3, RepoID: 1, Mode: AccessModeAdmin}, {UserID: 5, RepoID: 2, Mode: AccessModeWrite}, } - assert.Equal(t, expAccesses, accesses) + assert.Equal(t, wantAccesses, accesses) } diff --git a/internal/db/repo.go b/internal/db/repo.go index e044712b..360e967c 100644 --- a/internal/db/repo.go +++ b/internal/db/repo.go @@ -6,6 +6,7 @@ package db import ( "bytes" + "context" "fmt" "image" _ "image/jpeg" @@ -557,7 +558,7 @@ func (repo *Repository) ComposeCompareURL(oldCommitID, newCommitID string) strin } func (repo *Repository) HasAccess(userID int64) bool { - return Perms.Authorize(userID, repo.ID, AccessModeRead, + return Perms.Authorize(context.TODO(), userID, repo.ID, AccessModeRead, AccessModeOptions{ OwnerID: repo.OwnerID, Private: repo.IsPrivate, diff --git a/internal/db/repo_branch.go b/internal/db/repo_branch.go index 22c15b69..dc9e8795 100644 --- a/internal/db/repo_branch.go +++ b/internal/db/repo_branch.go @@ -5,6 +5,7 @@ package db import ( + "context" "fmt" "strings" @@ -175,7 +176,7 @@ func UpdateOrgProtectBranch(repo *Repository, protectBranch *ProtectBranch, whit userIDs := tool.StringsToInt64s(strings.Split(whitelistUserIDs, ",")) validUserIDs = make([]int64, 0, len(userIDs)) for _, userID := range userIDs { - if !Perms.Authorize(userID, repo.ID, AccessModeWrite, + if !Perms.Authorize(context.TODO(), userID, repo.ID, AccessModeWrite, AccessModeOptions{ OwnerID: repo.OwnerID, Private: repo.IsPrivate, diff --git a/internal/db/ssh_key.go b/internal/db/ssh_key.go index b53c66d6..d99a0dc6 100644 --- a/internal/db/ssh_key.go +++ b/internal/db/ssh_key.go @@ -5,6 +5,7 @@ package db import ( + "context" "encoding/base64" "encoding/binary" "errors" @@ -752,7 +753,7 @@ func DeleteDeployKey(doer *User, id int64) error { if err != nil { return fmt.Errorf("GetRepositoryByID: %v", err) } - if !Perms.Authorize(doer.ID, repo.ID, AccessModeAdmin, + if !Perms.Authorize(context.TODO(), doer.ID, repo.ID, AccessModeAdmin, AccessModeOptions{ OwnerID: repo.OwnerID, Private: repo.IsPrivate, diff --git a/internal/db/user.go b/internal/db/user.go index d5d2d60f..c6bee120 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -6,6 +6,7 @@ package db import ( "bytes" + "context" "crypto/sha256" "crypto/subtle" "encoding/hex" @@ -369,7 +370,7 @@ func (u *User) DeleteAvatar() error { // IsAdminOfRepo returns true if user has admin or higher access of repository. func (u *User) IsAdminOfRepo(repo *Repository) bool { - return Perms.Authorize(u.ID, repo.ID, AccessModeAdmin, + return Perms.Authorize(context.TODO(), u.ID, repo.ID, AccessModeAdmin, AccessModeOptions{ OwnerID: repo.OwnerID, Private: repo.IsPrivate, @@ -379,7 +380,7 @@ func (u *User) IsAdminOfRepo(repo *Repository) bool { // IsWriterOfRepo returns true if user has write access to given repository. func (u *User) IsWriterOfRepo(repo *Repository) bool { - return Perms.Authorize(u.ID, repo.ID, AccessModeWrite, + return Perms.Authorize(context.TODO(), u.ID, repo.ID, AccessModeWrite, AccessModeOptions{ OwnerID: repo.OwnerID, Private: repo.IsPrivate, @@ -941,7 +942,7 @@ func GetUserByID(id int64) (*User, error) { // GetAssigneeByID returns the user with read access of repository by given ID. func GetAssigneeByID(repo *Repository, userID int64) (*User, error) { - if !Perms.Authorize(userID, repo.ID, AccessModeRead, + if !Perms.Authorize(context.TODO(), userID, repo.ID, AccessModeRead, AccessModeOptions{ OwnerID: repo.OwnerID, Private: repo.IsPrivate, |