Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor broker refactoring and cleanup #349

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 65 additions & 41 deletions internal/broker/broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,9 @@ type session struct {
lang string
mode string

selectedMode string
firstSelectedMode string
authModes []string
attemptsPerMode map[string]int
selectedMode string
authModes []string
attemptsPerMode map[string]int

oidcServer *oidc.Provider
oauth2Config oauth2.Config
Expand Down Expand Up @@ -211,18 +210,13 @@ func (b *Broker) connectToOIDCServer(ctx context.Context) (*oidc.Provider, error
}

// GetAuthenticationModes returns the authentication modes available for the user.
func (b *Broker) GetAuthenticationModes(sessionID string, supportedUILayouts []map[string]string) (authModes []map[string]string, err error) {
func (b *Broker) GetAuthenticationModes(sessionID string, supportedUILayouts []map[string]string) (authModesWithLabels []map[string]string, err error) {
session, err := b.getSession(sessionID)
if err != nil {
return nil, err
}

supportedAuthModes := b.supportedAuthModesFromLayout(supportedUILayouts)

log.Debugf(context.Background(), "Supported UI Layouts for session %s: %#v", sessionID, supportedUILayouts)
log.Debugf(context.Background(), "Supported Authentication modes for session %s: %#v", sessionID, supportedAuthModes)

// Checks if the token exists in the cache.
// Check if the token exists in the cache.
tokenExists, err := fileutils.FileExists(session.tokenPath)
if err != nil {
log.Warningf(context.Background(), "Could not check if token exists: %v", err)
Expand All @@ -235,37 +229,26 @@ func (b *Broker) GetAuthenticationModes(sessionID string, supportedUILayouts []m
}
}

endpoints := make(map[string]struct{})
if session.oidcServer != nil && session.oidcServer.Endpoint().DeviceAuthURL != "" {
authMode := authmodes.DeviceQr
if _, ok := supportedAuthModes[authMode]; ok {
endpoints[authMode] = struct{}{}
}
authMode = authmodes.Device
if _, ok := supportedAuthModes[authMode]; ok {
endpoints[authMode] = struct{}{}
}
}

availableModes, err := b.provider.CurrentAuthenticationModesOffered(
session.mode,
supportedAuthModes,
tokenExists,
!session.isOffline,
endpoints,
session.currentAuthStep)
availableModes, err := b.availableAuthModes(session, tokenExists)
if err != nil {
return nil, err
}

for _, id := range availableModes {
authModes = append(authModes, map[string]string{
"id": id,
"label": supportedAuthModes[id],
modesSupportedByUI := b.authModesSupportedByUI(supportedUILayouts)

for _, mode := range availableModes {
if _, ok := modesSupportedByUI[mode]; !ok {
log.Infof(context.Background(), "Authentication mode %q is not supported by the UI", mode)
continue
}

authModesWithLabels = append(authModesWithLabels, map[string]string{
"id": mode,
"label": modesSupportedByUI[mode],
})
}

if len(authModes) == 0 {
if len(authModesWithLabels) == 0 {
return nil, fmt.Errorf("no authentication modes available for user %q", session.username)
}

Expand All @@ -274,10 +257,55 @@ func (b *Broker) GetAuthenticationModes(sessionID string, supportedUILayouts []m
return nil, err
}

return authModes, nil
return authModesWithLabels, nil
}

func (b *Broker) availableAuthModes(session session, tokenExists bool) (availableModes []string, err error) {
switch session.mode {
case sessionmode.ChangePassword, sessionmode.ChangePasswordOld:
// session is for changing the password
if !tokenExists {
return nil, errors.New("user has no cached token")
}
availableModes = []string{authmodes.Password}
if session.currentAuthStep > 0 {
availableModes = []string{authmodes.NewPassword}
}

default: // session is for login
if !session.isOffline {
availableModes = b.oidcAuthModes(session)
}
if tokenExists {
availableModes = append([]string{authmodes.Password}, availableModes...)
}
if session.currentAuthStep > 0 {
availableModes = []string{authmodes.NewPassword}
}
}
return availableModes, nil
}

func (b *Broker) oidcAuthModes(session session) []string {
var modes []string
endpoints := make(map[string]struct{})
if session.oidcServer != nil && session.oidcServer.Endpoint().DeviceAuthURL != "" {
endpoints[authmodes.DeviceQr] = struct{}{}
endpoints[authmodes.Device] = struct{}{}
}

for _, mode := range b.provider.SupportedOIDCAuthModes() {
if _, ok := endpoints[mode]; ok {
modes = append(modes, mode)
} else {
log.Warningf(context.Background(), "No provider endpoint for mode %q", mode)
}
}

return modes
}

func (b *Broker) supportedAuthModesFromLayout(supportedUILayouts []map[string]string) (supportedModes map[string]string) {
func (b *Broker) authModesSupportedByUI(supportedUILayouts []map[string]string) (supportedModes map[string]string) {
supportedModes = make(map[string]string)
for _, layout := range supportedUILayouts {
supportedEntries := strings.Split(strings.TrimPrefix(layout["entry"], "optional:"), ",")
Expand Down Expand Up @@ -322,10 +350,6 @@ func (b *Broker) SelectAuthenticationMode(sessionID, authModeID string) (uiLayou

// Store selected mode
session.selectedMode = authModeID
// Store the first one to use to update the lastSelectedMode in MFA cases.
if session.currentAuthStep == 0 {
session.firstSelectedMode = authModeID
}

if err = b.updateSession(sessionID, session); err != nil {
return nil, err
Expand Down
53 changes: 5 additions & 48 deletions internal/providers/msentraid/msentraid.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
msgraphauth "github.com/microsoftgraph/msgraph-sdk-go-core/authentication"
msgraphmodels "github.com/microsoftgraph/msgraph-sdk-go/models"
"github.com/ubuntu/authd-oidc-brokers/internal/broker/authmodes"
"github.com/ubuntu/authd-oidc-brokers/internal/broker/sessionmode"
"github.com/ubuntu/authd-oidc-brokers/internal/consts"
providerErrors "github.com/ubuntu/authd-oidc-brokers/internal/providers/errors"
"github.com/ubuntu/authd-oidc-brokers/internal/providers/info"
Expand Down Expand Up @@ -280,53 +279,6 @@ func isSecurityGroup(group msgraphmodels.Groupable) bool {
return !slices.Contains(group.GetGroupTypes(), "Unified")
}

// CurrentAuthenticationModesOffered returns the generic authentication modes supported by the provider.
//
// Token validity is not considered, only the presence of a token.
func (p Provider) CurrentAuthenticationModesOffered(
sessionMode string,
supportedAuthModes map[string]string,
tokenExists bool,
providerReachable bool,
endpoints map[string]struct{},
currentAuthStep int,
) ([]string, error) {
log.Debugf(context.Background(), "In CurrentAuthenticationModesOffered: sessionMode=%q, supportedAuthModes=%q, tokenExists=%t, providerReachable=%t, endpoints=%q, currentAuthStep=%d\n", sessionMode, supportedAuthModes, tokenExists, providerReachable, endpoints, currentAuthStep)
var offeredModes []string
switch sessionMode {
case sessionmode.ChangePassword, sessionmode.ChangePasswordOld:
if !tokenExists {
return nil, errors.New("user has no cached token")
}
offeredModes = []string{authmodes.Password}
if currentAuthStep > 0 {
offeredModes = []string{authmodes.NewPassword}
}

default: // auth mode
if _, ok := endpoints[authmodes.DeviceQr]; ok && providerReachable {
offeredModes = []string{authmodes.DeviceQr}
} else if _, ok := endpoints[authmodes.Device]; ok && providerReachable {
offeredModes = []string{authmodes.Device}
}
if tokenExists {
offeredModes = append([]string{authmodes.Password}, offeredModes...)
}
if currentAuthStep > 0 {
offeredModes = []string{authmodes.NewPassword}
}
}
log.Debugf(context.Background(), "Offered modes: %q", offeredModes)

for _, mode := range offeredModes {
if _, ok := supportedAuthModes[mode]; !ok {
return nil, fmt.Errorf("auth mode %q required by the provider, but is not supported locally", mode)
}
}

return offeredModes, nil
}

// NormalizeUsername parses a username into a normalized version.
func (p Provider) NormalizeUsername(username string) string {
// Microsoft Entra usernames are case-insensitive. We can safely use strings.ToLower here without worrying about
Expand All @@ -335,6 +287,11 @@ func (p Provider) NormalizeUsername(username string) string {
return strings.ToLower(username)
}

// SupportedOIDCAuthModes returns the OIDC authentication modes supported by the provider.
func (p Provider) SupportedOIDCAuthModes() []string {
return []string{authmodes.Device, authmodes.DeviceQr}
}

// VerifyUsername checks if the authenticated username matches the requested username and that both are valid.
func (p Provider) VerifyUsername(requestedUsername, authenticatedUsername string) error {
if p.NormalizeUsername(requestedUsername) != p.NormalizeUsername(authenticatedUsername) {
Expand Down
50 changes: 5 additions & 45 deletions internal/providers/noprovider/noprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ package noprovider

import (
"context"
"errors"
"fmt"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/ubuntu/authd-oidc-brokers/internal/broker/authmodes"
"github.com/ubuntu/authd-oidc-brokers/internal/broker/sessionmode"
"github.com/ubuntu/authd-oidc-brokers/internal/providers/info"
"golang.org/x/oauth2"
)
Expand Down Expand Up @@ -37,49 +35,6 @@ func (p NoProvider) AuthOptions() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{}
}

// CurrentAuthenticationModesOffered returns the generic authentication modes supported by the provider.
func (p NoProvider) CurrentAuthenticationModesOffered(
sessionMode string,
supportedAuthModes map[string]string,
tokenExists bool,
providerReachable bool,
endpoints map[string]struct{},
currentAuthStep int,
) ([]string, error) {
var offeredModes []string
switch sessionMode {
case sessionmode.ChangePassword, sessionmode.ChangePasswordOld:
if !tokenExists {
return nil, errors.New("user has no cached token")
}
offeredModes = []string{authmodes.Password}
if currentAuthStep > 0 {
offeredModes = []string{authmodes.NewPassword}
}

default: // auth mode
if _, ok := endpoints[authmodes.DeviceQr]; ok && providerReachable {
offeredModes = []string{authmodes.DeviceQr}
} else if _, ok := endpoints[authmodes.Device]; ok && providerReachable {
offeredModes = []string{authmodes.Device}
}
if tokenExists {
offeredModes = append([]string{authmodes.Password}, offeredModes...)
}
if currentAuthStep > 0 {
offeredModes = []string{authmodes.NewPassword}
}
}

for _, mode := range offeredModes {
if _, ok := supportedAuthModes[mode]; !ok {
return nil, fmt.Errorf("auth mode %q required by the provider, but is not supported locally", mode)
}
}

return offeredModes, nil
}

// GetExtraFields returns the extra fields of the token which should be stored persistently.
func (p NoProvider) GetExtraFields(token *oauth2.Token) map[string]interface{} {
return nil
Expand Down Expand Up @@ -125,6 +80,11 @@ func (p NoProvider) VerifyUsername(requestedUsername, username string) error {
return nil
}

// SupportedOIDCAuthModes returns the OIDC authentication modes supported by the provider.
func (p NoProvider) SupportedOIDCAuthModes() []string {
return []string{authmodes.Device, authmodes.DeviceQr}
}

type claims struct {
Email string `json:"email"`
Sub string `json:"sub"`
Expand Down
9 changes: 1 addition & 8 deletions internal/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,10 @@ type Provider interface {
AdditionalScopes() []string
AuthOptions() []oauth2.AuthCodeOption
CheckTokenScopes(token *oauth2.Token) error
CurrentAuthenticationModesOffered(
sessionMode string,
supportedAuthModes map[string]string,
tokenExists bool,
providerReachable bool,
endpoints map[string]struct{},
currentAuthStep int,
) ([]string, error)
GetExtraFields(token *oauth2.Token) map[string]interface{}
GetMetadata(provider *oidc.Provider) (map[string]interface{}, error)
GetUserInfo(ctx context.Context, accessToken *oauth2.Token, idToken *oidc.IDToken, providerMetadata map[string]interface{}) (info.User, error)
NormalizeUsername(username string) string
SupportedOIDCAuthModes() []string
VerifyUsername(requestedUsername, authenticatedUsername string) error
}
Loading