aboutsummaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/db.go2
-rw-r--r--internal/db/db_test.go20
-rw-r--r--internal/db/mock_gen.go12
-rw-r--r--internal/db/mocks.go153
-rw-r--r--internal/db/repos.go26
-rw-r--r--internal/db/repos_test.go101
6 files changed, 233 insertions, 81 deletions
diff --git a/internal/db/db.go b/internal/db/db.go
index b58d88c8..e67ffde9 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -73,7 +73,7 @@ func newDSN(opts conf.DatabaseOpts) (dsn string, err error) {
case "postgres":
host, port := parsePostgreSQLHostPort(opts.Host)
- dsn = fmt.Sprintf("user='%s' password='%s' host='%s' port='%s' dbname='%s' sslmode='%s' search_path='%s'",
+ dsn = fmt.Sprintf("user='%s' password='%s' host='%s' port='%s' dbname='%s' sslmode='%s' search_path='%s' application_name='gogs'",
opts.User, opts.Password, host, port, opts.Name, opts.SSLMode, opts.Schema)
case "mssql":
diff --git a/internal/db/db_test.go b/internal/db/db_test.go
index 1f4f0109..65f0c067 100644
--- a/internal/db/db_test.go
+++ b/internal/db/db_test.go
@@ -63,9 +63,9 @@ func Test_parseDSN(t *testing.T) {
})
tests := []struct {
- name string
- opts conf.DatabaseOpts
- expDSN string
+ name string
+ opts conf.DatabaseOpts
+ wantDSN string
}{
{
name: "mysql: unix",
@@ -76,7 +76,7 @@ func Test_parseDSN(t *testing.T) {
User: "gogs",
Password: "pa$$word",
},
- expDSN: "gogs:pa$$word@unix(/tmp/mysql.sock)/gogs?charset=utf8mb4&parseTime=true",
+ wantDSN: "gogs:pa$$word@unix(/tmp/mysql.sock)/gogs?charset=utf8mb4&parseTime=true",
},
{
name: "mysql: tcp",
@@ -87,7 +87,7 @@ func Test_parseDSN(t *testing.T) {
User: "gogs",
Password: "pa$$word",
},
- expDSN: "gogs:pa$$word@tcp(localhost:3306)/gogs?charset=utf8mb4&parseTime=true",
+ wantDSN: "gogs:pa$$word@tcp(localhost:3306)/gogs?charset=utf8mb4&parseTime=true",
},
{
@@ -101,7 +101,7 @@ func Test_parseDSN(t *testing.T) {
Password: "pa$$word",
SSLMode: "disable",
},
- expDSN: "user='gogs@local' password='pa$$word' host='/tmp/pg.sock' port='5432' dbname='gogs' sslmode='disable' search_path='test'",
+ wantDSN: "user='gogs@local' password='pa$$word' host='/tmp/pg.sock' port='5432' dbname='gogs' sslmode='disable' search_path='test' application_name='gogs'",
},
{
name: "postgres: tcp",
@@ -114,7 +114,7 @@ func Test_parseDSN(t *testing.T) {
Password: "pa$$word",
SSLMode: "disable",
},
- expDSN: "user='gogs@local' password='pa$$word' host='127.0.0.1' port='5432' dbname='gogs' sslmode='disable' search_path='test'",
+ wantDSN: "user='gogs@local' password='pa$$word' host='127.0.0.1' port='5432' dbname='gogs' sslmode='disable' search_path='test' application_name='gogs'",
},
{
@@ -126,7 +126,7 @@ func Test_parseDSN(t *testing.T) {
User: "gogs@local",
Password: "pa$$word",
},
- expDSN: "server=127.0.0.1; port=1433; database=gogs; user id=gogs@local; password=pa$$word;",
+ wantDSN: "server=127.0.0.1; port=1433; database=gogs; user id=gogs@local; password=pa$$word;",
},
{
@@ -135,7 +135,7 @@ func Test_parseDSN(t *testing.T) {
Type: "sqlite3",
Path: "/tmp/gogs.db",
},
- expDSN: "file:/tmp/gogs.db?cache=shared&mode=rwc",
+ wantDSN: "file:/tmp/gogs.db?cache=shared&mode=rwc",
},
}
for _, test := range tests {
@@ -144,7 +144,7 @@ func Test_parseDSN(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- assert.Equal(t, test.expDSN, dsn)
+ assert.Equal(t, test.wantDSN, dsn)
})
}
}
diff --git a/internal/db/mock_gen.go b/internal/db/mock_gen.go
index 8d94112f..eba5faf7 100644
--- a/internal/db/mock_gen.go
+++ b/internal/db/mock_gen.go
@@ -8,7 +8,7 @@ import (
"testing"
)
-//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i LFSStore -i LoginSourcesStore -i LoginSourceFilesStore -i loginSourceFileStore -i PermsStore -i TwoFactorsStore -i UsersStore -o mocks.go
+//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i LFSStore -i LoginSourcesStore -i LoginSourceFilesStore -i loginSourceFileStore -i PermsStore -i ReposStore -i TwoFactorsStore -i UsersStore -o mocks.go
func SetMockAccessTokensStore(t *testing.T, mock AccessTokensStore) {
before := AccessTokens
@@ -42,16 +42,6 @@ func SetMockPermsStore(t *testing.T, mock PermsStore) {
})
}
-var _ ReposStore = (*MockReposStore)(nil)
-
-type MockReposStore struct {
- MockGetByName func(ownerID int64, name string) (*Repository, error)
-}
-
-func (m *MockReposStore) GetByName(ownerID int64, name string) (*Repository, error) {
- return m.MockGetByName(ownerID, name)
-}
-
func SetMockReposStore(t *testing.T, mock ReposStore) {
before := Repos
Repos = mock
diff --git a/internal/db/mocks.go b/internal/db/mocks.go
index d3e302a8..ddf9ee5d 100644
--- a/internal/db/mocks.go
+++ b/internal/db/mocks.go
@@ -2371,6 +2371,159 @@ func (c PermsStoreSetRepoPermsFuncCall) Results() []interface{} {
return []interface{}{c.Result0}
}
+// MockReposStore is a mock implementation of the ReposStore interface (from
+// the package gogs.io/gogs/internal/db) used for unit testing.
+type MockReposStore struct {
+ // GetByNameFunc is an instance of a mock function object controlling
+ // the behavior of the method GetByName.
+ GetByNameFunc *ReposStoreGetByNameFunc
+}
+
+// NewMockReposStore creates a new mock of the ReposStore interface. All
+// methods return zero values for all results, unless overwritten.
+func NewMockReposStore() *MockReposStore {
+ return &MockReposStore{
+ GetByNameFunc: &ReposStoreGetByNameFunc{
+ defaultHook: func(context.Context, int64, string) (r0 *Repository, r1 error) {
+ return
+ },
+ },
+ }
+}
+
+// NewStrictMockReposStore creates a new mock of the ReposStore interface.
+// All methods panic on invocation, unless overwritten.
+func NewStrictMockReposStore() *MockReposStore {
+ return &MockReposStore{
+ GetByNameFunc: &ReposStoreGetByNameFunc{
+ defaultHook: func(context.Context, int64, string) (*Repository, error) {
+ panic("unexpected invocation of MockReposStore.GetByName")
+ },
+ },
+ }
+}
+
+// NewMockReposStoreFrom creates a new mock of the MockReposStore interface.
+// All methods delegate to the given implementation, unless overwritten.
+func NewMockReposStoreFrom(i ReposStore) *MockReposStore {
+ return &MockReposStore{
+ GetByNameFunc: &ReposStoreGetByNameFunc{
+ defaultHook: i.GetByName,
+ },
+ }
+}
+
+// ReposStoreGetByNameFunc describes the behavior when the GetByName method
+// of the parent MockReposStore instance is invoked.
+type ReposStoreGetByNameFunc struct {
+ defaultHook func(context.Context, int64, string) (*Repository, error)
+ hooks []func(context.Context, int64, string) (*Repository, error)
+ history []ReposStoreGetByNameFuncCall
+ mutex sync.Mutex
+}
+
+// GetByName delegates to the next hook function in the queue and stores the
+// parameter and result values of this invocation.
+func (m *MockReposStore) GetByName(v0 context.Context, v1 int64, v2 string) (*Repository, error) {
+ r0, r1 := m.GetByNameFunc.nextHook()(v0, v1, v2)
+ m.GetByNameFunc.appendCall(ReposStoreGetByNameFuncCall{v0, v1, v2, r0, r1})
+ return r0, r1
+}
+
+// SetDefaultHook sets function that is called when the GetByName method of
+// the parent MockReposStore instance is invoked and the hook queue is
+// empty.
+func (f *ReposStoreGetByNameFunc) SetDefaultHook(hook func(context.Context, int64, string) (*Repository, error)) {
+ f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// GetByName method of the parent MockReposStore instance invokes the hook
+// at the front of the queue and discards it. After the queue is empty, the
+// default hook function is invoked for any future action.
+func (f *ReposStoreGetByNameFunc) PushHook(hook func(context.Context, int64, string) (*Repository, error)) {
+ f.mutex.Lock()
+ f.hooks = append(f.hooks, hook)
+ f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *ReposStoreGetByNameFunc) SetDefaultReturn(r0 *Repository, r1 error) {
+ f.SetDefaultHook(func(context.Context, int64, string) (*Repository, error) {
+ return r0, r1
+ })
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *ReposStoreGetByNameFunc) PushReturn(r0 *Repository, r1 error) {
+ f.PushHook(func(context.Context, int64, string) (*Repository, error) {
+ return r0, r1
+ })
+}
+
+func (f *ReposStoreGetByNameFunc) nextHook() func(context.Context, int64, string) (*Repository, error) {
+ f.mutex.Lock()
+ defer f.mutex.Unlock()
+
+ if len(f.hooks) == 0 {
+ return f.defaultHook
+ }
+
+ hook := f.hooks[0]
+ f.hooks = f.hooks[1:]
+ return hook
+}
+
+func (f *ReposStoreGetByNameFunc) appendCall(r0 ReposStoreGetByNameFuncCall) {
+ f.mutex.Lock()
+ f.history = append(f.history, r0)
+ f.mutex.Unlock()
+}
+
+// History returns a sequence of ReposStoreGetByNameFuncCall objects
+// describing the invocations of this function.
+func (f *ReposStoreGetByNameFunc) History() []ReposStoreGetByNameFuncCall {
+ f.mutex.Lock()
+ history := make([]ReposStoreGetByNameFuncCall, len(f.history))
+ copy(history, f.history)
+ f.mutex.Unlock()
+
+ return history
+}
+
+// ReposStoreGetByNameFuncCall is an object that describes an invocation of
+// method GetByName on an instance of MockReposStore.
+type ReposStoreGetByNameFuncCall struct {
+ // Arg0 is the value of the 1st argument passed to this method
+ // invocation.
+ Arg0 context.Context
+ // Arg1 is the value of the 2nd argument passed to this method
+ // invocation.
+ Arg1 int64
+ // Arg2 is the value of the 3rd argument passed to this method
+ // invocation.
+ Arg2 string
+ // Result0 is the value of the 1st result returned from this method
+ // invocation.
+ Result0 *Repository
+ // Result1 is the value of the 2nd result returned from this method
+ // invocation.
+ Result1 error
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c ReposStoreGetByNameFuncCall) Args() []interface{} {
+ return []interface{}{c.Arg0, c.Arg1, c.Arg2}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c ReposStoreGetByNameFuncCall) Results() []interface{} {
+ return []interface{}{c.Result0, c.Result1}
+}
+
// MockTwoFactorsStore is a mock implementation of the TwoFactorsStore
// interface (from the package gogs.io/gogs/internal/db) used for unit
// testing.
diff --git a/internal/db/repos.go b/internal/db/repos.go
index ecdbc0a5..8b4c5bce 100644
--- a/internal/db/repos.go
+++ b/internal/db/repos.go
@@ -5,6 +5,7 @@
package db
import (
+ "context"
"fmt"
"strings"
"time"
@@ -18,14 +19,14 @@ import (
//
// NOTE: All methods are sorted in alphabetical order.
type ReposStore interface {
- // GetByName returns the repository with given owner and name.
- // It returns ErrRepoNotExist when not found.
- GetByName(ownerID int64, name string) (*Repository, error)
+ // GetByName returns the repository with given owner and name. It returns
+ // ErrRepoNotExist when not found.
+ GetByName(ctx context.Context, ownerID int64, name string) (*Repository, error)
}
var Repos ReposStore
-// NOTE: This is a GORM create hook.
+// BeforeCreate implements the GORM create hook.
func (r *Repository) BeforeCreate(tx *gorm.DB) error {
if r.CreatedUnix == 0 {
r.CreatedUnix = tx.NowFunc().Unix()
@@ -33,13 +34,13 @@ func (r *Repository) BeforeCreate(tx *gorm.DB) error {
return nil
}
-// NOTE: This is a GORM update hook.
+// BeforeUpdate implements the GORM update hook.
func (r *Repository) BeforeUpdate(tx *gorm.DB) error {
r.UpdatedUnix = tx.NowFunc().Unix()
return nil
}
-// NOTE: This is a GORM query hook.
+// AfterFind implements the GORM query hook.
func (r *Repository) AfterFind(_ *gorm.DB) error {
r.Created = time.Unix(r.CreatedUnix, 0).Local()
r.Updated = time.Unix(r.UpdatedUnix, 0).Local()
@@ -81,13 +82,13 @@ type createRepoOpts struct {
// create creates a new repository record in the database. Fields of "repo" will be updated
// in place upon insertion. It returns ErrNameNotAllowed when the repository name is not allowed,
// or ErrRepoAlreadyExist when a repository with same name already exists for the owner.
-func (db *repos) create(ownerID int64, opts createRepoOpts) (*Repository, error) {
+func (db *repos) create(ctx context.Context, ownerID int64, opts createRepoOpts) (*Repository, error) {
err := isRepoNameAllowed(opts.Name)
if err != nil {
return nil, err
}
- _, err = db.GetByName(ownerID, opts.Name)
+ _, err = db.GetByName(ctx, ownerID, opts.Name)
if err == nil {
return nil, ErrRepoAlreadyExist{args: errutil.Args{"ownerID": ownerID, "name": opts.Name}}
} else if !IsErrRepoNotExist(err) {
@@ -108,7 +109,7 @@ func (db *repos) create(ownerID int64, opts createRepoOpts) (*Repository, error)
IsFork: opts.Fork,
ForkID: opts.ForkID,
}
- return repo, db.DB.Create(repo).Error
+ return repo, db.WithContext(ctx).Create(repo).Error
}
var _ errutil.NotFound = (*ErrRepoNotExist)(nil)
@@ -130,9 +131,12 @@ func (ErrRepoNotExist) NotFound() bool {
return true
}
-func (db *repos) GetByName(ownerID int64, name string) (*Repository, error) {
+func (db *repos) GetByName(ctx context.Context, ownerID int64, name string) (*Repository, error) {
repo := new(Repository)
- err := db.Where("owner_id = ? AND lower_name = ?", ownerID, strings.ToLower(name)).First(repo).Error
+ err := db.WithContext(ctx).
+ Where("owner_id = ? AND lower_name = ?", ownerID, strings.ToLower(name)).
+ First(repo).
+ Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, ErrRepoNotExist{args: map[string]interface{}{"ownerID": ownerID, "name": name}}
diff --git a/internal/db/repos_test.go b/internal/db/repos_test.go
index d248f3b0..32482506 100644
--- a/internal/db/repos_test.go
+++ b/internal/db/repos_test.go
@@ -5,15 +5,17 @@
package db
import (
+ "context"
"testing"
"time"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
"gogs.io/gogs/internal/errutil"
)
-func Test_repos(t *testing.T) {
+func TestRepos(t *testing.T) {
if testing.Short() {
t.Skip()
}
@@ -29,15 +31,13 @@ func Test_repos(t *testing.T) {
name string
test func(*testing.T, *repos)
}{
- {"create", test_repos_create},
- {"GetByName", test_repos_GetByName},
+ {"create", reposCreate},
+ {"GetByName", reposGetByName},
} {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() {
err := clearTables(t, db.DB, tables...)
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
})
tc.test(t, db)
})
@@ -47,58 +47,63 @@ func Test_repos(t *testing.T) {
}
}
-func test_repos_create(t *testing.T, db *repos) {
+func reposCreate(t *testing.T, db *repos) {
+ ctx := context.Background()
+
t.Run("name not allowed", func(t *testing.T) {
- _, err := db.create(1, createRepoOpts{
- Name: "my.git",
- })
- expErr := ErrNameNotAllowed{args: errutil.Args{"reason": "reserved", "pattern": "*.git"}}
- assert.Equal(t, expErr, err)
+ _, err := db.create(ctx,
+ 1,
+ createRepoOpts{
+ Name: "my.git",
+ },
+ )
+ wantErr := ErrNameNotAllowed{args: errutil.Args{"reason": "reserved", "pattern": "*.git"}}
+ assert.Equal(t, wantErr, err)
})
t.Run("already exists", func(t *testing.T) {
- _, err := db.create(2, createRepoOpts{
- Name: "repo1",
- })
- if err != nil {
- t.Fatal(err)
- }
-
- _, err = db.create(2, createRepoOpts{
- Name: "repo1",
- })
- expErr := ErrRepoAlreadyExist{args: errutil.Args{"ownerID": int64(2), "name": "repo1"}}
- assert.Equal(t, expErr, err)
+ _, err := db.create(ctx, 2,
+ createRepoOpts{
+ Name: "repo1",
+ },
+ )
+ require.NoError(t, err)
+
+ _, err = db.create(ctx, 2,
+ createRepoOpts{
+ Name: "repo1",
+ },
+ )
+ wantErr := ErrRepoAlreadyExist{args: errutil.Args{"ownerID": int64(2), "name": "repo1"}}
+ assert.Equal(t, wantErr, err)
})
- repo, err := db.create(3, createRepoOpts{
- Name: "repo2",
- })
- if err != nil {
- t.Fatal(err)
- }
+ repo, err := db.create(ctx, 3,
+ createRepoOpts{
+ Name: "repo2",
+ },
+ )
+ require.NoError(t, err)
- repo, err = db.GetByName(repo.OwnerID, repo.Name)
- if err != nil {
- t.Fatal(err)
- }
+ repo, err = db.GetByName(ctx, repo.OwnerID, repo.Name)
+ require.NoError(t, err)
assert.Equal(t, db.NowFunc().Format(time.RFC3339), repo.Created.UTC().Format(time.RFC3339))
}
-func test_repos_GetByName(t *testing.T, db *repos) {
- repo, err := db.create(1, createRepoOpts{
- Name: "repo1",
- })
- if err != nil {
- t.Fatal(err)
- }
+func reposGetByName(t *testing.T, db *repos) {
+ ctx := context.Background()
- _, err = db.GetByName(repo.OwnerID, repo.Name)
- if err != nil {
- t.Fatal(err)
- }
+ repo, err := db.create(ctx, 1,
+ createRepoOpts{
+ Name: "repo1",
+ },
+ )
+ require.NoError(t, err)
+
+ _, err = db.GetByName(ctx, repo.OwnerID, repo.Name)
+ require.NoError(t, err)
- _, err = db.GetByName(1, "bad_name")
- expErr := ErrRepoNotExist{args: errutil.Args{"ownerID": int64(1), "name": "bad_name"}}
- assert.Equal(t, expErr, err)
+ _, err = db.GetByName(ctx, 1, "bad_name")
+ wantErr := ErrRepoNotExist{args: errutil.Args{"ownerID": int64(1), "name": "bad_name"}}
+ assert.Equal(t, wantErr, err)
}