Skip to content

Commit

Permalink
feat: authentication using CSP
Browse files Browse the repository at this point in the history
  • Loading branch information
ernst-riemer authored and LukeWinikates committed Aug 18, 2023
1 parent fb1a1cf commit c13bb15
Show file tree
Hide file tree
Showing 11 changed files with 562 additions and 53 deletions.
29 changes: 18 additions & 11 deletions internal/reporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"compress/gzip"
"crypto/tls"
"github.com/wavefronthq/wavefront-sdk-go/internal/token"
"io"
"io/ioutil"
"net/http"
Expand All @@ -13,17 +14,17 @@ import (

// The implementation of a Reporter that reports points directly to a Wavefront server.
type reporter struct {
serverURL string
token string
client *http.Client
serverURL string
tokenService token.TokenService
client *http.Client
}

// NewReporter creates a metrics Reporter
func NewReporter(server string, token string, client *http.Client) Reporter {
func NewReporter(server string, tokenService token.TokenService, client *http.Client) Reporter {
return &reporter{
serverURL: server,
token: token,
client: client,
serverURL: server,
tokenService: tokenService,
client: client,
}
}

Expand Down Expand Up @@ -77,8 +78,9 @@ func (reporter reporter) buildRequest(format string, body []byte) (*http.Request

req.Header.Set(contentType, octetStream)
req.Header.Set(contentEncoding, gzipFormat)
if len(reporter.token) > 0 {
req.Header.Set(authzHeader, bearer+reporter.token)

if len(reporter.tokenService.GetToken()) > 0 {
req.Header.Set(authzHeader, bearer+reporter.tokenService.GetToken())
}

q := req.URL.Query()
Expand All @@ -99,9 +101,10 @@ func (reporter reporter) ReportEvent(event string) (*http.Response, error) {
}

req.Header.Set(contentType, applicationJSON)
if len(reporter.token) > 0 {

if len(reporter.tokenService.GetToken()) > 0 {
req.Header.Set(contentEncoding, gzipFormat)
req.Header.Set(authzHeader, bearer+reporter.token)
req.Header.Set(authzHeader, bearer+reporter.tokenService.GetToken())
}

return reporter.execute(req)
Expand All @@ -116,3 +119,7 @@ func (reporter reporter) execute(req *http.Request) (*http.Response, error) {
defer resp.Body.Close()
return resp, nil
}

func (reporter reporter) Close() {
reporter.tokenService.Close()
}
3 changes: 2 additions & 1 deletion internal/reporter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ import (
"crypto/x509"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/wavefronthq/wavefront-sdk-go/internal/token"
"net/http"
"testing"
"time"
)

func TestBuildRequest(t *testing.T) {
var r *reporter
r = NewReporter("http://localhost:8010/wavefront", "", &http.Client{}).(*reporter)
r = NewReporter("http://localhost:8010/wavefront", token.NewNoopTokenService(), &http.Client{}).(*reporter)
request, err := r.buildRequest("wavefront", nil)
require.NoError(t, err)
assert.Equal(t, "http://localhost:8010/wavefront/report?f=wavefront", request.URL.String())
Expand Down
160 changes: 160 additions & 0 deletions internal/token/token_csp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package token

import (
"encoding/base64"
"encoding/json"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
)

type CSPAuthorizeResponse struct {
IdToken string `json:"id_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}

func (t *CspServerToServerTokenService) GetToken() string {

t.mutex.Lock()
defer t.mutex.Unlock()

if !t.tokenReady {
cspAccessToken, err := t.CallCSP()

if err != nil {
return "INVALID_TOKEN"
}

t.AccessToken = cspAccessToken
t.tokenReady = true
}

return t.AccessToken
}

func (t *CspServerToServerTokenService) CallCSP() (string, error) {
var oauthPath = "/csp/gateway/am/api/auth/authorize"
client := &http.Client{}

req, err := http.NewRequest("POST", t.CSPBaseUrl+oauthPath, strings.NewReader(url.Values{"grant_type": {"client_credentials"}}.Encode()))

if err != nil {
return "", err
}

req.Header.Add("Authorization", encodeCSPCredentials(t.CSPClientId, t.CSPClientSecret))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")

resp, err := client.Do(req)

if err != nil {
return "", err
}

defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)

var cspResponse CSPAuthorizeResponse

err = json.Unmarshal(body, &cspResponse)

if !hasDirectIngestScope(cspResponse.Scope) {
log.Println("The CSP response did not find any scope matching 'aoa:directDataIngestion' which is required for Wavefront direct ingestion.")
}

if err != nil {
return "", err
}

t.startOrResetTicker(cspResponse)

log.Println("ACCESS TOKEN: " + cspResponse.AccessToken)
return cspResponse.AccessToken, nil
}

func hasDirectIngestScope(scope string) bool {
if len(scope) == 0 {
return false
}

for _, s := range strings.Split(scope, " ") {
if strings.Contains(s, "aoa:directDataIngestion") || strings.Contains(s, "aoa/*") || strings.Contains(s, "aoa:*") {
return true
}
}

return false
}

func (t *CspServerToServerTokenService) startOrResetTicker(cspResponse CSPAuthorizeResponse) {
tickerDelay := t.getTickerDelay(cspResponse.ExpiresIn)

if t.ticker == nil {
t.ticker = time.NewTicker(time.Duration(tickerDelay) * time.Second)
t.done = make(chan bool)

// Goroutine
go func() {
log.Printf("Scheduling a goroutine to fetch fresh CSP credentials in %d seconds\n", tickerDelay)
for {
select {
case <-t.done:
return
case tick := <-t.ticker.C:
t.mutex.Lock()

log.Println("Re-fetching CSP credentials at", tick)
cspAccessToken, e := t.CallCSP()

if e != nil {
t.AccessToken = "INVALID_TOKEN"
} else {
t.AccessToken = cspAccessToken
}

t.mutex.Unlock()
}
}
}()
} else {
t.ticker.Reset(time.Duration(tickerDelay) * time.Second)
}
}

func (t *CspServerToServerTokenService) getTickerDelay(expiresIn int) int {
retVal := 0

if expiresIn < 600 {
retVal = expiresIn - 30
} else {
retVal = expiresIn - 180
}

if retVal <= 0 {
return t.tickerDelay
}

return retVal
}

func encodeCSPCredentials(CSPClientId string, CSPClientSecret string) string {
return "Basic " + base64.StdEncoding.EncodeToString([]byte(CSPClientId+":"+CSPClientSecret))
}

// NewCSPServerToServerTokenService returns a TokenService instance where it will call CSP with client credentials to return an access token
func NewCSPServerToServerTokenService(CSPBaseUrl string, CSPClientId string, CSPClientSecret string) TokenService {
return &CspServerToServerTokenService{CSPBaseUrl: CSPBaseUrl, CSPClientId: CSPClientId, CSPClientSecret: CSPClientSecret, tickerDelay: 60}
}

func (t *CspServerToServerTokenService) Close() {
log.Println("Shutting down the CspServerToServerTokenService")
t.done <- true
}
110 changes: 110 additions & 0 deletions internal/token/token_csp_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package token

import (
"encoding/base64"
"encoding/json"
"github.com/stretchr/testify/assert"
"github.com/wavefronthq/wavefront-sdk-go/test"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)

func TestMultipleCSPRequests(t *testing.T) {
directServer := httptest.NewServer(directHandler())
defer directServer.Close()
port := test.ExtractPort(directServer.URL)

tokenService := NewCSPServerToServerTokenService("http://localhost:"+port, "a", "b")

cspTokenService := tokenService.(*CspServerToServerTokenService)
cspTokenService.tickerDelay = 1

assert.NotNil(t, tokenService)
token := tokenService.GetToken()
assert.NotNil(t, token)
assert.NotEmpty(t, token)
assert.NotEqual(t, "INVALID_TOKEN", token)
assert.Equal(t, "abc", token)

time.Sleep(2 * time.Second)
token = tokenService.GetToken()

assert.NotNil(t, token)
assert.NotEmpty(t, token)
assert.NotEqual(t, "INVALID_TOKEN", token)
assert.Equal(t, "def", token)

time.Sleep(10 * time.Millisecond)
tokenService.Close()
}

func TestGetDelay(t *testing.T) {
cspStruct := CspServerToServerTokenService{
tickerDelay: 1,
}

assert.Equal(t, 999999820, cspStruct.getTickerDelay(1000000000))
assert.Equal(t, 569, cspStruct.getTickerDelay(599))
assert.Equal(t, 1, cspStruct.getTickerDelay(3))
assert.Equal(t, 1, cspStruct.getTickerDelay(1))
assert.Equal(t, 1, cspStruct.getTickerDelay(0))
assert.Equal(t, 1, cspStruct.getTickerDelay(-180))

cspStruct = CspServerToServerTokenService{
tickerDelay: 60,
}

assert.Equal(t, 999999820, cspStruct.getTickerDelay(1000000000))
assert.Equal(t, 569, cspStruct.getTickerDelay(599))
assert.Equal(t, 60, cspStruct.getTickerDelay(3))
assert.Equal(t, 60, cspStruct.getTickerDelay(1))
assert.Equal(t, 60, cspStruct.getTickerDelay(0))
assert.Equal(t, 60, cspStruct.getTickerDelay(-180))
}

func TestDirectIngestScopes(t *testing.T) {
assert.False(t, hasDirectIngestScope(""))
assert.False(t, hasDirectIngestScope("no direct ingest scopes"))

var scopeString = "external/51d98d2c-3ae1-11ee-be56-0242ac120002/*/aoa:directDataIngestion external/51d98d2c-3ae1-11ee-be56-0242ac120002/aoa:directDataIngestion csp:org_member"

assert.True(t, hasDirectIngestScope(scopeString))
assert.True(t, hasDirectIngestScope("some aoa:*"))
assert.True(t, hasDirectIngestScope("some aoa/*"))
}

func directHandler() http.Handler {
basicAuthCredentials := "Basic " + base64.StdEncoding.EncodeToString([]byte("a:b"))
firstRun := false

mux := http.NewServeMux()
mux.HandleFunc("/csp/gateway/am/api/auth/authorize", func(w http.ResponseWriter, r *http.Request) {
test.ReadBodyIntoString(r)
if strings.HasSuffix(r.Header.Get("Authorization"), basicAuthCredentials) {
var sup CSPAuthorizeResponse

if !firstRun {
sup = CSPAuthorizeResponse{
ExpiresIn: 1,
AccessToken: "abc",
}
firstRun = true
} else {
sup = CSPAuthorizeResponse{
ExpiresIn: 1,
AccessToken: "def",
}
}

w.WriteHeader(http.StatusOK)
marshal, _ := json.Marshal(sup)
w.Write(marshal)
return
}
w.WriteHeader(http.StatusUnauthorized)
})
return mux
}
17 changes: 17 additions & 0 deletions internal/token/token_noop.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package token

var (
defaultNoopService TokenService = &TokenNoOpService{}
)

func (t TokenNoOpService) GetToken() string {
return ""
}

func (t TokenNoOpService) Close() {
}

// NewNoopTokenService returns a TokenService instance where it always returns an empty string for the token (for proxy usage).
func NewNoopTokenService() TokenService {
return defaultNoopService
}
13 changes: 13 additions & 0 deletions internal/token/token_wavefront.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package token

func (t WavefrontTokenService) GetToken() string {
return t.Token
}

func (t WavefrontTokenService) Close() {
}

// NewWavefrontTokenService returns a TokenService instance where it always returns a Wavefront API Token
func NewWavefrontTokenService(Token string) TokenService {
return &WavefrontTokenService{Token: Token}
}
Loading

0 comments on commit c13bb15

Please sign in to comment.