aboutsummaryrefslogtreecommitdiff
path: root/internal/db/watches.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/watches.go')
-rw-r--r--internal/db/watches.go39
1 files changed, 39 insertions, 0 deletions
diff --git a/internal/db/watches.go b/internal/db/watches.go
index e93a4ab6..6ca78a37 100644
--- a/internal/db/watches.go
+++ b/internal/db/watches.go
@@ -7,6 +7,7 @@ package db
import (
"context"
+ "github.com/pkg/errors"
"gorm.io/gorm"
)
@@ -16,6 +17,8 @@ import (
type WatchesStore interface {
// ListByRepo returns all watches of the given repository.
ListByRepo(ctx context.Context, repoID int64) ([]*Watch, error)
+ // Watch marks the user to watch the repository.
+ Watch(ctx context.Context, userID, repoID int64) error
}
var Watches WatchesStore
@@ -36,3 +39,39 @@ func (db *watches) ListByRepo(ctx context.Context, repoID int64) ([]*Watch, erro
var watches []*Watch
return watches, db.WithContext(ctx).Where("repo_id = ?", repoID).Find(&watches).Error
}
+
+func (db *watches) updateWatchingCount(tx *gorm.DB, repoID int64) error {
+ /*
+ Equivalent SQL for PostgreSQL:
+
+ UPDATE repository
+ SET num_watches = (
+ SELECT COUNT(*) FROM watch WHERE repo_id = @repoID
+ )
+ WHERE id = @repoID
+ */
+ return tx.Model(&Repository{}).
+ Where("id = ?", repoID).
+ Update(
+ "num_watches",
+ tx.Model(&Watch{}).Select("COUNT(*)").Where("repo_id = ?", repoID),
+ ).
+ Error
+}
+
+func (db *watches) Watch(ctx context.Context, userID, repoID int64) error {
+ return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
+ w := &Watch{
+ UserID: userID,
+ RepoID: repoID,
+ }
+ result := tx.FirstOrCreate(w, w)
+ if result.Error != nil {
+ return errors.Wrap(result.Error, "upsert")
+ } else if result.RowsAffected <= 0 {
+ return nil // Relation already exists
+ }
+
+ return db.updateWatchingCount(tx, repoID)
+ })
+}