diff options
Diffstat (limited to 'modules/oauth2/oauth2.go')
-rw-r--r-- | modules/oauth2/oauth2.go | 89 |
1 files changed, 45 insertions, 44 deletions
diff --git a/modules/oauth2/oauth2.go b/modules/oauth2/oauth2.go index 088d65dd..05ae4606 100644 --- a/modules/oauth2/oauth2.go +++ b/modules/oauth2/oauth2.go @@ -1,16 +1,7 @@ // Copyright 2014 Google Inc. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// Copyright 2014 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. // Package oauth2 contains Martini handlers to provide // user login via an OAuth 2.0 backend. @@ -26,13 +17,16 @@ import ( "code.google.com/p/goauth2/oauth" "github.com/go-martini/martini" - "github.com/martini-contrib/sessions" + + "github.com/gogits/session" + + "github.com/gogits/gogs/modules/log" + "github.com/gogits/gogs/modules/middleware" ) const ( - codeRedirect = 302 - keyToken = "oauth2_token" - keyNextPage = "next" + keyToken = "oauth2_token" + keyNextPage = "next" ) var ( @@ -142,23 +136,23 @@ func NewOAuth2Provider(opts *Options) martini.Handler { Transport: http.DefaultTransport, } - return func(s sessions.Session, c martini.Context, w http.ResponseWriter, r *http.Request) { - if r.Method == "GET" { - switch r.URL.Path { + return func(c martini.Context, ctx *middleware.Context) { + if ctx.Req.Method == "GET" { + switch ctx.Req.URL.Path { case PathLogin: - login(transport, s, w, r) + login(transport, ctx) case PathLogout: - logout(transport, s, w, r) + logout(transport, ctx) case PathCallback: - handleOAuth2Callback(transport, s, w, r) + handleOAuth2Callback(transport, ctx) } } - tk := unmarshallToken(s) + tk := unmarshallToken(ctx.Session) if tk != nil { // check if the access token is expired if tk.IsExpired() && tk.Refresh() == "" { - s.Delete(keyToken) + ctx.Session.Delete(keyToken) tk = nil } } @@ -172,49 +166,56 @@ func NewOAuth2Provider(opts *Options) martini.Handler { // Sample usage: // m.Get("/login-required", oauth2.LoginRequired, func() ... {}) var LoginRequired martini.Handler = func() martini.Handler { - return func(s sessions.Session, c martini.Context, w http.ResponseWriter, r *http.Request) { - token := unmarshallToken(s) + return func(c martini.Context, ctx *middleware.Context) { + token := unmarshallToken(ctx.Session) if token == nil || token.IsExpired() { - next := url.QueryEscape(r.URL.RequestURI()) - http.Redirect(w, r, PathLogin+"?next="+next, codeRedirect) + next := url.QueryEscape(ctx.Req.URL.RequestURI()) + ctx.Redirect(PathLogin + "?next=" + next) + return } } }() -func login(t *oauth.Transport, s sessions.Session, w http.ResponseWriter, r *http.Request) { - next := extractPath(r.URL.Query().Get(keyNextPage)) - if s.Get(keyToken) == nil { +func login(t *oauth.Transport, ctx *middleware.Context) { + next := extractPath(ctx.Query(keyNextPage)) + if ctx.Session.Get(keyToken) == nil { // User is not logged in. - http.Redirect(w, r, t.Config.AuthCodeURL(next), codeRedirect) + ctx.Redirect(t.Config.AuthCodeURL(next)) return } // No need to login, redirect to the next page. - http.Redirect(w, r, next, codeRedirect) + ctx.Redirect(next) } -func logout(t *oauth.Transport, s sessions.Session, w http.ResponseWriter, r *http.Request) { - next := extractPath(r.URL.Query().Get(keyNextPage)) - s.Delete(keyToken) - http.Redirect(w, r, next, codeRedirect) +func logout(t *oauth.Transport, ctx *middleware.Context) { + next := extractPath(ctx.Query(keyNextPage)) + ctx.Session.Delete(keyToken) + ctx.Redirect(next) } -func handleOAuth2Callback(t *oauth.Transport, s sessions.Session, w http.ResponseWriter, r *http.Request) { - next := extractPath(r.URL.Query().Get("state")) - code := r.URL.Query().Get("code") +func handleOAuth2Callback(t *oauth.Transport, ctx *middleware.Context) { + if errMsg := ctx.Query("error_description"); len(errMsg) > 0 { + log.Error("oauth2.handleOAuth2Callback: %s", errMsg) + return + } + + next := extractPath(ctx.Query("state")) + code := ctx.Query("code") tk, err := t.Exchange(code) if err != nil { // Pass the error message, or allow dev to provide its own // error handler. - http.Redirect(w, r, PathError, codeRedirect) + log.Error("oauth2.handleOAuth2Callback(token.Exchange): %v", err) + // ctx.Redirect(PathError) return } // Store the credentials in the session. val, _ := json.Marshal(tk) - s.Set(keyToken, val) - http.Redirect(w, r, next, codeRedirect) + ctx.Session.Set(keyToken, val) + ctx.Redirect(next) } -func unmarshallToken(s sessions.Session) (t *token) { +func unmarshallToken(s session.SessionStore) (t *token) { if s.Get(keyToken) == nil { return } |