Skip to content
This repository has been archived by the owner on Dec 7, 2020. It is now read-only.

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
- cleaning up the logging line
- removing some verbose comments
  • Loading branch information
gambol99 committed Jan 27, 2017
1 parent 8429947 commit 8559052
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 70 deletions.
8 changes: 4 additions & 4 deletions cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func newOauthProxyApp() *cli.App {
app.Version = version
app.Author = author
app.Email = email
app.Flags = getCLIOptions()
app.Flags = getCommandLineOptions()
app.UsageText = "keycloak-proxy [options]"

// step: the standard usage message isn't that helpful
Expand Down Expand Up @@ -86,8 +86,9 @@ func newOauthProxyApp() *cli.App {
return app
}

// getCLIOptions returns the command line options
func getCLIOptions() []cli.Flag {
// getCommandLineOptions builds the command line options by reflecting the Config struct and extracting
// the tagged information
func getCommandLineOptions() []cli.Flag {
defaults := newDefaultConfig()
var flags []cli.Flag
count := reflect.TypeOf(Config{}).NumField()
Expand Down Expand Up @@ -149,7 +150,6 @@ func getCLIOptions() []cli.Flag {
}

// parseCLIOptions parses the command line options and constructs a config object
// @TODO look for a shorter way of doing this, we're maintaining the same options in multiple places, it's tedious!
func parseCLIOptions(cx *cli.Context, config *Config) (err error) {
// step: we can ignore these options in the Config struct
ignoredOptions := []string{"tag-data", "match-claims", "resources", "headers"}
Expand Down
4 changes: 2 additions & 2 deletions cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ import (
)

func TestGetCLIOptions(t *testing.T) {
if flags := getCLIOptions(); flags == nil {
if flags := getCommandLineOptions(); flags == nil {
t.Error("we should have received some flags options")
}
}

func TestReadOptions(t *testing.T) {
c := cli.NewApp()
c.Flags = getCLIOptions()
c.Flags = getCommandLineOptions()
c.Action = func(cx *cli.Context) error {
parseCLIOptions(cx, &Config{})
return nil
Expand Down
24 changes: 6 additions & 18 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,7 @@ func (r *oauthProxy) oauthCallbackHandler(cx *gin.Context) {
// step: exchange the authorization for a access token
response, err := exchangeAuthenticationCode(client, code)
if err != nil {
log.WithFields(log.Fields{
"error": err.Error(),
}).Errorf("unable to exchange code for access token")
log.WithFields(log.Fields{"error": err.Error()}).Errorf("unable to exchange code for access token")

r.accessForbidden(cx)
return
Expand All @@ -138,19 +136,15 @@ func (r *oauthProxy) oauthCallbackHandler(cx *gin.Context) {
// step: parse decode the identity token
session, identity, err := parseToken(response.IDToken)
if err != nil {
log.WithFields(log.Fields{
"error": err.Error(),
}).Errorf("unable to parse id token for identity")
log.WithFields(log.Fields{"error": err.Error()}).Errorf("unable to parse id token for identity")

r.accessForbidden(cx)
return
}

// step: verify the token is valid
if err = verifyToken(r.client, session); err != nil {
log.WithFields(log.Fields{
"error": err.Error(),
}).Errorf("unable to verify the id token")
log.WithFields(log.Fields{"error": err.Error()}).Errorf("unable to verify the id token")

r.accessForbidden(cx)
return
Expand All @@ -159,9 +153,7 @@ func (r *oauthProxy) oauthCallbackHandler(cx *gin.Context) {
// step: attempt to decode the access token else we default to the id token
accessToken, id, err := parseToken(response.AccessToken)
if err != nil {
log.WithFields(log.Fields{
"error": err.Error(),
}).Errorf("unable to parse the access token, using id token only")
log.WithFields(log.Fields{"error": err.Error()}).Errorf("unable to parse the access token, using id token only")
} else {
session = accessToken
identity = id
Expand All @@ -182,9 +174,7 @@ func (r *oauthProxy) oauthCallbackHandler(cx *gin.Context) {
// step: encrypt the refresh token
encrypted, err := encodeText(response.RefreshToken, r.config.EncryptionKey)
if err != nil {
log.WithFields(log.Fields{
"error": err.Error(),
}).Errorf("failed to encrypt the refresh token")
log.WithFields(log.Fields{"error": err.Error()}).Errorf("failed to encrypt the refresh token")

cx.AbortWithStatus(http.StatusInternalServerError)
return
Expand All @@ -194,9 +184,7 @@ func (r *oauthProxy) oauthCallbackHandler(cx *gin.Context) {
switch r.useStore() {
case true:
if err := r.StoreRefreshToken(session, encrypted); err != nil {
log.WithFields(log.Fields{
"error": err.Error(),
}).Warnf("failed to save the refresh token in the store")
log.WithFields(log.Fields{"error": err.Error()}).Warnf("failed to save the refresh token in the store")
}
// step: get expiration of the refresh token if we can
_, ident, err := parseToken(response.RefreshToken)
Expand Down
6 changes: 0 additions & 6 deletions misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ import (
"github.com/gin-gonic/gin"
)

//
// accessForbidden redirects the user to the forbidden page
//
func (r *oauthProxy) accessForbidden(cx *gin.Context) {
if r.config.hasCustomForbiddenPage() {
cx.HTML(http.StatusForbidden, path.Base(r.config.ForbiddenPage), r.config.Tags)
Expand All @@ -38,17 +36,13 @@ func (r *oauthProxy) accessForbidden(cx *gin.Context) {
cx.AbortWithStatus(http.StatusForbidden)
}

//
// redirectToURL redirects the user and aborts the context
//
func (r *oauthProxy) redirectToURL(url string, cx *gin.Context) {
cx.Redirect(http.StatusTemporaryRedirect, url)
cx.Abort()
}

//
// redirectToAuthorization redirects the user to authorization handler
//
func (r *oauthProxy) redirectToAuthorization(cx *gin.Context) {
if r.config.NoRedirects {
cx.AbortWithStatus(http.StatusUnauthorized)
Expand Down
2 changes: 0 additions & 2 deletions store_boltdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ var (
ErrNoBoltdbBucket = errors.New("the boltdb bucket does not exists")
)

//
// A local file store used to hold the refresh tokens
//
type boltdbStore struct {
client *bolt.DB
}
Expand Down
38 changes: 0 additions & 38 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,7 @@ func decodeKeyPairs(list []string) (map[string]string, error) {
return kp, nil
}

//
// isValidHTTPMethod ensure this is a valid http method type
//
func isValidHTTPMethod(method string) bool {
return httpMethodRegex.MatchString(method)
}
Expand All @@ -234,9 +232,7 @@ func defaultTo(v, d string) string {
return d
}

//
// cloneTLSConfig clones the tls configuration
//
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
Expand Down Expand Up @@ -264,9 +260,7 @@ func cloneTLSConfig(cfg *tls.Config) *tls.Config {
}
}

//
// fileExists check if a file exists
//
func fileExists(filename string) bool {
if _, err := os.Stat(filename); err != nil {
if os.IsNotExist(err) {
Expand All @@ -277,9 +271,7 @@ func fileExists(filename string) bool {
return true
}

//
// hasRoles checks the scopes are the same
//
func hasRoles(required, issued []string) bool {
for _, role := range required {
if !containedIn(role, issued) {
Expand All @@ -290,9 +282,7 @@ func hasRoles(required, issued []string) bool {
return true
}

//
// containedIn checks if a value in a list of a strings
//
func containedIn(value string, list []string) bool {
for _, x := range list {
if x == value {
Expand All @@ -303,9 +293,7 @@ func containedIn(value string, list []string) bool {
return false
}

//
// containsSubString checks if substring exists
//
func containsSubString(value string, list []string) bool {
for _, x := range list {
if strings.Contains(value, x) {
Expand All @@ -316,9 +304,7 @@ func containsSubString(value string, list []string) bool {
return false
}

//
// tryDialEndpoint dials the upstream endpoint via plain
//
func tryDialEndpoint(location *url.URL) (net.Conn, error) {
switch dialAddress := dialAddress(location); location.Scheme {
case httpSchema:
Expand All @@ -331,9 +317,7 @@ func tryDialEndpoint(location *url.URL) (net.Conn, error) {
}
}

//
// isUpgradedConnection checks to see if the request is requesting
//
func isUpgradedConnection(req *http.Request) bool {
if req.Header.Get(headerUpgrade) != "" {
return true
Expand All @@ -342,9 +326,7 @@ func isUpgradedConnection(req *http.Request) bool {
return false
}

//
// transferBytes transfers bytes between the sink and source
//
func transferBytes(src io.Reader, dest io.Writer, wg *sync.WaitGroup) (int64, error) {
defer wg.Done()
copied, err := io.Copy(dest, src)
Expand All @@ -355,9 +337,7 @@ func transferBytes(src io.Reader, dest io.Writer, wg *sync.WaitGroup) (int64, er
return copied, nil
}

//
// tryUpdateConnection attempt to upgrade the connection to a http pdy stream
//
func tryUpdateConnection(cx *gin.Context, endpoint *url.URL) error {
// step: dial the endpoint
tlsConn, err := tryDialEndpoint(endpoint)
Expand Down Expand Up @@ -388,9 +368,7 @@ func tryUpdateConnection(cx *gin.Context, endpoint *url.URL) error {
return nil
}

//
// dialAddress extracts the dial address from the url
//
func dialAddress(location *url.URL) string {
items := strings.Split(location.Host, ":")
if len(items) != 2 {
Expand All @@ -405,9 +383,7 @@ func dialAddress(location *url.URL) string {
return location.Host
}

//
// findCookie looks for a cookie in a list of cookies
//
func findCookie(name string, cookies []*http.Cookie) *http.Cookie {
for _, cookie := range cookies {
if cookie.Name == name {
Expand All @@ -418,9 +394,7 @@ func findCookie(name string, cookies []*http.Cookie) *http.Cookie {
return nil
}

//
// toHeader is a helper method to play nice in the headers
//
func toHeader(v string) string {
var list []string

Expand All @@ -432,9 +406,7 @@ func toHeader(v string) string {
return strings.Join(list, "-")
}

//
// capitalize capitalizes the first letter of a word
//
func capitalize(s string) string {
if s == "" {
return ""
Expand All @@ -444,9 +416,7 @@ func capitalize(s string) string {
return string(unicode.ToUpper(r)) + s[n:]
}

//
// mergeMaps simples copies the keys from source to destination
//
func mergeMaps(dest, source map[string]string) map[string]string {
for k, v := range source {
dest[k] = v
Expand All @@ -455,9 +425,7 @@ func mergeMaps(dest, source map[string]string) map[string]string {
return dest
}

//
// loadCA loads the certificate authority
//
func loadCA(cert, key string) (*tls.Certificate, error) {
caCert, err := ioutil.ReadFile(cert)
if err != nil {
Expand All @@ -479,26 +447,20 @@ func loadCA(cert, key string) (*tls.Certificate, error) {
return &ca, err
}

//
// getWithin calculates a duration of x percent of the time period, i.e. something
// expires in 1 hours, get me a duration within 80%
//
func getWithin(expires time.Time, in float64) time.Duration {
seconds := int(float64(expires.Sub(time.Now()).Seconds()) * in)
return time.Duration(seconds) * time.Second
}

//
// getHashKey returns a hash of the encodes jwt token
//
func getHashKey(token *jose.JWT) string {
hash := md5.Sum([]byte(token.Encode()))
return hex.EncodeToString(hash[:])
}

//
// printError display the command line usage and error
//
func printError(message string, args ...interface{}) *cli.ExitError {
return cli.NewExitError(fmt.Sprintf("[error] "+message, args...), 1)
}

0 comments on commit 8559052

Please sign in to comment.