diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/auth/admin.go | 55 | ||||
-rw-r--r-- | modules/auth/auth.go | 214 | ||||
-rw-r--r-- | modules/auth/issue.go | 51 | ||||
-rw-r--r-- | modules/auth/repo.go | 53 | ||||
-rw-r--r-- | modules/auth/setting.go | 55 | ||||
-rw-r--r-- | modules/auth/user.go | 145 | ||||
-rw-r--r-- | modules/avatar/avatar.go | 296 | ||||
-rw-r--r-- | modules/avatar/avatar_test.go | 61 | ||||
-rw-r--r-- | modules/base/base.go | 10 | ||||
-rw-r--r-- | modules/base/conf.go | 340 | ||||
-rw-r--r-- | modules/base/markdown.go | 168 | ||||
-rw-r--r-- | modules/base/template.go | 196 | ||||
-rw-r--r-- | modules/base/tool.go | 514 | ||||
-rw-r--r-- | modules/log/log.go | 49 | ||||
-rw-r--r-- | modules/mailer/mail.go | 162 | ||||
-rw-r--r-- | modules/mailer/mailer.go | 126 | ||||
-rw-r--r-- | modules/middleware/auth.go | 63 | ||||
-rw-r--r-- | modules/middleware/context.go | 286 | ||||
-rw-r--r-- | modules/middleware/logger.go | 46 | ||||
-rw-r--r-- | modules/middleware/render.go | 289 | ||||
-rw-r--r-- | modules/middleware/repo.go | 163 | ||||
-rw-r--r-- | modules/oauth2/oauth2.go | 234 | ||||
-rw-r--r-- | modules/oauth2/oauth2_test.go | 162 |
23 files changed, 3738 insertions, 0 deletions
diff --git a/modules/auth/admin.go b/modules/auth/admin.go new file mode 100644 index 00000000..fe889c23 --- /dev/null +++ b/modules/auth/admin.go @@ -0,0 +1,55 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "net/http" + "reflect" + + "github.com/go-martini/martini" + + "github.com/gogits/binding" + + "github.com/gogits/gogs/modules/base" + "github.com/gogits/gogs/modules/log" +) + +type AdminEditUserForm struct { + Email string `form:"email" binding:"Required;Email;MaxSize(50)"` + Website string `form:"website" binding:"MaxSize(50)"` + Location string `form:"location" binding:"MaxSize(50)"` + Avatar string `form:"avatar" binding:"Required;Email;MaxSize(50)"` + Active string `form:"active"` + Admin string `form:"admin"` +} + +func (f *AdminEditUserForm) Name(field string) string { + names := map[string]string{ + "Email": "E-mail address", + "Website": "Website", + "Location": "Location", + "Avatar": "Gravatar Email", + } + return names[field] +} + +func (f *AdminEditUserForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) { + if req.Method == "GET" || errors.Count() == 0 { + return + } + + data := context.Get(reflect.TypeOf(base.TmplData{})).Interface().(base.TmplData) + data["HasError"] = true + AssignForm(f, data) + + if len(errors.Overall) > 0 { + for _, err := range errors.Overall { + log.Error("AdminEditUserForm.Validate: %v", err) + } + return + } + + validate(errors, data, f) +} diff --git a/modules/auth/auth.go b/modules/auth/auth.go new file mode 100644 index 00000000..4561dd83 --- /dev/null +++ b/modules/auth/auth.go @@ -0,0 +1,214 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "net/http" + "reflect" + "strings" + + "github.com/go-martini/martini" + + "github.com/gogits/binding" + + "github.com/gogits/gogs/modules/base" + "github.com/gogits/gogs/modules/log" +) + +// Web form interface. +type Form interface { + Name(field string) string +} + +type RegisterForm struct { + UserName string `form:"username" binding:"Required;AlphaDash;MaxSize(30)"` + Email string `form:"email" binding:"Required;Email;MaxSize(50)"` + Password string `form:"passwd" binding:"Required;MinSize(6);MaxSize(30)"` + RetypePasswd string `form:"retypepasswd"` +} + +func (f *RegisterForm) Name(field string) string { + names := map[string]string{ + "UserName": "Username", + "Email": "E-mail address", + "Password": "Password", + "RetypePasswd": "Re-type password", + } + return names[field] +} + +func (f *RegisterForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) { + if req.Method == "GET" || errors.Count() == 0 { + return + } + + data := context.Get(reflect.TypeOf(base.TmplData{})).Interface().(base.TmplData) + data["HasError"] = true + AssignForm(f, data) + + if len(errors.Overall) > 0 { + for _, err := range errors.Overall { + log.Error("RegisterForm.Validate: %v", err) + } + return + } + + validate(errors, data, f) +} + +type LogInForm struct { + UserName string `form:"username" binding:"Required;AlphaDash;MaxSize(30)"` + Password string `form:"passwd" binding:"Required;MinSize(6);MaxSize(30)"` + Remember string `form:"remember"` +} + +func (f *LogInForm) Name(field string) string { + names := map[string]string{ + "UserName": "Username", + "Password": "Password", + } + return names[field] +} + +func (f *LogInForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) { + if req.Method == "GET" || errors.Count() == 0 { + return + } + + data := context.Get(reflect.TypeOf(base.TmplData{})).Interface().(base.TmplData) + data["HasError"] = true + AssignForm(f, data) + + if len(errors.Overall) > 0 { + for _, err := range errors.Overall { + log.Error("LogInForm.Validate: %v", err) + } + return + } + + validate(errors, data, f) +} + +func getMinMaxSize(field reflect.StructField) string { + for _, rule := range strings.Split(field.Tag.Get("binding"), ";") { + if strings.HasPrefix(rule, "MinSize(") || strings.HasPrefix(rule, "MaxSize(") { + return rule[8 : len(rule)-1] + } + } + return "" +} + +func validate(errors *binding.Errors, data base.TmplData, form Form) { + typ := reflect.TypeOf(form) + val := reflect.ValueOf(form) + + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + val = val.Elem() + } + + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + + fieldName := field.Tag.Get("form") + // Allow ignored fields in the struct + if fieldName == "-" { + continue + } + + if err, ok := errors.Fields[field.Name]; ok { + data["Err_"+field.Name] = true + switch err { + case binding.RequireError: + data["ErrorMsg"] = form.Name(field.Name) + " cannot be empty" + case binding.AlphaDashError: + data["ErrorMsg"] = form.Name(field.Name) + " must be valid alpha or numeric or dash(-_) characters" + case binding.MinSizeError: + data["ErrorMsg"] = form.Name(field.Name) + " must contain at least " + getMinMaxSize(field) + " characters" + case binding.MaxSizeError: + data["ErrorMsg"] = form.Name(field.Name) + " must contain at most " + getMinMaxSize(field) + " characters" + case binding.EmailError: + data["ErrorMsg"] = form.Name(field.Name) + " is not valid" + default: + data["ErrorMsg"] = "Unknown error: " + err + } + return + } + } +} + +// AssignForm assign form values back to the template data. +func AssignForm(form interface{}, data base.TmplData) { + typ := reflect.TypeOf(form) + val := reflect.ValueOf(form) + + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + val = val.Elem() + } + + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + + fieldName := field.Tag.Get("form") + // Allow ignored fields in the struct + if fieldName == "-" { + continue + } + + data[fieldName] = val.Field(i).Interface() + } +} + +type InstallForm struct { + Database string `form:"database" binding:"Required"` + Host string `form:"host"` + User string `form:"user"` + Passwd string `form:"passwd"` + DatabaseName string `form:"database_name"` + SslMode string `form:"ssl_mode"` + DatabasePath string `form:"database_path"` + RepoRootPath string `form:"repo_path"` + RunUser string `form:"run_user"` + Domain string `form:"domain"` + AppUrl string `form:"app_url"` + AdminName string `form:"admin_name" binding:"Required"` + AdminPasswd string `form:"admin_pwd" binding:"Required;MinSize(6);MaxSize(30)"` + AdminEmail string `form:"admin_email" binding:"Required;Email;MaxSize(50)"` + SmtpHost string `form:"smtp_host"` + SmtpEmail string `form:"mailer_user"` + SmtpPasswd string `form:"mailer_pwd"` + RegisterConfirm string `form:"register_confirm"` + MailNotify string `form:"mail_notify"` +} + +func (f *InstallForm) Name(field string) string { + names := map[string]string{ + "Database": "Database name", + "AdminName": "Admin user name", + "AdminPasswd": "Admin password", + "AdminEmail": "Admin e-maill address", + } + return names[field] +} + +func (f *InstallForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) { + if req.Method == "GET" || errors.Count() == 0 { + return + } + + data := context.Get(reflect.TypeOf(base.TmplData{})).Interface().(base.TmplData) + data["HasError"] = true + AssignForm(f, data) + + if len(errors.Overall) > 0 { + for _, err := range errors.Overall { + log.Error("InstallForm.Validate: %v", err) + } + return + } + + validate(errors, data, f) +} diff --git a/modules/auth/issue.go b/modules/auth/issue.go new file mode 100644 index 00000000..36c87627 --- /dev/null +++ b/modules/auth/issue.go @@ -0,0 +1,51 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "net/http" + "reflect" + + "github.com/go-martini/martini" + + "github.com/gogits/binding" + + "github.com/gogits/gogs/modules/base" + "github.com/gogits/gogs/modules/log" +) + +type CreateIssueForm struct { + IssueName string `form:"title" binding:"Required;MaxSize(50)"` + MilestoneId int64 `form:"milestoneid"` + AssigneeId int64 `form:"assigneeid"` + Labels string `form:"labels"` + Content string `form:"content"` +} + +func (f *CreateIssueForm) Name(field string) string { + names := map[string]string{ + "IssueName": "Issue name", + } + return names[field] +} + +func (f *CreateIssueForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) { + if req.Method == "GET" || errors.Count() == 0 { + return + } + + data := context.Get(reflect.TypeOf(base.TmplData{})).Interface().(base.TmplData) + data["HasError"] = true + AssignForm(f, data) + + if len(errors.Overall) > 0 { + for _, err := range errors.Overall { + log.Error("CreateIssueForm.Validate: %v", err) + } + return + } + + validate(errors, data, f) +} diff --git a/modules/auth/repo.go b/modules/auth/repo.go new file mode 100644 index 00000000..eddd6475 --- /dev/null +++ b/modules/auth/repo.go @@ -0,0 +1,53 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "net/http" + "reflect" + + "github.com/go-martini/martini" + + "github.com/gogits/binding" + + "github.com/gogits/gogs/modules/base" + "github.com/gogits/gogs/modules/log" +) + +type CreateRepoForm struct { + RepoName string `form:"repo" binding:"Required;AlphaDash"` + Visibility string `form:"visibility"` + Description string `form:"desc" binding:"MaxSize(100)"` + Language string `form:"language"` + License string `form:"license"` + InitReadme string `form:"initReadme"` +} + +func (f *CreateRepoForm) Name(field string) string { + names := map[string]string{ + "RepoName": "Repository name", + "Description": "Description", + } + return names[field] +} + +func (f *CreateRepoForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) { + if req.Method == "GET" || errors.Count() == 0 { + return + } + + data := context.Get(reflect.TypeOf(base.TmplData{})).Interface().(base.TmplData) + data["HasError"] = true + AssignForm(f, data) + + if len(errors.Overall) > 0 { + for _, err := range errors.Overall { + log.Error("CreateRepoForm.Validate: %v", err) + } + return + } + + validate(errors, data, f) +} diff --git a/modules/auth/setting.go b/modules/auth/setting.go new file mode 100644 index 00000000..cada7eea --- /dev/null +++ b/modules/auth/setting.go @@ -0,0 +1,55 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "net/http" + "reflect" + "strings" + + "github.com/go-martini/martini" + + "github.com/gogits/binding" + + "github.com/gogits/gogs/modules/base" + "github.com/gogits/gogs/modules/log" +) + +type AddSSHKeyForm struct { + KeyName string `form:"keyname" binding:"Required"` + KeyContent string `form:"key_content" binding:"Required"` +} + +func (f *AddSSHKeyForm) Name(field string) string { + names := map[string]string{ + "KeyName": "SSH key name", + "KeyContent": "SSH key content", + } + return names[field] +} + +func (f *AddSSHKeyForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) { + data := context.Get(reflect.TypeOf(base.TmplData{})).Interface().(base.TmplData) + AssignForm(f, data) + + if req.Method == "GET" || errors.Count() == 0 { + if req.Method == "POST" && + (len(f.KeyContent) < 100 || !strings.HasPrefix(f.KeyContent, "ssh-rsa")) { + data["HasError"] = true + data["ErrorMsg"] = "SSH key content is not valid" + } + return + } + + data["HasError"] = true + if len(errors.Overall) > 0 { + for _, err := range errors.Overall { + log.Error("AddSSHKeyForm.Validate: %v", err) + } + return + } + + validate(errors, data, f) +} diff --git a/modules/auth/user.go b/modules/auth/user.go new file mode 100644 index 00000000..015059f7 --- /dev/null +++ b/modules/auth/user.go @@ -0,0 +1,145 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "net/http" + "reflect" + + "github.com/go-martini/martini" + + "github.com/gogits/binding" + "github.com/gogits/session" + + "github.com/gogits/gogs/models" + "github.com/gogits/gogs/modules/base" + "github.com/gogits/gogs/modules/log" +) + +// SignedInId returns the id of signed in user. +func SignedInId(session session.SessionStore) int64 { + if !models.HasEngine { + return 0 + } + + userId := session.Get("userId") + if userId == nil { + return 0 + } + if s, ok := userId.(int64); ok { + if _, err := models.GetUserById(s); err != nil { + return 0 + } + return s + } + return 0 +} + +// SignedInName returns the name of signed in user. +func SignedInName(session session.SessionStore) string { + userName := session.Get("userName") + if userName == nil { + return "" + } + if s, ok := userName.(string); ok { + return s + } + return "" +} + +// SignedInUser returns the user object of signed user. +func SignedInUser(session session.SessionStore) *models.User { + id := SignedInId(session) + if id <= 0 { + return nil + } + + user, err := models.GetUserById(id) + if err != nil { + log.Error("user.SignedInUser: %v", err) + return nil + } + return user +} + +// IsSignedIn check if any user has signed in. +func IsSignedIn(session session.SessionStore) bool { + return SignedInId(session) > 0 +} + +type FeedsForm struct { + UserId int64 `form:"userid" binding:"Required"` + Page int64 `form:"p"` +} + +type UpdateProfileForm struct { + UserName string `form:"username" binding:"Required;AlphaDash;MaxSize(30)"` + Email string `form:"email" binding:"Required;Email;MaxSize(50)"` + Website string `form:"website" binding:"MaxSize(50)"` + Location string `form:"location" binding:"MaxSize(50)"` + Avatar string `form:"avatar" binding:"Required;Email;MaxSize(50)"` +} + +func (f *UpdateProfileForm) Name(field string) string { + names := map[string]string{ + "UserName": "Username", + "Email": "E-mail address", + "Website": "Website", + "Location": "Location", + "Avatar": "Gravatar Email", + } + return names[field] +} + +func (f *UpdateProfileForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) { + if req.Method == "GET" || errors.Count() == 0 { + return + } + + data := context.Get(reflect.TypeOf(base.TmplData{})).Interface().(base.TmplData) + data["HasError"] = true + + if len(errors.Overall) > 0 { + for _, err := range errors.Overall { + log.Error("UpdateProfileForm.Validate: %v", err) + } + return + } + + validate(errors, data, f) +} + +type UpdatePasswdForm struct { + OldPasswd string `form:"oldpasswd" binding:"Required;MinSize(6);MaxSize(30)"` + NewPasswd string `form:"newpasswd" binding:"Required;MinSize(6);MaxSize(30)"` + RetypePasswd string `form:"retypepasswd"` +} + +func (f *UpdatePasswdForm) Name(field string) string { + names := map[string]string{ + "OldPasswd": "Old password", + "NewPasswd": "New password", + "RetypePasswd": "Re-type password", + } + return names[field] +} + +func (f *UpdatePasswdForm) Validate(errors *binding.Errors, req *http.Request, context martini.Context) { + if req.Method == "GET" || errors.Count() == 0 { + return + } + + data := context.Get(reflect.TypeOf(base.TmplData{})).Interface().(base.TmplData) + data["HasError"] = true + + if len(errors.Overall) > 0 { + for _, err := range errors.Overall { + log.Error("UpdatePasswdForm.Validate: %v", err) + } + return + } + + validate(errors, data, f) +} diff --git a/modules/avatar/avatar.go b/modules/avatar/avatar.go new file mode 100644 index 00000000..edeb256f --- /dev/null +++ b/modules/avatar/avatar.go @@ -0,0 +1,296 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +// for www.gravatar.com image cache + +/* +It is recommend to use this way + + cacheDir := "./cache" + defaultImg := "./default.jpg" + http.Handle("/avatar/", avatar.CacheServer(cacheDir, defaultImg)) +*/ +package avatar + +import ( + "crypto/md5" + "encoding/hex" + "errors" + "fmt" + "image" + "image/jpeg" + "image/png" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/nfnt/resize" + + "github.com/gogits/gogs/modules/log" +) + +var ( + gravatar = "http://www.gravatar.com/avatar" +) + +// hash email to md5 string +// keep this func in order to make this package indenpent +func HashEmail(email string) string { + h := md5.New() + h.Write([]byte(strings.ToLower(email))) + return hex.EncodeToString(h.Sum(nil)) +} + +// Avatar represents the avatar object. +type Avatar struct { + Hash string + AlterImage string // image path + cacheDir string // image save dir + reqParams string + imagePath string + expireDuration time.Duration +} + +func New(hash string, cacheDir string) *Avatar { + return &Avatar{ + Hash: hash, + cacheDir: cacheDir, + expireDuration: time.Minute * 10, + reqParams: url.Values{ + "d": {"retro"}, + "size": {"200"}, + "r": {"pg"}}.Encode(), + imagePath: filepath.Join(cacheDir, hash+".image"), //maybe png or jpeg + } +} + +func (this *Avatar) HasCache() bool { + fileInfo, err := os.Stat(this.imagePath) + return err == nil && fileInfo.Mode().IsRegular() +} + +func (this *Avatar) Modtime() (modtime time.Time, err error) { + fileInfo, err := os.Stat(this.imagePath) + if err != nil { + return + } + return fileInfo.ModTime(), nil +} + +func (this *Avatar) Expired() bool { + modtime, err := this.Modtime() + return err != nil || time.Since(modtime) > this.expireDuration +} + +// default image format: jpeg +func (this *Avatar) Encode(wr io.Writer, size int) (err error) { + var img image.Image + decodeImageFile := func(file string) (img image.Image, err error) { + fd, err := os.Open(file) + if err != nil { + return + } + defer fd.Close() + + if img, err = jpeg.Decode(fd); err != nil { + fd.Seek(0, os.SEEK_SET) + img, err = png.Decode(fd) + } + return + } + imgPath := this.imagePath + if !this.HasCache() { + if this.AlterImage == "" { + return errors.New("request image failed, and no alt image offered") + } + imgPath = this.AlterImage + } + + if img, err = decodeImageFile(imgPath); err != nil { + return + } + m := resize.Resize(uint(size), 0, img, resize.Lanczos3) + return jpeg.Encode(wr, m, nil) +} + +// get image from gravatar.com +func (this *Avatar) Update() { + thunder.Fetch(gravatar+"/"+this.Hash+"?"+this.reqParams, + this.imagePath) +} + +func (this *Avatar) UpdateTimeout(timeout time.Duration) (err error) { + select { + case <-time.After(timeout): + err = fmt.Errorf("get gravatar image %s timeout", this.Hash) + case err = <-thunder.GoFetch(gravatar+"/"+this.Hash+"?"+this.reqParams, + this.imagePath): + } + return err +} + +type service struct { + cacheDir string + altImage string +} + +func (this *service) mustInt(r *http.Request, defaultValue int, keys ...string) (v int) { + for _, k := range keys { + if _, err := fmt.Sscanf(r.FormValue(k), "%d", &v); err == nil { + defaultValue = v + } + } + return defaultValue +} + +func (this *service) ServeHTTP(w http.ResponseWriter, r *http.Request) { + urlPath := r.URL.Path + hash := urlPath[strings.LastIndex(urlPath, "/")+1:] + size := this.mustInt(r, 80, "s", "size") // default size = 80*80 + + avatar := New(hash, this.cacheDir) + avatar.AlterImage = this.altImage + if avatar.Expired() { + err := avatar.UpdateTimeout(time.Millisecond * 500) + if err != nil { + log.Trace("avatar update error: %v", err) + } + } + if modtime, err := avatar.Modtime(); err == nil { + etag := fmt.Sprintf("size(%d)", size) + if t, err := time.Parse(http.TimeFormat, r.Header.Get("If-Modified-Since")); err == nil && modtime.Before(t.Add(1*time.Second)) && etag == r.Header.Get("If-None-Match") { + h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") + w.WriteHeader(http.StatusNotModified) + return + } + w.Header().Set("Last-Modified", modtime.UTC().Format(http.TimeFormat)) + w.Header().Set("ETag", etag) + } + w.Header().Set("Content-Type", "image/jpeg") + + if err := avatar.Encode(w, size); err != nil { + log.Warn("avatar encode error: %v", err) + w.WriteHeader(500) + } +} + +// http.Handle("/avatar/", avatar.CacheServer("./cache")) +func CacheServer(cacheDir string, defaultImgPath string) http.Handler { + return &service{ + cacheDir: cacheDir, + altImage: defaultImgPath, + } +} + +// thunder downloader +var thunder = &Thunder{QueueSize: 10} + +type Thunder struct { + QueueSize int // download queue size + q chan *thunderTask + once sync.Once +} + +func (t *Thunder) init() { + if t.QueueSize < 1 { + t.QueueSize = 1 + } + t.q = make(chan *thunderTask, t.QueueSize) + for i := 0; i < t.QueueSize; i++ { + go func() { + for { + task := <-t.q + task.Fetch() + } + }() + } +} + +func (t *Thunder) Fetch(url string, saveFile string) error { + t.once.Do(t.init) + task := &thunderTask{ + Url: url, + SaveFile: saveFile, + } + task.Add(1) + t.q <- task + task.Wait() + return task.err +} + +func (t *Thunder) GoFetch(url, saveFile string) chan error { + c := make(chan error) + go func() { + c <- t.Fetch(url, saveFile) + }() + return c +} + +// thunder download +type thunderTask struct { + Url string + SaveFile string + sync.WaitGroup + err error +} + +func (this *thunderTask) Fetch() { + this.err = this.fetch() + this.Done() +} + +var client = &http.Client{} + +func (this *thunderTask) fetch() error { + req, _ := http.NewRequest("GET", this.Url, nil) + req.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/jpeg,image/png,*/*;q=0.8") + req.Header.Set("Accept-Encoding", "deflate,sdch") + req.Header.Set("Accept-Language", "zh-CN,zh;q=0.8") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/33.0.1750.154 Safari/537.36") + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + return fmt.Errorf("status code: %d", resp.StatusCode) + } + + /* + log.Println("headers:", resp.Header) + switch resp.Header.Get("Content-Type") { + case "image/jpeg": + this.SaveFile += ".jpeg" + case "image/png": + this.SaveFile += ".png" + } + */ + /* + imgType := resp.Header.Get("Content-Type") + if imgType != "image/jpeg" && imgType != "image/png" { + return errors.New("not png or jpeg") + } + */ + + tmpFile := this.SaveFile + ".part" // mv to destination when finished + fd, err := os.Create(tmpFile) + if err != nil { + return err + } + _, err = io.Copy(fd, resp.Body) + fd.Close() + if err != nil { + os.Remove(tmpFile) + return err + } + return os.Rename(tmpFile, this.SaveFile) +} diff --git a/modules/avatar/avatar_test.go b/modules/avatar/avatar_test.go new file mode 100644 index 00000000..0cbf70fe --- /dev/null +++ b/modules/avatar/avatar_test.go @@ -0,0 +1,61 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. +package avatar_test + +import ( + "errors" + "os" + "strconv" + "testing" + "time" + + "github.com/gogits/gogs/modules/avatar" + "github.com/gogits/gogs/modules/log" +) + +const TMPDIR = "test-avatar" + +func TestFetch(t *testing.T) { + os.Mkdir(TMPDIR, 0755) + defer os.RemoveAll(TMPDIR) + + hash := avatar.HashEmail("ssx205@gmail.com") + a := avatar.New(hash, TMPDIR) + a.UpdateTimeout(time.Millisecond * 200) +} + +func TestFetchMany(t *testing.T) { + os.Mkdir(TMPDIR, 0755) + defer os.RemoveAll(TMPDIR) + + t.Log("start") + var n = 5 + ch := make(chan bool, n) + for i := 0; i < n; i++ { + go func(i int) { + hash := avatar.HashEmail(strconv.Itoa(i) + "ssx205@gmail.com") + a := avatar.New(hash, TMPDIR) + a.Update() + t.Log("finish", hash) + ch <- true + }(i) + } + for i := 0; i < n; i++ { + <-ch + } + t.Log("end") +} + +// cat +// wget http://www.artsjournal.com/artfulmanager/wp/wp-content/uploads/2013/12/200x200xmirror_cat.jpg.pagespeed.ic.GOZSv6v1_H.jpg -O default.jpg +/* +func TestHttp(t *testing.T) { + http.Handle("/", avatar.CacheServer("./", "default.jpg")) + http.ListenAndServe(":8001", nil) +} +*/ + +func TestLogTrace(t *testing.T) { + log.Trace("%v", errors.New("console log test")) +} diff --git a/modules/base/base.go b/modules/base/base.go new file mode 100644 index 00000000..7c08dcc5 --- /dev/null +++ b/modules/base/base.go @@ -0,0 +1,10 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package base + +type ( + // Type TmplData represents data in the templates. + TmplData map[string]interface{} +) diff --git a/modules/base/conf.go b/modules/base/conf.go new file mode 100644 index 00000000..871595e4 --- /dev/null +++ b/modules/base/conf.go @@ -0,0 +1,340 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package base + +import ( + "fmt" + "os" + "os/exec" + "path" + "path/filepath" + "strings" + + "github.com/Unknwon/com" + "github.com/Unknwon/goconfig" + qlog "github.com/qiniu/log" + + "github.com/gogits/cache" + "github.com/gogits/session" + + "github.com/gogits/gogs/modules/log" +) + +// Mailer represents mail service. +type Mailer struct { + Name string + Host string + User, Passwd string +} + +// Oauther represents oauth service. +type Oauther struct { + GitHub struct { + Enabled bool + ClientId, ClientSecret string + Scopes string + } +} + +var ( + AppVer string + AppName string + AppLogo string + AppUrl string + IsProdMode bool + Domain string + SecretKey string + RunUser string + RepoRootPath string + + InstallLock bool + + LogInRememberDays int + CookieUserName string + CookieRememberName string + + Cfg *goconfig.ConfigFile + MailService *Mailer + OauthService *Oauther + + LogMode string + LogConfig string + + Cache cache.Cache + CacheAdapter string + CacheConfig string + + SessionProvider string + SessionConfig *session.Config + SessionManager *session.Manager + + PictureService string +) + +var Service struct { + RegisterEmailConfirm bool + DisenableRegisteration bool + RequireSignInView bool + EnableCacheAvatar bool + NotifyMail bool + ActiveCodeLives int + ResetPwdCodeLives int +} + +func ExecDir() (string, error) { + file, err := exec.LookPath(os.Args[0]) + if err != nil { + return "", err + } + p, err := filepath.Abs(file) + if err != nil { + return "", err + } + return path.Dir(strings.Replace(p, "\\", "/", -1)), nil +} + +var logLevels = map[string]string{ + "Trace": "0", + "Debug": "1", + "Info": "2", + "Warn": "3", + "Error": "4", + "Critical": "5", +} + +func newService() { + Service.ActiveCodeLives = Cfg.MustInt("service", "ACTIVE_CODE_LIVE_MINUTES", 180) + Service.ResetPwdCodeLives = Cfg.MustInt("service", "RESET_PASSWD_CODE_LIVE_MINUTES", 180) + Service.DisenableRegisteration = Cfg.MustBool("service", "DISENABLE_REGISTERATION", false) + Service.RequireSignInView = Cfg.MustBool("service", "REQUIRE_SIGNIN_VIEW", false) + Service.EnableCacheAvatar = Cfg.MustBool("service", "ENABLE_CACHE_AVATAR", false) +} + +func newLogService() { + // Get and check log mode. + LogMode = Cfg.MustValue("log", "MODE", "console") + modeSec := "log." + LogMode + if _, err := Cfg.GetSection(modeSec); err != nil { + qlog.Fatalf("Unknown log mode: %s\n", LogMode) + } + + // Log level. + levelName := Cfg.MustValue("log."+LogMode, "LEVEL", "Trace") + level, ok := logLevels[levelName] + if !ok { + qlog.Fatalf("Unknown log level: %s\n", levelName) + } + + // Generate log configuration. + switch LogMode { + case "console": + LogConfig = fmt.Sprintf(`{"level":%s}`, level) + case "file": + logPath := Cfg.MustValue(modeSec, "FILE_NAME", "log/gogs.log") + os.MkdirAll(path.Dir(logPath), os.ModePerm) + LogConfig = fmt.Sprintf( + `{"level":%s,"filename":"%s","rotate":%v,"maxlines":%d,"maxsize":%d,"daily":%v,"maxdays":%d}`, level, + logPath, + Cfg.MustBool(modeSec, "LOG_ROTATE", true), + Cfg.MustInt(modeSec, "MAX_LINES", 1000000), + 1<<uint(Cfg.MustInt(modeSec, "MAX_SIZE_SHIFT", 28)), + Cfg.MustBool(modeSec, "DAILY_ROTATE", true), + Cfg.MustInt(modeSec, "MAX_DAYS", 7)) + case "conn": + LogConfig = fmt.Sprintf(`{"level":"%s","reconnectOnMsg":%v,"reconnect":%v,"net":"%s","addr":"%s"}`, level, + Cfg.MustBool(modeSec, "RECONNECT_ON_MSG", false), + Cfg.MustBool(modeSec, "RECONNECT", false), + Cfg.MustValue(modeSec, "PROTOCOL", "tcp"), + Cfg.MustValue(modeSec, "ADDR", ":7020")) + case "smtp": + LogConfig = fmt.Sprintf(`{"level":"%s","username":"%s","password":"%s","host":"%s","sendTos":"%s","subject":"%s"}`, level, + Cfg.MustValue(modeSec, "USER", "example@example.com"), + Cfg.MustValue(modeSec, "PASSWD", "******"), + Cfg.MustValue(modeSec, "HOST", "127.0.0.1:25"), + Cfg.MustValue(modeSec, "RECEIVERS", "[]"), + Cfg.MustValue(modeSec, "SUBJECT", "Diagnostic message from serve")) + case "database": + LogConfig = fmt.Sprintf(`{"level":"%s","driver":"%s","conn":"%s"}`, level, + Cfg.MustValue(modeSec, "Driver"), + Cfg.MustValue(modeSec, "CONN")) + } + + log.Info("%s %s", AppName, AppVer) + log.NewLogger(Cfg.MustInt64("log", "BUFFER_LEN", 10000), LogMode, LogConfig) + log.Info("Log Mode: %s(%s)", strings.Title(LogMode), levelName) +} + +func newCacheService() { + CacheAdapter = Cfg.MustValue("cache", "ADAPTER", "memory") + + switch CacheAdapter { + case "memory": + CacheConfig = fmt.Sprintf(`{"interval":%d}`, Cfg.MustInt("cache", "INTERVAL", 60)) + case "redis", "memcache": + CacheConfig = fmt.Sprintf(`{"conn":"%s"}`, Cfg.MustValue("cache", "HOST")) + default: + qlog.Fatalf("Unknown cache adapter: %s\n", CacheAdapter) + } + + var err error + Cache, err = cache.NewCache(CacheAdapter, CacheConfig) + if err != nil { + qlog.Fatalf("Init cache system failed, adapter: %s, config: %s, %v\n", + CacheAdapter, CacheConfig, err) + } + + log.Info("Cache Service Enabled") +} + +func newSessionService() { + SessionProvider = Cfg.MustValue("session", "PROVIDER", "memory") + + SessionConfig = new(session.Config) + SessionConfig.ProviderConfig = Cfg.MustValue("session", "PROVIDER_CONFIG") + SessionConfig.CookieName = Cfg.MustValue("session", "COOKIE_NAME", "i_like_gogits") + SessionConfig.CookieSecure = Cfg.MustBool("session", "COOKIE_SECURE") + SessionConfig.EnableSetCookie = Cfg.MustBool("session", "ENABLE_SET_COOKIE", true) + SessionConfig.GcIntervalTime = Cfg.MustInt64("session", "GC_INTERVAL_TIME", 86400) + SessionConfig.SessionLifeTime = Cfg.MustInt64("session", "SESSION_LIFE_TIME", 86400) + SessionConfig.SessionIDHashFunc = Cfg.MustValue("session", "SESSION_ID_HASHFUNC", "sha1") + SessionConfig.SessionIDHashKey = Cfg.MustValue("session", "SESSION_ID_HASHKEY") + + if SessionProvider == "file" { + os.MkdirAll(path.Dir(SessionConfig.ProviderConfig), os.ModePerm) + } + + var err error + SessionManager, err = session.NewManager(SessionProvider, *SessionConfig) + if err != nil { + qlog.Fatalf("Init session system failed, provider: %s, %v\n", + SessionProvider, err) + } + + log.Info("Session Service Enabled") +} + +func newMailService() { + // Check mailer setting. + if !Cfg.MustBool("mailer", "ENABLED") { + return + } + + MailService = &Mailer{ + Name: Cfg.MustValue("mailer", "NAME", AppName), + Host: Cfg.MustValue("mailer", "HOST"), + User: Cfg.MustValue("mailer", "USER"), + Passwd: Cfg.MustValue("mailer", "PASSWD"), + } + log.Info("Mail Service Enabled") +} + +func newRegisterMailService() { + if !Cfg.MustBool("service", "REGISTER_EMAIL_CONFIRM") { + return + } else if MailService == nil { + log.Warn("Register Mail Service: Mail Service is not enabled") + return + } + Service.RegisterEmailConfirm = true + log.Info("Register Mail Service Enabled") +} + +func newNotifyMailService() { + if !Cfg.MustBool("service", "ENABLE_NOTIFY_MAIL") { + return + } else if MailService == nil { + log.Warn("Notify Mail Service: Mail Service is not enabled") + return + } + Service.NotifyMail = true + log.Info("Notify Mail Service Enabled") +} + +func newOauthService() { + if !Cfg.MustBool("oauth", "ENABLED") { + return + } + + OauthService = &Oauther{} + oauths := make([]string, 0, 10) + + // GitHub. + if Cfg.MustBool("oauth.github", "ENABLED") { + OauthService.GitHub.Enabled = true + OauthService.GitHub.ClientId = Cfg.MustValue("oauth.github", "CLIENT_ID") + OauthService.GitHub.ClientSecret = Cfg.MustValue("oauth.github", "CLIENT_SECRET") + OauthService.GitHub.Scopes = Cfg.MustValue("oauth.github", "SCOPES") + oauths = append(oauths, "GitHub") + } + + log.Info("Oauth Service Enabled %s", oauths) +} + +func NewConfigContext() { + //var err error + workDir, err := ExecDir() + if err != nil { + qlog.Fatalf("Fail to get work directory: %s\n", err) + } + + cfgPath := filepath.Join(workDir, "conf/app.ini") + Cfg, err = goconfig.LoadConfigFile(cfgPath) + if err != nil { + qlog.Fatalf("Cannot load config file(%s): %v\n", cfgPath, err) + } + Cfg.BlockMode = false + + cfgPath = filepath.Join(workDir, "custom/conf/app.ini") + if com.IsFile(cfgPath) { + if err = Cfg.AppendFiles(cfgPath); err != nil { + qlog.Fatalf("Cannot load config file(%s): %v\n", cfgPath, err) + } + } + + AppName = Cfg.MustValue("", "APP_NAME", "Gogs: Go Git Service") + AppLogo = Cfg.MustValue("", "APP_LOGO", "img/favicon.png") + AppUrl = Cfg.MustValue("server", "ROOT_URL") + Domain = Cfg.MustValue("server", "DOMAIN") + SecretKey = Cfg.MustValue("security", "SECRET_KEY") + + InstallLock = Cfg.MustBool("security", "INSTALL_LOCK", false) + + RunUser = Cfg.MustValue("", "RUN_USER") + curUser := os.Getenv("USERNAME") + if len(curUser) == 0 { + curUser = os.Getenv("USER") + } + // Does not check run user when the install lock is off. + if InstallLock && RunUser != curUser { + qlog.Fatalf("Expect user(%s) but current user is: %s\n", RunUser, curUser) + } + + LogInRememberDays = Cfg.MustInt("security", "LOGIN_REMEMBER_DAYS") + CookieUserName = Cfg.MustValue("security", "COOKIE_USERNAME") + CookieRememberName = Cfg.MustValue("security", "COOKIE_REMEMBER_NAME") + + PictureService = Cfg.MustValue("picture", "SERVICE") + + // Determine and create root git reposiroty path. + homeDir, err := com.HomeDir() + if err != nil { + qlog.Fatalf("Fail to get home directory): %v\n", err) + } + RepoRootPath = Cfg.MustValue("repository", "ROOT", filepath.Join(homeDir, "gogs-repositories")) + if err = os.MkdirAll(RepoRootPath, os.ModePerm); err != nil { + qlog.Fatalf("Fail to create RepoRootPath(%s): %v\n", RepoRootPath, err) + } +} + +func NewServices() { + newService() + newLogService() + newCacheService() + newSessionService() + newMailService() + newRegisterMailService() + newNotifyMailService() + newOauthService() +} diff --git a/modules/base/markdown.go b/modules/base/markdown.go new file mode 100644 index 00000000..cc180775 --- /dev/null +++ b/modules/base/markdown.go @@ -0,0 +1,168 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package base + +import ( + "bytes" + "fmt" + "net/http" + "path" + "path/filepath" + "regexp" + "strings" + + "github.com/gogits/gfm" +) + +func isletter(c byte) bool { + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') +} + +func isalnum(c byte) bool { + return (c >= '0' && c <= '9') || isletter(c) +} + +var validLinks = [][]byte{[]byte("http://"), []byte("https://"), []byte("ftp://"), []byte("mailto://")} + +func isLink(link []byte) bool { + for _, prefix := range validLinks { + if len(link) > len(prefix) && bytes.Equal(bytes.ToLower(link[:len(prefix)]), prefix) && isalnum(link[len(prefix)]) { + return true + } + } + + return false +} + +func IsMarkdownFile(name string) bool { + name = strings.ToLower(name) + switch filepath.Ext(name) { + case ".md", ".markdown", ".mdown": + return true + } + return false +} + +func IsTextFile(data []byte) (string, bool) { + contentType := http.DetectContentType(data) + if strings.Index(contentType, "text/") != -1 { + return contentType, true + } + return contentType, false +} + +func IsImageFile(data []byte) (string, bool) { + contentType := http.DetectContentType(data) + if strings.Index(contentType, "image/") != -1 { + return contentType, true + } + return contentType, false +} + +func IsReadmeFile(name string) bool { + name = strings.ToLower(name) + if len(name) < 6 { + return false + } + if name[:6] == "readme" { + return true + } + return false +} + +type CustomRender struct { + gfm.Renderer + urlPrefix string +} + +func (options *CustomRender) Link(out *bytes.Buffer, link []byte, title []byte, content []byte) { + if len(link) > 0 && !isLink(link) { + if link[0] == '#' { + // link = append([]byte(options.urlPrefix), link...) + } else { + link = []byte(path.Join(options.urlPrefix, string(link))) + } + } + + options.Renderer.Link(out, link, title, content) +} + +var ( + MentionPattern = regexp.MustCompile(`@[0-9a-zA-Z_]{1,}`) + commitPattern = regexp.MustCompile(`(\s|^)https?.*commit/[0-9a-zA-Z]+(#+[0-9a-zA-Z-]*)?`) + issueFullPattern = regexp.MustCompile(`(\s|^)https?.*issues/[0-9]+(#+[0-9a-zA-Z-]*)?`) + issueIndexPattern = regexp.MustCompile(`#[0-9]+`) +) + +func RenderSpecialLink(rawBytes []byte, urlPrefix string) []byte { + ms := MentionPattern.FindAll(rawBytes, -1) + for _, m := range ms { + rawBytes = bytes.Replace(rawBytes, m, + []byte(fmt.Sprintf(`<a href="/user/%s">%s</a>`, m[1:], m)), -1) + } + ms = commitPattern.FindAll(rawBytes, -1) + for _, m := range ms { + m = bytes.TrimSpace(m) + i := strings.Index(string(m), "commit/") + j := strings.Index(string(m), "#") + if j == -1 { + j = len(m) + } + rawBytes = bytes.Replace(rawBytes, m, []byte(fmt.Sprintf( + ` <code><a href="%s">%s</a></code>`, m, ShortSha(string(m[i+7:j])))), -1) + } + ms = issueFullPattern.FindAll(rawBytes, -1) + for _, m := range ms { + m = bytes.TrimSpace(m) + i := strings.Index(string(m), "issues/") + j := strings.Index(string(m), "#") + if j == -1 { + j = len(m) + } + rawBytes = bytes.Replace(rawBytes, m, []byte(fmt.Sprintf( + ` <a href="%s">#%s</a>`, m, ShortSha(string(m[i+7:j])))), -1) + } + ms = issueIndexPattern.FindAll(rawBytes, -1) + for _, m := range ms { + rawBytes = bytes.Replace(rawBytes, m, []byte(fmt.Sprintf( + `<a href="%s/issues/%s">%s</a>`, urlPrefix, m[1:], m)), -1) + } + return rawBytes +} + +func RenderMarkdown(rawBytes []byte, urlPrefix string) []byte { + body := RenderSpecialLink(rawBytes, urlPrefix) + // fmt.Println(string(body)) + htmlFlags := 0 + // htmlFlags |= gfm.HTML_USE_XHTML + // htmlFlags |= gfm.HTML_USE_SMARTYPANTS + // htmlFlags |= gfm.HTML_SMARTYPANTS_FRACTIONS + // htmlFlags |= gfm.HTML_SMARTYPANTS_LATEX_DASHES + // htmlFlags |= gfm.HTML_SKIP_HTML + htmlFlags |= gfm.HTML_SKIP_STYLE + htmlFlags |= gfm.HTML_SKIP_SCRIPT + htmlFlags |= gfm.HTML_GITHUB_BLOCKCODE + htmlFlags |= gfm.HTML_OMIT_CONTENTS + // htmlFlags |= gfm.HTML_COMPLETE_PAGE + renderer := &CustomRender{ + Renderer: gfm.HtmlRenderer(htmlFlags, "", ""), + urlPrefix: urlPrefix, + } + + // set up the parser + extensions := 0 + extensions |= gfm.EXTENSION_NO_INTRA_EMPHASIS + extensions |= gfm.EXTENSION_TABLES + extensions |= gfm.EXTENSION_FENCED_CODE + extensions |= gfm.EXTENSION_AUTOLINK + extensions |= gfm.EXTENSION_STRIKETHROUGH + extensions |= gfm.EXTENSION_HARD_LINE_BREAK + extensions |= gfm.EXTENSION_SPACE_HEADERS + extensions |= gfm.EXTENSION_NO_EMPTY_LINE_BEFORE_BLOCK + + body = gfm.Markdown(body, renderer, extensions) + // fmt.Println(string(body)) + return body +} diff --git a/modules/base/template.go b/modules/base/template.go new file mode 100644 index 00000000..5a42107c --- /dev/null +++ b/modules/base/template.go @@ -0,0 +1,196 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package base + +import ( + "bytes" + "container/list" + "encoding/json" + "fmt" + "html/template" + "strings" + "time" +) + +func Str2html(raw string) template.HTML { + return template.HTML(raw) +} + +func Range(l int) []int { + return make([]int, l) +} + +func List(l *list.List) chan interface{} { + e := l.Front() + c := make(chan interface{}) + go func() { + for e != nil { + c <- e.Value + e = e.Next() + } + close(c) + }() + return c +} + +func ShortSha(sha1 string) string { + if len(sha1) == 40 { + return sha1[:10] + } + return sha1 +} + +var mailDomains = map[string]string{ + "gmail.com": "gmail.com", +} + +var TemplateFuncs template.FuncMap = map[string]interface{}{ + "AppName": func() string { + return AppName + }, + "AppVer": func() string { + return AppVer + }, + "AppDomain": func() string { + return Domain + }, + "IsProdMode": func() bool { + return IsProdMode + }, + "LoadTimes": func(startTime time.Time) string { + return fmt.Sprint(time.Since(startTime).Nanoseconds()/1e6) + "ms" + }, + "AvatarLink": AvatarLink, + "str2html": Str2html, + "TimeSince": TimeSince, + "FileSize": FileSize, + "Subtract": Subtract, + "ActionIcon": ActionIcon, + "ActionDesc": ActionDesc, + "DateFormat": DateFormat, + "List": List, + "Mail2Domain": func(mail string) string { + if !strings.Contains(mail, "@") { + return "try.gogits.org" + } + + suffix := strings.SplitN(mail, "@", 2)[1] + domain, ok := mailDomains[suffix] + if !ok { + return "mail." + suffix + } + return domain + }, + "SubStr": func(str string, start, length int) string { + return str[start : start+length] + }, + "DiffTypeToStr": DiffTypeToStr, + "DiffLineTypeToStr": DiffLineTypeToStr, + "ShortSha": ShortSha, +} + +type Actioner interface { + GetOpType() int + GetActUserName() string + GetActEmail() string + GetRepoName() string + GetBranch() string + GetContent() string +} + +// ActionIcon accepts a int that represents action operation type +// and returns a icon class name. +func ActionIcon(opType int) string { + switch opType { + case 1: // Create repository. + return "plus-circle" + case 5: // Commit repository. + return "arrow-circle-o-right" + case 6: // Create issue. + return "exclamation-circle" + case 8: // Transfer repository. + return "share" + default: + return "invalid type" + } +} + +const ( + TPL_CREATE_REPO = `<a href="/user/%s">%s</a> created repository <a href="/%s">%s</a>` + TPL_COMMIT_REPO = `<a href="/user/%s">%s</a> pushed to <a href="/%s/src/%s">%s</a> at <a href="/%s">%s</a>%s` + TPL_COMMIT_REPO_LI = `<div><img src="%s?s=16" alt="user-avatar"/> <a href="/%s/commit/%s">%s</a> %s</div>` + TPL_CREATE_ISSUE = `<a href="/user/%s">%s</a> opened issue <a href="/%s/issues/%s">%s#%s</a> +<div><img src="%s?s=16" alt="user-avatar"/> %s</div>` + TPL_TRANSFER_REPO = `<a href="/user/%s">%s</a> transfered repository <code>%s</code> to <a href="/%s">%s</a>` +) + +type PushCommit struct { + Sha1 string + Message string + AuthorEmail string + AuthorName string +} + +type PushCommits struct { + Len int + Commits []*PushCommit +} + +// ActionDesc accepts int that represents action operation type +// and returns the description. +func ActionDesc(act Actioner) string { + actUserName := act.GetActUserName() + email := act.GetActEmail() + repoName := act.GetRepoName() + repoLink := actUserName + "/" + repoName + branch := act.GetBranch() + content := act.GetContent() + switch act.GetOpType() { + case 1: // Create repository. + return fmt.Sprintf(TPL_CREATE_REPO, actUserName, actUserName, repoLink, repoName) + case 5: // Commit repository. + var push *PushCommits + if err := json.Unmarshal([]byte(content), &push); err != nil { + return err.Error() + } + buf := bytes.NewBuffer([]byte("\n")) + for _, commit := range push.Commits { + buf.WriteString(fmt.Sprintf(TPL_COMMIT_REPO_LI, AvatarLink(commit.AuthorEmail), repoLink, commit.Sha1, commit.Sha1[:7], commit.Message) + "\n") + } + if push.Len > 3 { + buf.WriteString(fmt.Sprintf(`<div><a href="/%s/%s/commits/%s">%d other commits >></a></div>`, actUserName, repoName, branch, push.Len)) + } + return fmt.Sprintf(TPL_COMMIT_REPO, actUserName, actUserName, repoLink, branch, branch, repoLink, repoLink, + buf.String()) + case 6: // Create issue. + infos := strings.SplitN(content, "|", 2) + return fmt.Sprintf(TPL_CREATE_ISSUE, actUserName, actUserName, repoLink, infos[0], repoLink, infos[0], + AvatarLink(email), infos[1]) + case 8: // Transfer repository. + newRepoLink := content + "/" + repoName + return fmt.Sprintf(TPL_TRANSFER_REPO, actUserName, actUserName, repoLink, newRepoLink, newRepoLink) + default: + return "invalid type" + } +} + +func DiffTypeToStr(diffType int) string { + diffTypes := map[int]string{ + 1: "add", 2: "modify", 3: "del", + } + return diffTypes[diffType] +} + +func DiffLineTypeToStr(diffType int) string { + switch diffType { + case 2: + return "add" + case 3: + return "del" + case 4: + return "tag" + } + return "same" +} diff --git a/modules/base/tool.go b/modules/base/tool.go new file mode 100644 index 00000000..0f06b3e0 --- /dev/null +++ b/modules/base/tool.go @@ -0,0 +1,514 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package base + +import ( + "crypto/hmac" + "crypto/md5" + "crypto/rand" + "crypto/sha1" + "encoding/hex" + "fmt" + "hash" + "math" + "strconv" + "strings" + "time" +) + +// Encode string to md5 hex value +func EncodeMd5(str string) string { + m := md5.New() + m.Write([]byte(str)) + return hex.EncodeToString(m.Sum(nil)) +} + +// GetRandomString generate random string by specify chars. +func GetRandomString(n int, alphabets ...byte) string { + const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + var bytes = make([]byte, n) + rand.Read(bytes) + for i, b := range bytes { + if len(alphabets) == 0 { + bytes[i] = alphanum[b%byte(len(alphanum))] + } else { + bytes[i] = alphabets[b%byte(len(alphabets))] + } + } + return string(bytes) +} + +// http://code.google.com/p/go/source/browse/pbkdf2/pbkdf2.go?repo=crypto +func PBKDF2(password, salt []byte, iter, keyLen int, h func() hash.Hash) []byte { + prf := hmac.New(h, password) + hashLen := prf.Size() + numBlocks := (keyLen + hashLen - 1) / hashLen + + var buf [4]byte + dk := make([]byte, 0, numBlocks*hashLen) + U := make([]byte, hashLen) + for block := 1; block <= numBlocks; block++ { + // N.B.: || means concatenation, ^ means XOR + // for each block T_i = U_1 ^ U_2 ^ ... ^ U_iter + // U_1 = PRF(password, salt || uint(i)) + prf.Reset() + prf.Write(salt) + buf[0] = byte(block >> 24) + buf[1] = byte(block >> 16) + buf[2] = byte(block >> 8) + buf[3] = byte(block) + prf.Write(buf[:4]) + dk = prf.Sum(dk) + T := dk[len(dk)-hashLen:] + copy(U, T) + + // U_n = PRF(password, U_(n-1)) + for n := 2; n <= iter; n++ { + prf.Reset() + prf.Write(U) + U = U[:0] + U = prf.Sum(U) + for x := range U { + T[x] ^= U[x] + } + } + } + return dk[:keyLen] +} + +// verify time limit code +func VerifyTimeLimitCode(data string, minutes int, code string) bool { + if len(code) <= 18 { + return false + } + + // split code + start := code[:12] + lives := code[12:18] + if d, err := StrTo(lives).Int(); err == nil { + minutes = d + } + + // right active code + retCode := CreateTimeLimitCode(data, minutes, start) + if retCode == code && minutes > 0 { + // check time is expired or not + before, _ := DateParse(start, "YmdHi") + now := time.Now() + if before.Add(time.Minute*time.Duration(minutes)).Unix() > now.Unix() { + return true + } + } + + return false +} + +const TimeLimitCodeLength = 12 + 6 + 40 + +// create a time limit code +// code format: 12 length date time string + 6 minutes string + 40 sha1 encoded string +func CreateTimeLimitCode(data string, minutes int, startInf interface{}) string { + format := "YmdHi" + + var start, end time.Time + var startStr, endStr string + + if startInf == nil { + // Use now time create code + start = time.Now() + startStr = DateFormat(start, format) + } else { + // use start string create code + startStr = startInf.(string) + start, _ = DateParse(startStr, format) + startStr = DateFormat(start, format) + } + + end = start.Add(time.Minute * time.Duration(minutes)) + endStr = DateFormat(end, format) + + // create sha1 encode string + sh := sha1.New() + sh.Write([]byte(data + SecretKey + startStr + endStr + ToStr(minutes))) + encoded := hex.EncodeToString(sh.Sum(nil)) + + code := fmt.Sprintf("%s%06d%s", startStr, minutes, encoded) + return code +} + +// AvatarLink returns avatar link by given e-mail. +func AvatarLink(email string) string { + if Service.EnableCacheAvatar { + return "/avatar/" + EncodeMd5(email) + } + return "http://1.gravatar.com/avatar/" + EncodeMd5(email) +} + +// Seconds-based time units +const ( + Minute = 60 + Hour = 60 * Minute + Day = 24 * Hour + Week = 7 * Day + Month = 30 * Day + Year = 12 * Month +) + +func computeTimeDiff(diff int64) (int64, string) { + diffStr := "" + switch { + case diff <= 0: + diff = 0 + diffStr = "now" + case diff < 2: + diff = 0 + diffStr = "1 second" + case diff < 1*Minute: + diffStr = fmt.Sprintf("%d seconds", diff) + diff = 0 + + case diff < 2*Minute: + diff -= 1 * Minute + diffStr = "1 minute" + case diff < 1*Hour: + diffStr = fmt.Sprintf("%d minutes", diff/Minute) + diff -= diff / Minute * Minute + + case diff < 2*Hour: + diff -= 1 * Hour + diffStr = "1 hour" + case diff < 1*Day: + diffStr = fmt.Sprintf("%d hours", diff/Hour) + diff -= diff / Hour * Hour + + case diff < 2*Day: + diff -= 1 * Day + diffStr = "1 day" + case diff < 1*Week: + diffStr = fmt.Sprintf("%d days", diff/Day) + diff -= diff / Day * Day + + case diff < 2*Week: + diff -= 1 * Week + diffStr = "1 week" + case diff < 1*Month: + diffStr = fmt.Sprintf("%d weeks", diff/Week) + diff -= diff / Week * Week + + case diff < 2*Month: + diff -= 1 * Month + diffStr = "1 month" + case diff < 1*Year: + diffStr = fmt.Sprintf("%d months", diff/Month) + diff -= diff / Month * Month + + case diff < 2*Year: + diff -= 1 * Year + diffStr = "1 year" + default: + diffStr = fmt.Sprintf("%d years", diff/Year) + diff = 0 + } + return diff, diffStr +} + +// TimeSincePro calculates the time interval and generate full user-friendly string. +func TimeSincePro(then time.Time) string { + now := time.Now() + diff := now.Unix() - then.Unix() + + if then.After(now) { + return "future" + } + + var timeStr, diffStr string + for { + if diff == 0 { + break + } + + diff, diffStr = computeTimeDiff(diff) + timeStr += ", " + diffStr + } + return strings.TrimPrefix(timeStr, ", ") +} + +// TimeSince calculates the time interval and generate user-friendly string. +func TimeSince(then time.Time) string { + now := time.Now() + + lbl := "ago" + diff := now.Unix() - then.Unix() + if then.After(now) { + lbl = "from now" + diff = then.Unix() - now.Unix() + } + + switch { + case diff <= 0: + return "now" + case diff <= 2: + return fmt.Sprintf("1 second %s", lbl) + case diff < 1*Minute: + return fmt.Sprintf("%d seconds %s", diff, lbl) + + case diff < 2*Minute: + return fmt.Sprintf("1 minute %s", lbl) + case diff < 1*Hour: + return fmt.Sprintf("%d minutes %s", diff/Minute, lbl) + + case diff < 2*Hour: + return fmt.Sprintf("1 hour %s", lbl) + case diff < 1*Day: + return fmt.Sprintf("%d hours %s", diff/Hour, lbl) + + case diff < 2*Day: + return fmt.Sprintf("1 day %s", lbl) + case diff < 1*Week: + return fmt.Sprintf("%d days %s", diff/Day, lbl) + + case diff < 2*Week: + return fmt.Sprintf("1 week %s", lbl) + case diff < 1*Month: + return fmt.Sprintf("%d weeks %s", diff/Week, lbl) + + case diff < 2*Month: + return fmt.Sprintf("1 month %s", lbl) + case diff < 1*Year: + return fmt.Sprintf("%d months %s", diff/Month, lbl) + + case diff < 2*Year: + return fmt.Sprintf("1 year %s", lbl) + default: + return fmt.Sprintf("%d years %s", diff/Year, lbl) + } + return then.String() +} + +const ( + Byte = 1 + KByte = Byte * 1024 + MByte = KByte * 1024 + GByte = MByte * 1024 + TByte = GByte * 1024 + PByte = TByte * 1024 + EByte = PByte * 1024 +) + +var bytesSizeTable = map[string]uint64{ + "b": Byte, + "kb": KByte, + "mb": MByte, + "gb": GByte, + "tb": TByte, + "pb": PByte, + "eb": EByte, +} + +func logn(n, b float64) float64 { + return math.Log(n) / math.Log(b) +} + +func humanateBytes(s uint64, base float64, sizes []string) string { + if s < 10 { + return fmt.Sprintf("%dB", s) + } + e := math.Floor(logn(float64(s), base)) + suffix := sizes[int(e)] + val := float64(s) / math.Pow(base, math.Floor(e)) + f := "%.0f" + if val < 10 { + f = "%.1f" + } + + return fmt.Sprintf(f+"%s", val, suffix) +} + +// FileSize calculates the file size and generate user-friendly string. +func FileSize(s int64) string { + sizes := []string{"B", "KB", "MB", "GB", "TB", "PB", "EB"} + return humanateBytes(uint64(s), 1024, sizes) +} + +// Subtract deals with subtraction of all types of number. +func Subtract(left interface{}, right interface{}) interface{} { + var rleft, rright int64 + var fleft, fright float64 + var isInt bool = true + switch left.(type) { + case int: + rleft = int64(left.(int)) + case int8: + rleft = int64(left.(int8)) + case int16: + rleft = int64(left.(int16)) + case int32: + rleft = int64(left.(int32)) + case int64: + rleft = left.(int64) + case float32: + fleft = float64(left.(float32)) + isInt = false + case float64: + fleft = left.(float64) + isInt = false + } + + switch right.(type) { + case int: + rright = int64(right.(int)) + case int8: + rright = int64(right.(int8)) + case int16: + rright = int64(right.(int16)) + case int32: + rright = int64(right.(int32)) + case int64: + rright = right.(int64) + case float32: + fright = float64(left.(float32)) + isInt = false + case float64: + fleft = left.(float64) + isInt = false + } + + if isInt { + return rleft - rright + } else { + return fleft + float64(rleft) - (fright + float64(rright)) + } +} + +// DateFormat pattern rules. +var datePatterns = []string{ + // year + "Y", "2006", // A full numeric representation of a year, 4 digits Examples: 1999 or 2003 + "y", "06", //A two digit representation of a year Examples: 99 or 03 + + // month + "m", "01", // Numeric representation of a month, with leading zeros 01 through 12 + "n", "1", // Numeric representation of a month, without leading zeros 1 through 12 + "M", "Jan", // A short textual representation of a month, three letters Jan through Dec + "F", "January", // A full textual representation of a month, such as January or March January through December + + // day + "d", "02", // Day of the month, 2 digits with leading zeros 01 to 31 + "j", "2", // Day of the month without leading zeros 1 to 31 + + // week + "D", "Mon", // A textual representation of a day, three letters Mon through Sun + "l", "Monday", // A full textual representation of the day of the week Sunday through Saturday + + // time + "g", "3", // 12-hour format of an hour without leading zeros 1 through 12 + "G", "15", // 24-hour format of an hour without leading zeros 0 through 23 + "h", "03", // 12-hour format of an hour with leading zeros 01 through 12 + "H", "15", // 24-hour format of an hour with leading zeros 00 through 23 + + "a", "pm", // Lowercase Ante meridiem and Post meridiem am or pm + "A", "PM", // Uppercase Ante meridiem and Post meridiem AM or PM + + "i", "04", // Minutes with leading zeros 00 to 59 + "s", "05", // Seconds, with leading zeros 00 through 59 + + // time zone + "T", "MST", + "P", "-07:00", + "O", "-0700", + + // RFC 2822 + "r", time.RFC1123Z, +} + +// Parse Date use PHP time format. +func DateParse(dateString, format string) (time.Time, error) { + replacer := strings.NewReplacer(datePatterns...) + format = replacer.Replace(format) + return time.ParseInLocation(format, dateString, time.Local) +} + +// Date takes a PHP like date func to Go's time format. +func DateFormat(t time.Time, format string) string { + replacer := strings.NewReplacer(datePatterns...) + format = replacer.Replace(format) + return t.Format(format) +} + +// convert string to specify type + +type StrTo string + +func (f StrTo) Exist() bool { + return string(f) != string(0x1E) +} + +func (f StrTo) Int() (int, error) { + v, err := strconv.ParseInt(f.String(), 10, 32) + return int(v), err +} + +func (f StrTo) Int64() (int64, error) { + v, err := strconv.ParseInt(f.String(), 10, 64) + return int64(v), err +} + +func (f StrTo) String() string { + if f.Exist() { + return string(f) + } + return "" +} + +// convert any type to string +func ToStr(value interface{}, args ...int) (s string) { + switch v := value.(type) { + case bool: + s = strconv.FormatBool(v) + case float32: + s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32)) + case float64: + s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64)) + case int: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int8: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int16: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int32: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int64: + s = strconv.FormatInt(v, argInt(args).Get(0, 10)) + case uint: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint8: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint16: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint32: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint64: + s = strconv.FormatUint(v, argInt(args).Get(0, 10)) + case string: + s = v + case []byte: + s = string(v) + default: + s = fmt.Sprintf("%v", v) + } + return s +} + +type argInt []int + +func (a argInt) Get(i int, args ...int) (r int) { + if i >= 0 && i < len(a) { + r = a[i] + } + if len(args) > 0 { + r = args[0] + } + return +} diff --git a/modules/log/log.go b/modules/log/log.go new file mode 100644 index 00000000..f21897b9 --- /dev/null +++ b/modules/log/log.go @@ -0,0 +1,49 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +// Package log is a wrapper of logs for short calling name. +package log + +import ( + "github.com/gogits/logs" +) + +var ( + logger *logs.BeeLogger + Mode, Config string +) + +func init() { + NewLogger(0, "console", `{"level": 0}`) +} + +func NewLogger(bufLen int64, mode, config string) { + Mode, Config = mode, config + logger = logs.NewLogger(bufLen) + logger.SetLogger(mode, config) +} + +func Trace(format string, v ...interface{}) { + logger.Trace(format, v...) +} + +func Debug(format string, v ...interface{}) { + logger.Debug(format, v...) +} + +func Info(format string, v ...interface{}) { + logger.Info(format, v...) +} + +func Error(format string, v ...interface{}) { + logger.Error(format, v...) +} + +func Warn(format string, v ...interface{}) { + logger.Warn(format, v...) +} + +func Critical(format string, v ...interface{}) { + logger.Critical(format, v...) +} diff --git a/modules/mailer/mail.go b/modules/mailer/mail.go new file mode 100644 index 00000000..d2bf1310 --- /dev/null +++ b/modules/mailer/mail.go @@ -0,0 +1,162 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package mailer + +import ( + "encoding/hex" + "errors" + "fmt" + + "github.com/gogits/gogs/models" + "github.com/gogits/gogs/modules/base" + "github.com/gogits/gogs/modules/log" + "github.com/gogits/gogs/modules/middleware" +) + +// Create New mail message use MailFrom and MailUser +func NewMailMessageFrom(To []string, from, subject, body string) Message { + msg := NewHtmlMessage(To, from, subject, body) + msg.User = base.MailService.User + return msg +} + +// Create New mail message use MailFrom and MailUser +func NewMailMessage(To []string, subject, body string) Message { + return NewMailMessageFrom(To, base.MailService.User, subject, body) +} + +func GetMailTmplData(user *models.User) map[interface{}]interface{} { + data := make(map[interface{}]interface{}, 10) + data["AppName"] = base.AppName + data["AppVer"] = base.AppVer + data["AppUrl"] = base.AppUrl + data["AppLogo"] = base.AppLogo + data["ActiveCodeLives"] = base.Service.ActiveCodeLives / 60 + data["ResetPwdCodeLives"] = base.Service.ResetPwdCodeLives / 60 + if user != nil { + data["User"] = user + } + return data +} + +// create a time limit code for user active +func CreateUserActiveCode(user *models.User, startInf interface{}) string { + minutes := base.Service.ActiveCodeLives + data := base.ToStr(user.Id) + user.Email + user.LowerName + user.Passwd + user.Rands + code := base.CreateTimeLimitCode(data, minutes, startInf) + + // add tail hex username + code += hex.EncodeToString([]byte(user.LowerName)) + return code +} + +// Send user register mail with active code +func SendRegisterMail(r *middleware.Render, user *models.User) { + code := CreateUserActiveCode(user, nil) + subject := "Register success, Welcome" + + data := GetMailTmplData(user) + data["Code"] = code + body, err := r.HTMLString("mail/auth/register_success", data) + if err != nil { + log.Error("mail.SendRegisterMail(fail to render): %v", err) + return + } + + msg := NewMailMessage([]string{user.Email}, subject, body) + msg.Info = fmt.Sprintf("UID: %d, send register mail", user.Id) + + SendAsync(&msg) +} + +// Send email verify active email. +func SendActiveMail(r *middleware.Render, user *models.User) { + code := CreateUserActiveCode(user, nil) + + subject := "Verify your e-mail address" + + data := GetMailTmplData(user) + data["Code"] = code + body, err := r.HTMLString("mail/auth/active_email", data) + if err != nil { + log.Error("mail.SendActiveMail(fail to render): %v", err) + return + } + + msg := NewMailMessage([]string{user.Email}, subject, body) + msg.Info = fmt.Sprintf("UID: %d, send active mail", user.Id) + + SendAsync(&msg) +} + +// Send reset password email. +func SendResetPasswdMail(r *middleware.Render, user *models.User) { + code := CreateUserActiveCode(user, nil) + + subject := "Reset your password" + + data := GetMailTmplData(user) + data["Code"] = code + body, err := r.HTMLString("mail/auth/reset_passwd", data) + if err != nil { + log.Error("mail.SendResetPasswdMail(fail to render): %v", err) + return + } + + msg := NewMailMessage([]string{user.Email}, subject, body) + msg.Info = fmt.Sprintf("UID: %d, send reset password email", user.Id) + + SendAsync(&msg) +} + +// SendIssueNotifyMail sends mail notification of all watchers of repository. +func SendIssueNotifyMail(user, owner *models.User, repo *models.Repository, issue *models.Issue) ([]string, error) { + watches, err := models.GetWatches(repo.Id) + if err != nil { + return nil, errors.New("mail.NotifyWatchers(get watches): " + err.Error()) + } + + tos := make([]string, 0, len(watches)) + for i := range watches { + uid := watches[i].UserId + if user.Id == uid { + continue + } + u, err := models.GetUserById(uid) + if err != nil { + return nil, errors.New("mail.NotifyWatchers(get user): " + err.Error()) + } + tos = append(tos, u.Email) + } + + if len(tos) == 0 { + return tos, nil + } + + subject := fmt.Sprintf("[%s] %s", repo.Name, issue.Name) + content := fmt.Sprintf("%s<br>-<br> <a href=\"%s%s/%s/issues/%d\">View it on Gogs</a>.", + base.RenderSpecialLink([]byte(issue.Content), owner.Name+"/"+repo.Name), + base.AppUrl, owner.Name, repo.Name, issue.Index) + msg := NewMailMessageFrom(tos, user.Name, subject, content) + msg.Info = fmt.Sprintf("Subject: %s, send issue notify emails", subject) + SendAsync(&msg) + return tos, nil +} + +// SendIssueMentionMail sends mail notification for who are mentioned in issue. +func SendIssueMentionMail(user, owner *models.User, repo *models.Repository, issue *models.Issue, tos []string) error { + if len(tos) == 0 { + return nil + } + + issueLink := fmt.Sprintf("%s%s/%s/issues/%d", base.AppUrl, owner.Name, repo.Name, issue.Index) + body := fmt.Sprintf(`%s mentioned you.`) + subject := fmt.Sprintf("[%s] %s", repo.Name, issue.Name) + content := fmt.Sprintf("%s<br>-<br> <a href=\"%s\">View it on Gogs</a>.", body, issueLink) + msg := NewMailMessageFrom(tos, user.Name, subject, content) + msg.Info = fmt.Sprintf("Subject: %s, send issue mention emails", subject) + SendAsync(&msg) + return nil +} diff --git a/modules/mailer/mailer.go b/modules/mailer/mailer.go new file mode 100644 index 00000000..63861d87 --- /dev/null +++ b/modules/mailer/mailer.go @@ -0,0 +1,126 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package mailer + +import ( + "fmt" + "net/smtp" + "strings" + + "github.com/gogits/gogs/modules/base" + "github.com/gogits/gogs/modules/log" +) + +type Message struct { + To []string + From string + Subject string + Body string + User string + Type string + Massive bool + Info string +} + +// create mail content +func (m Message) Content() string { + // set mail type + contentType := "text/plain; charset=UTF-8" + if m.Type == "html" { + contentType = "text/html; charset=UTF-8" + } + + // create mail content + content := "From: " + m.From + "<" + m.User + + ">\r\nSubject: " + m.Subject + "\r\nContent-Type: " + contentType + "\r\n\r\n" + m.Body + return content +} + +var mailQueue chan *Message + +func NewMailerContext() { + mailQueue = make(chan *Message, base.Cfg.MustInt("mailer", "SEND_BUFFER_LEN", 10)) + go processMailQueue() +} + +func processMailQueue() { + for { + select { + case msg := <-mailQueue: + num, err := Send(msg) + tos := strings.Join(msg.To, "; ") + info := "" + if err != nil { + if len(msg.Info) > 0 { + info = ", info: " + msg.Info + } + log.Error(fmt.Sprintf("Async sent email %d succeed, not send emails: %s%s err: %s", num, tos, info, err)) + } else { + log.Trace(fmt.Sprintf("Async sent email %d succeed, sent emails: %s%s", num, tos, info)) + } + } + } +} + +// Direct Send mail message +func Send(msg *Message) (int, error) { + log.Trace("Sending mails to: %s", strings.Join(msg.To, "; ")) + host := strings.Split(base.MailService.Host, ":") + + // get message body + content := msg.Content() + + auth := smtp.PlainAuth("", base.MailService.User, base.MailService.Passwd, host[0]) + + if len(msg.To) == 0 { + return 0, fmt.Errorf("empty receive emails") + } + + if len(msg.Body) == 0 { + return 0, fmt.Errorf("empty email body") + } + + if msg.Massive { + // send mail to multiple emails one by one + num := 0 + for _, to := range msg.To { + body := []byte("To: " + to + "\r\n" + content) + err := smtp.SendMail(base.MailService.Host, auth, msg.From, []string{to}, body) + if err != nil { + return num, err + } + num++ + } + return num, nil + } else { + body := []byte("To: " + strings.Join(msg.To, ";") + "\r\n" + content) + + // send to multiple emails in one message + err := smtp.SendMail(base.MailService.Host, auth, msg.From, msg.To, body) + if err != nil { + return 0, err + } else { + return 1, nil + } + } +} + +// Async Send mail message +func SendAsync(msg *Message) { + go func() { + mailQueue <- msg + }() +} + +// Create html mail message +func NewHtmlMessage(To []string, From, Subject, Body string) Message { + return Message{ + To: To, + From: From, + Subject: Subject, + Body: Body, + Type: "html", + } +} diff --git a/modules/middleware/auth.go b/modules/middleware/auth.go new file mode 100644 index 00000000..bde3be72 --- /dev/null +++ b/modules/middleware/auth.go @@ -0,0 +1,63 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package middleware + +import ( + "net/url" + + "github.com/go-martini/martini" + + "github.com/gogits/gogs/modules/base" +) + +type ToggleOptions struct { + SignInRequire bool + SignOutRequire bool + AdminRequire bool + DisableCsrf bool +} + +func Toggle(options *ToggleOptions) martini.Handler { + return func(ctx *Context) { + if !base.InstallLock { + ctx.Redirect("/install") + return + } + + if options.SignOutRequire && ctx.IsSigned && ctx.Req.RequestURI != "/" { + ctx.Redirect("/") + return + } + + if !options.DisableCsrf { + if ctx.Req.Method == "POST" { + if !ctx.CsrfTokenValid() { + ctx.Error(403, "CSRF token does not match") + return + } + } + } + + if options.SignInRequire { + if !ctx.IsSigned { + ctx.SetCookie("redirect_to", "/"+url.QueryEscape(ctx.Req.RequestURI)) + ctx.Redirect("/user/login") + return + } else if !ctx.User.IsActive && base.Service.RegisterEmailConfirm { + ctx.Data["Title"] = "Activate Your Account" + ctx.HTML(200, "user/active") + return + } + } + + if options.AdminRequire { + if !ctx.User.IsAdmin { + ctx.Error(403) + return + } + ctx.Data["PageIsAdmin"] = true + } + } +} diff --git a/modules/middleware/context.go b/modules/middleware/context.go new file mode 100644 index 00000000..8129b13b --- /dev/null +++ b/modules/middleware/context.go @@ -0,0 +1,286 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package middleware + +import ( + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "fmt" + "html/template" + "net/http" + "strconv" + "strings" + "time" + + "github.com/go-martini/martini" + + "github.com/gogits/cache" + "github.com/gogits/git" + "github.com/gogits/session" + + "github.com/gogits/gogs/models" + "github.com/gogits/gogs/modules/auth" + "github.com/gogits/gogs/modules/base" + "github.com/gogits/gogs/modules/log" +) + +// Context represents context of a request. +type Context struct { + *Render + c martini.Context + p martini.Params + Req *http.Request + Res http.ResponseWriter + Session session.SessionStore + Cache cache.Cache + User *models.User + IsSigned bool + + csrfToken string + + Repo struct { + IsOwner bool + IsWatching bool + IsBranch bool + IsTag bool + IsCommit bool + Repository *models.Repository + Owner *models.User + Commit *git.Commit + GitRepo *git.Repository + BranchName string + CommitId string + RepoLink string + CloneLink struct { + SSH string + HTTPS string + Git string + } + } +} + +// Query querys form parameter. +func (ctx *Context) Query(name string) string { + ctx.Req.ParseForm() + return ctx.Req.Form.Get(name) +} + +// func (ctx *Context) Param(name string) string { +// return ctx.p[name] +// } + +// HasError returns true if error occurs in form validation. +func (ctx *Context) HasError() bool { + hasErr, ok := ctx.Data["HasError"] + if !ok { + return false + } + return hasErr.(bool) +} + +// HTML calls render.HTML underlying but reduce one argument. +func (ctx *Context) HTML(status int, name string, htmlOpt ...HTMLOptions) { + ctx.Render.HTML(status, name, ctx.Data, htmlOpt...) +} + +// RenderWithErr used for page has form validation but need to prompt error to users. +func (ctx *Context) RenderWithErr(msg, tpl string, form auth.Form) { + ctx.Data["HasError"] = true + ctx.Data["ErrorMsg"] = msg + if form != nil { + auth.AssignForm(form, ctx.Data) + } + ctx.HTML(200, tpl) +} + +// Handle handles and logs error by given status. +func (ctx *Context) Handle(status int, title string, err error) { + log.Error("%s: %v", title, err) + if martini.Dev == martini.Prod { + ctx.HTML(500, "status/500") + return + } + + ctx.Data["ErrorMsg"] = err + ctx.HTML(status, fmt.Sprintf("status/%d", status)) +} + +func (ctx *Context) Debug(msg string, args ...interface{}) { + log.Debug(msg, args...) +} + +func (ctx *Context) GetCookie(name string) string { + cookie, err := ctx.Req.Cookie(name) + if err != nil { + return "" + } + return cookie.Value +} + +func (ctx *Context) SetCookie(name string, value string, others ...interface{}) { + cookie := http.Cookie{} + cookie.Name = name + cookie.Value = value + + if len(others) > 0 { + switch v := others[0].(type) { + case int: + cookie.MaxAge = v + case int64: + cookie.MaxAge = int(v) + case int32: + cookie.MaxAge = int(v) + } + } + + // default "/" + if len(others) > 1 { + if v, ok := others[1].(string); ok && len(v) > 0 { + cookie.Path = v + } + } else { + cookie.Path = "/" + } + + // default empty + if len(others) > 2 { + if v, ok := others[2].(string); ok && len(v) > 0 { + cookie.Domain = v + } + } + + // default empty + if len(others) > 3 { + switch v := others[3].(type) { + case bool: + cookie.Secure = v + default: + if others[3] != nil { + cookie.Secure = true + } + } + } + + // default false. for session cookie default true + if len(others) > 4 { + if v, ok := others[4].(bool); ok && v { + cookie.HttpOnly = true + } + } + + ctx.Res.Header().Add("Set-Cookie", cookie.String()) +} + +// Get secure cookie from request by a given key. +func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) { + val := ctx.GetCookie(key) + if val == "" { + return "", false + } + + parts := strings.SplitN(val, "|", 3) + + if len(parts) != 3 { + return "", false + } + + vs := parts[0] + timestamp := parts[1] + sig := parts[2] + + h := hmac.New(sha1.New, []byte(Secret)) + fmt.Fprintf(h, "%s%s", vs, timestamp) + + if fmt.Sprintf("%02x", h.Sum(nil)) != sig { + return "", false + } + res, _ := base64.URLEncoding.DecodeString(vs) + return string(res), true +} + +// Set Secure cookie for response. +func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) { + vs := base64.URLEncoding.EncodeToString([]byte(value)) + timestamp := strconv.FormatInt(time.Now().UnixNano(), 10) + h := hmac.New(sha1.New, []byte(Secret)) + fmt.Fprintf(h, "%s%s", vs, timestamp) + sig := fmt.Sprintf("%02x", h.Sum(nil)) + cookie := strings.Join([]string{vs, timestamp, sig}, "|") + ctx.SetCookie(name, cookie, others...) +} + +func (ctx *Context) CsrfToken() string { + if len(ctx.csrfToken) > 0 { + return ctx.csrfToken + } + + token := ctx.GetCookie("_csrf") + if len(token) == 0 { + token = base.GetRandomString(30) + ctx.SetCookie("_csrf", token) + } + ctx.csrfToken = token + return token +} + +func (ctx *Context) CsrfTokenValid() bool { + token := ctx.Query("_csrf") + if token == "" { + token = ctx.Req.Header.Get("X-Csrf-Token") + } + if token == "" { + return false + } else if ctx.csrfToken != token { + return false + } + return true +} + +// InitContext initializes a classic context for a request. +func InitContext() martini.Handler { + return func(res http.ResponseWriter, r *http.Request, c martini.Context, rd *Render) { + + ctx := &Context{ + c: c, + // p: p, + Req: r, + Res: res, + Cache: base.Cache, + Render: rd, + } + + ctx.Data["PageStartTime"] = time.Now() + + // start session + ctx.Session = base.SessionManager.SessionStart(res, r) + rw := res.(martini.ResponseWriter) + rw.Before(func(martini.ResponseWriter) { + ctx.Session.SessionRelease(res) + }) + + // Get user from session if logined. + user := auth.SignedInUser(ctx.Session) + ctx.User = user + ctx.IsSigned = user != nil + + ctx.Data["IsSigned"] = ctx.IsSigned + + if user != nil { + ctx.Data["SignedUser"] = user + ctx.Data["SignedUserId"] = user.Id + ctx.Data["SignedUserName"] = user.Name + ctx.Data["IsAdmin"] = ctx.User.IsAdmin + } + + // get or create csrf token + ctx.Data["CsrfToken"] = ctx.CsrfToken() + ctx.Data["CsrfTokenHtml"] = template.HTML(`<input type="hidden" name="_csrf" value="` + ctx.csrfToken + `">`) + + c.Map(ctx) + + c.Next() + } +} diff --git a/modules/middleware/logger.go b/modules/middleware/logger.go new file mode 100644 index 00000000..fc8e1a81 --- /dev/null +++ b/modules/middleware/logger.go @@ -0,0 +1,46 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package middleware + +import ( + "fmt" + "log" + "net/http" + "runtime" + "time" + + "github.com/go-martini/martini" +) + +var isWindows bool + +func init() { + isWindows = runtime.GOOS == "windows" +} + +func Logger() martini.Handler { + return func(res http.ResponseWriter, req *http.Request, ctx martini.Context, log *log.Logger) { + start := time.Now() + log.Printf("Started %s %s", req.Method, req.URL.Path) + + rw := res.(martini.ResponseWriter) + ctx.Next() + + content := fmt.Sprintf("Completed %v %s in %v", rw.Status(), http.StatusText(rw.Status()), time.Since(start)) + if !isWindows { + switch rw.Status() { + case 200: + content = fmt.Sprintf("\033[1;32m%s\033[0m", content) + case 304: + content = fmt.Sprintf("\033[1;33m%s\033[0m", content) + case 404: + content = fmt.Sprintf("\033[1;31m%s\033[0m", content) + case 500: + content = fmt.Sprintf("\033[1;36m%s\033[0m", content) + } + } + log.Println(content) + } +} diff --git a/modules/middleware/render.go b/modules/middleware/render.go new file mode 100644 index 00000000..66289988 --- /dev/null +++ b/modules/middleware/render.go @@ -0,0 +1,289 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +// foked from https://github.com/martini-contrib/render/blob/master/render.go +package middleware + +import ( + "bytes" + "encoding/json" + "fmt" + "html/template" + "io" + "io/ioutil" + "net/http" + "os" + "path/filepath" + "time" + + "github.com/go-martini/martini" + + "github.com/gogits/gogs/modules/base" +) + +const ( + ContentType = "Content-Type" + ContentLength = "Content-Length" + ContentJSON = "application/json" + ContentHTML = "text/html" + ContentXHTML = "application/xhtml+xml" + defaultCharset = "UTF-8" +) + +var helperFuncs = template.FuncMap{ + "yield": func() (string, error) { + return "", fmt.Errorf("yield called with no layout defined") + }, +} + +type Delims struct { + Left string + + Right string +} + +type RenderOptions struct { + Directory string + + Layout string + + Extensions []string + + Funcs []template.FuncMap + + Delims Delims + + Charset string + + IndentJSON bool + + HTMLContentType string +} + +type HTMLOptions struct { + Layout string +} + +func Renderer(options ...RenderOptions) martini.Handler { + opt := prepareOptions(options) + cs := prepareCharset(opt.Charset) + t := compile(opt) + return func(res http.ResponseWriter, req *http.Request, c martini.Context) { + var tc *template.Template + if martini.Env == martini.Dev { + + tc = compile(opt) + } else { + + tc, _ = t.Clone() + } + + rd := &Render{res, req, tc, opt, cs, base.TmplData{}, time.Time{}} + + rd.Data["TmplLoadTimes"] = func() string { + if rd.startTime.IsZero() { + return "" + } + return fmt.Sprint(time.Since(rd.startTime).Nanoseconds()/1e6) + "ms" + } + + c.Map(rd.Data) + c.Map(rd) + } +} + +func prepareCharset(charset string) string { + if len(charset) != 0 { + return "; charset=" + charset + } + + return "; charset=" + defaultCharset +} + +func prepareOptions(options []RenderOptions) RenderOptions { + var opt RenderOptions + if len(options) > 0 { + opt = options[0] + } + + if len(opt.Directory) == 0 { + opt.Directory = "templates" + } + if len(opt.Extensions) == 0 { + opt.Extensions = []string{".tmpl"} + } + if len(opt.HTMLContentType) == 0 { + opt.HTMLContentType = ContentHTML + } + + return opt +} + +func compile(options RenderOptions) *template.Template { + dir := options.Directory + t := template.New(dir) + t.Delims(options.Delims.Left, options.Delims.Right) + + template.Must(t.Parse("Martini")) + + filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + r, err := filepath.Rel(dir, path) + if err != nil { + return err + } + + ext := filepath.Ext(r) + for _, extension := range options.Extensions { + if ext == extension { + + buf, err := ioutil.ReadFile(path) + if err != nil { + panic(err) + } + + name := (r[0 : len(r)-len(ext)]) + tmpl := t.New(filepath.ToSlash(name)) + + for _, funcs := range options.Funcs { + tmpl = tmpl.Funcs(funcs) + } + + template.Must(tmpl.Funcs(helperFuncs).Parse(string(buf))) + break + } + } + + return nil + }) + + return t +} + +type Render struct { + http.ResponseWriter + req *http.Request + t *template.Template + opt RenderOptions + compiledCharset string + + Data base.TmplData + + startTime time.Time +} + +func (r *Render) JSON(status int, v interface{}) { + var result []byte + var err error + if r.opt.IndentJSON { + result, err = json.MarshalIndent(v, "", " ") + } else { + result, err = json.Marshal(v) + } + if err != nil { + http.Error(r, err.Error(), 500) + return + } + + r.Header().Set(ContentType, ContentJSON+r.compiledCharset) + r.WriteHeader(status) + r.Write(result) +} + +func (r *Render) JSONString(v interface{}) (string, error) { + var result []byte + var err error + if r.opt.IndentJSON { + result, err = json.MarshalIndent(v, "", " ") + } else { + result, err = json.Marshal(v) + } + if err != nil { + return "", err + } + return string(result), nil +} + +func (r *Render) renderBytes(name string, binding interface{}, htmlOpt ...HTMLOptions) (*bytes.Buffer, error) { + opt := r.prepareHTMLOptions(htmlOpt) + + if len(opt.Layout) > 0 { + r.addYield(name, binding) + name = opt.Layout + } + + out, err := r.execute(name, binding) + if err != nil { + return nil, err + } + + return out, nil +} + +func (r *Render) HTML(status int, name string, binding interface{}, htmlOpt ...HTMLOptions) { + r.startTime = time.Now() + + out, err := r.renderBytes(name, binding, htmlOpt...) + if err != nil { + http.Error(r, err.Error(), http.StatusInternalServerError) + return + } + + r.Header().Set(ContentType, r.opt.HTMLContentType+r.compiledCharset) + r.WriteHeader(status) + io.Copy(r, out) +} + +func (r *Render) HTMLString(name string, binding interface{}, htmlOpt ...HTMLOptions) (string, error) { + if out, err := r.renderBytes(name, binding, htmlOpt...); err != nil { + return "", err + } else { + return out.String(), nil + } +} + +func (r *Render) Error(status int, message ...string) { + r.WriteHeader(status) + if len(message) > 0 { + r.Write([]byte(message[0])) + } +} + +func (r *Render) Redirect(location string, status ...int) { + code := http.StatusFound + if len(status) == 1 { + code = status[0] + } + + http.Redirect(r, r.req, location, code) +} + +func (r *Render) Template() *template.Template { + return r.t +} + +func (r *Render) execute(name string, binding interface{}) (*bytes.Buffer, error) { + buf := new(bytes.Buffer) + return buf, r.t.ExecuteTemplate(buf, name, binding) +} + +func (r *Render) addYield(name string, binding interface{}) { + funcs := template.FuncMap{ + "yield": func() (template.HTML, error) { + buf, err := r.execute(name, binding) + + return template.HTML(buf.String()), err + }, + } + r.t.Funcs(funcs) +} + +func (r *Render) prepareHTMLOptions(htmlOpt []HTMLOptions) HTMLOptions { + if len(htmlOpt) > 0 { + return htmlOpt[0] + } + + return HTMLOptions{ + Layout: r.opt.Layout, + } +} diff --git a/modules/middleware/repo.go b/modules/middleware/repo.go new file mode 100644 index 00000000..2139742c --- /dev/null +++ b/modules/middleware/repo.go @@ -0,0 +1,163 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package middleware + +import ( + "errors" + "fmt" + "strings" + + "github.com/go-martini/martini" + + "github.com/gogits/git" + + "github.com/gogits/gogs/models" + "github.com/gogits/gogs/modules/base" +) + +func RepoAssignment(redirect bool, args ...bool) martini.Handler { + return func(ctx *Context, params martini.Params) { + // valid brachname + var validBranch bool + // display bare quick start if it is a bare repo + var displayBare bool + + if len(args) >= 1 { + validBranch = args[0] + } + + if len(args) >= 2 { + displayBare = args[1] + } + + var ( + user *models.User + err error + ) + + userName := params["username"] + repoName := params["reponame"] + branchName := params["branchname"] + + // get repository owner + ctx.Repo.IsOwner = ctx.IsSigned && ctx.User.LowerName == strings.ToLower(userName) + + if !ctx.Repo.IsOwner { + user, err = models.GetUserByName(params["username"]) + if err != nil { + if redirect { + ctx.Redirect("/") + return + } + ctx.Handle(200, "RepoAssignment", err) + return + } + } else { + user = ctx.User + } + + if user == nil { + if redirect { + ctx.Redirect("/") + return + } + ctx.Handle(200, "RepoAssignment", errors.New("invliad user account for single repository")) + return + } + + // get repository + repo, err := models.GetRepositoryByName(user.Id, repoName) + if err != nil { + if err == models.ErrRepoNotExist { + ctx.Handle(404, "RepoAssignment", err) + } else if redirect { + ctx.Redirect("/") + return + } + ctx.Handle(404, "RepoAssignment", err) + return + } + repo.NumOpenIssues = repo.NumIssues - repo.NumClosedIssues + ctx.Repo.Repository = repo + + ctx.Data["IsBareRepo"] = ctx.Repo.Repository.IsBare + + gitRepo, err := git.OpenRepository(models.RepoPath(userName, repoName)) + if err != nil { + ctx.Handle(404, "RepoAssignment Invalid repo "+models.RepoPath(userName, repoName), err) + return + } + ctx.Repo.GitRepo = gitRepo + + ctx.Repo.Owner = user + ctx.Repo.RepoLink = "/" + user.Name + "/" + repo.Name + + ctx.Data["Title"] = user.Name + "/" + repo.Name + ctx.Data["Repository"] = repo + ctx.Data["Owner"] = user + ctx.Data["RepoLink"] = ctx.Repo.RepoLink + ctx.Data["IsRepositoryOwner"] = ctx.Repo.IsOwner + ctx.Data["BranchName"] = "" + + ctx.Repo.CloneLink.SSH = fmt.Sprintf("%s@%s:%s/%s.git", base.RunUser, base.Domain, user.LowerName, repo.LowerName) + ctx.Repo.CloneLink.HTTPS = fmt.Sprintf("%s%s/%s.git", base.AppUrl, user.LowerName, repo.LowerName) + ctx.Data["CloneLink"] = ctx.Repo.CloneLink + + // when repo is bare, not valid branch + if !ctx.Repo.Repository.IsBare && validBranch { + detect: + if len(branchName) > 0 { + // TODO check tag + if models.IsBranchExist(user.Name, repoName, branchName) { + ctx.Repo.IsBranch = true + ctx.Repo.BranchName = branchName + + ctx.Repo.Commit, err = gitRepo.GetCommitOfBranch(branchName) + if err != nil { + ctx.Handle(404, "RepoAssignment invalid branch", nil) + return + } + + ctx.Repo.CommitId = ctx.Repo.Commit.Oid.String() + + } else if len(branchName) == 40 { + ctx.Repo.IsCommit = true + ctx.Repo.CommitId = branchName + ctx.Repo.BranchName = branchName + + ctx.Repo.Commit, err = gitRepo.GetCommit(branchName) + if err != nil { + ctx.Handle(404, "RepoAssignment invalid commit", nil) + return + } + } else { + ctx.Handle(404, "RepoAssignment invalid repo", nil) + return + } + + } else { + branchName = "master" + goto detect + } + + ctx.Data["IsBranch"] = ctx.Repo.IsBranch + ctx.Data["IsCommit"] = ctx.Repo.IsCommit + } + + // repo is bare and display enable + if displayBare && ctx.Repo.Repository.IsBare { + ctx.HTML(200, "repo/single_bare") + return + } + + if ctx.IsSigned { + ctx.Repo.IsWatching = models.IsWatching(ctx.User.Id, repo.Id) + } + + ctx.Data["BranchName"] = ctx.Repo.BranchName + ctx.Data["CommitId"] = ctx.Repo.CommitId + ctx.Data["IsRepositoryWatching"] = ctx.Repo.IsWatching + } +} diff --git a/modules/oauth2/oauth2.go b/modules/oauth2/oauth2.go new file mode 100644 index 00000000..05ae4606 --- /dev/null +++ b/modules/oauth2/oauth2.go @@ -0,0 +1,234 @@ +// Copyright 2014 Google Inc. All Rights Reserved. +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +// Package oauth2 contains Martini handlers to provide +// user login via an OAuth 2.0 backend. +package oauth2 + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "code.google.com/p/goauth2/oauth" + "github.com/go-martini/martini" + + "github.com/gogits/session" + + "github.com/gogits/gogs/modules/log" + "github.com/gogits/gogs/modules/middleware" +) + +const ( + keyToken = "oauth2_token" + keyNextPage = "next" +) + +var ( + // Path to handle OAuth 2.0 logins. + PathLogin = "/login" + // Path to handle OAuth 2.0 logouts. + PathLogout = "/logout" + // Path to handle callback from OAuth 2.0 backend + // to exchange credentials. + PathCallback = "/oauth2callback" + // Path to handle error cases. + PathError = "/oauth2error" +) + +// Represents OAuth2 backend options. +type Options struct { + ClientId string + ClientSecret string + RedirectURL string + Scopes []string + + AuthUrl string + TokenUrl string +} + +// Represents a container that contains +// user's OAuth 2.0 access and refresh tokens. +type Tokens interface { + Access() string + Refresh() string + IsExpired() bool + ExpiryTime() time.Time + ExtraData() map[string]string +} + +type token struct { + oauth.Token +} + +func (t *token) ExtraData() map[string]string { + return t.Extra +} + +// Returns the access token. +func (t *token) Access() string { + return t.AccessToken +} + +// Returns the refresh token. +func (t *token) Refresh() string { + return t.RefreshToken +} + +// Returns whether the access token is +// expired or not. +func (t *token) IsExpired() bool { + if t == nil { + return true + } + return t.Expired() +} + +// Returns the expiry time of the user's +// access token. +func (t *token) ExpiryTime() time.Time { + return t.Expiry +} + +// Formats tokens into string. +func (t *token) String() string { + return fmt.Sprintf("tokens: %v", t) +} + +// Returns a new Google OAuth 2.0 backend endpoint. +func Google(opts *Options) martini.Handler { + opts.AuthUrl = "https://accounts.google.com/o/oauth2/auth" + opts.TokenUrl = "https://accounts.google.com/o/oauth2/token" + return NewOAuth2Provider(opts) +} + +// Returns a new Github OAuth 2.0 backend endpoint. +func Github(opts *Options) martini.Handler { + opts.AuthUrl = "https://github.com/login/oauth/authorize" + opts.TokenUrl = "https://github.com/login/oauth/access_token" + return NewOAuth2Provider(opts) +} + +func Facebook(opts *Options) martini.Handler { + opts.AuthUrl = "https://www.facebook.com/dialog/oauth" + opts.TokenUrl = "https://graph.facebook.com/oauth/access_token" + return NewOAuth2Provider(opts) +} + +// Returns a generic OAuth 2.0 backend endpoint. +func NewOAuth2Provider(opts *Options) martini.Handler { + config := &oauth.Config{ + ClientId: opts.ClientId, + ClientSecret: opts.ClientSecret, + RedirectURL: opts.RedirectURL, + Scope: strings.Join(opts.Scopes, " "), + AuthURL: opts.AuthUrl, + TokenURL: opts.TokenUrl, + } + + transport := &oauth.Transport{ + Config: config, + Transport: http.DefaultTransport, + } + + return func(c martini.Context, ctx *middleware.Context) { + if ctx.Req.Method == "GET" { + switch ctx.Req.URL.Path { + case PathLogin: + login(transport, ctx) + case PathLogout: + logout(transport, ctx) + case PathCallback: + handleOAuth2Callback(transport, ctx) + } + } + + tk := unmarshallToken(ctx.Session) + if tk != nil { + // check if the access token is expired + if tk.IsExpired() && tk.Refresh() == "" { + ctx.Session.Delete(keyToken) + tk = nil + } + } + // Inject tokens. + c.MapTo(tk, (*Tokens)(nil)) + } +} + +// Handler that redirects user to the login page +// if user is not logged in. +// Sample usage: +// m.Get("/login-required", oauth2.LoginRequired, func() ... {}) +var LoginRequired martini.Handler = func() martini.Handler { + return func(c martini.Context, ctx *middleware.Context) { + token := unmarshallToken(ctx.Session) + if token == nil || token.IsExpired() { + next := url.QueryEscape(ctx.Req.URL.RequestURI()) + ctx.Redirect(PathLogin + "?next=" + next) + return + } + } +}() + +func login(t *oauth.Transport, ctx *middleware.Context) { + next := extractPath(ctx.Query(keyNextPage)) + if ctx.Session.Get(keyToken) == nil { + // User is not logged in. + ctx.Redirect(t.Config.AuthCodeURL(next)) + return + } + // No need to login, redirect to the next page. + ctx.Redirect(next) +} + +func logout(t *oauth.Transport, ctx *middleware.Context) { + next := extractPath(ctx.Query(keyNextPage)) + ctx.Session.Delete(keyToken) + ctx.Redirect(next) +} + +func handleOAuth2Callback(t *oauth.Transport, ctx *middleware.Context) { + if errMsg := ctx.Query("error_description"); len(errMsg) > 0 { + log.Error("oauth2.handleOAuth2Callback: %s", errMsg) + return + } + + next := extractPath(ctx.Query("state")) + code := ctx.Query("code") + tk, err := t.Exchange(code) + if err != nil { + // Pass the error message, or allow dev to provide its own + // error handler. + log.Error("oauth2.handleOAuth2Callback(token.Exchange): %v", err) + // ctx.Redirect(PathError) + return + } + // Store the credentials in the session. + val, _ := json.Marshal(tk) + ctx.Session.Set(keyToken, val) + ctx.Redirect(next) +} + +func unmarshallToken(s session.SessionStore) (t *token) { + if s.Get(keyToken) == nil { + return + } + data := s.Get(keyToken).([]byte) + var tk oauth.Token + json.Unmarshal(data, &tk) + return &token{tk} +} + +func extractPath(next string) string { + n, err := url.Parse(next) + if err != nil { + return "/" + } + return n.Path +} diff --git a/modules/oauth2/oauth2_test.go b/modules/oauth2/oauth2_test.go new file mode 100644 index 00000000..71443030 --- /dev/null +++ b/modules/oauth2/oauth2_test.go @@ -0,0 +1,162 @@ +// Copyright 2014 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package oauth2 + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-martini/martini" + "github.com/martini-contrib/sessions" +) + +func Test_LoginRedirect(t *testing.T) { + recorder := httptest.NewRecorder() + m := martini.New() + m.Use(sessions.Sessions("my_session", sessions.NewCookieStore([]byte("secret123")))) + m.Use(Google(&Options{ + ClientId: "client_id", + ClientSecret: "client_secret", + RedirectURL: "refresh_url", + Scopes: []string{"x", "y"}, + })) + + r, _ := http.NewRequest("GET", "/login", nil) + m.ServeHTTP(recorder, r) + + location := recorder.HeaderMap["Location"][0] + if recorder.Code != 302 { + t.Errorf("Not being redirected to the auth page.") + } + if location != "https://accounts.google.com/o/oauth2/auth?access_type=&approval_prompt=&client_id=client_id&redirect_uri=refresh_url&response_type=code&scope=x+y&state=" { + t.Errorf("Not being redirected to the right page, %v found", location) + } +} + +func Test_LoginRedirectAfterLoginRequired(t *testing.T) { + recorder := httptest.NewRecorder() + m := martini.Classic() + m.Use(sessions.Sessions("my_session", sessions.NewCookieStore([]byte("secret123")))) + m.Use(Google(&Options{ + ClientId: "client_id", + ClientSecret: "client_secret", + RedirectURL: "refresh_url", + Scopes: []string{"x", "y"}, + })) + + m.Get("/login-required", LoginRequired, func(tokens Tokens) (int, string) { + return 200, tokens.Access() + }) + + r, _ := http.NewRequest("GET", "/login-required?key=value", nil) + m.ServeHTTP(recorder, r) + + location := recorder.HeaderMap["Location"][0] + if recorder.Code != 302 { + t.Errorf("Not being redirected to the auth page.") + } + if location != "/login?next=%2Flogin-required%3Fkey%3Dvalue" { + t.Errorf("Not being redirected to the right page, %v found", location) + } +} + +func Test_Logout(t *testing.T) { + recorder := httptest.NewRecorder() + s := sessions.NewCookieStore([]byte("secret123")) + + m := martini.Classic() + m.Use(sessions.Sessions("my_session", s)) + m.Use(Google(&Options{ + // no need to configure + })) + + m.Get("/", func(s sessions.Session) { + s.Set(keyToken, "dummy token") + }) + + m.Get("/get", func(s sessions.Session) { + if s.Get(keyToken) != nil { + t.Errorf("User credentials are still kept in the session.") + } + }) + + logout, _ := http.NewRequest("GET", "/logout", nil) + index, _ := http.NewRequest("GET", "/", nil) + + m.ServeHTTP(httptest.NewRecorder(), index) + m.ServeHTTP(recorder, logout) + + if recorder.Code != 302 { + t.Errorf("Not being redirected to the next page.") + } +} + +func Test_LogoutOnAccessTokenExpiration(t *testing.T) { + recorder := httptest.NewRecorder() + s := sessions.NewCookieStore([]byte("secret123")) + + m := martini.Classic() + m.Use(sessions.Sessions("my_session", s)) + m.Use(Google(&Options{ + // no need to configure + })) + + m.Get("/addtoken", func(s sessions.Session) { + s.Set(keyToken, "dummy token") + }) + + m.Get("/", func(s sessions.Session) { + if s.Get(keyToken) != nil { + t.Errorf("User not logged out although access token is expired.") + } + }) + + addtoken, _ := http.NewRequest("GET", "/addtoken", nil) + index, _ := http.NewRequest("GET", "/", nil) + m.ServeHTTP(recorder, addtoken) + m.ServeHTTP(recorder, index) +} + +func Test_InjectedTokens(t *testing.T) { + recorder := httptest.NewRecorder() + m := martini.Classic() + m.Use(sessions.Sessions("my_session", sessions.NewCookieStore([]byte("secret123")))) + m.Use(Google(&Options{ + // no need to configure + })) + m.Get("/", func(tokens Tokens) string { + return "Hello world!" + }) + r, _ := http.NewRequest("GET", "/", nil) + m.ServeHTTP(recorder, r) +} + +func Test_LoginRequired(t *testing.T) { + recorder := httptest.NewRecorder() + m := martini.Classic() + m.Use(sessions.Sessions("my_session", sessions.NewCookieStore([]byte("secret123")))) + m.Use(Google(&Options{ + // no need to configure + })) + m.Get("/", LoginRequired, func(tokens Tokens) string { + return "Hello world!" + }) + r, _ := http.NewRequest("GET", "/", nil) + m.ServeHTTP(recorder, r) + if recorder.Code != 302 { + t.Errorf("Not being redirected to the auth page although user is not logged in.") + } +} |