aboutsummaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
authorJoe Chen <jc@unknwon.io>2022-11-27 15:19:44 +0800
committerGitHub <noreply@github.com>2022-11-27 15:19:44 +0800
commit13099a7e4fe7565bb858646d42d1fba817cb06cc (patch)
treeac932d0f5df9f14b0f9408c32f699ae7167edc25 /internal/db
parenta7dbc970dfaac9f04addf05da97bb0aa29083e37 (diff)
refactor(db): add `Users.Update` (#7263)
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/email_addresses.go20
-rw-r--r--internal/db/email_addresses_test.go17
-rw-r--r--internal/db/org.go2
-rw-r--r--internal/db/user.go10
-rw-r--r--internal/db/users.go58
-rw-r--r--internal/db/users_test.go116
6 files changed, 190 insertions, 33 deletions
diff --git a/internal/db/email_addresses.go b/internal/db/email_addresses.go
index 2929ffd0..8cef705e 100644
--- a/internal/db/email_addresses.go
+++ b/internal/db/email_addresses.go
@@ -8,6 +8,7 @@ import (
"context"
"fmt"
+ "github.com/pkg/errors"
"gorm.io/gorm"
"gogs.io/gogs/internal/errutil"
@@ -17,9 +18,11 @@ import (
//
// NOTE: All methods are sorted in alphabetical order.
type EmailAddressesStore interface {
- // GetByEmail returns the email address with given email. It may return
- // unverified email addresses and returns ErrEmailNotExist when not found.
- GetByEmail(ctx context.Context, email string) (*EmailAddress, error)
+ // GetByEmail returns the email address with given email. If `needsActivated` is
+ // true, only activated email will be returned, otherwise, it may return
+ // inactivated email addresses. It returns ErrEmailNotExist when no qualified
+ // email is not found.
+ GetByEmail(ctx context.Context, email string, needsActivated bool) (*EmailAddress, error)
}
var EmailAddresses EmailAddressesStore
@@ -43,7 +46,7 @@ type ErrEmailNotExist struct {
}
func IsErrEmailAddressNotExist(err error) bool {
- _, ok := err.(ErrEmailNotExist)
+ _, ok := errors.Cause(err).(ErrEmailNotExist)
return ok
}
@@ -55,9 +58,14 @@ func (ErrEmailNotExist) NotFound() bool {
return true
}
-func (db *emailAddresses) GetByEmail(ctx context.Context, email string) (*EmailAddress, error) {
+func (db *emailAddresses) GetByEmail(ctx context.Context, email string, needsActivated bool) (*EmailAddress, error) {
+ tx := db.WithContext(ctx).Where("email = ?", email)
+ if needsActivated {
+ tx = tx.Where("is_activated = ?", true)
+ }
+
emailAddress := new(EmailAddress)
- err := db.WithContext(ctx).Where("email = ?", email).First(emailAddress).Error
+ err := tx.First(emailAddress).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, ErrEmailNotExist{
diff --git a/internal/db/email_addresses_test.go b/internal/db/email_addresses_test.go
index b54ffbad..f55db549 100644
--- a/internal/db/email_addresses_test.go
+++ b/internal/db/email_addresses_test.go
@@ -49,7 +49,7 @@ func emailAddressesGetByEmail(t *testing.T, db *emailAddresses) {
ctx := context.Background()
const testEmail = "alice@example.com"
- _, err := db.GetByEmail(ctx, testEmail)
+ _, err := db.GetByEmail(ctx, testEmail, false)
wantErr := ErrEmailNotExist{
args: errutil.Args{
"email": testEmail,
@@ -58,9 +58,20 @@ func emailAddressesGetByEmail(t *testing.T, db *emailAddresses) {
assert.Equal(t, wantErr, err)
// TODO: Use EmailAddresses.Create to replace SQL hack when the method is available.
- err = db.Exec(`INSERT INTO email_address (uid, email) VALUES (1, ?)`, testEmail).Error
+ err = db.Exec(`INSERT INTO email_address (uid, email, is_activated) VALUES (1, ?, FALSE)`, testEmail).Error
require.NoError(t, err)
- got, err := db.GetByEmail(ctx, testEmail)
+ got, err := db.GetByEmail(ctx, testEmail, false)
+ require.NoError(t, err)
+ assert.Equal(t, testEmail, got.Email)
+
+ // Should not return if we only want activated emails
+ _, err = db.GetByEmail(ctx, testEmail, true)
+ assert.Equal(t, wantErr, err)
+
+ // TODO: Use EmailAddresses.MarkActivated to replace SQL hack when the method is available.
+ err = db.Exec(`UPDATE email_address SET is_activated = TRUE WHERE email = ?`, testEmail).Error
+ require.NoError(t, err)
+ got, err = db.GetByEmail(ctx, testEmail, true)
require.NoError(t, err)
assert.Equal(t, testEmail, got.Email)
}
diff --git a/internal/db/org.go b/internal/db/org.go
index b11123c6..cb68451a 100644
--- a/internal/db/org.go
+++ b/internal/db/org.go
@@ -106,7 +106,7 @@ func CreateOrganization(org, owner *User) (err error) {
return err
}
- if Users.IsUsernameUsed(context.TODO(), org.Name) {
+ if Users.IsUsernameUsed(context.TODO(), org.Name, 0) {
return ErrUserAlreadyExist{
args: errutil.Args{
"name": org.Name,
diff --git a/internal/db/user.go b/internal/db/user.go
index 24af0ec1..cc2de95b 100644
--- a/internal/db/user.go
+++ b/internal/db/user.go
@@ -21,6 +21,7 @@ import (
"gogs.io/gogs/internal/db/errors"
"gogs.io/gogs/internal/errutil"
"gogs.io/gogs/internal/repoutil"
+ "gogs.io/gogs/internal/strutil"
"gogs.io/gogs/internal/tool"
"gogs.io/gogs/internal/userutil"
)
@@ -41,6 +42,7 @@ func (u *User) AfterSet(colName string, _ xorm.Cell) {
}
}
+// TODO(unknwon): Update call sites to use refactored methods and delete this one.
func updateUser(e Engine, u *User) error {
// Organization does not need email
if !u.IsOrganization() {
@@ -59,9 +61,9 @@ func updateUser(e Engine, u *User) error {
}
u.LowerName = strings.ToLower(u.Name)
- u.Location = tool.TruncateString(u.Location, 255)
- u.Website = tool.TruncateString(u.Website, 255)
- u.Description = tool.TruncateString(u.Description, 255)
+ u.Location = strutil.Truncate(u.Location, 255)
+ u.Website = strutil.Truncate(u.Website, 255)
+ u.Description = strutil.Truncate(u.Description, 255)
_, err := e.ID(u.ID).AllCols().Update(u)
return err
@@ -76,6 +78,8 @@ func (u *User) BeforeUpdate() {
}
// UpdateUser updates user's information.
+//
+// TODO(unknwon): Update call sites to use refactored methods and delete this one.
func UpdateUser(u *User) error {
return updateUser(x, u)
}
diff --git a/internal/db/users.go b/internal/db/users.go
index 01688dab..2b599597 100644
--- a/internal/db/users.go
+++ b/internal/db/users.go
@@ -72,8 +72,10 @@ type UsersStore interface {
GetByUsername(ctx context.Context, username string) (*User, error)
// HasForkedRepository returns true if the user has forked given repository.
HasForkedRepository(ctx context.Context, userID, repoID int64) bool
- // IsUsernameUsed returns true if the given username has been used.
- IsUsernameUsed(ctx context.Context, username string) bool
+ // IsUsernameUsed returns true if the given username has been used other than
+ // the excluded user (a non-positive ID effectively meaning check against all
+ // users).
+ IsUsernameUsed(ctx context.Context, username string, excludeUserId int64) bool
// List returns a list of users. Results are paginated by given page and page
// size, and sorted by primary key (id) in ascending order.
List(ctx context.Context, page, pageSize int) ([]*User, error)
@@ -85,6 +87,9 @@ type UsersStore interface {
// Results are paginated by given page and page size, and sorted by the time of
// follow in descending order.
ListFollowings(ctx context.Context, userID int64, page, pageSize int) ([]*User, error)
+ // Update updates all fields for the given user, all values are persisted as-is
+ // (i.e. empty values would overwrite/wipe out existing values).
+ Update(ctx context.Context, userID int64, opts UpdateUserOptions) error
// UseCustomAvatar uses the given avatar as the user custom avatar.
UseCustomAvatar(ctx context.Context, userID int64, avatar []byte) error
}
@@ -201,7 +206,7 @@ func (db *users) ChangeUsername(ctx context.Context, userID int64, newUsername s
return err
}
- if db.IsUsernameUsed(ctx, newUsername) {
+ if db.IsUsernameUsed(ctx, newUsername, userID) {
return ErrUserAlreadyExist{
args: errutil.Args{
"name": newUsername,
@@ -226,6 +231,11 @@ func (db *users) ChangeUsername(ctx context.Context, userID int64, newUsername s
return errors.Wrap(err, "update user name")
}
+ // Stop here if it's just a case-change of the username
+ if strings.EqualFold(user.Name, newUsername) {
+ return nil
+ }
+
// Update all references to the user name in pull requests
err = tx.Model(&PullRequest{}).
Where("head_user_name = ?", user.LowerName).
@@ -328,7 +338,7 @@ func (db *users) Create(ctx context.Context, username, email string, opts Create
return nil, err
}
- if db.IsUsernameUsed(ctx, username) {
+ if db.IsUsernameUsed(ctx, username, 0) {
return nil, ErrUserAlreadyExist{
args: errutil.Args{
"name": username,
@@ -428,18 +438,13 @@ func (db *users) GetByEmail(ctx context.Context, email string) (*User, error) {
}
// Otherwise, check activated email addresses
- emailAddress := new(EmailAddress)
- err = db.WithContext(ctx).
- Where("email = ? AND is_activated = ?", email, true).
- First(emailAddress).
- Error
+ emailAddress, err := NewEmailAddressesStore(db.DB).GetByEmail(ctx, email, true)
if err != nil {
- if err == gorm.ErrRecordNotFound {
+ if IsErrEmailAddressNotExist(err) {
return nil, ErrUserNotExist{args: errutil.Args{"email": email}}
}
return nil, err
}
-
return db.GetByID(ctx, emailAddress.UserID)
}
@@ -473,13 +478,13 @@ func (db *users) HasForkedRepository(ctx context.Context, userID, repoID int64)
return count > 0
}
-func (db *users) IsUsernameUsed(ctx context.Context, username string) bool {
+func (db *users) IsUsernameUsed(ctx context.Context, username string, excludeUserId int64) bool {
if username == "" {
return false
}
return db.WithContext(ctx).
Select("id").
- Where("lower_name = ?", strings.ToLower(username)).
+ Where("lower_name = ? AND id != ?", strings.ToLower(username), excludeUserId).
First(&User{}).
Error != gorm.ErrRecordNotFound
}
@@ -534,6 +539,33 @@ func (db *users) ListFollowings(ctx context.Context, userID int64, page, pageSiz
Error
}
+type UpdateUserOptions struct {
+ FullName string
+ Website string
+ Location string
+ Description string
+
+ MaxRepoCreation int
+}
+
+func (db *users) Update(ctx context.Context, userID int64, opts UpdateUserOptions) error {
+ if opts.MaxRepoCreation < -1 {
+ opts.MaxRepoCreation = -1
+ }
+ return db.WithContext(ctx).
+ Model(&User{}).
+ Where("id = ?", userID).
+ Updates(map[string]any{
+ "full_name": strutil.Truncate(opts.FullName, 255),
+ "website": strutil.Truncate(opts.Website, 255),
+ "location": strutil.Truncate(opts.Location, 255),
+ "description": strutil.Truncate(opts.Description, 255),
+ "max_repo_creation": opts.MaxRepoCreation,
+ "updated_unix": db.NowFunc().Unix(),
+ }).
+ Error
+}
+
func (db *users) UseCustomAvatar(ctx context.Context, userID int64, avatar []byte) error {
err := userutil.SaveAvatar(userID, avatar)
if err != nil {
diff --git a/internal/db/users_test.go b/internal/db/users_test.go
index 26d5ec7c..2ce0efd9 100644
--- a/internal/db/users_test.go
+++ b/internal/db/users_test.go
@@ -104,6 +104,7 @@ func TestUsers(t *testing.T) {
{"List", usersList},
{"ListFollowers", usersListFollowers},
{"ListFollowings", usersListFollowings},
+ {"Update", usersUpdate},
{"UseCustomAvatar", usersUseCustomAvatar},
} {
t.Run(tc.name, func(t *testing.T) {
@@ -241,10 +242,20 @@ func usersChangeUsername(t *testing.T, db *users) {
})
t.Run("name already exists", func(t *testing.T) {
- err := db.ChangeUsername(ctx, alice.ID, alice.Name)
+ bob, err := db.Create(
+ ctx,
+ "bob",
+ "bob@example.com",
+ CreateUserOptions{
+ Activated: true,
+ },
+ )
+ require.NoError(t, err)
+
+ err = db.ChangeUsername(ctx, alice.ID, bob.Name)
wantErr := ErrUserAlreadyExist{
args: errutil.Args{
- "name": alice.Name,
+ "name": bob.Name,
},
}
assert.Equal(t, wantErr, err)
@@ -303,7 +314,7 @@ func usersChangeUsername(t *testing.T, db *users) {
assert.Equal(t, headUserName, alice.Name)
var updatedUnix int64
- err = db.Model(&User{}).Select("updated_unix").Row().Scan(&updatedUnix)
+ err = db.Model(&User{}).Select("updated_unix").Where("id = ?", alice.ID).Row().Scan(&updatedUnix)
require.NoError(t, err)
assert.Equal(t, int64(0), updatedUnix)
@@ -329,6 +340,13 @@ func usersChangeUsername(t *testing.T, db *users) {
require.NoError(t, err)
assert.Equal(t, newUsername, alice.Name)
assert.Equal(t, db.NowFunc().Unix(), alice.UpdatedUnix)
+
+ // Change the cases of the username should just be fine
+ err = db.ChangeUsername(ctx, alice.ID, strings.ToUpper(newUsername))
+ require.NoError(t, err)
+ alice, err = db.GetByID(ctx, alice.ID)
+ require.NoError(t, err)
+ assert.Equal(t, strings.ToUpper(newUsername), alice.Name)
}
func usersCount(t *testing.T, db *users) {
@@ -561,10 +579,50 @@ func usersIsUsernameUsed(t *testing.T, db *users) {
alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
require.NoError(t, err)
- got := db.IsUsernameUsed(ctx, alice.Name)
- assert.True(t, got)
- got = db.IsUsernameUsed(ctx, "bob")
- assert.False(t, got)
+ tests := []struct {
+ name string
+ username string
+ excludeUserID int64
+ want bool
+ }{
+ {
+ name: "no change",
+ username: alice.Name,
+ excludeUserID: alice.ID,
+ want: false,
+ },
+ {
+ name: "change case",
+ username: strings.ToUpper(alice.Name),
+ excludeUserID: alice.ID,
+ want: false,
+ },
+ {
+ name: "not used",
+ username: "bob",
+ excludeUserID: alice.ID,
+ want: false,
+ },
+ {
+ name: "not used when not excluded",
+ username: "bob",
+ excludeUserID: 0,
+ want: false,
+ },
+
+ {
+ name: "used when not excluded",
+ username: alice.Name,
+ excludeUserID: 0,
+ want: true,
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ got := db.IsUsernameUsed(ctx, test.username, test.excludeUserID)
+ assert.Equal(t, test.want, got)
+ })
+ }
}
func usersList(t *testing.T, db *users) {
@@ -670,6 +728,50 @@ func usersListFollowings(t *testing.T, db *users) {
assert.Equal(t, alice.ID, got[0].ID)
}
+func usersUpdate(t *testing.T, db *users) {
+ ctx := context.Background()
+
+ alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{})
+ require.NoError(t, err)
+
+ overLimitStr := strings.Repeat("a", 300)
+ opts := UpdateUserOptions{
+ FullName: overLimitStr,
+ Website: overLimitStr,
+ Location: overLimitStr,
+ Description: overLimitStr,
+ MaxRepoCreation: 1,
+ }
+ err = db.Update(ctx, alice.ID, opts)
+ require.NoError(t, err)
+
+ alice, err = db.GetByID(ctx, alice.ID)
+ require.NoError(t, err)
+
+ wantStr := strings.Repeat("a", 255)
+ assert.Equal(t, wantStr, alice.FullName)
+ assert.Equal(t, wantStr, alice.Website)
+ assert.Equal(t, wantStr, alice.Location)
+ assert.Equal(t, wantStr, alice.Description)
+ assert.Equal(t, 1, alice.MaxRepoCreation)
+
+ // Test empty values
+ opts = UpdateUserOptions{
+ FullName: "Alice John",
+ Website: "https://gogs.io",
+ }
+ err = db.Update(ctx, alice.ID, opts)
+ require.NoError(t, err)
+
+ alice, err = db.GetByID(ctx, alice.ID)
+ require.NoError(t, err)
+ assert.Equal(t, opts.FullName, alice.FullName)
+ assert.Equal(t, opts.Website, alice.Website)
+ assert.Empty(t, alice.Location)
+ assert.Empty(t, alice.Description)
+ assert.Empty(t, alice.MaxRepoCreation)
+}
+
func usersUseCustomAvatar(t *testing.T, db *users) {
ctx := context.Background()