Skip to content

Commit

Permalink
feat(oauth): add new route to retrieve user info
Browse files Browse the repository at this point in the history
GET {proxy-prefix}/token will output access_token,
refresh_token, username, email and expires in json
format, if the route is enabled (default: off).

Ref bitly#571
  • Loading branch information
blaskovicz committed Mar 25, 2018
1 parent a94b0a8 commit f8b6621
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 21 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ Usage of oauth2_proxy:
-profile-url string: Profile access endpoint
-provider string: OAuth provider (default "google")
-proxy-prefix string: the url root path that this proxy should be nested under (e.g. /<oauth2>/sign_in) (default "/oauth2")
-allow-token-request allow authenticated GET requests to {proxy-prefix}/token to output access_token, refresh_token, username, email and expires (default: false)
-redeem-url string: Token redemption endpoint
-redirect-url string: the OAuth Redirect URL. ie: "https://internalapp.yourcompany.com/oauth2/callback"
-request-logging: Log requests to stdout (default true)
Expand Down
1 change: 1 addition & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ func main() {
flagSet.String("custom-templates-dir", "", "path to custom html templates")
flagSet.String("footer", "", "custom footer string. Use \"-\" to disable default footer.")
flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. /<oauth2>/sign_in)")
flagSet.Bool("allow-token-request", false, "Allow authenticated GET requests to {proxy-prefix}/token to output access_token, refresh_token, username, email and expires.")

flagSet.String("cookie-name", "_oauth2_proxy", "the name of the cookie that the oauth_proxy creates")
flagSet.String("cookie-secret", "", "the seed string for secure cookies (optionally base64 encoded)")
Expand Down
45 changes: 38 additions & 7 deletions oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
b64 "encoding/base64"
"encoding/json"
"errors"
"fmt"
"html/template"
Expand Down Expand Up @@ -52,6 +53,7 @@ type OAuthProxy struct {
OAuthStartPath string
OAuthCallbackPath string
AuthOnlyPath string
OAuthTokenPath string

redirectURL *url.URL // the url to receive requests at
provider providers.Provider
Expand All @@ -67,6 +69,7 @@ type OAuthProxy struct {
BasicAuthPassword string
PassAccessToken bool
CookieCipher *cookie.Cipher
allowTokenRequest bool
skipAuthRegex []string
skipAuthPreflight bool
compiledRegex []*regexp.Regexp
Expand Down Expand Up @@ -189,13 +192,15 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
OAuthStartPath: fmt.Sprintf("%s/start", opts.ProxyPrefix),
OAuthCallbackPath: fmt.Sprintf("%s/callback", opts.ProxyPrefix),
AuthOnlyPath: fmt.Sprintf("%s/auth", opts.ProxyPrefix),
OAuthTokenPath: fmt.Sprintf("%s/token", opts.ProxyPrefix),

ProxyPrefix: opts.ProxyPrefix,
provider: opts.provider,
serveMux: serveMux,
redirectURL: redirectURL,
skipAuthRegex: opts.SkipAuthRegex,
skipAuthPreflight: opts.SkipAuthPreflight,
allowTokenRequest: opts.AllowTokenRequest,
compiledRegex: opts.CompiledRegex,
SetXAuthRequest: opts.SetXAuthRequest,
PassBasicAuth: opts.PassBasicAuth,
Expand Down Expand Up @@ -474,6 +479,8 @@ func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
p.OAuthCallback(rw, req)
case path == p.AuthOnlyPath:
p.AuthenticateOnly(rw, req)
case path == p.OAuthTokenPath:
p.OAuthToken(rw, req)
default:
p.Proxy(rw, req)
}
Expand Down Expand Up @@ -582,8 +589,25 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
}
}

func (p *OAuthProxy) OAuthToken(rw http.ResponseWriter, req *http.Request) {
if !p.allowTokenRequest {
http.Error(rw, "unauthorized request", http.StatusUnauthorized)
return
}

session, status := p.Authenticate(rw, req)
if status == http.StatusAccepted {
// TODO: accept header and different content-types, if required.
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(http.StatusOK)
json.NewEncoder(rw).Encode(session)
} else {
http.Error(rw, "unauthorized request", http.StatusUnauthorized)
}
}

func (p *OAuthProxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request) {
status := p.Authenticate(rw, req)
_, status := p.Authenticate(rw, req)
if status == http.StatusAccepted {
rw.WriteHeader(http.StatusAccepted)
} else {
Expand All @@ -592,7 +616,7 @@ func (p *OAuthProxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request)
}

func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
status := p.Authenticate(rw, req)
_, status := p.Authenticate(rw, req)
if status == http.StatusInternalServerError {
p.ErrorPage(rw, http.StatusInternalServerError,
"Internal Error", "Internal Error")
Expand All @@ -607,7 +631,7 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
}
}

func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int {
func (p *OAuthProxy) authenticateSession(rw http.ResponseWriter, req *http.Request) (*providers.SessionState, error) {
var saveSession, clearSession, revalidated bool
remoteAddr := getRemoteAddr(req)

Expand Down Expand Up @@ -656,7 +680,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
err := p.SaveSession(rw, req, session)
if err != nil {
log.Printf("%s %s", remoteAddr, err)
return http.StatusInternalServerError
return nil, err
}
}

Expand All @@ -671,8 +695,15 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
}
}

if session == nil {
return http.StatusForbidden
return session, nil
}

func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (*providers.SessionState, int) {
session, err := p.authenticateSession(rw, req)
if err != nil {
return nil, http.StatusInternalServerError
} else if session == nil {
return nil, http.StatusForbidden
}

// At this point, the user is authenticated. proxy normally
Expand Down Expand Up @@ -703,7 +734,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
} else {
rw.Header().Set("GAP-Auth", session.Email)
}
return http.StatusAccepted
return session, http.StatusAccepted
}

func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) {
Expand Down
64 changes: 55 additions & 9 deletions oauthproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"crypto"
"encoding/base64"
"encoding/json"
"io"
"io/ioutil"
"log"
Expand All @@ -18,6 +19,7 @@ import (
"github.com/bitly/oauth2_proxy/providers"
"github.com/mbland/hmacauth"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func init() {
Expand Down Expand Up @@ -470,17 +472,23 @@ type ProcessCookieTestOpts struct {
provider_validate_cookie_response bool
}

func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest {
var pc_test ProcessCookieTest

pc_test.opts = NewOptions()
pc_test.opts.ClientID = "bazquux"
pc_test.opts.ClientSecret = "xyzzyplugh"
pc_test.opts.CookieSecret = "0123456789abcdefabcd"
func defaultOpts() *Options {
opts := NewOptions()
opts.ClientID = "bazquux"
opts.ClientSecret = "xyzzyplugh"
opts.CookieSecret = "0123456789abcdefabcd"
// First, set the CookieRefresh option so proxy.AesCipher is created,
// needed to encrypt the access_token.
pc_test.opts.CookieRefresh = time.Hour
pc_test.opts.Validate()
opts.CookieRefresh = time.Hour
return opts
}
func NewProcessCookieTest(testOpts ProcessCookieTestOpts) *ProcessCookieTest {
return NewProcessCookieTestWithOpts(testOpts, defaultOpts())
}
func NewProcessCookieTestWithOpts(opts ProcessCookieTestOpts, pc_test_opts *Options) *ProcessCookieTest {
pc_test_opts.Validate()
var pc_test ProcessCookieTest
pc_test.opts = pc_test_opts

pc_test.proxy = NewOAuthProxy(pc_test.opts, func(email string) bool {
return pc_test.validate_user
Expand Down Expand Up @@ -589,6 +597,15 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
}
}

func NewOAuthTokenEndpointTest(allow bool) *ProcessCookieTest {
opts := defaultOpts()
opts.AllowTokenRequest = allow
pc_test := NewProcessCookieTestWithOpts(ProcessCookieTestOpts{provider_validate_cookie_response: true}, opts)
pc_test.req, _ = http.NewRequest("GET",
pc_test.opts.ProxyPrefix+"/token", nil)
return pc_test
}

func NewAuthOnlyEndpointTest() *ProcessCookieTest {
pc_test := NewProcessCookieTestWithDefaults()
pc_test.req, _ = http.NewRequest("GET",
Expand Down Expand Up @@ -674,6 +691,35 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
assert.Equal(t, "oauth_user@example.com", pc_test.rw.HeaderMap["X-Auth-Request-Email"][0])
}

func TestOAuthTokenEndpoint(t *testing.T) {
t.Run("NotAllowed", func(t *testing.T) {
test := NewOAuthTokenEndpointTest(false)
startSession := &providers.SessionState{User: "zach", Email: "zach@example.com", AccessToken: "my_access_token"}
test.SaveSession(startSession, time.Now())

test.proxy.ServeHTTP(test.rw, test.req)
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
bodyBytes, err := ioutil.ReadAll(test.rw.Body)
require.NoError(t, err, "failed to read token response body")
assert.Equal(t, "unauthorized request\n", string(bodyBytes))
})
t.Run("Allowed", func(t *testing.T) {
test := NewOAuthTokenEndpointTest(true)
startSession := &providers.SessionState{User: "zach", Email: "zach@example.com", AccessToken: "my_access_token"}
test.SaveSession(startSession, time.Now())

test.proxy.ServeHTTP(test.rw, test.req)
assert.Equal(t, http.StatusOK, test.rw.Code)
rb := map[string]interface{}{}
err := json.Unmarshal(test.rw.Body.Bytes(), &rb)
require.NoError(t, err, "failed to unmarshal token response (%#v)", test.rw.Body.String())
assert.Equal(t, "application/json", test.rw.Header().Get("Content-Type"), "token content-type header incorrect")
assert.Equal(t, startSession.User, rb["user"].(string), "token user incorrect (%#v)", rb)
assert.Equal(t, startSession.Email, rb["email"].(string), "token email incorrect (%#v)", rb)
assert.Equal(t, startSession.AccessToken, rb["access_token"].(string), "token access_token incorrect (%#v)", rb)
})
}

func TestAuthSkippedForPreflightRequests(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
Expand Down
2 changes: 2 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ type Options struct {
SSLInsecureSkipVerify bool `flag:"ssl-insecure-skip-verify" cfg:"ssl_insecure_skip_verify"`
SetXAuthRequest bool `flag:"set-xauthrequest" cfg:"set_xauthrequest"`
SkipAuthPreflight bool `flag:"skip-auth-preflight" cfg:"skip_auth_preflight"`
AllowTokenRequest bool `flag:"allow-token-request" cfg:"allow_token_request"`

// These options allow for other providers besides Google, with
// potential overrides.
Expand Down Expand Up @@ -106,6 +107,7 @@ func NewOptions() *Options {
CookieRefresh: time.Duration(0),
SetXAuthRequest: false,
SkipAuthPreflight: false,
AllowTokenRequest: false,
PassBasicAuth: true,
PassUserHeaders: true,
PassAccessToken: false,
Expand Down
10 changes: 5 additions & 5 deletions providers/session_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ import (
)

type SessionState struct {
AccessToken string
ExpiresOn time.Time
RefreshToken string
Email string
User string
AccessToken string `json:"access_token"`
ExpiresOn time.Time `json:"expires_on"`
RefreshToken string `json:"refresh_token"`
Email string `json:"email"`
User string `json:"user"`
}

func (s *SessionState) IsExpired() bool {
Expand Down

0 comments on commit f8b6621

Please sign in to comment.