aboutsummaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/access_tokens_test.go13
-rw-r--r--internal/db/db.go9
-rw-r--r--internal/db/error.go33
-rw-r--r--internal/db/errors/login_source.go14
-rw-r--r--internal/db/lfs_test.go5
-rw-r--r--internal/db/login_source.go564
-rw-r--r--internal/db/login_source_files.go212
-rw-r--r--internal/db/login_sources.go276
-rw-r--r--internal/db/login_sources_test.go389
-rw-r--r--internal/db/main_test.go3
-rw-r--r--internal/db/mocks.go53
-rw-r--r--internal/db/models.go12
-rw-r--r--internal/db/perms_test.go5
-rw-r--r--internal/db/user.go38
14 files changed, 1022 insertions, 604 deletions
diff --git a/internal/db/access_tokens_test.go b/internal/db/access_tokens_test.go
index d3dfb83f..0f979e95 100644
--- a/internal/db/access_tokens_test.go
+++ b/internal/db/access_tokens_test.go
@@ -21,8 +21,9 @@ func Test_accessTokens(t *testing.T) {
t.Parallel()
+ tables := []interface{}{new(AccessToken)}
db := &accessTokens{
- DB: initTestDB(t, "accessTokens", new(AccessToken)),
+ DB: initTestDB(t, "accessTokens", tables...),
}
for _, tc := range []struct {
@@ -37,7 +38,7 @@ func Test_accessTokens(t *testing.T) {
} {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() {
- err := deleteTables(db.DB, new(AccessToken))
+ err := clearTables(db.DB, tables...)
if err != nil {
t.Fatal(err)
}
@@ -78,14 +79,14 @@ func test_accessTokens_DeleteByID(t *testing.T, db *accessTokens) {
t.Fatal(err)
}
- // We should be able to get it back
- _, err = db.GetBySHA(token.Sha1)
+ // Delete a token with mismatched user ID is noop
+ err = db.DeleteByID(2, token.ID)
if err != nil {
t.Fatal(err)
}
- // Delete a token with mismatched user ID is noop
- err = db.DeleteByID(2, token.ID)
+ // We should be able to get it back
+ _, err = db.GetBySHA(token.Sha1)
if err != nil {
t.Fatal(err)
}
diff --git a/internal/db/db.go b/internal/db/db.go
index 1be2cc4b..77d78f53 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -124,7 +124,7 @@ func getLogWriter() (io.Writer, error) {
var tables = []interface{}{
new(AccessToken),
- new(LFSObject),
+ new(LFSObject), new(LoginSource),
}
func Init() error {
@@ -167,9 +167,14 @@ func Init() error {
return time.Now().UTC().Truncate(time.Microsecond)
}
+ sourceFiles, err := loadLoginSourceFiles(filepath.Join(conf.CustomDir(), "conf", "auth.d"))
+ if err != nil {
+ return errors.Wrap(err, "load login source files")
+ }
+
// Initialize stores, sorted in alphabetical order.
AccessTokens = &accessTokens{DB: db}
- LoginSources = &loginSources{DB: db}
+ LoginSources = &loginSources{DB: db, files: sourceFiles}
LFS = &lfs{DB: db}
Perms = &perms{DB: db}
Repos = &repos{DB: db}
diff --git a/internal/db/error.go b/internal/db/error.go
index ed173d86..46e7dde5 100644
--- a/internal/db/error.go
+++ b/internal/db/error.go
@@ -327,39 +327,6 @@ func (err ErrRepoFileAlreadyExist) Error() string {
return fmt.Sprintf("repository file already exists [file_name: %s]", err.FileName)
}
-// .____ .__ _________
-// | | ____ ____ |__| ____ / _____/ ____ __ _________ ____ ____
-// | | / _ \ / ___\| |/ \ \_____ \ / _ \| | \_ __ \_/ ___\/ __ \
-// | |__( <_> ) /_/ > | | \ / ( <_> ) | /| | \/\ \__\ ___/
-// |_______ \____/\___ /|__|___| / /_______ /\____/|____/ |__| \___ >___ >
-// \/ /_____/ \/ \/ \/ \/
-
-type ErrLoginSourceAlreadyExist struct {
- Name string
-}
-
-func IsErrLoginSourceAlreadyExist(err error) bool {
- _, ok := err.(ErrLoginSourceAlreadyExist)
- return ok
-}
-
-func (err ErrLoginSourceAlreadyExist) Error() string {
- return fmt.Sprintf("login source already exists [name: %s]", err.Name)
-}
-
-type ErrLoginSourceInUse struct {
- ID int64
-}
-
-func IsErrLoginSourceInUse(err error) bool {
- _, ok := err.(ErrLoginSourceInUse)
- return ok
-}
-
-func (err ErrLoginSourceInUse) Error() string {
- return fmt.Sprintf("login source is still used by some users [id: %d]", err.ID)
-}
-
// ___________
// \__ ___/___ _____ _____
// | |_/ __ \\__ \ / \
diff --git a/internal/db/errors/login_source.go b/internal/db/errors/login_source.go
index 876a0820..db0cd1f9 100644
--- a/internal/db/errors/login_source.go
+++ b/internal/db/errors/login_source.go
@@ -6,19 +6,6 @@ package errors
import "fmt"
-type LoginSourceNotExist struct {
- ID int64
-}
-
-func IsLoginSourceNotExist(err error) bool {
- _, ok := err.(LoginSourceNotExist)
- return ok
-}
-
-func (err LoginSourceNotExist) Error() string {
- return fmt.Sprintf("login source does not exist [id: %d]", err.ID)
-}
-
type LoginSourceNotActivated struct {
SourceID int64
}
@@ -44,4 +31,3 @@ func IsInvalidLoginSourceType(err error) bool {
func (err InvalidLoginSourceType) Error() string {
return fmt.Sprintf("invalid login source type [type: %v]", err.Type)
}
-
diff --git a/internal/db/lfs_test.go b/internal/db/lfs_test.go
index 6eb14019..29c7d665 100644
--- a/internal/db/lfs_test.go
+++ b/internal/db/lfs_test.go
@@ -22,8 +22,9 @@ func Test_lfs(t *testing.T) {
t.Parallel()
+ tables := []interface{}{new(LFSObject)}
db := &lfs{
- DB: initTestDB(t, "lfs", new(LFSObject)),
+ DB: initTestDB(t, "lfs", tables...),
}
for _, tc := range []struct {
@@ -36,7 +37,7 @@ func Test_lfs(t *testing.T) {
} {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() {
- err := deleteTables(db.DB, new(LFSObject))
+ err := clearTables(db.DB, tables...)
if err != nil {
t.Fatal(err)
}
diff --git a/internal/db/login_source.go b/internal/db/login_source.go
index b6665f4e..e821e186 100644
--- a/internal/db/login_source.go
+++ b/internal/db/login_source.go
@@ -10,30 +10,21 @@ import (
"fmt"
"net/smtp"
"net/textproto"
- "os"
- "path/filepath"
"strings"
- "sync"
- "time"
"github.com/go-macaron/binding"
- "github.com/json-iterator/go"
"github.com/unknwon/com"
- "gopkg.in/ini.v1"
- log "unknwon.dev/clog/v2"
- "xorm.io/core"
- "xorm.io/xorm"
"gogs.io/gogs/internal/auth/github"
"gogs.io/gogs/internal/auth/ldap"
"gogs.io/gogs/internal/auth/pam"
- "gogs.io/gogs/internal/conf"
"gogs.io/gogs/internal/db/errors"
)
type LoginType int
// Note: new type must append to the end of list to maintain compatibility.
+// TODO: Move to authutil.
const (
LoginNotype LoginType = iota
LoginPlain // 1
@@ -52,497 +43,24 @@ var LoginNames = map[LoginType]string{
LoginGitHub: "GitHub",
}
-var SecurityProtocolNames = map[ldap.SecurityProtocol]string{
- ldap.SECURITY_PROTOCOL_UNENCRYPTED: "Unencrypted",
- ldap.SECURITY_PROTOCOL_LDAPS: "LDAPS",
- ldap.SECURITY_PROTOCOL_START_TLS: "StartTLS",
-}
-
-// Ensure structs implemented interface.
-var (
- _ core.Conversion = &LDAPConfig{}
- _ core.Conversion = &SMTPConfig{}
- _ core.Conversion = &PAMConfig{}
- _ core.Conversion = &GitHubConfig{}
-)
+// ***********************
+// ----- LDAP config -----
+// ***********************
type LDAPConfig struct {
- *ldap.Source `ini:"config"`
+ ldap.Source `ini:"config"`
}
-func (cfg *LDAPConfig) FromDB(bs []byte) error {
- return jsoniter.Unmarshal(bs, &cfg)
-}
-
-func (cfg *LDAPConfig) ToDB() ([]byte, error) {
- return jsoniter.Marshal(cfg)
+var SecurityProtocolNames = map[ldap.SecurityProtocol]string{
+ ldap.SecurityProtocolUnencrypted: "Unencrypted",
+ ldap.SecurityProtocolLDAPS: "LDAPS",
+ ldap.SecurityProtocolStartTLS: "StartTLS",
}
func (cfg *LDAPConfig) SecurityProtocolName() string {
return SecurityProtocolNames[cfg.SecurityProtocol]
}
-type SMTPConfig struct {
- Auth string
- Host string
- Port int
- AllowedDomains string `xorm:"TEXT"`
- TLS bool `ini:"tls"`
- SkipVerify bool
-}
-
-func (cfg *SMTPConfig) FromDB(bs []byte) error {
- return jsoniter.Unmarshal(bs, cfg)
-}
-
-func (cfg *SMTPConfig) ToDB() ([]byte, error) {
- return jsoniter.Marshal(cfg)
-}
-
-type PAMConfig struct {
- ServiceName string // PAM service (e.g. system-auth)
-}
-
-func (cfg *PAMConfig) FromDB(bs []byte) error {
- return jsoniter.Unmarshal(bs, &cfg)
-}
-
-func (cfg *PAMConfig) ToDB() ([]byte, error) {
- return jsoniter.Marshal(cfg)
-}
-
-type GitHubConfig struct {
- APIEndpoint string // GitHub service (e.g. https://api.github.com/)
-}
-
-func (cfg *GitHubConfig) FromDB(bs []byte) error {
- return jsoniter.Unmarshal(bs, &cfg)
-}
-
-func (cfg *GitHubConfig) ToDB() ([]byte, error) {
- return jsoniter.Marshal(cfg)
-}
-
-// AuthSourceFile contains information of an authentication source file.
-type AuthSourceFile struct {
- abspath string
- file *ini.File
-}
-
-// SetGeneral sets new value to the given key in the general (default) section.
-func (f *AuthSourceFile) SetGeneral(name, value string) {
- f.file.Section("").Key(name).SetValue(value)
-}
-
-// SetConfig sets new values to the "config" section.
-func (f *AuthSourceFile) SetConfig(cfg core.Conversion) error {
- return f.file.Section("config").ReflectFrom(cfg)
-}
-
-// Save writes updates into file system.
-func (f *AuthSourceFile) Save() error {
- return f.file.SaveTo(f.abspath)
-}
-
-// LoginSource represents an external way for authorizing users.
-type LoginSource struct {
- ID int64
- Type LoginType
- Name string `xorm:"UNIQUE"`
- IsActived bool `xorm:"NOT NULL DEFAULT false"`
- IsDefault bool `xorm:"DEFAULT false"`
- Cfg core.Conversion `xorm:"TEXT" gorm:"COLUMN:remove-me-when-migrated-to-gorm"`
- RawCfg string `xorm:"-" gorm:"COLUMN:cfg"` // TODO: Remove me when migrated to GORM.
-
- Created time.Time `xorm:"-" json:"-"`
- CreatedUnix int64
- Updated time.Time `xorm:"-" json:"-"`
- UpdatedUnix int64
-
- LocalFile *AuthSourceFile `xorm:"-" json:"-"`
-}
-
-func (s *LoginSource) BeforeInsert() {
- s.CreatedUnix = time.Now().Unix()
- s.UpdatedUnix = s.CreatedUnix
-}
-
-func (s *LoginSource) BeforeUpdate() {
- s.UpdatedUnix = time.Now().Unix()
-}
-
-// Cell2Int64 converts a xorm.Cell type to int64,
-// and handles possible irregular cases.
-func Cell2Int64(val xorm.Cell) int64 {
- switch (*val).(type) {
- case []uint8:
- log.Trace("Cell2Int64 ([]uint8): %v", *val)
- return com.StrTo(string((*val).([]uint8))).MustInt64()
- }
- return (*val).(int64)
-}
-
-func (s *LoginSource) BeforeSet(colName string, val xorm.Cell) {
- switch colName {
- case "type":
- switch LoginType(Cell2Int64(val)) {
- case LoginLDAP, LoginDLDAP:
- s.Cfg = new(LDAPConfig)
- case LoginSMTP:
- s.Cfg = new(SMTPConfig)
- case LoginPAM:
- s.Cfg = new(PAMConfig)
- case LoginGitHub:
- s.Cfg = new(GitHubConfig)
- default:
- panic("unrecognized login source type: " + com.ToStr(*val))
- }
- }
-}
-
-func (s *LoginSource) AfterSet(colName string, _ xorm.Cell) {
- switch colName {
- case "created_unix":
- s.Created = time.Unix(s.CreatedUnix, 0).Local()
- case "updated_unix":
- s.Updated = time.Unix(s.UpdatedUnix, 0).Local()
- }
-}
-
-// NOTE: This is a GORM query hook.
-func (s *LoginSource) AfterFind() error {
- switch s.Type {
- case LoginLDAP, LoginDLDAP:
- s.Cfg = new(LDAPConfig)
- case LoginSMTP:
- s.Cfg = new(SMTPConfig)
- case LoginPAM:
- s.Cfg = new(PAMConfig)
- case LoginGitHub:
- s.Cfg = new(GitHubConfig)
- default:
- return fmt.Errorf("unrecognized login source type: %v", s.Type)
- }
- return jsoniter.UnmarshalFromString(s.RawCfg, s.Cfg)
-}
-
-func (s *LoginSource) TypeName() string {
- return LoginNames[s.Type]
-}
-
-func (s *LoginSource) IsLDAP() bool {
- return s.Type == LoginLDAP
-}
-
-func (s *LoginSource) IsDLDAP() bool {
- return s.Type == LoginDLDAP
-}
-
-func (s *LoginSource) IsSMTP() bool {
- return s.Type == LoginSMTP
-}
-
-func (s *LoginSource) IsPAM() bool {
- return s.Type == LoginPAM
-}
-
-func (s *LoginSource) IsGitHub() bool {
- return s.Type == LoginGitHub
-}
-
-func (s *LoginSource) HasTLS() bool {
- return ((s.IsLDAP() || s.IsDLDAP()) &&
- s.LDAP().SecurityProtocol > ldap.SECURITY_PROTOCOL_UNENCRYPTED) ||
- s.IsSMTP()
-}
-
-func (s *LoginSource) UseTLS() bool {
- switch s.Type {
- case LoginLDAP, LoginDLDAP:
- return s.LDAP().SecurityProtocol != ldap.SECURITY_PROTOCOL_UNENCRYPTED
- case LoginSMTP:
- return s.SMTP().TLS
- }
-
- return false
-}
-
-func (s *LoginSource) SkipVerify() bool {
- switch s.Type {
- case LoginLDAP, LoginDLDAP:
- return s.LDAP().SkipVerify
- case LoginSMTP:
- return s.SMTP().SkipVerify
- }
-
- return false
-}
-
-func (s *LoginSource) LDAP() *LDAPConfig {
- return s.Cfg.(*LDAPConfig)
-}
-
-func (s *LoginSource) SMTP() *SMTPConfig {
- return s.Cfg.(*SMTPConfig)
-}
-
-func (s *LoginSource) PAM() *PAMConfig {
- return s.Cfg.(*PAMConfig)
-}
-
-func (s *LoginSource) GitHub() *GitHubConfig {
- return s.Cfg.(*GitHubConfig)
-}
-
-func CreateLoginSource(source *LoginSource) error {
- has, err := x.Get(&LoginSource{Name: source.Name})
- if err != nil {
- return err
- } else if has {
- return ErrLoginSourceAlreadyExist{source.Name}
- }
-
- _, err = x.Insert(source)
- if err != nil {
- return err
- } else if source.IsDefault {
- return ResetNonDefaultLoginSources(source)
- }
- return nil
-}
-
-// ListLoginSources returns all login sources defined.
-func ListLoginSources() ([]*LoginSource, error) {
- sources := make([]*LoginSource, 0, 2)
- if err := x.Find(&sources); err != nil {
- return nil, err
- }
-
- return append(sources, localLoginSources.List()...), nil
-}
-
-// ActivatedLoginSources returns login sources that are currently activated.
-func ActivatedLoginSources() ([]*LoginSource, error) {
- sources := make([]*LoginSource, 0, 2)
- if err := x.Where("is_actived = ?", true).Find(&sources); err != nil {
- return nil, fmt.Errorf("find activated login sources: %v", err)
- }
- return append(sources, localLoginSources.ActivatedList()...), nil
-}
-
-// ResetNonDefaultLoginSources clean other default source flag
-func ResetNonDefaultLoginSources(source *LoginSource) error {
- // update changes to DB
- if _, err := x.NotIn("id", []int64{source.ID}).Cols("is_default").Update(&LoginSource{IsDefault: false}); err != nil {
- return err
- }
- // write changes to local authentications
- for i := range localLoginSources.sources {
- if localLoginSources.sources[i].LocalFile != nil && localLoginSources.sources[i].ID != source.ID {
- localLoginSources.sources[i].LocalFile.SetGeneral("is_default", "false")
- if err := localLoginSources.sources[i].LocalFile.SetConfig(source.Cfg); err != nil {
- return fmt.Errorf("LocalFile.SetConfig: %v", err)
- } else if err = localLoginSources.sources[i].LocalFile.Save(); err != nil {
- return fmt.Errorf("LocalFile.Save: %v", err)
- }
- }
- }
- // flush memory so that web page can show the same behaviors
- localLoginSources.UpdateLoginSource(source)
- return nil
-}
-
-// UpdateLoginSource updates information of login source to database or local file.
-func UpdateLoginSource(source *LoginSource) error {
- if source.LocalFile == nil {
- if _, err := x.Id(source.ID).AllCols().Update(source); err != nil {
- return err
- } else {
- return ResetNonDefaultLoginSources(source)
- }
-
- }
-
- source.LocalFile.SetGeneral("name", source.Name)
- source.LocalFile.SetGeneral("is_activated", com.ToStr(source.IsActived))
- source.LocalFile.SetGeneral("is_default", com.ToStr(source.IsDefault))
- if err := source.LocalFile.SetConfig(source.Cfg); err != nil {
- return fmt.Errorf("LocalFile.SetConfig: %v", err)
- } else if err = source.LocalFile.Save(); err != nil {
- return fmt.Errorf("LocalFile.Save: %v", err)
- }
- return ResetNonDefaultLoginSources(source)
-}
-
-func DeleteSource(source *LoginSource) error {
- count, err := x.Count(&User{LoginSource: source.ID})
- if err != nil {
- return err
- } else if count > 0 {
- return ErrLoginSourceInUse{source.ID}
- }
- _, err = x.Id(source.ID).Delete(new(LoginSource))
- return err
-}
-
-// CountLoginSources returns total number of login sources.
-func CountLoginSources() int64 {
- count, _ := x.Count(new(LoginSource))
- return count + int64(localLoginSources.Len())
-}
-
-// LocalLoginSources contains authentication sources configured and loaded from local files.
-// Calling its methods is thread-safe; otherwise, please maintain the mutex accordingly.
-type LocalLoginSources struct {
- sync.RWMutex
- sources []*LoginSource
-}
-
-func (s *LocalLoginSources) Len() int {
- return len(s.sources)
-}
-
-// List returns full clone of login sources.
-func (s *LocalLoginSources) List() []*LoginSource {
- s.RLock()
- defer s.RUnlock()
-
- list := make([]*LoginSource, s.Len())
- for i := range s.sources {
- list[i] = &LoginSource{}
- *list[i] = *s.sources[i]
- }
- return list
-}
-
-// ActivatedList returns clone of activated login sources.
-func (s *LocalLoginSources) ActivatedList() []*LoginSource {
- s.RLock()
- defer s.RUnlock()
-
- list := make([]*LoginSource, 0, 2)
- for i := range s.sources {
- if !s.sources[i].IsActived {
- continue
- }
- source := &LoginSource{}
- *source = *s.sources[i]
- list = append(list, source)
- }
- return list
-}
-
-// GetLoginSourceByID returns a clone of login source by given ID.
-func (s *LocalLoginSources) GetLoginSourceByID(id int64) (*LoginSource, error) {
- s.RLock()
- defer s.RUnlock()
-
- for i := range s.sources {
- if s.sources[i].ID == id {
- source := &LoginSource{}
- *source = *s.sources[i]
- return source, nil
- }
- }
-
- return nil, errors.LoginSourceNotExist{ID: id}
-}
-
-// UpdateLoginSource updates in-memory copy of the authentication source.
-func (s *LocalLoginSources) UpdateLoginSource(source *LoginSource) {
- s.Lock()
- defer s.Unlock()
-
- source.Updated = time.Now()
- for i := range s.sources {
- if s.sources[i].ID == source.ID {
- *s.sources[i] = *source
- } else if source.IsDefault {
- s.sources[i].IsDefault = false
- }
- }
-}
-
-var localLoginSources = &LocalLoginSources{}
-
-// LoadAuthSources loads authentication sources from local files
-// and converts them into login sources.
-func LoadAuthSources() {
- authdPath := filepath.Join(conf.CustomDir(), "conf", "auth.d")
- if !com.IsDir(authdPath) {
- return
- }
-
- paths, err := com.GetFileListBySuffix(authdPath, ".conf")
- if err != nil {
- log.Fatal("Failed to list authentication sources: %v", err)
- }
-
- localLoginSources.sources = make([]*LoginSource, 0, len(paths))
-
- for _, fpath := range paths {
- authSource, err := ini.Load(fpath)
- if err != nil {
- log.Fatal("Failed to load authentication source: %v", err)
- }
- authSource.NameMapper = ini.TitleUnderscore
-
- // Set general attributes
- s := authSource.Section("")
- loginSource := &LoginSource{
- ID: s.Key("id").MustInt64(),
- Name: s.Key("name").String(),
- IsActived: s.Key("is_activated").MustBool(),
- IsDefault: s.Key("is_default").MustBool(),
- LocalFile: &AuthSourceFile{
- abspath: fpath,
- file: authSource,
- },
- }
-
- fi, err := os.Stat(fpath)
- if err != nil {
- log.Fatal("Failed to load authentication source: %v", err)
- }
- loginSource.Updated = fi.ModTime()
-
- // Parse authentication source file
- authType := s.Key("type").String()
- switch authType {
- case "ldap_bind_dn":
- loginSource.Type = LoginLDAP
- loginSource.Cfg = &LDAPConfig{}
- case "ldap_simple_auth":
- loginSource.Type = LoginDLDAP
- loginSource.Cfg = &LDAPConfig{}
- case "smtp":
- loginSource.Type = LoginSMTP
- loginSource.Cfg = &SMTPConfig{}
- case "pam":
- loginSource.Type = LoginPAM
- loginSource.Cfg = &PAMConfig{}
- case "github":
- loginSource.Type = LoginGitHub
- loginSource.Cfg = &GitHubConfig{}
- default:
- log.Fatal("Failed to load authentication source: unknown type '%s'", authType)
- }
-
- if err = authSource.Section("config").MapTo(loginSource.Cfg); err != nil {
- log.Fatal("Failed to parse authentication source 'config': %v", err)
- }
-
- localLoginSources.sources = append(localLoginSources.sources, loginSource)
- }
-}
-
-// .____ ________ _____ __________
-// | | \______ \ / _ \\______ \
-// | | | | \ / /_\ \| ___/
-// | |___ | ` \/ | \ |
-// |_______ \/_______ /\____|__ /____|
-// \/ \/ \/
-
func composeFullName(firstname, surname, username string) string {
switch {
case len(firstname) == 0 && len(surname) == 0:
@@ -559,7 +77,7 @@ func composeFullName(firstname, surname, username string) string {
// LoginViaLDAP queries if login/password is valid against the LDAP directory pool,
// and create a local user if success when enabled.
func LoginViaLDAP(login, password string, source *LoginSource, autoRegister bool) (*User, error) {
- username, fn, sn, mail, isAdmin, succeed := source.Cfg.(*LDAPConfig).SearchEntry(login, password, source.Type == LoginDLDAP)
+ username, fn, sn, mail, isAdmin, succeed := source.Config.(*LDAPConfig).SearchEntry(login, password, source.Type == LoginDLDAP)
if !succeed {
// User not in LDAP, do nothing
return nil, ErrUserNotExist{args: map[string]interface{}{"login": login}}
@@ -606,12 +124,18 @@ func LoginViaLDAP(login, password string, source *LoginSource, autoRegister bool
return user, CreateUser(user)
}
-// _________ __________________________
-// / _____/ / \__ ___/\______ \
-// \_____ \ / \ / \| | | ___/
-// / \/ Y \ | | |
-// /_______ /\____|__ /____| |____|
-// \/ \/
+// ***********************
+// ----- SMTP config -----
+// ***********************
+
+type SMTPConfig struct {
+ Auth string
+ Host string
+ Port int
+ AllowedDomains string
+ TLS bool `ini:"tls"`
+ SkipVerify bool
+}
type smtpLoginAuth struct {
username, password string
@@ -634,11 +158,11 @@ func (auth *smtpLoginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
}
const (
- SMTP_PLAIN = "PLAIN"
- SMTP_LOGIN = "LOGIN"
+ SMTPPlain = "PLAIN"
+ SMTPLogin = "LOGIN"
)
-var SMTPAuths = []string{SMTP_PLAIN, SMTP_LOGIN}
+var SMTPAuths = []string{SMTPPlain, SMTPLogin}
func SMTPAuth(a smtp.Auth, cfg *SMTPConfig) error {
c, err := smtp.Dial(fmt.Sprintf("%s:%d", cfg.Host, cfg.Port))
@@ -687,9 +211,9 @@ func LoginViaSMTP(login, password string, sourceID int64, cfg *SMTPConfig, autoR
}
var auth smtp.Auth
- if cfg.Auth == SMTP_PLAIN {
+ if cfg.Auth == SMTPPlain {
auth = smtp.PlainAuth("", login, password, cfg.Host)
- } else if cfg.Auth == SMTP_LOGIN {
+ } else if cfg.Auth == SMTPLogin {
auth = &smtpLoginAuth{login, password}
} else {
return nil, errors.New("Unsupported SMTP authentication type")
@@ -729,12 +253,14 @@ func LoginViaSMTP(login, password string, sourceID int64, cfg *SMTPConfig, autoR
return user, CreateUser(user)
}
-// __________ _____ _____
-// \______ \/ _ \ / \
-// | ___/ /_\ \ / \ / \
-// | | / | \/ Y \
-// |____| \____|__ /\____|__ /
-// \/ \/
+// **********************
+// ----- PAM config -----
+// **********************
+
+type PAMConfig struct {
+ // The name of the PAM service, e.g. system-auth.
+ ServiceName string
+}
// LoginViaPAM queries if login/password is valid against the PAM,
// and create a local user if success when enabled.
@@ -763,12 +289,14 @@ func LoginViaPAM(login, password string, sourceID int64, cfg *PAMConfig, autoReg
return user, CreateUser(user)
}
-// ________.__ __ ___ ___ ___.
-// / _____/|__|/ |_ / | \ __ _\_ |__
-// / \ ___| \ __\/ ~ \ | \ __ \
-// \ \_\ \ || | \ Y / | / \_\ \
-// \______ /__||__| \___|_ /|____/|___ /
-// \/ \/ \/
+// *************************
+// ----- GitHub config -----
+// *************************
+
+type GitHubConfig struct {
+ // the GitHub service endpoint, e.g. https://api.github.com/.
+ APIEndpoint string
+}
func LoginViaGitHub(login, password string, sourceID int64, cfg *GitHubConfig, autoRegister bool) (*User, error) {
fullname, email, url, location, err := github.Authenticate(cfg.APIEndpoint, login, password)
@@ -807,11 +335,11 @@ func authenticateViaLoginSource(source *LoginSource, login, password string, aut
case LoginLDAP, LoginDLDAP:
return LoginViaLDAP(login, password, source, autoRegister)
case LoginSMTP:
- return LoginViaSMTP(login, password, source.ID, source.Cfg.(*SMTPConfig), autoRegister)
+ return LoginViaSMTP(login, password, source.ID, source.Config.(*SMTPConfig), autoRegister)
case LoginPAM:
- return LoginViaPAM(login, password, source.ID, source.Cfg.(*PAMConfig), autoRegister)
+ return LoginViaPAM(login, password, source.ID, source.Config.(*PAMConfig), autoRegister)
case LoginGitHub:
- return LoginViaGitHub(login, password, source.ID, source.Cfg.(*GitHubConfig), autoRegister)
+ return LoginViaGitHub(login, password, source.ID, source.Config.(*GitHubConfig), autoRegister)
}
return nil, errors.InvalidLoginSourceType{Type: source.Type}
diff --git a/internal/db/login_source_files.go b/internal/db/login_source_files.go
new file mode 100644
index 00000000..4c5cb377
--- /dev/null
+++ b/internal/db/login_source_files.go
@@ -0,0 +1,212 @@
+// Copyright 2020 The Gogs Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package db
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+ "sync"
+
+ "github.com/jinzhu/gorm"
+ "github.com/pkg/errors"
+ "gopkg.in/ini.v1"
+
+ "gogs.io/gogs/internal/errutil"
+ "gogs.io/gogs/internal/osutil"
+)
+
+// loginSourceFilesStore is the in-memory interface for login source files stored on file system.
+//
+// NOTE: All methods are sorted in alphabetical order.
+type loginSourceFilesStore interface {
+ // GetByID returns a clone of login source by given ID.
+ GetByID(id int64) (*LoginSource, error)
+ // Len returns number of login sources.
+ Len() int
+ // List returns a list of login sources filtered by options.
+ List(opts ListLoginSourceOpts) []*LoginSource
+ // Update updates in-memory copy of the authentication source.
+ Update(source *LoginSource)
+}
+
+var _ loginSourceFilesStore = (*loginSourceFiles)(nil)
+
+// loginSourceFiles contains authentication sources configured and loaded from local files.
+type loginSourceFiles struct {
+ sync.RWMutex
+ sources []*LoginSource
+}
+
+var _ errutil.NotFound = (*ErrLoginSourceNotExist)(nil)
+
+type ErrLoginSourceNotExist struct {
+ args errutil.Args
+}
+
+func IsErrLoginSourceNotExist(err error) bool {
+ _, ok := err.(ErrLoginSourceNotExist)
+ return ok
+}
+
+func (err ErrLoginSourceNotExist) Error() string {
+ return fmt.Sprintf("login source does not exist: %v", err.args)
+}
+
+func (ErrLoginSourceNotExist) NotFound() bool {
+ return true
+}
+
+func (s *loginSourceFiles) GetByID(id int64) (*LoginSource, error) {
+ s.RLock()
+ defer s.RUnlock()
+
+ for _, source := range s.sources {
+ if source.ID == id {
+ return source, nil
+ }
+ }
+
+ return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
+}
+
+func (s *loginSourceFiles) Len() int {
+ s.RLock()
+ defer s.RUnlock()
+ return len(s.sources)
+}
+
+func (s *loginSourceFiles) List(opts ListLoginSourceOpts) []*LoginSource {
+ s.RLock()
+ defer s.RUnlock()
+
+ list := make([]*LoginSource, 0, s.Len())
+ for _, source := range s.sources {
+ if opts.OnlyActivated && !source.IsActived {
+ continue
+ }
+
+ list = append(list, source)
+ }
+ return list
+}
+
+func (s *loginSourceFiles) Update(source *LoginSource) {
+ s.Lock()
+ defer s.Unlock()
+
+ source.Updated = gorm.NowFunc()
+ for _, old := range s.sources {
+ if old.ID == source.ID {
+ *old = *source
+ } else if source.IsDefault {
+ old.IsDefault = false
+ }
+ }
+}
+
+// loadLoginSourceFiles loads login sources from file system.
+func loadLoginSourceFiles(authdPath string) (loginSourceFilesStore, error) {
+ if !osutil.IsDir(authdPath) {
+ return &loginSourceFiles{}, nil
+ }
+
+ store := &loginSourceFiles{}
+ return store, filepath.Walk(authdPath, func(path string, info os.FileInfo, err error) error {
+ if err != nil {
+ return err
+ }
+
+ if path == authdPath || !strings.HasSuffix(path, ".conf") {
+ return nil
+ } else if info.IsDir() {
+ return filepath.SkipDir
+ }
+
+ authSource, err := ini.Load(path)
+ if err != nil {
+ return errors.Wrap(err, "load file")
+ }
+ authSource.NameMapper = ini.TitleUnderscore
+
+ // Set general attributes
+ s := authSource.Section("")
+ loginSource := &LoginSource{
+ ID: s.Key("id").MustInt64(),
+ Name: s.Key("name").String(),
+ IsActived: s.Key("is_activated").MustBool(),
+ IsDefault: s.Key("is_default").MustBool(),
+ File: &loginSourceFile{
+ path: path,
+ file: authSource,
+ },
+ }
+
+ fi, err := os.Stat(path)
+ if err != nil {
+ return errors.Wrap(err, "stat file")
+ }
+ loginSource.Updated = fi.ModTime()
+
+ // Parse authentication source file
+ authType := s.Key("type").String()
+ switch authType {
+ case "ldap_bind_dn":
+ loginSource.Type = LoginLDAP
+ loginSource.Config = &LDAPConfig{}
+ case "ldap_simple_auth":
+ loginSource.Type = LoginDLDAP
+ loginSource.Config = &LDAPConfig{}
+ case "smtp":
+ loginSource.Type = LoginSMTP
+ loginSource.Config = &SMTPConfig{}
+ case "pam":
+ loginSource.Type = LoginPAM
+ loginSource.Config = &PAMConfig{}
+ case "github":
+ loginSource.Type = LoginGitHub
+ loginSource.Config = &GitHubConfig{}
+ default:
+ return fmt.Errorf("unknown type %q", authType)
+ }
+
+ if err = authSource.Section("config").MapTo(loginSource.Config); err != nil {
+ return errors.Wrap(err, `map "config" section`)
+ }
+
+ store.sources = append(store.sources, loginSource)
+ return nil
+ })
+}
+
+// loginSourceFileStore is the persistent interface for a login source file.
+type loginSourceFileStore interface {
+ // SetGeneral sets new value to the given key in the general (default) section.
+ SetGeneral(name, value string)
+ // SetConfig sets new values to the "config" section.
+ SetConfig(cfg interface{}) error
+ // Save persists values to file system.
+ Save() error
+}
+
+var _ loginSourceFileStore = (*loginSourceFile)(nil)
+
+type loginSourceFile struct {
+ path string
+ file *ini.File
+}
+
+func (f *loginSourceFile) SetGeneral(name, value string) {
+ f.file.Section("").Key(name).SetValue(value)
+}
+
+func (f *loginSourceFile) SetConfig(cfg interface{}) error {
+ return f.file.Section("config").ReflectFrom(cfg)
+}
+
+func (f *loginSourceFile) Save() error {
+ return f.file.SaveTo(f.path)
+}
diff --git a/internal/db/login_sources.go b/internal/db/login_sources.go
index 41ee73b9..b620a442 100644
--- a/internal/db/login_sources.go
+++ b/internal/db/login_sources.go
@@ -5,22 +5,242 @@
package db
import (
+ "fmt"
+ "strconv"
+ "time"
+
"github.com/jinzhu/gorm"
+ jsoniter "github.com/json-iterator/go"
+ "github.com/pkg/errors"
+
+ "gogs.io/gogs/internal/auth/ldap"
+ "gogs.io/gogs/internal/errutil"
)
// LoginSourcesStore is the persistent interface for login sources.
//
// NOTE: All methods are sorted in alphabetical order.
type LoginSourcesStore interface {
+ // Create creates a new login source and persist to database.
+ // It returns ErrLoginSourceAlreadyExist when a login source with same name already exists.
+ Create(opts CreateLoginSourceOpts) (*LoginSource, error)
+ // Count returns the total number of login sources.
+ Count() int64
+ // DeleteByID deletes a login source by given ID.
+ // It returns ErrLoginSourceInUse if at least one user is associated with the login source.
+ DeleteByID(id int64) error
// GetByID returns the login source with given ID.
// It returns ErrLoginSourceNotExist when not found.
GetByID(id int64) (*LoginSource, error)
+ // List returns a list of login sources filtered by options.
+ List(opts ListLoginSourceOpts) ([]*LoginSource, error)
+ // ResetNonDefault clears default flag for all the other login sources.
+ ResetNonDefault(source *LoginSource) error
+ // Save persists all values of given login source to database or local file.
+ // The Updated field is set to current time automatically.
+ Save(t *LoginSource) error
}
var LoginSources LoginSourcesStore
+// LoginSource represents an external way for authorizing users.
+type LoginSource struct {
+ ID int64
+ Type LoginType
+ Name string `xorm:"UNIQUE"`
+ IsActived bool `xorm:"NOT NULL DEFAULT false"`
+ IsDefault bool `xorm:"DEFAULT false"`
+ Config interface{} `xorm:"-" gorm:"-"`
+ RawConfig string `xorm:"TEXT cfg" gorm:"COLUMN:cfg"`
+
+ Created time.Time `xorm:"-" gorm:"-" json:"-"`
+ CreatedUnix int64
+ Updated time.Time `xorm:"-" gorm:"-" json:"-"`
+ UpdatedUnix int64
+
+ File loginSourceFileStore `xorm:"-" gorm:"-" json:"-"`
+}
+
+// NOTE: This is a GORM save hook.
+func (s *LoginSource) BeforeSave() (err error) {
+ s.RawConfig, err = jsoniter.MarshalToString(s.Config)
+ return err
+}
+
+// NOTE: This is a GORM create hook.
+func (s *LoginSource) BeforeCreate() {
+ s.CreatedUnix = gorm.NowFunc().Unix()
+ s.UpdatedUnix = s.CreatedUnix
+}
+
+// NOTE: This is a GORM update hook.
+func (s *LoginSource) BeforeUpdate() {
+ s.UpdatedUnix = gorm.NowFunc().Unix()
+}
+
+// NOTE: This is a GORM query hook.
+func (s *LoginSource) AfterFind() error {
+ s.Created = time.Unix(s.CreatedUnix, 0).Local()
+ s.Updated = time.Unix(s.UpdatedUnix, 0).Local()
+
+ switch s.Type {
+ case LoginLDAP, LoginDLDAP:
+ s.Config = new(LDAPConfig)
+ case LoginSMTP:
+ s.Config = new(SMTPConfig)
+ case LoginPAM:
+ s.Config = new(PAMConfig)
+ case LoginGitHub:
+ s.Config = new(GitHubConfig)
+ default:
+ return fmt.Errorf("unrecognized login source type: %v", s.Type)
+ }
+ return jsoniter.UnmarshalFromString(s.RawConfig, s.Config)
+}
+
+func (s *LoginSource) TypeName() string {
+ return LoginNames[s.Type]
+}
+
+func (s *LoginSource) IsLDAP() bool {
+ return s.Type == LoginLDAP
+}
+
+func (s *LoginSource) IsDLDAP() bool {
+ return s.Type == LoginDLDAP
+}
+
+func (s *LoginSource) IsSMTP() bool {
+ return s.Type == LoginSMTP
+}
+
+func (s *LoginSource) IsPAM() bool {
+ return s.Type == LoginPAM
+}
+
+func (s *LoginSource) IsGitHub() bool {
+ return s.Type == LoginGitHub
+}
+
+func (s *LoginSource) HasTLS() bool {
+ return ((s.IsLDAP() || s.IsDLDAP()) &&
+ s.LDAP().SecurityProtocol > ldap.SecurityProtocolUnencrypted) ||
+ s.IsSMTP()
+}
+
+func (s *LoginSource) UseTLS() bool {
+ switch s.Type {
+ case LoginLDAP, LoginDLDAP:
+ return s.LDAP().SecurityProtocol != ldap.SecurityProtocolUnencrypted
+ case LoginSMTP:
+ return s.SMTP().TLS
+ }
+
+ return false
+}
+
+func (s *LoginSource) SkipVerify() bool {
+ switch s.Type {
+ case LoginLDAP, LoginDLDAP:
+ return s.LDAP().SkipVerify
+ case LoginSMTP:
+ return s.SMTP().SkipVerify
+ }
+
+ return false
+}
+
+func (s *LoginSource) LDAP() *LDAPConfig {
+ return s.Config.(*LDAPConfig)
+}
+
+func (s *LoginSource) SMTP() *SMTPConfig {
+ return s.Config.(*SMTPConfig)
+}
+
+func (s *LoginSource) PAM() *PAMConfig {
+ return s.Config.(*PAMConfig)
+}
+
+func (s *LoginSource) GitHub() *GitHubConfig {
+ return s.Config.(*GitHubConfig)
+}
+
+var _ LoginSourcesStore = (*loginSources)(nil)
+
type loginSources struct {
*gorm.DB
+ files loginSourceFilesStore
+}
+
+type CreateLoginSourceOpts struct {
+ Type LoginType
+ Name string
+ Activated bool
+ Default bool
+ Config interface{}
+}
+
+type ErrLoginSourceAlreadyExist struct {
+ args errutil.Args
+}
+
+func IsErrLoginSourceAlreadyExist(err error) bool {
+ _, ok := err.(ErrLoginSourceAlreadyExist)
+ return ok
+}
+
+func (err ErrLoginSourceAlreadyExist) Error() string {
+ return fmt.Sprintf("login source already exists: %v", err.args)
+}
+
+func (db *loginSources) Create(opts CreateLoginSourceOpts) (*LoginSource, error) {
+ err := db.Where("name = ?", opts.Name).First(new(LoginSource)).Error
+ if err == nil {
+ return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}}
+ } else if !gorm.IsRecordNotFoundError(err) {
+ return nil, err
+ }
+
+ source := &LoginSource{
+ Type: opts.Type,
+ Name: opts.Name,
+ IsActived: opts.Activated,
+ IsDefault: opts.Default,
+ Config: opts.Config,
+ }
+ return source, db.DB.Create(source).Error
+}
+
+func (db *loginSources) Count() int64 {
+ var count int64
+ db.Model(new(LoginSource)).Count(&count)
+ return count + int64(db.files.Len())
+}
+
+type ErrLoginSourceInUse struct {
+ args errutil.Args
+}
+
+func IsErrLoginSourceInUse(err error) bool {
+ _, ok := err.(ErrLoginSourceInUse)
+ return ok
+}
+
+func (err ErrLoginSourceInUse) Error() string {
+ return fmt.Sprintf("login source is still used by some users: %v", err.args)
+}
+
+func (db *loginSources) DeleteByID(id int64) error {
+ var count int64
+ err := db.Model(new(User)).Where("login_source = ?", id).Count(&count).Error
+ if err != nil {
+ return err
+ } else if count > 0 {
+ return ErrLoginSourceInUse{args: errutil.Args{"id": id}}
+ }
+
+ return db.Where("id = ?", id).Delete(new(LoginSource)).Error
}
func (db *loginSources) GetByID(id int64) (*LoginSource, error) {
@@ -28,9 +248,63 @@ func (db *loginSources) GetByID(id int64) (*LoginSource, error) {
err := db.Where("id = ?", id).First(source).Error
if err != nil {
if gorm.IsRecordNotFoundError(err) {
- return localLoginSources.GetLoginSourceByID(id)
+ return db.files.GetByID(id)
}
return nil, err
}
return source, nil
}
+
+type ListLoginSourceOpts struct {
+ // Whether to only include activated login sources.
+ OnlyActivated bool
+}
+
+func (db *loginSources) List(opts ListLoginSourceOpts) ([]*LoginSource, error) {
+ var sources []*LoginSource
+ query := db.Order("id ASC")
+ if opts.OnlyActivated {
+ query = query.Where("is_actived = ?", true)
+ }
+ err := query.Find(&sources).Error
+ if err != nil {
+ return nil, err
+ }
+
+ return append(sources, db.files.List(opts)...), nil
+}
+
+func (db *loginSources) ResetNonDefault(dflt *LoginSource) error {
+ err := db.Model(new(LoginSource)).Where("id != ?", dflt.ID).Updates(map[string]interface{}{"is_default": false}).Error
+ if err != nil {
+ return err
+ }
+
+ for _, source := range db.files.List(ListLoginSourceOpts{}) {
+ if source.File != nil && source.ID != dflt.ID {
+ source.File.SetGeneral("is_default", "false")
+ if err = source.File.Save(); err != nil {
+ return errors.Wrap(err, "save file")
+ }
+ }
+ }
+
+ db.files.Update(dflt)
+ return nil
+}
+
+func (db *loginSources) Save(source *LoginSource) error {
+ if source.File == nil {
+ return db.DB.Save(source).Error
+ }
+
+ source.File.SetGeneral("name", source.Name)
+ source.File.SetGeneral("is_activated", strconv.FormatBool(source.IsActived))
+ source.File.SetGeneral("is_default", strconv.FormatBool(source.IsDefault))
+ if err := source.File.SetConfig(source.Config); err != nil {
+ return errors.Wrap(err, "set config")
+ } else if err = source.File.Save(); err != nil {
+ return errors.Wrap(err, "save file")
+ }
+ return nil
+}
diff --git a/internal/db/login_sources_test.go b/internal/db/login_sources_test.go
new file mode 100644
index 00000000..8cc0ab19
--- /dev/null
+++ b/internal/db/login_sources_test.go
@@ -0,0 +1,389 @@
+// Copyright 2020 The Gogs Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package db
+
+import (
+ "testing"
+ "time"
+
+ "github.com/jinzhu/gorm"
+ "github.com/stretchr/testify/assert"
+
+ "gogs.io/gogs/internal/errutil"
+)
+
+func Test_loginSources(t *testing.T) {
+ if testing.Short() {
+ t.Skip()
+ }
+
+ t.Parallel()
+
+ tables := []interface{}{new(LoginSource), new(User)}
+ db := &loginSources{
+ DB: initTestDB(t, "loginSources", tables...),
+ }
+
+ for _, tc := range []struct {
+ name string
+ test func(*testing.T, *loginSources)
+ }{
+ {"Create", test_loginSources_Create},
+ {"Count", test_loginSources_Count},
+ {"DeleteByID", test_loginSources_DeleteByID},
+ {"GetByID", test_loginSources_GetByID},
+ {"List", test_loginSources_List},
+ {"ResetNonDefault", test_loginSources_ResetNonDefault},
+ {"Save", test_loginSources_Save},
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Cleanup(func() {
+ err := clearTables(db.DB, tables...)
+ if err != nil {
+ t.Fatal(err)
+ }
+ })
+ tc.test(t, db)
+ })
+ }
+}
+
+func test_loginSources_Create(t *testing.T, db *loginSources) {
+ // Create first login source with name "GitHub"
+ source, err := db.Create(CreateLoginSourceOpts{
+ Type: LoginGitHub,
+ Name: "GitHub",
+ Activated: true,
+ Default: false,
+ Config: &GitHubConfig{
+ APIEndpoint: "https://api.github.com",
+ },
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Get it back and check the Created field
+ source, err = db.GetByID(source.ID)
+ if err != nil {
+ t.Fatal(err)
+ }
+ assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), source.Created.Format(time.RFC3339))
+ assert.Equal(t, gorm.NowFunc().Format(time.RFC3339), source.Updated.Format(time.RFC3339))
+
+ // Try create second login source with same name should fail
+ _, err = db.Create(CreateLoginSourceOpts{Name: source.Name})
+ expErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}}
+ assert.Equal(t, expErr, err)
+}
+
+func test_loginSources_Count(t *testing.T, db *loginSources) {
+ // Create two login sources, one in database and one as source file.
+ _, err := db.Create(CreateLoginSourceOpts{
+ Type: LoginGitHub,
+ Name: "GitHub",
+ Activated: true,
+ Default: false,
+ Config: &GitHubConfig{
+ APIEndpoint: "https://api.github.com",
+ },
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
+ MockLen: func() int {
+ return 2
+ },
+ })
+
+ assert.Equal(t, int64(3), db.Count())
+}
+
+func test_loginSources_DeleteByID(t *testing.T, db *loginSources) {
+ t.Run("delete but in used", func(t *testing.T) {
+ source, err := db.Create(CreateLoginSourceOpts{
+ Type: LoginGitHub,
+ Name: "GitHub",
+ Activated: true,
+ Default: false,
+ Config: &GitHubConfig{
+ APIEndpoint: "https://api.github.com",
+ },
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Create a user that uses this login source
+ user := &User{
+ LoginSource: source.ID,
+ }
+ err = db.DB.Create(user).Error
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Delete the login source will result in error
+ err = db.DeleteByID(source.ID)
+ expErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}}
+ assert.Equal(t, expErr, err)
+ })
+
+ setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
+ MockGetByID: func(id int64) (*LoginSource, error) {
+ return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
+ },
+ })
+
+ // Create a login source with name "GitHub2"
+ source, err := db.Create(CreateLoginSourceOpts{
+ Type: LoginGitHub,
+ Name: "GitHub2",
+ Activated: true,
+ Default: false,
+ Config: &GitHubConfig{
+ APIEndpoint: "https://api.github.com",
+ },
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Delete a non-existent ID is noop
+ err = db.DeleteByID(9999)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // We should be able to get it back
+ _, err = db.GetByID(source.ID)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Now delete this login source with ID
+ err = db.DeleteByID(source.ID)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // We should get token not found error
+ _, err = db.GetByID(source.ID)
+ expErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}}
+ assert.Equal(t, expErr, err)
+}
+
+func test_loginSources_GetByID(t *testing.T, db *loginSources) {
+ setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
+ MockGetByID: func(id int64) (*LoginSource, error) {
+ if id != 101 {
+ return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
+ }
+ return &LoginSource{ID: id}, nil
+ },
+ })
+
+ expConfig := &GitHubConfig{
+ APIEndpoint: "https://api.github.com",
+ }
+
+ // Create a login source with name "GitHub"
+ source, err := db.Create(CreateLoginSourceOpts{
+ Type: LoginGitHub,
+ Name: "GitHub",
+ Activated: true,
+ Default: false,
+ Config: expConfig,
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Get the one in the database and test the read/write hooks
+ source, err = db.GetByID(source.ID)
+ if err != nil {
+ t.Fatal(err)
+ }
+ assert.Equal(t, expConfig, source.Config)
+
+ // Get the one in source file store
+ _, err = db.GetByID(101)
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func test_loginSources_List(t *testing.T, db *loginSources) {
+ setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
+ MockList: func(opts ListLoginSourceOpts) []*LoginSource {
+ if opts.OnlyActivated {
+ return []*LoginSource{
+ {ID: 1},
+ }
+ }
+ return []*LoginSource{
+ {ID: 1},
+ {ID: 2},
+ }
+ },
+ })
+
+ // Create two login sources in database, one activated and the other one not
+ _, err := db.Create(CreateLoginSourceOpts{
+ Type: LoginPAM,
+ Name: "PAM",
+ Config: &PAMConfig{
+ ServiceName: "PAM",
+ },
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = db.Create(CreateLoginSourceOpts{
+ Type: LoginGitHub,
+ Name: "GitHub",
+ Activated: true,
+ Config: &GitHubConfig{
+ APIEndpoint: "https://api.github.com",
+ },
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // List all login sources
+ sources, err := db.List(ListLoginSourceOpts{})
+ if err != nil {
+ t.Fatal(err)
+ }
+ assert.Equal(t, 4, len(sources), "number of sources")
+
+ // Only list activated login sources
+ sources, err = db.List(ListLoginSourceOpts{OnlyActivated: true})
+ if err != nil {
+ t.Fatal(err)
+ }
+ assert.Equal(t, 2, len(sources), "number of sources")
+}
+
+func test_loginSources_ResetNonDefault(t *testing.T, db *loginSources) {
+ setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
+ MockList: func(opts ListLoginSourceOpts) []*LoginSource {
+ return []*LoginSource{
+ {
+ File: &mockLoginSourceFileStore{
+ MockSetGeneral: func(name, value string) {
+ assert.Equal(t, "is_default", name)
+ assert.Equal(t, "false", value)
+ },
+ MockSave: func() error {
+ return nil
+ },
+ },
+ },
+ }
+ },
+ MockUpdate: func(source *LoginSource) {},
+ })
+
+ // Create two login sources both have default on
+ source1, err := db.Create(CreateLoginSourceOpts{
+ Type: LoginPAM,
+ Name: "PAM",
+ Default: true,
+ Config: &PAMConfig{
+ ServiceName: "PAM",
+ },
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ source2, err := db.Create(CreateLoginSourceOpts{
+ Type: LoginGitHub,
+ Name: "GitHub",
+ Activated: true,
+ Default: true,
+ Config: &GitHubConfig{
+ APIEndpoint: "https://api.github.com",
+ },
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Set source 1 as default
+ err = db.ResetNonDefault(source1)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Verify the default state
+ source1, err = db.GetByID(source1.ID)
+ if err != nil {
+ t.Fatal(err)
+ }
+ assert.True(t, source1.IsDefault)
+
+ source2, err = db.GetByID(source2.ID)
+ if err != nil {
+ t.Fatal(err)
+ }
+ assert.False(t, source2.IsDefault)
+}
+
+func test_loginSources_Save(t *testing.T, db *loginSources) {
+ t.Run("save to database", func(t *testing.T) {
+ // Create a login source with name "GitHub"
+ source, err := db.Create(CreateLoginSourceOpts{
+ Type: LoginGitHub,
+ Name: "GitHub",
+ Activated: true,
+ Default: false,
+ Config: &GitHubConfig{
+ APIEndpoint: "https://api.github.com",
+ },
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ source.IsActived = false
+ source.Config = &GitHubConfig{
+ APIEndpoint: "https://api2.github.com",
+ }
+ err = db.Save(source)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ source, err = db.GetByID(source.ID)
+ if err != nil {
+ t.Fatal(err)
+ }
+ assert.False(t, source.IsActived)
+ assert.Equal(t, "https://api2.github.com", source.GitHub().APIEndpoint)
+ })
+
+ t.Run("save to file", func(t *testing.T) {
+ calledSave := false
+ source := &LoginSource{
+ File: &mockLoginSourceFileStore{
+ MockSetGeneral: func(name, value string) {},
+ MockSetConfig: func(cfg interface{}) error { return nil },
+ MockSave: func() error {
+ calledSave = true
+ return nil
+ },
+ },
+ }
+ err := db.Save(source)
+ if err != nil {
+ t.Fatal(err)
+ }
+ assert.True(t, calledSave)
+ })
+}
diff --git a/internal/db/main_test.go b/internal/db/main_test.go
index fdd7447f..172f00fe 100644
--- a/internal/db/main_test.go
+++ b/internal/db/main_test.go
@@ -41,7 +41,8 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}
-func deleteTables(db *gorm.DB, tables ...interface{}) error {
+// clearTables removes all rows from given tables.
+func clearTables(db *gorm.DB, tables ...interface{}) error {
for _, t := range tables {
err := db.Delete(t).Error
if err != nil {
diff --git a/internal/db/mocks.go b/internal/db/mocks.go
index 7e6ddced..9b1e9028 100644
--- a/internal/db/mocks.go
+++ b/internal/db/mocks.go
@@ -78,6 +78,59 @@ func SetMockLFSStore(t *testing.T, mock LFSStore) {
})
}
+var _ loginSourceFilesStore = (*mockLoginSourceFilesStore)(nil)
+
+type mockLoginSourceFilesStore struct {
+ MockGetByID func(id int64) (*LoginSource, error)
+ MockLen func() int
+ MockList func(opts ListLoginSourceOpts) []*LoginSource
+ MockUpdate func(source *LoginSource)
+}
+
+func (m *mockLoginSourceFilesStore) GetByID(id int64) (*LoginSource, error) {
+ return m.MockGetByID(id)
+}
+
+func (m *mockLoginSourceFilesStore) Len() int {
+ return m.MockLen()
+}
+
+func (m *mockLoginSourceFilesStore) List(opts ListLoginSourceOpts) []*LoginSource {
+ return m.MockList(opts)
+}
+
+func (m *mockLoginSourceFilesStore) Update(source *LoginSource) {
+ m.MockUpdate(source)
+}
+
+func setMockLoginSourceFilesStore(t *testing.T, db *loginSources, mock loginSourceFilesStore) {
+ before := db.files
+ db.files = mock
+ t.Cleanup(func() {
+ db.files = before
+ })
+}
+
+var _ loginSourceFileStore = (*mockLoginSourceFileStore)(nil)
+
+type mockLoginSourceFileStore struct {
+ MockSetGeneral func(name, value string)
+ MockSetConfig func(cfg interface{}) error
+ MockSave func() error
+}
+
+func (m *mockLoginSourceFileStore) SetGeneral(name, value string) {
+ m.MockSetGeneral(name, value)
+}
+
+func (m *mockLoginSourceFileStore) SetConfig(cfg interface{}) error {
+ return m.MockSetConfig(cfg)
+}
+
+func (m *mockLoginSourceFileStore) Save() error {
+ return m.MockSave()
+}
+
var _ PermsStore = (*MockPermsStore)(nil)
type MockPermsStore struct {
diff --git a/internal/db/models.go b/internal/db/models.go
index 9b5b0d9a..37783a13 100644
--- a/internal/db/models.go
+++ b/internal/db/models.go
@@ -53,7 +53,7 @@ func init() {
new(Watch), new(Star), new(Follow), new(Action),
new(Issue), new(PullRequest), new(Comment), new(Attachment), new(IssueUser),
new(Label), new(IssueLabel), new(Milestone),
- new(Mirror), new(Release), new(LoginSource), new(Webhook), new(HookTask),
+ new(Mirror), new(Release), new(Webhook), new(HookTask),
new(ProtectBranch), new(ProtectBranchWhitelist),
new(Team), new(OrgUser), new(TeamUser), new(TeamRepo),
new(Notice), new(EmailAddress))
@@ -200,7 +200,7 @@ func GetStatistic() (stats Statistic) {
stats.Counter.Follow, _ = x.Count(new(Follow))
stats.Counter.Mirror, _ = x.Count(new(Mirror))
stats.Counter.Release, _ = x.Count(new(Release))
- stats.Counter.LoginSource = CountLoginSources()
+ stats.Counter.LoginSource = LoginSources.Count()
stats.Counter.Webhook, _ = x.Count(new(Webhook))
stats.Counter.Milestone, _ = x.Count(new(Milestone))
stats.Counter.Label, _ = x.Count(new(Label))
@@ -295,13 +295,13 @@ func ImportDatabase(dirPath string, verbose bool) (err error) {
tp := LoginType(com.StrTo(com.ToStr(meta["Type"])).MustInt64())
switch tp {
case LoginLDAP, LoginDLDAP:
- bean.Cfg = new(LDAPConfig)
+ bean.Config = new(LDAPConfig)
case LoginSMTP:
- bean.Cfg = new(SMTPConfig)
+ bean.Config = new(SMTPConfig)
case LoginPAM:
- bean.Cfg = new(PAMConfig)
+ bean.Config = new(PAMConfig)
case LoginGitHub:
- bean.Cfg = new(GitHubConfig)
+ bean.Config = new(GitHubConfig)
default:
return fmt.Errorf("unrecognized login source type:: %v", tp)
}
diff --git a/internal/db/perms_test.go b/internal/db/perms_test.go
index 9f3f4b1b..ea4ed1ea 100644
--- a/internal/db/perms_test.go
+++ b/internal/db/perms_test.go
@@ -17,8 +17,9 @@ func Test_perms(t *testing.T) {
t.Parallel()
+ tables := []interface{}{new(Access)}
db := &perms{
- DB: initTestDB(t, "perms", new(Access)),
+ DB: initTestDB(t, "perms", tables...),
}
for _, tc := range []struct {
@@ -31,7 +32,7 @@ func Test_perms(t *testing.T) {
} {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() {
- err := deleteTables(db.DB, new(Access))
+ err := clearTables(db.DB, tables...)
if err != nil {
t.Fatal(err)
}
diff --git a/internal/db/user.go b/internal/db/user.go
index a61bb687..11697bc7 100644
--- a/internal/db/user.go
+++ b/internal/db/user.go
@@ -48,33 +48,33 @@ const (
// User represents the object of individual and member of organization.
type User struct {
ID int64
- LowerName string `xorm:"UNIQUE NOT NULL"`
- Name string `xorm:"UNIQUE NOT NULL"`
+ LowerName string `xorm:"UNIQUE NOT NULL" gorm:"UNIQUE"`
+ Name string `xorm:"UNIQUE NOT NULL" gorm:"NOT NULL"`
FullName string
// Email is the primary email address (to be used for communication)
- Email string `xorm:"NOT NULL"`
- Passwd string `xorm:"NOT NULL"`
+ Email string `xorm:"NOT NULL" gorm:"NOT NULL"`
+ Passwd string `xorm:"NOT NULL" gorm:"NOT NULL"`
LoginType LoginType
- LoginSource int64 `xorm:"NOT NULL DEFAULT 0"`
+ LoginSource int64 `xorm:"NOT NULL DEFAULT 0" gorm:"NOT NULL;DEFAULT:0"`
LoginName string
Type UserType
- OwnedOrgs []*User `xorm:"-" json:"-"`
- Orgs []*User `xorm:"-" json:"-"`
- Repos []*Repository `xorm:"-" json:"-"`
+ OwnedOrgs []*User `xorm:"-" gorm:"-" json:"-"`
+ Orgs []*User `xorm:"-" gorm:"-" json:"-"`
+ Repos []*Repository `xorm:"-" gorm:"-" json:"-"`
Location string
Website string
- Rands string `xorm:"VARCHAR(10)"`
- Salt string `xorm:"VARCHAR(10)"`
+ Rands string `xorm:"VARCHAR(10)" gorm:"TYPE:VARCHAR(10)"`
+ Salt string `xorm:"VARCHAR(10)" gorm:"TYPE:VARCHAR(10)"`
- Created time.Time `xorm:"-" json:"-"`
+ Created time.Time `xorm:"-" gorm:"-" json:"-"`
CreatedUnix int64
- Updated time.Time `xorm:"-" json:"-"`
+ Updated time.Time `xorm:"-" gorm:"-" json:"-"`
UpdatedUnix int64
// Remember visibility choice for convenience, true for private
LastRepoVisibility bool
- // Maximum repository creation limit, -1 means use gloabl default
- MaxRepoCreation int `xorm:"NOT NULL DEFAULT -1"`
+ // Maximum repository creation limit, -1 means use global default
+ MaxRepoCreation int `xorm:"NOT NULL DEFAULT -1" gorm:"NOT NULL;DEFAULT:-1"`
// Permissions
IsActive bool // Activate primary email
@@ -84,13 +84,13 @@ type User struct {
ProhibitLogin bool
// Avatar
- Avatar string `xorm:"VARCHAR(2048) NOT NULL"`
- AvatarEmail string `xorm:"NOT NULL"`
+ Avatar string `xorm:"VARCHAR(2048) NOT NULL" gorm:"TYPE:VARCHAR(2048);NOT NULL"`
+ AvatarEmail string `xorm:"NOT NULL" gorm:"NOT NULL"`
UseCustomAvatar bool
// Counters
NumFollowers int
- NumFollowing int `xorm:"NOT NULL DEFAULT 0"`
+ NumFollowing int `xorm:"NOT NULL DEFAULT 0" gorm:"NOT NULL;DEFAULT:0"`
NumStars int
NumRepos int
@@ -98,8 +98,8 @@ type User struct {
Description string
NumTeams int
NumMembers int
- Teams []*Team `xorm:"-" json:"-"`
- Members []*User `xorm:"-" json:"-"`
+ Teams []*Team `xorm:"-" gorm:"-" json:"-"`
+ Members []*User `xorm:"-" gorm:"-" json:"-"`
}
func (u *User) BeforeInsert() {