diff options
author | Joe Chen <jc@unknwon.io> | 2022-11-05 23:33:05 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-11-05 23:33:05 +0800 |
commit | 5fb29db2db04bc128af410867f1f602320eb5d66 (patch) | |
tree | 9d0b86702d872f8f5ab7d0691e511c52f38fde34 /internal/db | |
parent | b5d47b969258f3d644ad797b29901eb607f6b94f (diff) |
refactor(db): migrate methods off and delete deprecated methods from `user.go` (#7231)
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/comment.go | 3 | ||||
-rw-r--r-- | internal/db/db.go | 1 | ||||
-rw-r--r-- | internal/db/email_addresses.go | 72 | ||||
-rw-r--r-- | internal/db/email_addresses_test.go | 66 | ||||
-rw-r--r-- | internal/db/issue.go | 5 | ||||
-rw-r--r-- | internal/db/issue_mail.go | 3 | ||||
-rw-r--r-- | internal/db/models.go | 2 | ||||
-rw-r--r-- | internal/db/org.go | 9 | ||||
-rw-r--r-- | internal/db/org_team.go | 3 | ||||
-rw-r--r-- | internal/db/orgs.go | 6 | ||||
-rw-r--r-- | internal/db/orgs_test.go | 12 | ||||
-rw-r--r-- | internal/db/repo.go | 10 | ||||
-rw-r--r-- | internal/db/update.go | 2 | ||||
-rw-r--r-- | internal/db/user.go | 161 | ||||
-rw-r--r-- | internal/db/user_mail.go | 17 | ||||
-rw-r--r-- | internal/db/users.go | 27 | ||||
-rw-r--r-- | internal/db/users_test.go | 63 | ||||
-rw-r--r-- | internal/db/wiki.go | 2 |
18 files changed, 274 insertions, 190 deletions
diff --git a/internal/db/comment.go b/internal/db/comment.go index 14b350e8..c5caa69d 100644 --- a/internal/db/comment.go +++ b/internal/db/comment.go @@ -5,6 +5,7 @@ package db import ( + "context" "fmt" "strings" "time" @@ -94,7 +95,7 @@ func (c *Comment) AfterSet(colName string, _ xorm.Cell) { func (c *Comment) loadAttributes(e Engine) (err error) { if c.Poster == nil { - c.Poster, err = GetUserByID(c.PosterID) + c.Poster, err = Users.GetByID(context.TODO(), c.PosterID) if err != nil { if IsErrUserNotExist(err) { c.PosterID = -1 diff --git a/internal/db/db.go b/internal/db/db.go index dfd23d81..600bdaf0 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -121,6 +121,7 @@ func Init(w logger.Writer) (*gorm.DB, error) { // Initialize stores, sorted in alphabetical order. AccessTokens = &accessTokens{DB: db} Actions = NewActionsStore(db) + EmailAddresses = NewEmailAddressesStore(db) Follows = NewFollowsStore(db) LoginSources = &loginSources{DB: db, files: sourceFiles} LFS = &lfs{DB: db} diff --git a/internal/db/email_addresses.go b/internal/db/email_addresses.go new file mode 100644 index 00000000..2929ffd0 --- /dev/null +++ b/internal/db/email_addresses.go @@ -0,0 +1,72 @@ +// Copyright 2022 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package db + +import ( + "context" + "fmt" + + "gorm.io/gorm" + + "gogs.io/gogs/internal/errutil" +) + +// EmailAddressesStore is the persistent interface for email addresses. +// +// NOTE: All methods are sorted in alphabetical order. +type EmailAddressesStore interface { + // GetByEmail returns the email address with given email. It may return + // unverified email addresses and returns ErrEmailNotExist when not found. + GetByEmail(ctx context.Context, email string) (*EmailAddress, error) +} + +var EmailAddresses EmailAddressesStore + +var _ EmailAddressesStore = (*emailAddresses)(nil) + +type emailAddresses struct { + *gorm.DB +} + +// NewEmailAddressesStore returns a persistent interface for email addresses +// with given database connection. +func NewEmailAddressesStore(db *gorm.DB) EmailAddressesStore { + return &emailAddresses{DB: db} +} + +var _ errutil.NotFound = (*ErrEmailNotExist)(nil) + +type ErrEmailNotExist struct { + args errutil.Args +} + +func IsErrEmailAddressNotExist(err error) bool { + _, ok := err.(ErrEmailNotExist) + return ok +} + +func (err ErrEmailNotExist) Error() string { + return fmt.Sprintf("email address does not exist: %v", err.args) +} + +func (ErrEmailNotExist) NotFound() bool { + return true +} + +func (db *emailAddresses) GetByEmail(ctx context.Context, email string) (*EmailAddress, error) { + emailAddress := new(EmailAddress) + err := db.WithContext(ctx).Where("email = ?", email).First(emailAddress).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + return nil, ErrEmailNotExist{ + args: errutil.Args{ + "email": email, + }, + } + } + return nil, err + } + return emailAddress, nil +} diff --git a/internal/db/email_addresses_test.go b/internal/db/email_addresses_test.go new file mode 100644 index 00000000..b54ffbad --- /dev/null +++ b/internal/db/email_addresses_test.go @@ -0,0 +1,66 @@ +// Copyright 2022 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package db + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gogs.io/gogs/internal/dbtest" + "gogs.io/gogs/internal/errutil" +) + +func TestEmailAddresses(t *testing.T) { + if testing.Short() { + t.Skip() + } + t.Parallel() + + tables := []interface{}{new(EmailAddress)} + db := &emailAddresses{ + DB: dbtest.NewDB(t, "emailAddresses", tables...), + } + + for _, tc := range []struct { + name string + test func(t *testing.T, db *emailAddresses) + }{ + {"GetByEmail", emailAddressesGetByEmail}, + } { + t.Run(tc.name, func(t *testing.T) { + t.Cleanup(func() { + err := clearTables(t, db.DB, tables...) + require.NoError(t, err) + }) + tc.test(t, db) + }) + if t.Failed() { + break + } + } +} + +func emailAddressesGetByEmail(t *testing.T, db *emailAddresses) { + ctx := context.Background() + + const testEmail = "alice@example.com" + _, err := db.GetByEmail(ctx, testEmail) + wantErr := ErrEmailNotExist{ + args: errutil.Args{ + "email": testEmail, + }, + } + assert.Equal(t, wantErr, err) + + // TODO: Use EmailAddresses.Create to replace SQL hack when the method is available. + err = db.Exec(`INSERT INTO email_address (uid, email) VALUES (1, ?)`, testEmail).Error + require.NoError(t, err) + got, err := db.GetByEmail(ctx, testEmail) + require.NoError(t, err) + assert.Equal(t, testEmail, got.Email) +} diff --git a/internal/db/issue.go b/internal/db/issue.go index a8909630..dbb34721 100644 --- a/internal/db/issue.go +++ b/internal/db/issue.go @@ -5,6 +5,7 @@ package db import ( + "context" "fmt" "strings" "time" @@ -391,7 +392,7 @@ func (issue *Issue) GetAssignee() (err error) { return nil } - issue.Assignee, err = GetUserByID(issue.AssigneeID) + issue.Assignee, err = Users.GetByID(context.TODO(), issue.AssigneeID) if IsErrUserNotExist(err) { return nil } @@ -597,7 +598,7 @@ func (issue *Issue) ChangeAssignee(doer *User, assigneeID int64) (err error) { return fmt.Errorf("UpdateIssueUserByAssignee: %v", err) } - issue.Assignee, err = GetUserByID(issue.AssigneeID) + issue.Assignee, err = Users.GetByID(context.TODO(), issue.AssigneeID) if err != nil && !IsErrUserNotExist(err) { log.Error("Failed to get user by ID: %v", err) return nil diff --git a/internal/db/issue_mail.go b/internal/db/issue_mail.go index e1dbb671..d529ecdd 100644 --- a/internal/db/issue_mail.go +++ b/internal/db/issue_mail.go @@ -5,6 +5,7 @@ package db import ( + "context" "fmt" "github.com/unknwon/com" @@ -124,7 +125,7 @@ func mailIssueCommentToParticipants(issue *Issue, doer *User, mentions []string) continue } - to, err := GetUserByID(watchers[i].UserID) + to, err := Users.GetByID(context.TODO(), watchers[i].UserID) if err != nil { return fmt.Errorf("GetUserByID [%d]: %v", watchers[i].UserID, err) } diff --git a/internal/db/models.go b/internal/db/models.go index c9cc5a3b..88cdedae 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -210,7 +210,7 @@ type Statistic struct { } func GetStatistic(ctx context.Context) (stats Statistic) { - stats.Counter.User = CountUsers() + stats.Counter.User = Users.Count(ctx) stats.Counter.Org = CountOrganizations() stats.Counter.PublicKey, _ = x.Count(new(PublicKey)) stats.Counter.Repo = CountRepositories(true) diff --git a/internal/db/org.go b/internal/db/org.go index 9685e567..b11123c6 100644 --- a/internal/db/org.go +++ b/internal/db/org.go @@ -15,6 +15,7 @@ import ( "xorm.io/xorm" "gogs.io/gogs/internal/errutil" + "gogs.io/gogs/internal/repoutil" "gogs.io/gogs/internal/userutil" ) @@ -72,7 +73,7 @@ func (org *User) GetMembers(limit int) error { org.Members = make([]*User, len(ous)) for i, ou := range ous { - org.Members[i], err = GetUserByID(ou.Uid) + org.Members[i], err = Users.GetByID(context.TODO(), ou.Uid) if err != nil { return err } @@ -166,7 +167,7 @@ func CreateOrganization(org, owner *User) (err error) { return fmt.Errorf("insert team-user relation: %v", err) } - if err = os.MkdirAll(UserPath(org.Name), os.ModePerm); err != nil { + if err = os.MkdirAll(repoutil.UserPath(org.Name), os.ModePerm); err != nil { return fmt.Errorf("create directory: %v", err) } @@ -366,11 +367,11 @@ func RemoveOrgUser(orgID, userID int64) error { return nil } - user, err := GetUserByID(userID) + user, err := Users.GetByID(context.TODO(), userID) if err != nil { return fmt.Errorf("GetUserByID [%d]: %v", userID, err) } - org, err := GetUserByID(orgID) + org, err := Users.GetByID(context.TODO(), orgID) if err != nil { return fmt.Errorf("GetUserByID [%d]: %v", orgID, err) } diff --git a/internal/db/org_team.go b/internal/db/org_team.go index cfc4c08c..c44deb67 100644 --- a/internal/db/org_team.go +++ b/internal/db/org_team.go @@ -5,6 +5,7 @@ package db import ( + "context" "fmt" "strings" @@ -417,7 +418,7 @@ func DeleteTeam(t *Team) error { } // Get organization. - org, err := GetUserByID(t.OrgID) + org, err := Users.GetByID(context.TODO(), t.OrgID) if err != nil { return err } diff --git a/internal/db/orgs.go b/internal/db/orgs.go index 9fbedf11..e7e8dd16 100644 --- a/internal/db/orgs.go +++ b/internal/db/orgs.go @@ -18,7 +18,7 @@ import ( // NOTE: All methods are sorted in alphabetical order. type OrgsStore interface { // List returns a list of organizations filtered by options. - List(ctx context.Context, opts ListOrgOptions) ([]*Organization, error) + List(ctx context.Context, opts ListOrgsOptions) ([]*Organization, error) } var Orgs OrgsStore @@ -35,14 +35,14 @@ func NewOrgsStore(db *gorm.DB) OrgsStore { return &orgs{DB: db} } -type ListOrgOptions struct { +type ListOrgsOptions struct { // Filter by the membership with the given user ID. MemberID int64 // Whether to include private memberships. IncludePrivateMembers bool } -func (db *orgs) List(ctx context.Context, opts ListOrgOptions) ([]*Organization, error) { +func (db *orgs) List(ctx context.Context, opts ListOrgsOptions) ([]*Organization, error) { if opts.MemberID <= 0 { return nil, errors.New("MemberID must be greater than 0") } diff --git a/internal/db/orgs_test.go b/internal/db/orgs_test.go index bee7444b..76e79345 100644 --- a/internal/db/orgs_test.go +++ b/internal/db/orgs_test.go @@ -58,14 +58,14 @@ func orgsList(t *testing.T, db *orgs) { org1, err := usersStore.Create(ctx, "org1", "org1@example.com", CreateUserOptions{}) require.NoError(t, err) err = db.Exec( - dbutil.Quote("UPDATE %s SET %s = ? WHERE id = ?", "user", "type"), + dbutil.Quote("UPDATE %s SET type = ? WHERE id = ?", "user"), UserTypeOrganization, org1.ID, ).Error require.NoError(t, err) org2, err := usersStore.Create(ctx, "org2", "org2@example.com", CreateUserOptions{}) require.NoError(t, err) err = db.Exec( - dbutil.Quote("UPDATE %s SET %s = ? WHERE id = ?", "user", "type"), + dbutil.Quote("UPDATE %s SET type = ? WHERE id = ?", "user"), UserTypeOrganization, org2.ID, ).Error require.NoError(t, err) @@ -80,12 +80,12 @@ func orgsList(t *testing.T, db *orgs) { tests := []struct { name string - opts ListOrgOptions + opts ListOrgsOptions wantOrgNames []string }{ { name: "only public memberships for a user", - opts: ListOrgOptions{ + opts: ListOrgsOptions{ MemberID: alice.ID, IncludePrivateMembers: false, }, @@ -93,7 +93,7 @@ func orgsList(t *testing.T, db *orgs) { }, { name: "all memberships for a user", - opts: ListOrgOptions{ + opts: ListOrgsOptions{ MemberID: alice.ID, IncludePrivateMembers: true, }, @@ -101,7 +101,7 @@ func orgsList(t *testing.T, db *orgs) { }, { name: "no membership for a non-existent user", - opts: ListOrgOptions{ + opts: ListOrgsOptions{ MemberID: 404, IncludePrivateMembers: true, }, diff --git a/internal/db/repo.go b/internal/db/repo.go index 81a6cd00..f9e19468 100644 --- a/internal/db/repo.go +++ b/internal/db/repo.go @@ -1310,12 +1310,12 @@ func FilterRepositoryWithIssues(repoIDs []int64) ([]int64, error) { // // Deprecated: Use repoutil.RepositoryPath instead. func RepoPath(userName, repoName string) string { - return filepath.Join(UserPath(userName), strings.ToLower(repoName)+".git") + return filepath.Join(repoutil.UserPath(userName), strings.ToLower(repoName)+".git") } // TransferOwnership transfers all corresponding setting from old user to new one. func TransferOwnership(doer *User, newOwnerName string, repo *Repository) error { - newOwner, err := GetUserByName(newOwnerName) + newOwner, err := Users.GetByUsername(context.TODO(), newOwnerName) if err != nil { return fmt.Errorf("get new owner '%s': %v", newOwnerName, err) } @@ -1437,7 +1437,7 @@ func TransferOwnership(doer *User, newOwnerName string, repo *Repository) error } // Rename remote repository to new path and delete local copy. - if err = os.MkdirAll(UserPath(newOwner.Name), os.ModePerm); err != nil { + if err = os.MkdirAll(repoutil.UserPath(newOwner.Name), os.ModePerm); err != nil { return err } if err = os.Rename(RepoPath(owner.Name, repo.Name), RepoPath(newOwner.Name, repo.Name)); err != nil { @@ -1606,7 +1606,7 @@ func DeleteRepository(ownerID, repoID int64) error { } // In case is a organization. - org, err := GetUserByID(ownerID) + org, err := Users.GetByID(context.TODO(), ownerID) if err != nil { return err } @@ -1724,7 +1724,7 @@ func GetRepositoryByRef(ref string) (*Repository, error) { } userName, repoName := ref[:n], ref[n+1:] - user, err := GetUserByName(userName) + user, err := Users.GetByUsername(context.TODO(), userName) if err != nil { return nil, err } diff --git a/internal/db/update.go b/internal/db/update.go index ec538b0b..4c0eb3c5 100644 --- a/internal/db/update.go +++ b/internal/db/update.go @@ -73,7 +73,7 @@ func PushUpdate(opts PushUpdateOptions) (err error) { return fmt.Errorf("open repository: %v", err) } - owner, err := GetUserByName(opts.RepoUserName) + owner, err := Users.GetByUsername(ctx, opts.RepoUserName) if err != nil { return fmt.Errorf("GetUserByName: %v", err) } diff --git a/internal/db/user.go b/internal/db/user.go index e15c9410..922a340a 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -6,11 +6,9 @@ package db import ( "context" - "encoding/hex" "fmt" _ "image/jpeg" "os" - "path/filepath" "strings" "time" @@ -23,6 +21,7 @@ import ( "gogs.io/gogs/internal/conf" "gogs.io/gogs/internal/db/errors" "gogs.io/gogs/internal/errutil" + "gogs.io/gogs/internal/repoutil" "gogs.io/gogs/internal/tool" "gogs.io/gogs/internal/userutil" ) @@ -58,77 +57,6 @@ func (u *User) getOrganizationCount(e Engine) (int64, error) { return e.Where("uid=?", u.ID).Count(new(OrgUser)) } -func countUsers(e Engine) int64 { - count, _ := e.Where("type=0").Count(new(User)) - return count -} - -// CountUsers returns number of users. -func CountUsers() int64 { - return countUsers(x) -} - -// Users returns number of users in given page. -func ListUsers(page, pageSize int) ([]*User, error) { - users := make([]*User, 0, pageSize) - return users, x.Limit(pageSize, (page-1)*pageSize).Where("type=0").Asc("id").Find(&users) -} - -// parseUserFromCode returns user by username encoded in code. -// It returns nil if code or username is invalid. -func parseUserFromCode(code string) (user *User) { - if len(code) <= tool.TIME_LIMIT_CODE_LENGTH { - return nil - } - - // Use tail hex username to query user - hexStr := code[tool.TIME_LIMIT_CODE_LENGTH:] - if b, err := hex.DecodeString(hexStr); err == nil { - if user, err = GetUserByName(string(b)); user != nil { - return user - } else if !IsErrUserNotExist(err) { - log.Error("Failed to get user by name %q: %v", string(b), err) - } - } - - return nil -} - -// verify active code when active account -func VerifyUserActiveCode(code string) (user *User) { - minutes := conf.Auth.ActivateCodeLives - - if user = parseUserFromCode(code); user != nil { - // time limit code - prefix := code[:tool.TIME_LIMIT_CODE_LENGTH] - data := com.ToStr(user.ID) + user.Email + user.LowerName + user.Password + user.Rands - - if tool.VerifyTimeLimitCode(data, minutes, prefix) { - return user - } - } - return nil -} - -// verify active code when active account -func VerifyActiveEmailCode(code, email string) *EmailAddress { - minutes := conf.Auth.ActivateCodeLives - - if user := parseUserFromCode(code); user != nil { - // time limit code - prefix := code[:tool.TIME_LIMIT_CODE_LENGTH] - data := com.ToStr(user.ID) + email + user.LowerName + user.Password + user.Rands - - if tool.VerifyTimeLimitCode(data, minutes, prefix) { - emailAddress := &EmailAddress{Email: email} - if has, _ := x.Get(emailAddress); has { - return emailAddress - } - } - } - return nil -} - // ChangeUserName changes all corresponding setting from old user name to new one. func ChangeUserName(u *User, newUserName string) (err error) { if err = isUsernameAllowed(newUserName); err != nil { @@ -155,8 +83,8 @@ func ChangeUserName(u *User, newUserName string) (err error) { } // Rename or create user base directory - baseDir := UserPath(u.Name) - newBaseDir := UserPath(newUserName) + baseDir := repoutil.UserPath(u.Name) + newBaseDir := repoutil.UserPath(newUserName) if com.IsExist(baseDir) { return os.Rename(baseDir, newBaseDir) } @@ -270,7 +198,7 @@ func deleteUser(e *xorm.Session, u *User) error { &Follow{FollowID: u.ID}, &Action{UserID: u.ID}, &IssueUser{UID: u.ID}, - &EmailAddress{UID: u.ID}, + &EmailAddress{UserID: u.ID}, ); err != nil { return fmt.Errorf("deleteBeans: %v", err) } @@ -303,7 +231,7 @@ func deleteUser(e *xorm.Session, u *User) error { // Note: There are something just cannot be roll back, // so just keep error logs of those operations. - _ = os.RemoveAll(UserPath(u.Name)) + _ = os.RemoveAll(repoutil.UserPath(u.Name)) _ = os.Remove(userutil.CustomAvatarPath(u.ID)) return nil @@ -351,13 +279,6 @@ func DeleteInactivateUsers() (err error) { return err } -// UserPath returns the path absolute path of user repositories. -// -// Deprecated: Use repoutil.UserPath instead. -func UserPath(username string) string { - return filepath.Join(conf.Repository.Root, strings.ToLower(username)) -} - func GetUserByKeyID(keyID int64) (*User, error) { user := new(User) has, err := x.SQL("SELECT a.* FROM `user` AS a, public_key AS b WHERE a.id = b.owner_id AND b.id=?", keyID).Get(user) @@ -380,12 +301,6 @@ func getUserByID(e Engine, id int64) (*User, error) { return u, nil } -// GetUserByID returns the user object by given ID if exists. -// Deprecated: Use Users.GetByID instead. -func GetUserByID(id int64) (*User, error) { - return getUserByID(x, id) -} - // GetAssigneeByID returns the user with read access of repository by given ID. func GetAssigneeByID(repo *Repository, userID int64) (*User, error) { ctx := context.TODO() @@ -400,27 +315,11 @@ func GetAssigneeByID(repo *Repository, userID int64) (*User, error) { return Users.GetByID(ctx, userID) } -// GetUserByName returns a user by given name. -// Deprecated: Use Users.GetByUsername instead. -func GetUserByName(name string) (*User, error) { - if name == "" { - return nil, ErrUserNotExist{args: map[string]interface{}{"name": name}} - } - u := &User{LowerName: strings.ToLower(name)} - has, err := x.Get(u) - if err != nil { - return nil, err - } else if !has { - return nil, ErrUserNotExist{args: map[string]interface{}{"name": name}} - } - return u, nil -} - // GetUserEmailsByNames returns a list of e-mails corresponds to names. func GetUserEmailsByNames(names []string) []string { mails := make([]string, 0, len(names)) for _, name := range names { - u, err := GetUserByName(name) + u, err := Users.GetByUsername(context.TODO(), name) if err != nil { continue } @@ -431,19 +330,6 @@ func GetUserEmailsByNames(names []string) []string { return mails } -// GetUserIDsByNames returns a slice of ids corresponds to names. -func GetUserIDsByNames(names []string) []int64 { - ids := make([]int64, 0, len(names)) - for _, name := range names { - u, err := GetUserByName(name) - if err != nil { - continue - } - ids = append(ids, u.ID) - } - return ids -} - // UserCommit represents a commit with validation of user. type UserCommit struct { User *User @@ -452,7 +338,7 @@ type UserCommit struct { // ValidateCommitWithEmail checks if author's e-mail of commit is corresponding to a user. func ValidateCommitWithEmail(c *git.Commit) *User { - u, err := GetUserByEmail(c.Author.Email) + u, err := Users.GetByEmail(context.TODO(), c.Author.Email) if err != nil { return nil } @@ -466,7 +352,7 @@ func ValidateCommitsWithEmails(oldCommits []*git.Commit) []*UserCommit { for i := range oldCommits { var u *User if v, ok := emails[oldCommits[i].Author.Email]; !ok { - u, _ = GetUserByEmail(oldCommits[i].Author.Email) + u, _ = Users.GetByEmail(context.TODO(), oldCommits[i].Author.Email) emails[oldCommits[i].Author.Email] = u } else { u = v @@ -480,37 +366,6 @@ func ValidateCommitsWithEmails(oldCommits []*git.Commit) []*UserCommit { return newCommits } -// GetUserByEmail returns the user object by given e-mail if exists. -// Deprecated: Use Users.GetByEmail instead. -func GetUserByEmail(email string) (*User, error) { - if email == "" { - return nil, ErrUserNotExist{args: map[string]interface{}{"email": email}} - } - - email = strings.ToLower(email) - // First try to find the user by primary email - user := &User{Email: email} - has, err := x.Get(user) - if err != nil { - return nil, err - } - if has { - return user, nil - } - - // Otherwise, check in alternative list for activated email addresses - emailAddress := &EmailAddress{Email: email, IsActivated: true} - has, err = x.Get(emailAddress) - if err != nil { - return nil, err - } - if has { - return GetUserByID(emailAddress.UID) - } - - return nil, ErrUserNotExist{args: map[string]interface{}{"email": email}} -} - type SearchUserOptions struct { Keyword string Type UserType diff --git a/internal/db/user_mail.go b/internal/db/user_mail.go index 122f19bc..eae79af7 100644 --- a/internal/db/user_mail.go +++ b/internal/db/user_mail.go @@ -5,6 +5,7 @@ package db import ( + "context" "fmt" "strings" @@ -16,8 +17,8 @@ import ( // EmailAddresses is the list of all email addresses of a user. Can contain the // primary email address, but is not obligatory. type EmailAddress struct { - ID int64 - UID int64 `xorm:"INDEX NOT NULL" gorm:"index;not null"` + ID int64 `gorm:"primaryKey"` + UserID int64 `xorm:"uid INDEX NOT NULL" gorm:"column:uid;index;not null"` Email string `xorm:"UNIQUE NOT NULL" gorm:"unique;not null"` IsActivated bool `gorm:"not null;default:FALSE"` IsPrimary bool `xorm:"-" gorm:"-" json:"-"` @@ -30,7 +31,7 @@ func GetEmailAddresses(uid int64) ([]*EmailAddress, error) { return nil, err } - u, err := GetUserByID(uid) + u, err := Users.GetByID(context.TODO(), uid) if err != nil { return nil, err } @@ -119,7 +120,7 @@ func AddEmailAddresses(emails []*EmailAddress) error { } func (email *EmailAddress) Activate() error { - user, err := GetUserByID(email.UID) + user, err := Users.GetByID(context.TODO(), email.UserID) if err != nil { return err } @@ -170,7 +171,7 @@ func MakeEmailPrimary(userID int64, email *EmailAddress) error { return errors.EmailNotFound{Email: email.Email} } - if email.UID != userID { + if email.UserID != userID { return errors.New("not the owner of the email") } @@ -178,12 +179,12 @@ func MakeEmailPrimary(userID int64, email *EmailAddress) error { return errors.EmailNotVerified{Email: email.Email} } - user := &User{ID: email.UID} + user := &User{ID: email.UserID} has, err = x.Get(user) if err != nil { return err } else if !has { - return ErrUserNotExist{args: map[string]interface{}{"userID": email.UID}} + return ErrUserNotExist{args: map[string]interface{}{"userID": email.UserID}} } // Make sure the former primary email doesn't disappear. @@ -200,7 +201,7 @@ func MakeEmailPrimary(userID int64, email *EmailAddress) error { } if !has { - formerPrimaryEmail.UID = user.ID + formerPrimaryEmail.UserID = user.ID formerPrimaryEmail.IsActivated = user.IsActive if _, err = sess.Insert(formerPrimaryEmail); err != nil { return err diff --git a/internal/db/users.go b/internal/db/users.go index 37494638..f12eca09 100644 --- a/internal/db/users.go +++ b/internal/db/users.go @@ -45,6 +45,8 @@ type UsersStore interface { // When the "loginSourceID" is positive, it tries to authenticate via given // login source and creates a new user when not yet exists in the database. Authenticate(ctx context.Context, username, password string, loginSourceID int64) (*User, error) + // Count returns the total number of users. + Count(ctx context.Context) int64 // Create creates a new user and persists to database. It returns // ErrUserAlreadyExist when a user with same name already exists, or // ErrEmailAlreadyUsed if the email has been used by another user. @@ -65,6 +67,9 @@ type UsersStore interface { HasForkedRepository(ctx context.Context, userID, repoID int64) bool // IsUsernameUsed returns true if the given username has been used. IsUsernameUsed(ctx context.Context, username string) bool + // List returns a list of users. Results are paginated by given page and page + // size, and sorted by primary key (id) in ascending order. + List(ctx context.Context, page, pageSize int) ([]*User, error) // ListFollowers returns a list of users that are following the given user. // Results are paginated by given page and page size, and sorted by the time of // follow in descending order. @@ -183,6 +188,12 @@ func (db *users) Authenticate(ctx context.Context, login, password string, login ) } +func (db *users) Count(ctx context.Context) int64 { + var count int64 + db.WithContext(ctx).Model(&User{}).Where("type = ?", UserTypeIndividual).Count(&count) + return count +} + type CreateUserOptions struct { FullName string Password string @@ -242,6 +253,7 @@ func (db *users) Create(ctx context.Context, username, email string, opts Create } } + email = strings.ToLower(email) _, err = db.GetByEmail(ctx, email) if err == nil { return nil, ErrEmailAlreadyUsed{ @@ -315,11 +327,10 @@ func (ErrUserNotExist) NotFound() bool { } func (db *users) GetByEmail(ctx context.Context, email string) (*User, error) { - email = strings.ToLower(email) - if email == "" { return nil, ErrUserNotExist{args: errutil.Args{"email": email}} } + email = strings.ToLower(email) // First try to find the user by primary email user := new(User) @@ -346,7 +357,7 @@ func (db *users) GetByEmail(ctx context.Context, email string) (*User, error) { return nil, err } - return db.GetByID(ctx, emailAddress.UID) + return db.GetByID(ctx, emailAddress.UserID) } func (db *users) GetByID(ctx context.Context, id int64) (*User, error) { @@ -390,6 +401,16 @@ func (db *users) IsUsernameUsed(ctx context.Context, username string) bool { Error != gorm.ErrRecordNotFound } +func (db *users) List(ctx context.Context, page, pageSize int) ([]*User, error) { + users := make([]*User, 0, pageSize) + return users, db.WithContext(ctx). + Where("type = ?", UserTypeIndividual). + Limit(pageSize).Offset((page - 1) * pageSize). + Order("id ASC"). + Find(&users). + Error +} + func (db *users) ListFollowers(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) { /* Equivalent SQL for PostgreSQL: diff --git a/internal/db/users_test.go b/internal/db/users_test.go index 2479702a..40006330 100644 --- a/internal/db/users_test.go +++ b/internal/db/users_test.go @@ -18,6 +18,7 @@ import ( "gogs.io/gogs/internal/auth" "gogs.io/gogs/internal/dbtest" + "gogs.io/gogs/internal/dbutil" "gogs.io/gogs/internal/errutil" "gogs.io/gogs/internal/osutil" "gogs.io/gogs/internal/userutil" @@ -88,6 +89,7 @@ func TestUsers(t *testing.T) { test func(t *testing.T, db *users) }{ {"Authenticate", usersAuthenticate}, + {"Count", usersCount}, {"Create", usersCreate}, {"DeleteCustomAvatar", usersDeleteCustomAvatar}, {"GetByEmail", usersGetByEmail}, @@ -95,6 +97,7 @@ func TestUsers(t *testing.T) { {"GetByUsername", usersGetByUsername}, {"HasForkedRepository", usersHasForkedRepository}, {"IsUsernameUsed", usersIsUsernameUsed}, + {"List", usersList}, {"ListFollowers", usersListFollowers}, {"ListFollowings", usersListFollowings}, {"UseCustomAvatar", usersUseCustomAvatar}, @@ -209,6 +212,31 @@ func usersAuthenticate(t *testing.T, db *users) { }) } +func usersCount(t *testing.T, db *users) { + ctx := context.Background() + + // Has no user initially + got := db.Count(ctx) + assert.Equal(t, int64(0), got) + + _, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) + require.NoError(t, err) + got = db.Count(ctx) + assert.Equal(t, int64(1), got) + + // Create an organization shouldn't count + // TODO: Use Orgs.Create to replace SQL hack when the method is available. + org1, err := db.Create(ctx, "org1", "org1@example.com", CreateUserOptions{}) + require.NoError(t, err) + err = db.Exec( + dbutil.Quote("UPDATE %s SET type = ? WHERE id = ?", "user"), + UserTypeOrganization, org1.ID, + ).Error + require.NoError(t, err) + got = db.Count(ctx) + assert.Equal(t, int64(1), got) +} + func usersCreate(t *testing.T, db *users) { ctx := context.Background() @@ -420,6 +448,41 @@ func usersIsUsernameUsed(t *testing.T, db *users) { assert.False(t, got) } +func usersList(t *testing.T, db *users) { + ctx := context.Background() + + alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) + require.NoError(t, err) + bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) + require.NoError(t, err) + + // Create an organization shouldn't count + // TODO: Use Orgs.Create to replace SQL hack when the method is available. + org1, err := db.Create(ctx, "org1", "org1@example.com", CreateUserOptions{}) + require.NoError(t, err) + err = db.Exec( + dbutil.Quote("UPDATE %s SET type = ? WHERE id = ?", "user"), + UserTypeOrganization, org1.ID, + ).Error + require.NoError(t, err) + + got, err := db.List(ctx, 1, 1) + require.NoError(t, err) + require.Len(t, got, 1) + assert.Equal(t, alice.ID, got[0].ID) + + got, err = db.List(ctx, 2, 1) + require.NoError(t, err) + require.Len(t, got, 1) + assert.Equal(t, bob.ID, got[0].ID) + + got, err = db.List(ctx, 1, 3) + require.NoError(t, err) + require.Len(t, got, 2) + assert.Equal(t, alice.ID, got[0].ID) + assert.Equal(t, bob.ID, got[1].ID) +} + func usersListFollowers(t *testing.T, db *users) { ctx := context.Background() diff --git a/internal/db/wiki.go b/internal/db/wiki.go index 8606b9a5..1e1e0cbc 100644 --- a/internal/db/wiki.go +++ b/internal/db/wiki.go @@ -47,7 +47,7 @@ func (repo *Repository) WikiCloneLink() (cl *repoutil.CloneLink) { // WikiPath returns wiki data path by given user and repository name. func WikiPath(userName, repoName string) string { - return filepath.Join(UserPath(userName), strings.ToLower(repoName)+".wiki.git") + return filepath.Join(repoutil.UserPath(userName), strings.ToLower(repoName)+".wiki.git") } func (repo *Repository) WikiPath() string { |