aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/auth/admin.go4
-rw-r--r--modules/auth/auth.go24
-rw-r--r--modules/auth/issue.go4
-rw-r--r--modules/auth/release.go50
-rw-r--r--modules/auth/repo.go46
-rw-r--r--modules/auth/setting.go4
-rw-r--r--modules/auth/user.go5
-rw-r--r--modules/avatar/avatar.go5
-rw-r--r--modules/base/base.go48
-rw-r--r--modules/base/base_memcache.go11
-rw-r--r--modules/base/base_redis.go11
-rw-r--r--modules/base/conf.go112
-rw-r--r--modules/base/markdown.go57
-rw-r--r--modules/base/template.go136
-rw-r--r--modules/base/tool.go149
-rw-r--r--modules/cron/cron.go17
-rw-r--r--modules/log/log.go3
-rw-r--r--modules/mailer/mail.go53
-rw-r--r--modules/middleware/auth.go2
-rw-r--r--modules/middleware/binding.go426
-rw-r--r--modules/middleware/binding_test.go701
-rw-r--r--modules/middleware/context.go86
-rw-r--r--modules/middleware/render.go2
-rw-r--r--modules/middleware/repo.go96
-rw-r--r--modules/social/social.go396
25 files changed, 2225 insertions, 223 deletions
diff --git a/modules/auth/admin.go b/modules/auth/admin.go
index fe889c23..877af19a 100644
--- a/modules/auth/admin.go
+++ b/modules/auth/admin.go
@@ -10,8 +10,6 @@ import (
"github.com/go-martini/martini"
- "github.com/gogits/binding"
-
"github.com/gogits/gogs/modules/base"
"github.com/gogits/gogs/modules/log"
)
@@ -35,7 +33,7 @@ func (f *AdminEditUserForm) Name(field string) string {
return names[field]
}
-func (f *AdminEditUserForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) {
+func (f *AdminEditUserForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) {
if req.Method == "GET" || errors.Count() == 0 {
return
}
diff --git a/modules/auth/auth.go b/modules/auth/auth.go
index 4561dd83..350ef4fc 100644
--- a/modules/auth/auth.go
+++ b/modules/auth/auth.go
@@ -11,8 +11,6 @@ import (
"github.com/go-martini/martini"
- "github.com/gogits/binding"
-
"github.com/gogits/gogs/modules/base"
"github.com/gogits/gogs/modules/log"
)
@@ -39,7 +37,7 @@ func (f *RegisterForm) Name(field string) string {
return names[field]
}
-func (f *RegisterForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) {
+func (f *RegisterForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) {
if req.Method == "GET" || errors.Count() == 0 {
return
}
@@ -72,7 +70,7 @@ func (f *LogInForm) Name(field string) string {
return names[field]
}
-func (f *LogInForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) {
+func (f *LogInForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) {
if req.Method == "GET" || errors.Count() == 0 {
return
}
@@ -100,7 +98,7 @@ func getMinMaxSize(field reflect.StructField) string {
return ""
}
-func validate(errors *binding.Errors, data base.TmplData, form Form) {
+func validate(errors *base.BindingErrors, data base.TmplData, form Form) {
typ := reflect.TypeOf(form)
val := reflect.ValueOf(form)
@@ -121,16 +119,18 @@ func validate(errors *binding.Errors, data base.TmplData, form Form) {
if err, ok := errors.Fields[field.Name]; ok {
data["Err_"+field.Name] = true
switch err {
- case binding.RequireError:
+ case base.BindingRequireError:
data["ErrorMsg"] = form.Name(field.Name) + " cannot be empty"
- case binding.AlphaDashError:
+ case base.BindingAlphaDashError:
data["ErrorMsg"] = form.Name(field.Name) + " must be valid alpha or numeric or dash(-_) characters"
- case binding.MinSizeError:
+ case base.BindingMinSizeError:
data["ErrorMsg"] = form.Name(field.Name) + " must contain at least " + getMinMaxSize(field) + " characters"
- case binding.MaxSizeError:
+ case base.BindingMaxSizeError:
data["ErrorMsg"] = form.Name(field.Name) + " must contain at most " + getMinMaxSize(field) + " characters"
- case binding.EmailError:
- data["ErrorMsg"] = form.Name(field.Name) + " is not valid"
+ case base.BindingEmailError:
+ data["ErrorMsg"] = form.Name(field.Name) + " is not a valid e-mail address"
+ case base.BindingUrlError:
+ data["ErrorMsg"] = form.Name(field.Name) + " is not a valid URL"
default:
data["ErrorMsg"] = "Unknown error: " + err
}
@@ -194,7 +194,7 @@ func (f *InstallForm) Name(field string) string {
return names[field]
}
-func (f *InstallForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) {
+func (f *InstallForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) {
if req.Method == "GET" || errors.Count() == 0 {
return
}
diff --git a/modules/auth/issue.go b/modules/auth/issue.go
index 36c87627..f73ddc74 100644
--- a/modules/auth/issue.go
+++ b/modules/auth/issue.go
@@ -10,8 +10,6 @@ import (
"github.com/go-martini/martini"
- "github.com/gogits/binding"
-
"github.com/gogits/gogs/modules/base"
"github.com/gogits/gogs/modules/log"
)
@@ -31,7 +29,7 @@ func (f *CreateIssueForm) Name(field string) string {
return names[field]
}
-func (f *CreateIssueForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) {
+func (f *CreateIssueForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) {
if req.Method == "GET" || errors.Count() == 0 {
return
}
diff --git a/modules/auth/release.go b/modules/auth/release.go
new file mode 100644
index 00000000..a29028e0
--- /dev/null
+++ b/modules/auth/release.go
@@ -0,0 +1,50 @@
+// Copyright 2014 The Gogs Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package auth
+
+import (
+ "net/http"
+ "reflect"
+
+ "github.com/go-martini/martini"
+
+ "github.com/gogits/gogs/modules/base"
+ "github.com/gogits/gogs/modules/log"
+)
+
+type NewReleaseForm struct {
+ TagName string `form:"tag_name" binding:"Required"`
+ Title string `form:"title" binding:"Required"`
+ Content string `form:"content" binding:"Required"`
+ Prerelease bool `form:"prerelease"`
+}
+
+func (f *NewReleaseForm) Name(field string) string {
+ names := map[string]string{
+ "TagName": "Tag name",
+ "Title": "Release title",
+ "Content": "Release content",
+ }
+ return names[field]
+}
+
+func (f *NewReleaseForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) {
+ if req.Method == "GET" || errors.Count() == 0 {
+ return
+ }
+
+ data := context.Get(reflect.TypeOf(base.TmplData{})).Interface().(base.TmplData)
+ data["HasError"] = true
+ AssignForm(f, data)
+
+ if len(errors.Overall) > 0 {
+ for _, err := range errors.Overall {
+ log.Error("NewReleaseForm.Validate: %v", err)
+ }
+ return
+ }
+
+ validate(errors, data, f)
+}
diff --git a/modules/auth/repo.go b/modules/auth/repo.go
index eddd6475..f67fbf67 100644
--- a/modules/auth/repo.go
+++ b/modules/auth/repo.go
@@ -10,19 +10,17 @@ import (
"github.com/go-martini/martini"
- "github.com/gogits/binding"
-
"github.com/gogits/gogs/modules/base"
"github.com/gogits/gogs/modules/log"
)
type CreateRepoForm struct {
RepoName string `form:"repo" binding:"Required;AlphaDash"`
- Visibility string `form:"visibility"`
+ Private bool `form:"private"`
Description string `form:"desc" binding:"MaxSize(100)"`
Language string `form:"language"`
License string `form:"license"`
- InitReadme string `form:"initReadme"`
+ InitReadme bool `form:"initReadme"`
}
func (f *CreateRepoForm) Name(field string) string {
@@ -33,7 +31,7 @@ func (f *CreateRepoForm) Name(field string) string {
return names[field]
}
-func (f *CreateRepoForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) {
+func (f *CreateRepoForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) {
if req.Method == "GET" || errors.Count() == 0 {
return
}
@@ -51,3 +49,41 @@ func (f *CreateRepoForm) Validate(errors *binding.Errors, req *http.Request, con
validate(errors, data, f)
}
+
+type MigrateRepoForm struct {
+ Url string `form:"url" binding:"Url"`
+ AuthUserName string `form:"auth_username"`
+ AuthPasswd string `form:"auth_password"`
+ RepoName string `form:"repo" binding:"Required;AlphaDash"`
+ Mirror bool `form:"mirror"`
+ Private bool `form:"private"`
+ Description string `form:"desc" binding:"MaxSize(100)"`
+}
+
+func (f *MigrateRepoForm) Name(field string) string {
+ names := map[string]string{
+ "Url": "Migration URL",
+ "RepoName": "Repository name",
+ "Description": "Description",
+ }
+ return names[field]
+}
+
+func (f *MigrateRepoForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) {
+ if req.Method == "GET" || errors.Count() == 0 {
+ return
+ }
+
+ data := context.Get(reflect.TypeOf(base.TmplData{})).Interface().(base.TmplData)
+ data["HasError"] = true
+ AssignForm(f, data)
+
+ if len(errors.Overall) > 0 {
+ for _, err := range errors.Overall {
+ log.Error("MigrateRepoForm.Validate: %v", err)
+ }
+ return
+ }
+
+ validate(errors, data, f)
+}
diff --git a/modules/auth/setting.go b/modules/auth/setting.go
index cada7eea..7cee00de 100644
--- a/modules/auth/setting.go
+++ b/modules/auth/setting.go
@@ -11,8 +11,6 @@ import (
"github.com/go-martini/martini"
- "github.com/gogits/binding"
-
"github.com/gogits/gogs/modules/base"
"github.com/gogits/gogs/modules/log"
)
@@ -30,7 +28,7 @@ func (f *AddSSHKeyForm) Name(field string) string {
return names[field]
}
-func (f *AddSSHKeyForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) {
+func (f *AddSSHKeyForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) {
data := context.Get(reflect.TypeOf(base.TmplData{})).Interface().(base.TmplData)
AssignForm(f, data)
diff --git a/modules/auth/user.go b/modules/auth/user.go
index 015059f7..97389422 100644
--- a/modules/auth/user.go
+++ b/modules/auth/user.go
@@ -10,7 +10,6 @@ import (
"github.com/go-martini/martini"
- "github.com/gogits/binding"
"github.com/gogits/session"
"github.com/gogits/gogs/models"
@@ -93,7 +92,7 @@ func (f *UpdateProfileForm) Name(field string) string {
return names[field]
}
-func (f *UpdateProfileForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) {
+func (f *UpdateProfileForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) {
if req.Method == "GET" || errors.Count() == 0 {
return
}
@@ -126,7 +125,7 @@ func (f *UpdatePasswdForm) Name(field string) string {
return names[field]
}
-func (f *UpdatePasswdForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) {
+func (f *UpdatePasswdForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) {
if req.Method == "GET" || errors.Count() == 0 {
return
}
diff --git a/modules/avatar/avatar.go b/modules/avatar/avatar.go
index edeb256f..5ed5d16a 100644
--- a/modules/avatar/avatar.go
+++ b/modules/avatar/avatar.go
@@ -157,9 +157,9 @@ func (this *service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
avatar := New(hash, this.cacheDir)
avatar.AlterImage = this.altImage
if avatar.Expired() {
- err := avatar.UpdateTimeout(time.Millisecond * 500)
- if err != nil {
+ if err := avatar.UpdateTimeout(time.Millisecond * 1000); err != nil {
log.Trace("avatar update error: %v", err)
+ return
}
}
if modtime, err := avatar.Modtime(); err == nil {
@@ -250,6 +250,7 @@ func (this *thunderTask) Fetch() {
var client = &http.Client{}
func (this *thunderTask) fetch() error {
+ log.Debug("avatar.fetch(fetch new avatar): %s", this.Url)
req, _ := http.NewRequest("GET", this.Url, nil)
req.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/jpeg,image/png,*/*;q=0.8")
req.Header.Set("Accept-Encoding", "deflate,sdch")
diff --git a/modules/base/base.go b/modules/base/base.go
index 7c08dcc5..3e80a436 100644
--- a/modules/base/base.go
+++ b/modules/base/base.go
@@ -8,3 +8,51 @@ type (
// Type TmplData represents data in the templates.
TmplData map[string]interface{}
)
+
+// __________.__ .___.__
+// \______ \__| ____ __| _/|__| ____ ____
+// | | _/ |/ \ / __ | | |/ \ / ___\
+// | | \ | | \/ /_/ | | | | \/ /_/ >
+// |______ /__|___| /\____ | |__|___| /\___ /
+// \/ \/ \/ \//_____/
+
+// Errors represents the contract of the response body when the
+// binding step fails before getting to the application.
+type BindingErrors struct {
+ Overall map[string]string `json:"overall"`
+ Fields map[string]string `json:"fields"`
+}
+
+// Total errors is the sum of errors with the request overall
+// and errors on individual fields.
+func (err BindingErrors) Count() int {
+ return len(err.Overall) + len(err.Fields)
+}
+
+func (this *BindingErrors) Combine(other BindingErrors) {
+ for key, val := range other.Fields {
+ if _, exists := this.Fields[key]; !exists {
+ this.Fields[key] = val
+ }
+ }
+ for key, val := range other.Overall {
+ if _, exists := this.Overall[key]; !exists {
+ this.Overall[key] = val
+ }
+ }
+}
+
+const (
+ BindingRequireError string = "Required"
+ BindingAlphaDashError string = "AlphaDash"
+ BindingMinSizeError string = "MinSize"
+ BindingMaxSizeError string = "MaxSize"
+ BindingEmailError string = "Email"
+ BindingUrlError string = "Url"
+ BindingDeserializationError string = "DeserializationError"
+ BindingIntegerTypeError string = "IntegerTypeError"
+ BindingBooleanTypeError string = "BooleanTypeError"
+ BindingFloatTypeError string = "FloatTypeError"
+)
+
+var GoGetMetas = make(map[string]bool)
diff --git a/modules/base/base_memcache.go b/modules/base/base_memcache.go
new file mode 100644
index 00000000..93ca93a4
--- /dev/null
+++ b/modules/base/base_memcache.go
@@ -0,0 +1,11 @@
+// +build memcache
+
+package base
+
+import (
+ _ "github.com/gogits/cache/memcache"
+)
+
+func init() {
+ EnableMemcache = true
+}
diff --git a/modules/base/base_redis.go b/modules/base/base_redis.go
new file mode 100644
index 00000000..327af841
--- /dev/null
+++ b/modules/base/base_redis.go
@@ -0,0 +1,11 @@
+// +build redis
+
+package base
+
+import (
+ _ "github.com/gogits/cache/redis"
+)
+
+func init() {
+ EnableRedis = true
+}
diff --git a/modules/base/conf.go b/modules/base/conf.go
index 3ebc4ede..abb67f6d 100644
--- a/modules/base/conf.go
+++ b/modules/base/conf.go
@@ -14,6 +14,7 @@ import (
"github.com/Unknwon/com"
"github.com/Unknwon/goconfig"
+ qlog "github.com/qiniu/log"
"github.com/gogits/cache"
"github.com/gogits/session"
@@ -21,22 +22,38 @@ import (
"github.com/gogits/gogs/modules/log"
)
-// Mailer represents a mail service.
+// Mailer represents mail service.
type Mailer struct {
Name string
Host string
User, Passwd string
}
+type OauthInfo struct {
+ ClientId, ClientSecret string
+ Scopes string
+ AuthUrl, TokenUrl string
+}
+
+// Oauther represents oauth service.
+type Oauther struct {
+ GitHub, Google, Tencent,
+ Twitter, Weibo bool
+ OauthInfos map[string]*OauthInfo
+}
+
var (
- AppVer string
- AppName string
- AppLogo string
- AppUrl string
- Domain string
- SecretKey string
- RunUser string
+ AppVer string
+ AppName string
+ AppLogo string
+ AppUrl string
+ IsProdMode bool
+ Domain string
+ SecretKey string
+ RunUser string
+
RepoRootPath string
+ ScriptType string
InstallLock bool
@@ -44,8 +61,9 @@ var (
CookieUserName string
CookieRememberName string
- Cfg *goconfig.ConfigFile
- MailService *Mailer
+ Cfg *goconfig.ConfigFile
+ MailService *Mailer
+ OauthService *Oauther
LogMode string
LogConfig string
@@ -59,11 +77,14 @@ var (
SessionManager *session.Manager
PictureService string
+
+ EnableRedis bool
+ EnableMemcache bool
)
var Service struct {
RegisterEmailConfirm bool
- DisenableRegisteration bool
+ DisableRegistration bool
RequireSignInView bool
EnableCacheAvatar bool
NotifyMail bool
@@ -95,7 +116,7 @@ var logLevels = map[string]string{
func newService() {
Service.ActiveCodeLives = Cfg.MustInt("service", "ACTIVE_CODE_LIVE_MINUTES", 180)
Service.ResetPwdCodeLives = Cfg.MustInt("service", "RESET_PASSWD_CODE_LIVE_MINUTES", 180)
- Service.DisenableRegisteration = Cfg.MustBool("service", "DISENABLE_REGISTERATION", false)
+ Service.DisableRegistration = Cfg.MustBool("service", "DISABLE_REGISTRATION", false)
Service.RequireSignInView = Cfg.MustBool("service", "REQUIRE_SIGNIN_VIEW", false)
Service.EnableCacheAvatar = Cfg.MustBool("service", "ENABLE_CACHE_AVATAR", false)
}
@@ -105,16 +126,14 @@ func newLogService() {
LogMode = Cfg.MustValue("log", "MODE", "console")
modeSec := "log." + LogMode
if _, err := Cfg.GetSection(modeSec); err != nil {
- fmt.Printf("Unknown log mode: %s\n", LogMode)
- os.Exit(2)
+ qlog.Fatalf("Unknown log mode: %s\n", LogMode)
}
// Log level.
levelName := Cfg.MustValue("log."+LogMode, "LEVEL", "Trace")
level, ok := logLevels[levelName]
if !ok {
- fmt.Printf("Unknown log level: %s\n", levelName)
- os.Exit(2)
+ qlog.Fatalf("Unknown log level: %s\n", levelName)
}
// Generate log configuration.
@@ -151,12 +170,19 @@ func newLogService() {
Cfg.MustValue(modeSec, "CONN"))
}
+ log.Info("%s %s", AppName, AppVer)
log.NewLogger(Cfg.MustInt64("log", "BUFFER_LEN", 10000), LogMode, LogConfig)
log.Info("Log Mode: %s(%s)", strings.Title(LogMode), levelName)
}
func newCacheService() {
CacheAdapter = Cfg.MustValue("cache", "ADAPTER", "memory")
+ if EnableRedis {
+ log.Info("Redis Enabled")
+ }
+ if EnableMemcache {
+ log.Info("Memcache Enabled")
+ }
switch CacheAdapter {
case "memory":
@@ -164,16 +190,14 @@ func newCacheService() {
case "redis", "memcache":
CacheConfig = fmt.Sprintf(`{"conn":"%s"}`, Cfg.MustValue("cache", "HOST"))
default:
- fmt.Printf("Unknown cache adapter: %s\n", CacheAdapter)
- os.Exit(2)
+ qlog.Fatalf("Unknown cache adapter: %s\n", CacheAdapter)
}
var err error
Cache, err = cache.NewCache(CacheAdapter, CacheConfig)
if err != nil {
- fmt.Printf("Init cache system failed, adapter: %s, config: %s, %v\n",
+ qlog.Fatalf("Init cache system failed, adapter: %s, config: %s, %v\n",
CacheAdapter, CacheConfig, err)
- os.Exit(2)
}
log.Info("Cache Service Enabled")
@@ -199,9 +223,8 @@ func newSessionService() {
var err error
SessionManager, err = session.NewManager(SessionProvider, *SessionConfig)
if err != nil {
- fmt.Printf("Init session system failed, provider: %s, %v\n",
+ qlog.Fatalf("Init session system failed, provider: %s, %v\n",
SessionProvider, err)
- os.Exit(2)
}
log.Info("Session Service Enabled")
@@ -209,15 +232,17 @@ func newSessionService() {
func newMailService() {
// Check mailer setting.
- if Cfg.MustBool("mailer", "ENABLED") {
- MailService = &Mailer{
- Name: Cfg.MustValue("mailer", "NAME", AppName),
- Host: Cfg.MustValue("mailer", "HOST"),
- User: Cfg.MustValue("mailer", "USER"),
- Passwd: Cfg.MustValue("mailer", "PASSWD"),
- }
- log.Info("Mail Service Enabled")
+ if !Cfg.MustBool("mailer", "ENABLED") {
+ return
+ }
+
+ MailService = &Mailer{
+ Name: Cfg.MustValue("mailer", "NAME", AppName),
+ Host: Cfg.MustValue("mailer", "HOST"),
+ User: Cfg.MustValue("mailer", "USER"),
+ Passwd: Cfg.MustValue("mailer", "PASSWD"),
}
+ log.Info("Mail Service Enabled")
}
func newRegisterMailService() {
@@ -246,23 +271,20 @@ func NewConfigContext() {
//var err error
workDir, err := ExecDir()
if err != nil {
- fmt.Printf("Fail to get work directory: %s\n", err)
- os.Exit(2)
+ qlog.Fatalf("Fail to get work directory: %s\n", err)
}
cfgPath := filepath.Join(workDir, "conf/app.ini")
Cfg, err = goconfig.LoadConfigFile(cfgPath)
if err != nil {
- fmt.Printf("Cannot load config file(%s): %v\n", cfgPath, err)
- os.Exit(2)
+ qlog.Fatalf("Cannot load config file(%s): %v\n", cfgPath, err)
}
Cfg.BlockMode = false
cfgPath = filepath.Join(workDir, "custom/conf/app.ini")
if com.IsFile(cfgPath) {
if err = Cfg.AppendFiles(cfgPath); err != nil {
- fmt.Printf("Cannot load config file(%s): %v\n", cfgPath, err)
- os.Exit(2)
+ qlog.Fatalf("Cannot load config file(%s): %v\n", cfgPath, err)
}
}
@@ -275,14 +297,13 @@ func NewConfigContext() {
InstallLock = Cfg.MustBool("security", "INSTALL_LOCK", false)
RunUser = Cfg.MustValue("", "RUN_USER")
- curUser := os.Getenv("USERNAME")
+ curUser := os.Getenv("USER")
if len(curUser) == 0 {
- curUser = os.Getenv("USER")
+ curUser = os.Getenv("USERNAME")
}
// Does not check run user when the install lock is off.
if InstallLock && RunUser != curUser {
- fmt.Printf("Expect user(%s) but current user is: %s\n", RunUser, curUser)
- os.Exit(2)
+ qlog.Fatalf("Expect user(%s) but current user is: %s\n", RunUser, curUser)
}
LogInRememberDays = Cfg.MustInt("security", "LOGIN_REMEMBER_DAYS")
@@ -294,17 +315,16 @@ func NewConfigContext() {
// Determine and create root git reposiroty path.
homeDir, err := com.HomeDir()
if err != nil {
- fmt.Printf("Fail to get home directory): %v\n", err)
- os.Exit(2)
+ qlog.Fatalf("Fail to get home directory): %v\n", err)
}
- RepoRootPath = Cfg.MustValue("repository", "ROOT", filepath.Join(homeDir, "git/gogs-repositories"))
+ RepoRootPath = Cfg.MustValue("repository", "ROOT", filepath.Join(homeDir, "gogs-repositories"))
if err = os.MkdirAll(RepoRootPath, os.ModePerm); err != nil {
- fmt.Printf("Fail to create RepoRootPath(%s): %v\n", RepoRootPath, err)
- os.Exit(2)
+ qlog.Fatalf("Fail to create RepoRootPath(%s): %v\n", RepoRootPath, err)
}
+ ScriptType = Cfg.MustValue("repository", "SCRIPT_TYPE", "bash")
}
-func NewServices() {
+func NewBaseServices() {
newService()
newLogService()
newCacheService()
diff --git a/modules/base/markdown.go b/modules/base/markdown.go
index 962e1ae1..95b4b212 100644
--- a/modules/base/markdown.go
+++ b/modules/base/markdown.go
@@ -6,9 +6,11 @@ package base
import (
"bytes"
+ "fmt"
"net/http"
"path"
"path/filepath"
+ "regexp"
"strings"
"github.com/gogits/gfm"
@@ -87,13 +89,58 @@ func (options *CustomRender) Link(out *bytes.Buffer, link []byte, title []byte,
options.Renderer.Link(out, link, title, content)
}
+var (
+ MentionPattern = regexp.MustCompile(`@[0-9a-zA-Z_]{1,}`)
+ commitPattern = regexp.MustCompile(`(\s|^)https?.*commit/[0-9a-zA-Z]+(#+[0-9a-zA-Z-]*)?`)
+ issueFullPattern = regexp.MustCompile(`(\s|^)https?.*issues/[0-9]+(#+[0-9a-zA-Z-]*)?`)
+ issueIndexPattern = regexp.MustCompile(`#[0-9]+`)
+)
+
+func RenderSpecialLink(rawBytes []byte, urlPrefix string) []byte {
+ ms := MentionPattern.FindAll(rawBytes, -1)
+ for _, m := range ms {
+ rawBytes = bytes.Replace(rawBytes, m,
+ []byte(fmt.Sprintf(`<a href="/user/%s">%s</a>`, m[1:], m)), -1)
+ }
+ ms = commitPattern.FindAll(rawBytes, -1)
+ for _, m := range ms {
+ m = bytes.TrimSpace(m)
+ i := strings.Index(string(m), "commit/")
+ j := strings.Index(string(m), "#")
+ if j == -1 {
+ j = len(m)
+ }
+ rawBytes = bytes.Replace(rawBytes, m, []byte(fmt.Sprintf(
+ ` <code><a href="%s">%s</a></code>`, m, ShortSha(string(m[i+7:j])))), -1)
+ }
+ ms = issueFullPattern.FindAll(rawBytes, -1)
+ for _, m := range ms {
+ m = bytes.TrimSpace(m)
+ i := strings.Index(string(m), "issues/")
+ j := strings.Index(string(m), "#")
+ if j == -1 {
+ j = len(m)
+ }
+ rawBytes = bytes.Replace(rawBytes, m, []byte(fmt.Sprintf(
+ ` <a href="%s">#%s</a>`, m, ShortSha(string(m[i+7:j])))), -1)
+ }
+ ms = issueIndexPattern.FindAll(rawBytes, -1)
+ for _, m := range ms {
+ rawBytes = bytes.Replace(rawBytes, m, []byte(fmt.Sprintf(
+ `<a href="%s/issues/%s">%s</a>`, urlPrefix, m[1:], m)), -1)
+ }
+ return rawBytes
+}
+
func RenderMarkdown(rawBytes []byte, urlPrefix string) []byte {
+ body := RenderSpecialLink(rawBytes, urlPrefix)
+ // fmt.Println(string(body))
htmlFlags := 0
// htmlFlags |= gfm.HTML_USE_XHTML
// htmlFlags |= gfm.HTML_USE_SMARTYPANTS
// htmlFlags |= gfm.HTML_SMARTYPANTS_FRACTIONS
// htmlFlags |= gfm.HTML_SMARTYPANTS_LATEX_DASHES
- htmlFlags |= gfm.HTML_SKIP_HTML
+ // htmlFlags |= gfm.HTML_SKIP_HTML
htmlFlags |= gfm.HTML_SKIP_STYLE
htmlFlags |= gfm.HTML_SKIP_SCRIPT
htmlFlags |= gfm.HTML_GITHUB_BLOCKCODE
@@ -115,7 +162,11 @@ func RenderMarkdown(rawBytes []byte, urlPrefix string) []byte {
extensions |= gfm.EXTENSION_SPACE_HEADERS
extensions |= gfm.EXTENSION_NO_EMPTY_LINE_BEFORE_BLOCK
- body := gfm.Markdown(rawBytes, renderer, extensions)
-
+ body = gfm.Markdown(body, renderer, extensions)
+ // fmt.Println(string(body))
return body
}
+
+func RenderMarkdownString(raw, urlPrefix string) string {
+ return string(RenderMarkdown([]byte(raw), urlPrefix))
+}
diff --git a/modules/base/template.go b/modules/base/template.go
index dfcae931..79aeeb9d 100644
--- a/modules/base/template.go
+++ b/modules/base/template.go
@@ -5,7 +5,9 @@
package base
import (
+ "bytes"
"container/list"
+ "encoding/json"
"fmt"
"html/template"
"strings"
@@ -54,6 +56,9 @@ var TemplateFuncs template.FuncMap = map[string]interface{}{
"AppDomain": func() string {
return Domain
},
+ "IsProdMode": func() bool {
+ return IsProdMode
+ },
"LoadTimes": func(startTime time.Time) string {
return fmt.Sprint(time.Since(startTime).Nanoseconds()/1e6) + "ms"
},
@@ -62,11 +67,18 @@ var TemplateFuncs template.FuncMap = map[string]interface{}{
"TimeSince": TimeSince,
"FileSize": FileSize,
"Subtract": Subtract,
+ "Add": func(a, b int) int {
+ return a + b
+ },
"ActionIcon": ActionIcon,
"ActionDesc": ActionDesc,
"DateFormat": DateFormat,
"List": List,
"Mail2Domain": func(mail string) string {
+ if !strings.Contains(mail, "@") {
+ return "try.gogits.org"
+ }
+
suffix := strings.SplitN(mail, "@", 2)[1]
domain, ok := mailDomains[suffix]
if !ok {
@@ -80,4 +92,128 @@ var TemplateFuncs template.FuncMap = map[string]interface{}{
"DiffTypeToStr": DiffTypeToStr,
"DiffLineTypeToStr": DiffLineTypeToStr,
"ShortSha": ShortSha,
+ "Oauth2Icon": Oauth2Icon,
+}
+
+type Actioner interface {
+ GetOpType() int
+ GetActUserName() string
+ GetActEmail() string
+ GetRepoName() string
+ GetBranch() string
+ GetContent() string
+}
+
+// ActionIcon accepts a int that represents action operation type
+// and returns a icon class name.
+func ActionIcon(opType int) string {
+ switch opType {
+ case 1: // Create repository.
+ return "plus-circle"
+ case 5, 9: // Commit repository.
+ return "arrow-circle-o-right"
+ case 6: // Create issue.
+ return "exclamation-circle"
+ case 8: // Transfer repository.
+ return "share"
+ default:
+ return "invalid type"
+ }
+}
+
+const (
+ TPL_CREATE_REPO = `<a href="/user/%s">%s</a> created repository <a href="/%s">%s</a>`
+ TPL_COMMIT_REPO = `<a href="/user/%s">%s</a> pushed to <a href="/%s/src/%s">%s</a> at <a href="/%s">%s</a>%s`
+ TPL_COMMIT_REPO_LI = `<div><img src="%s?s=16" alt="user-avatar"/> <a href="/%s/commit/%s">%s</a> %s</div>`
+ TPL_CREATE_ISSUE = `<a href="/user/%s">%s</a> opened issue <a href="/%s/issues/%s">%s#%s</a>
+<div><img src="%s?s=16" alt="user-avatar"/> %s</div>`
+ TPL_TRANSFER_REPO = `<a href="/user/%s">%s</a> transfered repository <code>%s</code> to <a href="/%s">%s</a>`
+ TPL_PUSH_TAG = `<a href="/user/%s">%s</a> pushed tag <a href="/%s/src/%s">%s</a> at <a href="/%s">%s</a>`
+)
+
+type PushCommit struct {
+ Sha1 string
+ Message string
+ AuthorEmail string
+ AuthorName string
+}
+
+type PushCommits struct {
+ Len int
+ Commits []*PushCommit
+}
+
+// ActionDesc accepts int that represents action operation type
+// and returns the description.
+func ActionDesc(act Actioner) string {
+ actUserName := act.GetActUserName()
+ email := act.GetActEmail()
+ repoName := act.GetRepoName()
+ repoLink := actUserName + "/" + repoName
+ branch := act.GetBranch()
+ content := act.GetContent()
+ switch act.GetOpType() {
+ case 1: // Create repository.
+ return fmt.Sprintf(TPL_CREATE_REPO, actUserName, actUserName, repoLink, repoName)
+ case 5: // Commit repository.
+ var push *PushCommits
+ if err := json.Unmarshal([]byte(content), &push); err != nil {
+ return err.Error()
+ }
+ buf := bytes.NewBuffer([]byte("\n"))
+ for _, commit := range push.Commits {
+ buf.WriteString(fmt.Sprintf(TPL_COMMIT_REPO_LI, AvatarLink(commit.AuthorEmail), repoLink, commit.Sha1, commit.Sha1[:7], commit.Message) + "\n")
+ }
+ if push.Len > 3 {
+ buf.WriteString(fmt.Sprintf(`<div><a href="/%s/%s/commits/%s">%d other commits >></a></div>`, actUserName, repoName, branch, push.Len))
+ }
+ return fmt.Sprintf(TPL_COMMIT_REPO, actUserName, actUserName, repoLink, branch, branch, repoLink, repoLink,
+ buf.String())
+ case 6: // Create issue.
+ infos := strings.SplitN(content, "|", 2)
+ return fmt.Sprintf(TPL_CREATE_ISSUE, actUserName, actUserName, repoLink, infos[0], repoLink, infos[0],
+ AvatarLink(email), infos[1])
+ case 8: // Transfer repository.
+ newRepoLink := content + "/" + repoName
+ return fmt.Sprintf(TPL_TRANSFER_REPO, actUserName, actUserName, repoLink, newRepoLink, newRepoLink)
+ case 9: // Push tag.
+ return fmt.Sprintf(TPL_PUSH_TAG, actUserName, actUserName, repoLink, branch, branch, repoLink, repoLink)
+ default:
+ return "invalid type"
+ }
+}
+
+func DiffTypeToStr(diffType int) string {
+ diffTypes := map[int]string{
+ 1: "add", 2: "modify", 3: "del",
+ }
+ return diffTypes[diffType]
+}
+
+func DiffLineTypeToStr(diffType int) string {
+ switch diffType {
+ case 2:
+ return "add"
+ case 3:
+ return "del"
+ case 4:
+ return "tag"
+ }
+ return "same"
+}
+
+func Oauth2Icon(t int) string {
+ switch t {
+ case 1:
+ return "fa-github-square"
+ case 2:
+ return "fa-google-plus-square"
+ case 3:
+ return "fa-twitter-square"
+ case 4:
+ return "fa-linux"
+ case 5:
+ return "fa-weibo"
+ }
+ return ""
}
diff --git a/modules/base/tool.go b/modules/base/tool.go
index 3946c4b5..9b165b97 100644
--- a/modules/base/tool.go
+++ b/modules/base/tool.go
@@ -5,13 +5,13 @@
package base
import (
- "bytes"
+ "crypto/hmac"
"crypto/md5"
"crypto/rand"
"crypto/sha1"
"encoding/hex"
- "encoding/json"
"fmt"
+ "hash"
"math"
"strconv"
"strings"
@@ -40,6 +40,44 @@ func GetRandomString(n int, alphabets ...byte) string {
return string(bytes)
}
+// http://code.google.com/p/go/source/browse/pbkdf2/pbkdf2.go?repo=crypto
+func PBKDF2(password, salt []byte, iter, keyLen int, h func() hash.Hash) []byte {
+ prf := hmac.New(h, password)
+ hashLen := prf.Size()
+ numBlocks := (keyLen + hashLen - 1) / hashLen
+
+ var buf [4]byte
+ dk := make([]byte, 0, numBlocks*hashLen)
+ U := make([]byte, hashLen)
+ for block := 1; block <= numBlocks; block++ {
+ // N.B.: || means concatenation, ^ means XOR
+ // for each block T_i = U_1 ^ U_2 ^ ... ^ U_iter
+ // U_1 = PRF(password, salt || uint(i))
+ prf.Reset()
+ prf.Write(salt)
+ buf[0] = byte(block >> 24)
+ buf[1] = byte(block >> 16)
+ buf[2] = byte(block >> 8)
+ buf[3] = byte(block)
+ prf.Write(buf[:4])
+ dk = prf.Sum(dk)
+ T := dk[len(dk)-hashLen:]
+ copy(U, T)
+
+ // U_n = PRF(password, U_(n-1))
+ for n := 2; n <= iter; n++ {
+ prf.Reset()
+ prf.Write(U)
+ U = U[:0]
+ U = prf.Sum(U)
+ for x := range U {
+ T[x] ^= U[x]
+ }
+ }
+ }
+ return dk[:keyLen]
+}
+
// verify time limit code
func VerifyTimeLimitCode(data string, minutes int, code string) bool {
if len(code) <= 18 {
@@ -105,7 +143,7 @@ func AvatarLink(email string) string {
if Service.EnableCacheAvatar {
return "/avatar/" + EncodeMd5(email)
}
- return "http://1.gravatar.com/avatar/" + EncodeMd5(email)
+ return "//1.gravatar.com/avatar/" + EncodeMd5(email)
}
// Seconds-based time units
@@ -246,7 +284,6 @@ func TimeSince(then time.Time) string {
default:
return fmt.Sprintf("%d years %s", diff/Year, lbl)
}
- return then.String()
}
const (
@@ -474,107 +511,3 @@ func (a argInt) Get(i int, args ...int) (r int) {
}
return
}
-
-type Actioner interface {
- GetOpType() int
- GetActUserName() string
- GetActEmail() string
- GetRepoName() string
- GetBranch() string
- GetContent() string
-}
-
-// ActionIcon accepts a int that represents action operation type
-// and returns a icon class name.
-func ActionIcon(opType int) string {
- switch opType {
- case 1: // Create repository.
- return "plus-circle"
- case 5: // Commit repository.
- return "arrow-circle-o-right"
- case 6: // Create issue.
- return "exclamation-circle"
- case 8: // Transfer repository.
- return "share"
- default:
- return "invalid type"
- }
-}
-
-const (
- TPL_CREATE_REPO = `<a href="/user/%s">%s</a> created repository <a href="/%s">%s</a>`
- TPL_COMMIT_REPO = `<a href="/user/%s">%s</a> pushed to <a href="/%s/src/%s">%s</a> at <a href="/%s">%s</a>%s`
- TPL_COMMIT_REPO_LI = `<div><img src="%s?s=16" alt="user-avatar"/> <a href="/%s/commit/%s">%s</a> %s</div>`
- TPL_CREATE_ISSUE = `<a href="/user/%s">%s</a> opened issue <a href="/%s/issues/%s">%s#%s</a>
-<div><img src="%s?s=16" alt="user-avatar"/> %s</div>`
- TPL_TRANSFER_REPO = `<a href="/user/%s">%s</a> transfered repository <code>%s</code> to <a href="/%s">%s</a>`
-)
-
-type PushCommit struct {
- Sha1 string
- Message string
- AuthorEmail string
- AuthorName string
-}
-
-type PushCommits struct {
- Len int
- Commits []*PushCommit
-}
-
-// ActionDesc accepts int that represents action operation type
-// and returns the description.
-func ActionDesc(act Actioner) string {
- actUserName := act.GetActUserName()
- email := act.GetActEmail()
- repoName := act.GetRepoName()
- repoLink := actUserName + "/" + repoName
- branch := act.GetBranch()
- content := act.GetContent()
- switch act.GetOpType() {
- case 1: // Create repository.
- return fmt.Sprintf(TPL_CREATE_REPO, actUserName, actUserName, repoLink, repoName)
- case 5: // Commit repository.
- var push *PushCommits
- if err := json.Unmarshal([]byte(content), &push); err != nil {
- return err.Error()
- }
- buf := bytes.NewBuffer([]byte("\n"))
- for _, commit := range push.Commits {
- buf.WriteString(fmt.Sprintf(TPL_COMMIT_REPO_LI, AvatarLink(commit.AuthorEmail), repoLink, commit.Sha1, commit.Sha1[:7], commit.Message) + "\n")
- }
- if push.Len > 3 {
- buf.WriteString(fmt.Sprintf(`<div><a href="/%s/%s/commits/%s">%d other commits >></a></div>`, actUserName, repoName, branch, push.Len))
- }
- return fmt.Sprintf(TPL_COMMIT_REPO, actUserName, actUserName, repoLink, branch, branch, repoLink, repoLink,
- buf.String())
- case 6: // Create issue.
- infos := strings.SplitN(content, "|", 2)
- return fmt.Sprintf(TPL_CREATE_ISSUE, actUserName, actUserName, repoLink, infos[0], repoLink, infos[0],
- AvatarLink(email), infos[1])
- case 8: // Transfer repository.
- newRepoLink := content + "/" + repoName
- return fmt.Sprintf(TPL_TRANSFER_REPO, actUserName, actUserName, repoLink, newRepoLink, newRepoLink)
- default:
- return "invalid type"
- }
-}
-
-func DiffTypeToStr(diffType int) string {
- diffTypes := map[int]string{
- 1: "add", 2: "modify", 3: "del",
- }
- return diffTypes[diffType]
-}
-
-func DiffLineTypeToStr(diffType int) string {
- switch diffType {
- case 2:
- return "add"
- case 3:
- return "del"
- case 4:
- return "tag"
- }
- return "same"
-}
diff --git a/modules/cron/cron.go b/modules/cron/cron.go
new file mode 100644
index 00000000..27b1fc41
--- /dev/null
+++ b/modules/cron/cron.go
@@ -0,0 +1,17 @@
+// Copyright 2014 The Gogs Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package cron
+
+import (
+ "github.com/robfig/cron"
+
+ "github.com/gogits/gogs/models"
+)
+
+func NewCronContext() {
+ c := cron.New()
+ c.AddFunc("@every 1h", models.MirrorUpdate)
+ c.Start()
+}
diff --git a/modules/log/log.go b/modules/log/log.go
index 65150237..636ea787 100644
--- a/modules/log/log.go
+++ b/modules/log/log.go
@@ -21,8 +21,7 @@ func init() {
func NewLogger(bufLen int64, mode, config string) {
Mode, Config = mode, config
logger = logs.NewLogger(bufLen)
- logger.EnableFuncCallDepth(true)
- logger.SetLogFuncCallDepth(4)
+ logger.SetLogFuncCallDepth(3)
logger.SetLogger(mode, config)
}
diff --git a/modules/mailer/mail.go b/modules/mailer/mail.go
index b99fc8fd..834f4a89 100644
--- a/modules/mailer/mail.go
+++ b/modules/mailer/mail.go
@@ -86,16 +86,36 @@ func SendActiveMail(r *middleware.Render, user *models.User) {
}
msg := NewMailMessage([]string{user.Email}, subject, body)
- msg.Info = fmt.Sprintf("UID: %d, send email verify mail", user.Id)
+ msg.Info = fmt.Sprintf("UID: %d, send active mail", user.Id)
SendAsync(&msg)
}
-// SendNotifyMail sends mail notification of all watchers.
-func SendNotifyMail(user, owner *models.User, repo *models.Repository, issue *models.Issue) error {
+// Send reset password email.
+func SendResetPasswdMail(r *middleware.Render, user *models.User) {
+ code := CreateUserActiveCode(user, nil)
+
+ subject := "Reset your password"
+
+ data := GetMailTmplData(user)
+ data["Code"] = code
+ body, err := r.HTMLString("mail/auth/reset_passwd", data)
+ if err != nil {
+ log.Error("mail.SendResetPasswdMail(fail to render): %v", err)
+ return
+ }
+
+ msg := NewMailMessage([]string{user.Email}, subject, body)
+ msg.Info = fmt.Sprintf("UID: %d, send reset password email", user.Id)
+
+ SendAsync(&msg)
+}
+
+// SendIssueNotifyMail sends mail notification of all watchers of repository.
+func SendIssueNotifyMail(user, owner *models.User, repo *models.Repository, issue *models.Issue) ([]string, error) {
watches, err := models.GetWatches(repo.Id)
if err != nil {
- return errors.New("mail.NotifyWatchers(get watches): " + err.Error())
+ return nil, errors.New("mail.NotifyWatchers(get watches): " + err.Error())
}
tos := make([]string, 0, len(watches))
@@ -106,20 +126,37 @@ func SendNotifyMail(user, owner *models.User, repo *models.Repository, issue *mo
}
u, err := models.GetUserById(uid)
if err != nil {
- return errors.New("mail.NotifyWatchers(get user): " + err.Error())
+ return nil, errors.New("mail.NotifyWatchers(get user): " + err.Error())
}
tos = append(tos, u.Email)
}
if len(tos) == 0 {
- return nil
+ return tos, nil
}
subject := fmt.Sprintf("[%s] %s", repo.Name, issue.Name)
content := fmt.Sprintf("%s<br>-<br> <a href=\"%s%s/%s/issues/%d\">View it on Gogs</a>.",
- issue.Content, base.AppUrl, owner.Name, repo.Name, issue.Index)
+ base.RenderSpecialLink([]byte(issue.Content), owner.Name+"/"+repo.Name),
+ base.AppUrl, owner.Name, repo.Name, issue.Index)
+ msg := NewMailMessageFrom(tos, user.Name, subject, content)
+ msg.Info = fmt.Sprintf("Subject: %s, send issue notify emails", subject)
+ SendAsync(&msg)
+ return tos, nil
+}
+
+// SendIssueMentionMail sends mail notification for who are mentioned in issue.
+func SendIssueMentionMail(user, owner *models.User, repo *models.Repository, issue *models.Issue, tos []string) error {
+ if len(tos) == 0 {
+ return nil
+ }
+
+ issueLink := fmt.Sprintf("%s%s/%s/issues/%d", base.AppUrl, owner.Name, repo.Name, issue.Index)
+ body := fmt.Sprintf(`%s mentioned you.`, user.Name)
+ subject := fmt.Sprintf("[%s] %s", repo.Name, issue.Name)
+ content := fmt.Sprintf("%s<br>-<br> <a href=\"%s\">View it on Gogs</a>.", body, issueLink)
msg := NewMailMessageFrom(tos, user.Name, subject, content)
- msg.Info = fmt.Sprintf("Subject: %s, send notify emails", subject)
+ msg.Info = fmt.Sprintf("Subject: %s, send issue mention emails", subject)
SendAsync(&msg)
return nil
}
diff --git a/modules/middleware/auth.go b/modules/middleware/auth.go
index bde3be72..39b7796d 100644
--- a/modules/middleware/auth.go
+++ b/modules/middleware/auth.go
@@ -47,7 +47,7 @@ func Toggle(options *ToggleOptions) martini.Handler {
return
} else if !ctx.User.IsActive && base.Service.RegisterEmailConfirm {
ctx.Data["Title"] = "Activate Your Account"
- ctx.HTML(200, "user/active")
+ ctx.HTML(200, "user/activate")
return
}
}
diff --git a/modules/middleware/binding.go b/modules/middleware/binding.go
new file mode 100644
index 00000000..cde9ae9c
--- /dev/null
+++ b/modules/middleware/binding.go
@@ -0,0 +1,426 @@
+// Copyright 2013 The Martini Contrib Authors. All rights reserved.
+// Copyright 2014 The Gogs Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package middleware
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "reflect"
+ "regexp"
+ "strconv"
+ "strings"
+ "unicode/utf8"
+
+ "github.com/go-martini/martini"
+
+ "github.com/gogits/gogs/modules/base"
+)
+
+/*
+ To the land of Middle-ware Earth:
+
+ One func to rule them all,
+ One func to find them,
+ One func to bring them all,
+ And in this package BIND them.
+*/
+
+// Bind accepts a copy of an empty struct and populates it with
+// values from the request (if deserialization is successful). It
+// wraps up the functionality of the Form and Json middleware
+// according to the Content-Type of the request, and it guesses
+// if no Content-Type is specified. Bind invokes the ErrorHandler
+// middleware to bail out if errors occurred. If you want to perform
+// your own error handling, use Form or Json middleware directly.
+// An interface pointer can be added as a second argument in order
+// to map the struct to a specific interface.
+func Bind(obj interface{}, ifacePtr ...interface{}) martini.Handler {
+ return func(context martini.Context, req *http.Request) {
+ contentType := req.Header.Get("Content-Type")
+
+ if strings.Contains(contentType, "form-urlencoded") {
+ context.Invoke(Form(obj, ifacePtr...))
+ } else if strings.Contains(contentType, "multipart/form-data") {
+ context.Invoke(MultipartForm(obj, ifacePtr...))
+ } else if strings.Contains(contentType, "json") {
+ context.Invoke(Json(obj, ifacePtr...))
+ } else {
+ context.Invoke(Json(obj, ifacePtr...))
+ if getErrors(context).Count() > 0 {
+ context.Invoke(Form(obj, ifacePtr...))
+ }
+ }
+
+ context.Invoke(ErrorHandler)
+ }
+}
+
+// BindIgnErr will do the exactly same thing as Bind but without any
+// error handling, which user has freedom to deal with them.
+// This allows user take advantages of validation.
+func BindIgnErr(obj interface{}, ifacePtr ...interface{}) martini.Handler {
+ return func(context martini.Context, req *http.Request) {
+ contentType := req.Header.Get("Content-Type")
+
+ if strings.Contains(contentType, "form-urlencoded") {
+ context.Invoke(Form(obj, ifacePtr...))
+ } else if strings.Contains(contentType, "multipart/form-data") {
+ context.Invoke(MultipartForm(obj, ifacePtr...))
+ } else if strings.Contains(contentType, "json") {
+ context.Invoke(Json(obj, ifacePtr...))
+ } else {
+ context.Invoke(Json(obj, ifacePtr...))
+ if getErrors(context).Count() > 0 {
+ context.Invoke(Form(obj, ifacePtr...))
+ }
+ }
+ }
+}
+
+// Form is middleware to deserialize form-urlencoded data from the request.
+// It gets data from the form-urlencoded body, if present, or from the
+// query string. It uses the http.Request.ParseForm() method
+// to perform deserialization, then reflection is used to map each field
+// into the struct with the proper type. Structs with primitive slice types
+// (bool, float, int, string) can support deserialization of repeated form
+// keys, for example: key=val1&key=val2&key=val3
+// An interface pointer can be added as a second argument in order
+// to map the struct to a specific interface.
+func Form(formStruct interface{}, ifacePtr ...interface{}) martini.Handler {
+ return func(context martini.Context, req *http.Request) {
+ ensureNotPointer(formStruct)
+ formStruct := reflect.New(reflect.TypeOf(formStruct))
+ errors := newErrors()
+ parseErr := req.ParseForm()
+
+ // Format validation of the request body or the URL would add considerable overhead,
+ // and ParseForm does not complain when URL encoding is off.
+ // Because an empty request body or url can also mean absence of all needed values,
+ // it is not in all cases a bad request, so let's return 422.
+ if parseErr != nil {
+ errors.Overall[base.BindingDeserializationError] = parseErr.Error()
+ }
+
+ mapForm(formStruct, req.Form, errors)
+
+ validateAndMap(formStruct, context, errors, ifacePtr...)
+ }
+}
+
+func MultipartForm(formStruct interface{}, ifacePtr ...interface{}) martini.Handler {
+ return func(context martini.Context, req *http.Request) {
+ ensureNotPointer(formStruct)
+ formStruct := reflect.New(reflect.TypeOf(formStruct))
+ errors := newErrors()
+
+ // Workaround for multipart forms returning nil instead of an error
+ // when content is not multipart
+ // https://code.google.com/p/go/issues/detail?id=6334
+ multipartReader, err := req.MultipartReader()
+ if err != nil {
+ errors.Overall[base.BindingDeserializationError] = err.Error()
+ } else {
+ form, parseErr := multipartReader.ReadForm(MaxMemory)
+
+ if parseErr != nil {
+ errors.Overall[base.BindingDeserializationError] = parseErr.Error()
+ }
+
+ req.MultipartForm = form
+ }
+
+ mapForm(formStruct, req.MultipartForm.Value, errors)
+
+ validateAndMap(formStruct, context, errors, ifacePtr...)
+ }
+}
+
+// Json is middleware to deserialize a JSON payload from the request
+// into the struct that is passed in. The resulting struct is then
+// validated, but no error handling is actually performed here.
+// An interface pointer can be added as a second argument in order
+// to map the struct to a specific interface.
+func Json(jsonStruct interface{}, ifacePtr ...interface{}) martini.Handler {
+ return func(context martini.Context, req *http.Request) {
+ ensureNotPointer(jsonStruct)
+ jsonStruct := reflect.New(reflect.TypeOf(jsonStruct))
+ errors := newErrors()
+
+ if req.Body != nil {
+ defer req.Body.Close()
+ }
+
+ if err := json.NewDecoder(req.Body).Decode(jsonStruct.Interface()); err != nil && err != io.EOF {
+ errors.Overall[base.BindingDeserializationError] = err.Error()
+ }
+
+ validateAndMap(jsonStruct, context, errors, ifacePtr...)
+ }
+}
+
+// Validate is middleware to enforce required fields. If the struct
+// passed in is a Validator, then the user-defined Validate method
+// is executed, and its errors are mapped to the context. This middleware
+// performs no error handling: it merely detects them and maps them.
+func Validate(obj interface{}) martini.Handler {
+ return func(context martini.Context, req *http.Request) {
+ errors := newErrors()
+ validateStruct(errors, obj)
+
+ if validator, ok := obj.(Validator); ok {
+ validator.Validate(errors, req, context)
+ }
+ context.Map(*errors)
+ }
+}
+
+var (
+ alphaDashPattern = regexp.MustCompile("[^\\d\\w-_]")
+ emailPattern = regexp.MustCompile("[\\w!#$%&'*+/=?^_`{|}~-]+(?:\\.[\\w!#$%&'*+/=?^_`{|}~-]+)*@(?:[\\w](?:[\\w-]*[\\w])?\\.)+[a-zA-Z0-9](?:[\\w-]*[\\w])?")
+ urlPattern = regexp.MustCompile(`(http|https):\/\/[\w\-_]+(\.[\w\-_]+)+([\w\-\.,@?^=%&amp;:/~\+#]*[\w\-\@?^=%&amp;/~\+#])?`)
+)
+
+func validateStruct(errors *base.BindingErrors, obj interface{}) {
+ typ := reflect.TypeOf(obj)
+ val := reflect.ValueOf(obj)
+
+ if typ.Kind() == reflect.Ptr {
+ typ = typ.Elem()
+ val = val.Elem()
+ }
+
+ for i := 0; i < typ.NumField(); i++ {
+ field := typ.Field(i)
+
+ // Allow ignored fields in the struct
+ if field.Tag.Get("form") == "-" {
+ continue
+ }
+
+ fieldValue := val.Field(i).Interface()
+ if field.Type.Kind() == reflect.Struct {
+ validateStruct(errors, fieldValue)
+ continue
+ }
+
+ zero := reflect.Zero(field.Type).Interface()
+
+ // Match rules.
+ for _, rule := range strings.Split(field.Tag.Get("binding"), ";") {
+ if len(rule) == 0 {
+ continue
+ }
+
+ switch {
+ case rule == "Required":
+ if reflect.DeepEqual(zero, fieldValue) {
+ errors.Fields[field.Name] = base.BindingRequireError
+ break
+ }
+ case rule == "AlphaDash":
+ if alphaDashPattern.MatchString(fmt.Sprintf("%v", fieldValue)) {
+ errors.Fields[field.Name] = base.BindingAlphaDashError
+ break
+ }
+ case strings.HasPrefix(rule, "MinSize("):
+ min, err := strconv.Atoi(rule[8 : len(rule)-1])
+ if err != nil {
+ errors.Overall["MinSize"] = err.Error()
+ break
+ }
+ if str, ok := fieldValue.(string); ok && utf8.RuneCountInString(str) < min {
+ errors.Fields[field.Name] = base.BindingMinSizeError
+ break
+ }
+ v := reflect.ValueOf(fieldValue)
+ if v.Kind() == reflect.Slice && v.Len() < min {
+ errors.Fields[field.Name] = base.BindingMinSizeError
+ break
+ }
+ case strings.HasPrefix(rule, "MaxSize("):
+ max, err := strconv.Atoi(rule[8 : len(rule)-1])
+ if err != nil {
+ errors.Overall["MaxSize"] = err.Error()
+ break
+ }
+ if str, ok := fieldValue.(string); ok && utf8.RuneCountInString(str) > max {
+ errors.Fields[field.Name] = base.BindingMaxSizeError
+ break
+ }
+ v := reflect.ValueOf(fieldValue)
+ if v.Kind() == reflect.Slice && v.Len() > max {
+ errors.Fields[field.Name] = base.BindingMinSizeError
+ break
+ }
+ case rule == "Email":
+ if !emailPattern.MatchString(fmt.Sprintf("%v", fieldValue)) {
+ errors.Fields[field.Name] = base.BindingEmailError
+ break
+ }
+ case rule == "Url":
+ if !urlPattern.MatchString(fmt.Sprintf("%v", fieldValue)) {
+ errors.Fields[field.Name] = base.BindingUrlError
+ break
+ }
+ }
+ }
+ }
+}
+
+func mapForm(formStruct reflect.Value, form map[string][]string, errors *base.BindingErrors) {
+ typ := formStruct.Elem().Type()
+
+ for i := 0; i < typ.NumField(); i++ {
+ typeField := typ.Field(i)
+ if inputFieldName := typeField.Tag.Get("form"); inputFieldName != "" {
+ structField := formStruct.Elem().Field(i)
+ if !structField.CanSet() {
+ continue
+ }
+
+ inputValue, exists := form[inputFieldName]
+
+ if !exists {
+ continue
+ }
+
+ numElems := len(inputValue)
+ if structField.Kind() == reflect.Slice && numElems > 0 {
+ sliceOf := structField.Type().Elem().Kind()
+ slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
+ for i := 0; i < numElems; i++ {
+ setWithProperType(sliceOf, inputValue[i], slice.Index(i), inputFieldName, errors)
+ }
+ formStruct.Elem().Field(i).Set(slice)
+ } else {
+ setWithProperType(typeField.Type.Kind(), inputValue[0], structField, inputFieldName, errors)
+ }
+ }
+ }
+}
+
+// ErrorHandler simply counts the number of errors in the
+// context and, if more than 0, writes a 400 Bad Request
+// response and a JSON payload describing the errors with
+// the "Content-Type" set to "application/json".
+// Middleware remaining on the stack will not even see the request
+// if, by this point, there are any errors.
+// This is a "default" handler, of sorts, and you are
+// welcome to use your own instead. The Bind middleware
+// invokes this automatically for convenience.
+func ErrorHandler(errs base.BindingErrors, resp http.ResponseWriter) {
+ if errs.Count() > 0 {
+ resp.Header().Set("Content-Type", "application/json; charset=utf-8")
+ if _, ok := errs.Overall[base.BindingDeserializationError]; ok {
+ resp.WriteHeader(http.StatusBadRequest)
+ } else {
+ resp.WriteHeader(422)
+ }
+ errOutput, _ := json.Marshal(errs)
+ resp.Write(errOutput)
+ return
+ }
+}
+
+// This sets the value in a struct of an indeterminate type to the
+// matching value from the request (via Form middleware) in the
+// same type, so that not all deserialized values have to be strings.
+// Supported types are string, int, float, and bool.
+func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value, nameInTag string, errors *base.BindingErrors) {
+ switch valueKind {
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ if val == "" {
+ val = "0"
+ }
+ intVal, err := strconv.ParseInt(val, 10, 64)
+ if err != nil {
+ errors.Fields[nameInTag] = base.BindingIntegerTypeError
+ } else {
+ structField.SetInt(intVal)
+ }
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ if val == "" {
+ val = "0"
+ }
+ uintVal, err := strconv.ParseUint(val, 10, 64)
+ if err != nil {
+ errors.Fields[nameInTag] = base.BindingIntegerTypeError
+ } else {
+ structField.SetUint(uintVal)
+ }
+ case reflect.Bool:
+ structField.SetBool(val == "on")
+ case reflect.Float32:
+ if val == "" {
+ val = "0.0"
+ }
+ floatVal, err := strconv.ParseFloat(val, 32)
+ if err != nil {
+ errors.Fields[nameInTag] = base.BindingFloatTypeError
+ } else {
+ structField.SetFloat(floatVal)
+ }
+ case reflect.Float64:
+ if val == "" {
+ val = "0.0"
+ }
+ floatVal, err := strconv.ParseFloat(val, 64)
+ if err != nil {
+ errors.Fields[nameInTag] = base.BindingFloatTypeError
+ } else {
+ structField.SetFloat(floatVal)
+ }
+ case reflect.String:
+ structField.SetString(val)
+ }
+}
+
+// Don't pass in pointers to bind to. Can lead to bugs. See:
+// https://github.com/codegangsta/martini-contrib/issues/40
+// https://github.com/codegangsta/martini-contrib/pull/34#issuecomment-29683659
+func ensureNotPointer(obj interface{}) {
+ if reflect.TypeOf(obj).Kind() == reflect.Ptr {
+ panic("Pointers are not accepted as binding models")
+ }
+}
+
+// Performs validation and combines errors from validation
+// with errors from deserialization, then maps both the
+// resulting struct and the errors to the context.
+func validateAndMap(obj reflect.Value, context martini.Context, errors *base.BindingErrors, ifacePtr ...interface{}) {
+ context.Invoke(Validate(obj.Interface()))
+ errors.Combine(getErrors(context))
+ context.Map(*errors)
+ context.Map(obj.Elem().Interface())
+ if len(ifacePtr) > 0 {
+ context.MapTo(obj.Elem().Interface(), ifacePtr[0])
+ }
+}
+
+func newErrors() *base.BindingErrors {
+ return &base.BindingErrors{make(map[string]string), make(map[string]string)}
+}
+
+func getErrors(context martini.Context) base.BindingErrors {
+ return context.Get(reflect.TypeOf(base.BindingErrors{})).Interface().(base.BindingErrors)
+}
+
+type (
+ // Implement the Validator interface to define your own input
+ // validation before the request even gets to your application.
+ // The Validate method will be executed during the validation phase.
+ Validator interface {
+ Validate(*base.BindingErrors, *http.Request, martini.Context)
+ }
+)
+
+var (
+ // Maximum amount of memory to use when parsing a multipart form.
+ // Set this to whatever value you prefer; default is 10 MB.
+ MaxMemory = int64(1024 * 1024 * 10)
+)
diff --git a/modules/middleware/binding_test.go b/modules/middleware/binding_test.go
new file mode 100644
index 00000000..2a74e1a6
--- /dev/null
+++ b/modules/middleware/binding_test.go
@@ -0,0 +1,701 @@
+// Copyright 2013 The Martini Contrib Authors. All rights reserved.
+// Copyright 2014 The Gogs Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package middleware
+
+import (
+ "bytes"
+ "mime/multipart"
+ "net/http"
+ "net/http/httptest"
+ "strconv"
+ "strings"
+ "testing"
+
+ "github.com/codegangsta/martini"
+)
+
+func TestBind(t *testing.T) {
+ testBind(t, false)
+}
+
+func TestBindWithInterface(t *testing.T) {
+ testBind(t, true)
+}
+
+func TestMultipartBind(t *testing.T) {
+ index := 0
+ for test, expectStatus := range bindMultipartTests {
+ handler := func(post BlogPost, errors Errors) {
+ handle(test, t, index, post, errors)
+ }
+ recorder := testMultipart(t, test, Bind(BlogPost{}), handler, index)
+
+ if recorder.Code != expectStatus {
+ t.Errorf("On test case %v, got status code %d but expected %d", test, recorder.Code, expectStatus)
+ }
+
+ index++
+ }
+}
+
+func TestForm(t *testing.T) {
+ testForm(t, false)
+}
+
+func TestFormWithInterface(t *testing.T) {
+ testForm(t, true)
+}
+
+func TestEmptyForm(t *testing.T) {
+ testEmptyForm(t)
+}
+
+func TestMultipartForm(t *testing.T) {
+ for index, test := range multipartformTests {
+ handler := func(post BlogPost, errors Errors) {
+ handle(test, t, index, post, errors)
+ }
+ testMultipart(t, test, MultipartForm(BlogPost{}), handler, index)
+ }
+}
+
+func TestMultipartFormWithInterface(t *testing.T) {
+ for index, test := range multipartformTests {
+ handler := func(post Modeler, errors Errors) {
+ post.Create(test, t, index)
+ }
+ testMultipart(t, test, MultipartForm(BlogPost{}, (*Modeler)(nil)), handler, index)
+ }
+}
+
+func TestJson(t *testing.T) {
+ testJson(t, false)
+}
+
+func TestJsonWithInterface(t *testing.T) {
+ testJson(t, true)
+}
+
+func TestEmptyJson(t *testing.T) {
+ testEmptyJson(t)
+}
+
+func TestValidate(t *testing.T) {
+ handlerMustErr := func(errors Errors) {
+ if errors.Count() == 0 {
+ t.Error("Expected at least one error, got 0")
+ }
+ }
+ handlerNoErr := func(errors Errors) {
+ if errors.Count() > 0 {
+ t.Error("Expected no errors, got", errors.Count())
+ }
+ }
+
+ performValidationTest(&BlogPost{"", "...", 0, 0, []int{}}, handlerMustErr, t)
+ performValidationTest(&BlogPost{"Good Title", "Good content", 0, 0, []int{}}, handlerNoErr, t)
+
+ performValidationTest(&User{Name: "Jim", Home: Address{"", ""}}, handlerMustErr, t)
+ performValidationTest(&User{Name: "Jim", Home: Address{"required", ""}}, handlerNoErr, t)
+}
+
+func handle(test testCase, t *testing.T, index int, post BlogPost, errors Errors) {
+ assertEqualField(t, "Title", index, test.ref.Title, post.Title)
+ assertEqualField(t, "Content", index, test.ref.Content, post.Content)
+ assertEqualField(t, "Views", index, test.ref.Views, post.Views)
+
+ for i := range test.ref.Multiple {
+ if i >= len(post.Multiple) {
+ t.Errorf("Expected: %v (size %d) to have same size as: %v (size %d)", post.Multiple, len(post.Multiple), test.ref.Multiple, len(test.ref.Multiple))
+ break
+ }
+ if test.ref.Multiple[i] != post.Multiple[i] {
+ t.Errorf("Expected: %v to deep equal: %v", post.Multiple, test.ref.Multiple)
+ break
+ }
+ }
+
+ if test.ok && errors.Count() > 0 {
+ t.Errorf("%+v should be OK (0 errors), but had errors: %+v", test, errors)
+ } else if !test.ok && errors.Count() == 0 {
+ t.Errorf("%+v should have errors, but was OK (0 errors)", test)
+ }
+}
+
+func handleEmpty(test emptyPayloadTestCase, t *testing.T, index int, section BlogSection, errors Errors) {
+ assertEqualField(t, "Title", index, test.ref.Title, section.Title)
+ assertEqualField(t, "Content", index, test.ref.Content, section.Content)
+
+ if test.ok && errors.Count() > 0 {
+ t.Errorf("%+v should be OK (0 errors), but had errors: %+v", test, errors)
+ } else if !test.ok && errors.Count() == 0 {
+ t.Errorf("%+v should have errors, but was OK (0 errors)", test)
+ }
+}
+
+func testBind(t *testing.T, withInterface bool) {
+ index := 0
+ for test, expectStatus := range bindTests {
+ m := martini.Classic()
+ recorder := httptest.NewRecorder()
+ handler := func(post BlogPost, errors Errors) { handle(test, t, index, post, errors) }
+ binding := Bind(BlogPost{})
+
+ if withInterface {
+ handler = func(post BlogPost, errors Errors) {
+ post.Create(test, t, index)
+ }
+ binding = Bind(BlogPost{}, (*Modeler)(nil))
+ }
+
+ switch test.method {
+ case "GET":
+ m.Get(route, binding, handler)
+ case "POST":
+ m.Post(route, binding, handler)
+ }
+
+ req, err := http.NewRequest(test.method, test.path, strings.NewReader(test.payload))
+ req.Header.Add("Content-Type", test.contentType)
+
+ if err != nil {
+ t.Error(err)
+ }
+ m.ServeHTTP(recorder, req)
+
+ if recorder.Code != expectStatus {
+ t.Errorf("On test case %v, got status code %d but expected %d", test, recorder.Code, expectStatus)
+ }
+
+ index++
+ }
+}
+
+func testJson(t *testing.T, withInterface bool) {
+ for index, test := range jsonTests {
+ recorder := httptest.NewRecorder()
+ handler := func(post BlogPost, errors Errors) { handle(test, t, index, post, errors) }
+ binding := Json(BlogPost{})
+
+ if withInterface {
+ handler = func(post BlogPost, errors Errors) {
+ post.Create(test, t, index)
+ }
+ binding = Bind(BlogPost{}, (*Modeler)(nil))
+ }
+
+ m := martini.Classic()
+ switch test.method {
+ case "GET":
+ m.Get(route, binding, handler)
+ case "POST":
+ m.Post(route, binding, handler)
+ case "PUT":
+ m.Put(route, binding, handler)
+ case "DELETE":
+ m.Delete(route, binding, handler)
+ }
+
+ req, err := http.NewRequest(test.method, route, strings.NewReader(test.payload))
+ if err != nil {
+ t.Error(err)
+ }
+ m.ServeHTTP(recorder, req)
+ }
+}
+
+func testEmptyJson(t *testing.T) {
+ for index, test := range emptyPayloadTests {
+ recorder := httptest.NewRecorder()
+ handler := func(section BlogSection, errors Errors) { handleEmpty(test, t, index, section, errors) }
+ binding := Json(BlogSection{})
+
+ m := martini.Classic()
+ switch test.method {
+ case "GET":
+ m.Get(route, binding, handler)
+ case "POST":
+ m.Post(route, binding, handler)
+ case "PUT":
+ m.Put(route, binding, handler)
+ case "DELETE":
+ m.Delete(route, binding, handler)
+ }
+
+ req, err := http.NewRequest(test.method, route, strings.NewReader(test.payload))
+ if err != nil {
+ t.Error(err)
+ }
+ m.ServeHTTP(recorder, req)
+ }
+}
+
+func testForm(t *testing.T, withInterface bool) {
+ for index, test := range formTests {
+ recorder := httptest.NewRecorder()
+ handler := func(post BlogPost, errors Errors) { handle(test, t, index, post, errors) }
+ binding := Form(BlogPost{})
+
+ if withInterface {
+ handler = func(post BlogPost, errors Errors) {
+ post.Create(test, t, index)
+ }
+ binding = Form(BlogPost{}, (*Modeler)(nil))
+ }
+
+ m := martini.Classic()
+ switch test.method {
+ case "GET":
+ m.Get(route, binding, handler)
+ case "POST":
+ m.Post(route, binding, handler)
+ }
+
+ req, err := http.NewRequest(test.method, test.path, nil)
+ if err != nil {
+ t.Error(err)
+ }
+ m.ServeHTTP(recorder, req)
+ }
+}
+
+func testEmptyForm(t *testing.T) {
+ for index, test := range emptyPayloadTests {
+ recorder := httptest.NewRecorder()
+ handler := func(section BlogSection, errors Errors) { handleEmpty(test, t, index, section, errors) }
+ binding := Form(BlogSection{})
+
+ m := martini.Classic()
+ switch test.method {
+ case "GET":
+ m.Get(route, binding, handler)
+ case "POST":
+ m.Post(route, binding, handler)
+ }
+
+ req, err := http.NewRequest(test.method, test.path, nil)
+ if err != nil {
+ t.Error(err)
+ }
+ m.ServeHTTP(recorder, req)
+ }
+}
+
+func testMultipart(t *testing.T, test testCase, middleware martini.Handler, handler martini.Handler, index int) *httptest.ResponseRecorder {
+ recorder := httptest.NewRecorder()
+
+ m := martini.Classic()
+ m.Post(route, middleware, handler)
+
+ body := &bytes.Buffer{}
+ writer := multipart.NewWriter(body)
+ writer.WriteField("title", test.ref.Title)
+ writer.WriteField("content", test.ref.Content)
+ writer.WriteField("views", strconv.Itoa(test.ref.Views))
+ if len(test.ref.Multiple) != 0 {
+ for _, value := range test.ref.Multiple {
+ writer.WriteField("multiple", strconv.Itoa(value))
+ }
+ }
+
+ req, err := http.NewRequest(test.method, test.path, body)
+ req.Header.Add("Content-Type", writer.FormDataContentType())
+
+ if err != nil {
+ t.Error(err)
+ }
+
+ err = writer.Close()
+ if err != nil {
+ t.Error(err)
+ }
+
+ m.ServeHTTP(recorder, req)
+
+ return recorder
+}
+
+func assertEqualField(t *testing.T, fieldname string, testcasenumber int, expected interface{}, got interface{}) {
+ if expected != got {
+ t.Errorf("%s: expected=%s, got=%s in test case %d\n", fieldname, expected, got, testcasenumber)
+ }
+}
+
+func performValidationTest(data interface{}, handler func(Errors), t *testing.T) {
+ recorder := httptest.NewRecorder()
+ m := martini.Classic()
+ m.Get(route, Validate(data), handler)
+
+ req, err := http.NewRequest("GET", route, nil)
+ if err != nil {
+ t.Error("HTTP error:", err)
+ }
+
+ m.ServeHTTP(recorder, req)
+}
+
+func (self BlogPost) Validate(errors *Errors, req *http.Request) {
+ if len(self.Title) < 4 {
+ errors.Fields["Title"] = "Too short; minimum 4 characters"
+ }
+ if len(self.Content) > 1024 {
+ errors.Fields["Content"] = "Too long; maximum 1024 characters"
+ }
+ if len(self.Content) < 5 {
+ errors.Fields["Content"] = "Too short; minimum 5 characters"
+ }
+}
+
+func (self BlogPost) Create(test testCase, t *testing.T, index int) {
+ assertEqualField(t, "Title", index, test.ref.Title, self.Title)
+ assertEqualField(t, "Content", index, test.ref.Content, self.Content)
+ assertEqualField(t, "Views", index, test.ref.Views, self.Views)
+
+ for i := range test.ref.Multiple {
+ if i >= len(self.Multiple) {
+ t.Errorf("Expected: %v (size %d) to have same size as: %v (size %d)", self.Multiple, len(self.Multiple), test.ref.Multiple, len(test.ref.Multiple))
+ break
+ }
+ if test.ref.Multiple[i] != self.Multiple[i] {
+ t.Errorf("Expected: %v to deep equal: %v", self.Multiple, test.ref.Multiple)
+ break
+ }
+ }
+}
+
+func (self BlogSection) Create(test emptyPayloadTestCase, t *testing.T, index int) {
+ // intentionally left empty
+}
+
+type (
+ testCase struct {
+ method string
+ path string
+ payload string
+ contentType string
+ ok bool
+ ref *BlogPost
+ }
+
+ emptyPayloadTestCase struct {
+ method string
+ path string
+ payload string
+ contentType string
+ ok bool
+ ref *BlogSection
+ }
+
+ Modeler interface {
+ Create(test testCase, t *testing.T, index int)
+ }
+
+ BlogPost struct {
+ Title string `form:"title" json:"title" binding:"required"`
+ Content string `form:"content" json:"content"`
+ Views int `form:"views" json:"views"`
+ internal int `form:"-"`
+ Multiple []int `form:"multiple"`
+ }
+
+ BlogSection struct {
+ Title string `form:"title" json:"title"`
+ Content string `form:"content" json:"content"`
+ }
+
+ User struct {
+ Name string `json:"name" binding:"required"`
+ Home Address `json:"address" binding:"required"`
+ }
+
+ Address struct {
+ Street1 string `json:"street1" binding:"required"`
+ Street2 string `json:"street2"`
+ }
+)
+
+var (
+ bindTests = map[testCase]int{
+ // These should bail at the deserialization/binding phase
+ testCase{
+ "POST",
+ path,
+ `{ bad JSON `,
+ "application/json",
+ false,
+ new(BlogPost),
+ }: http.StatusBadRequest,
+ testCase{
+ "POST",
+ path,
+ `not multipart but has content-type`,
+ "multipart/form-data",
+ false,
+ new(BlogPost),
+ }: http.StatusBadRequest,
+ testCase{
+ "POST",
+ path,
+ `no content-type and not URL-encoded or JSON"`,
+ "",
+ false,
+ new(BlogPost),
+ }: http.StatusBadRequest,
+
+ // These should deserialize, then bail at the validation phase
+ testCase{
+ "POST",
+ path + "?title= This is wrong ",
+ `not URL-encoded but has content-type`,
+ "x-www-form-urlencoded",
+ false,
+ new(BlogPost),
+ }: 422, // according to comments in Form() -> although the request is not url encoded, ParseForm does not complain
+ testCase{
+ "GET",
+ path + "?content=This+is+the+content",
+ ``,
+ "x-www-form-urlencoded",
+ false,
+ &BlogPost{Title: "", Content: "This is the content"},
+ }: 422,
+ testCase{
+ "GET",
+ path + "",
+ `{"content":"", "title":"Blog Post Title"}`,
+ "application/json",
+ false,
+ &BlogPost{Title: "Blog Post Title", Content: ""},
+ }: 422,
+
+ // These should succeed
+ testCase{
+ "GET",
+ path + "",
+ `{"content":"This is the content", "title":"Blog Post Title"}`,
+ "application/json",
+ true,
+ &BlogPost{Title: "Blog Post Title", Content: "This is the content"},
+ }: http.StatusOK,
+ testCase{
+ "GET",
+ path + "?content=This+is+the+content&title=Blog+Post+Title",
+ ``,
+ "",
+ true,
+ &BlogPost{Title: "Blog Post Title", Content: "This is the content"},
+ }: http.StatusOK,
+ testCase{
+ "GET",
+ path + "?content=This is the content&title=Blog+Post+Title",
+ `{"content":"This is the content", "title":"Blog Post Title"}`,
+ "",
+ true,
+ &BlogPost{Title: "Blog Post Title", Content: "This is the content"},
+ }: http.StatusOK,
+ testCase{
+ "GET",
+ path + "",
+ `{"content":"This is the content", "title":"Blog Post Title"}`,
+ "",
+ true,
+ &BlogPost{Title: "Blog Post Title", Content: "This is the content"},
+ }: http.StatusOK,
+ }
+
+ bindMultipartTests = map[testCase]int{
+ // This should deserialize, then bail at the validation phase
+ testCase{
+ "POST",
+ path,
+ "",
+ "multipart/form-data",
+ false,
+ &BlogPost{Title: "", Content: "This is the content"},
+ }: 422,
+ // This should succeed
+ testCase{
+ "POST",
+ path,
+ "",
+ "multipart/form-data",
+ true,
+ &BlogPost{Title: "This is the Title", Content: "This is the content"},
+ }: http.StatusOK,
+ }
+
+ formTests = []testCase{
+ {
+ "GET",
+ path + "?content=This is the content",
+ "",
+ "",
+ false,
+ &BlogPost{Title: "", Content: "This is the content"},
+ },
+ {
+ "POST",
+ path + "?content=This+is+the+content&title=Blog+Post+Title&views=3",
+ "",
+ "",
+ false, // false because POST requests should have a body, not just a query string
+ &BlogPost{Title: "Blog Post Title", Content: "This is the content", Views: 3},
+ },
+ {
+ "GET",
+ path + "?content=This+is+the+content&title=Blog+Post+Title&views=3&multiple=5&multiple=10&multiple=15&multiple=20",
+ "",
+ "",
+ true,
+ &BlogPost{Title: "Blog Post Title", Content: "This is the content", Views: 3, Multiple: []int{5, 10, 15, 20}},
+ },
+ }
+
+ multipartformTests = []testCase{
+ {
+ "POST",
+ path,
+ "",
+ "multipart/form-data",
+ false,
+ &BlogPost{Title: "", Content: "This is the content"},
+ },
+ {
+ "POST",
+ path,
+ "",
+ "multipart/form-data",
+ false,
+ &BlogPost{Title: "Blog Post Title", Views: 3},
+ },
+ {
+ "POST",
+ path,
+ "",
+ "multipart/form-data",
+ true,
+ &BlogPost{Title: "Blog Post Title", Content: "This is the content", Views: 3, Multiple: []int{5, 10, 15, 20}},
+ },
+ }
+
+ emptyPayloadTests = []emptyPayloadTestCase{
+ {
+ "GET",
+ "",
+ "",
+ "",
+ true,
+ &BlogSection{},
+ },
+ {
+ "POST",
+ "",
+ "",
+ "",
+ true,
+ &BlogSection{},
+ },
+ {
+ "PUT",
+ "",
+ "",
+ "",
+ true,
+ &BlogSection{},
+ },
+ {
+ "DELETE",
+ "",
+ "",
+ "",
+ true,
+ &BlogSection{},
+ },
+ }
+
+ jsonTests = []testCase{
+ // bad requests
+ {
+ "GET",
+ "",
+ `{blah blah blah}`,
+ "",
+ false,
+ &BlogPost{},
+ },
+ {
+ "POST",
+ "",
+ `{asdf}`,
+ "",
+ false,
+ &BlogPost{},
+ },
+ {
+ "PUT",
+ "",
+ `{blah blah blah}`,
+ "",
+ false,
+ &BlogPost{},
+ },
+ {
+ "DELETE",
+ "",
+ `{;sdf _SDf- }`,
+ "",
+ false,
+ &BlogPost{},
+ },
+
+ // Valid-JSON requests
+ {
+ "GET",
+ "",
+ `{"content":"This is the content"}`,
+ "",
+ false,
+ &BlogPost{Title: "", Content: "This is the content"},
+ },
+ {
+ "POST",
+ "",
+ `{}`,
+ "application/json",
+ false,
+ &BlogPost{Title: "", Content: ""},
+ },
+ {
+ "POST",
+ "",
+ `{"content":"This is the content", "title":"Blog Post Title"}`,
+ "",
+ true,
+ &BlogPost{Title: "Blog Post Title", Content: "This is the content"},
+ },
+ {
+ "PUT",
+ "",
+ `{"content":"This is the content", "title":"Blog Post Title"}`,
+ "",
+ true,
+ &BlogPost{Title: "Blog Post Title", Content: "This is the content"},
+ },
+ {
+ "DELETE",
+ "",
+ `{"content":"This is the content", "title":"Blog Post Title"}`,
+ "",
+ true,
+ &BlogPost{Title: "Blog Post Title", Content: "This is the content"},
+ },
+ }
+)
+
+const (
+ route = "/blogposts/create"
+ path = "http://localhost:3000" + route
+)
diff --git a/modules/middleware/context.go b/modules/middleware/context.go
index 8129b13b..1330172f 100644
--- a/modules/middleware/context.go
+++ b/modules/middleware/context.go
@@ -10,7 +10,10 @@ import (
"encoding/base64"
"fmt"
"html/template"
+ "io"
"net/http"
+ "net/url"
+ "path/filepath"
"strconv"
"strings"
"time"
@@ -34,6 +37,7 @@ type Context struct {
p martini.Params
Req *http.Request
Res http.ResponseWriter
+ Flash *Flash
Session session.SessionStore
Cache cache.Cache
User *models.User
@@ -47,6 +51,7 @@ type Context struct {
IsBranch bool
IsTag bool
IsCommit bool
+ HasAccess bool
Repository *models.Repository
Owner *models.User
Commit *git.Commit
@@ -59,6 +64,7 @@ type Context struct {
HTTPS string
Git string
}
+ Mirror *models.Mirror
}
}
@@ -78,6 +84,8 @@ func (ctx *Context) HasError() bool {
if !ok {
return false
}
+ ctx.Flash.ErrorMsg = ctx.Data["ErrorMsg"].(string)
+ ctx.Data["Flash"] = ctx.Flash
return hasErr.(bool)
}
@@ -88,23 +96,21 @@ func (ctx *Context) HTML(status int, name string, htmlOpt ...HTMLOptions) {
// RenderWithErr used for page has form validation but need to prompt error to users.
func (ctx *Context) RenderWithErr(msg, tpl string, form auth.Form) {
- ctx.Data["HasError"] = true
- ctx.Data["ErrorMsg"] = msg
if form != nil {
auth.AssignForm(form, ctx.Data)
}
+ ctx.Flash.ErrorMsg = msg
+ ctx.Data["Flash"] = ctx.Flash
ctx.HTML(200, tpl)
}
// Handle handles and logs error by given status.
func (ctx *Context) Handle(status int, title string, err error) {
log.Error("%s: %v", title, err)
- if martini.Dev == martini.Prod {
- ctx.HTML(500, "status/500")
- return
+ if martini.Dev != martini.Prod {
+ ctx.Data["ErrorMsg"] = err
}
- ctx.Data["ErrorMsg"] = err
ctx.HTML(status, fmt.Sprintf("status/%d", status))
}
@@ -239,6 +245,56 @@ func (ctx *Context) CsrfTokenValid() bool {
return true
}
+func (ctx *Context) ServeFile(file string, names ...string) {
+ var name string
+ if len(names) > 0 {
+ name = names[0]
+ } else {
+ name = filepath.Base(file)
+ }
+ ctx.Res.Header().Set("Content-Description", "File Transfer")
+ ctx.Res.Header().Set("Content-Type", "application/octet-stream")
+ ctx.Res.Header().Set("Content-Disposition", "attachment; filename="+name)
+ ctx.Res.Header().Set("Content-Transfer-Encoding", "binary")
+ ctx.Res.Header().Set("Expires", "0")
+ ctx.Res.Header().Set("Cache-Control", "must-revalidate")
+ ctx.Res.Header().Set("Pragma", "public")
+ http.ServeFile(ctx.Res, ctx.Req, file)
+}
+
+func (ctx *Context) ServeContent(name string, r io.ReadSeeker, params ...interface{}) {
+ modtime := time.Now()
+ for _, p := range params {
+ switch v := p.(type) {
+ case time.Time:
+ modtime = v
+ }
+ }
+ ctx.Res.Header().Set("Content-Description", "File Transfer")
+ ctx.Res.Header().Set("Content-Type", "application/octet-stream")
+ ctx.Res.Header().Set("Content-Disposition", "attachment; filename="+name)
+ ctx.Res.Header().Set("Content-Transfer-Encoding", "binary")
+ ctx.Res.Header().Set("Expires", "0")
+ ctx.Res.Header().Set("Cache-Control", "must-revalidate")
+ ctx.Res.Header().Set("Pragma", "public")
+ http.ServeContent(ctx.Res, ctx.Req, name, modtime, r)
+}
+
+type Flash struct {
+ url.Values
+ ErrorMsg, SuccessMsg string
+}
+
+func (f *Flash) Error(msg string) {
+ f.Set("error", msg)
+ f.ErrorMsg = msg
+}
+
+func (f *Flash) Success(msg string) {
+ f.Set("success", msg)
+ f.SuccessMsg = msg
+}
+
// InitContext initializes a classic context for a request.
func InitContext() martini.Handler {
return func(res http.ResponseWriter, r *http.Request, c martini.Context, rd *Render) {
@@ -256,9 +312,27 @@ func InitContext() martini.Handler {
// start session
ctx.Session = base.SessionManager.SessionStart(res, r)
+
+ // Get flash.
+ values, err := url.ParseQuery(ctx.GetCookie("gogs_flash"))
+ if err != nil {
+ log.Error("InitContext.ParseQuery(flash): %v", err)
+ } else if len(values) > 0 {
+ ctx.Flash = &Flash{Values: values}
+ ctx.Flash.ErrorMsg = ctx.Flash.Get("error")
+ ctx.Flash.SuccessMsg = ctx.Flash.Get("success")
+ ctx.Data["Flash"] = ctx.Flash
+ ctx.SetCookie("gogs_flash", "", -1)
+ }
+ ctx.Flash = &Flash{Values: url.Values{}}
+
rw := res.(martini.ResponseWriter)
rw.Before(func(martini.ResponseWriter) {
ctx.Session.SessionRelease(res)
+
+ if flash := ctx.Flash.Encode(); len(flash) > 0 {
+ ctx.SetCookie("gogs_flash", ctx.Flash.Encode(), 0)
+ }
})
// Get user from session if logined.
diff --git a/modules/middleware/render.go b/modules/middleware/render.go
index 98d485af..66289988 100644
--- a/modules/middleware/render.go
+++ b/modules/middleware/render.go
@@ -146,7 +146,7 @@ func compile(options RenderOptions) *template.Template {
tmpl := t.New(filepath.ToSlash(name))
for _, funcs := range options.Funcs {
- tmpl.Funcs(funcs)
+ tmpl = tmpl.Funcs(funcs)
}
template.Must(tmpl.Funcs(helperFuncs).Parse(string(buf)))
diff --git a/modules/middleware/repo.go b/modules/middleware/repo.go
index 2139742c..34144fe3 100644
--- a/modules/middleware/repo.go
+++ b/modules/middleware/repo.go
@@ -15,6 +15,7 @@ import (
"github.com/gogits/gogs/models"
"github.com/gogits/gogs/modules/base"
+ "github.com/gogits/gogs/modules/log"
)
func RepoAssignment(redirect bool, args ...bool) martini.Handler {
@@ -39,7 +40,7 @@ func RepoAssignment(redirect bool, args ...bool) martini.Handler {
userName := params["username"]
repoName := params["reponame"]
- branchName := params["branchname"]
+ refName := params["branchname"]
// get repository owner
ctx.Repo.IsOwner = ctx.IsSigned && ctx.User.LowerName == strings.ToLower(userName)
@@ -66,34 +67,69 @@ func RepoAssignment(redirect bool, args ...bool) martini.Handler {
ctx.Handle(200, "RepoAssignment", errors.New("invliad user account for single repository"))
return
}
+ ctx.Repo.Owner = user
// get repository
repo, err := models.GetRepositoryByName(user.Id, repoName)
if err != nil {
if err == models.ErrRepoNotExist {
ctx.Handle(404, "RepoAssignment", err)
+ return
} else if redirect {
ctx.Redirect("/")
return
}
- ctx.Handle(404, "RepoAssignment", err)
+ ctx.Handle(500, "RepoAssignment", err)
return
}
+
+ // Check access.
+ if repo.IsPrivate {
+ if ctx.User == nil {
+ ctx.Handle(404, "RepoAssignment(HasAccess)", nil)
+ return
+ }
+
+ hasAccess, err := models.HasAccess(ctx.User.Name, ctx.Repo.Owner.Name+"/"+repo.Name, models.AU_READABLE)
+ if err != nil {
+ ctx.Handle(500, "RepoAssignment(HasAccess)", err)
+ return
+ } else if !hasAccess {
+ ctx.Handle(404, "RepoAssignment(HasAccess)", nil)
+ return
+ }
+ }
+ ctx.Repo.HasAccess = true
+ ctx.Data["HasAccess"] = true
+
+ if repo.IsMirror {
+ ctx.Repo.Mirror, err = models.GetMirror(repo.Id)
+ if err != nil {
+ ctx.Handle(500, "RepoAssignment(GetMirror)", err)
+ return
+ }
+ ctx.Data["MirrorInterval"] = ctx.Repo.Mirror.Interval
+ }
+
repo.NumOpenIssues = repo.NumIssues - repo.NumClosedIssues
ctx.Repo.Repository = repo
-
ctx.Data["IsBareRepo"] = ctx.Repo.Repository.IsBare
gitRepo, err := git.OpenRepository(models.RepoPath(userName, repoName))
if err != nil {
- ctx.Handle(404, "RepoAssignment Invalid repo "+models.RepoPath(userName, repoName), err)
+ ctx.Handle(500, "RepoAssignment Invalid repo "+models.RepoPath(userName, repoName), err)
return
}
ctx.Repo.GitRepo = gitRepo
-
- ctx.Repo.Owner = user
ctx.Repo.RepoLink = "/" + user.Name + "/" + repo.Name
+ tags, err := ctx.Repo.GitRepo.GetTags()
+ if err != nil {
+ ctx.Handle(500, "RepoAssignment(GetTags))", err)
+ return
+ }
+ ctx.Repo.Repository.NumTags = len(tags)
+
ctx.Data["Title"] = user.Name + "/" + repo.Name
ctx.Data["Repository"] = repo
ctx.Data["Owner"] = user
@@ -105,29 +141,43 @@ func RepoAssignment(redirect bool, args ...bool) martini.Handler {
ctx.Repo.CloneLink.HTTPS = fmt.Sprintf("%s%s/%s.git", base.AppUrl, user.LowerName, repo.LowerName)
ctx.Data["CloneLink"] = ctx.Repo.CloneLink
+ if ctx.Repo.Repository.IsGoget {
+ ctx.Data["GoGetLink"] = fmt.Sprintf("%s%s/%s", base.AppUrl, user.LowerName, repo.LowerName)
+ ctx.Data["GoGetImport"] = fmt.Sprintf("%s/%s/%s", base.Domain, user.LowerName, repo.LowerName)
+ }
+
// when repo is bare, not valid branch
if !ctx.Repo.Repository.IsBare && validBranch {
detect:
- if len(branchName) > 0 {
- // TODO check tag
- if models.IsBranchExist(user.Name, repoName, branchName) {
+ if len(refName) > 0 {
+ if gitRepo.IsBranchExist(refName) {
ctx.Repo.IsBranch = true
- ctx.Repo.BranchName = branchName
+ ctx.Repo.BranchName = refName
- ctx.Repo.Commit, err = gitRepo.GetCommitOfBranch(branchName)
+ ctx.Repo.Commit, err = gitRepo.GetCommitOfBranch(refName)
if err != nil {
ctx.Handle(404, "RepoAssignment invalid branch", nil)
return
}
+ ctx.Repo.CommitId = ctx.Repo.Commit.Id.String()
- ctx.Repo.CommitId = ctx.Repo.Commit.Oid.String()
+ } else if gitRepo.IsTagExist(refName) {
+ ctx.Repo.IsBranch = true
+ ctx.Repo.BranchName = refName
- } else if len(branchName) == 40 {
+ ctx.Repo.Commit, err = gitRepo.GetCommitOfTag(refName)
+ if err != nil {
+ ctx.Handle(404, "RepoAssignment invalid tag", nil)
+ return
+ }
+ ctx.Repo.CommitId = ctx.Repo.Commit.Id.String()
+
+ } else if len(refName) == 40 {
ctx.Repo.IsCommit = true
- ctx.Repo.CommitId = branchName
- ctx.Repo.BranchName = branchName
+ ctx.Repo.CommitId = refName
+ ctx.Repo.BranchName = refName
- ctx.Repo.Commit, err = gitRepo.GetCommit(branchName)
+ ctx.Repo.Commit, err = gitRepo.GetCommit(refName)
if err != nil {
ctx.Handle(404, "RepoAssignment invalid commit", nil)
return
@@ -138,16 +188,23 @@ func RepoAssignment(redirect bool, args ...bool) martini.Handler {
}
} else {
- branchName = "master"
+ refName = ctx.Repo.Repository.DefaultBranch
+ if len(refName) == 0 {
+ refName = "master"
+ }
goto detect
}
ctx.Data["IsBranch"] = ctx.Repo.IsBranch
ctx.Data["IsCommit"] = ctx.Repo.IsCommit
+ log.Debug("Repo.Commit: %v", ctx.Repo.Commit)
}
+ log.Debug("displayBare: %v; IsBare: %v", displayBare, ctx.Repo.Repository.IsBare)
+
// repo is bare and display enable
if displayBare && ctx.Repo.Repository.IsBare {
+ log.Debug("Bare repository: %s", ctx.Repo.RepoLink)
ctx.HTML(200, "repo/single_bare")
return
}
@@ -157,6 +214,11 @@ func RepoAssignment(redirect bool, args ...bool) martini.Handler {
}
ctx.Data["BranchName"] = ctx.Repo.BranchName
+ brs, err := ctx.Repo.GitRepo.GetBranches()
+ if err != nil {
+ log.Error("RepoAssignment(GetBranches): %v", err)
+ }
+ ctx.Data["Branches"] = brs
ctx.Data["CommitId"] = ctx.Repo.CommitId
ctx.Data["IsRepositoryWatching"] = ctx.Repo.IsWatching
}
diff --git a/modules/social/social.go b/modules/social/social.go
new file mode 100644
index 00000000..a628bfe6
--- /dev/null
+++ b/modules/social/social.go
@@ -0,0 +1,396 @@
+// Copyright 2014 Google Inc. All Rights Reserved.
+// Copyright 2014 The Gogs Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package social
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/url"
+ "strconv"
+ "strings"
+
+ oauth "github.com/gogits/oauth2"
+
+ "github.com/gogits/gogs/models"
+ "github.com/gogits/gogs/modules/base"
+ "github.com/gogits/gogs/modules/log"
+)
+
+type BasicUserInfo struct {
+ Identity string
+ Name string
+ Email string
+}
+
+type SocialConnector interface {
+ Type() int
+ SetRedirectUrl(string)
+ UserInfo(*oauth.Token, *url.URL) (*BasicUserInfo, error)
+
+ AuthCodeURL(string) string
+ Exchange(string) (*oauth.Token, error)
+}
+
+var (
+ SocialBaseUrl = "/user/login"
+ SocialMap = make(map[string]SocialConnector)
+)
+
+func NewOauthService() {
+ if !base.Cfg.MustBool("oauth", "ENABLED") {
+ return
+ }
+
+ base.OauthService = &base.Oauther{}
+ base.OauthService.OauthInfos = make(map[string]*base.OauthInfo)
+
+ socialConfigs := make(map[string]*oauth.Config)
+ allOauthes := []string{"github", "google", "qq", "twitter", "weibo"}
+ // Load all OAuth config data.
+ for _, name := range allOauthes {
+ base.OauthService.OauthInfos[name] = &base.OauthInfo{
+ ClientId: base.Cfg.MustValue("oauth."+name, "CLIENT_ID"),
+ ClientSecret: base.Cfg.MustValue("oauth."+name, "CLIENT_SECRET"),
+ Scopes: base.Cfg.MustValue("oauth."+name, "SCOPES"),
+ AuthUrl: base.Cfg.MustValue("oauth."+name, "AUTH_URL"),
+ TokenUrl: base.Cfg.MustValue("oauth."+name, "TOKEN_URL"),
+ }
+ socialConfigs[name] = &oauth.Config{
+ ClientId: base.OauthService.OauthInfos[name].ClientId,
+ ClientSecret: base.OauthService.OauthInfos[name].ClientSecret,
+ RedirectURL: strings.TrimSuffix(base.AppUrl, "/") + SocialBaseUrl + name,
+ Scope: base.OauthService.OauthInfos[name].Scopes,
+ AuthURL: base.OauthService.OauthInfos[name].AuthUrl,
+ TokenURL: base.OauthService.OauthInfos[name].TokenUrl,
+ }
+ }
+
+ enabledOauths := make([]string, 0, 10)
+
+ // GitHub.
+ if base.Cfg.MustBool("oauth.github", "ENABLED") {
+ base.OauthService.GitHub = true
+ newGitHubOauth(socialConfigs["github"])
+ enabledOauths = append(enabledOauths, "GitHub")
+ }
+
+ // Google.
+ if base.Cfg.MustBool("oauth.google", "ENABLED") {
+ base.OauthService.Google = true
+ newGoogleOauth(socialConfigs["google"])
+ enabledOauths = append(enabledOauths, "Google")
+ }
+
+ // QQ.
+ if base.Cfg.MustBool("oauth.qq", "ENABLED") {
+ base.OauthService.Tencent = true
+ newTencentOauth(socialConfigs["qq"])
+ enabledOauths = append(enabledOauths, "QQ")
+ }
+
+ // Twitter.
+ if base.Cfg.MustBool("oauth.twitter", "ENABLED") {
+ base.OauthService.Twitter = true
+ newTwitterOauth(socialConfigs["twitter"])
+ enabledOauths = append(enabledOauths, "Twitter")
+ }
+
+ // Weibo.
+ if base.Cfg.MustBool("oauth.weibo", "ENABLED") {
+ base.OauthService.Weibo = true
+ newWeiboOauth(socialConfigs["weibo"])
+ enabledOauths = append(enabledOauths, "Weibo")
+ }
+
+ log.Info("Oauth Service Enabled %s", enabledOauths)
+}
+
+// ________.__ __ ___ ___ ___.
+// / _____/|__|/ |_ / | \ __ _\_ |__
+// / \ ___| \ __\/ ~ \ | \ __ \
+// \ \_\ \ || | \ Y / | / \_\ \
+// \______ /__||__| \___|_ /|____/|___ /
+// \/ \/ \/
+
+type SocialGithub struct {
+ Token *oauth.Token
+ *oauth.Transport
+}
+
+func (s *SocialGithub) Type() int {
+ return models.OT_GITHUB
+}
+
+func newGitHubOauth(config *oauth.Config) {
+ SocialMap["github"] = &SocialGithub{
+ Transport: &oauth.Transport{
+ Config: config,
+ Transport: http.DefaultTransport,
+ },
+ }
+}
+
+func (s *SocialGithub) SetRedirectUrl(url string) {
+ s.Transport.Config.RedirectURL = url
+}
+
+func (s *SocialGithub) UserInfo(token *oauth.Token, _ *url.URL) (*BasicUserInfo, error) {
+ transport := &oauth.Transport{
+ Token: token,
+ }
+ var data struct {
+ Id int `json:"id"`
+ Name string `json:"login"`
+ Email string `json:"email"`
+ }
+ var err error
+ r, err := transport.Client().Get(s.Transport.Scope)
+ if err != nil {
+ return nil, err
+ }
+ defer r.Body.Close()
+ if err = json.NewDecoder(r.Body).Decode(&data); err != nil {
+ return nil, err
+ }
+ return &BasicUserInfo{
+ Identity: strconv.Itoa(data.Id),
+ Name: data.Name,
+ Email: data.Email,
+ }, nil
+}
+
+// ________ .__
+// / _____/ ____ ____ ____ | | ____
+// / \ ___ / _ \ / _ \ / ___\| | _/ __ \
+// \ \_\ ( <_> | <_> ) /_/ > |_\ ___/
+// \______ /\____/ \____/\___ /|____/\___ >
+// \/ /_____/ \/
+
+type SocialGoogle struct {
+ Token *oauth.Token
+ *oauth.Transport
+}
+
+func (s *SocialGoogle) Type() int {
+ return models.OT_GOOGLE
+}
+
+func newGoogleOauth(config *oauth.Config) {
+ SocialMap["google"] = &SocialGoogle{
+ Transport: &oauth.Transport{
+ Config: config,
+ Transport: http.DefaultTransport,
+ },
+ }
+}
+
+func (s *SocialGoogle) SetRedirectUrl(url string) {
+ s.Transport.Config.RedirectURL = url
+}
+
+func (s *SocialGoogle) UserInfo(token *oauth.Token, _ *url.URL) (*BasicUserInfo, error) {
+ transport := &oauth.Transport{Token: token}
+ var data struct {
+ Id string `json:"id"`
+ Name string `json:"name"`
+ Email string `json:"email"`
+ }
+ var err error
+
+ reqUrl := "https://www.googleapis.com/oauth2/v1/userinfo"
+ r, err := transport.Client().Get(reqUrl)
+ if err != nil {
+ return nil, err
+ }
+ defer r.Body.Close()
+ if err = json.NewDecoder(r.Body).Decode(&data); err != nil {
+ return nil, err
+ }
+ return &BasicUserInfo{
+ Identity: data.Id,
+ Name: data.Name,
+ Email: data.Email,
+ }, nil
+}
+
+// ________ ________
+// \_____ \ \_____ \
+// / / \ \ / / \ \
+// / \_/. \/ \_/. \
+// \_____\ \_/\_____\ \_/
+// \__> \__>
+
+type SocialTencent struct {
+ Token *oauth.Token
+ *oauth.Transport
+ reqUrl string
+}
+
+func (s *SocialTencent) Type() int {
+ return models.OT_QQ
+}
+
+func newTencentOauth(config *oauth.Config) {
+ SocialMap["qq"] = &SocialTencent{
+ reqUrl: "https://open.t.qq.com/api/user/info",
+ Transport: &oauth.Transport{
+ Config: config,
+ Transport: http.DefaultTransport,
+ },
+ }
+}
+
+func (s *SocialTencent) SetRedirectUrl(url string) {
+ s.Transport.Config.RedirectURL = url
+}
+
+func (s *SocialTencent) UserInfo(token *oauth.Token, URL *url.URL) (*BasicUserInfo, error) {
+ var data struct {
+ Data struct {
+ Id string `json:"openid"`
+ Name string `json:"name"`
+ Email string `json:"email"`
+ } `json:"data"`
+ }
+ var err error
+ // https://open.t.qq.com/api/user/info?
+ //oauth_consumer_key=APP_KEY&
+ //access_token=ACCESSTOKEN&openid=openid
+ //clientip=CLIENTIP&oauth_version=2.a
+ //scope=all
+ var urls = url.Values{
+ "oauth_consumer_key": {s.Transport.Config.ClientId},
+ "access_token": {token.AccessToken},
+ "openid": URL.Query()["openid"],
+ "oauth_version": {"2.a"},
+ "scope": {"all"},
+ }
+ r, err := http.Get(s.reqUrl + "?" + urls.Encode())
+ if err != nil {
+ return nil, err
+ }
+ defer r.Body.Close()
+ if err = json.NewDecoder(r.Body).Decode(&data); err != nil {
+ return nil, err
+ }
+ return &BasicUserInfo{
+ Identity: data.Data.Id,
+ Name: data.Data.Name,
+ Email: data.Data.Email,
+ }, nil
+}
+
+// ___________ .__ __ __
+// \__ ___/_ _ _|__|/ |__/ |_ ___________
+// | | \ \/ \/ / \ __\ __\/ __ \_ __ \
+// | | \ /| || | | | \ ___/| | \/
+// |____| \/\_/ |__||__| |__| \___ >__|
+// \/
+
+type SocialTwitter struct {
+ Token *oauth.Token
+ *oauth.Transport
+}
+
+func (s *SocialTwitter) Type() int {
+ return models.OT_TWITTER
+}
+
+func newTwitterOauth(config *oauth.Config) {
+ SocialMap["twitter"] = &SocialTwitter{
+ Transport: &oauth.Transport{
+ Config: config,
+ Transport: http.DefaultTransport,
+ },
+ }
+}
+
+func (s *SocialTwitter) SetRedirectUrl(url string) {
+ s.Transport.Config.RedirectURL = url
+}
+
+//https://github.com/mrjones/oauth
+func (s *SocialTwitter) UserInfo(token *oauth.Token, _ *url.URL) (*BasicUserInfo, error) {
+ // transport := &oauth.Transport{Token: token}
+ // var data struct {
+ // Id string `json:"id"`
+ // Name string `json:"name"`
+ // Email string `json:"email"`
+ // }
+ // var err error
+
+ // reqUrl := "https://www.googleapis.com/oauth2/v1/userinfo"
+ // r, err := transport.Client().Get(reqUrl)
+ // if err != nil {
+ // return nil, err
+ // }
+ // defer r.Body.Close()
+ // if err = json.NewDecoder(r.Body).Decode(&data); err != nil {
+ // return nil, err
+ // }
+ // return &BasicUserInfo{
+ // Identity: data.Id,
+ // Name: data.Name,
+ // Email: data.Email,
+ // }, nil
+ return nil, nil
+}
+
+// __ __ ._____.
+// / \ / \ ____ |__\_ |__ ____
+// \ \/\/ // __ \| || __ \ / _ \
+// \ /\ ___/| || \_\ ( <_> )
+// \__/\ / \___ >__||___ /\____/
+// \/ \/ \/
+
+type SocialWeibo struct {
+ Token *oauth.Token
+ *oauth.Transport
+}
+
+func (s *SocialWeibo) Type() int {
+ return models.OT_WEIBO
+}
+
+func newWeiboOauth(config *oauth.Config) {
+ SocialMap["weibo"] = &SocialWeibo{
+ Transport: &oauth.Transport{
+ Config: config,
+ Transport: http.DefaultTransport,
+ },
+ }
+}
+
+func (s *SocialWeibo) SetRedirectUrl(url string) {
+ s.Transport.Config.RedirectURL = url
+}
+
+func (s *SocialWeibo) UserInfo(token *oauth.Token, _ *url.URL) (*BasicUserInfo, error) {
+ transport := &oauth.Transport{Token: token}
+ var data struct {
+ Name string `json:"name"`
+ }
+ var err error
+
+ var urls = url.Values{
+ "access_token": {token.AccessToken},
+ "uid": {token.Extra["id_token"]},
+ }
+ reqUrl := "https://api.weibo.com/2/users/show.json"
+ r, err := transport.Client().Get(reqUrl + "?" + urls.Encode())
+ if err != nil {
+ return nil, err
+ }
+ defer r.Body.Close()
+ if err = json.NewDecoder(r.Body).Decode(&data); err != nil {
+ return nil, err
+ }
+ return &BasicUserInfo{
+ Identity: token.Extra["id_token"],
+ Name: data.Name,
+ }, nil
+ return nil, nil
+}