aboutsummaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
authorJoe Chen <jc@unknwon.io>2022-06-11 11:54:11 +0800
committerGitHub <noreply@github.com>2022-06-11 11:54:11 +0800
commit5e32058c13f34b46c69b7cdee6ccc0b7fe3b6df3 (patch)
tree92353ba2d8b6461754b89e95f581f4d402cf42af /internal/db
parent75fbb8244086a2ad964d1c51e3bdbdfb95df90ac (diff)
db: use `context` and go-mockgen for `TwoFactorsStore` (#7045)
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/mock_gen.go22
-rw-r--r--internal/db/mocks.go402
-rw-r--r--internal/db/two_factors.go35
-rw-r--r--internal/db/two_factors_test.go70
-rw-r--r--internal/db/user.go2
5 files changed, 454 insertions, 77 deletions
diff --git a/internal/db/mock_gen.go b/internal/db/mock_gen.go
index 7235c6b5..8d94112f 100644
--- a/internal/db/mock_gen.go
+++ b/internal/db/mock_gen.go
@@ -8,7 +8,7 @@ import (
"testing"
)
-//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i LFSStore -i LoginSourcesStore -i LoginSourceFilesStore -i loginSourceFileStore -i PermsStore -i UsersStore -o mocks.go
+//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i LFSStore -i LoginSourcesStore -i LoginSourceFilesStore -i loginSourceFileStore -i PermsStore -i TwoFactorsStore -i UsersStore -o mocks.go
func SetMockAccessTokensStore(t *testing.T, mock AccessTokensStore) {
before := AccessTokens
@@ -60,26 +60,6 @@ func SetMockReposStore(t *testing.T, mock ReposStore) {
})
}
-var _ TwoFactorsStore = (*MockTwoFactorsStore)(nil)
-
-type MockTwoFactorsStore struct {
- MockCreate func(userID int64, key, secret string) error
- MockGetByUserID func(userID int64) (*TwoFactor, error)
- MockIsUserEnabled func(userID int64) bool
-}
-
-func (m *MockTwoFactorsStore) Create(userID int64, key, secret string) error {
- return m.MockCreate(userID, key, secret)
-}
-
-func (m *MockTwoFactorsStore) GetByUserID(userID int64) (*TwoFactor, error) {
- return m.MockGetByUserID(userID)
-}
-
-func (m *MockTwoFactorsStore) IsUserEnabled(userID int64) bool {
- return m.MockIsUserEnabled(userID)
-}
-
func SetMockTwoFactorsStore(t *testing.T, mock TwoFactorsStore) {
before := TwoFactors
TwoFactors = mock
diff --git a/internal/db/mocks.go b/internal/db/mocks.go
index e6a39963..d3e302a8 100644
--- a/internal/db/mocks.go
+++ b/internal/db/mocks.go
@@ -2371,6 +2371,408 @@ func (c PermsStoreSetRepoPermsFuncCall) Results() []interface{} {
return []interface{}{c.Result0}
}
+// MockTwoFactorsStore is a mock implementation of the TwoFactorsStore
+// interface (from the package gogs.io/gogs/internal/db) used for unit
+// testing.
+type MockTwoFactorsStore struct {
+ // CreateFunc is an instance of a mock function object controlling the
+ // behavior of the method Create.
+ CreateFunc *TwoFactorsStoreCreateFunc
+ // GetByUserIDFunc is an instance of a mock function object controlling
+ // the behavior of the method GetByUserID.
+ GetByUserIDFunc *TwoFactorsStoreGetByUserIDFunc
+ // IsUserEnabledFunc is an instance of a mock function object
+ // controlling the behavior of the method IsUserEnabled.
+ IsUserEnabledFunc *TwoFactorsStoreIsUserEnabledFunc
+}
+
+// NewMockTwoFactorsStore creates a new mock of the TwoFactorsStore
+// interface. All methods return zero values for all results, unless
+// overwritten.
+func NewMockTwoFactorsStore() *MockTwoFactorsStore {
+ return &MockTwoFactorsStore{
+ CreateFunc: &TwoFactorsStoreCreateFunc{
+ defaultHook: func(context.Context, int64, string, string) (r0 error) {
+ return
+ },
+ },
+ GetByUserIDFunc: &TwoFactorsStoreGetByUserIDFunc{
+ defaultHook: func(context.Context, int64) (r0 *TwoFactor, r1 error) {
+ return
+ },
+ },
+ IsUserEnabledFunc: &TwoFactorsStoreIsUserEnabledFunc{
+ defaultHook: func(context.Context, int64) (r0 bool) {
+ return
+ },
+ },
+ }
+}
+
+// NewStrictMockTwoFactorsStore creates a new mock of the TwoFactorsStore
+// interface. All methods panic on invocation, unless overwritten.
+func NewStrictMockTwoFactorsStore() *MockTwoFactorsStore {
+ return &MockTwoFactorsStore{
+ CreateFunc: &TwoFactorsStoreCreateFunc{
+ defaultHook: func(context.Context, int64, string, string) error {
+ panic("unexpected invocation of MockTwoFactorsStore.Create")
+ },
+ },
+ GetByUserIDFunc: &TwoFactorsStoreGetByUserIDFunc{
+ defaultHook: func(context.Context, int64) (*TwoFactor, error) {
+ panic("unexpected invocation of MockTwoFactorsStore.GetByUserID")
+ },
+ },
+ IsUserEnabledFunc: &TwoFactorsStoreIsUserEnabledFunc{
+ defaultHook: func(context.Context, int64) bool {
+ panic("unexpected invocation of MockTwoFactorsStore.IsUserEnabled")
+ },
+ },
+ }
+}
+
+// NewMockTwoFactorsStoreFrom creates a new mock of the MockTwoFactorsStore
+// interface. All methods delegate to the given implementation, unless
+// overwritten.
+func NewMockTwoFactorsStoreFrom(i TwoFactorsStore) *MockTwoFactorsStore {
+ return &MockTwoFactorsStore{
+ CreateFunc: &TwoFactorsStoreCreateFunc{
+ defaultHook: i.Create,
+ },
+ GetByUserIDFunc: &TwoFactorsStoreGetByUserIDFunc{
+ defaultHook: i.GetByUserID,
+ },
+ IsUserEnabledFunc: &TwoFactorsStoreIsUserEnabledFunc{
+ defaultHook: i.IsUserEnabled,
+ },
+ }
+}
+
+// TwoFactorsStoreCreateFunc describes the behavior when the Create method
+// of the parent MockTwoFactorsStore instance is invoked.
+type TwoFactorsStoreCreateFunc struct {
+ defaultHook func(context.Context, int64, string, string) error
+ hooks []func(context.Context, int64, string, string) error
+ history []TwoFactorsStoreCreateFuncCall
+ mutex sync.Mutex
+}
+
+// Create delegates to the next hook function in the queue and stores the
+// parameter and result values of this invocation.
+func (m *MockTwoFactorsStore) Create(v0 context.Context, v1 int64, v2 string, v3 string) error {
+ r0 := m.CreateFunc.nextHook()(v0, v1, v2, v3)
+ m.CreateFunc.appendCall(TwoFactorsStoreCreateFuncCall{v0, v1, v2, v3, r0})
+ return r0
+}
+
+// SetDefaultHook sets function that is called when the Create method of the
+// parent MockTwoFactorsStore instance is invoked and the hook queue is
+// empty.
+func (f *TwoFactorsStoreCreateFunc) SetDefaultHook(hook func(context.Context, int64, string, string) error) {
+ f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// Create method of the parent MockTwoFactorsStore instance invokes the hook
+// at the front of the queue and discards it. After the queue is empty, the
+// default hook function is invoked for any future action.
+func (f *TwoFactorsStoreCreateFunc) PushHook(hook func(context.Context, int64, string, string) error) {
+ f.mutex.Lock()
+ f.hooks = append(f.hooks, hook)
+ f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *TwoFactorsStoreCreateFunc) SetDefaultReturn(r0 error) {
+ f.SetDefaultHook(func(context.Context, int64, string, string) error {
+ return r0
+ })
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *TwoFactorsStoreCreateFunc) PushReturn(r0 error) {
+ f.PushHook(func(context.Context, int64, string, string) error {
+ return r0
+ })
+}
+
+func (f *TwoFactorsStoreCreateFunc) nextHook() func(context.Context, int64, string, string) error {
+ f.mutex.Lock()
+ defer f.mutex.Unlock()
+
+ if len(f.hooks) == 0 {
+ return f.defaultHook
+ }
+
+ hook := f.hooks[0]
+ f.hooks = f.hooks[1:]
+ return hook
+}
+
+func (f *TwoFactorsStoreCreateFunc) appendCall(r0 TwoFactorsStoreCreateFuncCall) {
+ f.mutex.Lock()
+ f.history = append(f.history, r0)
+ f.mutex.Unlock()
+}
+
+// History returns a sequence of TwoFactorsStoreCreateFuncCall objects
+// describing the invocations of this function.
+func (f *TwoFactorsStoreCreateFunc) History() []TwoFactorsStoreCreateFuncCall {
+ f.mutex.Lock()
+ history := make([]TwoFactorsStoreCreateFuncCall, len(f.history))
+ copy(history, f.history)
+ f.mutex.Unlock()
+
+ return history
+}
+
+// TwoFactorsStoreCreateFuncCall is an object that describes an invocation
+// of method Create on an instance of MockTwoFactorsStore.
+type TwoFactorsStoreCreateFuncCall struct {
+ // Arg0 is the value of the 1st argument passed to this method
+ // invocation.
+ Arg0 context.Context
+ // Arg1 is the value of the 2nd argument passed to this method
+ // invocation.
+ Arg1 int64
+ // Arg2 is the value of the 3rd argument passed to this method
+ // invocation.
+ Arg2 string
+ // Arg3 is the value of the 4th argument passed to this method
+ // invocation.
+ Arg3 string
+ // Result0 is the value of the 1st result returned from this method
+ // invocation.
+ Result0 error
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c TwoFactorsStoreCreateFuncCall) Args() []interface{} {
+ return []interface{}{c.Arg0, c.Arg1, c.Arg2, c.Arg3}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c TwoFactorsStoreCreateFuncCall) Results() []interface{} {
+ return []interface{}{c.Result0}
+}
+
+// TwoFactorsStoreGetByUserIDFunc describes the behavior when the
+// GetByUserID method of the parent MockTwoFactorsStore instance is invoked.
+type TwoFactorsStoreGetByUserIDFunc struct {
+ defaultHook func(context.Context, int64) (*TwoFactor, error)
+ hooks []func(context.Context, int64) (*TwoFactor, error)
+ history []TwoFactorsStoreGetByUserIDFuncCall
+ mutex sync.Mutex
+}
+
+// GetByUserID delegates to the next hook function in the queue and stores
+// the parameter and result values of this invocation.
+func (m *MockTwoFactorsStore) GetByUserID(v0 context.Context, v1 int64) (*TwoFactor, error) {
+ r0, r1 := m.GetByUserIDFunc.nextHook()(v0, v1)
+ m.GetByUserIDFunc.appendCall(TwoFactorsStoreGetByUserIDFuncCall{v0, v1, r0, r1})
+ return r0, r1
+}
+
+// SetDefaultHook sets function that is called when the GetByUserID method
+// of the parent MockTwoFactorsStore instance is invoked and the hook queue
+// is empty.
+func (f *TwoFactorsStoreGetByUserIDFunc) SetDefaultHook(hook func(context.Context, int64) (*TwoFactor, error)) {
+ f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// GetByUserID method of the parent MockTwoFactorsStore instance invokes the
+// hook at the front of the queue and discards it. After the queue is empty,
+// the default hook function is invoked for any future action.
+func (f *TwoFactorsStoreGetByUserIDFunc) PushHook(hook func(context.Context, int64) (*TwoFactor, error)) {
+ f.mutex.Lock()
+ f.hooks = append(f.hooks, hook)
+ f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *TwoFactorsStoreGetByUserIDFunc) SetDefaultReturn(r0 *TwoFactor, r1 error) {
+ f.SetDefaultHook(func(context.Context, int64) (*TwoFactor, error) {
+ return r0, r1
+ })
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *TwoFactorsStoreGetByUserIDFunc) PushReturn(r0 *TwoFactor, r1 error) {
+ f.PushHook(func(context.Context, int64) (*TwoFactor, error) {
+ return r0, r1
+ })
+}
+
+func (f *TwoFactorsStoreGetByUserIDFunc) nextHook() func(context.Context, int64) (*TwoFactor, error) {
+ f.mutex.Lock()
+ defer f.mutex.Unlock()
+
+ if len(f.hooks) == 0 {
+ return f.defaultHook
+ }
+
+ hook := f.hooks[0]
+ f.hooks = f.hooks[1:]
+ return hook
+}
+
+func (f *TwoFactorsStoreGetByUserIDFunc) appendCall(r0 TwoFactorsStoreGetByUserIDFuncCall) {
+ f.mutex.Lock()
+ f.history = append(f.history, r0)
+ f.mutex.Unlock()
+}
+
+// History returns a sequence of TwoFactorsStoreGetByUserIDFuncCall objects
+// describing the invocations of this function.
+func (f *TwoFactorsStoreGetByUserIDFunc) History() []TwoFactorsStoreGetByUserIDFuncCall {
+ f.mutex.Lock()
+ history := make([]TwoFactorsStoreGetByUserIDFuncCall, len(f.history))
+ copy(history, f.history)
+ f.mutex.Unlock()
+
+ return history
+}
+
+// TwoFactorsStoreGetByUserIDFuncCall is an object that describes an
+// invocation of method GetByUserID on an instance of MockTwoFactorsStore.
+type TwoFactorsStoreGetByUserIDFuncCall struct {
+ // Arg0 is the value of the 1st argument passed to this method
+ // invocation.
+ Arg0 context.Context
+ // Arg1 is the value of the 2nd argument passed to this method
+ // invocation.
+ Arg1 int64
+ // Result0 is the value of the 1st result returned from this method
+ // invocation.
+ Result0 *TwoFactor
+ // Result1 is the value of the 2nd result returned from this method
+ // invocation.
+ Result1 error
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c TwoFactorsStoreGetByUserIDFuncCall) Args() []interface{} {
+ return []interface{}{c.Arg0, c.Arg1}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c TwoFactorsStoreGetByUserIDFuncCall) Results() []interface{} {
+ return []interface{}{c.Result0, c.Result1}
+}
+
+// TwoFactorsStoreIsUserEnabledFunc describes the behavior when the
+// IsUserEnabled method of the parent MockTwoFactorsStore instance is
+// invoked.
+type TwoFactorsStoreIsUserEnabledFunc struct {
+ defaultHook func(context.Context, int64) bool
+ hooks []func(context.Context, int64) bool
+ history []TwoFactorsStoreIsUserEnabledFuncCall
+ mutex sync.Mutex
+}
+
+// IsUserEnabled delegates to the next hook function in the queue and stores
+// the parameter and result values of this invocation.
+func (m *MockTwoFactorsStore) IsUserEnabled(v0 context.Context, v1 int64) bool {
+ r0 := m.IsUserEnabledFunc.nextHook()(v0, v1)
+ m.IsUserEnabledFunc.appendCall(TwoFactorsStoreIsUserEnabledFuncCall{v0, v1, r0})
+ return r0
+}
+
+// SetDefaultHook sets function that is called when the IsUserEnabled method
+// of the parent MockTwoFactorsStore instance is invoked and the hook queue
+// is empty.
+func (f *TwoFactorsStoreIsUserEnabledFunc) SetDefaultHook(hook func(context.Context, int64) bool) {
+ f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// IsUserEnabled method of the parent MockTwoFactorsStore instance invokes
+// the hook at the front of the queue and discards it. After the queue is
+// empty, the default hook function is invoked for any future action.
+func (f *TwoFactorsStoreIsUserEnabledFunc) PushHook(hook func(context.Context, int64) bool) {
+ f.mutex.Lock()
+ f.hooks = append(f.hooks, hook)
+ f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *TwoFactorsStoreIsUserEnabledFunc) SetDefaultReturn(r0 bool) {
+ f.SetDefaultHook(func(context.Context, int64) bool {
+ return r0
+ })
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *TwoFactorsStoreIsUserEnabledFunc) PushReturn(r0 bool) {
+ f.PushHook(func(context.Context, int64) bool {
+ return r0
+ })
+}
+
+func (f *TwoFactorsStoreIsUserEnabledFunc) nextHook() func(context.Context, int64) bool {
+ f.mutex.Lock()
+ defer f.mutex.Unlock()
+
+ if len(f.hooks) == 0 {
+ return f.defaultHook
+ }
+
+ hook := f.hooks[0]
+ f.hooks = f.hooks[1:]
+ return hook
+}
+
+func (f *TwoFactorsStoreIsUserEnabledFunc) appendCall(r0 TwoFactorsStoreIsUserEnabledFuncCall) {
+ f.mutex.Lock()
+ f.history = append(f.history, r0)
+ f.mutex.Unlock()
+}
+
+// History returns a sequence of TwoFactorsStoreIsUserEnabledFuncCall
+// objects describing the invocations of this function.
+func (f *TwoFactorsStoreIsUserEnabledFunc) History() []TwoFactorsStoreIsUserEnabledFuncCall {
+ f.mutex.Lock()
+ history := make([]TwoFactorsStoreIsUserEnabledFuncCall, len(f.history))
+ copy(history, f.history)
+ f.mutex.Unlock()
+
+ return history
+}
+
+// TwoFactorsStoreIsUserEnabledFuncCall is an object that describes an
+// invocation of method IsUserEnabled on an instance of MockTwoFactorsStore.
+type TwoFactorsStoreIsUserEnabledFuncCall struct {
+ // Arg0 is the value of the 1st argument passed to this method
+ // invocation.
+ Arg0 context.Context
+ // Arg1 is the value of the 2nd argument passed to this method
+ // invocation.
+ Arg1 int64
+ // Result0 is the value of the 1st result returned from this method
+ // invocation.
+ Result0 bool
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c TwoFactorsStoreIsUserEnabledFuncCall) Args() []interface{} {
+ return []interface{}{c.Arg0, c.Arg1}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c TwoFactorsStoreIsUserEnabledFuncCall) Results() []interface{} {
+ return []interface{}{c.Result0}
+}
+
// MockUsersStore is a mock implementation of the UsersStore interface (from
// the package gogs.io/gogs/internal/db) used for unit testing.
type MockUsersStore struct {
diff --git a/internal/db/two_factors.go b/internal/db/two_factors.go
index 935f66db..2dba2a02 100644
--- a/internal/db/two_factors.go
+++ b/internal/db/two_factors.go
@@ -5,6 +5,7 @@
package db
import (
+ "context"
"encoding/base64"
"fmt"
"strings"
@@ -23,21 +24,21 @@ import (
//
// NOTE: All methods are sorted in alphabetical order.
type TwoFactorsStore interface {
- // Create creates a new 2FA token and recovery codes for given user.
- // The "key" is used to encrypt and later decrypt given "secret",
- // which should be configured in site-level and change of the "key"
- // will break all existing 2FA tokens.
- Create(userID int64, key, secret string) error
- // GetByUserID returns the 2FA token of given user.
- // It returns ErrTwoFactorNotFound when not found.
- GetByUserID(userID int64) (*TwoFactor, error)
+ // Create creates a new 2FA token and recovery codes for given user. The "key"
+ // is used to encrypt and later decrypt given "secret", which should be
+ // configured in site-level and change of the "key" will break all existing 2FA
+ // tokens.
+ Create(ctx context.Context, userID int64, key, secret string) error
+ // GetByUserID returns the 2FA token of given user. It returns
+ // ErrTwoFactorNotFound when not found.
+ GetByUserID(ctx context.Context, userID int64) (*TwoFactor, error)
// IsUserEnabled returns true if the user has enabled 2FA.
- IsUserEnabled(userID int64) bool
+ IsUserEnabled(ctx context.Context, userID int64) bool
}
var TwoFactors TwoFactorsStore
-// NOTE: This is a GORM create hook.
+// BeforeCreate implements the GORM create hook.
func (t *TwoFactor) BeforeCreate(tx *gorm.DB) error {
if t.CreatedUnix == 0 {
t.CreatedUnix = tx.NowFunc().Unix()
@@ -45,7 +46,7 @@ func (t *TwoFactor) BeforeCreate(tx *gorm.DB) error {
return nil
}
-// NOTE: This is a GORM query hook.
+// AfterFind implements the GORM query hook.
func (t *TwoFactor) AfterFind(_ *gorm.DB) error {
t.Created = time.Unix(t.CreatedUnix, 0).Local()
return nil
@@ -57,7 +58,7 @@ type twoFactors struct {
*gorm.DB
}
-func (db *twoFactors) Create(userID int64, key, secret string) error {
+func (db *twoFactors) Create(ctx context.Context, userID int64, key, secret string) error {
encrypted, err := cryptoutil.AESGCMEncrypt(cryptoutil.MD5Bytes(key), []byte(secret))
if err != nil {
return errors.Wrap(err, "encrypt secret")
@@ -72,7 +73,7 @@ func (db *twoFactors) Create(userID int64, key, secret string) error {
return errors.Wrap(err, "generate recovery codes")
}
- return db.Transaction(func(tx *gorm.DB) error {
+ return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err := tx.Create(tf).Error
if err != nil {
return err
@@ -101,9 +102,9 @@ func (ErrTwoFactorNotFound) NotFound() bool {
return true
}
-func (db *twoFactors) GetByUserID(userID int64) (*TwoFactor, error) {
+func (db *twoFactors) GetByUserID(ctx context.Context, userID int64) (*TwoFactor, error) {
tf := new(TwoFactor)
- err := db.Where("user_id = ?", userID).First(tf).Error
+ err := db.WithContext(ctx).Where("user_id = ?", userID).First(tf).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, ErrTwoFactorNotFound{args: errutil.Args{"userID": userID}}
@@ -113,9 +114,9 @@ func (db *twoFactors) GetByUserID(userID int64) (*TwoFactor, error) {
return tf, nil
}
-func (db *twoFactors) IsUserEnabled(userID int64) bool {
+func (db *twoFactors) IsUserEnabled(ctx context.Context, userID int64) bool {
var count int64
- err := db.Model(new(TwoFactor)).Where("user_id = ?", userID).Count(&count).Error
+ err := db.WithContext(ctx).Model(new(TwoFactor)).Where("user_id = ?", userID).Count(&count).Error
if err != nil {
log.Error("Failed to count two factors [user_id: %d]: %v", userID, err)
}
diff --git a/internal/db/two_factors_test.go b/internal/db/two_factors_test.go
index c8412213..acd9a576 100644
--- a/internal/db/two_factors_test.go
+++ b/internal/db/two_factors_test.go
@@ -5,15 +5,17 @@
package db
import (
+ "context"
"testing"
"time"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
"gogs.io/gogs/internal/errutil"
)
-func Test_twoFactors(t *testing.T) {
+func TestTwoFactors(t *testing.T) {
if testing.Short() {
t.Skip()
}
@@ -29,16 +31,14 @@ func Test_twoFactors(t *testing.T) {
name string
test func(*testing.T, *twoFactors)
}{
- {"Create", test_twoFactors_Create},
- {"GetByUserID", test_twoFactors_GetByUserID},
- {"IsUserEnabled", test_twoFactors_IsUserEnabled},
+ {"Create", twoFactorsCreate},
+ {"GetByUserID", twoFactorsGetByUserID},
+ {"IsUserEnabled", twoFactorsIsUserEnabled},
} {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() {
err := clearTables(t, db.DB, tables...)
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
})
tc.test(t, db)
})
@@ -48,55 +48,49 @@ func Test_twoFactors(t *testing.T) {
}
}
-func test_twoFactors_Create(t *testing.T, db *twoFactors) {
+func twoFactorsCreate(t *testing.T, db *twoFactors) {
+ ctx := context.Background()
+
// Create a 2FA token
- err := db.Create(1, "secure-key", "secure-secret")
- if err != nil {
- t.Fatal(err)
- }
+ err := db.Create(ctx, 1, "secure-key", "secure-secret")
+ require.NoError(t, err)
// Get it back and check the Created field
- tf, err := db.GetByUserID(1)
- if err != nil {
- t.Fatal(err)
- }
+ tf, err := db.GetByUserID(ctx, 1)
+ require.NoError(t, err)
assert.Equal(t, db.NowFunc().Format(time.RFC3339), tf.Created.UTC().Format(time.RFC3339))
// Verify there are 10 recover codes generated
var count int64
err = db.Model(new(TwoFactorRecoveryCode)).Count(&count).Error
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
assert.Equal(t, int64(10), count)
}
-func test_twoFactors_GetByUserID(t *testing.T, db *twoFactors) {
+func twoFactorsGetByUserID(t *testing.T, db *twoFactors) {
+ ctx := context.Background()
+
// Create a 2FA token for user 1
- err := db.Create(1, "secure-key", "secure-secret")
- if err != nil {
- t.Fatal(err)
- }
+ err := db.Create(ctx, 1, "secure-key", "secure-secret")
+ require.NoError(t, err)
// We should be able to get it back
- _, err = db.GetByUserID(1)
- if err != nil {
- t.Fatal(err)
- }
+ _, err = db.GetByUserID(ctx, 1)
+ require.NoError(t, err)
// Try to get a non-existent 2FA token
- _, err = db.GetByUserID(2)
- expErr := ErrTwoFactorNotFound{args: errutil.Args{"userID": int64(2)}}
- assert.Equal(t, expErr, err)
+ _, err = db.GetByUserID(ctx, 2)
+ wantErr := ErrTwoFactorNotFound{args: errutil.Args{"userID": int64(2)}}
+ assert.Equal(t, wantErr, err)
}
-func test_twoFactors_IsUserEnabled(t *testing.T, db *twoFactors) {
+func twoFactorsIsUserEnabled(t *testing.T, db *twoFactors) {
+ ctx := context.Background()
+
// Create a 2FA token for user 1
- err := db.Create(1, "secure-key", "secure-secret")
- if err != nil {
- t.Fatal(err)
- }
+ err := db.Create(ctx, 1, "secure-key", "secure-secret")
+ require.NoError(t, err)
- assert.True(t, db.IsUserEnabled(1))
- assert.False(t, db.IsUserEnabled(2))
+ assert.True(t, db.IsUserEnabled(ctx, 1))
+ assert.False(t, db.IsUserEnabled(ctx, 2))
}
diff --git a/internal/db/user.go b/internal/db/user.go
index ebfbf082..1ad4b955 100644
--- a/internal/db/user.go
+++ b/internal/db/user.go
@@ -405,7 +405,7 @@ func (u *User) IsPublicMember(orgId int64) bool {
// IsEnabledTwoFactor returns true if user has enabled two-factor authentication.
func (u *User) IsEnabledTwoFactor() bool {
- return TwoFactors.IsUserEnabled(u.ID)
+ return TwoFactors.IsUserEnabled(context.TODO(), u.ID)
}
func (u *User) getOrganizationCount(e Engine) (int64, error) {