aboutsummaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
authorJoe Chen <jc@unknwon.io>2022-11-27 19:36:10 +0800
committerGitHub <noreply@github.com>2022-11-27 19:36:10 +0800
commitae20d03aece78fb44dc1caaacfa40c3aa40c7949 (patch)
tree7e7b33f99eae57d8426eeead443276d5cbe0dd5a /internal/db
parent44333afd20a6312b617e0c33a497a4385ba3a250 (diff)
refactor(db): migrate `UpdateUser` off `user.go` (#7267)
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/repo.go26
-rw-r--r--internal/db/user.go45
-rw-r--r--internal/db/user_mail.go22
-rw-r--r--internal/db/users.go138
-rw-r--r--internal/db/users_test.go107
5 files changed, 219 insertions, 119 deletions
diff --git a/internal/db/repo.go b/internal/db/repo.go
index 8abe47ed..896c39b8 100644
--- a/internal/db/repo.go
+++ b/internal/db/repo.go
@@ -34,6 +34,7 @@ import (
"gogs.io/gogs/internal/avatar"
"gogs.io/gogs/internal/conf"
dberrors "gogs.io/gogs/internal/db/errors"
+ "gogs.io/gogs/internal/dbutil"
"gogs.io/gogs/internal/errutil"
"gogs.io/gogs/internal/markup"
"gogs.io/gogs/internal/osutil"
@@ -1112,11 +1113,9 @@ func createRepository(e *xorm.Session, doer, owner *User, repo *Repository) (err
return err
}
- owner.NumRepos++
- // Remember visibility preference.
- owner.LastRepoVisibility = repo.IsPrivate
- if err = updateUser(e, owner); err != nil {
- return fmt.Errorf("updateUser: %v", err)
+ _, err = e.Exec(dbutil.Quote("UPDATE %s SET num_repos = num_repos + 1 WHERE id = ?", "user"), owner.ID)
+ if err != nil {
+ return errors.Wrap(err, "increase owned repository count")
}
// Give access to all members in owner team.
@@ -1222,8 +1221,17 @@ func CreateRepository(doer, owner *User, opts CreateRepoOptionsLegacy) (_ *Repos
return nil, fmt.Errorf("CreateRepository 'git update-server-info': %s", stderr)
}
}
+ if err = sess.Commit(); err != nil {
+ return nil, err
+ }
- return repo, sess.Commit()
+ // Remember visibility preference
+ err = Users.Update(context.TODO(), owner.ID, UpdateUserOptions{LastRepoVisibility: &repo.IsPrivate})
+ if err != nil {
+ return nil, errors.Wrap(err, "update user")
+ }
+
+ return repo, nil
}
func countRepositories(userID int64, private bool) int64 {
@@ -2544,6 +2552,12 @@ func ForkRepository(doer, owner *User, baseRepo *Repository, name, desc string)
return nil, fmt.Errorf("Commit: %v", err)
}
+ // Remember visibility preference
+ err = Users.Update(context.TODO(), owner.ID, UpdateUserOptions{LastRepoVisibility: &repo.IsPrivate})
+ if err != nil {
+ return nil, errors.Wrap(err, "update user")
+ }
+
if err = repo.UpdateSize(); err != nil {
log.Error("UpdateSize [repo_id: %d]: %v", repo.ID, err)
}
diff --git a/internal/db/user.go b/internal/db/user.go
index cc2de95b..c2091e6f 100644
--- a/internal/db/user.go
+++ b/internal/db/user.go
@@ -19,10 +19,7 @@ import (
"gogs.io/gogs/internal/conf"
"gogs.io/gogs/internal/db/errors"
- "gogs.io/gogs/internal/errutil"
"gogs.io/gogs/internal/repoutil"
- "gogs.io/gogs/internal/strutil"
- "gogs.io/gogs/internal/tool"
"gogs.io/gogs/internal/userutil"
)
@@ -42,48 +39,6 @@ func (u *User) AfterSet(colName string, _ xorm.Cell) {
}
}
-// TODO(unknwon): Update call sites to use refactored methods and delete this one.
-func updateUser(e Engine, u *User) error {
- // Organization does not need email
- if !u.IsOrganization() {
- u.Email = strings.ToLower(u.Email)
- has, err := e.Where("id!=?", u.ID).And("type=?", u.Type).And("email=?", u.Email).Get(new(User))
- if err != nil {
- return err
- } else if has {
- return ErrEmailAlreadyUsed{args: errutil.Args{"email": u.Email}}
- }
-
- if u.AvatarEmail == "" {
- u.AvatarEmail = u.Email
- }
- u.Avatar = tool.HashEmail(u.AvatarEmail)
- }
-
- u.LowerName = strings.ToLower(u.Name)
- u.Location = strutil.Truncate(u.Location, 255)
- u.Website = strutil.Truncate(u.Website, 255)
- u.Description = strutil.Truncate(u.Description, 255)
-
- _, err := e.ID(u.ID).AllCols().Update(u)
- return err
-}
-
-// TODO(unknwon): Refactoring together with methods that do updates.
-func (u *User) BeforeUpdate() {
- if u.MaxRepoCreation < -1 {
- u.MaxRepoCreation = -1
- }
- u.UpdatedUnix = time.Now().Unix()
-}
-
-// UpdateUser updates user's information.
-//
-// TODO(unknwon): Update call sites to use refactored methods and delete this one.
-func UpdateUser(u *User) error {
- return updateUser(x, u)
-}
-
// deleteBeans deletes all given beans, beans should contain delete conditions.
func deleteBeans(e Engine, beans ...interface{}) (err error) {
for i := range beans {
diff --git a/internal/db/user_mail.go b/internal/db/user_mail.go
index eae79af7..d657faa0 100644
--- a/internal/db/user_mail.go
+++ b/internal/db/user_mail.go
@@ -11,7 +11,6 @@ import (
"gogs.io/gogs/internal/db/errors"
"gogs.io/gogs/internal/errutil"
- "gogs.io/gogs/internal/userutil"
)
// EmailAddresses is the list of all email addresses of a user. Can contain the
@@ -120,28 +119,11 @@ func AddEmailAddresses(emails []*EmailAddress) error {
}
func (email *EmailAddress) Activate() error {
- user, err := Users.GetByID(context.TODO(), email.UserID)
- if err != nil {
- return err
- }
- if user.Rands, err = userutil.RandomSalt(); err != nil {
- return err
- }
-
- sess := x.NewSession()
- defer sess.Close()
- if err = sess.Begin(); err != nil {
- return err
- }
-
email.IsActivated = true
- if _, err := sess.ID(email.ID).AllCols().Update(email); err != nil {
- return err
- } else if err = updateUser(sess, user); err != nil {
+ if _, err := x.ID(email.ID).AllCols().Update(email); err != nil {
return err
}
-
- return sess.Commit()
+ return Users.Update(context.TODO(), email.UserID, UpdateUserOptions{GenerateNewRands: true})
}
func DeleteEmailAddress(email *EmailAddress) (err error) {
diff --git a/internal/db/users.go b/internal/db/users.go
index 9e5687d4..65326a95 100644
--- a/internal/db/users.go
+++ b/internal/db/users.go
@@ -87,8 +87,7 @@ type UsersStore interface {
// Results are paginated by given page and page size, and sorted by the time of
// follow in descending order.
ListFollowings(ctx context.Context, userID int64, page, pageSize int) ([]*User, error)
- // Update updates all fields for the given user, all values are persisted as-is
- // (i.e. empty values would overwrite/wipe out existing values).
+ // Update updates fields for the given user.
Update(ctx context.Context, userID int64, opts UpdateUserOptions) error
// UseCustomAvatar uses the given avatar as the user custom avatar.
UseCustomAvatar(ctx context.Context, userID int64, avatar []byte) error
@@ -324,8 +323,10 @@ type ErrEmailAlreadyUsed struct {
args errutil.Args
}
+// IsErrEmailAlreadyUsed returns true if the underlying error has the type
+// ErrEmailAlreadyUsed.
func IsErrEmailAlreadyUsed(err error) bool {
- _, ok := err.(ErrEmailAlreadyUsed)
+ _, ok := errors.Cause(err).(ErrEmailAlreadyUsed)
return ok
}
@@ -415,8 +416,10 @@ type ErrUserNotExist struct {
args errutil.Args
}
+// IsErrUserNotExist returns true if the underlying error has the type
+// ErrUserNotExist.
func IsErrUserNotExist(err error) bool {
- _, ok := err.(ErrUserNotExist)
+ _, ok := errors.Cause(err).(ErrUserNotExist)
return ok
}
@@ -549,30 +552,117 @@ func (db *users) ListFollowings(ctx context.Context, userID int64, page, pageSiz
}
type UpdateUserOptions struct {
- FullName string
- Website string
- Location string
- Description string
-
- MaxRepoCreation int
+ LoginSource *int64
+ LoginName *string
+
+ Password *string
+ // GenerateNewRands indicates whether to force generate new rands for the user.
+ GenerateNewRands bool
+
+ FullName *string
+ Email *string
+ Website *string
+ Location *string
+ Description *string
+
+ MaxRepoCreation *int
+ LastRepoVisibility *bool
+
+ IsActivated *bool
+ IsAdmin *bool
+ AllowGitHook *bool
+ AllowImportLocal *bool
+ ProhibitLogin *bool
+
+ Avatar *string
+ AvatarEmail *string
}
func (db *users) Update(ctx context.Context, userID int64, opts UpdateUserOptions) error {
- if opts.MaxRepoCreation < -1 {
- opts.MaxRepoCreation = -1
+ updates := map[string]any{
+ "updated_unix": db.NowFunc().Unix(),
}
- return db.WithContext(ctx).
- Model(&User{}).
- Where("id = ?", userID).
- Updates(map[string]any{
- "full_name": strutil.Truncate(opts.FullName, 255),
- "website": strutil.Truncate(opts.Website, 255),
- "location": strutil.Truncate(opts.Location, 255),
- "description": strutil.Truncate(opts.Description, 255),
- "max_repo_creation": opts.MaxRepoCreation,
- "updated_unix": db.NowFunc().Unix(),
- }).
- Error
+
+ if opts.LoginSource != nil {
+ updates["login_source"] = *opts.LoginSource
+ }
+ if opts.LoginName != nil {
+ updates["login_name"] = *opts.LoginName
+ }
+
+ if opts.Password != nil {
+ salt, err := userutil.RandomSalt()
+ if err != nil {
+ return errors.Wrap(err, "generate salt")
+ }
+ updates["salt"] = salt
+ updates["passwd"] = userutil.EncodePassword(*opts.Password, salt)
+ opts.GenerateNewRands = true
+ }
+ if opts.GenerateNewRands {
+ rands, err := userutil.RandomSalt()
+ if err != nil {
+ return errors.Wrap(err, "generate rands")
+ }
+ updates["rands"] = rands
+ }
+
+ if opts.FullName != nil {
+ updates["full_name"] = strutil.Truncate(*opts.FullName, 255)
+ }
+ if opts.Email != nil {
+ _, err := db.GetByEmail(ctx, *opts.Email)
+ if err == nil {
+ return ErrEmailAlreadyUsed{args: errutil.Args{"email": *opts.Email}}
+ } else if !IsErrUserNotExist(err) {
+ return errors.Wrap(err, "check email")
+ }
+ updates["email"] = *opts.Email
+ }
+ if opts.Website != nil {
+ updates["website"] = strutil.Truncate(*opts.Website, 255)
+ }
+ if opts.Location != nil {
+ updates["location"] = strutil.Truncate(*opts.Location, 255)
+ }
+ if opts.Description != nil {
+ updates["description"] = strutil.Truncate(*opts.Description, 255)
+ }
+
+ if opts.MaxRepoCreation != nil {
+ if *opts.MaxRepoCreation < -1 {
+ *opts.MaxRepoCreation = -1
+ }
+ updates["max_repo_creation"] = *opts.MaxRepoCreation
+ }
+ if opts.LastRepoVisibility != nil {
+ updates["last_repo_visibility"] = *opts.LastRepoVisibility
+ }
+
+ if opts.IsActivated != nil {
+ updates["is_active"] = *opts.IsActivated
+ }
+ if opts.IsAdmin != nil {
+ updates["is_admin"] = *opts.IsAdmin
+ }
+ if opts.AllowGitHook != nil {
+ updates["allow_git_hook"] = *opts.AllowGitHook
+ }
+ if opts.AllowImportLocal != nil {
+ updates["allow_import_local"] = *opts.AllowImportLocal
+ }
+ if opts.ProhibitLogin != nil {
+ updates["prohibit_login"] = *opts.ProhibitLogin
+ }
+
+ if opts.Avatar != nil {
+ updates["avatar"] = strutil.Truncate(*opts.Avatar, 2048)
+ }
+ if opts.AvatarEmail != nil {
+ updates["avatar_email"] = strutil.Truncate(*opts.AvatarEmail, 255)
+ }
+
+ return db.WithContext(ctx).Model(&User{}).Where("id = ?", userID).Updates(updates).Error
}
func (db *users) UseCustomAvatar(ctx context.Context, userID int64, avatar []byte) error {
diff --git a/internal/db/users_test.go b/internal/db/users_test.go
index 2ce0efd9..4cf5aaf4 100644
--- a/internal/db/users_test.go
+++ b/internal/db/users_test.go
@@ -731,16 +731,69 @@ func usersListFollowings(t *testing.T, db *users) {
func usersUpdate(t *testing.T, db *users) {
ctx := context.Background()
- alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
+ const oldPassword = "Password"
+ alice, err := db.Create(
+ ctx,
+ "alice",
+ "alice@example.com",
+ CreateUserOptions{
+ FullName: "FullName",
+ Password: oldPassword,
+ LoginSource: 9,
+ LoginName: "LoginName",
+ Location: "Location",
+ Website: "Website",
+ Activated: false,
+ Admin: false,
+ },
+ )
require.NoError(t, err)
- overLimitStr := strings.Repeat("a", 300)
+ t.Run("update password", func(t *testing.T) {
+ got := userutil.ValidatePassword(alice.Password, alice.Salt, oldPassword)
+ require.True(t, got)
+
+ newPassword := "NewPassword"
+ err = db.Update(ctx, alice.ID, UpdateUserOptions{Password: &newPassword})
+ require.NoError(t, err)
+ alice, err = db.GetByID(ctx, alice.ID)
+ require.NoError(t, err)
+
+ got = userutil.ValidatePassword(alice.Password, alice.Salt, oldPassword)
+ assert.False(t, got, "Old password should stop working")
+
+ got = userutil.ValidatePassword(alice.Password, alice.Salt, newPassword)
+ assert.True(t, got, "New password should work")
+ })
+
+ t.Run("update email but already used", func(t *testing.T) {
+ // todo
+ })
+
+ loginSource := int64(1)
+ maxRepoCreation := 99
+ lastRepoVisibility := true
+ overLimitStr := strings.Repeat("a", 2050)
opts := UpdateUserOptions{
- FullName: overLimitStr,
- Website: overLimitStr,
- Location: overLimitStr,
- Description: overLimitStr,
- MaxRepoCreation: 1,
+ LoginSource: &loginSource,
+ LoginName: &alice.Name,
+
+ FullName: &overLimitStr,
+ Website: &overLimitStr,
+ Location: &overLimitStr,
+ Description: &overLimitStr,
+
+ MaxRepoCreation: &maxRepoCreation,
+ LastRepoVisibility: &lastRepoVisibility,
+
+ IsActivated: &lastRepoVisibility,
+ IsAdmin: &lastRepoVisibility,
+ AllowGitHook: &lastRepoVisibility,
+ AllowImportLocal: &lastRepoVisibility,
+ ProhibitLogin: &lastRepoVisibility,
+
+ Avatar: &overLimitStr,
+ AvatarEmail: &overLimitStr,
}
err = db.Update(ctx, alice.ID, opts)
require.NoError(t, err)
@@ -748,28 +801,34 @@ func usersUpdate(t *testing.T, db *users) {
alice, err = db.GetByID(ctx, alice.ID)
require.NoError(t, err)
- wantStr := strings.Repeat("a", 255)
- assert.Equal(t, wantStr, alice.FullName)
- assert.Equal(t, wantStr, alice.Website)
- assert.Equal(t, wantStr, alice.Location)
- assert.Equal(t, wantStr, alice.Description)
- assert.Equal(t, 1, alice.MaxRepoCreation)
-
- // Test empty values
- opts = UpdateUserOptions{
- FullName: "Alice John",
- Website: "https://gogs.io",
+ assertValues := func() {
+ assert.Equal(t, loginSource, alice.LoginSource)
+ assert.Equal(t, alice.Name, alice.LoginName)
+ wantStr255 := strings.Repeat("a", 255)
+ assert.Equal(t, wantStr255, alice.FullName)
+ assert.Equal(t, wantStr255, alice.Website)
+ assert.Equal(t, wantStr255, alice.Location)
+ assert.Equal(t, wantStr255, alice.Description)
+ assert.Equal(t, maxRepoCreation, alice.MaxRepoCreation)
+ assert.Equal(t, lastRepoVisibility, alice.LastRepoVisibility)
+ assert.Equal(t, lastRepoVisibility, alice.IsActive)
+ assert.Equal(t, lastRepoVisibility, alice.IsAdmin)
+ assert.Equal(t, lastRepoVisibility, alice.AllowGitHook)
+ assert.Equal(t, lastRepoVisibility, alice.AllowImportLocal)
+ assert.Equal(t, lastRepoVisibility, alice.ProhibitLogin)
+ wantStr2048 := strings.Repeat("a", 2048)
+ assert.Equal(t, wantStr2048, alice.Avatar)
+ assert.Equal(t, wantStr255, alice.AvatarEmail)
}
- err = db.Update(ctx, alice.ID, opts)
+ assertValues()
+
+ // Test ignored values
+ err = db.Update(ctx, alice.ID, UpdateUserOptions{})
require.NoError(t, err)
alice, err = db.GetByID(ctx, alice.ID)
require.NoError(t, err)
- assert.Equal(t, opts.FullName, alice.FullName)
- assert.Equal(t, opts.Website, alice.Website)
- assert.Empty(t, alice.Location)
- assert.Empty(t, alice.Description)
- assert.Empty(t, alice.MaxRepoCreation)
+ assertValues()
}
func usersUseCustomAvatar(t *testing.T, db *users) {