Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
LukeWinikates committed Aug 23, 2023
1 parent 2bb3cf8 commit 66bfa91
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 44 deletions.
7 changes: 3 additions & 4 deletions internal/reporter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,20 @@ import (
"time"
)

func TestBuildRequest(t *testing.T) {
func TestReporter_BuildRequest(t *testing.T) {
var r *reporter
r = NewReporter("http://localhost:8010/wavefront", token.NewNoopTokenService(), &http.Client{}).(*reporter)
request, err := r.buildRequest("wavefront", nil)
require.NoError(t, err)
assert.Equal(t, "http://localhost:8010/wavefront/report?f=wavefront", request.URL.String())
}

func TestNewClientWithNilTLSConfig(t *testing.T) {
func TestNewClient_WithNilTLSConfig(t *testing.T) {
client := NewClient(10*time.Second, nil)
assert.Equal(t, nil, client.Transport)
}

func TestNewClientWithCustomTLSConfig(t *testing.T) {
func TestNewClient_WithCustomTLSConfig(t *testing.T) {
caCertPool := x509.NewCertPool()
fakeCert := []byte("Not a real cert")
caCertPool.AppendCertsFromPEM(fakeCert)
Expand All @@ -41,5 +41,4 @@ func TestNewClientWithCustomTLSConfig(t *testing.T) {
client := NewClient(10*time.Second, tlsConfig)
assert.Equal(t, transport, client.Transport)
assert.NotEqual(t, transportWithEmptyTLSConfig, client.Transport)

}
85 changes: 48 additions & 37 deletions internal/token/csp_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,38 +9,43 @@ import (
"time"
)

type CSPService struct {
client csp.Client
mutex sync.Mutex
AccessToken string
tokenReady bool
type tokenResult struct {
accessToken string
err error
}

type CSPService struct {
client csp.Client
mutex sync.Mutex
tokenResult *tokenResult
tokenReady bool
ticker *time.Ticker
done chan bool
tickerInterval time.Duration
lastStatus error
}

// NewCSPServerToServerService returns a Service instance that gets access tokens via CSP
// NewCSPServerToServerService returns a Service instance that gets access tokens via CSP client credentials
func NewCSPServerToServerService(CSPBaseUrl string, ClientId string, ClientSecret string) Service {
return &CSPService{
client: &csp.ClientCredentialsClient{
BaseURL: CSPBaseUrl,
ClientID: ClientId,
ClientSecret: ClientSecret,
},
tickerInterval: 60 * time.Second,
}
return newService(&csp.ClientCredentialsClient{
BaseURL: CSPBaseUrl,
ClientID: ClientId,
ClientSecret: ClientSecret,
})
}

func NewCSPTokenService(CSPBaseUrl, apiToken string) Service {
return &CSPService{
client: &csp.APITokenClient{
BaseURL: CSPBaseUrl,
APIToken: apiToken,
},
return newService(&csp.APITokenClient{
BaseURL: CSPBaseUrl,
APIToken: apiToken,
})
}

func newService(client csp.Client) Service {
s := &CSPService{
client: client,
tickerInterval: 60 * time.Second,
}
return s
}

func (s *CSPService) IsDirect() bool {
Expand All @@ -51,55 +56,58 @@ func (s *CSPService) Authorize(r *http.Request) error {
s.mutex.Lock()
defer s.mutex.Unlock()

if s.lastStatus != nil {
return s.lastStatus
if s.tokenResult == nil {
s.RefreshAccessToken()
}

if !s.tokenReady {
s.RefreshAccessToken()
if s.tokenResult.err != nil {
return s.tokenResult.err
}

r.Header.Set("Authorization", "Bearer "+s.AccessToken)
r.Header.Set("Authorization", "Bearer "+s.tokenResult.accessToken)
return nil
}

func (s *CSPService) RefreshAccessToken() {
s.AccessToken = ""
s.tokenReady = false
cspResponse, err := s.client.GetAccessToken()

if err != nil {
s.lastStatus = err
s.tokenResult = &tokenResult{
accessToken: "",
err: err,
}
return
}

if !csp.HasDirectIngestScope(cspResponse.Scope) {
s.lastStatus = fmt.Errorf("response did not include required scope: 'aoa:directDataIngestion'")
s.tokenResult = &tokenResult{
accessToken: "",
err: fmt.Errorf("response did not include required scope: 'aoa:directDataIngestion'"),
}
return
}

s.startOrResetTicker(time.Duration(cspResponse.ExpiresIn) * time.Second)

s.AccessToken = cspResponse.AccessToken
s.lastStatus = nil
s.tokenReady = true
s.scheduleNextTokenRefresh(time.Duration(cspResponse.ExpiresIn) * time.Second)
s.tokenResult = &tokenResult{
accessToken: cspResponse.AccessToken,
err: nil,
}
}

func (s *CSPService) startOrResetTicker(expiresIn time.Duration) {
func (s *CSPService) scheduleNextTokenRefresh(expiresIn time.Duration) {
tickerInterval := calculateNewTickerInterval(expiresIn, s.tickerInterval)

if s.ticker == nil {
s.ticker = time.NewTicker(tickerInterval)
s.done = make(chan bool)

go func() {
for {
select {
case <-s.done:
return
case tick := <-s.ticker.C:
s.mutex.Lock()
log.Println("Re-fetching CSP credentials at", tick)
log.Printf("Re-fetching CSP credentials at: %v \n", tick)
s.RefreshAccessToken()
s.mutex.Unlock()
}
Expand All @@ -112,5 +120,8 @@ func (s *CSPService) startOrResetTicker(expiresIn time.Duration) {

func (s *CSPService) Close() {
log.Println("Shutting down the CSPService")
if s.ticker == nil {
return
}
s.done <- true
}
20 changes: 17 additions & 3 deletions internal/token/csp_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"time"
)

func TestMultipleCSPRequests(t *testing.T) {
func TestCSPService_MultipleCSPRequests(t *testing.T) {
cspServer := httptest.NewServer(csp.FakeCSPHandler(nil))
defer cspServer.Close()
tokenService := NewCSPServerToServerService(cspServer.URL, "a", "b")
Expand All @@ -35,7 +35,21 @@ func TestMultipleCSPRequests(t *testing.T) {
assert.NotEmpty(t, token)
assert.NotEqual(t, "INVALID_TOKEN", token)
assert.Equal(t, "Bearer def", token)

time.Sleep(10 * time.Millisecond)
tokenService.Close()
}

func TestCSPService_WhenAuthenticationFailsAuthorizeReturnsError(t *testing.T) {
cspServer := httptest.NewServer(csp.FakeCSPHandler(nil))
defer cspServer.Close()
tokenService := NewCSPServerToServerService(cspServer.URL, "nope", "wrong")
defer tokenService.Close()

cspTokenService := tokenService.(*CSPService)
cspTokenService.tickerInterval = 1 * time.Second

assert.NotNil(t, tokenService)
req, _ := http.NewRequest("GET", "https://example.com", nil)
assert.Error(t, tokenService.Authorize(req))
token := req.Header.Get("Authorization")
assert.Equal(t, "", token)
}

0 comments on commit 66bfa91

Please sign in to comment.