aboutsummaryrefslogtreecommitdiff
path: root/internal/db/follows.go
blob: bf50042a811d254ef8dc576e12358af0d6f4ff74 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
// Copyright 2022 The Gogs Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

package db

import (
	"context"

	"github.com/pkg/errors"
	"gorm.io/gorm"
)

// FollowsStore is the persistent interface for user follows: it records,
// queries and removes directed "user follows another user" relations.
//
// NOTE: All methods are sorted in alphabetical order.
type FollowsStore interface {
	// Follow records that the user identified by userID follows the user
	// identified by followID.
	Follow(ctx context.Context, userID int64, followID int64) error
	// IsFollowing reports whether the user identified by userID follows the
	// user identified by followID.
	IsFollowing(ctx context.Context, userID int64, followID int64) bool
	// Unfollow removes the record that the user identified by userID follows
	// the user identified by followID.
	Unfollow(ctx context.Context, userID int64, followID int64) error
}

var (
	// Follows is the package-level FollowsStore. It has no initializer here,
	// so it stays nil until something assigns it.
	// NOTE(review): presumably set during database initialization — confirm
	// against the caller.
	Follows FollowsStore

	// Compile-time assertion that *follows satisfies FollowsStore.
	_ FollowsStore = (*follows)(nil)
)

// follows is the GORM-backed implementation of FollowsStore. It embeds
// *gorm.DB so query-building methods such as WithContext are promoted onto
// the store itself.
type follows struct {
	*gorm.DB
}

// NewFollowsStore wraps the given GORM database connection in a FollowsStore
// that reads and writes user follows through it.
func NewFollowsStore(db *gorm.DB) FollowsStore {
	store := &follows{DB: db}
	return store
}

// updateFollowingCount uses tx to recompute two cached counters on the
// "user" table: num_followers of the user identified by followID, then
// num_following of the user identified by userID. Each counter is assigned
// from a COUNT(*) over the follow table instead of being incremented or
// decremented, so the values always reflect the rows currently visible to tx.
//
// Equivalent SQL for PostgreSQL, executed in this order:
//
//	UPDATE "user"
//	SET num_followers = (SELECT COUNT(*) FROM follow WHERE follow_id = @followID)
//	WHERE id = @followID;
//
//	UPDATE "user"
//	SET num_following = (SELECT COUNT(*) FROM follow WHERE user_id = @userID)
//	WHERE id = @userID;
func (*follows) updateFollowingCount(tx *gorm.DB, userID, followID int64) error {
	// recount sets column on the user row with the given id to the number of
	// follow rows whose fkColumn equals that same id.
	recount := func(column, fkColumn string, id int64) error {
		subquery := tx.Model(&Follow{}).Select("COUNT(*)").Where(fkColumn+" = ?", id)
		return tx.Model(&User{}).Where("id = ?", id).Update(column, subquery).Error
	}

	if err := recount("num_followers", "follow_id", followID); err != nil {
		return errors.Wrap(err, `update "user.num_followers"`)
	}
	if err := recount("num_following", "user_id", userID); err != nil {
		return errors.Wrap(err, `update "user.num_following"`)
	}
	return nil
}

// Follow records that userID follows followID. Following oneself is a no-op
// that returns nil.
//
// The relation row is looked up or inserted with FirstOrCreate inside a
// transaction. The cached counters of both users are refreshed in that same
// transaction, except when FirstOrCreate reports no affected rows; that case
// is treated as the relation already existing.
func (db *follows) Follow(ctx context.Context, userID, followID int64) error {
	if followID == userID {
		return nil
	}

	return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		relation := &Follow{UserID: userID, FollowID: followID}
		// The same struct serves as both destination and lookup conditions.
		res := tx.FirstOrCreate(relation, relation)
		switch {
		case res.Error != nil:
			return errors.Wrap(res.Error, "upsert")
		case res.RowsAffected <= 0:
			// Relation already exists, so the counters are already correct.
			return nil
		}
		return db.updateFollowingCount(tx, userID, followID)
	})
}

// IsFollowing reports whether a follow row exists for the userID → followID
// pair. Any error from the lookup, whether a missing row or a database
// failure, is reported as false.
func (db *follows) IsFollowing(ctx context.Context, userID, followID int64) bool {
	err := db.WithContext(ctx).
		Where("user_id = ? AND follow_id = ?", userID, followID).
		First(&Follow{}).
		Error
	return err == nil
}

// Unfollow deletes the follow row for the userID → followID pair (if any)
// and then recomputes both users' cached counters in the same transaction.
// The recount runs even when no row was deleted. Unfollowing oneself is a
// no-op that returns nil.
func (db *follows) Unfollow(ctx context.Context, userID, followID int64) error {
	if followID == userID {
		return nil
	}

	return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		if err := tx.Where("user_id = ? AND follow_id = ?", userID, followID).Delete(&Follow{}).Error; err != nil {
			return errors.Wrap(err, "delete")
		}
		return db.updateFollowingCount(tx, userID, followID)
	})
}

// Follow represents relations of users and their followers. Each row records
// that the user identified by UserID (the follower) follows the user
// identified by FollowID (the followed user). The (UserID, FollowID) pair is
// declared unique for both ORMs: the xorm UNIQUE(follow) tag and the GORM
// follow_user_follow_unique composite index. ID is the GORM primary key.
type Follow struct {
	ID       int64 `gorm:"primaryKey"`
	UserID   int64 `xorm:"UNIQUE(follow)" gorm:"uniqueIndex:follow_user_follow_unique;not null"`
	FollowID int64 `xorm:"UNIQUE(follow)" gorm:"uniqueIndex:follow_user_follow_unique;not null"`
}