diff options
author | Joe Chen <jc@unknwon.io> | 2022-06-11 11:54:11 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-06-11 11:54:11 +0800 |
commit | 5e32058c13f34b46c69b7cdee6ccc0b7fe3b6df3 (patch) | |
tree | 92353ba2d8b6461754b89e95f581f4d402cf42af /internal/db | |
parent | 75fbb8244086a2ad964d1c51e3bdbdfb95df90ac (diff) |
db: use `context` and go-mockgen for `TwoFactorsStore` (#7045)
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/mock_gen.go | 22 | ||||
-rw-r--r-- | internal/db/mocks.go | 402 | ||||
-rw-r--r-- | internal/db/two_factors.go | 35 | ||||
-rw-r--r-- | internal/db/two_factors_test.go | 70 | ||||
-rw-r--r-- | internal/db/user.go | 2 |
5 files changed, 454 insertions, 77 deletions
diff --git a/internal/db/mock_gen.go b/internal/db/mock_gen.go index 7235c6b5..8d94112f 100644 --- a/internal/db/mock_gen.go +++ b/internal/db/mock_gen.go @@ -8,7 +8,7 @@ import ( "testing" ) -//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i LFSStore -i LoginSourcesStore -i LoginSourceFilesStore -i loginSourceFileStore -i PermsStore -i UsersStore -o mocks.go +//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i LFSStore -i LoginSourcesStore -i LoginSourceFilesStore -i loginSourceFileStore -i PermsStore -i TwoFactorsStore -i UsersStore -o mocks.go func SetMockAccessTokensStore(t *testing.T, mock AccessTokensStore) { before := AccessTokens @@ -60,26 +60,6 @@ func SetMockReposStore(t *testing.T, mock ReposStore) { }) } -var _ TwoFactorsStore = (*MockTwoFactorsStore)(nil) - -type MockTwoFactorsStore struct { - MockCreate func(userID int64, key, secret string) error - MockGetByUserID func(userID int64) (*TwoFactor, error) - MockIsUserEnabled func(userID int64) bool -} - -func (m *MockTwoFactorsStore) Create(userID int64, key, secret string) error { - return m.MockCreate(userID, key, secret) -} - -func (m *MockTwoFactorsStore) GetByUserID(userID int64) (*TwoFactor, error) { - return m.MockGetByUserID(userID) -} - -func (m *MockTwoFactorsStore) IsUserEnabled(userID int64) bool { - return m.MockIsUserEnabled(userID) -} - func SetMockTwoFactorsStore(t *testing.T, mock TwoFactorsStore) { before := TwoFactors TwoFactors = mock diff --git a/internal/db/mocks.go b/internal/db/mocks.go index e6a39963..d3e302a8 100644 --- a/internal/db/mocks.go +++ b/internal/db/mocks.go @@ -2371,6 +2371,408 @@ func (c PermsStoreSetRepoPermsFuncCall) Results() []interface{} { return []interface{}{c.Result0} } +// MockTwoFactorsStore is a mock implementation of the TwoFactorsStore +// interface (from the package gogs.io/gogs/internal/db) used for unit +// testing. +type MockTwoFactorsStore struct { + // CreateFunc is an instance of a mock function object controlling the + // behavior of the method Create. + CreateFunc *TwoFactorsStoreCreateFunc + // GetByUserIDFunc is an instance of a mock function object controlling + // the behavior of the method GetByUserID. + GetByUserIDFunc *TwoFactorsStoreGetByUserIDFunc + // IsUserEnabledFunc is an instance of a mock function object + // controlling the behavior of the method IsUserEnabled. + IsUserEnabledFunc *TwoFactorsStoreIsUserEnabledFunc +} + +// NewMockTwoFactorsStore creates a new mock of the TwoFactorsStore +// interface. All methods return zero values for all results, unless +// overwritten. +func NewMockTwoFactorsStore() *MockTwoFactorsStore { + return &MockTwoFactorsStore{ + CreateFunc: &TwoFactorsStoreCreateFunc{ + defaultHook: func(context.Context, int64, string, string) (r0 error) { + return + }, + }, + GetByUserIDFunc: &TwoFactorsStoreGetByUserIDFunc{ + defaultHook: func(context.Context, int64) (r0 *TwoFactor, r1 error) { + return + }, + }, + IsUserEnabledFunc: &TwoFactorsStoreIsUserEnabledFunc{ + defaultHook: func(context.Context, int64) (r0 bool) { + return + }, + }, + } +} + +// NewStrictMockTwoFactorsStore creates a new mock of the TwoFactorsStore +// interface. All methods panic on invocation, unless overwritten. +func NewStrictMockTwoFactorsStore() *MockTwoFactorsStore { + return &MockTwoFactorsStore{ + CreateFunc: &TwoFactorsStoreCreateFunc{ + defaultHook: func(context.Context, int64, string, string) error { + panic("unexpected invocation of MockTwoFactorsStore.Create") + }, + }, + GetByUserIDFunc: &TwoFactorsStoreGetByUserIDFunc{ + defaultHook: func(context.Context, int64) (*TwoFactor, error) { + panic("unexpected invocation of MockTwoFactorsStore.GetByUserID") + }, + }, + IsUserEnabledFunc: &TwoFactorsStoreIsUserEnabledFunc{ + defaultHook: func(context.Context, int64) bool { + panic("unexpected invocation of MockTwoFactorsStore.IsUserEnabled") + }, + }, + } +} + +// NewMockTwoFactorsStoreFrom creates a new mock of the MockTwoFactorsStore +// interface. All methods delegate to the given implementation, unless +// overwritten. +func NewMockTwoFactorsStoreFrom(i TwoFactorsStore) *MockTwoFactorsStore { + return &MockTwoFactorsStore{ + CreateFunc: &TwoFactorsStoreCreateFunc{ + defaultHook: i.Create, + }, + GetByUserIDFunc: &TwoFactorsStoreGetByUserIDFunc{ + defaultHook: i.GetByUserID, + }, + IsUserEnabledFunc: &TwoFactorsStoreIsUserEnabledFunc{ + defaultHook: i.IsUserEnabled, + }, + } +} + +// TwoFactorsStoreCreateFunc describes the behavior when the Create method +// of the parent MockTwoFactorsStore instance is invoked. +type TwoFactorsStoreCreateFunc struct { + defaultHook func(context.Context, int64, string, string) error + hooks []func(context.Context, int64, string, string) error + history []TwoFactorsStoreCreateFuncCall + mutex sync.Mutex +} + +// Create delegates to the next hook function in the queue and stores the +// parameter and result values of this invocation. +func (m *MockTwoFactorsStore) Create(v0 context.Context, v1 int64, v2 string, v3 string) error { + r0 := m.CreateFunc.nextHook()(v0, v1, v2, v3) + m.CreateFunc.appendCall(TwoFactorsStoreCreateFuncCall{v0, v1, v2, v3, r0}) + return r0 +} + +// SetDefaultHook sets function that is called when the Create method of the +// parent MockTwoFactorsStore instance is invoked and the hook queue is +// empty. +func (f *TwoFactorsStoreCreateFunc) SetDefaultHook(hook func(context.Context, int64, string, string) error) { + f.defaultHook = hook +} + +// PushHook adds a function to the end of hook queue. Each invocation of the +// Create method of the parent MockTwoFactorsStore instance invokes the hook +// at the front of the queue and discards it. After the queue is empty, the +// default hook function is invoked for any future action. +func (f *TwoFactorsStoreCreateFunc) PushHook(hook func(context.Context, int64, string, string) error) { + f.mutex.Lock() + f.hooks = append(f.hooks, hook) + f.mutex.Unlock() +} + +// SetDefaultReturn calls SetDefaultHook with a function that returns the +// given values. +func (f *TwoFactorsStoreCreateFunc) SetDefaultReturn(r0 error) { + f.SetDefaultHook(func(context.Context, int64, string, string) error { + return r0 + }) +} + +// PushReturn calls PushHook with a function that returns the given values. +func (f *TwoFactorsStoreCreateFunc) PushReturn(r0 error) { + f.PushHook(func(context.Context, int64, string, string) error { + return r0 + }) +} + +func (f *TwoFactorsStoreCreateFunc) nextHook() func(context.Context, int64, string, string) error { + f.mutex.Lock() + defer f.mutex.Unlock() + + if len(f.hooks) == 0 { + return f.defaultHook + } + + hook := f.hooks[0] + f.hooks = f.hooks[1:] + return hook +} + +func (f *TwoFactorsStoreCreateFunc) appendCall(r0 TwoFactorsStoreCreateFuncCall) { + f.mutex.Lock() + f.history = append(f.history, r0) + f.mutex.Unlock() +} + +// History returns a sequence of TwoFactorsStoreCreateFuncCall objects +// describing the invocations of this function. +func (f *TwoFactorsStoreCreateFunc) History() []TwoFactorsStoreCreateFuncCall { + f.mutex.Lock() + history := make([]TwoFactorsStoreCreateFuncCall, len(f.history)) + copy(history, f.history) + f.mutex.Unlock() + + return history +} + +// TwoFactorsStoreCreateFuncCall is an object that describes an invocation +// of method Create on an instance of MockTwoFactorsStore. +type TwoFactorsStoreCreateFuncCall struct { + // Arg0 is the value of the 1st argument passed to this method + // invocation. + Arg0 context.Context + // Arg1 is the value of the 2nd argument passed to this method + // invocation. + Arg1 int64 + // Arg2 is the value of the 3rd argument passed to this method + // invocation. + Arg2 string + // Arg3 is the value of the 4th argument passed to this method + // invocation. + Arg3 string + // Result0 is the value of the 1st result returned from this method + // invocation. + Result0 error +} + +// Args returns an interface slice containing the arguments of this +// invocation. +func (c TwoFactorsStoreCreateFuncCall) Args() []interface{} { + return []interface{}{c.Arg0, c.Arg1, c.Arg2, c.Arg3} +} + +// Results returns an interface slice containing the results of this +// invocation. +func (c TwoFactorsStoreCreateFuncCall) Results() []interface{} { + return []interface{}{c.Result0} +} + +// TwoFactorsStoreGetByUserIDFunc describes the behavior when the +// GetByUserID method of the parent MockTwoFactorsStore instance is invoked. +type TwoFactorsStoreGetByUserIDFunc struct { + defaultHook func(context.Context, int64) (*TwoFactor, error) + hooks []func(context.Context, int64) (*TwoFactor, error) + history []TwoFactorsStoreGetByUserIDFuncCall + mutex sync.Mutex +} + +// GetByUserID delegates to the next hook function in the queue and stores +// the parameter and result values of this invocation. +func (m *MockTwoFactorsStore) GetByUserID(v0 context.Context, v1 int64) (*TwoFactor, error) { + r0, r1 := m.GetByUserIDFunc.nextHook()(v0, v1) + m.GetByUserIDFunc.appendCall(TwoFactorsStoreGetByUserIDFuncCall{v0, v1, r0, r1}) + return r0, r1 +} + +// SetDefaultHook sets function that is called when the GetByUserID method +// of the parent MockTwoFactorsStore instance is invoked and the hook queue +// is empty. +func (f *TwoFactorsStoreGetByUserIDFunc) SetDefaultHook(hook func(context.Context, int64) (*TwoFactor, error)) { + f.defaultHook = hook +} + +// PushHook adds a function to the end of hook queue. Each invocation of the +// GetByUserID method of the parent MockTwoFactorsStore instance invokes the +// hook at the front of the queue and discards it. After the queue is empty, +// the default hook function is invoked for any future action. +func (f *TwoFactorsStoreGetByUserIDFunc) PushHook(hook func(context.Context, int64) (*TwoFactor, error)) { + f.mutex.Lock() + f.hooks = append(f.hooks, hook) + f.mutex.Unlock() +} + +// SetDefaultReturn calls SetDefaultHook with a function that returns the +// given values. +func (f *TwoFactorsStoreGetByUserIDFunc) SetDefaultReturn(r0 *TwoFactor, r1 error) { + f.SetDefaultHook(func(context.Context, int64) (*TwoFactor, error) { + return r0, r1 + }) +} + +// PushReturn calls PushHook with a function that returns the given values. +func (f *TwoFactorsStoreGetByUserIDFunc) PushReturn(r0 *TwoFactor, r1 error) { + f.PushHook(func(context.Context, int64) (*TwoFactor, error) { + return r0, r1 + }) +} + +func (f *TwoFactorsStoreGetByUserIDFunc) nextHook() func(context.Context, int64) (*TwoFactor, error) { + f.mutex.Lock() + defer f.mutex.Unlock() + + if len(f.hooks) == 0 { + return f.defaultHook + } + + hook := f.hooks[0] + f.hooks = f.hooks[1:] + return hook +} + +func (f *TwoFactorsStoreGetByUserIDFunc) appendCall(r0 TwoFactorsStoreGetByUserIDFuncCall) { + f.mutex.Lock() + f.history = append(f.history, r0) + f.mutex.Unlock() +} + +// History returns a sequence of TwoFactorsStoreGetByUserIDFuncCall objects +// describing the invocations of this function. +func (f *TwoFactorsStoreGetByUserIDFunc) History() []TwoFactorsStoreGetByUserIDFuncCall { + f.mutex.Lock() + history := make([]TwoFactorsStoreGetByUserIDFuncCall, len(f.history)) + copy(history, f.history) + f.mutex.Unlock() + + return history +} + +// TwoFactorsStoreGetByUserIDFuncCall is an object that describes an +// invocation of method GetByUserID on an instance of MockTwoFactorsStore. +type TwoFactorsStoreGetByUserIDFuncCall struct { + // Arg0 is the value of the 1st argument passed to this method + // invocation. + Arg0 context.Context + // Arg1 is the value of the 2nd argument passed to this method + // invocation. + Arg1 int64 + // Result0 is the value of the 1st result returned from this method + // invocation. + Result0 *TwoFactor + // Result1 is the value of the 2nd result returned from this method + // invocation. + Result1 error +} + +// Args returns an interface slice containing the arguments of this +// invocation. +func (c TwoFactorsStoreGetByUserIDFuncCall) Args() []interface{} { + return []interface{}{c.Arg0, c.Arg1} +} + +// Results returns an interface slice containing the results of this +// invocation. +func (c TwoFactorsStoreGetByUserIDFuncCall) Results() []interface{} { + return []interface{}{c.Result0, c.Result1} +} + +// TwoFactorsStoreIsUserEnabledFunc describes the behavior when the +// IsUserEnabled method of the parent MockTwoFactorsStore instance is +// invoked. +type TwoFactorsStoreIsUserEnabledFunc struct { + defaultHook func(context.Context, int64) bool + hooks []func(context.Context, int64) bool + history []TwoFactorsStoreIsUserEnabledFuncCall + mutex sync.Mutex +} + +// IsUserEnabled delegates to the next hook function in the queue and stores +// the parameter and result values of this invocation. +func (m *MockTwoFactorsStore) IsUserEnabled(v0 context.Context, v1 int64) bool { + r0 := m.IsUserEnabledFunc.nextHook()(v0, v1) + m.IsUserEnabledFunc.appendCall(TwoFactorsStoreIsUserEnabledFuncCall{v0, v1, r0}) + return r0 +} + +// SetDefaultHook sets function that is called when the IsUserEnabled method +// of the parent MockTwoFactorsStore instance is invoked and the hook queue +// is empty. +func (f *TwoFactorsStoreIsUserEnabledFunc) SetDefaultHook(hook func(context.Context, int64) bool) { + f.defaultHook = hook +} + +// PushHook adds a function to the end of hook queue. Each invocation of the +// IsUserEnabled method of the parent MockTwoFactorsStore instance invokes +// the hook at the front of the queue and discards it. After the queue is +// empty, the default hook function is invoked for any future action. +func (f *TwoFactorsStoreIsUserEnabledFunc) PushHook(hook func(context.Context, int64) bool) { + f.mutex.Lock() + f.hooks = append(f.hooks, hook) + f.mutex.Unlock() +} + +// SetDefaultReturn calls SetDefaultHook with a function that returns the +// given values. +func (f *TwoFactorsStoreIsUserEnabledFunc) SetDefaultReturn(r0 bool) { + f.SetDefaultHook(func(context.Context, int64) bool { + return r0 + }) +} + +// PushReturn calls PushHook with a function that returns the given values. +func (f *TwoFactorsStoreIsUserEnabledFunc) PushReturn(r0 bool) { + f.PushHook(func(context.Context, int64) bool { + return r0 + }) +} + +func (f *TwoFactorsStoreIsUserEnabledFunc) nextHook() func(context.Context, int64) bool { + f.mutex.Lock() + defer f.mutex.Unlock() + + if len(f.hooks) == 0 { + return f.defaultHook + } + + hook := f.hooks[0] + f.hooks = f.hooks[1:] + return hook +} + +func (f *TwoFactorsStoreIsUserEnabledFunc) appendCall(r0 TwoFactorsStoreIsUserEnabledFuncCall) { + f.mutex.Lock() + f.history = append(f.history, r0) + f.mutex.Unlock() +} + +// History returns a sequence of TwoFactorsStoreIsUserEnabledFuncCall +// objects describing the invocations of this function. +func (f *TwoFactorsStoreIsUserEnabledFunc) History() []TwoFactorsStoreIsUserEnabledFuncCall { + f.mutex.Lock() + history := make([]TwoFactorsStoreIsUserEnabledFuncCall, len(f.history)) + copy(history, f.history) + f.mutex.Unlock() + + return history +} + +// TwoFactorsStoreIsUserEnabledFuncCall is an object that describes an +// invocation of method IsUserEnabled on an instance of MockTwoFactorsStore. +type TwoFactorsStoreIsUserEnabledFuncCall struct { + // Arg0 is the value of the 1st argument passed to this method + // invocation. + Arg0 context.Context + // Arg1 is the value of the 2nd argument passed to this method + // invocation. + Arg1 int64 + // Result0 is the value of the 1st result returned from this method + // invocation. + Result0 bool +} + +// Args returns an interface slice containing the arguments of this +// invocation. +func (c TwoFactorsStoreIsUserEnabledFuncCall) Args() []interface{} { + return []interface{}{c.Arg0, c.Arg1} +} + +// Results returns an interface slice containing the results of this +// invocation. +func (c TwoFactorsStoreIsUserEnabledFuncCall) Results() []interface{} { + return []interface{}{c.Result0} +} + // MockUsersStore is a mock implementation of the UsersStore interface (from // the package gogs.io/gogs/internal/db) used for unit testing. type MockUsersStore struct { diff --git a/internal/db/two_factors.go b/internal/db/two_factors.go index 935f66db..2dba2a02 100644 --- a/internal/db/two_factors.go +++ b/internal/db/two_factors.go @@ -5,6 +5,7 @@ package db import ( + "context" "encoding/base64" "fmt" "strings" @@ -23,21 +24,21 @@ import ( // // NOTE: All methods are sorted in alphabetical order. type TwoFactorsStore interface { - // Create creates a new 2FA token and recovery codes for given user. - // The "key" is used to encrypt and later decrypt given "secret", - // which should be configured in site-level and change of the "key" - // will break all existing 2FA tokens. - Create(userID int64, key, secret string) error - // GetByUserID returns the 2FA token of given user. - // It returns ErrTwoFactorNotFound when not found. - GetByUserID(userID int64) (*TwoFactor, error) + // Create creates a new 2FA token and recovery codes for given user. The "key" + // is used to encrypt and later decrypt given "secret", which should be + // configured in site-level and change of the "key" will break all existing 2FA + // tokens. + Create(ctx context.Context, userID int64, key, secret string) error + // GetByUserID returns the 2FA token of given user. It returns + // ErrTwoFactorNotFound when not found. + GetByUserID(ctx context.Context, userID int64) (*TwoFactor, error) // IsUserEnabled returns true if the user has enabled 2FA. - IsUserEnabled(userID int64) bool + IsUserEnabled(ctx context.Context, userID int64) bool } var TwoFactors TwoFactorsStore -// NOTE: This is a GORM create hook. +// BeforeCreate implements the GORM create hook. func (t *TwoFactor) BeforeCreate(tx *gorm.DB) error { if t.CreatedUnix == 0 { t.CreatedUnix = tx.NowFunc().Unix() @@ -45,7 +46,7 @@ func (t *TwoFactor) BeforeCreate(tx *gorm.DB) error { return nil } -// NOTE: This is a GORM query hook. +// AfterFind implements the GORM query hook. func (t *TwoFactor) AfterFind(_ *gorm.DB) error { t.Created = time.Unix(t.CreatedUnix, 0).Local() return nil @@ -57,7 +58,7 @@ type twoFactors struct { *gorm.DB } -func (db *twoFactors) Create(userID int64, key, secret string) error { +func (db *twoFactors) Create(ctx context.Context, userID int64, key, secret string) error { encrypted, err := cryptoutil.AESGCMEncrypt(cryptoutil.MD5Bytes(key), []byte(secret)) if err != nil { return errors.Wrap(err, "encrypt secret") @@ -72,7 +73,7 @@ func (db *twoFactors) Create(userID int64, key, secret string) error { return errors.Wrap(err, "generate recovery codes") } - return db.Transaction(func(tx *gorm.DB) error { + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { err := tx.Create(tf).Error if err != nil { return err @@ -101,9 +102,9 @@ func (ErrTwoFactorNotFound) NotFound() bool { return true } -func (db *twoFactors) GetByUserID(userID int64) (*TwoFactor, error) { +func (db *twoFactors) GetByUserID(ctx context.Context, userID int64) (*TwoFactor, error) { tf := new(TwoFactor) - err := db.Where("user_id = ?", userID).First(tf).Error + err := db.WithContext(ctx).Where("user_id = ?", userID).First(tf).Error if err != nil { if err == gorm.ErrRecordNotFound { return nil, ErrTwoFactorNotFound{args: errutil.Args{"userID": userID}} @@ -113,9 +114,9 @@ func (db *twoFactors) GetByUserID(userID int64) (*TwoFactor, error) { return tf, nil } -func (db *twoFactors) IsUserEnabled(userID int64) bool { +func (db *twoFactors) IsUserEnabled(ctx context.Context, userID int64) bool { var count int64 - err := db.Model(new(TwoFactor)).Where("user_id = ?", userID).Count(&count).Error + err := db.WithContext(ctx).Model(new(TwoFactor)).Where("user_id = ?", userID).Count(&count).Error if err != nil { log.Error("Failed to count two factors [user_id: %d]: %v", userID, err) } diff --git a/internal/db/two_factors_test.go b/internal/db/two_factors_test.go index c8412213..acd9a576 100644 --- a/internal/db/two_factors_test.go +++ b/internal/db/two_factors_test.go @@ -5,15 +5,17 @@ package db import ( + "context" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gogs.io/gogs/internal/errutil" ) -func Test_twoFactors(t *testing.T) { +func TestTwoFactors(t *testing.T) { if testing.Short() { t.Skip() } @@ -29,16 +31,14 @@ func Test_twoFactors(t *testing.T) { name string test func(*testing.T, *twoFactors) }{ - {"Create", test_twoFactors_Create}, - {"GetByUserID", test_twoFactors_GetByUserID}, - {"IsUserEnabled", test_twoFactors_IsUserEnabled}, + {"Create", twoFactorsCreate}, + {"GetByUserID", twoFactorsGetByUserID}, + {"IsUserEnabled", twoFactorsIsUserEnabled}, } { t.Run(tc.name, func(t *testing.T) { t.Cleanup(func() { err := clearTables(t, db.DB, tables...) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) }) tc.test(t, db) }) @@ -48,55 +48,49 @@ func Test_twoFactors(t *testing.T) { } } -func test_twoFactors_Create(t *testing.T, db *twoFactors) { +func twoFactorsCreate(t *testing.T, db *twoFactors) { + ctx := context.Background() + // Create a 2FA token - err := db.Create(1, "secure-key", "secure-secret") - if err != nil { - t.Fatal(err) - } + err := db.Create(ctx, 1, "secure-key", "secure-secret") + require.NoError(t, err) // Get it back and check the Created field - tf, err := db.GetByUserID(1) - if err != nil { - t.Fatal(err) - } + tf, err := db.GetByUserID(ctx, 1) + require.NoError(t, err) assert.Equal(t, db.NowFunc().Format(time.RFC3339), tf.Created.UTC().Format(time.RFC3339)) // Verify there are 10 recover codes generated var count int64 err = db.Model(new(TwoFactorRecoveryCode)).Count(&count).Error - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) assert.Equal(t, int64(10), count) } -func test_twoFactors_GetByUserID(t *testing.T, db *twoFactors) { +func twoFactorsGetByUserID(t *testing.T, db *twoFactors) { + ctx := context.Background() + // Create a 2FA token for user 1 - err := db.Create(1, "secure-key", "secure-secret") - if err != nil { - t.Fatal(err) - } + err := db.Create(ctx, 1, "secure-key", "secure-secret") + require.NoError(t, err) // We should be able to get it back - _, err = db.GetByUserID(1) - if err != nil { - t.Fatal(err) - } + _, err = db.GetByUserID(ctx, 1) + require.NoError(t, err) // Try to get a non-existent 2FA token - _, err = db.GetByUserID(2) - expErr := ErrTwoFactorNotFound{args: errutil.Args{"userID": int64(2)}} - assert.Equal(t, expErr, err) + _, err = db.GetByUserID(ctx, 2) + wantErr := ErrTwoFactorNotFound{args: errutil.Args{"userID": int64(2)}} + assert.Equal(t, wantErr, err) } -func test_twoFactors_IsUserEnabled(t *testing.T, db *twoFactors) { +func twoFactorsIsUserEnabled(t *testing.T, db *twoFactors) { + ctx := context.Background() + // Create a 2FA token for user 1 - err := db.Create(1, "secure-key", "secure-secret") - if err != nil { - t.Fatal(err) - } + err := db.Create(ctx, 1, "secure-key", "secure-secret") + require.NoError(t, err) - assert.True(t, db.IsUserEnabled(1)) - assert.False(t, db.IsUserEnabled(2)) + assert.True(t, db.IsUserEnabled(ctx, 1)) + assert.False(t, db.IsUserEnabled(ctx, 2)) } diff --git a/internal/db/user.go b/internal/db/user.go index ebfbf082..1ad4b955 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -405,7 +405,7 @@ func (u *User) IsPublicMember(orgId int64) bool { // IsEnabledTwoFactor returns true if user has enabled two-factor authentication. func (u *User) IsEnabledTwoFactor() bool { - return TwoFactors.IsUserEnabled(u.ID) + return TwoFactors.IsUserEnabled(context.TODO(), u.ID) } func (u *User) getOrganizationCount(e Engine) (int64, error) { |