Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow code flow callback to be direct to Vault #130

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ func backend() *jwtAuthBackend {
"login",
"oidc/auth_url",
"oidc/callback",
"oidc/poll",

// Uncomment to mount simple UI handler for local development
// "ui",
Expand Down
130 changes: 108 additions & 22 deletions cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"path"
Expand All @@ -27,9 +28,11 @@ const (
defaultPort = "8250"
defaultCallbackHost = "localhost"
defaultCallbackMethod = "http"
defaultCallbackMode = "client"

FieldCallbackHost = "callbackhost"
FieldCallbackMethod = "callbackmethod"
FieldCallbackMode = "callbackmode"
FieldListenAddress = "listenaddress"
FieldPort = "port"
FieldCallbackPort = "callbackport"
Expand Down Expand Up @@ -69,19 +72,42 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
port = defaultPort
}

var vaultURL *url.URL
callbackMode, ok := m[FieldCallbackMode]
if !ok {
callbackMode = defaultCallbackMode
} else if callbackMode == "direct" {
vaultAddr := os.Getenv("VAULT_ADDR")
if vaultAddr != "" {
vaultURL, _ = url.Parse(vaultAddr)
}
}

callbackHost, ok := m[FieldCallbackHost]
if !ok {
callbackHost = defaultCallbackHost
if vaultURL != nil {
callbackHost = vaultURL.Hostname()
} else {
callbackHost = defaultCallbackHost
}
}

callbackMethod, ok := m[FieldCallbackMethod]
if !ok {
callbackMethod = defaultCallbackMethod
if vaultURL != nil {
callbackMethod = vaultURL.Scheme
} else {
callbackMethod = defaultCallbackMethod
}
}

callbackPort, ok := m[FieldCallbackPort]
if !ok {
callbackPort = port
if vaultURL != nil {
callbackPort = vaultURL.Port() + "/v1/auth/" + mount
} else {
callbackPort = port
}
}

parseBool := func(f string, d bool) (bool, error) {
Expand Down Expand Up @@ -115,20 +141,49 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro

role := m["role"]

authURL, clientNonce, err := fetchAuthURL(c, role, mount, callbackPort, callbackMethod, callbackHost)
authURL, clientNonce, secret, err := fetchAuthURL(c, role, mount, callbackPort, callbackMethod, callbackHost)
if err != nil {
return nil, err
}

// Set up callback handler
doneCh := make(chan loginResp)
http.HandleFunc("/oidc/callback", callbackHandler(c, mount, clientNonce, doneCh))

listener, err := net.Listen("tcp", listenAddress+":"+port)
if err != nil {
return nil, err
var pollInterval string
var interval int
var state string
var listener net.Listener

if secret != nil {
pollInterval, _ = secret.Data["poll_interval"].(string)
state, _ = secret.Data["state"].(string)
}
if callbackMode == "direct" {
if state == "" {
return nil, errors.New("no state returned in direct callback mode")
}
if pollInterval == "" {
return nil, errors.New("no poll_interval returned in direct callback mode")
}
interval, err = strconv.Atoi(pollInterval)
if err != nil {
return nil, errors.New("cannot convert poll_interval " + pollInterval + " to integer")
}
} else {
if state != "" {
return nil, errors.New("state returned in client callback mode, try direct")
}
if pollInterval != "" {
return nil, errors.New("poll_interval returned in client callback mode")
}
// Set up callback handler
http.HandleFunc("/oidc/callback", callbackHandler(c, mount, clientNonce, doneCh))

listener, err = net.Listen("tcp", listenAddress+":"+port)
if err != nil {
return nil, err
}
defer listener.Close()
}
defer listener.Close()

// Open the default browser to the callback URL.
if !skipBrowserLaunch {
Expand All @@ -144,6 +199,26 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
}
fmt.Fprintf(os.Stderr, "Waiting for OIDC authentication to complete...\n")

if callbackMode == "direct" {
data := map[string]interface{}{
"state": state,
"client_nonce": clientNonce,
}
pollUrl := fmt.Sprintf("auth/%s/oidc/poll", mount)
for {
time.Sleep(time.Duration(interval) * time.Second)

secret, err := c.Logical().Write(pollUrl, data)
if err == nil {
return secret, nil
}
if !strings.HasSuffix(err.Error(), "authorization_pending") {
return nil, err
}
// authorization is pending, try again
}
}

// Start local server
go func() {
err := http.Serve(listener, nil)
Expand Down Expand Up @@ -210,12 +285,12 @@ func callbackHandler(c *api.Client, mount string, clientNonce string, doneCh cha
}
}

func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMethod string, callbackHost string) (string, string, error) {
func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMethod string, callbackHost string) (string, string, *api.Secret, error) {
var authURL string

clientNonce, err := base62.Random(20)
if err != nil {
return "", "", err
return "", "", nil, err
}

redirectURI := fmt.Sprintf("%s://%s:%s/oidc/callback", callbackMethod, callbackHost, callbackPort)
Expand All @@ -227,18 +302,18 @@ func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMetho

secret, err := c.Logical().Write(fmt.Sprintf("auth/%s/oidc/auth_url", mount), data)
if err != nil {
return "", "", err
return "", "", nil, err
}

if secret != nil {
authURL = secret.Data["auth_url"].(string)
}

if authURL == "" {
return "", "", fmt.Errorf("Unable to authorize role %q with redirect_uri %q. Check Vault logs for more information.", role, redirectURI)
return "", "", nil, fmt.Errorf("Unable to authorize role %q with redirect_uri %q. Check Vault logs for more information.", role, redirectURI)
}

return authURL, clientNonce, nil
return authURL, clientNonce, secret, nil
}

// parseError converts error from the API into summary and detailed portions.
Expand Down Expand Up @@ -292,35 +367,46 @@ Usage: vault login -method=oidc [CONFIG K=V...]

https://accounts.google.com/o/oauth2/v2/...

The default browser will be opened for the user to complete the login. Alternatively,
the user may visit the provided URL directly.
The default browser will be opened for the user to complete the login.
Alternatively, the user may visit the provided URL directly.

Configuration:

role=<string>
Vault role of type "OIDC" to use for authentication.

%s=<string>
Optional address to bind the OIDC callback listener to (default: localhost).
Mode of callback: "direct" for direct connection to Vault or "client"
for connection to command line client (default: client).

%s=<string>
Optional address to bind the OIDC callback listener to in client callback
mode (default: localhost).

%s=<string>
Optional localhost port to use for OIDC callback (default: 8250).
Optional localhost port to use for OIDC callback in client callback mode
(default: 8250).

%s=<string>
Optional method to to use in OIDC redirect_uri (default: http).
Optional method to use in OIDC redirect_uri (default: the method from
$VAULT_ADDR in direct callback mode, else http)

%s=<string>
Optional callback host address to use in OIDC redirect_uri (default: localhost).
Optional callback host address to use in OIDC redirect_uri (default:
the host from $VAULT_ADDR in direct callback mode, else localhost).

%s=<string>
Optional port to to use in OIDC redirect_uri (default: the value set for port).
Optional port to use in OIDC redirect_uri (default: the value set for
port in client callback mode, else the port from $VAULT_ADDR with an
added /v1/auth/<path> where <path> is from the login -path option).

%s=<bool>
Toggle the automatic launching of the default browser to the login URL. (default: false).

%s=<bool>
Abort on any error. (default: false).
`,
FieldCallbackMode,
FieldListenAddress, FieldPort, FieldCallbackMethod,
FieldCallbackHost, FieldCallbackPort, FieldSkipBrowser,
FieldAbortOnError,
Expand Down
File renamed without changes.
Loading