diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/error.go | 13 | ||||
-rw-r--r-- | models/errors/login_source.go | 27 | ||||
-rw-r--r-- | models/login_source.go | 339 |
3 files changed, 296 insertions, 83 deletions
diff --git a/models/error.go b/models/error.go index 410a5a7e..63e06f6e 100644 --- a/models/error.go +++ b/models/error.go @@ -505,19 +505,6 @@ func (err ErrAttachmentNotExist) Error() string { // |_______ \____/\___ /|__|___| / /_______ /\____/|____/ |__| \___ >___ > // \/ /_____/ \/ \/ \/ \/ -type ErrLoginSourceNotExist struct { - ID int64 -} - -func IsErrLoginSourceNotExist(err error) bool { - _, ok := err.(ErrLoginSourceNotExist) - return ok -} - -func (err ErrLoginSourceNotExist) Error() string { - return fmt.Sprintf("login source does not exist [id: %d]", err.ID) -} - type ErrLoginSourceAlreadyExist struct { Name string } diff --git a/models/errors/login_source.go b/models/errors/login_source.go index db0cd1f9..dd18664e 100644 --- a/models/errors/login_source.go +++ b/models/errors/login_source.go @@ -6,6 +6,19 @@ package errors import "fmt" +type LoginSourceNotExist struct { + ID int64 +} + +func IsLoginSourceNotExist(err error) bool { + _, ok := err.(LoginSourceNotExist) + return ok +} + +func (err LoginSourceNotExist) Error() string { + return fmt.Sprintf("login source does not exist [id: %d]", err.ID) +} + type LoginSourceNotActivated struct { SourceID int64 } @@ -31,3 +44,17 @@ func IsInvalidLoginSourceType(err error) bool { func (err InvalidLoginSourceType) Error() string { return fmt.Sprintf("invalid login source type [type: %v]", err.Type) } + +type LoginSourceMismatch struct { + Expect int64 + Actual int64 +} + +func IsLoginSourceMismatch(err error) bool { + _, ok := err.(LoginSourceMismatch) + return ok +} + +func (err LoginSourceMismatch) Error() string { + return fmt.Sprintf("login source mismatch [expect: %d, actual: %d]", err.Expect, err.Actual) +} diff --git a/models/login_source.go b/models/login_source.go index 12db7864..3a1162ec 100644 --- a/models/login_source.go +++ b/models/login_source.go @@ -10,7 +10,10 @@ import ( "fmt" "net/smtp" "net/textproto" + "os" + "path" "strings" + "sync" "time" "github.com/Unknwon/com" @@ -18,10 +21,12 @@ import ( "github.com/go-xorm/core" "github.com/go-xorm/xorm" log "gopkg.in/clog.v1" + "gopkg.in/ini.v1" "github.com/gogits/gogs/models/errors" "github.com/gogits/gogs/pkg/auth/ldap" "github.com/gogits/gogs/pkg/auth/pam" + "github.com/gogits/gogs/pkg/setting" ) type LoginType int @@ -57,7 +62,7 @@ var ( ) type LDAPConfig struct { - *ldap.Source + *ldap.Source `ini:"config"` } func (cfg *LDAPConfig) FromDB(bs []byte) error { @@ -77,7 +82,7 @@ type SMTPConfig struct { Host string Port int AllowedDomains string `xorm:"TEXT"` - TLS bool + TLS bool `ini:"tls"` SkipVerify bool } @@ -90,7 +95,7 @@ func (cfg *SMTPConfig) ToDB() ([]byte, error) { } type PAMConfig struct { - ServiceName string // pam service (e.g. system-auth) + ServiceName string // PAM service (e.g. system-auth) } func (cfg *PAMConfig) FromDB(bs []byte) error { @@ -101,6 +106,27 @@ func (cfg *PAMConfig) ToDB() ([]byte, error) { return json.Marshal(cfg) } +// AuthSourceFile contains information of an authentication source file. +type AuthSourceFile struct { + abspath string + file *ini.File +} + +// SetGeneral sets new value to the given key in the general (default) section. +func (f *AuthSourceFile) SetGeneral(name, value string) { + f.file.Section("").Key(name).SetValue(value) +} + +// SetConfig sets new values to the "config" section. +func (f *AuthSourceFile) SetConfig(cfg core.Conversion) error { + return f.file.Section("config").ReflectFrom(cfg) +} + +// Save writes updates into file system. +func (f *AuthSourceFile) Save() error { + return f.file.SaveTo(f.abspath) +} + // LoginSource represents an external way for authorizing users. type LoginSource struct { ID int64 @@ -113,6 +139,8 @@ type LoginSource struct { CreatedUnix int64 Updated time.Time `xorm:"-"` UpdatedUnix int64 + + LocalFile *AuthSourceFile `xorm:"-"` } func (s *LoginSource) BeforeInsert() { @@ -135,16 +163,16 @@ func Cell2Int64(val xorm.Cell) int64 { return (*val).(int64) } -func (source *LoginSource) BeforeSet(colName string, val xorm.Cell) { +func (s *LoginSource) BeforeSet(colName string, val xorm.Cell) { switch colName { case "type": switch LoginType(Cell2Int64(val)) { case LOGIN_LDAP, LOGIN_DLDAP: - source.Cfg = new(LDAPConfig) + s.Cfg = new(LDAPConfig) case LOGIN_SMTP: - source.Cfg = new(SMTPConfig) + s.Cfg = new(SMTPConfig) case LOGIN_PAM: - source.Cfg = new(PAMConfig) + s.Cfg = new(PAMConfig) default: panic("unrecognized login source type: " + com.ToStr(*val)) } @@ -160,65 +188,66 @@ func (s *LoginSource) AfterSet(colName string, _ xorm.Cell) { } } -func (source *LoginSource) TypeName() string { - return LoginNames[source.Type] +func (s *LoginSource) TypeName() string { + return LoginNames[s.Type] } -func (source *LoginSource) IsLDAP() bool { - return source.Type == LOGIN_LDAP +func (s *LoginSource) IsLDAP() bool { + return s.Type == LOGIN_LDAP } -func (source *LoginSource) IsDLDAP() bool { - return source.Type == LOGIN_DLDAP +func (s *LoginSource) IsDLDAP() bool { + return s.Type == LOGIN_DLDAP } -func (source *LoginSource) IsSMTP() bool { - return source.Type == LOGIN_SMTP +func (s *LoginSource) IsSMTP() bool { + return s.Type == LOGIN_SMTP } -func (source *LoginSource) IsPAM() bool { - return source.Type == LOGIN_PAM +func (s *LoginSource) IsPAM() bool { + return s.Type == LOGIN_PAM } -func (source *LoginSource) HasTLS() bool { - return ((source.IsLDAP() || source.IsDLDAP()) && - source.LDAP().SecurityProtocol > ldap.SECURITY_PROTOCOL_UNENCRYPTED) || - source.IsSMTP() +func (s *LoginSource) HasTLS() bool { + return ((s.IsLDAP() || s.IsDLDAP()) && + s.LDAP().SecurityProtocol > ldap.SECURITY_PROTOCOL_UNENCRYPTED) || + s.IsSMTP() } -func (source *LoginSource) UseTLS() bool { - switch source.Type { +func (s *LoginSource) UseTLS() bool { + switch s.Type { case LOGIN_LDAP, LOGIN_DLDAP: - return source.LDAP().SecurityProtocol != ldap.SECURITY_PROTOCOL_UNENCRYPTED + return s.LDAP().SecurityProtocol != ldap.SECURITY_PROTOCOL_UNENCRYPTED case LOGIN_SMTP: - return source.SMTP().TLS + return s.SMTP().TLS } return false } -func (source *LoginSource) SkipVerify() bool { - switch source.Type { +func (s *LoginSource) SkipVerify() bool { + switch s.Type { case LOGIN_LDAP, LOGIN_DLDAP: - return source.LDAP().SkipVerify + return s.LDAP().SkipVerify case LOGIN_SMTP: - return source.SMTP().SkipVerify + return s.SMTP().SkipVerify } return false } -func (source *LoginSource) LDAP() *LDAPConfig { - return source.Cfg.(*LDAPConfig) +func (s *LoginSource) LDAP() *LDAPConfig { + return s.Cfg.(*LDAPConfig) } -func (source *LoginSource) SMTP() *SMTPConfig { - return source.Cfg.(*SMTPConfig) +func (s *LoginSource) SMTP() *SMTPConfig { + return s.Cfg.(*SMTPConfig) } -func (source *LoginSource) PAM() *PAMConfig { - return source.Cfg.(*PAMConfig) +func (s *LoginSource) PAM() *PAMConfig { + return s.Cfg.(*PAMConfig) } + func CreateLoginSource(source *LoginSource) error { has, err := x.Get(&LoginSource{Name: source.Name}) if err != nil { @@ -231,9 +260,23 @@ func CreateLoginSource(source *LoginSource) error { return err } +// LoginSources returns all login sources defined. func LoginSources() ([]*LoginSource, error) { - auths := make([]*LoginSource, 0, 5) - return auths, x.Find(&auths) + sources := make([]*LoginSource, 0, 2) + if err := x.Find(&sources); err != nil { + return nil, err + } + + return append(sources, localLoginSources.List()...), nil +} + +// ActivatedLoginSources returns login sources that are currently activated. +func ActivatedLoginSources() ([]*LoginSource, error) { + sources := make([]*LoginSource, 0, 2) + if err := x.Where("is_actived = ?", true).Find(&sources); err != nil { + return nil, fmt.Errorf("find activated login sources: %v", err) + } + return append(sources, localLoginSources.ActivatedList()...), nil } // GetLoginSourceByID returns login source by given ID. @@ -243,14 +286,28 @@ func GetLoginSourceByID(id int64) (*LoginSource, error) { if err != nil { return nil, err } else if !has { - return nil, ErrLoginSourceNotExist{id} + return localLoginSources.GetLoginSourceByID(id) } return source, nil } -func UpdateSource(source *LoginSource) error { - _, err := x.Id(source.ID).AllCols().Update(source) - return err +// UpdateLoginSource updates information of login source to database or local file. +func UpdateLoginSource(source *LoginSource) error { + if source.LocalFile == nil { + _, err := x.Id(source.ID).AllCols().Update(source) + return err + } + + source.LocalFile.SetGeneral("name", source.Name) + source.LocalFile.SetGeneral("is_activated", com.ToStr(source.IsActived)) + if err := source.LocalFile.SetConfig(source.Cfg); err != nil { + return fmt.Errorf("LocalFile.SetConfig: %v", err) + } else if err = source.LocalFile.Save(); err != nil { + return fmt.Errorf("LocalFile.Save: %v", err) + } + + localLoginSources.UpdateLoginSource(source) + return nil } func DeleteSource(source *LoginSource) error { @@ -264,10 +321,151 @@ func DeleteSource(source *LoginSource) error { return err } -// CountLoginSources returns number of login sources. +// CountLoginSources returns total number of login sources. func CountLoginSources() int64 { count, _ := x.Count(new(LoginSource)) - return count + return count + int64(localLoginSources.Len()) +} + +// LocalLoginSources contains authentication sources configured and loaded from local files. +// Calling its methods is thread-safe; otherwise, please maintain the mutex accordingly. +type LocalLoginSources struct { + sync.RWMutex + sources []*LoginSource +} + +func (s *LocalLoginSources) Len() int { + return len(s.sources) +} + +// List returns full clone of login sources. +func (s *LocalLoginSources) List() []*LoginSource { + s.RLock() + defer s.RUnlock() + + list := make([]*LoginSource, s.Len()) + for i := range s.sources { + list[i] = &LoginSource{} + *list[i] = *s.sources[i] + } + return list +} + +// ActivatedList returns clone of activated login sources. +func (s *LocalLoginSources) ActivatedList() []*LoginSource { + s.RLock() + defer s.RUnlock() + + list := make([]*LoginSource, 0, 2) + for i := range s.sources { + if !s.sources[i].IsActived { + continue + } + + source := &LoginSource{} + *source = *s.sources[i] + list = append(list, source) + } + return list +} + +// GetLoginSourceByID returns a clone of login source by given ID. +func (s *LocalLoginSources) GetLoginSourceByID(id int64) (*LoginSource, error) { + s.RLock() + defer s.RUnlock() + + for i := range s.sources { + if s.sources[i].ID == id { + source := &LoginSource{} + *source = *s.sources[i] + return source, nil + } + } + + return nil, errors.LoginSourceNotExist{id} +} + +// UpdateLoginSource updates in-memory copy of the authentication source. +func (s *LocalLoginSources) UpdateLoginSource(source *LoginSource) { + s.Lock() + defer s.Unlock() + + source.Updated = time.Now() + for i := range s.sources { + if s.sources[i].ID == source.ID { + *s.sources[i] = *source + break + } + } +} + +var localLoginSources = &LocalLoginSources{} + +// LoadAuthSources loads authentication sources from local files +// and converts them into login sources. +func LoadAuthSources() { + authdPath := path.Join(setting.CustomPath, "conf/auth.d") + if !com.IsDir(authdPath) { + return + } + + paths, err := com.GetFileListBySuffix(authdPath, ".conf") + if err != nil { + log.Fatal(2, "Failed to list authentication sources: %v", err) + } + + localLoginSources.sources = make([]*LoginSource, 0, len(paths)) + + for _, fpath := range paths { + authSource, err := ini.Load(fpath) + if err != nil { + log.Fatal(2, "Failed to load authentication source: %v", err) + } + authSource.NameMapper = ini.TitleUnderscore + + // Set general attributes + s := authSource.Section("") + loginSource := &LoginSource{ + ID: s.Key("id").MustInt64(), + Name: s.Key("name").String(), + IsActived: s.Key("is_activated").MustBool(), + LocalFile: &AuthSourceFile{ + abspath: fpath, + file: authSource, + }, + } + + fi, err := os.Stat(fpath) + if err != nil { + log.Fatal(2, "Failed to load authentication source: %v", err) + } + loginSource.Updated = fi.ModTime() + + // Parse authentication source file + authType := s.Key("type").String() + switch authType { + case "ldap_bind_dn": + loginSource.Type = LOGIN_LDAP + loginSource.Cfg = &LDAPConfig{} + case "ldap_simple_auth": + loginSource.Type = LOGIN_DLDAP + loginSource.Cfg = &LDAPConfig{} + case "smtp": + loginSource.Type = LOGIN_SMTP + loginSource.Cfg = &SMTPConfig{} + case "pam": + loginSource.Type = LOGIN_PAM + loginSource.Cfg = &PAMConfig{} + default: + log.Fatal(2, "Failed to load authentication source: unknown type '%s'", authType) + } + + if err = authSource.Section("config").MapTo(loginSource.Cfg); err != nil { + log.Fatal(2, "Failed to parse authentication source 'config': %v", err) + } + + localLoginSources.sources = append(localLoginSources.sources, loginSource) + } } // .____ ________ _____ __________ @@ -497,7 +695,7 @@ func LoginViaPAM(user *User, login, password string, sourceID int64, cfg *PAMCon return user, CreateUser(user) } -func ExternalUserLogin(user *User, login, password string, source *LoginSource, autoRegister bool) (*User, error) { +func remoteUserLogin(user *User, login, password string, source *LoginSource, autoRegister bool) (*User, error) { if !source.IsActived { return nil, errors.LoginSourceNotActivated{source.ID} } @@ -514,8 +712,9 @@ func ExternalUserLogin(user *User, login, password string, source *LoginSource, return nil, errors.InvalidLoginSourceType{source.Type} } -// UserSignIn validates user name and password. -func UserSignIn(username, password string) (*User, error) { +// UserLogin validates user name and password via given login source ID. +// If the loginSourceID is negative, it will abort login process if user is not found. +func UserLogin(username, password string, loginSourceID int64) (*User, error) { var user *User if strings.Contains(username, "@") { user = &User{Email: strings.ToLower(username)} @@ -525,44 +724,44 @@ func UserSignIn(username, password string) (*User, error) { hasUser, err := x.Get(user) if err != nil { - return nil, err + return nil, fmt.Errorf("get user record: %v", err) } if hasUser { - switch user.LoginType { - case LOGIN_NOTYPE, LOGIN_PLAIN: + // Note: This check is unnecessary but to reduce user confusion at login page + // and make it more consistent at user's perspective. + if loginSourceID >= 0 && user.LoginSource != loginSourceID { + return nil, errors.LoginSourceMismatch{loginSourceID, user.LoginSource} + } + + // Validate password hash fetched from database for local accounts + if user.LoginType == LOGIN_NOTYPE || + user.LoginType == LOGIN_PLAIN { if user.ValidatePassword(password) { return user, nil } return nil, errors.UserNotExist{user.ID, user.Name} + } - default: - var source LoginSource - hasSource, err := x.Id(user.LoginSource).Get(&source) - if err != nil { - return nil, err - } else if !hasSource { - return nil, ErrLoginSourceNotExist{user.LoginSource} - } - - return ExternalUserLogin(user, user.LoginName, password, &source, false) + // Remote login to the login source the user is associated with + source, err := GetLoginSourceByID(user.LoginSource) + if err != nil { + return nil, err } - } - sources := make([]*LoginSource, 0, 3) - if err = x.UseBool().Find(&sources, &LoginSource{IsActived: true}); err != nil { - return nil, err + return remoteUserLogin(user, user.LoginName, password, source, false) } - for _, source := range sources { - authUser, err := ExternalUserLogin(nil, username, password, source, true) - if err == nil { - return authUser, nil - } + // Non-local login source is always greater than 0 + if loginSourceID <= 0 { + return nil, errors.UserNotExist{-1, username} + } - log.Warn("Failed to login '%s' via '%s': %v", username, source.Name, err) + source, err := GetLoginSourceByID(loginSourceID) + if err != nil { + return nil, err } - return nil, errors.UserNotExist{user.ID, user.Name} + return remoteUserLogin(nil, username, password, source, true) } |