diff options
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/access_tokens_test.go | 13 | ||||
-rw-r--r-- | internal/db/db.go | 9 | ||||
-rw-r--r-- | internal/db/error.go | 33 | ||||
-rw-r--r-- | internal/db/errors/login_source.go | 14 | ||||
-rw-r--r-- | internal/db/lfs_test.go | 5 | ||||
-rw-r--r-- | internal/db/login_source.go | 564 | ||||
-rw-r--r-- | internal/db/login_source_files.go | 212 | ||||
-rw-r--r-- | internal/db/login_sources.go | 276 | ||||
-rw-r--r-- | internal/db/login_sources_test.go | 389 | ||||
-rw-r--r-- | internal/db/main_test.go | 3 | ||||
-rw-r--r-- | internal/db/mocks.go | 53 | ||||
-rw-r--r-- | internal/db/models.go | 12 | ||||
-rw-r--r-- | internal/db/perms_test.go | 5 | ||||
-rw-r--r-- | internal/db/user.go | 38 |
14 files changed, 1022 insertions, 604 deletions
diff --git a/internal/db/access_tokens_test.go b/internal/db/access_tokens_test.go index d3dfb83f..0f979e95 100644 --- a/internal/db/access_tokens_test.go +++ b/internal/db/access_tokens_test.go @@ -21,8 +21,9 @@ func Test_accessTokens(t *testing.T) { t.Parallel() + tables := []interface{}{new(AccessToken)} db := &accessTokens{ - DB: initTestDB(t, "accessTokens", new(AccessToken)), + DB: initTestDB(t, "accessTokens", tables...), } for _, tc := range []struct { @@ -37,7 +38,7 @@ func Test_accessTokens(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { t.Cleanup(func() { - err := deleteTables(db.DB, new(AccessToken)) + err := clearTables(db.DB, tables...) if err != nil { t.Fatal(err) } @@ -78,14 +79,14 @@ func test_accessTokens_DeleteByID(t *testing.T, db *accessTokens) { t.Fatal(err) } - // We should be able to get it back - _, err = db.GetBySHA(token.Sha1) + // Delete a token with mismatched user ID is noop + err = db.DeleteByID(2, token.ID) if err != nil { t.Fatal(err) } - // Delete a token with mismatched user ID is noop - err = db.DeleteByID(2, token.ID) + // We should be able to get it back + _, err = db.GetBySHA(token.Sha1) if err != nil { t.Fatal(err) } diff --git a/internal/db/db.go b/internal/db/db.go index 1be2cc4b..77d78f53 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -124,7 +124,7 @@ func getLogWriter() (io.Writer, error) { var tables = []interface{}{ new(AccessToken), - new(LFSObject), + new(LFSObject), new(LoginSource), } func Init() error { @@ -167,9 +167,14 @@ func Init() error { return time.Now().UTC().Truncate(time.Microsecond) } + sourceFiles, err := loadLoginSourceFiles(filepath.Join(conf.CustomDir(), "conf", "auth.d")) + if err != nil { + return errors.Wrap(err, "load login source files") + } + // Initialize stores, sorted in alphabetical order. AccessTokens = &accessTokens{DB: db} - LoginSources = &loginSources{DB: db} + LoginSources = &loginSources{DB: db, files: sourceFiles} LFS = &lfs{DB: db} Perms = &perms{DB: db} Repos = &repos{DB: db} diff --git a/internal/db/error.go b/internal/db/error.go index ed173d86..46e7dde5 100644 --- a/internal/db/error.go +++ b/internal/db/error.go @@ -327,39 +327,6 @@ func (err ErrRepoFileAlreadyExist) Error() string { return fmt.Sprintf("repository file already exists [file_name: %s]", err.FileName) } -// .____ .__ _________ -// | | ____ ____ |__| ____ / _____/ ____ __ _________ ____ ____ -// | | / _ \ / ___\| |/ \ \_____ \ / _ \| | \_ __ \_/ ___\/ __ \ -// | |__( <_> ) /_/ > | | \ / ( <_> ) | /| | \/\ \__\ ___/ -// |_______ \____/\___ /|__|___| / /_______ /\____/|____/ |__| \___ >___ > -// \/ /_____/ \/ \/ \/ \/ - -type ErrLoginSourceAlreadyExist struct { - Name string -} - -func IsErrLoginSourceAlreadyExist(err error) bool { - _, ok := err.(ErrLoginSourceAlreadyExist) - return ok -} - -func (err ErrLoginSourceAlreadyExist) Error() string { - return fmt.Sprintf("login source already exists [name: %s]", err.Name) -} - -type ErrLoginSourceInUse struct { - ID int64 -} - -func IsErrLoginSourceInUse(err error) bool { - _, ok := err.(ErrLoginSourceInUse) - return ok -} - -func (err ErrLoginSourceInUse) Error() string { - return fmt.Sprintf("login source is still used by some users [id: %d]", err.ID) -} - // ___________ // \__ ___/___ _____ _____ // | |_/ __ \\__ \ / \ diff --git a/internal/db/errors/login_source.go b/internal/db/errors/login_source.go index 876a0820..db0cd1f9 100644 --- a/internal/db/errors/login_source.go +++ b/internal/db/errors/login_source.go @@ -6,19 +6,6 @@ package errors import "fmt" -type LoginSourceNotExist struct { - ID int64 -} - -func IsLoginSourceNotExist(err error) bool { - _, ok := err.(LoginSourceNotExist) - return ok -} - -func (err LoginSourceNotExist) Error() string { - return fmt.Sprintf("login source does not exist [id: %d]", err.ID) -} - type LoginSourceNotActivated struct { SourceID int64 } @@ -44,4 +31,3 @@ func IsInvalidLoginSourceType(err error) bool { func (err InvalidLoginSourceType) Error() string { return fmt.Sprintf("invalid login source type [type: %v]", err.Type) } - diff --git a/internal/db/lfs_test.go b/internal/db/lfs_test.go index 6eb14019..29c7d665 100644 --- a/internal/db/lfs_test.go +++ b/internal/db/lfs_test.go @@ -22,8 +22,9 @@ func Test_lfs(t *testing.T) { t.Parallel() + tables := []interface{}{new(LFSObject)} db := &lfs{ - DB: initTestDB(t, "lfs", new(LFSObject)), + DB: initTestDB(t, "lfs", tables...), } for _, tc := range []struct { @@ -36,7 +37,7 @@ func Test_lfs(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { t.Cleanup(func() { - err := deleteTables(db.DB, new(LFSObject)) + err := clearTables(db.DB, tables...) if err != nil { t.Fatal(err) } diff --git a/internal/db/login_source.go b/internal/db/login_source.go index b6665f4e..e821e186 100644 --- a/internal/db/login_source.go +++ b/internal/db/login_source.go @@ -10,30 +10,21 @@ import ( "fmt" "net/smtp" "net/textproto" - "os" - "path/filepath" "strings" - "sync" - "time" "github.com/go-macaron/binding" - "github.com/json-iterator/go" "github.com/unknwon/com" - "gopkg.in/ini.v1" - log "unknwon.dev/clog/v2" - "xorm.io/core" - "xorm.io/xorm" "gogs.io/gogs/internal/auth/github" "gogs.io/gogs/internal/auth/ldap" "gogs.io/gogs/internal/auth/pam" - "gogs.io/gogs/internal/conf" "gogs.io/gogs/internal/db/errors" ) type LoginType int // Note: new type must append to the end of list to maintain compatibility. +// TODO: Move to authutil. const ( LoginNotype LoginType = iota LoginPlain // 1 @@ -52,497 +43,24 @@ var LoginNames = map[LoginType]string{ LoginGitHub: "GitHub", } -var SecurityProtocolNames = map[ldap.SecurityProtocol]string{ - ldap.SECURITY_PROTOCOL_UNENCRYPTED: "Unencrypted", - ldap.SECURITY_PROTOCOL_LDAPS: "LDAPS", - ldap.SECURITY_PROTOCOL_START_TLS: "StartTLS", -} - -// Ensure structs implemented interface. -var ( - _ core.Conversion = &LDAPConfig{} - _ core.Conversion = &SMTPConfig{} - _ core.Conversion = &PAMConfig{} - _ core.Conversion = &GitHubConfig{} -) +// *********************** +// ----- LDAP config ----- +// *********************** type LDAPConfig struct { - *ldap.Source `ini:"config"` + ldap.Source `ini:"config"` } -func (cfg *LDAPConfig) FromDB(bs []byte) error { - return jsoniter.Unmarshal(bs, &cfg) -} - -func (cfg *LDAPConfig) ToDB() ([]byte, error) { - return jsoniter.Marshal(cfg) +var SecurityProtocolNames = map[ldap.SecurityProtocol]string{ + ldap.SecurityProtocolUnencrypted: "Unencrypted", + ldap.SecurityProtocolLDAPS: "LDAPS", + ldap.SecurityProtocolStartTLS: "StartTLS", } func (cfg *LDAPConfig) SecurityProtocolName() string { return SecurityProtocolNames[cfg.SecurityProtocol] } -type SMTPConfig struct { - Auth string - Host string - Port int - AllowedDomains string `xorm:"TEXT"` - TLS bool `ini:"tls"` - SkipVerify bool -} - -func (cfg *SMTPConfig) FromDB(bs []byte) error { - return jsoniter.Unmarshal(bs, cfg) -} - -func (cfg *SMTPConfig) ToDB() ([]byte, error) { - return jsoniter.Marshal(cfg) -} - -type PAMConfig struct { - ServiceName string // PAM service (e.g. system-auth) -} - -func (cfg *PAMConfig) FromDB(bs []byte) error { - return jsoniter.Unmarshal(bs, &cfg) -} - -func (cfg *PAMConfig) ToDB() ([]byte, error) { - return jsoniter.Marshal(cfg) -} - -type GitHubConfig struct { - APIEndpoint string // GitHub service (e.g. https://api.github.com/) -} - -func (cfg *GitHubConfig) FromDB(bs []byte) error { - return jsoniter.Unmarshal(bs, &cfg) -} - -func (cfg *GitHubConfig) ToDB() ([]byte, error) { - return jsoniter.Marshal(cfg) -} - -// AuthSourceFile contains information of an authentication source file. -type AuthSourceFile struct { - abspath string - file *ini.File -} - -// SetGeneral sets new value to the given key in the general (default) section. -func (f *AuthSourceFile) SetGeneral(name, value string) { - f.file.Section("").Key(name).SetValue(value) -} - -// SetConfig sets new values to the "config" section. -func (f *AuthSourceFile) SetConfig(cfg core.Conversion) error { - return f.file.Section("config").ReflectFrom(cfg) -} - -// Save writes updates into file system. -func (f *AuthSourceFile) Save() error { - return f.file.SaveTo(f.abspath) -} - -// LoginSource represents an external way for authorizing users. -type LoginSource struct { - ID int64 - Type LoginType - Name string `xorm:"UNIQUE"` - IsActived bool `xorm:"NOT NULL DEFAULT false"` - IsDefault bool `xorm:"DEFAULT false"` - Cfg core.Conversion `xorm:"TEXT" gorm:"COLUMN:remove-me-when-migrated-to-gorm"` - RawCfg string `xorm:"-" gorm:"COLUMN:cfg"` // TODO: Remove me when migrated to GORM. - - Created time.Time `xorm:"-" json:"-"` - CreatedUnix int64 - Updated time.Time `xorm:"-" json:"-"` - UpdatedUnix int64 - - LocalFile *AuthSourceFile `xorm:"-" json:"-"` -} - -func (s *LoginSource) BeforeInsert() { - s.CreatedUnix = time.Now().Unix() - s.UpdatedUnix = s.CreatedUnix -} - -func (s *LoginSource) BeforeUpdate() { - s.UpdatedUnix = time.Now().Unix() -} - -// Cell2Int64 converts a xorm.Cell type to int64, -// and handles possible irregular cases. -func Cell2Int64(val xorm.Cell) int64 { - switch (*val).(type) { - case []uint8: - log.Trace("Cell2Int64 ([]uint8): %v", *val) - return com.StrTo(string((*val).([]uint8))).MustInt64() - } - return (*val).(int64) -} - -func (s *LoginSource) BeforeSet(colName string, val xorm.Cell) { - switch colName { - case "type": - switch LoginType(Cell2Int64(val)) { - case LoginLDAP, LoginDLDAP: - s.Cfg = new(LDAPConfig) - case LoginSMTP: - s.Cfg = new(SMTPConfig) - case LoginPAM: - s.Cfg = new(PAMConfig) - case LoginGitHub: - s.Cfg = new(GitHubConfig) - default: - panic("unrecognized login source type: " + com.ToStr(*val)) - } - } -} - -func (s *LoginSource) AfterSet(colName string, _ xorm.Cell) { - switch colName { - case "created_unix": - s.Created = time.Unix(s.CreatedUnix, 0).Local() - case "updated_unix": - s.Updated = time.Unix(s.UpdatedUnix, 0).Local() - } -} - -// NOTE: This is a GORM query hook. -func (s *LoginSource) AfterFind() error { - switch s.Type { - case LoginLDAP, LoginDLDAP: - s.Cfg = new(LDAPConfig) - case LoginSMTP: - s.Cfg = new(SMTPConfig) - case LoginPAM: - s.Cfg = new(PAMConfig) - case LoginGitHub: - s.Cfg = new(GitHubConfig) - default: - return fmt.Errorf("unrecognized login source type: %v", s.Type) - } - return jsoniter.UnmarshalFromString(s.RawCfg, s.Cfg) -} - -func (s *LoginSource) TypeName() string { - return LoginNames[s.Type] -} - -func (s *LoginSource) IsLDAP() bool { - return s.Type == LoginLDAP -} - -func (s *LoginSource) IsDLDAP() bool { - return s.Type == LoginDLDAP -} - -func (s *LoginSource) IsSMTP() bool { - return s.Type == LoginSMTP -} - -func (s *LoginSource) IsPAM() bool { - return s.Type == LoginPAM -} - -func (s *LoginSource) IsGitHub() bool { - return s.Type == LoginGitHub -} - -func (s *LoginSource) HasTLS() bool { - return ((s.IsLDAP() || s.IsDLDAP()) && - s.LDAP().SecurityProtocol > ldap.SECURITY_PROTOCOL_UNENCRYPTED) || - s.IsSMTP() -} - -func (s *LoginSource) UseTLS() bool { - switch s.Type { - case LoginLDAP, LoginDLDAP: - return s.LDAP().SecurityProtocol != ldap.SECURITY_PROTOCOL_UNENCRYPTED - case LoginSMTP: - return s.SMTP().TLS - } - - return false -} - -func (s *LoginSource) SkipVerify() bool { - switch s.Type { - case LoginLDAP, LoginDLDAP: - return s.LDAP().SkipVerify - case LoginSMTP: - return s.SMTP().SkipVerify - } - - return false -} - -func (s *LoginSource) LDAP() *LDAPConfig { - return s.Cfg.(*LDAPConfig) -} - -func (s *LoginSource) SMTP() *SMTPConfig { - return s.Cfg.(*SMTPConfig) -} - -func (s *LoginSource) PAM() *PAMConfig { - return s.Cfg.(*PAMConfig) -} - -func (s *LoginSource) GitHub() *GitHubConfig { - return s.Cfg.(*GitHubConfig) -} - -func CreateLoginSource(source *LoginSource) error { - has, err := x.Get(&LoginSource{Name: source.Name}) - if err != nil { - return err - } else if has { - return ErrLoginSourceAlreadyExist{source.Name} - } - - _, err = x.Insert(source) - if err != nil { - return err - } else if source.IsDefault { - return ResetNonDefaultLoginSources(source) - } - return nil -} - -// ListLoginSources returns all login sources defined. -func ListLoginSources() ([]*LoginSource, error) { - sources := make([]*LoginSource, 0, 2) - if err := x.Find(&sources); err != nil { - return nil, err - } - - return append(sources, localLoginSources.List()...), nil -} - -// ActivatedLoginSources returns login sources that are currently activated. -func ActivatedLoginSources() ([]*LoginSource, error) { - sources := make([]*LoginSource, 0, 2) - if err := x.Where("is_actived = ?", true).Find(&sources); err != nil { - return nil, fmt.Errorf("find activated login sources: %v", err) - } - return append(sources, localLoginSources.ActivatedList()...), nil -} - -// ResetNonDefaultLoginSources clean other default source flag -func ResetNonDefaultLoginSources(source *LoginSource) error { - // update changes to DB - if _, err := x.NotIn("id", []int64{source.ID}).Cols("is_default").Update(&LoginSource{IsDefault: false}); err != nil { - return err - } - // write changes to local authentications - for i := range localLoginSources.sources { - if localLoginSources.sources[i].LocalFile != nil && localLoginSources.sources[i].ID != source.ID { - localLoginSources.sources[i].LocalFile.SetGeneral("is_default", "false") - if err := localLoginSources.sources[i].LocalFile.SetConfig(source.Cfg); err != nil { - return fmt.Errorf("LocalFile.SetConfig: %v", err) - } else if err = localLoginSources.sources[i].LocalFile.Save(); err != nil { - return fmt.Errorf("LocalFile.Save: %v", err) - } - } - } - // flush memory so that web page can show the same behaviors - localLoginSources.UpdateLoginSource(source) - return nil -} - -// UpdateLoginSource updates information of login source to database or local file. -func UpdateLoginSource(source *LoginSource) error { - if source.LocalFile == nil { - if _, err := x.Id(source.ID).AllCols().Update(source); err != nil { - return err - } else { - return ResetNonDefaultLoginSources(source) - } - - } - - source.LocalFile.SetGeneral("name", source.Name) - source.LocalFile.SetGeneral("is_activated", com.ToStr(source.IsActived)) - source.LocalFile.SetGeneral("is_default", com.ToStr(source.IsDefault)) - if err := source.LocalFile.SetConfig(source.Cfg); err != nil { - return fmt.Errorf("LocalFile.SetConfig: %v", err) - } else if err = source.LocalFile.Save(); err != nil { - return fmt.Errorf("LocalFile.Save: %v", err) - } - return ResetNonDefaultLoginSources(source) -} - -func DeleteSource(source *LoginSource) error { - count, err := x.Count(&User{LoginSource: source.ID}) - if err != nil { - return err - } else if count > 0 { - return ErrLoginSourceInUse{source.ID} - } - _, err = x.Id(source.ID).Delete(new(LoginSource)) - return err -} - -// CountLoginSources returns total number of login sources. -func CountLoginSources() int64 { - count, _ := x.Count(new(LoginSource)) - return count + int64(localLoginSources.Len()) -} - -// LocalLoginSources contains authentication sources configured and loaded from local files. -// Calling its methods is thread-safe; otherwise, please maintain the mutex accordingly. -type LocalLoginSources struct { - sync.RWMutex - sources []*LoginSource -} - -func (s *LocalLoginSources) Len() int { - return len(s.sources) -} - -// List returns full clone of login sources. -func (s *LocalLoginSources) List() []*LoginSource { - s.RLock() - defer s.RUnlock() - - list := make([]*LoginSource, s.Len()) - for i := range s.sources { - list[i] = &LoginSource{} - *list[i] = *s.sources[i] - } - return list -} - -// ActivatedList returns clone of activated login sources. -func (s *LocalLoginSources) ActivatedList() []*LoginSource { - s.RLock() - defer s.RUnlock() - - list := make([]*LoginSource, 0, 2) - for i := range s.sources { - if !s.sources[i].IsActived { - continue - } - source := &LoginSource{} - *source = *s.sources[i] - list = append(list, source) - } - return list -} - -// GetLoginSourceByID returns a clone of login source by given ID. -func (s *LocalLoginSources) GetLoginSourceByID(id int64) (*LoginSource, error) { - s.RLock() - defer s.RUnlock() - - for i := range s.sources { - if s.sources[i].ID == id { - source := &LoginSource{} - *source = *s.sources[i] - return source, nil - } - } - - return nil, errors.LoginSourceNotExist{ID: id} -} - -// UpdateLoginSource updates in-memory copy of the authentication source. -func (s *LocalLoginSources) UpdateLoginSource(source *LoginSource) { - s.Lock() - defer s.Unlock() - - source.Updated = time.Now() - for i := range s.sources { - if s.sources[i].ID == source.ID { - *s.sources[i] = *source - } else if source.IsDefault { - s.sources[i].IsDefault = false - } - } -} - -var localLoginSources = &LocalLoginSources{} - -// LoadAuthSources loads authentication sources from local files -// and converts them into login sources. -func LoadAuthSources() { - authdPath := filepath.Join(conf.CustomDir(), "conf", "auth.d") - if !com.IsDir(authdPath) { - return - } - - paths, err := com.GetFileListBySuffix(authdPath, ".conf") - if err != nil { - log.Fatal("Failed to list authentication sources: %v", err) - } - - localLoginSources.sources = make([]*LoginSource, 0, len(paths)) - - for _, fpath := range paths { - authSource, err := ini.Load(fpath) - if err != nil { - log.Fatal("Failed to load authentication source: %v", err) - } - authSource.NameMapper = ini.TitleUnderscore - - // Set general attributes - s := authSource.Section("") - loginSource := &LoginSource{ - ID: s.Key("id").MustInt64(), - Name: s.Key("name").String(), - IsActived: s.Key("is_activated").MustBool(), - IsDefault: s.Key("is_default").MustBool(), - LocalFile: &AuthSourceFile{ - abspath: fpath, - file: authSource, - }, - } - - fi, err := os.Stat(fpath) - if err != nil { - log.Fatal("Failed to load authentication source: %v", err) - } - loginSource.Updated = fi.ModTime() - - // Parse authentication source file - authType := s.Key("type").String() - switch authType { - case "ldap_bind_dn": - loginSource.Type = LoginLDAP - loginSource.Cfg = &LDAPConfig{} - case "ldap_simple_auth": - loginSource.Type = LoginDLDAP - loginSource.Cfg = &LDAPConfig{} - case "smtp": - loginSource.Type = LoginSMTP - loginSource.Cfg = &SMTPConfig{} - case "pam": - loginSource.Type = LoginPAM - loginSource.Cfg = &PAMConfig{} - case "github": - loginSource.Type = LoginGitHub - loginSource.Cfg = &GitHubConfig{} - default: - log.Fatal("Failed to load authentication source: unknown type '%s'", authType) - } - - if err = authSource.Section("config").MapTo(loginSource.Cfg); err != nil { - log.Fatal("Failed to parse authentication source 'config': %v", err) - } - - localLoginSources.sources = append(localLoginSources.sources, loginSource) - } -} - -// .____ ________ _____ __________ -// | | \______ \ / _ \\______ \ -// | | | | \ / /_\ \| ___/ -// | |___ | ` \/ | \ | -// |_______ \/_______ /\____|__ /____| -// \/ \/ \/ - func composeFullName(firstname, surname, username string) string { switch { case len(firstname) == 0 && len(surname) == 0: @@ -559,7 +77,7 @@ func composeFullName(firstname, surname, username string) string { // LoginViaLDAP queries if login/password is valid against the LDAP directory pool, // and create a local user if success when enabled. func LoginViaLDAP(login, password string, source *LoginSource, autoRegister bool) (*User, error) { - username, fn, sn, mail, isAdmin, succeed := source.Cfg.(*LDAPConfig).SearchEntry(login, password, source.Type == LoginDLDAP) + username, fn, sn, mail, isAdmin, succeed := source.Config.(*LDAPConfig).SearchEntry(login, password, source.Type == LoginDLDAP) if !succeed { // User not in LDAP, do nothing return nil, ErrUserNotExist{args: map[string]interface{}{"login": login}} @@ -606,12 +124,18 @@ func LoginViaLDAP(login, password string, source *LoginSource, autoRegister bool return user, CreateUser(user) } -// _________ __________________________ -// / _____/ / \__ ___/\______ \ -// \_____ \ / \ / \| | | ___/ -// / \/ Y \ | | | -// /_______ /\____|__ /____| |____| -// \/ \/ +// *********************** +// ----- SMTP config ----- +// *********************** + +type SMTPConfig struct { + Auth string + Host string + Port int + AllowedDomains string + TLS bool `ini:"tls"` + SkipVerify bool +} type smtpLoginAuth struct { username, password string @@ -634,11 +158,11 @@ func (auth *smtpLoginAuth) Next(fromServer []byte, more bool) ([]byte, error) { } const ( - SMTP_PLAIN = "PLAIN" - SMTP_LOGIN = "LOGIN" + SMTPPlain = "PLAIN" + SMTPLogin = "LOGIN" ) -var SMTPAuths = []string{SMTP_PLAIN, SMTP_LOGIN} +var SMTPAuths = []string{SMTPPlain, SMTPLogin} func SMTPAuth(a smtp.Auth, cfg *SMTPConfig) error { c, err := smtp.Dial(fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)) @@ -687,9 +211,9 @@ func LoginViaSMTP(login, password string, sourceID int64, cfg *SMTPConfig, autoR } var auth smtp.Auth - if cfg.Auth == SMTP_PLAIN { + if cfg.Auth == SMTPPlain { auth = smtp.PlainAuth("", login, password, cfg.Host) - } else if cfg.Auth == SMTP_LOGIN { + } else if cfg.Auth == SMTPLogin { auth = &smtpLoginAuth{login, password} } else { return nil, errors.New("Unsupported SMTP authentication type") @@ -729,12 +253,14 @@ func LoginViaSMTP(login, password string, sourceID int64, cfg *SMTPConfig, autoR return user, CreateUser(user) } -// __________ _____ _____ -// \______ \/ _ \ / \ -// | ___/ /_\ \ / \ / \ -// | | / | \/ Y \ -// |____| \____|__ /\____|__ / -// \/ \/ +// ********************** +// ----- PAM config ----- +// ********************** + +type PAMConfig struct { + // The name of the PAM service, e.g. system-auth. + ServiceName string +} // LoginViaPAM queries if login/password is valid against the PAM, // and create a local user if success when enabled. @@ -763,12 +289,14 @@ func LoginViaPAM(login, password string, sourceID int64, cfg *PAMConfig, autoReg return user, CreateUser(user) } -// ________.__ __ ___ ___ ___. -// / _____/|__|/ |_ / | \ __ _\_ |__ -// / \ ___| \ __\/ ~ \ | \ __ \ -// \ \_\ \ || | \ Y / | / \_\ \ -// \______ /__||__| \___|_ /|____/|___ / -// \/ \/ \/ +// ************************* +// ----- GitHub config ----- +// ************************* + +type GitHubConfig struct { + // the GitHub service endpoint, e.g. https://api.github.com/. + APIEndpoint string +} func LoginViaGitHub(login, password string, sourceID int64, cfg *GitHubConfig, autoRegister bool) (*User, error) { fullname, email, url, location, err := github.Authenticate(cfg.APIEndpoint, login, password) @@ -807,11 +335,11 @@ func authenticateViaLoginSource(source *LoginSource, login, password string, aut case LoginLDAP, LoginDLDAP: return LoginViaLDAP(login, password, source, autoRegister) case LoginSMTP: - return LoginViaSMTP(login, password, source.ID, source.Cfg.(*SMTPConfig), autoRegister) + return LoginViaSMTP(login, password, source.ID, source.Config.(*SMTPConfig), autoRegister) case LoginPAM: - return LoginViaPAM(login, password, source.ID, source.Cfg.(*PAMConfig), autoRegister) + return LoginViaPAM(login, password, source.ID, source.Config.(*PAMConfig), autoRegister) case LoginGitHub: - return LoginViaGitHub(login, password, source.ID, source.Cfg.(*GitHubConfig), autoRegister) + return LoginViaGitHub(login, password, source.ID, source.Config.(*GitHubConfig), autoRegister) } return nil, errors.InvalidLoginSourceType{Type: source.Type} diff --git a/internal/db/login_source_files.go b/internal/db/login_source_files.go new file mode 100644 index 00000000..4c5cb377 --- /dev/null +++ b/internal/db/login_source_files.go @@ -0,0 +1,212 @@ +// Copyright 2020 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package db + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/jinzhu/gorm" + "github.com/pkg/errors" + "gopkg.in/ini.v1" + + "gogs.io/gogs/internal/errutil" + "gogs.io/gogs/internal/osutil" +) + +// loginSourceFilesStore is the in-memory interface for login source files stored on file system. +// +// NOTE: All methods are sorted in alphabetical order. +type loginSourceFilesStore interface { + // GetByID returns a clone of login source by given ID. + GetByID(id int64) (*LoginSource, error) + // Len returns number of login sources. + Len() int + // List returns a list of login sources filtered by options. + List(opts ListLoginSourceOpts) []*LoginSource + // Update updates in-memory copy of the authentication source. + Update(source *LoginSource) +} + +var _ loginSourceFilesStore = (*loginSourceFiles)(nil) + +// loginSourceFiles contains authentication sources configured and loaded from local files. +type loginSourceFiles struct { + sync.RWMutex + sources []*LoginSource +} + +var _ errutil.NotFound = (*ErrLoginSourceNotExist)(nil) + +type ErrLoginSourceNotExist struct { + args errutil.Args +} + +func IsErrLoginSourceNotExist(err error) bool { + _, ok := err.(ErrLoginSourceNotExist) + return ok +} + +func (err ErrLoginSourceNotExist) Error() string { + return fmt.Sprintf("login source does not exist: %v", err.args) +} + +func (ErrLoginSourceNotExist) NotFound() bool { + return true +} + +func (s *loginSourceFiles) GetByID(id int64) (*LoginSource, error) { + s.RLock() + defer s.RUnlock() + + for _, source := range s.sources { + if source.ID == id { + return source, nil + } + } + + return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}} +} + +func (s *loginSourceFiles) Len() int { + s.RLock() + defer s.RUnlock() + return len(s.sources) +} + +func (s *loginSourceFiles) List(opts ListLoginSourceOpts) []*LoginSource { + s.RLock() + defer s.RUnlock() + + list := make([]*LoginSource, 0, s.Len()) + for _, source := range s.sources { + if opts.OnlyActivated && !source.IsActived { + continue + } + + list = append(list, source) + } + return list +} + +func (s *loginSourceFiles) Update(source *LoginSource) { + s.Lock() + defer s.Unlock() + + source.Updated = gorm.NowFunc() + for _, old := range s.sources { + if old.ID == source.ID { + *old = *source + } else if source.IsDefault { + old.IsDefault = false + } + } +} + +// loadLoginSourceFiles loads login sources from file system. +func loadLoginSourceFiles(authdPath string) (loginSourceFilesStore, error) { + if !osutil.IsDir(authdPath) { + return &loginSourceFiles{}, nil + } + + store := &loginSourceFiles{} + return store, filepath.Walk(authdPath, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if path == authdPath || !strings.HasSuffix(path, ".conf") { + return nil + } else if info.IsDir() { + return filepath.SkipDir + } + + authSource, err := ini.Load(path) + if err != nil { + return errors.Wrap(err, "load file") + } + authSource.NameMapper = ini.TitleUnderscore + + // Set general attributes + s := authSource.Section("") + loginSource := &LoginSource{ + ID: s.Key("id").MustInt64(), + Name: s.Key("name").String(), + IsActived: s.Key("is_activated").MustBool(), + IsDefault: s.Key("is_default").MustBool(), + File: &loginSourceFile{ + path: path, + file: authSource, + }, + } + + fi, err := os.Stat(path) + if err != nil { + return errors.Wrap(err, "stat file") + } + loginSource.Updated = fi.ModTime() + + // Parse authentication source file + authType := s.Key("type").String() + switch authType { + case "ldap_bind_dn": + loginSource.Type = LoginLDAP + loginSource.Config = &LDAPConfig{} + case "ldap_simple_auth": + loginSource.Type = LoginDLDAP + loginSource.Config = &LDAPConfig{} + case "smtp": + loginSource.Type = LoginSMTP + loginSource.Config = &SMTPConfig{} + case "pam": + loginSource.Type = LoginPAM + loginSource.Config = &PAMConfig{} + case "github": + loginSource.Type = LoginGitHub + loginSource.Config = &GitHubConfig{} + default: + return fmt.Errorf("unknown type %q", authType) + } + + if err = authSource.Section("config").MapTo(loginSource.Config); err != nil { + return errors.Wrap(err, `map "config" section`) + } + + store.sources = append(store.sources, loginSource) + return nil + }) +} + +// loginSourceFileStore is the persistent interface for a login source file. +type loginSourceFileStore interface { + // SetGeneral sets new value to the given key in the general (default) section. + SetGeneral(name, value string) + // SetConfig sets new values to the "config" section. + SetConfig(cfg interface{}) error + // Save persists values to file system. + Save() error +} + +var _ loginSourceFileStore = (*loginSourceFile)(nil) + +type loginSourceFile struct { + path string + file *ini.File +} + +func (f *loginSourceFile) SetGeneral(name, value string) { + f.file.Section("").Key(name).SetValue(value) +} + +func (f *loginSourceFile) SetConfig(cfg interface{}) error { + return f.file.Section("config").ReflectFrom(cfg) +} + +func (f *loginSourceFile) Save() error { + return f.file.SaveTo(f.path) +} diff --git a/internal/db/login_sources.go b/internal/db/login_sources.go index 41ee73b9..b620a442 100644 --- a/internal/db/login_sources.go +++ b/internal/db/login_sources.go @@ -5,22 +5,242 @@ package db import ( + "fmt" + "strconv" + "time" + "github.com/jinzhu/gorm" + jsoniter "github.com/json-iterator/go" + "github.com/pkg/errors" + + "gogs.io/gogs/internal/auth/ldap" + "gogs.io/gogs/internal/errutil" ) // LoginSourcesStore is the persistent interface for login sources. // // NOTE: All methods are sorted in alphabetical order. type LoginSourcesStore interface { + // Create creates a new login source and persist to database. + // It returns ErrLoginSourceAlreadyExist when a login source with same name already exists. + Create(opts CreateLoginSourceOpts) (*LoginSource, error) + // Count returns the total number of login sources. + Count() int64 + // DeleteByID deletes a login source by given ID. + // It returns ErrLoginSourceInUse if at least one user is associated with the login source. + DeleteByID(id int64) error // GetByID returns the login source with given ID. // It returns ErrLoginSourceNotExist when not found. GetByID(id int64) (*LoginSource, error) + // List returns a list of login sources filtered by options. + List(opts ListLoginSourceOpts) ([]*LoginSource, error) + // ResetNonDefault clears default flag for all the other login sources. + ResetNonDefault(source *LoginSource) error + // Save persists all values of given login source to database or local file. + // The Updated field is set to current time automatically. + Save(t *LoginSource) error } var LoginSources LoginSourcesStore +// LoginSource represents an external way for authorizing users. +type LoginSource struct { + ID int64 + Type LoginType + Name string `xorm:"UNIQUE"` + IsActived bool `xorm:"NOT NULL DEFAULT false"` + IsDefault bool `xorm:"DEFAULT false"` + Config interface{} `xorm:"-" gorm:"-"` + RawConfig string `xorm:"TEXT cfg" gorm:"COLUMN:cfg"` + + Created time.Time `xorm:"-" gorm:"-" json:"-"` + CreatedUnix int64 + Updated time.Time `xorm:"-" gorm:"-" json:"-"` + UpdatedUnix int64 + + File loginSourceFileStore `xorm:"-" gorm:"-" json:"-"` +} + +// NOTE: This is a GORM save hook. +func (s *LoginSource) BeforeSave() (err error) { + s.RawConfig, err = jsoniter.MarshalToString(s.Config) + return err +} + +// NOTE: This is a GORM create hook. +func (s *LoginSource) BeforeCreate() { + s.CreatedUnix = gorm.NowFunc().Unix() + s.UpdatedUnix = s.CreatedUnix +} + +// NOTE: This is a GORM update hook. +func (s *LoginSource) BeforeUpdate() { + s.UpdatedUnix = gorm.NowFunc().Unix() +} + +// NOTE: This is a GORM query hook. +func (s *LoginSource) AfterFind() error { + s.Created = time.Unix(s.CreatedUnix, 0).Local() + s.Updated = time.Unix(s.UpdatedUnix, 0).Local() + + switch s.Type { + case LoginLDAP, LoginDLDAP: + s.Config = new(LDAPConfig) + case LoginSMTP: + s.Config = new(SMTPConfig) + case LoginPAM: + s.Config = new(PAMConfig) + case LoginGitHub: + s.Config = new(GitHubConfig) + default: + return fmt.Errorf("unrecognized login source type: %v", s.Type) + } + return jsoniter.UnmarshalFromString(s.RawConfig, s.Config) +} + +func (s *LoginSource) TypeName() string { + return LoginNames[s.Type] +} + +func (s *LoginSource) IsLDAP() bool { + return s.Type == LoginLDAP +} + +func (s *LoginSource) IsDLDAP() bool { + return s.Type == LoginDLDAP +} + +func (s *LoginSource) IsSMTP() bool { + return s.Type == LoginSMTP +} + +func (s *LoginSource) IsPAM() bool { + return s.Type == LoginPAM +} + +func (s *LoginSource) IsGitHub() bool { + return s.Type == LoginGitHub +} + +func (s *LoginSource) HasTLS() bool { + return ((s.IsLDAP() || s.IsDLDAP()) && + s.LDAP().SecurityProtocol > ldap.SecurityProtocolUnencrypted) || + s.IsSMTP() +} + +func (s *LoginSource) UseTLS() bool { + switch s.Type { + case LoginLDAP, LoginDLDAP: + return s.LDAP().SecurityProtocol != ldap.SecurityProtocolUnencrypted + case LoginSMTP: + return s.SMTP().TLS + } + + return false +} + +func (s *LoginSource) SkipVerify() bool { + switch s.Type { + case LoginLDAP, LoginDLDAP: + return s.LDAP().SkipVerify + case LoginSMTP: + return s.SMTP().SkipVerify + } + + return false +} + +func (s *LoginSource) LDAP() *LDAPConfig { + return s.Config.(*LDAPConfig) +} + +func (s *LoginSource) SMTP() *SMTPConfig { + return s.Config.(*SMTPConfig) +} + +func (s *LoginSource) PAM() *PAMConfig { + return s.Config.(*PAMConfig) +} + +func (s *LoginSource) GitHub() *GitHubConfig { + return s.Config.(*GitHubConfig) +} + +var _ LoginSourcesStore = (*loginSources)(nil) + type loginSources struct { *gorm.DB + files loginSourceFilesStore +} + +type CreateLoginSourceOpts struct { + Type LoginType + Name string + Activated bool + Default bool + Config interface{} +} + +type ErrLoginSourceAlreadyExist struct { + args errutil.Args +} + +func IsErrLoginSourceAlreadyExist(err error) bool { + _, ok := err.(ErrLoginSourceAlreadyExist) + return ok +} + +func (err ErrLoginSourceAlreadyExist) Error() string { + return fmt.Sprintf("login source already exists: %v", err.args) +} + +func (db *loginSources) Create(opts CreateLoginSourceOpts) (*LoginSource, error) { + err := db.Where("name = ?", opts.Name).First(new(LoginSource)).Error + if err == nil { + return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}} + } else if !gorm.IsRecordNotFoundError(err) { + return nil, err + } + + source := &LoginSource{ + Type: opts.Type, + Name: opts.Name, + IsActived: opts.Activated, + IsDefault: opts.Default, + Config: opts.Config, + } + return source, db.DB.Create(source).Error +} + +func (db *loginSources) Count() int64 { + var count int64 + db.Model(new(LoginSource)).Count(&count) + return count + int64(db.files.Len()) +} + +type ErrLoginSourceInUse struct { + args errutil.Args +} + +func IsErrLoginSourceInUse(err error) bool { + _, ok := err.(ErrLoginSourceInUse) + return ok +} + +func (err ErrLoginSourceInUse) Error() string { + return fmt.Sprintf("login source is still used by some users: %v", err.args) +} + +func (db *loginSources) DeleteByID(id int64) error { + var count int64 + err := db.Model(new(User)).Where("login_source = ?", id).Count(&count).Error + if err != nil { + return err + } else if count > 0 { + return ErrLoginSourceInUse{args: errutil.Args{"id": id}} + } + + return db.Where("id = ?", id).Delete(new(LoginSource)).Error } func (db *loginSources) GetByID(id int64) (*LoginSource, error) { @@ -28,9 +248,63 @@ func (db *loginSources) GetByID(id int64) (*LoginSource, error) { err := db.Where("id = ?", id).First(source).Error if err != nil { if gorm.IsRecordNotFoundError(err) { - return localLoginSources.GetLoginSourceByID(id) + return db.files.GetByID(id) } return nil, err } return source, nil } + +type ListLoginSourceOpts struct { + // Whether to only include activated login sources. + OnlyActivated bool +} + +func (db *loginSources) List(opts ListLoginSourceOpts) ([]*LoginSource, error) { + var sources []*LoginSource + query := db.Order("id ASC") + if opts.OnlyActivated { + query = query.Where("is_actived = ?", true) + } + err := query.Find(&sources).Error + if err != nil { + return nil, err + } + + return append(sources, db.files.List(opts)...), nil +} + +func (db *loginSources) ResetNonDefault(dflt *LoginSource) error { + err := db.Model(new(LoginSource)).Where("id != ?", dflt.ID).Updates(map[string]interface{}{"is_default": false}).Error + if err != nil { + return err + } + + for _, source := range db.files.List(ListLoginSourceOpts{}) { + if source.File != nil && source.ID != dflt.ID { + source.File.SetGeneral("is_default", "false") + if err = source.File.Save(); err != nil { + return errors.Wrap(err, "save file") + } + } + } + + db.files.Update(dflt) + return nil +} + +func (db *loginSources) Save(source *LoginSource) error { + if source.File == nil { + return db.DB.Save(source).Error + } + + source.File.SetGeneral("name", source.Name) + source.File.SetGeneral("is_activated", strconv.FormatBool(source.IsActived)) + source.File.SetGeneral("is_default", strconv.FormatBool(source.IsDefault)) + if err := source.File.SetConfig(source.Config); err != nil { + return errors.Wrap(err, "set config") + } else if err = source.File.Save(); err != nil { + return errors.Wrap(err, "save file") + } + return nil +} diff --git a/internal/db/login_sources_test.go b/internal/db/login_sources_test.go new file mode 100644 index 00000000..8cc0ab19 --- /dev/null +++ b/internal/db/login_sources_test.go @@ -0,0 +1,389 @@ +// Copyright 2020 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package db + +import ( + "testing" + "time" + + "github.com/jinzhu/gorm" + "github.com/stretchr/testify/assert" + + "gogs.io/gogs/internal/errutil" +) + +func Test_loginSources(t *testing.T) { + if testing.Short() { + t.Skip() + } + + t.Parallel() + + tables := []interface{}{new(LoginSource), new(User)} + db := &loginSources{ + DB: initTestDB(t, "loginSources", tables...), + } + + for _, tc := range []struct { + name string + test func(*testing.T, *loginSources) + }{ + {"Create", test_loginSources_Create}, + {"Count", test_loginSources_Count}, + {"DeleteByID", test_loginSources_DeleteByID}, + {"GetByID", test_loginSources_GetByID}, + {"List", test_loginSources_List}, + {"ResetNonDefault", test_loginSources_ResetNonDefault}, + {"Save", test_loginSources_Save}, + } { + t.Run(tc.name, func(t *testing.T) { + t.Cleanup(func() { + err := clearTables(db.DB, tables...) + if err != nil { + t.Fatal(err) + } + }) + tc.test(t, db) + }) + } +} + +func test_loginSources_Create(t *testing.T, db *loginSources) { + // Create first login source with name "GitHub" + source, err := db.Create(CreateLoginSourceOpts{ + Type: LoginGitHub, + Name: "GitHub", + Activated: true, + Default: false, + Config: &GitHubConfig{ + APIEndpoint: "https://api.github.com", + }, + }) + if err != nil { + t.Fatal(err) + } + + // Get it back and check the Created field + source, err = db.GetByID(source.ID) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), source.Created.Format(time.RFC3339)) + assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), source.Updated.Format(time.RFC3339)) + + // Try create second login source with same name should fail + _, err = db.Create(CreateLoginSourceOpts{Name: source.Name}) + expErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}} + assert.Equal(t, expErr, err) +} + +func test_loginSources_Count(t *testing.T, db *loginSources) { + // Create two login sources, one in database and one as source file. + _, err := db.Create(CreateLoginSourceOpts{ + Type: LoginGitHub, + Name: "GitHub", + Activated: true, + Default: false, + Config: &GitHubConfig{ + APIEndpoint: "https://api.github.com", + }, + }) + if err != nil { + t.Fatal(err) + } + + setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{ + MockLen: func() int { + return 2 + }, + }) + + assert.Equal(t, int64(3), db.Count()) +} + +func test_loginSources_DeleteByID(t *testing.T, db *loginSources) { + t.Run("delete but in used", func(t *testing.T) { + source, err := db.Create(CreateLoginSourceOpts{ + Type: LoginGitHub, + Name: "GitHub", + Activated: true, + Default: false, + Config: &GitHubConfig{ + APIEndpoint: "https://api.github.com", + }, + }) + if err != nil { + t.Fatal(err) + } + + // Create a user that uses this login source + user := &User{ + LoginSource: source.ID, + } + err = db.DB.Create(user).Error + if err != nil { + t.Fatal(err) + } + + // Delete the login source will result in error + err = db.DeleteByID(source.ID) + expErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}} + assert.Equal(t, expErr, err) + }) + + setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{ + MockGetByID: func(id int64) (*LoginSource, error) { + return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}} + }, + }) + + // Create a login source with name "GitHub2" + source, err := db.Create(CreateLoginSourceOpts{ + Type: LoginGitHub, + Name: "GitHub2", + Activated: true, + Default: false, + Config: &GitHubConfig{ + APIEndpoint: "https://api.github.com", + }, + }) + if err != nil { + t.Fatal(err) + } + + // Delete a non-existent ID is noop + err = db.DeleteByID(9999) + if err != nil { + t.Fatal(err) + } + + // We should be able to get it back + _, err = db.GetByID(source.ID) + if err != nil { + t.Fatal(err) + } + + // Now delete this login source with ID + err = db.DeleteByID(source.ID) + if err != nil { + t.Fatal(err) + } + + // We should get token not found error + _, err = db.GetByID(source.ID) + expErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}} + assert.Equal(t, expErr, err) +} + +func test_loginSources_GetByID(t *testing.T, db *loginSources) { + setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{ + MockGetByID: func(id int64) (*LoginSource, error) { + if id != 101 { + return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}} + } + return &LoginSource{ID: id}, nil + }, + }) + + expConfig := &GitHubConfig{ + APIEndpoint: "https://api.github.com", + } + + // Create a login source with name "GitHub" + source, err := db.Create(CreateLoginSourceOpts{ + Type: LoginGitHub, + Name: "GitHub", + Activated: true, + Default: false, + Config: expConfig, + }) + if err != nil { + t.Fatal(err) + } + + // Get the one in the database and test the read/write hooks + source, err = db.GetByID(source.ID) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, expConfig, source.Config) + + // Get the one in source file store + _, err = db.GetByID(101) + if err != nil { + t.Fatal(err) + } +} + +func test_loginSources_List(t *testing.T, db *loginSources) { + setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{ + MockList: func(opts ListLoginSourceOpts) []*LoginSource { + if opts.OnlyActivated { + return []*LoginSource{ + {ID: 1}, + } + } + return []*LoginSource{ + {ID: 1}, + {ID: 2}, + } + }, + }) + + // Create two login sources in database, one activated and the other one not + _, err := db.Create(CreateLoginSourceOpts{ + Type: LoginPAM, + Name: "PAM", + Config: &PAMConfig{ + ServiceName: "PAM", + }, + }) + if err != nil { + t.Fatal(err) + } + _, err = db.Create(CreateLoginSourceOpts{ + Type: LoginGitHub, + Name: "GitHub", + Activated: true, + Config: &GitHubConfig{ + APIEndpoint: "https://api.github.com", + }, + }) + if err != nil { + t.Fatal(err) + } + + // List all login sources + sources, err := db.List(ListLoginSourceOpts{}) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, 4, len(sources), "number of sources") + + // Only list activated login sources + sources, err = db.List(ListLoginSourceOpts{OnlyActivated: true}) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, 2, len(sources), "number of sources") +} + +func test_loginSources_ResetNonDefault(t *testing.T, db *loginSources) { + setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{ + MockList: func(opts ListLoginSourceOpts) []*LoginSource { + return []*LoginSource{ + { + File: &mockLoginSourceFileStore{ + MockSetGeneral: func(name, value string) { + assert.Equal(t, "is_default", name) + assert.Equal(t, "false", value) + }, + MockSave: func() error { + return nil + }, + }, + }, + } + }, + MockUpdate: func(source *LoginSource) {}, + }) + + // Create two login sources both have default on + source1, err := db.Create(CreateLoginSourceOpts{ + Type: LoginPAM, + Name: "PAM", + Default: true, + Config: &PAMConfig{ + ServiceName: "PAM", + }, + }) + if err != nil { + t.Fatal(err) + } + source2, err := db.Create(CreateLoginSourceOpts{ + Type: LoginGitHub, + Name: "GitHub", + Activated: true, + Default: true, + Config: &GitHubConfig{ + APIEndpoint: "https://api.github.com", + }, + }) + if err != nil { + t.Fatal(err) + } + + // Set source 1 as default + err = db.ResetNonDefault(source1) + if err != nil { + t.Fatal(err) + } + + // Verify the default state + source1, err = db.GetByID(source1.ID) + if err != nil { + t.Fatal(err) + } + assert.True(t, source1.IsDefault) + + source2, err = db.GetByID(source2.ID) + if err != nil { + t.Fatal(err) + } + assert.False(t, source2.IsDefault) +} + +func test_loginSources_Save(t *testing.T, db *loginSources) { + t.Run("save to database", func(t *testing.T) { + // Create a login source with name "GitHub" + source, err := db.Create(CreateLoginSourceOpts{ + Type: LoginGitHub, + Name: "GitHub", + Activated: true, + Default: false, + Config: &GitHubConfig{ + APIEndpoint: "https://api.github.com", + }, + }) + if err != nil { + t.Fatal(err) + } + + source.IsActived = false + source.Config = &GitHubConfig{ + APIEndpoint: "https://api2.github.com", + } + err = db.Save(source) + if err != nil { + t.Fatal(err) + } + + source, err = db.GetByID(source.ID) + if err != nil { + t.Fatal(err) + } + assert.False(t, source.IsActived) + assert.Equal(t, "https://api2.github.com", source.GitHub().APIEndpoint) + }) + + t.Run("save to file", func(t *testing.T) { + calledSave := false + source := &LoginSource{ + File: &mockLoginSourceFileStore{ + MockSetGeneral: func(name, value string) {}, + MockSetConfig: func(cfg interface{}) error { return nil }, + MockSave: func() error { + calledSave = true + return nil + }, + }, + } + err := db.Save(source) + if err != nil { + t.Fatal(err) + } + assert.True(t, calledSave) + }) +} diff --git a/internal/db/main_test.go b/internal/db/main_test.go index fdd7447f..172f00fe 100644 --- a/internal/db/main_test.go +++ b/internal/db/main_test.go @@ -41,7 +41,8 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -func deleteTables(db *gorm.DB, tables ...interface{}) error { +// clearTables removes all rows from given tables. +func clearTables(db *gorm.DB, tables ...interface{}) error { for _, t := range tables { err := db.Delete(t).Error if err != nil { diff --git a/internal/db/mocks.go b/internal/db/mocks.go index 7e6ddced..9b1e9028 100644 --- a/internal/db/mocks.go +++ b/internal/db/mocks.go @@ -78,6 +78,59 @@ func SetMockLFSStore(t *testing.T, mock LFSStore) { }) } +var _ loginSourceFilesStore = (*mockLoginSourceFilesStore)(nil) + +type mockLoginSourceFilesStore struct { + MockGetByID func(id int64) (*LoginSource, error) + MockLen func() int + MockList func(opts ListLoginSourceOpts) []*LoginSource + MockUpdate func(source *LoginSource) +} + +func (m *mockLoginSourceFilesStore) GetByID(id int64) (*LoginSource, error) { + return m.MockGetByID(id) +} + +func (m *mockLoginSourceFilesStore) Len() int { + return m.MockLen() +} + +func (m *mockLoginSourceFilesStore) List(opts ListLoginSourceOpts) []*LoginSource { + return m.MockList(opts) +} + +func (m *mockLoginSourceFilesStore) Update(source *LoginSource) { + m.MockUpdate(source) +} + +func setMockLoginSourceFilesStore(t *testing.T, db *loginSources, mock loginSourceFilesStore) { + before := db.files + db.files = mock + t.Cleanup(func() { + db.files = before + }) +} + +var _ loginSourceFileStore = (*mockLoginSourceFileStore)(nil) + +type mockLoginSourceFileStore struct { + MockSetGeneral func(name, value string) + MockSetConfig func(cfg interface{}) error + MockSave func() error +} + +func (m *mockLoginSourceFileStore) SetGeneral(name, value string) { + m.MockSetGeneral(name, value) +} + +func (m *mockLoginSourceFileStore) SetConfig(cfg interface{}) error { + return m.MockSetConfig(cfg) +} + +func (m *mockLoginSourceFileStore) Save() error { + return m.MockSave() +} + var _ PermsStore = (*MockPermsStore)(nil) type MockPermsStore struct { diff --git a/internal/db/models.go b/internal/db/models.go index 9b5b0d9a..37783a13 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -53,7 +53,7 @@ func init() { new(Watch), new(Star), new(Follow), new(Action), new(Issue), new(PullRequest), new(Comment), new(Attachment), new(IssueUser), new(Label), new(IssueLabel), new(Milestone), - new(Mirror), new(Release), new(LoginSource), new(Webhook), new(HookTask), + new(Mirror), new(Release), new(Webhook), new(HookTask), new(ProtectBranch), new(ProtectBranchWhitelist), new(Team), new(OrgUser), new(TeamUser), new(TeamRepo), new(Notice), new(EmailAddress)) @@ -200,7 +200,7 @@ func GetStatistic() (stats Statistic) { stats.Counter.Follow, _ = x.Count(new(Follow)) stats.Counter.Mirror, _ = x.Count(new(Mirror)) stats.Counter.Release, _ = x.Count(new(Release)) - stats.Counter.LoginSource = CountLoginSources() + stats.Counter.LoginSource = LoginSources.Count() stats.Counter.Webhook, _ = x.Count(new(Webhook)) stats.Counter.Milestone, _ = x.Count(new(Milestone)) stats.Counter.Label, _ = x.Count(new(Label)) @@ -295,13 +295,13 @@ func ImportDatabase(dirPath string, verbose bool) (err error) { tp := LoginType(com.StrTo(com.ToStr(meta["Type"])).MustInt64()) switch tp { case LoginLDAP, LoginDLDAP: - bean.Cfg = new(LDAPConfig) + bean.Config = new(LDAPConfig) case LoginSMTP: - bean.Cfg = new(SMTPConfig) + bean.Config = new(SMTPConfig) case LoginPAM: - bean.Cfg = new(PAMConfig) + bean.Config = new(PAMConfig) case LoginGitHub: - bean.Cfg = new(GitHubConfig) + bean.Config = new(GitHubConfig) default: return fmt.Errorf("unrecognized login source type:: %v", tp) } diff --git a/internal/db/perms_test.go b/internal/db/perms_test.go index 9f3f4b1b..ea4ed1ea 100644 --- a/internal/db/perms_test.go +++ b/internal/db/perms_test.go @@ -17,8 +17,9 @@ func Test_perms(t *testing.T) { t.Parallel() + tables := []interface{}{new(Access)} db := &perms{ - DB: initTestDB(t, "perms", new(Access)), + DB: initTestDB(t, "perms", tables...), } for _, tc := range []struct { @@ -31,7 +32,7 @@ func Test_perms(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { t.Cleanup(func() { - err := deleteTables(db.DB, new(Access)) + err := clearTables(db.DB, tables...) if err != nil { t.Fatal(err) } diff --git a/internal/db/user.go b/internal/db/user.go index a61bb687..11697bc7 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -48,33 +48,33 @@ const ( // User represents the object of individual and member of organization. type User struct { ID int64 - LowerName string `xorm:"UNIQUE NOT NULL"` - Name string `xorm:"UNIQUE NOT NULL"` + LowerName string `xorm:"UNIQUE NOT NULL" gorm:"UNIQUE"` + Name string `xorm:"UNIQUE NOT NULL" gorm:"NOT NULL"` FullName string // Email is the primary email address (to be used for communication) - Email string `xorm:"NOT NULL"` - Passwd string `xorm:"NOT NULL"` + Email string `xorm:"NOT NULL" gorm:"NOT NULL"` + Passwd string `xorm:"NOT NULL" gorm:"NOT NULL"` LoginType LoginType - LoginSource int64 `xorm:"NOT NULL DEFAULT 0"` + LoginSource int64 `xorm:"NOT NULL DEFAULT 0" gorm:"NOT NULL;DEFAULT:0"` LoginName string Type UserType - OwnedOrgs []*User `xorm:"-" json:"-"` - Orgs []*User `xorm:"-" json:"-"` - Repos []*Repository `xorm:"-" json:"-"` + OwnedOrgs []*User `xorm:"-" gorm:"-" json:"-"` + Orgs []*User `xorm:"-" gorm:"-" json:"-"` + Repos []*Repository `xorm:"-" gorm:"-" json:"-"` Location string Website string - Rands string `xorm:"VARCHAR(10)"` - Salt string `xorm:"VARCHAR(10)"` + Rands string `xorm:"VARCHAR(10)" gorm:"TYPE:VARCHAR(10)"` + Salt string `xorm:"VARCHAR(10)" gorm:"TYPE:VARCHAR(10)"` - Created time.Time `xorm:"-" json:"-"` + Created time.Time `xorm:"-" gorm:"-" json:"-"` CreatedUnix int64 - Updated time.Time `xorm:"-" json:"-"` + Updated time.Time `xorm:"-" gorm:"-" json:"-"` UpdatedUnix int64 // Remember visibility choice for convenience, true for private LastRepoVisibility bool - // Maximum repository creation limit, -1 means use gloabl default - MaxRepoCreation int `xorm:"NOT NULL DEFAULT -1"` + // Maximum repository creation limit, -1 means use global default + MaxRepoCreation int `xorm:"NOT NULL DEFAULT -1" gorm:"NOT NULL;DEFAULT:-1"` // Permissions IsActive bool // Activate primary email @@ -84,13 +84,13 @@ type User struct { ProhibitLogin bool // Avatar - Avatar string `xorm:"VARCHAR(2048) NOT NULL"` - AvatarEmail string `xorm:"NOT NULL"` + Avatar string `xorm:"VARCHAR(2048) NOT NULL" gorm:"TYPE:VARCHAR(2048);NOT NULL"` + AvatarEmail string `xorm:"NOT NULL" gorm:"NOT NULL"` UseCustomAvatar bool // Counters NumFollowers int - NumFollowing int `xorm:"NOT NULL DEFAULT 0"` + NumFollowing int `xorm:"NOT NULL DEFAULT 0" gorm:"NOT NULL;DEFAULT:0"` NumStars int NumRepos int @@ -98,8 +98,8 @@ type User struct { Description string NumTeams int NumMembers int - Teams []*Team `xorm:"-" json:"-"` - Members []*User `xorm:"-" json:"-"` + Teams []*Team `xorm:"-" gorm:"-" json:"-"` + Members []*User `xorm:"-" gorm:"-" json:"-"` } func (u *User) BeforeInsert() { |