diff options
author | Joe Chen <jc@unknwon.io> | 2022-10-23 16:17:53 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-23 16:17:53 +0800 |
commit | b1fefcbe5011a4a792808faaf26fae6881ecc1b0 (patch) | |
tree | 42c3d2922ea8460f6f5bf1ebb314039275c25ec2 /internal/db | |
parent | 8077360cf6370c9ddb026f2432ceb4f4f4ac31c4 (diff) |
refactor(db): migrate `Follow` off `user.go` (#7203)
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/backup_test.go | 15 | ||||
-rw-r--r-- | internal/db/db.go | 2 | ||||
-rw-r--r-- | internal/db/follows.go | 127 | ||||
-rw-r--r-- | internal/db/follows_test.go | 122 | ||||
-rw-r--r-- | internal/db/main_test.go | 10 | ||||
-rw-r--r-- | internal/db/models.go | 2 | ||||
-rw-r--r-- | internal/db/testdata/backup/Follow.golden.json | 2 | ||||
-rw-r--r-- | internal/db/user.go | 99 | ||||
-rw-r--r-- | internal/db/users.go | 62 | ||||
-rw-r--r-- | internal/db/users_test.go | 72 |
10 files changed, 410 insertions, 103 deletions
diff --git a/internal/db/backup_test.go b/internal/db/backup_test.go index b79d455c..52ce9aaa 100644 --- a/internal/db/backup_test.go +++ b/internal/db/backup_test.go @@ -31,8 +31,8 @@ func TestDumpAndImport(t *testing.T) { } t.Parallel() - if len(Tables) != 5 { - t.Fatalf("New table has added (want 5 got %d), please add new tests for the table and update this check", len(Tables)) + if len(Tables) != 6 { + t.Fatalf("New table has added (want 6 got %d), please add new tests for the table and update this check", len(Tables)) } db := dbtest.NewDB(t, "dumpAndImport", Tables...) @@ -131,6 +131,17 @@ func setupDBToDump(t *testing.T, db *gorm.DB) { CreatedUnix: 1588568886, }, + &Follow{ + ID: 1, + UserID: 1, + FollowID: 2, + }, + &Follow{ + ID: 2, + UserID: 2, + FollowID: 1, + }, + &LFSObject{ RepoID: 1, OID: "ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f", diff --git a/internal/db/db.go b/internal/db/db.go index f287ab15..b765dfd8 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -42,6 +42,7 @@ func newLogWriter() (logger.Writer, error) { // NOTE: Lines are sorted in alphabetical order, each letter in its own line. var Tables = []interface{}{ new(Access), new(AccessToken), new(Action), + new(Follow), new(LFSObject), new(LoginSource), } @@ -120,6 +121,7 @@ func Init(w logger.Writer) (*gorm.DB, error) { // Initialize stores, sorted in alphabetical order. AccessTokens = &accessTokens{DB: db} Actions = NewActionsStore(db) + Follows = NewFollowsStore(db) LoginSources = &loginSources{DB: db, files: sourceFiles} LFS = &lfs{DB: db} Perms = &perms{DB: db} diff --git a/internal/db/follows.go b/internal/db/follows.go new file mode 100644 index 00000000..4f3d55f0 --- /dev/null +++ b/internal/db/follows.go @@ -0,0 +1,127 @@ +// Copyright 2022 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package db + +import ( + "context" + + "github.com/pkg/errors" + "gorm.io/gorm" +) + +// FollowsStore is the persistent interface for follows. +// +// NOTE: All methods are sorted in alphabetical order. +type FollowsStore interface { + // Follow marks the user to follow the other user. + Follow(ctx context.Context, userID, followID int64) error + // IsFollowing returns true if the user is following the other user. + IsFollowing(ctx context.Context, userID, followID int64) bool + // Unfollow removes the mark the user to follow the other user. + Unfollow(ctx context.Context, userID, followID int64) error +} + +var Follows FollowsStore + +var _ FollowsStore = (*follows)(nil) + +type follows struct { + *gorm.DB +} + +// NewFollowsStore returns a persistent interface for follows with given +// database connection. +func NewFollowsStore(db *gorm.DB) FollowsStore { + return &follows{DB: db} +} + +func (*follows) updateFollowingCount(tx *gorm.DB, userID, followID int64) error { + /* + Equivalent SQL for PostgreSQL: + + UPDATE "user" + SET num_followers = ( + SELECT COUNT(*) FROM follow WHERE follow_id = @followID + ) + WHERE id = @followID + */ + err := tx.Model(&User{}). + Where("id = ?", followID). + Update( + "num_followers", + tx.Model(&Follow{}).Select("COUNT(*)").Where("follow_id = ?", followID), + ). + Error + if err != nil { + return errors.Wrap(err, `update "num_followers"`) + } + + /* + Equivalent SQL for PostgreSQL: + + UPDATE "user" + SET num_following = ( + SELECT COUNT(*) FROM follow WHERE user_id = @userID + ) + WHERE id = @userID + */ + err = tx.Model(&User{}). + Where("id = ?", userID). + Update( + "num_following", + tx.Model(&Follow{}).Select("COUNT(*)").Where("user_id = ?", userID), + ). + Error + if err != nil { + return errors.Wrap(err, `update "num_following"`) + } + return nil +} + +func (db *follows) Follow(ctx context.Context, userID, followID int64) error { + if userID == followID { + return nil + } + + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + f := &Follow{ + UserID: userID, + FollowID: followID, + } + result := tx.FirstOrCreate(f, f) + if result.Error != nil { + return errors.Wrap(result.Error, "upsert") + } else if result.RowsAffected <= 0 { + return nil // Relation already exists + } + + return db.updateFollowingCount(tx, userID, followID) + }) +} + +func (db *follows) IsFollowing(ctx context.Context, userID, followID int64) bool { + return db.WithContext(ctx).Where("user_id = ? AND follow_id = ?", userID, followID).First(&Follow{}).Error == nil +} + +func (db *follows) Unfollow(ctx context.Context, userID, followID int64) error { + if userID == followID { + return nil + } + + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + err := tx.Where("user_id = ? AND follow_id = ?", userID, followID).Delete(&Follow{}).Error + if err != nil { + return errors.Wrap(err, "delete") + } + return db.updateFollowingCount(tx, userID, followID) + }) +} + +// Follow represents relations of users and their followers. +type Follow struct { + ID int64 `gorm:"primaryKey"` + UserID int64 `xorm:"UNIQUE(follow)" gorm:"uniqueIndex:follow_user_follow_unique;not null"` + FollowID int64 `xorm:"UNIQUE(follow)" gorm:"uniqueIndex:follow_user_follow_unique;not null"` +} diff --git a/internal/db/follows_test.go b/internal/db/follows_test.go new file mode 100644 index 00000000..cd37cc97 --- /dev/null +++ b/internal/db/follows_test.go @@ -0,0 +1,122 @@ +// Copyright 2022 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package db + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gogs.io/gogs/internal/dbtest" +) + +func TestFollows(t *testing.T) { + if testing.Short() { + t.Skip() + } + t.Parallel() + + tables := []interface{}{new(User), new(EmailAddress), new(Follow)} + db := &follows{ + DB: dbtest.NewDB(t, "follows", tables...), + } + + for _, tc := range []struct { + name string + test func(*testing.T, *follows) + }{ + {"Follow", followsFollow}, + {"IsFollowing", followsIsFollowing}, + {"Unfollow", followsUnfollow}, + } { + t.Run(tc.name, func(t *testing.T) { + t.Cleanup(func() { + err := clearTables(t, db.DB, tables...) + require.NoError(t, err) + }) + tc.test(t, db) + }) + if t.Failed() { + break + } + } +} + +func followsFollow(t *testing.T, db *follows) { + ctx := context.Background() + + usersStore := NewUsersStore(db.DB) + alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) + require.NoError(t, err) + bob, err := usersStore.Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) + require.NoError(t, err) + + err = db.Follow(ctx, alice.ID, bob.ID) + require.NoError(t, err) + + // It is OK to follow multiple times and just be noop. + err = db.Follow(ctx, alice.ID, bob.ID) + require.NoError(t, err) + + alice, err = usersStore.GetByID(ctx, alice.ID) + require.NoError(t, err) + assert.Equal(t, 1, alice.NumFollowing) + + bob, err = usersStore.GetByID(ctx, bob.ID) + require.NoError(t, err) + assert.Equal(t, 1, bob.NumFollowers) +} + +func followsIsFollowing(t *testing.T, db *follows) { + ctx := context.Background() + + usersStore := NewUsersStore(db.DB) + alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) + require.NoError(t, err) + bob, err := usersStore.Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) + require.NoError(t, err) + + got := db.IsFollowing(ctx, alice.ID, bob.ID) + assert.False(t, got) + + err = db.Follow(ctx, alice.ID, bob.ID) + require.NoError(t, err) + got = db.IsFollowing(ctx, alice.ID, bob.ID) + assert.True(t, got) + + err = db.Unfollow(ctx, alice.ID, bob.ID) + require.NoError(t, err) + got = db.IsFollowing(ctx, alice.ID, bob.ID) + assert.False(t, got) +} + +func followsUnfollow(t *testing.T, db *follows) { + ctx := context.Background() + + usersStore := NewUsersStore(db.DB) + alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) + require.NoError(t, err) + bob, err := usersStore.Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) + require.NoError(t, err) + + err = db.Follow(ctx, alice.ID, bob.ID) + require.NoError(t, err) + + // It is OK to unfollow multiple times and just be noop. + err = db.Unfollow(ctx, alice.ID, bob.ID) + require.NoError(t, err) + err = db.Unfollow(ctx, alice.ID, bob.ID) + require.NoError(t, err) + + alice, err = usersStore.GetByID(ctx, alice.ID) + require.NoError(t, err) + assert.Equal(t, 0, alice.NumFollowing) + + bob, err = usersStore.GetByID(ctx, bob.ID) + require.NoError(t, err) + assert.Equal(t, 0, bob.NumFollowers) +} diff --git a/internal/db/main_test.go b/internal/db/main_test.go index bc55a0d0..4f19da43 100644 --- a/internal/db/main_test.go +++ b/internal/db/main_test.go @@ -15,6 +15,7 @@ import ( _ "modernc.org/sqlite" log "unknwon.dev/clog/v2" + "gogs.io/gogs/internal/conf" "gogs.io/gogs/internal/testutil" ) @@ -37,6 +38,15 @@ func TestMain(m *testing.M) { // NOTE: AutoMigrate does not respect logger passed in gorm.Config. logger.Default = logger.Default.LogMode(level) + switch os.Getenv("GOGS_DATABASE_TYPE") { + case "mysql": + conf.UseMySQL = true + case "postgres": + conf.UsePostgreSQL = true + default: + conf.UseSQLite3 = true + } + os.Exit(m.Run()) } diff --git a/internal/db/models.go b/internal/db/models.go index 31db3a15..c9cc5a3b 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -52,7 +52,7 @@ func init() { legacyTables = append(legacyTables, new(User), new(PublicKey), new(TwoFactor), new(TwoFactorRecoveryCode), new(Repository), new(DeployKey), new(Collaboration), new(Upload), - new(Watch), new(Star), new(Follow), + new(Watch), new(Star), new(Issue), new(PullRequest), new(Comment), new(Attachment), new(IssueUser), new(Label), new(IssueLabel), new(Milestone), new(Mirror), new(Release), new(Webhook), new(HookTask), diff --git a/internal/db/testdata/backup/Follow.golden.json b/internal/db/testdata/backup/Follow.golden.json new file mode 100644 index 00000000..51250e5a --- /dev/null +++ b/internal/db/testdata/backup/Follow.golden.json @@ -0,0 +1,2 @@ +{"ID":1,"UserID":1,"FollowID":2} +{"ID":2,"UserID":2,"FollowID":1} diff --git a/internal/db/user.go b/internal/db/user.go index 48444b1e..95ee70c2 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -61,34 +61,6 @@ func (u *User) AfterSet(colName string, _ xorm.Cell) { } } -// User.GetFollowers returns range of user's followers. -func (u *User) GetFollowers(page int) ([]*User, error) { - users := make([]*User, 0, ItemsPerPage) - sess := x.Limit(ItemsPerPage, (page-1)*ItemsPerPage).Where("follow.follow_id=?", u.ID) - if conf.UsePostgreSQL { - sess = sess.Join("LEFT", "follow", `"user".id=follow.user_id`) - } else { - sess = sess.Join("LEFT", "follow", "user.id=follow.user_id") - } - return users, sess.Find(&users) -} - -func (u *User) IsFollowing(followID int64) bool { - return IsFollowing(u.ID, followID) -} - -// GetFollowing returns range of user's following. -func (u *User) GetFollowing(page int) ([]*User, error) { - users := make([]*User, 0, ItemsPerPage) - sess := x.Limit(ItemsPerPage, (page-1)*ItemsPerPage).Where("follow.user_id=?", u.ID) - if conf.UsePostgreSQL { - sess = sess.Join("LEFT", "follow", `"user".id=follow.follow_id`) - } else { - sess = sess.Join("LEFT", "follow", "user.id=follow.follow_id") - } - return users, sess.Find(&users) -} - // NewGitSig generates and returns the signature of given user. func (u *User) NewGitSig() *git.Signature { return &git.Signature{ @@ -887,77 +859,6 @@ func SearchUserByName(opts *SearchUserOptions) (users []*User, _ int64, _ error) return users, count, sess.Limit(opts.PageSize, (opts.Page-1)*opts.PageSize).Find(&users) } -// ___________ .__ .__ -// \_ _____/___ | | | | ______ _ __ -// | __)/ _ \| | | | / _ \ \/ \/ / -// | \( <_> ) |_| |_( <_> ) / -// \___ / \____/|____/____/\____/ \/\_/ -// \/ - -// Follow represents relations of user and his/her followers. -type Follow struct { - ID int64 - UserID int64 `xorm:"UNIQUE(follow)"` - FollowID int64 `xorm:"UNIQUE(follow)"` -} - -func IsFollowing(userID, followID int64) bool { - has, _ := x.Get(&Follow{UserID: userID, FollowID: followID}) - return has -} - -// FollowUser marks someone be another's follower. -func FollowUser(userID, followID int64) (err error) { - if userID == followID || IsFollowing(userID, followID) { - return nil - } - - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { - return err - } - - if _, err = sess.Insert(&Follow{UserID: userID, FollowID: followID}); err != nil { - return err - } - - if _, err = sess.Exec("UPDATE `user` SET num_followers = num_followers + 1 WHERE id = ?", followID); err != nil { - return err - } - - if _, err = sess.Exec("UPDATE `user` SET num_following = num_following + 1 WHERE id = ?", userID); err != nil { - return err - } - return sess.Commit() -} - -// UnfollowUser unmarks someone be another's follower. -func UnfollowUser(userID, followID int64) (err error) { - if userID == followID || !IsFollowing(userID, followID) { - return nil - } - - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { - return err - } - - if _, err = sess.Delete(&Follow{UserID: userID, FollowID: followID}); err != nil { - return err - } - - if _, err = sess.Exec("UPDATE `user` SET num_followers = num_followers - 1 WHERE id = ?", followID); err != nil { - return err - } - - if _, err = sess.Exec("UPDATE `user` SET num_following = num_following - 1 WHERE id = ?", userID); err != nil { - return err - } - return sess.Commit() -} - // GetRepositoryAccesses finds all repositories with their access mode where a user has access but does not own. func (u *User) GetRepositoryAccesses() (map[*Repository]AccessMode, error) { accesses := make([]*Access, 0, 10) diff --git a/internal/db/users.go b/internal/db/users.go index bc57f317..d537b8d3 100644 --- a/internal/db/users.go +++ b/internal/db/users.go @@ -56,6 +56,14 @@ type UsersStore interface { GetByUsername(ctx context.Context, username string) (*User, error) // HasForkedRepository returns true if the user has forked given repository. HasForkedRepository(ctx context.Context, userID, repoID int64) bool + // ListFollowers returns a list of users that are following the given user. + // Results are paginated by given page and page size, and sorted by the time of + // follow in descending order. + ListFollowers(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) + // ListFollowings returns a list of users that are followed by the given user. + // Results are paginated by given page and page size, and sorted by the time of + // follow in descending order. + ListFollowings(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) } var Users UsersStore @@ -343,6 +351,52 @@ func (db *users) HasForkedRepository(ctx context.Context, userID, repoID int64) return count > 0 } +func (db *users) ListFollowers(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) { + /* + Equivalent SQL for PostgreSQL: + + SELECT * FROM "user" + LEFT JOIN follow ON follow.user_id = "user".id + WHERE follow.follow_id = @userID + ORDER BY follow.id DESC + LIMIT @limit OFFSET @offset + */ + users := make([]*User, 0, pageSize) + tx := db.WithContext(ctx). + Where("follow.follow_id = ?", userID). + Limit(pageSize).Offset((page - 1) * pageSize). + Order("follow.id DESC") + if conf.UsePostgreSQL { + tx.Joins(`LEFT JOIN follow ON follow.user_id = "user".id`) + } else { + tx.Joins(`LEFT JOIN follow ON follow.user_id = user.id`) + } + return users, tx.Find(&users).Error +} + +func (db *users) ListFollowings(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) { + /* + Equivalent SQL for PostgreSQL: + + SELECT * FROM "user" + LEFT JOIN follow ON follow.user_id = "user".id + WHERE follow.user_id = @userID + ORDER BY follow.id DESC + LIMIT @limit OFFSET @offset + */ + users := make([]*User, 0, pageSize) + tx := db.WithContext(ctx). + Where("follow.user_id = ?", userID). + Limit(pageSize).Offset((page - 1) * pageSize). + Order("follow.id DESC") + if conf.UsePostgreSQL { + tx.Joins(`LEFT JOIN follow ON follow.follow_id = "user".id`) + } else { + tx.Joins(`LEFT JOIN follow ON follow.follow_id = user.id`) + } + return users, tx.Find(&users).Error +} + // UserType indicates the type of the user account. type UserType int @@ -530,3 +584,11 @@ func (u *User) AvatarURL() string { } return link } + +// IsFollowing returns true if the user is following the given user. +// +// TODO(unknwon): This is also used in templates, which should be fixed by +// having a dedicated type `template.User`. +func (u *User) IsFollowing(followID int64) bool { + return Follows.IsFollowing(context.TODO(), u.ID, followID) +} diff --git a/internal/db/users_test.go b/internal/db/users_test.go index 67be21dd..1a33151a 100644 --- a/internal/db/users_test.go +++ b/internal/db/users_test.go @@ -24,7 +24,7 @@ func TestUsers(t *testing.T) { } t.Parallel() - tables := []interface{}{new(User), new(EmailAddress), new(Repository)} + tables := []interface{}{new(User), new(EmailAddress), new(Repository), new(Follow)} db := &users{ DB: dbtest.NewDB(t, "users", tables...), } @@ -39,6 +39,8 @@ func TestUsers(t *testing.T) { {"GetByID", usersGetByID}, {"GetByUsername", usersGetByUsername}, {"HasForkedRepository", usersHasForkedRepository}, + {"ListFollowers", usersListFollowers}, + {"ListFollowings", usersListFollowings}, } { t.Run(tc.name, func(t *testing.T) { t.Cleanup(func() { @@ -296,3 +298,71 @@ func usersHasForkedRepository(t *testing.T, db *users) { has = db.HasForkedRepository(ctx, 1, 1) assert.True(t, has) } + +func usersListFollowers(t *testing.T, db *users) { + ctx := context.Background() + + john, err := db.Create(ctx, "john", "john@example.com", CreateUserOptions{}) + require.NoError(t, err) + + got, err := db.ListFollowers(ctx, john.ID, 1, 1) + require.NoError(t, err) + assert.Empty(t, got) + + alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) + require.NoError(t, err) + bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) + require.NoError(t, err) + + followsStore := NewFollowsStore(db.DB) + err = followsStore.Follow(ctx, alice.ID, john.ID) + require.NoError(t, err) + err = followsStore.Follow(ctx, bob.ID, john.ID) + require.NoError(t, err) + + // First page only has bob + got, err = db.ListFollowers(ctx, john.ID, 1, 1) + require.NoError(t, err) + require.Len(t, got, 1) + assert.Equal(t, bob.ID, got[0].ID) + + // Second page only has alice + got, err = db.ListFollowers(ctx, john.ID, 2, 1) + require.NoError(t, err) + require.Len(t, got, 1) + assert.Equal(t, alice.ID, got[0].ID) +} + +func usersListFollowings(t *testing.T, db *users) { + ctx := context.Background() + + john, err := db.Create(ctx, "john", "john@example.com", CreateUserOptions{}) + require.NoError(t, err) + + got, err := db.ListFollowers(ctx, john.ID, 1, 1) + require.NoError(t, err) + assert.Empty(t, got) + + alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) + require.NoError(t, err) + bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) + require.NoError(t, err) + + followsStore := NewFollowsStore(db.DB) + err = followsStore.Follow(ctx, john.ID, alice.ID) + require.NoError(t, err) + err = followsStore.Follow(ctx, john.ID, bob.ID) + require.NoError(t, err) + + // First page only has bob + got, err = db.ListFollowings(ctx, john.ID, 1, 1) + require.NoError(t, err) + require.Len(t, got, 1) + assert.Equal(t, bob.ID, got[0].ID) + + // Second page only has alice + got, err = db.ListFollowings(ctx, john.ID, 2, 1) + require.NoError(t, err) + require.Len(t, got, 1) + assert.Equal(t, alice.ID, got[0].ID) +} |