diff options
-rw-r--r-- | internal/context/auth.go | 2 | ||||
-rw-r--r-- | internal/context/go_get.go | 2 | ||||
-rw-r--r-- | internal/db/login_source_files_test.go | 4 | ||||
-rw-r--r-- | internal/db/login_sources_test.go | 20 | ||||
-rw-r--r-- | internal/db/mock_gen.go | 32 | ||||
-rw-r--r-- | internal/db/mocks.go | 657 | ||||
-rw-r--r-- | internal/db/user.go | 5 | ||||
-rw-r--r-- | internal/db/users.go | 94 | ||||
-rw-r--r-- | internal/db/users_test.go | 216 | ||||
-rw-r--r-- | internal/route/lfs/route.go | 6 | ||||
-rw-r--r-- | internal/route/lfs/route_test.go | 82 | ||||
-rw-r--r-- | internal/route/org/setting.go | 2 | ||||
-rw-r--r-- | internal/route/repo/http.go | 6 | ||||
-rw-r--r-- | internal/route/repo/tasks.go | 4 | ||||
-rw-r--r-- | internal/route/user/auth.go | 2 | ||||
-rw-r--r-- | internal/route/user/setting.go | 2 |
16 files changed, 879 insertions, 257 deletions
diff --git a/internal/context/auth.go b/internal/context/auth.go index 25625c3d..cba68706 100644 --- a/internal/context/auth.go +++ b/internal/context/auth.go @@ -208,7 +208,7 @@ func authenticatedUser(ctx *macaron.Context, sess session.Store) (_ *db.User, is if len(auths) == 2 && auths[0] == "Basic" { uname, passwd, _ := tool.BasicAuthDecode(auths[1]) - u, err := db.Users.Authenticate(uname, passwd, -1) + u, err := db.Users.Authenticate(ctx.Req.Context(), uname, passwd, -1) if err != nil { if !auth.IsErrBadCredentials(err) { log.Error("Failed to authenticate user: %v", err) diff --git a/internal/context/go_get.go b/internal/context/go_get.go index 2851b5db..5a59e1b5 100644 --- a/internal/context/go_get.go +++ b/internal/context/go_get.go @@ -26,7 +26,7 @@ func ServeGoGet() macaron.Handler { repoName := c.Params(":reponame") branchName := "master" - owner, err := db.Users.GetByUsername(ownerName) + owner, err := db.Users.GetByUsername(c.Req.Context(), ownerName) if err == nil { repo, err := db.Repos.GetByName(owner.ID, repoName) if err == nil && repo.DefaultBranch != "" { diff --git a/internal/db/login_source_files_test.go b/internal/db/login_source_files_test.go index 9f66f582..6254340f 100644 --- a/internal/db/login_source_files_test.go +++ b/internal/db/login_source_files_test.go @@ -23,8 +23,8 @@ func TestLoginSourceFiles_GetByID(t *testing.T) { t.Run("source does not exist", func(t *testing.T) { _, err := store.GetByID(1) - expErr := ErrLoginSourceNotExist{args: errutil.Args{"id": int64(1)}} - assert.Equal(t, expErr, err) + wantErr := ErrLoginSourceNotExist{args: errutil.Args{"id": int64(1)}} + assert.Equal(t, wantErr, err) }) t.Run("source exists", func(t *testing.T) { diff --git a/internal/db/login_sources_test.go b/internal/db/login_sources_test.go index 9e14caf9..ad09a8db 100644 --- a/internal/db/login_sources_test.go +++ b/internal/db/login_sources_test.go @@ -138,8 +138,8 @@ func loginSourcesCreate(t *testing.T, db *loginSources) { // Try create second login source with same name should fail _, err = db.Create(ctx, CreateLoginSourceOpts{Name: source.Name}) - expErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}} - assert.Equal(t, expErr, err) + wantErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}} + assert.Equal(t, wantErr, err) } func loginSourcesCount(t *testing.T, db *loginSources) { @@ -184,15 +184,17 @@ func loginSourcesDeleteByID(t *testing.T, db *loginSources) { require.NoError(t, err) // Create a user that uses this login source - _, err = (&users{DB: db.DB}).Create("alice", "", CreateUserOpts{ - LoginSource: source.ID, - }) + _, err = (&users{DB: db.DB}).Create(ctx, "alice", "", + CreateUserOpts{ + LoginSource: source.ID, + }, + ) require.NoError(t, err) // Delete the login source will result in error err = db.DeleteByID(ctx, source.ID) - expErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}} - assert.Equal(t, expErr, err) + wantErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}} + assert.Equal(t, wantErr, err) }) mock := NewMockLoginSourceFilesStore() @@ -229,8 +231,8 @@ func loginSourcesDeleteByID(t *testing.T, db *loginSources) { // We should get token not found error _, err = db.GetByID(ctx, source.ID) - expErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}} - assert.Equal(t, expErr, err) + wantErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}} + assert.Equal(t, wantErr, err) } func loginSourcesGetByID(t *testing.T, db *loginSources) { diff --git a/internal/db/mock_gen.go b/internal/db/mock_gen.go index a73f10d0..7235c6b5 100644 --- a/internal/db/mock_gen.go +++ b/internal/db/mock_gen.go @@ -8,7 +8,7 @@ import ( "testing" ) -//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i LFSStore -i LoginSourcesStore -i LoginSourceFilesStore -i loginSourceFileStore -i PermsStore -o mocks.go +//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i LFSStore -i LoginSourcesStore -i LoginSourceFilesStore -i loginSourceFileStore -i PermsStore -i UsersStore -o mocks.go func SetMockAccessTokensStore(t *testing.T, mock AccessTokensStore) { before := AccessTokens @@ -88,36 +88,6 @@ func SetMockTwoFactorsStore(t *testing.T, mock TwoFactorsStore) { }) } -var _ UsersStore = (*MockUsersStore)(nil) - -type MockUsersStore struct { - MockAuthenticate func(username, password string, loginSourceID int64) (*User, error) - MockCreate func(username, email string, opts CreateUserOpts) (*User, error) - MockGetByEmail func(email string) (*User, error) - MockGetByID func(id int64) (*User, error) - MockGetByUsername func(username string) (*User, error) -} - -func (m *MockUsersStore) Authenticate(username, password string, loginSourceID int64) (*User, error) { - return m.MockAuthenticate(username, password, loginSourceID) -} - -func (m *MockUsersStore) Create(username, email string, opts CreateUserOpts) (*User, error) { - return m.MockCreate(username, email, opts) -} - -func (m *MockUsersStore) GetByEmail(email string) (*User, error) { - return m.MockGetByEmail(email) -} - -func (m *MockUsersStore) GetByID(id int64) (*User, error) { - return m.MockGetByID(id) -} - -func (m *MockUsersStore) GetByUsername(username string) (*User, error) { - return m.MockGetByUsername(username) -} - func SetMockUsersStore(t *testing.T, mock UsersStore) { before := Users Users = mock diff --git a/internal/db/mocks.go b/internal/db/mocks.go index 87f3c2c3..e6a39963 100644 --- a/internal/db/mocks.go +++ b/internal/db/mocks.go @@ -2371,6 +2371,663 @@ func (c PermsStoreSetRepoPermsFuncCall) Results() []interface{} { return []interface{}{c.Result0} } +// MockUsersStore is a mock implementation of the UsersStore interface (from +// the package gogs.io/gogs/internal/db) used for unit testing. +type MockUsersStore struct { + // AuthenticateFunc is an instance of a mock function object controlling + // the behavior of the method Authenticate. + AuthenticateFunc *UsersStoreAuthenticateFunc + // CreateFunc is an instance of a mock function object controlling the + // behavior of the method Create. + CreateFunc *UsersStoreCreateFunc + // GetByEmailFunc is an instance of a mock function object controlling + // the behavior of the method GetByEmail. + GetByEmailFunc *UsersStoreGetByEmailFunc + // GetByIDFunc is an instance of a mock function object controlling the + // behavior of the method GetByID. + GetByIDFunc *UsersStoreGetByIDFunc + // GetByUsernameFunc is an instance of a mock function object + // controlling the behavior of the method GetByUsername. + GetByUsernameFunc *UsersStoreGetByUsernameFunc +} + +// NewMockUsersStore creates a new mock of the UsersStore interface. All +// methods return zero values for all results, unless overwritten. +func NewMockUsersStore() *MockUsersStore { + return &MockUsersStore{ + AuthenticateFunc: &UsersStoreAuthenticateFunc{ + defaultHook: func(context.Context, string, string, int64) (r0 *User, r1 error) { + return + }, + }, + CreateFunc: &UsersStoreCreateFunc{ + defaultHook: func(context.Context, string, string, CreateUserOpts) (r0 *User, r1 error) { + return + }, + }, + GetByEmailFunc: &UsersStoreGetByEmailFunc{ + defaultHook: func(context.Context, string) (r0 *User, r1 error) { + return + }, + }, + GetByIDFunc: &UsersStoreGetByIDFunc{ + defaultHook: func(context.Context, int64) (r0 *User, r1 error) { + return + }, + }, + GetByUsernameFunc: &UsersStoreGetByUsernameFunc{ + defaultHook: func(context.Context, string) (r0 *User, r1 error) { + return + }, + }, + } +} + +// NewStrictMockUsersStore creates a new mock of the UsersStore interface. +// All methods panic on invocation, unless overwritten. +func NewStrictMockUsersStore() *MockUsersStore { + return &MockUsersStore{ + AuthenticateFunc: &UsersStoreAuthenticateFunc{ + defaultHook: func(context.Context, string, string, int64) (*User, error) { + panic("unexpected invocation of MockUsersStore.Authenticate") + }, + }, + CreateFunc: &UsersStoreCreateFunc{ + defaultHook: func(context.Context, string, string, CreateUserOpts) (*User, error) { + panic("unexpected invocation of MockUsersStore.Create") + }, + }, + GetByEmailFunc: &UsersStoreGetByEmailFunc{ + defaultHook: func(context.Context, string) (*User, error) { + panic("unexpected invocation of MockUsersStore.GetByEmail") + }, + }, + GetByIDFunc: &UsersStoreGetByIDFunc{ + defaultHook: func(context.Context, int64) (*User, error) { + panic("unexpected invocation of MockUsersStore.GetByID") + }, + }, + GetByUsernameFunc: &UsersStoreGetByUsernameFunc{ + defaultHook: func(context.Context, string) (*User, error) { + panic("unexpected invocation of MockUsersStore.GetByUsername") + }, + }, + } +} + +// NewMockUsersStoreFrom creates a new mock of the MockUsersStore interface. +// All methods delegate to the given implementation, unless overwritten. +func NewMockUsersStoreFrom(i UsersStore) *MockUsersStore { + return &MockUsersStore{ + AuthenticateFunc: &UsersStoreAuthenticateFunc{ + defaultHook: i.Authenticate, + }, + CreateFunc: &UsersStoreCreateFunc{ + defaultHook: i.Create, + }, + GetByEmailFunc: &UsersStoreGetByEmailFunc{ + defaultHook: i.GetByEmail, + }, + GetByIDFunc: &UsersStoreGetByIDFunc{ + defaultHook: i.GetByID, + }, + GetByUsernameFunc: &UsersStoreGetByUsernameFunc{ + defaultHook: i.GetByUsername, + }, + } +} + +// UsersStoreAuthenticateFunc describes the behavior when the Authenticate +// method of the parent MockUsersStore instance is invoked. +type UsersStoreAuthenticateFunc struct { + defaultHook func(context.Context, string, string, int64) (*User, error) + hooks []func(context.Context, string, string, int64) (*User, error) + history []UsersStoreAuthenticateFuncCall + mutex sync.Mutex +} + +// Authenticate delegates to the next hook function in the queue and stores +// the parameter and result values of this invocation. +func (m *MockUsersStore) Authenticate(v0 context.Context, v1 string, v2 string, v3 int64) (*User, error) { + r0, r1 := m.AuthenticateFunc.nextHook()(v0, v1, v2, v3) + m.AuthenticateFunc.appendCall(UsersStoreAuthenticateFuncCall{v0, v1, v2, v3, r0, r1}) + return r0, r1 +} + +// SetDefaultHook sets function that is called when the Authenticate method +// of the parent MockUsersStore instance is invoked and the hook queue is +// empty. +func (f *UsersStoreAuthenticateFunc) SetDefaultHook(hook func(context.Context, string, string, int64) (*User, error)) { + f.defaultHook = hook +} + +// PushHook adds a function to the end of hook queue. Each invocation of the +// Authenticate method of the parent MockUsersStore instance invokes the +// hook at the front of the queue and discards it. After the queue is empty, +// the default hook function is invoked for any future action. +func (f *UsersStoreAuthenticateFunc) PushHook(hook func(context.Context, string, string, int64) (*User, error)) { + f.mutex.Lock() + f.hooks = append(f.hooks, hook) + f.mutex.Unlock() +} + +// SetDefaultReturn calls SetDefaultHook with a function that returns the +// given values. +func (f *UsersStoreAuthenticateFunc) SetDefaultReturn(r0 *User, r1 error) { + f.SetDefaultHook(func(context.Context, string, string, int64) (*User, error) { + return r0, r1 + }) +} + +// PushReturn calls PushHook with a function that returns the given values. +func (f *UsersStoreAuthenticateFunc) PushReturn(r0 *User, r1 error) { + f.PushHook(func(context.Context, string, string, int64) (*User, error) { + return r0, r1 + }) +} + +func (f *UsersStoreAuthenticateFunc) nextHook() func(context.Context, string, string, int64) (*User, error) { + f.mutex.Lock() + defer f.mutex.Unlock() + + if len(f.hooks) == 0 { + return f.defaultHook + } + + hook := f.hooks[0] + f.hooks = f.hooks[1:] + return hook +} + +func (f *UsersStoreAuthenticateFunc) appendCall(r0 UsersStoreAuthenticateFuncCall) { + f.mutex.Lock() + f.history = append(f.history, r0) + f.mutex.Unlock() +} + +// History returns a sequence of UsersStoreAuthenticateFuncCall objects +// describing the invocations of this function. +func (f *UsersStoreAuthenticateFunc) History() []UsersStoreAuthenticateFuncCall { + f.mutex.Lock() + history := make([]UsersStoreAuthenticateFuncCall, len(f.history)) + copy(history, f.history) + f.mutex.Unlock() + + return history +} + +// UsersStoreAuthenticateFuncCall is an object that describes an invocation +// of method Authenticate on an instance of MockUsersStore. +type UsersStoreAuthenticateFuncCall struct { + // Arg0 is the value of the 1st argument passed to this method + // invocation. + Arg0 context.Context + // Arg1 is the value of the 2nd argument passed to this method + // invocation. + Arg1 string + // Arg2 is the value of the 3rd argument passed to this method + // invocation. + Arg2 string + // Arg3 is the value of the 4th argument passed to this method + // invocation. + Arg3 int64 + // Result0 is the value of the 1st result returned from this method + // invocation. + Result0 *User + // Result1 is the value of the 2nd result returned from this method + // invocation. + Result1 error +} + +// Args returns an interface slice containing the arguments of this +// invocation. +func (c UsersStoreAuthenticateFuncCall) Args() []interface{} { + return []interface{}{c.Arg0, c.Arg1, c.Arg2, c.Arg3} +} + +// Results returns an interface slice containing the results of this +// invocation. +func (c UsersStoreAuthenticateFuncCall) Results() []interface{} { + return []interface{}{c.Result0, c.Result1} +} + +// UsersStoreCreateFunc describes the behavior when the Create method of the +// parent MockUsersStore instance is invoked. +type UsersStoreCreateFunc struct { + defaultHook func(context.Context, string, string, CreateUserOpts) (*User, error) + hooks []func(context.Context, string, string, CreateUserOpts) (*User, error) + history []UsersStoreCreateFuncCall + mutex sync.Mutex +} + +// Create delegates to the next hook function in the queue and stores the +// parameter and result values of this invocation. +func (m *MockUsersStore) Create(v0 context.Context, v1 string, v2 string, v3 CreateUserOpts) (*User, error) { + r0, r1 := m.CreateFunc.nextHook()(v0, v1, v2, v3) + m.CreateFunc.appendCall(UsersStoreCreateFuncCall{v0, v1, v2, v3, r0, r1}) + return r0, r1 +} + +// SetDefaultHook sets function that is called when the Create method of the +// parent MockUsersStore instance is invoked and the hook queue is empty. +func (f *UsersStoreCreateFunc) SetDefaultHook(hook func(context.Context, string, string, CreateUserOpts) (*User, error)) { + f.defaultHook = hook +} + +// PushHook adds a function to the end of hook queue. Each invocation of the +// Create method of the parent MockUsersStore instance invokes the hook at +// the front of the queue and discards it. After the queue is empty, the +// default hook function is invoked for any future action. +func (f *UsersStoreCreateFunc) PushHook(hook func(context.Context, string, string, CreateUserOpts) (*User, error)) { + f.mutex.Lock() + f.hooks = append(f.hooks, hook) + f.mutex.Unlock() +} + +// SetDefaultReturn calls SetDefaultHook with a function that returns the +// given values. +func (f *UsersStoreCreateFunc) SetDefaultReturn(r0 *User, r1 error) { + f.SetDefaultHook(func(context.Context, string, string, CreateUserOpts) (*User, error) { + return r0, r1 + }) +} + +// PushReturn calls PushHook with a function that returns the given values. +func (f *UsersStoreCreateFunc) PushReturn(r0 *User, r1 error) { + f.PushHook(func(context.Context, string, string, CreateUserOpts) (*User, error) { + return r0, r1 + }) +} + +func (f *UsersStoreCreateFunc) nextHook() func(context.Context, string, string, CreateUserOpts) (*User, error) { + f.mutex.Lock() + defer f.mutex.Unlock() + + if len(f.hooks) == 0 { + return f.defaultHook + } + + hook := f.hooks[0] + f.hooks = f.hooks[1:] + return hook +} + +func (f *UsersStoreCreateFunc) appendCall(r0 UsersStoreCreateFuncCall) { + f.mutex.Lock() + f.history = append(f.history, r0) + f.mutex.Unlock() +} + +// History returns a sequence of UsersStoreCreateFuncCall objects describing +// the invocations of this function. +func (f *UsersStoreCreateFunc) History() []UsersStoreCreateFuncCall { + f.mutex.Lock() + history := make([]UsersStoreCreateFuncCall, len(f.history)) + copy(history, f.history) + f.mutex.Unlock() + + return history +} + +// UsersStoreCreateFuncCall is an object that describes an invocation of +// method Create on an instance of MockUsersStore. +type UsersStoreCreateFuncCall struct { + // Arg0 is the value of the 1st argument passed to this method + // invocation. + Arg0 context.Context + // Arg1 is the value of the 2nd argument passed to this method + // invocation. + Arg1 string + // Arg2 is the value of the 3rd argument passed to this method + // invocation. + Arg2 string + // Arg3 is the value of the 4th argument passed to this method + // invocation. + Arg3 CreateUserOpts + // Result0 is the value of the 1st result returned from this method + // invocation. + Result0 *User + // Result1 is the value of the 2nd result returned from this method + // invocation. + Result1 error +} + +// Args returns an interface slice containing the arguments of this +// invocation. +func (c UsersStoreCreateFuncCall) Args() []interface{} { + return []interface{}{c.Arg0, c.Arg1, c.Arg2, c.Arg3} +} + +// Results returns an interface slice containing the results of this +// invocation. +func (c UsersStoreCreateFuncCall) Results() []interface{} { + return []interface{}{c.Result0, c.Result1} +} + +// UsersStoreGetByEmailFunc describes the behavior when the GetByEmail +// method of the parent MockUsersStore instance is invoked. +type UsersStoreGetByEmailFunc struct { + defaultHook func(context.Context, string) (*User, error) + hooks []func(context.Context, string) (*User, error) + history []UsersStoreGetByEmailFuncCall + mutex sync.Mutex +} + +// GetByEmail delegates to the next hook function in the queue and stores +// the parameter and result values of this invocation. +func (m *MockUsersStore) GetByEmail(v0 context.Context, v1 string) (*User, error) { + r0, r1 := m.GetByEmailFunc.nextHook()(v0, v1) + m.GetByEmailFunc.appendCall(UsersStoreGetByEmailFuncCall{v0, v1, r0, r1}) + return r0, r1 +} + +// SetDefaultHook sets function that is called when the GetByEmail method of +// the parent MockUsersStore instance is invoked and the hook queue is +// empty. +func (f *UsersStoreGetByEmailFunc) SetDefaultHook(hook func(context.Context, string) (*User, error)) { + f.defaultHook = hook +} + +// PushHook adds a function to the end of hook queue. Each invocation of the +// GetByEmail method of the parent MockUsersStore instance invokes the hook +// at the front of the queue and discards it. After the queue is empty, the +// default hook function is invoked for any future action. +func (f *UsersStoreGetByEmailFunc) PushHook(hook func(context.Context, string) (*User, error)) { + f.mutex.Lock() + f.hooks = append(f.hooks, hook) + f.mutex.Unlock() +} + +// SetDefaultReturn calls SetDefaultHook with a function that returns the +// given values. +func (f *UsersStoreGetByEmailFunc) SetDefaultReturn(r0 *User, r1 error) { + f.SetDefaultHook(func(context.Context, string) (*User, error) { + return r0, r1 + }) +} + +// PushReturn calls PushHook with a function that returns the given values. +func (f *UsersStoreGetByEmailFunc) PushReturn(r0 *User, r1 error) { + f.PushHook(func(context.Context, string) (*User, error) { + return r0, r1 + }) +} + +func (f *UsersStoreGetByEmailFunc) nextHook() func(context.Context, string) (*User, error) { + f.mutex.Lock() + defer f.mutex.Unlock() + + if len(f.hooks) == 0 { + return f.defaultHook + } + + hook := f.hooks[0] + f.hooks = f.hooks[1:] + return hook +} + +func (f *UsersStoreGetByEmailFunc) appendCall(r0 UsersStoreGetByEmailFuncCall) { + f.mutex.Lock() + f.history = append(f.history, r0) + f.mutex.Unlock() +} + +// History returns a sequence of UsersStoreGetByEmailFuncCall objects +// describing the invocations of this function. +func (f *UsersStoreGetByEmailFunc) History() []UsersStoreGetByEmailFuncCall { + f.mutex.Lock() + history := make([]UsersStoreGetByEmailFuncCall, len(f.history)) + copy(history, f.history) + f.mutex.Unlock() + + return history +} + +// UsersStoreGetByEmailFuncCall is an object that describes an invocation of +// method GetByEmail on an instance of MockUsersStore. +type UsersStoreGetByEmailFuncCall struct { + // Arg0 is the value of the 1st argument passed to this method + // invocation. + Arg0 context.Context + // Arg1 is the value of the 2nd argument passed to this method + // invocation. + Arg1 string + // Result0 is the value of the 1st result returned from this method + // invocation. + Result0 *User + // Result1 is the value of the 2nd result returned from this method + // invocation. + Result1 error +} + +// Args returns an interface slice containing the arguments of this +// invocation. +func (c UsersStoreGetByEmailFuncCall) Args() []interface{} { + return []interface{}{c.Arg0, c.Arg1} +} + +// Results returns an interface slice containing the results of this +// invocation. +func (c UsersStoreGetByEmailFuncCall) Results() []interface{} { + return []interface{}{c.Result0, c.Result1} +} + +// UsersStoreGetByIDFunc describes the behavior when the GetByID method of +// the parent MockUsersStore instance is invoked. +type UsersStoreGetByIDFunc struct { + defaultHook func(context.Context, int64) (*User, error) + hooks []func(context.Context, int64) (*User, error) + history []UsersStoreGetByIDFuncCall + mutex sync.Mutex +} + +// GetByID delegates to the next hook function in the queue and stores the +// parameter and result values of this invocation. +func (m *MockUsersStore) GetByID(v0 context.Context, v1 int64) (*User, error) { + r0, r1 := m.GetByIDFunc.nextHook()(v0, v1) + m.GetByIDFunc.appendCall(UsersStoreGetByIDFuncCall{v0, v1, r0, r1}) + return r0, r1 +} + +// SetDefaultHook sets function that is called when the GetByID method of +// the parent MockUsersStore instance is invoked and the hook queue is +// empty. +func (f *UsersStoreGetByIDFunc) SetDefaultHook(hook func(context.Context, int64) (*User, error)) { + f.defaultHook = hook +} + +// PushHook adds a function to the end of hook queue. Each invocation of the +// GetByID method of the parent MockUsersStore instance invokes the hook at +// the front of the queue and discards it. After the queue is empty, the +// default hook function is invoked for any future action. +func (f *UsersStoreGetByIDFunc) PushHook(hook func(context.Context, int64) (*User, error)) { + f.mutex.Lock() + f.hooks = append(f.hooks, hook) + f.mutex.Unlock() +} + +// SetDefaultReturn calls SetDefaultHook with a function that returns the +// given values. +func (f *UsersStoreGetByIDFunc) SetDefaultReturn(r0 *User, r1 error) { + f.SetDefaultHook(func(context.Context, int64) (*User, error) { + return r0, r1 + }) +} + +// PushReturn calls PushHook with a function that returns the given values. +func (f *UsersStoreGetByIDFunc) PushReturn(r0 *User, r1 error) { + f.PushHook(func(context.Context, int64) (*User, error) { + return r0, r1 + }) +} + +func (f *UsersStoreGetByIDFunc) nextHook() func(context.Context, int64) (*User, error) { + f.mutex.Lock() + defer f.mutex.Unlock() + + if len(f.hooks) == 0 { + return f.defaultHook + } + + hook := f.hooks[0] + f.hooks = f.hooks[1:] + return hook +} + +func (f *UsersStoreGetByIDFunc) appendCall(r0 UsersStoreGetByIDFuncCall) { + f.mutex.Lock() + f.history = append(f.history, r0) + f.mutex.Unlock() +} + +// History returns a sequence of UsersStoreGetByIDFuncCall objects +// describing the invocations of this function. +func (f *UsersStoreGetByIDFunc) History() []UsersStoreGetByIDFuncCall { + f.mutex.Lock() + history := make([]UsersStoreGetByIDFuncCall, len(f.history)) + copy(history, f.history) + f.mutex.Unlock() + + return history +} + +// UsersStoreGetByIDFuncCall is an object that describes an invocation of +// method GetByID on an instance of MockUsersStore. +type UsersStoreGetByIDFuncCall struct { + // Arg0 is the value of the 1st argument passed to this method + // invocation. + Arg0 context.Context + // Arg1 is the value of the 2nd argument passed to this method + // invocation. + Arg1 int64 + // Result0 is the value of the 1st result returned from this method + // invocation. + Result0 *User + // Result1 is the value of the 2nd result returned from this method + // invocation. + Result1 error +} + +// Args returns an interface slice containing the arguments of this +// invocation. +func (c UsersStoreGetByIDFuncCall) Args() []interface{} { + return []interface{}{c.Arg0, c.Arg1} +} + +// Results returns an interface slice containing the results of this +// invocation. +func (c UsersStoreGetByIDFuncCall) Results() []interface{} { + return []interface{}{c.Result0, c.Result1} +} + +// UsersStoreGetByUsernameFunc describes the behavior when the GetByUsername +// method of the parent MockUsersStore instance is invoked. +type UsersStoreGetByUsernameFunc struct { + defaultHook func(context.Context, string) (*User, error) + hooks []func(context.Context, string) (*User, error) + history []UsersStoreGetByUsernameFuncCall + mutex sync.Mutex +} + +// GetByUsername delegates to the next hook function in the queue and stores +// the parameter and result values of this invocation. +func (m *MockUsersStore) GetByUsername(v0 context.Context, v1 string) (*User, error) { + r0, r1 := m.GetByUsernameFunc.nextHook()(v0, v1) + m.GetByUsernameFunc.appendCall(UsersStoreGetByUsernameFuncCall{v0, v1, r0, r1}) + return r0, r1 +} + +// SetDefaultHook sets function that is called when the GetByUsername method +// of the parent MockUsersStore instance is invoked and the hook queue is +// empty. +func (f *UsersStoreGetByUsernameFunc) SetDefaultHook(hook func(context.Context, string) (*User, error)) { + f.defaultHook = hook +} + +// PushHook adds a function to the end of hook queue. Each invocation of the +// GetByUsername method of the parent MockUsersStore instance invokes the +// hook at the front of the queue and discards it. After the queue is empty, +// the default hook function is invoked for any future action. +func (f *UsersStoreGetByUsernameFunc) PushHook(hook func(context.Context, string) (*User, error)) { + f.mutex.Lock() + f.hooks = append(f.hooks, hook) + f.mutex.Unlock() +} + +// SetDefaultReturn calls SetDefaultHook with a function that returns the +// given values. +func (f *UsersStoreGetByUsernameFunc) SetDefaultReturn(r0 *User, r1 error) { + f.SetDefaultHook(func(context.Context, string) (*User, error) { + return r0, r1 + }) +} + +// PushReturn calls PushHook with a function that returns the given values. +func (f *UsersStoreGetByUsernameFunc) PushReturn(r0 *User, r1 error) { + f.PushHook(func(context.Context, string) (*User, error) { + return r0, r1 + }) +} + +func (f *UsersStoreGetByUsernameFunc) nextHook() func(context.Context, string) (*User, error) { + f.mutex.Lock() + defer f.mutex.Unlock() + + if len(f.hooks) == 0 { + return f.defaultHook + } + + hook := f.hooks[0] + f.hooks = f.hooks[1:] + return hook +} + +func (f *UsersStoreGetByUsernameFunc) appendCall(r0 UsersStoreGetByUsernameFuncCall) { + f.mutex.Lock() + f.history = append(f.history, r0) + f.mutex.Unlock() +} + +// History returns a sequence of UsersStoreGetByUsernameFuncCall objects +// describing the invocations of this function. +func (f *UsersStoreGetByUsernameFunc) History() []UsersStoreGetByUsernameFuncCall { + f.mutex.Lock() + history := make([]UsersStoreGetByUsernameFuncCall, len(f.history)) + copy(history, f.history) + f.mutex.Unlock() + + return history +} + +// UsersStoreGetByUsernameFuncCall is an object that describes an invocation +// of method GetByUsername on an instance of MockUsersStore. +type UsersStoreGetByUsernameFuncCall struct { + // Arg0 is the value of the 1st argument passed to this method + // invocation. + Arg0 context.Context + // Arg1 is the value of the 2nd argument passed to this method + // invocation. + Arg1 string + // Result0 is the value of the 1st result returned from this method + // invocation. + Result0 *User + // Result1 is the value of the 2nd result returned from this method + // invocation. + Result1 error +} + +// Args returns an interface slice containing the arguments of this +// invocation. +func (c UsersStoreGetByUsernameFuncCall) Args() []interface{} { + return []interface{}{c.Arg0, c.Arg1} +} + +// Results returns an interface slice containing the results of this +// invocation. +func (c UsersStoreGetByUsernameFuncCall) Results() []interface{} { + return []interface{}{c.Result0, c.Result1} +} + // MockLoginSourceFileStore is a mock implementation of the // loginSourceFileStore interface (from the package // gogs.io/gogs/internal/db) used for unit testing. diff --git a/internal/db/user.go b/internal/db/user.go index c6bee120..ebfbf082 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -942,7 +942,8 @@ func GetUserByID(id int64) (*User, error) { // GetAssigneeByID returns the user with read access of repository by given ID. func GetAssigneeByID(repo *Repository, userID int64) (*User, error) { - if !Perms.Authorize(context.TODO(), userID, repo.ID, AccessModeRead, + ctx := context.TODO() + if !Perms.Authorize(ctx, userID, repo.ID, AccessModeRead, AccessModeOptions{ OwnerID: repo.OwnerID, Private: repo.IsPrivate, @@ -950,7 +951,7 @@ func GetAssigneeByID(repo *Repository, userID int64) (*User, error) { ) { return nil, ErrUserNotExist{args: map[string]interface{}{"userID": userID}} } - return Users.GetByID(userID) + return Users.GetByID(ctx, userID) } // GetUserByName returns a user by given name. diff --git a/internal/db/users.go b/internal/db/users.go index 5c8f2d38..cac22c44 100644 --- a/internal/db/users.go +++ b/internal/db/users.go @@ -23,8 +23,8 @@ import ( // // NOTE: All methods are sorted in alphabetical order. type UsersStore interface { - // Authenticate validates username and password via given login source ID. - // It returns ErrUserNotExist when the user was not found. + // Authenticate validates username and password via given login source ID. It + // returns ErrUserNotExist when the user was not found. // // When the "loginSourceID" is negative, it aborts the process and returns // ErrUserNotExist if the user was not found in the database. @@ -34,23 +34,25 @@ type UsersStore interface { // // When the "loginSourceID" is positive, it tries to authenticate via given // login source and creates a new user when not yet exists in the database. - Authenticate(username, password string, loginSourceID int64) (*User, error) - // Create creates a new user and persists to database. - // It returns ErrUserAlreadyExist when a user with same name already exists, - // or ErrEmailAlreadyUsed if the email has been used by another user. - Create(username, email string, opts CreateUserOpts) (*User, error) - // GetByEmail returns the user (not organization) with given email. - // It ignores records with unverified emails and returns ErrUserNotExist when not found. - GetByEmail(email string) (*User, error) - // GetByID returns the user with given ID. It returns ErrUserNotExist when not found. - GetByID(id int64) (*User, error) - // GetByUsername returns the user with given username. It returns ErrUserNotExist when not found. - GetByUsername(username string) (*User, error) + Authenticate(ctx context.Context, username, password string, loginSourceID int64) (*User, error) + // Create creates a new user and persists to database. It returns + // ErrUserAlreadyExist when a user with same name already exists, or + // ErrEmailAlreadyUsed if the email has been used by another user. + Create(ctx context.Context, username, email string, opts CreateUserOpts) (*User, error) + // GetByEmail returns the user (not organization) with given email. It ignores + // records with unverified emails and returns ErrUserNotExist when not found. + GetByEmail(ctx context.Context, email string) (*User, error) + // GetByID returns the user with given ID. It returns ErrUserNotExist when not + // found. + GetByID(ctx context.Context, id int64) (*User, error) + // GetByUsername returns the user with given username. It returns + // ErrUserNotExist when not found. + GetByUsername(ctx context.Context, username string) (*User, error) } var Users UsersStore -// NOTE: This is a GORM create hook. +// BeforeCreate implements the GORM create hook. func (u *User) BeforeCreate(tx *gorm.DB) error { if u.CreatedUnix == 0 { u.CreatedUnix = tx.NowFunc().Unix() @@ -59,7 +61,7 @@ func (u *User) BeforeCreate(tx *gorm.DB) error { return nil } -// NOTE: This is a GORM query hook. +// AfterFind implements the GORM query hook. func (u *User) AfterFind(_ *gorm.DB) error { u.Created = time.Unix(u.CreatedUnix, 0).Local() u.Updated = time.Unix(u.UpdatedUnix, 0).Local() @@ -80,16 +82,14 @@ func (err ErrLoginSourceMismatch) Error() string { return fmt.Sprintf("login source mismatch: %v", err.args) } -func (db *users) Authenticate(login, password string, loginSourceID int64) (*User, error) { - ctx := context.TODO() - +func (db *users) Authenticate(ctx context.Context, login, password string, loginSourceID int64) (*User, error) { login = strings.ToLower(login) - var query *gorm.DB + query := db.WithContext(ctx) if strings.Contains(login, "@") { - query = db.Where("email = ?", login) + query = query.Where("email = ?", login) } else { - query = db.Where("lower_name = ?", login) + query = query.Where("lower_name = ?", login) } user := new(User) @@ -153,15 +153,17 @@ func (db *users) Authenticate(login, password string, loginSourceID int64) (*Use return nil, fmt.Errorf("invalid pattern for attribute 'username' [%s]: must be valid alpha or numeric or dash(-_) or dot characters", extAccount.Name) } - return Users.Create(extAccount.Name, extAccount.Email, CreateUserOpts{ - FullName: extAccount.FullName, - LoginSource: authSourceID, - LoginName: extAccount.Login, - Location: extAccount.Location, - Website: extAccount.Website, - Activated: true, - Admin: extAccount.Admin, - }) + return db.Create(ctx, extAccount.Name, extAccount.Email, + CreateUserOpts{ + FullName: extAccount.FullName, + LoginSource: authSourceID, + LoginName: extAccount.Login, + Location: extAccount.Location, + Website: extAccount.Website, + Activated: true, + Admin: extAccount.Admin, + }, + ) } type CreateUserOpts struct { @@ -209,20 +211,20 @@ func (err ErrEmailAlreadyUsed) Error() string { return fmt.Sprintf("email has been used: %v", err.args) } -func (db *users) Create(username, email string, opts CreateUserOpts) (*User, error) { +func (db *users) Create(ctx context.Context, username, email string, opts CreateUserOpts) (*User, error) { err := isUsernameAllowed(username) if err != nil { return nil, err } - _, err = db.GetByUsername(username) + _, err = db.GetByUsername(ctx, username) if err == nil { return nil, ErrUserAlreadyExist{args: errutil.Args{"name": username}} } else if !IsErrUserNotExist(err) { return nil, err } - _, err = db.GetByEmail(email) + _, err = db.GetByEmail(ctx, email) if err == nil { return nil, ErrEmailAlreadyUsed{args: errutil.Args{"email": email}} } else if !IsErrUserNotExist(err) { @@ -256,7 +258,7 @@ func (db *users) Create(username, email string, opts CreateUserOpts) (*User, err } user.EncodePassword() - return user, db.DB.Create(user).Error + return user, db.WithContext(ctx).Create(user).Error } var _ errutil.NotFound = (*ErrUserNotExist)(nil) @@ -278,7 +280,7 @@ func (ErrUserNotExist) NotFound() bool { return true } -func (db *users) GetByEmail(email string) (*User, error) { +func (db *users) GetByEmail(ctx context.Context, email string) (*User, error) { email = strings.ToLower(email) if email == "" { @@ -287,7 +289,10 @@ func (db *users) GetByEmail(email string) (*User, error) { // First try to find the user by primary email user := new(User) - err := db.Where("email = ? AND type = ? AND is_active = ?", email, UserIndividual, true).First(user).Error + err := db.WithContext(ctx). + Where("email = ? AND type = ? AND is_active = ?", email, UserIndividual, true). + First(user). + Error if err == nil { return user, nil } else if err != gorm.ErrRecordNotFound { @@ -296,7 +301,10 @@ func (db *users) GetByEmail(email string) (*User, error) { // Otherwise, check activated email addresses emailAddress := new(EmailAddress) - err = db.Where("email = ? AND is_activated = ?", email, true).First(emailAddress).Error + err = db.WithContext(ctx). + Where("email = ? AND is_activated = ?", email, true). + First(emailAddress). + Error if err != nil { if err == gorm.ErrRecordNotFound { return nil, ErrUserNotExist{args: errutil.Args{"email": email}} @@ -304,12 +312,12 @@ func (db *users) GetByEmail(email string) (*User, error) { return nil, err } - return db.GetByID(emailAddress.UID) + return db.GetByID(ctx, emailAddress.UID) } -func (db *users) GetByID(id int64) (*User, error) { +func (db *users) GetByID(ctx context.Context, id int64) (*User, error) { user := new(User) - err := db.Where("id = ?", id).First(user).Error + err := db.WithContext(ctx).Where("id = ?", id).First(user).Error if err != nil { if err == gorm.ErrRecordNotFound { return nil, ErrUserNotExist{args: errutil.Args{"userID": id}} @@ -319,9 +327,9 @@ func (db *users) GetByID(id int64) (*User, error) { return user, nil } -func (db *users) GetByUsername(username string) (*User, error) { +func (db *users) GetByUsername(ctx context.Context, username string) (*User, error) { user := new(User) - err := db.Where("lower_name = ?", strings.ToLower(username)).First(user).Error + err := db.WithContext(ctx).Where("lower_name = ?", strings.ToLower(username)).First(user).Error if err != nil { if err == gorm.ErrRecordNotFound { return nil, ErrUserNotExist{args: errutil.Args{"name": username}} diff --git a/internal/db/users_test.go b/internal/db/users_test.go index dac2c208..d691110a 100644 --- a/internal/db/users_test.go +++ b/internal/db/users_test.go @@ -5,16 +5,18 @@ package db import ( + "context" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gogs.io/gogs/internal/auth" "gogs.io/gogs/internal/errutil" ) -func Test_users(t *testing.T) { +func TestUsers(t *testing.T) { if testing.Short() { t.Skip() } @@ -30,18 +32,16 @@ func Test_users(t *testing.T) { name string test func(*testing.T, *users) }{ - {"Authenticate", test_users_Authenticate}, - {"Create", test_users_Create}, - {"GetByEmail", test_users_GetByEmail}, - {"GetByID", test_users_GetByID}, - {"GetByUsername", test_users_GetByUsername}, + {"Authenticate", usersAuthenticate}, + {"Create", usersCreate}, + {"GetByEmail", usersGetByEmail}, + {"GetByID", usersGetByID}, + {"GetByUsername", usersGetByUsername}, } { t.Run(tc.name, func(t *testing.T) { t.Cleanup(func() { err := clearTables(t, db.DB, tables...) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) }) tc.test(t, db) }) @@ -53,187 +53,165 @@ func Test_users(t *testing.T) { // TODO: Only local account is tested, tests for external account will be added // along with addressing https://github.com/gogs/gogs/issues/6115. -func test_users_Authenticate(t *testing.T, db *users) { +func usersAuthenticate(t *testing.T, db *users) { + ctx := context.Background() + password := "pa$$word" - alice, err := db.Create("alice", "alice@example.com", CreateUserOpts{ - Password: password, - }) - if err != nil { - t.Fatal(err) - } + alice, err := db.Create(ctx, "alice", "alice@example.com", + CreateUserOpts{ + Password: password, + }, + ) + require.NoError(t, err) t.Run("user not found", func(t *testing.T) { - _, err := db.Authenticate("bob", password, -1) - expErr := auth.ErrBadCredentials{Args: map[string]interface{}{"login": "bob"}} - assert.Equal(t, expErr, err) + _, err := db.Authenticate(ctx, "bob", password, -1) + wantErr := auth.ErrBadCredentials{Args: map[string]interface{}{"login": "bob"}} + assert.Equal(t, wantErr, err) }) t.Run("invalid password", func(t *testing.T) { - _, err := db.Authenticate(alice.Name, "bad_password", -1) - expErr := auth.ErrBadCredentials{Args: map[string]interface{}{"login": alice.Name, "userID": alice.ID}} - assert.Equal(t, expErr, err) + _, err := db.Authenticate(ctx, alice.Name, "bad_password", -1) + wantErr := auth.ErrBadCredentials{Args: map[string]interface{}{"login": alice.Name, "userID": alice.ID}} + assert.Equal(t, wantErr, err) }) t.Run("via email and password", func(t *testing.T) { - user, err := db.Authenticate(alice.Email, password, -1) - if err != nil { - t.Fatal(err) - } + user, err := db.Authenticate(ctx, alice.Email, password, -1) + require.NoError(t, err) assert.Equal(t, alice.Name, user.Name) }) t.Run("via username and password", func(t *testing.T) { - user, err := db.Authenticate(alice.Name, password, -1) - if err != nil { - t.Fatal(err) - } + user, err := db.Authenticate(ctx, alice.Name, password, -1) + require.NoError(t, err) assert.Equal(t, alice.Name, user.Name) }) } -func test_users_Create(t *testing.T, db *users) { - alice, err := db.Create("alice", "alice@example.com", CreateUserOpts{ - Activated: true, - }) - if err != nil { - t.Fatal(err) - } +func usersCreate(t *testing.T, db *users) { + ctx := context.Background() + + alice, err := db.Create(ctx, "alice", "alice@example.com", + CreateUserOpts{ + Activated: true, + }, + ) + require.NoError(t, err) t.Run("name not allowed", func(t *testing.T) { - _, err := db.Create("-", "", CreateUserOpts{}) - expErr := ErrNameNotAllowed{args: errutil.Args{"reason": "reserved", "name": "-"}} - assert.Equal(t, expErr, err) + _, err := db.Create(ctx, "-", "", CreateUserOpts{}) + wantErr := ErrNameNotAllowed{args: errutil.Args{"reason": "reserved", "name": "-"}} + assert.Equal(t, wantErr, err) }) t.Run("name already exists", func(t *testing.T) { - _, err := db.Create(alice.Name, "", CreateUserOpts{}) - expErr := ErrUserAlreadyExist{args: errutil.Args{"name": alice.Name}} - assert.Equal(t, expErr, err) + _, err := db.Create(ctx, alice.Name, "", CreateUserOpts{}) + wantErr := ErrUserAlreadyExist{args: errutil.Args{"name": alice.Name}} + assert.Equal(t, wantErr, err) }) t.Run("email already exists", func(t *testing.T) { - _, err := db.Create("bob", alice.Email, CreateUserOpts{}) - expErr := ErrEmailAlreadyUsed{args: errutil.Args{"email": alice.Email}} - assert.Equal(t, expErr, err) + _, err := db.Create(ctx, "bob", alice.Email, CreateUserOpts{}) + wantErr := ErrEmailAlreadyUsed{args: errutil.Args{"email": alice.Email}} + assert.Equal(t, wantErr, err) }) - user, err := db.GetByUsername(alice.Name) - if err != nil { - t.Fatal(err) - } + user, err := db.GetByUsername(ctx, alice.Name) + require.NoError(t, err) assert.Equal(t, db.NowFunc().Format(time.RFC3339), user.Created.UTC().Format(time.RFC3339)) assert.Equal(t, db.NowFunc().Format(time.RFC3339), user.Updated.UTC().Format(time.RFC3339)) } -func test_users_GetByEmail(t *testing.T, db *users) { +func usersGetByEmail(t *testing.T, db *users) { + ctx := context.Background() + t.Run("empty email", func(t *testing.T) { - _, err := db.GetByEmail("") - expErr := ErrUserNotExist{args: errutil.Args{"email": ""}} - assert.Equal(t, expErr, err) + _, err := db.GetByEmail(ctx, "") + wantErr := ErrUserNotExist{args: errutil.Args{"email": ""}} + assert.Equal(t, wantErr, err) }) t.Run("ignore organization", func(t *testing.T) { // TODO: Use Orgs.Create to replace SQL hack when the method is available. - org, err := db.Create("gogs", "gogs@exmaple.com", CreateUserOpts{}) - if err != nil { - t.Fatal(err) - } + org, err := db.Create(ctx, "gogs", "gogs@exmaple.com", CreateUserOpts{}) + require.NoError(t, err) err = db.Model(&User{}).Where("id", org.ID).UpdateColumn("type", UserOrganization).Error - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - _, err = db.GetByEmail(org.Email) - expErr := ErrUserNotExist{args: errutil.Args{"email": org.Email}} - assert.Equal(t, expErr, err) + _, err = db.GetByEmail(ctx, org.Email) + wantErr := ErrUserNotExist{args: errutil.Args{"email": org.Email}} + assert.Equal(t, wantErr, err) }) t.Run("by primary email", func(t *testing.T) { - alice, err := db.Create("alice", "alice@exmaple.com", CreateUserOpts{}) - if err != nil { - t.Fatal(err) - } + alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOpts{}) + require.NoError(t, err) - _, err = db.GetByEmail(alice.Email) - expErr := ErrUserNotExist{args: errutil.Args{"email": alice.Email}} - assert.Equal(t, expErr, err) + _, err = db.GetByEmail(ctx, alice.Email) + wantErr := ErrUserNotExist{args: errutil.Args{"email": alice.Email}} + assert.Equal(t, wantErr, err) // Mark user as activated // TODO: Use UserEmails.Verify to replace SQL hack when the method is available. err = db.Model(&User{}).Where("id", alice.ID).UpdateColumn("is_active", true).Error - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - user, err := db.GetByEmail(alice.Email) - if err != nil { - t.Fatal(err) - } + user, err := db.GetByEmail(ctx, alice.Email) + require.NoError(t, err) assert.Equal(t, alice.Name, user.Name) }) t.Run("by secondary email", func(t *testing.T) { - bob, err := db.Create("bob", "bob@example.com", CreateUserOpts{}) - if err != nil { - t.Fatal(err) - } + bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOpts{}) + require.NoError(t, err) // TODO: Use UserEmails.Create to replace SQL hack when the method is available. email2 := "bob2@exmaple.com" err = db.Exec(`INSERT INTO email_address (uid, email) VALUES (?, ?)`, bob.ID, email2).Error - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - _, err = db.GetByEmail(email2) - expErr := ErrUserNotExist{args: errutil.Args{"email": email2}} - assert.Equal(t, expErr, err) + _, err = db.GetByEmail(ctx, email2) + wantErr := ErrUserNotExist{args: errutil.Args{"email": email2}} + assert.Equal(t, wantErr, err) // TODO: Use UserEmails.Verify to replace SQL hack when the method is available. err = db.Exec(`UPDATE email_address SET is_activated = ? WHERE email = ?`, true, email2).Error - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - user, err := db.GetByEmail(email2) - if err != nil { - t.Fatal(err) - } + user, err := db.GetByEmail(ctx, email2) + require.NoError(t, err) assert.Equal(t, bob.Name, user.Name) }) } -func test_users_GetByID(t *testing.T, db *users) { - alice, err := db.Create("alice", "alice@exmaple.com", CreateUserOpts{}) - if err != nil { - t.Fatal(err) - } +func usersGetByID(t *testing.T, db *users) { + ctx := context.Background() - user, err := db.GetByID(alice.ID) - if err != nil { - t.Fatal(err) - } + alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOpts{}) + require.NoError(t, err) + + user, err := db.GetByID(ctx, alice.ID) + require.NoError(t, err) assert.Equal(t, alice.Name, user.Name) - _, err = db.GetByID(404) - expErr := ErrUserNotExist{args: errutil.Args{"userID": int64(404)}} - assert.Equal(t, expErr, err) + _, err = db.GetByID(ctx, 404) + wantErr := ErrUserNotExist{args: errutil.Args{"userID": int64(404)}} + assert.Equal(t, wantErr, err) } -func test_users_GetByUsername(t *testing.T, db *users) { - alice, err := db.Create("alice", "alice@exmaple.com", CreateUserOpts{}) - if err != nil { - t.Fatal(err) - } +func usersGetByUsername(t *testing.T, db *users) { + ctx := context.Background() - user, err := db.GetByUsername(alice.Name) - if err != nil { - t.Fatal(err) - } + alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOpts{}) + require.NoError(t, err) + + user, err := db.GetByUsername(ctx, alice.Name) + require.NoError(t, err) assert.Equal(t, alice.Name, user.Name) - _, err = db.GetByUsername("bad_username") - expErr := ErrUserNotExist{args: errutil.Args{"name": "bad_username"}} - assert.Equal(t, expErr, err) + _, err = db.GetByUsername(ctx, "bad_username") + wantErr := ErrUserNotExist{args: errutil.Args{"name": "bad_username"}} + assert.Equal(t, wantErr, err) } diff --git a/internal/route/lfs/route.go b/internal/route/lfs/route.go index f5195837..94c42fea 100644 --- a/internal/route/lfs/route.go +++ b/internal/route/lfs/route.go @@ -58,7 +58,7 @@ func authenticate() macaron.Handler { return } - user, err := db.Users.Authenticate(username, password, -1) + user, err := db.Users.Authenticate(c.Req.Context(), username, password, -1) if err != nil && !auth.IsErrBadCredentials(err) { internalServerError(c.Resp) log.Error("Failed to authenticate user [name: %s]: %v", username, err) @@ -86,7 +86,7 @@ func authenticate() macaron.Handler { log.Error("Failed to touch access token: %v", err) } - user, err = db.Users.GetByID(token.UserID) + user, err = db.Users.GetByID(c.Req.Context(), token.UserID) if err != nil { // Once we found the token, we're supposed to find its related user, // thus any error is unexpected. @@ -108,7 +108,7 @@ func authorize(mode db.AccessMode) macaron.Handler { username := c.Params(":username") reponame := strings.TrimSuffix(c.Params(":reponame"), ".git") - owner, err := db.Users.GetByUsername(username) + owner, err := db.Users.GetByUsername(c.Req.Context(), username) if err != nil { if db.IsErrUserNotExist(err) { c.Status(http.StatusNotFound) diff --git a/internal/route/lfs/route_test.go b/internal/route/lfs/route_test.go index 57d0728f..6202695b 100644 --- a/internal/route/lfs/route_test.go +++ b/internal/route/lfs/route_test.go @@ -30,7 +30,7 @@ func Test_authenticate(t *testing.T) { tests := []struct { name string header http.Header - mockUsersStore *db.MockUsersStore + mockUsersStore func() db.UsersStore mockTwoFactorsStore *db.MockTwoFactorsStore mockAccessTokensStore func() db.AccessTokensStore expStatusCode int @@ -51,10 +51,10 @@ func Test_authenticate(t *testing.T) { header: http.Header{ "Authorization": []string{"Basic dXNlcm5hbWU6cGFzc3dvcmQ="}, }, - mockUsersStore: &db.MockUsersStore{ - MockAuthenticate: func(username, password string, loginSourceID int64) (*db.User, error) { - return &db.User{}, nil - }, + mockUsersStore: func() db.UsersStore { + mock := db.NewMockUsersStore() + mock.AuthenticateFunc.SetDefaultReturn(&db.User{}, nil) + return mock }, mockTwoFactorsStore: &db.MockTwoFactorsStore{ MockIsUserEnabled: func(userID int64) bool { @@ -70,10 +70,10 @@ func Test_authenticate(t *testing.T) { header: http.Header{ "Authorization": []string{"Basic dXNlcm5hbWU="}, }, - mockUsersStore: &db.MockUsersStore{ - MockAuthenticate: func(username, password string, loginSourceID int64) (*db.User, error) { - return nil, auth.ErrBadCredentials{} - }, + mockUsersStore: func() db.UsersStore { + mock := db.NewMockUsersStore() + mock.AuthenticateFunc.SetDefaultReturn(nil, auth.ErrBadCredentials{}) + return mock }, mockAccessTokensStore: func() db.AccessTokensStore { mock := db.NewMockAccessTokensStore() @@ -93,10 +93,10 @@ func Test_authenticate(t *testing.T) { header: http.Header{ "Authorization": []string{"Basic dXNlcm5hbWU6cGFzc3dvcmQ="}, }, - mockUsersStore: &db.MockUsersStore{ - MockAuthenticate: func(username, password string, loginSourceID int64) (*db.User, error) { - return &db.User{ID: 1, Name: "unknwon"}, nil - }, + mockUsersStore: func() db.UsersStore { + mock := db.NewMockUsersStore() + mock.AuthenticateFunc.SetDefaultReturn(&db.User{ID: 1, Name: "unknwon"}, nil) + return mock }, mockTwoFactorsStore: &db.MockTwoFactorsStore{ MockIsUserEnabled: func(userID int64) bool { @@ -112,13 +112,11 @@ func Test_authenticate(t *testing.T) { header: http.Header{ "Authorization": []string{"Basic dXNlcm5hbWU="}, }, - mockUsersStore: &db.MockUsersStore{ - MockAuthenticate: func(username, password string, loginSourceID int64) (*db.User, error) { - return nil, auth.ErrBadCredentials{} - }, - MockGetByID: func(id int64) (*db.User, error) { - return &db.User{ID: 1, Name: "unknwon"}, nil - }, + mockUsersStore: func() db.UsersStore { + mock := db.NewMockUsersStore() + mock.AuthenticateFunc.SetDefaultReturn(nil, auth.ErrBadCredentials{}) + mock.GetByIDFunc.SetDefaultReturn(&db.User{ID: 1, Name: "unknwon"}, nil) + return mock }, mockAccessTokensStore: func() db.AccessTokensStore { mock := db.NewMockAccessTokensStore() @@ -132,9 +130,10 @@ func Test_authenticate(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - db.SetMockUsersStore(t, test.mockUsersStore) + if test.mockUsersStore != nil { + db.SetMockUsersStore(t, test.mockUsersStore()) + } db.SetMockTwoFactorsStore(t, test.mockTwoFactorsStore) - if test.mockAccessTokensStore != nil { db.SetMockAccessTokensStore(t, test.mockAccessTokensStore()) } @@ -165,7 +164,7 @@ func Test_authorize(t *testing.T) { tests := []struct { name string authroize macaron.Handler - mockUsersStore *db.MockUsersStore + mockUsersStore func() db.UsersStore mockReposStore *db.MockReposStore mockPermsStore func() db.PermsStore expStatusCode int @@ -174,20 +173,22 @@ func Test_authorize(t *testing.T) { { name: "user does not exist", authroize: authorize(db.AccessModeNone), - mockUsersStore: &db.MockUsersStore{ - MockGetByUsername: func(username string) (*db.User, error) { - return nil, db.ErrUserNotExist{} - }, + mockUsersStore: func() db.UsersStore { + mock := db.NewMockUsersStore() + mock.GetByUsernameFunc.SetDefaultReturn(nil, db.ErrUserNotExist{}) + return mock }, expStatusCode: http.StatusNotFound, }, { name: "repository does not exist", authroize: authorize(db.AccessModeNone), - mockUsersStore: &db.MockUsersStore{ - MockGetByUsername: func(username string) (*db.User, error) { + mockUsersStore: func() db.UsersStore { + mock := db.NewMockUsersStore() + mock.GetByUsernameFunc.SetDefaultHook(func(ctx context.Context, username string) (*db.User, error) { return &db.User{Name: username}, nil - }, + }) + return mock }, mockReposStore: &db.MockReposStore{ MockGetByName: func(ownerID int64, name string) (*db.Repository, error) { @@ -199,10 +200,12 @@ func Test_authorize(t *testing.T) { { name: "actor is not authorized", authroize: authorize(db.AccessModeWrite), - mockUsersStore: &db.MockUsersStore{ - MockGetByUsername: func(username string) (*db.User, error) { + mockUsersStore: func() db.UsersStore { + mock := db.NewMockUsersStore() + mock.GetByUsernameFunc.SetDefaultHook(func(ctx context.Context, username string) (*db.User, error) { return &db.User{Name: username}, nil - }, + }) + return mock }, mockReposStore: &db.MockReposStore{ MockGetByName: func(ownerID int64, name string) (*db.Repository, error) { @@ -222,10 +225,12 @@ func Test_authorize(t *testing.T) { { name: "actor is authorized", authroize: authorize(db.AccessModeRead), - mockUsersStore: &db.MockUsersStore{ - MockGetByUsername: func(username string) (*db.User, error) { + mockUsersStore: func() db.UsersStore { + mock := db.NewMockUsersStore() + mock.GetByUsernameFunc.SetDefaultHook(func(ctx context.Context, username string) (*db.User, error) { return &db.User{Name: username}, nil - }, + }) + return mock }, mockReposStore: &db.MockReposStore{ MockGetByName: func(ownerID int64, name string) (*db.Repository, error) { @@ -245,9 +250,10 @@ func Test_authorize(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - db.SetMockUsersStore(t, test.mockUsersStore) + if test.mockUsersStore != nil { + db.SetMockUsersStore(t, test.mockUsersStore()) + } db.SetMockReposStore(t, test.mockReposStore) - if test.mockPermsStore != nil { db.SetMockPermsStore(t, test.mockPermsStore()) } diff --git a/internal/route/org/setting.go b/internal/route/org/setting.go index 94e9b7a6..e15b9faf 100644 --- a/internal/route/org/setting.go +++ b/internal/route/org/setting.go @@ -109,7 +109,7 @@ func SettingsDelete(c *context.Context) { org := c.Org.Organization if c.Req.Method == "POST" { - if _, err := db.Users.Authenticate(c.User.Name, c.Query("password"), c.User.LoginSource); err != nil { + if _, err := db.Users.Authenticate(c.Req.Context(), c.User.Name, c.Query("password"), c.User.LoginSource); err != nil { if auth.IsErrBadCredentials(err) { c.RenderWithErr(c.Tr("form.enterred_invalid_password"), SETTINGS_DELETE, nil) } else { diff --git a/internal/route/repo/http.go b/internal/route/repo/http.go index 7e970194..888ad4d8 100644 --- a/internal/route/repo/http.go +++ b/internal/route/repo/http.go @@ -65,7 +65,7 @@ func HTTPContexter() macaron.Handler { strings.HasSuffix(c.Req.URL.Path, "git-upload-pack") || c.Req.Method == "GET" - owner, err := db.Users.GetByUsername(ownerName) + owner, err := db.Users.GetByUsername(c.Req.Context(), ownerName) if err != nil { if db.IsErrUserNotExist(err) { c.Status(http.StatusNotFound) @@ -123,7 +123,7 @@ func HTTPContexter() macaron.Handler { return } - authUser, err := db.Users.Authenticate(authUsername, authPassword, -1) + authUser, err := db.Users.Authenticate(c.Req.Context(), authUsername, authPassword, -1) if err != nil && !auth.IsErrBadCredentials(err) { c.Status(http.StatusInternalServerError) log.Error("Failed to authenticate user [name: %s]: %v", authUsername, err) @@ -146,7 +146,7 @@ func HTTPContexter() macaron.Handler { log.Error("Failed to touch access token: %v", err) } - authUser, err = db.Users.GetByID(token.UserID) + authUser, err = db.Users.GetByID(c.Req.Context(), token.UserID) if err != nil { // Once we found token, we're supposed to find its related user, // thus any error is unexpected. diff --git a/internal/route/repo/tasks.go b/internal/route/repo/tasks.go index 81e85e2a..c5b555b9 100644 --- a/internal/route/repo/tasks.go +++ b/internal/route/repo/tasks.go @@ -26,7 +26,7 @@ func TriggerTask(c *macaron.Context) { username := c.Params(":username") reponame := c.Params(":reponame") - owner, err := db.Users.GetByUsername(username) + owner, err := db.Users.GetByUsername(c.Req.Context(), username) if err != nil { if db.IsErrUserNotExist(err) { c.Error(http.StatusBadRequest, "Owner does not exist") @@ -55,7 +55,7 @@ func TriggerTask(c *macaron.Context) { return } - pusher, err := db.Users.GetByID(pusherID) + pusher, err := db.Users.GetByID(c.Req.Context(), pusherID) if err != nil { if db.IsErrUserNotExist(err) { c.Error(http.StatusBadRequest, "Pusher does not exist") diff --git a/internal/route/user/auth.go b/internal/route/user/auth.go index f8bbe7ab..b3b785c2 100644 --- a/internal/route/user/auth.go +++ b/internal/route/user/auth.go @@ -161,7 +161,7 @@ func LoginPost(c *context.Context, f form.SignIn) { return } - u, err := db.Users.Authenticate(f.UserName, f.Password, f.LoginSource) + u, err := db.Users.Authenticate(c.Req.Context(), f.UserName, f.Password, f.LoginSource) if err != nil { switch errors.Cause(err).(type) { case auth.ErrBadCredentials: diff --git a/internal/route/user/setting.go b/internal/route/user/setting.go index cf2226c5..c28a1747 100644 --- a/internal/route/user/setting.go +++ b/internal/route/user/setting.go @@ -640,7 +640,7 @@ func SettingsDelete(c *context.Context) { c.PageIs("SettingsDelete") if c.Req.Method == "POST" { - if _, err := db.Users.Authenticate(c.User.Name, c.Query("password"), c.User.LoginSource); err != nil { + if _, err := db.Users.Authenticate(c.Req.Context(), c.User.Name, c.Query("password"), c.User.LoginSource); err != nil { if auth.IsErrBadCredentials(err) { c.RenderWithErr(c.Tr("form.enterred_invalid_password"), SETTINGS_DELETE, nil) } else { |