aboutsummaryrefslogtreecommitdiff
path: root/internal/db/repos.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/repos.go')
-rw-r--r--internal/db/repos.go59
1 files changed, 56 insertions, 3 deletions
diff --git a/internal/db/repos.go b/internal/db/repos.go
index 7791a458..28a38148 100644
--- a/internal/db/repos.go
+++ b/internal/db/repos.go
@@ -19,8 +19,6 @@ import (
)
// ReposStore is the persistent interface for repositories.
-//
-// NOTE: All methods are sorted in alphabetical order.
type ReposStore interface {
// Create creates a new repository record in the database. It returns
// ErrNameNotAllowed when the repository name is not allowed, or
@@ -48,6 +46,14 @@ type ReposStore interface {
// Touch updates the updated time to the current time and removes the bare state
// of the given repository.
Touch(ctx context.Context, id int64) error
+
+ // ListWatches returns all watches of the given repository.
+ ListWatches(ctx context.Context, repoID int64) ([]*Watch, error)
+ // Watch marks the user to watch the repository.
+ Watch(ctx context.Context, userID, repoID int64) error
+
+ // HasForkedBy returns true if the given repository has forked by the given user.
+ HasForkedBy(ctx context.Context, repoID, userID int64) bool
}
var Repos ReposStore
@@ -189,7 +195,7 @@ func (db *repos) Create(ctx context.Context, ownerID int64, opts CreateRepoOptio
return errors.Wrap(err, "create")
}
- err = NewWatchesStore(tx).Watch(ctx, ownerID, repo.ID)
+ err = NewReposStore(tx).Watch(ctx, ownerID, repo.ID)
if err != nil {
return errors.Wrap(err, "watch")
}
@@ -371,3 +377,50 @@ func (db *repos) Touch(ctx context.Context, id int64) error {
}).
Error
}
+
+func (db *repos) ListWatches(ctx context.Context, repoID int64) ([]*Watch, error) {
+ var watches []*Watch
+ return watches, db.WithContext(ctx).Where("repo_id = ?", repoID).Find(&watches).Error
+}
+
+func (db *repos) recountWatches(tx *gorm.DB, repoID int64) error {
+ /*
+ Equivalent SQL for PostgreSQL:
+
+ UPDATE repository
+ SET num_watches = (
+ SELECT COUNT(*) FROM watch WHERE repo_id = @repoID
+ )
+ WHERE id = @repoID
+ */
+ return tx.Model(&Repository{}).
+ Where("id = ?", repoID).
+ Update(
+ "num_watches",
+ tx.Model(&Watch{}).Select("COUNT(*)").Where("repo_id = ?", repoID),
+ ).
+ Error
+}
+
+func (db *repos) Watch(ctx context.Context, userID, repoID int64) error {
+ return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
+ w := &Watch{
+ UserID: userID,
+ RepoID: repoID,
+ }
+ result := tx.FirstOrCreate(w, w)
+ if result.Error != nil {
+ return errors.Wrap(result.Error, "upsert")
+ } else if result.RowsAffected <= 0 {
+ return nil // Relation already exists
+ }
+
+ return db.recountWatches(tx, repoID)
+ })
+}
+
+func (db *repos) HasForkedBy(ctx context.Context, repoID, userID int64) bool {
+ var count int64
+ db.WithContext(ctx).Model(new(Repository)).Where("owner_id = ? AND fork_id = ?", userID, repoID).Count(&count)
+ return count > 0
+}