-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fb1a1cf
commit c13bb15
Showing
11 changed files
with
562 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
package token | ||
|
||
import ( | ||
"encoding/base64" | ||
"encoding/json" | ||
"io" | ||
"log" | ||
"net/http" | ||
"net/url" | ||
"strings" | ||
"time" | ||
) | ||
|
||
type CSPAuthorizeResponse struct { | ||
IdToken string `json:"id_token"` | ||
TokenType string `json:"token_type"` | ||
ExpiresIn int `json:"expires_in"` | ||
Scope string `json:"scope"` | ||
AccessToken string `json:"access_token"` | ||
RefreshToken string `json:"refresh_token"` | ||
} | ||
|
||
func (t *CspServerToServerTokenService) GetToken() string { | ||
|
||
t.mutex.Lock() | ||
defer t.mutex.Unlock() | ||
|
||
if !t.tokenReady { | ||
cspAccessToken, err := t.CallCSP() | ||
|
||
if err != nil { | ||
return "INVALID_TOKEN" | ||
} | ||
|
||
t.AccessToken = cspAccessToken | ||
t.tokenReady = true | ||
} | ||
|
||
return t.AccessToken | ||
} | ||
|
||
func (t *CspServerToServerTokenService) CallCSP() (string, error) { | ||
var oauthPath = "/csp/gateway/am/api/auth/authorize" | ||
client := &http.Client{} | ||
|
||
req, err := http.NewRequest("POST", t.CSPBaseUrl+oauthPath, strings.NewReader(url.Values{"grant_type": {"client_credentials"}}.Encode())) | ||
|
||
if err != nil { | ||
return "", err | ||
} | ||
|
||
req.Header.Add("Authorization", encodeCSPCredentials(t.CSPClientId, t.CSPClientSecret)) | ||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded") | ||
|
||
resp, err := client.Do(req) | ||
|
||
if err != nil { | ||
return "", err | ||
} | ||
|
||
defer resp.Body.Close() | ||
|
||
body, err := io.ReadAll(resp.Body) | ||
|
||
var cspResponse CSPAuthorizeResponse | ||
|
||
err = json.Unmarshal(body, &cspResponse) | ||
|
||
if !hasDirectIngestScope(cspResponse.Scope) { | ||
log.Println("The CSP response did not find any scope matching 'aoa:directDataIngestion' which is required for Wavefront direct ingestion.") | ||
} | ||
|
||
if err != nil { | ||
return "", err | ||
} | ||
|
||
t.startOrResetTicker(cspResponse) | ||
|
||
log.Println("ACCESS TOKEN: " + cspResponse.AccessToken) | ||
return cspResponse.AccessToken, nil | ||
} | ||
|
||
func hasDirectIngestScope(scope string) bool { | ||
if len(scope) == 0 { | ||
return false | ||
} | ||
|
||
for _, s := range strings.Split(scope, " ") { | ||
if strings.Contains(s, "aoa:directDataIngestion") || strings.Contains(s, "aoa/*") || strings.Contains(s, "aoa:*") { | ||
return true | ||
} | ||
} | ||
|
||
return false | ||
} | ||
|
||
func (t *CspServerToServerTokenService) startOrResetTicker(cspResponse CSPAuthorizeResponse) { | ||
tickerDelay := t.getTickerDelay(cspResponse.ExpiresIn) | ||
|
||
if t.ticker == nil { | ||
t.ticker = time.NewTicker(time.Duration(tickerDelay) * time.Second) | ||
t.done = make(chan bool) | ||
|
||
// Goroutine | ||
go func() { | ||
log.Printf("Scheduling a goroutine to fetch fresh CSP credentials in %d seconds\n", tickerDelay) | ||
for { | ||
select { | ||
case <-t.done: | ||
return | ||
case tick := <-t.ticker.C: | ||
t.mutex.Lock() | ||
|
||
log.Println("Re-fetching CSP credentials at", tick) | ||
cspAccessToken, e := t.CallCSP() | ||
|
||
if e != nil { | ||
t.AccessToken = "INVALID_TOKEN" | ||
} else { | ||
t.AccessToken = cspAccessToken | ||
} | ||
|
||
t.mutex.Unlock() | ||
} | ||
} | ||
}() | ||
} else { | ||
t.ticker.Reset(time.Duration(tickerDelay) * time.Second) | ||
} | ||
} | ||
|
||
func (t *CspServerToServerTokenService) getTickerDelay(expiresIn int) int { | ||
retVal := 0 | ||
|
||
if expiresIn < 600 { | ||
retVal = expiresIn - 30 | ||
} else { | ||
retVal = expiresIn - 180 | ||
} | ||
|
||
if retVal <= 0 { | ||
return t.tickerDelay | ||
} | ||
|
||
return retVal | ||
} | ||
|
||
func encodeCSPCredentials(CSPClientId string, CSPClientSecret string) string { | ||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(CSPClientId+":"+CSPClientSecret)) | ||
} | ||
|
||
// NewCSPServerToServerTokenService returns a TokenService instance where it will call CSP with client credentials to return an access token | ||
func NewCSPServerToServerTokenService(CSPBaseUrl string, CSPClientId string, CSPClientSecret string) TokenService { | ||
return &CspServerToServerTokenService{CSPBaseUrl: CSPBaseUrl, CSPClientId: CSPClientId, CSPClientSecret: CSPClientSecret, tickerDelay: 60} | ||
} | ||
|
||
func (t *CspServerToServerTokenService) Close() { | ||
log.Println("Shutting down the CspServerToServerTokenService") | ||
t.done <- true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
package token | ||
|
||
import ( | ||
"encoding/base64" | ||
"encoding/json" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/wavefronthq/wavefront-sdk-go/test" | ||
"net/http" | ||
"net/http/httptest" | ||
"strings" | ||
"testing" | ||
"time" | ||
) | ||
|
||
func TestMultipleCSPRequests(t *testing.T) { | ||
directServer := httptest.NewServer(directHandler()) | ||
defer directServer.Close() | ||
port := test.ExtractPort(directServer.URL) | ||
|
||
tokenService := NewCSPServerToServerTokenService("http://localhost:"+port, "a", "b") | ||
|
||
cspTokenService := tokenService.(*CspServerToServerTokenService) | ||
cspTokenService.tickerDelay = 1 | ||
|
||
assert.NotNil(t, tokenService) | ||
token := tokenService.GetToken() | ||
assert.NotNil(t, token) | ||
assert.NotEmpty(t, token) | ||
assert.NotEqual(t, "INVALID_TOKEN", token) | ||
assert.Equal(t, "abc", token) | ||
|
||
time.Sleep(2 * time.Second) | ||
token = tokenService.GetToken() | ||
|
||
assert.NotNil(t, token) | ||
assert.NotEmpty(t, token) | ||
assert.NotEqual(t, "INVALID_TOKEN", token) | ||
assert.Equal(t, "def", token) | ||
|
||
time.Sleep(10 * time.Millisecond) | ||
tokenService.Close() | ||
} | ||
|
||
func TestGetDelay(t *testing.T) { | ||
cspStruct := CspServerToServerTokenService{ | ||
tickerDelay: 1, | ||
} | ||
|
||
assert.Equal(t, 999999820, cspStruct.getTickerDelay(1000000000)) | ||
assert.Equal(t, 569, cspStruct.getTickerDelay(599)) | ||
assert.Equal(t, 1, cspStruct.getTickerDelay(3)) | ||
assert.Equal(t, 1, cspStruct.getTickerDelay(1)) | ||
assert.Equal(t, 1, cspStruct.getTickerDelay(0)) | ||
assert.Equal(t, 1, cspStruct.getTickerDelay(-180)) | ||
|
||
cspStruct = CspServerToServerTokenService{ | ||
tickerDelay: 60, | ||
} | ||
|
||
assert.Equal(t, 999999820, cspStruct.getTickerDelay(1000000000)) | ||
assert.Equal(t, 569, cspStruct.getTickerDelay(599)) | ||
assert.Equal(t, 60, cspStruct.getTickerDelay(3)) | ||
assert.Equal(t, 60, cspStruct.getTickerDelay(1)) | ||
assert.Equal(t, 60, cspStruct.getTickerDelay(0)) | ||
assert.Equal(t, 60, cspStruct.getTickerDelay(-180)) | ||
} | ||
|
||
func TestDirectIngestScopes(t *testing.T) { | ||
assert.False(t, hasDirectIngestScope("")) | ||
assert.False(t, hasDirectIngestScope("no direct ingest scopes")) | ||
|
||
var scopeString = "external/51d98d2c-3ae1-11ee-be56-0242ac120002/*/aoa:directDataIngestion external/51d98d2c-3ae1-11ee-be56-0242ac120002/aoa:directDataIngestion csp:org_member" | ||
|
||
assert.True(t, hasDirectIngestScope(scopeString)) | ||
assert.True(t, hasDirectIngestScope("some aoa:*")) | ||
assert.True(t, hasDirectIngestScope("some aoa/*")) | ||
} | ||
|
||
func directHandler() http.Handler { | ||
basicAuthCredentials := "Basic " + base64.StdEncoding.EncodeToString([]byte("a:b")) | ||
firstRun := false | ||
|
||
mux := http.NewServeMux() | ||
mux.HandleFunc("/csp/gateway/am/api/auth/authorize", func(w http.ResponseWriter, r *http.Request) { | ||
test.ReadBodyIntoString(r) | ||
if strings.HasSuffix(r.Header.Get("Authorization"), basicAuthCredentials) { | ||
var sup CSPAuthorizeResponse | ||
|
||
if !firstRun { | ||
sup = CSPAuthorizeResponse{ | ||
ExpiresIn: 1, | ||
AccessToken: "abc", | ||
} | ||
firstRun = true | ||
} else { | ||
sup = CSPAuthorizeResponse{ | ||
ExpiresIn: 1, | ||
AccessToken: "def", | ||
} | ||
} | ||
|
||
w.WriteHeader(http.StatusOK) | ||
marshal, _ := json.Marshal(sup) | ||
w.Write(marshal) | ||
return | ||
} | ||
w.WriteHeader(http.StatusUnauthorized) | ||
}) | ||
return mux | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
package token | ||
|
||
var ( | ||
defaultNoopService TokenService = &TokenNoOpService{} | ||
) | ||
|
||
func (t TokenNoOpService) GetToken() string { | ||
return "" | ||
} | ||
|
||
func (t TokenNoOpService) Close() { | ||
} | ||
|
||
// NewNoopTokenService returns a TokenService instance where it always returns an empty string for the token (for proxy usage). | ||
func NewNoopTokenService() TokenService { | ||
return defaultNoopService | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
package token | ||
|
||
func (t WavefrontTokenService) GetToken() string { | ||
return t.Token | ||
} | ||
|
||
func (t WavefrontTokenService) Close() { | ||
} | ||
|
||
// NewWavefrontTokenService returns a TokenService instance where it always returns a Wavefront API Token | ||
func NewWavefrontTokenService(Token string) TokenService { | ||
return &WavefrontTokenService{Token: Token} | ||
} |
Oops, something went wrong.