From f74266f8f83847cbe3057cbf983ada4699d3d5ac Mon Sep 17 00:00:00 2001 From: Juan Font Date: Mon, 14 Nov 2022 14:05:47 +0000 Subject: [PATCH] OIDC code cleanup and harmonize with regular web auth --- oidc.go | 91 +++++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 63 insertions(+), 28 deletions(-) diff --git a/oidc.go b/oidc.go index 853345a6f2..443827960e 100644 --- a/oidc.go +++ b/oidc.go @@ -76,20 +76,52 @@ func (h *Headscale) RegisterOIDC( ) { vars := mux.Vars(req) nodeKeyStr, ok := vars["nkey"] - if !ok || nodeKeyStr == "" { - log.Error(). - Caller(). - Msg("Missing node key in URL") - http.Error(writer, "Missing node key in URL", http.StatusBadRequest) - - return - } log.Trace(). Caller(). Str("node_key", nodeKeyStr). Msg("Received oidc register call") + if !NodePublicKeyRegex.Match([]byte(nodeKeyStr)) { + log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url") + + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusUnauthorized) + _, err := writer.Write([]byte("Unauthorized")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } + + return + } + + // We need to make sure we dont open for XSS style injections, if the parameter that + // is passed as a key is not parsable/validated as a NodePublic key, then fail to render + // the template and log an error. + var nodeKey key.NodePublic + err := nodeKey.UnmarshalText( + []byte(NodePublicKeyEnsurePrefix(nodeKeyStr)), + ) + + if !ok || nodeKeyStr == "" || err != nil { + log.Warn().Err(err).Msg("Failed to parse incoming nodekey") + + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusBadRequest) + _, err := writer.Write([]byte("Wrong params")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } + + return + } + randomBlob := make([]byte, randomByteSize) if _, err := rand.Read(randomBlob); err != nil { log.Error(). @@ -103,7 +135,7 @@ func (h *Headscale) RegisterOIDC( stateStr := hex.EncodeToString(randomBlob)[:32] // place the node key into the state cache, so it can be retrieved later - h.registrationCache.Set(stateStr, nodeKeyStr, registerCacheExpiration) + h.registrationCache.Set(stateStr, NodePublicKeyStripPrefix(nodeKey), registerCacheExpiration) // Add any extra parameter provided in the configuration to the Authorize Endpoint request extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams)) @@ -405,8 +437,8 @@ func (h *Headscale) validateMachineForOIDCCallback( claims *IDTokenClaims, ) (*key.NodePublic, bool, error) { // retrieve machinekey from state cache - machineKeyIf, machineKeyFound := h.registrationCache.Get(state) - if !machineKeyFound { + nodeKeyIf, nodeKeyFound := h.registrationCache.Get(state) + if !nodeKeyFound { log.Error(). Msg("requested machine state key expired before authorisation completed") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") @@ -423,39 +455,42 @@ func (h *Headscale) validateMachineForOIDCCallback( } var nodeKey key.NodePublic - nodeKeyFromCache, nodeKeyOK := machineKeyIf.(string) - err := nodeKey.UnmarshalText( - []byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)), - ) - if err != nil { + nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string) + if !nodeKeyOK { log.Error(). - Msg("could not parse node public key") + Msg("requested machine state key is not a string") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) - _, werr := writer.Write([]byte("could not parse public key")) - if werr != nil { + _, err := writer.Write([]byte("state is invalid")) + if err != nil { log.Error(). Caller(). - Err(werr). + Err(err). Msg("Failed to write response") } - return nil, false, err + return nil, false, errOIDCInvalidMachineState } - if !nodeKeyOK { - log.Error().Msg("could not get node key from cache") + err := nodeKey.UnmarshalText( + []byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)), + ) + if err != nil { + log.Error(). + Str("nodeKey", nodeKeyFromCache). + Bool("nodeKeyOK", nodeKeyOK). + Msg("could not parse node public key") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusInternalServerError) - _, err := writer.Write([]byte("could not get node key from cache")) - if err != nil { + writer.WriteHeader(http.StatusBadRequest) + _, werr := writer.Write([]byte("could not parse node public key")) + if werr != nil { log.Error(). Caller(). - Err(err). + Err(werr). Msg("Failed to write response") } - return nil, false, errOIDCNodeKeyMissing + return nil, false, err } // retrieve machine information if it exist