diff options
author | Joe Chen <jc@unknwon.io> | 2023-02-05 16:28:47 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-02-05 16:28:47 +0800 |
commit | 7ff09cf35916cad50495c26a47f4c0d05487e24e (patch) | |
tree | dc4516234c3a8e72051e2ab1674adb7e7c32c3cc /internal/db | |
parent | 3c43b9b21c74faf60d62b2cbf2ee89e9ada37f0c (diff) |
refactor(db): migrate methods off `user.go` (#7336)
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/actions_test.go | 4 | ||||
-rw-r--r-- | internal/db/db.go | 2 | ||||
-rw-r--r-- | internal/db/perms.go | 6 | ||||
-rw-r--r-- | internal/db/repos.go | 63 | ||||
-rw-r--r-- | internal/db/repos_test.go | 62 | ||||
-rw-r--r-- | internal/db/user.go | 39 |
6 files changed, 135 insertions, 41 deletions
diff --git a/internal/db/actions_test.go b/internal/db/actions_test.go index aa6bbf4e..15e11e9f 100644 --- a/internal/db/actions_test.go +++ b/internal/db/actions_test.go @@ -720,6 +720,10 @@ func actionsNewRepo(t *testing.T, db *actions) { func actionsPushTag(t *testing.T, db *actions) { ctx := context.Background() + // NOTE: We set a noop mock here to avoid data race with other tests that writes + // to the mock server because this function holds a lock. + conf.SetMockServer(t, conf.ServerOpts{}) + alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) repo, err := NewReposStore(db.DB).Create(ctx, diff --git a/internal/db/db.go b/internal/db/db.go index 7b063dec..2e883fc0 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -127,7 +127,7 @@ func Init(w logger.Writer) (*gorm.DB, error) { LFS = &lfs{DB: db} Orgs = NewOrgsStore(db) OrgUsers = NewOrgUsersStore(db) - Perms = &perms{DB: db} + Perms = NewPermsStore(db) Repos = NewReposStore(db) TwoFactors = &twoFactors{DB: db} Users = NewUsersStore(db) diff --git a/internal/db/perms.go b/internal/db/perms.go index a72a013a..b0a1a85a 100644 --- a/internal/db/perms.go +++ b/internal/db/perms.go @@ -82,6 +82,12 @@ type perms struct { *gorm.DB } +// NewPermsStore returns a persistent interface for permissions with given +// database connection. +func NewPermsStore(db *gorm.DB) PermsStore { + return &perms{DB: db} +} + type AccessModeOptions struct { OwnerID int64 // The ID of the repository owner. Private bool // Whether the repository is private. diff --git a/internal/db/repos.go b/internal/db/repos.go index 7ca78ad8..10f292b9 100644 --- a/internal/db/repos.go +++ b/internal/db/repos.go @@ -26,6 +26,16 @@ type ReposStore interface { // ErrRepoAlreadyExist when a repository with same name already exists for the // owner. Create(ctx context.Context, ownerID int64, opts CreateRepoOptions) (*Repository, error) + // GetByCollaboratorID returns a list of repositories that the given + // collaborator has access to. Results are limited to the given limit and sorted + // by the given order (e.g. "updated_unix DESC"). Repositories that are owned + // directly by the given collaborator are not included. + GetByCollaboratorID(ctx context.Context, collaboratorID int64, limit int, orderBy string) ([]*Repository, error) + // GetByCollaboratorIDWithAccessMode returns a list of repositories and + // corresponding access mode that the given collaborator has access to. + // Repositories that are owned directly by the given collaborator are not + // included. + GetByCollaboratorIDWithAccessMode(ctx context.Context, collaboratorID int64) (map[*Repository]AccessMode, error) // GetByName returns the repository with given owner and name. It returns // ErrRepoNotExist when not found. GetByName(ctx context.Context, ownerID int64, name string) (*Repository, error) @@ -170,6 +180,59 @@ func (db *repos) Create(ctx context.Context, ownerID int64, opts CreateRepoOptio return repo, db.WithContext(ctx).Create(repo).Error } +func (db *repos) GetByCollaboratorID(ctx context.Context, collaboratorID int64, limit int, orderBy string) ([]*Repository, error) { + /* + Equivalent SQL for PostgreSQL: + + SELECT * FROM repository + JOIN access ON access.repo_id = repository.id AND access.user_id = @collaboratorID + WHERE access.mode >= @accessModeRead + ORDER BY @orderBy + LIMIT @limit + */ + var repos []*Repository + return repos, db.WithContext(ctx). + Joins("JOIN access ON access.repo_id = repository.id AND access.user_id = ?", collaboratorID). + Where("access.mode >= ?", AccessModeRead). + Order(orderBy). + Limit(limit). + Find(&repos). + Error +} + +func (db *repos) GetByCollaboratorIDWithAccessMode(ctx context.Context, collaboratorID int64) (map[*Repository]AccessMode, error) { + /* + Equivalent SQL for PostgreSQL: + + SELECT + repository.*, + access.mode + FROM repository + JOIN access ON access.repo_id = repository.id AND access.user_id = @collaboratorID + WHERE access.mode >= @accessModeRead + */ + var reposWithAccessMode []*struct { + *Repository + Mode AccessMode + } + err := db.WithContext(ctx). + Select("repository.*", "access.mode"). + Table("repository"). + Joins("JOIN access ON access.repo_id = repository.id AND access.user_id = ?", collaboratorID). + Where("access.mode >= ?", AccessModeRead). + Find(&reposWithAccessMode). + Error + if err != nil { + return nil, err + } + + repos := make(map[*Repository]AccessMode, len(reposWithAccessMode)) + for _, repoWithAccessMode := range reposWithAccessMode { + repos[repoWithAccessMode.Repository] = repoWithAccessMode.Mode + } + return repos, nil +} + var _ errutil.NotFound = (*ErrRepoNotExist)(nil) type ErrRepoNotExist struct { diff --git a/internal/db/repos_test.go b/internal/db/repos_test.go index 3c01105d..09289729 100644 --- a/internal/db/repos_test.go +++ b/internal/db/repos_test.go @@ -85,7 +85,7 @@ func TestRepos(t *testing.T) { } t.Parallel() - tables := []any{new(Repository)} + tables := []any{new(Repository), new(Access)} db := &repos{ DB: dbtest.NewDB(t, "repos", tables...), } @@ -95,6 +95,8 @@ func TestRepos(t *testing.T) { test func(t *testing.T, db *repos) }{ {"Create", reposCreate}, + {"GetByCollaboratorID", reposGetByCollaboratorID}, + {"GetByCollaboratorIDWithAccessMode", reposGetByCollaboratorIDWithAccessMode}, {"GetByName", reposGetByName}, {"Touch", reposTouch}, } { @@ -154,6 +156,64 @@ func reposCreate(t *testing.T, db *repos) { assert.Equal(t, db.NowFunc().Format(time.RFC3339), repo.Created.UTC().Format(time.RFC3339)) } +func reposGetByCollaboratorID(t *testing.T, db *repos) { + ctx := context.Background() + + repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) + require.NoError(t, err) + repo2, err := db.Create(ctx, 2, CreateRepoOptions{Name: "repo2"}) + require.NoError(t, err) + + permsStore := NewPermsStore(db.DB) + err = permsStore.SetRepoPerms(ctx, repo1.ID, map[int64]AccessMode{3: AccessModeRead}) + require.NoError(t, err) + err = permsStore.SetRepoPerms(ctx, repo2.ID, map[int64]AccessMode{4: AccessModeAdmin}) + require.NoError(t, err) + + t.Run("user 3 is a collaborator of repo1", func(t *testing.T) { + got, err := db.GetByCollaboratorID(ctx, 3, 10, "") + require.NoError(t, err) + require.Len(t, got, 1) + assert.Equal(t, repo1.ID, got[0].ID) + }) + + t.Run("do not return directly owned repository", func(t *testing.T) { + got, err := db.GetByCollaboratorID(ctx, 1, 10, "") + require.NoError(t, err) + require.Len(t, got, 0) + }) +} + +func reposGetByCollaboratorIDWithAccessMode(t *testing.T, db *repos) { + ctx := context.Background() + + repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) + require.NoError(t, err) + repo2, err := db.Create(ctx, 2, CreateRepoOptions{Name: "repo2"}) + require.NoError(t, err) + repo3, err := db.Create(ctx, 2, CreateRepoOptions{Name: "repo3"}) + require.NoError(t, err) + + permsStore := NewPermsStore(db.DB) + err = permsStore.SetRepoPerms(ctx, repo1.ID, map[int64]AccessMode{3: AccessModeRead}) + require.NoError(t, err) + err = permsStore.SetRepoPerms(ctx, repo2.ID, map[int64]AccessMode{3: AccessModeAdmin, 4: AccessModeWrite}) + require.NoError(t, err) + err = permsStore.SetRepoPerms(ctx, repo3.ID, map[int64]AccessMode{4: AccessModeWrite}) + require.NoError(t, err) + + got, err := db.GetByCollaboratorIDWithAccessMode(ctx, 3) + require.NoError(t, err) + require.Len(t, got, 2) + + accessModes := make(map[int64]AccessMode) + for repo, mode := range got { + accessModes[repo.ID] = mode + } + assert.Equal(t, AccessModeRead, accessModes[repo1.ID]) + assert.Equal(t, AccessModeAdmin, accessModes[repo2.ID]) +} + func reposGetByName(t *testing.T, db *repos) { ctx := context.Background() diff --git a/internal/db/user.go b/internal/db/user.go index 5be64668..9cbc2fa3 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -10,7 +10,6 @@ import ( "os" "time" - log "unknwon.dev/clog/v2" "xorm.io/xorm" "gogs.io/gogs/internal/repoutil" @@ -196,41 +195,3 @@ func DeleteInactivateUsers() (err error) { _, err = x.Where("is_activated = ?", false).Delete(new(EmailAddress)) return err } - -// GetRepositoryAccesses finds all repositories with their access mode where a user has access but does not own. -func (u *User) GetRepositoryAccesses() (map[*Repository]AccessMode, error) { - accesses := make([]*Access, 0, 10) - if err := x.Find(&accesses, &Access{UserID: u.ID}); err != nil { - return nil, err - } - - repos := make(map[*Repository]AccessMode, len(accesses)) - for _, access := range accesses { - repo, err := GetRepositoryByID(access.RepoID) - if err != nil { - if IsErrRepoNotExist(err) { - log.Error("Failed to get repository by ID: %v", err) - continue - } - return nil, err - } - if repo.OwnerID == u.ID { - continue - } - repos[repo] = access.Mode - } - return repos, nil -} - -// GetAccessibleRepositories finds repositories which the user has access but does not own. -// If limit is smaller than 1 means returns all found results. -func (user *User) GetAccessibleRepositories(limit int) (repos []*Repository, _ error) { - sess := x.Where("owner_id !=? ", user.ID).Desc("updated_unix") - if limit > 0 { - sess.Limit(limit) - repos = make([]*Repository, 0, limit) - } else { - repos = make([]*Repository, 0, 10) - } - return repos, sess.Join("INNER", "access", "access.user_id = ? AND access.repo_id = repository.id", user.ID).Find(&repos) -} |