diff options
author | Joe Chen <jc@unknwon.io> | 2022-11-27 19:36:10 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-11-27 19:36:10 +0800 |
commit | ae20d03aece78fb44dc1caaacfa40c3aa40c7949 (patch) | |
tree | 7e7b33f99eae57d8426eeead443276d5cbe0dd5a /internal/db | |
parent | 44333afd20a6312b617e0c33a497a4385ba3a250 (diff) |
refactor(db): migrate `UpdateUser` off `user.go` (#7267)
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/repo.go | 26 | ||||
-rw-r--r-- | internal/db/user.go | 45 | ||||
-rw-r--r-- | internal/db/user_mail.go | 22 | ||||
-rw-r--r-- | internal/db/users.go | 138 | ||||
-rw-r--r-- | internal/db/users_test.go | 107 |
5 files changed, 219 insertions, 119 deletions
diff --git a/internal/db/repo.go b/internal/db/repo.go index 8abe47ed..896c39b8 100644 --- a/internal/db/repo.go +++ b/internal/db/repo.go @@ -34,6 +34,7 @@ import ( "gogs.io/gogs/internal/avatar" "gogs.io/gogs/internal/conf" dberrors "gogs.io/gogs/internal/db/errors" + "gogs.io/gogs/internal/dbutil" "gogs.io/gogs/internal/errutil" "gogs.io/gogs/internal/markup" "gogs.io/gogs/internal/osutil" @@ -1112,11 +1113,9 @@ func createRepository(e *xorm.Session, doer, owner *User, repo *Repository) (err return err } - owner.NumRepos++ - // Remember visibility preference. - owner.LastRepoVisibility = repo.IsPrivate - if err = updateUser(e, owner); err != nil { - return fmt.Errorf("updateUser: %v", err) + _, err = e.Exec(dbutil.Quote("UPDATE %s SET num_repos = num_repos + 1 WHERE id = ?", "user"), owner.ID) + if err != nil { + return errors.Wrap(err, "increase owned repository count") } // Give access to all members in owner team. @@ -1222,8 +1221,17 @@ func CreateRepository(doer, owner *User, opts CreateRepoOptionsLegacy) (_ *Repos return nil, fmt.Errorf("CreateRepository 'git update-server-info': %s", stderr) } } + if err = sess.Commit(); err != nil { + return nil, err + } - return repo, sess.Commit() + // Remember visibility preference + err = Users.Update(context.TODO(), owner.ID, UpdateUserOptions{LastRepoVisibility: &repo.IsPrivate}) + if err != nil { + return nil, errors.Wrap(err, "update user") + } + + return repo, nil } func countRepositories(userID int64, private bool) int64 { @@ -2544,6 +2552,12 @@ func ForkRepository(doer, owner *User, baseRepo *Repository, name, desc string) return nil, fmt.Errorf("Commit: %v", err) } + // Remember visibility preference + err = Users.Update(context.TODO(), owner.ID, UpdateUserOptions{LastRepoVisibility: &repo.IsPrivate}) + if err != nil { + return nil, errors.Wrap(err, "update user") + } + if err = repo.UpdateSize(); err != nil { log.Error("UpdateSize [repo_id: %d]: %v", repo.ID, err) } diff --git a/internal/db/user.go b/internal/db/user.go index cc2de95b..c2091e6f 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -19,10 +19,7 @@ import ( "gogs.io/gogs/internal/conf" "gogs.io/gogs/internal/db/errors" - "gogs.io/gogs/internal/errutil" "gogs.io/gogs/internal/repoutil" - "gogs.io/gogs/internal/strutil" - "gogs.io/gogs/internal/tool" "gogs.io/gogs/internal/userutil" ) @@ -42,48 +39,6 @@ func (u *User) AfterSet(colName string, _ xorm.Cell) { } } -// TODO(unknwon): Update call sites to use refactored methods and delete this one. -func updateUser(e Engine, u *User) error { - // Organization does not need email - if !u.IsOrganization() { - u.Email = strings.ToLower(u.Email) - has, err := e.Where("id!=?", u.ID).And("type=?", u.Type).And("email=?", u.Email).Get(new(User)) - if err != nil { - return err - } else if has { - return ErrEmailAlreadyUsed{args: errutil.Args{"email": u.Email}} - } - - if u.AvatarEmail == "" { - u.AvatarEmail = u.Email - } - u.Avatar = tool.HashEmail(u.AvatarEmail) - } - - u.LowerName = strings.ToLower(u.Name) - u.Location = strutil.Truncate(u.Location, 255) - u.Website = strutil.Truncate(u.Website, 255) - u.Description = strutil.Truncate(u.Description, 255) - - _, err := e.ID(u.ID).AllCols().Update(u) - return err -} - -// TODO(unknwon): Refactoring together with methods that do updates. -func (u *User) BeforeUpdate() { - if u.MaxRepoCreation < -1 { - u.MaxRepoCreation = -1 - } - u.UpdatedUnix = time.Now().Unix() -} - -// UpdateUser updates user's information. -// -// TODO(unknwon): Update call sites to use refactored methods and delete this one. -func UpdateUser(u *User) error { - return updateUser(x, u) -} - // deleteBeans deletes all given beans, beans should contain delete conditions. func deleteBeans(e Engine, beans ...interface{}) (err error) { for i := range beans { diff --git a/internal/db/user_mail.go b/internal/db/user_mail.go index eae79af7..d657faa0 100644 --- a/internal/db/user_mail.go +++ b/internal/db/user_mail.go @@ -11,7 +11,6 @@ import ( "gogs.io/gogs/internal/db/errors" "gogs.io/gogs/internal/errutil" - "gogs.io/gogs/internal/userutil" ) // EmailAddresses is the list of all email addresses of a user. Can contain the @@ -120,28 +119,11 @@ func AddEmailAddresses(emails []*EmailAddress) error { } func (email *EmailAddress) Activate() error { - user, err := Users.GetByID(context.TODO(), email.UserID) - if err != nil { - return err - } - if user.Rands, err = userutil.RandomSalt(); err != nil { - return err - } - - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { - return err - } - email.IsActivated = true - if _, err := sess.ID(email.ID).AllCols().Update(email); err != nil { - return err - } else if err = updateUser(sess, user); err != nil { + if _, err := x.ID(email.ID).AllCols().Update(email); err != nil { return err } - - return sess.Commit() + return Users.Update(context.TODO(), email.UserID, UpdateUserOptions{GenerateNewRands: true}) } func DeleteEmailAddress(email *EmailAddress) (err error) { diff --git a/internal/db/users.go b/internal/db/users.go index 9e5687d4..65326a95 100644 --- a/internal/db/users.go +++ b/internal/db/users.go @@ -87,8 +87,7 @@ type UsersStore interface { // Results are paginated by given page and page size, and sorted by the time of // follow in descending order. ListFollowings(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) - // Update updates all fields for the given user, all values are persisted as-is - // (i.e. empty values would overwrite/wipe out existing values). + // Update updates fields for the given user. Update(ctx context.Context, userID int64, opts UpdateUserOptions) error // UseCustomAvatar uses the given avatar as the user custom avatar. UseCustomAvatar(ctx context.Context, userID int64, avatar []byte) error @@ -324,8 +323,10 @@ type ErrEmailAlreadyUsed struct { args errutil.Args } +// IsErrEmailAlreadyUsed returns true if the underlying error has the type +// ErrEmailAlreadyUsed. func IsErrEmailAlreadyUsed(err error) bool { - _, ok := err.(ErrEmailAlreadyUsed) + _, ok := errors.Cause(err).(ErrEmailAlreadyUsed) return ok } @@ -415,8 +416,10 @@ type ErrUserNotExist struct { args errutil.Args } +// IsErrUserNotExist returns true if the underlying error has the type +// ErrUserNotExist. func IsErrUserNotExist(err error) bool { - _, ok := err.(ErrUserNotExist) + _, ok := errors.Cause(err).(ErrUserNotExist) return ok } @@ -549,30 +552,117 @@ func (db *users) ListFollowings(ctx context.Context, userID int64, page, pageSiz } type UpdateUserOptions struct { - FullName string - Website string - Location string - Description string - - MaxRepoCreation int + LoginSource *int64 + LoginName *string + + Password *string + // GenerateNewRands indicates whether to force generate new rands for the user. + GenerateNewRands bool + + FullName *string + Email *string + Website *string + Location *string + Description *string + + MaxRepoCreation *int + LastRepoVisibility *bool + + IsActivated *bool + IsAdmin *bool + AllowGitHook *bool + AllowImportLocal *bool + ProhibitLogin *bool + + Avatar *string + AvatarEmail *string } func (db *users) Update(ctx context.Context, userID int64, opts UpdateUserOptions) error { - if opts.MaxRepoCreation < -1 { - opts.MaxRepoCreation = -1 + updates := map[string]any{ + "updated_unix": db.NowFunc().Unix(), } - return db.WithContext(ctx). - Model(&User{}). - Where("id = ?", userID). - Updates(map[string]any{ - "full_name": strutil.Truncate(opts.FullName, 255), - "website": strutil.Truncate(opts.Website, 255), - "location": strutil.Truncate(opts.Location, 255), - "description": strutil.Truncate(opts.Description, 255), - "max_repo_creation": opts.MaxRepoCreation, - "updated_unix": db.NowFunc().Unix(), - }). - Error + + if opts.LoginSource != nil { + updates["login_source"] = *opts.LoginSource + } + if opts.LoginName != nil { + updates["login_name"] = *opts.LoginName + } + + if opts.Password != nil { + salt, err := userutil.RandomSalt() + if err != nil { + return errors.Wrap(err, "generate salt") + } + updates["salt"] = salt + updates["passwd"] = userutil.EncodePassword(*opts.Password, salt) + opts.GenerateNewRands = true + } + if opts.GenerateNewRands { + rands, err := userutil.RandomSalt() + if err != nil { + return errors.Wrap(err, "generate rands") + } + updates["rands"] = rands + } + + if opts.FullName != nil { + updates["full_name"] = strutil.Truncate(*opts.FullName, 255) + } + if opts.Email != nil { + _, err := db.GetByEmail(ctx, *opts.Email) + if err == nil { + return ErrEmailAlreadyUsed{args: errutil.Args{"email": *opts.Email}} + } else if !IsErrUserNotExist(err) { + return errors.Wrap(err, "check email") + } + updates["email"] = *opts.Email + } + if opts.Website != nil { + updates["website"] = strutil.Truncate(*opts.Website, 255) + } + if opts.Location != nil { + updates["location"] = strutil.Truncate(*opts.Location, 255) + } + if opts.Description != nil { + updates["description"] = strutil.Truncate(*opts.Description, 255) + } + + if opts.MaxRepoCreation != nil { + if *opts.MaxRepoCreation < -1 { + *opts.MaxRepoCreation = -1 + } + updates["max_repo_creation"] = *opts.MaxRepoCreation + } + if opts.LastRepoVisibility != nil { + updates["last_repo_visibility"] = *opts.LastRepoVisibility + } + + if opts.IsActivated != nil { + updates["is_active"] = *opts.IsActivated + } + if opts.IsAdmin != nil { + updates["is_admin"] = *opts.IsAdmin + } + if opts.AllowGitHook != nil { + updates["allow_git_hook"] = *opts.AllowGitHook + } + if opts.AllowImportLocal != nil { + updates["allow_import_local"] = *opts.AllowImportLocal + } + if opts.ProhibitLogin != nil { + updates["prohibit_login"] = *opts.ProhibitLogin + } + + if opts.Avatar != nil { + updates["avatar"] = strutil.Truncate(*opts.Avatar, 2048) + } + if opts.AvatarEmail != nil { + updates["avatar_email"] = strutil.Truncate(*opts.AvatarEmail, 255) + } + + return db.WithContext(ctx).Model(&User{}).Where("id = ?", userID).Updates(updates).Error } func (db *users) UseCustomAvatar(ctx context.Context, userID int64, avatar []byte) error { diff --git a/internal/db/users_test.go b/internal/db/users_test.go index 2ce0efd9..4cf5aaf4 100644 --- a/internal/db/users_test.go +++ b/internal/db/users_test.go @@ -731,16 +731,69 @@ func usersListFollowings(t *testing.T, db *users) { func usersUpdate(t *testing.T, db *users) { ctx := context.Background() - alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) + const oldPassword = "Password" + alice, err := db.Create( + ctx, + "alice", + "alice@example.com", + CreateUserOptions{ + FullName: "FullName", + Password: oldPassword, + LoginSource: 9, + LoginName: "LoginName", + Location: "Location", + Website: "Website", + Activated: false, + Admin: false, + }, + ) require.NoError(t, err) - overLimitStr := strings.Repeat("a", 300) + t.Run("update password", func(t *testing.T) { + got := userutil.ValidatePassword(alice.Password, alice.Salt, oldPassword) + require.True(t, got) + + newPassword := "NewPassword" + err = db.Update(ctx, alice.ID, UpdateUserOptions{Password: &newPassword}) + require.NoError(t, err) + alice, err = db.GetByID(ctx, alice.ID) + require.NoError(t, err) + + got = userutil.ValidatePassword(alice.Password, alice.Salt, oldPassword) + assert.False(t, got, "Old password should stop working") + + got = userutil.ValidatePassword(alice.Password, alice.Salt, newPassword) + assert.True(t, got, "New password should work") + }) + + t.Run("update email but already used", func(t *testing.T) { + // todo + }) + + loginSource := int64(1) + maxRepoCreation := 99 + lastRepoVisibility := true + overLimitStr := strings.Repeat("a", 2050) opts := UpdateUserOptions{ - FullName: overLimitStr, - Website: overLimitStr, - Location: overLimitStr, - Description: overLimitStr, - MaxRepoCreation: 1, + LoginSource: &loginSource, + LoginName: &alice.Name, + + FullName: &overLimitStr, + Website: &overLimitStr, + Location: &overLimitStr, + Description: &overLimitStr, + + MaxRepoCreation: &maxRepoCreation, + LastRepoVisibility: &lastRepoVisibility, + + IsActivated: &lastRepoVisibility, + IsAdmin: &lastRepoVisibility, + AllowGitHook: &lastRepoVisibility, + AllowImportLocal: &lastRepoVisibility, + ProhibitLogin: &lastRepoVisibility, + + Avatar: &overLimitStr, + AvatarEmail: &overLimitStr, } err = db.Update(ctx, alice.ID, opts) require.NoError(t, err) @@ -748,28 +801,34 @@ func usersUpdate(t *testing.T, db *users) { alice, err = db.GetByID(ctx, alice.ID) require.NoError(t, err) - wantStr := strings.Repeat("a", 255) - assert.Equal(t, wantStr, alice.FullName) - assert.Equal(t, wantStr, alice.Website) - assert.Equal(t, wantStr, alice.Location) - assert.Equal(t, wantStr, alice.Description) - assert.Equal(t, 1, alice.MaxRepoCreation) - - // Test empty values - opts = UpdateUserOptions{ - FullName: "Alice John", - Website: "https://gogs.io", + assertValues := func() { + assert.Equal(t, loginSource, alice.LoginSource) + assert.Equal(t, alice.Name, alice.LoginName) + wantStr255 := strings.Repeat("a", 255) + assert.Equal(t, wantStr255, alice.FullName) + assert.Equal(t, wantStr255, alice.Website) + assert.Equal(t, wantStr255, alice.Location) + assert.Equal(t, wantStr255, alice.Description) + assert.Equal(t, maxRepoCreation, alice.MaxRepoCreation) + assert.Equal(t, lastRepoVisibility, alice.LastRepoVisibility) + assert.Equal(t, lastRepoVisibility, alice.IsActive) + assert.Equal(t, lastRepoVisibility, alice.IsAdmin) + assert.Equal(t, lastRepoVisibility, alice.AllowGitHook) + assert.Equal(t, lastRepoVisibility, alice.AllowImportLocal) + assert.Equal(t, lastRepoVisibility, alice.ProhibitLogin) + wantStr2048 := strings.Repeat("a", 2048) + assert.Equal(t, wantStr2048, alice.Avatar) + assert.Equal(t, wantStr255, alice.AvatarEmail) } - err = db.Update(ctx, alice.ID, opts) + assertValues() + + // Test ignored values + err = db.Update(ctx, alice.ID, UpdateUserOptions{}) require.NoError(t, err) alice, err = db.GetByID(ctx, alice.ID) require.NoError(t, err) - assert.Equal(t, opts.FullName, alice.FullName) - assert.Equal(t, opts.Website, alice.Website) - assert.Empty(t, alice.Location) - assert.Empty(t, alice.Description) - assert.Empty(t, alice.MaxRepoCreation) + assertValues() } func usersUseCustomAvatar(t *testing.T, db *users) { |