Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for OIDC device flow #122

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func backend() *jwtAuthBackend {
"login",
"oidc/auth_url",
"oidc/callback",
"oidc/device_wait",

// Uncomment to mount simple UI handler for local development
// "ui",
Expand Down
49 changes: 37 additions & 12 deletions cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,26 +74,51 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro

role := m["role"]

authURL, clientNonce, err := fetchAuthURL(c, role, mount, callbackPort, callbackMethod, callbackHost)
authURL, clientNonce, secret, err := fetchAuthURL(c, role, mount, callbackPort, callbackMethod, callbackHost)
if err != nil {
return nil, err
}

// Set up callback handler
http.HandleFunc("/oidc/callback", callbackHandler(c, mount, clientNonce, doneCh))
var deviceCode string
var userCode string
var listener net.Listener

listener, err := net.Listen("tcp", listenAddress+":"+port)
if err != nil {
return nil, err
if secret != nil {
deviceCode, _ = secret.Data["device_code"].(string)
}
if deviceCode != "" {
userCode, _ = secret.Data["user_code"].(string)
} else {
// Set up callback handler
http.HandleFunc("/oidc/callback", callbackHandler(c, mount, clientNonce, doneCh))

listener, err = net.Listen("tcp", listenAddress+":"+port)
if err != nil {
return nil, err
}
defer listener.Close()
}
defer listener.Close()

// Open the default browser to the callback URL.
fmt.Fprintf(os.Stderr, "Complete the login via your OIDC provider. Launching browser to:\n\n %s\n\n\n", authURL)
if userCode != "" {
fmt.Fprintf(os.Stderr, "When prompted, enter code %s\n\n", userCode)
}
if err := openURL(authURL); err != nil {
fmt.Fprintf(os.Stderr, "Error attempting to automatically open browser: '%s'.\nPlease visit the authorization URL manually.", err)
}

if deviceCode != "" {
interval, _ := secret.Data["interval"].(string)
data := map[string]interface{}{
"role": role,
"device_code": deviceCode,
"interval": interval,
}

return c.Logical().Write(fmt.Sprintf("auth/%s/oidc/device_wait", mount), data)
}

// Start local server
go func() {
err := http.Serve(listener, nil)
Expand Down Expand Up @@ -160,12 +185,12 @@ func callbackHandler(c *api.Client, mount string, clientNonce string, doneCh cha
}
}

func fetchAuthURL(c *api.Client, role, mount, callbackport string, callbackMethod string, callbackHost string) (string, string, error) {
func fetchAuthURL(c *api.Client, role, mount, callbackport string, callbackMethod string, callbackHost string) (string, string, *api.Secret, error) {
var authURL string

clientNonce, err := base62.Random(20)
if err != nil {
return "", "", err
return "", "", nil, err
}

data := map[string]interface{}{
Expand All @@ -176,18 +201,18 @@ func fetchAuthURL(c *api.Client, role, mount, callbackport string, callbackMetho

secret, err := c.Logical().Write(fmt.Sprintf("auth/%s/oidc/auth_url", mount), data)
if err != nil {
return "", "", err
return "", "", nil, err
}

if secret != nil {
authURL = secret.Data["auth_url"].(string)
}

if authURL == "" {
return "", "", fmt.Errorf("Unable to authorize role %q. Check Vault logs for more information.", role)
return "", "", nil, fmt.Errorf("Unable to authorize role %q. Check Vault logs for more information.", role)
}

return authURL, clientNonce, nil
return authURL, clientNonce, secret, nil
}

// isWSL tests if the binary is being run in Windows Subsystem for Linux
Expand Down
189 changes: 189 additions & 0 deletions oauth2device.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
// This is a small shim on golang's oauth2 library to add device flow.
// If the library adds its own support, this file can be eliminated.
//
// The below code was copied from
// https://raw.githubusercontent.com/rjw57/oauth2device/master/oauth2device.go
// on 16 June 2020 and updated according to the more recent RFC8628.
// Documentation for the original code was available at
// https://godoc.org/github.com/rjw57/oauth2device
// The BSD license applied was this:
//
// Copyright (c) 2014, Rich Wareham rich.oauth2device@richwareham.com
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following
// disclaimer in the documentation and/or other materials provided
// with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
// BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
// OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
// AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
// WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.

package jwtauth

import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"

"golang.org/x/oauth2"
)

// A DeviceCode represents the user-visible code, verification URI and
// device-visible code used to allow for user authorisation of this app.
// The VerificationURIComplete is optional and combines the user code
// and verification URI. If present, apps may choose to show to
// the user the VerificationURIComplete, otherwise the app should show
// the UserCode and VerificationURL to the user. ExpiresIn is how many
// seconds the user has to respond, and the optional Interval is how many
// seconds the app should wait in between polls (default 5).
type DeviceCode struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
VerificationURI string `json:"verification_uri"`
VerificationURIComplete string `json:"verification_uri_complete"`
ExpiresIn int64 `json:"expires_in"`
Interval int64 `json:"interval"`
}

// DeviceEndpoint contains the URLs required to initiate the OAuth2.0 flow for a
// provider's device flow.
type DeviceEndpoint struct {
CodeURL string
}

// A version of oauth2.Config augmented with device endpoints
type DeviceConfig struct {
*oauth2.Config
DeviceEndpoint DeviceEndpoint
}

// A tokenOrError is either an OAuth2 Token response or an error indicating why
// such a response failed.
type tokenOrError struct {
*oauth2.Token
Error string `json:"error,omitempty"`
}

var (
// ErrAccessDenied is an error returned when the user has denied this
// app access to their account.
ErrAccessDenied = errors.New("access denied by user")
)

const (
deviceGrantType = "urn:ietf:params:oauth:grant-type:device_code"
)

// RequestDeviceCode will initiate the OAuth2 device authorization flow. It
// requests a device code and information on the code and URL to show to the
// user. Pass the returned DeviceCode to WaitForDeviceAuthorization.
func RequestDeviceCode(client *http.Client, config *DeviceConfig) (*DeviceCode, error) {
scopes := strings.Join(config.Scopes, " ")
resp, err := client.PostForm(config.DeviceEndpoint.CodeURL,
url.Values{"client_id": {config.ClientID}, "scope": {scopes}})

if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf(
"request for device code authorisation returned status %v (%v)",
resp.StatusCode, http.StatusText(resp.StatusCode))
}

// Unmarshal response
var dcr DeviceCode
dec := json.NewDecoder(resp.Body)
if err := dec.Decode(&dcr); err != nil {
return nil, err
}

if dcr.Interval == 0 {
dcr.Interval = 5
}

return &dcr, nil
}

// WaitForDeviceAuthorization polls the token URL waiting for the user to
// authorize the app. Upon authorization, it returns the new token. If
// authorization fails then an error is returned. If that failure was due to a
// user explicitly denying access, the error is ErrAccessDenied.
func WaitForDeviceAuthorization(client *http.Client, config *DeviceConfig, code *DeviceCode) (*oauth2.Token, error) {
for {

resp, err := client.PostForm(config.Endpoint.TokenURL,
url.Values{
"client_secret": {config.ClientSecret},
"client_id": {config.ClientID},
"device_code": {code.DeviceCode},
"grant_type": {deviceGrantType}})
if err != nil {
return nil, fmt.Errorf("post error while polling for OAuth token: %v", err)
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusBadRequest {
return nil, fmt.Errorf("HTTP error %v (%v) when polling for OAuth token",
resp.StatusCode, http.StatusText(resp.StatusCode))
}

body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading response body while polling for OAuth token: %v", err)
}

// Unmarshal response, checking for errors
var token tokenOrError
if err := json.Unmarshal(body, &token); err != nil {
return nil, fmt.Errorf("error decoding response body while polling for OAuth token: %v", err)
}


switch token.Error {
case "":

extra := make(map[string]interface{})
err := json.Unmarshal(body, &extra)
if err != nil {
// already been unmarshalled once, unlikely
return nil, err
}
return token.Token.WithExtra(extra), nil
case "authorization_pending":

case "slow_down":

code.Interval *= 2
case "access_denied":

return nil, ErrAccessDenied
default:

return nil, fmt.Errorf("authorization failed: %v", token.Error)
}

time.Sleep(time.Duration(code.Interval) * time.Second)
}
}
11 changes: 11 additions & 0 deletions path_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ func pathConfig(b *jwtAuthBackend) *framework.Path {
Type: framework.TypeCommaStringSlice,
Description: "The response types to request. Allowed values are 'code' and 'id_token'. Defaults to 'code'.",
},
"oidc_device_auth_url": {
Type: framework.TypeString,
Description: `OIDC Device Flow authentication URL. May only be used with "oidc_discovery_url".`,
},
"jwks_url": {
Type: framework.TypeString,
Description: `JWKS URL to use to authenticate signatures. Cannot be used with "oidc_discovery_url" or "jwt_validation_pubkeys".`,
Expand Down Expand Up @@ -163,6 +167,7 @@ func (b *jwtAuthBackend) pathConfigRead(ctx context.Context, req *logical.Reques
"oidc_client_id": config.OIDCClientID,
"oidc_response_mode": config.OIDCResponseMode,
"oidc_response_types": config.OIDCResponseTypes,
"oidc_device_auth_url": config.OIDCDeviceAuthURL,
"default_role": config.DefaultRole,
"jwt_validation_pubkeys": config.JWTValidationPubKeys,
"jwt_supported_algs": config.JWTSupportedAlgs,
Expand All @@ -184,6 +189,7 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
OIDCClientSecret: d.Get("oidc_client_secret").(string),
OIDCResponseMode: d.Get("oidc_response_mode").(string),
OIDCResponseTypes: d.Get("oidc_response_types").([]string),
OIDCDeviceAuthURL: d.Get("oidc_device_auth_url").(string),
JWKSURL: d.Get("jwks_url").(string),
JWKSCAPEM: d.Get("jwks_ca_pem").(string),
DefaultRole: d.Get("default_role").(string),
Expand Down Expand Up @@ -222,6 +228,10 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
case config.OIDCClientID != "" && config.OIDCDiscoveryURL == "":
return logical.ErrorResponse("'oidc_discovery_url' must be set for OIDC"), nil

case config.OIDCDeviceAuthURL != "" && config.OIDCDiscoveryURL == "":
return logical.ErrorResponse("'oidc_discovery_url' must be set when 'oidc_device_auth_url' is set"), nil

case config.JWKSURL != "":
case config.JWKSURL != "":
ctx, err := b.createCAContext(context.Background(), config.JWKSCAPEM)
if err != nil {
Expand Down Expand Up @@ -346,6 +356,7 @@ type jwtConfig struct {
OIDCClientSecret string `json:"oidc_client_secret"`
OIDCResponseMode string `json:"oidc_response_mode"`
OIDCResponseTypes []string `json:"oidc_response_types"`
OIDCDeviceAuthURL string `json:"oidc_device_auth_url"`
JWKSURL string `json:"jwks_url"`
JWKSCAPEM string `json:"jwks_ca_pem"`
JWTValidationPubKeys []string `json:"jwt_validation_pubkeys"`
Expand Down
4 changes: 2 additions & 2 deletions path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (b *jwtAuthBackend) pathLogin(ctx context.Context, req *logical.Request, d
return logical.ErrorResponse("role %q could not be found", roleName), nil
}

if role.RoleType == "oidc" {
if role.RoleType == "oidc" || role.RoleType == "oidcdevice" {
return logical.ErrorResponse("role with oidc role_type is not allowed"), nil
}

Expand Down Expand Up @@ -278,7 +278,7 @@ func (b *jwtAuthBackend) verifyOIDCToken(ctx context.Context, config *jwtConfig,
SupportedSigningAlgs: config.JWTSupportedAlgs,
}

if role.RoleType == "oidc" {
if role.RoleType == "oidc" || role.RoleType == "oidcdevice" {
oidcConfig.ClientID = config.OIDCClientID
} else {
oidcConfig.SkipClientIDCheck = true
Expand Down
Loading