Skip to content

Commit

Permalink
Add Config.LocalServerCallbackPath (#235)
Browse files Browse the repository at this point in the history
  • Loading branch information
int128 authored Jan 25, 2025
1 parent a6300b0 commit 833e3ba
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 18 deletions.
12 changes: 6 additions & 6 deletions e2e_test/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestHappyPath(t *testing.T) {
t.Errorf("scope wants %s but %s", want, req.Scope)
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
}
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost") {
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "/") {
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
}
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")
Expand Down Expand Up @@ -106,7 +106,7 @@ func TestRedirectURLHostname(t *testing.T) {
t.Errorf("scope wants %s but %s", want, req.Scope)
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
}
if !assertRedirectURI(t, req.RedirectURI, "http", "127.0.0.1") {
if !assertRedirectURI(t, req.RedirectURI, "http", "127.0.0.1", "/") {
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
}
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")
Expand Down Expand Up @@ -177,7 +177,7 @@ func TestSuccessRedirect(t *testing.T) {
t.Errorf("scope wants %s but %s", want, req.Scope)
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
}
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost") {
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "/") {
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
}
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")
Expand Down Expand Up @@ -242,7 +242,7 @@ func TestSuccessRedirect(t *testing.T) {
wg.Wait()
}

func assertRedirectURI(t *testing.T, actualURI, scheme, hostname string) bool {
func assertRedirectURI(t *testing.T, actualURI, scheme, hostname, path string) bool {
redirect, err := url.Parse(actualURI)
if err != nil {
t.Errorf("could not parse redirect_uri: %s", err)
Expand All @@ -256,8 +256,8 @@ func assertRedirectURI(t *testing.T, actualURI, scheme, hostname string) bool {
t.Errorf("redirect_uri wants hostname %s but was %s", hostname, actualHostname)
return false
}
if redirect.Path != "" {
t.Errorf("redirect_uri wants path `` but was %s", redirect.Path)
if actualPath := redirect.Path; actualPath != path {
t.Errorf("redirect_uri wants path %s but was %s", path, actualPath)
return false
}
return true
Expand Down
86 changes: 86 additions & 0 deletions e2e_test/localserveropts_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package e2e_test

import (
"context"
"fmt"
"net/http/httptest"
"sync"
"testing"
"time"

"github.com/int128/oauth2cli"
"github.com/int128/oauth2cli/e2e_test/authserver"
"github.com/int128/oauth2cli/e2e_test/client"
"golang.org/x/oauth2"
)

func TestLocalServerCallbackPath(t *testing.T) {
ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second)
defer cancel()
openBrowserCh := make(chan string)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
defer close(openBrowserCh)
// Start a local server and get a token.
testServer := httptest.NewServer(&authserver.Handler{
TestingT: t,
NewAuthorizationResponse: func(req authserver.AuthorizationRequest) string {
if want := "email profile"; req.Scope != want {
t.Errorf("scope wants %s but %s", want, req.Scope)
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
}
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "/callback") {
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
}
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")
},
NewTokenResponse: func(req authserver.TokenRequest) (int, string) {
if want := "AUTH_CODE"; req.Code != want {
t.Errorf("code wants %s but %s", want, req.Code)
return 400, invalidGrantResponse
}
return 200, validTokenResponse
},
})
defer testServer.Close()
cfg := oauth2cli.Config{
OAuth2Config: oauth2.Config{
ClientID: "YOUR_CLIENT_ID",
ClientSecret: "YOUR_CLIENT_SECRET",
Scopes: []string{"email", "profile"},
Endpoint: oauth2.Endpoint{
AuthURL: testServer.URL + "/auth",
TokenURL: testServer.URL + "/token",
},
},
LocalServerCallbackPath: "/callback",
LocalServerReadyChan: openBrowserCh,
LocalServerMiddleware: loggingMiddleware(t),
Logf: t.Logf,
}
token, err := oauth2cli.GetToken(ctx, cfg)
if err != nil {
t.Errorf("could not get a token: %s", err)
return
}
if token.AccessToken != "ACCESS_TOKEN" {
t.Errorf("AccessToken wants %s but %s", "ACCESS_TOKEN", token.AccessToken)
}
if token.RefreshToken != "REFRESH_TOKEN" {
t.Errorf("RefreshToken wants %s but %s", "REFRESH_TOKEN", token.RefreshToken)
}
}()
wg.Add(1)
go func() {
defer wg.Done()
toURL, ok := <-openBrowserCh
if !ok {
t.Errorf("server already closed")
return
}
client.GetAndVerify(t, toURL, 200, oauth2cli.DefaultLocalServerSuccessHTML)
}()
wg.Wait()
}
2 changes: 1 addition & 1 deletion e2e_test/pkce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestPKCE(t *testing.T) {
t.Errorf("scope wants %s but %s", want, req.Scope)
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
}
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost") {
if !assertRedirectURI(t, req.RedirectURI, "http", "localhost", "/") {
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
}
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")
Expand Down
2 changes: 1 addition & 1 deletion e2e_test/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestTLS(t *testing.T) {
t.Errorf("scope wants %s but %s", want, req.Scope)
return fmt.Sprintf("%s?error=invalid_scope", req.RedirectURI)
}
if !assertRedirectURI(t, req.RedirectURI, "https", "localhost") {
if !assertRedirectURI(t, req.RedirectURI, "https", "localhost", "/") {
return fmt.Sprintf("%s?error=invalid_redirect_uri", req.RedirectURI)
}
return fmt.Sprintf("%s?state=%s&code=%s", req.RedirectURI, req.State, "AUTH_CODE")
Expand Down
8 changes: 8 additions & 0 deletions oauth2cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ type Config struct {
// This is required when LocalServerCertFile is set.
LocalServerKeyFile string

// Callback path of the local server.
// If your provider requires a specific path of the redirect URL, set it here.
// Default to "/".
LocalServerCallbackPath string

// Response HTML body on authorization completed.
// Default to DefaultLocalServerSuccessHTML.
LocalServerSuccessHTML string
Expand Down Expand Up @@ -119,6 +124,9 @@ func (cfg *Config) validateAndSetDefaults() error {
}
cfg.State = state
}
if cfg.LocalServerCallbackPath == "" {
cfg.LocalServerCallbackPath = "/"
}
if cfg.LocalServerMiddleware == nil {
cfg.LocalServerMiddleware = noopMiddleware
}
Expand Down
29 changes: 19 additions & 10 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,16 @@ func receiveCodeViaLocalServer(ctx context.Context, cfg *Config) (string, error)
defer localServerListener.Close()

localServerPort := localServerListener.Addr().(*net.TCPAddr).Port
cfg.OAuth2Config.RedirectURL = constructRedirectURL(cfg, localServerPort)
localServerURL := constructLocalServerURL(cfg, localServerPort)
localServerIndexURL, err := localServerURL.Parse("/")
if err != nil {
return "", fmt.Errorf("construct the index URL: %w", err)
}
localServerCallbackURL, err := localServerURL.Parse(cfg.LocalServerCallbackPath)
if err != nil {
return "", fmt.Errorf("construct the callback URL: %w", err)
}
cfg.OAuth2Config.RedirectURL = localServerCallbackURL.String()

respCh := make(chan *authorizationResponse)
server := http.Server{
Expand Down Expand Up @@ -84,7 +93,7 @@ func receiveCodeViaLocalServer(ctx context.Context, cfg *Config) (string, error)
return nil
}
select {
case cfg.LocalServerReadyChan <- cfg.OAuth2Config.RedirectURL:
case cfg.LocalServerReadyChan <- localServerIndexURL.String():
return nil
case <-ctx.Done():
return ctx.Err()
Expand All @@ -99,14 +108,14 @@ func receiveCodeViaLocalServer(ctx context.Context, cfg *Config) (string, error)
return resp.code, resp.err
}

func constructRedirectURL(cfg *Config, port int) string {
var redirect url.URL
redirect.Host = fmt.Sprintf("%s:%d", cfg.RedirectURLHostname, port)
redirect.Scheme = "http"
func constructLocalServerURL(cfg *Config, port int) url.URL {
var localServer url.URL
localServer.Host = fmt.Sprintf("%s:%d", cfg.RedirectURLHostname, port)
localServer.Scheme = "http"
if cfg.isLocalServerHTTPS() {
redirect.Scheme = "https"
localServer.Scheme = "https"
}
return redirect.String()
return localServer
}

type authorizationResponse struct {
Expand All @@ -123,11 +132,11 @@ type localServerHandler struct {
func (h *localServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
switch {
case r.Method == "GET" && r.URL.Path == "/" && q.Get("error") != "":
case r.Method == "GET" && r.URL.Path == h.config.LocalServerCallbackPath && q.Get("error") != "":
h.onceRespCh.Do(func() {
h.respCh <- h.handleErrorResponse(w, r)
})
case r.Method == "GET" && r.URL.Path == "/" && q.Get("code") != "":
case r.Method == "GET" && r.URL.Path == h.config.LocalServerCallbackPath && q.Get("code") != "":
h.onceRespCh.Do(func() {
h.respCh <- h.handleCodeResponse(w, r)
})
Expand Down

0 comments on commit 833e3ba

Please sign in to comment.