diff options
Diffstat (limited to 'internal/db/repos.go')
-rw-r--r-- | internal/db/repos.go | 59 |
1 files changed, 56 insertions, 3 deletions
diff --git a/internal/db/repos.go b/internal/db/repos.go index 7791a458..28a38148 100644 --- a/internal/db/repos.go +++ b/internal/db/repos.go @@ -19,8 +19,6 @@ import ( ) // ReposStore is the persistent interface for repositories. -// -// NOTE: All methods are sorted in alphabetical order. type ReposStore interface { // Create creates a new repository record in the database. It returns // ErrNameNotAllowed when the repository name is not allowed, or @@ -48,6 +46,14 @@ type ReposStore interface { // Touch updates the updated time to the current time and removes the bare state // of the given repository. Touch(ctx context.Context, id int64) error + + // ListWatches returns all watches of the given repository. + ListWatches(ctx context.Context, repoID int64) ([]*Watch, error) + // Watch marks the user to watch the repository. + Watch(ctx context.Context, userID, repoID int64) error + + // HasForkedBy returns true if the given repository has forked by the given user. + HasForkedBy(ctx context.Context, repoID, userID int64) bool } var Repos ReposStore @@ -189,7 +195,7 @@ func (db *repos) Create(ctx context.Context, ownerID int64, opts CreateRepoOptio return errors.Wrap(err, "create") } - err = NewWatchesStore(tx).Watch(ctx, ownerID, repo.ID) + err = NewReposStore(tx).Watch(ctx, ownerID, repo.ID) if err != nil { return errors.Wrap(err, "watch") } @@ -371,3 +377,50 @@ func (db *repos) Touch(ctx context.Context, id int64) error { }). Error } + +func (db *repos) ListWatches(ctx context.Context, repoID int64) ([]*Watch, error) { + var watches []*Watch + return watches, db.WithContext(ctx).Where("repo_id = ?", repoID).Find(&watches).Error +} + +func (db *repos) recountWatches(tx *gorm.DB, repoID int64) error { + /* + Equivalent SQL for PostgreSQL: + + UPDATE repository + SET num_watches = ( + SELECT COUNT(*) FROM watch WHERE repo_id = @repoID + ) + WHERE id = @repoID + */ + return tx.Model(&Repository{}). + Where("id = ?", repoID). + Update( + "num_watches", + tx.Model(&Watch{}).Select("COUNT(*)").Where("repo_id = ?", repoID), + ). + Error +} + +func (db *repos) Watch(ctx context.Context, userID, repoID int64) error { + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + w := &Watch{ + UserID: userID, + RepoID: repoID, + } + result := tx.FirstOrCreate(w, w) + if result.Error != nil { + return errors.Wrap(result.Error, "upsert") + } else if result.RowsAffected <= 0 { + return nil // Relation already exists + } + + return db.recountWatches(tx, repoID) + }) +} + +func (db *repos) HasForkedBy(ctx context.Context, repoID, userID int64) bool { + var count int64 + db.WithContext(ctx).Model(new(Repository)).Where("owner_id = ? AND fork_id = ?", userID, repoID).Count(&count) + return count > 0 +} |