diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/auth/admin.go | 4 | ||||
-rw-r--r-- | modules/auth/auth.go | 24 | ||||
-rw-r--r-- | modules/auth/issue.go | 4 | ||||
-rw-r--r-- | modules/auth/release.go | 50 | ||||
-rw-r--r-- | modules/auth/repo.go | 46 | ||||
-rw-r--r-- | modules/auth/setting.go | 4 | ||||
-rw-r--r-- | modules/auth/user.go | 5 | ||||
-rw-r--r-- | modules/avatar/avatar.go | 5 | ||||
-rw-r--r-- | modules/base/base.go | 48 | ||||
-rw-r--r-- | modules/base/base_memcache.go | 11 | ||||
-rw-r--r-- | modules/base/base_redis.go | 11 | ||||
-rw-r--r-- | modules/base/conf.go | 112 | ||||
-rw-r--r-- | modules/base/markdown.go | 57 | ||||
-rw-r--r-- | modules/base/template.go | 136 | ||||
-rw-r--r-- | modules/base/tool.go | 149 | ||||
-rw-r--r-- | modules/cron/cron.go | 17 | ||||
-rw-r--r-- | modules/log/log.go | 3 | ||||
-rw-r--r-- | modules/mailer/mail.go | 53 | ||||
-rw-r--r-- | modules/middleware/auth.go | 2 | ||||
-rw-r--r-- | modules/middleware/binding.go | 426 | ||||
-rw-r--r-- | modules/middleware/binding_test.go | 701 | ||||
-rw-r--r-- | modules/middleware/context.go | 86 | ||||
-rw-r--r-- | modules/middleware/render.go | 2 | ||||
-rw-r--r-- | modules/middleware/repo.go | 96 | ||||
-rw-r--r-- | modules/social/social.go | 396 |
25 files changed, 2225 insertions, 223 deletions
diff --git a/modules/auth/admin.go b/modules/auth/admin.go index fe889c23..877af19a 100644 --- a/modules/auth/admin.go +++ b/modules/auth/admin.go @@ -10,8 +10,6 @@ import ( "github.com/go-martini/martini" - "github.com/gogits/binding" - "github.com/gogits/gogs/modules/base" "github.com/gogits/gogs/modules/log" ) @@ -35,7 +33,7 @@ func (f *AdminEditUserForm) Name(field string) string { return names[field] } -func (f *AdminEditUserForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) { +func (f *AdminEditUserForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) { if req.Method == "GET" || errors.Count() == 0 { return } diff --git a/modules/auth/auth.go b/modules/auth/auth.go index 4561dd83..350ef4fc 100644 --- a/modules/auth/auth.go +++ b/modules/auth/auth.go @@ -11,8 +11,6 @@ import ( "github.com/go-martini/martini" - "github.com/gogits/binding" - "github.com/gogits/gogs/modules/base" "github.com/gogits/gogs/modules/log" ) @@ -39,7 +37,7 @@ func (f *RegisterForm) Name(field string) string { return names[field] } -func (f *RegisterForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) { +func (f *RegisterForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) { if req.Method == "GET" || errors.Count() == 0 { return } @@ -72,7 +70,7 @@ func (f *LogInForm) Name(field string) string { return names[field] } -func (f *LogInForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) { +func (f *LogInForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) { if req.Method == "GET" || errors.Count() == 0 { return } @@ -100,7 +98,7 @@ func getMinMaxSize(field reflect.StructField) string { return "" } -func validate(errors *binding.Errors, data base.TmplData, form Form) { +func validate(errors *base.BindingErrors, data base.TmplData, form Form) { typ := reflect.TypeOf(form) val := reflect.ValueOf(form) @@ -121,16 +119,18 @@ func validate(errors *binding.Errors, data base.TmplData, form Form) { if err, ok := errors.Fields[field.Name]; ok { data["Err_"+field.Name] = true switch err { - case binding.RequireError: + case base.BindingRequireError: data["ErrorMsg"] = form.Name(field.Name) + " cannot be empty" - case binding.AlphaDashError: + case base.BindingAlphaDashError: data["ErrorMsg"] = form.Name(field.Name) + " must be valid alpha or numeric or dash(-_) characters" - case binding.MinSizeError: + case base.BindingMinSizeError: data["ErrorMsg"] = form.Name(field.Name) + " must contain at least " + getMinMaxSize(field) + " characters" - case binding.MaxSizeError: + case base.BindingMaxSizeError: data["ErrorMsg"] = form.Name(field.Name) + " must contain at most " + getMinMaxSize(field) + " characters" - case binding.EmailError: - data["ErrorMsg"] = form.Name(field.Name) + " is not valid" + case base.BindingEmailError: + data["ErrorMsg"] = form.Name(field.Name) + " is not a valid e-mail address" + case base.BindingUrlError: + data["ErrorMsg"] = form.Name(field.Name) + " is not a valid URL" default: data["ErrorMsg"] = "Unknown error: " + err } @@ -194,7 +194,7 @@ func (f *InstallForm) Name(field string) string { return names[field] } -func (f *InstallForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) { +func (f *InstallForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) { if req.Method == "GET" || errors.Count() == 0 { return } diff --git a/modules/auth/issue.go b/modules/auth/issue.go index 36c87627..f73ddc74 100644 --- a/modules/auth/issue.go +++ b/modules/auth/issue.go @@ -10,8 +10,6 @@ import ( "github.com/go-martini/martini" - "github.com/gogits/binding" - "github.com/gogits/gogs/modules/base" "github.com/gogits/gogs/modules/log" ) @@ -31,7 +29,7 @@ func (f *CreateIssueForm) Name(field string) string { return names[field] } -func (f *CreateIssueForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) { +func (f *CreateIssueForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) { if req.Method == "GET" || errors.Count() == 0 { return } diff --git a/modules/auth/release.go b/modules/auth/release.go new file mode 100644 index 00000000..a29028e0 --- /dev/null +++ b/modules/auth/release.go @@ -0,0 +1,50 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "net/http" + "reflect" + + "github.com/go-martini/martini" + + "github.com/gogits/gogs/modules/base" + "github.com/gogits/gogs/modules/log" +) + +type NewReleaseForm struct { + TagName string `form:"tag_name" binding:"Required"` + Title string `form:"title" binding:"Required"` + Content string `form:"content" binding:"Required"` + Prerelease bool `form:"prerelease"` +} + +func (f *NewReleaseForm) Name(field string) string { + names := map[string]string{ + "TagName": "Tag name", + "Title": "Release title", + "Content": "Release content", + } + return names[field] +} + +func (f *NewReleaseForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) { + if req.Method == "GET" || errors.Count() == 0 { + return + } + + data := context.Get(reflect.TypeOf(base.TmplData{})).Interface().(base.TmplData) + data["HasError"] = true + AssignForm(f, data) + + if len(errors.Overall) > 0 { + for _, err := range errors.Overall { + log.Error("NewReleaseForm.Validate: %v", err) + } + return + } + + validate(errors, data, f) +} diff --git a/modules/auth/repo.go b/modules/auth/repo.go index eddd6475..f67fbf67 100644 --- a/modules/auth/repo.go +++ b/modules/auth/repo.go @@ -10,19 +10,17 @@ import ( "github.com/go-martini/martini" - "github.com/gogits/binding" - "github.com/gogits/gogs/modules/base" "github.com/gogits/gogs/modules/log" ) type CreateRepoForm struct { RepoName string `form:"repo" binding:"Required;AlphaDash"` - Visibility string `form:"visibility"` + Private bool `form:"private"` Description string `form:"desc" binding:"MaxSize(100)"` Language string `form:"language"` License string `form:"license"` - InitReadme string `form:"initReadme"` + InitReadme bool `form:"initReadme"` } func (f *CreateRepoForm) Name(field string) string { @@ -33,7 +31,7 @@ func (f *CreateRepoForm) Name(field string) string { return names[field] } -func (f *CreateRepoForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) { +func (f *CreateRepoForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) { if req.Method == "GET" || errors.Count() == 0 { return } @@ -51,3 +49,41 @@ func (f *CreateRepoForm) Validate(errors *binding.Errors, req *http.Request, con validate(errors, data, f) } + +type MigrateRepoForm struct { + Url string `form:"url" binding:"Url"` + AuthUserName string `form:"auth_username"` + AuthPasswd string `form:"auth_password"` + RepoName string `form:"repo" binding:"Required;AlphaDash"` + Mirror bool `form:"mirror"` + Private bool `form:"private"` + Description string `form:"desc" binding:"MaxSize(100)"` +} + +func (f *MigrateRepoForm) Name(field string) string { + names := map[string]string{ + "Url": "Migration URL", + "RepoName": "Repository name", + "Description": "Description", + } + return names[field] +} + +func (f *MigrateRepoForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) { + if req.Method == "GET" || errors.Count() == 0 { + return + } + + data := context.Get(reflect.TypeOf(base.TmplData{})).Interface().(base.TmplData) + data["HasError"] = true + AssignForm(f, data) + + if len(errors.Overall) > 0 { + for _, err := range errors.Overall { + log.Error("MigrateRepoForm.Validate: %v", err) + } + return + } + + validate(errors, data, f) +} diff --git a/modules/auth/setting.go b/modules/auth/setting.go index cada7eea..7cee00de 100644 --- a/modules/auth/setting.go +++ b/modules/auth/setting.go @@ -11,8 +11,6 @@ import ( "github.com/go-martini/martini" - "github.com/gogits/binding" - "github.com/gogits/gogs/modules/base" "github.com/gogits/gogs/modules/log" ) @@ -30,7 +28,7 @@ func (f *AddSSHKeyForm) Name(field string) string { return names[field] } -func (f *AddSSHKeyForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) { +func (f *AddSSHKeyForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) { data := context.Get(reflect.TypeOf(base.TmplData{})).Interface().(base.TmplData) AssignForm(f, data) diff --git a/modules/auth/user.go b/modules/auth/user.go index 015059f7..97389422 100644 --- a/modules/auth/user.go +++ b/modules/auth/user.go @@ -10,7 +10,6 @@ import ( "github.com/go-martini/martini" - "github.com/gogits/binding" "github.com/gogits/session" "github.com/gogits/gogs/models" @@ -93,7 +92,7 @@ func (f *UpdateProfileForm) Name(field string) string { return names[field] } -func (f *UpdateProfileForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) { +func (f *UpdateProfileForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) { if req.Method == "GET" || errors.Count() == 0 { return } @@ -126,7 +125,7 @@ func (f *UpdatePasswdForm) Name(field string) string { return names[field] } -func (f *UpdatePasswdForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) { +func (f *UpdatePasswdForm) Validate(errors *base.BindingErrors, req *http.Request, context martini.Context) { if req.Method == "GET" || errors.Count() == 0 { return } diff --git a/modules/avatar/avatar.go b/modules/avatar/avatar.go index edeb256f..5ed5d16a 100644 --- a/modules/avatar/avatar.go +++ b/modules/avatar/avatar.go @@ -157,9 +157,9 @@ func (this *service) ServeHTTP(w http.ResponseWriter, r *http.Request) { avatar := New(hash, this.cacheDir) avatar.AlterImage = this.altImage if avatar.Expired() { - err := avatar.UpdateTimeout(time.Millisecond * 500) - if err != nil { + if err := avatar.UpdateTimeout(time.Millisecond * 1000); err != nil { log.Trace("avatar update error: %v", err) + return } } if modtime, err := avatar.Modtime(); err == nil { @@ -250,6 +250,7 @@ func (this *thunderTask) Fetch() { var client = &http.Client{} func (this *thunderTask) fetch() error { + log.Debug("avatar.fetch(fetch new avatar): %s", this.Url) req, _ := http.NewRequest("GET", this.Url, nil) req.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/jpeg,image/png,*/*;q=0.8") req.Header.Set("Accept-Encoding", "deflate,sdch") diff --git a/modules/base/base.go b/modules/base/base.go index 7c08dcc5..3e80a436 100644 --- a/modules/base/base.go +++ b/modules/base/base.go @@ -8,3 +8,51 @@ type ( // Type TmplData represents data in the templates. TmplData map[string]interface{} ) + +// __________.__ .___.__ +// \______ \__| ____ __| _/|__| ____ ____ +// | | _/ |/ \ / __ | | |/ \ / ___\ +// | | \ | | \/ /_/ | | | | \/ /_/ > +// |______ /__|___| /\____ | |__|___| /\___ / +// \/ \/ \/ \//_____/ + +// Errors represents the contract of the response body when the +// binding step fails before getting to the application. +type BindingErrors struct { + Overall map[string]string `json:"overall"` + Fields map[string]string `json:"fields"` +} + +// Total errors is the sum of errors with the request overall +// and errors on individual fields. +func (err BindingErrors) Count() int { + return len(err.Overall) + len(err.Fields) +} + +func (this *BindingErrors) Combine(other BindingErrors) { + for key, val := range other.Fields { + if _, exists := this.Fields[key]; !exists { + this.Fields[key] = val + } + } + for key, val := range other.Overall { + if _, exists := this.Overall[key]; !exists { + this.Overall[key] = val + } + } +} + +const ( + BindingRequireError string = "Required" + BindingAlphaDashError string = "AlphaDash" + BindingMinSizeError string = "MinSize" + BindingMaxSizeError string = "MaxSize" + BindingEmailError string = "Email" + BindingUrlError string = "Url" + BindingDeserializationError string = "DeserializationError" + BindingIntegerTypeError string = "IntegerTypeError" + BindingBooleanTypeError string = "BooleanTypeError" + BindingFloatTypeError string = "FloatTypeError" +) + +var GoGetMetas = make(map[string]bool) diff --git a/modules/base/base_memcache.go b/modules/base/base_memcache.go new file mode 100644 index 00000000..93ca93a4 --- /dev/null +++ b/modules/base/base_memcache.go @@ -0,0 +1,11 @@ +// +build memcache + +package base + +import ( + _ "github.com/gogits/cache/memcache" +) + +func init() { + EnableMemcache = true +} diff --git a/modules/base/base_redis.go b/modules/base/base_redis.go new file mode 100644 index 00000000..327af841 --- /dev/null +++ b/modules/base/base_redis.go @@ -0,0 +1,11 @@ +// +build redis + +package base + +import ( + _ "github.com/gogits/cache/redis" +) + +func init() { + EnableRedis = true +} diff --git a/modules/base/conf.go b/modules/base/conf.go index 3ebc4ede..abb67f6d 100644 --- a/modules/base/conf.go +++ b/modules/base/conf.go @@ -14,6 +14,7 @@ import ( "github.com/Unknwon/com" "github.com/Unknwon/goconfig" + qlog "github.com/qiniu/log" "github.com/gogits/cache" "github.com/gogits/session" @@ -21,22 +22,38 @@ import ( "github.com/gogits/gogs/modules/log" ) -// Mailer represents a mail service. +// Mailer represents mail service. type Mailer struct { Name string Host string User, Passwd string } +type OauthInfo struct { + ClientId, ClientSecret string + Scopes string + AuthUrl, TokenUrl string +} + +// Oauther represents oauth service. +type Oauther struct { + GitHub, Google, Tencent, + Twitter, Weibo bool + OauthInfos map[string]*OauthInfo +} + var ( - AppVer string - AppName string - AppLogo string - AppUrl string - Domain string - SecretKey string - RunUser string + AppVer string + AppName string + AppLogo string + AppUrl string + IsProdMode bool + Domain string + SecretKey string + RunUser string + RepoRootPath string + ScriptType string InstallLock bool @@ -44,8 +61,9 @@ var ( CookieUserName string CookieRememberName string - Cfg *goconfig.ConfigFile - MailService *Mailer + Cfg *goconfig.ConfigFile + MailService *Mailer + OauthService *Oauther LogMode string LogConfig string @@ -59,11 +77,14 @@ var ( SessionManager *session.Manager PictureService string + + EnableRedis bool + EnableMemcache bool ) var Service struct { RegisterEmailConfirm bool - DisenableRegisteration bool + DisableRegistration bool RequireSignInView bool EnableCacheAvatar bool NotifyMail bool @@ -95,7 +116,7 @@ var logLevels = map[string]string{ func newService() { Service.ActiveCodeLives = Cfg.MustInt("service", "ACTIVE_CODE_LIVE_MINUTES", 180) Service.ResetPwdCodeLives = Cfg.MustInt("service", "RESET_PASSWD_CODE_LIVE_MINUTES", 180) - Service.DisenableRegisteration = Cfg.MustBool("service", "DISENABLE_REGISTERATION", false) + Service.DisableRegistration = Cfg.MustBool("service", "DISABLE_REGISTRATION", false) Service.RequireSignInView = Cfg.MustBool("service", "REQUIRE_SIGNIN_VIEW", false) Service.EnableCacheAvatar = Cfg.MustBool("service", "ENABLE_CACHE_AVATAR", false) } @@ -105,16 +126,14 @@ func newLogService() { LogMode = Cfg.MustValue("log", "MODE", "console") modeSec := "log." + LogMode if _, err := Cfg.GetSection(modeSec); err != nil { - fmt.Printf("Unknown log mode: %s\n", LogMode) - os.Exit(2) + qlog.Fatalf("Unknown log mode: %s\n", LogMode) } // Log level. levelName := Cfg.MustValue("log."+LogMode, "LEVEL", "Trace") level, ok := logLevels[levelName] if !ok { - fmt.Printf("Unknown log level: %s\n", levelName) - os.Exit(2) + qlog.Fatalf("Unknown log level: %s\n", levelName) } // Generate log configuration. @@ -151,12 +170,19 @@ func newLogService() { Cfg.MustValue(modeSec, "CONN")) } + log.Info("%s %s", AppName, AppVer) log.NewLogger(Cfg.MustInt64("log", "BUFFER_LEN", 10000), LogMode, LogConfig) log.Info("Log Mode: %s(%s)", strings.Title(LogMode), levelName) } func newCacheService() { CacheAdapter = Cfg.MustValue("cache", "ADAPTER", "memory") + if EnableRedis { + log.Info("Redis Enabled") + } + if EnableMemcache { + log.Info("Memcache Enabled") + } switch CacheAdapter { case "memory": @@ -164,16 +190,14 @@ func newCacheService() { case "redis", "memcache": CacheConfig = fmt.Sprintf(`{"conn":"%s"}`, Cfg.MustValue("cache", "HOST")) default: - fmt.Printf("Unknown cache adapter: %s\n", CacheAdapter) - os.Exit(2) + qlog.Fatalf("Unknown cache adapter: %s\n", CacheAdapter) } var err error Cache, err = cache.NewCache(CacheAdapter, CacheConfig) if err != nil { - fmt.Printf("Init cache system failed, adapter: %s, config: %s, %v\n", + qlog.Fatalf("Init cache system failed, adapter: %s, config: %s, %v\n", CacheAdapter, CacheConfig, err) - os.Exit(2) } log.Info("Cache Service Enabled") @@ -199,9 +223,8 @@ func newSessionService() { var err error SessionManager, err = session.NewManager(SessionProvider, *SessionConfig) if err != nil { - fmt.Printf("Init session system failed, provider: %s, %v\n", + qlog.Fatalf("Init session system failed, provider: %s, %v\n", SessionProvider, err) - os.Exit(2) } log.Info("Session Service Enabled") @@ -209,15 +232,17 @@ func newSessionService() { func newMailService() { // Check mailer setting. - if Cfg.MustBool("mailer", "ENABLED") { - MailService = &Mailer{ - Name: Cfg.MustValue("mailer", "NAME", AppName), - Host: Cfg.MustValue("mailer", "HOST"), - User: Cfg.MustValue("mailer", "USER"), - Passwd: Cfg.MustValue("mailer", "PASSWD"), - } - log.Info("Mail Service Enabled") + if !Cfg.MustBool("mailer", "ENABLED") { + return + } + + MailService = &Mailer{ + Name: Cfg.MustValue("mailer", "NAME", AppName), + Host: Cfg.MustValue("mailer", "HOST"), + User: Cfg.MustValue("mailer", "USER"), + Passwd: Cfg.MustValue("mailer", "PASSWD"), } + log.Info("Mail Service Enabled") } func newRegisterMailService() { @@ -246,23 +271,20 @@ func NewConfigContext() { //var err error workDir, err := ExecDir() if err != nil { - fmt.Printf("Fail to get work directory: %s\n", err) - os.Exit(2) + qlog.Fatalf("Fail to get work directory: %s\n", err) } cfgPath := filepath.Join(workDir, "conf/app.ini") Cfg, err = goconfig.LoadConfigFile(cfgPath) if err != nil { - fmt.Printf("Cannot load config file(%s): %v\n", cfgPath, err) - os.Exit(2) + qlog.Fatalf("Cannot load config file(%s): %v\n", cfgPath, err) } Cfg.BlockMode = false cfgPath = filepath.Join(workDir, "custom/conf/app.ini") if com.IsFile(cfgPath) { if err = Cfg.AppendFiles(cfgPath); err != nil { - fmt.Printf("Cannot load config file(%s): %v\n", cfgPath, err) - os.Exit(2) + qlog.Fatalf("Cannot load config file(%s): %v\n", cfgPath, err) } } @@ -275,14 +297,13 @@ func NewConfigContext() { InstallLock = Cfg.MustBool("security", "INSTALL_LOCK", false) RunUser = Cfg.MustValue("", "RUN_USER") - curUser := os.Getenv("USERNAME") + curUser := os.Getenv("USER") if len(curUser) == 0 { - curUser = os.Getenv("USER") + curUser = os.Getenv("USERNAME") } // Does not check run user when the install lock is off. if InstallLock && RunUser != curUser { - fmt.Printf("Expect user(%s) but current user is: %s\n", RunUser, curUser) - os.Exit(2) + qlog.Fatalf("Expect user(%s) but current user is: %s\n", RunUser, curUser) } LogInRememberDays = Cfg.MustInt("security", "LOGIN_REMEMBER_DAYS") @@ -294,17 +315,16 @@ func NewConfigContext() { // Determine and create root git reposiroty path. homeDir, err := com.HomeDir() if err != nil { - fmt.Printf("Fail to get home directory): %v\n", err) - os.Exit(2) + qlog.Fatalf("Fail to get home directory): %v\n", err) } - RepoRootPath = Cfg.MustValue("repository", "ROOT", filepath.Join(homeDir, "git/gogs-repositories")) + RepoRootPath = Cfg.MustValue("repository", "ROOT", filepath.Join(homeDir, "gogs-repositories")) if err = os.MkdirAll(RepoRootPath, os.ModePerm); err != nil { - fmt.Printf("Fail to create RepoRootPath(%s): %v\n", RepoRootPath, err) - os.Exit(2) + qlog.Fatalf("Fail to create RepoRootPath(%s): %v\n", RepoRootPath, err) } + ScriptType = Cfg.MustValue("repository", "SCRIPT_TYPE", "bash") } -func NewServices() { +func NewBaseServices() { newService() newLogService() newCacheService() diff --git a/modules/base/markdown.go b/modules/base/markdown.go index 962e1ae1..95b4b212 100644 --- a/modules/base/markdown.go +++ b/modules/base/markdown.go @@ -6,9 +6,11 @@ package base import ( "bytes" + "fmt" "net/http" "path" "path/filepath" + "regexp" "strings" "github.com/gogits/gfm" @@ -87,13 +89,58 @@ func (options *CustomRender) Link(out *bytes.Buffer, link []byte, title []byte, options.Renderer.Link(out, link, title, content) } +var ( + MentionPattern = regexp.MustCompile(`@[0-9a-zA-Z_]{1,}`) + commitPattern = regexp.MustCompile(`(\s|^)https?.*commit/[0-9a-zA-Z]+(#+[0-9a-zA-Z-]*)?`) + issueFullPattern = regexp.MustCompile(`(\s|^)https?.*issues/[0-9]+(#+[0-9a-zA-Z-]*)?`) + issueIndexPattern = regexp.MustCompile(`#[0-9]+`) +) + +func RenderSpecialLink(rawBytes []byte, urlPrefix string) []byte { + ms := MentionPattern.FindAll(rawBytes, -1) + for _, m := range ms { + rawBytes = bytes.Replace(rawBytes, m, + []byte(fmt.Sprintf(`<a href="/user/%s">%s</a>`, m[1:], m)), -1) + } + ms = commitPattern.FindAll(rawBytes, -1) + for _, m := range ms { + m = bytes.TrimSpace(m) + i := strings.Index(string(m), "commit/") + j := strings.Index(string(m), "#") + if j == -1 { + j = len(m) + } + rawBytes = bytes.Replace(rawBytes, m, []byte(fmt.Sprintf( + ` <code><a href="%s">%s</a></code>`, m, ShortSha(string(m[i+7:j])))), -1) + } + ms = issueFullPattern.FindAll(rawBytes, -1) + for _, m := range ms { + m = bytes.TrimSpace(m) + i := strings.Index(string(m), "issues/") + j := strings.Index(string(m), "#") + if j == -1 { + j = len(m) + } + rawBytes = bytes.Replace(rawBytes, m, []byte(fmt.Sprintf( + ` <a href="%s">#%s</a>`, m, ShortSha(string(m[i+7:j])))), -1) + } + ms = issueIndexPattern.FindAll(rawBytes, -1) + for _, m := range ms { + rawBytes = bytes.Replace(rawBytes, m, []byte(fmt.Sprintf( + `<a href="%s/issues/%s">%s</a>`, urlPrefix, m[1:], m)), -1) + } + return rawBytes +} + func RenderMarkdown(rawBytes []byte, urlPrefix string) []byte { + body := RenderSpecialLink(rawBytes, urlPrefix) + // fmt.Println(string(body)) htmlFlags := 0 // htmlFlags |= gfm.HTML_USE_XHTML // htmlFlags |= gfm.HTML_USE_SMARTYPANTS // htmlFlags |= gfm.HTML_SMARTYPANTS_FRACTIONS // htmlFlags |= gfm.HTML_SMARTYPANTS_LATEX_DASHES - htmlFlags |= gfm.HTML_SKIP_HTML + // htmlFlags |= gfm.HTML_SKIP_HTML htmlFlags |= gfm.HTML_SKIP_STYLE htmlFlags |= gfm.HTML_SKIP_SCRIPT htmlFlags |= gfm.HTML_GITHUB_BLOCKCODE @@ -115,7 +162,11 @@ func RenderMarkdown(rawBytes []byte, urlPrefix string) []byte { extensions |= gfm.EXTENSION_SPACE_HEADERS extensions |= gfm.EXTENSION_NO_EMPTY_LINE_BEFORE_BLOCK - body := gfm.Markdown(rawBytes, renderer, extensions) - + body = gfm.Markdown(body, renderer, extensions) + // fmt.Println(string(body)) return body } + +func RenderMarkdownString(raw, urlPrefix string) string { + return string(RenderMarkdown([]byte(raw), urlPrefix)) +} diff --git a/modules/base/template.go b/modules/base/template.go index dfcae931..79aeeb9d 100644 --- a/modules/base/template.go +++ b/modules/base/template.go @@ -5,7 +5,9 @@ package base import ( + "bytes" "container/list" + "encoding/json" "fmt" "html/template" "strings" @@ -54,6 +56,9 @@ var TemplateFuncs template.FuncMap = map[string]interface{}{ "AppDomain": func() string { return Domain }, + "IsProdMode": func() bool { + return IsProdMode + }, "LoadTimes": func(startTime time.Time) string { return fmt.Sprint(time.Since(startTime).Nanoseconds()/1e6) + "ms" }, @@ -62,11 +67,18 @@ var TemplateFuncs template.FuncMap = map[string]interface{}{ "TimeSince": TimeSince, "FileSize": FileSize, "Subtract": Subtract, + "Add": func(a, b int) int { + return a + b + }, "ActionIcon": ActionIcon, "ActionDesc": ActionDesc, "DateFormat": DateFormat, "List": List, "Mail2Domain": func(mail string) string { + if !strings.Contains(mail, "@") { + return "try.gogits.org" + } + suffix := strings.SplitN(mail, "@", 2)[1] domain, ok := mailDomains[suffix] if !ok { @@ -80,4 +92,128 @@ var TemplateFuncs template.FuncMap = map[string]interface{}{ "DiffTypeToStr": DiffTypeToStr, "DiffLineTypeToStr": DiffLineTypeToStr, "ShortSha": ShortSha, + "Oauth2Icon": Oauth2Icon, +} + +type Actioner interface { + GetOpType() int + GetActUserName() string + GetActEmail() string + GetRepoName() string + GetBranch() string + GetContent() string +} + +// ActionIcon accepts a int that represents action operation type +// and returns a icon class name. +func ActionIcon(opType int) string { + switch opType { + case 1: // Create repository. + return "plus-circle" + case 5, 9: // Commit repository. + return "arrow-circle-o-right" + case 6: // Create issue. + return "exclamation-circle" + case 8: // Transfer repository. + return "share" + default: + return "invalid type" + } +} + +const ( + TPL_CREATE_REPO = `<a href="/user/%s">%s</a> created repository <a href="/%s">%s</a>` + TPL_COMMIT_REPO = `<a href="/user/%s">%s</a> pushed to <a href="/%s/src/%s">%s</a> at <a href="/%s">%s</a>%s` + TPL_COMMIT_REPO_LI = `<div><img src="%s?s=16" alt="user-avatar"/> <a href="/%s/commit/%s">%s</a> %s</div>` + TPL_CREATE_ISSUE = `<a href="/user/%s">%s</a> opened issue <a href="/%s/issues/%s">%s#%s</a> +<div><img src="%s?s=16" alt="user-avatar"/> %s</div>` + TPL_TRANSFER_REPO = `<a href="/user/%s">%s</a> transfered repository <code>%s</code> to <a href="/%s">%s</a>` + TPL_PUSH_TAG = `<a href="/user/%s">%s</a> pushed tag <a href="/%s/src/%s">%s</a> at <a href="/%s">%s</a>` +) + +type PushCommit struct { + Sha1 string + Message string + AuthorEmail string + AuthorName string +} + +type PushCommits struct { + Len int + Commits []*PushCommit +} + +// ActionDesc accepts int that represents action operation type +// and returns the description. +func ActionDesc(act Actioner) string { + actUserName := act.GetActUserName() + email := act.GetActEmail() + repoName := act.GetRepoName() + repoLink := actUserName + "/" + repoName + branch := act.GetBranch() + content := act.GetContent() + switch act.GetOpType() { + case 1: // Create repository. + return fmt.Sprintf(TPL_CREATE_REPO, actUserName, actUserName, repoLink, repoName) + case 5: // Commit repository. + var push *PushCommits + if err := json.Unmarshal([]byte(content), &push); err != nil { + return err.Error() + } + buf := bytes.NewBuffer([]byte("\n")) + for _, commit := range push.Commits { + buf.WriteString(fmt.Sprintf(TPL_COMMIT_REPO_LI, AvatarLink(commit.AuthorEmail), repoLink, commit.Sha1, commit.Sha1[:7], commit.Message) + "\n") + } + if push.Len > 3 { + buf.WriteString(fmt.Sprintf(`<div><a href="/%s/%s/commits/%s">%d other commits >></a></div>`, actUserName, repoName, branch, push.Len)) + } + return fmt.Sprintf(TPL_COMMIT_REPO, actUserName, actUserName, repoLink, branch, branch, repoLink, repoLink, + buf.String()) + case 6: // Create issue. + infos := strings.SplitN(content, "|", 2) + return fmt.Sprintf(TPL_CREATE_ISSUE, actUserName, actUserName, repoLink, infos[0], repoLink, infos[0], + AvatarLink(email), infos[1]) + case 8: // Transfer repository. + newRepoLink := content + "/" + repoName + return fmt.Sprintf(TPL_TRANSFER_REPO, actUserName, actUserName, repoLink, newRepoLink, newRepoLink) + case 9: // Push tag. + return fmt.Sprintf(TPL_PUSH_TAG, actUserName, actUserName, repoLink, branch, branch, repoLink, repoLink) + default: + return "invalid type" + } +} + +func DiffTypeToStr(diffType int) string { + diffTypes := map[int]string{ + 1: "add", 2: "modify", 3: "del", + } + return diffTypes[diffType] +} + +func DiffLineTypeToStr(diffType int) string { + switch diffType { + case 2: + return "add" + case 3: + return "del" + case 4: + return "tag" + } + return "same" +} + +func Oauth2Icon(t int) string { + switch t { + case 1: + return "fa-github-square" + case 2: + return "fa-google-plus-square" + case 3: + return "fa-twitter-square" + case 4: + return "fa-linux" + case 5: + return "fa-weibo" + } + return "" } diff --git a/modules/base/tool.go b/modules/base/tool.go index 3946c4b5..9b165b97 100644 --- a/modules/base/tool.go +++ b/modules/base/tool.go @@ -5,13 +5,13 @@ package base import ( - "bytes" + "crypto/hmac" "crypto/md5" "crypto/rand" "crypto/sha1" "encoding/hex" - "encoding/json" "fmt" + "hash" "math" "strconv" "strings" @@ -40,6 +40,44 @@ func GetRandomString(n int, alphabets ...byte) string { return string(bytes) } +// http://code.google.com/p/go/source/browse/pbkdf2/pbkdf2.go?repo=crypto +func PBKDF2(password, salt []byte, iter, keyLen int, h func() hash.Hash) []byte { + prf := hmac.New(h, password) + hashLen := prf.Size() + numBlocks := (keyLen + hashLen - 1) / hashLen + + var buf [4]byte + dk := make([]byte, 0, numBlocks*hashLen) + U := make([]byte, hashLen) + for block := 1; block <= numBlocks; block++ { + // N.B.: || means concatenation, ^ means XOR + // for each block T_i = U_1 ^ U_2 ^ ... ^ U_iter + // U_1 = PRF(password, salt || uint(i)) + prf.Reset() + prf.Write(salt) + buf[0] = byte(block >> 24) + buf[1] = byte(block >> 16) + buf[2] = byte(block >> 8) + buf[3] = byte(block) + prf.Write(buf[:4]) + dk = prf.Sum(dk) + T := dk[len(dk)-hashLen:] + copy(U, T) + + // U_n = PRF(password, U_(n-1)) + for n := 2; n <= iter; n++ { + prf.Reset() + prf.Write(U) + U = U[:0] + U = prf.Sum(U) + for x := range U { + T[x] ^= U[x] + } + } + } + return dk[:keyLen] +} + // verify time limit code func VerifyTimeLimitCode(data string, minutes int, code string) bool { if len(code) <= 18 { @@ -105,7 +143,7 @@ func AvatarLink(email string) string { if Service.EnableCacheAvatar { return "/avatar/" + EncodeMd5(email) } - return "http://1.gravatar.com/avatar/" + EncodeMd5(email) + return "//1.gravatar.com/avatar/" + EncodeMd5(email) } // Seconds-based time units @@ -246,7 +284,6 @@ func TimeSince(then time.Time) string { default: return fmt.Sprintf("%d years %s", diff/Year, lbl) } - return then.String() } const ( @@ -474,107 +511,3 @@ func (a argInt) Get(i int, args ...int) (r int) { } return } - -type Actioner interface { - GetOpType() int - GetActUserName() string - GetActEmail() string - GetRepoName() string - GetBranch() string - GetContent() string -} - -// ActionIcon accepts a int that represents action operation type -// and returns a icon class name. -func ActionIcon(opType int) string { - switch opType { - case 1: // Create repository. - return "plus-circle" - case 5: // Commit repository. - return "arrow-circle-o-right" - case 6: // Create issue. - return "exclamation-circle" - case 8: // Transfer repository. - return "share" - default: - return "invalid type" - } -} - -const ( - TPL_CREATE_REPO = `<a href="/user/%s">%s</a> created repository <a href="/%s">%s</a>` - TPL_COMMIT_REPO = `<a href="/user/%s">%s</a> pushed to <a href="/%s/src/%s">%s</a> at <a href="/%s">%s</a>%s` - TPL_COMMIT_REPO_LI = `<div><img src="%s?s=16" alt="user-avatar"/> <a href="/%s/commit/%s">%s</a> %s</div>` - TPL_CREATE_ISSUE = `<a href="/user/%s">%s</a> opened issue <a href="/%s/issues/%s">%s#%s</a> -<div><img src="%s?s=16" alt="user-avatar"/> %s</div>` - TPL_TRANSFER_REPO = `<a href="/user/%s">%s</a> transfered repository <code>%s</code> to <a href="/%s">%s</a>` -) - -type PushCommit struct { - Sha1 string - Message string - AuthorEmail string - AuthorName string -} - -type PushCommits struct { - Len int - Commits []*PushCommit -} - -// ActionDesc accepts int that represents action operation type -// and returns the description. -func ActionDesc(act Actioner) string { - actUserName := act.GetActUserName() - email := act.GetActEmail() - repoName := act.GetRepoName() - repoLink := actUserName + "/" + repoName - branch := act.GetBranch() - content := act.GetContent() - switch act.GetOpType() { - case 1: // Create repository. - return fmt.Sprintf(TPL_CREATE_REPO, actUserName, actUserName, repoLink, repoName) - case 5: // Commit repository. - var push *PushCommits - if err := json.Unmarshal([]byte(content), &push); err != nil { - return err.Error() - } - buf := bytes.NewBuffer([]byte("\n")) - for _, commit := range push.Commits { - buf.WriteString(fmt.Sprintf(TPL_COMMIT_REPO_LI, AvatarLink(commit.AuthorEmail), repoLink, commit.Sha1, commit.Sha1[:7], commit.Message) + "\n") - } - if push.Len > 3 { - buf.WriteString(fmt.Sprintf(`<div><a href="/%s/%s/commits/%s">%d other commits >></a></div>`, actUserName, repoName, branch, push.Len)) - } - return fmt.Sprintf(TPL_COMMIT_REPO, actUserName, actUserName, repoLink, branch, branch, repoLink, repoLink, - buf.String()) - case 6: // Create issue. - infos := strings.SplitN(content, "|", 2) - return fmt.Sprintf(TPL_CREATE_ISSUE, actUserName, actUserName, repoLink, infos[0], repoLink, infos[0], - AvatarLink(email), infos[1]) - case 8: // Transfer repository. - newRepoLink := content + "/" + repoName - return fmt.Sprintf(TPL_TRANSFER_REPO, actUserName, actUserName, repoLink, newRepoLink, newRepoLink) - default: - return "invalid type" - } -} - -func DiffTypeToStr(diffType int) string { - diffTypes := map[int]string{ - 1: "add", 2: "modify", 3: "del", - } - return diffTypes[diffType] -} - -func DiffLineTypeToStr(diffType int) string { - switch diffType { - case 2: - return "add" - case 3: - return "del" - case 4: - return "tag" - } - return "same" -} diff --git a/modules/cron/cron.go b/modules/cron/cron.go new file mode 100644 index 00000000..27b1fc41 --- /dev/null +++ b/modules/cron/cron.go @@ -0,0 +1,17 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package cron + +import ( + "github.com/robfig/cron" + + "github.com/gogits/gogs/models" +) + +func NewCronContext() { + c := cron.New() + c.AddFunc("@every 1h", models.MirrorUpdate) + c.Start() +} diff --git a/modules/log/log.go b/modules/log/log.go index 65150237..636ea787 100644 --- a/modules/log/log.go +++ b/modules/log/log.go @@ -21,8 +21,7 @@ func init() { func NewLogger(bufLen int64, mode, config string) { Mode, Config = mode, config logger = logs.NewLogger(bufLen) - logger.EnableFuncCallDepth(true) - logger.SetLogFuncCallDepth(4) + logger.SetLogFuncCallDepth(3) logger.SetLogger(mode, config) } diff --git a/modules/mailer/mail.go b/modules/mailer/mail.go index b99fc8fd..834f4a89 100644 --- a/modules/mailer/mail.go +++ b/modules/mailer/mail.go @@ -86,16 +86,36 @@ func SendActiveMail(r *middleware.Render, user *models.User) { } msg := NewMailMessage([]string{user.Email}, subject, body) - msg.Info = fmt.Sprintf("UID: %d, send email verify mail", user.Id) + msg.Info = fmt.Sprintf("UID: %d, send active mail", user.Id) SendAsync(&msg) } -// SendNotifyMail sends mail notification of all watchers. -func SendNotifyMail(user, owner *models.User, repo *models.Repository, issue *models.Issue) error { +// Send reset password email. +func SendResetPasswdMail(r *middleware.Render, user *models.User) { + code := CreateUserActiveCode(user, nil) + + subject := "Reset your password" + + data := GetMailTmplData(user) + data["Code"] = code + body, err := r.HTMLString("mail/auth/reset_passwd", data) + if err != nil { + log.Error("mail.SendResetPasswdMail(fail to render): %v", err) + return + } + + msg := NewMailMessage([]string{user.Email}, subject, body) + msg.Info = fmt.Sprintf("UID: %d, send reset password email", user.Id) + + SendAsync(&msg) +} + +// SendIssueNotifyMail sends mail notification of all watchers of repository. +func SendIssueNotifyMail(user, owner *models.User, repo *models.Repository, issue *models.Issue) ([]string, error) { watches, err := models.GetWatches(repo.Id) if err != nil { - return errors.New("mail.NotifyWatchers(get watches): " + err.Error()) + return nil, errors.New("mail.NotifyWatchers(get watches): " + err.Error()) } tos := make([]string, 0, len(watches)) @@ -106,20 +126,37 @@ func SendNotifyMail(user, owner *models.User, repo *models.Repository, issue *mo } u, err := models.GetUserById(uid) if err != nil { - return errors.New("mail.NotifyWatchers(get user): " + err.Error()) + return nil, errors.New("mail.NotifyWatchers(get user): " + err.Error()) } tos = append(tos, u.Email) } if len(tos) == 0 { - return nil + return tos, nil } subject := fmt.Sprintf("[%s] %s", repo.Name, issue.Name) content := fmt.Sprintf("%s<br>-<br> <a href=\"%s%s/%s/issues/%d\">View it on Gogs</a>.", - issue.Content, base.AppUrl, owner.Name, repo.Name, issue.Index) + base.RenderSpecialLink([]byte(issue.Content), owner.Name+"/"+repo.Name), + base.AppUrl, owner.Name, repo.Name, issue.Index) + msg := NewMailMessageFrom(tos, user.Name, subject, content) + msg.Info = fmt.Sprintf("Subject: %s, send issue notify emails", subject) + SendAsync(&msg) + return tos, nil +} + +// SendIssueMentionMail sends mail notification for who are mentioned in issue. +func SendIssueMentionMail(user, owner *models.User, repo *models.Repository, issue *models.Issue, tos []string) error { + if len(tos) == 0 { + return nil + } + + issueLink := fmt.Sprintf("%s%s/%s/issues/%d", base.AppUrl, owner.Name, repo.Name, issue.Index) + body := fmt.Sprintf(`%s mentioned you.`, user.Name) + subject := fmt.Sprintf("[%s] %s", repo.Name, issue.Name) + content := fmt.Sprintf("%s<br>-<br> <a href=\"%s\">View it on Gogs</a>.", body, issueLink) msg := NewMailMessageFrom(tos, user.Name, subject, content) - msg.Info = fmt.Sprintf("Subject: %s, send notify emails", subject) + msg.Info = fmt.Sprintf("Subject: %s, send issue mention emails", subject) SendAsync(&msg) return nil } diff --git a/modules/middleware/auth.go b/modules/middleware/auth.go index bde3be72..39b7796d 100644 --- a/modules/middleware/auth.go +++ b/modules/middleware/auth.go @@ -47,7 +47,7 @@ func Toggle(options *ToggleOptions) martini.Handler { return } else if !ctx.User.IsActive && base.Service.RegisterEmailConfirm { ctx.Data["Title"] = "Activate Your Account" - ctx.HTML(200, "user/active") + ctx.HTML(200, "user/activate") return } } diff --git a/modules/middleware/binding.go b/modules/middleware/binding.go new file mode 100644 index 00000000..cde9ae9c --- /dev/null +++ b/modules/middleware/binding.go @@ -0,0 +1,426 @@ +// Copyright 2013 The Martini Contrib Authors. All rights reserved. +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package middleware + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "reflect" + "regexp" + "strconv" + "strings" + "unicode/utf8" + + "github.com/go-martini/martini" + + "github.com/gogits/gogs/modules/base" +) + +/* + To the land of Middle-ware Earth: + + One func to rule them all, + One func to find them, + One func to bring them all, + And in this package BIND them. +*/ + +// Bind accepts a copy of an empty struct and populates it with +// values from the request (if deserialization is successful). It +// wraps up the functionality of the Form and Json middleware +// according to the Content-Type of the request, and it guesses +// if no Content-Type is specified. Bind invokes the ErrorHandler +// middleware to bail out if errors occurred. If you want to perform +// your own error handling, use Form or Json middleware directly. +// An interface pointer can be added as a second argument in order +// to map the struct to a specific interface. +func Bind(obj interface{}, ifacePtr ...interface{}) martini.Handler { + return func(context martini.Context, req *http.Request) { + contentType := req.Header.Get("Content-Type") + + if strings.Contains(contentType, "form-urlencoded") { + context.Invoke(Form(obj, ifacePtr...)) + } else if strings.Contains(contentType, "multipart/form-data") { + context.Invoke(MultipartForm(obj, ifacePtr...)) + } else if strings.Contains(contentType, "json") { + context.Invoke(Json(obj, ifacePtr...)) + } else { + context.Invoke(Json(obj, ifacePtr...)) + if getErrors(context).Count() > 0 { + context.Invoke(Form(obj, ifacePtr...)) + } + } + + context.Invoke(ErrorHandler) + } +} + +// BindIgnErr will do the exactly same thing as Bind but without any +// error handling, which user has freedom to deal with them. +// This allows user take advantages of validation. +func BindIgnErr(obj interface{}, ifacePtr ...interface{}) martini.Handler { + return func(context martini.Context, req *http.Request) { + contentType := req.Header.Get("Content-Type") + + if strings.Contains(contentType, "form-urlencoded") { + context.Invoke(Form(obj, ifacePtr...)) + } else if strings.Contains(contentType, "multipart/form-data") { + context.Invoke(MultipartForm(obj, ifacePtr...)) + } else if strings.Contains(contentType, "json") { + context.Invoke(Json(obj, ifacePtr...)) + } else { + context.Invoke(Json(obj, ifacePtr...)) + if getErrors(context).Count() > 0 { + context.Invoke(Form(obj, ifacePtr...)) + } + } + } +} + +// Form is middleware to deserialize form-urlencoded data from the request. +// It gets data from the form-urlencoded body, if present, or from the +// query string. It uses the http.Request.ParseForm() method +// to perform deserialization, then reflection is used to map each field +// into the struct with the proper type. Structs with primitive slice types +// (bool, float, int, string) can support deserialization of repeated form +// keys, for example: key=val1&key=val2&key=val3 +// An interface pointer can be added as a second argument in order +// to map the struct to a specific interface. +func Form(formStruct interface{}, ifacePtr ...interface{}) martini.Handler { + return func(context martini.Context, req *http.Request) { + ensureNotPointer(formStruct) + formStruct := reflect.New(reflect.TypeOf(formStruct)) + errors := newErrors() + parseErr := req.ParseForm() + + // Format validation of the request body or the URL would add considerable overhead, + // and ParseForm does not complain when URL encoding is off. + // Because an empty request body or url can also mean absence of all needed values, + // it is not in all cases a bad request, so let's return 422. + if parseErr != nil { + errors.Overall[base.BindingDeserializationError] = parseErr.Error() + } + + mapForm(formStruct, req.Form, errors) + + validateAndMap(formStruct, context, errors, ifacePtr...) + } +} + +func MultipartForm(formStruct interface{}, ifacePtr ...interface{}) martini.Handler { + return func(context martini.Context, req *http.Request) { + ensureNotPointer(formStruct) + formStruct := reflect.New(reflect.TypeOf(formStruct)) + errors := newErrors() + + // Workaround for multipart forms returning nil instead of an error + // when content is not multipart + // https://code.google.com/p/go/issues/detail?id=6334 + multipartReader, err := req.MultipartReader() + if err != nil { + errors.Overall[base.BindingDeserializationError] = err.Error() + } else { + form, parseErr := multipartReader.ReadForm(MaxMemory) + + if parseErr != nil { + errors.Overall[base.BindingDeserializationError] = parseErr.Error() + } + + req.MultipartForm = form + } + + mapForm(formStruct, req.MultipartForm.Value, errors) + + validateAndMap(formStruct, context, errors, ifacePtr...) + } +} + +// Json is middleware to deserialize a JSON payload from the request +// into the struct that is passed in. The resulting struct is then +// validated, but no error handling is actually performed here. +// An interface pointer can be added as a second argument in order +// to map the struct to a specific interface. +func Json(jsonStruct interface{}, ifacePtr ...interface{}) martini.Handler { + return func(context martini.Context, req *http.Request) { + ensureNotPointer(jsonStruct) + jsonStruct := reflect.New(reflect.TypeOf(jsonStruct)) + errors := newErrors() + + if req.Body != nil { + defer req.Body.Close() + } + + if err := json.NewDecoder(req.Body).Decode(jsonStruct.Interface()); err != nil && err != io.EOF { + errors.Overall[base.BindingDeserializationError] = err.Error() + } + + validateAndMap(jsonStruct, context, errors, ifacePtr...) + } +} + +// Validate is middleware to enforce required fields. If the struct +// passed in is a Validator, then the user-defined Validate method +// is executed, and its errors are mapped to the context. This middleware +// performs no error handling: it merely detects them and maps them. +func Validate(obj interface{}) martini.Handler { + return func(context martini.Context, req *http.Request) { + errors := newErrors() + validateStruct(errors, obj) + + if validator, ok := obj.(Validator); ok { + validator.Validate(errors, req, context) + } + context.Map(*errors) + } +} + +var ( + alphaDashPattern = regexp.MustCompile("[^\\d\\w-_]") + emailPattern = regexp.MustCompile("[\\w!#$%&'*+/=?^_`{|}~-]+(?:\\.[\\w!#$%&'*+/=?^_`{|}~-]+)*@(?:[\\w](?:[\\w-]*[\\w])?\\.)+[a-zA-Z0-9](?:[\\w-]*[\\w])?") + urlPattern = regexp.MustCompile(`(http|https):\/\/[\w\-_]+(\.[\w\-_]+)+([\w\-\.,@?^=%&:/~\+#]*[\w\-\@?^=%&/~\+#])?`) +) + +func validateStruct(errors *base.BindingErrors, obj interface{}) { + typ := reflect.TypeOf(obj) + val := reflect.ValueOf(obj) + + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + val = val.Elem() + } + + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + + // Allow ignored fields in the struct + if field.Tag.Get("form") == "-" { + continue + } + + fieldValue := val.Field(i).Interface() + if field.Type.Kind() == reflect.Struct { + validateStruct(errors, fieldValue) + continue + } + + zero := reflect.Zero(field.Type).Interface() + + // Match rules. + for _, rule := range strings.Split(field.Tag.Get("binding"), ";") { + if len(rule) == 0 { + continue + } + + switch { + case rule == "Required": + if reflect.DeepEqual(zero, fieldValue) { + errors.Fields[field.Name] = base.BindingRequireError + break + } + case rule == "AlphaDash": + if alphaDashPattern.MatchString(fmt.Sprintf("%v", fieldValue)) { + errors.Fields[field.Name] = base.BindingAlphaDashError + break + } + case strings.HasPrefix(rule, "MinSize("): + min, err := strconv.Atoi(rule[8 : len(rule)-1]) + if err != nil { + errors.Overall["MinSize"] = err.Error() + break + } + if str, ok := fieldValue.(string); ok && utf8.RuneCountInString(str) < min { + errors.Fields[field.Name] = base.BindingMinSizeError + break + } + v := reflect.ValueOf(fieldValue) + if v.Kind() == reflect.Slice && v.Len() < min { + errors.Fields[field.Name] = base.BindingMinSizeError + break + } + case strings.HasPrefix(rule, "MaxSize("): + max, err := strconv.Atoi(rule[8 : len(rule)-1]) + if err != nil { + errors.Overall["MaxSize"] = err.Error() + break + } + if str, ok := fieldValue.(string); ok && utf8.RuneCountInString(str) > max { + errors.Fields[field.Name] = base.BindingMaxSizeError + break + } + v := reflect.ValueOf(fieldValue) + if v.Kind() == reflect.Slice && v.Len() > max { + errors.Fields[field.Name] = base.BindingMinSizeError + break + } + case rule == "Email": + if !emailPattern.MatchString(fmt.Sprintf("%v", fieldValue)) { + errors.Fields[field.Name] = base.BindingEmailError + break + } + case rule == "Url": + if !urlPattern.MatchString(fmt.Sprintf("%v", fieldValue)) { + errors.Fields[field.Name] = base.BindingUrlError + break + } + } + } + } +} + +func mapForm(formStruct reflect.Value, form map[string][]string, errors *base.BindingErrors) { + typ := formStruct.Elem().Type() + + for i := 0; i < typ.NumField(); i++ { + typeField := typ.Field(i) + if inputFieldName := typeField.Tag.Get("form"); inputFieldName != "" { + structField := formStruct.Elem().Field(i) + if !structField.CanSet() { + continue + } + + inputValue, exists := form[inputFieldName] + + if !exists { + continue + } + + numElems := len(inputValue) + if structField.Kind() == reflect.Slice && numElems > 0 { + sliceOf := structField.Type().Elem().Kind() + slice := reflect.MakeSlice(structField.Type(), numElems, numElems) + for i := 0; i < numElems; i++ { + setWithProperType(sliceOf, inputValue[i], slice.Index(i), inputFieldName, errors) + } + formStruct.Elem().Field(i).Set(slice) + } else { + setWithProperType(typeField.Type.Kind(), inputValue[0], structField, inputFieldName, errors) + } + } + } +} + +// ErrorHandler simply counts the number of errors in the +// context and, if more than 0, writes a 400 Bad Request +// response and a JSON payload describing the errors with +// the "Content-Type" set to "application/json". +// Middleware remaining on the stack will not even see the request +// if, by this point, there are any errors. +// This is a "default" handler, of sorts, and you are +// welcome to use your own instead. The Bind middleware +// invokes this automatically for convenience. +func ErrorHandler(errs base.BindingErrors, resp http.ResponseWriter) { + if errs.Count() > 0 { + resp.Header().Set("Content-Type", "application/json; charset=utf-8") + if _, ok := errs.Overall[base.BindingDeserializationError]; ok { + resp.WriteHeader(http.StatusBadRequest) + } else { + resp.WriteHeader(422) + } + errOutput, _ := json.Marshal(errs) + resp.Write(errOutput) + return + } +} + +// This sets the value in a struct of an indeterminate type to the +// matching value from the request (via Form middleware) in the +// same type, so that not all deserialized values have to be strings. +// Supported types are string, int, float, and bool. +func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value, nameInTag string, errors *base.BindingErrors) { + switch valueKind { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if val == "" { + val = "0" + } + intVal, err := strconv.ParseInt(val, 10, 64) + if err != nil { + errors.Fields[nameInTag] = base.BindingIntegerTypeError + } else { + structField.SetInt(intVal) + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if val == "" { + val = "0" + } + uintVal, err := strconv.ParseUint(val, 10, 64) + if err != nil { + errors.Fields[nameInTag] = base.BindingIntegerTypeError + } else { + structField.SetUint(uintVal) + } + case reflect.Bool: + structField.SetBool(val == "on") + case reflect.Float32: + if val == "" { + val = "0.0" + } + floatVal, err := strconv.ParseFloat(val, 32) + if err != nil { + errors.Fields[nameInTag] = base.BindingFloatTypeError + } else { + structField.SetFloat(floatVal) + } + case reflect.Float64: + if val == "" { + val = "0.0" + } + floatVal, err := strconv.ParseFloat(val, 64) + if err != nil { + errors.Fields[nameInTag] = base.BindingFloatTypeError + } else { + structField.SetFloat(floatVal) + } + case reflect.String: + structField.SetString(val) + } +} + +// Don't pass in pointers to bind to. Can lead to bugs. See: +// https://github.com/codegangsta/martini-contrib/issues/40 +// https://github.com/codegangsta/martini-contrib/pull/34#issuecomment-29683659 +func ensureNotPointer(obj interface{}) { + if reflect.TypeOf(obj).Kind() == reflect.Ptr { + panic("Pointers are not accepted as binding models") + } +} + +// Performs validation and combines errors from validation +// with errors from deserialization, then maps both the +// resulting struct and the errors to the context. +func validateAndMap(obj reflect.Value, context martini.Context, errors *base.BindingErrors, ifacePtr ...interface{}) { + context.Invoke(Validate(obj.Interface())) + errors.Combine(getErrors(context)) + context.Map(*errors) + context.Map(obj.Elem().Interface()) + if len(ifacePtr) > 0 { + context.MapTo(obj.Elem().Interface(), ifacePtr[0]) + } +} + +func newErrors() *base.BindingErrors { + return &base.BindingErrors{make(map[string]string), make(map[string]string)} +} + +func getErrors(context martini.Context) base.BindingErrors { + return context.Get(reflect.TypeOf(base.BindingErrors{})).Interface().(base.BindingErrors) +} + +type ( + // Implement the Validator interface to define your own input + // validation before the request even gets to your application. + // The Validate method will be executed during the validation phase. + Validator interface { + Validate(*base.BindingErrors, *http.Request, martini.Context) + } +) + +var ( + // Maximum amount of memory to use when parsing a multipart form. + // Set this to whatever value you prefer; default is 10 MB. + MaxMemory = int64(1024 * 1024 * 10) +) diff --git a/modules/middleware/binding_test.go b/modules/middleware/binding_test.go new file mode 100644 index 00000000..2a74e1a6 --- /dev/null +++ b/modules/middleware/binding_test.go @@ -0,0 +1,701 @@ +// Copyright 2013 The Martini Contrib Authors. All rights reserved. +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package middleware + +import ( + "bytes" + "mime/multipart" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + + "github.com/codegangsta/martini" +) + +func TestBind(t *testing.T) { + testBind(t, false) +} + +func TestBindWithInterface(t *testing.T) { + testBind(t, true) +} + +func TestMultipartBind(t *testing.T) { + index := 0 + for test, expectStatus := range bindMultipartTests { + handler := func(post BlogPost, errors Errors) { + handle(test, t, index, post, errors) + } + recorder := testMultipart(t, test, Bind(BlogPost{}), handler, index) + + if recorder.Code != expectStatus { + t.Errorf("On test case %v, got status code %d but expected %d", test, recorder.Code, expectStatus) + } + + index++ + } +} + +func TestForm(t *testing.T) { + testForm(t, false) +} + +func TestFormWithInterface(t *testing.T) { + testForm(t, true) +} + +func TestEmptyForm(t *testing.T) { + testEmptyForm(t) +} + +func TestMultipartForm(t *testing.T) { + for index, test := range multipartformTests { + handler := func(post BlogPost, errors Errors) { + handle(test, t, index, post, errors) + } + testMultipart(t, test, MultipartForm(BlogPost{}), handler, index) + } +} + +func TestMultipartFormWithInterface(t *testing.T) { + for index, test := range multipartformTests { + handler := func(post Modeler, errors Errors) { + post.Create(test, t, index) + } + testMultipart(t, test, MultipartForm(BlogPost{}, (*Modeler)(nil)), handler, index) + } +} + +func TestJson(t *testing.T) { + testJson(t, false) +} + +func TestJsonWithInterface(t *testing.T) { + testJson(t, true) +} + +func TestEmptyJson(t *testing.T) { + testEmptyJson(t) +} + +func TestValidate(t *testing.T) { + handlerMustErr := func(errors Errors) { + if errors.Count() == 0 { + t.Error("Expected at least one error, got 0") + } + } + handlerNoErr := func(errors Errors) { + if errors.Count() > 0 { + t.Error("Expected no errors, got", errors.Count()) + } + } + + performValidationTest(&BlogPost{"", "...", 0, 0, []int{}}, handlerMustErr, t) + performValidationTest(&BlogPost{"Good Title", "Good content", 0, 0, []int{}}, handlerNoErr, t) + + performValidationTest(&User{Name: "Jim", Home: Address{"", ""}}, handlerMustErr, t) + performValidationTest(&User{Name: "Jim", Home: Address{"required", ""}}, handlerNoErr, t) +} + +func handle(test testCase, t *testing.T, index int, post BlogPost, errors Errors) { + assertEqualField(t, "Title", index, test.ref.Title, post.Title) + assertEqualField(t, "Content", index, test.ref.Content, post.Content) + assertEqualField(t, "Views", index, test.ref.Views, post.Views) + + for i := range test.ref.Multiple { + if i >= len(post.Multiple) { + t.Errorf("Expected: %v (size %d) to have same size as: %v (size %d)", post.Multiple, len(post.Multiple), test.ref.Multiple, len(test.ref.Multiple)) + break + } + if test.ref.Multiple[i] != post.Multiple[i] { + t.Errorf("Expected: %v to deep equal: %v", post.Multiple, test.ref.Multiple) + break + } + } + + if test.ok && errors.Count() > 0 { + t.Errorf("%+v should be OK (0 errors), but had errors: %+v", test, errors) + } else if !test.ok && errors.Count() == 0 { + t.Errorf("%+v should have errors, but was OK (0 errors)", test) + } +} + +func handleEmpty(test emptyPayloadTestCase, t *testing.T, index int, section BlogSection, errors Errors) { + assertEqualField(t, "Title", index, test.ref.Title, section.Title) + assertEqualField(t, "Content", index, test.ref.Content, section.Content) + + if test.ok && errors.Count() > 0 { + t.Errorf("%+v should be OK (0 errors), but had errors: %+v", test, errors) + } else if !test.ok && errors.Count() == 0 { + t.Errorf("%+v should have errors, but was OK (0 errors)", test) + } +} + +func testBind(t *testing.T, withInterface bool) { + index := 0 + for test, expectStatus := range bindTests { + m := martini.Classic() + recorder := httptest.NewRecorder() + handler := func(post BlogPost, errors Errors) { handle(test, t, index, post, errors) } + binding := Bind(BlogPost{}) + + if withInterface { + handler = func(post BlogPost, errors Errors) { + post.Create(test, t, index) + } + binding = Bind(BlogPost{}, (*Modeler)(nil)) + } + + switch test.method { + case "GET": + m.Get(route, binding, handler) + case "POST": + m.Post(route, binding, handler) + } + + req, err := http.NewRequest(test.method, test.path, strings.NewReader(test.payload)) + req.Header.Add("Content-Type", test.contentType) + + if err != nil { + t.Error(err) + } + m.ServeHTTP(recorder, req) + + if recorder.Code != expectStatus { + t.Errorf("On test case %v, got status code %d but expected %d", test, recorder.Code, expectStatus) + } + + index++ + } +} + +func testJson(t *testing.T, withInterface bool) { + for index, test := range jsonTests { + recorder := httptest.NewRecorder() + handler := func(post BlogPost, errors Errors) { handle(test, t, index, post, errors) } + binding := Json(BlogPost{}) + + if withInterface { + handler = func(post BlogPost, errors Errors) { + post.Create(test, t, index) + } + binding = Bind(BlogPost{}, (*Modeler)(nil)) + } + + m := martini.Classic() + switch test.method { + case "GET": + m.Get(route, binding, handler) + case "POST": + m.Post(route, binding, handler) + case "PUT": + m.Put(route, binding, handler) + case "DELETE": + m.Delete(route, binding, handler) + } + + req, err := http.NewRequest(test.method, route, strings.NewReader(test.payload)) + if err != nil { + t.Error(err) + } + m.ServeHTTP(recorder, req) + } +} + +func testEmptyJson(t *testing.T) { + for index, test := range emptyPayloadTests { + recorder := httptest.NewRecorder() + handler := func(section BlogSection, errors Errors) { handleEmpty(test, t, index, section, errors) } + binding := Json(BlogSection{}) + + m := martini.Classic() + switch test.method { + case "GET": + m.Get(route, binding, handler) + case "POST": + m.Post(route, binding, handler) + case "PUT": + m.Put(route, binding, handler) + case "DELETE": + m.Delete(route, binding, handler) + } + + req, err := http.NewRequest(test.method, route, strings.NewReader(test.payload)) + if err != nil { + t.Error(err) + } + m.ServeHTTP(recorder, req) + } +} + +func testForm(t *testing.T, withInterface bool) { + for index, test := range formTests { + recorder := httptest.NewRecorder() + handler := func(post BlogPost, errors Errors) { handle(test, t, index, post, errors) } + binding := Form(BlogPost{}) + + if withInterface { + handler = func(post BlogPost, errors Errors) { + post.Create(test, t, index) + } + binding = Form(BlogPost{}, (*Modeler)(nil)) + } + + m := martini.Classic() + switch test.method { + case "GET": + m.Get(route, binding, handler) + case "POST": + m.Post(route, binding, handler) + } + + req, err := http.NewRequest(test.method, test.path, nil) + if err != nil { + t.Error(err) + } + m.ServeHTTP(recorder, req) + } +} + +func testEmptyForm(t *testing.T) { + for index, test := range emptyPayloadTests { + recorder := httptest.NewRecorder() + handler := func(section BlogSection, errors Errors) { handleEmpty(test, t, index, section, errors) } + binding := Form(BlogSection{}) + + m := martini.Classic() + switch test.method { + case "GET": + m.Get(route, binding, handler) + case "POST": + m.Post(route, binding, handler) + } + + req, err := http.NewRequest(test.method, test.path, nil) + if err != nil { + t.Error(err) + } + m.ServeHTTP(recorder, req) + } +} + +func testMultipart(t *testing.T, test testCase, middleware martini.Handler, handler martini.Handler, index int) *httptest.ResponseRecorder { + recorder := httptest.NewRecorder() + + m := martini.Classic() + m.Post(route, middleware, handler) + + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + writer.WriteField("title", test.ref.Title) + writer.WriteField("content", test.ref.Content) + writer.WriteField("views", strconv.Itoa(test.ref.Views)) + if len(test.ref.Multiple) != 0 { + for _, value := range test.ref.Multiple { + writer.WriteField("multiple", strconv.Itoa(value)) + } + } + + req, err := http.NewRequest(test.method, test.path, body) + req.Header.Add("Content-Type", writer.FormDataContentType()) + + if err != nil { + t.Error(err) + } + + err = writer.Close() + if err != nil { + t.Error(err) + } + + m.ServeHTTP(recorder, req) + + return recorder +} + +func assertEqualField(t *testing.T, fieldname string, testcasenumber int, expected interface{}, got interface{}) { + if expected != got { + t.Errorf("%s: expected=%s, got=%s in test case %d\n", fieldname, expected, got, testcasenumber) + } +} + +func performValidationTest(data interface{}, handler func(Errors), t *testing.T) { + recorder := httptest.NewRecorder() + m := martini.Classic() + m.Get(route, Validate(data), handler) + + req, err := http.NewRequest("GET", route, nil) + if err != nil { + t.Error("HTTP error:", err) + } + + m.ServeHTTP(recorder, req) +} + +func (self BlogPost) Validate(errors *Errors, req *http.Request) { + if len(self.Title) < 4 { + errors.Fields["Title"] = "Too short; minimum 4 characters" + } + if len(self.Content) > 1024 { + errors.Fields["Content"] = "Too long; maximum 1024 characters" + } + if len(self.Content) < 5 { + errors.Fields["Content"] = "Too short; minimum 5 characters" + } +} + +func (self BlogPost) Create(test testCase, t *testing.T, index int) { + assertEqualField(t, "Title", index, test.ref.Title, self.Title) + assertEqualField(t, "Content", index, test.ref.Content, self.Content) + assertEqualField(t, "Views", index, test.ref.Views, self.Views) + + for i := range test.ref.Multiple { + if i >= len(self.Multiple) { + t.Errorf("Expected: %v (size %d) to have same size as: %v (size %d)", self.Multiple, len(self.Multiple), test.ref.Multiple, len(test.ref.Multiple)) + break + } + if test.ref.Multiple[i] != self.Multiple[i] { + t.Errorf("Expected: %v to deep equal: %v", self.Multiple, test.ref.Multiple) + break + } + } +} + +func (self BlogSection) Create(test emptyPayloadTestCase, t *testing.T, index int) { + // intentionally left empty +} + +type ( + testCase struct { + method string + path string + payload string + contentType string + ok bool + ref *BlogPost + } + + emptyPayloadTestCase struct { + method string + path string + payload string + contentType string + ok bool + ref *BlogSection + } + + Modeler interface { + Create(test testCase, t *testing.T, index int) + } + + BlogPost struct { + Title string `form:"title" json:"title" binding:"required"` + Content string `form:"content" json:"content"` + Views int `form:"views" json:"views"` + internal int `form:"-"` + Multiple []int `form:"multiple"` + } + + BlogSection struct { + Title string `form:"title" json:"title"` + Content string `form:"content" json:"content"` + } + + User struct { + Name string `json:"name" binding:"required"` + Home Address `json:"address" binding:"required"` + } + + Address struct { + Street1 string `json:"street1" binding:"required"` + Street2 string `json:"street2"` + } +) + +var ( + bindTests = map[testCase]int{ + // These should bail at the deserialization/binding phase + testCase{ + "POST", + path, + `{ bad JSON `, + "application/json", + false, + new(BlogPost), + }: http.StatusBadRequest, + testCase{ + "POST", + path, + `not multipart but has content-type`, + "multipart/form-data", + false, + new(BlogPost), + }: http.StatusBadRequest, + testCase{ + "POST", + path, + `no content-type and not URL-encoded or JSON"`, + "", + false, + new(BlogPost), + }: http.StatusBadRequest, + + // These should deserialize, then bail at the validation phase + testCase{ + "POST", + path + "?title= This is wrong ", + `not URL-encoded but has content-type`, + "x-www-form-urlencoded", + false, + new(BlogPost), + }: 422, // according to comments in Form() -> although the request is not url encoded, ParseForm does not complain + testCase{ + "GET", + path + "?content=This+is+the+content", + ``, + "x-www-form-urlencoded", + false, + &BlogPost{Title: "", Content: "This is the content"}, + }: 422, + testCase{ + "GET", + path + "", + `{"content":"", "title":"Blog Post Title"}`, + "application/json", + false, + &BlogPost{Title: "Blog Post Title", Content: ""}, + }: 422, + + // These should succeed + testCase{ + "GET", + path + "", + `{"content":"This is the content", "title":"Blog Post Title"}`, + "application/json", + true, + &BlogPost{Title: "Blog Post Title", Content: "This is the content"}, + }: http.StatusOK, + testCase{ + "GET", + path + "?content=This+is+the+content&title=Blog+Post+Title", + ``, + "", + true, + &BlogPost{Title: "Blog Post Title", Content: "This is the content"}, + }: http.StatusOK, + testCase{ + "GET", + path + "?content=This is the content&title=Blog+Post+Title", + `{"content":"This is the content", "title":"Blog Post Title"}`, + "", + true, + &BlogPost{Title: "Blog Post Title", Content: "This is the content"}, + }: http.StatusOK, + testCase{ + "GET", + path + "", + `{"content":"This is the content", "title":"Blog Post Title"}`, + "", + true, + &BlogPost{Title: "Blog Post Title", Content: "This is the content"}, + }: http.StatusOK, + } + + bindMultipartTests = map[testCase]int{ + // This should deserialize, then bail at the validation phase + testCase{ + "POST", + path, + "", + "multipart/form-data", + false, + &BlogPost{Title: "", Content: "This is the content"}, + }: 422, + // This should succeed + testCase{ + "POST", + path, + "", + "multipart/form-data", + true, + &BlogPost{Title: "This is the Title", Content: "This is the content"}, + }: http.StatusOK, + } + + formTests = []testCase{ + { + "GET", + path + "?content=This is the content", + "", + "", + false, + &BlogPost{Title: "", Content: "This is the content"}, + }, + { + "POST", + path + "?content=This+is+the+content&title=Blog+Post+Title&views=3", + "", + "", + false, // false because POST requests should have a body, not just a query string + &BlogPost{Title: "Blog Post Title", Content: "This is the content", Views: 3}, + }, + { + "GET", + path + "?content=This+is+the+content&title=Blog+Post+Title&views=3&multiple=5&multiple=10&multiple=15&multiple=20", + "", + "", + true, + &BlogPost{Title: "Blog Post Title", Content: "This is the content", Views: 3, Multiple: []int{5, 10, 15, 20}}, + }, + } + + multipartformTests = []testCase{ + { + "POST", + path, + "", + "multipart/form-data", + false, + &BlogPost{Title: "", Content: "This is the content"}, + }, + { + "POST", + path, + "", + "multipart/form-data", + false, + &BlogPost{Title: "Blog Post Title", Views: 3}, + }, + { + "POST", + path, + "", + "multipart/form-data", + true, + &BlogPost{Title: "Blog Post Title", Content: "This is the content", Views: 3, Multiple: []int{5, 10, 15, 20}}, + }, + } + + emptyPayloadTests = []emptyPayloadTestCase{ + { + "GET", + "", + "", + "", + true, + &BlogSection{}, + }, + { + "POST", + "", + "", + "", + true, + &BlogSection{}, + }, + { + "PUT", + "", + "", + "", + true, + &BlogSection{}, + }, + { + "DELETE", + "", + "", + "", + true, + &BlogSection{}, + }, + } + + jsonTests = []testCase{ + // bad requests + { + "GET", + "", + `{blah blah blah}`, + "", + false, + &BlogPost{}, + }, + { + "POST", + "", + `{asdf}`, + "", + false, + &BlogPost{}, + }, + { + "PUT", + "", + `{blah blah blah}`, + "", + false, + &BlogPost{}, + }, + { + "DELETE", + "", + `{;sdf _SDf- }`, + "", + false, + &BlogPost{}, + }, + + // Valid-JSON requests + { + "GET", + "", + `{"content":"This is the content"}`, + "", + false, + &BlogPost{Title: "", Content: "This is the content"}, + }, + { + "POST", + "", + `{}`, + "application/json", + false, + &BlogPost{Title: "", Content: ""}, + }, + { + "POST", + "", + `{"content":"This is the content", "title":"Blog Post Title"}`, + "", + true, + &BlogPost{Title: "Blog Post Title", Content: "This is the content"}, + }, + { + "PUT", + "", + `{"content":"This is the content", "title":"Blog Post Title"}`, + "", + true, + &BlogPost{Title: "Blog Post Title", Content: "This is the content"}, + }, + { + "DELETE", + "", + `{"content":"This is the content", "title":"Blog Post Title"}`, + "", + true, + &BlogPost{Title: "Blog Post Title", Content: "This is the content"}, + }, + } +) + +const ( + route = "/blogposts/create" + path = "http://localhost:3000" + route +) diff --git a/modules/middleware/context.go b/modules/middleware/context.go index 8129b13b..1330172f 100644 --- a/modules/middleware/context.go +++ b/modules/middleware/context.go @@ -10,7 +10,10 @@ import ( "encoding/base64" "fmt" "html/template" + "io" "net/http" + "net/url" + "path/filepath" "strconv" "strings" "time" @@ -34,6 +37,7 @@ type Context struct { p martini.Params Req *http.Request Res http.ResponseWriter + Flash *Flash Session session.SessionStore Cache cache.Cache User *models.User @@ -47,6 +51,7 @@ type Context struct { IsBranch bool IsTag bool IsCommit bool + HasAccess bool Repository *models.Repository Owner *models.User Commit *git.Commit @@ -59,6 +64,7 @@ type Context struct { HTTPS string Git string } + Mirror *models.Mirror } } @@ -78,6 +84,8 @@ func (ctx *Context) HasError() bool { if !ok { return false } + ctx.Flash.ErrorMsg = ctx.Data["ErrorMsg"].(string) + ctx.Data["Flash"] = ctx.Flash return hasErr.(bool) } @@ -88,23 +96,21 @@ func (ctx *Context) HTML(status int, name string, htmlOpt ...HTMLOptions) { // RenderWithErr used for page has form validation but need to prompt error to users. func (ctx *Context) RenderWithErr(msg, tpl string, form auth.Form) { - ctx.Data["HasError"] = true - ctx.Data["ErrorMsg"] = msg if form != nil { auth.AssignForm(form, ctx.Data) } + ctx.Flash.ErrorMsg = msg + ctx.Data["Flash"] = ctx.Flash ctx.HTML(200, tpl) } // Handle handles and logs error by given status. func (ctx *Context) Handle(status int, title string, err error) { log.Error("%s: %v", title, err) - if martini.Dev == martini.Prod { - ctx.HTML(500, "status/500") - return + if martini.Dev != martini.Prod { + ctx.Data["ErrorMsg"] = err } - ctx.Data["ErrorMsg"] = err ctx.HTML(status, fmt.Sprintf("status/%d", status)) } @@ -239,6 +245,56 @@ func (ctx *Context) CsrfTokenValid() bool { return true } +func (ctx *Context) ServeFile(file string, names ...string) { + var name string + if len(names) > 0 { + name = names[0] + } else { + name = filepath.Base(file) + } + ctx.Res.Header().Set("Content-Description", "File Transfer") + ctx.Res.Header().Set("Content-Type", "application/octet-stream") + ctx.Res.Header().Set("Content-Disposition", "attachment; filename="+name) + ctx.Res.Header().Set("Content-Transfer-Encoding", "binary") + ctx.Res.Header().Set("Expires", "0") + ctx.Res.Header().Set("Cache-Control", "must-revalidate") + ctx.Res.Header().Set("Pragma", "public") + http.ServeFile(ctx.Res, ctx.Req, file) +} + +func (ctx *Context) ServeContent(name string, r io.ReadSeeker, params ...interface{}) { + modtime := time.Now() + for _, p := range params { + switch v := p.(type) { + case time.Time: + modtime = v + } + } + ctx.Res.Header().Set("Content-Description", "File Transfer") + ctx.Res.Header().Set("Content-Type", "application/octet-stream") + ctx.Res.Header().Set("Content-Disposition", "attachment; filename="+name) + ctx.Res.Header().Set("Content-Transfer-Encoding", "binary") + ctx.Res.Header().Set("Expires", "0") + ctx.Res.Header().Set("Cache-Control", "must-revalidate") + ctx.Res.Header().Set("Pragma", "public") + http.ServeContent(ctx.Res, ctx.Req, name, modtime, r) +} + +type Flash struct { + url.Values + ErrorMsg, SuccessMsg string +} + +func (f *Flash) Error(msg string) { + f.Set("error", msg) + f.ErrorMsg = msg +} + +func (f *Flash) Success(msg string) { + f.Set("success", msg) + f.SuccessMsg = msg +} + // InitContext initializes a classic context for a request. func InitContext() martini.Handler { return func(res http.ResponseWriter, r *http.Request, c martini.Context, rd *Render) { @@ -256,9 +312,27 @@ func InitContext() martini.Handler { // start session ctx.Session = base.SessionManager.SessionStart(res, r) + + // Get flash. + values, err := url.ParseQuery(ctx.GetCookie("gogs_flash")) + if err != nil { + log.Error("InitContext.ParseQuery(flash): %v", err) + } else if len(values) > 0 { + ctx.Flash = &Flash{Values: values} + ctx.Flash.ErrorMsg = ctx.Flash.Get("error") + ctx.Flash.SuccessMsg = ctx.Flash.Get("success") + ctx.Data["Flash"] = ctx.Flash + ctx.SetCookie("gogs_flash", "", -1) + } + ctx.Flash = &Flash{Values: url.Values{}} + rw := res.(martini.ResponseWriter) rw.Before(func(martini.ResponseWriter) { ctx.Session.SessionRelease(res) + + if flash := ctx.Flash.Encode(); len(flash) > 0 { + ctx.SetCookie("gogs_flash", ctx.Flash.Encode(), 0) + } }) // Get user from session if logined. diff --git a/modules/middleware/render.go b/modules/middleware/render.go index 98d485af..66289988 100644 --- a/modules/middleware/render.go +++ b/modules/middleware/render.go @@ -146,7 +146,7 @@ func compile(options RenderOptions) *template.Template { tmpl := t.New(filepath.ToSlash(name)) for _, funcs := range options.Funcs { - tmpl.Funcs(funcs) + tmpl = tmpl.Funcs(funcs) } template.Must(tmpl.Funcs(helperFuncs).Parse(string(buf))) diff --git a/modules/middleware/repo.go b/modules/middleware/repo.go index 2139742c..34144fe3 100644 --- a/modules/middleware/repo.go +++ b/modules/middleware/repo.go @@ -15,6 +15,7 @@ import ( "github.com/gogits/gogs/models" "github.com/gogits/gogs/modules/base" + "github.com/gogits/gogs/modules/log" ) func RepoAssignment(redirect bool, args ...bool) martini.Handler { @@ -39,7 +40,7 @@ func RepoAssignment(redirect bool, args ...bool) martini.Handler { userName := params["username"] repoName := params["reponame"] - branchName := params["branchname"] + refName := params["branchname"] // get repository owner ctx.Repo.IsOwner = ctx.IsSigned && ctx.User.LowerName == strings.ToLower(userName) @@ -66,34 +67,69 @@ func RepoAssignment(redirect bool, args ...bool) martini.Handler { ctx.Handle(200, "RepoAssignment", errors.New("invliad user account for single repository")) return } + ctx.Repo.Owner = user // get repository repo, err := models.GetRepositoryByName(user.Id, repoName) if err != nil { if err == models.ErrRepoNotExist { ctx.Handle(404, "RepoAssignment", err) + return } else if redirect { ctx.Redirect("/") return } - ctx.Handle(404, "RepoAssignment", err) + ctx.Handle(500, "RepoAssignment", err) return } + + // Check access. + if repo.IsPrivate { + if ctx.User == nil { + ctx.Handle(404, "RepoAssignment(HasAccess)", nil) + return + } + + hasAccess, err := models.HasAccess(ctx.User.Name, ctx.Repo.Owner.Name+"/"+repo.Name, models.AU_READABLE) + if err != nil { + ctx.Handle(500, "RepoAssignment(HasAccess)", err) + return + } else if !hasAccess { + ctx.Handle(404, "RepoAssignment(HasAccess)", nil) + return + } + } + ctx.Repo.HasAccess = true + ctx.Data["HasAccess"] = true + + if repo.IsMirror { + ctx.Repo.Mirror, err = models.GetMirror(repo.Id) + if err != nil { + ctx.Handle(500, "RepoAssignment(GetMirror)", err) + return + } + ctx.Data["MirrorInterval"] = ctx.Repo.Mirror.Interval + } + repo.NumOpenIssues = repo.NumIssues - repo.NumClosedIssues ctx.Repo.Repository = repo - ctx.Data["IsBareRepo"] = ctx.Repo.Repository.IsBare gitRepo, err := git.OpenRepository(models.RepoPath(userName, repoName)) if err != nil { - ctx.Handle(404, "RepoAssignment Invalid repo "+models.RepoPath(userName, repoName), err) + ctx.Handle(500, "RepoAssignment Invalid repo "+models.RepoPath(userName, repoName), err) return } ctx.Repo.GitRepo = gitRepo - - ctx.Repo.Owner = user ctx.Repo.RepoLink = "/" + user.Name + "/" + repo.Name + tags, err := ctx.Repo.GitRepo.GetTags() + if err != nil { + ctx.Handle(500, "RepoAssignment(GetTags))", err) + return + } + ctx.Repo.Repository.NumTags = len(tags) + ctx.Data["Title"] = user.Name + "/" + repo.Name ctx.Data["Repository"] = repo ctx.Data["Owner"] = user @@ -105,29 +141,43 @@ func RepoAssignment(redirect bool, args ...bool) martini.Handler { ctx.Repo.CloneLink.HTTPS = fmt.Sprintf("%s%s/%s.git", base.AppUrl, user.LowerName, repo.LowerName) ctx.Data["CloneLink"] = ctx.Repo.CloneLink + if ctx.Repo.Repository.IsGoget { + ctx.Data["GoGetLink"] = fmt.Sprintf("%s%s/%s", base.AppUrl, user.LowerName, repo.LowerName) + ctx.Data["GoGetImport"] = fmt.Sprintf("%s/%s/%s", base.Domain, user.LowerName, repo.LowerName) + } + // when repo is bare, not valid branch if !ctx.Repo.Repository.IsBare && validBranch { detect: - if len(branchName) > 0 { - // TODO check tag - if models.IsBranchExist(user.Name, repoName, branchName) { + if len(refName) > 0 { + if gitRepo.IsBranchExist(refName) { ctx.Repo.IsBranch = true - ctx.Repo.BranchName = branchName + ctx.Repo.BranchName = refName - ctx.Repo.Commit, err = gitRepo.GetCommitOfBranch(branchName) + ctx.Repo.Commit, err = gitRepo.GetCommitOfBranch(refName) if err != nil { ctx.Handle(404, "RepoAssignment invalid branch", nil) return } + ctx.Repo.CommitId = ctx.Repo.Commit.Id.String() - ctx.Repo.CommitId = ctx.Repo.Commit.Oid.String() + } else if gitRepo.IsTagExist(refName) { + ctx.Repo.IsBranch = true + ctx.Repo.BranchName = refName - } else if len(branchName) == 40 { + ctx.Repo.Commit, err = gitRepo.GetCommitOfTag(refName) + if err != nil { + ctx.Handle(404, "RepoAssignment invalid tag", nil) + return + } + ctx.Repo.CommitId = ctx.Repo.Commit.Id.String() + + } else if len(refName) == 40 { ctx.Repo.IsCommit = true - ctx.Repo.CommitId = branchName - ctx.Repo.BranchName = branchName + ctx.Repo.CommitId = refName + ctx.Repo.BranchName = refName - ctx.Repo.Commit, err = gitRepo.GetCommit(branchName) + ctx.Repo.Commit, err = gitRepo.GetCommit(refName) if err != nil { ctx.Handle(404, "RepoAssignment invalid commit", nil) return @@ -138,16 +188,23 @@ func RepoAssignment(redirect bool, args ...bool) martini.Handler { } } else { - branchName = "master" + refName = ctx.Repo.Repository.DefaultBranch + if len(refName) == 0 { + refName = "master" + } goto detect } ctx.Data["IsBranch"] = ctx.Repo.IsBranch ctx.Data["IsCommit"] = ctx.Repo.IsCommit + log.Debug("Repo.Commit: %v", ctx.Repo.Commit) } + log.Debug("displayBare: %v; IsBare: %v", displayBare, ctx.Repo.Repository.IsBare) + // repo is bare and display enable if displayBare && ctx.Repo.Repository.IsBare { + log.Debug("Bare repository: %s", ctx.Repo.RepoLink) ctx.HTML(200, "repo/single_bare") return } @@ -157,6 +214,11 @@ func RepoAssignment(redirect bool, args ...bool) martini.Handler { } ctx.Data["BranchName"] = ctx.Repo.BranchName + brs, err := ctx.Repo.GitRepo.GetBranches() + if err != nil { + log.Error("RepoAssignment(GetBranches): %v", err) + } + ctx.Data["Branches"] = brs ctx.Data["CommitId"] = ctx.Repo.CommitId ctx.Data["IsRepositoryWatching"] = ctx.Repo.IsWatching } diff --git a/modules/social/social.go b/modules/social/social.go new file mode 100644 index 00000000..a628bfe6 --- /dev/null +++ b/modules/social/social.go @@ -0,0 +1,396 @@ +// Copyright 2014 Google Inc. All Rights Reserved. +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package social + +import ( + "encoding/json" + "net/http" + "net/url" + "strconv" + "strings" + + oauth "github.com/gogits/oauth2" + + "github.com/gogits/gogs/models" + "github.com/gogits/gogs/modules/base" + "github.com/gogits/gogs/modules/log" +) + +type BasicUserInfo struct { + Identity string + Name string + Email string +} + +type SocialConnector interface { + Type() int + SetRedirectUrl(string) + UserInfo(*oauth.Token, *url.URL) (*BasicUserInfo, error) + + AuthCodeURL(string) string + Exchange(string) (*oauth.Token, error) +} + +var ( + SocialBaseUrl = "/user/login" + SocialMap = make(map[string]SocialConnector) +) + +func NewOauthService() { + if !base.Cfg.MustBool("oauth", "ENABLED") { + return + } + + base.OauthService = &base.Oauther{} + base.OauthService.OauthInfos = make(map[string]*base.OauthInfo) + + socialConfigs := make(map[string]*oauth.Config) + allOauthes := []string{"github", "google", "qq", "twitter", "weibo"} + // Load all OAuth config data. + for _, name := range allOauthes { + base.OauthService.OauthInfos[name] = &base.OauthInfo{ + ClientId: base.Cfg.MustValue("oauth."+name, "CLIENT_ID"), + ClientSecret: base.Cfg.MustValue("oauth."+name, "CLIENT_SECRET"), + Scopes: base.Cfg.MustValue("oauth."+name, "SCOPES"), + AuthUrl: base.Cfg.MustValue("oauth."+name, "AUTH_URL"), + TokenUrl: base.Cfg.MustValue("oauth."+name, "TOKEN_URL"), + } + socialConfigs[name] = &oauth.Config{ + ClientId: base.OauthService.OauthInfos[name].ClientId, + ClientSecret: base.OauthService.OauthInfos[name].ClientSecret, + RedirectURL: strings.TrimSuffix(base.AppUrl, "/") + SocialBaseUrl + name, + Scope: base.OauthService.OauthInfos[name].Scopes, + AuthURL: base.OauthService.OauthInfos[name].AuthUrl, + TokenURL: base.OauthService.OauthInfos[name].TokenUrl, + } + } + + enabledOauths := make([]string, 0, 10) + + // GitHub. + if base.Cfg.MustBool("oauth.github", "ENABLED") { + base.OauthService.GitHub = true + newGitHubOauth(socialConfigs["github"]) + enabledOauths = append(enabledOauths, "GitHub") + } + + // Google. + if base.Cfg.MustBool("oauth.google", "ENABLED") { + base.OauthService.Google = true + newGoogleOauth(socialConfigs["google"]) + enabledOauths = append(enabledOauths, "Google") + } + + // QQ. + if base.Cfg.MustBool("oauth.qq", "ENABLED") { + base.OauthService.Tencent = true + newTencentOauth(socialConfigs["qq"]) + enabledOauths = append(enabledOauths, "QQ") + } + + // Twitter. + if base.Cfg.MustBool("oauth.twitter", "ENABLED") { + base.OauthService.Twitter = true + newTwitterOauth(socialConfigs["twitter"]) + enabledOauths = append(enabledOauths, "Twitter") + } + + // Weibo. + if base.Cfg.MustBool("oauth.weibo", "ENABLED") { + base.OauthService.Weibo = true + newWeiboOauth(socialConfigs["weibo"]) + enabledOauths = append(enabledOauths, "Weibo") + } + + log.Info("Oauth Service Enabled %s", enabledOauths) +} + +// ________.__ __ ___ ___ ___. +// / _____/|__|/ |_ / | \ __ _\_ |__ +// / \ ___| \ __\/ ~ \ | \ __ \ +// \ \_\ \ || | \ Y / | / \_\ \ +// \______ /__||__| \___|_ /|____/|___ / +// \/ \/ \/ + +type SocialGithub struct { + Token *oauth.Token + *oauth.Transport +} + +func (s *SocialGithub) Type() int { + return models.OT_GITHUB +} + +func newGitHubOauth(config *oauth.Config) { + SocialMap["github"] = &SocialGithub{ + Transport: &oauth.Transport{ + Config: config, + Transport: http.DefaultTransport, + }, + } +} + +func (s *SocialGithub) SetRedirectUrl(url string) { + s.Transport.Config.RedirectURL = url +} + +func (s *SocialGithub) UserInfo(token *oauth.Token, _ *url.URL) (*BasicUserInfo, error) { + transport := &oauth.Transport{ + Token: token, + } + var data struct { + Id int `json:"id"` + Name string `json:"login"` + Email string `json:"email"` + } + var err error + r, err := transport.Client().Get(s.Transport.Scope) + if err != nil { + return nil, err + } + defer r.Body.Close() + if err = json.NewDecoder(r.Body).Decode(&data); err != nil { + return nil, err + } + return &BasicUserInfo{ + Identity: strconv.Itoa(data.Id), + Name: data.Name, + Email: data.Email, + }, nil +} + +// ________ .__ +// / _____/ ____ ____ ____ | | ____ +// / \ ___ / _ \ / _ \ / ___\| | _/ __ \ +// \ \_\ ( <_> | <_> ) /_/ > |_\ ___/ +// \______ /\____/ \____/\___ /|____/\___ > +// \/ /_____/ \/ + +type SocialGoogle struct { + Token *oauth.Token + *oauth.Transport +} + +func (s *SocialGoogle) Type() int { + return models.OT_GOOGLE +} + +func newGoogleOauth(config *oauth.Config) { + SocialMap["google"] = &SocialGoogle{ + Transport: &oauth.Transport{ + Config: config, + Transport: http.DefaultTransport, + }, + } +} + +func (s *SocialGoogle) SetRedirectUrl(url string) { + s.Transport.Config.RedirectURL = url +} + +func (s *SocialGoogle) UserInfo(token *oauth.Token, _ *url.URL) (*BasicUserInfo, error) { + transport := &oauth.Transport{Token: token} + var data struct { + Id string `json:"id"` + Name string `json:"name"` + Email string `json:"email"` + } + var err error + + reqUrl := "https://www.googleapis.com/oauth2/v1/userinfo" + r, err := transport.Client().Get(reqUrl) + if err != nil { + return nil, err + } + defer r.Body.Close() + if err = json.NewDecoder(r.Body).Decode(&data); err != nil { + return nil, err + } + return &BasicUserInfo{ + Identity: data.Id, + Name: data.Name, + Email: data.Email, + }, nil +} + +// ________ ________ +// \_____ \ \_____ \ +// / / \ \ / / \ \ +// / \_/. \/ \_/. \ +// \_____\ \_/\_____\ \_/ +// \__> \__> + +type SocialTencent struct { + Token *oauth.Token + *oauth.Transport + reqUrl string +} + +func (s *SocialTencent) Type() int { + return models.OT_QQ +} + +func newTencentOauth(config *oauth.Config) { + SocialMap["qq"] = &SocialTencent{ + reqUrl: "https://open.t.qq.com/api/user/info", + Transport: &oauth.Transport{ + Config: config, + Transport: http.DefaultTransport, + }, + } +} + +func (s *SocialTencent) SetRedirectUrl(url string) { + s.Transport.Config.RedirectURL = url +} + +func (s *SocialTencent) UserInfo(token *oauth.Token, URL *url.URL) (*BasicUserInfo, error) { + var data struct { + Data struct { + Id string `json:"openid"` + Name string `json:"name"` + Email string `json:"email"` + } `json:"data"` + } + var err error + // https://open.t.qq.com/api/user/info? + //oauth_consumer_key=APP_KEY& + //access_token=ACCESSTOKEN&openid=openid + //clientip=CLIENTIP&oauth_version=2.a + //scope=all + var urls = url.Values{ + "oauth_consumer_key": {s.Transport.Config.ClientId}, + "access_token": {token.AccessToken}, + "openid": URL.Query()["openid"], + "oauth_version": {"2.a"}, + "scope": {"all"}, + } + r, err := http.Get(s.reqUrl + "?" + urls.Encode()) + if err != nil { + return nil, err + } + defer r.Body.Close() + if err = json.NewDecoder(r.Body).Decode(&data); err != nil { + return nil, err + } + return &BasicUserInfo{ + Identity: data.Data.Id, + Name: data.Data.Name, + Email: data.Data.Email, + }, nil +} + +// ___________ .__ __ __ +// \__ ___/_ _ _|__|/ |__/ |_ ___________ +// | | \ \/ \/ / \ __\ __\/ __ \_ __ \ +// | | \ /| || | | | \ ___/| | \/ +// |____| \/\_/ |__||__| |__| \___ >__| +// \/ + +type SocialTwitter struct { + Token *oauth.Token + *oauth.Transport +} + +func (s *SocialTwitter) Type() int { + return models.OT_TWITTER +} + +func newTwitterOauth(config *oauth.Config) { + SocialMap["twitter"] = &SocialTwitter{ + Transport: &oauth.Transport{ + Config: config, + Transport: http.DefaultTransport, + }, + } +} + +func (s *SocialTwitter) SetRedirectUrl(url string) { + s.Transport.Config.RedirectURL = url +} + +//https://github.com/mrjones/oauth +func (s *SocialTwitter) UserInfo(token *oauth.Token, _ *url.URL) (*BasicUserInfo, error) { + // transport := &oauth.Transport{Token: token} + // var data struct { + // Id string `json:"id"` + // Name string `json:"name"` + // Email string `json:"email"` + // } + // var err error + + // reqUrl := "https://www.googleapis.com/oauth2/v1/userinfo" + // r, err := transport.Client().Get(reqUrl) + // if err != nil { + // return nil, err + // } + // defer r.Body.Close() + // if err = json.NewDecoder(r.Body).Decode(&data); err != nil { + // return nil, err + // } + // return &BasicUserInfo{ + // Identity: data.Id, + // Name: data.Name, + // Email: data.Email, + // }, nil + return nil, nil +} + +// __ __ ._____. +// / \ / \ ____ |__\_ |__ ____ +// \ \/\/ // __ \| || __ \ / _ \ +// \ /\ ___/| || \_\ ( <_> ) +// \__/\ / \___ >__||___ /\____/ +// \/ \/ \/ + +type SocialWeibo struct { + Token *oauth.Token + *oauth.Transport +} + +func (s *SocialWeibo) Type() int { + return models.OT_WEIBO +} + +func newWeiboOauth(config *oauth.Config) { + SocialMap["weibo"] = &SocialWeibo{ + Transport: &oauth.Transport{ + Config: config, + Transport: http.DefaultTransport, + }, + } +} + +func (s *SocialWeibo) SetRedirectUrl(url string) { + s.Transport.Config.RedirectURL = url +} + +func (s *SocialWeibo) UserInfo(token *oauth.Token, _ *url.URL) (*BasicUserInfo, error) { + transport := &oauth.Transport{Token: token} + var data struct { + Name string `json:"name"` + } + var err error + + var urls = url.Values{ + "access_token": {token.AccessToken}, + "uid": {token.Extra["id_token"]}, + } + reqUrl := "https://api.weibo.com/2/users/show.json" + r, err := transport.Client().Get(reqUrl + "?" + urls.Encode()) + if err != nil { + return nil, err + } + defer r.Body.Close() + if err = json.NewDecoder(r.Body).Decode(&data); err != nil { + return nil, err + } + return &BasicUserInfo{ + Identity: token.Extra["id_token"], + Name: data.Name, + }, nil + return nil, nil +} |