diff options
author | Joe Chen <jc@unknwon.io> | 2023-12-17 16:32:28 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-17 16:32:28 -0500 |
commit | 25fdeaac49af622ff8e059a3502df6d7190f7313 (patch) | |
tree | 7a0a516f8034f6bdf070d2e7c27c1b70950d0476 | |
parent | 0c7b45ad1f41e0aa4494fae70cfa3b7668245e24 (diff) |
db: pass context to tests by default (#7622)
[skip ci]
-rw-r--r-- | internal/db/access_tokens_test.go | 25 | ||||
-rw-r--r-- | internal/db/actions_test.go | 51 | ||||
-rw-r--r-- | internal/db/lfs_test.go | 17 | ||||
-rw-r--r-- | internal/db/login_sources_test.go | 33 | ||||
-rw-r--r-- | internal/db/notices_test.go | 25 | ||||
-rw-r--r-- | internal/db/orgs_test.go | 17 | ||||
-rw-r--r-- | internal/db/perms_test.go | 17 | ||||
-rw-r--r-- | internal/db/public_keys_test.go | 8 | ||||
-rw-r--r-- | internal/db/repos_test.go | 45 | ||||
-rw-r--r-- | internal/db/two_factors_test.go | 17 | ||||
-rw-r--r-- | internal/db/users_test.go | 114 |
11 files changed, 114 insertions, 255 deletions
diff --git a/internal/db/access_tokens_test.go b/internal/db/access_tokens_test.go index 36a39c73..2a1d452f 100644 --- a/internal/db/access_tokens_test.go +++ b/internal/db/access_tokens_test.go @@ -98,6 +98,7 @@ func TestAccessTokens(t *testing.T) { } t.Parallel() + ctx := context.Background() tables := []any{new(AccessToken)} db := &accessTokens{ DB: dbtest.NewDB(t, "accessTokens", tables...), @@ -105,7 +106,7 @@ func TestAccessTokens(t *testing.T) { for _, tc := range []struct { name string - test func(t *testing.T, db *accessTokens) + test func(t *testing.T, ctx context.Context, db *accessTokens) }{ {"Create", accessTokensCreate}, {"DeleteByID", accessTokensDeleteByID}, @@ -118,7 +119,7 @@ func TestAccessTokens(t *testing.T) { err := clearTables(t, db.DB, tables...) require.NoError(t, err) }) - tc.test(t, db) + tc.test(t, ctx, db) }) if t.Failed() { break @@ -126,9 +127,7 @@ func TestAccessTokens(t *testing.T) { } } -func accessTokensCreate(t *testing.T, db *accessTokens) { - ctx := context.Background() - +func accessTokensCreate(t *testing.T, ctx context.Context, db *accessTokens) { // Create first access token with name "Test" token, err := db.Create(ctx, 1, "Test") require.NoError(t, err) @@ -153,9 +152,7 @@ func accessTokensCreate(t *testing.T, db *accessTokens) { assert.Equal(t, wantErr, err) } -func accessTokensDeleteByID(t *testing.T, db *accessTokens) { - ctx := context.Background() - +func accessTokensDeleteByID(t *testing.T, ctx context.Context, db *accessTokens) { // Create an access token with name "Test" token, err := db.Create(ctx, 1, "Test") require.NoError(t, err) @@ -182,9 +179,7 @@ func accessTokensDeleteByID(t *testing.T, db *accessTokens) { assert.Equal(t, wantErr, err) } -func accessTokensGetBySHA(t *testing.T, db *accessTokens) { - ctx := context.Background() - +func accessTokensGetBySHA(t *testing.T, ctx context.Context, db *accessTokens) { // Create an access token with name "Test" token, err := db.Create(ctx, 1, "Test") require.NoError(t, err) @@ -203,9 +198,7 @@ func accessTokensGetBySHA(t *testing.T, db *accessTokens) { assert.Equal(t, wantErr, err) } -func accessTokensList(t *testing.T, db *accessTokens) { - ctx := context.Background() - +func accessTokensList(t *testing.T, ctx context.Context, db *accessTokens) { // Create two access tokens for user 1 _, err := db.Create(ctx, 1, "user1_1") require.NoError(t, err) @@ -228,9 +221,7 @@ func accessTokensList(t *testing.T, db *accessTokens) { assert.Equal(t, "user1_2", tokens[1].Name) } -func accessTokensTouch(t *testing.T, db *accessTokens) { - ctx := context.Background() - +func accessTokensTouch(t *testing.T, ctx context.Context, db *accessTokens) { // Create an access token with name "Test" token, err := db.Create(ctx, 1, "Test") require.NoError(t, err) diff --git a/internal/db/actions_test.go b/internal/db/actions_test.go index 6b8a01a3..8bf20f69 100644 --- a/internal/db/actions_test.go +++ b/internal/db/actions_test.go @@ -97,8 +97,9 @@ func TestActions(t *testing.T) { if testing.Short() { t.Skip() } - t.Parallel() + ctx := context.Background() + t.Parallel() tables := []any{new(Action), new(User), new(Repository), new(EmailAddress), new(Watch)} db := &actions{ DB: dbtest.NewDB(t, "actions", tables...), @@ -106,7 +107,7 @@ func TestActions(t *testing.T) { for _, tc := range []struct { name string - test func(t *testing.T, db *actions) + test func(t *testing.T, ctx context.Context, db *actions) }{ {"CommitRepo", actionsCommitRepo}, {"ListByOrganization", actionsListByOrganization}, @@ -125,7 +126,7 @@ func TestActions(t *testing.T) { err := clearTables(t, db.DB, tables...) require.NoError(t, err) }) - tc.test(t, db) + tc.test(t, ctx, db) }) if t.Failed() { break @@ -133,9 +134,7 @@ func TestActions(t *testing.T) { } } -func actionsCommitRepo(t *testing.T, db *actions) { - ctx := context.Background() - +func actionsCommitRepo(t *testing.T, ctx context.Context, db *actions) { alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) repo, err := NewReposStore(db.DB).Create(ctx, @@ -327,14 +326,12 @@ func actionsCommitRepo(t *testing.T, db *actions) { }) } -func actionsListByOrganization(t *testing.T, db *actions) { +func actionsListByOrganization(t *testing.T, ctx context.Context, db *actions) { if os.Getenv("GOGS_DATABASE_TYPE") != "postgres" { t.Skip("Skipping testing with not using PostgreSQL") return } - ctx := context.Background() - conf.SetMockUI(t, conf.UIOpts{ User: conf.UIUserOpts{ @@ -375,14 +372,12 @@ func actionsListByOrganization(t *testing.T, db *actions) { } } -func actionsListByUser(t *testing.T, db *actions) { +func actionsListByUser(t *testing.T, ctx context.Context, db *actions) { if os.Getenv("GOGS_DATABASE_TYPE") != "postgres" { t.Skip("Skipping testing with not using PostgreSQL") return } - ctx := context.Background() - conf.SetMockUI(t, conf.UIOpts{ User: conf.UIUserOpts{ @@ -442,9 +437,7 @@ func actionsListByUser(t *testing.T, db *actions) { } } -func actionsMergePullRequest(t *testing.T, db *actions) { - ctx := context.Background() - +func actionsMergePullRequest(t *testing.T, ctx context.Context, db *actions) { alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) repo, err := NewReposStore(db.DB).Create(ctx, @@ -489,9 +482,7 @@ func actionsMergePullRequest(t *testing.T, db *actions) { assert.Equal(t, want, got) } -func actionsMirrorSyncCreate(t *testing.T, db *actions) { - ctx := context.Background() - +func actionsMirrorSyncCreate(t *testing.T, ctx context.Context, db *actions) { alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) repo, err := NewReposStore(db.DB).Create(ctx, @@ -532,9 +523,7 @@ func actionsMirrorSyncCreate(t *testing.T, db *actions) { assert.Equal(t, want, got) } -func actionsMirrorSyncDelete(t *testing.T, db *actions) { - ctx := context.Background() - +func actionsMirrorSyncDelete(t *testing.T, ctx context.Context, db *actions) { alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) repo, err := NewReposStore(db.DB).Create(ctx, @@ -575,9 +564,7 @@ func actionsMirrorSyncDelete(t *testing.T, db *actions) { assert.Equal(t, want, got) } -func actionsMirrorSyncPush(t *testing.T, db *actions) { - ctx := context.Background() - +func actionsMirrorSyncPush(t *testing.T, ctx context.Context, db *actions) { alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) repo, err := NewReposStore(db.DB).Create(ctx, @@ -642,9 +629,7 @@ func actionsMirrorSyncPush(t *testing.T, db *actions) { assert.Equal(t, want, got) } -func actionsNewRepo(t *testing.T, db *actions) { - ctx := context.Background() - +func actionsNewRepo(t *testing.T, ctx context.Context, db *actions) { alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) repo, err := NewReposStore(db.DB).Create(ctx, @@ -719,9 +704,7 @@ func actionsNewRepo(t *testing.T, db *actions) { }) } -func actionsPushTag(t *testing.T, db *actions) { - ctx := context.Background() - +func actionsPushTag(t *testing.T, ctx context.Context, db *actions) { // NOTE: We set a noop mock here to avoid data race with other tests that writes // to the mock server because this function holds a lock. conf.SetMockServer(t, conf.ServerOpts{}) @@ -817,9 +800,7 @@ func actionsPushTag(t *testing.T, db *actions) { }) } -func actionsRenameRepo(t *testing.T, db *actions) { - ctx := context.Background() - +func actionsRenameRepo(t *testing.T, ctx context.Context, db *actions) { alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) repo, err := NewReposStore(db.DB).Create(ctx, @@ -856,9 +837,7 @@ func actionsRenameRepo(t *testing.T, db *actions) { assert.Equal(t, want, got) } -func actionsTransferRepo(t *testing.T, db *actions) { - ctx := context.Background() - +func actionsTransferRepo(t *testing.T, ctx context.Context, db *actions) { alice, err := NewUsersStore(db.DB).Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) bob, err := NewUsersStore(db.DB).Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) diff --git a/internal/db/lfs_test.go b/internal/db/lfs_test.go index cf56e37d..ee56a2c7 100644 --- a/internal/db/lfs_test.go +++ b/internal/db/lfs_test.go @@ -23,6 +23,7 @@ func TestLFS(t *testing.T) { } t.Parallel() + ctx := context.Background() tables := []any{new(LFSObject)} db := &lfs{ DB: dbtest.NewDB(t, "lfs", tables...), @@ -30,7 +31,7 @@ func TestLFS(t *testing.T) { for _, tc := range []struct { name string - test func(t *testing.T, db *lfs) + test func(t *testing.T, ctx context.Context, db *lfs) }{ {"CreateObject", lfsCreateObject}, {"GetObjectByOID", lfsGetObjectByOID}, @@ -41,7 +42,7 @@ func TestLFS(t *testing.T) { err := clearTables(t, db.DB, tables...) require.NoError(t, err) }) - tc.test(t, db) + tc.test(t, ctx, db) }) if t.Failed() { break @@ -49,9 +50,7 @@ func TestLFS(t *testing.T) { } } -func lfsCreateObject(t *testing.T, db *lfs) { - ctx := context.Background() - +func lfsCreateObject(t *testing.T, ctx context.Context, db *lfs) { // Create first LFS object repoID := int64(1) oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f") @@ -68,9 +67,7 @@ func lfsCreateObject(t *testing.T, db *lfs) { assert.Error(t, err) } -func lfsGetObjectByOID(t *testing.T, db *lfs) { - ctx := context.Background() - +func lfsGetObjectByOID(t *testing.T, ctx context.Context, db *lfs) { // Create a LFS object repoID := int64(1) oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f") @@ -87,9 +84,7 @@ func lfsGetObjectByOID(t *testing.T, db *lfs) { assert.Equal(t, expErr, err) } -func lfsGetObjectsByOIDs(t *testing.T, db *lfs) { - ctx := context.Background() - +func lfsGetObjectsByOIDs(t *testing.T, ctx context.Context, db *lfs) { // Create two LFS objects repoID := int64(1) oid1 := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f") diff --git a/internal/db/login_sources_test.go b/internal/db/login_sources_test.go index 13c17f79..830a07a8 100644 --- a/internal/db/login_sources_test.go +++ b/internal/db/login_sources_test.go @@ -163,6 +163,7 @@ func TestLoginSources(t *testing.T) { } t.Parallel() + ctx := context.Background() tables := []any{new(LoginSource), new(User)} db := &loginSources{ DB: dbtest.NewDB(t, "loginSources", tables...), @@ -170,7 +171,7 @@ func TestLoginSources(t *testing.T) { for _, tc := range []struct { name string - test func(t *testing.T, db *loginSources) + test func(t *testing.T, ctx context.Context, db *loginSources) }{ {"Create", loginSourcesCreate}, {"Count", loginSourcesCount}, @@ -185,7 +186,7 @@ func TestLoginSources(t *testing.T) { err := clearTables(t, db.DB, tables...) require.NoError(t, err) }) - tc.test(t, db) + tc.test(t, ctx, db) }) if t.Failed() { break @@ -193,9 +194,7 @@ func TestLoginSources(t *testing.T) { } } -func loginSourcesCreate(t *testing.T, db *loginSources) { - ctx := context.Background() - +func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSources) { // Create first login source with name "GitHub" source, err := db.Create(ctx, CreateLoginSourceOptions{ @@ -222,9 +221,7 @@ func loginSourcesCreate(t *testing.T, db *loginSources) { assert.Equal(t, wantErr, err) } -func loginSourcesCount(t *testing.T, db *loginSources) { - ctx := context.Background() - +func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSources) { // Create two login sources, one in database and one as source file. _, err := db.Create(ctx, CreateLoginSourceOptions{ @@ -246,9 +243,7 @@ func loginSourcesCount(t *testing.T, db *loginSources) { assert.Equal(t, int64(3), db.Count(ctx)) } -func loginSourcesDeleteByID(t *testing.T, db *loginSources) { - ctx := context.Background() - +func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSources) { t.Run("delete but in used", func(t *testing.T) { source, err := db.Create(ctx, CreateLoginSourceOptions{ @@ -315,9 +310,7 @@ func loginSourcesDeleteByID(t *testing.T, db *loginSources) { assert.Equal(t, wantErr, err) } -func loginSourcesGetByID(t *testing.T, db *loginSources) { - ctx := context.Background() - +func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSources) { mock := NewMockLoginSourceFilesStore() mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) { if id != 101 { @@ -353,9 +346,7 @@ func loginSourcesGetByID(t *testing.T, db *loginSources) { require.NoError(t, err) } -func loginSourcesList(t *testing.T, db *loginSources) { - ctx := context.Background() - +func loginSourcesList(t *testing.T, ctx context.Context, db *loginSources) { mock := NewMockLoginSourceFilesStore() mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource { if opts.OnlyActivated { @@ -404,9 +395,7 @@ func loginSourcesList(t *testing.T, db *loginSources) { assert.Equal(t, 2, len(sources), "number of sources") } -func loginSourcesResetNonDefault(t *testing.T, db *loginSources) { - ctx := context.Background() - +func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSources) { mock := NewMockLoginSourceFilesStore() mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource { mockFile := NewMockLoginSourceFileStore() @@ -461,9 +450,7 @@ func loginSourcesResetNonDefault(t *testing.T, db *loginSources) { assert.False(t, source2.IsDefault) } -func loginSourcesSave(t *testing.T, db *loginSources) { - ctx := context.Background() - +func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSources) { t.Run("save to database", func(t *testing.T) { // Create a login source with name "GitHub" source, err := db.Create(ctx, diff --git a/internal/db/notices_test.go b/internal/db/notices_test.go index 3ae934dd..b88f92f3 100644 --- a/internal/db/notices_test.go +++ b/internal/db/notices_test.go @@ -66,6 +66,7 @@ func TestNotices(t *testing.T) { } t.Parallel() + ctx := context.Background() tables := []any{new(Notice)} db := ¬ices{ DB: dbtest.NewDB(t, "notices", tables...), @@ -73,7 +74,7 @@ func TestNotices(t *testing.T) { for _, tc := range []struct { name string - test func(t *testing.T, db *notices) + test func(t *testing.T, ctx context.Context, db *notices) }{ {"Create", noticesCreate}, {"DeleteByIDs", noticesDeleteByIDs}, @@ -86,7 +87,7 @@ func TestNotices(t *testing.T) { err := clearTables(t, db.DB, tables...) require.NoError(t, err) }) - tc.test(t, db) + tc.test(t, ctx, db) }) if t.Failed() { break @@ -94,9 +95,7 @@ func TestNotices(t *testing.T) { } } -func noticesCreate(t *testing.T, db *notices) { - ctx := context.Background() - +func noticesCreate(t *testing.T, ctx context.Context, db *notices) { err := db.Create(ctx, NoticeTypeRepository, "test") require.NoError(t, err) @@ -104,9 +103,7 @@ func noticesCreate(t *testing.T, db *notices) { assert.Equal(t, int64(1), count) } -func noticesDeleteByIDs(t *testing.T, db *notices) { - ctx := context.Background() - +func noticesDeleteByIDs(t *testing.T, ctx context.Context, db *notices) { err := db.Create(ctx, NoticeTypeRepository, "test") require.NoError(t, err) @@ -126,9 +123,7 @@ func noticesDeleteByIDs(t *testing.T, db *notices) { assert.Equal(t, int64(0), count) } -func noticesDeleteAll(t *testing.T, db *notices) { - ctx := context.Background() - +func noticesDeleteAll(t *testing.T, ctx context.Context, db *notices) { err := db.Create(ctx, NoticeTypeRepository, "test") require.NoError(t, err) @@ -139,9 +134,7 @@ func noticesDeleteAll(t *testing.T, db *notices) { assert.Equal(t, int64(0), count) } -func noticesList(t *testing.T, db *notices) { - ctx := context.Background() - +func noticesList(t *testing.T, ctx context.Context, db *notices) { err := db.Create(ctx, NoticeTypeRepository, "test 1") require.NoError(t, err) err = db.Create(ctx, NoticeTypeRepository, "test 2") @@ -161,9 +154,7 @@ func noticesList(t *testing.T, db *notices) { require.Len(t, got, 2) } -func noticesCount(t *testing.T, db *notices) { - ctx := context.Background() - +func noticesCount(t *testing.T, ctx context.Context, db *notices) { count := db.Count(ctx) assert.Equal(t, int64(0), count) diff --git a/internal/db/orgs_test.go b/internal/db/orgs_test.go index 89550d81..64831c8a 100644 --- a/internal/db/orgs_test.go +++ b/internal/db/orgs_test.go @@ -21,6 +21,7 @@ func TestOrgs(t *testing.T) { } t.Parallel() + ctx := context.Background() tables := []any{new(User), new(EmailAddress), new(OrgUser)} db := &orgs{ DB: dbtest.NewDB(t, "orgs", tables...), @@ -28,7 +29,7 @@ func TestOrgs(t *testing.T) { for _, tc := range []struct { name string - test func(t *testing.T, db *orgs) + test func(t *testing.T, ctx context.Context, db *orgs) }{ {"List", orgsList}, {"SearchByName", orgsSearchByName}, @@ -39,7 +40,7 @@ func TestOrgs(t *testing.T) { err := clearTables(t, db.DB, tables...) require.NoError(t, err) }) - tc.test(t, db) + tc.test(t, ctx, db) }) if t.Failed() { break @@ -47,9 +48,7 @@ func TestOrgs(t *testing.T) { } } -func orgsList(t *testing.T, db *orgs) { - ctx := context.Background() - +func orgsList(t *testing.T, ctx context.Context, db *orgs) { usersStore := NewUsersStore(db.DB) alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -119,9 +118,7 @@ func orgsList(t *testing.T, db *orgs) { } } -func orgsSearchByName(t *testing.T, db *orgs) { - ctx := context.Background() - +func orgsSearchByName(t *testing.T, ctx context.Context, db *orgs) { // TODO: Use Orgs.Create to replace SQL hack when the method is available. usersStore := NewUsersStore(db.DB) org1, err := usersStore.Create(ctx, "org1", "org1@example.com", CreateUserOptions{FullName: "Acme Corp"}) @@ -166,9 +163,7 @@ func orgsSearchByName(t *testing.T, db *orgs) { }) } -func orgsCountByUser(t *testing.T, db *orgs) { - ctx := context.Background() - +func orgsCountByUser(t *testing.T, ctx context.Context, db *orgs) { // TODO: Use Orgs.Join to replace SQL hack when the method is available. err := db.Exec(`INSERT INTO org_user (uid, org_id) VALUES (?, ?)`, 1, 1).Error require.NoError(t, err) diff --git a/internal/db/perms_test.go b/internal/db/perms_test.go index 4d3c93c6..5fc4d050 100644 --- a/internal/db/perms_test.go +++ b/internal/db/perms_test.go @@ -20,6 +20,7 @@ func TestPerms(t *testing.T) { } t.Parallel() + ctx := context.Background() tables := []any{new(Access)} db := &perms{ DB: dbtest.NewDB(t, "perms", tables...), @@ -27,7 +28,7 @@ func TestPerms(t *testing.T) { for _, tc := range []struct { name string - test func(t *testing.T, db *perms) + test func(t *testing.T, ctx context.Context, db *perms) }{ {"AccessMode", permsAccessMode}, {"Authorize", permsAuthorize}, @@ -38,7 +39,7 @@ func TestPerms(t *testing.T) { err := clearTables(t, db.DB, tables...) require.NoError(t, err) }) - tc.test(t, db) + tc.test(t, ctx, db) }) if t.Failed() { break @@ -46,9 +47,7 @@ func TestPerms(t *testing.T) { } } -func permsAccessMode(t *testing.T, db *perms) { - ctx := context.Background() - +func permsAccessMode(t *testing.T, ctx context.Context, db *perms) { // Set up permissions err := db.SetRepoPerms(ctx, 1, map[int64]AccessMode{ @@ -159,9 +158,7 @@ func permsAccessMode(t *testing.T, db *perms) { } } -func permsAuthorize(t *testing.T, db *perms) { - ctx := context.Background() - +func permsAuthorize(t *testing.T, ctx context.Context, db *perms) { // Set up permissions err := db.SetRepoPerms(ctx, 1, map[int64]AccessMode{ @@ -247,9 +244,7 @@ func permsAuthorize(t *testing.T, db *perms) { } } -func permsSetRepoPerms(t *testing.T, db *perms) { - ctx := context.Background() - +func permsSetRepoPerms(t *testing.T, ctx context.Context, db *perms) { for _, update := range []struct { repoID int64 accessMap map[int64]AccessMode diff --git a/internal/db/public_keys_test.go b/internal/db/public_keys_test.go index 1e83014b..32983e96 100644 --- a/internal/db/public_keys_test.go +++ b/internal/db/public_keys_test.go @@ -5,6 +5,7 @@ package db import ( + "context" "fmt" "os" "path/filepath" @@ -23,6 +24,7 @@ func TestPublicKeys(t *testing.T) { } t.Parallel() + ctx := context.Background() tables := []any{new(PublicKey)} db := &publicKeys{ DB: dbtest.NewDB(t, "publicKeys", tables...), @@ -30,7 +32,7 @@ func TestPublicKeys(t *testing.T) { for _, tc := range []struct { name string - test func(t *testing.T, db *publicKeys) + test func(t *testing.T, ctx context.Context, db *publicKeys) }{ {"RewriteAuthorizedKeys", publicKeysRewriteAuthorizedKeys}, } { @@ -39,7 +41,7 @@ func TestPublicKeys(t *testing.T) { err := clearTables(t, db.DB, tables...) require.NoError(t, err) }) - tc.test(t, db) + tc.test(t, ctx, db) }) if t.Failed() { break @@ -47,7 +49,7 @@ func TestPublicKeys(t *testing.T) { } } -func publicKeysRewriteAuthorizedKeys(t *testing.T, db *publicKeys) { +func publicKeysRewriteAuthorizedKeys(t *testing.T, ctx context.Context, db *publicKeys) { // TODO: Use PublicKeys.Add to replace SQL hack when the method is available. publicKey := &PublicKey{ OwnerID: 1, diff --git a/internal/db/repos_test.go b/internal/db/repos_test.go index 64b59d78..7bb74f88 100644 --- a/internal/db/repos_test.go +++ b/internal/db/repos_test.go @@ -85,6 +85,7 @@ func TestRepos(t *testing.T) { } t.Parallel() + ctx := context.Background() tables := []any{new(Repository), new(Access), new(Watch), new(User), new(EmailAddress), new(Star)} db := &repos{ DB: dbtest.NewDB(t, "repos", tables...), @@ -92,7 +93,7 @@ func TestRepos(t *testing.T) { for _, tc := range []struct { name string - test func(t *testing.T, db *repos) + test func(t *testing.T, ctx context.Context, db *repos) }{ {"Create", reposCreate}, {"GetByCollaboratorID", reposGetByCollaboratorID}, @@ -110,7 +111,7 @@ func TestRepos(t *testing.T) { err := clearTables(t, db.DB, tables...) require.NoError(t, err) }) - tc.test(t, db) + tc.test(t, ctx, db) }) if t.Failed() { break @@ -118,9 +119,7 @@ func TestRepos(t *testing.T) { } } -func reposCreate(t *testing.T, db *repos) { - ctx := context.Background() - +func reposCreate(t *testing.T, ctx context.Context, db *repos) { t.Run("name not allowed", func(t *testing.T) { _, err := db.Create(ctx, 1, @@ -162,9 +161,7 @@ func reposCreate(t *testing.T, db *repos) { assert.Equal(t, 1, repo.NumWatches) // The owner is watching the repo by default. } -func reposGetByCollaboratorID(t *testing.T, db *repos) { - ctx := context.Background() - +func reposGetByCollaboratorID(t *testing.T, ctx context.Context, db *repos) { repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) require.NoError(t, err) repo2, err := db.Create(ctx, 2, CreateRepoOptions{Name: "repo2"}) @@ -190,9 +187,7 @@ func reposGetByCollaboratorID(t *testing.T, db *repos) { }) } -func reposGetByCollaboratorIDWithAccessMode(t *testing.T, db *repos) { - ctx := context.Background() - +func reposGetByCollaboratorIDWithAccessMode(t *testing.T, ctx context.Context, db *repos) { repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) require.NoError(t, err) repo2, err := db.Create(ctx, 2, CreateRepoOptions{Name: "repo2"}) @@ -220,9 +215,7 @@ func reposGetByCollaboratorIDWithAccessMode(t *testing.T, db *repos) { assert.Equal(t, AccessModeAdmin, accessModes[repo2.ID]) } -func reposGetByID(t *testing.T, db *repos) { - ctx := context.Background() - +func reposGetByID(t *testing.T, ctx context.Context, db *repos) { repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) require.NoError(t, err) @@ -235,9 +228,7 @@ func reposGetByID(t *testing.T, db *repos) { assert.Equal(t, wantErr, err) } -func reposGetByName(t *testing.T, db *repos) { - ctx := context.Background() - +func reposGetByName(t *testing.T, ctx context.Context, db *repos) { repo, err := db.Create(ctx, 1, CreateRepoOptions{ Name: "repo1", @@ -253,9 +244,7 @@ func reposGetByName(t *testing.T, db *repos) { assert.Equal(t, wantErr, err) } -func reposStar(t *testing.T, db *repos) { - ctx := context.Background() - +func reposStar(t *testing.T, ctx context.Context, db *repos) { repo1, err := db.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) require.NoError(t, err) usersStore := NewUsersStore(db.DB) @@ -274,9 +263,7 @@ func reposStar(t *testing.T, db *repos) { assert.Equal(t, 1, alice.NumStars) } -func reposTouch(t *testing.T, db *repos) { - ctx := context.Background() - +func reposTouch(t *testing.T, ctx context.Context, db *repos) { repo, err := db.Create(ctx, 1, CreateRepoOptions{ Name: "repo1", @@ -302,9 +289,7 @@ func reposTouch(t *testing.T, db *repos) { assert.False(t, got.IsBare) } -func reposListWatches(t *testing.T, db *repos) { - ctx := context.Background() - +func reposListWatches(t *testing.T, ctx context.Context, db *repos) { err := db.Watch(ctx, 1, 1) require.NoError(t, err) err = db.Watch(ctx, 2, 1) @@ -325,9 +310,7 @@ func reposListWatches(t *testing.T, db *repos) { assert.Equal(t, want, got) } -func reposWatch(t *testing.T, db *repos) { - ctx := context.Background() - +func reposWatch(t *testing.T, ctx context.Context, db *repos) { reposStore := NewReposStore(db.DB) repo1, err := reposStore.Create(ctx, 1, CreateRepoOptions{Name: "repo1"}) require.NoError(t, err) @@ -344,9 +327,7 @@ func reposWatch(t *testing.T, db *repos) { assert.Equal(t, 2, repo1.NumWatches) // The owner is watching the repo by default. } -func reposHasForkedBy(t *testing.T, db *repos) { - ctx := context.Background() - +func reposHasForkedBy(t *testing.T, ctx context.Context, db *repos) { has := db.HasForkedBy(ctx, 1, 2) assert.False(t, has) diff --git a/internal/db/two_factors_test.go b/internal/db/two_factors_test.go index 64e253bd..3ebcfd7d 100644 --- a/internal/db/two_factors_test.go +++ b/internal/db/two_factors_test.go @@ -67,6 +67,7 @@ func TestTwoFactors(t *testing.T) { } t.Parallel() + ctx := context.Background() tables := []any{new(TwoFactor), new(TwoFactorRecoveryCode)} db := &twoFactors{ DB: dbtest.NewDB(t, "twoFactors", tables...), @@ -74,7 +75,7 @@ func TestTwoFactors(t *testing.T) { for _, tc := range []struct { name string - test func(t *testing.T, db *twoFactors) + test func(t *testing.T, ctx context.Context, db *twoFactors) }{ {"Create", twoFactorsCreate}, {"GetByUserID", twoFactorsGetByUserID}, @@ -85,7 +86,7 @@ func TestTwoFactors(t *testing.T) { err := clearTables(t, db.DB, tables...) require.NoError(t, err) }) - tc.test(t, db) + tc.test(t, ctx, db) }) if t.Failed() { break @@ -93,9 +94,7 @@ func TestTwoFactors(t *testing.T) { } } -func twoFactorsCreate(t *testing.T, db *twoFactors) { - ctx := context.Background() - +func twoFactorsCreate(t *testing.T, ctx context.Context, db *twoFactors) { // Create a 2FA token err := db.Create(ctx, 1, "secure-key", "secure-secret") require.NoError(t, err) @@ -112,9 +111,7 @@ func twoFactorsCreate(t *testing.T, db *twoFactors) { assert.Equal(t, int64(10), count) } -func twoFactorsGetByUserID(t *testing.T, db *twoFactors) { - ctx := context.Background() - +func twoFactorsGetByUserID(t *testing.T, ctx context.Context, db *twoFactors) { // Create a 2FA token for user 1 err := db.Create(ctx, 1, "secure-key", "secure-secret") require.NoError(t, err) @@ -129,9 +126,7 @@ func twoFactorsGetByUserID(t *testing.T, db *twoFactors) { assert.Equal(t, wantErr, err) } -func twoFactorsIsEnabled(t *testing.T, db *twoFactors) { - ctx := context.Background() - +func twoFactorsIsEnabled(t *testing.T, ctx context.Context, db *twoFactors) { // Create a 2FA token for user 1 err := db.Create(ctx, 1, "secure-key", "secure-secret") require.NoError(t, err) diff --git a/internal/db/users_test.go b/internal/db/users_test.go index 8b2e7e59..bb0273a5 100644 --- a/internal/db/users_test.go +++ b/internal/db/users_test.go @@ -84,6 +84,7 @@ func TestUsers(t *testing.T) { } t.Parallel() + ctx := context.Background() tables := []any{ new(User), new(EmailAddress), new(Repository), new(Follow), new(PullRequest), new(PublicKey), new(OrgUser), new(Watch), new(Star), new(Issue), new(AccessToken), new(Collaboration), new(Action), new(IssueUser), @@ -95,7 +96,7 @@ func TestUsers(t *testing.T) { for _, tc := range []struct { name string - test func(t *testing.T, db *users) + test func(t *testing.T, ctx context.Context, db *users) }{ {"Authenticate", usersAuthenticate}, {"ChangeUsername", usersChangeUsername}, @@ -131,7 +132,7 @@ func TestUsers(t *testing.T) { err := clearTables(t, db.DB, tables...) require.NoError(t, err) }) - tc.test(t, db) + tc.test(t, ctx, db) }) if t.Failed() { break @@ -139,9 +140,7 @@ func TestUsers(t *testing.T) { } } -func usersAuthenticate(t *testing.T, db *users) { - ctx := context.Background() - +func usersAuthenticate(t *testing.T, ctx context.Context, db *users) { password := "pa$$word" alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{ @@ -236,9 +235,7 @@ func usersAuthenticate(t *testing.T, db *users) { }) } -func usersChangeUsername(t *testing.T, db *users) { - ctx := context.Background() - +func usersChangeUsername(t *testing.T, ctx context.Context, db *users) { alice, err := db.Create( ctx, "alice", @@ -368,9 +365,7 @@ func usersChangeUsername(t *testing.T, db *users) { assert.Equal(t, strings.ToUpper(newUsername), alice.Name) } -func usersCount(t *testing.T, db *users) { - ctx := context.Background() - +func usersCount(t *testing.T, ctx context.Context, db *users) { // Has no user initially got := db.Count(ctx) assert.Equal(t, int64(0), got) @@ -393,9 +388,7 @@ func usersCount(t *testing.T, db *users) { assert.Equal(t, int64(1), got) } -func usersCreate(t *testing.T, db *users) { - ctx := context.Background() - +func usersCreate(t *testing.T, ctx context.Context, db *users) { alice, err := db.Create( ctx, "alice", @@ -443,9 +436,7 @@ func usersCreate(t *testing.T, db *users) { assert.Equal(t, db.NowFunc().Format(time.RFC3339), user.Updated.UTC().Format(time.RFC3339)) } -func usersDeleteCustomAvatar(t *testing.T, db *users) { - ctx := context.Background() - +func usersDeleteCustomAvatar(t *testing.T, ctx context.Context, db *users) { alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -479,8 +470,7 @@ func usersDeleteCustomAvatar(t *testing.T, db *users) { assert.False(t, alice.UseCustomAvatar) } -func usersDeleteByID(t *testing.T, db *users) { - ctx := context.Background() +func usersDeleteByID(t *testing.T, ctx context.Context, db *users) { reposStore := NewReposStore(db.DB) t.Run("user still has repository ownership", func(t *testing.T) { @@ -690,9 +680,7 @@ func usersDeleteByID(t *testing.T, db *users) { assert.Equal(t, wantErr, err) } -func usersDeleteInactivated(t *testing.T, db *users) { - ctx := context.Background() - +func usersDeleteInactivated(t *testing.T, ctx context.Context, db *users) { // User with repository ownership should be skipped alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{}) require.NoError(t, err) @@ -738,9 +726,7 @@ func usersDeleteInactivated(t *testing.T, db *users) { require.Len(t, users, 3) } -func usersGetByEmail(t *testing.T, db *users) { - ctx := context.Background() - +func usersGetByEmail(t *testing.T, ctx context.Context, db *users) { t.Run("empty email", func(t *testing.T) { _, err := db.GetByEmail(ctx, "") wantErr := ErrUserNotExist{args: errutil.Args{"email": ""}} @@ -801,9 +787,7 @@ func usersGetByEmail(t *testing.T, db *users) { }) } -func usersGetByID(t *testing.T, db *users) { - ctx := context.Background() - +func usersGetByID(t *testing.T, ctx context.Context, db *users) { alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{}) require.NoError(t, err) @@ -816,9 +800,7 @@ func usersGetByID(t *testing.T, db *users) { assert.Equal(t, wantErr, err) } -func usersGetByUsername(t *testing.T, db *users) { - ctx := context.Background() - +func usersGetByUsername(t *testing.T, ctx context.Context, db *users) { alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{}) require.NoError(t, err) @@ -831,9 +813,7 @@ func usersGetByUsername(t *testing.T, db *users) { assert.Equal(t, wantErr, err) } -func usersGetByKeyID(t *testing.T, db *users) { - ctx := context.Background() - +func usersGetByKeyID(t *testing.T, ctx context.Context, db *users) { alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{}) require.NoError(t, err) @@ -858,9 +838,7 @@ func usersGetByKeyID(t *testing.T, db *users) { assert.Equal(t, wantErr, err) } -func usersGetMailableEmailsByUsernames(t *testing.T, db *users) { - ctx := context.Background() - +func usersGetMailableEmailsByUsernames(t *testing.T, ctx context.Context, db *users) { alice, err := db.Create(ctx, "alice", "alice@exmaple.com", CreateUserOptions{}) require.NoError(t, err) bob, err := db.Create(ctx, "bob", "bob@exmaple.com", CreateUserOptions{Activated: true}) @@ -874,9 +852,7 @@ func usersGetMailableEmailsByUsernames(t *testing.T, db *users) { assert.Equal(t, want, got) } -func usersIsUsernameUsed(t *testing.T, db *users) { - ctx := context.Background() - +func usersIsUsernameUsed(t *testing.T, ctx context.Context, db *users) { alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -926,9 +902,7 @@ func usersIsUsernameUsed(t *testing.T, db *users) { } } -func usersList(t *testing.T, db *users) { - ctx := context.Background() - +func usersList(t *testing.T, ctx context.Context, db *users) { alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{}) @@ -961,9 +935,7 @@ func usersList(t *testing.T, db *users) { assert.Equal(t, bob.ID, got[1].ID) } -func usersListFollowers(t *testing.T, db *users) { - ctx := context.Background() - +func usersListFollowers(t *testing.T, ctx context.Context, db *users) { john, err := db.Create(ctx, "john", "john@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -994,9 +966,7 @@ func usersListFollowers(t *testing.T, db *users) { assert.Equal(t, alice.ID, got[0].ID) } -func usersListFollowings(t *testing.T, db *users) { - ctx := context.Background() - +func usersListFollowings(t *testing.T, ctx context.Context, db *users) { john, err := db.Create(ctx, "john", "john@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -1027,9 +997,7 @@ func usersListFollowings(t *testing.T, db *users) { assert.Equal(t, alice.ID, got[0].ID) } -func usersSearchByName(t *testing.T, db *users) { - ctx := context.Background() - +func usersSearchByName(t *testing.T, ctx context.Context, db *users) { alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{FullName: "Alice Jordan"}) require.NoError(t, err) bob, err := db.Create(ctx, "bob", "bob@example.com", CreateUserOptions{FullName: "Bob Jordan"}) @@ -1067,9 +1035,7 @@ func usersSearchByName(t *testing.T, db *users) { }) } -func usersUpdate(t *testing.T, db *users) { - ctx := context.Background() - +func usersUpdate(t *testing.T, ctx context.Context, db *users) { const oldPassword = "Password" alice, err := db.Create( ctx, @@ -1182,9 +1148,7 @@ func usersUpdate(t *testing.T, db *users) { assertValues() } -func usersUseCustomAvatar(t *testing.T, db *users) { - ctx := context.Background() - +func usersUseCustomAvatar(t *testing.T, ctx context.Context, db *users) { alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -1222,9 +1186,7 @@ func TestIsUsernameAllowed(t *testing.T) { } } -func usersAddEmail(t *testing.T, db *users) { - ctx := context.Background() - +func usersAddEmail(t *testing.T, ctx context.Context, db *users) { t.Run("multiple users can add the same unverified email", func(t *testing.T) { alice, err := db.Create(ctx, "alice", "unverified@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -1241,9 +1203,7 @@ func usersAddEmail(t *testing.T, db *users) { }) } -func usersGetEmail(t *testing.T, db *users) { - ctx := context.Background() - +func usersGetEmail(t *testing.T, ctx context.Context, db *users) { const testUserID = 1 const testEmail = "alice@example.com" _, err := db.GetEmail(ctx, testUserID, testEmail, false) @@ -1275,9 +1235,7 @@ func usersGetEmail(t *testing.T, db *users) { assert.Equal(t, testEmail, got.Email) } -func usersListEmails(t *testing.T, db *users) { - ctx := context.Background() - +func usersListEmails(t *testing.T, ctx context.Context, db *users) { t.Run("list emails with primary email", func(t *testing.T) { alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -1313,9 +1271,7 @@ func usersListEmails(t *testing.T, db *users) { }) } -func usersMarkEmailActivated(t *testing.T, db *users) { - ctx := context.Background() - +func usersMarkEmailActivated(t *testing.T, ctx context.Context, db *users) { alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -1333,8 +1289,7 @@ func usersMarkEmailActivated(t *testing.T, db *users) { assert.NotEqual(t, alice.Rands, gotAlice.Rands) } -func usersMarkEmailPrimary(t *testing.T, db *users) { - ctx := context.Background() +func usersMarkEmailPrimary(t *testing.T, ctx context.Context, db *users) { alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) err = db.AddEmail(ctx, alice.ID, "alice2@example.com", false) @@ -1360,8 +1315,7 @@ func usersMarkEmailPrimary(t *testing.T, db *users) { assert.False(t, gotEmail.IsActivated) } -func usersDeleteEmail(t *testing.T, db *users) { - ctx := context.Background() +func usersDeleteEmail(t *testing.T, ctx context.Context, db *users) { alice, err := db.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -1377,9 +1331,7 @@ func usersDeleteEmail(t *testing.T, db *users) { require.Equal(t, want, got) } -func usersFollow(t *testing.T, db *users) { - ctx := context.Background() - +func usersFollow(t *testing.T, ctx context.Context, db *users) { usersStore := NewUsersStore(db.DB) alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -1402,9 +1354,7 @@ func usersFollow(t *testing.T, db *users) { assert.Equal(t, 1, bob.NumFollowers) } -func usersIsFollowing(t *testing.T, db *users) { - ctx := context.Background() - +func usersIsFollowing(t *testing.T, ctx context.Context, db *users) { usersStore := NewUsersStore(db.DB) alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) @@ -1425,9 +1375,7 @@ func usersIsFollowing(t *testing.T, db *users) { assert.False(t, got) } -func usersUnfollow(t *testing.T, db *users) { - ctx := context.Background() - +func usersUnfollow(t *testing.T, ctx context.Context, db *users) { usersStore := NewUsersStore(db.DB) alice, err := usersStore.Create(ctx, "alice", "alice@example.com", CreateUserOptions{}) require.NoError(t, err) |