aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/context/auth.go2
-rw-r--r--internal/context/go_get.go2
-rw-r--r--internal/db/login_source_files_test.go4
-rw-r--r--internal/db/login_sources_test.go20
-rw-r--r--internal/db/mock_gen.go32
-rw-r--r--internal/db/mocks.go657
-rw-r--r--internal/db/user.go5
-rw-r--r--internal/db/users.go94
-rw-r--r--internal/db/users_test.go216
-rw-r--r--internal/route/lfs/route.go6
-rw-r--r--internal/route/lfs/route_test.go82
-rw-r--r--internal/route/org/setting.go2
-rw-r--r--internal/route/repo/http.go6
-rw-r--r--internal/route/repo/tasks.go4
-rw-r--r--internal/route/user/auth.go2
-rw-r--r--internal/route/user/setting.go2
16 files changed, 879 insertions, 257 deletions
diff --git a/internal/context/auth.go b/internal/context/auth.go
index 25625c3d..cba68706 100644
--- a/internal/context/auth.go
+++ b/internal/context/auth.go
@@ -208,7 +208,7 @@ func authenticatedUser(ctx *macaron.Context, sess session.Store) (_ *db.User, is
if len(auths) == 2 && auths[0] == "Basic" {
uname, passwd, _ := tool.BasicAuthDecode(auths[1])
- u, err := db.Users.Authenticate(uname, passwd, -1)
+ u, err := db.Users.Authenticate(ctx.Req.Context(), uname, passwd, -1)
if err != nil {
if !auth.IsErrBadCredentials(err) {
log.Error("Failed to authenticate user: %v", err)
diff --git a/internal/context/go_get.go b/internal/context/go_get.go
index 2851b5db..5a59e1b5 100644
--- a/internal/context/go_get.go
+++ b/internal/context/go_get.go
@@ -26,7 +26,7 @@ func ServeGoGet() macaron.Handler {
repoName := c.Params(":reponame")
branchName := "master"
- owner, err := db.Users.GetByUsername(ownerName)
+ owner, err := db.Users.GetByUsername(c.Req.Context(), ownerName)
if err == nil {
repo, err := db.Repos.GetByName(owner.ID, repoName)
if err == nil && repo.DefaultBranch != "" {
diff --git a/internal/db/login_source_files_test.go b/internal/db/login_source_files_test.go
index 9f66f582..6254340f 100644
--- a/internal/db/login_source_files_test.go
+++ b/internal/db/login_source_files_test.go
@@ -23,8 +23,8 @@ func TestLoginSourceFiles_GetByID(t *testing.T) {
t.Run("source does not exist", func(t *testing.T) {
_, err := store.GetByID(1)
- expErr := ErrLoginSourceNotExist{args: errutil.Args{"id": int64(1)}}
- assert.Equal(t, expErr, err)
+ wantErr := ErrLoginSourceNotExist{args: errutil.Args{"id": int64(1)}}
+ assert.Equal(t, wantErr, err)
})
t.Run("source exists", func(t *testing.T) {
diff --git a/internal/db/login_sources_test.go b/internal/db/login_sources_test.go
index 9e14caf9..ad09a8db 100644
--- a/internal/db/login_sources_test.go
+++ b/internal/db/login_sources_test.go
@@ -138,8 +138,8 @@ func loginSourcesCreate(t *testing.T, db *loginSources) {
// Try create second login source with same name should fail
_, err = db.Create(ctx, CreateLoginSourceOpts{Name: source.Name})
- expErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}}
- assert.Equal(t, expErr, err)
+ wantErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}}
+ assert.Equal(t, wantErr, err)
}
func loginSourcesCount(t *testing.T, db *loginSources) {
@@ -184,15 +184,17 @@ func loginSourcesDeleteByID(t *testing.T, db *loginSources) {
require.NoError(t, err)
// Create a user that uses this login source
- _, err = (&users{DB: db.DB}).Create("alice", "", CreateUserOpts{
- LoginSource: source.ID,
- })
+ _, err = (&users{DB: db.DB}).Create(ctx, "alice", "",
+ CreateUserOpts{
+ LoginSource: source.ID,
+ },
+ )
require.NoError(t, err)
// Delete the login source will result in error
err = db.DeleteByID(ctx, source.ID)
- expErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}}
- assert.Equal(t, expErr, err)
+ wantErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}}
+ assert.Equal(t, wantErr, err)
})
mock := NewMockLoginSourceFilesStore()
@@ -229,8 +231,8 @@ func loginSourcesDeleteByID(t *testing.T, db *loginSources) {
// We should get token not found error
_, err = db.GetByID(ctx, source.ID)
- expErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}}
- assert.Equal(t, expErr, err)
+ wantErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}}
+ assert.Equal(t, wantErr, err)
}
func loginSourcesGetByID(t *testing.T, db *loginSources) {
diff --git a/internal/db/mock_gen.go b/internal/db/mock_gen.go
index a73f10d0..7235c6b5 100644
--- a/internal/db/mock_gen.go
+++ b/internal/db/mock_gen.go
@@ -8,7 +8,7 @@ import (
"testing"
)
-//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i LFSStore -i LoginSourcesStore -i LoginSourceFilesStore -i loginSourceFileStore -i PermsStore -o mocks.go
+//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i LFSStore -i LoginSourcesStore -i LoginSourceFilesStore -i loginSourceFileStore -i PermsStore -i UsersStore -o mocks.go
func SetMockAccessTokensStore(t *testing.T, mock AccessTokensStore) {
before := AccessTokens
@@ -88,36 +88,6 @@ func SetMockTwoFactorsStore(t *testing.T, mock TwoFactorsStore) {
})
}
-var _ UsersStore = (*MockUsersStore)(nil)
-
-type MockUsersStore struct {
- MockAuthenticate func(username, password string, loginSourceID int64) (*User, error)
- MockCreate func(username, email string, opts CreateUserOpts) (*User, error)
- MockGetByEmail func(email string) (*User, error)
- MockGetByID func(id int64) (*User, error)
- MockGetByUsername func(username string) (*User, error)
-}
-
-func (m *MockUsersStore) Authenticate(username, password string, loginSourceID int64) (*User, error) {
- return m.MockAuthenticate(username, password, loginSourceID)
-}
-
-func (m *MockUsersStore) Create(username, email string, opts CreateUserOpts) (*User, error) {
- return m.MockCreate(username, email, opts)
-}
-
-func (m *MockUsersStore) GetByEmail(email string) (*User, error) {
- return m.MockGetByEmail(email)
-}
-
-func (m *MockUsersStore) GetByID(id int64) (*User, error) {
- return m.MockGetByID(id)
-}
-
-func (m *MockUsersStore) GetByUsername(username string) (*User, error) {
- return m.MockGetByUsername(username)
-}
-
func SetMockUsersStore(t *testing.T, mock UsersStore) {
before := Users
Users = mock
diff --git a/internal/db/mocks.go b/internal/db/mocks.go
index 87f3c2c3..e6a39963 100644
--- a/internal/db/mocks.go
+++ b/internal/db/mocks.go
@@ -2371,6 +2371,663 @@ func (c PermsStoreSetRepoPermsFuncCall) Results() []interface{} {
return []interface{}{c.Result0}
}
+// MockUsersStore is a mock implementation of the UsersStore interface (from
+// the package gogs.io/gogs/internal/db) used for unit testing.
+type MockUsersStore struct {
+ // AuthenticateFunc is an instance of a mock function object controlling
+ // the behavior of the method Authenticate.
+ AuthenticateFunc *UsersStoreAuthenticateFunc
+ // CreateFunc is an instance of a mock function object controlling the
+ // behavior of the method Create.
+ CreateFunc *UsersStoreCreateFunc
+ // GetByEmailFunc is an instance of a mock function object controlling
+ // the behavior of the method GetByEmail.
+ GetByEmailFunc *UsersStoreGetByEmailFunc
+ // GetByIDFunc is an instance of a mock function object controlling the
+ // behavior of the method GetByID.
+ GetByIDFunc *UsersStoreGetByIDFunc
+ // GetByUsernameFunc is an instance of a mock function object
+ // controlling the behavior of the method GetByUsername.
+ GetByUsernameFunc *UsersStoreGetByUsernameFunc
+}
+
+// NewMockUsersStore creates a new mock of the UsersStore interface. All
+// methods return zero values for all results, unless overwritten.
+func NewMockUsersStore() *MockUsersStore {
+ return &MockUsersStore{
+ AuthenticateFunc: &UsersStoreAuthenticateFunc{
+ defaultHook: func(context.Context, string, string, int64) (r0 *User, r1 error) {
+ return
+ },
+ },
+ CreateFunc: &UsersStoreCreateFunc{
+ defaultHook: func(context.Context, string, string, CreateUserOpts) (r0 *User, r1 error) {
+ return
+ },
+ },
+ GetByEmailFunc: &UsersStoreGetByEmailFunc{
+ defaultHook: func(context.Context, string) (r0 *User, r1 error) {
+ return
+ },
+ },
+ GetByIDFunc: &UsersStoreGetByIDFunc{
+ defaultHook: func(context.Context, int64) (r0 *User, r1 error) {
+ return
+ },
+ },
+ GetByUsernameFunc: &UsersStoreGetByUsernameFunc{
+ defaultHook: func(context.Context, string) (r0 *User, r1 error) {
+ return
+ },
+ },
+ }
+}
+
+// NewStrictMockUsersStore creates a new mock of the UsersStore interface.
+// All methods panic on invocation, unless overwritten.
+func NewStrictMockUsersStore() *MockUsersStore {
+ return &MockUsersStore{
+ AuthenticateFunc: &UsersStoreAuthenticateFunc{
+ defaultHook: func(context.Context, string, string, int64) (*User, error) {
+ panic("unexpected invocation of MockUsersStore.Authenticate")
+ },
+ },
+ CreateFunc: &UsersStoreCreateFunc{
+ defaultHook: func(context.Context, string, string, CreateUserOpts) (*User, error) {
+ panic("unexpected invocation of MockUsersStore.Create")
+ },
+ },
+ GetByEmailFunc: &UsersStoreGetByEmailFunc{
+ defaultHook: func(context.Context, string) (*User, error) {
+ panic("unexpected invocation of MockUsersStore.GetByEmail")
+ },
+ },
+ GetByIDFunc: &UsersStoreGetByIDFunc{
+ defaultHook: func(context.Context, int64) (*User, error) {
+ panic("unexpected invocation of MockUsersStore.GetByID")
+ },
+ },
+ GetByUsernameFunc: &UsersStoreGetByUsernameFunc{
+ defaultHook: func(context.Context, string) (*User, error) {
+ panic("unexpected invocation of MockUsersStore.GetByUsername")
+ },
+ },
+ }
+}
+
+// NewMockUsersStoreFrom creates a new mock of the MockUsersStore interface.
+// All methods delegate to the given implementation, unless overwritten.
+func NewMockUsersStoreFrom(i UsersStore) *MockUsersStore {
+ return &MockUsersStore{
+ AuthenticateFunc: &UsersStoreAuthenticateFunc{
+ defaultHook: i.Authenticate,
+ },
+ CreateFunc: &UsersStoreCreateFunc{
+ defaultHook: i.Create,
+ },
+ GetByEmailFunc: &UsersStoreGetByEmailFunc{
+ defaultHook: i.GetByEmail,
+ },
+ GetByIDFunc: &UsersStoreGetByIDFunc{
+ defaultHook: i.GetByID,
+ },
+ GetByUsernameFunc: &UsersStoreGetByUsernameFunc{
+ defaultHook: i.GetByUsername,
+ },
+ }
+}
+
+// UsersStoreAuthenticateFunc describes the behavior when the Authenticate
+// method of the parent MockUsersStore instance is invoked.
+type UsersStoreAuthenticateFunc struct {
+ defaultHook func(context.Context, string, string, int64) (*User, error)
+ hooks []func(context.Context, string, string, int64) (*User, error)
+ history []UsersStoreAuthenticateFuncCall
+ mutex sync.Mutex
+}
+
+// Authenticate delegates to the next hook function in the queue and stores
+// the parameter and result values of this invocation.
+func (m *MockUsersStore) Authenticate(v0 context.Context, v1 string, v2 string, v3 int64) (*User, error) {
+ r0, r1 := m.AuthenticateFunc.nextHook()(v0, v1, v2, v3)
+ m.AuthenticateFunc.appendCall(UsersStoreAuthenticateFuncCall{v0, v1, v2, v3, r0, r1})
+ return r0, r1
+}
+
+// SetDefaultHook sets function that is called when the Authenticate method
+// of the parent MockUsersStore instance is invoked and the hook queue is
+// empty.
+func (f *UsersStoreAuthenticateFunc) SetDefaultHook(hook func(context.Context, string, string, int64) (*User, error)) {
+ f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// Authenticate method of the parent MockUsersStore instance invokes the
+// hook at the front of the queue and discards it. After the queue is empty,
+// the default hook function is invoked for any future action.
+func (f *UsersStoreAuthenticateFunc) PushHook(hook func(context.Context, string, string, int64) (*User, error)) {
+ f.mutex.Lock()
+ f.hooks = append(f.hooks, hook)
+ f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *UsersStoreAuthenticateFunc) SetDefaultReturn(r0 *User, r1 error) {
+ f.SetDefaultHook(func(context.Context, string, string, int64) (*User, error) {
+ return r0, r1
+ })
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *UsersStoreAuthenticateFunc) PushReturn(r0 *User, r1 error) {
+ f.PushHook(func(context.Context, string, string, int64) (*User, error) {
+ return r0, r1
+ })
+}
+
+func (f *UsersStoreAuthenticateFunc) nextHook() func(context.Context, string, string, int64) (*User, error) {
+ f.mutex.Lock()
+ defer f.mutex.Unlock()
+
+ if len(f.hooks) == 0 {
+ return f.defaultHook
+ }
+
+ hook := f.hooks[0]
+ f.hooks = f.hooks[1:]
+ return hook
+}
+
+func (f *UsersStoreAuthenticateFunc) appendCall(r0 UsersStoreAuthenticateFuncCall) {
+ f.mutex.Lock()
+ f.history = append(f.history, r0)
+ f.mutex.Unlock()
+}
+
+// History returns a sequence of UsersStoreAuthenticateFuncCall objects
+// describing the invocations of this function.
+func (f *UsersStoreAuthenticateFunc) History() []UsersStoreAuthenticateFuncCall {
+ f.mutex.Lock()
+ history := make([]UsersStoreAuthenticateFuncCall, len(f.history))
+ copy(history, f.history)
+ f.mutex.Unlock()
+
+ return history
+}
+
+// UsersStoreAuthenticateFuncCall is an object that describes an invocation
+// of method Authenticate on an instance of MockUsersStore.
+type UsersStoreAuthenticateFuncCall struct {
+ // Arg0 is the value of the 1st argument passed to this method
+ // invocation.
+ Arg0 context.Context
+ // Arg1 is the value of the 2nd argument passed to this method
+ // invocation.
+ Arg1 string
+ // Arg2 is the value of the 3rd argument passed to this method
+ // invocation.
+ Arg2 string
+ // Arg3 is the value of the 4th argument passed to this method
+ // invocation.
+ Arg3 int64
+ // Result0 is the value of the 1st result returned from this method
+ // invocation.
+ Result0 *User
+ // Result1 is the value of the 2nd result returned from this method
+ // invocation.
+ Result1 error
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c UsersStoreAuthenticateFuncCall) Args() []interface{} {
+ return []interface{}{c.Arg0, c.Arg1, c.Arg2, c.Arg3}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c UsersStoreAuthenticateFuncCall) Results() []interface{} {
+ return []interface{}{c.Result0, c.Result1}
+}
+
+// UsersStoreCreateFunc describes the behavior when the Create method of the
+// parent MockUsersStore instance is invoked.
+type UsersStoreCreateFunc struct {
+ defaultHook func(context.Context, string, string, CreateUserOpts) (*User, error)
+ hooks []func(context.Context, string, string, CreateUserOpts) (*User, error)
+ history []UsersStoreCreateFuncCall
+ mutex sync.Mutex
+}
+
+// Create delegates to the next hook function in the queue and stores the
+// parameter and result values of this invocation.
+func (m *MockUsersStore) Create(v0 context.Context, v1 string, v2 string, v3 CreateUserOpts) (*User, error) {
+ r0, r1 := m.CreateFunc.nextHook()(v0, v1, v2, v3)
+ m.CreateFunc.appendCall(UsersStoreCreateFuncCall{v0, v1, v2, v3, r0, r1})
+ return r0, r1
+}
+
+// SetDefaultHook sets function that is called when the Create method of the
+// parent MockUsersStore instance is invoked and the hook queue is empty.
+func (f *UsersStoreCreateFunc) SetDefaultHook(hook func(context.Context, string, string, CreateUserOpts) (*User, error)) {
+ f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// Create method of the parent MockUsersStore instance invokes the hook at
+// the front of the queue and discards it. After the queue is empty, the
+// default hook function is invoked for any future action.
+func (f *UsersStoreCreateFunc) PushHook(hook func(context.Context, string, string, CreateUserOpts) (*User, error)) {
+ f.mutex.Lock()
+ f.hooks = append(f.hooks, hook)
+ f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *UsersStoreCreateFunc) SetDefaultReturn(r0 *User, r1 error) {
+ f.SetDefaultHook(func(context.Context, string, string, CreateUserOpts) (*User, error) {
+ return r0, r1
+ })
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *UsersStoreCreateFunc) PushReturn(r0 *User, r1 error) {
+ f.PushHook(func(context.Context, string, string, CreateUserOpts) (*User, error) {
+ return r0, r1
+ })
+}
+
+func (f *UsersStoreCreateFunc) nextHook() func(context.Context, string, string, CreateUserOpts) (*User, error) {
+ f.mutex.Lock()
+ defer f.mutex.Unlock()
+
+ if len(f.hooks) == 0 {
+ return f.defaultHook
+ }
+
+ hook := f.hooks[0]
+ f.hooks = f.hooks[1:]
+ return hook
+}
+
+func (f *UsersStoreCreateFunc) appendCall(r0 UsersStoreCreateFuncCall) {
+ f.mutex.Lock()
+ f.history = append(f.history, r0)
+ f.mutex.Unlock()
+}
+
+// History returns a sequence of UsersStoreCreateFuncCall objects describing
+// the invocations of this function.
+func (f *UsersStoreCreateFunc) History() []UsersStoreCreateFuncCall {
+ f.mutex.Lock()
+ history := make([]UsersStoreCreateFuncCall, len(f.history))
+ copy(history, f.history)
+ f.mutex.Unlock()
+
+ return history
+}
+
+// UsersStoreCreateFuncCall is an object that describes an invocation of
+// method Create on an instance of MockUsersStore.
+type UsersStoreCreateFuncCall struct {
+ // Arg0 is the value of the 1st argument passed to this method
+ // invocation.
+ Arg0 context.Context
+ // Arg1 is the value of the 2nd argument passed to this method
+ // invocation.
+ Arg1 string
+ // Arg2 is the value of the 3rd argument passed to this method
+ // invocation.
+ Arg2 string
+ // Arg3 is the value of the 4th argument passed to this method
+ // invocation.
+ Arg3 CreateUserOpts
+ // Result0 is the value of the 1st result returned from this method
+ // invocation.
+ Result0 *User
+ // Result1 is the value of the 2nd result returned from this method
+ // invocation.
+ Result1 error
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c UsersStoreCreateFuncCall) Args() []interface{} {
+ return []interface{}{c.Arg0, c.Arg1, c.Arg2, c.Arg3}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c UsersStoreCreateFuncCall) Results() []interface{} {
+ return []interface{}{c.Result0, c.Result1}
+}
+
+// UsersStoreGetByEmailFunc describes the behavior when the GetByEmail
+// method of the parent MockUsersStore instance is invoked.
+type UsersStoreGetByEmailFunc struct {
+ defaultHook func(context.Context, string) (*User, error)
+ hooks []func(context.Context, string) (*User, error)
+ history []UsersStoreGetByEmailFuncCall
+ mutex sync.Mutex
+}
+
+// GetByEmail delegates to the next hook function in the queue and stores
+// the parameter and result values of this invocation.
+func (m *MockUsersStore) GetByEmail(v0 context.Context, v1 string) (*User, error) {
+ r0, r1 := m.GetByEmailFunc.nextHook()(v0, v1)
+ m.GetByEmailFunc.appendCall(UsersStoreGetByEmailFuncCall{v0, v1, r0, r1})
+ return r0, r1
+}
+
+// SetDefaultHook sets function that is called when the GetByEmail method of
+// the parent MockUsersStore instance is invoked and the hook queue is
+// empty.
+func (f *UsersStoreGetByEmailFunc) SetDefaultHook(hook func(context.Context, string) (*User, error)) {
+ f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// GetByEmail method of the parent MockUsersStore instance invokes the hook
+// at the front of the queue and discards it. After the queue is empty, the
+// default hook function is invoked for any future action.
+func (f *UsersStoreGetByEmailFunc) PushHook(hook func(context.Context, string) (*User, error)) {
+ f.mutex.Lock()
+ f.hooks = append(f.hooks, hook)
+ f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *UsersStoreGetByEmailFunc) SetDefaultReturn(r0 *User, r1 error) {
+ f.SetDefaultHook(func(context.Context, string) (*User, error) {
+ return r0, r1
+ })
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *UsersStoreGetByEmailFunc) PushReturn(r0 *User, r1 error) {
+ f.PushHook(func(context.Context, string) (*User, error) {
+ return r0, r1
+ })
+}
+
+func (f *UsersStoreGetByEmailFunc) nextHook() func(context.Context, string) (*User, error) {
+ f.mutex.Lock()
+ defer f.mutex.Unlock()
+
+ if len(f.hooks) == 0 {
+ return f.defaultHook
+ }
+
+ hook := f.hooks[0]
+ f.hooks = f.hooks[1:]
+ return hook
+}
+
+func (f *UsersStoreGetByEmailFunc) appendCall(r0 UsersStoreGetByEmailFuncCall) {
+ f.mutex.Lock()
+ f.history = append(f.history, r0)
+ f.mutex.Unlock()
+}
+
+// History returns a sequence of UsersStoreGetByEmailFuncCall objects
+// describing the invocations of this function.
+func (f *UsersStoreGetByEmailFunc) History() []UsersStoreGetByEmailFuncCall {
+ f.mutex.Lock()
+ history := make([]UsersStoreGetByEmailFuncCall, len(f.history))
+ copy(history, f.history)
+ f.mutex.Unlock()
+
+ return history
+}
+
+// UsersStoreGetByEmailFuncCall is an object that describes an invocation of
+// method GetByEmail on an instance of MockUsersStore.
+type UsersStoreGetByEmailFuncCall struct {
+ // Arg0 is the value of the 1st argument passed to this method
+ // invocation.
+ Arg0 context.Context
+ // Arg1 is the value of the 2nd argument passed to this method
+ // invocation.
+ Arg1 string
+ // Result0 is the value of the 1st result returned from this method
+ // invocation.
+ Result0 *User
+ // Result1 is the value of the 2nd result returned from this method
+ // invocation.
+ Result1 error
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c UsersStoreGetByEmailFuncCall) Args() []interface{} {
+ return []interface{}{c.Arg0, c.Arg1}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c UsersStoreGetByEmailFuncCall) Results() []interface{} {
+ return []interface{}{c.Result0, c.Result1}
+}
+
+// UsersStoreGetByIDFunc describes the behavior when the GetByID method of
+// the parent MockUsersStore instance is invoked.
+type UsersStoreGetByIDFunc struct {
+ defaultHook func(context.Context, int64) (*User, error)
+ hooks []func(context.Context, int64) (*User, error)
+ history []UsersStoreGetByIDFuncCall
+ mutex sync.Mutex
+}
+
+// GetByID delegates to the next hook function in the queue and stores the
+// parameter and result values of this invocation.
+func (m *MockUsersStore) GetByID(v0 context.Context, v1 int64) (*User, error) {
+ r0, r1 := m.GetByIDFunc.nextHook()(v0, v1)
+ m.GetByIDFunc.appendCall(UsersStoreGetByIDFuncCall{v0, v1, r0, r1})
+ return r0, r1
+}
+
+// SetDefaultHook sets function that is called when the GetByID method of
+// the parent MockUsersStore instance is invoked and the hook queue is
+// empty.
+func (f *UsersStoreGetByIDFunc) SetDefaultHook(hook func(context.Context, int64) (*User, error)) {
+ f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// GetByID method of the parent MockUsersStore instance invokes the hook at
+// the front of the queue and discards it. After the queue is empty, the
+// default hook function is invoked for any future action.
+func (f *UsersStoreGetByIDFunc) PushHook(hook func(context.Context, int64) (*User, error)) {
+ f.mutex.Lock()
+ f.hooks = append(f.hooks, hook)
+ f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *UsersStoreGetByIDFunc) SetDefaultReturn(r0 *User, r1 error) {
+ f.SetDefaultHook(func(context.Context, int64) (*User, error) {
+ return r0, r1
+ })
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *UsersStoreGetByIDFunc) PushReturn(r0 *User, r1 error) {
+ f.PushHook(func(context.Context, int64) (*User, error) {
+ return r0, r1
+ })
+}
+
+func (f *UsersStoreGetByIDFunc) nextHook() func(context.Context, int64) (*User, error) {
+ f.mutex.Lock()
+ defer f.mutex.Unlock()
+
+ if len(f.hooks) == 0 {
+ return f.defaultHook
+ }
+
+ hook := f.hooks[0]
+ f.hooks = f.hooks[1:]
+ return hook
+}
+
+func (f *UsersStoreGetByIDFunc) appendCall(r0 UsersStoreGetByIDFuncCall) {
+ f.mutex.Lock()
+ f.history = append(f.history, r0)
+ f.mutex.Unlock()
+}
+
+// History returns a sequence of UsersStoreGetByIDFuncCall objects
+// describing the invocations of this function.
+func (f *UsersStoreGetByIDFunc) History() []UsersStoreGetByIDFuncCall {
+ f.mutex.Lock()
+ history := make([]UsersStoreGetByIDFuncCall, len(f.history))
+ copy(history, f.history)
+ f.mutex.Unlock()
+
+ return history
+}
+
+// UsersStoreGetByIDFuncCall is an object that describes an invocation of
+// method GetByID on an instance of MockUsersStore.
+type UsersStoreGetByIDFuncCall struct {
+ // Arg0 is the value of the 1st argument passed to this method
+ // invocation.
+ Arg0 context.Context
+ // Arg1 is the value of the 2nd argument passed to this method
+ // invocation.
+ Arg1 int64
+ // Result0 is the value of the 1st result returned from this method
+ // invocation.
+ Result0 *User
+ // Result1 is the value of the 2nd result returned from this method
+ // invocation.
+ Result1 error
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c UsersStoreGetByIDFuncCall) Args() []interface{} {
+ return []interface{}{c.Arg0, c.Arg1}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c UsersStoreGetByIDFuncCall) Results() []interface{} {
+ return []interface{}{c.Result0, c.Result1}
+}
+
+// UsersStoreGetByUsernameFunc describes the behavior when the GetByUsername
+// method of the parent MockUsersStore instance is invoked.
+type UsersStoreGetByUsernameFunc struct {
+ defaultHook func(context.Context, string) (*User, error)
+ hooks []func(context.Context, string) (*User, error)
+ history []UsersStoreGetByUsernameFuncCall
+ mutex sync.Mutex
+}
+
+// GetByUsername delegates to the next hook function in the queue and stores
+// the parameter and result values of this invocation.
+func (m *MockUsersStore) GetByUsername(v0 context.Context, v1 string) (*User, error) {
+ r0, r1 := m.GetByUsernameFunc.nextHook()(v0, v1)
+ m.GetByUsernameFunc.appendCall(UsersStoreGetByUsernameFuncCall{v0, v1, r0, r1})
+ return r0, r1
+}
+
+// SetDefaultHook sets function that is called when the GetByUsername method
+// of the parent MockUsersStore instance is invoked and the hook queue is
+// empty.
+func (f *UsersStoreGetByUsernameFunc) SetDefaultHook(hook func(context.Context, string) (*User, error)) {
+ f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// GetByUsername method of the parent MockUsersStore instance invokes the
+// hook at the front of the queue and discards it. After the queue is empty,
+// the default hook function is invoked for any future action.
+func (f *UsersStoreGetByUsernameFunc) PushHook(hook func(context.Context, string) (*User, error)) {
+ f.mutex.Lock()
+ f.hooks = append(f.hooks, hook)
+ f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *UsersStoreGetByUsernameFunc) SetDefaultReturn(r0 *User, r1 error) {
+ f.SetDefaultHook(func(context.Context, string) (*User, error) {
+ return r0, r1
+ })
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *UsersStoreGetByUsernameFunc) PushReturn(r0 *User, r1 error) {
+ f.PushHook(func(context.Context, string) (*User, error) {
+ return r0, r1
+ })
+}
+
+func (f *UsersStoreGetByUsernameFunc) nextHook() func(context.Context, string) (*User, error) {
+ f.mutex.Lock()
+ defer f.mutex.Unlock()
+
+ if len(f.hooks) == 0 {
+ return f.defaultHook
+ }
+
+ hook := f.hooks[0]
+ f.hooks = f.hooks[1:]
+ return hook
+}
+
+func (f *UsersStoreGetByUsernameFunc) appendCall(r0 UsersStoreGetByUsernameFuncCall) {
+ f.mutex.Lock()
+ f.history = append(f.history, r0)
+ f.mutex.Unlock()
+}
+
+// History returns a sequence of UsersStoreGetByUsernameFuncCall objects
+// describing the invocations of this function.
+func (f *UsersStoreGetByUsernameFunc) History() []UsersStoreGetByUsernameFuncCall {
+ f.mutex.Lock()
+ history := make([]UsersStoreGetByUsernameFuncCall, len(f.history))
+ copy(history, f.history)
+ f.mutex.Unlock()
+
+ return history
+}
+
+// UsersStoreGetByUsernameFuncCall is an object that describes an invocation
+// of method GetByUsername on an instance of MockUsersStore.
+type UsersStoreGetByUsernameFuncCall struct {
+ // Arg0 is the value of the 1st argument passed to this method
+ // invocation.
+ Arg0 context.Context
+ // Arg1 is the value of the 2nd argument passed to this method
+ // invocation.
+ Arg1 string
+ // Result0 is the value of the 1st result returned from this method
+ // invocation.
+ Result0 *User
+ // Result1 is the value of the 2nd result returned from this method
+ // invocation.
+ Result1 error
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c UsersStoreGetByUsernameFuncCall) Args() []interface{} {
+ return []interface{}{c.Arg0, c.Arg1}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c UsersStoreGetByUsernameFuncCall) Results() []interface{} {
+ return []interface{}{c.Result0, c.Result1}
+}
+
// MockLoginSourceFileStore is a mock implementation of the
// loginSourceFileStore interface (from the package
// gogs.io/gogs/internal/db) used for unit testing.
diff --git a/internal/db/user.go b/internal/db/user.go
index c6bee120..ebfbf082 100644
--- a/internal/db/user.go
+++ b/internal/db/user.go
@@ -942,7 +942,8 @@ func GetUserByID(id int64) (*User, error) {
// GetAssigneeByID returns the user with read access of repository by given ID.
func GetAssigneeByID(repo *Repository, userID int64) (*User, error) {
- if !Perms.Authorize(context.TODO(), userID, repo.ID, AccessModeRead,
+ ctx := context.TODO()
+ if !Perms.Authorize(ctx, userID, repo.ID, AccessModeRead,
AccessModeOptions{
OwnerID: repo.OwnerID,
Private: repo.IsPrivate,
@@ -950,7 +951,7 @@ func GetAssigneeByID(repo *Repository, userID int64) (*User, error) {
) {
return nil, ErrUserNotExist{args: map[string]interface{}{"userID": userID}}
}
- return Users.GetByID(userID)
+ return Users.GetByID(ctx, userID)
}
// GetUserByName returns a user by given name.
diff --git a/internal/db/users.go b/internal/db/users.go
index 5c8f2d38..cac22c44 100644
--- a/internal/db/users.go
+++ b/internal/db/users.go
@@ -23,8 +23,8 @@ import (
//
// NOTE: All methods are sorted in alphabetical order.
type UsersStore interface {
- // Authenticate validates username and password via given login source ID.
- // It returns ErrUserNotExist when the user was not found.
+ // Authenticate validates username and password via given login source ID. It
+ // returns ErrUserNotExist when the user was not found.
//
// When the "loginSourceID" is negative, it aborts the process and returns
// ErrUserNotExist if the user was not found in the database.
@@ -34,23 +34,25 @@ type UsersStore interface {
//
// When the "loginSourceID" is positive, it tries to authenticate via given
// login source and creates a new user when not yet exists in the database.
- Authenticate(username, password string, loginSourceID int64) (*User, error)
- // Create creates a new user and persists to database.
- // It returns ErrUserAlreadyExist when a user with same name already exists,
- // or ErrEmailAlreadyUsed if the email has been used by another user.
- Create(username, email string, opts CreateUserOpts) (*User, error)
- // GetByEmail returns the user (not organization) with given email.
- // It ignores records with unverified emails and returns ErrUserNotExist when not found.
- GetByEmail(email string) (*User, error)
- // GetByID returns the user with given ID. It returns ErrUserNotExist when not found.
- GetByID(id int64) (*User, error)
- // GetByUsername returns the user with given username. It returns ErrUserNotExist when not found.
- GetByUsername(username string) (*User, error)
+ Authenticate(ctx context.Context, username, password string, loginSourceID int64) (*User, error)
+ // Create creates a new user and persists to database. It returns
+ // ErrUserAlreadyExist when a user with same name already exists, or
+ // ErrEmailAlreadyUsed if the email has been used by another user.
+ Create(ctx context.Context, username, email string, opts CreateUserOpts) (*User, error)
+ // GetByEmail returns the user (not organization) with given email. It ignores
+ // records with unverified emails and returns ErrUserNotExist when not found.
+ GetByEmail(ctx context.Context, email string) (*User, error)
+ // GetByID returns the user with given ID. It returns ErrUserNotExist when not
+ // found.
+ GetByID(ctx context.Context, id int64) (*User, error)
+ // GetByUsername returns the user with given username. It returns
+ // ErrUserNotExist when not found.
+ GetByUsername(ctx context.Context, username string) (*User, error)
}
var Users UsersStore
-// NOTE: This is a GORM create hook.
+// BeforeCreate implements the GORM create hook.
func (u *User) BeforeCreate(tx *gorm.DB) error {
if u.CreatedUnix == 0 {
u.CreatedUnix = tx.NowFunc().Unix()
@@ -59,7 +61,7 @@ func (u *User) BeforeCreate(tx *gorm.DB) error {
return nil
}
-// NOTE: This is a GORM query hook.
+// AfterFind implements the GORM query hook.
func (u *User) AfterFind(_ *gorm.DB) error {
u.Created = time.Unix(u.CreatedUnix, 0).Local()
u.Updated = time.Unix(u.UpdatedUnix, 0).Local()
@@ -80,16 +82,14 @@ func (err ErrLoginSourceMismatch) Error() string {
return fmt.Sprintf("login source mismatch: %v", err.args)
}
-func (db *users) Authenticate(login, password string, loginSourceID int64) (*User, error) {
- ctx := context.TODO()
-
+func (db *users) Authenticate(ctx context.Context, login, password string, loginSourceID int64) (*User, error) {
login = strings.ToLower(login)
- var query *gorm.DB
+ query := db.WithContext(ctx)
if strings.Contains(login, "@") {
- query = db.Where("email = ?", login)
+ query = query.Where("email = ?", login)
} else {
- query = db.Where("lower_name = ?", login)
+ query = query.Where("lower_name = ?", login)
}
user := new(User)
@@ -153,15 +153,17 @@ func (db *users) Authenticate(login, password string, loginSourceID int64) (*Use
return nil, fmt.Errorf("invalid pattern for attribute 'username' [%s]: must be valid alpha or numeric or dash(-_) or dot characters", extAccount.Name)
}
- return Users.Create(extAccount.Name, extAccount.Email, CreateUserOpts{
- FullName: extAccount.FullName,
- LoginSource: authSourceID,
- LoginName: extAccount.Login,
- Location: extAccount.Location,
- Website: extAccount.Website,
- Activated: true,
- Admin: extAccount.Admin,
- })
+ return db.Create(ctx, extAccount.Name, extAccount.Email,
+ CreateUserOpts{
+ FullName: extAccount.FullName,
+ LoginSource: authSourceID,
+ LoginName: extAccount.Login,
+ Location: extAccount.Location,
+ Website: extAccount.Website,
+ Activated: true,
+ Admin: extAccount.Admin,
+ },
+ )
}
type CreateUserOpts struct {
@@ -209,20 +211,20 @@ func (err ErrEmailAlreadyUsed) Error() string {
return fmt.Sprintf("email has been used: %v", err.args)
}
-func (db *users) Create(username, email string, opts CreateUserOpts) (*User, error) {
+func (db *users) Create(ctx context.Context, username, email string, opts CreateUserOpts) (*User, error) {
err := isUsernameAllowed(username)
if err != nil {
return nil, err
}
- _, err = db.GetByUsername(username)
+ _, err = db.GetByUsername(ctx, username)
if err == nil {
return nil, ErrUserAlreadyExist{args: errutil.Args{"name": username}}
} else if !IsErrUserNotExist(err) {
return nil, err
}
- _, err = db.GetByEmail(email)
+ _, err = db.GetByEmail(ctx, email)
if err == nil {
return nil, ErrEmailAlreadyUsed{args: errutil.Args{"email": email}}
} else if !IsErrUserNotExist(err) {
@@ -256,7 +258,7 @@ func (db *users) Create(username, email string, opts CreateUserOpts) (*User, err
}
user.EncodePassword()
- return user, db.DB.Create(user).Error
+ return user, db.WithContext(ctx).Create(user).Error
}
var _ errutil.NotFound = (*ErrUserNotExist)(nil)
@@ -278,7 +280,7 @@ func (ErrUserNotExist) NotFound() bool {
return true
}
-func (db *users) GetByEmail(email string) (*User, error) {
+func (db *users) GetByEmail(ctx context.Context, email string) (*User, error) {
email = strings.ToLower(email)
if email == "" {
@@ -287,7 +289,10 @@ func (db *users) GetByEmail(email string) (*User, error) {
// First try to find the user by primary email
user := new(User)
- err := db.Where("email = ? AND type = ? AND is_active = ?", email, UserIndividual, true).First(user).Error
+ err := db.WithContext(ctx).
+ Where("email = ? AND type = ? AND is_active = ?", email, UserIndividual, true).
+ First(user).
+ Error
if err == nil {
return user, nil
} else if err != gorm.ErrRecordNotFound {
@@ -296,7 +301,10 @@ func (db *users) GetByEmail(email string) (*User, error) {
// Otherwise, check activated email addresses
emailAddress := new(EmailAddress)
- err = db.Where("email = ? AND is_activated = ?", email, true).First(emailAddress).Error
+ err = db.WithContext(ctx).
+ Where("email = ? AND is_activated = ?", email, true).
+ First(emailAddress).
+ Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, ErrUserNotExist{args: errutil.Args{"email": email}}
@@ -304,12 +312,12 @@ func (db *users) GetByEmail(email string) (*User, error) {
return nil, err
}
- return db.GetByID(emailAddress.UID)
+ return db.GetByID(ctx, emailAddress.UID)
}
-func (db *users) GetByID(id int64) (*User, error) {
+func (db *users) GetByID(ctx context.Context, id int64) (*User, error) {
user := new(User)
- err := db.Where("id = ?", id).First(user).Error
+ err := db.WithContext(ctx).Where("id = ?", id).First(user).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, ErrUserNotExist{args: errutil.Args{"userID": id}}
@@ -319,9 +327,9 @@ func (db *users) GetByID(id int64) (*User, error) {
return user, nil
}
-func (db *users) GetByUsername(username string) (*User, error) {
+func (db *users) GetByUsername(ctx context.Context, username string) (*User, error) {
user := new(User)
- err := db.Where("lower_name = ?", strings.ToLower(username)).First(user).Error
+ err := db.WithContext(ctx).Where("lower_name = ?", strings.ToLower(username)).First(user).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, ErrUserNotExist{args: errutil.Args{"name": username}}
diff --git a/internal/db/users_test.go b/internal/db/users_test.go
index dac2c208..d691110a 100644
--- a/internal/db/users_test.go
+++ b/internal/db/users_test.go
@@ -5,16 +5,18 @@
package db
import (
+ "context"
"testing"
"time"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
"gogs.io/gogs/internal/auth"
"gogs.io/gogs/internal/errutil"
)
-func Test_users(t *testing.T) {
+func TestUsers(t *testing.T) {
if testing.Short() {
t.Skip()
}
@@ -30,18 +32,16 @@ func Test_users(t *testing.T) {
name string
test func(*testing.T, *users)
}{
- {"Authenticate", test_users_Authenticate},
- {"Create", test_users_Create},
- {"GetByEmail", test_users_GetByEmail},
- {"GetByID", test_users_GetByID},
- {"GetByUsername", test_users_GetByUsername},
+ {"Authenticate", usersAuthenticate},
+ {"Create", usersCreate},
+ {"GetByEmail", usersGetByEmail},
+ {"GetByID", usersGetByID},
+ {"GetByUsername", usersGetByUsername},
} {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() {
err := clearTables(t, db.DB, tables...)
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
})
tc.test(t, db)
})
@@ -53,187 +53,165 @@ func Test_users(t *testing.T) {
// TODO: Only local account is tested, tests for external account will be added
// along with addressing https://github.com/gogs/gogs/issues/6115.
-func test_users_Authenticate(t *testing.T, db *users) {
+func usersAuthenticate(t *testing.T, db *users) {
+ ctx := context.Background()
+
password := "pa$$word"
- alice, err := db.Create("alice", "alice@example.com", CreateUserOpts{
- Password: password,
- })
- if err != nil {
- t.Fatal(err)
- }
+ alice, err := db.Create(ctx, "alice", "alice@example.com",
+ CreateUserOpts{
+ Password: password,
+ },
+ )
+ require.NoError(t, err)
t.Run("user not found", func(t *testing.T) {
- _, err := db.Authenticate("bob", password, -1)
- expErr := auth.ErrBadCredentials{Args: map[string]interface{}{"login": "bob"}}
- assert.Equal(t, expErr, err)
+ _, err := db.Authenticate(ctx, "bob", password, -1)
+ wantErr := auth.ErrBadCredentials{Args: map[string]interface{}{"login": "bob"}}
+ assert.Equal(t, wantErr, err)
})
t.Run("invalid password", func(t *testing.T) {
- _, err := db.Authenticate(alice.Name, "bad_password", -1)
- expErr := auth.ErrBadCredentials{Args: map[string]interface{}{"login": alice.Name, "userID": alice.ID}}
- assert.Equal(t, expErr, err)
+ _, err := db.Authenticate(ctx, alice.Name, "bad_password", -1)
+ wantErr := auth.ErrBadCredentials{Args: map[string]interface{}{"login": alice.Name, "userID": alice.ID}}
+ assert.Equal(t, wantErr, err)
})
t.Run("via email and password", func(t *testing.T) {
- user, err := db.Authenticate(alice.Email, password, -1)
- if err != nil {
- t.Fatal(err)
- }
+ user, err := db.Authenticate(ctx, alice.Email, password, -1)
+ require.NoError(t, err)
assert.Equal(t, alice.Name, user.Name)
})
t.Run("via username and password", func(t *testing.T) {
- user, err := db.Authenticate(alice.Name, password, -1)
- if err != nil {
- t.Fatal(err)
- }
+ user, err := db.Authenticate(ctx, alice.Name, password, -1)
+ require.NoError(t, err)
assert.Equal(t, alice.Name, user.Name)
})
}
-func test_users_Create(t *testing.T, db *users) {
- alice, err := db.Create("alice", "alice@example.com", CreateUserOpts{
- Activated: true,
- })
- if err != nil {
- t.Fatal(err)
- }
+func usersCreate(t *testing.T, db *users) {
+ ctx := context.Background()
+
+ alice, err := db.Create(ctx, "alice", "alice@example.com",
+ CreateUserOpts{
+ Activated: true,
+ },
+ )
+ require.NoError(t, err)
t.Run("name not allowed", func(t *testing.T) {
- _, err := db.Create("-", "", CreateUserOpts{})
- expErr := ErrNameNotAllowed{args: errutil.Args{"reason": "reserved", "name": "-"}}
- assert.Equal(t, expErr, err)
+ _, err := db.Create(ctx, "-", "", CreateUserOpts{})
+ wantErr := ErrNameNotAllowed{args: errutil.Args{"reason": "reserved", "name": "-"}}
+ assert.Equal(t, wantErr, err)
})
t.Run("name already exists", func(t *testing.T) {
- _, err := db.Create(alice.Name, "", CreateUserOpts{})
- expErr := ErrUserAlreadyExist{args: errutil.Args{"name": alice.Name}}
- assert.Equal(t, expErr, err)
+ _, err := db.Create(ctx, alice.Name, "", CreateUserOpts{})
+ wantErr := ErrUserAlreadyExist{args: errutil.Args{"name": alice.Name}}
+ assert.Equal(t, wantErr, err)
})
t.Run("email already exists", func(t *testing.T) {
- _, err := db.Create("bob", alice.Email, CreateUserOpts{})
- expErr := ErrEmailAlreadyUsed{args: errutil.Args{"email": alice.Email}}
- assert.Equal(t, expErr, err)
+ _, err := db.Create(ctx, "bob", alice.Email, CreateUserOpts{})
+ wantErr := ErrEmailAlreadyUsed{args: errutil.Args{"email": alice.Email}}
+ assert.Equal(t, wantErr, err)
})
- user, err := db.GetByUsername(alice.Name)
- if err != nil {
- t.Fatal(err)
- }
+ user, err := db.GetByUsername(ctx, alice.Name)
+ require.NoError(t, err)
assert.Equal(t, db.NowFunc().Format(time.RFC3339), user.Created.UTC().Format(time.RFC3339))
assert.Equal(t, db.NowFunc().Format(time.RFC3339), user.Updated.UTC().Format(time.RFC3339))
}
-func test_users_GetByEmail(t *testing.T, db *users) {
+func usersGetByEmail(t *testing.T, db *users) {
+ ctx := context.Background()
+
t.Run("empty email", func(t *testing.T) {
- _, err := db.GetByEmail("")
- expErr := ErrUserNotExist{args: errutil.Args{"email": ""}}
- assert.Equal(t, expErr, err)
+ _, err := db.GetByEmail(ctx, "")
+ wantErr := ErrUserNotExist{args: errutil.Args{"email": ""}}
+ assert.Equal(t, wantErr, err)
})
t.Run("ignore organization", func(t *testing.T) {
// TODO: Use Orgs.Create to replace SQL hack when the method is available.
- org, err := db.Create("gogs", "gogs@exmaple.com", CreateUserOpts{})
- if err != nil {
- t.Fatal(err)
- }
+ org, err := db.Create(ctx, "gogs", "gogs@exmaple.com", CreateUserOpts{})
+ require.NoError(t, err)
err = db.Model(&User{}).Where("id", org.ID).UpdateColumn("type", UserOrganization).Error
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
- _, err = db.GetByEmail(org.Email)
- expErr := ErrUserNotExist{args: errutil.Args{"email": org.Email}}
- assert.Equal(t, expErr, err)
+ _, err = db.GetByEmail(ctx, org.Email)
+ wantErr := ErrUserNotExist{args: errutil.Args{"email": org.Email}}
+ assert.Equal(t, wantErr, err)
})
t.Run("by primary email", func(t *testing.T) {
- alice, err := db.Create("alice", "alice@exmaple.com", CreateUserOpts{})
- if err != nil {
- t.Fatal(err)
- }
+ alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOpts{})
+ require.NoError(t, err)
- _, err = db.GetByEmail(alice.Email)
- expErr := ErrUserNotExist{args: errutil.Args{"email": alice.Email}}
- assert.Equal(t, expErr, err)
+ _, err = db.GetByEmail(ctx, alice.Email)
+ wantErr := ErrUserNotExist{args: errutil.Args{"email": alice.Email}}
+ assert.Equal(t, wantErr, err)
// Mark user as activated
// TODO: Use UserEmails.Verify to replace SQL hack when the method is available.
err = db.Model(&User{}).Where("id", alice.ID).UpdateColumn("is_active", true).Error
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
- user, err := db.GetByEmail(alice.Email)
- if err != nil {
- t.Fatal(err)
- }
+ user, err := db.GetByEmail(ctx, alice.Email)
+ require.NoError(t, err)
assert.Equal(t, alice.Name, user.Name)
})
t.Run("by secondary email", func(t *testing.T) {
- bob, err := db.Create("bob", "bob@example.com", CreateUserOpts{})
- if err != nil {
- t.Fatal(err)
- }
+ bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOpts{})
+ require.NoError(t, err)
// TODO: Use UserEmails.Create to replace SQL hack when the method is available.
email2 := "bob2@exmaple.com"
err = db.Exec(`INSERT INTO email_address (uid, email) VALUES (?, ?)`, bob.ID, email2).Error
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
- _, err = db.GetByEmail(email2)
- expErr := ErrUserNotExist{args: errutil.Args{"email": email2}}
- assert.Equal(t, expErr, err)
+ _, err = db.GetByEmail(ctx, email2)
+ wantErr := ErrUserNotExist{args: errutil.Args{"email": email2}}
+ assert.Equal(t, wantErr, err)
// TODO: Use UserEmails.Verify to replace SQL hack when the method is available.
err = db.Exec(`UPDATE email_address SET is_activated = ? WHERE email = ?`, true, email2).Error
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
- user, err := db.GetByEmail(email2)
- if err != nil {
- t.Fatal(err)
- }
+ user, err := db.GetByEmail(ctx, email2)
+ require.NoError(t, err)
assert.Equal(t, bob.Name, user.Name)
})
}
-func test_users_GetByID(t *testing.T, db *users) {
- alice, err := db.Create("alice", "alice@exmaple.com", CreateUserOpts{})
- if err != nil {
- t.Fatal(err)
- }
+func usersGetByID(t *testing.T, db *users) {
+ ctx := context.Background()
- user, err := db.GetByID(alice.ID)
- if err != nil {
- t.Fatal(err)
- }
+ alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOpts{})
+ require.NoError(t, err)
+
+ user, err := db.GetByID(ctx, alice.ID)
+ require.NoError(t, err)
assert.Equal(t, alice.Name, user.Name)
- _, err = db.GetByID(404)
- expErr := ErrUserNotExist{args: errutil.Args{"userID": int64(404)}}
- assert.Equal(t, expErr, err)
+ _, err = db.GetByID(ctx, 404)
+ wantErr := ErrUserNotExist{args: errutil.Args{"userID": int64(404)}}
+ assert.Equal(t, wantErr, err)
}
-func test_users_GetByUsername(t *testing.T, db *users) {
- alice, err := db.Create("alice", "alice@exmaple.com", CreateUserOpts{})
- if err != nil {
- t.Fatal(err)
- }
+func usersGetByUsername(t *testing.T, db *users) {
+ ctx := context.Background()
- user, err := db.GetByUsername(alice.Name)
- if err != nil {
- t.Fatal(err)
- }
+ alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOpts{})
+ require.NoError(t, err)
+
+ user, err := db.GetByUsername(ctx, alice.Name)
+ require.NoError(t, err)
assert.Equal(t, alice.Name, user.Name)
- _, err = db.GetByUsername("bad_username")
- expErr := ErrUserNotExist{args: errutil.Args{"name": "bad_username"}}
- assert.Equal(t, expErr, err)
+ _, err = db.GetByUsername(ctx, "bad_username")
+ wantErr := ErrUserNotExist{args: errutil.Args{"name": "bad_username"}}
+ assert.Equal(t, wantErr, err)
}
diff --git a/internal/route/lfs/route.go b/internal/route/lfs/route.go
index f5195837..94c42fea 100644
--- a/internal/route/lfs/route.go
+++ b/internal/route/lfs/route.go
@@ -58,7 +58,7 @@ func authenticate() macaron.Handler {
return
}
- user, err := db.Users.Authenticate(username, password, -1)
+ user, err := db.Users.Authenticate(c.Req.Context(), username, password, -1)
if err != nil && !auth.IsErrBadCredentials(err) {
internalServerError(c.Resp)
log.Error("Failed to authenticate user [name: %s]: %v", username, err)
@@ -86,7 +86,7 @@ func authenticate() macaron.Handler {
log.Error("Failed to touch access token: %v", err)
}
- user, err = db.Users.GetByID(token.UserID)
+ user, err = db.Users.GetByID(c.Req.Context(), token.UserID)
if err != nil {
// Once we found the token, we're supposed to find its related user,
// thus any error is unexpected.
@@ -108,7 +108,7 @@ func authorize(mode db.AccessMode) macaron.Handler {
username := c.Params(":username")
reponame := strings.TrimSuffix(c.Params(":reponame"), ".git")
- owner, err := db.Users.GetByUsername(username)
+ owner, err := db.Users.GetByUsername(c.Req.Context(), username)
if err != nil {
if db.IsErrUserNotExist(err) {
c.Status(http.StatusNotFound)
diff --git a/internal/route/lfs/route_test.go b/internal/route/lfs/route_test.go
index 57d0728f..6202695b 100644
--- a/internal/route/lfs/route_test.go
+++ b/internal/route/lfs/route_test.go
@@ -30,7 +30,7 @@ func Test_authenticate(t *testing.T) {
tests := []struct {
name string
header http.Header
- mockUsersStore *db.MockUsersStore
+ mockUsersStore func() db.UsersStore
mockTwoFactorsStore *db.MockTwoFactorsStore
mockAccessTokensStore func() db.AccessTokensStore
expStatusCode int
@@ -51,10 +51,10 @@ func Test_authenticate(t *testing.T) {
header: http.Header{
"Authorization": []string{"Basic dXNlcm5hbWU6cGFzc3dvcmQ="},
},
- mockUsersStore: &db.MockUsersStore{
- MockAuthenticate: func(username, password string, loginSourceID int64) (*db.User, error) {
- return &db.User{}, nil
- },
+ mockUsersStore: func() db.UsersStore {
+ mock := db.NewMockUsersStore()
+ mock.AuthenticateFunc.SetDefaultReturn(&db.User{}, nil)
+ return mock
},
mockTwoFactorsStore: &db.MockTwoFactorsStore{
MockIsUserEnabled: func(userID int64) bool {
@@ -70,10 +70,10 @@ func Test_authenticate(t *testing.T) {
header: http.Header{
"Authorization": []string{"Basic dXNlcm5hbWU="},
},
- mockUsersStore: &db.MockUsersStore{
- MockAuthenticate: func(username, password string, loginSourceID int64) (*db.User, error) {
- return nil, auth.ErrBadCredentials{}
- },
+ mockUsersStore: func() db.UsersStore {
+ mock := db.NewMockUsersStore()
+ mock.AuthenticateFunc.SetDefaultReturn(nil, auth.ErrBadCredentials{})
+ return mock
},
mockAccessTokensStore: func() db.AccessTokensStore {
mock := db.NewMockAccessTokensStore()
@@ -93,10 +93,10 @@ func Test_authenticate(t *testing.T) {
header: http.Header{
"Authorization": []string{"Basic dXNlcm5hbWU6cGFzc3dvcmQ="},
},
- mockUsersStore: &db.MockUsersStore{
- MockAuthenticate: func(username, password string, loginSourceID int64) (*db.User, error) {
- return &db.User{ID: 1, Name: "unknwon"}, nil
- },
+ mockUsersStore: func() db.UsersStore {
+ mock := db.NewMockUsersStore()
+ mock.AuthenticateFunc.SetDefaultReturn(&db.User{ID: 1, Name: "unknwon"}, nil)
+ return mock
},
mockTwoFactorsStore: &db.MockTwoFactorsStore{
MockIsUserEnabled: func(userID int64) bool {
@@ -112,13 +112,11 @@ func Test_authenticate(t *testing.T) {
header: http.Header{
"Authorization": []string{"Basic dXNlcm5hbWU="},
},
- mockUsersStore: &db.MockUsersStore{
- MockAuthenticate: func(username, password string, loginSourceID int64) (*db.User, error) {
- return nil, auth.ErrBadCredentials{}
- },
- MockGetByID: func(id int64) (*db.User, error) {
- return &db.User{ID: 1, Name: "unknwon"}, nil
- },
+ mockUsersStore: func() db.UsersStore {
+ mock := db.NewMockUsersStore()
+ mock.AuthenticateFunc.SetDefaultReturn(nil, auth.ErrBadCredentials{})
+ mock.GetByIDFunc.SetDefaultReturn(&db.User{ID: 1, Name: "unknwon"}, nil)
+ return mock
},
mockAccessTokensStore: func() db.AccessTokensStore {
mock := db.NewMockAccessTokensStore()
@@ -132,9 +130,10 @@ func Test_authenticate(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- db.SetMockUsersStore(t, test.mockUsersStore)
+ if test.mockUsersStore != nil {
+ db.SetMockUsersStore(t, test.mockUsersStore())
+ }
db.SetMockTwoFactorsStore(t, test.mockTwoFactorsStore)
-
if test.mockAccessTokensStore != nil {
db.SetMockAccessTokensStore(t, test.mockAccessTokensStore())
}
@@ -165,7 +164,7 @@ func Test_authorize(t *testing.T) {
tests := []struct {
name string
authroize macaron.Handler
- mockUsersStore *db.MockUsersStore
+ mockUsersStore func() db.UsersStore
mockReposStore *db.MockReposStore
mockPermsStore func() db.PermsStore
expStatusCode int
@@ -174,20 +173,22 @@ func Test_authorize(t *testing.T) {
{
name: "user does not exist",
authroize: authorize(db.AccessModeNone),
- mockUsersStore: &db.MockUsersStore{
- MockGetByUsername: func(username string) (*db.User, error) {
- return nil, db.ErrUserNotExist{}
- },
+ mockUsersStore: func() db.UsersStore {
+ mock := db.NewMockUsersStore()
+ mock.GetByUsernameFunc.SetDefaultReturn(nil, db.ErrUserNotExist{})
+ return mock
},
expStatusCode: http.StatusNotFound,
},
{
name: "repository does not exist",
authroize: authorize(db.AccessModeNone),
- mockUsersStore: &db.MockUsersStore{
- MockGetByUsername: func(username string) (*db.User, error) {
+ mockUsersStore: func() db.UsersStore {
+ mock := db.NewMockUsersStore()
+ mock.GetByUsernameFunc.SetDefaultHook(func(ctx context.Context, username string) (*db.User, error) {
return &db.User{Name: username}, nil
- },
+ })
+ return mock
},
mockReposStore: &db.MockReposStore{
MockGetByName: func(ownerID int64, name string) (*db.Repository, error) {
@@ -199,10 +200,12 @@ func Test_authorize(t *testing.T) {
{
name: "actor is not authorized",
authroize: authorize(db.AccessModeWrite),
- mockUsersStore: &db.MockUsersStore{
- MockGetByUsername: func(username string) (*db.User, error) {
+ mockUsersStore: func() db.UsersStore {
+ mock := db.NewMockUsersStore()
+ mock.GetByUsernameFunc.SetDefaultHook(func(ctx context.Context, username string) (*db.User, error) {
return &db.User{Name: username}, nil
- },
+ })
+ return mock
},
mockReposStore: &db.MockReposStore{
MockGetByName: func(ownerID int64, name string) (*db.Repository, error) {
@@ -222,10 +225,12 @@ func Test_authorize(t *testing.T) {
{
name: "actor is authorized",
authroize: authorize(db.AccessModeRead),
- mockUsersStore: &db.MockUsersStore{
- MockGetByUsername: func(username string) (*db.User, error) {
+ mockUsersStore: func() db.UsersStore {
+ mock := db.NewMockUsersStore()
+ mock.GetByUsernameFunc.SetDefaultHook(func(ctx context.Context, username string) (*db.User, error) {
return &db.User{Name: username}, nil
- },
+ })
+ return mock
},
mockReposStore: &db.MockReposStore{
MockGetByName: func(ownerID int64, name string) (*db.Repository, error) {
@@ -245,9 +250,10 @@ func Test_authorize(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- db.SetMockUsersStore(t, test.mockUsersStore)
+ if test.mockUsersStore != nil {
+ db.SetMockUsersStore(t, test.mockUsersStore())
+ }
db.SetMockReposStore(t, test.mockReposStore)
-
if test.mockPermsStore != nil {
db.SetMockPermsStore(t, test.mockPermsStore())
}
diff --git a/internal/route/org/setting.go b/internal/route/org/setting.go
index 94e9b7a6..e15b9faf 100644
--- a/internal/route/org/setting.go
+++ b/internal/route/org/setting.go
@@ -109,7 +109,7 @@ func SettingsDelete(c *context.Context) {
org := c.Org.Organization
if c.Req.Method == "POST" {
- if _, err := db.Users.Authenticate(c.User.Name, c.Query("password"), c.User.LoginSource); err != nil {
+ if _, err := db.Users.Authenticate(c.Req.Context(), c.User.Name, c.Query("password"), c.User.LoginSource); err != nil {
if auth.IsErrBadCredentials(err) {
c.RenderWithErr(c.Tr("form.enterred_invalid_password"), SETTINGS_DELETE, nil)
} else {
diff --git a/internal/route/repo/http.go b/internal/route/repo/http.go
index 7e970194..888ad4d8 100644
--- a/internal/route/repo/http.go
+++ b/internal/route/repo/http.go
@@ -65,7 +65,7 @@ func HTTPContexter() macaron.Handler {
strings.HasSuffix(c.Req.URL.Path, "git-upload-pack") ||
c.Req.Method == "GET"
- owner, err := db.Users.GetByUsername(ownerName)
+ owner, err := db.Users.GetByUsername(c.Req.Context(), ownerName)
if err != nil {
if db.IsErrUserNotExist(err) {
c.Status(http.StatusNotFound)
@@ -123,7 +123,7 @@ func HTTPContexter() macaron.Handler {
return
}
- authUser, err := db.Users.Authenticate(authUsername, authPassword, -1)
+ authUser, err := db.Users.Authenticate(c.Req.Context(), authUsername, authPassword, -1)
if err != nil && !auth.IsErrBadCredentials(err) {
c.Status(http.StatusInternalServerError)
log.Error("Failed to authenticate user [name: %s]: %v", authUsername, err)
@@ -146,7 +146,7 @@ func HTTPContexter() macaron.Handler {
log.Error("Failed to touch access token: %v", err)
}
- authUser, err = db.Users.GetByID(token.UserID)
+ authUser, err = db.Users.GetByID(c.Req.Context(), token.UserID)
if err != nil {
// Once we found token, we're supposed to find its related user,
// thus any error is unexpected.
diff --git a/internal/route/repo/tasks.go b/internal/route/repo/tasks.go
index 81e85e2a..c5b555b9 100644
--- a/internal/route/repo/tasks.go
+++ b/internal/route/repo/tasks.go
@@ -26,7 +26,7 @@ func TriggerTask(c *macaron.Context) {
username := c.Params(":username")
reponame := c.Params(":reponame")
- owner, err := db.Users.GetByUsername(username)
+ owner, err := db.Users.GetByUsername(c.Req.Context(), username)
if err != nil {
if db.IsErrUserNotExist(err) {
c.Error(http.StatusBadRequest, "Owner does not exist")
@@ -55,7 +55,7 @@ func TriggerTask(c *macaron.Context) {
return
}
- pusher, err := db.Users.GetByID(pusherID)
+ pusher, err := db.Users.GetByID(c.Req.Context(), pusherID)
if err != nil {
if db.IsErrUserNotExist(err) {
c.Error(http.StatusBadRequest, "Pusher does not exist")
diff --git a/internal/route/user/auth.go b/internal/route/user/auth.go
index f8bbe7ab..b3b785c2 100644
--- a/internal/route/user/auth.go
+++ b/internal/route/user/auth.go
@@ -161,7 +161,7 @@ func LoginPost(c *context.Context, f form.SignIn) {
return
}
- u, err := db.Users.Authenticate(f.UserName, f.Password, f.LoginSource)
+ u, err := db.Users.Authenticate(c.Req.Context(), f.UserName, f.Password, f.LoginSource)
if err != nil {
switch errors.Cause(err).(type) {
case auth.ErrBadCredentials:
diff --git a/internal/route/user/setting.go b/internal/route/user/setting.go
index cf2226c5..c28a1747 100644
--- a/internal/route/user/setting.go
+++ b/internal/route/user/setting.go
@@ -640,7 +640,7 @@ func SettingsDelete(c *context.Context) {
c.PageIs("SettingsDelete")
if c.Req.Method == "POST" {
- if _, err := db.Users.Authenticate(c.User.Name, c.Query("password"), c.User.LoginSource); err != nil {
+ if _, err := db.Users.Authenticate(c.Req.Context(), c.User.Name, c.Query("password"), c.User.LoginSource); err != nil {
if auth.IsErrBadCredentials(err) {
c.RenderWithErr(c.Tr("form.enterred_invalid_password"), SETTINGS_DELETE, nil)
} else {