diff options
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/access_tokens.go | 2 | ||||
-rw-r--r-- | internal/db/actions.go | 4 | ||||
-rw-r--r-- | internal/db/db.go | 3 | ||||
-rw-r--r-- | internal/db/email_addresses.go | 2 | ||||
-rw-r--r-- | internal/db/follows.go | 127 | ||||
-rw-r--r-- | internal/db/follows_test.go | 122 | ||||
-rw-r--r-- | internal/db/lfs.go | 2 | ||||
-rw-r--r-- | internal/db/login_source_files.go | 2 | ||||
-rw-r--r-- | internal/db/login_sources.go | 2 | ||||
-rw-r--r-- | internal/db/org_users.go | 38 | ||||
-rw-r--r-- | internal/db/org_users_test.go | 63 | ||||
-rw-r--r-- | internal/db/orgs.go | 10 | ||||
-rw-r--r-- | internal/db/orgs_test.go | 19 | ||||
-rw-r--r-- | internal/db/perms.go | 2 | ||||
-rw-r--r-- | internal/db/public_keys.go | 2 | ||||
-rw-r--r-- | internal/db/repos.go | 59 | ||||
-rw-r--r-- | internal/db/repos_test.go | 65 | ||||
-rw-r--r-- | internal/db/two_factors.go | 2 | ||||
-rw-r--r-- | internal/db/users.go | 176 | ||||
-rw-r--r-- | internal/db/users_test.go | 116 | ||||
-rw-r--r-- | internal/db/watches.go | 77 | ||||
-rw-r--r-- | internal/db/watches_test.go | 88 |
22 files changed, 367 insertions, 616 deletions
diff --git a/internal/db/access_tokens.go b/internal/db/access_tokens.go index 9dab65d4..825cfa87 100644 --- a/internal/db/access_tokens.go +++ b/internal/db/access_tokens.go @@ -18,8 +18,6 @@ import ( ) // AccessTokensStore is the persistent interface for access tokens. -// -// NOTE: All methods are sorted in alphabetical order. type AccessTokensStore interface { // Create creates a new access token and persist to database. It returns // ErrAccessTokenAlreadyExist when an access token with same name already exists diff --git a/internal/db/actions.go b/internal/db/actions.go index 72d3b6ac..48d080b3 100644 --- a/internal/db/actions.go +++ b/internal/db/actions.go @@ -29,8 +29,6 @@ import ( ) // ActionsStore is the persistent interface for actions. -// -// NOTE: All methods are sorted in alphabetical order. type ActionsStore interface { // CommitRepo creates actions for pushing commits to the repository. An action // with the type ActionDeleteBranch is created if the push deletes a branch; an @@ -166,7 +164,7 @@ func (db *actions) ListByUser(ctx context.Context, userID, actorID, afterID int6 // notifyWatchers creates rows in action table for watchers who are able to see the action. func (db *actions) notifyWatchers(ctx context.Context, act *Action) error { - watches, err := NewWatchesStore(db.DB).ListByRepo(ctx, act.RepoID) + watches, err := NewReposStore(db.DB).ListWatches(ctx, act.RepoID) if err != nil { return errors.Wrap(err, "list watches") } diff --git a/internal/db/db.go b/internal/db/db.go index 2e883fc0..20573334 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -122,16 +122,13 @@ func Init(w logger.Writer) (*gorm.DB, error) { AccessTokens = &accessTokens{DB: db} Actions = NewActionsStore(db) EmailAddresses = NewEmailAddressesStore(db) - Follows = NewFollowsStore(db) LoginSources = &loginSources{DB: db, files: sourceFiles} LFS = &lfs{DB: db} Orgs = NewOrgsStore(db) - OrgUsers = NewOrgUsersStore(db) Perms = NewPermsStore(db) Repos = NewReposStore(db) TwoFactors = &twoFactors{DB: db} Users = NewUsersStore(db) - Watches = NewWatchesStore(db) return db, nil } diff --git a/internal/db/email_addresses.go b/internal/db/email_addresses.go index 4f30a898..d27b926d 100644 --- a/internal/db/email_addresses.go +++ b/internal/db/email_addresses.go @@ -15,8 +15,6 @@ import ( ) // EmailAddressesStore is the persistent interface for email addresses. -// -// NOTE: All methods are sorted in alphabetical order. type EmailAddressesStore interface { // GetByEmail returns the email address with given email. If `needsActivated` is // true, only activated email will be returned, otherwise, it may return diff --git a/internal/db/follows.go b/internal/db/follows.go deleted file mode 100644 index bf50042a..00000000 --- a/internal/db/follows.go +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright 2022 The Gogs Authors. All rights reserved. -// Use of this source code is governed by a MIT-style -// license that can be found in the LICENSE file. - -package db - -import ( - "context" - - "github.com/pkg/errors" - "gorm.io/gorm" -) - -// FollowsStore is the persistent interface for user follows. -// -// NOTE: All methods are sorted in alphabetical order. -type FollowsStore interface { - // Follow marks the user to follow the other user. - Follow(ctx context.Context, userID, followID int64) error - // IsFollowing returns true if the user is following the other user. - IsFollowing(ctx context.Context, userID, followID int64) bool - // Unfollow removes the mark the user to follow the other user. - Unfollow(ctx context.Context, userID, followID int64) error -} - -var Follows FollowsStore - -var _ FollowsStore = (*follows)(nil) - -type follows struct { - *gorm.DB -} - -// NewFollowsStore returns a persistent interface for user follows with given -// database connection. -func NewFollowsStore(db *gorm.DB) FollowsStore { - return &follows{DB: db} -} - -func (*follows) updateFollowingCount(tx *gorm.DB, userID, followID int64) error { - /* - Equivalent SQL for PostgreSQL: - - UPDATE "user" - SET num_followers = ( - SELECT COUNT(*) FROM follow WHERE follow_id = @followID - ) - WHERE id = @followID - */ - err := tx.Model(&User{}). - Where("id = ?", followID). - Update( - "num_followers", - tx.Model(&Follow{}).Select("COUNT(*)").Where("follow_id = ?", followID), - ). - Error - if err != nil { - return errors.Wrap(err, `update "user.num_followers"`) - } - - /* - Equivalent SQL for PostgreSQL: - - UPDATE "user" - SET num_following = ( - SELECT COUNT(*) FROM follow WHERE user_id = @userID - ) - WHERE id = @userID - */ - err = tx.Model(&User{}). - Where("id = ?", userID). - Update( - "num_following", - tx.Model(&Follow{}).Select("COUNT(*)").Where("user_id = ?", userID), - ). - Error - if err != nil { - return errors.Wrap(err, `update "user.num_following"`) - } - return nil -} - -func (db *follows) Follow(ctx context.Context, userID, followID int64) error { - if userID == followID { - return nil - } - - return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - f := &Follow{ - UserID: userID, - FollowID: followID, - } - result := tx.FirstOrCreate(f, f) - if result.Error != nil { - return errors.Wrap(result.Error, "upsert") - } else if result.RowsAffected <= 0 { - return nil // Relation already exists - } - - return db.updateFollowingCount(tx, userID, followID) - }) -} - -func (db *follows) IsFollowing(ctx context.Context, userID, followID int64) bool { - return db.WithContext(ctx).Where("user_id = ? AND follow_id = ?", userID, followID).First(&Follow{}).Error == nil -} - -func (db *follows) Unfollow(ctx context.Context, userID, followID int64) error { - if userID == followID { - return nil - } - - return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - err := tx.Where("user_id = ? AND follow_id = ?", userID, followID).Delete(&Follow{}).Error - if err != nil { - return errors.Wrap(err, "delete") - } - return db.updateFollowingCount(tx, userID, followID) - }) -} - -// Follow represents relations of users and their followers. -type Follow struct { - ID int64 `gorm:"primaryKey"` - UserID int64 `xorm:"UNIQUE(follow)" gorm:"uniqueIndex:follow_user_follow_unique;not null"` - FollowID int64 `xorm:"UNIQUE(follow)" gorm:"uniqueIndex:follow_user_follow_unique;not null"` -} diff --git a/internal/db/follows_test.go b/internal/db/follows_test.go deleted file mode 100644 index ce71fcaa..00000000 --- a/internal/db/follows_test.go +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright 2022 The Gogs Authors. All rights reserved. -// Use of this source code is governed by a MIT-style -// license that can be found in the LICENSE file. - -package db - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "gogs.io/gogs/internal/dbtest" -) - -func TestFollows(t *testing.T) { - if testing.Short() { - t.Skip() - } - t.Parallel() - - tables := []any{new(User), new(EmailAddress), new(Follow)} - db := &follows{ - DB: dbtest.NewDB(t, "follows", tables...), - } - - for _, tc := range []struct { - name string - test func(t *testing.T, db *follows) - }{ - {"Follow", followsFollow}, - {"IsFollowing", followsIsFollowing}, - {"Unfollow", followsUnfollow}, - } { - t.Run(tc.name, func(t *testing.T) { - t.Cleanup(func() { - err := clearTables(t, db.DB, tables...) - require.NoError(t, err) - }) - tc.test(t, db) - }) - if t.Failed() { - break - } - } -} - -func followsFollow(t *testing.T, db *follows) { - ctx := context.Background() - - usersStore := NewUsersStore(db.DB) - alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) - require.NoError(t, err) - bob, err := usersStore.Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) - require.NoError(t, err) - - err = db.Follow(ctx, alice.ID, bob.ID) - require.NoError(t, err) - - // It is OK to follow multiple times and just be noop. - err = db.Follow(ctx, alice.ID, bob.ID) - require.NoError(t, err) - - alice, err = usersStore.GetByID(ctx, alice.ID) - require.NoError(t, err) - assert.Equal(t, 1, alice.NumFollowing) - - bob, err = usersStore.GetByID(ctx, bob.ID) - require.NoError(t, err) - assert.Equal(t, 1, bob.NumFollowers) -} - -func followsIsFollowing(t *testing.T, db *follows) { - ctx := context.Background() - - usersStore := NewUsersStore(db.DB) - alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) - require.NoError(t, err) - bob, err := usersStore.Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) - require.NoError(t, err) - - got := db.IsFollowing(ctx, alice.ID, bob.ID) - assert.False(t, got) - - err = db.Follow(ctx, alice.ID, bob.ID) - require.NoError(t, err) - got = db.IsFollowing(ctx, alice.ID, bob.ID) - assert.True(t, got) - - err = db.Unfollow(ctx, alice.ID, bob.ID) - require.NoError(t, err) - got = db.IsFollowing(ctx, alice.ID, bob.ID) - assert.False(t, got) -} - -func followsUnfollow(t *testing.T, db *follows) { - ctx := context.Background() - - usersStore := NewUsersStore(db.DB) - alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) - require.NoError(t, err) - bob, err := usersStore.Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) - require.NoError(t, err) - - err = db.Follow(ctx, alice.ID, bob.ID) - require.NoError(t, err) - - // It is OK to unfollow multiple times and just be noop. - err = db.Unfollow(ctx, alice.ID, bob.ID) - require.NoError(t, err) - err = db.Unfollow(ctx, alice.ID, bob.ID) - require.NoError(t, err) - - alice, err = usersStore.GetByID(ctx, alice.ID) - require.NoError(t, err) - assert.Equal(t, 0, alice.NumFollowing) - - bob, err = usersStore.GetByID(ctx, bob.ID) - require.NoError(t, err) - assert.Equal(t, 0, bob.NumFollowers) -} diff --git a/internal/db/lfs.go b/internal/db/lfs.go index 5b5796c8..bff18efd 100644 --- a/internal/db/lfs.go +++ b/internal/db/lfs.go @@ -16,8 +16,6 @@ import ( ) // LFSStore is the persistent interface for LFS objects. -// -// NOTE: All methods are sorted in alphabetical order. type LFSStore interface { // CreateObject creates a LFS object record in database. CreateObject(ctx context.Context, repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error diff --git a/internal/db/login_source_files.go b/internal/db/login_source_files.go index 494f828a..7c0967a5 100644 --- a/internal/db/login_source_files.go +++ b/internal/db/login_source_files.go @@ -25,8 +25,6 @@ import ( ) // loginSourceFilesStore is the in-memory interface for login source files stored on file system. -// -// NOTE: All methods are sorted in alphabetical order. type loginSourceFilesStore interface { // GetByID returns a clone of login source by given ID. GetByID(id int64) (*LoginSource, error) diff --git a/internal/db/login_sources.go b/internal/db/login_sources.go index 8bbbaf07..9469a3f0 100644 --- a/internal/db/login_sources.go +++ b/internal/db/login_sources.go @@ -23,8 +23,6 @@ import ( ) // LoginSourcesStore is the persistent interface for login sources. -// -// NOTE: All methods are sorted in alphabetical order. type LoginSourcesStore interface { // Create creates a new login source and persist to database. It returns // ErrLoginSourceAlreadyExist when a login source with same name already exists. diff --git a/internal/db/org_users.go b/internal/db/org_users.go deleted file mode 100644 index 5c4add26..00000000 --- a/internal/db/org_users.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2022 The Gogs Authors. All rights reserved. -// Use of this source code is governed by a MIT-style -// license that can be found in the LICENSE file. - -package db - -import ( - "context" - - "gorm.io/gorm" -) - -// OrgUsersStore is the persistent interface for organization-user relations. -// -// NOTE: All methods are sorted in alphabetical order. -type OrgUsersStore interface { - // CountByUser returns the number of organizations the user is a member of. - CountByUser(ctx context.Context, userID int64) (int64, error) -} - -var OrgUsers OrgUsersStore - -var _ OrgUsersStore = (*orgUsers)(nil) - -type orgUsers struct { - *gorm.DB -} - -// NewOrgUsersStore returns a persistent interface for organization-user -// relations with given database connection. -func NewOrgUsersStore(db *gorm.DB) OrgUsersStore { - return &orgUsers{DB: db} -} - -func (db *orgUsers) CountByUser(ctx context.Context, userID int64) (int64, error) { - var count int64 - return count, db.WithContext(ctx).Model(&OrgUser{}).Where("uid = ?", userID).Count(&count).Error -} diff --git a/internal/db/org_users_test.go b/internal/db/org_users_test.go deleted file mode 100644 index bff515bc..00000000 --- a/internal/db/org_users_test.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2022 The Gogs Authors. All rights reserved. -// Use of this source code is governed by a MIT-style -// license that can be found in the LICENSE file. - -package db - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "gogs.io/gogs/internal/dbtest" -) - -func TestOrgUsers(t *testing.T) { - if testing.Short() { - t.Skip() - } - t.Parallel() - - tables := []any{new(OrgUser)} - db := &orgUsers{ - DB: dbtest.NewDB(t, "orgUsers", tables...), - } - - for _, tc := range []struct { - name string - test func(t *testing.T, db *orgUsers) - }{ - {"CountByUser", orgUsersCountByUser}, - } { - t.Run(tc.name, func(t *testing.T) { - t.Cleanup(func() { - err := clearTables(t, db.DB, tables...) - require.NoError(t, err) - }) - tc.test(t, db) - }) - if t.Failed() { - break - } - } -} - -func orgUsersCountByUser(t *testing.T, db *orgUsers) { - ctx := context.Background() - - // TODO: Use Orgs.Join to replace SQL hack when the method is available. - err := db.Exec(`INSERT INTO org_user (uid, org_id) VALUES (?, ?)`, 1, 1).Error - require.NoError(t, err) - err = db.Exec(`INSERT INTO org_user (uid, org_id) VALUES (?, ?)`, 2, 1).Error - require.NoError(t, err) - - got, err := db.CountByUser(ctx, 1) - require.NoError(t, err) - assert.Equal(t, int64(1), got) - - got, err = db.CountByUser(ctx, 404) - require.NoError(t, err) - assert.Equal(t, int64(0), got) -} diff --git a/internal/db/orgs.go b/internal/db/orgs.go index db1078ba..753d8120 100644 --- a/internal/db/orgs.go +++ b/internal/db/orgs.go @@ -14,8 +14,6 @@ import ( ) // OrgsStore is the persistent interface for organizations. -// -// NOTE: All methods are sorted in alphabetical order. type OrgsStore interface { // List returns a list of organizations filtered by options. List(ctx context.Context, opts ListOrgsOptions) ([]*Organization, error) @@ -25,6 +23,9 @@ type OrgsStore interface { // count of all results is also returned. If the order is not given, it's up to // the database to decide. SearchByName(ctx context.Context, keyword string, page, pageSize int, orderBy string) ([]*Organization, int64, error) + + // CountByUser returns the number of organizations the user is a member of. + CountByUser(ctx context.Context, userID int64) (int64, error) } var Orgs OrgsStore @@ -79,6 +80,11 @@ func (db *orgs) SearchByName(ctx context.Context, keyword string, page, pageSize return searchUserByName(ctx, db.DB, UserTypeOrganization, keyword, page, pageSize, orderBy) } +func (db *orgs) CountByUser(ctx context.Context, userID int64) (int64, error) { + var count int64 + return count, db.WithContext(ctx).Model(&OrgUser{}).Where("uid = ?", userID).Count(&count).Error +} + type Organization = User func (o *Organization) TableName() string { diff --git a/internal/db/orgs_test.go b/internal/db/orgs_test.go index 9989394d..89550d81 100644 --- a/internal/db/orgs_test.go +++ b/internal/db/orgs_test.go @@ -32,6 +32,7 @@ func TestOrgs(t *testing.T) { }{ {"List", orgsList}, {"SearchByName", orgsSearchByName}, + {"CountByUser", orgsCountByUser}, } { t.Run(tc.name, func(t *testing.T) { t.Cleanup(func() { @@ -164,3 +165,21 @@ func orgsSearchByName(t *testing.T, db *orgs) { assert.Equal(t, org2.ID, orgs[0].ID) }) } + +func orgsCountByUser(t *testing.T, db *orgs) { + ctx := context.Background() + + // TODO: Use Orgs.Join to replace SQL hack when the method is available. + err := db.Exec(`INSERT INTO org_user (uid, org_id) VALUES (?, ?)`, 1, 1).Error + require.NoError(t, err) + err = db.Exec(`INSERT INTO org_user (uid, org_id) VALUES (?, ?)`, 2, 1).Error + require.NoError(t, err) + + got, err := db.CountByUser(ctx, 1) + require.NoError(t, err) + assert.Equal(t, int64(1), got) + + got, err = db.CountByUser(ctx, 404) + require.NoError(t, err) + assert.Equal(t, int64(0), got) +} diff --git a/internal/db/perms.go b/internal/db/perms.go index b0a1a85a..c3cd0566 100644 --- a/internal/db/perms.go +++ b/internal/db/perms.go @@ -12,8 +12,6 @@ import ( ) // PermsStore is the persistent interface for permissions. -// -// NOTE: All methods are sorted in alphabetical order. type PermsStore interface { // AccessMode returns the access mode of given user has to the repository. AccessMode(ctx context.Context, userID, repoID int64, opts AccessModeOptions) AccessMode diff --git a/internal/db/public_keys.go b/internal/db/public_keys.go index d2f8307d..71b0ed99 100644 --- a/internal/db/public_keys.go +++ b/internal/db/public_keys.go @@ -16,8 +16,6 @@ import ( ) // PublicKeysStore is the persistent interface for public keys. -// -// NOTE: All methods are sorted in alphabetical order. type PublicKeysStore interface { // RewriteAuthorizedKeys rewrites the "authorized_keys" file under the SSH root // path with all public keys stored in the database. diff --git a/internal/db/repos.go b/internal/db/repos.go index 7791a458..28a38148 100644 --- a/internal/db/repos.go +++ b/internal/db/repos.go @@ -19,8 +19,6 @@ import ( ) // ReposStore is the persistent interface for repositories. -// -// NOTE: All methods are sorted in alphabetical order. type ReposStore interface { // Create creates a new repository record in the database. It returns // ErrNameNotAllowed when the repository name is not allowed, or @@ -48,6 +46,14 @@ type ReposStore interface { // Touch updates the updated time to the current time and removes the bare state // of the given repository. Touch(ctx context.Context, id int64) error + + // ListWatches returns all watches of the given repository. + ListWatches(ctx context.Context, repoID int64) ([]*Watch, error) + // Watch marks the user to watch the repository. + Watch(ctx context.Context, userID, repoID int64) error + + // HasForkedBy returns true if the given repository has forked by the given user. + HasForkedBy(ctx context.Context, repoID, userID int64) bool } var Repos ReposStore @@ -189,7 +195,7 @@ func (db *repos) Create(ctx context.Context, ownerID int64, opts CreateRepoOptio return errors.Wrap(err, "create") } - err = NewWatchesStore(tx).Watch(ctx, ownerID, repo.ID) + err = NewReposStore(tx).Watch(ctx, ownerID, repo.ID) if err != nil { return errors.Wrap(err, "watch") } @@ -371,3 +377,50 @@ func (db *repos) Touch(ctx context.Context, id int64) error { }). Error } + +func (db *repos) ListWatches(ctx context.Context, repoID int64) ([]*Watch, error) { + var watches []*Watch + return watches, db.WithContext(ctx).Where("repo_id = ?", repoID).Find(&watches).Error +} + +func (db *repos) recountWatches(tx *gorm.DB, repoID int64) error { + /* + Equivalent SQL for PostgreSQL: + + UPDATE repository + SET num_watches = ( + SELECT COUNT(*) FROM watch WHERE repo_id = @repoID + ) + WHERE id = @repoID + */ + return tx.Model(&Repository{}). + Where("id = ?", repoID). + Update( + "num_watches", + tx.Model(&Watch{}).Select("COUNT(*)").Where("repo_id = ?", repoID), + ). + Error +} + +func (db *repos) Watch(ctx context.Context, userID, repoID int64) error { + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + w := &Watch{ + UserID: userID, + RepoID: repoID, + } + result := tx.FirstOrCreate(w, w) + if result.Error != nil { + return errors.Wrap(result.Error, "upsert") + } else if result.RowsAffected <= 0 { + return nil // Relation already exists + } + + return db.recountWatches(tx, repoID) + }) +} + +func (db *repos) HasForkedBy(ctx context.Context, repoID, userID int64) bool { + var count int64 + db.WithContext(ctx).Model(new(Repository)).Where("owner_id = ? AND fork_id = ?", userID, repoID).Count(&count) + return count > 0 +} diff --git a/internal/db/repos_test.go b/internal/db/repos_test.go index d6bfcb0d..64b59d78 100644 --- a/internal/db/repos_test.go +++ b/internal/db/repos_test.go @@ -101,6 +101,9 @@ func TestRepos(t *testing.T) { {"GetByName", reposGetByName}, {"Star", reposStar}, {"Touch", reposTouch}, + {"ListByRepo", reposListWatches}, + {"Watch", reposWatch}, + {"HasForkedBy", reposHasForkedBy}, } { t.Run(tc.name, func(t *testing.T) { t.Cleanup(func() { @@ -298,3 +301,65 @@ func reposTouch(t *testing.T, db *repos) { require.NoError(t, err) assert.False(t, got.IsBare) } + +func reposListWatches(t *testing.T, db *repos) { + ctx := context.Background() + + err := db.Watch(ctx, 1, 1) + require.NoError(t, err) + err = db.Watch(ctx, 2, 1) + require.NoError(t, err) + err = db.Watch(ctx, 2, 2) + require.NoError(t, err) + + got, err := db.ListWatches(ctx, 1) + require.NoError(t, err) + for _, w := range got { + w.ID = 0 + } + + want := []*Watch{ + {UserID: 1, RepoID: 1}, + {UserID: 2, RepoID: 1}, + } + assert.Equal(t, want, got) +} + +func reposWatch(t *testing.T, db *repos) { + ctx := context.Background() + + reposStore := NewReposStore(db.DB) + repo1, err := reposStore.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) + require.NoError(t, err) + + err = db.Watch(ctx, 2, repo1.ID) + require.NoError(t, err) + + // It is OK to watch multiple times and just be noop. + err = db.Watch(ctx, 2, repo1.ID) + require.NoError(t, err) + + repo1, err = reposStore.GetByID(ctx, repo1.ID) + require.NoError(t, err) + assert.Equal(t, 2, repo1.NumWatches) // The owner is watching the repo by default. +} + +func reposHasForkedBy(t *testing.T, db *repos) { + ctx := context.Background() + + has := db.HasForkedBy(ctx, 1, 2) + assert.False(t, has) + + _, err := NewReposStore(db.DB).Create( + ctx, + 2, + CreateRepoOptions{ + Name: "repo1", + ForkID: 1, + }, + ) + require.NoError(t, err) + + has = db.HasForkedBy(ctx, 1, 2) + assert.True(t, has) +} diff --git a/internal/db/two_factors.go b/internal/db/two_factors.go index 6125dda7..741a2ff7 100644 --- a/internal/db/two_factors.go +++ b/internal/db/two_factors.go @@ -21,8 +21,6 @@ import ( ) // TwoFactorsStore is the persistent interface for 2FA. -// -// NOTE: All methods are sorted in alphabetical order. type TwoFactorsStore interface { // Create creates a new 2FA token and recovery codes for given user. The "key" // is used to encrypt and later decrypt given "secret", which should be diff --git a/internal/db/users.go b/internal/db/users.go index 51810dc7..b33772c0 100644 --- a/internal/db/users.go +++ b/internal/db/users.go @@ -32,8 +32,6 @@ import ( ) // UsersStore is the persistent interface for users. -// -// NOTE: All methods are sorted in alphabetical order. type UsersStore interface { // Authenticate validates username and password via given login source ID. It // returns ErrUserNotExist when the user was not found. @@ -47,29 +45,12 @@ type UsersStore interface { // When the "loginSourceID" is positive, it tries to authenticate via given // login source and creates a new user when not yet exists in the database. Authenticate(ctx context.Context, username, password string, loginSourceID int64) (*User, error) - // ChangeUsername changes the username of the given user and updates all - // references to the old username. It returns ErrNameNotAllowed if the given - // name or pattern of the name is not allowed as a username, or - // ErrUserAlreadyExist when another user with same name already exists. - ChangeUsername(ctx context.Context, userID int64, newUsername string) error - // Count returns the total number of users. - Count(ctx context.Context) int64 // Create creates a new user and persists to database. It returns // ErrNameNotAllowed if the given name or pattern of the name is not allowed as // a username, or ErrUserAlreadyExist when a user with same name already exists, // or ErrEmailAlreadyUsed if the email has been used by another user. Create(ctx context.Context, username, email string, opts CreateUserOptions) (*User, error) - // DeleteCustomAvatar deletes the current user custom avatar and falls back to - // use look up avatar by email. - DeleteCustomAvatar(ctx context.Context, userID int64) error - // DeleteByID deletes the given user and all their resources. It returns - // ErrUserOwnRepos when the user still has repository ownership, or returns - // ErrUserHasOrgs when the user still has organization membership. It is more - // performant to skip rewriting the "authorized_keys" file for individual - // deletion in a batch operation. - DeleteByID(ctx context.Context, userID int64, skipRewriteAuthorizedKeys bool) error - // DeleteInactivated deletes all inactivated users. - DeleteInactivated() error + // GetByEmail returns the user (not organization) with given email. It ignores // records with unverified emails and returns ErrUserNotExist when not found. GetByEmail(ctx context.Context, email string) (*User, error) @@ -86,15 +67,45 @@ type UsersStore interface { // addresses (where email notifications are sent to) of users with given list of // usernames. Non-existing usernames are ignored. GetMailableEmailsByUsernames(ctx context.Context, usernames []string) ([]string, error) - // HasForkedRepository returns true if the user has forked given repository. - HasForkedRepository(ctx context.Context, userID, repoID int64) bool + // SearchByName returns a list of users whose username or full name matches the + // given keyword case-insensitively. Results are paginated by given page and + // page size, and sorted by the given order (e.g. "id DESC"). A total count of + // all results is also returned. If the order is not given, it's up to the + // database to decide. + SearchByName(ctx context.Context, keyword string, page, pageSize int, orderBy string) ([]*User, int64, error) + // IsUsernameUsed returns true if the given username has been used other than // the excluded user (a non-positive ID effectively meaning check against all // users). IsUsernameUsed(ctx context.Context, username string, excludeUserId int64) bool - // List returns a list of users. Results are paginated by given page and page - // size, and sorted by primary key (id) in ascending order. - List(ctx context.Context, page, pageSize int) ([]*User, error) + // ChangeUsername changes the username of the given user and updates all + // references to the old username. It returns ErrNameNotAllowed if the given + // name or pattern of the name is not allowed as a username, or + // ErrUserAlreadyExist when another user with same name already exists. + ChangeUsername(ctx context.Context, userID int64, newUsername string) error + // Update updates fields for the given user. + Update(ctx context.Context, userID int64, opts UpdateUserOptions) error + // UseCustomAvatar uses the given avatar as the user custom avatar. + UseCustomAvatar(ctx context.Context, userID int64, avatar []byte) error + + // DeleteCustomAvatar deletes the current user custom avatar and falls back to + // use look up avatar by email. + DeleteCustomAvatar(ctx context.Context, userID int64) error + // DeleteByID deletes the given user and all their resources. It returns + // ErrUserOwnRepos when the user still has repository ownership, or returns + // ErrUserHasOrgs when the user still has organization membership. It is more + // performant to skip rewriting the "authorized_keys" file for individual + // deletion in a batch operation. + DeleteByID(ctx context.Context, userID int64, skipRewriteAuthorizedKeys bool) error + // DeleteInactivated deletes all inactivated users. + DeleteInactivated() error + + // Follow marks the user to follow the other user. + Follow(ctx context.Context, userID, followID int64) error + // Unfollow removes the mark the user to follow the other user. + Unfollow(ctx context.Context, userID, followID int64) error + // IsFollowing returns true if the user is following the other user. + IsFollowing(ctx context.Context, userID, followID int64) bool // ListFollowers returns a list of users that are following the given user. // Results are paginated by given page and page size, and sorted by the time of // follow in descending order. @@ -103,16 +114,12 @@ type UsersStore interface { // Results are paginated by given page and page size, and sorted by the time of // follow in descending order. ListFollowings(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) - // SearchByName returns a list of users whose username or full name matches the - // given keyword case-insensitively. Results are paginated by given page and - // page size, and sorted by the given order (e.g. "id DESC"). A total count of - // all results is also returned. If the order is not given, it's up to the - // database to decide. - SearchByName(ctx context.Context, keyword string, page, pageSize int, orderBy string) ([]*User, int64, error) - // Update updates fields for the given user. - Update(ctx context.Context, userID int64, opts UpdateUserOptions) error - // UseCustomAvatar uses the given avatar as the user custom avatar. - UseCustomAvatar(ctx context.Context, userID int64, avatar []byte) error + + // List returns a list of users. Results are paginated by given page and page + // size, and sorted by primary key (id) in ascending order. + List(ctx context.Context, page, pageSize int) ([]*User, error) + // Count returns the total number of users. + Count(ctx context.Context) int64 } var Users UsersStore @@ -650,6 +657,88 @@ func (db *users) DeleteInactivated() error { return nil } +func (*users) recountFollows(tx *gorm.DB, userID, followID int64) error { + /* + Equivalent SQL for PostgreSQL: + + UPDATE "user" + SET num_followers = ( + SELECT COUNT(*) FROM follow WHERE follow_id = @followID + ) + WHERE id = @followID + */ + err := tx.Model(&User{}). + Where("id = ?", followID). + Update( + "num_followers", + tx.Model(&Follow{}).Select("COUNT(*)").Where("follow_id = ?", followID), + ). + Error + if err != nil { + return errors.Wrap(err, `update "user.num_followers"`) + } + + /* + Equivalent SQL for PostgreSQL: + + UPDATE "user" + SET num_following = ( + SELECT COUNT(*) FROM follow WHERE user_id = @userID + ) + WHERE id = @userID + */ + err = tx.Model(&User{}). + Where("id = ?", userID). + Update( + "num_following", + tx.Model(&Follow{}).Select("COUNT(*)").Where("user_id = ?", userID), + ). + Error + if err != nil { + return errors.Wrap(err, `update "user.num_following"`) + } + return nil +} + +func (db *users) Follow(ctx context.Context, userID, followID int64) error { + if userID == followID { + return nil + } + + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + f := &Follow{ + UserID: userID, + FollowID: followID, + } + result := tx.FirstOrCreate(f, f) + if result.Error != nil { + return errors.Wrap(result.Error, "upsert") + } else if result.RowsAffected <= 0 { + return nil // Relation already exists + } + + return db.recountFollows(tx, userID, followID) + }) +} + +func (db *users) Unfollow(ctx context.Context, userID, followID int64) error { + if userID == followID { + return nil + } + + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + err := tx.Where("user_id = ? AND follow_id = ?", userID, followID).Delete(&Follow{}).Error + if err != nil { + return errors.Wrap(err, "delete") + } + return db.recountFollows(tx, userID, followID) + }) +} + +func (db *users) IsFollowing(ctx context.Context, userID, followID int64) bool { + return db.WithContext(ctx).Where("user_id = ? AND follow_id = ?", userID, followID).First(&Follow{}).Error == nil +} + var _ errutil.NotFound = (*ErrUserNotExist)(nil) type ErrUserNotExist struct { @@ -757,12 +846,6 @@ func (db *users) GetMailableEmailsByUsernames(ctx context.Context, usernames []s Find(&emails).Error } -func (db *users) HasForkedRepository(ctx context.Context, userID, repoID int64) bool { - var count int64 - db.WithContext(ctx).Model(new(Repository)).Where("owner_id = ? AND fork_id = ?", userID, repoID).Count(&count) - return count > 0 -} - func (db *users) IsUsernameUsed(ctx context.Context, username string, excludeUserId int64) bool { if username == "" { return false @@ -1181,7 +1264,7 @@ func (u *User) AvatarURL() string { // TODO(unknwon): This is also used in templates, which should be fixed by // having a dedicated type `template.User`. func (u *User) IsFollowing(followID int64) bool { - return Follows.IsFollowing(context.TODO(), u.ID, followID) + return Users.IsFollowing(context.TODO(), u.ID, followID) } // IsUserOrgOwner returns true if the user is in the owner team of the given @@ -1208,7 +1291,7 @@ func (u *User) IsPublicMember(orgId int64) bool { // TODO(unknwon): This is also used in templates, which should be fixed by // having a dedicated type `template.User`. func (u *User) GetOrganizationCount() (int64, error) { - return OrgUsers.CountByUser(context.TODO(), u.ID) + return Orgs.CountByUser(context.TODO(), u.ID) } // ShortName truncates and returns the username at most in given length. @@ -1336,3 +1419,10 @@ func isNameAllowed(names map[string]struct{}, patterns []string, name string) er func isUsernameAllowed(name string) error { return isNameAllowed(reservedUsernames, reservedUsernamePatterns, name) } + +// Follow represents relations of users and their followers. +type Follow struct { + ID int64 `gorm:"primaryKey"` + UserID int64 `xorm:"UNIQUE(follow)" gorm:"uniqueIndex:follow_user_follow_unique;not null"` + FollowID int64 `xorm:"UNIQUE(follow)" gorm:"uniqueIndex:follow_user_follow_unique;not null"` +} diff --git a/internal/db/users_test.go b/internal/db/users_test.go index 69f157ea..edb9c1dd 100644 --- a/internal/db/users_test.go +++ b/internal/db/users_test.go @@ -107,7 +107,6 @@ func TestUsers(t *testing.T) { {"GetByUsername", usersGetByUsername}, {"GetByKeyID", usersGetByKeyID}, {"GetMailableEmailsByUsernames", usersGetMailableEmailsByUsernames}, - {"HasForkedRepository", usersHasForkedRepository}, {"IsUsernameUsed", usersIsUsernameUsed}, {"List", usersList}, {"ListFollowers", usersListFollowers}, @@ -115,6 +114,9 @@ func TestUsers(t *testing.T) { {"SearchByName", usersSearchByName}, {"Update", usersUpdate}, {"UseCustomAvatar", usersUseCustomAvatar}, + {"Follow", usersFollow}, + {"IsFollowing", usersIsFollowing}, + {"Unfollow", usersUnfollow}, } { t.Run(tc.name, func(t *testing.T) { t.Cleanup(func() { @@ -518,14 +520,13 @@ func usersDeleteByID(t *testing.T, db *users) { require.NoError(t, err) // Mock watches, stars and follows - err = NewWatchesStore(db.DB).Watch(ctx, testUser.ID, repo2.ID) + err = reposStore.Watch(ctx, testUser.ID, repo2.ID) require.NoError(t, err) err = reposStore.Star(ctx, testUser.ID, repo2.ID) require.NoError(t, err) - followsStore := NewFollowsStore(db.DB) - err = followsStore.Follow(ctx, testUser.ID, cindy.ID) + err = db.Follow(ctx, testUser.ID, cindy.ID) require.NoError(t, err) - err = followsStore.Follow(ctx, frank.ID, testUser.ID) + err = db.Follow(ctx, frank.ID, testUser.ID) require.NoError(t, err) // Mock "authorized_keys" file @@ -865,26 +866,6 @@ func usersGetMailableEmailsByUsernames(t *testing.T, db *users) { assert.Equal(t, want, got) } -func usersHasForkedRepository(t *testing.T, db *users) { - ctx := context.Background() - - has := db.HasForkedRepository(ctx, 1, 1) - assert.False(t, has) - - _, err := NewReposStore(db.DB).Create( - ctx, - 1, - CreateRepoOptions{ - Name: "repo1", - ForkID: 1, - }, - ) - require.NoError(t, err) - - has = db.HasForkedRepository(ctx, 1, 1) - assert.True(t, has) -} - func usersIsUsernameUsed(t *testing.T, db *users) { ctx := context.Background() @@ -987,10 +968,9 @@ func usersListFollowers(t *testing.T, db *users) { bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) require.NoError(t, err) - followsStore := NewFollowsStore(db.DB) - err = followsStore.Follow(ctx, alice.ID, john.ID) + err = db.Follow(ctx, alice.ID, john.ID) require.NoError(t, err) - err = followsStore.Follow(ctx, bob.ID, john.ID) + err = db.Follow(ctx, bob.ID, john.ID) require.NoError(t, err) // First page only has bob @@ -1021,10 +1001,9 @@ func usersListFollowings(t *testing.T, db *users) { bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) require.NoError(t, err) - followsStore := NewFollowsStore(db.DB) - err = followsStore.Follow(ctx, john.ID, alice.ID) + err = db.Follow(ctx, john.ID, alice.ID) require.NoError(t, err) - err = followsStore.Follow(ctx, john.ID, bob.ID) + err = db.Follow(ctx, john.ID, bob.ID) require.NoError(t, err) // First page only has bob @@ -1222,3 +1201,78 @@ func TestIsUsernameAllowed(t *testing.T) { }) } } + +func usersFollow(t *testing.T, db *users) { + ctx := context.Background() + + usersStore := NewUsersStore(db.DB) + alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) + require.NoError(t, err) + bob, err := usersStore.Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) + require.NoError(t, err) + + err = db.Follow(ctx, alice.ID, bob.ID) + require.NoError(t, err) + + // It is OK to follow multiple times and just be noop. + err = db.Follow(ctx, alice.ID, bob.ID) + require.NoError(t, err) + + alice, err = usersStore.GetByID(ctx, alice.ID) + require.NoError(t, err) + assert.Equal(t, 1, alice.NumFollowing) + + bob, err = usersStore.GetByID(ctx, bob.ID) + require.NoError(t, err) + assert.Equal(t, 1, bob.NumFollowers) +} + +func usersIsFollowing(t *testing.T, db *users) { + ctx := context.Background() + + usersStore := NewUsersStore(db.DB) + alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) + require.NoError(t, err) + bob, err := usersStore.Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) + require.NoError(t, err) + + got := db.IsFollowing(ctx, alice.ID, bob.ID) + assert.False(t, got) + + err = db.Follow(ctx, alice.ID, bob.ID) + require.NoError(t, err) + got = db.IsFollowing(ctx, alice.ID, bob.ID) + assert.True(t, got) + + err = db.Unfollow(ctx, alice.ID, bob.ID) + require.NoError(t, err) + got = db.IsFollowing(ctx, alice.ID, bob.ID) + assert.False(t, got) +} + +func usersUnfollow(t *testing.T, db *users) { + ctx := context.Background() + + usersStore := NewUsersStore(db.DB) + alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) + require.NoError(t, err) + bob, err := usersStore.Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) + require.NoError(t, err) + + err = db.Follow(ctx, alice.ID, bob.ID) + require.NoError(t, err) + + // It is OK to unfollow multiple times and just be noop. + err = db.Unfollow(ctx, alice.ID, bob.ID) + require.NoError(t, err) + err = db.Unfollow(ctx, alice.ID, bob.ID) + require.NoError(t, err) + + alice, err = usersStore.GetByID(ctx, alice.ID) + require.NoError(t, err) + assert.Equal(t, 0, alice.NumFollowing) + + bob, err = usersStore.GetByID(ctx, bob.ID) + require.NoError(t, err) + assert.Equal(t, 0, bob.NumFollowers) +} diff --git a/internal/db/watches.go b/internal/db/watches.go deleted file mode 100644 index 6ca78a37..00000000 --- a/internal/db/watches.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2020 The Gogs Authors. All rights reserved. -// Use of this source code is governed by a MIT-style -// license that can be found in the LICENSE file. - -package db - -import ( - "context" - - "github.com/pkg/errors" - "gorm.io/gorm" -) - -// WatchesStore is the persistent interface for watches. -// -// NOTE: All methods are sorted in alphabetical order. -type WatchesStore interface { - // ListByRepo returns all watches of the given repository. - ListByRepo(ctx context.Context, repoID int64) ([]*Watch, error) - // Watch marks the user to watch the repository. - Watch(ctx context.Context, userID, repoID int64) error -} - -var Watches WatchesStore - -var _ WatchesStore = (*watches)(nil) - -type watches struct { - *gorm.DB -} - -// NewWatchesStore returns a persistent interface for watches with given -// database connection. -func NewWatchesStore(db *gorm.DB) WatchesStore { - return &watches{DB: db} -} - -func (db *watches) ListByRepo(ctx context.Context, repoID int64) ([]*Watch, error) { - var watches []*Watch - return watches, db.WithContext(ctx).Where("repo_id = ?", repoID).Find(&watches).Error -} - -func (db *watches) updateWatchingCount(tx *gorm.DB, repoID int64) error { - /* - Equivalent SQL for PostgreSQL: - - UPDATE repository - SET num_watches = ( - SELECT COUNT(*) FROM watch WHERE repo_id = @repoID - ) - WHERE id = @repoID - */ - return tx.Model(&Repository{}). - Where("id = ?", repoID). - Update( - "num_watches", - tx.Model(&Watch{}).Select("COUNT(*)").Where("repo_id = ?", repoID), - ). - Error -} - -func (db *watches) Watch(ctx context.Context, userID, repoID int64) error { - return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - w := &Watch{ - UserID: userID, - RepoID: repoID, - } - result := tx.FirstOrCreate(w, w) - if result.Error != nil { - return errors.Wrap(result.Error, "upsert") - } else if result.RowsAffected <= 0 { - return nil // Relation already exists - } - - return db.updateWatchingCount(tx, repoID) - }) -} diff --git a/internal/db/watches_test.go b/internal/db/watches_test.go deleted file mode 100644 index 245be7b3..00000000 --- a/internal/db/watches_test.go +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2022 The Gogs Authors. All rights reserved. -// Use of this source code is governed by a MIT-style -// license that can be found in the LICENSE file. - -package db - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "gogs.io/gogs/internal/dbtest" -) - -func TestWatches(t *testing.T) { - if testing.Short() { - t.Skip() - } - t.Parallel() - - tables := []any{new(Watch), new(Repository)} - db := &watches{ - DB: dbtest.NewDB(t, "watches", tables...), - } - - for _, tc := range []struct { - name string - test func(t *testing.T, db *watches) - }{ - {"ListByRepo", watchesListByRepo}, - {"Watch", watchesWatch}, - } { - t.Run(tc.name, func(t *testing.T) { - t.Cleanup(func() { - err := clearTables(t, db.DB, tables...) - require.NoError(t, err) - }) - tc.test(t, db) - }) - if t.Failed() { - break - } - } -} - -func watchesListByRepo(t *testing.T, db *watches) { - ctx := context.Background() - - err := db.Watch(ctx, 1, 1) - require.NoError(t, err) - err = db.Watch(ctx, 2, 1) - require.NoError(t, err) - err = db.Watch(ctx, 2, 2) - require.NoError(t, err) - - got, err := db.ListByRepo(ctx, 1) - require.NoError(t, err) - for _, w := range got { - w.ID = 0 - } - - want := []*Watch{ - {UserID: 1, RepoID: 1}, - {UserID: 2, RepoID: 1}, - } - assert.Equal(t, want, got) -} - -func watchesWatch(t *testing.T, db *watches) { - ctx := context.Background() - - reposStore := NewReposStore(db.DB) - repo1, err := reposStore.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) - require.NoError(t, err) - - err = db.Watch(ctx, 2, repo1.ID) - require.NoError(t, err) - - // It is OK to watch multiple times and just be noop. - err = db.Watch(ctx, 2, repo1.ID) - require.NoError(t, err) - - repo1, err = reposStore.GetByID(ctx, repo1.ID) - require.NoError(t, err) - assert.Equal(t, 2, repo1.NumWatches) // The owner is watching the repo by default. -} |