diff options
Diffstat (limited to 'internal/db/watches.go')
-rw-r--r-- | internal/db/watches.go | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/internal/db/watches.go b/internal/db/watches.go index e93a4ab6..6ca78a37 100644 --- a/internal/db/watches.go +++ b/internal/db/watches.go @@ -7,6 +7,7 @@ package db import ( "context" + "github.com/pkg/errors" "gorm.io/gorm" ) @@ -16,6 +17,8 @@ import ( type WatchesStore interface { // ListByRepo returns all watches of the given repository. ListByRepo(ctx context.Context, repoID int64) ([]*Watch, error) + // Watch marks the user to watch the repository. + Watch(ctx context.Context, userID, repoID int64) error } var Watches WatchesStore @@ -36,3 +39,39 @@ func (db *watches) ListByRepo(ctx context.Context, repoID int64) ([]*Watch, erro var watches []*Watch return watches, db.WithContext(ctx).Where("repo_id = ?", repoID).Find(&watches).Error } + +func (db *watches) updateWatchingCount(tx *gorm.DB, repoID int64) error { + /* + Equivalent SQL for PostgreSQL: + + UPDATE repository + SET num_watches = ( + SELECT COUNT(*) FROM watch WHERE repo_id = @repoID + ) + WHERE id = @repoID + */ + return tx.Model(&Repository{}). + Where("id = ?", repoID). + Update( + "num_watches", + tx.Model(&Watch{}).Select("COUNT(*)").Where("repo_id = ?", repoID), + ). + Error +} + +func (db *watches) Watch(ctx context.Context, userID, repoID int64) error { + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + w := &Watch{ + UserID: userID, + RepoID: repoID, + } + result := tx.FirstOrCreate(w, w) + if result.Error != nil { + return errors.Wrap(result.Error, "upsert") + } else if result.RowsAffected <= 0 { + return nil // Relation already exists + } + + return db.updateWatchingCount(tx, repoID) + }) +} |